diff --git a/test/_inductor/test_cat.py b/test/_inductor/test_cat.py
index 26d89caaa8dd679975d020621fdeafbb79f78c0a..c75e0fd7f6885cfb70276ca1e5b633b5d820a1e7 100644
--- a/test/_inductor/test_cat.py
+++ b/test/_inductor/test_cat.py
@@ -20,6 +20,48 @@ class TestCat(TestUtils):
         inductor_cat = compiled_op_calc(input_element, dim)
         self.assertEqual(std_cat, inductor_cat, atol=1e-1, rtol=1e-1, equal_nan=True)
 
+    def op_calc_non_contiguous(self, input_element, dim):
+        return torch.cat([input_element, input_element], dim)
+
+    @parametrize('shape', [(8, 16, 32, 64)])
+    @parametrize('dim', [1])
+    @parametrize('dtype', ['bfloat16'])
+    def test_cat_non_contiguous(self, shape, dim, dtype):
+        input_element = self._generate_tensor(shape, dtype)
+        input_element = input_element.transpose(-1, -2)
+        std_cat = self.op_calc_non_contiguous(input_element, dim)
+        compiled_op_calc = torch.compile(self.op_calc_non_contiguous, backend="inductor")
+        inductor_cat = compiled_op_calc(input_element, dim)
+        self.assertEqual(std_cat, inductor_cat, atol=1e-4, rtol=1e-4, equal_nan=True)
+
+    class PatternModel(torch.nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, *xs):
+            slices = [x[..., :sz] for x, sz in zip(xs, (128, 32, 48, 48, 48, 48, 48))]
+            output_tensor = torch.cat(slices, self.dim)
+
+            return output_tensor
+
+    @parametrize('shape', [(128, 50, 128)])
+    @parametrize('dim', [2])
+    @parametrize('dtype', ['float32', 'bfloat16'])
+    def test_model_input_is_concat(self, shape, dim, dtype):
+        inputs = [self._generate_tensor(shape, dtype) for _ in range(7)]
+
+        model = self.PatternModel(dim).to(dtype=getattr(torch, dtype))
+        model.eval()
+        with torch.no_grad():
+            eager_out = model(*inputs)
+
+        compiled_model = torch.compile(model, backend="inductor")
+        with torch.no_grad():
+            inductor_out = compiled_model(*inputs)
+
+        self.assertEqual(eager_out, inductor_out,
+                         atol=1e-4, rtol=1e-4, equal_nan=True)
 
 
 instantiate_parametrized_tests(TestCat)
diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py
index a6779c7c602bfbfd9854d06e9bc8e211a9082846..2c65ab65727026e38f0395feb22615075ca57e6e 100644
--- a/torch_npu/_inductor/__init__.py
+++ b/torch_npu/_inductor/__init__.py
@@ -20,7 +20,7 @@ from .lowering import make_reduction, npu_make_fallback
 from .npu_choices import should_use_persistent_reduction
 from .npu_device import NewNPUDeviceOpOverrides
 from .runtime import _load_cached_autotuning
-from .utils import get_current_raw_stream, patch_is_gpu, patch_has_triton
+from .utils import get_current_raw_stream, patch_is_gpu, patch_has_triton, disable_foreach
 from .codecache import patch_aot_code_compiler_compile, patch_cache_base_get_system
 
 set_compile_threads()
@@ -106,6 +106,5 @@ register_fa_pass()
 patch_cache_base_get_system()
 patch_is_gpu()
 patch_has_triton()
-
-
+disable_foreach()
 
diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py
index 2b47e091af8f49d3662f4c613a97b505f3e9266b..2e98a9b55bb89d9573be9877edc2350d223ccf32 100644
--- a/torch_npu/_inductor/lowering.py
+++ b/torch_npu/_inductor/lowering.py
@@ -5,12 +5,13 @@ from torch._inductor import lowering
 from torch._inductor.decomposition import decompositions, pw_cast_for_opmath
 from torch._inductor.ir import ExpandView, TensorBox, ops_wrapper
 from torch._inductor.ir import Reduction
-from torch._inductor.lowering import sum_
+from torch._inductor.lowering import sum_, clone
 from torch._inductor.utils import sympy_product
 from torch._prims_common import (
     is_boolean_dtype,
     is_integer_dtype,
     get_computation_dtype,
+    ELEMENTWISE_TYPE_PROMOTION_KIND,
 )
 from torch._inductor.lowering import (
     lowerings,
@@ -30,7 +31,10 @@ from torch._inductor.lowering import (
     _make_reduction_inner,
     _validate_reduction_axis,
     add_needs_realized_inputs,
-    add_layout_constraint
+    add_layout_constraint,
+    require_channels_last,
+    _validate_dim,
+    get_promoted_dtype,
 )
 import torch_npu
 from torch_npu import npu_dtype_cast, _npu_dtype_cast
@@ -262,7 +266,16 @@ def _register_npu_inductor_fallbacks():
 
     @register_lowering(aten.cat)
     def cat(inputs, dim=0):
-        return fallback_handler(aten.cat.default)(inputs, dim)
+        if len(inputs) == 1:
+            return clone(inputs[0])
+        dim = _validate_dim(inputs[0], dim, 0)
+        dtype = get_promoted_dtype(
+            *inputs,
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+
+        )
+        inputs = [to_dtype(inp, dtype) for inp in inputs]
+        return TensorBox(ir.ConcatKernel.create(inputs, dim))
 
     make_fallback(aten._log_softmax)
     make_fallback(aten.gather)
diff --git a/torch_npu/_inductor/lowering_fx.py b/torch_npu/_inductor/lowering_fx.py
index 78e20064570ccf7c1978ab23092f119002c0929e..a272144e14143b44b25b5bc8c2f07f2ade62f202 100644
--- a/torch_npu/_inductor/lowering_fx.py
+++ b/torch_npu/_inductor/lowering_fx.py
@@ -2298,7 +2298,16 @@ def _register_npu_inductor_fallbacks():
 
     @register_lowering(aten.cat)
     def cat(inputs, dim=0):
-        return lowering.fallback_handler(aten.cat.default)(inputs, dim)
+        if len(inputs) == 1:
+            return clone(inputs[0])
+        dim = _validate_dim(inputs[0], dim, 0)
+        dtype = lowering.get_promoted_dtype(
+            *inputs,
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+
+        )
+        inputs = [to_dtype(inp, dtype) for inp in inputs]
+        return TensorBox(ir.ConcatKernel.create(inputs, dim))
 
     lowering.make_fallback(aten._log_softmax)
     lowering.make_fallback(aten.gather)
diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py
index 095f1f69cf2bff023c2ee492c726940d801c75c5..2caf6e3e5f57a480e4580c116805a39db8c04770 100644
--- a/torch_npu/_inductor/utils.py
+++ b/torch_npu/_inductor/utils.py
@@ -76,3 +76,10 @@ def patch_has_triton():
     torch._inductor.scheduler.has_triton = has_triton
 
 
+def disable_foreach():
+    from torch._inductor.scheduler import Scheduler
+
+    def create_foreach_nodes(self):
+        return
+
+    Scheduler.create_foreach_nodes = create_foreach_nodes
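For reference, a minimal sketch of how the new lowering path can be exercised outside the test harness, mirroring `test_cat_non_contiguous` above. It assumes a working `torch_npu` install with an NPU device available; the `cat_twice` helper is hypothetical and not part of the patch, while the shape, dtype, and tolerances are taken from the added test.

```python
import torch
import torch_npu  # registers the NPU backend and applies the inductor patches from __init__.py


def cat_twice(x, dim):
    # same pattern as op_calc_non_contiguous in the new test
    return torch.cat([x, x], dim)


# non-contiguous bfloat16 input, as in test_cat_non_contiguous
x = torch.randn(8, 16, 32, 64, dtype=torch.bfloat16, device="npu").transpose(-1, -2)

eager_out = cat_twice(x, 1)
compiled_out = torch.compile(cat_twice, backend="inductor")(x, 1)

# with the new cat lowering this should go through ir.ConcatKernel instead of the aten fallback
torch.testing.assert_close(compiled_out, eager_out, atol=1e-4, rtol=1e-4, equal_nan=True)
```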