From cb3e9711644f542690384833872c3d9fcfa56085 Mon Sep 17 00:00:00 2001
From: wangchao
Date: Mon, 21 Jul 2025 15:13:18 +0800
Subject: [PATCH 1/3] hccl lazy init

---
 CMakeLists.txt                                |   2 -
 setup.py                                      |   2 +-
 torch_npu/csrc/distributed/HCCLUtils.hpp      |   5 +
 torch_npu/csrc/distributed/HcclCompile.h      | 147 +++++++++++++++++-
 .../csrc/distributed/ProcessGroupHCCL.cpp     |  30 ++--
 torch_npu/csrc/npu/DataParallelComm.cpp       |   6 +-
 6 files changed, 170 insertions(+), 22 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9d506ad2dc..10dcc23b12 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -297,8 +297,6 @@ endif()
 link_directories(${PYTORCH_INSTALL_DIR}/lib)
 link_directories(${TORCHNPU_THIRD_PARTY_ROOT}/acl/libs)
 
-target_link_libraries(${PLUGIN_NAME} PUBLIC ${TORCHNPU_THIRD_PARTY_ROOT}/acl/libs/libhccl.so)
-
 if (NOT DEFINED BUILD_LIBTORCH)
     target_link_libraries(${PLUGIN_NAME} PUBLIC ${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so)
 endif()
diff --git a/setup.py b/setup.py
index 5b18fa0664..d19af49c80 100644
--- a/setup.py
+++ b/setup.py
@@ -236,6 +236,7 @@ def patchelf_dynamic_library():
     library_files = [str(i) for i in library_dir.rglob('*.so')]
     for library_file in library_files:
         subprocess.check_call(["patchelf", "--remove-needed", "libgomp.so.1", library_file], cwd=BASE_DIR)  # Compliant
+        subprocess.check_call(["patchelf", "--remove-needed", "libhccl.so", library_file], cwd=BASE_DIR)  # Compliant
 
 
 def CppExtension(name, sources, *args, **kwargs):
@@ -258,7 +259,6 @@ def CppExtension(name, sources, *args, **kwargs):
     libraries.append('torch')
     libraries.append('torch_cpu')
     libraries.append('torch_python')
-    libraries.append('hccl')
     kwargs['libraries'] = libraries
     kwargs['language'] = 'c++'
     return Extension(name, sources, *args, **kwargs)
diff --git a/torch_npu/csrc/distributed/HCCLUtils.hpp b/torch_npu/csrc/distributed/HCCLUtils.hpp
index b4662c1e49..d6910c590a 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.hpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.hpp
@@ -58,12 +58,17 @@
 } while (0)
 
 namespace c10d_npu {
+extern HcclResult hcclCommDestroy(HcclComm comm);
+extern HcclResult hcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm,
+    aclrtStream stream);
 extern HcclResult hcclGetCommAsyncError(HcclComm comm, HcclResult* asyncError);
+extern HcclResult hcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
 extern HcclResult hcclCommInitRootInfoConfig(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank,
     HcclCommConfig* config, HcclComm *comm);
 extern HcclResult hcclCommInitClusterInfoConfig(const char *clusterInfo, uint32_t rank, HcclCommConfig *config,
     HcclComm *comm);
 extern HcclResult hcclCreateSubCommConfig(HcclComm *comm, uint32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
     uint32_t subCommRankId, HcclCommConfig* config, HcclComm *subComm);
 extern HcclResult hcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks);
+extern HcclResult hcclCommInitAll(uint32_t ndev, int32_t *devices, HcclComm *comms);
 
 // Provides additional detail into HCCL error codes based on when these are
 // thrown in the HCCL codebase.
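Note on the build changes above: with the hard libhccl.so link removed from CMakeLists.txt and the DT_NEEDED entry stripped by patchelf in setup.py, HCCL symbols can no longer be bound at plugin load time; they must be resolved at runtime through the REGISTER_LIBRARY/GET_FUNCTION machinery used in HcclCompile.h below. What follows is a minimal sketch of the dlopen/dlsym pattern those macros presumably expand to; hccl_handle and hccl_symbol are illustrative names, not the actual torch_npu implementation.

    #include <dlfcn.h>

    // Sketch only: open libhccl.so on first use instead of at load time,
    // so importing torch_npu no longer fails on machines without HCCL.
    static void* hccl_handle()
    {
        static void* handle = dlopen("libhccl.so", RTLD_NOW | RTLD_GLOBAL);
        return handle;  // nullptr if HCCL is absent; callers must check
    }

    // Sketch only: resolve one symbol from the lazily opened library.
    static void* hccl_symbol(const char* name)
    {
        void* h = hccl_handle();
        return h ? dlsym(h, name) : nullptr;
    }

Under this scheme a missing HCCL installation surfaces at the first collective call rather than at import time, which is the point of the lazy init.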
diff --git a/torch_npu/csrc/distributed/HcclCompile.h b/torch_npu/csrc/distributed/HcclCompile.h
index a63ad73696..07762464b5 100644
--- a/torch_npu/csrc/distributed/HcclCompile.h
+++ b/torch_npu/csrc/distributed/HcclCompile.h
@@ -14,14 +14,24 @@ namespace c10d_npu {
     GET_FUNCTION(libhccl, funcName)
 
 REGISTER_LIBRARY(libhccl)
+LOAD_FUNCTION(HcclGetRootInfo)
+LOAD_FUNCTION(HcclCommDestroy)
+LOAD_FUNCTION(HcclSend)
+LOAD_FUNCTION(HcclRecv)
+LOAD_FUNCTION(HcclAllReduce)
 LOAD_FUNCTION(HcclAlltoAllV)
+LOAD_FUNCTION(HcclBroadcast)
 LOAD_FUNCTION(HcclAllGatherV)
+LOAD_FUNCTION(HcclAllGather)
+LOAD_FUNCTION(HcclReduceScatter)
 LOAD_FUNCTION(HcclReduceScatterV)
 LOAD_FUNCTION(HcclReduce)
 LOAD_FUNCTION(HcclGetCommAsyncError)
 LOAD_FUNCTION(HcclScatter)
 LOAD_FUNCTION(HcclBatchSendRecv)
 LOAD_FUNCTION(HcclAlltoAll)
+LOAD_FUNCTION(HcclCommInitAll)
+LOAD_FUNCTION(HcclCommInitRootInfo)
 LOAD_FUNCTION(HcclCommInitRootInfoConfig)
 LOAD_FUNCTION(HcclGetCommConfigCapability)
 LOAD_FUNCTION(HcclCommInitClusterInfoConfig)
@@ -29,9 +39,103 @@ LOAD_FUNCTION(HcclCreateSubCommConfig)
 LOAD_FUNCTION(HcclCommWorkingDevNicSet)
 
+extern HcclResult hcclGetRootInfo(HcclRootInfo *rootInfo)
+{
+    using HcclGetRootInfoFunc = HcclResult(*)(HcclRootInfo *);
+    static HcclGetRootInfoFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclGetRootInfoFunc)GET_FUNC(HcclGetRootInfo);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclGetRootInfo", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(rootInfo);
+    return ret;
+}
+
+extern HcclResult hcclCommDestroy(HcclComm comm)
+{
+    using HcclCommDestroyFunc = HcclResult(*)(HcclComm);
+    static HcclCommDestroyFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclCommDestroyFunc)GET_FUNC(HcclCommDestroy);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclCommDestroy", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(comm);
+    return ret;
+}
+
+extern HcclResult hcclSend(void *sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank,
+    HcclComm comm, aclrtStream stream)
+{
+    using HcclSendFunc = HcclResult(*)(
+        void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
+    static HcclSendFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclSendFunc)GET_FUNC(HcclSend);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclSend", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(sendBuf, count, dataType, destRank, comm, stream);
+    return ret;
+}
+
+extern HcclResult hcclRecv(void *recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
+    HcclComm comm, aclrtStream stream)
+{
+    using HcclRecvFunc = HcclResult(*)(
+        void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
+    static HcclRecvFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclRecvFunc)GET_FUNC(HcclRecv);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclRecv", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(recvBuf, count, dataType, srcRank, comm, stream);
+    return ret;
+}
+
+extern HcclResult hcclCommInitAll(uint32_t ndev, int32_t *devices, HcclComm *comms)
+{
+    using HcclCommInitAllFunc = HcclResult(*)(
+        uint32_t, int32_t *, HcclComm *);
+    static HcclCommInitAllFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclCommInitAllFunc)GET_FUNC(HcclCommInitAll);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclCommInitAll", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(ndev, devices, comms);
+    return ret;
+}
+
+extern HcclResult hcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
+    HcclComm comm, aclrtStream stream)
+{
+    using HcclAllGatherFunc = HcclResult(*)(
+        void *, void *, uint64_t, HcclDataType, HcclComm, aclrtStream);
+    static HcclAllGatherFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclAllGatherFunc)GET_FUNC(HcclAllGather);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclAllGather", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(sendBuf, recvBuf, sendCount, dataType, comm, stream);
+    return ret;
+}
+
+extern HcclResult hcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType,
+    HcclReduceOp op, HcclComm comm, aclrtStream stream)
+{
+    using HcclAllReduceFunc = HcclResult(*)(
+        void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream);
+    static HcclAllReduceFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclAllReduceFunc)GET_FUNC(HcclAllReduce);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclAllReduce", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(sendBuf, recvBuf, count, dataType, op, comm, stream);
+    return ret;
+}
+
 extern HcclResult hcclAlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls,
     HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls,
-    HcclDataType recvType, HcclComm comm, aclrtStream stream) {
+    HcclDataType recvType, HcclComm comm, aclrtStream stream)
+{
     using HcclAlltoAllVFunc = HcclResult(*)(
         const void *, const void *, const void *, HcclDataType,
         const void *, const void *, const void *, HcclDataType,
@@ -63,6 +167,47 @@ extern HcclResult hcclAllGatherV(const void *sendBuf, uint64_t sendCount,
     return ret;
 }
 
+extern HcclResult hcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm,
+    aclrtStream stream)
+{
+    using HcclBroadcastFunc = HcclResult(*)(
+        void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
+    static HcclBroadcastFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclBroadcastFunc)GET_FUNC(HcclBroadcast);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclBroadcast", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(buf, count, dataType, root, comm, stream);
+    return ret;
+}
+
+extern HcclResult hcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm)
+{
+    using HcclCommInitRootInfoFunc = HcclResult(*)(
+        uint32_t, const HcclRootInfo *, uint32_t, HcclComm *);
+    static HcclCommInitRootInfoFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclCommInitRootInfoFunc)GET_FUNC(HcclCommInitRootInfo);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclCommInitRootInfo", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(nRanks, rootInfo, rank, comm);
+    return ret;
+}
+
+extern HcclResult hcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType,
+    HcclReduceOp op, HcclComm comm, aclrtStream stream)
+{
+    using HcclReduceScatterFunc = HcclResult(*)(
+        void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream);
+    static HcclReduceScatterFunc func = nullptr;
+    if (func == nullptr) {
+        func = (HcclReduceScatterFunc)GET_FUNC(HcclReduceScatter);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "HcclReduceScatter", DIST_ERROR(ErrCode::NOT_FOUND));
+    auto ret = func(sendBuf, recvBuf, recvCount, dataType, op, comm, stream);
+    return ret;
+}
+
 extern HcclResult hcclReduceScatterV(const void *sendBuf, const void *sendCounts, const void *sdispls,
     const void *recvBuf, uint64_t recvCount,
     HcclDataType dataType, HcclReduceOp op, HcclComm comm, aclrtStream stream)
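A remark on the wrappers added above: each caches the resolved pointer in a function-local static that starts as nullptr and is filled in by plain assignment, so two threads making the first call can race on the cache (benignly, since both write the same value, but it is still a formal data race). Below is a sketch of an alternative that folds the lookup into the static's initializer, which C++11 magic statics run exactly once even under concurrency; hccl_symbol is the illustrative loader from the sketch after patch 1's build changes, and hcclSendOnce is a hypothetical name, not part of this patch.

    // Sketch only: the dlsym lookup happens in the static initializer,
    // which the compiler guards, so there is no race on the cached pointer.
    extern HcclResult hcclSendOnce(void *sendBuf, uint64_t count, HcclDataType dataType,
        uint32_t destRank, HcclComm comm, aclrtStream stream)
    {
        using Fn = HcclResult (*)(void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
        static Fn func = reinterpret_cast<Fn>(hccl_symbol("HcclSend"));  // resolved once, thread-safely
        TORCH_CHECK(func, "Failed to find function ", "HcclSend", DIST_ERROR(ErrCode::NOT_FOUND));
        return func(sendBuf, count, dataType, destRank, comm, stream);
    }

The repeated if-null-then-assign blocks in the patch behave the same way in practice; this variant just makes the once-only guarantee explicit.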
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 7ec4974387..a4d01f44b9 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -2213,7 +2213,7 @@ void ProcessGroupHCCL::createHCCLComm(
     HcclRootInfo hcclID;
     bool isSingleP2POp = commType == HcclCommType::P2P ? true : false;
     if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) {
-        HCCL_CHECK_ERROR(HcclGetRootInfo(&hcclID));
+        HCCL_CHECK_ERROR(hcclGetRootInfo(&hcclID));
     }
     broadcastMasterID(&hcclID, isSingleP2POp, devicesKey, p2pRank);
 
@@ -2383,7 +2383,7 @@ void ProcessGroupHCCL::createHCCLCommForZeroCopy(
 
     HcclRootInfo hcclID;
     if (envMap["local_rank"] == localRootRank) {
-        HCCL_CHECK_ERROR(HcclGetRootInfo(&hcclID));
+        HCCL_CHECK_ERROR(hcclGetRootInfo(&hcclID));
     }
 
     HcclRootInfo* hcclID_ = &hcclID;
@@ -3785,7 +3785,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allreduce(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclAllreduce", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclAllReduce(
+        auto hccl_result = hcclAllReduce(
             inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
@@ -3910,7 +3910,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::broadcast(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclBroadcast", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclBroadcast(inputDataPtr, numel, hcclType, root, comm, stream.stream(false));
+        auto hccl_result = hcclBroadcast(inputDataPtr, numel, hcclType, root, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
@@ -3959,7 +3959,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allreduce_coalesced(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclAllreduce", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclAllReduce(
+        auto hccl_result = hcclAllReduce(
             inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
@@ -4367,7 +4367,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclAllGather", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
+        auto hccl_result = hcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
@@ -4524,7 +4524,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclBroadcast", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclBroadcast(inputDataPtr, numel, hcclType, root, comm, stream.stream());
+        auto hccl_result = hcclBroadcast(inputDataPtr, numel, hcclType, root, comm, stream.stream());
         *is_dispatched = true;
         return hccl_result;
     },
@@ -4564,7 +4564,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather_into_tensor_coalesced
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclAllGather", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
+        auto hccl_result = hcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
@@ -4610,7 +4610,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather_togather(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclAllGather", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
+        auto hccl_result = hcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
@@ -4661,7 +4661,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_allgather_base(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclAllGather", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
+        auto hccl_result = hcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
@@ -4708,7 +4708,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::reduce_scatter(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclReduceScatter", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclReduceScatter(
+        auto hccl_result = hcclReduceScatter(
            inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
@@ -4904,7 +4904,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_reduce_scatter_base(
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclReduceScatter", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclReduceScatter(
+        auto hccl_result = hcclReduceScatter(
            inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
@@ -4950,7 +4950,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::reduce_scatter_tensor_coalesced
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclReduceScatter", numel, hcclType, comm, streamId, -1, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclReduceScatter(
+        auto hccl_result = hcclReduceScatter(
            inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
@@ -5133,7 +5133,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::send(std::vector<at::Tensor>& t
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclSend", numel, hcclType, comm, streamId, -1, dst_rank), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclSend(inputDataPtr, numel, hcclType, static_cast<uint32_t>(dst_rank), comm, stream.stream(false));
+        auto hccl_result = hcclSend(inputDataPtr, numel, hcclType, static_cast<uint32_t>(dst_rank), comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
@@ -5168,7 +5168,7 @@ c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::recv(std::vector<at::Tensor>& t
         torch_npu::profiler::MstxRange range(
             getMstxHcclMsg("HcclRecv", numel, hcclType, comm, streamId, src_rank, -1), stream.stream(false),
             torch_npu::profiler::DOMAIN_COMMUNICATION);
-        auto hccl_result = HcclRecv(outputDataPtr, numel, hcclType, static_cast<uint32_t>(src_rank), comm, stream.stream(false));
+        auto hccl_result = hcclRecv(outputDataPtr, numel, hcclType, static_cast<uint32_t>(src_rank), comm, stream.stream(false));
         *is_dispatched = true;
         return hccl_result;
     };
diff --git a/torch_npu/csrc/npu/DataParallelComm.cpp b/torch_npu/csrc/npu/DataParallelComm.cpp
index c744e1e1ba..643f5d0cbb 100644
--- a/torch_npu/csrc/npu/DataParallelComm.cpp
+++ b/torch_npu/csrc/npu/DataParallelComm.cpp
@@ -44,7 +44,7 @@ struct HcclCommList {
     int ndevices;
     explicit HcclCommList(const std::vector<int32_t>& devices)
         : comms(new HcclComm[devices.size()]), ndevices(devices.size()) {
-        HCCL_CHECK_ERROR(HcclCommInitAll(devices.size(), const_cast<int32_t*>(devices.data()), comms.get()));
+        HCCL_CHECK_ERROR(c10d_npu::hcclCommInitAll(devices.size(), const_cast<int32_t*>(devices.data()), comms.get()));
     }
     HcclCommList(HcclCommList&& foo) = default;
     HcclCommList& operator=(HcclCommList&& foo) = default;
@@ -61,7 +61,7 @@ struct HcclCommList {
                    In these cases, skip hcclCommDestroy */
                 return;
             }
-            HcclCommDestroy(comms[i]);
+            c10d_npu::hcclCommDestroy(comms[i]);
         }
     }
 }
@@ -228,7 +228,7 @@ void broadcast(TensorList tensors, const stream_list& streams = {}, const comm_l
                 count_max,
                 ")" + PTA_ERROR(ErrCode::VALUE));
             HcclComm comm = comms[i];
-            HCCL_CHECK_ERROR(HcclBroadcast(tensors[i].data_ptr(), numel, data_type, 0, comm, stream));
+            HCCL_CHECK_ERROR(c10d_npu::hcclBroadcast(tensors[i].data_ptr(), numel, data_type, 0, comm, stream));
         });
     }
     for (auto &t : threads) {
--
Gitee

From 7e506d02ddae011e5813fa381c4cfa5eb7df454b Mon Sep 17 00:00:00 2001
From: wangchao
Date: Mon, 21 Jul 2025 20:06:58 +0800
Subject: [PATCH 2/3] use the lazily loaded hcclCommInitRootInfo wrapper in
 HCCLComm::create

---
 torch_npu/csrc/distributed/HCCLUtils.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_npu/csrc/distributed/HCCLUtils.cpp b/torch_npu/csrc/distributed/HCCLUtils.cpp
index 74c2334ade..227092c466 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.cpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.cpp
@@ -130,7 +130,7 @@ std::shared_ptr<HCCLComm> HCCLComm::create(
     HcclRootInfo& rootInfo)
 {
     auto comm = std::make_shared<HCCLComm>();
-    HCCL_CHECK_ERROR(HcclCommInitRootInfo(numRanks, &rootInfo, rank, &(comm->hcclComm_)));
+    HCCL_CHECK_ERROR(hcclCommInitRootInfo(numRanks, &rootInfo, rank, &(comm->hcclComm_)));
     c10_npu::NpuSysCtrl::GetInstance().RegisterReleaseFn([=]() ->void { comm->destroyHcclComm(); },
         c10_npu::ReleasePriority::PriorityMiddle);
     return comm;
--
Gitee

From 0e71b79c3852c0b6a33d4d3518c9a3d44f077ce1 Mon Sep 17 00:00:00 2001
From: wangchao
Date: Tue, 22 Jul 2025 09:17:55 +0800
Subject: [PATCH 3/3] use the lazily loaded hcclCommDestroy wrapper in
 HCCLComm::destroyHcclComm

---
 torch_npu/csrc/distributed/HCCLUtils.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_npu/csrc/distributed/HCCLUtils.cpp b/torch_npu/csrc/distributed/HCCLUtils.cpp
index 227092c466..e34e7166a9 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.cpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.cpp
@@ -200,7 +200,7 @@ void HCCLComm::destroyHcclComm()
 {
     std::unique_lock<std::mutex> lock(mutex_);
     if (hcclComm_) {
-        HcclCommDestroy(hcclComm_);
+        hcclCommDestroy(hcclComm_);
         hcclComm_ = nullptr;
     }
 }
--
Gitee
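A closing note on teardown: the destroy paths patched above (HcclCommList's destructor and HCCLComm::destroyHcclComm) go through the lazily bound hcclCommDestroy, which is safe here because a communicator can only exist if libhccl.so was already loaded to create it, so the GET_FUNC lookup never triggers a fresh dlopen during process exit. Below is a sketch of a defensive variant that makes this invariant explicit; hccl_symbol is the illustrative loader from the first sketch, and hcclCommDestroySafe is a hypothetical name, not part of this series.

    // Sketch only: skip the symbol lookup entirely when there is nothing
    // to destroy, so teardown cannot load the library as a side effect.
    extern HcclResult hcclCommDestroySafe(HcclComm comm)
    {
        using Fn = HcclResult (*)(HcclComm);
        if (comm == nullptr) {
            return HCCL_SUCCESS;  // no communicator was ever created
        }
        static Fn func = reinterpret_cast<Fn>(hccl_symbol("HcclCommDestroy"));
        TORCH_CHECK(func, "Failed to find function ", "HcclCommDestroy", DIST_ERROR(ErrCode::NOT_FOUND));
        return func(comm);
    }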