From 763b4b3951570d0d20f008c33737e707e330b48a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E7=AB=8B?= Date: Tue, 8 Jul 2025 06:19:58 +0000 Subject: [PATCH] fix(inductor): regist _npu_dtype_cast into lowering list --- torch_npu/_inductor/lowering.py | 6 +++++- torch_npu/_inductor/lowering_fx.py | 4 ++++ torch_npu/_inductor/lowering_op_list.py | 3 ++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py index 29ac8924a1..2b47e091af 100644 --- a/torch_npu/_inductor/lowering.py +++ b/torch_npu/_inductor/lowering.py @@ -33,7 +33,7 @@ from torch._inductor.lowering import ( add_layout_constraint ) import torch_npu -from torch_npu import npu_dtype_cast +from torch_npu import npu_dtype_cast, _npu_dtype_cast from .lowering_op_list import GENERATE_LIST, GENERATE_LIST2, FALLBACK_LIST, LOWERING_OVERLOAD_OP @@ -198,6 +198,10 @@ def _register_npu_inductor_fallbacks(): def _convert_npu_type(x: TensorBox, dtype: torch.dtype): return to_dtype(x, dtype, copy=True) + @register_lowering(_npu_dtype_cast, type_promotion_kind=None) + def _convert__npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + def var_mean_sum_(x, axis, correction, keepdim, return_mean): if correction is None: correction = 1 diff --git a/torch_npu/_inductor/lowering_fx.py b/torch_npu/_inductor/lowering_fx.py index 5084c29534..e9d3e3a0e9 100644 --- a/torch_npu/_inductor/lowering_fx.py +++ b/torch_npu/_inductor/lowering_fx.py @@ -2223,6 +2223,10 @@ def _register_npu_inductor_fallbacks(): def _convert_npu_type(x: TensorBox, dtype: torch.dtype): return to_dtype(x, dtype, copy=True) + @register_lowering(npu._npu_dtype_cast, type_promotion_kind=None) + def _convert__npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + def var_mean_sum_(x, axis, correction, keepdim, return_mean): if correction is None: correction = 1 diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index 0e8bb3a9a5..e750953242 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -1,5 +1,5 @@ import torch -from torch_npu import npu_dtype_cast +from torch_npu import npu_dtype_cast, _npu_dtype_cast aten = torch.ops.aten tr_c10d = torch.ops.tr_c10d @@ -56,6 +56,7 @@ GENERATE_LIST = [ aten.clamp_max, aten.mean, npu_dtype_cast, + _npu_dtype_cast, aten.select_scatter, aten.slice_scatter, prims.broadcast_in_dim, -- Gitee