diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py index 19357a41307c3803363b4d174cc98a8a77b6daf1..1689e1f55da4ea30af241d6ff9ff4eb002ef1493 100644 --- a/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/__init__.py @@ -1,4 +1,4 @@ from .pipelines import CogVideoXPipeline -from .models import CogVideoXTransformer3DModel +from .models import CogVideoXTransformer3DModel, AdaStep from .utils import get_world_size, get_rank, all_gather, set_parallel from .utils import get_sp_world_size, get_sp_rank, get_dp_rank, get_dp_world_size, get_sp_group, get_dp_group \ No newline at end of file diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py index a267e101cd0c03bcc4f076ed254a02309fb22712..bd6e63e9be3d073d886967829dfa8e924e19f844 100644 --- a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/__init__.py @@ -1 +1,2 @@ from .transformers import CogVideoXTransformer3DModel +from .sampling_optm import AdaStep diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/linear.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..c871bb867f6a9f1195f7f6672f5bdf8b24b0c7a5 --- /dev/null +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/models/linear.py @@ -0,0 +1,97 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
import torch
import torch.nn as nn


class QKVLinear(nn.Module):
    """Fused query/key/value projection for self- or cross-attention.

    In self-attention mode (``cross_attention_dim is None``) a single weight of
    shape ``[attention_dim, 3 * hidden_size]`` produces q, k and v in one
    matmul.  In cross-attention mode q is projected from ``hidden_states`` with
    ``q_weight`` and k/v are projected together from ``encoder_hidden_states``
    with ``kv_weight``.

    Weights are stored as ``[in_features, out_features]`` (inputs are
    right-multiplied by the weight), and are allocated with ``torch.empty``:
    they hold arbitrary values until real weights are copied/loaded into them.

    Args:
        attention_dim: Feature size of ``hidden_states``.
        hidden_size: Output feature size of each of q (and of k/v in
            self-attention mode).
        qkv_bias: Whether bias parameters are created and added.
        cross_attention_dim: Feature size of ``encoder_hidden_states``; ``None``
            selects self-attention mode.
        cross_hidden_size: Output feature size of each of k and v in
            cross-attention mode; defaults to ``hidden_size``.
        device: Device passed to ``torch.empty`` for the parameters.
        dtype: Dtype passed to ``torch.empty`` for the parameters.
    """

    def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None,
                 device=None, dtype=None):
        super().__init__()
        self.attention_dim = attention_dim
        self.hidden_size = hidden_size

        self.cross_attention_dim = cross_attention_dim
        self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size
        self.qkv_bias = qkv_bias

        factory_kwargs = {"device": device, "dtype": dtype}

        if cross_attention_dim is None:
            self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs))
            if self.qkv_bias:
                self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs))
        else:
            self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs))
            self.kv_weight = nn.Parameter(
                torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs))

            if self.qkv_bias:
                self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs))
                self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs))

    @staticmethod
    def _project(inputs, weight, bias):
        """Return ``inputs @ weight`` (+ ``bias``), using a fused addmm when biased."""
        if bias is None:
            return torch.matmul(inputs, weight)
        # addmm needs a 2-D input.  reshape (unlike view) also accepts
        # non-contiguous activations, e.g. the result of a transpose.
        flat = inputs.reshape(-1, inputs.size(-1))
        return torch.addmm(bias, flat, weight)

    def forward(self, hidden_states, encoder_hidden_states=None):
        """Project the inputs to query, key and value tensors.

        Args:
            hidden_states: 3-D tensor ``(batch, seq_len, attention_dim)``.
            encoder_hidden_states: 3-D tensor
                ``(batch, kv_seq_len, cross_attention_dim)``; required in
                cross-attention mode and unused otherwise.

        Returns:
            Tuple ``(q, k, v)``.  q is ``(batch, seq_len, hidden_size)``; k and
            v are ``(batch, seq_len, hidden_size)`` in self-attention mode and
            ``(batch, kv_seq_len, cross_hidden_size)`` in cross-attention mode.

        Raises:
            ValueError: In cross-attention mode when ``encoder_hidden_states``
                is ``None``.
        """
        batch, q_seqlen, _ = hidden_states.shape

        if self.cross_attention_dim is None:
            bias = self.bias if self.qkv_bias else None
            qkv = self._project(hidden_states, self.weight, bias)
            # The fused output is laid out as [q | k | v] along the last axis.
            q, k, v = qkv.reshape(batch, q_seqlen, 3, -1).unbind(2)
            return q, k, v

        if encoder_hidden_states is None:
            raise ValueError(
                "`encoder_hidden_states` is required when `cross_attention_dim` is set.")

        q_bias = self.q_bias if self.qkv_bias else None
        kv_bias = self.kv_bias if self.qkv_bias else None

        q = self._project(hidden_states, self.q_weight, q_bias)
        q = q.reshape(batch, q_seqlen, -1)

        kv_batch, kv_seqlen, _ = encoder_hidden_states.shape
        kv = self._project(encoder_hidden_states, self.kv_weight, kv_bias)
        # The fused output is laid out as [k | v] along the last axis.
        k, v = kv.reshape(kv_batch, kv_seqlen, 2, -1).unbind(2)

        return q, k, v
import functools
from functools import reduce

import torch


class AdaStep:
    """Adaptive step skipping for the sampling loop of a diffusion model.

    After every scheduler step the mean absolute change of the latents is
    recorded.  When the last three changes are (almost) linear, the next
    transformer forward is skipped and the previously cached noise prediction
    is returned instead.
    """

    def __init__(self, skip_thr=0.015, max_skip_steps=4, decay_ratio=0.99,
                 device="npu", forward_value=None, step_value=None):
        """
        Args:
            skip_thr (float): Relative threshold used to decide whether a step can be skipped.
                Recommended values are between 0.01 (skips around 35 steps) and 0.015
                (skips around 50 steps). Default is 0.015.
            max_skip_steps (int): The maximum number of consecutive steps that can be skipped.
                Recommended values are between 3 and 4. Default is 4. A value <= 0 disables skipping.
            decay_ratio (float): Factor the threshold is multiplied by on every estimate.
                Recommended values are between 0.95 and 0.99. Default is 0.99.
            device (str): The device on which the vote tensors are created. Default is "npu".
            forward_value: How to locate the noise prediction in the transformer output:
                ``tuple`` (first element), ``torch.Tensor`` (the output itself), a ``str`` key or an
                ``int`` index. ``None`` auto-detects tuple/tensor on the first call.
            step_value: Same convention as ``forward_value``, for the scheduler-step output.
        """
        self.skip_thr = skip_thr
        self.max_skip_steps = max_skip_steps
        self.decay_ratio = decay_ratio
        self.device = device
        self.if_skip = self.max_skip_steps > 0
        self.forwardretype = forward_value
        self.stepretype = step_value
        self.reset_status()

    def __call__(self, transformer, *model_args, **model_kwargs):
        """
        Args:
            transformer (Callable): The callable that works as the DiT and returns the noise prediction.
            model_args (tuple): The positional arguments passed to the transformer forward.
            model_kwargs (dict): The keyword arguments passed to the transformer forward.
        Returns:
            The transformer output, or the cached noise prediction re-wrapped in the
            same container kind when the current step is skipped.
        Raises:
            ValueError: If the output kind cannot be auto-detected and no
                ``forward_value`` was given.
        """
        # Only skip when there is a cached prediction and every rank voted to skip.
        if self.if_skip and self.skip_noise_pred is not None and torch.all(self.skip_vote):
            return self._return_output(self.skip_noise_pred, self.forwardretype)

        noise_pred = transformer(*model_args, **model_kwargs)
        # `is None` (not truthiness) so that an explicit index 0 is honoured.
        if self.forwardretype is None:
            if isinstance(noise_pred, tuple):
                self.forwardretype = tuple
            elif isinstance(noise_pred, torch.Tensor):
                self.forwardretype = torch.Tensor
            else:
                raise ValueError("Transformer needs return a tuple whose first element is the result, "
                                 "or return a tensor. In other cases, please enter `forward_value`.")
        self.skip_noise_pred = self._get_input(noise_pred, self.forwardretype)
        return noise_pred

    @staticmethod
    def _get_input(input_value, inp_type):
        """Extract the payload from ``input_value`` according to ``inp_type``."""
        if isinstance(inp_type, type):
            return input_value[0] if inp_type is tuple else input_value
        # Otherwise `inp_type` is a key or an index into the container.
        return input_value[inp_type]

    @staticmethod
    def _return_output(output, outptype):
        """Wrap ``output`` in the container kind described by ``outptype``."""
        if isinstance(outptype, type):
            return (output,) if outptype is tuple else output
        if isinstance(outptype, str):
            return {outptype: output}
        # Integer index: pad the leading positions so `output` sits at that index.
        return (0,) * outptype + (output,)

    def set_param(self, skip_thr=None, max_skip_steps=None, decay_ratio=None, device=None):
        """
        Update the parameters of the AdaStep instance.

        Arguments left as ``None`` keep their current value. Explicit zeros are
        applied, so ``max_skip_steps=0`` disables skipping.
        """
        if skip_thr is not None:
            self.skip_thr = skip_thr
        if max_skip_steps is not None:
            self.max_skip_steps = max_skip_steps
        if decay_ratio is not None:
            self.decay_ratio = decay_ratio
        if device is not None:
            self.device = device
            # Tensor.to is not in-place; the result has to be stored.
            self.skip_vote = self.skip_vote.to(self.device)
        self.if_skip = self.max_skip_steps > 0

    def reset_status(self):
        """
        Reset the per-run bookkeeping (history, cached prediction and vote).

        NOTE(review): ``skip_thr`` keeps the decay accumulated by previous
        estimates; confirm whether it should be restored here as well.
        """
        self.skip_mask = [False]
        self.skip_latents_diff = []
        self.skip_noise_pred = None
        self.skip_prev_latents = 0
        self.skip_vote = torch.tensor([False], dtype=torch.bool, device=self.device)

    def update_strategy(self, latents, sequence_parallel=False, sp_group=None):
        """
        Update the skip decision from the latents returned by the scheduler step.

        Args:
            latents: Output of the scheduler step (see ``step_value``).
            sequence_parallel (bool): When True, the local decision of every rank in
                ``sp_group`` is gathered so that ``__call__`` only skips if all agree.
            sp_group: Process group used for the gather.
        Raises:
            ValueError: If the output kind cannot be auto-detected and no
                ``step_value`` was given.
        """
        if self.stepretype is None:
            if isinstance(latents, tuple):
                self.stepretype = tuple
            elif isinstance(latents, torch.Tensor):
                self.stepretype = torch.Tensor
            else:
                raise ValueError("step needs return a tuple whose first element is the result, "
                                 "or return a tensor. In other cases, please enter `step_value`.")
        if self.if_skip:
            latents = self._get_input(latents, self.stepretype)
            diff = latents - self.skip_prev_latents
            self.skip_latents_diff.append(diff.abs().mean())
            # The estimate needs the three most recent differences.
            if len(self.skip_latents_diff) >= 3:
                self.skip_mask.append(self._estimate())

            self.skip_prev_latents = latents

        mask_value = torch.tensor([self.skip_mask[-1]], dtype=torch.bool, device=self.device)
        if sequence_parallel:
            skip_vote = torch.zeros(torch.distributed.get_world_size(sp_group),
                                    dtype=torch.bool, device=self.device)
            torch.distributed.all_gather_into_tensor(skip_vote, mask_value, group=sp_group)
        else:
            skip_vote = mask_value
        self.skip_vote = skip_vote

    def _estimate(self):
        """Return True when the next transformer forward may be skipped."""
        cur_diff = self.skip_latents_diff[-1]
        prev_diff = self.skip_latents_diff[-2]
        prev_prev_diff = self.skip_latents_diff[-3]

        # The threshold shrinks on every estimate.
        self.skip_thr = self.skip_thr * self.decay_ratio

        # Never skip more than `max_skip_steps` steps in a row.
        if len(self.skip_mask) >= self.max_skip_steps and \
                all(self.skip_mask[-self.max_skip_steps:]):
            return False

        # Skip when the middle difference is close to the average of its neighbours.
        if abs((cur_diff + prev_prev_diff) / 2 - prev_diff) <= prev_diff * self.skip_thr:
            return True
        return False


class SamplingOptm:
    """Context manager that patches a pipeline's DiT forward and scheduler step with AdaStep.

    Inside the ``with`` block the attributes named by ``dit_forward`` and
    ``scheduler_step`` are replaced by wrappers; on exit the originals are put
    back.  Nothing is patched unless ``config["method"] == "AdaStep"``.

    NOTE(review): the AdaStep state is not reset on ``__enter__``; reusing one
    instance for several runs keeps the previous run's history.
    """

    def __init__(self, pipe, dit_forward="transformer.forward", scheduler_step="scheduler.step",
                 forward_value=None, step_value=None, parallel=False, group=None, config=None):
        """
        Args:
            pipe: Pipeline object that owns the transformer and the scheduler.
            dit_forward (str): Dotted path, relative to ``pipe``, of the transformer forward.
            scheduler_step (str): Dotted path, relative to ``pipe``, of the scheduler step.
            forward_value: Forwarded to ``AdaStep(forward_value=...)``.
            step_value: Forwarded to ``AdaStep(step_value=...)``.
            parallel (bool): Whether sequence parallelism is used.
            group: Process group used to gather the skip votes.
            config (dict | None): Optional settings. Recognised keys: ``method``,
                ``skip_thr``, ``max_skip_steps``, ``decay_ratio`` and ``device``.
        """
        self.parallel = parallel
        self.group = group
        self.pipe = pipe

        config = config or {}
        self.skip = config.get("method") == "AdaStep"

        self.ori_forward = self.forward = self.ori_dit = None
        self.ori_step = self.step = self.ori_scheduler = None
        self.skip_strategy = None
        if not self.skip:
            return

        self.ori_dit, self.forward, self.ori_forward = self._resolve(pipe, dit_forward)
        self.ori_scheduler, self.step, self.ori_step = self._resolve(pipe, scheduler_step)

        self.skip_strategy = AdaStep(
            skip_thr=config.get("skip_thr", 0.015),
            max_skip_steps=config.get("max_skip_steps", 4),
            decay_ratio=config.get("decay_ratio", 0.99),
            device=config.get("device", "npu"),
            forward_value=forward_value,
            step_value=step_value,
        )

    @staticmethod
    def _resolve(root, dotted_path):
        """Return ``(owner, attribute_name, attribute)`` for ``dotted_path`` under ``root``."""
        *owner_path, attr_name = dotted_path.split(".")
        owner = reduce(getattr, owner_path, root)
        return owner, attr_name, getattr(owner, attr_name)

    def __enter__(self):
        if self.skip:
            self._sub_forward()
            self._sub_step()
        return self

    def __exit__(self, t, v, trace):
        if self.skip:
            self._revert_forward()
            self._revert_step()

    def _sub_forward(self):
        """Replace the transformer forward with a wrapper that may skip it."""
        @functools.wraps(self.ori_forward)
        def _optm_forward(*args, **kwargs):
            return self.skip_strategy(self.ori_forward, *args, **kwargs)
        setattr(self.ori_dit, self.forward, _optm_forward)

    def _sub_step(self):
        """Replace the scheduler step with a wrapper that updates the skip decision."""
        @functools.wraps(self.ori_step)
        def _optm_step(*args, **kwargs):
            latents = self.ori_step(*args, **kwargs)
            self.skip_strategy.update_strategy(latents, self.parallel, self.group)
            return latents
        setattr(self.ori_scheduler, self.step, _optm_step)

    def _revert_forward(self):
        """Put the original transformer forward back."""
        setattr(self.ori_dit, self.forward, self.ori_forward)

    def _revert_step(self):
        """Put the original scheduler step back."""
        setattr(self.ori_scheduler, self.step, self.ori_step)
mindiesd.pipeline.sampling_optm import AdaStep +from cogvideox_5b import ( + CogVideoXPipeline, + CogVideoXTransformer3DModel, + get_rank, + get_world_size, + all_gather, + set_parallel, + AdaStep +) from mindiesd import CacheAgent, CacheConfig