diff --git a/test/contrib/test_transfer_to_npu.py b/test/contrib/test_transfer_to_npu.py
index bd950a79ce80e3c407fed4f4ec64bb9ba2da0f00..6584ebfe8847b1893abc64e705bd8f8c00756086 100644
--- a/test/contrib/test_transfer_to_npu.py
+++ b/test/contrib/test_transfer_to_npu.py
@@ -11,6 +11,27 @@ from torch_npu.contrib import transfer_to_npu


 class TestTransferToNpu(TestCase):
+    def test_generator(self):
+        g0 = torch.Generator()
+        self.assertTrue(isinstance(g0, torch.Generator))
+        self.assertEqual(g0.device.type, 'cpu')
+
+        g1 = torch.Generator('cuda')
+        self.assertTrue(isinstance(g1, torch.Generator))
+        self.assertEqual(g1.device.type, 'npu')
+
+        g2 = torch.Generator(torch.device('cuda'))
+        self.assertTrue(isinstance(g2, torch.Generator))
+        self.assertEqual(g2.device.type, 'npu')
+
+        g3 = torch.Generator(device='cuda')
+        self.assertTrue(isinstance(g3, torch.Generator))
+        self.assertEqual(g3.device.type, 'npu')
+
+        g4 = torch.Generator(device=torch.device('cuda'))
+        self.assertTrue(isinstance(g4, torch.Generator))
+        self.assertEqual(g4.device.type, 'npu')
+
     def test_wrap_isinstance(self):
         # check builtins isinstance grammar
         self.assertTrue(isinstance(1, int))
diff --git a/test/cpp_extensions/extension.cpp b/test/cpp_extensions/extension.cpp
index 636982882df2fcb123bb323f9b8266c5a5452609..8d3a62f1ac2fa0c476550bb337ab2861843f1d4f 100644
--- a/test/cpp_extensions/extension.cpp
+++ b/test/cpp_extensions/extension.cpp
@@ -48,6 +48,17 @@ bool check_from_blob()
     return dtype_same && num_same && pos1_same && pos2_same && pos3_same && sub_same;
 }

+bool check_from_blob_delete()
+{
+    int isgone = 0;
+    {
+        auto data = torch::tensor({1.0, 2.0, 3.0}, torch::kFloat).to(at::Device("npu:0"));
+        auto res = at_npu::native::from_blob(data.data_ptr(), data.sizes(), [&](void*) { isgone++; });
+    }
+    bool is_deleted = (isgone == 1);
+    return is_deleted;
+}
+
 bool check_from_blob_strides()
 {
     auto data = torch::tensor({1, 2, 3, 4, 5, 6, 7, 8, 9}, torch::kInt32).to(at::Device("npu:0"));
@@ -131,6 +142,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("check_storage_sizes", &check_storage_sizes, "check_storage_sizes");
     m.def("check_from_blob", &check_from_blob, "check_from_blob");
     m.def("check_from_blob_strides", &check_from_blob_strides, "check_from_blob_strides");
+    m.def("check_from_blob_delete", &check_from_blob_delete, "check_from_blob_delete");
     m.def("blocking_ops", &blocking_ops, "blocking_ops");
     m.def("register_op_hook", &register_op_hook, "register_op_hook");
     m.def("get_op_hook_call_count", &get_op_hook_call_count, "get_op_hook_call_count");
diff --git a/test/cpp_extensions/test/test_cpp_extensions_aot.py b/test/cpp_extensions/test/test_cpp_extensions_aot.py
index 0adbfd4127d6fc0a0992866a4ec782bb250c931e..409bce348bc94dce54773b4eb2b7504bf44e0ee1 100644
--- a/test/cpp_extensions/test/test_cpp_extensions_aot.py
+++ b/test/cpp_extensions/test/test_cpp_extensions_aot.py
@@ -53,6 +53,7 @@ class TestCppExtensionAOT(TestCase):
     def test_from_blob(self):
         self.assertTrue(npu_extension.check_from_blob())
         self.assertTrue(npu_extension.check_from_blob_strides())
+        self.assertTrue(npu_extension.check_from_blob_delete())

     def test_dispatch_allreduce(self):
         flags = os.O_WRONLY | os.O_RDONLY | os.O_CREAT
diff --git a/test/npu/test_npu_format.py b/test/npu/test_npu_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bc1c067ff4496896e816493c36529074bbfb2a8
--- /dev/null
+++ b/test/npu/test_npu_format.py
@@ -0,0 +1,49 @@
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+class TestNPUFormat(TestCase):
+
+    def test_enum_values(self):
+        """test the enumeration value"""
+        self.assertEqual(torch_npu.Format.NCHW.value, 0)
+        self.assertEqual(torch_npu.Format.NHWC.value, 1)
+
+    def test_npu_format_cast(self):
+        """test npu_format_cast"""
+        tensor = torch.ones(2, 2).npu()
+
+        out1 = torch_npu.npu_format_cast(tensor, 0)
+        fmt1 = torch_npu.get_npu_format(out1)
+        self.assertEqual(fmt1, torch_npu.Format.NCHW)
+
+        out2 = torch_npu.npu_format_cast(tensor, torch_npu.Format.NHWC)
+        fmt2 = torch_npu.get_npu_format(out2)
+        self.assertEqual(fmt2, torch_npu.Format.NHWC)
+
+    def test_npu_format_cast_(self):
+        """test npu_format_cast_"""
+        x1 = torch.ones(2, 2).npu()
+        x2 = torch.ones(2, 2).npu()
+
+        torch_npu.npu_format_cast_(x1, 0)
+        fmt1 = torch_npu.get_npu_format(x1)
+        self.assertEqual(fmt1, torch_npu.Format.NCHW)
+
+        torch_npu.npu_format_cast_(x2, torch_npu.Format.NHWC)
+        fmt2 = torch_npu.get_npu_format(x2)
+        self.assertEqual(fmt2, torch_npu.Format.NHWC)
+
+    def test_get_npu_format(self):
+        """test get_npu_format"""
+        x1 = torch.ones(2, 2).npu()
+        torch_npu.npu_format_cast_(x1, 0)
+
+        fmt1 = torch_npu.get_npu_format(x1)
+        self.assertEqual(fmt1, torch_npu.Format.NCHW)
+        self.assertEqual(fmt1, 0)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/npu/test_torch_npu.py b/test/npu/test_torch_npu.py
index ca5c77b21e8ca0d2ba9b9715d4b3bd020dc870a2..541dd3c59e7fa569d00862d10cb5723431270985 100644
--- a/test/npu/test_torch_npu.py
+++ b/test/npu/test_torch_npu.py
@@ -78,6 +78,12 @@ class TorchNPUDeviceTestCase(TestCase):
         torch_npu.npu.synchronize()
         after_free_memory, after_total_memory = torch_npu.npu.mem_get_info(0)
         self.assertEqual(before_total_memory, after_total_memory)
+
+    @unittest.skip("CANN doesn't support this yet.")
+    def test_set_device_res_limit(self):
+        ans_dict = {'cube_num': 12, 'vector_num': 24}
+        torch.npu.set_device_res_limit(torch.npu.current_device(), 12, 24)
+        self.assertEqual(ans_dict, torch.npu.get_device_res_limit(torch.npu.current_device()))

 class TorchNPUMemoryApiTestCase(TestCase):
     def test_npu_memory_stats(self):
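The two test files above pin down the intended Python surface: torch_npu.Format is an IntEnum whose members compare equal to the raw format ints, get_npu_format returns a Format, and npu_format_cast accepts either an int or an enum member. A minimal usage sketch, assuming an initialized NPU device:

    import torch
    import torch_npu

    t = torch.ones(2, 2).npu()
    t = torch_npu.npu_format_cast(t, torch_npu.Format.NHWC)
    fmt = torch_npu.get_npu_format(t)
    print(fmt, int(fmt))  # NHWC 1 -- str() yields the name, members equal their ints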
diff --git a/third_party/acl/inc/acl/acl_rt.h b/third_party/acl/inc/acl/acl_rt.h
index 98b520ba4ac73a4b5072d98fd436edde37b51655..ecc36f38128bd746bc9f9cb5064e6f47f9bc5b6a 100755
--- a/third_party/acl/inc/acl/acl_rt.h
+++ b/third_party/acl/inc/acl/acl_rt.h
@@ -181,6 +181,11 @@ typedef enum aclrtLastErrLevel {
     ACL_RT_THREAD_LEVEL = 0,
 } aclrtLastErrLevel;

+typedef enum {
+    ACL_RT_DEV_RES_CUBE_CORE = 0,
+    ACL_RT_DEV_RES_VECTOR_CORE,
+} aclrtDevResModelType;
+
 typedef void* aclrtDrvMemHandle;

 typedef void (*aclrtCallback)(void *userData);
@@ -1541,6 +1546,37 @@ ACL_FUNC_VISIBILITY aclError aclrtPeekAtLastError(aclrtLastErrLevel level);
 */
 ACL_FUNC_VISIBILITY aclError aclrtGetLastError(aclrtLastErrLevel level);

+/**
+ * @ingroup AscendCL
+ * @brief Get the value of the current device's limited resources
+ * @param [in] deviceId the device id
+ * @param [in] type resources type
+ * @param [out] value resources limit value
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclrtGetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t* value);
+
+/**
+ * @ingroup AscendCL
+ * @brief Set the value of the current device's limited resources
+ * @param [in] deviceId the device id
+ * @param [in] type resource type
+ * @param [in] value resource limit value
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclrtSetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t value);
+
+/**
+ * @ingroup AscendCL
+ * @brief Reset the value of the current device's limited resources
+ * @param [in] deviceId the device id
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclrtResetDeviceResLimit(int32_t deviceId);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/third_party/acl/libs/acl.cpp b/third_party/acl/libs/acl.cpp
index 4f24e6bf043cba7c53c7015e597f5c6e82164bd6..9bb32581dd7ea6ca7d1b5fe01c7896dfb7d84764 100644
--- a/third_party/acl/libs/acl.cpp
+++ b/third_party/acl/libs/acl.cpp
@@ -18,6 +18,9 @@ aclError aclmdlSetDump(const char *configPath){return 0;}
 aclError aclmdlInitDump(){return 0;}
 aclError aclmdlFinalizeDump(){return 0;}
 aclError aclrtDeviceTaskAbort(int32_t deviceId, uint32_t timeout){return 0;}
+aclError aclrtGetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t* value){return 0;}
+aclError aclrtSetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t value){return 0;}
+aclError aclrtResetDeviceResLimit(int32_t deviceId){return 0;}

 // Stream
 aclError aclrtCreateStream(aclrtStream *stream) { return 0; }
diff --git a/third_party/hccl/inc/hccl/hccl.h b/third_party/hccl/inc/hccl/hccl.h
index 023914a348285ad17c459b077cdd03c4593637ea..216ef7a83847e424ee1b0679b351d188452a2981 100644
--- a/third_party/hccl/inc/hccl/hccl.h
+++ b/third_party/hccl/inc/hccl/hccl.h
@@ -212,6 +212,8 @@ inline void HcclCommConfigInit(HcclCommConfig *config)
     config->hcclRdmaTrafficClass = HCCL_COMM_TRAFFIC_CLASS_CONFIG_NOT_SET;
     config->hcclRdmaServiceLevel = HCCL_COMM_SERVICE_LEVEL_CONFIG_NOT_SET;
     config->hcclOpExpansionMode = HCCL_COMM_DEFAULT_OP_EXPANSION_MODE;
+    config->hcclWorldRankID = 0;
+    config->hcclJobID = 0;
 }

 /**
diff --git a/third_party/hccl/inc/hccl/hccl_types.h b/third_party/hccl/inc/hccl/hccl_types.h
index 40631676c1bdc9bb44256b083e647e99e8f6fc8f..9a02c61c0414a96af23bf2468ab96482512240fa 100644
--- a/third_party/hccl/inc/hccl/hccl_types.h
+++ b/third_party/hccl/inc/hccl/hccl_types.h
@@ -15,7 +15,7 @@ extern "C" {

 const uint32_t HCCL_COMM_CONFIG_INFO_BYTES = 24;
 const uint32_t HCCL_COMM_CONFIG_MAGIC_WORD = 0xf0f0f0f0;
-const uint32_t HCCL_COMM_CONFIG_VERSION = 5;
+const uint32_t HCCL_COMM_CONFIG_VERSION = 6;
 const uint32_t HCCL_COMM_DEFAULT_BUFFSIZE = 200; // 200MB buffer size
 const uint32_t HCCL_COMM_DEFAULT_DETERMINISTIC = 0; // Disable deterministic calculations
 const uint32_t COMM_NAME_MAX_LENGTH = 128;
@@ -132,6 +132,8 @@ typedef struct HcclCommConfigDef {
     uint32_t hcclOpExpansionMode;
     uint32_t hcclRdmaTrafficClass;
     uint32_t hcclRdmaServiceLevel;
+    uint32_t hcclWorldRankID;
+    uint64_t hcclJobID;
 } HcclCommConfig;

 typedef enum {
diff --git a/third_party/op-plugin b/third_party/op-plugin
index 161f835137eaa0ca36e62202c141dfbde80babfe..f8fab40561b64047e20d2a98c7eac6f100cc71b6 160000
--- a/third_party/op-plugin
+++ b/third_party/op-plugin
@@ -1 +1 @@
-Subproject commit 161f835137eaa0ca36e62202c141dfbde80babfe
+Subproject commit f8fab40561b64047e20d2a98c7eac6f100cc71b6
diff --git a/third_party/torchair/torchair b/third_party/torchair/torchair
index edf95b3a70ccd0fcb90a935cfa9836879df9453d..ec5747ba5477a4508131ca4401088e7383908266 160000
--- a/third_party/torchair/torchair
+++ b/third_party/torchair/torchair
@@ -1 +1 @@
-Subproject commit edf95b3a70ccd0fcb90a935cfa9836879df9453d
+Subproject commit ec5747ba5477a4508131ca4401088e7383908266
diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index d897abb1095e1aecbe2d86f5ef349bf6640e9a8a..755309772ab24cc7188ce734e3daf150c017e2c9 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -60,6 +60,7 @@ from torch_npu.utils import _register_ops_under_dtensor_rules
 from torch_npu.utils.exposed_api import public_npu_functions
 from torch_npu.distributed.checkpoint.checkpoint import _apply_dcp_patch
 from torch_npu.npu._stream_check import apply_sanitizer_patch
+from torch_npu.npu._format import _apply_npu_format_patch
 from torch_npu.multiprocessing.reductions import _add_reductions_methods
 from torch_npu.npu.utils import _erase_stream as erase_stream
 from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler
@@ -155,6 +156,7 @@ def _apply_class_patches():
     _apply_distributed_methods_patch()
     _apply_mstx_patch()
     _add_reductions_methods()
+    _apply_npu_format_patch()


 def _apply_distributed_methods_patch():
diff --git a/torch_npu/contrib/transfer_to_npu.py b/torch_npu/contrib/transfer_to_npu.py
index ea9a08b51361f8de4d1598891e0aebbef6421fca..b899c0ecec371271bdeb25d03f03cac95916bf01 100644
--- a/torch_npu/contrib/transfer_to_npu.py
+++ b/torch_npu/contrib/transfer_to_npu.py
@@ -29,7 +29,7 @@ torch_fn_white_list = ['logspace', 'randint', 'hann_window', 'rand', 'full_like',
                        'eye', '_sparse_csr_tensor_unsafe', 'empty', '_sparse_coo_tensor_unsafe', 'blackman_window',
                        'zeros_like', 'range', 'sparse_csr_tensor', 'randn_like', 'from_file',
                        '_cudnn_init_dropout_state', '_empty_affine_quantized', 'linspace', 'hamming_window',
-                       'empty_quantized', '_pin_memory', 'autocast', 'load', "Generator", 'set_default_device']
+                       'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device']
 torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros',
                               'to', 'pin_memory']
 torch_module_fn_white_list = ['to', 'to_empty']
@@ -46,6 +46,13 @@ cur_path = os.path.dirname(os.path.realpath(__file__))
 config_path = os.path.join(cur_path, 'apis_config.json')


+class _GeneratorProxy(torch.Generator):
+
+    def __new__(cls, device='cpu'):
+        device = _replace_cuda_to_npu_in_list([device], None)[0]
+        instance = super().__new__(cls, device)
+        return instance
+
+
 def _get_function_from_string(attribute_string):
     try:
@@ -329,6 +336,7 @@ def _init():
     # torch.*
     _device_wrapper(torch, torch_fn_white_list)
     torch.UntypedStorage.__new__ = _wrapper_cuda(torch.UntypedStorage.__new__)
+    torch.Generator = _GeneratorProxy

     # torch.Tensor.*
     _device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
diff --git a/torch_npu/csrc/aten/common/from_blob.cpp b/torch_npu/csrc/aten/common/from_blob.cpp
index fdd44e3f3a587a2d31fdf62941f2af7cd6d462a0..43e72a83aea30b8ca3152a6378ce7bfcbc97166e 100644
--- a/torch_npu/csrc/aten/common/from_blob.cpp
+++ b/torch_npu/csrc/aten/common/from_blob.cpp
@@ -36,7 +36,12 @@ at::Tensor TensorMaker::make_tensor()

     std::size_t size_bytes = computeStorageSize();

-    c10::DataPtr data_ptr{data_, *device_};
+    c10::DataPtr data_ptr{};
+    if (deleter_) {
+        data_ptr = c10::InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
+    } else {
+        data_ptr = c10::DataPtr(data_, *device_);
+    }

     c10::intrusive_ptr<c10::StorageImpl> storage_impl = torch_npu::make_npu_storage_impl_inner(
         c10::StorageImpl::use_byte_size_t(),
@@ -86,6 +91,54 @@ std::size_t TensorMaker::computeStorageSize() const noexcept
     return storage_size;
 }

+at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    std::function<void(void*)> deleter,
+    const at::TensorOptions& options,
+    const c10::optional<c10::Device> target_device)
+{
+    return for_blob(data, sizes)
+        .deleter(std::move(deleter))
+        .options(options)
+        .target_device(target_device)
+        .make_tensor();
+}
+
+at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    int64_t storage_offset,
+    const std::function<void(void*)>& deleter,
+    const at::TensorOptions& options,
+    const c10::optional<c10::Device> target_device)
+{
+    return for_blob(data, sizes)
+        .strides(strides)
+        .storage_offset(storage_offset)
+        .deleter(deleter)
+        .options(options)
+        .target_device(target_device)
+        .make_tensor();
+}
+
+at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    const std::function<void(void*)>& deleter,
+    const at::TensorOptions& options,
+    const c10::optional<c10::Device> target_device)
+{
+    return for_blob(data, sizes)
+        .strides(strides)
+        .deleter(deleter)
+        .options(options)
+        .target_device(target_device)
+        .make_tensor();
+}
+
 at::Tensor from_blob(
     void* data,
     at::IntArrayRef sizes,
diff --git a/torch_npu/csrc/aten/common/from_blob.h b/torch_npu/csrc/aten/common/from_blob.h
index f0d6bbd12700ec295d322762febe80070286bb43..0669d2fdca08965e9797918b35d83b185ef1272e 100644
--- a/torch_npu/csrc/aten/common/from_blob.h
+++ b/torch_npu/csrc/aten/common/from_blob.h
@@ -41,6 +41,12 @@ public:
         return *this;
     }

+    TensorMaker& deleter(std::function<void(void*)> value) noexcept
+    {
+        deleter_ = std::move(value);
+
+        return *this;
+    }
     at::Tensor make_tensor();

 private:
@@ -58,6 +64,7 @@ private:
     c10::optional<c10::Device> device_{};
     at::TensorOptions opts_{};
     c10::Allocator* allocator_{};
+    std::function<void(void*)> deleter_{};
 };

 inline TensorMaker for_blob(void* data, at::IntArrayRef sizes) noexcept
@@ -65,6 +72,30 @@ inline TensorMaker for_blob(void* data, at::IntArrayRef sizes) noexcept
     return TensorMaker{data, sizes};
 }

+TORCH_NPU_API at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    std::function<void(void*)> deleter,
+    const at::TensorOptions& options = {},
+    const c10::optional<c10::Device> target_device = c10::nullopt);
+
+TORCH_NPU_API at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    int64_t storage_offset,
+    const std::function<void(void*)>& deleter,
+    const at::TensorOptions& options = {},
+    const c10::optional<c10::Device> target_device = c10::nullopt);
+
+TORCH_NPU_API at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    const std::function<void(void*)>& deleter,
+    const at::TensorOptions& options = {},
+    const c10::optional<c10::Device> target_device = c10::nullopt);
+
 TORCH_NPU_API at::Tensor from_blob(
     void* data,
     at::IntArrayRef sizes,
diff --git a/torch_npu/csrc/core/npu/NPUAffinityController.cpp b/torch_npu/csrc/core/npu/NPUAffinityController.cpp
index 5567c3e6e22292f53e4a5907f46f56287df666af..6c2d35fd951aaae50b0293c0437db458a3896874 100644
--- a/torch_npu/csrc/core/npu/NPUAffinityController.cpp
+++ b/torch_npu/csrc/core/npu/NPUAffinityController.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <mutex>

 namespace c10_npu {
@@ -16,6 +17,7 @@
 static thread_local ThreadType local_thread = ThreadType::MAIN_THREAD;

 static pthread_t main_thread;
 static bool start_main_thread_bind = false;
+static std::mutex core_map_mutex;

 using ThreadCoreMap = std::unordered_map;
@@ -264,6 +266,7 @@ CoreIdRange getCoreRange(c10::DeviceIndex device_id, ThreadType type)
     if (cpu_affinity_mode == 0 || cpu_affinity_mode == 1) {
         core_range = device_ranges[device_id];
     } else {
+        std::lock_guard<std::mutex> lock(core_map_mutex);
         if (device_thread_core_maps.find(device_id) == device_thread_core_maps.end()) {
             device_thread_core_maps.emplace(device_id, getCpuAffinityMap(device_id, device_ranges));
         }
diff --git a/torch_npu/csrc/core/npu/NPUException.cpp b/torch_npu/csrc/core/npu/NPUException.cpp
index 3d1ca96bf4eea944b15f42e927b703e246204178..1f1b4ea47830468850aaf59d8924a478e26590fe 100644
--- a/torch_npu/csrc/core/npu/NPUException.cpp
+++ b/torch_npu/csrc/core/npu/NPUException.cpp
@@ -47,11 +47,14 @@ void warn_(const ::c10::Warning& warning)

 std::string formatErrorCode(SubModule submodule, ErrCode errorCode)
 {
+    if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {
+        return " ";
+    }
     std::ostringstream oss;
     int deviceIndex = -1;
     c10_npu::GetDevice(&deviceIndex);
     auto rank_id = c10_npu::option::OptionsManager::GetRankId();
-    if (!(c10_npu::option::OptionsManager::ShouldPrintLessError())) {
+    if (!(c10_npu::option::OptionsManager::IsCompactErrorOutput())) {
         oss << "\n[ERROR] " << getCurrentTimestamp() << " (PID:" << getpid() << ", Device:" << deviceIndex << ", RankID:" << rank_id << ") ";
     }
     oss << "ERR" << std::setw(2) << std::setfill('0') << static_cast<int>(submodule);
@@ -149,10 +152,10 @@ const std::string c10_npu_check_error_message(std::string& errmsg)
 const char *c10_npu_get_error_message()
 {
     auto errmsg = c10_npu::acl::AclGetErrMsg();
-    if (c10_npu::option::OptionsManager::ShouldPrintLessError()) {
+    if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {
         std::string log(errmsg);
         std::string errmsg_ = c10_npu::c10_npu_check_error_message(log);
-        thread_local std::string processedErrMsg = errmsg_;
+        thread_local std::string processedErrMsg = "CANN error: " + errmsg_;
         c10_npu::setRepoErrMsg(processedErrMsg.c_str());
         return processedErrMsg.c_str();
     } else {
diff --git a/torch_npu/csrc/core/npu/NPUException.h b/torch_npu/csrc/core/npu/NPUException.h
index 203b6529b786f743d6e58bb6bbaf6fd2d6c48ce1..1d34ae2050ded401c5508ccaa57a89ad481e2a74 100644
--- a/torch_npu/csrc/core/npu/NPUException.h
+++ b/torch_npu/csrc/core/npu/NPUException.h
@@ -151,7 +151,7 @@ inline const char* getErrorFunction(const char* /* msg */, const char* args)
             " that driver and firmware packages do not match.");              \
             return true;                                                      \
         }();                                                                  \
-    } else if (c10_npu::option::OptionsManager::ShouldPrintLessError()) {     \
+    } else if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {     \
         std::ostringstream oss;                                               \
         oss << " NPU function error: "                                        \
             << (device_error_msg.empty() ? getErrorFunction(#err_code, ##__VA_ARGS__) : device_error_msg) \
@@ -207,7 +207,7 @@ inline const char* getErrorFunction(const char* /* msg */, const char* args)
     static c10_npu::acl::AclErrorCode err_map;                                \
     if ((Error) != ACL_ERROR_NONE) {                                          \
         CHECK_AND_THROW_ERROR_WITH_SPECIFIC_MESSAGE(Error);                   \
-        if (c10_npu::option::OptionsManager::ShouldPrintLessError())          \
+        if (c10_npu::option::OptionsManager::IsCompactErrorOutput())          \
         {                                                                     \
             std::ostringstream oss;                                           \
             oss << " OPS function error: " << getErrorFunction(#err_code, ##__VA_ARGS__) \
diff --git a/torch_npu/csrc/core/npu/NPUFunctions.cpp b/torch_npu/csrc/core/npu/NPUFunctions.cpp
index 3d146783f0a0532712e08a10727acf12b8e01228..40e865f10989e495766b81146bd10af4e251e1b5 100644
--- a/torch_npu/csrc/core/npu/NPUFunctions.cpp
+++ b/torch_npu/csrc/core/npu/NPUFunctions.cpp
@@ -5,6 +5,7 @@
 #include "torch_npu/csrc/core/npu/NPUStream.h"
 #include "torch_npu/csrc/core/npu/NPUAffinityController.h"
 #include "torch_npu/csrc/core/npu/register/OptionsManager.h"
+#include "third_party/acl/inc/acl/acl_rt.h"
 #ifndef BUILD_LIBTORCH
 #include "torch_npu/csrc/sanitizer/NPUTrace.h"
 #endif
@@ -46,7 +47,6 @@ aclError GetDevice(int32_t *device)
 {
     if (targetDeviceIndex >= 0) {
         *device = targetDeviceIndex;
-        NPU_CHECK_ERROR_WITHOUT_UCE(SetDevice(targetDeviceIndex));
         return ACL_ERROR_NONE;
     }

@@ -60,13 +60,8 @@ aclError GetDevice(int32_t *device)
     }
     if (err == ACL_ERROR_NONE) {
         local_device = *device;
-    } else if (err == ACL_ERROR_RT_CONTEXT_NULL && aclrtSetDevice(0) == ACL_ERROR_NONE) {
+    } else if (err == ACL_ERROR_RT_CONTEXT_NULL) {
         *device = 0;
-        local_device = 0;
-        std::lock_guard<std::mutex> lock(mtx);
-        if (used_devices.find(local_device) == used_devices.end()) {
-            NPU_CHECK_ERROR_WITHOUT_UCE(aclrtGetCurrentContext(&used_devices[local_device]));
-        }
         return ACL_ERROR_NONE;
     }
     return err;
@@ -284,4 +279,42 @@ void stream_synchronize(aclrtStream stream)
     NPU_CHECK_ERROR(aclrtSynchronizeStream(stream));
 }

+aclError SetDeviceResLimit(int32_t device, int32_t type, uint32_t value)
+{
+    std::lock_guard<std::mutex> lock(mtx);
+    if (used_devices.find(device) == used_devices.end()) {
+        TORCH_CHECK(false, "NPU device ", device, " has not been initialized! Can not set device resource limit");
+    }
+    TORCH_CHECK(device >= 0, "device id must be non-negative!", PTA_ERROR(ErrCode::VALUE));
+    c10_npu::acl::aclrtDevResModelType restype = static_cast<c10_npu::acl::aclrtDevResModelType>(type);
+    aclError err = c10_npu::acl::AclrtSetDeviceResLimit(device, restype, value);
+    NPU_CHECK_ERROR_WITHOUT_UCE(err);
+    return err;
+}
+
+uint32_t GetDeviceResLimit(int32_t device, int32_t type)
+{
+    std::lock_guard<std::mutex> lock(mtx);
+    if (used_devices.find(device) == used_devices.end()) {
+        TORCH_CHECK(false, "NPU device ", device, " has not been initialized! Can not get device resource limit");
+    }
+    TORCH_CHECK(device >= 0, "device id must be non-negative!", PTA_ERROR(ErrCode::VALUE));
+    c10_npu::acl::aclrtDevResModelType restype = static_cast<c10_npu::acl::aclrtDevResModelType>(type);
+    uint32_t value;
+    NPU_CHECK_ERROR_WITHOUT_UCE(c10_npu::acl::AclrtGetDeviceResLimit(device, restype, &value));
+    return value;
+}
+
+aclError ResetDeviceResLimit(int32_t device)
+{
+    std::lock_guard<std::mutex> lock(mtx);
+    if (used_devices.find(device) == used_devices.end()) {
+        TORCH_CHECK(false, "NPU device ", device, " has not been initialized! Can not reset device resource limit");
+    }
+    TORCH_CHECK(device >= 0, "device id must be non-negative!", PTA_ERROR(ErrCode::VALUE));
+    aclError err = c10_npu::acl::AclrtResetDeviceResLimit(device);
+    NPU_CHECK_ERROR_WITHOUT_UCE(err);
+    return err;
+}
+
 } // namespace c10_npu
diff --git a/torch_npu/csrc/core/npu/NPUFunctions.h b/torch_npu/csrc/core/npu/NPUFunctions.h
index 9bb715bdb85fc5007e04026865794e9f3a5cc1cd..3e8220a09f17406c45a2b5346c2f24e371186d1d 100644
--- a/torch_npu/csrc/core/npu/NPUFunctions.h
+++ b/torch_npu/csrc/core/npu/NPUFunctions.h
@@ -77,6 +77,12 @@ void SetTargetDevice();

 int GetLocalDevice();

+aclError SetDeviceResLimit(int32_t device, int32_t type, uint32_t value);
+
+C10_NPU_API uint32_t GetDeviceResLimit(int32_t deviceId, int32_t type);
+
+aclError ResetDeviceResLimit(int32_t deviceId);
+
 enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

 // it's used to store npu synchronization state
diff --git a/torch_npu/csrc/core/npu/NPUStream.cpp b/torch_npu/csrc/core/npu/NPUStream.cpp
index 714aa57f97744ee29fcd66cd97423053e5b2c4ab..4ffc125d162347017bda4ed44acf2c11a37582f7 100644
--- a/torch_npu/csrc/core/npu/NPUStream.cpp
+++ b/torch_npu/csrc/core/npu/NPUStream.cpp
@@ -229,6 +229,8 @@ static void initNPUStreamsOnce()
 {
     // Inits default and secondary streams (once, globally)
     c10::DeviceIndex device_index = current_device();
+    // make sure we are on the real device
+    SetTargetDevice();
     if (!initialize_flag[device_index]) {
         std::lock_guard<std::mutex> lock(mtx[device_index]);
         if (!initialize_flag[device_index]) {
diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp
index bfcb8aa87d49b77ba3a8505a03028730e95a300d..7d38fbde3cf0c24fd3722d8fe22cc2c2c074bbcb 100644
--- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp
+++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp
@@ -89,6 +89,9 @@ LOAD_FUNCTION(aclrtIpcMemClose)
 LOAD_FUNCTION(aclrtMemExportToShareableHandle)
 LOAD_FUNCTION(aclrtMemSetPidToShareableHandle)
 LOAD_FUNCTION(aclrtMemImportFromShareableHandle)
+LOAD_FUNCTION(aclrtGetDeviceResLimit)
+LOAD_FUNCTION(aclrtSetDeviceResLimit)
+LOAD_FUNCTION(aclrtResetDeviceResLimit)

 aclprofStepInfoPtr init_stepinfo() {
     typedef aclprofStepInfoPtr(*npdInitFunc)();
@@ -1020,5 +1023,41 @@ aclError AclrtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t deviceId, aclrtDrvMemHandle *handle)
     return func(shareableHandle, deviceId, handle);
 }

+aclError AclrtGetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t* value)
+{
+    typedef aclError (*AclrtGetDeviceResLimit)(int32_t, aclrtDevResModelType, uint32_t*);
+    static AclrtGetDeviceResLimit func = nullptr;
+    if (func == nullptr) {
+        func = (AclrtGetDeviceResLimit) GET_FUNC(aclrtGetDeviceResLimit);
+    }
+
+    TORCH_CHECK(func, "Failed to find function aclrtGetDeviceResLimit", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(deviceId, type, value);
+}
+
+aclError AclrtSetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t value)
+{
+    typedef aclError (*AclrtSetDeviceResLimit)(int32_t, aclrtDevResModelType, uint32_t);
+    static AclrtSetDeviceResLimit func = nullptr;
+    if (func == nullptr) {
+        func = (AclrtSetDeviceResLimit) GET_FUNC(aclrtSetDeviceResLimit);
+    }
+
+    TORCH_CHECK(func, "Failed to find function aclrtSetDeviceResLimit", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(deviceId, type, value);
+}
+
+aclError AclrtResetDeviceResLimit(int32_t deviceId)
+{
+    typedef aclError (*AclrtResetDeviceResLimit)(int32_t);
+    static AclrtResetDeviceResLimit func = nullptr;
+    if (func == nullptr) {
+        func = (AclrtResetDeviceResLimit) GET_FUNC(aclrtResetDeviceResLimit);
+    }
+
+    TORCH_CHECK(func, "Failed to find function aclrtResetDeviceResLimit", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(deviceId);
+}
+
 } // namespace acl
 } // namespace c10
diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h
index 6c4b3ff82f698ec8376d4f155f1969e9d94390f9..47035c68074fa41367ef4301a58c043dc184e0d4 100644
--- a/torch_npu/csrc/core/npu/interface/AclInterface.h
+++ b/torch_npu/csrc/core/npu/interface/AclInterface.h
@@ -32,6 +32,12 @@ enum aclrtStreamStatus {
 };
 using aclrtStreamStatus = enum aclrtStreamStatus;

+enum aclrtDevResModelType {
+    ACL_RT_DEV_RES_CUBE_CORE = 0,
+    ACL_RT_DEV_RES_VECTOR_CORE = 1,
+};
+using aclrtDevResModelType = enum aclrtDevResModelType;
+
 /**
   aclprofStepInfo is provide by acl, it used to be store dispatch op info.
  */
@@ -243,5 +249,9 @@ aclError AclrtMemSetPidToShareableHandle(uint64_t shareableHandle, int32_t *pid,
 aclError AclrtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t deviceId, aclrtDrvMemHandle *handle);

+aclError AclrtGetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t* value);
+aclError AclrtSetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uint32_t value);
+aclError AclrtResetDeviceResLimit(int32_t deviceId);
+
 } // namespace acl
 } // namespace c10_npu
diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.cpp b/torch_npu/csrc/core/npu/register/OptionsManager.cpp
index 0f9ffb565bed3ebddc6d6850676fbcc6bd451211..24e924e518207797cede3c552da8f4dd0acaf488 100644
--- a/torch_npu/csrc/core/npu/register/OptionsManager.cpp
+++ b/torch_npu/csrc/core/npu/register/OptionsManager.cpp
@@ -482,11 +482,11 @@ uint32_t OptionsManager::GetAclOpInitMode()
     const static uint32_t acl_op_init_mode = []() -> uint32_t {
         char* buf_val = std::getenv("ACL_OP_INIT_MODE");
         // Default 0
-        int64_t acl_op_init_mode = (buf_val != nullptr) ? strtol(buf_val, nullptr, 10) : 1;
+        int64_t acl_op_init_mode = (buf_val != nullptr) ? strtol(buf_val, nullptr, 10) : 0;
         std::unordered_map aclOpInitMode = getAclOpInitMode();
         if (aclOpInitMode.find(acl_op_init_mode) == aclOpInitMode.end()) {
-            acl_op_init_mode = 1;
-            TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 1.");
+            acl_op_init_mode = 0;
+            TORCH_NPU_WARN_ONCE("Get env ACL_OP_INIT_MODE not in [0, 1, 2], so reset it to the default value 0.");
         }
         return static_cast<uint32_t>(acl_op_init_mode);
     }();
@@ -622,7 +622,7 @@ bool OptionsManager::IsOomSnapshotEnable()
     return (envFlag != 0);
 }

-bool OptionsManager::ShouldPrintLessError()
+bool OptionsManager::IsCompactErrorOutput()
 {
     static bool should_print = []() -> bool {
         int32_t disabled_error = OptionsManager::GetBoolTypeOption("TORCH_NPU_COMPACT_ERROR_OUTPUT");
diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.h b/torch_npu/csrc/core/npu/register/OptionsManager.h
index 5be33e06daae47716164f4ad7299afabd8c3426c..73f5dbcb81f9fc268d8ef9122407e66b976dad08 100644
--- a/torch_npu/csrc/core/npu/register/OptionsManager.h
+++ b/torch_npu/csrc/core/npu/register/OptionsManager.h
@@ -133,7 +133,7 @@ public:
     static std::string GetOomSnapshotDumpPath();
     static bool IsOomSnapshotEnable();
     static bool ShouldPrintWarning();
-    static bool ShouldPrintLessError();
+    static bool IsCompactErrorOutput();

 private:
     static int GetBoolTypeOption(const char* env_str, int defaultVal = 0);
diff --git a/torch_npu/csrc/distributed/HCCLUtils.hpp b/torch_npu/csrc/distributed/HCCLUtils.hpp
index ffd645a7c50575b8dfe9f3d2652e316cf7723a39..b4662c1e49b85cce5b0f80d55f85706b503c5916 100644
--- a/torch_npu/csrc/distributed/HCCLUtils.hpp
+++ b/torch_npu/csrc/distributed/HCCLUtils.hpp
@@ -17,7 +17,7 @@
     auto Error = err_code;                                                    \
     if ((Error) != HCCL_SUCCESS) {                                            \
         CHECK_AND_THROW_ERROR_WITH_SPECIFIC_MESSAGE(Error);                   \
-        if (c10_npu::option::OptionsManager::ShouldPrintLessError()) {        \
+        if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {        \
            std::ostringstream oss;                                            \
            oss << " HCCL function error: " << getErrorFunction(#err_code, ##__VA_ARGS__) \
                << ", error code is " << Error << " "                          \
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 0ddf9445c56703f4d0ef561b336f0dc1d17c9812..7ec49743870720ea41c86335f892f684ff221379 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -19,6 +19,10 @@
 #include
 #include
 #include
+#include <arpa/inet.h>
+#include <netinet/in.h>
+
+#include <torch/csrc/distributed/c10d/TCPStore.hpp>

 #include "op_plugin/OpInterface.h"
 #include "third_party/acl/inc/acl/acl.h"
@@ -63,6 +67,7 @@ constexpr const char* P2P_DEVICE_KEY = "_p2p";

 using hcclUs = std::chrono::steady_clock::time_point;

 constexpr int32_t MAX_GROUP_NAME_LEN = 128;
+constexpr int32_t NSLB_JOBID_OFFSET = 32;

 // HCCL ReduceOp mapping
 std::map<c10d::ReduceOp, HcclReduceOp> hcclOp = {
@@ -949,6 +954,24 @@ ProcessGroupHCCL::ProcessGroupHCCL(
     PrefixStore *prefixStore = dynamic_cast<PrefixStore *>(store_.get());
     globalStore_ = prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;

+    c10::intrusive_ptr<c10d::Store> getTcpStore = store_;
+    while (getTcpStore) {
+        c10d::PrefixStore *asPrefixStore = dynamic_cast<c10d::PrefixStore *>(getTcpStore.get());
+        c10d::TCPStore *tcpStore = dynamic_cast<c10d::TCPStore *>(getTcpStore.get());
+        if (tcpStore) {
+            if (!(tcpStore->getHost().empty())) {
+                tcpMasterAddr = tcpStore->getHost();
+                tcpMasterPort = tcpStore->getPort();
+                break;
+            }
+        }
+        if (asPrefixStore) {
+            getTcpStore = asPrefixStore->getUnderlyingStore();
+        } else {
+            break;
+        }
+    }
+
     const char* blockingWait = getenv(HCCL_BLOCKING_WAIT);
     try {
         if (blockingWait != nullptr) {
@@ -1181,6 +1204,7 @@ void ProcessGroupHCCL::abortAndClearHcclComm(c10::optional<std::string> abortReason)
     abortCommsFromMap(devHCCLCommMap_, rank_, abortReason);
     devHCCLCommMap_.clear();
     devHCCLCommNameMap_.clear();
+    p2pSendRecvKeys_.clear();
     hcclCommCounter_ = 0;
     return;
 }
@@ -1223,6 +1247,7 @@ ProcessGroupHCCL::~ProcessGroupHCCL()
         }
     }
     devHCCLCommMap_.clear();
+    p2pSendRecvKeys_.clear();
     }
     ASCEND_LOGI("process group destroyed, group id is %s.", options_->group_id.c_str());
     logger->info("process group destroyed, group id is %s.", options_->group_id.c_str());
@@ -2152,6 +2177,30 @@ std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::getHCCLComm(
     return createHCCLComm(devicesKey, devices, commType, commConfig, p2pRank);
 }

+void ProcessGroupHCCL::setNSLBCommConfig(HcclCommConfig** commConfig)
+{
+    const char* envPtr = std::getenv("RANK");
+    if (envPtr == nullptr) {
+        ASCEND_LOGI("Failed to get env info for NSLB-DP.");
+        return;
+    }
+    uint32_t worldRankID = std::stoi(std::string(envPtr));
+    options_->hccl_config["hccl_world_rank_id"] = worldRankID;
+    uint32_t masterPort = tcpMasterPort;
+    struct sockaddr_in sa;
+    std::string master_addr = tcpMasterAddr;
+    inet_pton(AF_INET, std::string(master_addr).c_str(), &(sa.sin_addr));
+    uint32_t masterIp = ntohl(sa.sin_addr.s_addr);
+    uint64_t jobID = masterPort;
+    jobID = (jobID << NSLB_JOBID_OFFSET);
+    jobID += masterIp;
+    options_->hccl_config["hccl_job_id"] = jobID;
+    if ((*commConfig) != nullptr) {
+        (*commConfig)->hcclWorldRankID = worldRankID;
+        (*commConfig)->hcclJobID = jobID;
+    }
+}
+
 void ProcessGroupHCCL::createHCCLComm(
     const std::string& devicesKey,
     const std::vector<at::Device>& devices,
@@ -2176,6 +2225,10 @@ void ProcessGroupHCCL::createHCCLComm(

     HcclCommConfig config;

+    if (options_->global_ranks_in_group.empty()) {
+        setNSLBCommConfig(&commConfig);
+    }
+
     npuGuard.set_index(devices[i].index());
     switch (commType) {
         case HcclCommType::DEFAULT:
@@ -2306,6 +2359,9 @@ bool ProcessGroupHCCL::createHCCLCommEx(
             return false;
         }
         hcclComms[i] = subComm;
+        if (commType == HcclCommType::P2P) {
+            hcclComms[i]->p2pPeer = getP2pPeer();
+        }
         // Creates the HCCL streams
         streamVal.push_back(getNPUStreamByCurrentType(devices[i].index()));
     }
@@ -2409,6 +2465,14 @@ std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::createHCCLComm(
     // Move the HCCL resource to cache
     devHCCLCommMap_.emplace(devicesKey, std::move(hcclComms));
+    if (commType == HcclCommType::P2P) {
+        auto iter = p2pSendRecvKeys_.find(rank_);
+        if (iter == p2pSendRecvKeys_.end()) {
+            p2pSendRecvKeys_.emplace(rank_, std::vector<std::string>{devicesKey});
+        } else {
+            iter->second.push_back(devicesKey);
+        }
+    }
     return devHCCLCommMap_[devicesKey];
 }
@@ -2419,7 +2483,13 @@ int64_t ProcessGroupHCCL::getStreamId(bool p2p, int peer)
     std::vector<at::Device> devices = {at::Device(c10::DeviceType::PrivateUse1, device)};
     auto key = getKeyFromDevices(devices);
     if (p2p && hcclCommInitRootInfoConfigExist() && c10_npu::option::OptionsManager::GetP2PBufferSize() != 0) {
-        TORCH_CHECK(peer >= 0, "In p2p scenarios, the passed 'dst rank id' is error.", DIST_ERROR(ErrCode::PARAM));
+        TORCH_CHECK(
+            peer >= 0,
+            "In p2p scenarios, the passed 'dst rank id' : ",
+            peer,
+            " is invalid, ",
+            "expected value >= 0.",
+            DIST_ERROR(ErrCode::PARAM));
         key = getKeySendRecv(rank_, peer);
     }
     if ((hcclStreams_.count(key) == 0) || hcclStreams_[key].empty()) {
@@ -2746,7 +2816,7 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id)
 {
     at::Device device = at::Device(c10::DeviceType::PrivateUse1, device_id);
     std::vector<at::Device> devices = {device};
-    const auto key = getKeyFromDevices(devices);
+    auto key = getKeyFromDevices(devices);

     {
         std::lock_guard<std::mutex> lock(mutex_);
@@ -2758,6 +2828,19 @@ void ProcessGroupHCCL::resumeHcclComm(int device_id)
                 HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
             }
         }
+        if (p2pSendRecvKeys_.find(rank_) != p2pSendRecvKeys_.end()) {
+            auto p2pKeys = p2pSendRecvKeys_[rank_];
+            for (const auto& p2pKey : p2pKeys) {
+                if (devHCCLCommMap_.find(p2pKey) != devHCCLCommMap_.end()) {
+                    // Reuse the cached communicator if there is one.
+                    auto& hcclComms = devHCCLCommMap_[p2pKey];
+                    for (const auto& hcclComm : hcclComms) {
+                        auto comm = hcclComm->getHcclComm();
+                        HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
+                    }
+                }
+            }
+        }
     }
     ASCEND_LOGI("resumeHcclComm success, group id is %s.", options_->group_id.c_str());
 }
@@ -3097,6 +3180,22 @@ HcclCommConfig ProcessGroupHCCL::createHcclCommConfigWithOptions()
         }
     }

+    if (options_->hccl_config.find("hccl_world_rank_id") != options_->hccl_config.end()) {
+        if (std::holds_alternative<uint32_t>(options_->hccl_config["hccl_world_rank_id"])) {
+            config.hcclWorldRankID = std::get<uint32_t>(options_->hccl_config["hccl_world_rank_id"]);
+        } else {
+            TORCH_CHECK(false, "Value type of hccl_world_rank_id should be int.", DIST_ERROR(ErrCode::TYPE));
+        }
+    }
+
+    if (options_->hccl_config.find("hccl_job_id") != options_->hccl_config.end()) {
+        if (std::holds_alternative<uint64_t>(options_->hccl_config["hccl_job_id"])) {
+            config.hcclJobID = std::get<uint64_t>(options_->hccl_config["hccl_job_id"]);
+        } else {
+            TORCH_CHECK(false, "Value type of hccl_job_id should be int.", DIST_ERROR(ErrCode::TYPE));
+        }
+    }
+
     return config;
 }
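setNSLBCommConfig packs the rendezvous endpoint into the 64-bit hcclJobID: the TCP master port goes in the high 32 bits (NSLB_JOBID_OFFSET) and the IPv4 master address, converted to host byte order, goes in the low 32 bits. The same packing expressed as a Python sketch, with a hypothetical endpoint:

    import socket
    import struct

    master_addr, master_port = '192.0.2.10', 29500  # hypothetical endpoint
    master_ip = struct.unpack('!I', socket.inet_aton(master_addr))[0]  # inet_pton + ntohl
    job_id = (master_port << 32) + master_ip
    assert job_id >> 32 == master_port and job_id & 0xFFFFFFFF == master_ip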
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
index 4021373b52b42290db011dc93094df4784e99842..940075105b6f53e062208290104951f9650dcb8f 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.hpp
@@ -384,7 +384,7 @@ public:
             return c10::make_intrusive<Options>(_is_high_priority_stream);
         }

-        std::unordered_map<std::string, std::variant<uint32_t, std::string>> hccl_config;
+        std::unordered_map<std::string, std::variant<uint32_t, uint64_t, std::string>> hccl_config;

         std::chrono::milliseconds opTimeout;
         // Schedule HCCL operations on high priority CUDA streams
@@ -571,6 +571,8 @@ public:

     void resumeHcclComm(int device_id);

+    void setNSLBCommConfig(HcclCommConfig** commConfig);
+
     bool setCommWorkingDevNic(
         const HcclComm& comm,
         int nranks,
@@ -746,6 +748,8 @@ protected:
     //
     // Note that the order of the device for the tensor list matters.
     std::unordered_map<std::string, std::vector<std::shared_ptr<HCCLComm>>> devHCCLCommMap_;
+
+    std::unordered_map<int, std::vector<std::string>> p2pSendRecvKeys_;

     std::unordered_map<std::string, std::string> devHCCLCommNameMap_;

@@ -953,6 +957,10 @@ protected:

     static std::string exceptionMessage_;

+    std::string tcpMasterAddr;
+
+    uint32_t tcpMasterPort;
+
 private:
     // Helper that encapsulates work shared across all collective communication
     // primitives.
diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.h b/torch_npu/csrc/framework/utils/CalcuOpUtil.h
index 9a4a8024435cf77db1a1aba49e22cf73b580062f..c249a332ba24c19cc4879b08e8b7f053ddd23eee 100644
--- a/torch_npu/csrc/framework/utils/CalcuOpUtil.h
+++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.h
@@ -36,14 +36,23 @@ using std::vector;
 #define ASCEND_ALWAYS_INLINE inline
 #endif

-#define ACL_REQUIRE_OK_OP(expr, opstr)                                        \
-    do {                                                                      \
-        if (ASCEND_UNLIKELY((expr) != 0)) {                                   \
-            std::cout << (opstr) << std::endl;                                \
-            TORCH_CHECK((expr) == 0, __func__, ":", __FILE__, ":", __LINE__,  \
-                " NPU error,NPU error code is:", expr, "\n",                  \
-                c10_npu::acl::AclGetErrMsg(), OPS_ERROR(ErrCode::INTERNAL));  \
-        }                                                                     \
+#define ACL_REQUIRE_OK_OP(expr, opstr)                                        \
+    do {                                                                      \
+        if (ASCEND_UNLIKELY((expr) != 0)) {                                   \
+            std::cout << (opstr) << std::endl;                                \
+            if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {    \
+                std::ostringstream oss;                                       \
+                oss << " NPU error,NPU error code is:" << (expr) << "\n"      \
+                    << OPS_ERROR(ErrCode::INTERNAL);                          \
+                std::string err_msg = oss.str();                              \
+                ASCEND_LOGE("%s", err_msg.c_str());                           \
+                TORCH_CHECK((expr) == 0, c10_npu::c10_npu_get_error_message()); \
+            } else {                                                          \
+                TORCH_CHECK((expr) == 0, __func__, ":", __FILE__, ":", __LINE__, \
+                    " NPU error,NPU error code is:", expr, "\n",              \
+                    c10_npu::acl::AclGetErrMsg(), OPS_ERROR(ErrCode::INTERNAL)); \
+            }                                                                 \
+        }                                                                     \
     } while (0)

 using StorageAndOffsetMemSizePair = std::pair;
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index 123e14f028f02bd9183d65ab3f7672ab3f53a363..234f0f37e7afdc5846ae727c9f1456feaf7b13f4 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -1032,6 +1032,7 @@ PyObject* THNPModule_memorySnapshot(PyObject* _unused, PyObject* noargs)
     py::str requested_size_s = "requested_size";
     py::str stream_s = "stream";
     py::str segment_type_s = "segment_type";
+    py::str segment_pool_id = "segment_pool_id";
     py::str large_s = "large";
     py::str small_s = "small";
     py::str size_s = "size";
@@ -1071,6 +1072,7 @@ PyObject* THNPModule_memorySnapshot(PyObject* _unused, PyObject* noargs)
         // represent the stream rather than a torch.cuda.stream object
         segmentDict[stream_s] = int64_t(segmentInfo.stream);
         segmentDict[segment_type_s] = (segmentInfo.is_large ? large_s : small_s);
+        segmentDict[segment_pool_id] = segmentInfo.owner_private_pool_id;
         segmentDict[is_expandable_s] = segmentInfo.is_expandable;

         add_frame_key(segmentDict, segmentInfo.context_when_allocated);
@@ -1694,6 +1696,50 @@ static PyObject* THNPModule_add_p2p_access(PyObject* self, PyObject *args)
     END_HANDLE_TH_ERRORS
 }

+static PyObject* THNPModule_set_device_res_limit(PyObject* self, PyObject *args)
+{
+    HANDLE_TH_ERRORS
+    PyObject* device = nullptr;
+    PyObject* type = nullptr;
+    PyObject* value = nullptr;
+
+    if (!PyArg_ParseTuple(args, "OOO", &device, &type, &value)) {
+        throw torch::TypeError("Pybind failed to parse parameters."
+                               + PTA_ERROR(ErrCode::TYPE));
+    }
+    int32_t device_ = THPUtils_unpackLong(device);
+    int32_t type_ = THPUtils_unpackLong(type);
+    uint32_t value_ = static_cast<uint32_t>(THPUtils_unpackUInt32(value));
+    c10_npu::SetDeviceResLimit(device_, type_, value_);
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THNPModule_get_device_res_limit(PyObject* self, PyObject *args)
+{
+    HANDLE_TH_ERRORS
+    PyObject* device = nullptr;
+    PyObject* type = nullptr;
+
+    if (!PyArg_ParseTuple(args, "OO", &device, &type)) {
+        throw torch::TypeError("Pybind failed to parse parameters."
+                               + PTA_ERROR(ErrCode::TYPE));
+    }
+    int32_t device_ = THPUtils_unpackLong(device);
+    int32_t type_ = THPUtils_unpackLong(type);
+    uint32_t value = c10_npu::GetDeviceResLimit(device_, type_);
+    return PyLong_FromUnsignedLong(value);
+    END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THNPModule_reset_device_res_limit(PyObject* self, PyObject *args)
+{
+    HANDLE_TH_ERRORS
+    int32_t device = THPUtils_unpackLong(args);
+    c10_npu::ResetDeviceResLimit(device);
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}

 static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr},
@@ -1758,6 +1804,9 @@ static struct PyMethodDef THNPModule_methods[] = {
     {"_is_gte_cann_version", (PyCFunction)THNPModule_is_gte_cann_version, METH_VARARGS, nullptr},
     {"_add_ipc_pid", (PyCFunction)THNPModule_add_ipc_pid, METH_VARARGS, nullptr},
     {"_add_p2p_access", (PyCFunction)THNPModule_add_p2p_access, METH_VARARGS, nullptr},
+    {"_npu_get_device_res_limit", (PyCFunction)THNPModule_get_device_res_limit, METH_VARARGS, nullptr},
+    {"_npu_set_device_res_limit", (PyCFunction)THNPModule_set_device_res_limit, METH_VARARGS, nullptr},
+    {"_npu_reset_device_res_limit", (PyCFunction)THNPModule_reset_device_res_limit, METH_O, nullptr},
     {nullptr}};

 TORCH_NPU_API PyMethodDef* THNPModule_get_methods()
diff --git a/torch_npu/csrc/npu/memory_snapshot.cpp b/torch_npu/csrc/npu/memory_snapshot.cpp
index 9f0aadbcd7628db5d6bb94df9ad8aa7c6eecbb72..9e6bf3c74b10a16d91de05cd7e9ceb01c62345d8 100644
--- a/torch_npu/csrc/npu/memory_snapshot.cpp
+++ b/torch_npu/csrc/npu/memory_snapshot.cpp
@@ -159,6 +159,7 @@ std::string _memory_snapshot_pickled()
     c10::IValue requested_size_s = "requested_size";
     c10::IValue stream_s = "stream";
     c10::IValue segment_type_s = "segment_type";
+    c10::IValue segment_pool_id = "segment_pool_id";
     c10::IValue large_s = "large";
     c10::IValue small_s = "small";
     c10::IValue size_s = "size";
@@ -200,6 +201,8 @@ std::string _memory_snapshot_pickled()
         segmentDict.insert(stream_s, int64_t(segmentInfo.stream));
         segmentDict.insert(segment_type_s, (segmentInfo.is_large ? large_s : small_s));
+        segmentDict.insert(segment_pool_id,
+                           std::tuple<int64_t, int64_t>(segmentInfo.owner_private_pool_id));
         segmentDict.insert(is_expandable_s, segmentInfo.is_expandable);

         add_frame_key(segmentDict, segmentInfo.context_when_allocated);
diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py
index bd94890b27e32637643e927d41bc6f2a85eafe30..4e8e94fbfc515e2f66c58ad3ef0a96c47d529644 100644
--- a/torch_npu/npu/__init__.py
+++ b/torch_npu/npu/__init__.py
@@ -115,7 +115,9 @@ __all__ = [
     "graph_task_group_begin",
     "graph_task_group_end",
     "graph_task_update_begin",
-    "graph_task_update_end"
+    "graph_task_update_end",
+    "set_device_res_limit",
+    "get_device_res_limit"
 ]

 from typing import Tuple, Union
diff --git a/torch_npu/npu/_format.py b/torch_npu/npu/_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb65e076f74f5537daf4bcc76a58eaae4fdedbd
--- /dev/null
+++ b/torch_npu/npu/_format.py
@@ -0,0 +1,38 @@
+from enum import IntEnum
+
+import torch
+import torch_npu
+
+
+class Format(IntEnum):
+    """NPU storage format enumeration class"""
+    UNDEFINED = -1
+    NCHW = 0
+    NHWC = 1
+    ND = 2
+    NC1HWC0 = 3
+    FRACTAL_Z = 4
+    NC1HWC0_C04 = 12
+    HWCN = 16
+    NDHWC = 27
+    FRACTAL_NZ = 29
+    NCDHW = 30
+    NDC1HWC0 = 32
+    FRACTAL_Z_3D = 33
+    NC = 35
+    NCL = 47
+
+    def __str__(self):
+        return self.name
+
+
+def _apply_npu_format_patch():
+    orig_get_format = torch_npu.get_npu_format
+
+    def patched_get_format(tensor):
+        """get the Format type of tensor"""
+        format_int = orig_get_format(tensor)
+        return Format(format_int)
+
+    torch_npu.get_npu_format = patched_get_format
+    torch_npu.Format = Format
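Both snapshot paths now record which private memory pool owns each segment. A sketch of reading the new key from Python, assuming torch_npu.npu.memory_snapshot mirrors the torch.cuda API and reports the pool id as a two-element tuple ((0, 0) for segments outside any private pool):

    import torch_npu

    for seg in torch_npu.npu.memory_snapshot():
        print(seg['segment_type'], seg['segment_pool_id'])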
diff --git a/torch_npu/npu/npu_config.py b/torch_npu/npu/npu_config.py
index f2a5104920fcb342df1971f56e64bcb922a0f4c1..f10f5c8c7a674785479b3a5f7b975bcd2294068b 100644
--- a/torch_npu/npu/npu_config.py
+++ b/torch_npu/npu/npu_config.py
@@ -5,12 +5,14 @@ import warnings

 import torch_npu._C
 from torch_npu.utils._path_manager import PathManager
 from torch_npu.utils._error_code import ErrCode, pta_error, prof_error
+from .utils import _get_device_index

 # this file is used to enhance the npu frontend API by set_option or other.

 __all__ = ["set_option", "set_aoe", "set_compile_mode", "set_mm_bmm_format_nd", "get_mm_bmm_format_nd",
-           "is_jit_compile_false", "finalize_dump", "init_dump", "set_dump"]
+           "is_jit_compile_false", "finalize_dump", "init_dump", "set_dump",
+           "set_device_res_limit", "get_device_res_limit", "reset_device_res_limit"]

 _option_map = {"ACL_PRECISION_MODE": ["allow_fp32_to_fp16", "must_keep_origin_dtype"],
                "ACL_OP_SELECT_IMPL_MODE": ["high_performance", "high_precision"],
@@ -169,3 +171,42 @@ class _allowHF32Conv:
             hf32_value = torch_npu._C._npu_getOption("ALLOW_CONV_HF32")
             return (hf32_value is None) or (hf32_value.decode() == "") or (hf32_value.decode() == "enable")
         return None
+
+
+class _call_once_class:
+    def __init__(self, func):
+        self.func = func
+        self.called = False
+        self.result = None
+
+    def __call__(self, *args, **kwargs):
+        if self.called:
+            raise RuntimeError(f"Function '{self.func.__name__}' has already been called. "
+                               "You can only set this interface once.")
+
+        self.called = True
+        self.result = self.func(*args, **kwargs)
+        return self.result
+
+
+@_call_once_class
+def set_device_res_limit(device, cube_num=-1, vector_num=-1):
+    from torch_npu.npu import device_count
+    device_id = _get_device_index(device, optional=True)
+    if device_id < 0 or device_id >= device_count():
+        raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
+    torch_npu.npu._lazy_init()
+    if cube_num != -1:
+        torch_npu._C._npu_set_device_res_limit(device_id, 0, cube_num)
+    if vector_num != -1:
+        torch_npu._C._npu_set_device_res_limit(device_id, 1, vector_num)
+
+
+def get_device_res_limit(device):
+    from torch_npu.npu import device_count
+    device_id = _get_device_index(device, optional=True)
+    if device_id < 0 or device_id >= device_count():
+        raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
+    torch_npu.npu._lazy_init()
+    return {"cube_num": torch_npu._C._npu_get_device_res_limit(device_id, 0),
+            "vector_num": torch_npu._C._npu_get_device_res_limit(device_id, 1)}
\ No newline at end of file
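As wired up here, set_device_res_limit is deliberately one-shot (the _call_once_class guard raises on a second call), -1 leaves a knob unchanged, and get_device_res_limit returns both limits as a dict. A usage sketch matching the skipped test above, assuming device 0 is initialized and the installed CANN supports the feature:

    import torch
    import torch_npu

    torch.npu.set_device_res_limit(0, cube_num=12, vector_num=24)
    print(torch.npu.get_device_res_limit(0))  # {'cube_num': 12, 'vector_num': 24}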
diff --git a/torch_npu/profiler/analysis/prof_common_func/_log.py b/torch_npu/profiler/analysis/prof_common_func/_log.py
index eba5db1af7f74910d1afd3a1fcf47bfb2a928098..0fecde48c41b465cf04eff26282a02911655c032 100644
--- a/torch_npu/profiler/analysis/prof_common_func/_log.py
+++ b/torch_npu/profiler/analysis/prof_common_func/_log.py
@@ -57,14 +57,15 @@ class ProfilerLogger:
         if cls._instance is not None:
             if cls._pid == os.getpid():
                 return
-            cls.destroy()

         # Create logs directory
         log_dir = os.path.join(output_dir, cls.DEFAULT_LOG_DIR)
         PathManager.make_dir_safety(log_dir)

         # Create logger
-        logger = logging.getLogger(cls.DEFAULT_LOGGER_NAME)
+        logger = logging.getLogger(
+            f"{cls.DEFAULT_LOGGER_NAME}_{custom_name}" if custom_name else cls.DEFAULT_LOGGER_NAME
+        )
         logger.setLevel(cls.DEFAULT_LOG_LEVEL)
         logger.propagate = False

@@ -112,19 +113,11 @@ class ProfilerLogger:
     def destroy(cls) -> None:
         """
         Close and cleanup the logger.
-        To avoid the deadlock problem caused by directly calling close on handler in multi-process scenarios, close the
-        file descriptor manually.
+        To avoid the deadlock problem caused by directly calling close on a handler in multi-process scenarios,
+        the parent-process instance obtained via fork does not call this method when a child process updates the instance.
         """
         if cls._instance:
             for handler in cls._instance.handlers[:]:
                 cls._instance.removeHandler(handler)
-                if cls._pid == os.getpid():
-                    handler.close()
-                else:
-                    try:
-                        if hasattr(handler.stream, 'fileno'):
-                            fileno = handler.stream.fileno()
-                            os.close(fileno)
-                    except (OSError, AttributeError, ValueError):
-                        logging.warning("Close profiler logger handler stream failed.")
+                handler.close()
             cls._instance = None
diff --git a/torch_npu/profiler/analysis/prof_view/_communication_parser.py b/torch_npu/profiler/analysis/prof_view/_communication_parser.py
index fff6d265d6ceb5198681e78956b6268efc732cb9..e07f68b785b31eb509602a99a12760fad476a5f3 100644
--- a/torch_npu/profiler/analysis/prof_view/_communication_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/_communication_parser.py
@@ -46,8 +46,6 @@ class CommunicationParser(BaseParser):
         self._root_node = TorchOpNode()
         self._kernel_dict = {}
         self.step_list = []
-        ProfilerLogger.init(self._profiler_path, "CommunicationParser")
-        self.logger = ProfilerLogger.get_instance()

     @staticmethod
     def combine_size_distribution(op_dict: dict, total_dict: dict):
@@ -63,6 +61,8 @@ class CommunicationParser(BaseParser):
         return round(dividend / divisor, 4)

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "CommunicationParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             self._init_step_list(deps_data)
             self.generate_view()
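The parser diffs that follow all apply the same fix as _communication_parser.py above: ProfilerLogger.init moves out of __init__, which runs in the parent process, into run, which executes in the worker, so a fork-inherited logger is never reused across processes. The shared pattern, with a hypothetical subclass name:

    class SomeParser(BaseParser):  # hypothetical example subclass
        def run(self, deps_data: dict):
            # create the logger in the process that actually does the work
            ProfilerLogger.init(self._profiler_path, 'SomeParser')
            self.logger = ProfilerLogger.get_instance()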
""" if cls._instance: for handler in cls._instance.handlers[:]: cls._instance.removeHandler(handler) - if cls._pid == os.getpid(): - handler.close() - else: - try: - if hasattr(handler.stream, 'fileno'): - fileno = handler.stream.fileno() - os.close(fileno) - except (OSError, AttributeError, ValueError): - logging.warning("Close profiler logger handler stream failed.") + handler.close() cls._instance = None diff --git a/torch_npu/profiler/analysis/prof_view/_communication_parser.py b/torch_npu/profiler/analysis/prof_view/_communication_parser.py index fff6d265d6ceb5198681e78956b6268efc732cb9..e07f68b785b31eb509602a99a12760fad476a5f3 100644 --- a/torch_npu/profiler/analysis/prof_view/_communication_parser.py +++ b/torch_npu/profiler/analysis/prof_view/_communication_parser.py @@ -46,8 +46,6 @@ class CommunicationParser(BaseParser): self._root_node = TorchOpNode() self._kernel_dict = {} self.step_list = [] - ProfilerLogger.init(self._profiler_path, "CommunicationParser") - self.logger = ProfilerLogger.get_instance() @staticmethod def combine_size_distribution(op_dict: dict, total_dict: dict): @@ -63,6 +61,8 @@ class CommunicationParser(BaseParser): return round(dividend / divisor, 4) def run(self, deps_data: dict): + ProfilerLogger.init(self._profiler_path, "CommunicationParser") + self.logger = ProfilerLogger.get_instance() try: self._init_step_list(deps_data) self.generate_view() diff --git a/torch_npu/profiler/analysis/prof_view/_integrate_parser.py b/torch_npu/profiler/analysis/prof_view/_integrate_parser.py index b6c545420c3bb961640c7ef25dc54e8050fad6ae..28472a241177ed4f8f13c7b090e02a98db1113c2 100644 --- a/torch_npu/profiler/analysis/prof_view/_integrate_parser.py +++ b/torch_npu/profiler/analysis/prof_view/_integrate_parser.py @@ -26,10 +26,10 @@ class IntegrateParser(BaseParser): def __init__(self, name: str, param_dict: dict): super().__init__(name, param_dict) - ProfilerLogger.init(self._profiler_path, "IntegrateParser") - self.logger = ProfilerLogger.get_instance() def run(self, deps_data: dict): + ProfilerLogger.init(self._profiler_path, "IntegrateParser") + self.logger = ProfilerLogger.get_instance() try: ProfilerConfig().load_info(self._profiler_path) self.generate_view() diff --git a/torch_npu/profiler/analysis/prof_view/_kernel_view_parser.py b/torch_npu/profiler/analysis/prof_view/_kernel_view_parser.py index 30ffd8be8ba46e0b8cc5ac1300c4eba389211eaa..ded9a612c6cfd98a7076fb749457e0c3da9aa44c 100644 --- a/torch_npu/profiler/analysis/prof_view/_kernel_view_parser.py +++ b/torch_npu/profiler/analysis/prof_view/_kernel_view_parser.py @@ -17,8 +17,6 @@ class KernelViewParser(BaseParser): def __init__(self, name: str, param_dict: dict): super().__init__(name, param_dict) self.step_range = [] - ProfilerLogger.init(self._profiler_path, "KernelViewParser") - self.logger = ProfilerLogger.get_instance() @classmethod def _project_map_for_headers(cls, input_headers: list): @@ -35,6 +33,8 @@ class KernelViewParser(BaseParser): return output_headers def run(self, deps_data: dict): + ProfilerLogger.init(self._profiler_path, "KernelViewParser") + self.logger = ProfilerLogger.get_instance() try: ProfilerConfig().load_info(self._profiler_path) self._init_step_range(deps_data) diff --git a/torch_npu/profiler/analysis/prof_view/_memory_view_parser.py b/torch_npu/profiler/analysis/prof_view/_memory_view_parser.py index a82c3dc3c8f08ebe6875f0b7a5e59730c6cf4e6e..47255efd09dbdca635e4888fd575f311fbcff5ef 100644 --- a/torch_npu/profiler/analysis/prof_view/_memory_view_parser.py +++ 
diff --git a/torch_npu/profiler/analysis/prof_view/_operator_view_parser.py b/torch_npu/profiler/analysis/prof_view/_operator_view_parser.py
index f87e8dc8b85e7f35097afd2666194f7cd0311b68..7c10e9d4bf45c2881fb8bd04ae3c2b1124f578c5 100644
--- a/torch_npu/profiler/analysis/prof_view/_operator_view_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/_operator_view_parser.py
@@ -22,10 +22,10 @@ class OperatorViewParser(BaseParser):
         self._torch_op_node = []
         self._root_node = None
         self._kernel_dict = {}
-        ProfilerLogger.init(self._profiler_path, "OperatorViewParser")
-        self.logger = ProfilerLogger.get_instance()

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "OperatorViewParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             self._torch_op_node = deps_data.get(Constant.TREE_BUILD_PARSER, [])
             self._kernel_dict = deps_data.get(Constant.RELATION_PARSER, {})
diff --git a/torch_npu/profiler/analysis/prof_view/_stack_view_parser.py b/torch_npu/profiler/analysis/prof_view/_stack_view_parser.py
index 2f793a8af8b611559613799a004531224c366590..b4a85271d99034e55936d682e9b4748f6251cf11 100644
--- a/torch_npu/profiler/analysis/prof_view/_stack_view_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/_stack_view_parser.py
@@ -23,10 +23,10 @@ class StackViewParser(BaseParser):
         self._root_node = None
         self._kernel_dict = {}
         self._metric = param_dict.get("metric")
-        ProfilerLogger.init(self._profiler_path, "StackViewParser")
-        self.logger = ProfilerLogger.get_instance()

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "StackViewParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             self._torch_op_node = deps_data.get(Constant.TREE_BUILD_PARSER, [])
             self.generate_view()
diff --git a/torch_npu/profiler/analysis/prof_view/_trace_step_time_parser.py b/torch_npu/profiler/analysis/prof_view/_trace_step_time_parser.py
index b5e0502ee410027dea1bc9d0f2b324c969bf26c3..bcdb7d2c6eb3092cfee64b681534cab1357ba89c 100644
--- a/torch_npu/profiler/analysis/prof_view/_trace_step_time_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/_trace_step_time_parser.py
@@ -51,8 +51,6 @@ class TraceStepTimeParser(BaseParser):
     def __init__(self, name: str, param_dict: dict):
         super().__init__(name, param_dict)
         self.step_range = []
-        ProfilerLogger.init(self._profiler_path, "TraceStepTimeParser")
-        self.logger = ProfilerLogger.get_instance()

     @classmethod
     def is_float_num(cls, num):
@@ -165,6 +163,8 @@ class TraceStepTimeParser(BaseParser):
         FileManager.create_csv_file(output_path, print_time, file_name, self.title)

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "TraceStepTimeParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             self._init_step_range(deps_data)
             self.generate_view()
diff --git a/torch_npu/profiler/analysis/prof_view/_trace_view_parser.py b/torch_npu/profiler/analysis/prof_view/_trace_view_parser.py
index f90100e869fd4c4ea92661dd2183b8fd20808412..c5e572e1bcfeba5ecaa4c4e6db93b47c896392eb 100644
--- a/torch_npu/profiler/analysis/prof_view/_trace_view_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/_trace_view_parser.py
@@ -27,8 +27,6 @@ class TraceViewParser(BaseParser):
         self._trace_data = []
         self._torch_op_node = []
         self._root_node = None
-        ProfilerLogger.init(self._profiler_path, "TraceViewParser")
-        self.logger = ProfilerLogger.get_instance()

     @staticmethod
     def _prune_trace_by_level(json_data: list) -> list:
@@ -47,6 +45,8 @@ class TraceViewParser(BaseParser):
         return result

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "TraceViewParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             ProfilerConfig().load_info(self._profiler_path)
             torch_op_node = deps_data.get(Constant.TREE_BUILD_PARSER, [])
diff --git a/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_analyze.py b/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_analyze.py
index d5b577eaee1a53c1a3cf8b4f42293d712fce6083..9c1916753f7845a005918d56b0ffceb17ebdcc00 100644
--- a/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_analyze.py
+++ b/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_analyze.py
@@ -34,10 +34,10 @@ class CANNAnalyzeParser(BaseParser):
         super().__init__(name, param_dict)
         self._cann_path = ProfilerPathManager.get_cann_path(self._profiler_path)
         self.msprof_path = shutil.which("msprof")
-        ProfilerLogger.init(self._profiler_path, "CANNAnalyzeParser")
-        self.logger = ProfilerLogger.get_instance()

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "CANNAnalyzeParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             ProfilerConfig().load_info(self._profiler_path)
             if not os.path.isdir(self._cann_path):
diff --git a/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_export.py b/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_export.py
index 49d4e7eb8f6ac9b5d08f6de0177274ce148bd9b7..2b13bc25e7976b3ef1815f276cc067a6576120f5 100644
--- a/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_export.py
+++ b/torch_npu/profiler/analysis/prof_view/cann_parse/_cann_export.py
@@ -41,10 +41,10 @@ class CANNExportParser(BaseParser):
         super().__init__(name, param_dict)
         self._cann_path = ProfilerPathManager.get_cann_path(self._profiler_path)
         self.msprof_path = shutil.which("msprof")
-        ProfilerLogger.init(self._profiler_path, "CANNExportParser")
-        self.logger = ProfilerLogger.get_instance()

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "CANNExportParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             ProfilerConfig().load_info(self._profiler_path)
             if not os.path.isdir(self._cann_path):
diff --git a/torch_npu/profiler/analysis/prof_view/prepare_parse/_fwk_pre_parser.py b/torch_npu/profiler/analysis/prof_view/prepare_parse/_fwk_pre_parser.py
index 490488d5e15703e576dd29ad776bd179a71f1add..a54ec86d4063512977acd4d6314b34f7f1f3616e 100644
--- a/torch_npu/profiler/analysis/prof_view/prepare_parse/_fwk_pre_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/prepare_parse/_fwk_pre_parser.py
@@ -28,10 +28,10 @@ class TracePreParser(BaseParser):

     def __init__(self, name: str, param_dict: dict):
         super().__init__(name, param_dict)
-        ProfilerLogger.init(self._profiler_path, "TracePreParser")
-        self.logger = ProfilerLogger.get_instance()

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "TracePreParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             fwk_trace_data = FwkFileParser(self._profiler_path).get_fwk_trace_data()
             trace_file_path = os.path.join(self._output_path, Constant.TRACE_VIEW_TEMP) if os.path.isdir(
diff --git a/torch_npu/profiler/analysis/prof_view/prepare_parse/_relation_parser.py b/torch_npu/profiler/analysis/prof_view/prepare_parse/_relation_parser.py
index 86e8c1e9ea27463adda9b8f1abb805e4395dab5a..27437eaa654bf55529ec7f6c7e7577d4c237d440 100644
--- a/torch_npu/profiler/analysis/prof_view/prepare_parse/_relation_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/prepare_parse/_relation_parser.py
@@ -23,10 +23,10 @@ __all__ = []
 class RelationParser(BaseParser):
     def __init__(self, name: str, param_dict: dict):
         super().__init__(name, param_dict)
-        ProfilerLogger.init(self._profiler_path, "RelationParser")
-        self.logger = ProfilerLogger.get_instance()

     def run(self, deps_data: dict):
+        ProfilerLogger.init(self._profiler_path, "RelationParser")
+        self.logger = ProfilerLogger.get_instance()
         try:
             kernel_dict = FwkCANNRelationParser(self._profiler_path).get_kernel_dict()
         except Exception as e: