From 9f8866f7e8f46d045cd7e752dad2c892cac25544 Mon Sep 17 00:00:00 2001
From: zhengying
Date: Wed, 3 Jul 2024 21:58:57 +0800
Subject: [PATCH] Adapt antiquantv2 to int4
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/custom_ops/test_npu_anti_quant.py    |  4 ++
 test/onnx/test_pytorch_onnx_no_runtime.py | 78 +----------------------
 test/onnx/test_utility_funs.py            | 25 ++------
 test/onnx/test_wrapper_onnx_ops.py        | 21 ++++++
 torch_npu/meta/_meta_registrations.py     | 14 ++++
 torch_npu/onnx/wrapper_onnx_ops.py        |  7 +-
 6 files changed, 51 insertions(+), 98 deletions(-)

diff --git a/test/custom_ops/test_npu_anti_quant.py b/test/custom_ops/test_npu_anti_quant.py
index 8405b6eb44..a5cdf9b574 100644
--- a/test/custom_ops/test_npu_anti_quant.py
+++ b/test/custom_ops/test_npu_anti_quant.py
@@ -35,6 +35,10 @@ class TestAntiQuant(TestCase):
             [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], [np.float32, -1, [100]], torch.bfloat16, None],
             [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], [np.float32, -1, [100]], torch.float16, torch.int8],
             [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], [np.float32, -1, [100]], torch.bfloat16, torch.int8],
+            # [[np.int32, -1, [10, 25]], [np.float32, -1, [200]], [np.float32, -1, [200]], torch.float16, None],
+            # [[np.int32, -1, [10, 25]], [np.float32, -1, [200]], [np.float32, -1, [200]], torch.bfloat16, None],
+            # [[np.int32, -1, [10, 25]], [np.float32, -1, [200]], [np.float32, -1, [200]], torch.float16, torch.quint4x2],
+            # [[np.int32, -1, [10, 25]], [np.float32, -1, [200]], [np.float32, -1, [200]], torch.bfloat16, torch.quint4x2],
         ]
 
         for item in shape_format:
diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py
index e49e058907..6c3c837120 100644
--- a/test/onnx/test_pytorch_onnx_no_runtime.py
+++ b/test/onnx/test_pytorch_onnx_no_runtime.py
@@ -20,7 +20,7 @@ import pytorch_test_common
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.onnx import OperatorExportTypes, symbolic_helper, utils
+from torch.onnx import symbolic_helper, utils
 from torch.onnx._internal import registration
 from torch.testing._internal import common_quantization, common_utils, jit_utils
 import torch_npu
@@ -396,7 +396,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
         for node in graph.nodes():
             self.assertTrue(node.sourceRange())
 
-    @common_utils.skipIfCaffe2
     def test_clip_aten_fallback_due_exception(self):
         def bad_clamp(g, self, min, max):
             return symbolic_helper._onnx_unsupported("Bad boy!")
@@ -413,7 +412,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
         )
         self.assertAtenOp(onnx_model, "clamp", "Tensor")
 
-    @common_utils.skipIfCaffe2
     def test_clip_aten_fallback_explicit_request(self):
         class MyClip(torch.nn.Module):
             def forward(self, x):
@@ -963,58 +961,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
 
         torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
 
-    @common_utils.skipIfNoCaffe2
-    def test_caffe2_aten_fallback_must_fallback(self):
-        class ModelWithAtenNotONNXOp(torch.nn.Module):
-            def forward(self, x, y):
-                abcd = x + y
-                defg = torch.linalg.qr(abcd)
-                return defg
-
-        for operator_export_type in (
-            OperatorExportTypes.ONNX_ATEN,
-            OperatorExportTypes.ONNX_ATEN_FALLBACK,
-        ):
-            x = torch.rand(3, 4)
-            y = torch.rand(3, 4)
-            f = io.BytesIO()
-            torch.onnx.export(
-                ModelWithAtenNotONNXOp(),
-                (x, y),
-                f,
-                do_constant_folding=False,
-                operator_export_type=operator_export_type,
-                # support for linalg.qr was added in later op set versions.
-                opset_version=9,
-            )
-            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
-            self.assertAtenOp(onnx_model, "linalg_qr")
-
-    @common_utils.skipIfNoCaffe2
-    def test_caffe2_onnx_aten_must_not_fallback(self):
-        class ModelWithAtenFmod(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.fmod(x, y)
-
-        for operator_export_type in (
-            OperatorExportTypes.ONNX_ATEN_FALLBACK,
-            OperatorExportTypes.ONNX_ATEN,
-        ):
-            x = torch.randn(3, 4, dtype=torch.float32)
-            y = torch.randn(3, 4, dtype=torch.float32)
-            f = io.BytesIO()
-            torch.onnx.export(
-                ModelWithAtenFmod(),
-                (x, y),
-                f,
-                do_constant_folding=False,
-                operator_export_type=operator_export_type,
-                opset_version=10,  # or higher
-            )
-            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
-            assert onnx_model.graph.node[0].op_type == "Mod"
-
-    @common_utils.skipIfCaffe2
     def test_aten_fallback_must_fallback(self):
         class ModelWithAtenNotONNXOp(torch.nn.Module):
             def forward(self, x, y):
@@ -1037,7 +983,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
         self.assertAtenOp(onnx_model, "linalg_qr")
 
-    @common_utils.skipIfCaffe2
    def test_onnx_aten(self):
         class ModelWithAtenFmod(torch.nn.Module):
             def forward(self, x, y):
@@ -1056,7 +1001,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
         self.assertAtenOp(onnx_model, "fmod", "Tensor")
 
-    @common_utils.skipIfCaffe2
     def test_onnx_aten_fallback_must_not_fallback(self):
         # For BUILD_CAFFE2=0, aten fallback only when not exportable
         class ONNXExportable(torch.nn.Module):
@@ -1232,26 +1176,6 @@ class TestQuantizeEagerONNXExport(common_utils.TestCase):
 
         _export_to_onnx(model, data, input_names)
 
-    @common_quantization.skipIfNoFBGEMM
-    @common_utils.skipIfNoCaffe2
-    def test_lower_graph_linear(self):
-        model = torch.ao.quantization.QuantWrapper(
-            torch.nn.Linear(5, 10, bias=True)
-        ).to(dtype=torch.float)
-        data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
-        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
-        self._test_lower_graph_impl(model, data)
-
-    @common_quantization.skipIfNoFBGEMM
-    @common_utils.skipIfNoCaffe2
-    def test_lower_graph_conv2d(self):
-        model = torch.ao.quantization.QuantWrapper(
-            torch.nn.Conv2d(3, 5, 2, bias=True)
-        ).to(dtype=torch.float)
-        data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
-        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
-        self._test_lower_graph_impl(model, data)
-
     @common_quantization.skipIfNoFBGEMM
     @unittest.skip(
         "onnx opset9 does not support quantize_per_tensor and caffe2 \
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index ea024e4d5d..3076ab86b1 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -25,7 +25,7 @@ from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
 from torch.onnx._globals import GLOBALS
 from torch.onnx.symbolic_helper import _unpack_list, parse_args
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
+from torch.testing._internal.common_utils import skipIfNoLapack
 from verify import verify
 from url import get_url
 import torch_npu
@@ -1362,6 +1362,8 @@ class TestUtilityFuns(_BaseTestCase):
         iter_ = graph.nodes()
         self.assertEqual(next(iter_).kind(), "custom_namespace::custom_op")
 
+    # gelu is exported as onnx::Gelu for opset >= 20
+    @skipIfUnsupportedMaxOpsetVersion(19)
     def test_custom_opsets_gelu(self):
         self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
 
@@ -1386,6 +1388,8 @@ class TestUtilityFuns(_BaseTestCase):
         self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
         self.assertEqual(graph.opset_import[1].version, 1)
 
+    # gelu is exported as onnx::Gelu for opset >= 20
+    @skipIfUnsupportedMaxOpsetVersion(19)
     def test_register_aten_custom_op_symbolic(self):
         self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
 
@@ -1626,25 +1630,6 @@ class TestUtilityFuns(_BaseTestCase):
             "Graph parameter names does not match model parameters.",
         )
 
-    @skipIfNoCaffe2
-    def test_modifying_params(self):
-        class MyModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.param = torch.nn.Parameter(torch.tensor([2.0]))
-
-            def forward(self, x):
-                y = x * x
-                self.param.data.add_(1.0)
-                return y
-
-        x = torch.tensor([1, 2])
-        # Move import to local as caffe2 backend requires additional build flag,
-        # and is only used in this test case.
-        import caffe2.python.onnx.backend as backend
-
-        verify(MyModel(), x, backend, do_constant_folding=False)
-
     def test_fuse_conv_bn(self):
         class Fuse(torch.nn.Module):
             def __init__(self):
diff --git a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py
index ced26ab6aa..3ad78ad3ad 100644
--- a/test/onnx/test_wrapper_onnx_ops.py
+++ b/test/onnx/test_wrapper_onnx_ops.py
@@ -1380,6 +1380,27 @@ class TestOnnxOps(TestCase):
         export_onnx(onnx_model_name)
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name)))
 
+    # @SupportedDevices(['Ascend910B'])
+    # def test_wrapper_npu_anti_quant_s42bf16(self):
+    #     class Model(torch.nn.Module):
+    #         def __init__(self):
+    #             super().__init__()
+
+    #         def forward(self, x, scale, offset=None, dst_dtype=torch.bfloat16, src_dtype=torch.quint4x2):
+    #             return torch_npu.npu_anti_quant(x, scale, offset=offset, dst_dtype=dst_dtype, src_dtype=src_dtype)
+
+    #     def export_onnx(onnx_model_name):
+    #         x = torch.randint(low=-128, high=127, size=(10, 1), dtype=torch.int32).npu()
+    #         scale = torch.randn((1,), dtype=torch.bfloat16).npu()
+    #         offset = torch.randn((1,), dtype=torch.bfloat16).npu()
+    #         model = Model().to("npu")
+    #         model(x, scale, offset, None, None)
+    #         self.onnx_export(model, (x, scale, offset, None, None), onnx_model_name)
+
+    #     onnx_model_name = "mode_npu_anti_quant_s42bf16.onnx"
+    #     export_onnx(onnx_model_name)
+    #     assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name)))
+
     def test_wrapper_npu_quantize(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index ca4390ef8a..cef4152112 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -583,6 +583,20 @@ def npu_anti_quant_meta(x, scale, *, offset=None, dst_dtype=None, src_dtype=None
         return torch.empty_like(x, dtype=torch.float16)
     return torch.empty_like(x, dtype=dst_dtype)
 
+    # if dst_dtype is None:
+    #     dst_dtype = torch.float16
+
+    # if x.dtype == torch.qint32:
+    #     x_shape = x.size()
+    #     if len(x_shape) == 0:
+    #         y_shape = (8,)
+    #     else:
+    #         y_shape = (*(x_shape[:-1]), x_shape[-1] * 8)
+    #     y = x.new_empty(y_shape, dtype=dst_dtype)
+    #     return torch.empty_like(y)
+    # else:
+    #     return torch.empty_like(x, dtype=dst_dtype)
+
 
 @impl(m, "npu_apply_rotary_pos_emb")
 def npu_apply_rotary_pos_emb_meta(query, key, cos, sin, layout=1):
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 3d173fac91..60ec07ee09 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -785,11 +785,16 @@ class NPUAntiQuantOP(torch.autograd.Function):
         else:
             raise TypeError("The argument 'dst_dtype' must be torch.float16 or torch.bfloat16." +
                             pta_error(ErrCode.TYPE))
-
+
         if src_dtype is None or src_dtype == torch.int8:
             src_dtype = 2
         else:
             raise TypeError("The argument 'src_dtype' must be torch.int8." + pta_error(ErrCode.TYPE))
+        # elif src_dtype == torch.quint4x2:
+        #     src_dtype = 29
+        # else:
+        #     raise TypeError("The argument 'src_dtype' must be torch.int8 or torch.quint4x2. " +
+        #                     pta_error(ErrCode.TYPE))
 
         return g.op("npu::NPUAntiQuant", x, scale, offset, dst_dtype_i=dst_dtype,
                     src_dtype_i=src_dtype)
--
Gitee
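
Background on the int4 path this patch sketches in comments: npu_anti_quant treats an int32 input as eight packed int4 values per element, so the dequantized output grows its last dimension by a factor of 8, and the ONNX wrapper encodes src_dtype torch.int8 as 2 (the commented-out branch would encode torch.quint4x2 as 29). Below is a minimal standalone Python sketch of that shape rule, mirroring the commented-out npu_anti_quant_meta branch; the helper name is illustrative only and is not part of the patch or of the torch_npu API.

    import torch

    def int4_anti_quant_out_shape(x: torch.Tensor) -> tuple:
        # Each int32 element packs eight int4 values, so dequantization expands
        # the last dimension by 8; a 0-d input unpacks to shape (8,).
        shape = tuple(x.size())
        if len(shape) == 0:
            return (8,)
        return (*shape[:-1], shape[-1] * 8)

    # A [10, 25] int32 input therefore pairs with scale/offset tensors of
    # shape [200], matching the commented-out shape_format entries (25 * 8 == 200).
    x = torch.randint(-128, 127, (10, 25), dtype=torch.int32)
    assert int4_anti_quant_out_shape(x) == (10, 200)

    # src_dtype codes used by the ONNX wrapper: torch.int8 -> 2 today; the
    # commented-out branch would add torch.quint4x2 -> 29 (assumption drawn
    # from the patch text, not from released torch_npu behaviour).
    SRC_DTYPE_CODES = {torch.int8: 2, torch.quint4x2: 29}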