diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp
index 7ea22768e69e200fa2163e63f1a6d63811d1f555..7b9442cfc8b63d77b42b42b65b289c485c07b379 100644
--- a/torch_npu/csrc/distributed/Init.cpp
+++ b/torch_npu/csrc/distributed/Init.cpp
@@ -454,7 +454,8 @@ PyObject* c10d_npu_init(PyObject* _unused, PyObject* noargs)
     auto processGroupLCCL = intrusive_ptr_no_gil_destructor_class_<::c10d_npu::ProcessGroupLCCL>(
         module, "ProcessGroupLCCL", dist.attr("Backend"))
         .def(py::init<const c10::intrusive_ptr<::c10d::Store> &, int, int>(),
-            py::call_guard<py::gil_scoped_release>());
+            py::call_guard<py::gil_scoped_release>())
+        .def("get_lccl_comm_args", &::c10d_npu::ProcessGroupLCCL::getLCCLCommArgs);
 
     auto cDist = py::module_::import("torch._C._distributed_c10d");
     auto parallelStore = intrusive_ptr_no_gil_destructor_class_<::c10d::ParallelTcpStore>(
diff --git a/torch_npu/csrc/distributed/ProcessGroupLCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupLCCL.cpp
index e2d50c6dbc3fb1f5b85532be3b7f1e37f9be2c30..cc41f9b34e56cf3eff9801144fd83c99806cec3a 100644
--- a/torch_npu/csrc/distributed/ProcessGroupLCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupLCCL.cpp
@@ -208,6 +208,62 @@ std::vector<at_npu::lccl::LcclComm> &ProcessGroupLCCL::getLCCLComm(
     return devLCCLCommMap_[devicesKey];
 }
 
+// Returns (as an opaque integer) the LCCL communicator handle for the given
+// device key, creating and caching the communicator, its streams and events
+// on first use. Exposed to Python as ProcessGroupLCCL.get_lccl_comm_args.
+uintptr_t ProcessGroupLCCL::getLCCLCommArgs(
+    const std::string &devicesKey,
+    const std::vector<at::Device> &devices)
+{
+    // Sanity check: we need at least one known NPU device to create or look
+    // up a communicator (devices[0] is dereferenced below).
+    if (devicesKey.empty() || devices.empty()) {
+        throw std::runtime_error("Not able to create/get the LCCL Communicator since "
+                                 "the NPU devices are not known" +
+                                 DIST_ERROR(ErrCode::PARAM));
+    }
+
+    {
+        // Reuse the cached communicator if there is one.
+        std::lock_guard<std::mutex> lock(mutex_);
+        auto found = devLCCLCommMap_.find(devicesKey);
+        if (found != devLCCLCommMap_.end()) {
+            // NOTE(review): the function name suggests returning
+            // Lcal::LcalComm::GetCommArgsPtr(); the raw communicator handle
+            // is returned instead (on both paths) — confirm which pointer
+            // callers expect.
+            return reinterpret_cast<uintptr_t>(found->second[0]);
+        }
+    }
+
+    std::vector<at_npu::lccl::LcclComm> lcclComms;
+    lcclComms.resize(devices.size());
+
+    c10_npu::OptionalNPUGuard npuGuard;
+    std::vector<c10_npu::NPUStream> streamVal;
+    streamVal.reserve(devices.size());
+
+    // Create one communicator and one dedicated stream per device.
+    for (size_t i = 0; i < devices.size(); ++i) {
+        npuGuard.set_index(devices[i].index());
+        auto ret = at_npu::lccl::LcclCommInitRankLocal(size_, rank_, &lcclComms[i]);
+        TORCH_CHECK(ret == 0, "init lccl comm failed, error code:", ret, PTA_ERROR(ErrCode::INTERNAL));
+
+        // Creates the LCCL streams
+        streamVal.push_back(c10_npu::getNPUStreamFromPool(devices[i].index()));
+    }
+
+    lcclStreams_.emplace(devicesKey, std::move(streamVal));
+
+    // The events are created with default flags (no timing), which gives the
+    // best performance for stream-wait / query style synchronization; we do
+    // not measure time through these events.
+    lcclEvents_.emplace(std::piecewise_construct, std::make_tuple(devicesKey), std::make_tuple(devices.size()));
+
+    // Hold the lock before modifying the cache.
+    // NOTE(review): the lock was released after the lookup above, so two
+    // threads may race to create communicators for the same key — confirm
+    // this matches getLCCLComm's locking discipline.
+    std::lock_guard<std::mutex> lock(mutex_);
+    devLCCLCommMap_.emplace(devicesKey, std::move(lcclComms));
+    return reinterpret_cast<uintptr_t>(devLCCLCommMap_[devicesKey][0]);
+}
+
 template <typename Fn, typename PreProcess, typename PostProcess>
 c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::collective(
     std::vector<at::Tensor> &inputs, std::vector<at::Tensor> &outputs, Fn fn, PreProcess pre,
diff --git a/torch_npu/csrc/distributed/ProcessGroupLCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupLCCL.hpp
index a26eb8f9f9082e2469b5f5854bda44fec81608b7..a027f3b187efde3a34c7c767900d90f2ff56a66e 100644
--- a/torch_npu/csrc/distributed/ProcessGroupLCCL.hpp
+++ b/torch_npu/csrc/distributed/ProcessGroupLCCL.hpp
@@ -130,6 +130,10 @@ public:
 
     static const int64_t kProcessGroupLCCLOpTimeoutMillis;
 
+    uintptr_t getLCCLCommArgs(
+        const std::string &devicesKey,
+        const std::vector<at::Device> &devices);
+
 protected:
     // Helper that either looks up the cached LCCL communicators or creates
     // a new set of LCCL communicators as a cache entry