From 83bd81cd1044978036dced3fbe7668ff7b7a2047 Mon Sep 17 00:00:00 2001
From: 15591922239
Date: Thu, 26 Oct 2023 19:57:58 +0800
Subject: [PATCH 1/2] add MaskedSoftmaxWithRelPosBias onnx

---
 torch_npu/onnx/wrapper_onnx_ops.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 91d660d3d5..b80c24103a 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -636,6 +636,24 @@ class NPUFlashAttentionOP(torch.autograd.Function):
                     head_num_i=head_num, input_layout_s=input_layout, scale_f=scale, keep_prob_f=keep_prob,
                     pre_tockens_i=pre_tockens, next_tockens_i=next_tockens,
                     gen_mask_parallel_i=gen_mask_parallel, sync_i=sync)
+# npu_masked_softmax_with_rel_pos_bias(Tensor x, Tensor relative_pos_bias, Tensor? atten_mask=None, float scale_value=1.0, int inner_precision_mode=0)
+class NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch_npu._C._VariableFunctionsClass.npu_masked_softmax_with_rel_pos_bias(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, x: Tensor, relative_pos_bias: Tensor, atten_mask: Tensor, scale_value: float = 1.0,
+                 inner_precision_mode: int = 0):
+        if atten_mask is None:
+            atten_mask = g.op("Constant", value_t=torch.tensor([]).to(torch.float))
+        return g.op("npu::NPUMaskedSoftmaxWithRelPosBias", x, relative_pos_bias, atten_mask,
+                    scale_value_f=scale_value, inner_precision_mode_i=inner_precision_mode)
+
+
+def wrapper_npu_masked_softmax_with_rel_pos_bias(x, relative_pos_bias, atten_mask=None, scale_value=1.0, inner_precision_mode=0):
+    return NPUFlashAttentionOP.apply(x, relative_pos_bias, atten_mask, scale_value, inner_precision_mode)
 
 
 def wrapper_npu_flash_attention(query, key, value, head_num,
@@ -916,3 +934,4 @@ def add_onnx_ops():
     torch_npu.npu_mish = wrapper_npu_mish
     torch_npu.npu_rotary_mul = wrapper_npu_rotary_mul
     torch_npu.npu_flash_attention = wrapper_npu_flash_attention
+    torch_npu.npu_masked_softmax_with_rel_pos_bias = wrapper_npu_masked_softmax_with_rel_pos_bias
-- 
Gitee

From 073be81c465dee1391fd4f11c54b401656c98549 Mon Sep 17 00:00:00 2001
From: 15591922239
Date: Thu, 26 Oct 2023 20:03:18 +0800
Subject: [PATCH 2/2] add MaskedSoftmaxWithRelPosBias onnx

---
 torch_npu/onnx/wrapper_onnx_ops.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index b80c24103a..8a166a4af0 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -653,8 +653,7 @@ class NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function):
 
 
 def wrapper_npu_masked_softmax_with_rel_pos_bias(x, relative_pos_bias, atten_mask=None, scale_value=1.0, inner_precision_mode=0):
-    return NPUFlashAttentionOP.apply(x, relative_pos_bias, atten_mask, scale_value, inner_precision_mode)
-
+    return NPUMaskedSoftmaxWithRelPosBiasOP.apply(x, relative_pos_bias, atten_mask, scale_value, inner_precision_mode)
 
 def wrapper_npu_flash_attention(query, key, value, head_num,
                                 input_layout, pse=None, padding_mask=None, atten_mask=None,
-- 
Gitee
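
For reviewers, a minimal export sketch of the new wrapper. It is not part of the patch series: the module, tensor shapes, and opset version below are illustrative assumptions, and it presumes add_onnx_ops() has already been applied by whatever torch_npu initialization normally invokes it.

# Usage sketch only -- NOT part of the patches above; shapes are placeholders.
import torch
import torch_npu


class RelPosBiasSoftmax(torch.nn.Module):
    def forward(self, x, relative_pos_bias, atten_mask):
        # During torch.onnx.export this call is routed through
        # NPUMaskedSoftmaxWithRelPosBiasOP.symbolic and emits a single
        # npu::NPUMaskedSoftmaxWithRelPosBias node.
        return torch_npu.npu_masked_softmax_with_rel_pos_bias(
            x, relative_pos_bias, atten_mask,
            scale_value=1.0, inner_precision_mode=0)


model = RelPosBiasSoftmax().npu().eval()
x = torch.randn(1, 4, 8, 96, 96).npu()                  # hypothetical shapes
relative_pos_bias = torch.randn(1, 1, 8, 96, 96).npu()
atten_mask = torch.zeros(1, 4, 1, 96, 96).npu()
torch.onnx.export(model, (x, relative_pos_bias, atten_mask),
                  "npu_masked_softmax_with_rel_pos_bias.onnx",
                  input_names=["x", "relative_pos_bias", "atten_mask"],
                  opset_version=11)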