From f3fb0d4fcbe30a3cffdcb0cd38bf07a48a14f9b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=94=B0=E9=87=8E?=
Date: Fri, 19 Jul 2024 11:06:00 +0800
Subject: [PATCH] fixed 2faeb78 from https://gitee.com/tianye0806/pytorch/pulls/13039 npu_quantize support int4

---
 torch_npu/csrc/framework/utils/CalcuOpUtil.cpp | 2 +-
 torch_npu/meta/_meta_registrations.py          | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp
index 99ee462357..6af98cdec2 100644
--- a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp
+++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp
@@ -48,7 +48,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(ENUM_PAIR_FUNC)
     _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED)   \
     _(at::ScalarType::QInt32, ACL_DT_UNDEFINED)   \
     _(at::ScalarType::BFloat16, ACL_BF16)         \
-    _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \
+    _(at::ScalarType::QUInt4x2, ACL_INT4)         \
     _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \
     _(at::ScalarType::Bits1x8, ACL_DT_UNDEFINED)  \
     _(at::ScalarType::Bits2x4, ACL_DT_UNDEFINED)  \
diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index fbae9ae3b7..58ec076231 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -688,6 +688,15 @@ def npu_quantize_meta(self, scales, zero_points, dtype, axis=1, div_mode=True):
         return torch.empty_like(self, dtype=torch.int8)
     elif dtype == torch.qint32:
         return torch.empty_like(self, dtype=torch.int32)
+    elif dtype == torch.quint4x2:
+        dim_num = self.dim()
+        if self.size(dim_num - 1) % 8:
+            raise RuntimeError("If dtype is quint4x2, the last dim must be divisible by 8.")
+        output_shape = []
+        for dim in range(dim_num - 1):
+            output_shape.append(self.size(dim))
+        output_shape.append(self.size(dim_num - 1) // 8)
+        return self.new_empty(output_shape, dtype=torch.int32)
     return torch.empty_like(self, dtype=torch.int8)
--
Gitee
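
For reference, below is a minimal CPU-only sketch of the shape rule the new quint4x2 branch encodes: eight 4-bit values are packed into one int32 output element, so the input's last dimension must be divisible by 8 and shrinks by a factor of 8. The helper name int4_packed_output_shape is illustrative only; it is not part of the patch or of the torch_npu API.

import torch

def int4_packed_output_shape(x: torch.Tensor) -> list:
    # Mirrors the quint4x2 branch added to npu_quantize_meta: eight int4
    # values are packed into a single int32 output element, so the last
    # dimension must be divisible by 8 and is reduced by a factor of 8.
    if x.size(-1) % 8:
        raise RuntimeError("If dtype is quint4x2, the last dim must be divisible by 8.")
    return list(x.shape[:-1]) + [x.size(-1) // 8]

# Example: a (4, 64) float input packed as int4 yields a (4, 8) int32 meta tensor.
x = torch.randn(4, 64)
print(int4_packed_output_shape(x))  # [4, 8]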