diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aee5852fdf189ff3cbc1572fbe85808d9e09798e --- /dev/null +++ b/torch_npu/_inductor/__init__.py @@ -0,0 +1,72 @@ +import torch +from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides +from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device +from torch._inductor import lowering as inductor_lowering + +from torch_npu.utils._inductor import NPUDeviceOpOverrides +from torch_npu.utils._dynamo_device import NpuInterface +from torch_npu.npu.utils import device_count + +from . import config as npu_config +from . import codegen +from . import npu_fusion_attention_graph +from . import embedding_backward_patch +from .lowering import _register_npu_inductor_fallbacks, make_reduction +from .decomposition import _register_npu_inductor_decompositons +from .utils import get_current_raw_stream +from .config import log as npulog +from .config import aggresive_autotune + +npulog.info("perform torch_npu._inductor patch") + + +def _inductor_register_backend_for_device(): + from .codegen.schduling import NPUTritonScheduling + from .codegen.wrapper import NPUWrapperCodeGen + register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen) + +_inductor_register_backend_for_device() + +## Override original inductor device overrides in torch_npu + + +class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch_npu._inductor import get_current_raw_stream as {name}" + + +def _inductor_register_device_op_overrides(): + register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) + +_inductor_register_device_op_overrides() + + +## Override original dynamo device interface in torch_npu +class NewNpuInterface(NpuInterface): + + @staticmethod + def is_available() -> bool: + return device_count() > 0 + + @staticmethod + def get_compute_capability(device_cur=None): + # npu has no concept of cc. triton-npu compiler depends on subarch instead + return torch.npu.get_device_name(device_cur) + +register_interface_for_device("npu", NewNpuInterface) +device = get_interface_for_device("npu") + +inductor_lowering.make_reduction = make_reduction +_register_npu_inductor_fallbacks() +_register_npu_inductor_decompositons() + + +def _replace_benchmark_all_configs(): + from torch._inductor.triton_heuristics import CachingAutotuner + from .npu_triton_heuristics import benchmark_all_configs + CachingAutotuner.benchmark_all_configs = benchmark_all_configs + +if (aggresive_autotune): + _replace_benchmark_all_configs() + import os + os.environ["TRITON_BENCH_METHOD"] = "npu" \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/__init__.py b/torch_npu/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee8a7d6726378f6daf6ce0c4fa5dbc00ae04aeaf --- /dev/null +++ b/torch_npu/_inductor/codegen/__init__.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. 
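The torch_npu/_inductor/__init__.py module above works entirely through import-time side effects: importing it registers the NPU scheduling and wrapper code generators, the device op overrides and the dynamo device interface, replaces Inductor's make_reduction, registers the NPU fallbacks/decompositions, and (under aggresive_autotune) swaps in the NPU benchmarking method. A minimal usage sketch, not part of the patch, assuming a working torch_npu installation with at least one NPU device (depending on the torch_npu version the `_inductor` import may already happen inside `import torch_npu`); the function and shapes below are illustrative only:

import torch
import torch_npu
import torch_npu._inductor  # import performs the registrations shown above

def fused(a, b):
    # a small pointwise + reduction graph for Inductor to compile
    return (a + b).relu().sum(dim=-1)

compiled = torch.compile(fused)  # default Inductor backend, now routed to the NPU codegen
x = torch.randn(64, 128, device="npu")
y = torch.randn(64, 128, device="npu")
out = compiled(x, y)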
+ +from torch._inductor.ir import Reduction, LoopBody +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor import sizevars +from torch._inductor.codegen.triton import TritonKernel + +from torch_npu._inductor.codegen._sizevars import simplify +from torch_npu._inductor.codegen.ir import (num_splits, loopbody__call__, transform_dims_in_indexing, + substituted_dims_in_indexing) +from torch_npu._inductor.codegen.triton import is_compatible +from torch_npu._inductor.codegen.triton import group_fn, select_index_dtype, select_tiling +from ..config import log as npulog + +npulog.info("perform npu_indexing patch") + +Reduction.num_splits = num_splits +setattr(LoopBody, 'transform_dims_in_indexing', transform_dims_in_indexing) +setattr(LoopBody, 'substituted_dims_in_indexing', substituted_dims_in_indexing) + +LoopBody.__call__ = loopbody__call__ +# need to enable this to speedup attn_cp_test +# ComputedBuffer.simplify_and_reorder = simplify_and_reorder +# triton scheduling +TritonScheduling.group_fn = group_fn +TritonScheduling.select_index_dtype = select_index_dtype +TritonScheduling.select_tiling = select_tiling +# triton kernel +setattr(TritonKernel, 'is_compatible', is_compatible) + +# util +sizevars.SizeVarAllocator.simplify = simplify \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/_sizevars.py b/torch_npu/_inductor/codegen/_sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..84206554041b15e3930fead7d0759bb3b9c8ab8e --- /dev/null +++ b/torch_npu/_inductor/codegen/_sizevars.py @@ -0,0 +1,10 @@ +import sympy +from sympy import Expr +from torch._inductor.utils import sympy_subs + + +def simplify(self, expr: Expr): + if isinstance(expr, (tuple, list)): + return [sympy.expand(s).xreplace(self.replacements) for s in expr] + return sympy.expand(expr).xreplace(self.replacements) + diff --git a/torch_npu/_inductor/codegen/ir.py b/torch_npu/_inductor/codegen/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..54f9240925bb257f9b1184087e59c43667c3b85e --- /dev/null +++ b/torch_npu/_inductor/codegen/ir.py @@ -0,0 +1,194 @@ +from typing import List, Tuple, Dict, Any, Optional +import itertools +from torch._inductor.virtualized import V +from torch._inductor.ir import (ReductionHint, IRNode, ModularIndexing, FloorDiv) +from torch._inductor.utils import sympy_subs, sympy_index_symbol +import sympy +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel +from ..config import log + + +# NPU doesn't need to support ReductionHint.OUTER, and persistent reduction +def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + return ReductionHint.DEFAULT, 1 + + +def detect_flattened_dims(kernel, index): + new_vars = {} + if not isinstance(index, (sympy.core.add.Add, ModularIndexing, FloorDiv)): + return new_vars + + def detect_flattened_axis(expr): + def init_new_vars(var, length): + if var not in new_vars: + new_vars[var] = {length: [None, None]} + if length not in new_vars[var]: + new_vars[var][length] = [None, None] + if isinstance(expr, ModularIndexing): + var, divisor, length = expr.args + init_new_vars(var, length) + new_vars[var][length][1] = (expr, divisor, length) + elif isinstance(expr, FloorDiv): + var, divisor = expr.args + init_new_vars(var, divisor) + # over than 1 node_schedule, var may be deleted in kernel.range_tree_nodes + # it shoule be find in range_tree_nodes_removed 
dict + if (var in kernel.range_tree_nodes): + numel = kernel.range_tree_nodes[var].length + else: + numel = kernel.range_tree_nodes_removed[var].length + + length = expr.eval(numel, divisor) + new_vars[var][divisor][0] = (expr, divisor, length) + + else: + for x in expr.args: + detect_flattened_axis(x) + + # add + if isinstance(index, sympy.core.add.Add): + for x in index.args: + detect_flattened_axis(x) + elif isinstance(index, (ModularIndexing, FloorDiv)): + detect_flattened_axis(index) + else: + pass + + # make sure FloorDiv, MouldarIndexing must be in-pair + for var, divisors in new_vars.items(): + if var in kernel.range_tree_nodes: + parent_axis = kernel.range_tree_nodes[var] + else: + parent_axis = kernel.range_tree_nodes_removed[var] + for divisor, pair in divisors.items(): + if not pair[0] and not pair[1]: + pass + #FloorDiv not inplace + elif not pair[0]: + _, _, length = pair[1] + expr = FloorDiv(var, length) + new_vars[var][divisor][0] = (expr, length, parent_axis.length // length) + #ModularIndexing not inplace + elif not pair[1]: + expr = ModularIndexing(var, 1, divisor) + new_vars[var][divisor][1] = (expr, 1, divisor) + else: + pass + + return new_vars + + +def rebuild_flattened_dims(indexing): + def rebuild_flattened_dim(key, index, old_node, flatten_dim): + for _, pair in flatten_dim.items(): + new_var_expr = sympy.Integer(0) + origin_axis_length = 0 + for axis in pair: + expr, divisor, length = axis + # 3. try to rebuild the axis in kernel + new_node = old_node.parent.lookup(divisor, length) + # 4. substitute div/mod expression in indexing + index = index.subs(expr, new_node.symbol()) + indexing[key] = index + if isinstance(expr, FloorDiv): + new_var_expr = new_var_expr + new_node.symbol() * divisor + origin_axis_length = divisor * length + elif isinstance(expr, ModularIndexing): + new_var_expr = new_var_expr + new_node.symbol() + V.kernel.expr_substituted[expr] = new_node.symbol() + + if var not in V.kernel.range_tree_nodes_substituted: + V.kernel.range_tree_nodes_substituted[var] = [] + V.kernel.range_tree_nodes_substituted[var].append((origin_axis_length, new_var_expr)) + + def find_index_in_substitute(index, kernel): + return any([index.find(key) for key in kernel.expr_substituted.keys()]) + + kernel = V.kernel + for key, index in indexing.items(): + if find_index_in_substitute(index, kernel): + new_index = sympy_subs(index, kernel.expr_substituted) + indexing[key] = new_index + + # 1. try to find out flattened axis from indexing + flatten_dims = detect_flattened_dims(kernel, index) + #2. 
try to rebuild these flattened dims + for var, flatten_dim in flatten_dims.items(): + if (var in kernel.range_tree_nodes): + old_node = kernel.range_tree_nodes[var] + else: + old_node = kernel.range_tree_nodes_removed[var] + + rebuild_flattened_dim(key, index, old_node, flatten_dim) + + +def substituted_dims_in_indexing(self, indexing, kernel, range_tree_nodes_substituted): + substituted = False + for var, candidates in range_tree_nodes_substituted.items(): + if len(candidates) == 0: + raise ValueError(f"No candidates found for variable {var}: {candidates}") + exprs = sorted(candidates, reverse=True, key=lambda x: x[0]) + # the best candidate is with the longest numel + numel = exprs[0][0] + expr = exprs[0][1] + node = kernel.range_tree_nodes[var] + if node.length != numel: + log.debug("sub nodes (expr%s, numel:%d) can not substitute parent node(%s:%d)", + expr, numel, node.symbol(), node.length) + continue + for key, index in indexing.items(): + if var in index.free_symbols: + index = index.subs(var, expr) + indexing[key] = index + substituted = True + + return substituted + + +def transform_dims_in_indexing(self, indices): + if self.indexing is None: + index = list(itertools.chain.from_iterable(indices)) + if len(index) != len(self.var_ranges): + raise ValueError(f"Length mismatch: index len {len(index)} does not matches var_ranges len {len(self.var_ranges)}") + if any(v in self.var_ranges for v in index): + raise ValueError(f"v in self.var_ranges for v in index") + replacements = dict(zip(self.var_ranges.keys(), index)) + indexing_map = dict(zip(index, self.var_ranges.keys())) + setattr(self, 'indexing_map', indexing_map) + self.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + + if V.kernel is not None and isinstance(V.kernel, NPUIndexTritonKernel): + rebuild_flattened_dims(self.indexing) + + +# select tiling axis, recover missing dimensions, +def loopbody__call__(self, *indices): + if self.indexing is None: + index = list(itertools.chain.from_iterable(indices)) + if len(index) != len(self.var_ranges): + raise ValueError(f"Length mismatch: index len {len(index)} does not matches var_ranges len {len(self.var_ranges)}") + if any(v in self.var_ranges for v in index): + raise ValueError(f"v in self.var_ranges for v in index") + replacements = dict(zip(self.var_ranges.keys(), index)) + self.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + result = self.root_block() + self.indexing = None + return result + + diff --git a/torch_npu/_inductor/codegen/schduling.py b/torch_npu/_inductor/codegen/schduling.py new file mode 100644 index 0000000000000000000000000000000000000000..824bad0b7936164a5141c07ff074da3315f2898d --- /dev/null +++ b/torch_npu/_inductor/codegen/schduling.py @@ -0,0 +1,163 @@ +import itertools +import contextlib +from torch._inductor.codegen.triton import (TritonScheduling, log, config, EnableReduction, DisableReduction, + indexing_dtype_strength_reduction) +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.virtualized import ( + V, +) +from torch._inductor.codecache import code_hash +from torch._dynamo.utils import counters +from torch._inductor.utils import sympy_index_symbol, ModularIndexing, FloorDiv +import sympy +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel +from .split_tiling import SplitTiling + + +class NPUTritonScheduling(TritonScheduling): + def __init__(self, scheduler): + super().__init__(scheduler) + 
self.kernel_type = NPUIndexTritonKernel + + # create NPUTritonKernel or NPUIndexTritonKernel + # set final_kernel to V after kernel context exits + def codegen_node_schedule( + self, node_schedule, buf_accesses, numel, reduction_numel + ): + from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel + tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, reduction_numel + ) + + is_split_scan = any( + isinstance(node, BaseSchedulerNode) and node.is_split_scan() + for node in node_schedule + ) + # Note: backported patch + kernel_type = TritonSplitScanKernel if is_split_scan else self.kernel_type + kernel_args = tiled_groups + kernel_kwargs = { + "reduction_hint": reduction_hint_val, + "mutations": mutations, + "index_dtype": index_dtype, + } + kernel = kernel_type( + *kernel_args, + **kernel_kwargs, + ) + kernel.buf_accesses = buf_accesses + setattr(kernel, "node_schedule", node_schedule) + # generate code for the kernel + self.decide_codegen_dims_in_kernel(node_schedule, kernel) + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + kernel_name = self.define_kernel(src_code, node_schedule) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + #NPU don't need persistent reduction + final_kernel = kernel # type: ignore[assignment] + with V.set_kernel_handler(final_kernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + setattr(V, "final_kernel", final_kernel) + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernel.args.live_output_buffers() + for node in node_schedule: + if not isinstance(node, BaseSchedulerNode): + continue + name = node.get_name() + if name not in live_outs: + continue + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + def decide_codegen_dims_in_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + # 1. 
transform dims: create new dims to substitute floor_divide and modular expression + stack = contextlib.ExitStack() + for i, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) + else: + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node._body.transform_dims_in_indexing(index_vars) + + # 2.collection additional node to be substituted + self.additional_nodes_to_be_subs(kernel, kernel.range_tree_nodes_substituted) + # 3.do the substitution on all indexing + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + continue + indexing = node._body.indexing + node._body.substituted_dims_in_indexing(indexing, kernel, kernel.range_tree_nodes_substituted) + # 4.remove the substituted dims from kernel + for var, _ in kernel.range_tree_nodes_substituted.items(): + if (var in kernel.range_tree_nodes): + root = kernel.range_tree_nodes[var].parent + root.remove_entry(var) + + # select split and tiling axis + split_tiling = SplitTiling(kernel) + split_tiling.select_tiling_axis() + + def additional_nodes_to_be_subs(self, kernel, node_to_be_substituted): + for node in kernel.range_tree_nodes.values(): + if node.expr != sympy_index_symbol(f"{node.parent.prefix}index") \ + or len(node.parent.var_ranges) == 1 \ + or node.symbol() in node_to_be_substituted: + continue + numel = sympy.Integer(1) + new_var_expr = sympy.Integer(0) + for k, s in node.parent.var_ranges.items(): + if k == node.symbol(): + continue + numel = numel * s + sub_node = kernel.range_tree_nodes[k] + if isinstance(sub_node.expr, FloorDiv): + new_var_expr = new_var_expr + sub_node.symbol() * sub_node.divisor + elif isinstance(sub_node.expr, ModularIndexing): + new_var_expr = new_var_expr + sub_node.symbol() + + if numel == node.length: + node_to_be_substituted[node.symbol()] = [(node.length, new_var_expr)] + else: + log.warning("sub nodes (expr%s, numel:%d) can not make up parent node(%s:%d)", + new_var_expr, numel, node.symbol(), node.length) + + + + + diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..c36674de7483494014168c38a6d432480374ef92 --- /dev/null +++ b/torch_npu/_inductor/codegen/split_tiling.py @@ -0,0 +1,294 @@ +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.utils import ModularIndexing, sympy_subs +import sympy as sympy +from torch._inductor.virtualized import V +from torch._inductor.codegen.triton import (EnableReduction, DisableReduction) +from torch._inductor.utils import next_power_of_2 +from ..config import num_vector_core, log +from .triton_utils import get_aligned_numel + + +# split and tiling axis selector +class SplitTiling: + def __init__(self, kernel: TritonKernel): + self.kernel = kernel + self.indexing = [] + + def key(x): + # to be higher than x and y + if x.name[0] == 'w' or x.name[0] == 'v' or x.name[0] == 'p' or x.name[0] == 't': + return "z" + x.name + # to be lower than floor_dir + elif isinstance(x.expr, ModularIndexing): + return x.name[0] + "0" + x.name[1:] + else: + return x.name + + kernel.sorted_axis = [x for x in kernel.range_tree_nodes.values()] + kernel.sorted_axis.sort(reverse=True, key=key) + for i, dim in enumerate(kernel.sorted_axis): + dim.sorted_order = i + + self.find_lowest_dimension() + self.should_outer_reduce = False + + # Split 
principle 1: merge dimensions first, then split. Merging dimensions lowers the rank and reduces the complexity of the split/tiling axis selection strategy.
+    # Split principle 2: the number of splits should be aligned with the number of AI cores (equal or a multiple). Each core gets the same number of splits, and every split has the same shape (rank and sizes).
+    # Split principle 3: for reduction-type fused operators, pick the split axis from the non-reduction axes; for non-reduction fused operators, pick it from all axes.
+    # To maximize the low-dim tile size at tiling time, the split axis should preferably not be a low-dim axis and should be longer than the number of AI cores.
+    # Split principle 4: for a high-order reduction fused operator whose outer size is very large (>= 64KB) and whose low-dim size is small (<= 32B), the reduction axis itself can be split, using atomic
+    # primitives for the cross-core reduction.
+    # Split principle 5: based on the operator's logic, prefer a one-dimensional launch.
+    def select_split_axis(self):
+        def is_reduction(x):
+            return x.prefix == 'r'
+
+        def select_longest_dim(can_be_low_dim=True):
+            longest = -1
+            longest_dim = None
+            for x in candidates:
+                if SplitTiling.great_than(x.length, longest) and (can_be_low_dim or not self.is_lowest_dimension(x)):
+                    longest_dim = x
+                    longest = x.length
+            return longest_dim
+        # point-wise: all dims; reduction: outer_reduction dim or non-reduction dims
+        candidates = [x for x in self.kernel.sorted_axis if not is_reduction(x) or self.should_outer_reduce_me(x)]
+        if self.should_outer_reduce:
+            return self.kernel.split_axis
+
+        # 0307 patch 5lines
+        if len(candidates) > 0 and SplitTiling.ge_than(candidates[0].length, num_vector_core):
+            longest_dim = candidates[0]
+            self.kernel.split_axis = longest_dim
+            self.kernel.split_axis.is_split_axis = True
+            return longest_dim
+
+        # longest and not low dims
+        longest_dim = select_longest_dim(can_be_low_dim=False)
+
+        # longest and can be low dims
+        if longest_dim is None or SplitTiling.less_than(longest_dim.length, int(num_vector_core * 0.8)):
+            longest_dim = select_longest_dim(can_be_low_dim=True)
+        if longest_dim is not None:
+            self.kernel.split_axis = longest_dim
+            self.kernel.split_axis.is_split_axis = True
+        elif len(self.kernel.sorted_axis) > 0:
+            longest_dim = self.kernel.sorted_axis[0]
+            self.kernel.split_axis = longest_dim
+            self.kernel.split_axis.is_split_axis = True
+
+        return longest_dim
+
+    # Tiling principle 1: tiling must cover every low-dim axis appearing in load/store indexing expressions: all low-dim axes are tiled and thus become tiling
+    # axes. When generating code, contiguous indices are produced for every tiling axis via make_range, which keeps loads/stores contiguous.
+    # Tiling principle 2: a reduction tile must be two-dimensional. For a low-dim reduction, the reduction axis and at least one non-reduction axis are chosen as tiling axes; for a high-order reduction, the reduction axis and the low-dim axis are chosen as tiling axes.
+    # For a multi-dim reduction, all reduction axes are chosen as tiling axes.
+    # Tiling principle 3: if a tiling axis is a low-dim axis, its tile size on that axis must be aligned to the SIMD block size (32 bytes).
+    # Tiling principle 4: the larger the low-dim tile size, the better the performance. This is really an autotune principle; it is kept here to better explain the values used in the examples.
+
+    # fixme: two tiling axes might be insufficient when there are 3 or more low dims in the indexing
+    def select_tiling_axis(self):
+        # True: self.kernel.axis2 is not None and all reduction axes selected; False: other cases
+        def axis2_selection_done(axis):
+            if self.kernel.total_numels <= 1:
+                return True
+            elif self.kernel.axis2 is not None:
+                is_reduction = axis.prefix == "r"
+                if not is_reduction:
+                    return True
+                reduction_axis = self.kernel.numof_reduction_axis()
+                return True if reduction_axis <= 1 else len(self.kernel.axis2_list) == reduction_axis
+            else:
+                return False
+
+        if self.kernel.axis2 is not None or self.kernel.axis1 is not None:
+            return
+        # two or more reduction axes: the reduction dims need to be flattened into one so a 1-D reduction can be done.
+ if self.kernel.numof_reduction_axis() > 1: + self.kernel.persistent_reduction = True + biggest = -1 + dims = self.kernel.sorted_axis + if self.kernel.split_axis is None: + self.select_split_axis() + + if self.kernel.split_axis is None: + return + # select tiling_axis2 then tiling_axis1, for reduction, all reduction axis will be selected as tiling_axis2 + for i in range(len(dims) - 1, -1, -1): + axis = dims[i] + numel = axis.length + if isinstance(numel, (sympy.Symbol, sympy.Expr)) and not isinstance(numel, sympy.Integer): + numel = numel.subs(V.graph.sizevars.var_to_val) + if axis.is_split_axis: + dtype = self.kernel.get_axis_dtype(axis) + + min_aligned_numel = get_aligned_numel(dtype) + _, numel = SplitTiling.decide_nblocks_xblock(numel, len(self.kernel.sorted_axis) <= 1, min_aligned_numel) + + # choose reduction axis or low-dim as axis2 + if not axis2_selection_done(axis): + axis.is_tiling_axis2 = True if SplitTiling.great_than(numel, 1) else False + # axis2 must be the reduction axis in case inside_reduction + if axis.prefix == "r": + axis.is_tiling_axis2 = True + if axis.is_tiling_axis2 and self.kernel.axis2 is None: + self.kernel.axis2 = axis.symbol() + if self.kernel.numof_reduction_axis() > 1: + self.kernel.axis2_list.append(axis.symbol()) + self.kernel.axis2 = axis.symbol() if isinstance(axis.expr, ModularIndexing) else self.kernel.axis2 + else: + # for _higher_order_reduction, axis1 must be the lowest dimension + if self.kernel.inside_reduction and self.kernel.is_higher_order_reduction(): + self.kernel.axis1 = axis.symbol() + break + + # low-dim should be selected as another tiling axis + if self.is_lowest_dimension(axis): + self.kernel.axis1 = axis.symbol() + break + # select the longest in other cases + if numel > biggest: + self.kernel.axis1 = axis.symbol() + biggest = numel + + if self.kernel.axis1 is not None: + axis = self.kernel.range_tree_nodes[self.kernel.axis1] + axis.is_tiling_axis1 = True + + log.debug(f"split_tiling numels:{self.kernel.numels} split_axis: {self.kernel.split_axis.symbol()} " + f"axis1:{self.kernel.axis1} axis2:{self.kernel.axis2} low_dims:{self.kernel.low_dims}, " + f"indexing: {self.indexing}") + + # the below logic doesn't work when there're two reduction axis, but only one need outer reduction + def should_outer_reduce_me(self, x): + should_outer = self.kernel.is_higher_order_reduction(True) and SplitTiling.great_than(x.length, 32768) and x.is_loop + if should_outer: + self.should_outer_reduce = True + self.kernel.split_axis = x + self.kernel.split_axis.is_split_axis = True + return should_outer + + @staticmethod + def decide_nblocks_xblock(numel, no_axis2, min_aligned_numel, xblock=None): + #no_axis2 mean there's only on dims + min_xblock = min_aligned_numel if no_axis2 else 1 + + # need to keep linearity for low_dims + if xblock is None: + xblock = (numel + num_vector_core - 1) // num_vector_core if numel > num_vector_core else min_xblock + + # fixme, aligning is wasting cores . 
+ #if (not no_axis2 and is_low_dim) or same_axis1: + xblock = next_power_of_2(xblock) + + nblocks = (numel + xblock - 1) // xblock + return nblocks, xblock + + @staticmethod + def get_nblocks_before_launch(numel, xblock): + nblocks = (numel + xblock - 1) // xblock + return nblocks, xblock + + @staticmethod + def get_nblocks_xblock_list(numel): + ret = [] + XBLOCK = numel + NBLOCKS = 1 + ret.append((NBLOCKS, XBLOCK)) + while NBLOCKS <= num_vector_core and XBLOCK > 1: + XBLOCK -= 1 + NBLOCKS = (numel + XBLOCK - 1) // XBLOCK + XBLOCK = (numel + NBLOCKS - 1) // NBLOCKS + ret.append((NBLOCKS, XBLOCK)) + + return ret + + # return True when x is the low-dim in indexing + def is_lowest_dimension(self, x): + return x.sorted_order in self.kernel.low_dims + + def find_lowest_dimension(self): + def construct_low_dim(): + for index in self.indexing: + coefficients_dict = index.as_coefficients_dict() + for key, value in coefficients_dict.items(): + if not key.free_symbols: + continue + key = list(key.free_symbols)[0] + if key not in self.kernel.range_tree_nodes: + continue + + if value == sympy.Integer(1): + axis = self.kernel.range_tree_nodes[key] + self.kernel.low_dims.add(axis.sorted_order) + + # all read index should be considered + buf_names = [node.node.name for node in self.kernel.node_schedule if node not in (EnableReduction, DisableReduction)] + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + + for read in node._body.reads: + name = node._body.indexing_exprs_name[read] + read_is_inptr = True + for arg, expr in node._body.reads_name2expr.items(): + # read inner buf should be excluded (tmp will cse replace load) + if read == expr and (arg[:3] != 'arg' and arg in buf_names): + read_is_inptr = False + if read_is_inptr: + names.append(name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + if self.kernel.inside_reduction: + construct_low_dim() + return + # for non-reduction, write index should be considered + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + + for write in node._body.writes: + name = node._body.indexing_exprs_name[write] + names.append(name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + construct_low_dim() + + @staticmethod + def convert(x, y): + xnumel = x + ynumel = y + if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(xnumel, sympy.Integer): + xnumel = xnumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(ynumel, sympy.Integer): + ynumel = ynumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(xnumel, sympy.Integer) and isinstance(ynumel, int): + ynumel = sympy.Integer(ynumel) + + if isinstance(ynumel, sympy.Integer) and isinstance(xnumel, int): + xnumel = sympy.Integer(xnumel) + + return (xnumel, ynumel) + + + @staticmethod + def less_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel < ynumel + + @staticmethod + def great_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel > ynumel + + @staticmethod + def ge_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel >= ynumel diff --git a/torch_npu/_inductor/codegen/tile_generator.py b/torch_npu/_inductor/codegen/tile_generator.py new file mode 100644 index 
0000000000000000000000000000000000000000..6aef7699a3ddc0f692badeb5e8d43879388b9b4f --- /dev/null +++ b/torch_npu/_inductor/codegen/tile_generator.py @@ -0,0 +1,135 @@ +import copy +import math + +from torch._inductor.triton_heuristics import Config +from torch._inductor.utils import next_power_of_2 +from .triton_utils import get_aligned_numel, byte_per_numel + + +# generate tiling configs +class TileGenerator: + + @staticmethod + def aligned_numel(numel): + aligned = next_power_of_2(numel) + return aligned + + @staticmethod + def get_byte_per_numel(dtype): + if dtype is None: + return 1 + return byte_per_numel[dtype] + + @staticmethod + def valid_config(config, align_numel, rnumel=1): + + bytes_align_numel = align_numel + max_numel = 16384 * 4 // bytes_align_numel + + rblock = config["RBLOCK"] if "RBLOCK" in config else rnumel + xblock_sub = config["XBLOCK_SUB"] + if rblock * xblock_sub <= max_numel: + return True + + return False + + # when rblock is low dim, need to maximize rblock + @staticmethod + def descend_xblock(rnumel, xblock, configs, cfg, align_numel, aggresive=True): + + bytes_align_numel = align_numel + start_numel = 2048 // bytes_align_numel if aggresive else 1024 // bytes_align_numel + # include rblock is too big, need to decend rblock first + rblock = rnumel if rnumel > 0 else 1 + while (rblock > start_numel): + newcfg = copy.deepcopy(cfg) + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + rblock = rblock // 2 + cfg["RBLOCK"] = rblock + xblock_sub = TileGenerator.aligned_numel(xblock) + + while True: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + if TileGenerator.valid_config(newcfg, align_numel, rnumel=rblock): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + xblock_sub = xblock_sub // 2 + if xblock_sub * rblock <= start_numel: + break + + @staticmethod + def descend_rblock(rnumel, xblock, configs, cfg, align_numel, aggresive=True): + bytes_align_numel = align_numel + start_numel = 4096 // bytes_align_numel if aggresive else 1024 // bytes_align_numel + + xblock_sub = start_numel if xblock > start_numel else xblock + cfg["XBLOCK_SUB"] = xblock_sub + rblock = rnumel + while True: + newcfg = copy.deepcopy(cfg) + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + rblock = rblock // 2 + if xblock_sub * rblock <= start_numel: + break + + @staticmethod + def descend_xblock_rblock(rnumel, xblock, configs, cfg, align_numel, aggresive=True): + bytes_align_numel = align_numel + start_numel = 4096 // bytes_align_numel if aggresive else 1024 // bytes_align_numel + + # Depending on the number of bytes available to the hardware UB, + # 4096 bytes is an appropriate empirical value for an intra-core split. 
+ # Rule: xblock_sub * rblock <= start_numel + end_numel = math.floor(math.sqrt(start_numel)) + + xblock = next_power_of_2(xblock) + rnumel = next_power_of_2(rnumel) + + xblock_sub = xblock if xblock > start_numel else xblock + rblock = start_numel if rnumel > start_numel else rnumel + + rblock_is_biggerr = rblock > xblock_sub + + if xblock_sub * rblock <= start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + + if rblock_is_biggerr: + while rblock > xblock_sub and xblock_sub * rblock > start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["RBLOCK"] = rblock + xblock_sub = xblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + rblock = rblock // 2 + else: + while rblock < xblock_sub and xblock_sub * rblock > start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + xblock_sub = xblock_sub // 2 + + while xblock_sub * rblock > start_numel: + newcfg = copy.deepcopy(cfg) + newcfg["XBLOCK_SUB"] = xblock_sub + newcfg["RBLOCK"] = rblock + if TileGenerator.valid_config(newcfg, align_numel): + configs.append(Config(newcfg, num_warps=1, num_stages=1)) + if xblock_sub >= end_numel: + xblock_sub = xblock_sub // 2 + if rblock >= end_numel: + rblock = rblock // 2 + + @staticmethod + def nearest_power_of_2(n): + big = next_power_of_2(n) + small = big // 2 + return big if (big - n) < (n - small) else small diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..791593a6da522d2d09bd11fefa5d573fcebf2c43 --- /dev/null +++ b/torch_npu/_inductor/codegen/triton.py @@ -0,0 +1,1858 @@ +import os +import itertools +import operator +from typing import List, Set, Iterable, Callable, Dict +import functools +from enum import Enum +from typing import ( + Optional, + Union, + Tuple, + Any, +) +import re +import sympy + +import torch +from torch._inductor.utils import sympy_subs +from torch._inductor.scheduler import SchedulerNode + +from torch._inductor.codegen.triton import ( + IndexingOptions, + sympy_dot, + CantSplit, + triton_reshape, + TritonCSEVariable, + free_symbol_startswith, + OpsHandler, DisableReduction, EnableReduction, +) + +from torch._inductor.codegen.triton import ( + + TritonKernel, + TritonKernelOverrides, + IterationRangesRoot, + IterationRangesEntry, + CSEVariable, + gen_common_triton_imports, + ReductionHint, + BlockPtrOptions, + triton_acc_type, + triton_constant, + is_welford_reduction, + triton_compute_type, + cast, + ModularIndexing, FloorDiv, sympy_index_symbol, + log +) + +from torch.utils import _pytree as pytree +from torch.utils._sympy.value_ranges import ValueRanges + + +from torch._inductor import config, ir +from torch._inductor.virtualized import ( + V, + StoreMode, + ReductionType, + _ops as ops, +) + +from torch._inductor.utils import ( + Placeholder, + next_power_of_2, +) + + +from torch._inductor.codegen.common import ( + IndentedBuffer, + SizeArg, + DeferredLine, +) +from torch._inductor.codegen.triton_utils import config_of, signature_of, signature_to_meta + + + +def flatten(nums): + res = [] + for i in nums: + if isinstance(i, list): + res.extend(flatten(i)) + else: + res.append(i) + return res + + +class 
AxisDirection(Enum): + Flat = 0, + Vertical = 1, + Horizontal = 2 + + +def reverse_direction(direction): + if direction == AxisDirection.Vertical: + return AxisDirection.Horizontal + elif direction == AxisDirection.Horizontal: + return AxisDirection.Vertical + else: + return AxisDirection.Flat + + +class NPUTritonKernelOverrides(TritonKernelOverrides): + @staticmethod + def exp(x): + return f"tl_math.exp({x})" + + @staticmethod + def sqrt(x): + return f"tl_math.sqrt({x})" + + @staticmethod + def tanh(x): + return f"tl_math.tanh({x})" + + @staticmethod + def rsqrt(x): + return f"tl.rsqrt({x})" + + @staticmethod + def floor(x): + return f"tl_math.floor({x})" + + @staticmethod + def erf(x): + return f"tl_math.erf({x})" + + @staticmethod + def ceil(x): + return f"tl_math.ceil({x})" + + +class NumelList(Tuple): + def numels(self): + numel = functools.reduce(lambda a, b: a * b, self) + return numel + + def __eq__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel == numel2 + + def __mod__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel % numel2 + + def __truediv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel / numel2 + + def __floordiv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel // numel2 + + def __mul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __rmul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __add__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __radd__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __hash__(self): + return super(NumelList, self).__hash__() + + +def group_fn(self, sizes): + groups = list() + for s in sizes: + if not s: + groups.append(1) + elif isinstance(s, list): + group = flatten(s) + groups.append(NumelList(tuple(group)) if isinstance(group, list) else group) + else: + groups.append(s) + return tuple(groups) + + +@staticmethod +def select_index_dtype(node_schedule, numel, reduction_numel): + return "tl.int32" + + +@classmethod +def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): + return (numel, reduction_numel) + + +class IterationRangesEntryNPUIndex(IterationRangesEntry): + def __init__( + self, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_tiling_axis1 = False + self.is_tiling_axis2 = False + self.is_split_axis = False + self.indexing_code = IndentedBuffer() + self.sorted_order = None + self.low_dims = set() + + def _codegen_mask(self): + if self.is_tiling_axis1 or self.is_tiling_axis2: + upper = f"{self.name}_numel" + line = f"{self.name}_mask = {self.name} < {upper}" + self.writeline(line) + line = f"{self.name}_prime_mask = {self.name}_prime < {upper}" + self.writeline(line) + else: + pass + + def _codegen(self): + index = None + vertical = self.is_tiling_axis1 if V.kernel.numof_reduction_axis() <= 1 else not isinstance(self.expr, ModularIndexing) + direction = V.kernel.get_axis_direction(vertical) + # for multiple reduce dims, don't need this + if 
self.is_tiling_axis1 and V.kernel.numof_reduction_axis() <= 1: + index = f"{self.name} = {self.codegen_index(direction)}" + #to be fixed, only permute need to this . + self.writeline(f"{self.name}_prime = {self.codegen_index(reverse_direction(direction))}") + + elif self.is_tiling_axis2: + index = f"{self.name} = {self.codegen_index(direction)}" + #to be fixed, only permute need to this . + self.writeline(f"{self.name}_prime = {self.codegen_index(reverse_direction(direction))}") + if V.kernel.inside_reduction and V.kernel.current_node \ + and isinstance(V.kernel.current_node, SchedulerNode) \ + and V.kernel.current_node.node \ + and V.kernel.current_node.node.data \ + and isinstance(V.kernel.current_node.node.data, ir.Reduction): + reduction_type = V.kernel.current_node.node.data.reduction_type + if reduction_type in {"argmax", "argmin"}: + self.writeline(f"{self.parent.prefix}index = " + f"{self.codegen_index(reverse_direction(AxisDirection.Flat))}") + if index: + self.writeline(index) + self._codegen_mask() + return self.name + + def writeline(self, line): + self.indexing_code.writeline(line) + + def codegen_index(self, direction): + index = "" + if self.is_tiling_axis1 and V.kernel.axis2 is None and V.kernel.persistent_reduction: + index = f"tl.arange(0, RBLOCK)" + + elif self.is_tiling_axis1: + if self.is_split_axis: + offset = f"{self.symbol()}_offset" + index = f"{offset} + (loop1 * XBLOCK_SUB) + base1" + else: + index = f"(loop1 * XBLOCK_SUB) + base1" + if V.kernel.axis2 is not None and direction != AxisDirection.Flat: + index += ("[None,:]" if direction == AxisDirection.Horizontal else "[:, None]") + + elif self.is_tiling_axis2: + if V.kernel.persistent_reduction: + index = f"tl.arange(0, RBLOCK_{self.symbol()})" if V.kernel.numof_reduction_axis() > 1 else "base2" + elif self.is_split_axis: + offset = f"{self.symbol()}_offset" + index = f"{offset} + (loop2 * RBLOCK) + base2" + else: + index = "loop2 * RBLOCK + base2" + + if direction != AxisDirection.Flat: + index += ("[:, None]" if direction == AxisDirection.Vertical else "[None,:]") + + return index + + def codegen_header(self, code): + # generate offset index loop + lines = [] + + if self.is_split_axis and not (V.kernel.axis2 is None and V.kernel.persistent_reduction): + lines.append(f"{self.symbol()}_offset = tl.program_id(0) * XBLOCK") + + if self.is_tiling_axis1 and not (V.kernel.axis2 is None and V.kernel.persistent_reduction): + # don't create loops for multi-reductions + if V.kernel.numof_reduction_axis() <= 1: + lines.append("base1 = tl.arange(0, XBLOCK_SUB)") + xblock = f"XBLOCK" if self.is_split_axis else f"{self.symbol()}_numel" + lines.append(f"loops1 = ({xblock} + XBLOCK_SUB - 1) // XBLOCK_SUB") + + elif self.is_tiling_axis2 and len(V.kernel.axis2_list) <= 1: + lines.append("base2 = tl.arange(0, RBLOCK)") + if self.is_split_axis: + lines.append(f"loops2 = (XBLOCK + RBLOCK - 1) // RBLOCK") + else: + lines.append(f"loops2 = ({self.name}_numel + RBLOCK - 1) // RBLOCK") + else: + pass + + code.writelines(lines) + + +class IterationRangesRootNPUIndex(IterationRangesRoot): + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: TritonKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + ): + super().__init__(name, numel, prefix, index, kernel, pid_cache, is_loop=is_loop, tensor_dim=tensor_dim, + grid_dim=grid_dim) + + def __repr__(self): + return f"IterationRangesRootNPUIndex({self.name!r}, {self.numel}, ...)" + + def 
remove_entry(self, name): + if name in self.var_ranges: + del self.var_ranges[name] + if name in self.var_list: + del self.var_list[self.var_list.index(name)] + if name in V.kernel.range_tree_nodes: + V.kernel.range_tree_nodes_removed[name] = V.kernel.range_tree_nodes[name] + del V.kernel.range_tree_nodes[name] + if name in self.nodes: + del self.nodes[name] + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntryNPUIndex( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + + + return self.nodes[expr] + + +def is_compatible(groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]]): + try: + groups = flatten(groups) + NPUIndexTritonKernel._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + +class NPUIndexTritonKernel(TritonKernel): + overrides = NPUTritonKernelOverrides + + def __init__(self, + *groups, + index_dtype: str, + mutations: Optional[Set[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + min_elem_per_thread=0, + disable_persistent_reduction=False,): + + super().__init__(*groups, index_dtype=index_dtype, mutations=mutations, pid_cache=pid_cache, + reduction_hint=reduction_hint, min_elem_per_thread=min_elem_per_thread, + disable_persistent_reduction=disable_persistent_reduction) + self.first_node = True + self.inside_high_order_reduction = False + # split axis + self.split_axis = None + # tiling axis + self.axis1 = None + self.axis2 = None + # incase two reduction axis + self.axis2_list = [] + self.low_dims = set() + + self.range_tree_nodes_removed: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.range_tree_nodes_substituted = {} + self.expr_substituted = {} + self.sorted_axis = [] + self.prefix: IndentedBuffer = IndentedBuffer() + + def gen_triton_ext_imports(self): + imports = IndentedBuffer() + imports.splice( + """ + from torch._inductor import triton_helpers + from torch_npu._inductor import npu_triton_heuristics + from torch_npu._inductor import npu_triton_helpers + from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math + import torch + """ + ) + return imports.getvalue() + + def patch_triton_hash(self): + # remove this method once the original invocation is fixed + import hashlib + from triton.compiler.compiler import triton_key, make_backend + from triton.runtime.driver import driver + backend = make_backend(driver.active.get_current_target()) + key = f"{triton_key()}-{backend.hash()}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + # persistent_reduction means reduction without loop2 + # for big reduction numel (> 1024), should use outer reduction or loop2 inner reduction + def should_use_persistent_reduction(self) -> bool: + if not (self.inside_reduction and config.triton.persistent_reductions): + return False + threshold = { + ReductionHint.INNER: 1024, + ReductionHint.DEFAULT: 1024 + }.get(self.reduction_hint, 64) + + if config.triton.multi_kernel: + threshold *= 16 + last_numel = self.numels[-1] + if 
isinstance(last_numel, (list, NumelList)): + last_numel = NumelList(last_numel).numels() + self.numels[-1] = last_numel + + if not isinstance(last_numel, (int, sympy.Integer)): + # Not static + return False + hint = V.graph.sizevars.size_hint(last_numel) + if hint > threshold: + return False + # will need to recompile if we cross a larger power of 2 boundary + V.graph.sizevars.guard_leq(last_numel, next_power_of_2(hint)) # type: ignore[arg-type] + return True + + def numof_reduction_axis(self): + root = self.range_trees[-1] + if root is None: + return 0 + + return len(root.var_list) + + def numof_tiling_axis(self): + return (1 if self.axis1 is not None else 0) + (1 if self.axis2 is not None else 0) + + def initialize_range_tree(self, pid_cache): + self.numels = flatten(self.numels) + self.total_numels = 0 + for x in self.numels: + if not isinstance(x, sympy.Integer): + x = x.subs(V.graph.sizevars.var_to_val) + if x > 1: + self.total_numels += 1 + no_r_dim = not self.inside_reduction or self.numels[-1] == 1 + prefixes = "wvtpyxr" + active_prefixes = prefixes[-len(self.numels):] + #prefix can not be 's', 'u', 'ps' , 'i', 'z', 'q' + grid_dims = "xyptvw" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyptvw" + else: + tensor_dims = "xyptvwr" + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix == "r" + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRootNPUIndex( + f"{prefix}index", + self.numels[i], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim + ) + ) + + # numels sent to autotune configs + def get_size_hints(self): + size_hints = [] + + if (len(self.range_tree_nodes.values()) == 0): + return self.numels + for _, node in enumerate(self.sorted_axis): + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + numel_expr = V.graph.sizevars.symbolic_hint(numel_expr) + + size_hints.append(numel_expr) + return size_hints + + def add_numel_to_call_args_and_grid(self, name, call_args, grid): + for node in self.sorted_axis: + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + if isinstance(numel_expr, (sympy.Integer, sympy.Symbol)): + expr = numel_expr + else: + expr = V.graph.wrapper_code.generate_node_numel_expr(name, node, numel_expr) + call_args.append(expr) + if node.parent.grid_dim is not None: + grid.append(expr) + + def gen_numel_args(self, signature, triton_meta_signature, argdefs): + for node in self.sorted_axis: + if not os.environ.get('INDUCTOR_STATIC_MODE'): + sizearg = SizeArg(f"{node.name}_numel", node.length) + signature.append(sizearg) + triton_meta_signature[len(argdefs)] = signature_of( + sizearg, size_dtype=self.index_dtype + ) + argdefs.append(f"{node.name}_numel") + else: + argdefs.append(f"{node.name}_numel: tl.constexpr") + self.triton_meta["constants"][f"{node.name}_numel"] = node.length + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + size_hints = self.get_size_hints() + heuristics = self._get_heuristic() + if name is 
None: + code.splice(gen_common_triton_imports()) + # Note: add extra imports for extensions + code.splice(self.gen_triton_ext_imports()) + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature = self.args.python_argdefs() + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype + ) + triton_meta = { + "signature": triton_meta_signature, + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + # special config for NPU, specify compile target + "mix_mode": "aiv", + } + + inductor_meta = self.create_inductor_meta() + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + self.triton_meta = triton_meta + self.gen_numel_args(signature, triton_meta_signature, argdefs) + + #add in tiling args + self.add_autotune_args(argdefs) + #for scalar codegen + if len(self.range_tree_nodes) == 0: + self.write_scalar() + else: + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + # Note: override original triton_heuristics + if self.inside_reduction: + reduction_hint = self.reduction_hint + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + def codegen_static_numels(self, code): + no_x_axis = self.numof_reduction_axis() > 1 + symbols = [] + if self.axis2 is not None: + symbols = list(self.axis2_list) if no_x_axis else list([self.axis2]) + elif self.persistent_reduction and self.axis1 is not None: + symbols = list([self.axis1]) + + nodes = [self.range_tree_nodes[symbol] for symbol in symbols if symbol is not None] + for node in nodes: + if node.prefix == "r" and self.persistent_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(node.length) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + val = int(simplified_tree_numel) + else: + continue + val = next_power_of_2(val) + if no_x_axis: + code.writeline(f"RBLOCK_{node.symbol()}: tl.constexpr = {val}") + else: + code.writeline(f"RBLOCK: tl.constexpr = {val}") + + def axis2_variable(self): + if self.axis2 is not None: + return 
self.range_tree_nodes[self.axis2]
+        return None
+
+    def is_isolated_symbol(self, input_str, symbol):
+        # use a regex to match the symbol as a whole word, so that e.g. out_ptr0 does not match r0 or r0_prime
+        pattern1 = r'\b' + re.escape(symbol) + r'\b'
+        pattern2 = r'\b' + re.escape(symbol + '_prime') + r'\b'
+
+        return bool(re.search(pattern1, input_str)) or bool(re.search(pattern2, input_str))
+
+    def find_axis2_in_load_store(self):
+        var = self.axis2_variable()
+        if not var:
+            return False
+        for line in self.loads._lines:
+            if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, var.name):
+                return True
+        for line in self.compute._lines:
+            if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, var.name):
+                return True
+        for line in self.suffix._lines:
+            if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, var.name):
+                return True
+        for line in self.stores._lines:
+            if isinstance(line, DeferredLine):
+                line = line.line
+            if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, var.name):
+                return True
+        return False
+
+    def find_axis2_in_indexing(self):
+        var = self.axis2_variable()
+        if not var:
+            return False
+        if self.current_node is None:
+            return False
+        for index in self.current_node._body.indexing.values():
+            if var.symbol() in index.free_symbols:
+                return True
+        return False
+
+    def write_scalar(self):
+        self.body.splice(self.indexing_code)
+        self.body.splice(self.loads)
+        self.body.splice(self.compute)
+        self.body.splice(self.stores)
+        self.loads.clear()
+        self.compute.clear()
+        self.stores.clear()
+        self.suffix.clear()
+        self.prefix.clear()
+
+    def codegen_body(self):
+        if not (
+            self.loads
+            or self.stores
+            or self.compute
+            or self.suffix
+        ):
+            return
+
+        def write_pointwise():
+            self.body.splice(self.indexing_code)
+            self.body.splice(self.loads)
+            self.body.splice(self.compute)
+            self.body.splice(self.stores)
+
+        def is_1d_reduction():
+            return self.numels[-1] > 1 and self.axis2 is None
+
+        def codegen_range(index):
+            def loop_body(index, indexing_code, is_last_axis, do_indent=True):
+                if do_indent:
+                    self.body.do_indent()
+                if indexing_code:
+                    self.body.splice(indexing_code)
+
+                if is_last_axis:
+                    write_pointwise()
+                else:
+                    codegen_range(index + 1)
+
+                if do_indent:
+                    self.body.do_unindent()
+
+            if index < 0 or index >= len(self.range_tree_nodes):
+                return
+            nodes = self.sorted_axis
+            node_range = nodes[index]
+            is_tilling_asix1 = getattr(node_range, "is_tiling_axis1")
+            is_tilling_asix2 = getattr(node_range, "is_tiling_axis2")
+            is_last_axis = index == len(nodes) - 1
+            indexing_code = getattr(node_range, "indexing_code")
+            numof_axis2 = self.numof_reduction_axis()
+            if is_tilling_asix1:
+                do_indent = True
+                reduction_1d = is_1d_reduction()
+                if reduction_1d:
+                    self.body.splice(self.prefix)
+                    self.prefix.clear()
+
+                # multi-dim reduction, i.e.
var_mean[1,2] + if numof_axis2 > 1: + if node_range.is_split_axis: + offset = f"{node_range.name}node_range" + self.body.writeline(f"for {node_range.name} in range({offset}, " + f"min({offset} + XBLOCK), {node_range.name}_numel)):") + else: + self.body.writeline(f"for {node_range.name} in range({node_range.name}_numel):") + # 1D persistent_reduction or 1d reduction non-first-node + elif self.axis2 is None and (self.persistent_reduction or len(self.loads._lines) == 0): + do_indent = False + if len(self.loads._lines) == 0: + indexing_code = None + else: + self.body.writeline(f"for loop1 in range(loops1):") + + if not reduction_1d and self.persistent_reduction: + self.body.do_indent() + self.body.splice(self.prefix) + self.prefix.clear() + self.body.do_unindent() + + loop_body(index, indexing_code, is_last_axis, do_indent=do_indent) + + # for 1D reduction, need to add in suffix for persist_reduction or second node of 1d reduction + if is_1d_reduction() or self.persistent_reduction: + self.body.splice(self.suffix) + self.suffix.clear() + + elif is_tilling_asix2: + do_indent = False + need_axis2_loop = self.find_axis2_in_load_store() + if not need_axis2_loop: + indexing_code = None + if (not self.inside_reduction or not self.persistent_reduction) \ + and need_axis2_loop: + self.body.splice(self.prefix) + self.body.writeline(f"for loop2 in range(loops2):") + do_indent = True + loop_body(index, indexing_code, is_last_axis, do_indent) + self.body.splice(self.suffix) + self.suffix.clear() + + # pointwise, last axis = 1 + elif is_last_axis and node_range.numel == 1: + write_pointwise() + else: + if node_range.is_split_axis: + offset = f"{node_range.symbol()}_offset" + self.body.writeline(f"for {node_range.symbol()} in range({offset}, min({offset} + XBLOCK, {node_range.name}_numel)):") + else: + self.body.writeline(f"for {node_range.symbol()} in range({node_range.name}_numel):") + loop_body(index, indexing_code, is_last_axis) + + if self.first_node: + for node in self.sorted_axis: + node.codegen_header(self.body) + + if self.first_node: + codegen_range(0) + else: + if self.axis2 is None: + codegen_range(0) + else: + axis2_order = self.range_tree_nodes[self.axis2].sorted_order + if self.persistent_reduction and self.numof_reduction_axis() > 1: + axis2_order = axis2_order - self.numof_reduction_axis() + 1 + for _ in node_range(axis2_order): + self.body.do_indent() + codegen_range(axis2_order) + for _ in node_range(axis2_order): + self.body.do_unindent() + + self.cse.invalidate(self.outside_loop_vars) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.suffix.clear() + self.prefix.clear() + self.first_node = False + + # for creat constant tensor, if have two axis, constant=tl.full([1,1]) else tl.full([1]) + def triton_tensor_ndim(self): + if self.numof_reduction_axis() > 1: + return 1 + if self.axis1 is not None and self.axis2 is not None: + ndim = 2 + else: + ndim = 1 + return ndim + + # fixme, indexing.mask_str is None , see varmean_test.py + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + if not self.inside_reduction: + raise ValueError(f"self.inside_reduction = False") + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + if isinstance(indexing, BlockPtrOptions): + self.suffix.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + 
) + ) + else: + if not isinstance(indexing, IndexingOptions): + raise TypeError("indexing is not IndexingOptions") + line = f"tl.store({var} + ({indexing.index_str} ), {value}, {indexing.mask_str})" + if self.numof_reduction_axis() > 1: + line = f"tl.store({var} + ({indexing.index_str} + tl.arange(0,1) ), {value}, {indexing.mask_str})" + self.suffix.writeline( + DeferredLine(name, line) + ) + + def apply_var_prime(self, index, line, mask): + # axis should only be replaced once + axis_list = [] + for key in index.as_coefficients_dict().keys(): + if not key.free_symbols: + continue + symbol = list(key.free_symbols)[0] + if symbol not in self.range_tree_nodes: + continue + range_tree_node = self.range_tree_nodes[symbol] + if (range_tree_node.is_tiling_axis1 or range_tree_node.is_tiling_axis2) and (symbol not in axis_list): + line = line.replace(f"{range_tree_node.name}", f"{range_tree_node.name}_prime") + mask = mask.replace(f"{range_tree_node.name}", f"{range_tree_node.name}_prime") + axis_list.append(symbol) + return line, mask + + # apply xxx_prime var in case dim are permuted + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + original_index = index + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) + index_str = indexing.index_str + value_str = f"{value}" + + # need to reshape when value's dimensions > 2, e.g. (XBLOCK,1,RBLOCK) + is_permuted = self.need_permuted(index) + + mask_str = indexing.mask_str + if is_permuted: + index_str, mask_str = self.apply_var_prime(index, index_str, indexing.mask_str) + value_str = value_str.replace(f"{value}", f"{value}.permute(1,0)") + + advance_block_ptr = None + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing + ) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({index_str}), {value_str}, {mask_str})" + if len(self.axis2_list) > 1: + line = f"tl.store({var} + ({index_str} + tl.arange(0,1) ), {value_str}, {indexing.mask_str})" + + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({index_str}), {value_str}, {indexing.mask_str})" + else: + raise NotImplementedError(f"store mode={mode}") + + self.stores.writeline(DeferredLine(name, line)) + if advance_block_ptr: + self.stores.writeline(advance_block_ptr) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + + @staticmethod + def _get_next_scheduler_node(node_schedule, current_node): + found_current = False if current_node else True + for node in node_schedule: + if isinstance(node, SchedulerNode): + if not found_current and node.get_name() == current_node.get_name(): + found_current = True + continue + if found_current: + return node + return None + + #fixme, this seems not reliable, need to refactor . 
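+    # Hedged illustration of the traversal above (schedule and buffer names are
+    # hypothetical, not taken from this patch):
+    #   node_schedule = [EnableReduction, node_buf0, node_buf1, DisableReduction]
+    #   _get_next_scheduler_node(node_schedule, None)      -> node_buf0
+    #   _get_next_scheduler_node(node_schedule, node_buf0) -> node_buf1
+    #   _get_next_scheduler_node(node_schedule, node_buf1) -> None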
+ def get_next_scheduler_node(self, node): + return self._get_next_scheduler_node(self.node_schedule, node) + + def get_prev_scheduler_node(self, node): + return self._get_next_scheduler_node(reversed(self.node_schedule), node) + + def check_all_index_is_1d_for_dual_reduction(self): + if self.numof_reduction_axis() <= 1: + return False + + all_index_is_1d = True + for _, index in self.current_node._body.indexing.items(): + count = 0 + for symbol in index.free_symbols: + if symbol in self.axis2_list: + count = count + 1 + if count > 1: + all_index_is_1d = False + + if not all_index_is_1d: + break + return all_index_is_1d + + # to generate the shape of the accumulator of RBLOCK loop + def dense_size_list(self, is_permute) -> List[str]: + sizes = [] + if self.numof_reduction_axis() > 1: + sizes = [] if self.check_all_index_is_1d_for_dual_reduction() else [f"RBLOCK_{axis}" for axis in self.axis2_list] + return sizes + if self.persistent_reduction and self.axis2 is None: + sizes = ["RBLOCK"] + return sizes + # current computedbuffer is reduction + cb_is_reduction = self.inside_reduction if not self.current_node else isinstance(self.current_node.node.data, ir.Reduction) + + for tree in self.sorted_axis: + if tree.is_tiling_axis1: + sizes.append("XBLOCK_SUB") + elif tree.is_tiling_axis2: + sizes.append("RBLOCK") + + if cb_is_reduction and self.inside_reduction and self.is_higher_order_reduction() or is_permute: + sizes = sizes[::-1] + + return sizes + + def dense_size_str(self, is_permute=False): + sizes = self.dense_size_list(is_permute) + if self.numof_reduction_axis() > 1: + return f"[{'* '.join(sizes)}]" + return f"[{', '.join(sizes)}]" + + def filter_masks(self, mask_vars): + for node in self.sorted_axis: + if not (node.is_tiling_axis1 or node.is_tiling_axis2): + mask_vars.discard(f"{node.name}_mask") + if len(self.axis2_list) > 1 and not node.is_tiling_axis2: + mask_vars.discard(f"{node.name}_mask") + + # and add to shape to value + def reduction_resize(self, value): + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + is_higher_order_reduction = self.is_higher_order_reduction() + + expand_str = "1," if is_higher_order_reduction else ",1" + if is_higher_order_reduction: + return f"{value}.reshape({expand_str}XBLOCK_SUB)" + else: + return f"{value}.reshape(XBLOCK_SUB{expand_str})" + + def get_axis_direction(self, is_axis1, is_reversed=False): + if self.check_all_index_is_1d_for_dual_reduction(): + result = AxisDirection.Flat + elif not self.inside_reduction: + if self.numof_tiling_axis() > 1: + result = AxisDirection.Vertical if is_axis1 else AxisDirection.Horizontal + else: + result = AxisDirection.Flat + else: + if is_axis1: + result = AxisDirection.Horizontal if V.kernel.is_higher_order_reduction() else AxisDirection.Vertical + else: + result = AxisDirection.Vertical if V.kernel.is_higher_order_reduction() else AxisDirection.Horizontal + + result = reverse_direction(result) if is_reversed else result + return result + + def is_higher_order_reduction(self, check_prev_node=False): + if self.numof_reduction_axis() > 1: + return False + if not self.inside_reduction: + raise ValueError(f"self.inside_reduction = False") + if self.inside_high_order_reduction: + return self.inside_high_order_reduction + + node = self.current_node if self.current_node is not None else self.get_prev_scheduler_node(None) + if node is None or not isinstance(node, SchedulerNode): + return False + + reduction = node.node.data + while check_prev_node and reduction 
is not None and not isinstance(reduction, ir.Reduction): + node = self.get_prev_scheduler_node(node) + if node is None: + reduction = None + else: + reduction = node.node.data + + if reduction is None or not isinstance(reduction, ir.Reduction): + return False + if not hasattr(reduction, "reduced_idx"): + return False + + reduced_order = reduction.reduced_idx[0] + is_last_axis = all(_ < reduced_order for _ in reduction.kept_idx) + self.inside_high_order_reduction = not is_last_axis + return self.inside_high_order_reduction + + def get_axis_dtype(self, axis): + dtype = None + if axis is None: + return None + for node in self.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + if axis.symbol() in node._body.indexing_map: + dtype = V.graph.get_dtype(node.node.name) + break + if dtype is None: + should_break_all = False + for node in self.node_schedule: + if should_break_all: + break + if node in (EnableReduction, DisableReduction): + continue + for key, _ in node._body.indexing_map.items(): + if key in self.range_tree_nodes: + dim = self.range_tree_nodes[key] + else: + dim = self.range_tree_nodes_removed[key] + + if dim.parent == axis.parent: + dtype = V.graph.get_dtype(node.node.name) + should_break_all = True + break + return dtype + + def create_inductor_meta(self): + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + axis1_order = self.range_tree_nodes[self.axis1].sorted_order if self.axis1 is not None else None + axis2_order = self.range_tree_nodes[self.axis2].sorted_order if self.axis2 is not None else None + split_axis_dtype = self.get_axis_dtype(self.split_axis) + inductor_meta = { + "autotune_hints": set(self.autotune_hints), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "no_x_dim": self.no_x_dim, + # Due to breaking change of triton 3.0, the original invocation is broken + "backend_hash": self.patch_triton_hash(), # torch.utils._triton.triton_hash_with_backend(), + "split_axis_order": self.split_axis.sorted_order if self.split_axis is not None else None, + "axis1_order": axis1_order, + "axis2_order": axis2_order, + "low_dims": self.low_dims, + "numof_reduction_axis": self.numof_reduction_axis(), + "split_axis_dtype": split_axis_dtype + } + return inductor_meta + + def reduction_dim(self): + if not self.inside_reduction: + raise ValueError(f"self.inside_reduction = False") + if self.numof_reduction_axis() > 1: + return 0 + return 0 if self.is_higher_order_reduction() or len(self.sorted_axis) == 1 else 1 + + def reduction_var(self): + var = self.axis2 + return var + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + if not self.inside_reduction: + raise ValueError(f"self.inside_reduction = False") + masks = {f"{node.symbol()}_mask" for node in self.sorted_axis} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix 
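+        # Hedged sketch of what this method emits for a plain (non-persistent) sum
+        # over axis2; variable and mask names below are illustrative only:
+        #   prefix:  _tmp4 = tl.full([XBLOCK_SUB, RBLOCK], 0, tl.float32)
+        #   compute: _tmp4 = tl.where(x1_mask & r2_mask, _tmp4 + tmp3, _tmp4)
+        #   suffix:  tmp4 = tl.sum(_tmp4, 1).reshape(XBLOCK_SUB,1)
+        # Persistent reductions instead mask the value up front and reduce it
+        # directly in the compute buffer.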
+ + dense_size_str = self.dense_size_str(False) + + if len(dense_size_str) > 2: + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, f"tl.reshape({v}, {dense_size_str})" + ), + value, + ) + + dim: int + root_op: str + + def final_reduction(value): + # use tl + module = "tl" + if reduction_type in {"max", "min"}: + # use tl.max + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + + def final_argreduce(buffer, result_var, value, index): + buffer.splice( + f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp')} + """ + ) + + def get_reduction_axis(): + return list(self.range_tree_nodes.values())[-1] + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = self.reduction_dim() + acc_type = triton_acc_type(src_dtype) + result_var: Any = self.cse.newvar() + result_var.mask_vars = {var for var in masks if var[0] != "r"} + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(triton_constant, default) + + def _mask_value(value, default): + return self.cse.generate(self.compute, where_cond(value, default)) + # masked_value doesn't work dual reduction + if self.numof_reduction_axis() == 1: + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + else: + masked_value = value + + if reduction_type in {"argmax", "argmin", "max", "min"}: + reduce_axis = get_reduction_axis() + broadcast_string: str + if self.is_higher_order_reduction(): + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape({reduction_range_prefix.upper()}BLOCK,1), {masked_value}.shape)" + else: + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape(1,{reduction_range_prefix.upper()}BLOCK), {masked_value}.shape)" + accumulator_index = str( + self.cse.generate( + self.compute, + broadcast_string + ) + ) + if reduction_type == "argmax" or reduction_type == "argmin": + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + elif reduction_type == "max" or reduction_type == "min": + result_var = self.cse.generate( + self.compute, final_reduction(masked_value) + ) + elif reduction_type == "welford_reduce": + raise ValueError(f"welford_reduction is not supported now..") + elif reduction_type == "welford_combine": + raise ValueError(f"welford_combine is not supported now..") + else: + result_var = self.cse.generate( + self.compute, final_reduction(masked_value) + ) + else: + accumulator = f"_{result_var}" + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(triton_constant, default) + if not isinstance(default, tuple): + self.prefix.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.prefix.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, 
{long_max}, tl.int64)" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + """ + ) + final_argreduce(self.suffix, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + raise ValueError(f"welford_reduction is not supported now..") + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.suffix.writeline( + f"{result_var} = {final_reduction(accumulator)}.to({result_type})" + ) + else: + self.suffix.writeline( + f"{result_var} = {final_reduction(accumulator)}" + ) + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + self.outside_loop_vars |= set(result_var) + else: + self.outside_loop_vars.add(result_var) + + return result_var + + #XBLICK:split size, XBLOCK_SUB: tile1 size, RBLOCK:tile2 size + def add_autotune_args(self, argdefs): + # no tiling in this case + if self.persistent_reduction and self.axis2 is None: + return + argdefs.append(f"XBLOCK: tl.constexpr") + if self.numof_reduction_axis() <= 1: + argdefs.append(f"XBLOCK_SUB: tl.constexpr") + if self.axis2 is not None and not self.persistent_reduction: + argdefs.append(f"RBLOCK: tl.constexpr") + + def _get_heuristic(self): + if self.persistent_reduction: + if not self.inside_reduction: + raise ValueError(f"self.inside_reduction = False") + return "persistent_reduction_npu_index" + elif self.inside_reduction: + return "reduction_npu_index" + return "pointwise_npu_index" + + def need_broadcast(self, index: sympy.Expr): + tiling_axis = [False, False] + for axis in index.free_symbols: + if axis not in self.range_tree_nodes: + continue + if self.range_tree_nodes[axis].is_tiling_axis1: + tiling_axis[0] = True + elif self.range_tree_nodes[axis].is_tiling_axis2: + tiling_axis[1] = True + #implict broadcast + result = (self.numof_tiling_axis() > 1 and not self.persistent_reduction) and (tiling_axis[1] ^ tiling_axis[0]) + result = result and self.find_axis2_in_indexing() + return result, tiling_axis + + def current_node_has_permute(self): + if not self.current_node: + return False + for index in self.current_node._body.indexing.values(): + if self.need_permuted(index): + return True + return False + + def need_permuted(self, index: sympy.Expr): + if self.numof_tiling_axis() <= 1: + return False + + need_permute = False + tmp_list = [] + coefficients_dict = index.as_coefficients_dict() + need_permute_axis1 = False + need_permute_axis2 = False + for key, value in coefficients_dict.items(): + if not key.free_symbols: + continue + key = list(key.free_symbols)[0] + if key not in self.range_tree_nodes: + continue + axis = self.range_tree_nodes[key] + # normally, axis2 is lowest dimension, except for higher_order_reduction + if (self.inside_reduction and self.is_higher_order_reduction(True)): + if axis.is_tiling_axis1 and value > sympy.Integer(1): + need_permute_axis1 = True + elif axis.is_tiling_axis2 and value > sympy.Integer(1): + 
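+                    # Hedged reading of this branch: a coefficient above 1 means the
+                    # tiling axis is not the stride-1 dimension of this access, so the
+                    # tile is treated as transposed and repaired later with
+                    # .permute(1, 0) in load()/store(); the all(tmp_list) check below
+                    # cancels the permute when no axis is contiguous anyway.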
need_permute_axis2 = True if self.numof_reduction_axis() <= 1 else isinstance(axis.expr, ModularIndexing) + tmp_list.append(True if value > sympy.Integer(1) else False) + + # If all axes have coefficients greater than 1, + # then the stride is not 1, and in this case, return false, + # indicating that the transpose is not required. + if all(tmp_list): + return False + return need_permute_axis1 or need_permute_axis2 + + def get_reshape_dense_str(self, tiling_axis): + # there must be one tiling asis missing + if not tiling_axis[1] and not tiling_axis[0]: + raise ValueError(f"Both tiling_axis[1] and tiling_axis[0] should not be 0") + sizes = ["XBLOCK_SUB", "1"] + if not tiling_axis[0]: + sizes = ["1", "RBLOCK"] + + if self.inside_reduction and self.is_higher_order_reduction(): + sizes = reversed(sizes) + return f"[{', '.join(sizes)}]" + + def get_reshape_str(self, tiling_axis, check_prev_node=True): + # there must be one tiling asis missing + if not tiling_axis[1] and not tiling_axis[0]: + raise ValueError(f"Both tiling_axis[1] and tiling_axis[0] should not be 0") + sizes = ["XBLOCK_SUB", "RBLOCK"] + if not tiling_axis[0]: + sizes[0] = "1" + elif not tiling_axis[1]: + sizes[1] = "1" + if self.inside_reduction and self.is_higher_order_reduction(check_prev_node): + sizes = reversed(sizes) + + return f"[{', '.join(sizes)}]" + + def get_broadcast_dense_str(self, tiling_axis, check_prev_node=True): + # there must be one tiling asis missing + if not tiling_axis[1] and not tiling_axis[0]: + raise ValueError(f"Both tiling_axis[1] and tiling_axis[0] should not be 0") + sizes = ["XBLOCK_SUB", "RBLOCK"] + if self.inside_reduction and self.is_higher_order_reduction(check_prev_node): + sizes = reversed(sizes) + return f"[{', '.join(sizes)}]" + + #broadcast, permute handling + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + original_index = index + is_permuted = self.need_permuted(index) + store_cache = self.cse.store_cache + if name in store_cache: + broadcasted, tiling_axis = self.need_broadcast(original_index) + result_var = store_cache[name] + if broadcasted: + line = f"{result_var}.broadcast_to({self.get_broadcast_dense_str(tiling_axis, True)})" + buffer = self.compute if self.persistent_reduction else self.loads + result_var = self.cse.generate(buffer, line) + elif is_permuted: + line = f"{result_var}.permute(1,0)" + buffer = self.compute if self.persistent_reduction else self.loads + result_var = self.cse.generate(self.loads, line) + return result_var + + need_broadcast, tiling_axis = self.need_broadcast(index) + indirect_indexing = self.is_indirect_indexing(index) + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + ep = "" + if ( + (has_tmpmask or has_rindex) + and V.graph.get_dtype(name) != torch.bool + and indexing.has_mask() + ): + other = ", other=0.0" + else: + other = "" + + advance_block_ptr = None + append_broadcast = None + + if V.graph.is_unspec_arg(name): + line = var + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing, other + ) + line = f"tl.load({block_ptr}{other}{ep})" + # add needed size=1 dimensions + line = triton_reshape( + line, indexing.block_shape, indexing.reshape_suffix + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + num_size = 
len(self.dense_size_list(is_permuted)) + append_broadcast = "[1, 1]" if (num_size > 1) else "[1]" + else: + index_str = indexing.index_str + mask_str = indexing.mask_str + if is_permuted: + index_str, mask_str = self.apply_var_prime(index, index_str, mask_str) + line = f"tl.load({var} + ({index_str}), {mask_str}{ep}{other})" + + dtype = V.graph.get_dtype(name) + if dtype in (torch.float16, torch.bfloat16): + line += ".to(tl.float32)" + if dtype == torch.bool and torch.version.hip is None: + line += ".to(tl.int1)" + if has_tmpmask: + # Masked loads must come after the mask is computed + load_buffer = self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indirect_indexing + and not has_rindex + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + load_buffer = self.prefix + + else: + load_buffer = self.loads + result_var = self.cse.generate(load_buffer, line) + if not isinstance(result_var, TritonCSEVariable): + raise TypeError("result_var is not TritonCSEVariable") + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast and append_broadcast != '[]': + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line) + elif need_broadcast and not indirect_indexing: + line = f"{result_var}.broadcast_to({self.get_broadcast_dense_str(tiling_axis)})" + result_var = self.cse.generate(load_buffer, line) + elif is_permuted: + line = f"{result_var}.permute(1,0)" + result_var = self.cse.generate(self.loads, line) + + if advance_block_ptr: + load_buffer.writeline(advance_block_ptr) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + #1. only remove the line which asserts index var should be in "xyr" + #2. don't do simplify_indexing, which combine continuous dims + #3. removed block_ptr, removed dense mask/broadcast support + # fixme, dense_mask_vars should be generated from sorted_axis + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ) -> Union[IndexingOptions, BlockPtrOptions]: + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) 
+ # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + s.name.startswith("s") or s.name.startswith("ps") for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + if not isinstance(var, sympy.Symbol): + raise TypeError("var is not sympy.Symbol") + has_rindex = has_rindex or var.name.startswith("r") + if override_mask: + pass + elif var.name.startswith("tmp"): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif var.name.startswith(("s", "ps", "i")): + pass + else: + # var is one of xN, yN or rN + mask_vars.add(f"{var.name}_mask") + + expand_str = None + index_str = self.index_to_str(index) + is_permute = self.need_permuted(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str(is_permute) + if (index != 0): + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + else: + index_str = f"tl.arange(0,1)" + return IndexingOptions(index_str, set(), "None", expand_str, has_rindex) + + if override_mask: + mask_vars = {override_mask} + if self._load_mask: + mask_vars.add(self._load_mask) + self.filter_masks(mask_vars) + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type] + + #support split multiple ranges (instead of double) from one flatten range, triple-ranges are needed in mamba model + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + for i, group in enumerate(remaining): + if isinstance(group, (list, tuple)): + remaining[i] = NumelList(group).numels() + + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit() + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(strides, index_list): + def getter(flat_vars): + expr = sympy.Integer(0) + for stride, index in zip(strides, index_list): + expr = stride * flat_vars[index] + expr + return expr + + return getter + + def size_hints(group): + if isinstance(group, (list, tuple)): + return sv.size_hint(NumelList(group).numels()) + return sv.size_hint(group) + + def add_multiple_range(size, return_getters): + # need to break size in multiple + index_list = [] + stride_list = [] + group = current_group + remained_size = size + # Two checks: + # 1. remaining sizes to be merged + # 2. 
remained_size is already divided to 1 + while (group < len(remaining) and remaining[group] > 1) and (remained_size > 1): + group_size = remaining[group] + # size should be divisible by group_size + if not sv.statically_known_multiple_of(remained_size, group_size): + raise CantSplit() + index_list.append(add_range(group, group_size)) + remained_size = FloorDiv(remained_size, group_size) + stride_list.append(remained_size) + group = group + 1 + if remained_size != 1: + raise CantSplit() + return_getters.append(make_combined(stride_list, index_list)) + + return_getters_groups = [] + current_group = 0 + + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while ( + current_group < len(remaining) + and size_hints(remaining[current_group]) == 1 + ): + # scroll to next group with remaining elements + current_group += 1 + + if sv.size_hint(size) > size_hints(remaining[current_group]): + # add multiple ranges (two or more) to the list, as well as the getter funcs + add_multiple_range(size, return_getters) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + if any(V.graph.sizevars.size_hint(s) != 1 for s in remaining): + raise ValueError(f"failed to set ranges {remaining} {lengths}") + + return new_ranges, return_getters_groups + + # just to override load method of CSEProxy, however, CSEProxy is an inner which can not be monkey patched, + # we need to override the whole inner class + def __enter__(self): + class CSEProxy: + self.name = "CSEProxy" + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] + def inner(*args, **kwargs): + # TritonTemplateKernel has no current_node + buf_bounds = ValueRanges.unknown() + if hasattr(V.interpreter, "current_node"): + fx_node = V.interpreter.current_node + if not isinstance(self.node_to_bounds, dict): + raise TypeError("self.node_to_bounds is not dict") + buf_bounds = self.node_to_bounds.get( + fx_node, ValueRanges.unknown() + ) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + + def do_cse(v): + csevar = self.cse.generate(self.compute, v, bounds=buf_bounds) + csevar.update_on_args(name, args, kwargs) + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def indirect_indexing( + var: CSEVariable, size: sympy.Expr, check: bool = True + ): + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg = var.bounds & ValueRanges(-sympy.oo, -1) + new_bounds = ValueRanges(neg.lower + size, neg.upper + size) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, sympy.oo) + new_bounds = new_bounds | pos + + stm = ops.add(var, self.rename_indexing(size)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, "0") + stm = ops.where(lt, stm, var) + new_var = self.cse.generate(self.compute, stm, 
bounds=new_bounds) + + new_var.update_on_args("index_wrap", (var,), {}) + var = new_var + + if self.generate_assert(check): + mask = self.load_mask(var) + + # An assertion line may have been written already, if so just + # update the max size. + map_key = (var, mask) + existing_size, _ = self.indirect_max_sizes.get( + map_key, (None, None) + ) + if existing_size is not None: + size = sympy.Min(size, existing_size) + else: + pass + self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) + return sympy_index_symbol(str(var)) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + V.kernel.must_keep_buffers.add(name) + if free_symbol_startswith(index, "tmp"): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return self.load(name, index) + return self.load(name, index) + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + self.cse.store_cache[name] = value + if self.current_node: + for other_name in self.current_node.get_mutations(): + self.cse.store_cache[other_name] = value + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + else: + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + self.cse.store_cache[name] = value + if self.current_node: + for other_name in self.current_node.get_mutations(): + self.cse.store_cache[other_name] = value + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + return None + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtype: torch.dtype, + combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable], + value: CSEVariable, + init: int, + ) -> CSEVariable: + return self.scan(dtype, combine_fn, value, init) + + @staticmethod + def bucketize( + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + return self.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + # Use sympy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + if not self.overrides: + raise ValueError("self.overrides = False") + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + diff --git a/torch_npu/_inductor/codegen/triton_utils.py b/torch_npu/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5acd971ba027980ccede319c0a8870b1419ae37a --- /dev/null +++ b/torch_npu/_inductor/codegen/triton_utils.py @@ -0,0 +1,29 @@ + +import torch + +# wrapper npu 32 bytes align, get and pass unalign info to triton meta +# then autotune choose tiling param and send them to bishengIR +byte_per_numel = { + torch.float32: 4, # torch.float32 or torch.float + torch.float64: 8, # torch.float64 or torch.double + 
torch.float16: 2, # torch.float16 or torch.half + torch.bfloat16: 2, # torch.bfloat16 + torch.int32: 4, # torch.int32 or torch.int + torch.int64: 8, # torch.int64 or torch.long + torch.int16: 2, # torch.int16 or torch.short + torch.int8: 1, # torch.int8 + torch.uint8: 1, # torch.uint8 + torch.bool: 1, # torch.bool + torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) + torch.complex64: 8, # torch.complex64 + torch.complex128: 16 # torch.complex128 +} + + +def get_aligned_numel(dtype): + if dtype in byte_per_numel: + return 32 // byte_per_numel[dtype] + else: + return 1 + + diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..b1469d6dffc0e30025ff6673c4dd0d511d4cddb3 --- /dev/null +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -0,0 +1,48 @@ +from torch._inductor.codegen.wrapper import WrapperCodeGen, SymbolicCallArg +from torch._inductor.virtualized import V + + +class NPUWrapperCodeGen(WrapperCodeGen): + def __init__(self): + super().__init__() + + def write_triton_header_once(self) -> None: + self.header.splice( + """ + import triton + import triton.language as tl + from torch._inductor.triton_heuristics import split_scan_grid, start_graph, end_graph + from torch_npu._inductor.npu_triton_heuristics import grid + {} + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + #generate numel expr for range_tree_node + def generate_node_numel_expr(self, kernel_name: str, node, numel_expr): + expr = f"{kernel_name}_{node.name}_numel" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(numel_expr)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(numel_expr)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. 
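+        # Hedged example of the wrapper line produced above (kernel and axis names
+        # are hypothetical): for a kernel "triton_unk_fused_add_0" whose split-axis
+        # node "x0" has numel s0*64, the generated Python wrapper contains
+        #     triton_unk_fused_add_0_x0_numel = s0*64
+        # and the SymbolicCallArg returned below is what lands in the launch args.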
+ return SymbolicCallArg(expr, numel_expr) + + # don't free anything + def make_buffer_free(self, buffer): + return "" + + # don't assert + def codegen_input_size_asserts(self) -> None: + pass diff --git a/torch_npu/_inductor/config.py b/torch_npu/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fe356796b989f32bd9be0e902208e5d7e982f20f --- /dev/null +++ b/torch_npu/_inductor/config.py @@ -0,0 +1,44 @@ +import os # noqa: C101 +import sys +import logging +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from triton.runtime.driver import driver +from torch._inductor import config +enable_npu_indexing = True + +config.triton.unique_kernel_names = True +# avoid test_opensora_cases_model_16_forward reinterpre_tensor issue +config.allow_buffer_reuse = False +#inductor debug switch +config.trace.enabled = True + +# npu hardware params from trion +target = driver.active.get_current_target() +device = driver.active.get_current_device() +prop = driver.active.utils.get_device_properties(device) + +num_cube_core = prop["num_aicore"] +num_vector_core = prop["num_aicore"] + +# unit byte +npu_block = 32 + +if ("Ascend910B" in target.arch): + num_vector_core = num_cube_core * 2 + +log_level_env = os.getenv('INDUCTOR_ASCEND_LOG_LEVEL', 'INFO').upper() +log_level_mapping = { + 'DEBUG': logging.DEBUG, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL +} +log_level = log_level_mapping.get(log_level_env.upper(), logging.INFO) +logging.basicConfig( + level=log_level, + format='%(asctime)s - %(levelname)s - %(message)s' +) +log = logging.getLogger(__name__) + +aggresive_autotune = os.getenv("INDUCTOR_ASCEND_AGGRESSIVE_AUTOTUNE", '0').lower() in ('1', 'true') \ No newline at end of file diff --git a/torch_npu/_inductor/decomposition.py b/torch_npu/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..a31c27ff382a4d4e36274d998d16256d0d0266a4 --- /dev/null +++ b/torch_npu/_inductor/decomposition.py @@ -0,0 +1,48 @@ +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +from torch._inductor.decomposition import register_decomposition +import torch._ops +from .lowering import _init_set + +aten = torch.ops.aten + +DECOMPOSITION_OVERLOAD_OP = [ + aten._log_softmax, + aten.nll_loss_forward, + # aten.gelu_backward, + # aten.gelu, + aten.nll_loss_backward, + aten._log_softmax_backward_data, + aten.embedding_dense_backward +] + + +def _register_npu_inductor_decompositons(): + overload_op_set = set() + _init_set(DECOMPOSITION_OVERLOAD_OP, overload_op_set) + + for op in overload_op_set: + if (op in decompositions): + del decompositions[op] + + @register_decomposition([aten.scatter.src]) + @pw_cast_for_opmath + def scatter_src(self, input_tensor, dim, index_tensor, source_tensor): + if self.device.type != "npu" or dim != 1: + raise ValueError("Device type must be 'npu' and dim must be 1") + (XNUMEL, YS) = input_tensor.shape + index_rblock = torch.arange(YS).npu().reshape((1, YS)).repeat((XNUMEL, 1)) + + index_tensor_brd = index_tensor.to(torch.int32).broadcast_to(XNUMEL, YS) + source_tensor_brd = source_tensor.broadcast_to(XNUMEL, YS).to(torch.float32) + scatter1 = torch.where(index_rblock == index_tensor_brd, 1.0, 0.0) * source_tensor_brd + return scatter1 + + @register_decomposition([aten.expm1]) + def expm1(x): + tensor = torch.exp(x) - torch.ones_like(x) + return tensor + + @register_decomposition([aten.erfc]) + def erfc(x): + tensor = 
torch.ones_like(x) - torch.exp(x) + return tensor \ No newline at end of file diff --git a/torch_npu/_inductor/embedding_backward_patch.py b/torch_npu/_inductor/embedding_backward_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..a0bd902bac98d294e3b9842998287fad88e20093 --- /dev/null +++ b/torch_npu/_inductor/embedding_backward_patch.py @@ -0,0 +1,11 @@ +import torch +from torch.library import Library, impl + +python_dispatcher_lib = Library("aten", "IMPL", "PythonDispatcher") + + +@impl(python_dispatcher_lib, "embedding_backward") +def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse): + if sparse: + raise RuntimeError("the current NPU does not yet support sparse tensor, when sparse is set to True") + return torch.ops.aten.embedding_dense_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq) \ No newline at end of file diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..44b69f2e666832812e4e0a25352ad098f7112d68 --- /dev/null +++ b/torch_npu/_inductor/lowering.py @@ -0,0 +1,313 @@ +import sympy +from torch._inductor.ir import Reduction +from torch._inductor.utils import sympy_product +from torch._inductor import ir +from torch._inductor.ir import ExpandView, TensorBox, ops_wrapper +from torch._inductor.lowering import sum_ +from torch._inductor import lowering +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +import torch._ops + +from torch._inductor.lowering import ( + lowerings, + make_fallback, + register_lowering, + to_dtype, + # make_reduction, + # reduce_amax, + # reduce_amin, + fallback_cumsum, + _validate_reduction_axis, + div, + squeeze, + square, + sub, + fallback_handler, + is_boolean_type, + logical_and, + make_pointwise, + _make_reduction_inner, + _validate_reduction_axis, +) + +import torch_npu +from torch_npu import npu_dtype_cast + + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + # Only realize if reduction isn't unrolled + if isinstance(result.data.data, Reduction): + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + kept_idx = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + else: + kept_idx.append(i) + + setattr(result.data.data, "kept_idx", kept_idx) + setattr(result.data.data, "reduced_idx", reduced_idx) + + result.realize() + return result + + return inner + +lowering.make_reduction = make_reduction +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims + + +def _init_set(input_list, output_set): + for fn in input_list: + output_set.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + output_set.add(other_fn) + + +GENERATE_LIST = [ + aten.mul, + aten.add, + aten.sub, + aten.div, + aten.exp, + aten.maximum, + aten.sum, + aten.select, + aten.unsqueeze, + aten.repeat, + #aten.clone, + aten.reshape, + aten.where, + aten.lt, + aten.minimum, + aten.gt, + aten.le, + aten.ceil, + aten.floor, + aten.rsqrt, + 
aten.abs, + aten.log, + aten.bitwise_xor, + aten.amax, + # backward + prims.convert_element_type, + aten.min, + aten.max, + aten.erf, + aten.argmax, + aten.argmin, + aten.clamp_min, + aten.slice, + aten.neg, + aten.cat, + aten.arange, + aten.expand, + aten.eq, + aten.where, + aten.scalar_tensor, + aten.ge, + aten.permute, + aten.sqrt, + aten.relu, + aten.clamp, + aten.clamp_max, + aten.mean, + # npu.npu_dtype_cast + npu_dtype_cast, + aten.select_scatter, + aten.slice_scatter, + prims.broadcast_in_dim, + prims.maximum, + aten.ne, + aten.sigmoid, + aten.sign, + aten.logical_and, + aten.logical_or, + aten.logical_not, + aten.pow, + aten.gelu, + aten.tanh, + aten.isnan, + aten.bitwise_and, + aten.squeeze, + aten.copy, + aten.reciprocal +] + +GENERATE_LIST2 = [ + "foreach" +] + +FALLBACK_LIST = [] + +# 先删除从lowering已经注册的op,再更新,不然会lowering的时候找到在torch注册的op +LOWERING_OVERLOAD_OP = [ + aten.cumsum, + aten.mean, + # aten.max, + # aten.min, + # aten.mul, + aten.var_mean, + aten.var, + + # work round for electraModel + aten.embedding, + aten.split, + aten.split_with_sizes, + aten.nll_loss_forward, + aten.gather, + aten.cat, + aten.clone +] + + +def _register_npu_inductor_fallbacks(): + gen_set = set() + _init_set(GENERATE_LIST, gen_set) + overload_op_set = set() + _init_set(LOWERING_OVERLOAD_OP, overload_op_set) + + # 把不在白名单的op fallback + for op in lowerings: + if op not in decompositions and op not in gen_set: + if isinstance(op, torch._ops.OpOverloadPacket) or \ + isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + flag = False + for gens in GENERATE_LIST2: + if str(op).find(gens) != -1: + flag = True + if flag: + continue + else: + make_fallback(op) + FALLBACK_LIST.append(op) + # 把需要overload的op在lowering里删除 + for op in overload_op_set: + if op in lowerings: + del lowerings[op] + + @register_lowering(aten.mean) + def mean(x, axis=None, keepdim=False, *, dtype=None): + size = x.get_size() + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.bfloat16,): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + @register_lowering(aten.cumsum) + def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int32 # torch.int64->torch.int32 + if len(x.get_size()) == 0: + if axis not in [0, -1]: + raise ValueError("axis must in [0, -1]") + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + return fallback_cumsum(x, dim=axis, dtype=dtype) + + @register_lowering(npu_dtype_cast, type_promotion_kind=None) + def _convert_npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = 
ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + # The welford reduction branch is annotated + # if use_two_step_variance(x,axis=axis,keepdim=keepdim) + # else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + @register_lowering(aten.var_mean) + def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + @register_lowering([aten.var, prims.var]) + def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + @register_lowering(aten.embedding, type_promotion_kind=None) + def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + return fallback_handler(aten.embedding.default)(weight, indices, padding_idx=-1, scale_grad_by_freq=False, + sparse=False) + + @register_lowering(aten.cat) + def cat(inputs, dim=0): + # work round for electraModel backward + return fallback_handler(aten.cat.default)(inputs, dim) + + make_fallback(aten._log_softmax) + make_fallback(aten.gather) + make_fallback(aten.nll_loss_forward) diff --git a/torch_npu/_inductor/npu_fusion_attention_graph.py b/torch_npu/_inductor/npu_fusion_attention_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..323bbfc5eefc64f8c1a6a811654f0631cc40c4e5 --- /dev/null +++ b/torch_npu/_inductor/npu_fusion_attention_graph.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import functools +import sympy + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.library import Library, impl +import torch_npu + + +npu_def = Library("npu_graph", "DEF") +npu_lib = Library("npu_graph", "IMPL", "PrivateUse1") +meta_lib = Library("npu_graph", "IMPL", "Meta") + +npu_def.define("npu_fa(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)") +npu_def.define("npu_fa_backward(Tensor query, Tensor key, Tensor value, Tensor dy, int head_num, str input_layout, *, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? softmax_max=None, Tensor? softmax_sum=None, Tensor? softmax_in=None, Tensor? attention_in=None, float scale_value=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, Tensor? seed=None, Tensor? offset=None, Tensor? 
numels=None, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor)") + + +@impl(npu_lib, "npu_fa") +def npu_fa(*args, **kwargs): + if len(args) > 8: + args = list(args) + # for scale + try: + args[8] = 1.0 / args[8] + except IndexError: + args[8] = 1.0 / (args[8] + 1e-6) + print("args[8]: zero can not be divided") + r1, r2, r3, r4, seed, offset, numel = torch_npu.npu_fusion_attention(*args, **kwargs) + r2.requires_grad = False + r3.requires_grad = False + r4.requires_grad = False + return r1, r2, r3, r4, torch.tensor([seed], requires_grad=False), torch.tensor([offset], requires_grad=False), torch.tensor([numel], requires_grad=False) + + +@impl(npu_lib, "npu_fa_backward") +def npu_fa_backward(*args, **kwargs): + if 'scale_value' in kwargs: + kwargs['scale_value'] = 1.0 / kwargs['scale_value'] + return torch_npu.npu_fusion_attention_grad(*args, **kwargs) + + +@impl(meta_lib, "npu_fa") +def npu_fa(query, key, value, head_num, input_layout, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, + inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + B = query.size(0) + N = head_num + S1 = query.size(2) + S2 = key.size(2) + + if input_layout == "BSH": + B = query.size(0) + S1 = query.size(1) + S2 = key.size(1) + + if input_layout == "SBH": + B = query.size(1) + S1 = query.size(0) + S2 = key.size(0) + + attention_score = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous() + softmax_max = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta') + softmax_sum = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta') + softmax_out = torch.empty([0], dtype=query.dtype, device='meta') + return (torch.empty_like(attention_score), + torch.empty_like(softmax_max), + torch.empty_like(softmax_sum), + torch.empty_like(softmax_out), + torch.tensor([0], device='meta', requires_grad=False), + torch.tensor([0], device='meta', requires_grad=False), + torch.tensor([0], device='meta', requires_grad=False)) + + +@impl(meta_lib, "npu_fa_backward") +def npu_fa_backward(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None, + softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0, + keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=0, offset=0, + numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + + dq = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous() + dk = torch.empty_like(key, dtype=query.dtype, device='meta').contiguous() + dv = torch.empty_like(value, dtype=query.dtype, device='meta').contiguous() + dpse = torch.empty([0], dtype=query.dtype, device='meta').contiguous() + return (torch.empty_like(dq), torch.empty_like(dk), torch.empty_like(dv), torch.empty_like(dpse) if pse else None) + + +class NpuGraphAttentionFunction(Function): + @staticmethod + def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + # 前向传播逻辑 + # 这里假设有一个实现前向传播的函数 
`npu_fusion_attention_forward` + result0, result1, result2, result3, result4, result5, result6 = torch.ops.npu_graph.npu_fa( + query, key, value, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, scale=scale, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens, inner_precise=inner_precise, prefix=prefix, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync + ) + # 保存中间结果,以便在反向传播中使用 + ctx.save_for_backward(query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6) + ctx.head_num = head_num + ctx.input_layout = input_layout + ctx.scale = scale + ctx.keep_prob = keep_prob + ctx.pre_tockens = pre_tockens + ctx.next_tockens = next_tockens + ctx.inner_precise = inner_precise + ctx.prefix = prefix + ctx.actual_seq_qlen = actual_seq_qlen + ctx.actual_seq_kvlen = actual_seq_kvlen + ctx.sparse_mode = sparse_mode + ctx.gen_mask_parallel = gen_mask_parallel + ctx.sync = sync + + return result0, result1, result2, result3, result4, result5, result6 + + @staticmethod + def backward(ctx, grad_result0, grad_result1, grad_result2, grad_result3, grad_result4, grad_result5, grad_result6): + # 获取保存的中间结果 + query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6 = ctx.saved_tensors + # 反向传播逻辑 + # 这里假设有一个实现反向传播的函数 `npu_fusion_attention_backward` + grad_query, grad_key, grad_value, grad_pse = torch.ops.npu_graph.npu_fa_backward( + query, key, value, grad_result0, ctx.head_num, ctx.input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, softmax_max=result1, softmax_sum=result2, softmax_in=result3, attention_in=result0, scale_value=ctx.scale, keep_prob=ctx.keep_prob, pre_tockens=ctx.pre_tockens, next_tockens=ctx.next_tockens, inner_precise=ctx.inner_precise, seed=result4, offset=result5, numels=result6, prefix=ctx.prefix, actual_seq_qlen=ctx.actual_seq_qlen, actual_seq_kvlen=ctx.actual_seq_kvlen, sparse_mode=ctx.sparse_mode, gen_mask_parallel=ctx.gen_mask_parallel, sync=ctx.sync + ) + return (grad_query, grad_key, grad_value, None, None, grad_pse, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) + + +def npu_fusion_attention_graph(query, key, value, head_num, input_layout, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, + inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + return NpuGraphAttentionFunction.apply(query, key, value, head_num, input_layout, pse, padding_mask, + atten_mask, scale, keep_prob, pre_tockens, next_tockens, + inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode, gen_mask_parallel, sync) +torch_npu.npu_fusion_attention_graph = npu_fusion_attention_graph + + +def register_fx_pass(): + TOKEN_MAX = 2147483647 + from torch._inductor.pattern_matcher import register_replacement, fwd_only, joint_fwd_bwd + from torch._inductor.fx_passes.joint_graph import patterns + from torch._dynamo.utils import counters + from torch._inductor.fx_passes.fuse_attention import partialize_and_update_signature + + def _npu_fusion_attention_graph_pattern_1(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return 
torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(v) + + + def _npu_fusion_attention_graph_replacement_1(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + head_num = query.size(2) + input_layout = "BNSD" + return torch_npu.npu_fusion_attention_graph( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + head_num, + input_layout, + None, + atten_mask=None, + scale=inv_scale_factor, + keep_prob=1.0 - dropout_p, + )[0] + + def _get_sfdp_patterns(): + device = 'npu' + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + c_inp = functools.partial(torch.tensor, 2.0, device=device) + d = {"dropout_p": 0.113377} + candidates = [] + for dtype in [torch.float]: + g = functools.partial(g_inp, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + candidates.append(( + _npu_fusion_attention_graph_pattern_1, + _npu_fusion_attention_graph_replacement_1, + [g(), g(), g(), c()], + d, + )) + + for pattern, replacement, args, workaround in candidates: + # when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + if not isinstance(workaround, dict): + raise TypeError("workaround is not dict") + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield training_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + if workaround: + if len(workaround) != 1 or "dropout_p" not in workaround: + raise ValueError("len(workaround) must be 1 and dropout_p must in workaround") + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + for _, register_replacement_kwargs in _get_sfdp_patterns(): + register_replacement( + **register_replacement_kwargs, + ) + +register_fx_pass() + + + diff --git a/torch_npu/_inductor/npu_triton_helpers.py b/torch_npu/_inductor/npu_triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..a18217298aaca73fb0e3ccb11617dff49838c6e5 --- /dev/null +++ b/torch_npu/_inductor/npu_triton_helpers.py @@ -0,0 +1,21 @@ +import triton +import triton.language as tl +import triton.language.extra.ascend.libdevice as libdevice + +from torch._inductor import triton_helpers + +libdevice = tl.extra.ascend.libdevice +math = tl.math + + +@triton.jit +def maximum(a, b): + return tl.maximum(a, b) + + +@triton.jit +def minimum(a, b): + return tl.minimum(a, b) + +triton_helpers.maximum = maximum +triton_helpers.minimum = minimum diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfcdb5cddcd81762569829648e2a5065d30dd53 --- /dev/null +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -0,0 +1,834 @@ +# This file is based on 
triton_heuristics with heuristics designed for NPU +import os +import copy +import hashlib +import functools +import logging +import re +from typing import Any, Callable, List, Optional +import getpass +import tempfile +import shutil + +from torch._inductor.triton_heuristics import ( + CachingAutotuner, + HeuristicType, + unique_configs, + hash_configs, + load_cached_autotuning, + Config, + ASTSource, + _find_names, + get_first_attr, + collected_calls, + json, +) + +from torch._inductor.utils import ( + create_bandwidth_info_str, + get_num_bytes, +) + +from torch._inductor import config +from torch._dynamo.utils import dynamo_timed +import triton +import torch +import torch_npu + +from .codegen.tile_generator import TileGenerator +from .codegen.triton_utils import get_aligned_numel + +from .codegen.split_tiling import SplitTiling + +from .config import log +from .config import aggresive_autotune + +try: + from triton.backends.compiler import GPUTarget + from triton.runtime.autotuner import OutOfResources + import torch.autograd.profiler as autograd_profiler +except ImportError: + GPUTarget = None + OutOfResources = None + autograd_profiler = None + +from .utils import get_current_raw_stream + + +class NPUCachingAutotuner(CachingAutotuner): + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + ): + super().__init__(fn, triton_meta, configs, save_cache_hook, mutated_arg_names, heuristic_type, size_hints, inductor_meta, custom_kernel) + self.gpu_device.get_raw_stream = get_current_raw_stream + self.exceptions = [] + + # don't print exceptions when UB exception thrown by underlying compiler + def precompile(self, warm_cache_only_with_cc=None): + # xpu_graph changed TORCHINDUCTOR_CACHE_DIR. + # When TORCHINDUCTOR_COMPILE_THREADS > 1, multiprocessing's fork method + # does not propagate TORCHINDUCTOR_CACHE_DIR into the child threads. + # However, after all the child threads finished, the main thread reaches + # here and inherits xpu_graph's TORCHINDUCTOR_CACHE_DIR. Then the main + # thread finds the cache dir does not have any compiled kernel. It will + # compile all kernels one by one. + # So we directly replace TORCHINDUCTOR_CACHE_DIR with the standard cache dir. 
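For context, here is a minimal standalone sketch of the fallback cache location that the workaround below switches to. The helper names are hypothetical, and the `triton/0` sub-directory layout simply mirrors this patch; it is an assumption, not documented Inductor behaviour.

```python
import getpass
import os
import re
import tempfile


def default_inductor_cache_dir() -> str:
    # getpass.getuser() can contain characters that are not valid in a
    # directory name, so they are replaced before building the path.
    sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
    return os.path.join(tempfile.gettempdir(), "torchinductor_" + sanitized_username)


def reset_inductor_cache_env() -> None:
    # Point both Inductor and Triton at the standard per-user cache location.
    cache_dir = default_inductor_cache_dir()
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
    os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_dir, "triton", "0")
```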
+ if ("xpu_graph" in os.getenv("TORCHINDUCTOR_CACHE_DIR", "")): + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + cache_dir = os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir + os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_dir, "triton", "0") + with self.lock: + if self.launchers: + return + self.launchers = [] + compiled_binaries = [] + save_configs = [] + latest_config = None + for c in self.configs: + try: + latest_config = c + compiled_binary, launcher = self._precompile_config( + c, warm_cache_only_with_cc + ) + if (compiled_binary is None): + continue + except Exception as e: + log.debug(f"[thread {os.getpid()}][InductorNPU.precompile] Exception = {e}, kernel = {self.fn.__name__} config = {c}") + # Skip the config if the compilation fails + continue + self.launchers.append(launcher) + compiled_binaries.append(compiled_binary) + # remove compile failure tiling case + self.configs = save_configs + if len(self.launchers) == 0: + kernel_name = self.inductor_meta.get("kernel_name", "triton_") + log.exception( + "Triton compilation failed: %s with metadata: %s", + kernel_name, + latest_config.kwargs, + ) + error_messages = [] + for e in self.exceptions: + error_messages.append(e.message) + line_delim = "\n" + log.exception("Compile %s report %d times exceptions:\n%s", + kernel_name, + len(self.exceptions), + line_delim.join(error_messages), + ) + raise RuntimeError( + "No valid triton configs. Report a fatal compilation error" + ) + self.configs = None + + # to add the line options["mix_mode"] = "aiv" + # to filter out some options on cfg used for gird, but not for constants + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + for k, v in cfg.kwargs.items(): + if k not in self.fn.arg_names: + continue + index = self.fn.arg_names.index(k) + compile_meta["constants"][self.fn.arg_names[index]] = v + # for higher version triton + kwargs_list = [k for k, v in cfg.kwargs.items()] + for i, arg in enumerate(self.fn.arg_names): + if arg in kwargs_list: + continue + if os.environ.get('INDUCTOR_STATIC_MODE') and arg.endswith('_numel'): + continue + name = self.fn.arg_names[i] + value = compile_meta["signature"][i] + del compile_meta["signature"][i] + compile_meta["signature"][name] = value + + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + compile_meta["debug"] = ( + os.getenv("INDUCTOR_ASCEND_DEBUG", 'false').lower() in ('true', '1') and + config.assert_indirect_indexing and torch.version.hip is None + ) + + # Setting device_type="hip" required on ROCm to pass down to triton + compile_meta["device_type"] = ( + self.device_type if torch.version.hip is None else "hip" + ) + if warm_cache_only_with_cc: + cc = warm_cache_only_with_cc + else: + # Use device_type 'cuda' for both cuda and hip devices to retrieve + # the compute capability. 
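As an aside, the loop near the top of `_precompile_config` rewrites the index-keyed `signature` dict into the name-keyed form newer Triton front ends expect, skipping tiling constants (and `*_numel` arguments in static mode). A small sketch of that rewrite with hypothetical argument names and dummy type strings:

```python
# Hypothetical kernel arguments, an index-keyed signature, and the tiling
# constants taken from the chosen Config; none of these names are real.
arg_names = ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK", "XBLOCK_SUB"]
signature = {0: "*fp32", 1: "*fp32", 2: "i32", 3: "constexpr", 4: "constexpr"}
cfg_kwargs = {"XBLOCK": 1024, "XBLOCK_SUB": 64}  # become compile-time constants

for i, arg in enumerate(arg_names):
    if arg in cfg_kwargs:  # tiling constants stay out of the runtime signature
        continue
    signature[arg] = signature.pop(i)

# Remaining runtime arguments are now keyed by name:
# {3: 'constexpr', 4: 'constexpr', 'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}
```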
+ device_type = self.device_type if torch.version.hip is None else "cuda" + device_id = compile_meta["device"] + device = torch.device(device_type, device_id) + cc = self.gpu_device.get_compute_capability(device) + + compile_meta["cc"] = cc + + if ASTSource: + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + None, + ), + ) + + if GPUTarget: + target = GPUTarget(compile_meta["device_type"], cc, 0) + else: + target = (compile_meta["device_type"], cc) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + # Note: currently force to generate vector kernels only + if self.device_type == "npu": + options["mix_mode"] = "aiv" + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn,) + compile_kwargs = compile_meta + + if warm_cache_only_with_cc: + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + # compile failed don't need raise error for npu + return None, None + binary._init_handles() + return binary, None + + # load binary to the correct device + with self.gpu_device.device(compile_meta["device"]): # type: ignore[attr-defined] + # need to initialize context + self.gpu_device.synchronize(self.gpu_device.current_device()) + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception as e: + self.exceptions.append(e) + # compile failed don't need raise error for npu + return None, None + binary._init_handles() + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": binary.launch_enter_hook, + "launch_exit_hook": binary.launch_exit_hook, + "metadata": binary.metadata, + "torch": torch, + "set_device": self.gpu_device.set_device, + "current_device": self.gpu_device.current_device, + } + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + scope["function"] = get_first_attr(binary, "function", "cu_function") + scope["cta_args"] = ( + (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ) + scope["num_warps"] = ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ) + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + scope["shared"] = binary_shared + + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + bin[grid_0, grid_1, grid_2]( + {', '.join(call_args)}, + stream=stream) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = config.triton.store_cubin + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = self.fn + launcher.bin = binary + + return binary, launcher + + def bench(self, launcher, *args, grid_cur, **kwargs): + if not self.custom_kernel and launcher.n_spills > 
config.triton.spill_threshold: + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + + def kernel_call(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} + ) + + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid_cur, + stream=stream, + ) + from torch._inductor.utils import do_bench + # remove fast_flush=True for high level triton + try: + ret = do_bench(kernel_call, rep=40) + print(f"do bench ret = {ret}", flush=True) + except Exception as e: + print(f"do bench error on launcher.config : {launcher.config}", flush=True) + print(f"[ERROR MESSAGE] : {e}", flush=True) + ret = float("inf") + return ret + + +class NPUDebugAutotuner(NPUCachingAutotuner): + def __init__(self, *args, regex_filter="", **kwargs): + self.regex_filter = regex_filter + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, grid_cur, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=grid_cur, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench(launcher, *args, grid=grid_cur) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = (ms, num_gb, gb_per_s, kernel_name) + else: + ms, num_gb, gb_per_s, kernel_name = self.cached + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}") + ) + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
+ """ + configs = unique_configs(configs) + if len(configs) != 1 and filename is None: + raise ValueError("Either len(configs) = 1 or filename is provided") + save_cache_hook: Optional[Callable[[Any, Any], Any]] + inductor_meta = {} if inductor_meta is None else inductor_meta + + # on disk caching logic and/or remote caching + if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning): + configs_hash = hash_configs(configs) + + cache_filename = None + remote_cache = None + remote_cache_key = None + if config.use_autotune_local_cache: + cache_filename = os.path.splitext(filename)[0] + ".best_config" + if config.use_autotune_remote_cache or ( + config.is_fbcode() + and torch._utils_internal.justknobs_check( + "pytorch/autotune_remote_cache:enable" + ) + ): + backend_hash = inductor_meta.get("backend_hash", None) + if backend_hash is not None: + key = backend_hash + configs_hash + "autotune-best-config" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + try: + if config.is_fbcode(): + remote_cache = ( + triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend( + key, is_autotune=True + ) + ) + else: + remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key) + except Exception: + remote_cache = None + log.error("Unable to create a remote cache", exc_info=True) + # we already sha256 hash the source contents + remote_cache_key = os.path.basename(filename) + else: + log.warning( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + + best_config = None + if cache_filename is not None and os.path.exists(cache_filename): + with open(cache_filename) as fd: + best_config = json.loads(fd.read()) + elif remote_cache is not None and remote_cache_key is not None: + cache_outs = remote_cache.get([remote_cache_key]) + cache_out = cache_outs.get(remote_cache_key, None) + best_config = json.loads(cache_out) if cache_out else None + + best_config = load_cached_autotuning(best_config, configs_hash, configs) + if best_config: + configs = [best_config] + + def save_cache_hook(cfg, found_by_coordesc=False): + data = json.dumps( + { + **cfg.kwargs, + "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages, + "configs_hash": configs_hash, + "found_by_coordesc": found_by_coordesc, + } + ) + if cache_filename is not None: + with open(cache_filename, "w") as fd: + fd.write(data) + if remote_cache is not None and remote_cache_key is not None: + remote_cache.put(remote_cache_key, data) + else: + save_cache_hook = None + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. 
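A minimal sketch of the check that follows, using a plain Python function as a stand-in for a Triton JIT kernel whose signature has no `XBLOCK` parameter (all names below are made up for illustration):

```python
import inspect


def kernel_no_x_dim(in_ptr0, out_ptr0, rnumel, NBLOCKS, RBLOCK):
    # Stand-in for a generated kernel compiled with no_x_dim=True.
    pass


config_kwargs = [{"NBLOCKS": 8, "XBLOCK": 1, "RBLOCK": 64}]
if "XBLOCK" not in inspect.signature(kernel_no_x_dim).parameters:
    for kwargs in config_kwargs:
        if "XBLOCK" in kwargs:
            # Only the hard-coded value 1 is expected; anything else means the
            # config and the generated kernel signature are out of sync.
            if kwargs["XBLOCK"] != 1:
                raise ValueError('kwargs["XBLOCK"] != 1')
            kwargs.pop("XBLOCK")
```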
+ import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + if tconfig.kwargs["XBLOCK"] != 1: + raise ValueError('tconfig.kwargs["XBLOCK"] != 1') + tconfig.kwargs.pop("XBLOCK") + + if config.profile_bandwidth: + return NPUDebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=config.profile_bandwidth_regex, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + ) + return NPUCachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + ) + + return decorator + + +###################################################### +## Main entry points for triton kernel invocation ## +## adapts original heuristics for NPU arch, and ## +## redirect to NPUCaching autotuner ## +###################################################### + +def grid(*numels): + def grid_fn(meta): + split_axis_order = meta["split_axis_order"] + + if split_axis_order is not None and split_axis_order < len(numels): + numel = numels[split_axis_order] if split_axis_order is not None else 1 + xblock = meta["XBLOCK"] + NBLOCKS, _ = SplitTiling.get_nblocks_before_launch(numel, xblock) + else: + NBLOCKS = 1 + + log.debug("launch grid(%s), NBLOCKS:%d, meta:%s", numels, NBLOCKS, meta) + return ( + NBLOCKS, + 1, + 1, + ) + + return grid_fn + + +# split:sizeof split, xblock:axis1 length, rblock:axis2 length +def triton_config_npu_index( + size_hints, + inductor_meta, + triton_meta=None, + reduction=False, + persistent_reduction=False, + +) -> List[Config]: + num_warps = 1 + num_stages = 1 + configs = [] + log.info("[InductorNPU] processing kernel %s", inductor_meta['kernel_name']) + split_axis_order = inductor_meta["split_axis_order"] + axis1_order = inductor_meta["axis1_order"] + axis2_order = inductor_meta["axis2_order"] + low_dims = inductor_meta["low_dims"] + split_axis_dtype = inductor_meta["split_axis_dtype"] + split_numel = size_hints[split_axis_order] if split_axis_order is not None else 1 + is_low_dim = True if split_axis_order is not None and split_axis_order in low_dims else False + + min_aligned_numel = get_aligned_numel(split_axis_dtype) + grid_list = [] + if (aggresive_autotune): + grid_list = SplitTiling.get_nblocks_xblock_list(split_numel) + else: + nblocks, split = SplitTiling.decide_nblocks_xblock(split_numel, axis2_order is None, min_aligned_numel) + grid_list.append((nblocks, split)) + + for nblocks, split in grid_list: + log.debug("generating tiling : size_hints:%s split_axis_order:%s, axis1_order:%s, axis2_order:%s, " + "low_dims:%s nblocks %s, split:%s persistent_reduction:%s split_axis_dtype:%s", size_hints, + split_axis_order, axis1_order, axis2_order, low_dims, nblocks, split, + persistent_reduction, split_axis_dtype) + # xblock is a range, don't auto_tune + xnumel = split if split_axis_order == axis1_order else size_hints[axis1_order] + rblock = 1 + if axis2_order is not None: + rblock = split if split_axis_order == axis2_order else size_hints[axis2_order] + + xblock_sub = xnumel + cfg = {"NBLOCKS": nblocks, "XBLOCK": split, "XBLOCK_SUB": xblock_sub} + # forward to grid() + cfg["split_axis_order"] = split_axis_order + cfg["axis2_order"] = axis2_order if not(axis2_order is 
None) else -1 + cfg["is_low_dim"] = is_low_dim + cfg["min_aligned_numel"] = min_aligned_numel + is_1d_reduction = reduction and axis2_order is None + if persistent_reduction: + numof_reduction_axis = inductor_meta["numof_reduction_axis"] + if numof_reduction_axis > 1: + del cfg["XBLOCK_SUB"] + configs.append(Config(cfg, num_warps=1, num_stages=1)) + elif axis2_order is None: + del cfg["XBLOCK"] + del cfg["XBLOCK_SUB"] + cfg["NBLOCKS"] = 1 + configs.append(Config(cfg, num_warps=1, num_stages=1)) + else: + TileGenerator.descend_xblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif is_1d_reduction: + cfg["NBLOCKS"] = 1 + cfg["XBLOCK"] = split_numel + cfg["XBLOCK_SUB"] = split_numel + TileGenerator.descend_xblock(rnumel=rblock, xblock=split_numel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + # both of the two axis are low dims + elif axis1_order in low_dims and axis2_order in low_dims: + cfg["RBLOCK"] = rblock + TileGenerator.descend_xblock_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif axis2_order is None and axis1_order is not None: + TileGenerator.descend_xblock(rnumel=0, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + # need to maximize xblock_sub + elif axis1_order in low_dims: + cfg["RBLOCK"] = rblock + TileGenerator.descend_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif axis2_order in low_dims: + cfg["RBLOCK"] = rblock + TileGenerator.descend_xblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel) + elif len(low_dims) == 0: + cfg["RBLOCK"] = rblock + if (axis1_order is not None) and (axis2_order is not None): + TileGenerator.descend_xblock_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel, aggresive=False) + elif axis1_order is not None: + TileGenerator.descend_xblock(rnumel=0, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel, aggresive=False) + else: + TileGenerator.descend_rblock(rnumel=rblock, xblock=xnumel, configs=configs, cfg=cfg, align_numel=min_aligned_numel, aggresive=False) + else: + cfg["RBLOCK"] = rblock + tmp = Config(cfg, num_warps=num_warps, num_stages=num_stages) + configs.append(tmp) + + for cfg in configs: + log.debug("generated tiling configs %s", cfg.kwargs) + + return configs + + +def pointwise_npu_index( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + + inductor_meta = {} if inductor_meta is None else inductor_meta + triton_config_with_settings = functools.partial( + triton_config_npu_index + ) + return cached_autotune( + size_hints, + triton_config_with_settings(size_hints, inductor_meta=inductor_meta), + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if triton_meta is None: + raise ValueError('triton_meta is None') + contiguous_config = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True) + return cached_autotune( + size_hints, + [ + *contiguous_config, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + 
filename=filename, + heuristic_type=HeuristicType.REDUCTION, + ) + + +def persistent_reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + configs = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True, persistent_reduction=True) + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dynamo_timed +def benchmark_all_configs(self, *args, grid_cur, **kwargs): + print(f"candidate launcher count = {len(self.launchers)}") + + tilling_kernel_list = [] + + def kernel_call(launcher): + def call_kernel(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} + ) + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid_cur, + stream=stream, + ) + return call_kernel + + for launcher in self.launchers: + if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold: + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + tilling_kernel_list.append(kernel_call(launcher)) + + def do_batch_benchmark(tilling_kernel_list): + + def delete_file(base_path): + if os.path.exists(base_path): + shutil.rmtree(base_path) + + from datetime import datetime + + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + + md5_hash = hashlib.md5() + md5_hash = hashlib.md5(datetime.now(tz=timezone.utc).strftime('%Y-%m-%d').encode('utf-8')).hexdigest() + + torch_path = os.path.join("./profile_result/", md5_hash) + rep = 1 + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=rep, repeat=1, skip_first=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(rep + 3): + for fn in tilling_kernel_list: + fn() + prof.step() + stream.synchronize() + + import pandas as pd + for root, _, files in os.walk(torch_path): + for file in files: + if file != 'kernel_details.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['Name'].str.startswith('triton', na=False)] + ret = triton_rows['Duration(us)'].astype(float).tolist() + delete_file(torch_path) + return ret + + delete_file(torch_path) + return [] + + try: + timinglist 
= do_batch_benchmark(tilling_kernel_list)
+        if len(timinglist) != len(self.launchers):
+            raise ValueError('len(timinglist) != len(self.launchers)')
+        timings = {launcher: timing for launcher, timing in zip(self.launchers, timinglist)}
+    except Exception as e:
+        print("Some cases in the batch benchmark raised errors! Logging exception as:")
+        print(e)
+        print("Switching to single-config benchmarking...")
+        timings = {
+            launcher: self.bench(launcher, *args, grid_cur=grid_cur, **kwargs)
+            for launcher in self.launchers
+        }
+
+    for k, v in timings.items():
+        self.coordesc_tuner.cache_benchmark_result(k.config, v)
+
+    if log.isEnabledFor(logging.DEBUG):
+        log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
+        for k, v in timings.items():
+            log.debug(
+                "%s: %f, nreg %d, nspill %d, #shared-mem %s",
+                k.config,
+                v,
+                k.n_regs,
+                k.n_spills,
+                k.shared,
+            )
+    print(f"final valid tilings count = {len(timings)}")
+    return timings
\ No newline at end of file
diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..697059f7a885b8b2ec3912954728b6a3a84a234c
--- /dev/null
+++ b/torch_npu/_inductor/utils.py
@@ -0,0 +1,7 @@
+import torch
+import torch_npu
+
+
+# Not an ideal implementation, but there is no other way to obtain the raw stream
+def get_current_raw_stream(device):
+    return torch.npu.current_stream(device).npu_stream
\ No newline at end of file
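Taken together, importing this package is intended to route `torch.compile`'s Inductor codegen through the NPU Triton backend and the heuristics above. A hedged end-to-end sketch follows; whether `import torch_npu` alone pulls in `torch_npu._inductor`, and the availability of an `npu` device in the environment, are assumptions here.

```python
import torch
import torch_npu  # noqa: F401
import torch_npu._inductor  # noqa: F401  (registers the NPU Inductor backend)


def fused_op(x, y):
    # A pointwise + reduction combination; the pointwise/reduction heuristics
    # above pick the NBLOCKS/XBLOCK/XBLOCK_SUB tiling for the generated kernels.
    return (x + y).relu().sum(dim=-1)


compiled = torch.compile(fused_op)
x = torch.randn(64, 128, device="npu")
y = torch.randn(64, 128, device="npu")
out = compiled(x, y)  # first call triggers Triton-NPU compilation and autotuning
```

With `aggresive_autotune` enabled, that first call benchmarks all generated tiling configs through the patched `benchmark_all_configs` above instead of timing a single config.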