From 6700d6b56c6f498794beb8360ab6529a62474f0d Mon Sep 17 00:00:00 2001 From: wangjiacheng Date: Thu, 29 May 2025 13:21:21 +0800 Subject: [PATCH] add new version pattern to GetCANNInfo --- torch_npu/csrc/core/npu/GetCANNInfo.cpp | 60 ++++++++++++++++--------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/torch_npu/csrc/core/npu/GetCANNInfo.cpp b/torch_npu/csrc/core/npu/GetCANNInfo.cpp index c009465e4e..8916a70fc9 100644 --- a/torch_npu/csrc/core/npu/GetCANNInfo.cpp +++ b/torch_npu/csrc/core/npu/GetCANNInfo.cpp @@ -61,54 +61,70 @@ int64_t VersionToNum(std::string versionStr) return num; } -double DriverVersionToNum(std::string versionStr) +int64_t DriverVersionToNum(std::string versionStr) { std::smatch results; - int major = -1; - int minor = -1; - int release = -1; - int TVersion = -1; - int RCVersion = -51; - int bVersion = 0; + int64_t major = -1; + int64_t minor = -1; + int64_t release = -1; + int64_t TVersion = -1; + int64_t RCVersion = -51; + int64_t patch = 0; + int64_t bVersion = 1; + int64_t alphaVersion = 0; // driver version check only supports pattern listed here: - // 24.1.0,24.1.RC1,24.1.rc1,24.1.RC1.B10,24.1.rc1.b10,24.1.T1 - if (std::regex_match(versionStr, results, std::regex("([0-9]+).([0-9]+).RC([0-9]+)"))) { + // pattern is major.minor.release.patch. release:num or RC+num or T+num, patch: num or alpha+num or beta+num. + std::regex re_rc("([0-9]+).([0-9]+).RC([0-9]+)", std::regex::icase); + std::regex re_num("([0-9]+).([0-9]+).([0-9]+)"); + std::regex re_rc_num("([0-9]+).([0-9]+).RC([0-9]+).([0-9]+)", std::regex::icase); + std::regex re_num_num("([0-9]+).([0-9]+).([0-9]+).([0-9]+)"); + std::regex re_t("([0-9]+).([0-9]+).T([0-9]+)", std::regex::icase); + std::regex re_rc_beta("([0-9]+).([0-9]+).RC([0-9]+).beta([0-9]+)", std::regex::icase); + std::regex re_rc_alpha("([0-9]+).([0-9]+).RC([0-9]+).alpha([0-9]+)", std::regex::icase); + if (std::regex_match(versionStr, results, re_rc)) { major = stoi(results[kVersionIndex1]); minor = stoi(results[kVersionIndex2]); RCVersion = stoi(results[kVersionIndex3]); - } else if (std::regex_match(versionStr, results, std::regex("([0-9]+).([0-9]+).rc([0-9]+)"))) { + } else if (std::regex_match(versionStr, results, re_rc_num)) { major = stoi(results[kVersionIndex1]); minor = stoi(results[kVersionIndex2]); RCVersion = stoi(results[kVersionIndex3]); - } else if (std::regex_match(versionStr, results, std::regex("([0-9]+).([0-9]+).([0-9]+)"))) { + patch = stoi(results[kVersionIndex4]); + } else if (std::regex_match(versionStr, results, re_num)) { major = stoi(results[kVersionIndex1]); minor = stoi(results[kVersionIndex2]); release = stoi(results[kVersionIndex3]); - } else if (std::regex_match(versionStr, results, std::regex("([0-9]+).([0-9]+).T([0-9]+)"))) { + } else if (std::regex_match(versionStr, results, re_num_num)) { + major = stoi(results[kVersionIndex1]); + minor = stoi(results[kVersionIndex2]); + release = stoi(results[kVersionIndex3]); + patch = stoi(results[kVersionIndex4]); + } else if (std::regex_match(versionStr, results, re_t)) { major = stoi(results[kVersionIndex1]); minor = stoi(results[kVersionIndex2]); TVersion = stoi(results[kVersionIndex3]); - } else if (std::regex_match(versionStr, results, std::regex("([0-9]+).([0-9]+).RC([0-9]+).B([0-9]+)"))) { + } else if (std::regex_match(versionStr, results, re_rc_beta)) { major = stoi(results[kVersionIndex1]); minor = stoi(results[kVersionIndex2]); RCVersion = stoi(results[kVersionIndex3]); bVersion = stoi(results[kVersionIndex4]); - } else if (std::regex_match(versionStr, results, std::regex("([0-9]+).([0-9]+).rc([0-9]+).b([0-9]+)"))) { + } else if (std::regex_match(versionStr, results, re_rc_alpha)) { major = stoi(results[kVersionIndex1]); minor = stoi(results[kVersionIndex2]); RCVersion = stoi(results[kVersionIndex3]); - bVersion = stoi(results[kVersionIndex4]); + alphaVersion = stoi(results[kVersionIndex4]); } else { TORCH_NPU_WARN_ONCE("Driver Version: " + versionStr + " is invalid or not supported yet."); - return 0.0; + return 0; } - double num = ((static_cast(major) + 1.0) * 100000000) + - ((static_cast(minor) + 1.0) * 1000000) + - ((static_cast(release) + 1.0) * 10000) + - ((static_cast(RCVersion) + 1.0) * 100 + 5000) + - ((static_cast(TVersion) + 1.0) * 100) + - static_cast(bVersion); + int64_t num = ((major + 1) * 100000000) + + ((minor + 1) * 1000000) + + ((release + 1) * 10000) + + ((RCVersion + 1) * 100 + 5000) + + ((TVersion + 1) * 100) - + (alphaVersion ? 1 : 0) * (100 - alphaVersion) + + (bVersion - 1) + patch; return num; } -- Gitee