diff --git a/test/_inductor/test_permute_layernorm.py b/test/_inductor/test_permute_layernorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b69338fdb29e492112f5758f706729351380e1c
--- /dev/null
+++ b/test/_inductor/test_permute_layernorm.py
@@ -0,0 +1,47 @@
+import torch
+from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
+from testutils import OperatorType, TestUtils
+from torch._dynamo.testing import rand_strided
+import torch_npu
+
+
+class TestPermuteLayerNorm(TestUtils):
+    def forward(self, arg2_1, arg3_1, arg4_1, arg5_1):
+        unsqueeze = torch.ops.aten.unsqueeze.default(arg2_1, 1)
+        npu_dtype_cast_2 = torch.ops.npu.npu_dtype_cast.default(arg3_1, torch.float32)
+        npu_dtype_cast_3 = torch.ops.npu.npu_dtype_cast.default(arg4_1, torch.float32)
+        npu_dtype_cast_4 = torch.ops.npu.npu_dtype_cast.default(arg5_1, torch.float32)
+        clone = torch.ops.aten.clone.default(npu_dtype_cast_2, memory_format=torch.contiguous_format)
+        var_mean = torch.ops.aten.var_mean.correction(clone, [2], correction=0, keepdim=True)
+        getitem = var_mean[0]
+        getitem_1 = var_mean[1]
+        add = torch.ops.aten.add.Tensor(getitem, 1e-06)
+        rsqrt = torch.ops.aten.rsqrt.default(add)
+        sub = torch.ops.aten.sub.Tensor(clone, getitem_1)
+        mul_1 = torch.ops.aten.mul.Tensor(sub, rsqrt)
+        mul_2 = torch.ops.aten.mul.Tensor(mul_1, npu_dtype_cast_3)
+        add_1 = torch.ops.aten.add.Tensor(mul_2, npu_dtype_cast_4)
+        npu_dtype_cast_5 = torch.ops.npu.npu_dtype_cast.default(add_1, torch.float16)
+        add_2 = torch.ops.aten.add.Tensor(npu_dtype_cast_5, unsqueeze)
+        return add_2
+
+    def test_permute_layernorm_cases(self):
+        arg2 = rand_strided((2, 1408), (1408, 1), device='npu', dtype=torch.float32)
+        arg3 = rand_strided((2, 3840, 1408), (5406720, 1, 3840), device='npu', dtype=torch.float16)
+        arg4 = rand_strided((1408,), (1,), device='npu', dtype=torch.float16)
+        arg5 = rand_strided((1408,), (1,), device='npu', dtype=torch.float16)
+
+        std_result = self.forward(arg2, arg3, arg4, arg5)
+        compiled_op_calc = torch.compile(self.forward, backend="inductor")
+        inductor_result = compiled_op_calc(arg2, arg3, arg4, arg5)
+
+        rtol = 1e-2
+        atol = 1e-2
+        torch.testing.assert_close(std_result, inductor_result, equal_nan=True, rtol=rtol, atol=atol)
+
+
+instantiate_parametrized_tests(TestPermuteLayerNorm)
+
+if __name__ == "__main__":
+    test = TestPermuteLayerNorm()
+    test.test_permute_layernorm_cases()
diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py
index e39948eb657c63dc6a490555182af9c578cc20e7..ecf9bc7dc0bffe6041e109d181cde1f3ed9ca131 100644
--- a/torch_npu/_inductor/codegen/triton.py
+++ b/torch_npu/_inductor/codegen/triton.py
@@ -438,6 +438,7 @@ class NPUIndexTritonKernel(TritonKernel):
         self.golden_var_list = None
         self.reduce_analysis = None
         self.load_store_indexing = None
+        self.store_mask = {}

     def gen_triton_ext_imports(self):
         imports = IndentedBuffer()
@@ -867,6 +868,18 @@ class NPUIndexTritonKernel(TritonKernel):
         is_last_axis = index == len(self.sorted_axis) - 1
         indexing_code = getattr(range_val, "indexing_code")

+        if not self.first_node:
+            for mask in self.store_mask:
+                idx = self.store_mask[mask]
+                if idx == index:
+                    continue
+                if mask in str(self.body):
+                    continue
+                # splice in the indexing code of the axis that defines this store mask
+                other_axis_indexing_code = self.sorted_axis[idx].indexing_code
+                indexing_code.splice(other_axis_indexing_code)
+
+
         reduction_1d = is_1d_reduction()
         do_indent = False
         # do nothing except for writing porintwise
@@ -1002,6 +1015,18 @@ class NPUIndexTritonKernel(TritonKernel):
         value_str = f"{value}"
         mask_str = indexing.mask_str
+        if index_analyze.var_replacements:
+            for tmp_var in index.free_symbols:
+                if tmp_var not in index_analyze.var_replacements:
+                    continue
+
+                for mask in indexing.mask_vars:
+                    str_var = str(index_analyze.var_replacements[tmp_var])
+                    if str_var in mask:
+                        axis = self.range_tree_nodes[tmp_var]
+                        idx = self.sorted_axis.index(axis)
+                        self.store_mask[mask] = idx
+
         if index_analyze.need_permute:
             value_str = value_str.replace(f"{value}", f"{value}{index_analyze.generate_statement()}")
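
The patch works in two phases: in store(), when a store's index variable has been replaced (e.g. by a permute), the kernel records in self.store_mask which axis each affected mask variable belongs to; then, when generating later nodes, any recorded mask that is not already defined in the emitted body gets its owning axis's indexing code spliced in, so the mask variable exists at the point of the store. The following is a minimal, self-contained sketch of that pattern; Axis, Kernel, record_store_mask, and emit_axis_code are hypothetical stand-ins for the real sorted_axis entries, NPUIndexTritonKernel methods, and IndentedBuffer, not torch_npu APIs.

class Axis:
    def __init__(self, name):
        self.name = name
        # text that defines this axis's index variable and mask in the kernel body
        self.indexing_code = f"{name} = base_{name} + offset\n{name}_mask = {name} < {name}_numel\n"


class Kernel:
    def __init__(self, axes):
        self.sorted_axis = axes
        self.store_mask = {}  # mask name -> index of the axis that defines it
        self.body = ""

    def record_store_mask(self, mask_name, axis_index):
        # store() phase: a replaced index variable was found inside mask_name,
        # so remember which axis the mask really belongs to
        self.store_mask[mask_name] = axis_index

    def emit_axis_code(self, index):
        code = self.sorted_axis[index].indexing_code
        for mask, idx in self.store_mask.items():
            # splice in the owning axis's indexing code unless the mask is this
            # axis's own, or it is already defined in the emitted body
            if idx != index and mask not in self.body:
                code += self.sorted_axis[idx].indexing_code
        self.body += code


kernel = Kernel([Axis("x0"), Axis("x1")])
kernel.record_store_mask("x1_mask", 1)
kernel.emit_axis_code(0)  # x1's definition is spliced in so x1_mask exists
print(kernel.body)

Running the sketch shows x1's index and mask definitions emitted alongside x0's, which mirrors why the test above stores through a permuted (5406720, 1, 3840)-strided layout: the store mask of the transposed axis must be defined even when the current node is iterating a different axis.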