From b70ee72c9b3c9d868ddc3c7d4d8e28d7d8d27c01 Mon Sep 17 00:00:00 2001
From: zhangtao
Date: Fri, 18 Jul 2025 09:22:43 +0000
Subject: [PATCH] HunyuanVideo online quantization matmul demo

---
 .../hyvideo/modules/mlp_layers.py         | 11 +++++--
 .../HunyuanVideo/hyvideo/utils/helpers.py | 33 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py
index 10dd9a85e2..31f325efda 100644
--- a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py
+++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 
 from .modulate_layers import modulate
-from ..utils.helpers import to_2tuple
+from ..utils.helpers import to_2tuple, dynamic_quant_matmul_func
 
 
 class MLP(nn.Module):
@@ -49,7 +49,12 @@ class MLP(nn.Module):
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
-        x = self.fc1(x)
+        # == Float Matmul ==
+        # x = self.fc1(x)
+
+        # == Dynamic Quant Matmul ==
+        x = dynamic_quant_matmul_func(x, self.fc1.weight, self.fc1.bias)
+
         x = self.act(x)
         x = self.drop1(x)
         x = self.norm(x)
@@ -113,4 +118,4 @@ class FinalLayer(nn.Module):
         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
         x = modulate(self.norm_final(x), shift=shift, scale=scale)
         x = self.linear(x)
-        return x
\ No newline at end of file
+        return x
diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/utils/helpers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/utils/helpers.py
index 5e1aadf118..7e8b02a1f8 100644
--- a/MindIE/MultiModal/HunyuanVideo/hyvideo/utils/helpers.py
+++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/utils/helpers.py
@@ -2,6 +2,8 @@ import collections.abc
 
 from itertools import repeat
 
+import torch
+import torch_npu
 
 def _ntuple(n):
     def parse(x):
@@ -36,4 +38,33 @@ def as_list_of_2tuple(x):
     lst = []
     for i in range(0, len(x), 2):
         lst.append((x[i], x[i + 1]))
-    return lst
\ No newline at end of file
+    return lst
+
+def dynamic_quant_matmul_func(x, weight, bias=None):
+    # npu_dynamic_quant expects float16 or bfloat16 input; cast anything else to bfloat16
+    if x.dtype not in (torch.float16, torch.bfloat16):
+        x = x.to(torch.bfloat16)
+
+    # Get batch size and sequence length
+    b, s = x.shape[0], x.shape[1]
+
+    # Dynamically quantize activations and weight to int8, with per-row scales
+    # Official Document: https://www.hiascend.com/document/detail/zh/Pytorch/700/apiref/apilist/ptaoplist_001228.html
+    x_int8, x_scale = torch_npu.npu_dynamic_quant(x.reshape(b * s, -1))
+    w_int8, w_scale = torch_npu.npu_dynamic_quant(weight)
+
+    # Transpose weight to [in_features, out_features] for quant matmul
+    w_int8 = w_int8.T
+
+    # Quantized matmul: int8 inputs, float16/bfloat16 output
+    # Official Document: https://www.hiascend.com/document/detail/zh/Pytorch/700/apiref/apilist/ptaoplist_000172.html
+    x = torch_npu.npu_quant_matmul(
+        x_int8, w_int8, w_scale,
+        pertoken_scale=x_scale,
+        output_dtype=x.dtype,
+        bias=bias
+    )
+
+    # Reshape the output back to [batch size, sequence length, hidden size]
+    x = x.reshape(b, s, -1)
+    return x
-- 
Gitee
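
Usage note (not part of the patch): a minimal sketch of how the new helper could be exercised on its own. It assumes an Ascend NPU environment where torch_npu is importable and an "npu" device is available, that the script runs from the HunyuanVideo root so the hyvideo package is importable, and that the tensor sizes below are illustrative placeholders, not HunyuanVideo's actual dimensions.

    import torch
    import torch_npu  # requires an Ascend NPU environment

    from hyvideo.utils.helpers import dynamic_quant_matmul_func

    # Shapes mirror a transformer MLP input: [batch, seq_len, hidden]
    x = torch.randn(2, 256, 3072, dtype=torch.bfloat16, device="npu")
    fc1 = torch.nn.Linear(3072, 12288, dtype=torch.bfloat16, device="npu")

    ref = fc1(x)                                              # float matmul baseline
    out = dynamic_quant_matmul_func(x, fc1.weight, fc1.bias)  # int8 dynamic-quant path

    assert out.shape == ref.shape
    # Per-token dynamic quantization introduces some error; the acceptable
    # gap is model-dependent and should be validated on real activations.
    print((out - ref).abs().mean())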