From c70f30e7f21920fec30802e3a919a4ea6cfe368b Mon Sep 17 00:00:00 2001
From: zhangqiongwen
Date: Mon, 8 Sep 2025 20:50:02 +0800
Subject: [PATCH 1/5] [bugfix]add fsdp patch for supporting host inputs

---
 torch_npu/distributed/fsdp/_add_fsdp_patch.py | 36 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/torch_npu/distributed/fsdp/_add_fsdp_patch.py b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
index 55322057b4..0289fbcc29 100644
--- a/torch_npu/distributed/fsdp/_add_fsdp_patch.py
+++ b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
@@ -1,6 +1,8 @@
-from typing import Tuple, Union, cast, List
+from typing import Tuple, Union, cast, List, Dict, Any
+import logging
 
 import torch
+import torch.nn as nn
 from torch import distributed as dist
 from torch._dynamo import tensor_version_op
 from torch._prims import _make_prim, RETURN_TYPE
@@ -15,6 +17,9 @@ import torch_npu
 from torch_npu.utils._error_code import ErrCode, pta_error
 
 
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+
+
 def _patched_finalize_backward(self):
     self._wait_for_post_backward()
     for fsdp_param in self.fsdp_params:
@@ -258,6 +263,34 @@ def _get_param_all_gather_inputs(
     return param_all_gather_inputs
 
 
+def _root_pre_forward(
+    self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    self._lazy_init()
+    if self._state_ctx.iter_forward_root is not None:
+        return args, kwargs
+    if not compiled_autograd_enabled():
+        logger.debug("FSDP::root_pre_forward")
+    self._state_ctx.iter_forward_root = self
+    with torch.profiler.record_function("FSDP::root_pre_forward"):
+        # Wait for optimizer before implicitly prefetched all-gathers
+        if (event := self._state_ctx.post_optim_event) is not None:
+            self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
+            self._comm_ctx.all_gather_stream.wait_event(event)
+            self._state_ctx.post_optim_event = None
+        else:
+            current_stream = self._device_handle.current_stream()
+            self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
+            self._comm_ctx.all_gather_stream.wait_stream(current_stream)
+        if self._device.type in ["cuda", "hpu"]:
+            with torch.profiler.record_function("FSDP::inputs_to_device"):
+                args_tuple, kwargs_tuple = _to_kwargs(
+                    args, kwargs, self._device, False
+                )  # same as DDP
+            args, kwargs = args_tuple[0], kwargs_tuple[0]
+    return args, kwargs
+
+
 def _apply_fsdp_patch():
     FSDPParamGroup.finalize_backward = _patched_finalize_backward
     FSDPParamGroup.wait_for_unshard = patched_wait_for_unshard
@@ -266,3 +299,4 @@ def _apply_fsdp_patch():
     _unsafe_preserve_version_counter.__init__ = _patched_unsafe_preserve_version_counter.__init__
     _unsafe_preserve_version_counter.__exit__ = _patched_unsafe_preserve_version_counter.__exit__
     torch.distributed.fsdp._fully_shard._fsdp_collectives._get_param_all_gather_inputs = _get_param_all_gather_inputs
+    torch.distributed.fsdp._fully_shard._fsdp_state.FSDPState._root_pre_forward = _patched_root_pre_forward
--
Gitee
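For context on what PATCH 1/5 wires in: the `_to_kwargs` helper used by the patched `_root_pre_forward` moves host (CPU) tensors found in `args`/`kwargs` onto the compute device before the first all-gather, mirroring DDP's input handling. Below is a minimal standalone sketch of that effect, assuming plain tensors, lists, tuples, and dicts as inputs; it re-implements the idea for illustration and is not the torch internals.

```python
from typing import Any, Dict, Tuple

import torch


def _move_to_device(obj: Any, device: torch.device) -> Any:
    # Recursively copy host tensors to `device`; pass everything else through.
    if isinstance(obj, torch.Tensor):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_move_to_device(x, device) for x in obj)
    if isinstance(obj, dict):
        return {k: _move_to_device(v, device) for k, v in obj.items()}
    return obj


def move_inputs(
    args: Tuple[Any, ...], kwargs: Dict[str, Any], device: torch.device
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # The patched _root_pre_forward derives the same shape of result from
    # _to_kwargs: one args tuple and one kwargs dict, both on `device`.
    return _move_to_device(args, device), _move_to_device(kwargs, device)
```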
From bb8be6839151927d69cfe86a16a3737deced16b7 Mon Sep 17 00:00:00 2001
From: zhangqiongwen
Date: Mon, 8 Sep 2025 20:52:44 +0800
Subject: [PATCH 2/5] [bugfix]add fsdp patch for supporting host inputs

---
 torch_npu/distributed/fsdp/_add_fsdp_patch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_npu/distributed/fsdp/_add_fsdp_patch.py b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
index 0289fbcc29..1d6166956d 100644
--- a/torch_npu/distributed/fsdp/_add_fsdp_patch.py
+++ b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
@@ -274,7 +274,8 @@ def _root_pre_forward(
     self._state_ctx.iter_forward_root = self
     with torch.profiler.record_function("FSDP::root_pre_forward"):
         # Wait for optimizer before implicitly prefetched all-gathers
-        if (event := self._state_ctx.post_optim_event) is not None:
+        event = self._state_ctx.post_optim_event
+        if event is not None:
             self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
             self._comm_ctx.all_gather_stream.wait_event(event)
             self._state_ctx.post_optim_event = None
--
Gitee

From 40dca491498d261624f16c9020a4bb6d72557356 Mon Sep 17 00:00:00 2001
From: zhangqiongwen
Date: Tue, 9 Sep 2025 09:10:09 +0800
Subject: [PATCH 3/5] [bugfix]add fsdp patch for supporting host inputs

---
 torch_npu/distributed/fsdp/_add_fsdp_patch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_npu/distributed/fsdp/_add_fsdp_patch.py b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
index 1d6166956d..06f1f5b0d1 100644
--- a/torch_npu/distributed/fsdp/_add_fsdp_patch.py
+++ b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union, cast, List, Dict, Any
+from typing import Any, Tuple, Union, cast, List, Dict
 import logging
 
 import torch
--
Gitee

From c6cb68e342ac17b2d6aab67305d6b1d3b23df203 Mon Sep 17 00:00:00 2001
From: zhangqiongwen
Date: Tue, 9 Sep 2025 15:26:39 +0800
Subject: [PATCH 4/5] [bugfix]add fsdp patch for supporting host inputs

---
 torch_npu/distributed/fsdp/_add_fsdp_patch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torch_npu/distributed/fsdp/_add_fsdp_patch.py b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
index 06f1f5b0d1..b536b3fd4b 100644
--- a/torch_npu/distributed/fsdp/_add_fsdp_patch.py
+++ b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
@@ -263,7 +263,7 @@ def _get_param_all_gather_inputs(
     return param_all_gather_inputs
 
 
-def _root_pre_forward(
+def _patched_root_pre_forward(
     self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
 ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
     self._lazy_init()
@@ -283,7 +283,8 @@ def _root_pre_forward(
             current_stream = self._device_handle.current_stream()
             self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
             self._comm_ctx.all_gather_stream.wait_stream(current_stream)
-        if self._device.type in ["cuda", "hpu"]:
+        # add patch for supporting self._device.type="npu"
+        if self._device.type in ["cuda", "hpu", "npu"]:
             with torch.profiler.record_function("FSDP::inputs_to_device"):
                 args_tuple, kwargs_tuple = _to_kwargs(
                     args, kwargs, self._device, False
--
Gitee
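Two details above are worth isolating. PATCH 2/5 only unrolls the walrus assignment into an explicit one, while PATCH 4/5 carries the functional fix: with `"npu"` missing from the allow-list, the `FSDP::inputs_to_device` branch never ran on Ascend devices, so host inputs reached the sharded module unmoved. The surrounding event logic orders the communication streams after the optimizer step. Below is a hedged, CUDA-flavored sketch of that ordering; the stream and event names are illustrative stand-ins, not FSDP's `_comm_ctx` internals.

```python
import torch

if torch.cuda.is_available():
    # Assumed setup: a dedicated all-gather stream besides the default
    # compute stream, loosely modeling FSDP's communication context.
    all_gather_stream = torch.cuda.Stream()
    post_optim_event = torch.cuda.Event()

    # ... optimizer.step() would run here on the current stream ...
    post_optim_event.record()  # mark the end of the optimizer step

    # Same pattern as the patched code: the next iteration's implicitly
    # prefetched all-gather must not start before the optimizer finished.
    all_gather_stream.wait_event(post_optim_event)
```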
From cddc29bfe3852f20901c92339a5403f7d285c130 Mon Sep 17 00:00:00 2001
From: zhangqiongwen
Date: Thu, 11 Sep 2025 10:32:57 +0800
Subject: [PATCH 5/5] [bugfix]add fsdp patch for supporting host inputs

---
 torch_npu/distributed/fsdp/_add_fsdp_patch.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torch_npu/distributed/fsdp/_add_fsdp_patch.py b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
index b536b3fd4b..59b12070c5 100644
--- a/torch_npu/distributed/fsdp/_add_fsdp_patch.py
+++ b/torch_npu/distributed/fsdp/_add_fsdp_patch.py
@@ -12,6 +12,7 @@ from torch.distributed.fsdp._fully_shard._fsdp_collectives import AllGatherResul
 from torch.distributed.fsdp._fully_shard._fsdp_common import compiled_autograd_enabled, TrainingState
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam, ShardedState
 from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup, AllGatherState
+from torch.distributed.utils import _to_kwargs
 from torch.profiler import record_function
 
 import torch_npu
 from torch_npu.utils._error_code import ErrCode, pta_error
--
Gitee
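End to end, the series is meant to let `fully_shard` accept host inputs on NPU. A usage sketch under stated assumptions: `_apply_fsdp_patch` and its module path come from the diff itself (torch_npu may already install it on import, in which case the explicit call is redundant), while the single-rank `hccl` process-group setup, the `fully_shard` import path (torch >= 2.6), and the toy model are illustrative, not verified here.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu  # noqa: F401  # registers the "npu" device type

from torch.distributed.fsdp import fully_shard
from torch_npu.distributed.fsdp._add_fsdp_patch import _apply_fsdp_patch

# Single-rank process group so fully_shard can build its device mesh.
dist.init_process_group("hccl", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
_apply_fsdp_patch()  # installs _patched_root_pre_forward and friends

model = nn.Linear(16, 16).to("npu")
fully_shard(model)

# Host input: with the patch, root_pre_forward moves it to the NPU
# before the first all-gather, so no manual .to("npu") is required.
x = torch.randn(4, 16)  # deliberately left on the CPU
out = model(x)
dist.destroy_process_group()
```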