From 4ebb8c6c99279ac486f599849f0b32a74ae473a5 Mon Sep 17 00:00:00 2001 From: x00430237 Date: Fri, 15 Dec 2023 15:08:08 +0800 Subject: [PATCH 1/2] fake # Conflicts: # torch_npu/__init__.py --- torch_npu/__init__.py | 2 + torch_npu/csrc/core/npu/NPUMocker.cpp | 429 ++++++++++++++++++ torch_npu/csrc/core/npu/NPUMocker.h | 11 + torch_npu/csrc/core/npu/NPUQueue.cpp | 5 + .../csrc/core/npu/register/FunctionLoader.cpp | 9 + torch_npu/csrc/toolkit/CMakeLists.txt | 1 + torch_npu/csrc/toolkit/mocker/CMakeLists.txt | 23 + .../csrc/toolkit/mocker/inc/mocker_defines.h | 30 ++ .../csrc/toolkit/mocker/src/acl_mocker.cpp | 216 +++++++++ .../csrc/toolkit/mocker/src/hccl_mocker.cpp | 48 ++ .../toolkit/mocker/src/metrics_config.cpp | 42 ++ torch_npu/mocker/mocker.py | 45 ++ 12 files changed, 861 insertions(+) create mode 100644 torch_npu/csrc/core/npu/NPUMocker.cpp create mode 100644 torch_npu/csrc/core/npu/NPUMocker.h create mode 100644 torch_npu/csrc/toolkit/mocker/CMakeLists.txt create mode 100644 torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h create mode 100644 torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp create mode 100644 torch_npu/csrc/toolkit/mocker/src/hccl_mocker.cpp create mode 100644 torch_npu/csrc/toolkit/mocker/src/metrics_config.cpp create mode 100644 torch_npu/mocker/mocker.py diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index ec083ecfc3..25aaf29216 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -26,6 +26,8 @@ from typing import Set, Type from functools import wraps import torch + +from torch_npu.mocker import mocker import torch_npu try: diff --git a/torch_npu/csrc/core/npu/NPUMocker.cpp b/torch_npu/csrc/core/npu/NPUMocker.cpp new file mode 100644 index 0000000000..c3a1f011a2 --- /dev/null +++ b/torch_npu/csrc/core/npu/NPUMocker.cpp @@ -0,0 +1,429 @@ +#include "torch_npu/csrc/core/npu/NPUEventManager.h" +#include "torch_npu/csrc/core/npu/NPUQueue.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/NPUDefine.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h" + +#include "ATen/record_function.h" + +#include +#include +#include +#include + +#include "acl/acl_op_compiler.h" +namespace c10_npu { +namespace { +class Task { +public: + Task(queue::QueueParamType type, aclrtStream stream) + : type(type), stream(stream) {} + Task(queue::QueueParamType type, aclrtStream stream, + std::vector &&data) + : Task(type, stream) { + this->data = data; + } + Task(queue::QueueParamType type, aclrtStream stream, void *raw_data, + size_t size) + : Task(type, stream) { + data.resize(size); + memcpy(data.data(), raw_data, size); + } + template [[nodiscard]] const T *Data() const { + return reinterpret_cast(data.data()); + } + queue::QueueParamType type; + aclrtStream stream; + std::vector data; + std::string label; +}; + +class TorchQueueTaskHandle { +public: + explicit TorchQueueTaskHandle(aclrtStream stream) : stream_(stream) {} + virtual ~TorchQueueTaskHandle() = default; + int ProcTask(const Task &task) { + int before, inner, after = 0; + GetCost(task.type, &before, &inner, &after); + if (before > 0) { + std::this_thread::sleep_for(std::chrono::microseconds(before)); + } + int result = Proc(task, inner); + if (after > 0) { + std::this_thread::sleep_for(std::chrono::microseconds(after)); + } + return result; + } + virtual int Proc(const Task &task, int cost) = 0; + Task PackTask(queue::QueueParamType type, void *data, size_t size) { + Task task(type, stream_); + task.label = MakeLabel(type, data, size); + Pack(task, type, data, size); + return task; + } + static void GetCost(queue::QueueParamType type, int *before, int *inner, + int *after); + static std::string MakeLabel(queue::QueueParamType type, void *data, + size_t size) { + static_cast(size); + const static std::vector kQueueType = { + "Unknown", "Execute", "MemcpyAsync", "RecordEvent", + "WaitEvent", "LazyDestroyEvent", "ResetEvent"}; + if (type == queue::QueueParamType::COMPILE_AND_EXECUTE) { + auto param = static_cast(data); + std::stringstream label; + label << kQueueType[static_cast(type)] << "(" << param->opType + << ")"; + return label.str(); + } + if (type == queue::QueueParamType::ASYNC_MEMCPY) { + auto param = static_cast(data); + std::stringstream label; + label << kQueueType[static_cast(type)]; + const static std::vector kMemcpyKind = {"H2H", "H2D", "D2H", + "D2D"}; + label << (param->kind >= kMemcpyKind.size() || param->kind < 0 + ? "Kind" + std::to_string(param->kind) + : kMemcpyKind[param->kind]); + label << "(" << param->dst << ", " << param->src << ")"; + return label.str(); + } + if (type == queue::QueueParamType::RECORD_EVENT || + type == queue::QueueParamType::WAIT_EVENT || + type == queue::QueueParamType::LAZY_DESTROY_EVENT || + type == queue::QueueParamType::RESET_EVENT) { + auto param = static_cast(data); + std::stringstream label; + label << kQueueType[static_cast(type)] << "(" << param->event << ")"; + return label.str(); + } + return kQueueType.front() + std::to_string(type); + } + +protected: + virtual void Pack(Task &task, queue::QueueParamType type, void *data, + size_t size) {} + aclrtStream stream_; +}; + +#define TASK_COST_TIME() \ + do { \ + if (cost > 0) { \ + std::this_thread::sleep_for(std::chrono::microseconds(cost)); \ + } \ + } while (false) + +class CompileAndExecute : public TorchQueueTaskHandle { +public: + explicit CompileAndExecute(aclrtStream stream) + : TorchQueueTaskHandle(stream) {} + int Proc(const Task &task, int cost) override { + TASK_COST_TIME(); + return 0; + } +}; + +class AsyncMemcpy : public TorchQueueTaskHandle { +public: + explicit AsyncMemcpy(aclrtStream stream) : TorchQueueTaskHandle(stream) {} + int Proc(const Task &task, int cost) override { + TASK_COST_TIME(); + return 0; + } +}; + +class Event : public TorchQueueTaskHandle { +public: + explicit Event(aclrtStream stream) : TorchQueueTaskHandle(stream) {} + int Proc(const Task &task, int cost) override { + TASK_COST_TIME(); + auto event = *task.Data(); + switch (task.type) { + case queue::QueueParamType::RECORD_EVENT: { + aclrtRecordEvent(event, stream_); + c10_npu::NPUEventManager::GetInstance().DecreaseUnrecordedCount(event); + aclrtSynchronizeStream(stream_); + break; + } + case queue::QueueParamType::WAIT_EVENT: { + aclrtStreamWaitEvent(stream_, event); + aclrtSynchronizeStream(stream_); + break; + } + case queue::QueueParamType::LAZY_DESTROY_EVENT: { + c10_npu::NPUEventManager::GetInstance().LazyDestroy(event); + break; + } + case queue::QueueParamType::RESET_EVENT: { + aclrtResetEvent(event, stream_); + aclrtSynchronizeStream(stream_); + break; + } + default: + break; + } + return 0; + } + void Pack(Task &task, queue::QueueParamType type, void *data, + size_t size) override { + auto param = static_cast(data); + task.data.resize(sizeof(aclrtEvent), 0U); + memcpy(task.data.data(), ¶m->event, sizeof(aclrtEvent)); + } +}; + +class Default : public TorchQueueTaskHandle { +public: + explicit Default(aclrtStream stream) : TorchQueueTaskHandle(stream) {} + int Proc(const Task &task, int cost) override { + TASK_COST_TIME(); + return 0; + } +}; + +class Stream { +public: + explicit Stream(aclrtStream stream) : stream_(stream) { + auto event_handle = std::make_shared(stream); + handles_ = { + std::make_shared(stream), // OTHERS + std::make_shared(stream), // COMPILE_AND_EXECUTE + std::make_shared(stream), // ASYNC_MEMCPY + event_handle, // RECORD_EVENT + event_handle, // WAIT_EVENT + event_handle, // LAZY_DESTROY_EVENT + event_handle, // RESET_EVENT + }; + + int32_t device = 0; + aclrtGetDevice(&device); + worker_ = std::thread([this, device]() { + aclrtSetDevice(device); + while (true) { + std::unique_lock lock(mutex_); + if (tasks_.empty() && !cancel_) { + cond_.wait(lock); + } + if (cancel_) { + while (!tasks_.empty()) { + auto task = tasks_.front(); + tasks_.pop(); + Proc(task); + pendingTaskNum_--; + } + return; + } + auto task = tasks_.front(); + tasks_.pop(); + lock.unlock(); + Proc(task); + pendingTaskNum_--; + } + }); + } + + ~Stream() { + { + std::unique_lock lock(mutex_); + cancel_ = true; + cond_.notify_one(); + } + worker_.join(); + } + + std::shared_ptr GetHandle(queue::QueueParamType type) { + if (type > queue::QueueParamType::RESET_EVENT || + type < queue::QueueParamType::COMPILE_AND_EXECUTE) { + return handles_.front(); + } + return handles_[type]; + } + + int Launch(queue::QueueParamType type, void *data, size_t size) { + std::shared_lock sharedLock(launch_mutex_); + auto task = GetHandle(type)->PackTask(type, data, size); + M_DLOG() << "Stream " << stream_ << " launch " << task.label; + RECORD_FUNCTION("Launch@" + task.label, std::vector{}); + std::unique_lock lock(mutex_); + if (cancel_) { + M_DLOG() << "Stream " << stream_ << " skip launch " << task.label + << " as canceled"; + return 0; + } + tasks_.emplace(std::move(task)); + pendingTaskNum_++; + cond_.notify_one(); + return 0; + } + + void Sync() { + std::unique_lock sharedLock(launch_mutex_); + while (pendingTaskNum_ > 0 && !cancel_) { + } + } + + aclError LaunchSync(const std::string &func, const std::string &msg = "") { + RECORD_FUNCTION("CpuSyncExe@" + func, std::vector{}); + Sync(); + M_DLOG() << "Stream " << stream_ << " proc sync op " << func + << ", extra:" << msg; + return ACL_ERROR_NONE; + } + + void Proc(const Task &task) { + M_DLOG() << "Stream " << stream_ << " proc " << task.label; + RECORD_FUNCTION("StreamExe@" + task.label, std::vector{}); + (void)GetHandle(task.type)->ProcTask(task); + } + +private: + aclrtStream stream_; + std::vector> handles_; + + std::thread worker_; + std::queue tasks_; + std::atomic_bool cancel_{}; + std::atomic_int64_t pendingTaskNum_{0}; + + std::mutex mutex_; + std::condition_variable cond_; + + std::shared_mutex launch_mutex_; +}; + +class StreamManager { +public: + static StreamManager &Instance() { + static StreamManager manager; + return manager; + } + int Schedule(queue::QueueParas *task) { + Stream *stream = GetOrCreate(task->paramStream); + queue::QueueParamType type = task->paramType; + void *data = task->paramVal; + size_t size = task->paramLen; + auto ret = stream->Launch(type, data, size); + if (type == queue::QueueParamType::RECORD_EVENT) { + stream->Sync(); + } + return ret; + } + + int ScheduleSync(aclrtStream rtStream, const std::string &func, + const std::string &msg = "") { + Stream *stream = GetOrCreate(rtStream); + return stream->LaunchSync(func, msg); + } + +private: + Stream *GetOrCreate(aclrtStream stream) { + std::unique_lock lock(mutex_); + auto it = streams_.find(stream); + if (it != streams_.end()) { + return it->second.get(); + } + streams_[stream] = std::make_unique(stream); + return streams_[stream].get(); + } + std::mutex mutex_; + std::map> streams_; +}; + +namespace stubs { +aclError aclopCompileAndExecuteV2( + const char *opType, int numInputs, aclTensorDesc *inputDesc[], + aclDataBuffer *inputs[], int numOutputs, aclTensorDesc *outputDesc[], + aclDataBuffer *outputs[], aclopAttr *attr, aclopEngineType engineType, + aclopCompileType compileFlag, const char *opPath, aclrtStream stream) { + return StreamManager::Instance().ScheduleSync( + stream, "aclopCompileAndExecuteV2", opType); +} + +aclError aclrtSynchronizeStream(aclrtStream stream) { + return StreamManager::Instance().ScheduleSync(stream, + "aclrtSynchronizeStream"); +} +aclError aclrtSynchronizeStreamWithTimeout(aclrtStream stream, + int32_t timeout) { + return StreamManager::Instance().ScheduleSync( + stream, "aclrtSynchronizeStreamWithTimeout", + "timeout:" + std::to_string(timeout)); +} + +#define MOCK_FUNC(v) \ + { #v, reinterpret_cast(v) } + +std::map stubFuncs = { + MOCK_FUNC(aclopCompileAndExecuteV2), MOCK_FUNC(aclrtSynchronizeStream), + MOCK_FUNC(aclrtSynchronizeStreamWithTimeout)}; +} // namespace stubs + +class MockerLoader { +public: + static MockerLoader &Instance() { + static MockerLoader loader("libnpu_mocker.so"); + return loader; + } + + ~MockerLoader() { + if (handle_ != nullptr) { + dlclose(handle_); + } + } + + void *Get(const std::string &soName, const std::string &funcName) { + auto iter = stubs::stubFuncs.find(funcName); + if (iter != stubs::stubFuncs.end()) { + return iter->second; + } + if (handle_ == nullptr) { + return nullptr; + } + static_cast(soName); + void *func = dlsym(handle_, funcName.c_str()); + return func; + } + +private: + explicit MockerLoader(std::string mocker) : mocker_(std::move(mocker)) { + if (mocker_.empty()) { + return; + } + handle_ = dlopen(mocker_.c_str(), RTLD_LAZY); + if (handle_ == nullptr) { + M_DLOG() << "Dlopen " << mocker_ << " failed: " << dlerror(); + } + } + std::string mocker_; + void *handle_ = nullptr; +}; + +typedef int (*GetMockerCost)(int type, int *before, int *inner, int *after); +void TorchQueueTaskHandle::GetCost(queue::QueueParamType type, int *before, + int *inner, int *after) { + static GetMockerCost handle = reinterpret_cast( + MockerLoader::Instance().Get("libnpu_mocker.so", "GetMockerCost")); + if (handle == nullptr) { + return; + } + handle(static_cast(type), before, inner, after); +} +} // namespace + +namespace mocker { +bool Enabled() { + static bool kIsMocking = IsEnvEnabled("TORCH_NPU_FAKE_MODE"); + return kIsMocking; +} + +int Launch(void *raw_data) { + return StreamManager::Instance().Schedule( + static_cast(raw_data)); +} + +void *Get(const std::string &soName, const std::string &funcName) { + return MockerLoader::Instance().Get(soName, funcName); +} +} // namespace mocker +} // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/NPUMocker.h b/torch_npu/csrc/core/npu/NPUMocker.h new file mode 100644 index 0000000000..291ab99798 --- /dev/null +++ b/torch_npu/csrc/core/npu/NPUMocker.h @@ -0,0 +1,11 @@ +#pragma once +#include +namespace c10_npu { +namespace mocker { +int Launch(void *raw_data); + +void *Get(const std::string &soName, const std::string &funcName); + +bool Enabled(); +} // namespace mocker +} // namespace c10_npu \ No newline at end of file diff --git a/torch_npu/csrc/core/npu/NPUQueue.cpp b/torch_npu/csrc/core/npu/NPUQueue.cpp index 96a3a1df93..d1fd8d52fd 100644 --- a/torch_npu/csrc/core/npu/NPUQueue.cpp +++ b/torch_npu/csrc/core/npu/NPUQueue.cpp @@ -1,3 +1,4 @@ +#include "torch_npu/csrc/core/npu/NPUMocker.h" #include "torch_npu/csrc/core/npu/NPUQueue.h" #include "torch_npu/csrc/core/npu/NPUStream.h" #include "torch_npu/csrc/core/npu/npu_log.h" @@ -21,6 +22,10 @@ public: CallBackManager() {} ~CallBackManager() {} void SetExec(const ACL_EXEC_FUNC& func) { + if (mocker::Enabled()) { + this->execFunc = mocker::Launch; + return; + } this->execFunc = func; } diff --git a/torch_npu/csrc/core/npu/register/FunctionLoader.cpp b/torch_npu/csrc/core/npu/register/FunctionLoader.cpp index 1595a05bcc..477b11c2e1 100644 --- a/torch_npu/csrc/core/npu/register/FunctionLoader.cpp +++ b/torch_npu/csrc/core/npu/register/FunctionLoader.cpp @@ -18,6 +18,9 @@ #include "torch_npu/csrc/core/npu/register/FunctionLoader.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/core/npu/NPUMocker.h" + namespace c10_npu { namespace option { @@ -83,6 +86,12 @@ namespace register_function { } void* FunctionRegister::Get(const std::string& soName, const std::string& funcName) { + if (ASCEND_UNLIKELY(mocker::Enabled())) { + auto mock_handle = mocker::Get(soName, funcName); + if (mock_handle != nullptr) { + return mock_handle; + } + } auto itr = registry.find(soName); if (itr != registry.end()) { return itr->second->Get(funcName); diff --git a/torch_npu/csrc/toolkit/CMakeLists.txt b/torch_npu/csrc/toolkit/CMakeLists.txt index 6a8bc1f9a8..d88f3485c6 100644 --- a/torch_npu/csrc/toolkit/CMakeLists.txt +++ b/torch_npu/csrc/toolkit/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(profiler) +add_subdirectory(mocker) \ No newline at end of file diff --git a/torch_npu/csrc/toolkit/mocker/CMakeLists.txt b/torch_npu/csrc/toolkit/mocker/CMakeLists.txt new file mode 100644 index 0000000000..b2255ec7c7 --- /dev/null +++ b/torch_npu/csrc/toolkit/mocker/CMakeLists.txt @@ -0,0 +1,23 @@ +set(MOCKER_NAME npu_mocker) + +FILE(GLOB NPU_MOCKER_SRCS + src/*.cpp +) + +set(NPU_MOCKER_INC + ${CMAKE_CURRENT_SOURCE_DIR}/inc +) + +add_library(${MOCKER_NAME} SHARED + ${NPU_MOCKER_SRCS} +) + +target_include_directories(${MOCKER_NAME} PRIVATE + ${NPU_MOCKER_INC} +) + +target_compile_options(${MOCKER_NAME} PRIVATE + ${TORCH_CXX_FLAGS} +) + +target_link_libraries(${MOCKER_NAME} PRIVATE stdc++) diff --git a/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h b/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h new file mode 100644 index 0000000000..f2caf81a66 --- /dev/null +++ b/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h @@ -0,0 +1,30 @@ +#pragma once +#include +#include +#include + +namespace mocker { +class Logger : public std::basic_ostringstream { +public: + Logger() { *this << "PID:" << getpid() << " TID:" << gettid() << " "; } + ~Logger() override { std::cerr << str() << std::endl; } +}; +} // namespace mocker + +static bool IsEnvEnabled(const char *envVar) { + auto env = std::getenv(envVar); + return (env != nullptr) && (std::string(env) == "1"); +} + +const static bool kIsDebug = IsEnvEnabled("TORCH_NPU_FAKE_MODE_DEBUG"); + +#define M_DLOG() \ + if (kIsDebug) mocker::Logger() + +#define RECORD_MOCKER_LOADED(V) \ + std::cerr << "PID:" << getpid() << " loaded mocker " << V << std::endl + +#define RECORD_MOCKER_UNLOADED(V) \ + std::cerr << "PID:" << getpid() << " unloaded mocker " << V << std::endl + +#define RECORD_MOCK() M_DLOG() << "Mocking " << __FUNCTION__ \ No newline at end of file diff --git a/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp b/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp new file mode 100644 index 0000000000..217751be31 --- /dev/null +++ b/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp @@ -0,0 +1,216 @@ +#include "dlfcn.h" +#include "stdio.h" +#include "stdlib.h" +#include +#include + +#include "acl/acl_rt.h" + +#include "mocker_defines.h" + +static bool initialized = false; + +static void *realAclHandle = nullptr; + +#define DEF_MOCK_ORIGIN(F, ...) \ + typedef aclError (*F##_t)(__VA_ARGS__); \ + static F##_t real_##F = nullptr + +#define LOAD_MOCK_ORIGIN(F) real_##F = (F##_t)dlsym(realAclHandle, #F) +#define CALL_ORIGIN(F, ...) real_##F(__VA_ARGS__) + +DEF_MOCK_ORIGIN(aclrtSetDevice, int32_t deviceId); +DEF_MOCK_ORIGIN(aclrtGetDevice, int32_t *deviceId); +DEF_MOCK_ORIGIN(aclrtResetDevice, int32_t deviceId); +DEF_MOCK_ORIGIN(aclrtGetDeviceCount, uint32_t *count); +DEF_MOCK_ORIGIN(aclrtQueryEvent, aclrtEvent event, aclrtEventStatus *status); +DEF_MOCK_ORIGIN(aclrtQueryEventStatus, aclrtEvent event, + aclrtEventRecordedStatus *status); +DEF_MOCK_ORIGIN(aclrtCreateEvent, aclrtEvent *event); +DEF_MOCK_ORIGIN(aclrtCreateEventWithFlag, aclrtEvent *event, uint32_t flag); +DEF_MOCK_ORIGIN(aclrtResetEvent, aclrtEvent event, aclrtStream stream); +DEF_MOCK_ORIGIN(aclrtDestroyEvent, aclrtEvent event); +DEF_MOCK_ORIGIN(aclrtRecordEvent, aclrtEvent event, aclrtStream stream); +DEF_MOCK_ORIGIN(aclrtStreamWaitEvent, aclrtStream stream, aclrtEvent event); +DEF_MOCK_ORIGIN(aclrtSynchronizeEvent, aclrtEvent event); +DEF_MOCK_ORIGIN(aclrtEventElapsedTime, float *ms, aclrtEvent start, + aclrtEvent end); +DEF_MOCK_ORIGIN(aclrtMemcpy, void *dst, size_t destMax, const void *src, + size_t count, aclrtMemcpyKind kind); + +__attribute__((constructor)) void InitMockAcl() { + if (initialized) { + return; + } + realAclHandle = dlopen("libascendcl.so", RTLD_LAZY); + if (realAclHandle == nullptr) { + printf("Failed to open libascendcl.so %s\n", dlerror()); + exit(1); + } + LOAD_MOCK_ORIGIN(aclrtSetDevice); + LOAD_MOCK_ORIGIN(aclrtGetDevice); + LOAD_MOCK_ORIGIN(aclrtResetDevice); + LOAD_MOCK_ORIGIN(aclrtGetDeviceCount); + LOAD_MOCK_ORIGIN(aclrtQueryEvent); + LOAD_MOCK_ORIGIN(aclrtQueryEventStatus); + LOAD_MOCK_ORIGIN(aclrtCreateEvent); + LOAD_MOCK_ORIGIN(aclrtCreateEventWithFlag); + LOAD_MOCK_ORIGIN(aclrtResetEvent); + LOAD_MOCK_ORIGIN(aclrtDestroyEvent); + LOAD_MOCK_ORIGIN(aclrtRecordEvent); + LOAD_MOCK_ORIGIN(aclrtStreamWaitEvent); + LOAD_MOCK_ORIGIN(aclrtSynchronizeEvent); + LOAD_MOCK_ORIGIN(aclrtEventElapsedTime); + LOAD_MOCK_ORIGIN(aclrtMemcpy); + + initialized = true; + RECORD_MOCKER_LOADED("acl"); +} + +__attribute__((destructor)) void DeInitMockAcl() { + if (!initialized) { + return; + } + if (realAclHandle != nullptr) { + dlclose(realAclHandle); + } + RECORD_MOCKER_UNLOADED("acl"); + initialized = false; +} + +thread_local int32_t currentDeviceId = -1; +uint32_t RealDeviceCount() { + static uint32_t deviceCount = []() { + if (IsEnvEnabled("TORCH_NPU_FAKE_MODE_FORCE_SINGLE_DEVICE")) { + M_DLOG() << "Mocker force device count to 1"; + return 1U; + } + uint32_t deviceCount = 0; + CALL_ORIGIN(aclrtGetDeviceCount, &deviceCount); + M_DLOG() << "Mocker detect real device count " << deviceCount; + return deviceCount; + }(); + return deviceCount; +} + +#ifdef __cplusplus +extern "C" { +#endif + +aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { + RECORD_MOCK(); + *devPtr = malloc(size); + M_DLOG() << "Mocking malloc " << size << " bytes to " << *devPtr; + memset(*devPtr, 0, size); + return ACL_SUCCESS; +} + +aclError aclrtMallocAlign32(void **devPtr, size_t size, + aclrtMemMallocPolicy policy) { + RECORD_MOCK(); + *devPtr = malloc((size + 31U) / 32U * 32U); + M_DLOG() << "Mocking malloc " << size << " bytes to " << *devPtr + << ", align32"; + memset(*devPtr, 0, size); + return ACL_SUCCESS; +} + +aclError aclrtFree(void *devPtr) { + RECORD_MOCK(); + free(devPtr); + return ACL_SUCCESS; +} + +aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, + aclrtMemcpyKind kind) { + RECORD_MOCK(); + if (kind == ACL_MEMCPY_DEVICE_TO_HOST) { + M_DLOG() << "Mocking copy from device " << src << " to host " << dst + << " size " << count; + memset(dst, 0, count); + } else { + CALL_ORIGIN(aclrtMemcpy, dst, destMax, src, count, kind); + } + return ACL_SUCCESS; +} + +aclError aclrtGetDeviceCount(uint32_t *count) { + RECORD_MOCK(); + *count = 8U; + return ACL_SUCCESS; +} + +aclError aclrtSetDevice(int32_t deviceId) { + RECORD_MOCK(); + currentDeviceId = deviceId; + if (deviceId >= RealDeviceCount()) { + M_DLOG() << "Mocking aclrtSetDevice from device " << deviceId << " to 0"; + return CALL_ORIGIN(aclrtSetDevice, 0); + } + return CALL_ORIGIN(aclrtSetDevice, deviceId); +} + +aclError aclrtResetDevice(int32_t deviceId) { + RECORD_MOCK(); + if (deviceId >= RealDeviceCount()) { + M_DLOG() << "Mocking aclrtResetDevice from device " << deviceId << " to 0"; + return CALL_ORIGIN(aclrtResetDevice, 0); + } + return CALL_ORIGIN(aclrtResetDevice, deviceId); +} + +aclError aclrtGetDevice(int32_t *deviceId) { + RECORD_MOCK(); + if (currentDeviceId == -1) { + return ACL_ERROR_UNINITIALIZE; + } + CALL_ORIGIN(aclrtGetDevice, deviceId); + if (*deviceId != currentDeviceId) { + M_DLOG() << "Mocking aclrtGetDevice from real " << *deviceId << " to " + << currentDeviceId; + *deviceId = currentDeviceId; + } + return ACL_SUCCESS; +} + +aclError aclrtQueryEvent(aclrtEvent event, aclrtEventStatus *status) { + return CALL_ORIGIN(aclrtQueryEvent, event, status); +} +aclError aclrtQueryEventStatus(aclrtEvent event, + aclrtEventRecordedStatus *status) { + return CALL_ORIGIN(aclrtQueryEventStatus, event, status); +} +aclError aclrtCreateEvent(aclrtEvent *event) { + auto ret = CALL_ORIGIN(aclrtCreateEvent, event); + M_DLOG() << "Create default event " << *event; + return ret; +} + +aclError aclrtCreateEventWithFlag(aclrtEvent *event, uint32_t flag) { + auto ret = CALL_ORIGIN(aclrtCreateEventWithFlag, event, flag); + M_DLOG() << "Create event " << *event << " with flag " << flag; + return ret; +} +aclError aclrtResetEvent(aclrtEvent event, aclrtStream stream) { + return CALL_ORIGIN(aclrtResetEvent, event, stream); +} + +aclError aclrtDestroyEvent(aclrtEvent event) { + return CALL_ORIGIN(aclrtDestroyEvent, event); +} +aclError aclrtRecordEvent(aclrtEvent event, aclrtStream stream) { + return CALL_ORIGIN(aclrtRecordEvent, event, stream); +} +aclError aclrtStreamWaitEvent(aclrtStream stream, aclrtEvent event) { + return CALL_ORIGIN(aclrtStreamWaitEvent, stream, event); +} +aclError aclrtSynchronizeEvent(aclrtEvent event) { + return CALL_ORIGIN(aclrtSynchronizeEvent, event); +} +aclError aclrtEventElapsedTime(float *ms, aclrtEvent start, aclrtEvent end) { + return CALL_ORIGIN(aclrtEventElapsedTime, ms, start, end); +} + +#ifdef __cplusplus +} +#endif diff --git a/torch_npu/csrc/toolkit/mocker/src/hccl_mocker.cpp b/torch_npu/csrc/toolkit/mocker/src/hccl_mocker.cpp new file mode 100644 index 0000000000..4bd5a51e10 --- /dev/null +++ b/torch_npu/csrc/toolkit/mocker/src/hccl_mocker.cpp @@ -0,0 +1,48 @@ +#include "stdio.h" + +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +#include "mocker_defines.h" + +static bool initialized = false; +__attribute__((constructor)) void InitMockHccl() { + if (initialized) { + return; + } + RECORD_MOCKER_LOADED("hccl"); + initialized = true; +} + +__attribute__((destructor)) void DeInitMockHccl() { + if (!initialized) { + return; + } + RECORD_MOCKER_UNLOADED("hccl"); + initialized = false; +} + +#ifdef __cplusplus +extern "C" { +#endif + +HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo) { + RECORD_MOCK(); + return HCCL_SUCCESS; +} + +HcclResult HcclCommInitClusterInfo(const char *clusterInfo, uint32_t rank, + HcclComm *comm) { + RECORD_MOCK(); + return HCCL_SUCCESS; +} + +HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, + uint32_t rank, HcclComm *comm) { + RECORD_MOCK(); + return HCCL_SUCCESS; +} + +#ifdef __cplusplus +} +#endif diff --git a/torch_npu/csrc/toolkit/mocker/src/metrics_config.cpp b/torch_npu/csrc/toolkit/mocker/src/metrics_config.cpp new file mode 100644 index 0000000000..8afbeea20f --- /dev/null +++ b/torch_npu/csrc/toolkit/mocker/src/metrics_config.cpp @@ -0,0 +1,42 @@ +#define INTERVAL_NUM 3 +#define TYPE_NUM 7 + +#ifdef __cplusplus +extern "C" { +#endif + +static int kCosts[INTERVAL_NUM][TYPE_NUM] = { + {0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0}}; +void GetMockerCost(int type, int *before, int *inner, int *after) { + if ((type >= TYPE_NUM) || (type < 0)) { + return; + } + if (before != nullptr) { + *before = kCosts[0][type] > 0 ? kCosts[0][type] : 0; + } + if (inner != nullptr) { + *inner = kCosts[1][type] > 0 ? kCosts[1][type] : 0; + } + if (after != nullptr) { + *after = kCosts[2][type] > 0 ? kCosts[2][type] : 0; + } +} + +#define MAKE_BEAN(V, i) \ + void SetCostOf##V(int before, int inner, int after) { \ + kCosts[0][i] = before; \ + kCosts[1][i] = inner; \ + kCosts[2][i] = after; \ + } + +MAKE_BEAN(Default, 0) +MAKE_BEAN(CompileAndExecute, 1) +MAKE_BEAN(AsyncMemcpy, 2) +MAKE_BEAN(RecordEvent, 3) +MAKE_BEAN(WaitEvent, 4) +MAKE_BEAN(LazyDestroyEvent, 5) +MAKE_BEAN(ResetEvent, 6) + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/torch_npu/mocker/mocker.py b/torch_npu/mocker/mocker.py new file mode 100644 index 0000000000..32e251b4ed --- /dev/null +++ b/torch_npu/mocker/mocker.py @@ -0,0 +1,45 @@ +import os + + +def _load_mocker(): + import ctypes + package_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + mocker_path = str(os.path.join(package_path, "lib", "libnpu_mocker.so")) + mocker = ctypes.CDLL(mocker_path, ctypes.RTLD_GLOBAL) + mocker.SetCostOfCompileAndExecute(0, 0, 0) + mocker.SetCostOfAsyncMemcpy(0, 0, 0) + mocker.SetCostOfRecordEvent(0, 0, 0) + mocker.SetCostOfWaitEvent(0, 0, 0) + mocker.SetCostOfLazyDestroyEvent(0, 0, 0) + mocker.SetCostOfResetEvent(0, 0, 0) + + +def _mock_pg(): + import inspect + import torch.distributed as dist + _origin = dist.init_process_group + + def _mock_init_process_group(*args, **kwargs): + kwargs = inspect.signature(_origin).bind(*args, **kwargs) + kwargs.apply_defaults() + kwargs = dict(kwargs.arguments) + kwargs.pop("init_method") + + store = dist.FileStore("dummy_pg.store", kwargs["world_size"]) + store.set("0//npu//0", "x" * 4108) + store.set("0//npu//version_key", "git@HEAD") + kwargs["store"] = store + + print(f"{'*' * 5} You are running in fake mode and pg will init locally with dummy store {'*' * 5}", flush=True) + assert kwargs["rank"] >= 0, f"Expect rank {kwargs['rank']} >=0 in single process pg mode" + assert kwargs["world_size"] > 0, f"Expect world_size {kwargs['world_size']} > 0 in single process pg mode" + + return _origin(**kwargs) + + dist.init_process_group = _mock_init_process_group + + +if os.getenv("TORCH_NPU_FAKE_MODE", None) == "1": + _load_mocker() + if os.getenv("TORCH_NPU_FAKE_MODE_SINGLE_PROCESS_PG", None) == "1": + _mock_pg() -- Gitee From 440926dbcdc1b875ea6a201bc75e5774ef0132eb Mon Sep 17 00:00:00 2001 From: x00430237 Date: Fri, 15 Dec 2023 15:29:24 +0800 Subject: [PATCH 2/2] fakev1.11.0 --- torch_npu/csrc/core/npu/NPUMocker.cpp | 18 +++++++++++----- .../csrc/toolkit/mocker/inc/mocker_defines.h | 4 +++- .../csrc/toolkit/mocker/src/acl_mocker.cpp | 21 ++++++++++++------- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/torch_npu/csrc/core/npu/NPUMocker.cpp b/torch_npu/csrc/core/npu/NPUMocker.cpp index c3a1f011a2..8156a1217e 100644 --- a/torch_npu/csrc/core/npu/NPUMocker.cpp +++ b/torch_npu/csrc/core/npu/NPUMocker.cpp @@ -4,12 +4,14 @@ #include "torch_npu/csrc/framework/NPUDefine.h" #include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h" #include "ATen/record_function.h" #include #include -#include +#include #include #include "acl/acl_op_compiler.h" @@ -37,6 +39,7 @@ public: aclrtStream stream; std::vector data; std::string label; + queue::EventAllocatorType allocatorType = queue::RESERVED; }; class TorchQueueTaskHandle { @@ -142,7 +145,11 @@ public: switch (task.type) { case queue::QueueParamType::RECORD_EVENT: { aclrtRecordEvent(event, stream_); - c10_npu::NPUEventManager::GetInstance().DecreaseUnrecordedCount(event); + if (task.allocatorType == c10_npu::queue::HOST_ALLOCATOR_EVENT) { + THNPUCachingHostAllocator_insertCompleteEvent(event); + } else if (task.allocatorType == c10_npu::queue::NPU_ALLOCATOR_EVENT) { + c10_npu::NPUCachingAllocator::NpuAllocatorInsertRecordedEvent(event); + } aclrtSynchronizeStream(stream_); break; } @@ -168,6 +175,7 @@ public: void Pack(Task &task, queue::QueueParamType type, void *data, size_t size) override { auto param = static_cast(data); + task.allocatorType = param->eventAllocatorType; task.data.resize(sizeof(aclrtEvent), 0U); memcpy(task.data.data(), ¶m->event, sizeof(aclrtEvent)); } @@ -241,7 +249,7 @@ public: } int Launch(queue::QueueParamType type, void *data, size_t size) { - std::shared_lock sharedLock(launch_mutex_); + std::unique_lock sharedLock(launch_mutex_); auto task = GetHandle(type)->PackTask(type, data, size); M_DLOG() << "Stream " << stream_ << " launch " << task.label; RECORD_FUNCTION("Launch@" + task.label, std::vector{}); @@ -258,7 +266,7 @@ public: } void Sync() { - std::unique_lock sharedLock(launch_mutex_); + std::unique_lock sharedLock(launch_mutex_); while (pendingTaskNum_ > 0 && !cancel_) { } } @@ -289,7 +297,7 @@ private: std::mutex mutex_; std::condition_variable cond_; - std::shared_mutex launch_mutex_; + std::mutex launch_mutex_; }; class StreamManager { diff --git a/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h b/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h index f2caf81a66..2f26b923b2 100644 --- a/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h +++ b/torch_npu/csrc/toolkit/mocker/inc/mocker_defines.h @@ -1,12 +1,14 @@ #pragma once #include #include +#include +#include #include namespace mocker { class Logger : public std::basic_ostringstream { public: - Logger() { *this << "PID:" << getpid() << " TID:" << gettid() << " "; } + Logger() { *this << "PID:" << getpid() << " TID:" << syscall(SYS_gettid) << " "; } ~Logger() override { std::cerr << str() << std::endl; } }; } // namespace mocker diff --git a/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp b/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp index 217751be31..c7e4567898 100644 --- a/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp +++ b/torch_npu/csrc/toolkit/mocker/src/acl_mocker.cpp @@ -37,6 +37,7 @@ DEF_MOCK_ORIGIN(aclrtEventElapsedTime, float *ms, aclrtEvent start, aclrtEvent end); DEF_MOCK_ORIGIN(aclrtMemcpy, void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind); +DEF_MOCK_ORIGIN(aclrtQueryDeviceStatus, int32_t deviceId, aclrtDeviceStatus *deviceStatus); __attribute__((constructor)) void InitMockAcl() { if (initialized) { @@ -62,6 +63,7 @@ __attribute__((constructor)) void InitMockAcl() { LOAD_MOCK_ORIGIN(aclrtSynchronizeEvent); LOAD_MOCK_ORIGIN(aclrtEventElapsedTime); LOAD_MOCK_ORIGIN(aclrtMemcpy); + LOAD_MOCK_ORIGIN(aclrtQueryDeviceStatus); initialized = true; RECORD_MOCKER_LOADED("acl"); @@ -99,19 +101,19 @@ extern "C" { aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { RECORD_MOCK(); - *devPtr = malloc(size); + *devPtr = malloc(32U); M_DLOG() << "Mocking malloc " << size << " bytes to " << *devPtr; - memset(*devPtr, 0, size); + memset(*devPtr, 0, 32U); return ACL_SUCCESS; } aclError aclrtMallocAlign32(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { RECORD_MOCK(); - *devPtr = malloc((size + 31U) / 32U * 32U); + *devPtr = malloc(32U); M_DLOG() << "Mocking malloc " << size << " bytes to " << *devPtr << ", align32"; - memset(*devPtr, 0, size); + memset(*devPtr, 0, 32U); return ACL_SUCCESS; } @@ -128,8 +130,6 @@ aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, M_DLOG() << "Mocking copy from device " << src << " to host " << dst << " size " << count; memset(dst, 0, count); - } else { - CALL_ORIGIN(aclrtMemcpy, dst, destMax, src, count, kind); } return ACL_SUCCESS; } @@ -191,6 +191,13 @@ aclError aclrtCreateEventWithFlag(aclrtEvent *event, uint32_t flag) { M_DLOG() << "Create event " << *event << " with flag " << flag; return ret; } + +aclError aclrtQueryDeviceStatus(int32_t deviceId, aclrtDeviceStatus *deviceStatus) { + M_DLOG() << "Query device " << deviceId << " status"; + *deviceStatus = ACL_RT_DEVICE_STATUS_NORMAL; + return ACL_SUCCESS; +} + aclError aclrtResetEvent(aclrtEvent event, aclrtStream stream) { return CALL_ORIGIN(aclrtResetEvent, event, stream); } @@ -202,7 +209,7 @@ aclError aclrtRecordEvent(aclrtEvent event, aclrtStream stream) { return CALL_ORIGIN(aclrtRecordEvent, event, stream); } aclError aclrtStreamWaitEvent(aclrtStream stream, aclrtEvent event) { - return CALL_ORIGIN(aclrtStreamWaitEvent, stream, event); + return ACL_SUCCESS; } aclError aclrtSynchronizeEvent(aclrtEvent event) { return CALL_ORIGIN(aclrtSynchronizeEvent, event); -- Gitee