diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py
index 29ac8924a116041f77a02e44c046dad52ce4884d..2b47e091af8f49d3662f4c613a97b505f3e9266b 100644
--- a/torch_npu/_inductor/lowering.py
+++ b/torch_npu/_inductor/lowering.py
@@ -33,7 +33,7 @@ from torch._inductor.lowering import (
     add_layout_constraint
 )
 import torch_npu
-from torch_npu import npu_dtype_cast
+from torch_npu import npu_dtype_cast, _npu_dtype_cast

 from .lowering_op_list import GENERATE_LIST, GENERATE_LIST2, FALLBACK_LIST, LOWERING_OVERLOAD_OP

@@ -198,6 +198,10 @@ def _register_npu_inductor_fallbacks():
     def _convert_npu_type(x: TensorBox, dtype: torch.dtype):
         return to_dtype(x, dtype, copy=True)

+    @register_lowering(_npu_dtype_cast, type_promotion_kind=None)
+    def _convert__npu_type(x: TensorBox, dtype: torch.dtype):
+        return to_dtype(x, dtype, copy=True)
+
     def var_mean_sum_(x, axis, correction, keepdim, return_mean):
         if correction is None:
             correction = 1
diff --git a/torch_npu/_inductor/lowering_fx.py b/torch_npu/_inductor/lowering_fx.py
index 5084c29534afdc2760c8baa6a987b59a215bd821..e9d3e3a0e97f376526d2713ea38ec345fe17721f 100644
--- a/torch_npu/_inductor/lowering_fx.py
+++ b/torch_npu/_inductor/lowering_fx.py
@@ -2223,6 +2223,10 @@ def _register_npu_inductor_fallbacks():
     def _convert_npu_type(x: TensorBox, dtype: torch.dtype):
         return to_dtype(x, dtype, copy=True)

+    @register_lowering(npu._npu_dtype_cast, type_promotion_kind=None)
+    def _convert__npu_type(x: TensorBox, dtype: torch.dtype):
+        return to_dtype(x, dtype, copy=True)
+
     def var_mean_sum_(x, axis, correction, keepdim, return_mean):
         if correction is None:
             correction = 1
diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py
index 0e8bb3a9a53234e6d5e5ac228eeaf8b101fe66c9..e750953242662e80d8b38c07366c2d5cf8d3f15d 100644
--- a/torch_npu/_inductor/lowering_op_list.py
+++ b/torch_npu/_inductor/lowering_op_list.py
@@ -1,5 +1,5 @@
 import torch
-from torch_npu import npu_dtype_cast
+from torch_npu import npu_dtype_cast, _npu_dtype_cast

 aten = torch.ops.aten
 tr_c10d = torch.ops.tr_c10d
@@ -56,6 +56,7 @@ GENERATE_LIST = [
     aten.clamp_max,
     aten.mean,
     npu_dtype_cast,
+    _npu_dtype_cast,
     aten.select_scatter,
     aten.slice_scatter,
     prims.broadcast_in_dim,