diff --git a/torch_npu/distributed/fsdp/_add_fsdp_patch.py b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
index 6bc07049be300275c4b89ce858e349f1f3ba246f..dd53703f2add9b91322a40970023d491d39d0ead 100644
--- a/torch_npu/distributed/fsdp/_add_fsdp_patch.py
+++ b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
@@ -1,14 +1,20 @@
-from typing import cast
+from typing import Any, cast
+import logging
 
 import torch
+import torch.nn as nn
 from torch import distributed as dist
 from torch.distributed.fsdp._fully_shard._fsdp_common import compiled_autograd_enabled, TrainingState
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam, ShardedState
 from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup, AllGatherState
+from torch.distributed.utils import _to_kwargs
 
 import torch_npu
 
 
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+
+
 def _patched_finalize_backward(self):
     self._wait_for_post_backward()
     for fsdp_param in self.fsdp_params:
@@ -83,6 +89,37 @@ def _get_param_all_gather_inputs(
     return param_all_gather_inputs
 
 
+def _patched_root_pre_forward(
+    self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> tuple[tuple[Any, ...], dict[str, Any]]:
+    self._lazy_init()
+    if self._state_ctx.iter_forward_root is not None:
+        return args, kwargs
+    if not compiled_autograd_enabled():
+        logger.debug("FSDP::root_pre_forward")
+    self._state_ctx.iter_forward_root = self
+    with torch.profiler.record_function("FSDP::root_pre_forward"):
+        # Wait for optimizer before implicitly prefetched all-gathers
+        event = self._state_ctx.post_optim_event
+        if event is not None:
+            self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
+            self._comm_ctx.all_gather_stream.wait_event(event)
+            self._state_ctx.post_optim_event = None
+        else:
+            current_stream = self._device_handle.current_stream()
+            self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
+            self._comm_ctx.all_gather_stream.wait_stream(current_stream)
+        # Patched to also support self._device.type == "npu"
+        if self._device.type in ["cuda", "hpu", "xpu", "mtia", "npu"]:
+            with torch.profiler.record_function("FSDP::inputs_to_device"):
+                args_tuple, kwargs_tuple = _to_kwargs(
+                    args, kwargs, self._device, False
+                )  # same as DDP
+            args, kwargs = args_tuple[0], kwargs_tuple[0]
+    return args, kwargs
+
+
 def _apply_fsdp_patch():
     FSDPParamGroup.finalize_backward = _patched_finalize_backward
     torch.distributed.fsdp._fully_shard._fsdp_collectives._get_param_all_gather_inputs = _get_param_all_gather_inputs
+    torch.distributed.fsdp._fully_shard._fsdp_state.FSDPState._root_pre_forward = _patched_root_pre_forward
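
A minimal smoke-test sketch for the monkey patch, assuming an environment with torch and torch_npu installed. Whether torch_npu already calls _apply_fsdp_patch() automatically at import time is not shown in this diff, so the explicit call below is illustrative only:

    import torch_npu  # noqa: F401  # registers the NPU backend with torch
    from torch_npu.distributed.fsdp._add_fsdp_patch import _apply_fsdp_patch
    from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState

    _apply_fsdp_patch()

    # After patching, the root pre-forward hook is the NPU-aware variant that
    # includes "npu" in the device-type list guarding FSDP::inputs_to_device,
    # so module inputs are moved to the NPU device before the forward pass.
    assert FSDPState._root_pre_forward.__name__ == "_patched_root_pre_forward"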