From 77c9707617c21cab7f18838a6b97cbdee40c386d Mon Sep 17 00:00:00 2001 From: wangchao Date: Fri, 1 Aug 2025 17:19:56 +0800 Subject: [PATCH] fix p2p bug with ranktable --- .../csrc/distributed/ProcessGroupHCCL.cpp | 38 ++++++++++++++----- .../csrc/distributed/ProcessGroupHCCL.hpp | 5 ++- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 3403c954be..7cdc9b94ab 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -2154,7 +2154,7 @@ std::vector>& ProcessGroupHCCL::getHCCLComm( return createHCCLComm(devicesKey, devices, commType, commConfig, p2pRank); } -void ProcessGroupHCCL::createHCCLComm( +void ProcessGroupHCCL::createHCCLCommOrigin( const std::string& devicesKey, const std::vector& devices, HcclCommType commType, @@ -2214,6 +2214,7 @@ void ProcessGroupHCCL::createHCCLComm( } bool ProcessGroupHCCL::createHCCLCommEx( + const std::string& devicesKey, const std::vector& devices, HcclCommType commType, HcclCommConfig* commConfig, @@ -2301,10 +2302,27 @@ bool ProcessGroupHCCL::createHCCLCommEx( } commConfig = &config; } - auto subComm = HCCLComm::createSubHcclComm(globalHcclComm, numRanks, options_->global_ranks_in_group.data(), hcclid, rank, commConfig); + std::shared_ptr subComm = nullptr; + if (commType == HcclCommType::P2P && options_->global_ranks_in_group.empty()) { + uint32_t peer = static_cast(getP2pPeer()); + uint32_t lowRank = rank_ < peer ? rank_ : peer; + uint32_t highRank = rank_ < peer ? peer : rank_; + std::vector p2pRanks = {lowRank, highRank}; + hcclid = (std::hash{}(devicesKey)); + std::string p2pName = "p2p_" + std::to_string(lowRank) + "_" + std::to_string(highRank); + if (strlen(commConfig->hcclCommName) > 0) { + torch_npu::toolkit::profiler::Utils::safe_strcpy_s(commConfig->hcclCommName, p2pName.c_str(), COMM_NAME_MAX_LENGTH); + } + if (strlen(commConfig->hcclUdi) > 0) { + torch_npu::toolkit::profiler::Utils::safe_strcpy_s(commConfig->hcclUdi, p2pName.c_str(), UDI_MAX_LENGTH); + } + subComm = HCCLComm::createSubHcclComm(globalHcclComm, numRanks, p2pRanks.data(), hcclid, rank, commConfig); + } else { + subComm = HCCLComm::createSubHcclComm(globalHcclComm, numRanks, options_->global_ranks_in_group.data(), hcclid, rank, commConfig); + } if (subComm == nullptr) { - ASCEND_LOGI("Create sub hccl comm by hcclCreateSubCommConfig failed, group id is %s, subCommId is %llu.", - options_->group_id.c_str(), hcclid); + ASCEND_LOGI("Create sub hccl comm by hcclCreateSubCommConfig failed, group id is %s, subCommId is %llu, devicesKey is %s.", + options_->group_id.c_str(), hcclid, devicesKey.c_str()); return false; } hcclComms[i] = subComm; @@ -2316,10 +2334,10 @@ bool ProcessGroupHCCL::createHCCLCommEx( } auto subEndTime = std::chrono::steady_clock::now(); auto subTimeElapsed = std::chrono::duration_cast(subEndTime - subStartTime); - ASCEND_LOGI("Create sub hccl comm by hcclCreateSubCommConfig success, group id is %s, subCommId is %llu, use %d ms.", - options_->group_id.c_str(), hcclid, subTimeElapsed.count()); - logger->info("Create sub hccl comm by hcclCreateSubCommConfig success, group id is %s, subCommId is %llu, use %d ms.", - options_->group_id.c_str(), hcclid, subTimeElapsed.count()); + ASCEND_LOGI("Create sub hccl comm by hcclCreateSubCommConfig success, group id is %s, subCommId is %llu, devicesKey is %s, use %d ms.", + options_->group_id.c_str(), hcclid, devicesKey.c_str(), subTimeElapsed.count()); + logger->info("Create sub hccl comm by hcclCreateSubCommConfig success, group id is %s, subCommId is %llu, devicesKey is %s, use %d ms.", + options_->group_id.c_str(), hcclid, devicesKey.c_str(), subTimeElapsed.count()); return true; } @@ -2391,8 +2409,8 @@ std::vector>& ProcessGroupHCCL::createHCCLComm( std::vector streamVal; streamVal.reserve(devices.size()); - if (!createHCCLCommEx(devices, commType, commConfig, hcclComms, streamVal, p2pRank)) { - createHCCLComm(devicesKey, devices, commType, commConfig, hcclComms, streamVal, p2pRank); + if (!createHCCLCommEx(devicesKey, devices, commType, commConfig, hcclComms, streamVal, p2pRank)) { + createHCCLCommOrigin(devicesKey, devices, commType, commConfig, hcclComms, streamVal, p2pRank); } hcclStreams_.emplace(devicesKey, std::move(streamVal)); diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp index 2dd834306f..0240ea8531 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp @@ -923,6 +923,8 @@ protected: int peer_; + std::vector global_ranks_in_group; + std::exception_ptr watchDogException_ = nullptr; std::shared_ptr pgStatus_ = std::make_shared(); @@ -989,7 +991,7 @@ private: HcclCommConfig* commConfig = nullptr, int p2pRank = 0); - void createHCCLComm( + void createHCCLCommOrigin( const std::string& devicesKey, const std::vector& devices, HcclCommType commType, @@ -999,6 +1001,7 @@ private: int p2pRank); bool createHCCLCommEx( + const std::string& devicesKey, const std::vector& devices, HcclCommType commType, HcclCommConfig* commConfig, -- Gitee