diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py
index 7ea242593d7ab5bd0bbe1498fe54ea25d8e8502f..8f45fb4e9fee6de36e3839dd6d521541dc7b76c2 100644
--- a/torch_npu/_inductor/npu_triton_heuristics.py
+++ b/torch_npu/_inductor/npu_triton_heuristics.py
@@ -77,6 +77,38 @@ class NPUCachingAutotuner(CachingAutotuner):
 
         self.exceptions = []
 
+    @staticmethod
+    def api_accuracy_checker(expected, actual, kernel_name, dump_path):
+        from msprobe.core.common.const import CompareConst
+        from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import BENCHMARK_COMPARE_SUPPORT_LIST
+        from .tools.api_accuracy_checker.precision_compare import precision_compare
+        from .tools.api_accuracy_checker.precision_standard.triton_standard_register import exist_in_precision_standard
+        from .tools.api_accuracy_checker.get_compare_result import get_compare_result
+        from .tools.api_accuracy_checker.common.compare_utils import convert_compare_column_to_row, print_check_details
+
+        dtype = actual.dtype
+
+        # only floating-point dtypes use the new precision standard
+        if exist_in_precision_standard(kernel_name):
+            if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
+                compare_column = precision_compare(kernel_name, expected, actual, dtype)  # calc metrics
+                compare_row = convert_compare_column_to_row(compare_column, kernel_name)
+                status = get_compare_result(compare_row, kernel_name)  # get compare results
+                if status == CompareConst.ERROR:
+                    print('CHECK ACCURACY FAILED!', flush=True)
+                    print(f'kernel: {kernel_name}, Dump Path: {dump_path}')
+                    print_check_details(compare_column, kernel_name)
+                    actual.copy_(expected)
+                check_flag = 1
+            else:
+                print(f'The data type {dtype} is not supported by the new precision standard. '
+                      f'Check accuracy by tolerance method.')
+                check_flag = 2
+        else:
+            print(f'kernel_name {kernel_name} is not in the new precision standard. Check accuracy by tolerance method.')
+            check_flag = 2
+        return check_flag
+
     def precompile(self, warm_cache_only=False):
         # xpu_graph changed TORCHINDUCTOR_CACHE_DIR.
         # When TORCHINDUCTOR_COMPILE_THREADS > 1, multiprocessing's fork method
@@ -563,26 +595,36 @@ class NPUCachingAutotuner(CachingAutotuner):
                 grid=grid,
                 stream=stream,
             )
+            try:
+                import msprobe
+                check_flag = 1  # 1: use the new precision standard, 2: fall back to the tolerance check
+            except ImportError:
+                check_flag = 2
+                print("msprobe import failed, possibly due to missing dependencies. "
+                      "Check accuracy by tolerance method.")
             for actual, expected in zip([args[i] for i in call_outputs_indices], fx_args[fx_module.num_inputs:]):
                 if actual.dtype != expected.dtype:
                     expected = expected.to(actual.dtype)
-                acc_comp_tol = npu_config.acc_comp_tol.get(actual.dtype, npu_config.acc_comp_tol['default'])
-                rtol = acc_comp_tol['rtol']
-                atol = acc_comp_tol['atol']
-
-                matches = torch.isclose(
-                    actual, expected, rtol=rtol, atol=atol, equal_nan=False
-                )
-                if not matches.all():
-                    abs_diff = torch.abs(actual - expected)
-                    rel_diff = abs_diff / torch.abs(expected)
-                    rel_diff.masked_fill_(matches, 0)
-                    print(
-                        f"CHECK ACCURACY FAILED! Greatest Relative Difference: {rel_diff.max().item()}, " f"Kernel Name: {kernel_name}",
-                        flush=True)
-                    print(f"kernel {kernel_name} Dump Path: {dump_path}")
-                    actual.copy_(expected)
-                del matches
+                if check_flag == 1:
+                    check_flag = self.api_accuracy_checker(expected, actual, kernel_name, dump_path)
+                if check_flag == 2:
+                    acc_comp_tol = npu_config.acc_comp_tol.get(actual.dtype, npu_config.acc_comp_tol['default'])
+                    rtol = acc_comp_tol['rtol']
+                    atol = acc_comp_tol['atol']
+
+                    matches = torch.isclose(
+                        actual, expected, rtol=rtol, atol=atol, equal_nan=False
+                    )
+                    if not matches.all():
+                        abs_diff = torch.abs(actual - expected)
+                        rel_diff = abs_diff / torch.abs(expected)
+                        rel_diff.masked_fill_(matches, 0)
+                        print(
+                            f"CHECK ACCURACY FAILED! Greatest Relative Difference: {rel_diff.max().item()}, " f"Kernel Name: {kernel_name}",
+                            flush=True)
+                        print(f"kernel {kernel_name} Dump Path: {dump_path}")
+                        actual.copy_(expected)
+                    del matches
             for arg in fx_args:
                 del arg
             return True
diff --git a/torch_npu/_inductor/tools/__init__.py b/torch_npu/_inductor/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/__init__.py b/torch_npu/_inductor/tools/api_accuracy_checker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/common/__init__.py b/torch_npu/_inductor/tools/api_accuracy_checker/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/common/compare_input.py b/torch_npu/_inductor/tools/api_accuracy_checker/common/compare_input.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52ea5bb5ebb6c0d1372423035b55964c232d66e
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/common/compare_input.py
@@ -0,0 +1,5 @@
+class PrecisionCompareInput:
+    def __init__(self, compare_row, dtype, compare_column):
+        self.row_npu = compare_row
+        self.dtype = dtype
+        self.compare_column = compare_column
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/common/compare_utils.py b/torch_npu/_inductor/tools/api_accuracy_checker/common/compare_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d3d3332fe02339ce32eccfc0ec8ad03d0ea2e9b
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/common/compare_utils.py
@@ -0,0 +1,45 @@
+import pandas as pd
+
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS
+from msprobe.core.common.const import CompareConst
+from ..precision_standard.triton_standard_register import absolute_standard_api_list, \
+    binary_standard_api_list, ulp_standard_api_list, thousandth_standard_api_list
+from .const import accumulative_error_eb_threshold, ulp_err_threshold
+
+
+def convert_compare_column_to_row(compare_column, api_name):
+    compare_column_list = compare_column.to_column_value("pass", " ")
+    compare_column_list.insert(0, api_name)
+    compare_row = pd.Series(compare_column_list, DETAIL_TEST_ROWS[0])
+    return compare_row
+
+
+def print_check_details(compare_column, api_name):
+    if api_name in absolute_standard_api_list:
+        standard = CompareConst.ABSOLUTE_THRESHOLD
+        metrics = ['inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio']
+        values = [compare_column.inf_nan_error_ratio, compare_column.rel_err_ratio, compare_column.abs_err_ratio]
+        thresholds = [0, 0, 0]
+    elif api_name in binary_standard_api_list:
+        standard = CompareConst.BINARY_CONSISTENCY
+        metrics = ['error_rate']
+        values = [compare_column.error_rate]
+        thresholds = [0]
+    elif api_name in ulp_standard_api_list:
+        standard = CompareConst.ULP_COMPARE
+        metrics = ['mean_ulp_error', 'ulp_error_proportion']
+        values = [compare_column.mean_ulp_error, compare_column.ulp_error_proportion]
+        thresholds = [ulp_err_threshold]
+    elif api_name in thousandth_standard_api_list:
+        standard = CompareConst.THOUSANDTH_STANDARD
+        metrics = ['rel_err_thousandth']
+        values = [compare_column.rel_err_thousandth]
+        thresholds = [CompareConst.THOUSANDTH_PASS_VALUE]
+    else:
+        standard = CompareConst.ACCUMULATIVE_ERROR_COMPARE
+        metrics = ['inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio', 'eb']
+        values = [compare_column.inf_nan_error_ratio, compare_column.rel_err_ratio, compare_column.abs_err_ratio,
+                  compare_column.eb]
+        thresholds = [0, 0, 0, accumulative_error_eb_threshold]
+
+    print(f"Checked by precision standard:{standard}, metrics:{metrics}, values:{values}, thresholds:{thresholds}")
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/common/const.py b/torch_npu/_inductor/tools/api_accuracy_checker/common/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..aecc4489f5c764555a31e6b2d3e5d67deaa57fd1
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/common/const.py
@@ -0,0 +1,16 @@
+accumulative_error_eb_threshold = {
+    'torch.float16': 2 ** -20,
+    'torch.bfloat16': 2 ** -7,
+    'torch.float32': 2 ** -14,
+    'default': 2 ** -14
+}
+
+ulp_err_threshold = {
+    'torch.float32': {
+        'mean_ulp_error': 64,
+        'ulp_err_proportion': 0.05
+    },
+    'torch.float16': {
+        'ulp_err_proportion': 0.001
+    }
+}
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/get_compare_result.py b/torch_npu/_inductor/tools/api_accuracy_checker/get_compare_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2ca775f0e1ac79be72fe6f34a836c38829c1e69
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/get_compare_result.py
@@ -0,0 +1,43 @@
+from msprobe.core.common.const import CompareConst
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_UNSUPPORT_LIST, \
+    ApiPrecisionCompareColumn
+from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import record_absolute_threshold_result, \
+    record_binary_consistency_result, record_thousandth_threshold_result, record_accumulative_error_compare_result
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
+from .precision_standard.ulp_compare import record_ulp_compare_result
+from .precision_standard.triton_standard_register import TritonStandardRegister
+from .common.compare_input import PrecisionCompareInput
+
+
+def register_compare_func():
+    registry = TritonStandardRegister()
+    registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result)
+    registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result)
+    registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result)
+    registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result)
+    registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result)
+    return registry
+
+
+def get_api_status(compare_row, api_name, compare_column, registry):
+    # compare_row is the row built from the CompareColumn produced by run_ut
+    # skip the comparison when the API output is empty (e.g. requires_grad=False in the backward pass)
+    if (compare_row[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace() or
+            compare_row[ApiPrecisionCompareColumn.DEVICE_DTYPE] in API_PRECISION_COMPARE_UNSUPPORT_LIST or
+            compare_row[ApiPrecisionCompareColumn.SHAPE] == CompareConst.ZERO_SHAPE):
+        compare_column.compare_result = CompareConst.SKIP
+        new_status = CompareConst.SKIP
+    else:
+        compare_column.api_name = api_name
+        dtype = compare_row[ApiPrecisionCompareColumn.DEVICE_DTYPE]
+        input_data = PrecisionCompareInput(compare_row, dtype, compare_column)
+        comparison_func = registry.get_comparison_function(api_name, dtype)
+        new_status = comparison_func(input_data)
+    return new_status
+
+
+def get_compare_result(run_ut_column, api_name):
+    compare_column = ApiPrecisionOutputColumn()
+    registry = register_compare_func()
+    status = get_api_status(run_ut_column, api_name, compare_column, registry)
+    return status
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/precision_compare.py b/torch_npu/_inductor/tools/api_accuracy_checker/precision_compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2a0ebd58ceb27123b0b7c909dc92444e55916f
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/precision_compare.py
@@ -0,0 +1,87 @@
+import torch
+
+from msprobe.pytorch.api_accuracy_checker.precision_standard.absolute_threshold import AbsolutethdCompare
+from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare
+from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare
+from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare
+from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkCompare
+from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare
+from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \
+    get_rel_err_origin
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
+from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput
+from msprobe.core.common.const import CompareConst
+from .precision_standard.triton_standard_register import TritonStandardRegister
+
+
+class Comparator:
+    def __init__(self):
+        self.registry = self._register_compare_func()
+
+    @staticmethod
+    def _absolute_standard_compare(input_data):
+        absolute_compare = AbsolutethdCompare(input_data)
+        absolute_compare.compare()
+
+    @staticmethod
+    def _binary_standard_compare(input_data):
+        binary_compare = BinaryCompare(input_data)
+        binary_compare.compare()
+
+    @staticmethod
+    def _ulp_compare(input_data):
+        ulp_compare = UlpCompare(input_data)
+        ulp_compare.compare()
+
+    @staticmethod
+    def _thousandth_standard_compare(input_data):
+        thousandth_compare = ThousandthStdCompare(input_data)
+        thousandth_compare.compare()
+
+    @staticmethod
+    def _benchmark_compare(input_data):
+        benchmark_compare = BenchmarkCompare(input_data)
+        benchmark_compare.compare()
+
+    @staticmethod
+    def _accumulative_error_compare(input_data):
+        accumulative_error_compare = AccumulativeErrorCompare(input_data)
+        accumulative_error_compare.compare()
+
+    def _register_compare_func(self):
+        registry = TritonStandardRegister()
+        registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare)
+        registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare)
+        registry.register(CompareConst.ULP_COMPARE, self._ulp_compare)
+        registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare)
+        registry.register(CompareConst.BENCHMARK, self._benchmark_compare)
+        registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare)
+        return registry
+
+    def perform_comparison(self, api_name, input_data):
+        comparison_func = self.registry.get_comparison_function(api_name, None)
+        comparison_func(input_data)
+
+
+def precision_compare(api_name, expected, actual, dtype):
+    compare_column = CompareColumn()
+    compare_column.bench_type = str(expected.dtype)
+    compare_column.npu_type = str(actual.dtype)
+    compare_column.shape = tuple(actual.shape)
+
+    # convert to float32 because numpy has no bfloat16
+    if dtype == torch.bfloat16:
+        expected = expected.to(torch.float32)
+        actual = actual.to(torch.float32)
+
+    fx_output = expected.cpu().numpy()  # fx_output and triton_output need to be numpy arrays
+    triton_output = actual.cpu().numpy()
+
+    _, abs_bench_with_eps = get_abs_bench_with_eps(fx_output, dtype)
+    abs_err = get_abs_err(fx_output, triton_output)
+    rel_err_origin = get_rel_err_origin(abs_err, abs_bench_with_eps)
+
+    input_data = CompareInput(fx_output, triton_output, compare_column, dtype, rel_err_origin)
+    comparator = Comparator()
+    comparator.perform_comparison(api_name, input_data)
+    return compare_column
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/__init__.py b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/triton_op_precision_standard.yaml b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/triton_op_precision_standard.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..999a79051272b7d82cf2477f5a7dfcb89d734f68
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/triton_op_precision_standard.yaml
@@ -0,0 +1,14 @@
+AbsoluteThreshStandard:
+  - triton_unk_fused_add_native_layer_norm_9_2048_1024
+
+BinaryCompareStandard:
+  - test_triton
+
+ULPStandard:
+  - test_triton
+
+ThousandthStandard:
+  - test_triton
+
+AccumulativeErrorStandard:
+  - test_triton
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/triton_standard_register.py b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/triton_standard_register.py
new file mode 100644
index 0000000000000000000000000000000000000000..4023e0f5e97d21b77f3d7733bc98c31216def004
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/triton_standard_register.py
@@ -0,0 +1,32 @@
+import os
+
+from msprobe.core.common.file_utils import load_yaml
+from msprobe.core.common.const import CompareConst
+from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
+
+
+cur_dir = os.path.dirname(os.path.realpath(__file__))
+standard_yaml_path = os.path.join(cur_dir, 'triton_op_precision_standard.yaml')
+apis = load_yaml(standard_yaml_path)
+absolute_standard_api_list = apis.get('AbsoluteThreshStandard')
+binary_standard_api_list = apis.get('BinaryCompareStandard')
+ulp_standard_api_list = apis.get('ULPStandard')
+thousandth_standard_api_list = apis.get('ThousandthStandard')
+accumulative_error_standard_api_list = apis.get('AccumulativeErrorStandard')
+
+
+class TritonStandardRegister(StandardRegistry):
+    def __init__(self):
+        super().__init__()
+        self.api_standard_function_map[CompareConst.ABSOLUTE_THRESHOLD] = absolute_standard_api_list
+        self.api_standard_function_map[CompareConst.BINARY_CONSISTENCY] = binary_standard_api_list
+        self.api_standard_function_map[CompareConst.ULP_COMPARE] = ulp_standard_api_list
+        self.api_standard_function_map[CompareConst.THOUSANDTH_STANDARD] = thousandth_standard_api_list
+        self.api_standard_function_map[CompareConst.ACCUMULATIVE_ERROR_COMPARE] = accumulative_error_standard_api_list
+
+
+def exist_in_precision_standard(kernel_name):
+    for api_list in apis.values():
+        if kernel_name in api_list:
+            return True
+    return False
diff --git a/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/ulp_compare.py b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/ulp_compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..23040b0b0badb567a091b2ea9cbda33c587111a3
--- /dev/null
+++ b/torch_npu/_inductor/tools/api_accuracy_checker/precision_standard/ulp_compare.py
@@ -0,0 +1,103 @@
+from collections import namedtuple
+
+import torch
+
+from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
+from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BasePrecisionCompare
+from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, convert_str_to_float
+from msprobe.core.common.const import Const, CompareConst
+
+
+UlpInfNanConsistency = namedtuple('UlpInfNanConsistency',
+                                  ['mean_ulp_err_inf_nan_consistency',
+                                   'ulp_err_proportion_ratio_inf_nan_consistency'])
+
+
+class UlpPrecisionCompare(BasePrecisionCompare):
+    def __init__(self, input_data):
+        input_data.row_gpu = None  # reusing msprobe's BasePrecisionCompare requires row_gpu; it is unused here
+        super().__init__(input_data)
+        self.compare_algorithm = CompareConst.ULP_COMPARE_ALGORITHM_NAME
+
+    @staticmethod
+    def _get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion):
+        mean_ulp_err_threshold, ulp_err_proportion_threshold = StandardConfig.get_ulp_threshold(torch.float32)
+        if mean_ulp_err < mean_ulp_err_threshold:
+            return CompareConst.PASS, ""
+        elif ulp_err_proportion < ulp_err_proportion_threshold:
+            return CompareConst.PASS, ""
+        compare_message = "ERROR: the ULP error does not meet the standard\n"
+        return CompareConst.ERROR, compare_message
+
+    @staticmethod
+    def _get_fp16_ulp_err_status(ulp_err_proportion):
+        _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16)
+        if ulp_err_proportion < ulp_err_proportion_threshold:
+            return CompareConst.PASS, ""
+        compare_message = "ERROR: the ULP error does not meet the standard\n"
+        return CompareConst.ERROR, compare_message
+
+    def _compute_mean_ulp_err(self):
+        column_name = ApiPrecisionCompareColumn.MEAN_ULP_ERR
+        npu_value = self._get_and_convert_values(column_name)
+        return npu_value, ""
+
+    def _compute_ulp_err_proportion(self):
+        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
+        npu_value = self._get_and_convert_values(column_name)
+        return npu_value
+
+    def _get_status(self, metrics, inf_nan_consistency):
+        ulp_inf_nan_consistency = inf_nan_consistency.mean_ulp_err_inf_nan_consistency and \
+                                  inf_nan_consistency.ulp_err_proportion_ratio_inf_nan_consistency
+
+        if not ulp_inf_nan_consistency:
+            compare_result = CompareConst.ERROR
+            metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + \
"") + \ + "ERROR: ULP误差不满足标准\n" + metrics.update({CompareConst.COMPARE_RESULT: compare_result}) + return metrics + + dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE) + mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR) + ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION) + + if dtype == Const.TORCH_FLOAT32: + status, final_message = self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion) + else: + status, final_message = self._get_fp16_ulp_err_status(ulp_err_proportion) + metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message + + status_dict = { + CompareConst.ULP_ERR_STATUS: status + } + compare_result = status + metrics.update(status_dict) + metrics.update({CompareConst.COMPARE_RESULT: compare_result}) + return metrics + + def _compute_ratio(self): + compare_message = "" + mean_ulp_err, mean_ulp_err_message = self._compute_mean_ulp_err() + compare_message += mean_ulp_err_message + npu_ulp_err_proportion = self._compute_ulp_err_proportion() + + metrics = { + CompareConst.MEAN_ULP_ERR: mean_ulp_err, + CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion, + CompareConst.COMPARE_MESSAGE: compare_message + } + return metrics, UlpInfNanConsistency(True, True) + + def _get_and_convert_values(self, column_name): + npu_value = self.row_npu.get(column_name) + if npu_value is None: + raise ValueError(f"value for column '{column_name}' is None.") + npu_value = convert_str_to_float(npu_value) + return npu_value + + +def record_ulp_compare_result(input_data): + us = UlpPrecisionCompare(input_data) + compare_result = us.compare() + return compare_result