diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 71f5232cba21aa0488e46192db024fe01c7e098b..b896438aa1635e22f27a948399dc928299e97c42 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1,6 +1,8 @@ # Owner(s): ["module: dynamo"] import functools -import unittest +import math +import sys +import unittest # noqa: F811 from importlib import import_module import torch @@ -14,32 +16,54 @@ from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import CompileCounterWithBackend from torch._higher_order_ops.wrap import tag_activation_checkpoint -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm +from torch.testing._internal.two_tensor import TwoTensor from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") +def checkpoint_wrapper(fn): + def inner(*args): + return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) + + return inner + + def count_ops( gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None ): - assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops) - if op: + def match_rng_op(node, op): + if isinstance(node.target, torch._ops.HigherOrderOperator): + if node.name == "run_and_save_rng_state": + return node.args[0] == op + elif node.name == "run_with_rng_state": + return node.args[1] == op + return False + + # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops) + if op is not None: ops = [op] - if freq: + if freq is not None: freqs = [freq] - if freq_ge: + if freq_ge is not None: freqs_ge = [freq_ge] if freqs: for op, freq in zip(ops, freqs): - actual_count = [node.target for node in gm.graph.nodes].count(op) + actual_count = 0 + for node in gm.graph.nodes: + if match_rng_op(node, op) or node.target == op: + actual_count += 1 assert ( actual_count == freq ), f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}." else: assert freqs_ge is not None for op, freq_ge in zip(ops, freqs_ge): - actual_count = [node.target for node in gm.graph.nodes].count(op) + actual_count = 0 + for node in gm.graph.nodes: + if match_rng_op(node, op) or node.target == op: + actual_count += 1 assert ( actual_count >= freq_ge ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." 
@@ -89,11 +113,11 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): for arg in args: cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad)) - torch.manual_seed(0) + torch_npu.npu.manual_seed(0) expected = fn(*args) expected.sum().backward() - torch.manual_seed(0) + torch_npu.npu.manual_seed(0) result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args) result.sum().backward() @@ -110,13 +134,62 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): msg="Gradient mismatch between torch.compile and eager versions", ) + def _compare_orig_and_checkpointed_fns( + self, orig_fn, checkpointed_fn, *args, fullgraph=True + ): + # The original version and the checkpointed version of the same function + # should produce the same outputs and the same gradients under torch.compile. + + # Run original version + cloned_args_orig_fn = [] + for arg in args: + cloned_args_orig_fn.append( + arg.clone().detach().requires_grad_(arg.requires_grad) + ) + torch_npu.npu.manual_seed(0) + compiled_orig_fn = torch.compile( + orig_fn, fullgraph=fullgraph, backend="npu" + ) + result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn) + result_orig_fn.sum().backward() + + # Run checkpointed version + cloned_args_checkpointed_fn = [] + for arg in args: + cloned_args_checkpointed_fn.append( + arg.clone().detach().requires_grad_(arg.requires_grad) + ) + torch_npu.npu.manual_seed(0) + compiled_checkpointed_fn = torch.compile( + checkpointed_fn, fullgraph=fullgraph, backend="npu" + ) + result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn) + result_checkpointed_fn.sum().backward() + + # Check that outputs and gradients are equal + self.assertEqual( + result_orig_fn, + result_checkpointed_fn, + msg="Output mismatch between the original version and the checkpointed version of the same function", + ) + for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip( + cloned_args_orig_fn, cloned_args_checkpointed_fn + ): + self.assertEqual( + cloned_arg_orig_fn.grad, + cloned_arg_checkpointed_fn.grad, + msg="Gradient mismatch between the original version and the checkpointed version of the same function", + ) + @requires_npu() def test_tags_function(self): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): - return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y) + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True + ) x = torch.randn(4, 4, device="npu:0", requires_grad=True) y = torch.randn(4, 4, device="npu:0", requires_grad=True) @@ -135,7 +208,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): # This goes through VariableBuilder - return checkpoint(gn, torch.sin(x), y) + return checkpoint(gn, torch.sin(x), y, use_reentrant=True) x = torch.randn(4, 4, device="npu:0", requires_grad=True) y = torch.randn(4, 4, device="npu:0", requires_grad=True) @@ -174,9 +247,9 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): x = torch.sin(x) - z = torch.utils.checkpoint.checkpoint(gn, x, y) + z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) x = torch.sin(z) - z = torch.utils.checkpoint.checkpoint(gn, x, y) + z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) return z x = torch.randn(4, 4, device="npu:0", requires_grad=True) @@ -202,7 +275,9 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): mod = MockModule().npu() def fn(x): 
- return torch.utils.checkpoint.checkpoint(mod, torch.sin(x)) + return torch.utils.checkpoint.checkpoint( + mod, torch.sin(x), use_reentrant=True + ) x = torch.randn(10, 10, device="npu:0", requires_grad=True) @@ -229,7 +304,9 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): mod = MockModule().npu() def fn(x): - return torch.utils.checkpoint.checkpoint(mod, torch.sin(x)) + return torch.utils.checkpoint.checkpoint( + mod, torch.sin(x), use_reentrant=True + ) x = torch.randn(10, 10, device="npu:0", requires_grad=True) @@ -256,9 +333,9 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): x = torch.sin(x) - x = torch.utils.checkpoint.checkpoint(gn, x, y) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) x = torch.sin(x) - z = torch.utils.checkpoint.checkpoint(gn, x, y) + z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) return z x = torch.randn(4, 4, device="npu:0", requires_grad=True) @@ -269,7 +346,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): # count_ops, freq=6, op=torch.ops.aten.mm.default # ) # mm recomputed in the bwd # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) - backend = "inductor" + backend = "npu" self._validate(fn, backend, x, y) @requires_npu() @@ -282,9 +359,9 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): x = torch.sin(x) - x = torch.utils.checkpoint.checkpoint(gn, x, y) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) x = torch.sin(x) - # x = torch.utils.checkpoint.checkpoint(gn, x, y) + # x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) return x x = torch.randn(4, 4, device="npu:0", requires_grad=True) @@ -296,7 +373,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): # ) # mm recomputed in the bwd # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) # backend = "aot_eager" - backend = "inductor" + backend = "npu" self._validate(fn, backend, x, y) @requires_npu() @@ -315,10 +392,10 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): mod = MockModule().npu() def fn(x): - return torch.utils.checkpoint.checkpoint(mod, x) + return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True) x = torch.randn(10, 10, device="npu:0", requires_grad=True) - backend = "inductor" + backend = "npu" # rand decomps do not have have numerical results as eager self._validate(fn, backend, x, skip_check=True) @@ -345,7 +422,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): self.assertEqual(result, expected) - # One graph for torch.sin on the input, and other for torch.cos. + # One graph for torch.sin on the input, and another for torch.cos.
self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.op_count, 2) self.assertEqual(len(cnt.graphs), 2) @@ -390,7 +467,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): return torch.matmul(x, torch.nn.functional.dropout(y, 0.5)) def fn(x, y): - return torch.utils.checkpoint.checkpoint(gn, x, y) + return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) backend = "aot_eager" cnt = CompileCounterWithBackend(backend) @@ -414,7 +491,11 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint) self.assertEqual(len(wrap_node.args), 3) + @requires_npu() @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @torch._dynamo.config.patch( "_experimental_support_context_fn_in_torch_utils_checkpoint", True ) @@ -433,14 +514,14 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): return torch.utils.checkpoint.checkpoint( gn, - torch.sin(x), + x, y, use_reentrant=False, context_fn=selective_checkpointing_context_fn, ) - x = torch.randn(4, 4, requires_grad=True) - y = torch.randn(4, 4, requires_grad=True) + x = torch.randn(4, 4, requires_grad=True, device="npu:0") + y = torch.randn(4, 4, requires_grad=True, device="npu:0") fw_compiler = functools.partial( count_ops, @@ -461,8 +542,69 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + @requires_npu() @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @torch._dynamo.config.patch( + "_experimental_support_context_fn_in_torch_utils_checkpoint", True + ) + def test_compile_selective_checkpoint_tensor_subclass(self): + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.mm.default, + ] + return _pt2_selective_checkpoint_context_fn_gen( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x, y): + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + + def fn(x, y): + return torch.utils.checkpoint.checkpoint( + gn, + x, + y, + use_reentrant=False, + context_fn=selective_checkpointing_context_fn, + ) + + rand_tensor = torch.randn(4, 4, requires_grad=True, device="npu:0") + + # tensor subclasses as inputs + x = TwoTensor(rand_tensor, rand_tensor.clone()) + y = TwoTensor(rand_tensor.clone(), rand_tensor.clone()) + + fw_compiler = functools.partial( + count_ops, + freq=4, + op=torch.ops.aten.mm.default, + ) + bw_compiler = functools.partial( + count_ops, + # We would've expected 12 here + # (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12) + # if we didn't enable selective checkpointing. 
+ freq=8, + op=torch.ops.aten.mm.default, + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=min_cut_rematerialization_partition, + ) + self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_npu() + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @torch._dynamo.config.patch( "_experimental_support_context_fn_in_torch_utils_checkpoint", True ) @@ -498,14 +640,14 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): return torch.utils.checkpoint.checkpoint( gn, - torch.sin(x), + x, y, use_reentrant=False, context_fn=selective_checkpointing_context_fn, ) - x = torch.randn(4, 4, requires_grad=True) - y = torch.randn(4, 4, requires_grad=True) + x = torch.randn(4, 4, requires_grad=True, device="npu:0") + y = torch.randn(4, 4, requires_grad=True, device="npu:0") fw_compiler = functools.partial( count_ops, @@ -527,45 +669,51 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + @requires_npu() @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @torch._dynamo.config.patch( "_experimental_support_context_fn_in_torch_utils_checkpoint", True ) - def test_compile_selective_checkpoint_outplace_op(self): - def selective_checkpointing_context_fn(): - no_recompute_list = [ - torch.ops.aten.mm.default, - torch.ops.aten.sigmoid.default, - ] + def test_compile_selective_checkpoint_partial_ctx_fn(self): + def selective_checkpointing_context_fn(no_recompute_list): return _pt2_selective_checkpoint_context_fn_gen( - _get_custom_policy(no_recompute_list=no_recompute_list), + _get_custom_policy(no_recompute_list=no_recompute_list) ) def gn(x, y): - return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu() + return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y def fn(x, y): return torch.utils.checkpoint.checkpoint( gn, - torch.sin(x), + x, y, use_reentrant=False, - context_fn=selective_checkpointing_context_fn, + context_fn=functools.partial( + selective_checkpointing_context_fn, [torch.ops.aten.mm.default] + ), ) - x = torch.randn(4, 4, requires_grad=True) - y = torch.randn(4, 4, requires_grad=True) + x = torch.randn(4, 4, requires_grad=True, device="npu:0") + y = torch.randn(4, 4, requires_grad=True, device="npu:0") fw_compiler = functools.partial( count_ops, - freqs=[2, 1], - ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], + freq=2, + op=torch.ops.aten.mm.default, ) bw_compiler = functools.partial( count_ops, - freqs=[4, 0], - ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], + # We would've expected 6 here + # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6) + # if we didn't enable selective checkpointing. 
+ freq=4, + op=torch.ops.aten.mm.default, ) backend = aot_autograd( fw_compiler=fw_compiler, @@ -573,41 +721,40 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + @requires_npu() @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @unittest.skip( - "In-place op support in selective checkpointing + torch.compile " - "requires TorchDispatchMode + torch.compile work to complete" + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" ) @torch._dynamo.config.patch( "_experimental_support_context_fn_in_torch_utils_checkpoint", True ) - def test_compile_selective_checkpoint_inplace_op(self): + def test_compile_selective_checkpoint_outplace_op(self): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] return _pt2_selective_checkpoint_context_fn_gen( - _get_custom_policy(no_recompute_list=no_recompute_list) + _get_custom_policy(no_recompute_list=no_recompute_list), ) def gn(x, y): - return torch.sigmoid( - torch.selu_(torch.matmul(torch.matmul(x, y), y)) - ).relu_() + return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu() def fn(x, y): return torch.utils.checkpoint.checkpoint( gn, - torch.sin(x), + x, y, use_reentrant=False, context_fn=selective_checkpointing_context_fn, ) - x = torch.randn(4, 4, requires_grad=True) - y = torch.randn(4, 4, requires_grad=True) + x = torch.randn(4, 4, requires_grad=True, device="npu:0") + y = torch.randn(4, 4, requires_grad=True, device="npu:0") fw_compiler = functools.partial( count_ops, @@ -625,12 +772,21 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + @requires_npu() @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skip( + "In-place op support in selective checkpointing + torch.compile " + "requires TorchDispatchMode + torch.compile work to complete" + ) @torch._dynamo.config.patch( "_experimental_support_context_fn_in_torch_utils_checkpoint", True ) - def test_compile_selective_checkpoint_random_op(self): + def test_compile_selective_checkpoint_inplace_op(self): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -642,24 +798,24 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def gn(x, y): return torch.sigmoid( - torch.matmul(torch.matmul(torch.bernoulli(torch.sigmoid(x)), y), y) - ) + torch.selu_(torch.matmul(torch.matmul(x, y), y)) + ).relu_() def fn(x, y): return torch.utils.checkpoint.checkpoint( gn, - torch.sin(x), + x, y, use_reentrant=False, context_fn=selective_checkpointing_context_fn, ) - x = torch.randn(4, 4, requires_grad=True) - y = torch.randn(4, 4, requires_grad=True) + x = torch.randn(4, 4, requires_grad=True, device="npu:0") + y = torch.randn(4, 4, requires_grad=True, device="npu:0") fw_compiler = functools.partial( count_ops, - freqs=[2, 2], + freqs=[2, 1], ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default], ) bw_compiler = functools.partial( @@ -673,8 +829,77 @@ class 
ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) + self._compare_orig_and_checkpointed_fns(gn, fn, x, y) + + @requires_npu() + @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @torch._dynamo.config.patch( + "_experimental_support_context_fn_in_torch_utils_checkpoint", True + ) + def test_compile_selective_checkpoint_random_op(self): + for preserve_rng_state in [True, False]: + + def selective_checkpointing_context_fn(): + no_recompute_list = [ + torch.ops.aten.sigmoid.default, + ] + return _pt2_selective_checkpoint_context_fn_gen( + _get_custom_policy(no_recompute_list=no_recompute_list) + ) + + def gn(x): + return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True)) + + def fn(x): + return torch.utils.checkpoint.checkpoint( + gn, + x, + use_reentrant=False, + # Regardless of whether `preserve_rng_state` is True or False, + # we will always preserve RNG state when using `torch.compile`. + preserve_rng_state=preserve_rng_state, + context_fn=selective_checkpointing_context_fn, + ) + + x = torch.randn(4, 4, requires_grad=True, device="npu:0") + + fw_compiler = functools.partial( + count_ops, + freqs=[2, 1], + ops=[ + torch.ops.aten.sigmoid.default, + torch.ops.aten.native_dropout.default, + ], + ) + bw_compiler = functools.partial( + count_ops, + # NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1). + freqs=[0, 1], + ops=[ + torch.ops.aten.sigmoid.default, + torch.ops.aten.native_dropout.default, + ], + ) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=min_cut_rematerialization_partition, + ) + + # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, + # because eager version doesn't preserve RNG state while torch.compile still does. + # Hence when `preserve_rng_state` is False, we skip the output and gradient comparison + # between torch.compile and eager. 
+ self._validate(fn, backend, x, skip_check=not preserve_rng_state) + self._compare_orig_and_checkpointed_fns(gn, fn, x) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @torch._dynamo.config.patch( "_experimental_support_context_fn_in_torch_utils_checkpoint", True ) @@ -685,7 +910,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): def fn(x, y): return torch.utils.checkpoint.checkpoint( gn, - torch.sin(x), + x, y, use_reentrant=False, context_fn=_invalid_context_gen, @@ -715,6 +940,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): self._validate(fn, backend, x, y) @requires_npu() + @skipIfRocm def test_autocast_flash_attention(self): def fn(primals_1, primals_2, primals_3): return torch.ops.aten._scaled_dot_product_efficient_attention.default( @@ -722,7 +948,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): )[0] def gn(*args): - return torch.utils.checkpoint.checkpoint(fn, *args) + return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) with torch.npu.amp.autocast(): x = torch.randn(4, 2, 16, 32, device="npu:0", requires_grad=True) @@ -730,11 +956,11 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): z = torch.randn(4, 2, 16, 32, device="npu:0", requires_grad=True) args = (x, y, z) - torch.manual_seed(0) + torch_npu.npu.manual_seed(0) ref = gn(*args) opt_gn = torch.compile(gn) - torch.manual_seed(0) + torch_npu.npu.manual_seed(0) res = opt_gn(*args) self.assertEqual(ref, res) @@ -753,13 +979,12 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): mod = MockModule().npu() def fn(x): - return torch.utils.checkpoint.checkpoint(mod, x) + return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True) x = torch.randn(4, 4).npu() opt_fn = torch.compile(fn, fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "while introspecting torch.utils.checkpoint.checkpoint, we were unable to trace function `NNModuleVariable`", + torch._dynamo.exc.Unsupported, "skip function graph_break in file" ): opt_fn(x) @@ -778,7 +1003,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): mod = MockModule().npu() def fn(x, ys): - return torch.utils.checkpoint.checkpoint(mod, x, ys) + return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True) x = torch.randn(4, 4).npu() y = torch.randn(4, 4).npu() @@ -788,6 +1013,72 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): res = opt_fn(x, [y, z]) self.assertEqual(ref, res) + @requires_npu() + def test_pattern_matcher(self): + # Check that the sdpa op is recomputed in the backward graph + # tests percolate_tags + + @checkpoint_wrapper + def dot_prod_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(1.0 / math.sqrt(key.shape[-1])) + .softmax(dim=-1) + .matmul(value) + ) + + def fn(query, key, value): + # Checks that sin is not recomputed in the backward graph + return dot_prod_attention(query.sin(), key, value) + + tensor_shape = (4, 2, 16, 32) + dtype = torch.float16 + args1 = [ + torch.randn(tensor_shape, device="npu:0", dtype=dtype, requires_grad=True), + torch.randn(tensor_shape, device="npu:0", dtype=dtype, requires_grad=True), + torch.randn(tensor_shape, device="npu:0", dtype=dtype, 
requires_grad=True), + ] + + # Save the AOT graphs + aot_graphs = [] + from torch._inductor import compile_fx + + def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): + aot_graphs.append(graph) + return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs) + + backend = functools.partial( + compile_fx.compile_fx, inner_compile=debug_compile_fx_inner + ) + + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + opt_fn(*args1).sum().backward() + + fwd_graph = aot_graphs[0] + self.assertTrue( + count_ops( + fwd_graph, + [], + freq=1, + op=torch.ops.aten._scaled_dot_product_flash_attention.default, + ) + ) + + bwd_graph = aot_graphs[1] + # Check that sin is not recomputed in the backward graph - checks percolate tags + self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default)) + # Check that the sdpa op is recomputed in the backward graph + self.assertTrue( + count_ops( + bwd_graph, + [], + freq=1, + op=torch.ops.aten._scaled_dot_product_flash_attention.default, + ) + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 2cdcafab1328c7a37242d73ed1a90678ff536cfb..de015ae1c5a4cb9c73af004e730066b544db8c43 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -26,7 +26,7 @@ def maybe_dupe_op(x): aten = torch.ops.aten -lib = torch.library.Library("custom", "DEF") +lib = torch.library.Library("custom", "DEF") # noqa: TOR901 lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)") lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU") lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta") @@ -404,7 +404,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F()) # Note: to prevent a recompilation between the two calls, # we need to clone x and y on each use. - # fxy mutates the input's metadata, so otherwise dynamo will end up recompiling. + # fxy mutates the input's metadata, so otherwise dynamo will end up recompiling.
fxy(x1, y1) fxy(x2, y2) @@ -672,6 +672,43 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): self.assertEqual(cc.frame_count, 2) self.assertIn("""L['c'] is L['d']""", failure_reason) + def test_alias_inputs(self): + def fn(): + a = torch.tensor([1]) + a = a[0:1] + b = a.squeeze() + a[0] = 0 + if a[0] < 1e5: + pass + a[0] = 2 + return b + + ref_output = fn() + aot_fn = torch._dynamo.optimize("aot_eager")(fn) + actual_output = aot_fn() + self.assertEqual(ref_output, actual_output) + + def test_grad_inputs_alias_inputs(self): + class Test(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x) + return y + + @staticmethod + def backward(ctx, grad): + (x,) = ctx.saved_tensors + return x, grad + + def fn(x, y): + return Test.apply(x, y) + + x = torch.ones(1, requires_grad=True) + y = torch.ones(1, requires_grad=True) + compiled_fn = torch.compile(fn, backend="aot_eager") + out = compiled_fn(x, y) + out.sum().backward() + @expectedFailureDynamic # See pytorch/pytorch/issues/103539 @torch._dynamo.config.patch(automatic_dynamic_shapes=False) @patch("torch._functorch.config.debug_assert", True) @@ -792,7 +829,7 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): param_and_buf_len = len(full_args) full_args.extend([x, target]) - # aot_export requires a graph mod input of fwd graph + # aot_export requires a graph mod input of fwd graph # returns the full fwd/bwd graph in graph mod format with torch.enable_grad(), fx_traceback.preserve_node_meta(): fx_g, _, _, _ = _aot_export_function( @@ -834,6 +871,7 @@ SeqNr|OrigAten|SrcFn 1|aten._native_batch_norm_legit_functional.default|l__self___bn1 2|aten.relu.default|l__self___relu1 2|aten.detach.default|l__self___relu1 +2|aten.detach.default|l__self___relu1 3|aten.add.Tensor|add 4|aten.view.default|flatten 5|aten.view.default|l__self___fc1 @@ -860,6 +898,7 @@ SeqNr|OrigAten|SrcFn 5|aten.view.default| 4|aten.view.default| 2|aten.detach.default| +2|aten.detach.default| 2|aten.threshold_backward.default| 1|aten.native_batch_norm_backward.default| 0|aten.convolution_backward.default| @@ -868,6 +907,25 @@ SeqNr|OrigAten|SrcFn ), ) + def test_split_with_sizes_aot_autograd_cleans_up_traceback_meta(self): + from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks + + def fn(result, split_sizes): + rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist()) + return rs + + example_inputs = ( + torch.randn(32, requires_grad=True), + torch.tensor((7, 16, 9)), + ) + outs = fn(*example_inputs) + setup_stacktrace_preservation_hooks([out.grad_fn for out in outs]) + with fx_traceback.preserve_node_meta(): + (outs[0].sum() + outs[1].sum() + outs[2].sum()).backward() + + self.assertNotIn("grad_fn_seq_nr", fx_traceback.current_meta) + self.assertNotIn("in_grad_fn", fx_traceback.current_meta) + # See pytorch/pytorch/issues/110121 def test_aot_export_joint_simple_repro(self): class Mod(torch.nn.Module): @@ -982,7 +1040,7 @@ SeqNr|OrigAten|SrcFn self.assertIsNotNone(y_ref[1].grad_fn) self.assertIsNotNone(y[1].grad_fn) - # Check that the grad computed for the inputs, given the input, is the same + # Check that the grad computed for the inputs, given the input, is the same # The tangent to `y[0]`, which has grad_required=False, is irrelevant self.assertEqual( sum(y_ref[1].grad_fn(torch.tensor([-1.0, 2.0, 0.0]))), diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 
0716cc2adcf73ee88cdf54d9a08b53818ae12a7b..96009cf2fa86720902726b4e3baa148ba9e85469 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -1,14 +1,19 @@ # Owner(s): ["module: dynamo"] - +import functools import copy import math +import unittest import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils +from torch.testing._internal.common_utils import skipIfRocm + + +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") +requires_triton = functools.partial(unittest.skip, "requires cuda and triton") class CustomFunc1(torch.autograd.Function): @@ -271,14 +276,51 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): opt_model(x) def test_stride_in_bwd(self): + torch._dynamo.utils.counters.clear() + cnt = torch._dynamo.testing.CompileCounter() model = CustomFuncStrideModule() - opt_model = torch._dynamo.optimize("eager", nopython=True)(model) + opt_model = torch.compile(backend=cnt)(model) x = torch.randn(2, 2, dtype=torch.double, requires_grad=True) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Illegal getattr invocation stride in strict mod", - ): - opt_model(x) + ref = model(x) + res = opt_model(x) + + self.assertEqual(ref, res) + self.assertEqual(cnt.frame_count, 1) + # graph break: Illegal getattr invocation stride in strict mod. + self.assertEqual( + list(torch._dynamo.utils.counters["graph_break"].values()), [1] + ) + + def test_enum_arg(self): + from enum import Enum + + class SomeEnum(Enum): + A = 0 + B = 1 + + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, x, e): + if e is SomeEnum.A: + return x.sin() + else: + return x.cos() + + @staticmethod + def backward(ctx, g): + return g + + @torch.compile(backend="eager", fullgraph=True) + def f(x, enum): + output = Foo.apply( + x, + enum, + ) + return output + + x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True) + y = f(x, SomeEnum.A) + self.assertEqual(y, x.sin()) def test_save_for_bwd(self): model = SaveForBwdModule() @@ -595,7 +637,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): @staticmethod def jvp(ctx, x_t): - if jvp_err: + if jvp_err: # noqa: F821 return x_t else: return x_t.mul_(2) @@ -611,7 +653,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): @staticmethod def jvp(ctx, x_t, y_t): - return x_t + y_t, fn(x_t) + return x_t + y_t, fn(x_t) # noqa: F821 class MyFn3(torch.autograd.Function): @staticmethod @@ -697,7 +739,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): ) @staticmethod - def __tensor_unflatten__(tensors, metadatas): + def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride): return FooTensor(tensors["_data"], metadatas[0], metadatas[1]) @classmethod @@ -714,8 +756,6 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): raise NotImplementedError() - __torch_function__ = torch._C._disabled_torch_function_impl - class foo_autograd_fn(torch.autograd.Function): @staticmethod def forward(ctx, x): @@ -780,6 +820,42 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): foo(torch.randn(2, requires_grad=True)) self.assertEqual(cnts.frame_count, 1) + def test_repeated_save_for_backward_calls(self): + from torch.autograd import Function + + class Foo(Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x) + ctx.save_for_backward(x, y) + return x * y + + @staticmethod + def backward(ctx, grad_out): + x, y = 
ctx.saved_tensors + return grad_out * x, grad_out * y + + cnts = torch._dynamo.testing.CompileCounter() + + def foo(x, y): + return Foo.apply(x, y) + + x_ref = torch.randn(2, requires_grad=True) + y_ref = torch.randn(2, requires_grad=True) + x_test = x_ref.clone().detach().requires_grad_() + y_test = y_ref.clone().detach().requires_grad_() + + out_ref = foo(x_ref, y_ref) + out_ref.sum().backward() + + out_test = torch.compile(foo, backend=cnts)(x_test, y_test) + out_test.sum().backward() + + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(out_ref, out_test) + self.assertEqual(x_ref.grad, x_test.grad) + self.assertEqual(y_ref.grad, y_test.grad) + def test_smuggle_tensor_and_complex_structures(self): from torch.autograd import Function @@ -806,6 +882,94 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): foo(torch.randn(2, requires_grad=True)) self.assertEqual(cnts.frame_count, 1) + def test_default_values(self): + from torch.autograd import Function + + class Foo(Function): + @staticmethod + def forward(ctx, x, alpha=0.99): + return x + + @staticmethod + def backward(ctx, grad_out): + return grad_out + + @torch.compile + def foo(x): + return Foo.apply(x) + + # Make sure guards for default values do not crash + foo(torch.randn(2)) + foo(torch.randn(2, requires_grad=True)) + + @requires_triton() + @skipIfRocm + def test_triton_kernel_basic(self): + class Add(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + output = torch.zeros_like(x) + n_elements = output.numel() + grid = lambda meta: ( # noqa: E731 + triton.cdiv(n_elements, meta["BLOCK_SIZE"]), + ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) + return output + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + return x * grad_output, y * grad_output + + @torch.compile(fullgraph=True, backend="npu") + def f(x, y): + z = Add.apply(x, y) + return z + + x = torch.randn(10, device="npu:0", requires_grad=True) + y = torch.randn(10, device="npu:0", requires_grad=True) + z = f(x, y) + loss = z.sum() + loss.backward() + self.assertEqual(x + y, z) + + @requires_triton() + @skipIfRocm + def test_triton_kernel_multiple_out(self): + class Add(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + ctx.t1 = x + ctx.t2 = y + output = torch.zeros_like(x) + n_elements = output.numel() + grid = lambda meta: ( # noqa: E731 + triton.cdiv(n_elements, meta["BLOCK_SIZE"]), + ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) + return output, x + + @staticmethod + def backward(ctx, grad_output, old_x): + x, y = ctx.saved_tensors + x1 = ctx.t1 + y1 = ctx.t2 + return old_x * x * x1 * grad_output, y * y1 * grad_output + + @torch.compile(fullgraph=True, backend="npu") + def f(x, y): + z = Add.apply(x, y) + return z + + x = torch.randn(10, device="npu:0", requires_grad=True) + y = torch.randn(10, device="npu:0", requires_grad=True) + z, _ = f(x, y) + loss = z.sum() + loss.backward() + self.assertEqual(x + y, z) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 08add391f4ae4506a3295f99de116242fcedc6d6..2c22a6224fa24b4dc3965021759d236585efbff4 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -1,17 +1,18 @@ # Owner(s): ["module: dynamo"] import functools import unittest - import torch import torch_npu -import torchair import torch._dynamo import torch._dynamo.test_case 
from torch._dynamo.backends.debugging import ExplainWithBackend from torch._dynamo.backends.onnxrt import has_onnxruntime from torch._dynamo.backends.tvm import has_tvm from torch._dynamo.testing import same +from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module +from torch.testing._internal.inductor_utils import HAS_CUDA +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") @@ -106,6 +107,7 @@ class TestOptimizations(torch._dynamo.test_case.TestCase): def test_eager(self): self._check_backend_works("eager") + @_force_skip_lazy_graph_module() def test_torchscript(self): self._check_backend_works("ts") @@ -115,20 +117,13 @@ class TestOptimizations(torch._dynamo.test_case.TestCase): def test_aot_eager_decomp_partition(self): self._check_backend_works("aot_eager_decomp_partition") + @_force_skip_lazy_graph_module() def test_aot_ts(self): self._check_backend_works("aot_ts") - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @requires_cuda def test_aot_cudagraphs(self): self._check_backend_works("cudagraphs") - - def test_npu_backend(self): - npu_backend = torchair.get_npu_backend() - model = Seq().eval() - ipt = torch.randn(2, 10) - r1 = model(ipt).npu() - r2 = torch.compile(model, backend=npu_backend)(ipt) - self.assertTrue(same(r1, r2.float(), tol=0.01)) @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") def test_onnxrt(self): @@ -144,6 +139,7 @@ class TestOptimizations(torch._dynamo.test_case.TestCase): self.assertNotIn("eager", torch._dynamo.list_backends()) self.assertNotIn("eager", torch._dynamo.list_backends(exclude_tags=["debug"])) self.assertIn("eager", torch._dynamo.list_backends(exclude_tags=[])) + self.assertIn("npu", torch._dynamo.list_backends()) class NormalizeIRTests(torch._dynamo.test_case.TestCase): diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c86aaec634942fbe2196d57bae27adc01d2edb --- /dev/null +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -0,0 +1,248 @@ +# Owner(s): ["module: dynamo"] +# flake8: noqa + +import functools + +import torch +import torch_npu +import torch._dynamo.test_case +import torch._dynamo.testing +import torch._dynamo.utils +from torch import _inductor as inductor +from torch._dynamo import compiled_autograd +from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped +from torch._dynamo.testing import normalize_gm +from torch._dynamo.utils import counters +from torch.fx.experimental.proxy_tensor import make_fx + + +def _multiply(x): + return x * x + + +def _multiply_invoke(grad): + return trace_wrapped(grad, fn=_multiply) + + +class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase): + def test_invoke_in_eager(self): + x = torch.tensor([0.5, 0.5], requires_grad=True) + y = torch.tensor([0.5, 0.5], requires_grad=True) + + def fn(x, y): + x.register_hook(_multiply_invoke) + return x * y + + out = fn(x, y) + grad_out = torch.tensor([2.0, 2.0]) + out.backward(grad_out) + self.assertEqual(x.grad, y * grad_out) + + def test_invoke_in_pt2(self): + for backend in ["eager", "aot_eager", "inductor"]: + torch._dynamo.reset() + x = torch.tensor([0.5, 0.5], requires_grad=True) + y = torch.tensor([0.5, 0.5], requires_grad=True) + + def fn(x, y): + x.register_hook(_multiply_invoke) + return x * y + + fn = torch._dynamo.optimize(backend)(fn) + out = fn(x, y) + 
grad_out = torch.tensor([2.0, 2.0]) + out.backward(grad_out) + self.assertEqual(x.grad, grad_out * y) + + def test_invoke_make_fx_forward_contrived(self): + x = torch.tensor([0.5, 0.5], requires_grad=True) + out = make_fx(_multiply_invoke)(x) + self.assertEqual(out(x), torch.tensor([0.25, 0.25])) + actual = normalize_gm(out.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class _multiply_invoke(torch.nn.Module): + def forward(self, grad_1: "f32[2]"): + trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None + return trace_wrapped +""", + ) + + def test_invoke_make_bw(self): + x = torch.tensor([0.5, 0.5], requires_grad=True) + + def fwd(x): + z = x * x + return z + z + + res = fwd(x) + res.backward(torch.tensor([1.0, 1.0])) + out = make_fx(_multiply_invoke)(x.grad) + self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0])) + actual = normalize_gm(out.print_readable(False)) + + self.assertExpectedInline( + actual, + """\ +class _multiply_invoke(torch.nn.Module): + def forward(self, grad_1: "f32[2]"): + trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None + return trace_wrapped +""", + ) + + def test_invoke_in_pt2_compiled_autograd(self): + graph = None + + def compiler_fn(gm): + def inner_compiler(gm_, example_inputs_): + nonlocal graph + self.assertEqual(graph, None) + graph = gm_ + return inductor.compile(gm_, example_inputs_) + + return torch.compile( + gm, backend=inner_compiler, fullgraph=True, dynamic=True + ) + + for backend in ["eager", "aot_eager", "inductor"]: + torch._dynamo.reset() + x = torch.tensor([0.5, 0.5], requires_grad=True) + y = torch.tensor([0.5, 0.5], requires_grad=True) + + def fn(x, y): + x.register_hook(_multiply_invoke) + return x + y + + fn = torch._dynamo.optimize(backend)(fn) + out = fn(x, y) + grad_out = torch.tensor([2.0, 2.0]) + with compiled_autograd.enable(compiler_fn): + out.backward(grad_out) + actual = normalize_gm(graph.print_readable(False)) + self.assertEqual(x.grad, grad_out * grad_out) + expected = """\ +class GraphModule(torch.nn.Module): + def forward(self, s0 : torch.SymInt, L_inputs_0_ : torch.Tensor): + getitem = L_inputs_0_ + + new_grad = torch.clone(getitem) + + call_hook = getitem * getitem; getitem = None + + new_grad_1 = torch.clone(call_hook); call_hook = None + return (new_grad, new_grad_1) +""" + self.assertExpectedInline(actual, expected) + + graph = None + + def test_invoke_in_pt2_compiled_autograd_side_effect(self): + def _side_effect_stateful_fn2(x, obj): + obj.counter = obj.counter + 1 + return _multiply(x) + + def _side_effectful_invoke2(grad, fn): + return trace_wrapped(grad, fn=fn) + + graph = None + + def compiler_fn(gm): + def inner_compiler(gm_, example_inputs_): + nonlocal graph + self.assertEqual(graph, None) + graph = gm_ + return inductor.compile(gm_, example_inputs_) + + return torch.compile( + gm, backend=inner_compiler, fullgraph=True, dynamic=True + ) + + for backend in ["eager", "aot_eager", "inductor"]: + torch._dynamo.reset() + x = torch.tensor([0.5, 0.5], requires_grad=True) + y = torch.tensor([0.5, 0.5], requires_grad=True) + + class MyObj: + def __init__(self): + self.counter = 0 + + obj = MyObj() + inner_fn = functools.partial(_side_effect_stateful_fn2, obj=obj) + hook_fn = functools.partial(_side_effectful_invoke2, fn=inner_fn) + x.register_hook(hook_fn) + + def fn(x, y): + return x + y + + fn = torch._dynamo.optimize(backend, nopython=True)(fn) + out = fn(x, y) + grad_out = 
torch.tensor([2.0, 2.0]) + with compiled_autograd.enable(compiler_fn): + out.backward(grad_out) + actual = normalize_gm(graph.print_readable(False)) + self.assertEqual(obj.counter, 1) + self.assertEqual(x.grad, grad_out + grad_out) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s0 : torch.SymInt, L_inputs_0_ : torch.Tensor): + getitem = L_inputs_0_ + + new_grad = torch.clone(getitem) + + call_hook = getitem * getitem; getitem = None + + new_grad_1 = torch.clone(call_hook); call_hook = None + return (new_grad, new_grad_1) +""", + ) + + out = fn(x, y) + out.backward(grad_out) + self.assertEqual(obj.counter, 2) + + out = fn(x, y) + out.backward(grad_out) + self.assertEqual(obj.counter, 3) + graph = None + + def test_invoke_in_pt2_compiled_autograd_graph_breaks(self): + def _graph_breaking_fn(x): + print("Boo!") + return _multiply(x) + + def _graph_break_invoke(grad): + return trace_wrapped(grad, fn=_graph_breaking_fn) + + def compiler_fn(gm): + return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True) + + for backend in ["eager", "aot_eager", "inductor"]: + torch._dynamo.reset() + x = torch.tensor([0.5, 0.5], requires_grad=True) + y = torch.tensor([0.5, 0.5], requires_grad=True) + + def fn(x, y): + x.register_hook(_graph_break_invoke) + return x + y + + fn = torch._dynamo.optimize(backend, nopython=True)(fn) + out = fn(x, y) + grad_out = torch.tensor([2.0, 2.0]) + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "print", + ): + with compiled_autograd.enable(compiler_fn): + out.backward(grad_out) + + graph = None + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index 19674d04e2a506a21a3110ca3bebee2b421d8f45..cd1936f479438ece020b5770eed578fea9923273 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -1,12 +1,14 @@ # Owner(s): ["module: dynamo"] import inspect +import io import os import tempfile -import unittest +from unittest.mock import patch import torch import torch_npu +from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import CompileCounter @@ -20,7 +22,7 @@ class ToyModel(torch.nn.Module): return self.relu(self.linear(x)) -class InPlaceCompilationTests(unittest.TestCase): +class InPlaceCompilationTests(TestCase): def test_compilation(self): torch._dynamo.reset() model = ToyModel() @@ -72,10 +74,63 @@ class InPlaceCompilationTests(unittest.TestCase): loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt")) loaded_model(torch.randn(1, 10)) + def test_compilation_callback(self): + torch._dynamo.reset() + + @torch._dynamo.on_compile_start + def start_callback(): + print("Compilation started.") + + @torch._dynamo.on_compile_end + def end_callback(): + print("Compilation ended.") + + mod = ToyModel() + x = torch.randn(10, 10) + + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_mod = torch.compile(backend="eager", fullgraph=True)(mod) + opt_mod(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "Compilation started.\nCompilation ended.") + + def test_compilation_callback_with_graph_break(self): + torch._dynamo.reset() + counter = 0 + + @torch._dynamo.on_compile_start + def start_callback(): + nonlocal counter + counter += 1 + print(f"Counter = {counter}") + + @torch._dynamo.on_compile_end + def end_callback(): + nonlocal counter + counter += 1 + print(f"Counter 
= {counter}") + + @torch.compile(backend="eager") + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + return torch.sin(x) + + x = torch.randn(10, 10) + + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + fn(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual( + printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4" + ) + # The private variants of the below functions are extensively tested # So as long as the signatures match we're good -class PublicTorchCompilerTests(unittest.TestCase): +class PublicTorchCompilerTests(TestCase): def check_signature(self, public_fn_name, private_fn_name, private_namespace): public_fn = getattr(torch.compiler, public_fn_name) private_fn = getattr(private_namespace, private_fn_name) @@ -102,6 +157,5 @@ class PublicTorchCompilerTests(unittest.TestCase): self.check_signature(fn_name, fn_name, torch._dynamo) -if __name__ == '__main__': - unittest.main() - \ No newline at end of file +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 3acc61d8d1c16879df5a8c9fab4428df1dbf329c..433b3fc0683b6427e9cb9f29ca5c4cdd97868bb7 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -233,13 +233,6 @@ y = TensorVariable() 'obj_weakref': None 'guarded_class': None } - global '' CONFIG_HASH_MATCH - { - 'guard_types': None, - 'code': None, - 'obj_weakref': None - 'guarded_class': None - } shape_env '' SHAPE_ENV { 'guard_types': None, diff --git a/test/dynamo/test_config.py b/test/dynamo/test_config.py index 9f856800887b3ff15d2bd5692f2bde08cf066754..b350da76e508bc4fa2fec5435eabd4a433528ab5 100644 --- a/test/dynamo/test_config.py +++ b/test/dynamo/test_config.py @@ -82,7 +82,7 @@ class ConfigTests(torch._dynamo.test_case.TestCase): "debug_dir_root", } for k in dynamo_guarded_config_ignorelist: - assert k in torch._dynamo.config._compile_ignored_keys + assert k in torch._dynamo.config._compile_ignored_keys, k def test_config_hash(self): config = torch._dynamo.config @@ -111,238 +111,6 @@ class ConfigTests(torch._dynamo.test_case.TestCase): assert changed_hash != newest_hash assert newest_hash == starting_hash - @disable_cache_limit() - def test_no_saved_config(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - cnt_dynamic = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - opt_fn_static_shape = torch._dynamo.optimize( - cnt_dynamic, save_config=False - )(fn) - opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - for i in range(2, 12): - opt_fn_static_shape( - torch.randn(i), torch.randn(i) - ) # will be recompiled under new config - - self.assertEqual(cnt_dynamic.frame_count, 3) - - @disable_cache_limit() - def test_no_saved_config_nested(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - cnt_dynamic = torch._dynamo.testing.CompileCounter() - cnt_dynamic_1 = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(fn) - - # Will trigger recompile as compiled as static - opt_fn_static_shape(torch.randn(2), torch.randn(2)) - 
opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - opt_fn_try_dynamic = torch._dynamo.optimize( - cnt_dynamic_1, save_config=False - )(opt_fn_static_shape) - - for i in range(2, 6): - opt_fn_try_dynamic(torch.randn(i), torch.randn(i)) - self.assertEqual(cnt_dynamic_1.frame_count, 1) - - # Saved config = False will use whatever config is available - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - for i in range(6, 12): - opt_fn_try_dynamic(torch.randn(i), torch.randn(i)) - self.assertEqual(cnt_dynamic_1.frame_count, 7) - - @disable_cache_limit() - def test_config_changed_from_guarded_config_1(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn) - res = opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - for i in range(2, 12): - # Only 4-11 will now be recompiled under old config - # 2-3 have been already been compiled under old config - # and hence will hit cache - opt_fn_static_shape(torch.randn(i), torch.randn(i)) - - self.assertEqual(cnt_dynamic.frame_count, 10) - - @disable_cache_limit() - def test_config_changed_from_guarded_config_2(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - opt_fn_dynamic_shape = torch._dynamo.optimize(cnt_dynamic)(fn) - opt_fn_dynamic_shape(torch.randn(2), torch.randn(2)) - opt_fn_dynamic_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 1) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - for i in range(2, 12): - opt_fn_dynamic_shape( - torch.randn(i), torch.randn(i) - ) # will not be recompiled due to automatic dynamic shapes - - self.assertEqual(cnt_dynamic.frame_count, 1) - - @disable_cache_limit() - def test_nested_compile_outer_wins(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - cnt_dynamic_1 = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn) - opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1)( - lambda x, y: opt_fn_static_shape(x, y) - ) - for i in range(2, 12): - opt_fn_dynamic( - torch.randn(i), torch.randn(i) - ) # will be recompiled under new config - - self.assertEqual(cnt_dynamic.frame_count, 2) - self.assertEqual(cnt_dynamic_1.frame_count, 1) - - @disable_cache_limit() - def test_nested_fn_does_not_inherit_outer_config(self): - def g1(x): - return x + 1 - - def g2(x): - return x * 2 - - def f(x): - x = g1(x) - 
torch._dynamo.graph_break() - return g2(x) - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - cnt_dynamic_1 = torch._dynamo.testing.CompileCounter() - - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(f) - opt_fn_static_shape(torch.randn(2)) - opt_fn_static_shape(torch.randn(3)) - self.assertEqual(cnt_dynamic.frame_count, 4) # 2 compiles * 2 graphs - - opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1, dynamic=True)(g2) - - for i in range(2, 12): - opt_fn_dynamic( - torch.randn(i), - ) # will be recompiled under new config - - self.assertEqual(cnt_dynamic_1.frame_count, 1) - - @disable_cache_limit() - def test_multiple_compile_recompiles(self): - cnt_dynamic = torch._dynamo.testing.CompileCounter() - - def f(dynamic, compile_count): - @torch._dynamo.optimize(cnt_dynamic, dynamic=dynamic) - def g(x): - return x + 1 - - for i in range(2, 12): - g(torch.randn(i)) # will be recompiled under new config - self.assertEqual(cnt_dynamic.frame_count, compile_count) - cnt_dynamic.clear() - - f(dynamic=True, compile_count=1) # first compile - f(dynamic=False, compile_count=10) # recompile - f(dynamic=True, compile_count=0) # reuse first compile product - - def test_cache_size_limit(self): - cnt = torch._dynamo.testing.CompileCounter() - key = "_ConfigTests___test_cache_size_limit_key" - try: - torch._dynamo.config._allowed_keys.add(key) - torch._dynamo.config._ConfigTests___test_cache_size_limit_key = -1 - with torch._dynamo.config.patch( - {"cache_size_limit": 1, "accumulated_cache_size_limit": 10} - ): - - def g(x): - return x + 1 - - for i in range(12): - with torch._dynamo.config.patch( - {key: i % 6} - ): # same config doesn't recompile - opt_g = torch._dynamo.optimize(cnt)(g) - opt_g(torch.randn(1)) - self.assertEqual(cnt.frame_count, 6) - - for i in range(6, 12): - with torch._dynamo.config.patch({key: i}): - opt_g = torch._dynamo.optimize(cnt)(g) - opt_g(torch.randn(1)) - self.assertEqual( - cnt.frame_count, 10 - ) # only recompile up to cache size limit - finally: - if key in torch._dynamo.config._allowed_keys: - torch._dynamo.config._allowed_keys.remove(key) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 0de9bb849ae24a34d1306ff1d9478633d68ed18a..20f5df9ded2c36efb8ece5092a83590b0ee4ff49 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -10,6 +10,7 @@ from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_utils import TEST_WITH_ROCM class CutomizedCtxManager: @@ -161,8 +162,11 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): s = torch.npu.Stream() x = torch.mul(x, 5) x = torch.add(x, 2) + current_stream = torch.npu.current_stream() + s.wait_stream(current_stream) with torch.npu.stream(s): x = torch.relu(x) + current_stream.wait_stream(s) x = torch.add(x, 1) x = torch.cos(x) return x @@ -174,24 +178,62 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 12) + + @unittest.expectedFailure # pytorch/issues/118204 + @unittest.skipIf(not torch.npu.is_available(), "requires npu") + def test_npu_stream_across_graph_break(self): + def fn(x): + s = torch.npu.Stream() + x = torch.mul(x, 5) + 
x = torch.add(x, 2) + + print("foo") + + tcs = torch.npu.stream(s) + current_stream = torch.npu.current_stream() + s.wait_stream(current_stream) + + with tcs: + x = torch.relu(x) + + current_stream.wait_stream(s) + x = torch.add(x, 1) + x = torch.cos(x) + return x + + x = torch.randn((2, 2), device="npu:0") + ref = fn(x) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + res = opt_fn(x) + self.assertEqual(ref, res) + self.assertEqual(cnts.frame_count, 2) self.assertEqual(cnts.op_count, 9) + @unittest.expectedFailure # pytorch/issues/118204 @unittest.skipIf(not torch.npu.is_available(), "requires npu") def test_npu_stream_context_manager2(self): def fn(x, s): x = torch.mul(x, 5) x = torch.add(x, 2) + + current_stream = torch.npu.current_stream() + s.wait_stream(current_stream) + with torch.npu.stream(s): x = torch.relu(x) - s1 = torch.npu.current_stream() - with torch.npu.stream(s1): + current_stream.wait_stream(s) + with torch.npu.stream(current_stream): x = torch.relu(x) s2 = torch.npu.Stream() + s2.wait_stream(current_stream) with torch.npu.stream(s2): x = torch.relu(x) + current_stream.wait_stream(s2) x = torch.add(x, 1) x = torch.cos(x) return x @@ -213,11 +255,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): x = torch.add(x, 2) new_stream = torch.npu.Stream() + cur_stream = torch.npu.current_stream() + new_stream.wait_stream(cur_stream) + with torch.npu.stream(new_stream): x = torch.sin(x) x = torch.add(x, 3) - cur_stream = torch.npu.current_stream() cur_stream.wait_stream(new_stream) x = torch.add(x, 4) @@ -239,11 +283,81 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) res = opt_fn(x) - self.assertTrue(same(ref, res)) + self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 20) + self.assertEqual(cnts.op_count, 21) + + @unittest.skipIf(not torch.npu.is_available(), "requires npu") + def test_npu_stream_compared_with_constant(self): + def fn(x): + x = torch.mul(x, 1) + x = torch.add(x, 2) + + cur_stream = torch.npu.current_stream() + if cur_stream is not None: + return x + 1 + return x - 1 + + def fn2(x): + x = torch.mul(x, 1) + x = torch.add(x, 2) + + cur_stream = torch.npu.current_stream() + if cur_stream != "const_str": + return x + 1 + return x - 1 + + x = torch.randn((2, 2), device="npu:0") + ref = fn(x) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2) + res = opt_fn(x) + res2 = opt_fn2(x) + self.assertEqual(ref, res) + self.assertEqual(ref, res2) - @unittest.skipIf(not torch.npu.is_available(), "requires npu:0") + @unittest.skipIf(not torch.npu.is_available(), "requires npu") + def test_npu_stream_compared_with_stream(self): + def fn(x, s0, s1): + if s0 == s1: + return x + 1 + else: + return x - 1 + + s0 = torch.npu.Stream() + s1 = torch.npu.Stream() + x = torch.randn(2, 2) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + + ref0 = fn(x, s0, s1) + res0 = opt_fn(x, s0, s1) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(ref0, res0) + + ref1 = fn(x, s1, s1) + res1 = opt_fn(x, s1, s1) + # We have a re-compilation because of chaning inputs + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(ref1, res1) + + torch._dynamo.reset() + cnts = torch._dynamo.testing.CompileCounter() + 
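# Same experiment in the opposite order: compile first with identical streams
# (s1, s1), then call with distinct streams (s0, s1). Because the compiled code
# depends on the outcome of the `s0 == s1` comparison, swapping the inputs is
# again expected to trigger a recompilation (frame_count goes from 1 to 2 below).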
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + + ref1 = fn(x, s1, s1) + res1 = opt_fn(x, s1, s1) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(ref1, res1) + + ref0 = fn(x, s0, s1) + res0 = opt_fn(x, s0, s1) + # We have a re-compilation because of chaning inputs + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(ref0, res0) + + @unittest.skipIf(not torch.npu.is_available(), "requires npu") def test_npu_event_method(self): def fn(x): x = torch.mul(x, 1) @@ -264,8 +378,8 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): new_event = torch.npu.Event() new_event.record(new_stream) - x = torch.add(x, 5) new_event.wait(cur_stream) + x = torch.add(x, 5) # use new event to sync new_event.synchronize() @@ -279,7 +393,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) res = opt_fn(x) - self.assertTrue(same(ref, res)) + self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 19) @@ -314,9 +428,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): class MyModule(torch.nn.Module): def forward(self, x): - a_float32 = torch.rand((8, 8), device="npu") - b_float32 = torch.rand((8, 8), device="npu") - d_float32 = torch.rand((8, 8), device="npu") + a_float32 = torch.rand((8, 8), device="npu:0") + b_float32 = torch.rand((8, 8), device="npu:0") + d_float32 = torch.rand((8, 8), device="npu:0") with torch.autocast(device_type="npu", dtype=torch.bfloat16): e_float16 = torch.mm(a_float32, b_float32) @@ -341,8 +455,8 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): def test_npu_amp_autocast(self): class MyModule(torch.nn.Module): def forward(self, x): - a_float32 = torch.rand((8, 8), device="npu") - b_float32 = torch.rand((8, 8), device="npu") + a_float32 = torch.rand((8, 8), device="npu:0") + b_float32 = torch.rand((8, 8), device="npu:0") with torch.npu.amp.autocast(dtype=torch.torch.float64): c_float64 = torch.mm(a_float32, b_float32) @@ -378,7 +492,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(ref, res)) @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, + not PLATFORM_SUPPORTS_FLASH_ATTENTION or TEST_WITH_ROCM, "Can't run fused SDPA on this platform", ) def test_autocast_sdpa(self): @@ -396,13 +510,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): seq_len_k = 1 head_dim = 8 query = torch.ones( - 1, 8, seq_len_q, head_dim, device="npu", dtype=dtype, requires_grad=True + 1, 8, seq_len_q, head_dim, device="npu:0", dtype=dtype, requires_grad=True ) key = torch.ones( - 1, 8, seq_len_k, head_dim, device="npu", dtype=dtype, requires_grad=True + 1, 8, seq_len_k, head_dim, device="npu:0", dtype=dtype, requires_grad=True ) value = torch.ones( - 1, 8, seq_len_k, head_dim, device="npu", dtype=dtype, requires_grad=True + 1, 8, seq_len_k, head_dim, device="npu:0", dtype=dtype, requires_grad=True ) module = MyModule() @@ -410,7 +524,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): real_device = real.device real_dtype = real.dtype - opt_mod = torch._dynamo.optimize("inductor")(module) + opt_mod = torch._dynamo.optimize("npu")(module) compiled = opt_mod(query, key, value) self.assertEqual(compiled.device, real_device) @@ -473,7 +587,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): self.assertEqual(res.dtype, torch.bfloat16) def test_autocast_cpu_graph_break_2(self): - # Regression for: See pytorch/pytorch/issues/93890 + # Regression 
for: pytorch/issues/93890 def fn(x): with torch.autocast(device_type="cpu", dtype=torch.bfloat16): x = torch.mm(x, x) @@ -597,9 +711,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): def test_autocast_float64(self): class MyModule(torch.nn.Module): def forward(self, x): - a_float32 = torch.rand((8, 8), device="npu") - b_float32 = torch.rand((8, 8), device="npu") - d_float32 = torch.rand((8, 8), device="npu") + a_float32 = torch.rand((8, 8), device="npu:0") + b_float32 = torch.rand((8, 8), device="npu:0") + d_float32 = torch.rand((8, 8), device="npu:0") with torch.autocast(device_type="npu", dtype=torch.float64): e_float64 = torch.mm(a_float32, b_float32) @@ -623,9 +737,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): def test_autocast_device(self): class MyModule(torch.nn.Module): def forward(self, x): - a_float32 = torch.rand((8, 8), device="npu") - b_float32 = torch.rand((8, 8), device="npu") - d_float32 = torch.rand((8, 8), device="npu") + a_float32 = torch.rand((8, 8), device="npu:0") + b_float32 = torch.rand((8, 8), device="npu:0") + d_float32 = torch.rand((8, 8), device="npu:0") with torch.autocast("npu"): e_float64 = torch.mm(a_float32, b_float32) @@ -700,8 +814,8 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): def fn(a, b): return mm_float16(a, b), mm_float16_npu(a, b), mm_float16_cpu(a, b) - a_float32 = torch.rand((8, 8), device="npu") - b_float32 = torch.rand((8, 8), device="npu") + a_float32 = torch.rand((8, 8), device="npu:0") + b_float32 = torch.rand((8, 8), device="npu:0") ref = fn(a_float32, b_float32) opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) @@ -719,11 +833,23 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): x = torch.relu(x) return x - 1 + x = torch.rand(2, 3) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn) + with torch.no_grad(): - torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6) + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.op_count, 2) with torch.enable_grad(): - torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6) + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.frame_count, 4) + self.assertEqual(cnts.op_count, 4) def test_nested_generic_context_manager(self): def fn(x): @@ -738,11 +864,23 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): x = torch.relu(x) return x - 1 + x = torch.rand(2, 3) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn) + with torch.no_grad(): - torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9) + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.frame_count, 4) + self.assertEqual(cnts.op_count, 4) with torch.enable_grad(): - torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9) + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.frame_count, 6) + self.assertEqual(cnts.op_count, 6) def test_generic_context_manager_with_graph_break(self): def fn(x): diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index 9f63481d4878d3ab8710814c3d69bf0dd8c74e0d..dc1348c6974966bd29ab3db78f24847b852c7213 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -1,11 +1,20 @@ # Owner(s): ["module: dynamo"] +import unittest +import 
functools import torch import torch_npu from functorch import make_fx from torch._dynamo import debug_utils +from torch._dynamo.debug_utils import aot_graph_input_parser from torch._dynamo.test_case import TestCase +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + +f32 = torch.float32 +i64 = torch.int64 +i32 = torch.int32 + class TestDebugUtils(TestCase): def test_cast_model_to_fp64_dtype_args(self): @@ -50,6 +59,118 @@ def forward(self, x_1): """, # NOQA: B950 ) + @requires_npu() + def test_aot_graph_parser(self): + from torch import device + + def forward( + self, + primals_1: "f32[1001, 6]", + primals_2: "f32[1001]", + primals_3: "f32[1001, 64]", + primals_4: "f32[4190]", + primals_5: "f32[4190]", + primals_6: "f32[1739, 4190]", + primals_48: "f32[6144, 4191]", + ): + _tensor_constant0: "i64[4190]" = self._tensor_constant0 + lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default( + _tensor_constant0 + ) + _tensor_constant0 = None + index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( + primals_48, [None, lift_fresh_copy] + ) + lift_fresh_copy = None + + _tensor_constant1: "i64[6]" = self._tensor_constant1 + lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default( + _tensor_constant1 + ) + _tensor_constant1 = None + index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor( + primals_48, [None, lift_fresh_copy_1] + ) + primals_48 = lift_fresh_copy_1 = None + permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0]) + primals_1 = None + addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default( + primals_2, index_1, permute + ) + primals_2 = permute = None + amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True) + sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax) + exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub) + sub = None + sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True) + div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1) + exp = None + + full_default: "i32[6144, 1001]" = torch.ops.aten.full.default( + [6144, 1001], + 1, + dtype=torch.int32, + layout=torch.strided, + device=device(type="npu", index=0), + pin_memory=False, + ) + + iota: "i32[1001]" = torch.ops.prims.iota.default( + 1001, + start=0, + step=1, + dtype=torch.int32, + device=device(type="npu"), + requires_grad=False, + ) + + mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota) + full_default = iota = None + + iota_1: "i32[6144]" = torch.ops.prims.iota.default( + 6144, + start=0, + step=1001, + dtype=torch.int32, + device=device(type="npu", index=0), + requires_grad=False, + ) + view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1]) + mul = None + view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1]) + div = None + _embedding_bag = torch.ops.aten._embedding_bag.default( + primals_3, view, iota_1, False, 0, False, view_1 + ) + + return _embedding_bag + + kwargs = aot_graph_input_parser(forward, device="npu:0") + # runs successfully + forward(**kwargs) + + @requires_npu() + def test_sym_aot_graph_parser(self): + def forward( + self, + primals_1: "f32[1001, 6]", # noqa: F821 + primals_2: "f32[s0]", # noqa: F821 + primals_3: "Sym(s0)", # noqa: F821, + primals_4: "f32[s1]", # noqa: F821, + primals_5: "Sym(s1)", # noqa: F821, + ): + _tensor_constant0: "i64[4190]" = self._tensor_constant0 + + kwargs = aot_graph_input_parser( + forward, device="npu:0", sym_shapes={"s0": 10}, default_sym_shape=5 + ) + + 
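# aot_graph_input_parser reads the string annotations on `forward` and
# materializes matching inputs on the requested device: "f32[s0]" becomes a
# float32 tensor whose symbolic size s0 is bound through sym_shapes (10 here),
# "Sym(s0)" becomes that bound integer, and symbols without an explicit binding
# (s1) fall back to default_sym_shape (5 here), as the assertions below verify.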
self.assertEqual(list(kwargs["primals_2"].shape), [10]) + self.assertEqual(kwargs["primals_3"], 10) + + self.assertEqual(list(kwargs["primals_4"].shape), [5]) + self.assertEqual(kwargs["primals_5"], 5) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index c09236b92217521cace9b81d5924bc88e82567f2..f5196c5002145bc6148c0917e2d9c8972c2bea6b 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -2,7 +2,6 @@ import os import unittest.mock as mock from unittest.mock import patch -import torch.library import torch import torch_npu @@ -37,7 +36,10 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnts.op_count, 4) def test_disable_for_custom_op(self): - foo = torch.library.Library("foo", "DEF") + import torch.library + from torch.library import Library + + foo = Library("foo", "DEF") # noqa: TOR901 foo.define("custom(Tensor self) -> Tensor") # Dynamic shape data dependent operator. For static shape compilation, Dynamo @@ -286,6 +288,66 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): def test_mark_static_address_unguarded(self): self._test_mark_static_address(guarded=False) + def test_class_methods(self): + class A: + @classmethod + def my_class_method(cls, arg1): + return cls, arg1 + + @staticmethod + def my_static_method(arg1): + return None, arg1 + + def my_regular_method(self, arg1): + return self, arg1 + + class B(A): + def my_class_method(self, arg1): + return super().my_class_method(arg1) + + def my_static_method(self, arg1): + return super().my_static_method(arg1) + + class C(A): + @classmethod + def my_class_method(cls, arg1): + return super().my_class_method(arg1) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt) + def fn(a, b, c): + # We want a function that does not graph break but + # does generate custom bytecode + v1 = a.my_class_method(1) + v2 = A.my_class_method(2) + v3 = a.my_static_method(3) + v4 = A.my_static_method(4) + v5 = a.my_regular_method(5) + v6 = b.my_class_method(6) + v7 = b.my_static_method(7) + v8 = c.my_class_method(8) + v9 = C.my_class_method(9) + torch.rand(2) + return v1, v2, v3, v4, v5, v6, v7, v8, v9 + + a, b, c = A(), B(), C() + v1, v2, v3, v4, v5, v6, v7, v8, v9 = fn(a, b, c) + + self.assertEqual(v1, (A, 1)) + self.assertEqual(v2, (A, 2)) + self.assertEqual(v3, (None, 3)) + self.assertEqual(v4, (None, 4)) + self.assertEqual(v5, (a, 5)) + # do for later fix me: we do not resolve classmethods properly + # from a regular method + # self.assertEqual(v6, (B, 6)) + self.assertEqual(v7, (None, 7)) + self.assertEqual(v8, (C, 8)) + self.assertEqual(v9, (C, 9)) + + self.assertEqual(cnt.frame_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 827962eae3a3ceef02a8879add50bfb9da64ed2f..cb6eb1c098a5f6ac391c619195fefa5c1f9c2f3c 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -1,12 +1,11 @@ # Owner(s): ["module: dynamo"] import unittest import warnings -import torch -import torch_npu + from torch._dynamo import config from torch._dynamo.testing import make_test_cls_with_patches from torch.fx.experimental import _config as fx_config -from torch.testing._internal.common_utils import TEST_Z3 +from torch.testing._internal.common_utils import slowTest, TEST_Z3 try: from . 
import ( @@ -18,6 +17,7 @@ try: test_misc, test_modules, test_repros, + test_sdpa, test_subgraphs, ) except ImportError: @@ -29,6 +29,7 @@ except ImportError: import test_misc import test_modules import test_repros + import test_sdpa import test_subgraphs @@ -48,7 +49,7 @@ def make_dynamic_cls(cls): (config, "specialize_int", False), (fx_config, "translation_validation", TEST_Z3), (fx_config, "check_shape_env_recorded_events", True), - (fx_config, "validate_shape_env_verison_key", True), + (fx_config, "validate_shape_env_version_key", True), xfail_prop="_expected_failure_dynamic", ) @@ -70,6 +71,7 @@ tests = [ test_higher_order_ops.HigherOrderOpTests, test_higher_order_ops.FuncTorchHigherOrderOpTests, test_aot_autograd.AotAutogradFallbackTests, + test_sdpa.TestSDPA, ] for test in tests: make_dynamic_cls(test) @@ -79,9 +81,27 @@ if TEST_Z3: # this only fails when z3 is available unittest.expectedFailure( # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. - # Ref: See github sympy issue 25146 - DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes + # Ref: https://github.com/sympy/sympy/issues/25146 + DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 + ) + + # TODO model is somehow not being freed when z3 is available + unittest.expectedFailure( + DynamicShapesMiscTests.test_custom_module_free_dynamic_shapes # noqa: F821 ) + unittest.expectedFailure( + DynamicShapesMiscTests.test_sequential_module_free_dynamic_shapes # noqa: F821 + ) + +unittest.expectedFailure( + # Test is only valid without dynamic shapes + DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes # noqa: F821 +) + +# Test takes too long ~700s as of 414a1fd29f04d06e41b7f895368dd1f83a4be29d +DynamicShapesExportTests.test_retracibility_dynamic_shapes = slowTest( # noqa: F821 + DynamicShapesExportTests.test_retracibility_dynamic_shapes # noqa: F821 +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests @@ -92,4 +112,4 @@ if __name__ == "__main__": "Testing with translation validation requires Z3." 
) - run_tests() + run_tests() \ No newline at end of file diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py new file mode 100644 index 0000000000000000000000000000000000000000..c5dd9288d8695cc538fb8b5b25127ae3def0bfd1 --- /dev/null +++ b/test/dynamo/test_exc.py @@ -0,0 +1,336 @@ +# Owner(s): ["module: dynamo"] + +import logging +import unittest + +import torch +import torch_npu +import torch._dynamo +import torch._dynamo.config +import torch._dynamo.test_case +from torch._dynamo.comptime import comptime +from torch._dynamo.exc import Unsupported +from torch.testing._internal.common_device_type import skipIf +from torch.testing._internal.common_utils import IS_FBCODE, munge_exc, TEST_Z3 +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test + + +class ExcTests(LoggingTestCase): + maxDiff = None + + def test_unsupported_real_stack(self): + # exercise Unsupported constructor and augment_exc_message + def fn002(x): + torch._dynamo.graph_break() + + def fn001(x): + x = x + 1 + fn002(x) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn001, backend="eager", fullgraph=True)( + torch.randn(1) + ), + """\ +'skip function graph_break in file _dynamo/decorators.py' + +from user code: + File "test_exc.py", line N, in fn001 + fn002(x) + File "test_exc.py", line N, in fn002 + torch._dynamo.graph_break()""", + ) + + @torch._dynamo.config.patch(verbose=True, suppress_errors=True) + @make_logging_test() + @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode") + def test_internal_error_suppress_errors(self, records): + def fn001(x): + def f(ctx): + raise AssertionError() + + comptime(f) + + torch.compile(fn001, backend="eager")(torch.randn(1)) + + record = self.getRecord(records, "WON'T CONVERT") + + self.assertExpectedInline( + munge_exc(record.getMessage()), + """\ +WON'T CONVERT fn001 test_exc.py line N +========== TorchDynamo Stack Trace ========== +Traceback (most recent call last): + File "test_exc.py", line N, in f + raise AssertionError() +AssertionError: + +from user code: + File "test_exc.py", line N, in fn001 + comptime(f) + + +========== The above exception occurred while processing the following code ========== + + File "test_exc.py", line N, in test_internal_error_suppress_errors + torch.compile(fn001, backend="eager")(torch.randn(1)) + File "test_exc.py", line N, in fn001 + comptime(f) + +==========""", + ) + + @make_logging_test() + def test_not_implemented_error(self, records): + def fn001(x): + def f(ctx): + raise NotImplementedError() + + # Ensure graph break is not possible + for i in range(3): + comptime(f) + + torch.compile(fn001, backend="eager")(torch.randn(1)) + + record = self.getRecord(records, "WON'T CONVERT") + + self.assertExpectedInline( + munge_exc(record.getMessage()), + """\ +WON'T CONVERT fn001 test_exc.py line N +due to: +Traceback (most recent call last): + File "test_exc.py", line N, in f + raise NotImplementedError() +torch._dynamo.exc.InternalTorchDynamoError: + +from user code: + File "test_exc.py", line N, in fn001 + comptime(f)""", + ) + + @unittest.expectedFailure + @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True) + @make_logging_test(dynamo=logging.DEBUG) + def test_unsupported_error(self, records): + def fn001(x): + return {1, 2} + + torch.compile(fn001, backend="eager")(torch.randn(1)) + + # do for later: There is no graph break log! 
This is because the graph break + # logging is not in a centralized location; unsupported + # instruction bypasses it + self.getRecord(records, "Graph break:") + + @torch._dynamo.config.patch(suppress_errors=False) + def test_internal_error_no_suppress(self): + def fn001(x): + # NB: avoid decorator, as 3.11 changed the line number attributed + # in this situation + def f(ctx): + raise AssertionError() + + comptime(f) + + # NB: OK for user code to be truncated here, because the regular + # exception backtrace has the rest of the crumbs + self.assertExpectedInlineMunged( + AssertionError, + lambda: torch.compile(fn001, backend="eager")(torch.randn(1)), + """\ + + +from user code: + File "test_exc.py", line N, in fn001 + comptime(f)""", + ) + + @make_logging_test(graph_breaks=True) + def test_graph_break_log(self, records): + def fn002(x): + x = x + 1 + torch._dynamo.graph_break() + x = x + 1 + return x + + def fn001(x): + return fn002(x) + + torch.compile(fn001, backend="eager")(torch.randn(1)) + + record = self.getRecord(records, "Graph break:") + + # do for later: This should also report the enclosing frames; need to plumb + # frame object to it + self.assertExpectedInline( + munge_exc(record.getMessage()), + """\ +Graph break: from user code at: + File "test_exc.py", line N, in fn001 + return fn002(x) + File "test_exc.py", line N, in fn002 + torch._dynamo.graph_break() +""", # noqa: B950 + ) + + @torch._dynamo.config.patch(suppress_errors=False) + def test_backend_suppress_line(self): + def fn001(x): + x = torch.relu(x) + return x + 1 + + # Do NOT let this get attributed to x + 1 + self.assertExpectedInlineMunged( + torch._dynamo.exc.BackendCompilerFailed, + lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")( + torch.randn(1) + ), + """\ +backend='relu_compile_error_TESTING_ONLY' raised: +ReluCompileError:""", + ) + + @skipIf(not TEST_Z3, "z3 not installed") + @torch._dynamo.config.patch( + assume_static_by_default=False, + suppress_errors=False, + ) + @torch.fx.experimental._config.patch( + inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True, + translation_validation=True, + translation_validation_no_bisect=True, + ) + def test_trigger_on_error(self): + from torch.fx.experimental.validator import ValidationException + + @torch.compile + def fn(x, shape): + return x.split(shape) + + self.assertExpectedInlineMunged( + ValidationException, + lambda: fn(torch.randn(20), (5, 10, 5)), + """\ +translation validation failed. 
+ +Model: + ==> L['shape'][0]: 0 + ==> L['shape'][1]: 0 + ==> L['shape'][2]: 0 + ==> L['x'].size()[0]: 3 + ==> L['x'].storage_offset(): 0 + ==> L['x'].stride()[0]: 1 + ==> s0: 3 + ==> s1: 0 + ==> s2: 0 + ==> s3: 0 + +Assertions: + ==> (== 0 L['x'].storage_offset()) + ==> (== 1 L['x'].stride()[0]) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) + ==> (== L['shape'][2] s3) + ==> (== L['x'].size()[0] s0) + ==> (> s0 1) + ==> (True) + +Target Expressions: + ==> (<= 0 s1) + ==> (<= 0 s2) + ==> (<= 0 s3) + ==> (<= 2 s0) + ==> (== 0 L['shape'][0]) + ==> (== 0 L['shape'][1]) + ==> (== 0 L['shape'][2]) + ==> (== 0 L['x'].storage_offset()) + ==> (== 0 s1) + ==> (== 0 s2) + ==> (== 0 s3) + ==> (== 1 L['x'].stride()[0]) + ==> (== L['x'].size()[0] s0) + ==> (> s0 0) + ==> (>= 0 s1) + ==> (>= 0 s2) + ==> (>= 0 s3) + ==> (>= 9223372036854775806 s0) + +Failed Source Expressions: + ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", + ) + + @skipIf(not TEST_Z3, "z3 not installed") + @torch._dynamo.config.patch( + assume_static_by_default=False, + suppress_errors=False, + ) + @torch.fx.experimental._config.patch( + inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True, + translation_validation=True, + ) + def test_trigger_bisect_on_error(self): + from torch.fx.experimental.validator import BisectValidationException + + @torch.compile + def fn(x, shape): + return x.split(shape) + + self.assertExpectedInlineMunged( + BisectValidationException, + lambda: fn(torch.randn(20), (5, 10, 5)), + """\ +translation validation failed when evaluating: Eq(s1 + s2 + s3, s0) + +Failure occurred while running node: + %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) + +Model: + ==> L['shape'][0]: 1 + ==> L['shape'][1]: 1 + ==> L['shape'][2]: 2 + ==> L['x'].size()[0]: 3 + ==> L['x'].storage_offset(): 0 + ==> L['x'].stride()[0]: 1 + ==> s0: 3 + ==> s1: 1 + ==> s2: 1 + ==> s3: 2 + +Assertions: + ==> (== 0 L['x'].storage_offset()) + ==> (== 1 L['x'].stride()[0]) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) + ==> (== L['shape'][2] s3) + ==> (== L['x'].size()[0] s0) + ==> (> s0 1) + +Target Expressions: + ==> (!= (+ s1 s2 s3) s0) + ==> (<= 0 s1) + ==> (<= 0 s2) + ==> (<= 0 s3) + ==> (<= 2 s0) + ==> (== 0 L['x'].storage_offset()) + ==> (== 1 L['x'].stride()[0]) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) + ==> (== L['shape'][2] s3) + ==> (== L['x'].size()[0] s0) + ==> (> s0 0) + ==> (>= 9223372036854775806 s0) + ==> (>= 9223372036854775807 s1) + ==> (>= 9223372036854775807 s2) + ==> (>= 9223372036854775807 s3) + +Failed Source Expressions: + ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index ff09a5c5f86ee510d862a5a632f5ad8802f9f836..27d3b28b16825e4f5c92e931feef12632972bba3 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -7,7 +7,6 @@ import copy import functools import inspect import io -import math import operator import unittest from enum import Enum @@ -16,19 +15,16 @@ from unittest.mock import patch import torch import torch_npu - import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing -from functorch.experimental.control_flow import map as mp from functorch.experimental.control_flow import cond from torch._dynamo import config -from 
torch._dynamo.output_graph import config as output_graph_config from torch._dynamo.exc import UserError from torch._dynamo.testing import normalize_gm -from torch._export import dynamic_dim from torch._higher_order_ops.out_dtype import out_dtype from torch._subclasses import fake_tensor +from torch.export import dynamic_dim from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, @@ -37,6 +33,7 @@ from torch.fx.experimental.symbolic_shapes import ( StatelessSymbolicContext, ) from torch.testing._internal import common_utils +from torch.testing._internal.common_cuda import TEST_CUDA class ExportTests(torch._dynamo.test_case.TestCase): @@ -124,7 +121,11 @@ class ExportTests(torch._dynamo.test_case.TestCase): for guard in out_guards: if guard.source == GuardSource.SHAPE_ENV: hit = True - self.assertTrue("L['x'].size()[0] <= 10" in guard.code_list) + self.assertExpectedInline( + guard.code_list, + """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""", # noqa: B950 + ) + break self.assertTrue(hit) @@ -1254,7 +1255,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): real_result = module(torch.tensor([1.0, 1.0])) graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) - # Tensor input can be almost anything here, and the result will capture what we + # Tensor ipt can be almost anything here, and the result will capture what we # made constant at compile time. result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1278,7 +1279,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): real_result = module(torch.tensor([1.0, 1.0])) graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) - # Tensor input can be almost anything here, and the result will capture what we + # Tensor ipt can be almost anything here, and the result will capture what we # made constant at compile time. result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1302,7 +1303,7 @@ class ExportTests(torch._dynamo.test_case.TestCase): real_result = module(torch.tensor([1.0, 1.0])) graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0])) - # Tensor input can be almost anything here, and the result will capture what we + # Tensor ipt can be almost anything here, and the result will capture what we # made constant at compile time. result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1316,15 +1317,15 @@ class ExportTests(torch._dynamo.test_case.TestCase): def forward(self, x): y = torch.tensor([0.5]) elements = self.helper_fn(x) - y = y * elements.get("x") - y = y * elements.get("x^2") + y = y * elements["x"] + y = y * elements["x^2"] return y module = MyModule() real_result = module(torch.tensor([2.0, 2.0])) graph, guards = torch._dynamo.export(module)(torch.tensor([2.0, 2.0])) - # Tensor input can be almost anything here, and the result will capture what we + # Tensor ipt can be almost anything here, and the result will capture what we # made constant at compile time. 
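# That is, the dict returned by helper_fn was evaluated once at export time and
# its values were baked into the exported fx.GraphModule as constants, so calling
# `graph` below with a tensor of a completely different shape still reproduces
# `real_result`.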
result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) @@ -1485,6 +1486,29 @@ class ExportTests(torch._dynamo.test_case.TestCase): resB = graph(torch.tensor([2])) self.assertTrue(torch._dynamo.utils.same(resA, resB)) + def test_export_with_builtin_op_on_assume_constant(self): + @torch._dynamo.assume_constant_result + def get_y(y) -> torch.Tensor: + return y + + class Bob(torch.nn.Module): + def __init__(self, p, val) -> None: + super().__init__() + self.p = p + self.y = torch.nn.Parameter(torch.tensor(val)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # This only looks dynamic but it's actually a constant value + if get_y(self.y) < self.p: + return torch.cat([x, x]) + else: + return x + + model = Bob(0.5, 0.3) + inp = torch.ones(3, 4) + graph, guards = torch._dynamo.export(model)(inp) + self.assertEqual(model(inp), graph(inp)) + def test_export_decomp(self): def f(x): return x.t() + x.t() @@ -1527,6 +1551,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): @config.patch(capture_scalar_outputs=True) def test_export_with_module_layer(self): + from functorch.experimental.control_flow import cond + class Module(torch.nn.Module): def __init__(self): super().__init__() @@ -1563,6 +1589,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): @config.patch(capture_scalar_outputs=True) def test_export_with_cond_branches_calling_methods(self): + from functorch.experimental.control_flow import cond + class Module(torch.nn.Module): # ok def __init__(self): @@ -1594,6 +1622,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): @config.patch(capture_scalar_outputs=True) def test_export_with_cond_closure(self): + from functorch.experimental.control_flow import cond + class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -1669,6 +1699,8 @@ class ExportTests(torch._dynamo.test_case.TestCase): self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) def test_export_with_cond_dynamic_shape_pred(self): + from functorch.experimental.control_flow import cond + class Module(torch.nn.Module): def forward(self, x): def true_fn(x): @@ -1739,6 +1771,8 @@ def forward(self, l_x_): mod(test_x) def test_export_with_map_cond(self): + from functorch.experimental.control_flow import cond, map + class Module(torch.nn.Module): def inner(self, x, pred): def true_fn(x): @@ -1753,7 +1787,7 @@ def forward(self, l_x_): def body(x, pred): return self.inner(x, pred) - return mp(body, xs, pred) + return map(body, xs, pred) mod = Module() x = torch.randn(3, 2, 1) @@ -1767,12 +1801,14 @@ def forward(self, l_x_): self.assertEqual(real_result, out_graph(pred_y, y)) def test_export_with_map_zero_sized_tensor(self): + from functorch.experimental.control_flow import map + class Module(torch.nn.Module): def forward(self, xs): def body(x): return x + 1 - return mp(body, xs) + return map(body, xs) mod = Module() xs = torch.randn(0, 2) @@ -1944,7 +1980,7 @@ def forward(self, x): for arg in myargs: out *= arg out *= mykw0 - out *= mykwargs.get("input0") * mykwargs.get("input1") + out *= mykwargs["input0"] * mykwargs["input1"] return out mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)} @@ -2010,13 +2046,13 @@ def forward(self, x): common_utils.subtest(None, name="None"), common_utils.subtest(42.0, name="float"), common_utils.subtest( - # for fixme: AssertionError: Dynamo input and output is a strict subset of traced input/output + # FIXME: AssertionError: Dynamo ipt and output is a strict subset of 
traced ipt/output torch.randn(4), name="tensor", decorators=[unittest.expectedFailure], ), common_utils.subtest( - # for fixme: AssertionError: Dynamo input and output is a strict subset of traced input/output + # FIXME: AssertionError: Dynamo ipt and output is a strict subset of traced ipt/output (torch.randn(4),), name="tuple", decorators=[unittest.expectedFailure], @@ -2045,13 +2081,13 @@ def forward(self, x): common_utils.subtest(None, name="None"), common_utils.subtest(42.0, name="float"), common_utils.subtest( - # for fixme: AssertionError: Dynamo input and output is a strict subset of traced input/output + # FIXME: AssertionError: Dynamo ipt and output is a strict subset of traced ipt/output torch.randn(4), name="tensor", decorators=[unittest.expectedFailure], ), common_utils.subtest( - # for fixme: AssertionError: Dynamo input and output is a strict subset of traced input/output + # FIXME: AssertionError: Dynamo ipt and output is a strict subset of traced ipt/output (torch.randn(4),), name="tuple", decorators=[unittest.expectedFailure], @@ -2067,7 +2103,7 @@ def forward(self, x): elif isinstance(kw1_default, tuple): kw1_default = kw1_default[0] out += kw1_default - out += kwargs.get("kw2") + out += kwargs["kw2"] return out pos0 = torch.randn(4) @@ -2223,66 +2259,46 @@ def forward(self, x): return t.x + t.y with self.assertRaisesRegex( - AssertionError, - "graph-captured input #1, of type .*Tensor.*, " - "is not among original inputs of types: .*Tensors", + UserError, + "It looks like one of the inputs with type .*Tensors.* " + "is not supported or pytree-flattenable", ): - torch._dynamo.export( - f, Tensors(x=torch.randn(10), y=torch.randn(10)), aten_graph=False + torch._dynamo.export(f, aten_graph=False)( + Tensors(x=torch.randn(10), y=torch.randn(10)) ) def f(x, y): return Tensors(x=x.sin(), y=y.cos()) with self.assertRaisesRegex( - AssertionError, - "original output #1 is .*Tensors.*, " - "but only the following types are supported", - ): - torch._dynamo.export(f, torch.randn(10), torch.randn(10), aten_graph=False) - - def test_none_out(self): - def f(x, y): - _ = x + y - - with self.assertRaisesRegex( - AssertionError, - "original output #1 is None, but only the following types are supported", + UserError, + "It looks like one of the outputs with type .*Tensors.* " + "is not supported or pytree-flattenable", ): - torch._dynamo.export(f, torch.randn(10), torch.randn(10), aten_graph=False) - - def test_primitive_constant_output(self): - def foo(x): - # return a constant of primitive type - y = 5 - return y * x, y + torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10)) - with self.assertRaisesRegex( - AssertionError, - "original output #2 is 5, but only the following types are supported", - ): - torch.export.export(foo, (torch.tensor(3),)) + def test_empty(self): + def f(x): + return x - def bar(x, y): - return y * x, y + exported = torch._dynamo.export(f)(torch.randn(3, 3)) + out_graph = exported[0] + inp = torch.randn(3, 3) + self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp))) - # new behavior - with self.assertRaisesRegex( - AssertionError, - "original output #2 is 5, but only the following types are supported", - ): - torch.export.export(bar, (torch.tensor(3), 5)) + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.ones(3, 3) - def qux(x, y): - return y * x, y - 1 + def forward(self): + return self.a - with self.assertRaisesRegex( - AssertionError, - "original output #2 is 4, but only the following types are 
supported", - ): - torch.export.export(qux, (torch.tensor(3), 5)) + exported = torch._dynamo.export(M())() + out_graph = exported[0] + self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph())) - @unittest.skipIf(not torch.npu.is_available(), "requires npu") + @unittest.skipIf(not TEST_CUDA, "No CUDA available.") def test_export_with_parameters(self): class MyModule(torch.nn.Module): def __init__(self): @@ -2298,7 +2314,7 @@ def forward(self, x): return self.features(x) model = MyModule().eval().npu() - random_inputs = (torch.rand([32, 3, 32, 32]).npu(),) + random_inputs = (torch.rand([32, 3, 32, 32]).to("npu:0"),) dim_x = torch.export.Dim("dim_x", min=1, max=32) exp_program = torch.export.export( model, random_inputs, dynamic_shapes={"x": {0: dim_x}} @@ -2308,9 +2324,31 @@ def forward(self, x): torch.export.save(exp_program, output_buffer) loaded_model = torch.export.load(output_buffer) self.assertTrue( - isinstance(loaded_model.module().features_0_weight, torch.nn.Parameter) + isinstance( + loaded_model.module().get_parameter("features.0.weight"), + torch.nn.Parameter, + ) ) + def test_export_fast_binary_broadcast_check(self): + # This test looks at the case where we erroneously create a guard + # when checking the equality of the operands' shape and the output + # shape during FakeTensor's binary op fast path. + + class MyModel(torch.nn.Module): + def forward(self, a, b): + # final shape is (dim0, 4, 8) + # order matters since a & the output have the same shape + return b + a + + a = torch.randn(100, 4, 8) + b = torch.randn(4, 8) + model = MyModel().eval().npu() + batchsize = torch.export.Dim("dim0", min=3, max=1024) + dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]} + + torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec) + def test_export_meta(self): class MyModule(torch.nn.Module): def __init__(self): @@ -2330,11 +2368,14 @@ def forward(self, x): self.assertEqual(dynamo_result, m(inp)) def test_constraint_violation_error_messages(self): - def foo(x): - if x.shape[0] == x.shape[1] * 2: - return x + 1 - else: - return x + 2 + class Foo(torch.nn.Module): + def forward(self, x): + if x.shape[0] == x.shape[1] * 2: + return x + 1 + else: + return x + 2 + + foo = Foo() t = torch.zeros([8, 4]) dim0 = torch.export.Dim("dim0", min=3, max=10) @@ -2344,15 +2385,19 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Not all values.*valid.*inferred to be equal to(.*\n)*.*" - "must be specialized.*guards generated.*too complex", + "The values of.*must always be related to the values of.*" + "by dim0 = 2\\*dim1", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) - def bar(x): - if x.shape[0] == 5: - return x + 1 - else: - return x + 2 + class Bar(torch.nn.Module): + def forward(self, x): + if x.shape[0] == 5: + return x + 1 + else: + return x + 2 + + bar = Bar() t = torch.zeros([5]) dim0 = torch.export.Dim("dim0", min=3, max=8) @@ -2363,11 +2408,14 @@ def forward(self, x): ): torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes) - def qux(x): - if x.shape[0] > 5 and x.shape[0] < 10: - return x + 1 - else: - return x + 2 + class Qux(torch.nn.Module): + def forward(self, x): + if x.shape[0] > 5 and x.shape[0] < 10: + return x + 1 + else: + return x + 2 + + qux = Qux() t = torch.zeros([7]) dim0 = torch.export.Dim("dim0", min=3, max=8) @@ -2379,27 +2427,23 @@ def forward(self, x): torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes) def test_untracked_inputs_in_constraints(self): - from 
copy import copy - def foo(x, y): - return y + 1 + class Foo(torch.nn.Module): + def forward(self, x, y): + return y + 1 + + foo = Foo() x = torch.randn(2) y = torch.randn(5, 4) - constraints = [dynamic_dim(x, 0), dynamic_dim(y, 0)] - - example_inputs = (copy(x), y) - ep = torch._export._export(foo, example_inputs, constraints=constraints) - with self.assertRaisesRegex(RuntimeError, "Input.*shape.*specialized at 2"): - ep(torch.randn(3), y) dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y") dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}} example_inputs = (copy(x), y) ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes) - ep(torch.randn(3), y) # no specialization error + ep.module()(torch.randn(3), y) # no specialization error def test_export_raise_guard_full_constraint(self): y = torch.randn([3, 3, 3]) @@ -2492,11 +2536,14 @@ def forward(self, x): torch._dynamo.export(my_dyn_fn, constraints=constraints)(x, y, z) def test_remove_redundant_dynamic_dim_in_error_message(self): - def foo(x, y): - if x.shape[0] == y["k"].shape[0]: - return x + 1 - else: - return x - 1 + class Foo(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] == y["k"].shape[0]: + return x + 1 + else: + return x - 1 + + foo = Foo() a = torch.randn(3) b = torch.randn(3) @@ -2509,8 +2556,11 @@ def forward(self, x): ) def test_enforce_equalities(self): - def bar(x, y): - return torch.matmul(x, y) + class Bar(torch.nn.Module): + def forward(self, x, y): + return torch.matmul(x, y) + + bar = Bar() batch, size = torch.export.dims("batch", "size") dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)} @@ -2688,19 +2738,35 @@ def forward(self, x): )(x) def test_trivial_constraint(self): - def foo(x): - # non-trivial divisibility condition - if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0: - return x + 1 - else: - return x - 1 + class Foo(torch.nn.Module): + def forward(self, x): + # complex divisibility condition + if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0: + return x + 1 + else: + return x - 1 - def bar(x): - # trivially true - if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0: - return x + 1 - else: - return x - 1 + foo = Foo() + + class Bar(torch.nn.Module): + def forward(self, x): + # trivially true + if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0: + return x + 1 + else: + return x - 1 + + bar = Bar() + + class Qux(torch.nn.Module): + def forward(self, x): + # simple divisibility condition (not trivially true) + if (3 * x.shape[0]) % 2 == 0: + return x + 1 + else: + return x - 1 + + qux = Qux() x = torch.randn(12) dim0 = torch.export.Dim("dim0", max=100) @@ -2713,6 +2779,12 @@ def forward(self, x): torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes) + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + "Not all values.*satisfy the generated guard", + ): + torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes) + def test_list_contains(self): def func(x): assert x.size(-1) in [4, 5, 6], "bad" @@ -2885,11 +2957,11 @@ def forward(self, x): @config.patch(assume_static_by_default=False) def test_export_persist_assert(self): def f(x): - assert x.shape[0] > 4, "Shape must be more than 4" + assert x[0].sum() > 4, "Shape must be more than 4" return x.cos() + x.sin() gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( - torch.randn(5, 4, 6) + torch.ones(5, 4, 6) ) def has_aten_op(gm, op): @@ -2905,7 +2977,7 @@ def forward(self, x): self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg)) with 
self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"): - gm(torch.randn(3, 4, 5)) + gm(torch.zeros(3, 4, 5)) @common_utils.parametrize( "type_fn", @@ -2956,15 +3028,9 @@ def forward(self, x): def f(x): return x[: round(x.shape[0] / 2)] - def f_correct(x): - return x[: math.floor(x.shape[0] / 2)] - - with self.assertRaisesRegex(torch._dynamo.exc.UserError, "Calling round()"): - gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) - - gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4)) + gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) - self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4))) + self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) def test_cond_supported_pred_types(self): def true_fn(x): @@ -3215,7 +3281,7 @@ def forward(self, x): ) def test_byte_tensor_does_not_crash(self): - # See See pytorch/pytorch/issues/100455 + # See pytorch/pytorch/issues/100455 def func(text): tensor = torch.ByteTensor(list(bytes(text, "utf8"))) return tensor + tensor @@ -3263,6 +3329,7 @@ def forward(self, x): ) def test_capture_symbolic_tracing_simple_within_fake_mode(self): + from torch._dynamo.output_graph import config def f(x): y = torch.randn(3) @@ -3270,8 +3337,8 @@ def forward(self, x): with fake_tensor.FakeTensorMode( shape_env=ShapeEnv( - allow_scalar_outputs=output_graph_config.capture_scalar_outputs, - allow_dynamic_output_shape_ops=output_graph_config.capture_dynamic_output_shape_ops, + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, ), ): x = torch.randn(3) @@ -3434,6 +3501,7 @@ G['macademia'], accessed at: torch._dynamo.export(f)(torch.randn(3)) def test_symbolic_tracing_within_fake_mode_with_constraints(self): + from torch._subclasses import fake_tensor fake_mode = fake_tensor.FakeTensorMode() @@ -3469,6 +3537,8 @@ G['macademia'], accessed at: self.assertEqual(model(*inputs), gm(*inputs)) def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self): + from torch._subclasses import fake_tensor + fake_mode = fake_tensor.FakeTensorMode() # do for later: Seems to choke if you don't make a fresh model and @@ -3496,6 +3566,9 @@ G['macademia'], accessed at: )(*inputs).graph_module def test_capture_symbolic_tracing_within_fake_mode(self): + from torch._dynamo.output_graph import config + from torch._subclasses import fake_tensor + from torch.fx.experimental.symbolic_shapes import ShapeEnv class Model(torch.nn.Module): def __init__(self) -> None: @@ -3513,11 +3586,11 @@ G['macademia'], accessed at: allow_non_fake_inputs=False, allow_fallback_kernels=True, shape_env=ShapeEnv( - allow_scalar_outputs=output_graph_config.capture_scalar_outputs, - allow_dynamic_output_shape_ops=output_graph_config.capture_dynamic_output_shape_ops, + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, ), ) - # Fakefy input+model before exporting it + # Fakefy ipt+model before exporting it with fake_mode: x = torch.rand(5, 2, 2) model = Model() @@ -3611,6 +3684,8 @@ G['macademia'], accessed at: self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) def test_map_cond_param_buffer_lifted(self): + from functorch.experimental.control_flow import cond, map + class A(torch.nn.Module): def __init__(self): super().__init__() @@ -3646,7 +3721,7 @@ G['macademia'], accessed at: def body(x, pred): return self.inner(x, pred) + self.b() - return mp(body, xs, pred) + 
return map(body, xs, pred) mod = Module() x = torch.randn(3, 2, 1) @@ -3660,6 +3735,8 @@ G['macademia'], accessed at: self.assertEqual(real_result, out_graph(pred_y, y)) def test_cond_free_variables_overlapping(self): + from functorch.experimental.control_flow import cond + class Module(torch.nn.Module): def __init__(self): super().__init__() @@ -3737,7 +3814,7 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch): @unittest.skipIf( common_utils.TEST_WITH_ASAN, - "Times out with ASAN, see See pytorch/pytorch/issues/110416", + "Times out with ASAN, See pytorch/pytorch/issues/110416", ) def test_retracibility(self): class MyLinear(torch.nn.Module): @@ -3895,12 +3972,14 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch): @config.patch(suppress_errors=True) @config.patch(verbose=True) def test_export_with_map_zero_sized_tensor_suppress_errors(self): + from functorch.experimental.control_flow import map + class Module(torch.nn.Module): def forward(self, xs): def body(x): return x + 1 - return mp(body, xs) + return map(body, xs) mod = Module() xs = torch.randn(0, 2) @@ -4288,6 +4367,90 @@ def forward(self, x): self.assertEqual(out.requires_grad, False) out.requires_grad = True + def fn(x): + with torch.inference_mode(): + return x + 1 + + gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2)) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x): + arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + l_x_ = arg0 + _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) + add = l_x_ + 1; l_x_ = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", + ) + inp = torch.randn(2, 2, requires_grad=True) + out = gm(inp) + self.assertEqual(out.requires_grad, False) + + def test_export_masking_with_no_grad(self): + def fn(x, b, y): + x = x.clone() + x[b] = y + return x + + def fn_no_grad(x, b, y): + with torch.no_grad(): + return fn(x, b, y) + + def fn_inference_mode(x, b, y): + with torch.inference_mode(): + return fn(x, b, y) + + x = torch.randn(4, requires_grad=True) + b = torch.tensor([True, False, True, False]) + y = torch.randn(2, requires_grad=True) + + gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x, b, y): + arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec) + l_x_ = arg0 + l_b_ = arg1 + l_y_ = arg2 + _set_grad_enabled = torch._C._set_grad_enabled(False) + x = l_x_.clone(); l_x_ = None + x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = None + _set_grad_enabled_1 = torch._C._set_grad_enabled(True) + return pytree.tree_unflatten([x], self._out_spec)""", + ) + + gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x, b, y): + arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec) + l_x_ = arg0 + l_b_ = arg1 + l_y_ = arg2 + _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) + x = l_x_.clone(); l_x_ = None + x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None + return pytree.tree_unflatten([x], self._out_spec)""", + ) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "boolean masking setitem backwards" + ): + gm, _ = 
torch._dynamo.export(fn)(x, b, y) + + def test_dynamo_list_index(self): + def fn(x, in_list): + return x + in_list.index(2) + + inputs = (torch.ones(2, 2), [1, 2]) + graph, _ = torch._dynamo.export(fn)(*inputs) + out = graph(*inputs) + self.assertEqual(out, torch.ones(2, 2) + 1) + common_utils.instantiate_parametrized_tests(ExportTests) diff --git a/test/dynamo/test_export_mutations.py b/test/dynamo/test_export_mutations.py index 683d02f0667380b291cedb33393c329916160f14..052f7e605533bc5022d503f56ae49f2ab2054f14 100644 --- a/test/dynamo/test_export_mutations.py +++ b/test/dynamo/test_export_mutations.py @@ -3,7 +3,6 @@ import unittest import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing from torch.testing._internal.common_utils import IS_FBCODE @@ -18,20 +17,20 @@ class MutationExportTests(torch._dynamo.test_case.TestCase): real_result = mod(arg) graph, _ = torch._dynamo.export(mod)(arg) result = graph(arg) - self.assertTrue(torch._dynamo.utils.same(result, real_result)) + self.assertEqual(result, real_result) def test_module_attribute_mutation_violation_positive_1(self): # Mutating attribute with a Tensor type class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.Tensor(3, 2) + self.a = torch.randn(3, 2) def forward(self, x): self.a = self.a.to(torch.float64) return x.sum() + self.a.sum() - self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + self.check_failure_on_export(Foo(), torch.randn(3, 2)) def test_module_attribute_mutation_violation_positive_2(self): # Mutating attribute with a scalar type @@ -44,20 +43,20 @@ class MutationExportTests(torch._dynamo.test_case.TestCase): self.a = self.a * 3 return x.sum() + self.a - self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + self.check_failure_on_export(Foo(), torch.randn(3, 2)) def test_module_attribute_mutation_violation_positive_3(self): # Setting a new attribute inside forward() class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.Tensor(3, 2) + self.a = torch.randn(3, 2) def forward(self, x): self.b = 2 return x.sum() + self.a.sum() + self.b - self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + self.check_failure_on_export(Foo(), torch.randn(3, 2)) def test_module_attribute_mutation_violation_positive_4(self): # Mutating attribute with an inline function @@ -69,7 +68,7 @@ class MutationExportTests(torch._dynamo.test_case.TestCase): self.a = self.add(1, 2) * self.add(3, 4) return x.sum() + self.a - self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + self.check_failure_on_export(Foo(), torch.randn(3, 2)) def test_module_attribute_mutation_violation_negative_1(self): # Mutating attribute with a Tensor type inside __init__ but @@ -77,39 +76,39 @@ class MutationExportTests(torch._dynamo.test_case.TestCase): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.Tensor(3, 2) + self.a = torch.randn(3, 2) def forward(self, x): return x.sum() + self.a.to(torch.float64).sum() - self.check_same_with_export(Foo(), torch.Tensor(3, 2)) + self.check_same_with_export(Foo(), torch.randn(3, 2)) def test_module_attribute_mutation_violation_negative_2(self): # Mutating attribute with a Tensor type inside __init__ twice class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.Tensor(3, 2) + self.a = torch.randn(3, 2) self.a = self.a.to(torch.float64) def forward(self, x): return x.sum() + self.a.sum() - self.check_same_with_export(Foo(), torch.Tensor(3, 2)) + 
self.check_same_with_export(Foo(), torch.randn(3, 2)) def test_module_attribute_mutation_violation_negative_3(self): # Mutating local variable inside forward() class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.Tensor(3, 2) + self.a = torch.randn(3, 2) def forward(self, x): b = 1 b = b * 5 return x.sum() + self.a.sum() + b - self.check_same_with_export(Foo(), torch.Tensor(3, 2)) + self.check_same_with_export(Foo(), torch.randn(3, 2)) @unittest.skipIf(IS_FBCODE, "Broken in fbcode") def test_module_attribute_mutation_violation_negative_4(self): @@ -118,17 +117,17 @@ class MutationExportTests(torch._dynamo.test_case.TestCase): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.a = torch.Tensor(3, 2) + self.a = torch.randn(3, 2) def forward(self, x): self.a = self.a.to(torch.float64) return x.sum() + self.a.sum() mod = Foo() - arg = torch.Tensor(3, 2) + arg = torch.randn(3, 2) real_result = mod(arg) opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) - self.assertTrue(torch._dynamo.utils.same(opt_mod(arg), real_result)) + self.assertEqual(opt_mod(arg), real_result) if __name__ == "__main__": diff --git a/test/dynamo/test_frame_init.py b/test/dynamo/test_frame_init.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8ae7fef3a3c7f6008607b778ce77d1e10779a0 --- /dev/null +++ b/test/dynamo/test_frame_init.py @@ -0,0 +1,128 @@ +# Owner(s): ["module: dynamo"] + +import torch +import torch_npu +import torch._dynamo.test_case + + +def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): + local = 1 + return { + "local": local, + "arg1": arg1, + "positional_only_arg": positional_only_arg, + "keyword_only_arg": keyword_only_arg, + "kwargs": kwargs, + } + + +def varkwargs_code1(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): + # remove a local variable: local = 1 + return { + "local": 1, + "arg1": arg1, + "positional_only_arg": positional_only_arg, + "keyword_only_arg": keyword_only_arg, + "kwargs": kwargs, + } + + +def varkwargs_code2(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): + # introduce a local variable + local1 = 0 + local2 = 1 + return { + "local": local1 + local2, + "arg1": arg1, + "positional_only_arg": positional_only_arg, + "keyword_only_arg": keyword_only_arg, + "kwargs": kwargs, + } + + +def target_with_varargs(arg1, /, positional_only_arg, *varargs, **kwargs): + local = 1 + return { + "local": local, + "arg1": arg1, + "positional_only_arg": positional_only_arg, + "varargs": varargs, + "kwargs": kwargs, + } + + +def varargs_code1(arg1, /, positional_only_arg, *varargs, **kwargs): + # remove a local variable: local = 1 + return { + "local": 1, + "arg1": arg1, + "positional_only_arg": positional_only_arg, + "varargs": varargs, + "kwargs": kwargs, + } + + +def varargs_code2(arg1, /, positional_only_arg, *varargs, **kwargs): + # introduce a local variable + local1 = 0 + local2 = 1 + return { + "local": local1 + local2, + "arg1": arg1, + "positional_only_arg": positional_only_arg, + "varargs": varargs, + "kwargs": kwargs, + } + + +class FrameInitTests(torch._dynamo.test_case.TestCase): + def test_frame_init(self): + code_map1 = { + target_with_varargs.__code__: varargs_code1.__code__, + target_with_varkwargs.__code__: varkwargs_code1.__code__, + } + code_map2 = { + target_with_varargs.__code__: varargs_code2.__code__, + target_with_varkwargs.__code__: varkwargs_code2.__code__, + } + + def callback1(frame, cache_entry, frame_state): + if 
frame.f_code in code_map1: + transformed_code = code_map1[frame.f_code] + return torch._dynamo.types.GuardedCode( + transformed_code, lambda f_locals: True + ) + return None + + def callback2(frame, cache_entry, frame_state): + if frame.f_code in code_map2: + transformed_code = code_map2[frame.f_code] + return torch._dynamo.types.GuardedCode( + transformed_code, lambda f_locals: True + ) + return None + + for callback in [callback1, callback2]: + torch._dynamo.reset() + expected_varargs_output = target_with_varargs( + 1, 2, 3, 4, name1=1, name2=2, name3=3 + ) + expected_kwargs_output = target_with_varkwargs( + 1, 2, keyword_only_arg=1, name2=2, name3=3 + ) + original = torch._dynamo.eval_frame.set_eval_frame(callback) + real_varargs_output = target_with_varargs( + 1, 2, 3, 4, name1=1, name2=2, name3=3 + ) + real_kwargs_output = target_with_varkwargs( + 1, 2, keyword_only_arg=1, name2=2, name3=3 + ) + self.assertEqual(real_varargs_output, expected_varargs_output) + self.assertEqual(real_kwargs_output, expected_kwargs_output) + torch._dynamo.eval_frame.set_eval_frame(original) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index be22e333e6a80a1051d0bab69fabc02325c9d4bf..9544a82a8739bdd3f364d0aab3ec39eede28539f 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1,11 +1,12 @@ # Owner(s): ["module: dynamo"] -# flake8: noqa +# flake8: noqa: E731, C405, F811, C418, C417 import collections import functools import inspect import itertools import math import operator +import random import sys import unittest from dataclasses import dataclass, field @@ -16,35 +17,28 @@ import numpy as np import torch import torch_npu + import torch._dynamo.test_case import torch._dynamo.testing from torch import sub -from torch._dynamo.testing import expectedFailureDynamic +from torch._dynamo.testing import ( + CompileCounterWithBackend, + EagerAndRecordGraphs, + expectedFailureDynamic, + normalize_gm, +) from torch._dynamo.utils import ifdynstaticdefault, same -from torch._higher_order_ops.triton_kernel_wrap import ( - triton_kernel_wrapper_functional, - triton_kernel_wrapper_mutation, -) -from torch._inductor import metrics from torch.nn import functional as F -from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( disable_translation_validation_if_dynamic_shapes, - skipIfRocm, + instantiate_parametrized_tests, + parametrize, ) # Defines all the kernels for tests from torch.testing._internal.triton_utils import * # noqa: F403 -if HAS_CUDA: - import triton - from triton import language as tl - -requires_cuda = functools.partial( - unittest.skipIf, not torch.cuda.is_available(), "requires cuda" -) - d = torch.ones(10, 10) e = torch.nn.Linear(10, 10) flag = True @@ -66,11 +60,19 @@ def func_with_default(a, b, some_default_arg=True): return a - b -def make_test(fn): +def make_test(fn=None, expected_frame_count=1): + if fn is None: + return lambda fn: make_test(fn, expected_frame_count=expected_frame_count) + nargs = len(inspect.signature(fn).parameters) def test_fn(self): - return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs) + return torch._dynamo.testing.standard_test( + self, + fn=fn, + nargs=nargs, + expected_frame_count=expected_frame_count, + ) return test_fn @@ -90,6 +92,16 @@ def inline_unused(x): return x + 5.6 +@functools.lru_cache +def inline_lru_cache_fn_with_default_args(x, y, _=None): + 
return torch.sin(x * y) + + +@torch.jit.script_if_tracing +def inline_script_if_tracing_fn_with_default_args(x, y, _=None): + return torch.cos(x * y) + + class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_inline_jit_annotations(x): @@ -98,6 +110,14 @@ class FunctionTests(torch._dynamo.test_case.TestCase): x = inline_unused(x) return + @make_test + def test_inline_script_if_tracing_fn_with_default_args(a, b): + return inline_script_if_tracing_fn_with_default_args(a, 2, b) + + @make_test + def test_inline_lru_cache_fn_with_default_args(a, b): + return inline_lru_cache_fn_with_default_args(a, 2, b) + @make_test def test_add(a, b): return a + b @@ -141,6 +161,13 @@ class FunctionTests(torch._dynamo.test_case.TestCase): v = v + x return v + @make_test + def test_itertools_chain_from_iterable(a, b): + v = a + for x in itertools.chain.from_iterable([[a, b], [1, 2]]): + v = v + x + return v + @make_test def test_itertools_combinations(a, b): combs = [] @@ -160,14 +187,14 @@ class FunctionTests(torch._dynamo.test_case.TestCase): def test_constant3(a): b = 1 c = 2 - f = 3 - return b + c - f + a + d = 3 + return b + c - d + a @make_test def test_constant4(a, b): c = 2 - f = 3 - if c > f: + d = 3 + if c > d: return a - b return b - a @@ -271,26 +298,26 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_deque(a, b): - f = collections.deque([a, b]) - f.append(a + 1) - f.extend([a, b]) - f.insert(0, "foo") - tmp = f.pop() + d = collections.deque([a, b]) + d.append(a + 1) + d.extend([a, b]) + d.insert(0, "foo") + tmp = d.pop() another_deque = collections.deque([tmp]) - f.extendleft(another_deque) + d.extendleft(another_deque) another_deque.clear() - f.extend(another_deque) + d.extend(another_deque) - f[2] = "setitem" - f = f.copy() - f.append(f.popleft()) + d[2] = "setitem" + d = d.copy() + d.append(d.popleft()) empty = collections.deque() - f.extend(empty) + d.extend(empty) # dynamo same() util doesn't support deque so just return a list - return list(f) + return list(d) @make_test def test_slice1(a): @@ -353,8 +380,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_return_tuple1(a, b): - res = (a - b, b - a, a, b) - return res + return (a - b, b - a, a, b) @make_test def test_globalvar(a, b): @@ -412,6 +438,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase): z = dict({"foo": x + 1}) return z + @make_test + def test_dict_keys(x): + d = {3: x} + keys = d.keys() + d[4] = x + 1 + d2 = {3: 2, 4: "aa"} + return 3 in keys, 4 in keys, 5 in keys, d2.keys() == keys + + @make_test + def test_dict_values(x): + d = {3: x} + values = d.values() + d[3] = x + 1 + d[4] = x + 2 + return len(values) + @make_test def test_callable_lambda(x): if callable(lambda x: True): @@ -433,6 +475,40 @@ class FunctionTests(torch._dynamo.test_case.TestCase): else: return x - 1 + def test_callable_class(self): + class CallableClass: + def __call__(): + pass + + class NotCallableClass: + pass + + @torch.compile(backend="eager", fullgraph=True) + def fn1(x, arg): + if callable(arg): + return x + return x + 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn2(x, arg): + if callable(arg): + return x * 2 + return x + 1 + + ipt = torch.randn(4) + + for f in [fn1, fn2]: + self.assertEqual(f(ipt, NotCallableClass()), ipt + 1) + self.assertEqual( + f(ipt, CallableClass()), ipt if f is fn1 else ipt * 2 + ) + + # passing tensor and scalars + self.assertEqual(f(ipt, 1), ipt + 1) + self.assertEqual(f(ipt, 1.1), ipt + 1) + self.assertEqual(f(ipt, True), 
ipt + 1) + self.assertEqual(f(ipt, ipt), ipt + 1) + @make_test def test_len_constant_misc_iterables(x): a = len((1, 2, 3)) @@ -457,7 +533,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_float(x): - y = float(1.2) + y = float(1.2) # noqa: UP018 y += float("1.2") return torch.add(x, y) @@ -490,6 +566,13 @@ class FunctionTests(torch._dynamo.test_case.TestCase): else: return x - 1 + @make_test + def test_cublas_allow_tf32(x): + if torch.backends.npu.matmul.allow_tf32: + return x.sin() + 1 + + return x.cos() - 1 + @make_test def test_get_calculate_correct_fan(x): fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in") @@ -514,6 +597,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase): if not x.is_npu: return x + 1 + @unittest.skipIf(not torch.npu.is_available(), "requires npu") + @make_test + def test_get_device_properties_tensor_device(a): + x = a.to("npu:0") + prop = torch.npu.get_device_properties(x.device) + if prop.major == 8: + return x + prop.multi_processor_count + return x + prop.max_threads_per_multi_processor + @make_test def test_tensor_type(a, b): m = a.to(torch.float16) @@ -522,7 +614,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @unittest.skipIf(not torch.npu.is_available(), "requires npu") @make_test def test_tensor_type2(a, b): - m = a.to("npu") + m = a.to("npu:0") return m + b.type(m.type()) @make_test @@ -541,6 +633,12 @@ class FunctionTests(torch._dynamo.test_case.TestCase): m = a.type(torch.npu.HalfTensor) return b.type(m.type()) + @make_test + def test_tensor_element_size(a): + if a.element_size() > 1: + return (a + a.element_size(), a - a.element_size()) + return (a - a.element_size(), a + a.element_size()) + @make_test def test_ndim(x): if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2: @@ -572,9 +670,9 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_del(a, b): c = a + 1 - f = c + 2 + d = c + 2 del c, a - return b + f + return b + d @make_test def test_chunks1(x): @@ -585,6 +683,9 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_import1(x, y): + import torch + from torch import sub + return sub(torch.add(x, y), y) @make_test @@ -693,10 +794,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase): assert tmp.get("zzz") is None v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4) tmp.update({"d": 3}) - tmp["c"] = v + tmp.get('d') + tmp["c"] = v + tmp["d"] if "c" in tmp and "missing" not in tmp: return tmp["c"] - tmp["a"] + len(tmp) + @make_test + def test_inline_jit__unwrap_optional(x): + if torch.jit._unwrap_optional(x) is None: + return torch.ones(2, 2) + return x.sin() + def test_dict_param_keys(self): a_param = torch.nn.Parameter(torch.ones([4, 4])) @@ -765,6 +872,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(ref[1]["e"], res[1]["e"])) self.assertTrue(same(ref[1][param], res[1][param])) + def test_dict_tuple_lazy_guard(self): + @torch.compile(backend="eager") + def fn(x, y): + return torch.sin(x) * y[1] + + fn(torch.randn(3), {1: 1, 2: 2}) + # Changing the value of other key should not causing recompilation + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + fn(torch.randn(3), {1: 1, 2: 3}) + + fn(torch.randn(3), (1, 2, 3)) + # Changing the value of index 0, 2 (not 1) should not cause recompilation + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + fn(torch.randn(3), (11, 2, 13)) + @make_test def test_call_dict1(x): d1 = dict() 
@@ -813,11 +935,11 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_dict_fromkeys(x, y): lst = ["a", "b"] - dd = dict.fromkeys(lst) - d1 = dict.fromkeys(dd, x + 1) + d = dict.fromkeys(lst) + d1 = dict.fromkeys(d, x + 1) d2 = collections.defaultdict.fromkeys(iter(d1), x - 2) d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y) - return d1.get('a') * d2.get('b') + d2.get('a') + d1.get('b') + d3.get('a')+ d3.get('b') + 1 + return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1 @make_test def test_dict_copy(x): @@ -831,20 +953,20 @@ class FunctionTests(torch._dynamo.test_case.TestCase): d3["c"] = x + 20 d4 = d3.copy() d4["c"] = x - 10 - return d1.get('a') * d2.get('b') + d2.get('a') + d3.get('c') + d4.get('c') + 1 + return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1 @make_test def test_dict_update(x, y, z): - dd = {"a": x, "b": y} - dd.update({"a": y - 1}) - dd.update([("b", z + 1), ["c", z]]) - dd.update(zip("ab", [z + 3, y + 2])) + d = {"a": x, "b": y} + d.update({"a": y - 1}) + d.update([("b", z + 1), ["c", z]]) + d.update(zip("ab", [z + 3, y + 2])) od = collections.OrderedDict(a=x * 3, b=y + 2) od.update({"a": y + 5}) od.update([["b", z + 6], ("c", z - 7)]) od.update(zip("ab", [z - 3, x + 2])) - return dd.get("a") * od.get("a") + od.get("c") + dd.get("b") + od.get("b") * dd.get("c") + return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"] @make_test def test_min_max(a, b): @@ -864,12 +986,52 @@ class FunctionTests(torch._dynamo.test_case.TestCase): return x - 1 @make_test - def test_map_sum(a, b, c, f): - return sum(map(lambda x: x + 1, [a, b, c, f])) + def test_map_sum(a, b, c, d): + return sum(map(lambda x: x + 1, [a, b, c, d])) + + @make_test + def test_sum(a, b, c, d): + return sum([a, b, c, d]) + + @make_test + def test_sum_with_start_arg(a, b, c, d): + return sum([b, c, d], a) @make_test - def test_reduce(a, b, c, f): - return functools.reduce(operator.add, [a, b, c, f]) + def test_sum_with_start_kwarg(a, b, c, d): + return sum([b, c, d], start=a) + + @make_test(expected_frame_count=0) + def test_sum_shortcut(): + return sum([0, 1.0, 2, 3.0]) + + @make_test(expected_frame_count=0) + def test_sum_shortcut_with_start_arg(): + return sum([0, 1.0, 2, 3.0], -10) + + @make_test(expected_frame_count=0) + def test_sum_shortcut_with_start_kwarg(): + return sum([0, 1.0, 2, 3.0], start=-10) + + @make_test + def test_reduce(a, b, c, d): + return functools.reduce(operator.add, [a, b, c, d]) + + @make_test + def test_reduce_with_initial(a, b, c, d): + return functools.reduce(operator.add, [b, c, d], a) + + @make_test(expected_frame_count=0) + def test_reduce_with_single(x): + return functools.reduce(lambda a, b: (a, b), [x]) + + @make_test(expected_frame_count=0) + def test_reduce_with_single_with_initial(x, y): + return functools.reduce(lambda a, b: (a, b), [y], x) + + @make_test(expected_frame_count=0) + def test_reduce_with_none_initial(x): + return functools.reduce(lambda a, b: (a, b), [x], None) @make_test def test_tuple_contains(a, b): @@ -910,14 +1072,14 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_unpack_ex2(x): output = (x, x + 1, x + 2, x + 3) - *ab, c, dd = output - return c - dd / ab[0] + *ab, c, d = output + return c - d / ab[0] @make_test def test_unpack_ex3(x): output = (x, x + 1, x + 2, x + 3) - a, *bc, dd = output - return a - dd / bc[0] + a, *bc, d = output + return a - d / bc[0] @make_test def test_const_tuple_add1(x): @@ -956,13 +1118,12 @@ class 
FunctionTests(torch._dynamo.test_case.TestCase): ("jane", "B", 5), ("dave", "B", 10), ] - res = ( + return ( x + 1, sorted(y), sorted(y, key=lambda student: student[2]), sorted(y, key=lambda student: student[2], reverse=True), ) - return res @make_test def test_tuple_sorted(x): @@ -1214,7 +1375,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_numpy_attributes(x): a = x.numpy() - res = ( + return ( a.itemsize, a.strides, a.shape, @@ -1224,7 +1385,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase): torch.from_numpy(a.real), torch.from_numpy(a.imag), ) - return res @make_test def test_mean_sum_np(x: torch.Tensor): @@ -1321,6 +1481,98 @@ class FunctionTests(torch._dynamo.test_case.TestCase): triple = functools.partial(multiply, y=3) return triple(x) + @unittest.skipUnless(torch.distributed.is_available(), "requires torch.distributed") + @make_test + def test_flat_param_same_storage_size(x, y): + import torch.distributed.fsdp._flat_param as flat_param + + if flat_param._same_storage_size(x, 100): + x = x + 1 + else: + x = x - 1 + if flat_param._same_storage_size(y, 123): + y = y + 1 + else: + y = y - 1 + return x, y + + @parametrize( + "attr", + ( + # True + "__subclasshook__", + "__lt__", + "__hash__", + "__ge__", + "__le__", + "__gt__", + "__dict__", + "__getattribute__", + "__setattr__", + "__doc__", + "__repr__", + "__dir__", + "__init__", + "__new__", + "__class__", + "__eq__", + "__delattr__", + "__reduce__", + "__module__", + "__format__", + "__str__", + "__sizeof__", + "__ne__", + "__call__", + "__reduce_ex__", + "__init_subclass__", + "args", + "keywords", + "func", + # False + "__code__", + "__kwdefaults__", + "__defaults__", + "__name__", + "__annotations__", + "__get__", + "__builtins__", + "__qualname__", + "__globals__", + "__closure__", + ), + ) + def test_partials_hasattr(self, attr): + def fn(t): + f = lambda x, y: torch.sin(x) + torch.cos(y) + p = functools.partial(f, y=t) + if hasattr(p, attr): + return p(t) + else: + return torch.zeros_like(t) + + t = torch.randn(3, 4) + counter = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(fullgraph=True, backend=counter)(fn) + self.assertEqual(opt_fn(t), fn(t)) + self.assertGreater(counter.frame_count, 0) + + @unittest.expectedFailure + def test_partials_hasattr_set_attr(self): + def fn(t): + f = lambda x, y: torch.sin(x) + torch.cos(y) + p = functools.partial(f, y=t) + p.__name__ = "test" + if hasattr(p, "__name__"): + return p(t) + else: + return torch.zeros_like(t) + + t = torch.randn(3, 4) + counter = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(fullgraph=True, backend=counter)(fn) + self.assertEqual(opt_fn(t), fn(t)) + def test_pow_int(self): def fn(a, b): return torch.pow(a, b) @@ -1387,6 +1639,201 @@ class FunctionTests(torch._dynamo.test_case.TestCase): eager_result = fn(lambda0, lambda1, x) self.assertEqual(eager_result, dynamo_result) + def test_partials_graph_break_reconstruct(self): + def fn(udf_mul_0, udf_mul_1, x): + lambda0 = functools.partial(udf_mul_0, y=x) + lambda1 = functools.partial(udf_mul_1, y=x) + + print("break") + return torch.mul(lambda0(x), lambda1(x)) + + backend = EagerAndRecordGraphs() + cnts = CompileCounterWithBackend(backend) + x = torch.randn(2, 2) + dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x) + + eager_result = fn(udf_mul, udf_mul, x) + gm = backend.graphs[0] + self.assertEqual(eager_result, dynamo_result) + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + 
normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_lambda0_keywords_y_ : torch.Tensor): + l_lambda0_keywords_y_ = L_lambda0_keywords_y_ + + mul = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + mul_1 = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + + mul_2 = torch.mul(mul, mul_1); mul = mul_1 = None + return (mul_2,) +""", + ) + else: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, s0 : torch.SymInt, L_lambda0_keywords_y_ : torch.Tensor): + l_lambda0_keywords_y_ = L_lambda0_keywords_y_ + + mul = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + mul_1 = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + + mul_2 = torch.mul(mul, mul_1); mul = mul_1 = None + return (mul_2,) +""", + ) + + def test_partials_graph_break_reconstruct_mix(self): + def fn(udf_mul_0, udf_add_1, x): + lambda0 = functools.partial(udf_mul_0, y=x) + lambda1 = functools.partial(udf_add_1, x) + + print("break") + return torch.mul(lambda0(x), lambda1(x)) + + backend = EagerAndRecordGraphs() + cnts = CompileCounterWithBackend(backend) + x = torch.randn(2, 2) + dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x) + + eager_result = fn(udf_mul, udf_add, x) + gm = backend.graphs[0] + self.assertEqual(eager_result, dynamo_result) + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_lambda0_keywords_y_ : torch.Tensor): + l_lambda0_keywords_y_ = L_lambda0_keywords_y_ + + mul = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + + add = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + + mul_1 = torch.mul(mul, add); mul = add = None + return (mul_1,) +""", + ) + else: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, s0 : torch.SymInt, L_lambda0_keywords_y_ : torch.Tensor): + l_lambda0_keywords_y_ = L_lambda0_keywords_y_ + + mul = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + + add = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + + mul_1 = torch.mul(mul, add); mul = add = None + return (mul_1,) +""", + ) + + def test_partials_graph_break_reconstruct_mix_no_source(self): + def fn(udf_mul_0, x): + udf_add_1 = lambda x, y: x + y + + lambda0 = functools.partial(udf_mul_0, y=x) + lambda1 = functools.partial(udf_add_1, x) + + print("break") + return torch.mul(lambda0(x), lambda1(x)) + + backend = EagerAndRecordGraphs() + cnts = CompileCounterWithBackend(backend) + x = torch.randn(2, 2) + dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x) + + eager_result = fn(udf_mul, x) + gm = backend.graphs[0] + self.assertEqual(eager_result, dynamo_result) + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_lambda0_keywords_y_ : torch.Tensor): + l_lambda0_keywords_y_ = L_lambda0_keywords_y_ + + mul = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + + add = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + + mul_1 = torch.mul(mul, add); mul = add = None + 
return (mul_1,) +""", + ) + else: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, s0 : torch.SymInt, L_lambda0_keywords_y_ : torch.Tensor): + l_lambda0_keywords_y_ = L_lambda0_keywords_y_ + + mul = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + + add = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + + mul_1 = torch.mul(mul, add); mul = add = None + return (mul_1,) +""", + ) + + def test_partials_graph_break_reconstruct_args_and_kwargs(self): + def fn(udf_mul_0, x): + lambda0 = functools.partial(udf_mul_0, x, 4, z=x) + lambda1 = functools.partial(udf_mul_0, 4, z=x) + + return torch.mul(lambda0(), lambda1(5)) + + backend = EagerAndRecordGraphs() + cnts = CompileCounterWithBackend(backend) + x = torch.randn(2, 2) + dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x) + + eager_result = fn(udf_mul2, x) + gm = backend.graphs[0] + self.assertEqual(eager_result, dynamo_result) + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + l_x_ = L_x_ + + mul = l_x_ * 4 + mul_1 = mul * l_x_; mul = None + mul_2 = 20 * l_x_; l_x_ = None + + mul_3 = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None + return (mul_3,) +""", + ) + else: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor): + l_x_ = L_x_ + + mul = l_x_ * 4 + mul_1 = mul * l_x_; mul = None + mul_2 = 20 * l_x_; l_x_ = None + + mul_3 = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None + return (mul_3,) +""", + ) + def test_partials_recompilation(self): def fn(f0, f1, x): return f0(x) * f1(x) @@ -1395,6 +1842,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2)) cnts = torch._dynamo.testing.CompileCounter() + x = torch.randn(2, 2) fn = torch._dynamo.optimize(cnts, nopython=True)(fn) dynamo_result = fn(lambda0, lambda1, x) @@ -1442,7 +1890,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): fn2(lambda6, lambda6, [x]) self.assertEqual( cnts.frame_count, 4 - ) # Recompile! input is no longer a functools partial + ) # Recompile! 
ipt is no longer a functools partial def test_manual_seed(self): @torch.compile @@ -1542,11 +1990,170 @@ class FunctionTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(program(input1, input2), input1 + input1)) + def test_compare_constant_and_tensor(self): + for op in [ + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.ne, + operator.eq, + operator.is_, + operator.is_not, + ]: + with self.subTest(op=op): + + def fn(x): + return op(-10, x) + + opt_fn = torch.compile(fullgraph=True)(fn) + + x = torch.randn(10) + self.assertEqual(opt_fn(x), fn(x)) + + def test_pos(self): + def fn(x, y): + return operator.pos(x) * +y + + opt_fn = torch.compile(fullgraph=True, dynamic=True)(fn) + + def test(x, y): + self.assertEqual(opt_fn(x, y), fn(x, y)) + + test(torch.ones(4), 1) + test(1, torch.ones(4)) + test(-1, -1) + test(-1.1, 1.1) + test(True, False) + test(torch.ones(4, dtype=torch.float32), 1.1) + + def test_truth(self): + def fn(x, y): + return operator.truth(x) and bool(y) + + opt_fn = torch.compile(fullgraph=True, dynamic=False)(fn) + + def test(x, y): + self.assertEqual(opt_fn(x, y), fn(x, y)) + + test(1, 100) + test(-1.1, True) + test(-1.1, 1.1) + test(True, False) + test(torch.ones(1), 1) + test(torch.zeros(1), 1) + test(torch.ones(1), torch.ones(1)) + + def test_unary_fold_op(self): + for op in (operator.abs, abs, operator.neg, operator.pos, operator.truth): + with self.subTest(op=op): + + def fn(): + a = range(-10, 10) + return list(map(op, a)) + + opt_fn = torch._dynamo.optimize(nopython=True)(fn) + self.assertEqual(opt_fn(), fn()) + + def test_unary_fold_op_seq(self): + for op in (operator.length_hint,): + with self.subTest(op=op): + + def fn(): + a = [tuple(range(-10, i)) for i in range(10)] + return tuple(map(op, a)) + + opt_fn = torch._dynamo.optimize(nopython=True)(fn) + self.assertEqual(opt_fn(), fn()) + + def test_rand_inlined(self): + @torch.compile(backend="eager", dynamic=True) + def fn(): + idx_size = [10] + idx_size[random.randint(0, 0)] = random.randint(1, 8) + t = tuple(idx_size) + src_size = [random.randint(1, 5) + s for s in idx_size] + idx = torch.empty(t) + + fn() + + def test_rand_tensor_partial(self): + from collections import namedtuple + from functools import partial + + SdpaShape = namedtuple( + "Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"] + ) + + @torch.compile(backend="eager") + def func(): + make_tensor = partial( + torch.rand, device="cpu", dtype=torch.float16, requires_grad=True + ) + + bsz, num_heads, seq_len_q, seq_len_kv, head_dim = (16, 16, 128, 128, 16) + make_q_tensor = partial( + make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim) + ) + make_kv_tensor = partial( + make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim) + ) + t1 = make_q_tensor() + t2 = make_kv_tensor() + t3 = t1 + t2 + + func() + + def test_to(self): + @torch.compile(backend="eager") + def fn(): + t = torch.ones(2) + y = t.to("meta") + + fn() + + def test_elipsis(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(a, ind, val): + a[ind] = val + return a + + arr = np.zeros(4) + self.assertEqual(fn(arr, np.s_[...], np.ones(4)), np.ones(4)) + + arr = np.array([[1, 1], [2, 2]]) + self.assertEqual( + fn(arr, np.s_[0, ...], np.zeros(2)), np.array([[0, 0], [2, 2]]) + ) + + arr = np.array([[1, 1], [2, 2]]) + self.assertEqual( + fn(arr, np.s_[1, ...], np.zeros(2)), np.array([[1, 1], [0, 0]]) + ) + + arr = np.array([[1, 1], [2, 2]]) + self.assertEqual( + fn(arr, np.s_[..., 0], np.array([3, 3])), np.array([[3, 1], [3, 2]]) + ) 
+ + arr = np.array([[1, 1], [2, 2]]) + self.assertEqual( + fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) + ) + def udf_mul(x, y): return x * y +def udf_mul2(x, y, z): + return x * y * z + + +def udf_add(x, y): + return x + y + + class SmallNN(torch.nn.Module): def forward(self, x, y): combined = torch.cat((x, y), dim=1) @@ -1583,12 +2190,6 @@ class WrapperModule(torch.nn.Module): return self.m() -# Define shared triton constants here. -CONSTANT_C = 4 -STRING_CONSTANT_C = "CONSTANT_C" -BOOL_CONSTANT_C = True - - class DefaultsTests(torch._dynamo.test_case.TestCase): def test_func_default_tensor_args(self): """ @@ -1696,704 +2297,6 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 1) - @requires_cuda() - def test_triton_kernel_with_kernel_param(self): - @triton.jit - def pass_kernel(kernel): - pass - - @torch.compile(backend="eager") - def f(x): - grid = (x.numel(),) - pass_kernel[grid](kernel=x) - - t1 = torch.rand(5, device="cuda") - f(t1) - # No need to assert anything, the goal is to make sure dynamo does - # not crash - - @requires_cuda() - def test_triton_kernel_higher_order_func(self): - from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table - - add_kernel_id = kernel_side_table.add_kernel(add_kernel) - - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") - - torch_add = t1 + t2 - - # Test higher order function with mutation - output = torch.zeros_like(t1) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - triton_kernel_wrapper_mutation( - kernel_idx=add_kernel_id, - grid=[grid], - kwargs={ - "in_ptr0": t1, - "in_ptr1": t2, - "out_ptr": output, - "n_elements": n_elements, - "BLOCK_SIZE": 16, - }, - ) - self.assertEqual(output, torch_add) - # Make sure it is modified - self.assertNotEqual(output, torch.zeros_like(t1)) - - # Test higher order function without mutation - output = torch.zeros_like(t1) - out_dict = triton_kernel_wrapper_functional( - kernel_idx=add_kernel_id, - grid=[grid], - kwargs={ - "in_ptr0": t1, - "in_ptr1": t2, - "out_ptr": output, - "n_elements": n_elements, - "BLOCK_SIZE": 16, - }, - tensors_to_clone=["in_ptr0", "in_ptr1", "out_ptr"], - ) - self.assertEqual(out_dict["out_ptr"], torch_add) - # Make sure it is NOT modified - self.assertEqual(output, torch.zeros_like(t1)) - - @requires_cuda() - @skipIfRocm - def test_triton_kernel_functionalize(self): - import functorch - from functorch import make_fx - from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table - from torch._subclasses.functional_tensor import ( - CppFunctionalizeAPI, - FunctorchFunctionalizeAPI, - PythonFunctionalizeAPI, - ) - - kernel_side_table.reset_table() - - def f(x, output): - out = triton_kernel_wrapper_functional( - kernel_idx=kernel_side_table.add_kernel(mul2_kernel), - grid=[(x.numel(),)], - kwargs={ - "in_ptr0": x, - "out_ptr": output, - "n_elements": output.numel(), - "BLOCK_SIZE": 16, - }, - tensors_to_clone=["in_ptr0", "out_ptr"], - ) - return out["out_ptr"] - - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") - - gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2) - # Make sure t2 was not modified - self.assertNotEqual(gm(t1, t2), t2) - - gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2) - # Make sure t2 was not modified - self.assertNotEqual(gm(t1, t2), t2) - - gm = make_fx(torch.func.functionalize(f))(t1, t2) - # Make sure t2 was not 
modified - self.assertNotEqual(gm(t1, t2), t2) - - gm = make_fx(f, tracing_mode="fake")(t1, t2) - self.assertExpectedInline( - gm.code.strip(), - """\ -def forward(self, x_1, output_1): - triton_kernel_wrapper_functional_proxy = torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional(kernel_idx = 0, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1, 'n_elements': 5, 'BLOCK_SIZE': 16}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None - getitem = triton_kernel_wrapper_functional_proxy['in_ptr0'] - getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr'] - getitem_2 = triton_kernel_wrapper_functional_proxy['n_elements'] - getitem_3 = triton_kernel_wrapper_functional_proxy['BLOCK_SIZE']; triton_kernel_wrapper_functional_proxy = None - return getitem_1""", - ) - - @requires_cuda() - @skipIfRocm - def test_triton_kernel_mutation_type(self): - from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table - from torch._subclasses.fake_tensor import FakeTensorMode - from torch._subclasses.functional_tensor import ( - FunctionalTensor, - FunctionalTensorMode, - ) - - def prep(): - x = torch.ones(4, device="cuda", requires_grad=True) - x_func = FunctionalTensor.to_functional(x) - self.assertTrue(torch._is_functional_tensor(x_func.elem)) - return x_func - - # normal mutation only - with FakeTensorMode(): - x_func = prep() - - with FunctionalTensorMode(): - x_func.mul_(2) - - self.assertFalse( - torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) - ) - - # triton kernel mutation only - with FakeTensorMode(): - x_func = prep() - - with FunctionalTensorMode(): - triton_kernel_wrapper_mutation( - kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel), - grid=[(x_func.numel(),)], - kwargs={ - "ptr": x_func, - "n_elements": x_func.numel(), - "BLOCK_SIZE": 16, - }, - ) - - self.assertTrue( - torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) - ) - - # normal mutation + triton kernel mutation - with FakeTensorMode(): - x_func = prep() - - with FunctionalTensorMode(): - x_func.mul_(2) - triton_kernel_wrapper_mutation( - kernel_idx=kernel_side_table.add_kernel(mul2_inplace_kernel), - grid=[(x_func.numel(),)], - kwargs={ - "ptr": x_func, - "n_elements": x_func.numel(), - "BLOCK_SIZE": 16, - }, - ) - - self.assertFalse( - torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) - ) - - @requires_cuda() - @common_utils.parametrize("dynamic", [False, True]) - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - def test_triton_kernel_with_views(self, dynamic, backend): - def call_triton_take_view(x: torch.Tensor): - output = torch.zeros_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) - return output - - def call_triton_return_view(x: torch.Tensor): - output = torch.zeros_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) - return output.view(4, 4) - - t = torch.rand(4, 4, device="cuda") - t_view = t.view(16) - - compiled_func = torch.compile( - call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic - ) - self.assertEqual(2 * t_view, compiled_func(t_view)) - self.assertEqual(2 * t, compiled_func(t_view).view(4, 4)) - - compiled_func = torch.compile( - call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic 
- ) - self.assertEqual(2 * t_view, compiled_func(t).view(16)) - self.assertEqual(2 * t, compiled_func(t)) - - @requires_cuda() - @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad]) - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - def test_triton_kernel_with_grad_option(self, grad_fn, backend): - def call_triton(x: torch.Tensor): - with grad_fn(): - output = torch.zeros_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) - return output - - t = torch.rand(5, device="cuda") - compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True) - self.assertEqual(2 * t, compiled_func(t)) - - @requires_cuda() - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - def test_triton_kernel_inner_triton_function(self, backend): - def f(x: torch.Tensor): - @triton.jit - def pow2_kernel( - in_ptr0, - out_ptr, - n_elements, - BLOCK_SIZE: "tl.constexpr", - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(in_ptr0 + offsets, mask=mask) - output = x * x - tl.store(out_ptr + offsets, output, mask=mask) - - output = torch.zeros_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) - return output - - t = torch.rand(5, device="cuda") - - compiled_func = torch.compile(f, backend=backend, fullgraph=True) - # do for later(oulgen): NYI - Support this - # self.assertEqual(t * t, compiled_func(t)) - - @requires_cuda() - @common_utils.parametrize("grad", [False, True]) - @common_utils.parametrize("dynamic", [False, True]) - @patch.object(torch._inductor.config, "implicit_fallbacks", False) - def test_triton_kernel_no_clones(self, grad, dynamic): - from torch._inductor.utils import run_and_get_code - - def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): - n_elements = output.numel() - - tmp = torch.add(x, 1) - grid = (x.numel(),) - add_kernel.run(x, y, output, n_elements, grid=grid, BLOCK_SIZE=16) - - return output, tmp - - t1 = torch.rand(5, device="cuda", requires_grad=grad) - t2 = torch.rand(5, device="cuda", requires_grad=grad) - o1 = torch.zeros_like(t1, requires_grad=grad) - - torch_add = call_triton(t1, t2, o1) - metrics.reset() - o2 = torch.zeros_like(t1, requires_grad=grad) - test, codes = run_and_get_code( - torch.compile(call_triton, dynamic=dynamic), t1, t2, o2 - ) - if not grad: - self.assertEqual(metrics.generated_kernel_count, 1) - self.assertEqual(torch_add, test) - # These two asserts are not optimal since it requires original aten - # to be in the metadata, so there might be false negatives - self.assertTrue("aten.copy" not in codes[0]) - self.assertTrue("aten.clone" not in codes[0]) - # The following checks that there are only the tensor output is in - # the compiled graph - if dynamic and grad: - self.assertTrue("return (buf0, s0, )" in codes[0]) - else: - self.assertTrue("return (buf0, )" in codes[0]) - - @requires_cuda() - @skipIfRocm - def test_triton_kernel_caching(self): - from torch._inductor.utils import run_and_get_code - - def add_in_loop( - x: torch.Tensor, - y: torch.Tensor, - ): - output = torch.zeros_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - add_kernel_autotuned[grid](x, y, output, 
n_elements) - return output - - def call_triton_add( - x: torch.Tensor, - y: torch.Tensor, - ): - for i in range(4): - x = add_in_loop(x, y) - return x - - t1 = torch.ones(5, device="cuda") - t2 = torch.ones(5, device="cuda") - - test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2) - self.assertEqual(test, 5 * torch.ones(5, device="cuda")) - self.assertTrue("add_kernel_autotuned_1.run" not in code) - - @requires_cuda() - @skipIfRocm - def test_triton_kernel_caching_duplicate(self): - from torch._inductor.utils import run_and_get_code - - class C: - @triton.jit - def pass_kernel( - in_ptr0, - out_ptr, - n_elements, - BLOCK_SIZE: "tl.constexpr", - ): - pass - - class D: - @triton.jit - def pass_kernel( - in_ptr0, - out_ptr, - n_elements, - BLOCK_SIZE: "tl.constexpr", - ): - pass - - def call_triton(x: torch.Tensor): - output = torch.zeros_like(x) - n_elements = output.numel() - grid = (n_elements,) - C.pass_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) - D.pass_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) - - t = torch.ones(5, device="cuda") - test, (code,) = run_and_get_code(torch.compile(call_triton), t) - # Make sure we emitted two kernels here - self.assertTrue("pass_kernel_0.run" in code) - self.assertTrue("pass_kernel_1.run" in code) - - @requires_cuda() - @skipIfRocm - def test_triton_kernel_various_args(self): - @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE": 128})], - key=[], - ) - @triton.jit - def pass_kernel( - out_ptr, - n_elements, - dummy_None, - dummy_empty, - dummy_float, - BLOCK_SIZE: "tl.constexpr", - RANDOM_SIZE: "tl.constexpr", - ): - pass - - @torch.compile - def call_triton(output): - n_elements = output.numel() - grid = (n_elements,) - pass_kernel[grid]( - output, - n_elements, - None, - torch.empty_like(output), - 3.1415926, - RANDOM_SIZE=0, - ) - return output - - output = torch.randn(5, device="cuda") - # Make sure this does not crash - call_triton(output) - - @requires_cuda() - @skipIfRocm - def test_triton_kernel_dependancies(self): - def call_triton( - x: torch.Tensor, - y: torch.Tensor, - ): - output = torch.zeros_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - add_kernel_autotuned[grid](x, y, output, n_elements) - output2 = torch.zeros_like(output) - add_kernel_autotuned[grid](output, y, output2, n_elements) - output3 = torch.add(output2, 1) - return output3 - - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") - torch_result = call_triton(t1, t2) - compiled_result = torch.compile(call_triton)(t1, t2) - self.assertEqual(torch_result, compiled_result) - - @requires_cuda() - @common_utils.parametrize("grad", [False, True]) - def test_triton_kernel_multi_kernel(self, grad): - @triton.jit - def mul2_and_add_and_zero_negatives_kernel( - in_ptr0, - in_ptr1, - out_ptr, - n_elements, - BLOCK_SIZE: "tl.constexpr", - ACTIVATION: "tl.constexpr", - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - indirection_kernel( - in_ptr0, - in_ptr0, - n_elements, - BLOCK_SIZE=BLOCK_SIZE, - ACTIVATION="mul2_inplace_kernel", - ) - indirection_kernel( - in_ptr1, - in_ptr1, - n_elements, - BLOCK_SIZE=BLOCK_SIZE, - ACTIVATION="mul2_inplace_kernel", - ) - x = tl.load(in_ptr0 + offsets, mask=mask) - y = tl.load(in_ptr1 + offsets, mask=mask) - output = x + y - if ACTIVATION == "zero_negs": - output = zero_negs(output) - tl.store(out_ptr + offsets, output, mask=mask) - - 
@torch.compile - def call_triton( - x: torch.Tensor, - y: torch.Tensor, - xi: torch.Tensor, - yi: torch.Tensor, - output: torch.Tensor, - outputi: torch.Tensor, - ): - n_elements = output.numel() - - grid = (x.numel(),) - mul2_and_add_and_zero_negatives_kernel[grid]( - x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION="zero_negs" - ) - mul2_and_add_and_zero_negatives_kernel[grid]( - xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None - ) - - return (output, outputi) - - t1 = torch.tensor( - [-2.0, -1.0, 0.0, 1.0, 2.0], device="cuda", requires_grad=grad - ) - t2 = torch.tensor( - [-2.0, -1.0, 0.0, 1.0, 2.0], device="cuda", requires_grad=grad - ) - float_result = 2 * t1 + 2 * t2 - float_result = float_result.where(float_result >= 0, 0.0) - - t1i = torch.randint(-2, 2, (5,), device="cuda") - t2i = torch.randint(-2, 2, (5,), device="cuda") - o_tensor = torch.zeros_like(t1, requires_grad=grad) - oi = torch.zeros_like(t1i) - int_result = 2 * t1i + 2 * t2i - - (result, resulti) = call_triton(t1, t2, t1i, t2i, o_tensor, oi) - self.assertEqual(float_result, result) - self.assertEqual(int_result, resulti) - - @requires_cuda() - def test_triton_kernel_constants(self): - @triton.jit - def mulC_kernel( - in_ptr0, - out_ptr, - n_elements, - BLOCK_SIZE: "tl.constexpr", - CONSTANT_NAME: "tl.constexpr", - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(in_ptr0 + offsets, mask=mask) - if CONSTANT_NAME.value == STRING_CONSTANT_C: - output = CONSTANT_C * x - if BOOL_CONSTANT_C: - output *= CONSTANT_C - tl.store(out_ptr + offsets, output, mask=mask) - - def call_triton( - x: torch.Tensor, - ): - output = torch.zeros_like(x) - n_elements = output.numel() - - grid = (x.numel(),) - mulC_kernel[grid]( - x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C" - ) - return output - - # Triton kernels capture global constants by their parse time value - # not runtime value - global CONSTANT_C - prev_c = CONSTANT_C - # If the behavior of triton kernels change, this test will fail - CONSTANT_C = 10 - assert CONSTANT_C != prev_c - - t = torch.randn(5, device="cuda") - torch_result = call_triton(t) - compiled_result = torch.compile(call_triton)(t) - - self.assertEqual(torch_result, compiled_result) - - # reset back - CONSTANT_C = prev_c - - @requires_cuda() - @skipIfRocm - @common_utils.parametrize("grad", [False, True]) - @common_utils.parametrize("dynamic", [False, True]) - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - @common_utils.parametrize("grid_type", [1, 2, 3]) - def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type): - def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): - n_elements = output.numel() - - def grid_fn(meta): - return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - - if grid_type == 1: - grid = (n_elements,) - elif grid_type == 2: - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - elif grid_type == 3: - grid = grid_fn - - add_kernel_autotuned[grid](x, y, output, n_elements) - return output - - t1 = torch.rand(256, device="cuda", requires_grad=grad) - t2 = torch.rand(256, device="cuda", requires_grad=grad) - output = torch.zeros_like(t1, requires_grad=grad) - - torch_add = call_triton(t1, t2, output) - compiled_func = torch.compile( - call_triton, backend=backend, fullgraph=True, dynamic=dynamic - ) - - output2 = torch.zeros_like(t1, requires_grad=grad) - 
self.assertEqual(compiled_func(t1, t2, output2), torch_add) - - @requires_cuda() - @skipIfRocm - @common_utils.parametrize("grad", [False, True]) - @common_utils.parametrize("dynamic", [False, True]) - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - @common_utils.parametrize("grid_type", [1, 2, 3]) - def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type): - def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): - x_elements = output.size()[0] - y_elements = output.size()[1] - - def grid_fn(meta): - return ( - triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]), - triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]), - ) - - if grid_type == 1: - grid = (x_elements, y_elements) - elif grid_type == 2: - grid = lambda meta: ( - triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]), - triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]), - ) - elif grid_type == 3: - grid = grid_fn - - add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements) - return output - - t1 = torch.rand((512, 256), device="cuda", requires_grad=grad) - t2 = torch.rand((512, 256), device="cuda", requires_grad=grad) - output = torch.zeros_like(t1, requires_grad=grad) - - torch_result = call_triton(t1, t2, output) - compiled_func = torch.compile( - call_triton, backend=backend, fullgraph=True, dynamic=dynamic - ) - output2 = torch.zeros_like(t1, requires_grad=grad) - self.assertEqual(compiled_func(t1, t2, output2), torch_result) - - @requires_cuda() - @common_utils.parametrize("grad", [False, True]) - @common_utils.parametrize("dynamic", [False, True]) - @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - @patch.object(torch._inductor.config, "implicit_fallbacks", False) - def test_triton_kernel_native(self, grad, dynamic, backend): - def call_triton_add( - x: torch.Tensor, - y: torch.Tensor, - output: torch.Tensor, - grid_type: int, - num=1, - positional=False, - ): - n_elements = output.numel() - - def grid_fn(meta): - return (triton.cdiv(num, meta["BLOCK_SIZE"]),) - - if grid_type == 0: - grid = (x.numel(),) - elif grid_type == 1: - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - else: - grid = grid_fn - - if positional: - add_kernel[grid](x, y, output, n_elements, 16) - else: - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) - - return output - - t1 = torch.rand(5, device="cuda", requires_grad=grad) - t2 = torch.rand(5, device="cuda", requires_grad=grad) - o1 = torch.zeros_like(t1, requires_grad=grad) - - torch_add = t1 + t2 - - # No Dynamo -- Make sure triton kernel works - self.assertEqual(call_triton_add(t1, t2, o1, 1), torch_add) - # No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE) - o2 = torch.zeros_like(t1, requires_grad=grad) - self.assertEqual(call_triton_add(t1, t2, o2, 1, True), torch_add) - - # With Dynamo - compiled_func = torch.compile( - call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic - ) - # With simple kernel - o3 = torch.zeros_like(t1, requires_grad=grad) - self.assertEqual(compiled_func(t1, t2, o3, 0), torch_add) - # With lambda kernel - o4 = torch.zeros_like(t1, requires_grad=grad) - self.assertEqual(compiled_func(t1, t2, o4, 1), torch_add) - # With lambda kernel (with positional BLOCK_SIZE) - o5 = torch.zeros_like(t1, requires_grad=grad) - self.assertEqual(compiled_func(t1, t2, o5, 1, 1, True), torch_add) - # With user defined function kernel - o6 = torch.zeros_like(t1, requires_grad=grad) - self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add) - def 
test_dataclass_factory(self): @dataclass class Output: @@ -2451,8 +2354,8 @@ def forward(self, x_1, output_1): inner_a: Any = field(default_factory=list) def fn(x): - l_derived = Derived(1, 2) - return l_derived.outer_a * x + l_ = Derived(1, 2) + return l_.outer_a * x opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) @@ -2532,6 +2435,21 @@ def forward(self, x_1, output_1): self.assertEqual(fn(x, y), fn_opt(x, y)) self.assertEqual(fn(x, x), fn_opt(x, x)) + def test_is_not_tensor_tensor(self): + def fn(x, y): + if x is not y: + return x * 2 + else: + return x + y + + fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) + + x = torch.zeros(2) + y = torch.ones(2) + + self.assertEqual(fn(x, y), fn_opt(x, y)) + self.assertEqual(fn(x, x), fn_opt(x, x)) + def test_is_mutated_tensor_tensor(self): def fn(x): y = x.add_(1) @@ -2583,6 +2501,7 @@ def forward(self, x_1, output_1): self.assertEqual(fn(z), fn_opt(z)) + @torch._dynamo.config.patch(capture_func_transforms=True) def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self): def fn(z): x = z.clone() @@ -2596,6 +2515,7 @@ def forward(self, x_1, output_1): self.assertEqual(fn(z), fn_opt(z)) + @torch._dynamo.config.patch(capture_func_transforms=True) def test_is_vmapped_mutated_tensor_tensor(self): def fn(x): y = torch.vmap(torch.Tensor.acos_)(x) @@ -2607,6 +2527,7 @@ def forward(self, x_1, output_1): self.assertEqual(fn(z), fn_opt(z)) + @torch._dynamo.config.patch(capture_func_transforms=True) def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self): def fn(y, z): a = y.clone() @@ -2615,8 +2536,8 @@ def forward(self, x_1, output_1): def g(a, b): return a.acos_(), b.acos_() - c, dd = torch.vmap(g)(a, b) - return a is c is b is dd + c, d = torch.vmap(g)(a, b) + return a is c is b is d fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) @@ -2669,6 +2590,26 @@ def forward(self, x_1, output_1): self.assertEqual(opt_fn(param, param), fn(param, param)) self.assertEqual(cnts.frame_count, 2) # Recompiles + def test_reconstructed_name(self): + lst = [] + + @torch._dynamo.disable + def disallowed(g): + lst.append(g.__name__) + + def f(): + def g(): + return () + + disallowed(g) + + f_opt = torch._dynamo + opt_f = torch._dynamo.optimize(backend="eager")(f) + opt_f() + f() + self.assertEqual(len(lst), 2) + self.assertEqual(lst[0], lst[1]) + @unittest.skipIf( sys.version_info < (3, 10), "zip strict kwargs not implemented for Python < 3.10", @@ -2697,26 +2638,8 @@ def forward(self, x_1, output_1): with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys[:1], zs) - def test_compare_constant_and_tensor(self): - for op in [ - operator.lt, - operator.le, - operator.gt, - operator.ge, - operator.ne, - operator.eq, - ]: - - def fn(x): - return op(-10, x) - - opt_fn = torch.compile(fullgraph=True)(fn) - - x = torch.randn(10) - self.assertEqual(opt_fn(x), fn(x)) - -common_utils.instantiate_parametrized_tests(DefaultsTests) +instantiate_parametrized_tests(FunctionTests) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index 410da7f76f457244423a58ce0f8a938f84612967..2e612c4f38fa52114edbf7b394c92c148debd567 100644 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same @@ -125,7 +124,7 @@ class 
TestGlobals(torch._dynamo.test_case.TestCase): def test_store_global_dict(self): def fn(x): global g_dict - val = x + g_dict.get("b") + val = x + g_dict["b"] """ Strictly speaking, we are not testing STORE_GLOBAL here, since STORE_SUBSCR is actually used to store. diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5d20f6f2d71a3be157f2deffdb1d874787f029e1 --- /dev/null +++ b/test/dynamo/test_guard_manager.py @@ -0,0 +1,667 @@ +# Owner(s): ["module: dynamo"] +import functools +import weakref + +import torch +import torch_npu +import torch._dynamo +import torch._dynamo.test_case +from torch._C._dynamo import guards +from torch.testing._internal.common_utils import set_default_dtype + +RootGuardManager = guards.RootGuardManager +DictGuardManager = guards.DictGuardManager +GetAttrGuardAccessor = guards.GetAttrGuardAccessor +GetItemGuardAccessor = guards.GetItemGuardAccessor +TypeGuardAccessor = guards.TypeGuardAccessor +TENSOR_ALIASING = guards.TENSOR_ALIASING +install_tensor_aliasing_guard = guards.install_tensor_aliasing_guard +NO_TENSOR_ALIASING = guards.NO_TENSOR_ALIASING +install_no_tensor_aliasing_guard = guards.install_no_tensor_aliasing_guard + + +x = torch.tensor(4) +weakref_x = weakref.ref(x) + + +class Pair: + def __init__(self, x, y): + self.x = x + self.y = y + + +global_pair = Pair(torch.randn(4), 1) + + +def id_type(x): + return id(type(x)) + + +def equals_match(x, expected): + return x == expected + + +def equals_match_verbose_code_parts(expected): + return [f"x == {expected}"] + + +def ge_match(x, expected): + return x >= expected + + +def ge_match_verbose_code_parts(expected): + return f"expected >= {expected}" + + +def less_match(x, expected): + return x < expected + + +def less_match_verbose_code_parts(expected): + return [f"expected < {expected}"] + + +class GuardManagerTests(torch._dynamo.test_case.TestCase): + def test_global_state_guard(self): + guard = guards.GLOBAL_STATE(["global_state_check"]) + self.assertTrue(guard(None)) + with set_default_dtype(torch.double): + self.assertFalse(guard(None)) + self.assertTrue(guard(None)) + _orig = torch.are_deterministic_algorithms_enabled() + try: + torch.use_deterministic_algorithms(not _orig) + self.assertFalse(guard(None)) + finally: + torch.use_deterministic_algorithms(_orig) + self.assertTrue(guard(None)) + + def test_python_lambda_leaf_guard(self): + const_guard = guards.LAMBDA_GUARD( + functools.partial(equals_match, expected=5), + equals_match_verbose_code_parts(5), + ) + self.assertTrue(const_guard(5)) + self.assertFalse(const_guard(4)) + self.assertFalse(const_guard("foo")) + + def test_type_guard(self): + foo = 4 + guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"]) + + self.assertTrue(guard(5)) + self.assertTrue(guard(4)) + self.assertFalse(guard("foo")) + + foo = {"a": 1} + guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"]) + self.assertTrue(guard(foo)) + self.assertTrue(guard({})) + self.assertFalse(guard(5)) + self.assertFalse(guard("foo")) + + class Foo: + def __init__(self, x, y): + self.x = x + self.y = y + + foo = Foo(1, 2) + + guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"]) + self.assertTrue(guard(foo)) + self.assertFalse(guard({})) + self.assertFalse(guard(5)) + self.assertFalse(guard("foo")) + + def test_id_guard(self): + foo = 4 + guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) + + self.assertTrue(guard(foo)) + self.assertFalse(guard(5)) + self.assertFalse(guard("foo")) + 
+ foo = {"a": 1} + guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) + self.assertTrue(guard(foo)) + self.assertFalse(guard({"a": 1})) + self.assertFalse(guard({})) + self.assertFalse(guard(5)) + + def test_equals_guard(self): + foo = 4 + guard = guards.EQUALS_MATCH(foo, ["x == 4"]) + + self.assertTrue(guard(4)) + self.assertFalse(guard(5)) + self.assertFalse(guard("foo")) + + # tuple + foo = (1, 2, 3) + guard = guards.EQUALS_MATCH(foo, ["x == foo"]) + self.assertTrue(guard(foo)) + self.assertTrue(guard((1, 2, 3))) + self.assertFalse(guard((1, 2, 3, 4))) + self.assertFalse(guard({})) + + # list + foo = [1, 2, 3] + guard = guards.EQUALS_MATCH(foo, ["x == foo"]) + self.assertTrue(guard(foo)) + self.assertTrue(guard([1, 2, 3])) + self.assertFalse(guard([1, 2, 3, 4])) + + # type + foo = int + guard = guards.EQUALS_MATCH(foo, ["x == foo"]) + self.assertTrue(guard(foo)) + self.assertTrue(guard(int)) + self.assertFalse(guard(float)) + + def test_default_device_guard(self): + foo = 1 + guard = guards.DEFAULT_DEVICE(["cpu device"]) + self.assertTrue(guard(foo)) + + try: + torch.set_default_device("npu:0") + self.assertFalse(guard(foo)) + finally: + torch.set_default_device(None) + + def test_data_ptr_match_guard(self): + foo = torch.tensor([1, 2, 3]) + guard = guards.DATA_PTR_MATCH(foo, ["x.data_ptr() == foo.data_ptr()"]) + self.assertTrue(guard(foo)) + self.assertFalse(guard(torch.tensor([1, 2, 3]))) + + def test_length_check_guard(self): + foo = [1, 2, 3] + guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"]) + self.assertTrue(guard(foo)) + self.assertFalse(guard([])) + + def test_no_hasattr_guard(self): + class Bar: + def __init__(self): + self.bar = 2 + + bar = Bar() + + class Foo: + def __init__(self): + self.foo = 2 + + foo = Foo() + + guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"]) + self.assertTrue(guard(bar)) + self.assertFalse(guard(foo)) + + def test_tensor_aliasing_guard(self): + guard_manager = RootGuardManager() + + a = torch.randn(3, 4) + + class Foo: + def __init__(self, x, y): + self.x = x + self.y = y + + f_locals = Foo(a, a) + + x_guard_mgr = guard_manager.getattr_manager("x", "", a) + y_guard_mgr = guard_manager.getattr_manager("y", "", a) + install_tensor_aliasing_guard(x_guard_mgr, y_guard_mgr, ["x is y"]) + + # Check structure + x_guards = x_guard_mgr.get_leaf_guards() + y_guards = y_guard_mgr.get_leaf_guards() + self.assertEqual(len(x_guards), 1) + self.assertEqual(len(y_guards), 1) + self.assertTrue(isinstance(x_guards[0], TENSOR_ALIASING)) + self.assertTrue(isinstance(y_guards[0], TENSOR_ALIASING)) + # Check that the two guards are the same object + self.assertTrue(x_guards[0] is y_guards[0]) + + f_locals_unaliased = Foo(torch.randn(3, 4), torch.randn(3, 4)) + self.assertEqual(len(x_guard_mgr.get_leaf_guards()), 1) + self.assertEqual(len(y_guard_mgr.get_leaf_guards()), 1) + self.assertTrue(guard_manager.check(f_locals)) + + self.assertFalse(guard_manager.check(f_locals_unaliased)) + + def test_dict_version_guard(self): + foo = {"a": 1, "b": 2} + guard = guards.DICT_VERSION(foo, ["x.version == foo.version"]) + + self.assertTrue(guard(foo)) + self.assertFalse(guard(dict(foo))) + foo["a"] = 2 + self.assertFalse(guard(foo)) + self.assertFalse(guard({"a": 1, "b": 2})) + self.assertFalse(guard({})) + + def test_dynamic_indices_guard(self): + guard1 = guards.DYNAMIC_INDICES(False, set(), ["x.size(0) == y.size(0)"]) + guard2 = guards.DYNAMIC_INDICES(True, set({0, 1}), ["x.size(0) == y.size(0)"]) + + x = torch.randn(4) + self.assertTrue(guard1(x)) + 
self.assertTrue(guard2(x)) + + x._dynamo_dynamic_indices = set({0}) + self.assertFalse(guard1(x)) + self.assertTrue(guard2(x)) + + x._dynamo_dynamic_indices = set({2}) + self.assertFalse(guard1(x)) + self.assertFalse(guard2(x)) + + def test_tensor_match_guard(self): + guard_manager = RootGuardManager() + x = torch.randn(4, 4) + size = list(x.size()) + stride = list(x.stride()) + guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"]) + self.assertTrue(guard_manager.check(x)) + self.assertTrue(guard_manager.check_verbose(x).result) + self.assertTrue(guard_manager.check(torch.randn(4, 4))) + self.assertTrue(guard_manager.check_verbose(torch.randn(4, 4)).result) + self.assertFalse(guard_manager.check(x.t_())) + + x = torch.randn(4, 4) + x.t_() + debug_info = guard_manager.check_verbose(x) + print(debug_info.verbose_code_parts[0]) + self.assertTrue( + "tensor 'x' stride mismatch" in debug_info.verbose_code_parts[0] + ) + + def test_no_tensor_aliasing_guard(self): + guard_manager = RootGuardManager() + + a = torch.randn(3, 4) + + class Foo: + def __init__(self, x, y, z): + self.x = x + self.y = y + self.z = z + + f_locals = Foo(a, a, a) + + x_guard_mgr = guard_manager.getattr_manager("x", "", a) + y_guard_mgr = guard_manager.getattr_manager("y", "", a) + z_guard_mgr = guard_manager.getattr_manager("z", "", a) + install_no_tensor_aliasing_guard( + [x_guard_mgr, y_guard_mgr, z_guard_mgr], + ["x", "y", "z"], + ["no_aliasing(x, y, z)"], + ) + + # Check structure + x_guards = x_guard_mgr.get_leaf_guards() + y_guards = y_guard_mgr.get_leaf_guards() + z_guards = z_guard_mgr.get_leaf_guards() + self.assertEqual(len(x_guards), 1) + self.assertEqual(len(y_guards), 1) + self.assertEqual(len(z_guards), 1) + self.assertTrue(isinstance(x_guards[0], NO_TENSOR_ALIASING)) + self.assertTrue(isinstance(y_guards[0], NO_TENSOR_ALIASING)) + self.assertTrue(isinstance(z_guards[0], NO_TENSOR_ALIASING)) + # Check that the two guards are the same object + self.assertTrue(x_guards[0] is y_guards[0] is z_guards[0]) + self.assertFalse(guard_manager.check(f_locals)) + self.assertFalse(guard_manager.check_verbose(f_locals).result) + + f_locals_unaliased = Foo( + torch.randn(3, 4), + torch.randn(3, 4), + torch.randn(3, 4), + ) + self.assertTrue(guard_manager.check(f_locals_unaliased)) + self.assertTrue(guard_manager.check_verbose(f_locals_unaliased).result) + # Check that hash map is cleared. 
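+        # (The guard records every tensor it sees during a single check() call to
+        # detect aliasing; if that state were not reset, re-checking the same
+        # unaliased inputs would be reported as aliasing.)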
+ self.assertTrue(guard_manager.check(f_locals_unaliased)) + + f_locals_unaliased = Foo( + a, + torch.randn(3, 4), + a, + ) + self.assertFalse(guard_manager.check(f_locals_unaliased)) + self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result) + + def test_weakref_alive_guard(self): + x = torch.rand(3, 4) + weakref_x = weakref.ref(x) + + guard = guards.WEAKREF_ALIVE(["weakref_x is not None"]) + self.assertTrue(guard(weakref_x())) + del x + self.assertFalse(guard(weakref_x())) + + def test_guard_manager_leaf_guard(self): + guard_manager = RootGuardManager() + guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"]) + guard_manager.add_lambda_guard( + functools.partial(ge_match, expected=5), + ge_match_verbose_code_parts(expected=5), + ) + guard_manager.add_lambda_guard( + functools.partial(less_match, expected=10), + less_match_verbose_code_parts(expected=10), + ) + self.assertEqual(len(guard_manager.get_leaf_guards()), 3) + self.assertEqual(len(guard_manager.get_accessors()), 0) + self.assertTrue(guard_manager.check(6)) + self.assertFalse(guard_manager.check(4)) + self.assertFalse(guard_manager.check("foo")) + + def test_attr_guard_manager(self): + class Foo: + def __init__(self, x, y): + self.x = x + self.y = y + + foo = Foo(1, 2) + guard_manager = RootGuardManager() + guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) + guard_manager.getattr_manager("x", "x", 1).add_lambda_guard( + functools.partial(equals_match, expected=foo.x), + equals_match_verbose_code_parts(foo.x), + ) + guard_manager.getattr_manager("y", "y", 2).add_lambda_guard( + functools.partial(equals_match, expected=foo.y), + equals_match_verbose_code_parts(foo.y), + ) + self.assertEqual(len(guard_manager.get_leaf_guards()), 1) + # 2 child managers, one for x and one for y + self.assertEqual(len(guard_manager.get_accessors()), 2) + self.assertTrue( + isinstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor) + ) + self.assertTrue( + isinstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor) + ) + # Check leaf guards on child managers + self.assertEqual( + len( + guard_manager.getattr_manager( + attr="x", source="x", example_value=None + ).get_leaf_guards() + ), + 1, + ) + self.assertEqual( + len(guard_manager.getattr_manager("y", "y", None).get_leaf_guards()), 1 + ) + + self.assertTrue(guard_manager.check(foo)) + self.assertFalse(guard_manager.check(Foo(3, 4))) + self.assertFalse(guard_manager.check("foo")) + + def test_item_guard_manager(self): + foo = [1, 2] + guard_manager = RootGuardManager() + guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) + guard_manager.getitem_manager(0, "", 1).add_lambda_guard( + functools.partial(equals_match, expected=foo[0]), + equals_match_verbose_code_parts(foo[0]), + ) + guard_manager.getitem_manager(1, "", 2).add_lambda_guard( + functools.partial(equals_match, expected=foo[1]), + equals_match_verbose_code_parts(foo[1]), + ) + self.assertEqual(len(guard_manager.get_leaf_guards()), 1) + # 2 child managers, one for x and one for y + self.assertEqual(len(guard_manager.get_accessors()), 2) + self.assertTrue( + isinstance(guard_manager.get_accessors()[0], GetItemGuardAccessor) + ) + self.assertTrue( + isinstance(guard_manager.get_accessors()[1], GetItemGuardAccessor) + ) + # Check leaf guards on child managers + self.assertEqual( + len(guard_manager.getitem_manager(0, "", None).get_leaf_guards()), 1 + ) + self.assertEqual( + len(guard_manager.getitem_manager(1, "", None).get_leaf_guards()), 1 + ) + + 
self.assertTrue(guard_manager.check(foo)) + self.assertFalse(guard_manager.check([3, 4])) + self.assertFalse(guard_manager.check("foo")) + + def test_dict_getitem_accessor(self): + foo = { + "a": 1, + "b": 2, + } + + guards_manager = RootGuardManager() + guards_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) + guards_manager.dict_getitem_manager("a", "", 1).add_equals_match_guard( + 1, ["a == 1"] + ) + guards_manager.dict_getitem_manager("b", "", 2).add_equals_match_guard( + 2, ["b == 2"] + ) + + self.assertTrue(guards_manager.check(foo)) + self.assertFalse(guards_manager.check({"a": 1, "b": 3})) + + def test_globals(self): + global global_pair, Pair + guard_manager = RootGuardManager() + gpair_mgr = guard_manager.globals_dict_manager( + globals(), "", None + ).getitem_manager("global_pair", "", global_pair) + + gpair_mgr.add_lambda_guard( + lambda x: isinstance(x, Pair) + and isinstance(x.x, torch.Tensor) + and isinstance(x.y, int), + "global guard fail", + ) + + self.assertTrue(guard_manager.check(global_pair)) + global_pair.y = "foo" + self.assertFalse(guard_manager.check(global_pair)) + + def test_type_manager(self): + guard_manager = RootGuardManager() + + class A: + a = 4 + + class B(A): + def mul(self, x): + super().mul(x) + + foo = B() + f_locals = {"foo": foo} + + # len(type(foo).__mro__) == 2 + foo_mgr = guard_manager.getitem_manager("foo", "", foo) + type_manager = foo_mgr.type_manager("", type(foo)) + self.assertTrue(isinstance(foo_mgr.get_accessors()[0], TypeGuardAccessor)) + mro_manager = type_manager.getattr_manager("__mro__", "", type(foo).__mro__) + self.assertTrue( + isinstance(type_manager.get_accessors()[0], GetAttrGuardAccessor) + ) + mro_manager.add_length_check_guard( + 3, + "Expected len(type(foo).__mro__) == 3", + ) + + # type(foo).__mro__[0].a = 4 + item_manager = mro_manager.getitem_manager(1, "", type(foo).__mro__[1]) + self.assertTrue( + isinstance(mro_manager.get_accessors()[0], GetItemGuardAccessor) + ) + attr_manager = item_manager.getattr_manager("a", "", type(foo).__mro__[0].a) + self.assertTrue( + isinstance(item_manager.get_accessors()[0], GetAttrGuardAccessor) + ) + attr_manager.add_lambda_guard( + lambda x: x == 4, + "Expected value 4", + ) + + self.assertTrue(guard_manager.check(f_locals)) + + def test_tuple_iterator_getitem(self): + a = (1, 2, 3, 4, 5, 6) + foo = iter(a) + next(foo) # foo points at index=1 + + guard_manager = RootGuardManager() + # Check a[3] which is tuple_iterator_getitem(foo, 2) + guard_manager.add_tuple_iterator_length_guard( + 5, id_type(iter(tuple())), ["len == 5"] + ) + guard_manager.tuple_iterator_getitem_manager(2, "", foo).add_equals_match_guard( + a[3], ["x==4"] + ) + + # Check that type match works + self.assertFalse(guard_manager.check(False)) + + self.assertTrue(guard_manager.check(foo)) + + # Check that index error fails gracefully + b = (1, 2) + b_foo = iter(b) + self.assertFalse(guard_manager.check(b_foo)) + + def test_global_weakref(self): + guard_manager = RootGuardManager() + globals_manager = guard_manager.globals_dict_manager(globals(), "", None) + weakref_manager = globals_manager.global_weakref_manager("weakref_x", "", None) + + weakref_manager.add_lambda_guard( + lambda x: isinstance(x, torch.Tensor), + "global weakref fail", + ) + + self.assertTrue(guard_manager.check(None)) + global x + del x + self.assertFalse(guard_manager.check(None)) + + def test_lambda_manager(self): + a = (1, 1, 3, 4, 5, 6) + + guard_manager = RootGuardManager() + + # Check that we can use the same accessor + foo_mgr = 
guard_manager.lambda_manager(lambda x: x[2], "", None) + foo_mgr.add_lambda_guard( + lambda x: x == 3, + "Expected value 3", + ) + self.assertTrue(guard_manager.check(a)) + + # test that exception works + guard_manager = RootGuardManager() + + def fn(x): + raise AssertionError("Test") + return x + + foo_mgr = guard_manager.lambda_manager(fn, "", None) + + self.assertFalse(guard_manager.check(None)) + debug_info = guard_manager.check_verbose(None) + self.assertFalse(debug_info.result) + self.assertTrue("Test" in debug_info.verbose_code_parts[0]) + + def test_dict_contains_guard(self): + foo = {"a": 1, "b": 2} + guard = guards.DICT_CONTAINS(True, "a", ["has a"]) + + self.assertTrue(guard(foo)) + self.assertTrue(guard({"a": 1, "b": 2})) + self.assertFalse(guard({"b": 2, "c": 3})) + self.assertFalse(guard({})) + + guard = guards.DICT_CONTAINS(False, "c", ["not has c"]) + self.assertTrue(guard(foo)) + self.assertTrue(guard({"a": 1, "b": 2})) + self.assertFalse(guard({"b": 2, "c": 3})) + self.assertTrue(guard({})) + + def test_dict_guard_manager(self): + root = RootGuardManager() + + def nothing(): + pass + + f_locals = { + "d": {"a": 1, nothing: {"z": 3}, 100: torch.randn(4)}, + } + + # its a getitem_manager just for f_locals. But the child guard manager + # should be a DictGuardManager. + dict_mgr = root.getitem_manager("d", "", f_locals["d"]) + self.assertTrue(isinstance(dict_mgr, DictGuardManager)) + + self.assertTrue(root.check(f_locals)) + + # Check that no one can add a leaf guard + with self.assertRaises(RuntimeError): + dict_mgr.add_id_match_guard(id_type(f_locals), "id match") + + # Check that no one can add an arbitrary accessor + with self.assertRaises(RuntimeError): + dict_mgr.getitem_manager("a", "", f_locals["d"]["a"]) + + # Check that it fails with different length dict + f_locals_prime = { + "d": {"a": 1, "b": 2}, + } + self.assertFalse(root.check(f_locals_prime)) + + # Add key-value manager ("a" : 1) + self.assertTrue(root.check(f_locals)) + dict_mgr.get_key_manager(0, "", "a").add_equals_match_guard( + "a", ["dict.keys()[0] == a"] + ) + self.assertTrue(root.check(f_locals)) + dict_mgr.get_value_manager(0, "", 1).add_equals_match_guard(1, ["d[0] == 1"]) + self.assertTrue(root.check(f_locals)) + + # Add key-value manager (nothing : {"z" : 3}) + self.assertTrue(root.check(f_locals)) + dict_mgr.get_key_manager(1, "", nothing).add_lambda_guard( + lambda x: x is nothing, ["x is nothing"] + ) + self.assertTrue(root.check(f_locals)) + value_mgr = dict_mgr.get_value_manager(1, "", f_locals["d"][nothing]) + self.assertTrue(isinstance(value_mgr, DictGuardManager)) + self.assertTrue(root.check(f_locals)) + + # Check structure + # Check that we are only guarding on two keys. This is common in + # LazyVariableTracker. 
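+        # Key/value managers were installed only for indices 0 ("a") and 1
+        # (`nothing`); the entry under key 100 was never touched, so it should
+        # not appear in get_key_value_managers().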
+ self.assertEqual(len(dict_mgr.get_key_value_managers()), 2) + + f_locals["d"]["a"] = 2 + self.assertFalse(root.check(f_locals)) + self.assertFalse(root.check_verbose(f_locals).result) + + f_locals["d"]["a"] = 1 + self.assertTrue(root.check(f_locals)) + + f_locals["d"].pop(100) + # fails because of len check + self.assertFalse(root.check(f_locals)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 8696b535809da41ee49d046191a42b78c6743f81..5579cf3db7098df51b54fdf3a1ebc2bb296799c3 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1,14 +1,15 @@ # Owner(s): ["module: dynamo"] +import enum import functools import pprint import re import unittest +import warnings import functorch.experimental.control_flow as control_flow import torch import torch_npu -import torchair import torch._dynamo.config as config import torch._dynamo.test_case @@ -25,6 +26,12 @@ from torch._dynamo.testing import ( ) from torch._dynamo.utils import counters, ifdynstaticdefault from torch._higher_order_ops.wrap import wrap +from torch.testing._internal.common_utils import ( + munge_exc, + TEST_WITH_TORCHDYNAMO, + xfailIfTorchDynamo, +) +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") @@ -115,6 +122,8 @@ def default_args_generator(seed_value): new_val = val + 1 * i elif isinstance(val, float): new_val = val + 0.1 * i + elif isinstance(val, enum.Enum): + new_val = val else: raise AssertionError("unexpected arg type") @@ -195,8 +204,8 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase): x = torch.randn(3) with self.assertRaisesRegex( - RuntimeError, - "while introspecting wrap, we were unable to trace function `inner`", + torch._dynamo.exc.Unsupported, + r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)", ): f(x) @@ -207,6 +216,22 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase): x = torch.randn(3) self._test_wrap_simple(f, default_args_generator((x,)), 2) + def test_enum_arg(self): + class SomeEnum(enum.Enum): + A = 0 + B = 1 + + def g(x, val): + if val == SomeEnum.A: + return torch.sin(x) + return torch.cos(x) + + def f(x, val): + return wrap(g, x, val) + + x = torch.randn(3) + self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), 2) + def test_return_captured_var(self): freevar = torch.randn(3) @@ -219,7 +244,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase): x = torch.randn(3) # Since, `x` is unused, we don't lift it to - # be the input. + # be the ipt. self._test_wrap_simple(fn, default_args_generator((x,)), 2) def test_return_captured_vars(self): @@ -235,7 +260,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase): x = torch.randn(3) # Since, `x` is unused, we don't lift it to - # be the input. + # be the ipt. self._test_wrap_simple(fn, default_args_generator((x,)), 3, 4) def test_return_captured_var_used_multiple_times(self): @@ -408,7 +433,7 @@ class GraphModule(torch.nn.Module): x = torch.tensor(1.2) y = MyClass(torch.tensor(3.4)) - self._assert_wrap_fallback(f, (x, y)) + self._test_wrap_simple(f, [(x, y)], 3) def test_capture_constants(self): x = torch.randn(3, 3) @@ -514,7 +539,7 @@ class GraphModule(torch.nn.Module): x = torch.randn(3) # Since, `x` is unused, we don't lift it to - # be the input. 
+ # be the ipt. self._test_wrap_simple(f, default_args_generator((x,)), 2, 3) def test_capture_value_created_in_subgraph(self): @@ -1084,7 +1109,6 @@ class GraphModule(torch.nn.Module): xs = torch.randn(2, 3, 3) y = torch.randn(3) - @torch.compile(backend=cnt, fullgraph=True) def map_f(xs, y): def inner(x, y): def inner2(x, y): @@ -1094,26 +1118,99 @@ class GraphModule(torch.nn.Module): return control_flow.map(inner, xs, y) - result = map_f(xs, y) - self.assertEqual(result, xs + y) - - map_gm = backend.graphs[0] - name_set = set() - for name, _ in map_gm.named_modules(): - name_set.add(name) - self.assertEqual(name_set, {"", "map_body_1.map_body_0", "map_body_1"}) + graphs = self._check_map_graph_and_extract(map_f, (xs, y)) + if graphs: + graph, body_graph = graphs + self.assertExpectedInline( + graph, + """\ +def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor): + l_xs_ = L_xs_ + l_y_ = L_y_ + map_body_1 = self.map_body_1 + map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None + getitem_1 = map_impl[0]; map_impl = None + return (getitem_1,)""", + ) + self.assertExpectedInline( + body_graph, + """\ +def forward(self, getitem, l_y_): + getitem_1 = getitem[0] + map_body_0 = self.map_body_0 + map_impl = torch.ops.higher_order.map_impl(map_body_0, [getitem], [l_y_]); map_body_0 = getitem = l_y_ = None + getitem_2 = map_impl[0]; map_impl = None + return (getitem_2,)""", + ) def test_map_multi_return(self): cnt = CompileCounter() - @torch.compile(backend=cnt) def f(x): return control_flow.map(lambda x: (x.sin(), x.sin()), x) x = torch.randn(3) - result = f(x) - self.assertEqual(result, (x.sin(), x.sin())) - self.assertEqual(cnt.frame_count, 0) + graphs = self._check_map_graph_and_extract(f, (x,)) + if graphs: + graph, body_graph = graphs + self.assertExpectedInline( + graph, + """\ +def forward(self, L_x_ : torch.Tensor): + l_x_ = L_x_ + map_body_0 = self.map_body_0 + map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None + getitem_1 = map_impl[0] + getitem_2 = map_impl[1]; map_impl = None + return (getitem_1, getitem_2)""", + ) + self.assertExpectedInline( + body_graph, + """\ +def forward(self, getitem): + sin = getitem.sin() + sin_1 = getitem.sin(); getitem = None + return (sin, sin_1)""", + ) + + def test_map_pytree_return(self): + cnt = CompileCounter() + + def _construct_pytree(a): + return (a, [[[a]]], a, (a, (a,), a), {"a": a}) + + def f(x): + def inner_f(xs): + return _construct_pytree(xs) + + return control_flow.map(inner_f, x) + + x = torch.randn(3) + graphs = self._check_map_graph_and_extract(f, (x,)) + if graphs: + graph, body_graph = graphs + self.assertExpectedInline( + graph, + """\ +def forward(self, L_x_ : torch.Tensor): + l_x_ = L_x_ + map_body_0 = self.map_body_0 + map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None + getitem_1 = map_impl[0] + getitem_2 = map_impl[1] + getitem_3 = map_impl[2] + getitem_4 = map_impl[3] + getitem_5 = map_impl[4] + getitem_6 = map_impl[5] + getitem_7 = map_impl[6]; map_impl = None + return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""", + ) + self.assertExpectedInline( + body_graph, + """\ +def forward(self, getitem): + return (getitem, getitem, getitem, getitem, getitem, getitem, getitem)""", + ) def test_map_kwargs(self): cnt = CompileCounter() @@ -1138,14 +1235,61 @@ class GraphModule(torch.nn.Module): x = torch.randn(3, 1) y = torch.randn(3, 1) - compiled_fn = 
torch.compile(fn, backend=cnt, fullgraph=True) + graphs = self._check_map_graph_and_extract(fn, (x, y)) + if graphs: + graph, body_graph = graphs + self.assertExpectedInline( + graph, + """\ +def forward(self, L_x_ : torch.Tensor): + l_x_ = L_x_ + map_body_0 = self.map_body_0 + map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None + getitem_1 = map_impl[0]; map_impl = None + return (getitem_1,)""", + ) + self.assertExpectedInline( + body_graph, + """\ +def forward(self, getitem, const): + add = getitem + 3; getitem = None + sin = torch.sin(add); add = None + return (sin,)""", + ) - ref = fn(x, y) - res = compiled_fn(x, y) + def test_map_lowers_to_graph(self): + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) - self.assertEqual(ref, res) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdynstaticdefault(2, 3)) + def fn(x, y): + def inner(x, y): + return torch.sin(x + y) + + return control_flow.map(inner, x, y.size(0)) + + x = torch.randn(3, 1) + y = torch.randn(3, 1) + graphs = self._check_map_graph_and_extract(fn, (x, y)) + if graphs: + graph, body_graph = graphs + self.assertExpectedInline( + graph, + """\ +def forward(self, L_x_ : torch.Tensor): + l_x_ = L_x_ + map_body_0 = self.map_body_0 + map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None + getitem_1 = map_impl[0]; map_impl = None + return (getitem_1,)""", + ) + self.assertExpectedInline( + body_graph, + """\ +def forward(self, getitem, const): + add = getitem + 3; getitem = None + sin = torch.sin(add); add = None + return (sin,)""", + ) def test_cond_subgraph_name_is_valid(self): backend = EagerAndRecordGraphs() @@ -1297,6 +1441,25 @@ class GraphModule(torch.nn.Module): false_graph = gm.cond_false_0.code.strip() return (graph, true_graph, false_graph) + def _check_map_graph_and_extract(self, fn, args): + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + out = torch.compile(fn, backend=cnt, fullgraph=True)(*args) + self.assertEqual(out, fn(*args)) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(len(backend.graphs), 1) + + # Dynamic shapes produce a slightly different graph. 
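+        # The assertExpectedInline strings in the callers were recorded with
+        # static shapes, so return None here and let the callers skip the
+        # graph comparison in that case.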
+ if check_dynamic_shape_capture(): + return + + gm = backend.graphs[0] + graph = gm.code.strip() + subgraphs = [] + for module_name in gm._modules.keys(): + subgraphs.append(getattr(gm, module_name).code.strip()) + return (graph, *subgraphs) + def test_cond_branches_no_arguments(self): def fn(x): def true_fn(): @@ -1702,7 +1865,7 @@ class GraphModule(torch.nn.Module): self.assertEqual(len(backend.graphs), 1) wrap_node = find_first_node(backend.graphs[0], wrap) - # 3 args - 1 for input, and other 2 for the weight and bias + # 3 args - 1 for ipt, and other 2 for the weight and bias self.assertTrue(len(wrap_node.args), 3) # Check that the linear bias and weight are getattr in the outer graph @@ -2079,6 +2242,7 @@ class GraphModule(torch.nn.Module): """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""", ) + @config.patch(capture_func_transforms=True) def test_grad_source_fn_stack(self): backend = EagerAndRecordGraphs() @@ -2096,11 +2260,10 @@ class GraphModule(torch.nn.Module): actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"}) self.assertExpectedInline( pprint.pformat(actual_stack), - """\ -{'sin': ['grad_impl', 'grad_impl', 'sin'], - 'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""", + """{'sin': ['sin']}""", ) + @config.patch(capture_func_transforms=True) def test_vmap_source_fn_stack(self): backend = EagerAndRecordGraphs() @@ -2114,13 +2277,13 @@ class GraphModule(torch.nn.Module): x = torch.randn(3, 3, 3, 3) fn(x) gm = backend.graphs[0] - actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sum_2", "add"}) + actual_stack = self._get_source_fn_stack( + gm, {"sum_1", "sum_2", "batched_output"} + ) self.assertExpectedInline( pprint.pformat(actual_stack), """\ -{'add': ['vmap_impl', 'vmap_impl', 'add'], - 'sum_1': ['vmap_impl', 'vmap_impl', 'sum_1'], - 'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""", +{'batched_output': ['add'], 'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""", ) def test_cond_pytree_operands(self): @@ -2132,8 +2295,7 @@ class GraphModule(torch.nn.Module): e = torch.randn(3, 3) f = torch.randn(3, 3) g = torch.randn(3, 3) - res = (a, [[[b]]], c, (d, (e,), f), {"g": g}) - return res + return (a, [[[b]]], c, (d, (e,), f), {"g": g}) pred = torch.tensor(True) inp = _construct_pytree() @@ -2213,13 +2375,212 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre torch.compile(fn, backend="eager")(pred, pytree_in) +class HigherOrderOpVmapGuardTests(LoggingTestCase): + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_vmap_grad_guard_ok(self, records): + vmap = torch.vmap + grad = torch.func.grad + + def g(x): + return vmap(grad(torch.sin))(x) + + @torch.compile(backend="eager") + def fn(x): + return vmap(g)(x) + + x = torch.randn(4, 5) + y = fn(x) + # sanity check + self.assertEqual(len(records), 0) + self.assertEqual(x.cos(), y) + + # Calling the same function again won't have any effect on guards + fn(x) + self.assertEqual(len(records), 0) + + @xfailIfTorchDynamo + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_grad_guard_fail(self, records): + grad = torch.func.grad + + @torch.compile(backend="eager") + def fn(x): + return grad(torch.sin)(x.sum()) + + x = torch.randn([]) + fn(x) + self.assertEqual(len(records), 0) + + # calling again should not invalidate the graph + fn(x) + self.assertEqual(len(records), 0) + + # call grad should retrigger compilation + x = torch.randn(3) + grad(fn)(x) + self.assertGreater(len(records), 0) + record 
= self.getRecord(records, "pyfunctorch") + self.assertIn( + """\ + triggered by the following guard failure(s): + - torch._functorch.pyfunctorch.compare_functorch_state([])""", + munge_exc(record.getMessage()), + ) + + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_vmap_guard_ok(self, records): + @torch.compile(backend="eager") + def fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.randn(3, 3, 4, 5) + y = fn(x) + # sanity check + self.assertEqual(len(records), 0) + self.assertEqual(x.sin(), y) + + # Calling the same function again won't have any effect on guards + z = fn(x) + self.assertEqual(len(records), 0) + self.assertEqual(x.sin(), z) + + # calling with a different object will also not affect guards + w = fn(z) + self.assertEqual(len(records), 0) + self.assertEqual(z.sin(), w) + + @xfailIfTorchDynamo + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_vmap_guard_fail_different_state(self, records): + @torch.compile(backend="eager") + def fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 4) + y = torch.vmap(fn, randomness="same")(x) + self.assertEqual(x.sin(), y) + self.assertEqual(len(records), 0) + + # call vmap(vmap(fn))(x) should retrigger compilation + y = torch.vmap(fn, randomness="different")(x) + self.assertEqual(x.sin(), y) + self.assertGreater(len(records), 0) + record = self.getRecord(records, "pyfunctorch") + self.assertIn( + """\ + triggered by the following guard failure(s): + - torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""", + record.getMessage(), + ) + + @xfailIfTorchDynamo + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_vmap_guard_fail(self, records): + @torch.compile(backend="eager") + def fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 3, 4, 5) + y = torch.vmap(fn)(x) + self.assertEqual(x.sin(), y) + self.assertEqual(len(records), 0) + + # call vmap(vmap(fn))(x) should retrigger compilation as + # _functorch.current_level() is not the same + x = torch.zeros(3, 3, 3, 4, 5) + y = torch.vmap(torch.vmap(fn))(x) + self.assertEqual(x.sin(), y) + self.assertGreater(len(records), 0) + record = self.getRecord(records, "pyfunctorch") + self.assertIn( + """\ + triggered by the following guard failure(s): + - torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""", + record.getMessage(), + ) + + @xfailIfTorchDynamo + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_vmap_grad_vmap_guard_fail(self, records): + vmap = torch.vmap + grad = torch.func.grad + + def g(x): + y = vmap(torch.sin, randomness="same")(x) + return y.sum(0) + + @torch.compile(backend="eager") + def fn(x): + return grad(g)(x) + + x = torch.randn(3, 3) + y = vmap(fn, randomness="error")(x) + self.assertEqual(x.cos(), y) + + # previous FX graph should be invalidated + x = torch.randn(3, 3, 4) + y = vmap(vmap(fn, randomness="different"))(x) + self.assertGreater(len(records), 0) + record = self.getRecord(records, "pyfunctorch") + self.assertIn( + """\ + triggered by the following guard failure(s): + - torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""", + munge_exc(record.getMessage()), + ) + + @xfailIfTorchDynamo + @config.patch(capture_func_transforms=True) + @make_logging_test(recompiles=True) + def test_vmap_recompile_different_states(self, records): + @torch.compile(backend="eager") + def 
fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 3, 4, 5) + y = torch.vmap(fn, randomness="same")(x) + self.assertEqual(len(records), 0) # sanity check + + y = torch.vmap(fn, randomness="different")(x) + self.assertGreater(len(records), 0) + record = self.getRecord(records, "pyfunctorch") + self.assertIn( + """\ + triggered by the following guard failure(s): + - torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""", + munge_exc(record.getMessage()), + ) + + class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase): - def run(self, result=None): - # capture_func_transform will be set to False (for 2.1) till we - # support all transforms, so manually patch it to `True`` for - # testing on release branch. - with config.patch(capture_func_transforms=True): - super().run(result) + def tearDown(self): + # Ensure that in the case of a test failure, the next test won't fail + # because of a previous call to _vmap_increment_nesting that wasn't undone + # i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1 + # and the call to increment nesting is not undone + if not TEST_WITH_TORCHDYNAMO: + return + + warn = False + while ci := torch._C._functorch.peek_interpreter_stack(): + if ci.key() == torch._C._functorch.TransformType.Vmap: + warn = True + torch._C._functorch._vmap_decrement_nesting() + else: + break + + if warn: + msg = ( + "Interpreter stack is not empty. Test should have called " + "'torch._C._functorch._vmap_decrement_nesting()'" + ) + warnings.warn(msg) def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): backend = EagerAndRecordGraphs() @@ -2231,18 +2592,15 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase): wrapped_gm = backend.graphs[graph_idx] return wrapped_gm - def test_grad(self): + @config.patch(capture_func_transforms=True) + def test_jacrev(self): counters.clear() - def fn(x): - return x.sin().sum() - def wrapper_fn(x): - return torch.func.grad(fn)(x) + return torch.func.jacrev(torch.sin)(x) - x = torch.randn(3, 3, 3) + x = torch.randn(4, 3) wrapped_gm = self._compile_check(wrapper_fn, (x,)) - # Dynamic shapes produce a slightly different graph. if check_dynamic_shape_capture(): return @@ -2253,92 +2611,77 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + child_3 = L_x_ - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None - call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return (contiguous,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - class GraphModule(torch.nn.Module): - def forward(self, l_x_): - sin = l_x_.sin(); l_x_ = None - sum_1 = sin.sum(); sin = None - return sum_1 -""", - ) + diff_primals = torch._C._functorch._wrap_for_grad(child_3, 1); child_3 = None - def test_grad_freevar_tensor(self): - counters.clear() - y = torch.randn(3, 3) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) - def fn(x): - return (x.sin() + y).sum() + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) - def wrapper_fn(x): - return torch.func.grad(fn)(x) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) - x = torch.randn(3, 3, 3) - expected = wrapper_fn(x) - actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x) - self.assertEqual(actual, expected) + primal_out = torch.sin(diff_primals) - def test_grad_freevar_python_scalar(self): - counters.clear() - y = 3 + out_1 = torch._C._functorch._unwrap_for_grad(primal_out, 1) - def fn(x): - return (x.sin() + y).sum() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() - def wrapper_fn(x): - return torch.func.grad(fn)(x) + tensor = torch.tensor((12,)) + cumsum = tensor.cumsum(dim = 0); tensor = None + getitem = cumsum[slice(None, -1, None)]; cumsum = None + neg = getitem.neg(); getitem = None + unbind = neg.unbind(); neg = None - x = torch.randn(3, 3, 3) - wrapped_gm = self._compile_check(wrapper_fn, (x,)) + chunk = out_1.new_zeros(12, 12); out_1 = None - # Dynamic shapes produce a slightly different graph. - if check_dynamic_shape_capture(): - return + diagonal = chunk.diagonal(0) + fill_ = diagonal.fill_(1); diagonal = None - actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) - self.assertExpectedInline( - actual, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + arg_4 = chunk.view(12, 4, 3); chunk = None - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None - call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return (contiguous,) + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - class GraphModule(torch.nn.Module): - def forward(self, l_x_): - sin = l_x_.sin(); l_x_ = None - add = sin + 3; sin = None - sum_1 = add.sum(); add = None - return sum_1 -""", - ) + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') - def test_grad_capture_tensor(self): - counters.clear() + _add_batch_dim = torch._C._functorch._add_batch_dim(arg_4, 0, 1); arg_4 = None - def wrapper_fn(x): - y = torch.randn(3) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(primal_out, _add_batch_dim) - def fn(x): - return (x.sin() + y).sum() + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primal_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None + batched_output = _autograd_grad[0]; _autograd_grad = None - return torch.func.grad(fn)(x) + result = torch._C._functorch._remove_batch_dim(batched_output, 1, 12, 0); batched_output = None - x = torch.randn(3, 3, 3) + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable_1 = torch._C._autograd._saved_tensors_hooks_enable() - wrapped_gm = self._compile_check(wrapper_fn, (x,)) + split = result.split((12,), dim = 0); result = None + split_1 = split[0]; split = None + output_input = split_1.view((4, 3, 4, 3)); split_1 = None + return (output_input, diff_primals, primal_out) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_jacrev_two_tensors_argnums(self): + counters.clear() + + def fn(x, y): + return y.sin() + + def wrapper_fn(x, y): + return torch.func.jacrev(fn, argnums=1)(x, y) + + x = torch.randn(4, 3) + y = torch.randn(3, 4) + wrapped_gm = self._compile_check(wrapper_fn, (x, y)) # Dynamic shapes produce a slightly different graph. if check_dynamic_shape_capture(): return @@ -2348,79 +2691,612 @@ class GraphModule(torch.nn.Module): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): + child_2 = L_x_ + child_5 = L_y_ - y = torch.randn(3) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None - call = grad_proxy.__call__(l_x_, y); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return (y, contiguous) + _wrap_for_grad = torch._C._functorch._wrap_for_grad(child_2, 1); child_2 = None + diff_primals = torch._C._functorch._wrap_for_grad(child_5, 1); child_5 = None - class GraphModule(torch.nn.Module): - def forward(self, l_x_, y): - sin = l_x_.sin(); l_x_ = None - add = sin + y; sin = y = None - sum_1 = add.sum(); add = None - return sum_1 -""", - ) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) - def test_grad_closure_scalar(self): - counters.clear() + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) - def wrapper_fn(x): - y = 3.14 + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) - def fn(x): - return (x.sin() + y).sum() + primal_out = diff_primals.sin() - return torch.func.grad(fn)(x) + out_1 = torch._C._functorch._unwrap_for_grad(primal_out, 1) - x = torch.randn(3, 3, 3) + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() - # Graph break because dynamo is unable to get source `fn` and - # functools.wraps in `grad` leads to graph-break - wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False) + tensor = torch.tensor((12,)) + cumsum = tensor.cumsum(dim = 0); tensor = None + getitem = cumsum[slice(None, -1, None)]; cumsum = None + neg = getitem.neg(); getitem = None + unbind = neg.unbind(); neg = None - # Dynamic shapes produce a slightly different graph. - if check_dynamic_shape_capture(): - return + chunk = out_1.new_zeros(12, 12); out_1 = None - actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) - self.assertExpectedInline( - actual, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + diagonal = chunk.diagonal(0) + fill_ = diagonal.fill_(1); diagonal = None - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None - call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return (contiguous,) + arg_5 = chunk.view(12, 3, 4); chunk = None - class GraphModule(torch.nn.Module): - def forward(self, l_x_): - sin = l_x_.sin(); l_x_ = None - add = sin + 3.14; sin = None - sum_1 = add.sum(); add = None - return sum_1 -""", - ) + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - def test_grad_has_aux(self): - counters.clear() + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') - y = 3.14 + _add_batch_dim = torch._C._functorch._add_batch_dim(arg_5, 0, 1); arg_5 = None - def fn(x): - return ((x.sin() + y).sum(), x.cos()) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(primal_out, _add_batch_dim) - def wrapper_fn(x): + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primal_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None + batched_output = _autograd_grad[0]; _autograd_grad = None + + result = torch._C._functorch._remove_batch_dim(batched_output, 1, 12, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable_1 = torch._C._autograd._saved_tensors_hooks_enable() + + split = result.split((12,), dim = 0); result = None + split_1 = split[0]; split = None + + output_input = split_1.view((3, 4, 3, 4)); split_1 = None + return (output_input, diff_primals, primal_out) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_jacrev_has_aux(self): + counters.clear() + + def fn(x, y): + return y.sin(), x + + def wrapper_fn(x, y): + return torch.func.jacrev(fn, argnums=1, has_aux=True)(x, y) + + x = torch.randn(4, 3) + y = torch.randn(3, 4) + wrapped_gm = self._compile_check(wrapper_fn, (x, y)) + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): + child_2 = L_x_ + child_5 = L_y_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + aux = torch._C._functorch._wrap_for_grad(child_2, 1); child_2 = None + diff_primals = torch._C._functorch._wrap_for_grad(child_5, 1); child_5 = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + primal_out = diff_primals.sin() + + aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + + out_1 = torch._C._functorch._unwrap_for_grad(primal_out, 1) + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + + tensor = torch.tensor((12,)) + cumsum = tensor.cumsum(dim = 0); tensor = None + getitem = cumsum[slice(None, -1, None)]; cumsum = None + neg = getitem.neg(); getitem = None + unbind = neg.unbind(); neg = None + + chunk = out_1.new_zeros(12, 12); out_1 = None + + diagonal = chunk.diagonal(0) + fill_ = diagonal.fill_(1); diagonal = None + + arg_5 = chunk.view(12, 3, 4); chunk = None + + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + + _add_batch_dim = torch._C._functorch._add_batch_dim(arg_5, 0, 1); arg_5 = None + + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(primal_out, _add_batch_dim) + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primal_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None + batched_output = _autograd_grad[0]; _autograd_grad = None + + result = torch._C._functorch._remove_batch_dim(batched_output, 1, 12, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable_1 = torch._C._autograd._saved_tensors_hooks_enable() + + split = result.split((12,), dim = 0); result = None + split_1 = split[0]; split = None + + output_input = split_1.view((3, 4, 3, 4)); split_1 = None + return (output_input, aux_2, diff_primals, primal_out) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_vjp(self): + counters.clear() + + def fn(x): + return x.sin().sum() + + def wrapper_fn(x, v): + (out, vjpfunc) = torch.func.vjp(fn, x) + return out + + x = torch.randn([5]) + v = torch.randn(5) + wrapped_gm = self._compile_check(wrapper_fn, (x, v)) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + child_1 = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + child_2 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = child_1.sin(); child_1 = None + primal_out = sin.sum(); sin = None + + out = torch._C._functorch._unwrap_for_grad(primal_out, 1); primal_out = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (out,) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_vjp_multiple_outputs(self): + counters.clear() + + def wrapper_fn(x, v): + fn = lambda x: (x.sin(), x.cos()) # noqa: E731 + (out, vjpfunc) = torch.func.vjp(fn, x) + vjps = vjpfunc((v, v)) + return out, vjps + + x = torch.randn([5]) + v = torch.randn(5) + wrapped_gm = self._compile_check(wrapper_fn, (x, v)) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor, L_v_ : torch.Tensor): + child = L_x_ + child_8 = L_v_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + child_1 = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + child_4 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + primal_out = child_1.sin() + primal_out_1 = child_1.cos(); child_1 = None + + _unwrap_for_grad = torch._C._functorch._unwrap_for_grad(primal_out, 1) + _unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(primal_out_1, 1) + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare((primal_out, primal_out_1), (child_8, child_8)) + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primal_out, primal_out_1], [child_4], [child_8, child_8], retain_graph = True, create_graph = True); primal_out = primal_out_1 = child_4 = child_8 = None + getitem = _autograd_grad[0]; _autograd_grad = None + return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_vjp_multiple_outputs_python_struct(self): + counters.clear() + + def wrapper_fn(x, v): + fn = lambda x: {"first": x.sin(), "second": x.cos()} # noqa: E731 + (out, vjpfunc) = torch.func.vjp(fn, x) + vjps = vjpfunc({"first": v, "second": v.sin()}) + return out, vjps + + x = torch.randn([5]) + v = torch.randn(5) + wrapped_gm = self._compile_check(wrapper_fn, (x, v)) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor, L_v_ : torch.Tensor): + child = L_x_ + child_7 = L_v_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + child_1 = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + child_4 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + primal_out = child_1.sin() + primal_out_1 = child_1.cos(); child_1 = None + + _unwrap_for_grad = torch._C._functorch._unwrap_for_grad(primal_out, 1) + _unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(primal_out_1, 1) + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + + child_8 = child_7.sin() + + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare({'first': primal_out, 'second': primal_out_1}, {'first': child_7, 'second': child_8}) + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primal_out, primal_out_1], [child_4], [child_7, child_8], retain_graph = True, create_graph = True); primal_out = primal_out_1 = child_4 = child_7 = child_8 = None + getitem = _autograd_grad[0]; _autograd_grad = None + return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_vjp_has_aux(self): + counters.clear() + + def fn(x): + return x.sin().sum(), x + + def wrapper_fn(x, v): + (out, vjpfunc, _) = torch.func.vjp(fn, x, has_aux=True) + return out + + x = torch.randn([5]) + v = torch.randn(5) + wrapped_gm = self._compile_check(wrapper_fn, (x, v)) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + aux = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + child_2 = torch._functorch.eager_transforms._set_tensor_requires_grad(aux) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = aux.sin() + primal_out = sin.sum(); sin = None + + _ = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + + out = torch._C._functorch._unwrap_for_grad(primal_out, 1); primal_out = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (out,) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_grad(self): + counters.clear() + + def fn(x): + return x.sin().sum() + + def wrapper_fn(x): + return torch.func.grad(fn)(x) + + x = torch.randn(3, 3, 3) + wrapped_gm = self._compile_check(wrapper_fn, (x,)) + + # Dynamic shapes produce a slightly different graph. 
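+        # (With dynamic-shape capture the traced graph can pick up extra symbolic
+        # size inputs, so it would not match the static-shape string below.)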
+ if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + output = sin.sum(); sin = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad,) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_grad_freevar_tensor(self): + counters.clear() + y = torch.randn(3, 3) + + def fn(x): + return (x.sin() + y).sum() + + def wrapper_fn(x): + return torch.func.grad(fn)(x) + + x = torch.randn(3, 3, 3) + expected = wrapper_fn(x) + actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(actual, expected) + + @config.patch(capture_func_transforms=True) + def test_grad_freevar_python_scalar(self): + counters.clear() + y = 3 + + def fn(x): + return (x.sin() + y).sum() + + def wrapper_fn(x): + return torch.func.grad(fn)(x) + + x = torch.randn(3, 3, 3) + wrapped_gm = self._compile_check(wrapper_fn, (x,)) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + add = sin + 3; sin = None + output = add.sum(); add = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad,) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_grad_capture_tensor(self): + counters.clear() + + def wrapper_fn(x): + y = torch.randn(3) + + def fn(x): + return (x.sin() + y).sum() + + return torch.func.grad(fn)(x) + + x = torch.randn(3, 3, 3) + + wrapped_gm = self._compile_check(wrapper_fn, (x,)) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + y = torch.randn(3) + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + add = sin + y; sin = None + output = add.sum(); add = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad, y) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_grad_closure_scalar(self): + counters.clear() + + def wrapper_fn(x): + y = 3.14 + + def fn(x): + return (x.sin() + y).sum() + + return torch.func.grad(fn)(x) + + x = torch.randn(3, 3, 3) + + # Graph break because dynamo is unable to get source `fn` and + # functools.wraps in `grad` leads to graph-break + wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False) + + # Dynamic shapes produce a slightly different graph. 
+ if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + add = sin + 3.14; sin = None + output = add.sum(); add = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad,) +""", + ) + + @config.patch(capture_func_transforms=True) + def test_grad_has_aux(self): + counters.clear() + + y = 3.14 + + def fn(x): + return ((x.sin() + y).sum(), x.cos()) + + def wrapper_fn(x): return torch.func.grad(fn, has_aux=True)(x) x = torch.randn(3, 3, 3) @@ -2436,26 +3312,40 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + child = L_x_ - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, True); grad_body_0 = None - call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None - getitem = call[0] - getitem_1 = call[1]; call = None - contiguous = getitem.contiguous(); getitem = None - return (contiguous, getitem_1) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - class GraphModule(torch.nn.Module): - def forward(self, l_x_): - sin = l_x_.sin() - add = sin + 3.14; sin = None - sum_1 = add.sum(); add = None - cos = l_x_.cos(); l_x_ = None - return (sum_1, cos) + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + add = sin + 3.14; sin = None + output = add.sum(); add = None + aux = diff_args.cos() + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad, aux_2) """, ) + @config.patch(capture_func_transforms=True) def test_grad_two_tensor_has_aux(self): counters.clear() @@ -2479,27 +3369,42 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - l_x_ = L_x_ - l_y_ = L_y_ + child = L_x_ + child_1 = L_y_ - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, True); grad_body_0 = None - call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None - getitem = call[0] - getitem_1 = call[1]; call = None - contiguous = getitem.contiguous(); getitem = None - return (contiguous, getitem_1) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - class GraphModule(torch.nn.Module): - def forward(self, l_x_, l_y_): - sin = l_x_.sin() - add = sin + l_y_; sin = l_y_ = None - sum_1 = add.sum(); add = None - cos = l_x_.cos(); l_x_ = None - return (sum_1, cos) + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + add = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None + output = add.sum(); add = None + aux = diff_args.cos() + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad, aux_2) """, ) + @config.patch(capture_func_transforms=True) def test_grad_two_tensor_all_grad_has_aux(self): counters.clear() @@ -2534,27 +3439,45 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - l_x_ = L_x_ - l_y_ = L_y_ + child = L_x_ + child_1 = L_y_ - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, (0, 1), True); grad_body_0 = None - call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None - getitem = call[0] - getitem_1 = getitem[0] - getitem_2 = getitem[1]; getitem = None - getitem_3 = call[1]; call = None - contiguous = getitem_1.contiguous(); getitem_1 = None - contiguous_1 = getitem_2.contiguous(); getitem_2 = None - return (contiguous, contiguous_1, getitem_3) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - class GraphModule(torch.nn.Module): - def forward(self, l_x_, l_y_): - sin = l_x_.sin() - add = sin + l_y_; sin = l_y_ = None - sum_1 = add.sum(); add = None - cos = l_x_.cos(); l_x_ = None - return (sum_1, cos) + child_4 = torch._C._functorch._wrap_for_grad(child, 1); child = None + child_5 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_5) + + set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = child_4.sin() + add = sin + child_5; sin = None + output = add.sum(); add = None + aux = child_4.cos() + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child_4, child_5], create_graph = True); child_4 = child_5 = None + child_6 = _autograd_grad[0] + child_7 = _autograd_grad[1]; _autograd_grad = None + + _unwrap_for_grad = torch._C._functorch._unwrap_for_grad(child_6, 1); child_6 = None + _unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(child_7, 1); child_7 = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_unwrap_for_grad, _unwrap_for_grad_1, aux_2) """, ) self.assertExpectedInline( @@ -2562,75 +3485,120 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - l_x_ = L_x_ - l_y_ = L_y_ + child = L_x_ + child_1 = L_y_ - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, (0, 1), True); grad_body_0 = None - call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None - getitem = call[0] - getitem_1 = getitem[0] - getitem_2 = getitem[1]; getitem = None - getitem_3 = call[1]; call = None - contiguous = getitem_1.contiguous(); getitem_1 = None - contiguous_1 = getitem_2.contiguous(); getitem_2 = None - return (contiguous, contiguous_1, getitem_3) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - class GraphModule(torch.nn.Module): - def forward(self, l_x_, l_y_): - sin = l_x_.sin() - add = sin + l_y_; sin = l_y_ = None - sum_1 = add.sum(); add = None - cos = l_x_.cos(); l_x_ = None - return (sum_1, cos) + child_4 = torch._C._functorch._wrap_for_grad(child, 1); child = None + child_5 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_5) + + set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = child_4.sin() + add = sin + child_5; sin = None + output = add.sum(); add = None + aux = child_4.cos() + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child_4, child_5], create_graph = True); child_4 = child_5 = None + child_6 = _autograd_grad[0] + child_7 = _autograd_grad[1]; _autograd_grad = None + + _unwrap_for_grad = torch._C._functorch._unwrap_for_grad(child_6, 1); child_6 = None + _unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(child_7, 1); child_7 = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_unwrap_for_grad, _unwrap_for_grad_1, aux_2) """, ) - def test_grad_over_grad(self): - counters.clear() + @config.patch(capture_func_transforms=True) + def test_grad_over_grad(self): + counters.clear() + + def fn(x): + return x.sin().sum() + + def wrapper_fn(x): + return torch.func.grad(torch.func.grad(fn))(x) + + x = torch.randn(()) + wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False) + + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_ : torch.Tensor): + child = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + + child_1 = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting() + + diff_args_1 = torch._C._functorch._wrap_for_grad(child_1, 2) + + set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1) + + set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args_1.sin() + output = sin.sum(); sin = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None + grad_input = _autograd_grad[0]; _autograd_grad = None - def fn(x): - return x.sin().sum() + output_2 = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None - def wrapper_fn(x): - return torch.func.grad(torch.func.grad(fn))(x) + _ = torch._C._functorch._unwrap_for_grad(output, 2); output = None - x = torch.randn(()) - wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False) + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") - if check_dynamic_shape_capture(): - return + _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((output_2,), [child_1], create_graph = True); child_1 = None + grad_input_2 = _autograd_grad_1[0]; _autograd_grad_1 = None - actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) - self.assertExpectedInline( - actual, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + grad_1 = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None - grad_body_1 = self.grad_body_1 - grad_proxy = torch.func.grad(grad_body_1, 0, False); grad_body_1 = None - call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return (contiguous,) + __1 = torch._C._functorch._unwrap_for_grad(output_2, 1); output_2 = None - class GraphModule(torch.nn.Module): - def forward(self, l_x_): - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None - call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return contiguous - - class GraphModule(torch.nn.Module): - def forward(self, l_x_): - sin = l_x_.sin(); l_x_ = None - sum_1 = sin.sum(); sin = None - return sum_1 + _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad_1,) """, ) + @config.patch(capture_func_transforms=True) def test_grad_with_graph_break(self): counters.clear() @@ -2647,6 +3615,7 @@ class GraphModule(torch.nn.Module): self.assertEqual(len(counters["graph_break"]), 1) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_grad_with_side_effect(self): counters.clear() @@ -2662,16 +3631,10 @@ class GraphModule(torch.nn.Module): x = torch.randn(3, 3, 3) actual = wrapper_fn(x) expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(len(counters["graph_break"]), 1) - assert_dict_matches_regex( - self, - dict(counters["graph_break"]), - { - r".*HigherOrderOperator: Mutating a variable not in the 
current scope \(replace_all\)": 2 - }, - ) + self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_grad_pytree(self): counters.clear() @@ -2688,14 +3651,10 @@ class GraphModule(torch.nn.Module): expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( (x1, x2) ) - self.assertEqual(len(counters["graph_break"]), 1) - assert_dict_matches_regex( - self, - dict(counters["graph_break"]), - {".*HigherOrderOperator with body that accepts non-Tensors as input": 2}, - ) + self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_grad_non_tensor_input(self): counters.clear() @@ -2719,23 +3678,37 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - l_x_ = L_x_ + child = L_x_ - grad_body_0 = self.grad_body_0 - grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None - call = grad_proxy.__call__(l_x_, 3.0); grad_proxy = l_x_ = None - contiguous = call.contiguous(); call = None - return (contiguous,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() - class GraphModule(torch.nn.Module): - def forward(self, l_x_, const): - sin = l_x_.sin(); l_x_ = None - sum_1 = sin.sum(); sin = None - add = sum_1 + 3.0; sum_1 = None - return add + diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + + sin = diff_args.sin() + sum_1 = sin.sum(); sin = None + output = sum_1 + 3.0; sum_1 = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input = _autograd_grad[0]; _autograd_grad = None + + grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + + _ = torch._C._functorch._unwrap_for_grad(output, 1); output = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (grad,) """, ) + @config.patch(capture_func_transforms=True) def test_grad_disable_capture(self): counters.clear() @@ -2763,6 +3736,7 @@ class GraphModule(torch.nn.Module): ) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_grad_fn_with_kwargs(self): def fn(x, y): return (x + y).sum() @@ -2774,13 +3748,177 @@ class GraphModule(torch.nn.Module): y = torch.randn(3, 3) actual = wrapper_fn(x, y) expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y) - self.assertEqual(len(counters["graph_break"]), 1) - self.assertEqual( - dict(counters["graph_break"]), - {"torch.func.grad: kwargs arguments are currently unsupported.": 2}, - ) + self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) + @config.patch(error_on_recompile=True) + def test_vmap_recompile(self): + @torch.compile(backend="eager") + def fn(x): + return 
torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 3, 4, 5) + y = torch.vmap(fn)(x) + # should not recompile on second call. See Pytorch issue #118493 + y = torch.vmap(fn)(x) + + @xfailIfTorchDynamo + @config.patch(capture_func_transforms=True) + @config.patch(error_on_recompile=True) + def test_vmap_recompile_different_config(self): + @torch.compile(backend="eager") + def fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 3, 4, 5) + y = torch.vmap(fn)(x) + with self.assertRaises(torch._dynamo.exc.RecompileError): + fn(x) + + @config.patch(capture_func_transforms=True) + @config.patch(error_on_recompile=True) + def test_vmap_recompile_same_config(self): + @torch.compile(backend="eager") + def fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 3, 4, 5) + torch.vmap(torch.vmap(fn, randomness="same"), randomness="same")(x) + with self.assertRaises(torch._dynamo.exc.RecompileError): + torch.vmap(torch.vmap(fn, randomness="same"), randomness="error")(x) + + @config.patch(capture_func_transforms=True) + @config.patch(error_on_recompile=True) + def test_vmap_recompile_with_randomness(self): + @torch.compile(backend="eager") + def fn(x): + return torch.vmap(lambda x: x.sin())(x) + + x = torch.zeros(3, 3, 4, 5) + torch.vmap(fn, randomness="same")(x) + with self.assertRaises(torch._dynamo.exc.RecompileError): + torch.vmap(fn, randomness="different")(x) + + @config.patch(capture_func_transforms=True) + @config.patch(error_on_recompile=True) + def test_grad_recompile(self): + @torch.compile(backend="eager") + def fn(x): + return torch.func.grad(torch.sin)(x) + + x = torch.randn([]) + torch.func.grad(fn)(x) + # should not recompile on second call + torch.func.grad(fn)(x) + + @config.patch(capture_func_transforms=True) + def test_vmap_get_wrapped(self): + counters.clear() + + def g(x): + return x.sin() + + @torch.compile(backend="aot_eager", fullgraph=True) + def fn(): + return torch.vmap(g) + + x = torch.randn(3, 4) + expected = torch.vmap(g)(x) + wrapper = fn() + got = wrapper(x) + self.assertEqual(expected, got) + + @config.patch(capture_func_transforms=True) + def test_vmap_with_conditional_graph_break(self): + def g(x): + if len(x.shape) < 2: + torch._dynamo.graph_break() + return x.sin() + else: + return x.cos() + + @torch.compile(backend="aot_eager") + def fn(x): + return torch.vmap(g)(x) + + counters.clear() + x = torch.randn(2, 3) + expected = x.sin() + got = fn(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 1) + + counters.clear() + y = torch.randn(2, 3, 4) + expected = y.cos() + got = fn(y) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 0) + + @config.patch(capture_func_transforms=True) + def test_vmap_with_graph_break(self): + counters.clear() + + def g(x): + y = x.cos() + print("hi") + return y.sin() + + def fn(x): + return torch.vmap(g)(x) + + x = torch.randn(3, 4) + opt = torch.compile(fn, backend="aot_eager", fullgraph=False) + expected = fn(x) + got = opt(x) + self.assertEqual(len(counters["graph_break"]), 1) + self.assertEqual(expected, got) + + @config.patch(capture_func_transforms=True) + def test_vmap_with_graph_break_2(self): + counters.clear() + + def cos(x): + print("cos") + return x.cos() + + def sin(x): + print("sin") + return x.sin() + + def g(x): + y = cos(x) + return sin(y) + + def fn(x): + return torch.vmap(g, randomness="same")(x) + + x = torch.randn(3, 4) + opt = torch.compile(fn, backend="aot_eager", fullgraph=False) + expected = fn(x) + got = opt(x) 
+ self.assertEqual(len(counters["graph_break"]), 1) + self.assertEqual(expected, got) + + def test_vmap_with_graph_break_lambda(self): + counters.clear() + + def sin(x): + print("sin") + return x.sin() + + def fn(x): + return torch.vmap(lambda x: sin(x))(x) + + x = torch.randn(3, 4) + opt = torch.compile(fn, backend="aot_eager", fullgraph=False) + expected = fn(x) + got = opt(x) + self.assertEqual(len(counters["graph_break"]), 1) + self.assertEqual(expected, got) + + @config.patch(capture_func_transforms=True) def test_vmap(self): def fn(x): return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x) @@ -2798,25 +3936,28 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - child = L_x_ + arg = L_x_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child); vmap_proxy = child = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - add = sum_1 + sum_2; sum_1 = sum_2 = None - return add + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + sum_1 = _add_batch_dim.sum(0) + sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None + batched_output = sum_1 + sum_2; sum_1 = sum_2 = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_free_const(self): y = 3 @@ -2836,26 +3977,29 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - child = L_x_ + arg = L_x_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child); vmap_proxy = child = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - add = sum_1 + sum_2; sum_1 = sum_2 = None - add_1 = add + 3; add = None - return add_1 + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + sum_1 = _add_batch_dim.sum(0) + sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None + add = sum_1 + sum_2; sum_1 = sum_2 = None + batched_output = add + 3; add = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_free_tensor(self): y = torch.randn(3, 3) @@ -2875,27 +4019,30 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - child = L_x_ + arg = L_x_ l_y_ = L_y_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0, None), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child, l_y_); vmap_proxy = child = l_y_ = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select, l_y_): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - add = sum_1 + sum_2; sum_1 = sum_2 = None - add_1 = add + l_y_; add = l_y_ = None - return add_1 + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + sum_1 = _add_batch_dim.sum(0) + sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None + add = sum_1 + sum_2; sum_1 = sum_2 = None + batched_output = add + l_y_; add = l_y_ = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_two_inputs(self): def fn(x, y): return torch.func.vmap( @@ -2916,28 +4063,31 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - child = L_x_ - child_1 = L_y_ + arg = L_x_ + arg_3 = L_y_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - select_1 = child_1.select(1, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0, 1), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't 
yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select, select_1): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - add = sum_1 + sum_2; sum_1 = sum_2 = None - add_1 = add + select_1; add = select_1 = None - return add_1 + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + _add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 1, 1); arg_3 = None + + sum_1 = _add_batch_dim.sum(0) + sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None + add = sum_1 + sum_2; sum_1 = sum_2 = None + batched_output = add + _add_batch_dim_1; add = _add_batch_dim_1 = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_two_inputs_tuple_in_dims(self): in_dims = (0, 1) @@ -2960,28 +4110,31 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - child = L_x_ - child_1 = L_y_ + arg = L_x_ + arg_3 = L_y_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - select_1 = child_1.select(1, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0, 1), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select, select_1): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - add = sum_1 + sum_2; sum_1 = sum_2 = None - add_1 = add + select_1; add = select_1 = None - return add_1 + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + _add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 1, 1); arg_3 = None + + sum_1 = _add_batch_dim.sum(0) + sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None + add = sum_1 + sum_2; sum_1 = sum_2 = None + batched_output = add + _add_batch_dim_1; add = _add_batch_dim_1 = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_over_vmap_two_inputs(self): def fn(x, y): return torch.func.vmap(torch.func.vmap(lambda x, y: x + y, in_dims=1))(x, y) @@ -3000,35 +4153,41 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): - child = L_x_ - child_1 = L_y_ + arg = L_x_ + arg_3 = L_y_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') - _check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - select_1 = child_1.select(0, 0) - vmap_body_1 = self.vmap_body_1 - vmap_proxy = torch.func.vmap(vmap_body_1, (0, 0), 0, 'error'); vmap_body_1 = None - call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select, select_1): - select_2 = select.select(1, 0) - select_3 = select_1.select(1, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (1, 1), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(select, select_1); vmap_proxy = select = select_1 = None - return call - - class GraphModule(torch.nn.Module): - def forward(self, select_2, select_3): - add = select_2 + select_3; select_2 = select_3 = None - return add + arg_8 = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + arg_9 = torch._C._functorch._add_batch_dim(arg_3, 0, 1); arg_3 = None + + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions() + + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error') + + _add_batch_dim_2 = torch._C._functorch._add_batch_dim(arg_8, 1, 2); arg_8 = None + _add_batch_dim_3 = torch._C._functorch._add_batch_dim(arg_9, 1, 2); arg_9 = None + + batched_output = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None + + batched_output_1 = torch._C._functorch._remove_batch_dim(batched_output, 2, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + + _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 3, 0); batched_output_1 = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim_1,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_over_vmap_captured(self): x = torch.ones(2, 3) y = torch.ones(5, 3) @@ -3048,33 +4207,39 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor): - child = L_y_ + arg = L_y_ l_x_ = L_x_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') - _check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_1 = self.vmap_body_1 - vmap_proxy = torch.func.vmap(vmap_body_1, (0, None), 0, 'error'); vmap_body_1 = None - call = vmap_proxy.__call__(child, l_x_); vmap_proxy = child = l_x_ = None - return (call,) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select, l_x_): - select_1 = select.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0, None), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(select, l_x_); vmap_proxy = select = l_x_ = None - return call - - class GraphModule(torch.nn.Module): - def forward(self, select_1, l_x_): - mul = l_x_ * select_1; l_x_ = select_1 = None - return mul + arg_3 = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions() + + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error') + + _add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 0, 2); arg_3 = None + + batched_output = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None + + batched_output_1 = torch._C._functorch._remove_batch_dim(batched_output, 2, 3, 0); batched_output = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + + _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 5, 0); batched_output_1 = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim_1,) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_multiple_outputs(self): x = torch.ones(2, 4, 3) @@ -3093,26 +4258,28 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - child = L_x_ + arg = L_x_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child); vmap_proxy = child = None - getitem = call[0] - getitem_1 = call[1]; call = None - return (getitem, getitem_1) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - return (sum_1, sum_2) + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + batched_output = _add_batch_dim.sum(0) + batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 0); batched_output = None + _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim, _remove_batch_dim_1) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_multiple_outputs_diff_dims(self): x = torch.ones(2, 4, 3) @@ -3131,26 +4298,28 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - child = L_x_ + arg = L_x_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0,), (1, 0), 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child); vmap_proxy = child = None - getitem = call[0] - getitem_1 = call[1]; call = None - return (getitem, getitem_1) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - return (sum_1, sum_2) + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + batched_output = _add_batch_dim.sum(0) + batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 1); batched_output = None + _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim, _remove_batch_dim_1) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_multiple_outputs_out_dims_tuple(self): x = torch.ones(2, 4, 3) out_dims = (1, 0) @@ -3170,26 +4339,28 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, L_x_ : torch.Tensor): - child = L_x_ + arg = L_x_ - _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() - select = child.select(0, 0) - vmap_body_0 = self.vmap_body_0 - vmap_proxy = torch.func.vmap(vmap_body_0, (0,), (1, 0), 'error'); vmap_body_0 = None - call = vmap_proxy.__call__(child); vmap_proxy = child = None - getitem = call[0] - getitem_1 = call[1]; call = None - return (getitem, getitem_1) + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. 
Please open an issue with your use case.") + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error') - class GraphModule(torch.nn.Module): - def forward(self, select): - sum_1 = select.sum(0) - sum_2 = select.sum(1); select = None - return (sum_1, sum_2) + _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None + + batched_output = _add_batch_dim.sum(0) + batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None + + _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 1); batched_output = None + _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + return (_remove_batch_dim, _remove_batch_dim_1) """, ) + @config.patch(capture_func_transforms=True) def test_vmap_kwargs(self): counters.clear() x = torch.ones(2, 3) @@ -3200,13 +4371,10 @@ class GraphModule(torch.nn.Module): actual = fn(x, y) expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y) - self.assertEqual(len(counters["graph_break"]), 1) - self.assertEqual( - dict(counters["graph_break"]), - {"NYI - torch.func.vmap: kwargs arguments are currently unsupported.": 2}, - ) + self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_vmap_pytree_inputs(self): counters.clear() x = torch.ones(2, 3) @@ -3222,17 +4390,10 @@ class GraphModule(torch.nn.Module): actual = fn(x, y) expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y) - self.assertEqual(len(counters["graph_break"]), 2) - assert_dict_matches_regex( - self, - dict(counters["graph_break"]), - { - ".*HigherOrderOperator with body that accepts non-Tensors as input": 2, - "Unsupported: meta converter nyi with fake tensor propagation.": 1, - }, - ) + self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_vmap_side_effects(self): counters.clear() x = torch.ones(2, 3) @@ -3249,16 +4410,51 @@ class GraphModule(torch.nn.Module): actual = wrapper_fn(x, y) expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y) - self.assertEqual(len(counters["graph_break"]), 1) - assert_dict_matches_regex( - self, - dict(counters["graph_break"]), - { - r".*HigherOrderOperator: Mutating a variable not in the current scope \(replace_all\)": 2 - }, - ) + self.assertEqual(len(counters["graph_break"]), 0) + self.assertEqual(actual, expected) + self.assertEqual(some_list, [1, 1]) + + @unittest.expectedFailure + @config.patch(capture_func_transforms=True) + def test_vmap_side_effects_append_input(self): + counters.clear() + x = torch.ones(2, 3) + y = torch.randn(2, 3) + + some_list = [] + + def f(x, y): + some_list.append(x) + return x + y + + def wrapper_fn(x, y): + return torch.func.vmap(f)(x, y) + + actual = wrapper_fn(x, y) + expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y) + self.assertEqual(len(counters["graph_break"]), 0) + self.assertEqual(actual, expected) + + @config.patch(capture_func_transforms=True) + def test_vmap_previous_illegal_op_no_graph_break(self): + counters.clear() + + # calling .stride() would previously graph break + def bad_fn(x): + y = x.view((4, 3)) + y.stride() + return y + + def wrapper_fn(x): + return torch.func.vmap(bad_fn)(x) + + x = 
torch.randn(2, 3, 4) + actual = wrapper_fn(x) + expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + @config.patch(capture_func_transforms=True) def test_vmap_disable_capture(self): counters.clear() @@ -3283,27 +4479,7 @@ class GraphModule(torch.nn.Module): ) self.assertEqual(actual, expected) - def test_vmap_illegal_op_graph_break(self): - counters.clear() - - def bad_fn(x): - x.stride() - return x - - def wrapper_fn(x): - return torch.func.vmap(bad_fn)(x) - - x = torch.randn(3, 3, 3) - actual = wrapper_fn(x) - expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(len(counters["graph_break"]), 1) - assert_dict_matches_regex( - self, - dict(counters["graph_break"]), - {".*Illegal getattr invocation stride in strict mode": 2}, - ) - self.assertEqual(actual, expected) - + @config.patch(capture_func_transforms=True) def test_vmap_multiple_invocation_in_dims(self): counters.clear() @@ -3318,8 +4494,9 @@ class GraphModule(torch.nn.Module): actual = opt(x, 0), opt(x, 1), opt(x, 2) self.assertEqual(expected, actual) self.assertEqual(cnt.frame_count, 3) - self.assertEqual(cnt.op_count, 9) + self.assertEqual(cnt.op_count, 33) + @config.patch(capture_func_transforms=True) def test_vmap_multiple_invocation_out_dims(self): counters.clear() @@ -3334,8 +4511,9 @@ class GraphModule(torch.nn.Module): actual = opt(x, 0), opt(x, 1), opt(x, 2) self.assertEqual(expected, actual) self.assertEqual(cnt.frame_count, 3) - self.assertEqual(cnt.op_count, 9) + self.assertEqual(cnt.op_count, 30) + @config.patch(capture_func_transforms=True) def test_vmap_new_tensor_in_body(self): def fn(x): return x + torch.ones(3) @@ -3346,11 +4524,12 @@ class GraphModule(torch.nn.Module): x = torch.randn( 3, ) - opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True) + opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True) expected = wrapper_fn(x) actual = opt(x) self.assertEqual(expected, actual) + @config.patch(capture_func_transforms=True) def test_vmap_new_tensor_unused_in_body(self): def fn(x): return torch.tensor(0.5) @@ -3359,17 +4538,18 @@ class GraphModule(torch.nn.Module): return torch.func.vmap(fn)(x) x = torch.randn(3) - opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True) + opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True) expected = wrapper_fn(x) actual = opt(x) self.assertEqual(expected, actual) + @config.patch(capture_func_transforms=True) def test_vmap_new_tensor_implicit_via_op(self): def wrapper_fn(x): return torch.func.vmap(lambda t: torch.add(t, 0.5))(x) x = torch.randn(3) - opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True) + opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True) expected = wrapper_fn(x) actual = opt(x) self.assertEqual(expected, actual) @@ -3402,7 +4582,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): - return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y) + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True + ) x = torch.randn(4, 4, requires_grad=True) y = torch.randn(4, 4, requires_grad=True) @@ -3422,7 +4604,11 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): def fn(x, y): return torch.utils.checkpoint.checkpoint( - gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False + gn, + torch.sin(x), + y, + 
use_reentrant=True, + preserve_rng_state=False, ) x = torch.randn(4, 4, requires_grad=True) @@ -3442,7 +4628,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2) def fn(x, y): - return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y) + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True + ) x = torch.randn(4, 4, device="npu:0", requires_grad=True) y = torch.randn(4, 4, device="npu:0", requires_grad=True) @@ -3460,20 +4648,21 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): @requires_npu() @torch._functorch.config.patch(functionalize_rng_ops=True) - def test_dropout_npubackend(self): - # official case is test_dropout_inductor + def test_dropout_inductor(self): def gn(x, y): return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2) def fn(x, y): - return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y) + return torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True + ) x = torch.randn(4, 4, device="npu:0", requires_grad=True) y = torch.randn(4, 4, device="npu:0", requires_grad=True) - - npu_backend = torchair.get_npu_backend() + + backend = "npu" self._validate( - fn, npu_backend, x, y, skip_check=True + fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager @requires_npu() @@ -3484,7 +4673,11 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): - return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)) + return torch.cos( + torch.utils.checkpoint.checkpoint( + gn, torch.sin(x), y, use_reentrant=True + ), + ) x = torch.randn(4, 4, requires_grad=True) y = torch.randn(4, 4, requires_grad=True) @@ -3498,7 +4691,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): self.assertEqual(result, expected) - # One graph for torch.sin on the input, and other for torch.cos. + # One graph for torch.sin on the ipt, and other for torch.cos. 
self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.op_count, 2) self.assertEqual(len(backend.graphs), 2) @@ -3517,7 +4710,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): mod = MockModule() def fn(x): - return torch.utils.checkpoint.checkpoint(mod, torch.sin(x)) + return torch.utils.checkpoint.checkpoint( + mod, torch.sin(x), use_reentrant=True + ) x = torch.randn(10, 10, requires_grad=True) diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 520ed34976ce1e92613f82cd63c55a98556d3e0b..7f7dbc2fad0b259a1ded53b6a78e7ab8536329ce 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -2,6 +2,7 @@ import contextlib import functools +import unittest import torch import torch_npu @@ -11,6 +12,7 @@ import torch._dynamo.testing from functorch.compile import nop from torch._dynamo import compiled_autograd from torch._functorch.aot_autograd import aot_module_simplified +from torch.utils.hooks import RemovableHandle def compiler_fn(gm): @@ -32,6 +34,11 @@ def global_hook_2(grad): h0 = None +class ClassWithVal: + def __init__(self, val): + self.val = val + + class HooksTests(torch._dynamo.test_case.TestCase): def test_tensor_only_register_hook_in_graph_lambda(self): def fn(x): @@ -73,7 +80,7 @@ class HooksTests(torch._dynamo.test_case.TestCase): v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0] v.backward(torch.tensor([1.0, 2.0, 3.0])) self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0])) - self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.frame_count, 1) def test_tensor_register_hook_multi_handle_return(self): def fn(x, y, z): @@ -107,9 +114,27 @@ class HooksTests(torch._dynamo.test_case.TestCase): v.backward(torch.tensor([1.0, 2.0, 3.0])) self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) self.assertEqual(cnts.frame_count, 1) - self.assertNotEqual(h, None) - self.assertNotEqual(h2, None) - self.assertEqual(h2, h) + self.assertIsInstance(h, RemovableHandle) + self.assertIs(h2, h) + + def test_removed_handle_return(self): + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y, z): + handle = x.register_hook(lambda grad: grad * 2) + z = z * z + handle.remove() + handle.remove() + return x, y * y, z, handle, handle + + v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) + v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2])) + v.backward(torch.tensor([1.0, 2.0, 3.0])) + self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0])) + self.assertEqual(cnt.frame_count, 1) + self.assertIsInstance(h, RemovableHandle) + self.assertIs(h2, h) def test_tensor_register_hook_repeated_handle_not_local(self): def fn(x, y, z, mod): @@ -118,7 +143,7 @@ class HooksTests(torch._dynamo.test_case.TestCase): return x, y * y, z cnts = torch._dynamo.testing.CompileCounter() - fn = torch._dynamo.optimize(cnts)(fn) + fn = torch._dynamo.optimize(cnts, nopython=True)(fn) v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) mod = torch.nn.Module() @@ -234,10 +259,10 @@ class HooksTests(torch._dynamo.test_case.TestCase): def test_tensor_register_multiple_hooks_handles_in_list(self): def fn(x): - h_0 = x.register_hook(global_hook_0) # * 4 - h_1 = x.register_hook(global_hook_1) # / 2 - h_2 = x.register_hook(global_hook_2) # * 3 - return x, x * x, h_0, h_1, h_2 + h0 = x.register_hook(global_hook_0) # * 4 + h1 = x.register_hook(global_hook_1) # / 2 + h2 = x.register_hook(global_hook_2) # * 3 + return x, x * x, h0, h1, h2 cnts = torch._dynamo.testing.CompileCounter() fn = 
torch._dynamo.optimize(cnts)(fn) @@ -495,6 +520,127 @@ class HooksTests(torch._dynamo.test_case.TestCase): self.assertEqual(obj.count, 2) + def test_register_hook_partial_guarding( + self, + ): + def some_hook(grad, *, obj): + return grad + obj.val + + class MyMod(torch.nn.Module): + def forward(self, x, obj): + y = x.mul(2) + hook1 = functools.partial(some_hook, obj=obj) + y.register_hook(hook1) + z = y.mul(3) + return (z,) + + mod = MyMod() + obj1 = ClassWithVal(torch.tensor(88)) + obj2 = ClassWithVal(torch.tensor(99)) + obj3 = ClassWithVal(11) + cnt = torch._dynamo.testing.CompileCounter() + + x0 = torch.ones(4, requires_grad=True) + x1 = torch.ones(4, requires_grad=True) + + with compiled_autograd.enable(compiler_fn): + torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1) + torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1) + torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2) + torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3) + self.assertEqual(cnt.frame_count, 1) + + def test_hook_with_closure(self): + def fn(x, obj): + y = x.sin() + x.register_hook(lambda grad: grad + obj.val) + z = y.sin() + return z + + cnt_fw = torch._dynamo.testing.CompileCounter() + cnt_bw = torch._dynamo.testing.CompileCounter() + opt = torch.compile(fn, backend=cnt_fw, fullgraph=True) + + obj1 = ClassWithVal(torch.tensor(88)) + obj2 = ClassWithVal(torch.tensor(99)) + x0 = torch.ones(4, requires_grad=True) + x1 = torch.ones(4, requires_grad=True) + x2 = torch.ones(4, requires_grad=True) + x3 = torch.ones(4, requires_grad=True) + fn(x0, obj1).sum().backward() + fn(x1, obj2).sum().backward() + + with compiled_autograd.enable( + functools.partial(torch.compile, backend=cnt_bw, fullgraph=True) + ): + opt(x2, obj1).sum().backward() + opt(x3, obj2).sum().backward() + self.assertEqual(cnt_fw.frame_count, 1) + self.assertEqual(cnt_bw.frame_count, 1) + + self.assertEqual(x0.grad, x2.grad) + self.assertEqual(x1.grad, x3.grad) + + def test_intermediate_hook_with_closure_eager(self): + def fn(x, obj): + y = x.sin() + y.register_hook(lambda grad: grad + obj.val) + z = y.sin() + return z + + cnt_fw = torch._dynamo.testing.CompileCounter() + cnt_bw = torch._dynamo.testing.CompileCounter() + opt = torch.compile(fn, backend=cnt_fw, fullgraph=True) + + obj1 = ClassWithVal(torch.tensor(88)) + obj2 = ClassWithVal(torch.tensor(99)) + x0 = torch.ones(4, requires_grad=True) + x1 = torch.ones(4, requires_grad=True) + x2 = torch.ones(4, requires_grad=True) + x3 = torch.ones(4, requires_grad=True) + fn(x0, obj1).sum().backward() + fn(x1, obj2).sum().backward() + + with compiled_autograd.enable( + functools.partial(torch.compile, backend=cnt_bw, fullgraph=True) + ): + opt(x2, obj1).sum().backward() + opt(x3, obj2).sum().backward() + self.assertEqual(cnt_fw.frame_count, 1) + self.assertEqual(cnt_bw.frame_count, 1) + + self.assertEqual(x0.grad, x2.grad) + self.assertEqual(x1.grad, x3.grad) + + def test_intermediate_hook_with_closure_aot(self): + def fn(x, obj): + y = x.sin() + y.register_hook(lambda grad: grad + obj.val) + z = y.sin() + return z + + cnt_bw = torch._dynamo.testing.CompileCounter() + opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + + obj1 = ClassWithVal(torch.tensor(88)) + obj2 = ClassWithVal(torch.tensor(99)) + x0 = torch.ones(4, requires_grad=True) + x1 = torch.ones(4, requires_grad=True) + x2 = torch.ones(4, requires_grad=True) + x3 = torch.ones(4, requires_grad=True) + fn(x0, obj1).sum().backward() + fn(x1, obj2).sum().backward() + + with compiled_autograd.enable( + 
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True) + ): + opt(x2, obj1).sum().backward() + opt(x3, obj2).sum().backward() + self.assertEqual(cnt_bw.frame_count, 1) + + self.assertEqual(x0.grad, x2.grad) + self.assertEqual(x1.grad, x3.grad) + def test_no_recompile_on_hook_identity_change(self): def my_hook(grad, k=0): return grad + k @@ -531,7 +677,7 @@ class HooksTests(torch._dynamo.test_case.TestCase): comp_out = comp_mod(x1) - self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.frame_count, 1) comp_out[0].backward(torch.ones(4)) self.assertEqual(x0.grad, x1.grad) @@ -605,6 +751,26 @@ class HooksTests(torch._dynamo.test_case.TestCase): with compiled_bwd_ctx: test_fn(compiled_fn) + def test_recompile(self): + def hook(param): + param.grad *= 2 + + x = torch.ones(10) + x.requires_grad = True + + def run(ipt): + return x * ipt + + x.register_post_accumulate_grad_hook(hook) + with compiled_autograd.enable(compiler_fn): + for i in range(5): + with unittest.mock.patch( + "torch._dynamo.config.error_on_recompile", True + ): + # Mimic optimizer.zero_grad() to clear the gradient + x.grad = None + run(i).sum().backward() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_input_attr_tracking.py b/test/dynamo/test_input_attr_tracking.py index 092de04693c2fc8113e283045cc3ba089e89d934..d7ef3ea4d982baf15af4f2d7c22bdc0e0c6d43a4 100644 --- a/test/dynamo/test_input_attr_tracking.py +++ b/test/dynamo/test_input_attr_tracking.py @@ -2,7 +2,6 @@ # flake8: noqa import torch import torch_npu - import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing @@ -70,7 +69,7 @@ class TestInputAttrTracking(torch._dynamo.test_case.TestCase): if node.op == "placeholder": placeholder_cnt += 1 - # y is already an input + # y is already an ipt self.assertEqual(placeholder_cnt, 2) def test_const_property_on_tensor(self): diff --git a/test/dynamo/test_interop.py b/test/dynamo/test_interop.py index 48cd8ba4bdaa6ad0021f03f02f98d907cfbd1371..d14ea5d7be45dc27e1330a29027bb0c4d852a436 100644 --- a/test/dynamo/test_interop.py +++ b/test/dynamo/test_interop.py @@ -31,6 +31,7 @@ class InteropTests(torch._dynamo.test_case.TestCase): trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)]) self._common(lambda a, b: trace_fn(a, b) + 1) + @torch._dynamo.config.patch(capture_func_transforms=True) def test_vmap_in_graph(self): from functools import wraps @@ -48,13 +49,13 @@ class InteropTests(torch._dynamo.test_case.TestCase): cnts = torch._dynamo.testing.CompileCounter() x = torch.randn(3, 5, 3) - def fn1(x): + def fn(x): return torch.vmap(torch.Tensor.t)(x) - fn_opt = torch.compile(fn1, backend=cnts, fullgraph=True) - fn_opt_traceable = torch.compile(traceable(fn1), backend=cnts, fullgraph=True) + fn_opt = torch.compile(fn, backend=cnts, fullgraph=True) + fn_opt_traceable = torch.compile(traceable(fn), backend=cnts, fullgraph=True) - self.assertEqual(fn1(x), fn_opt(x)) + self.assertEqual(fn(x), fn_opt(x)) self.assertEqual(cnts.frame_count, 1) self.assertEqual(fn_opt(x), fn_opt_traceable(x)) self.assertEqual(cnts.frame_count, 2) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 4e67a38408c607075b4c2587d04ce3a73b7c0998..ae2db0e3a644ae608c01f6e4a78c981160033d1b 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import atexit import contextlib import functools import logging @@ -9,16 +8,19 @@ import unittest.mock import torch import 
torch_npu -import torchair import torch._dynamo.test_case import torch._dynamo.testing -import torch._inductor.lowering import torch.distributed as dist from torch._dynamo.testing import skipIfNotPy311 from torch.nn.parallel import DistributedDataParallel as DDP -from torch.testing._internal.common_utils import find_free_port, munge_exc +from torch.testing._internal.common_utils import ( + find_free_port, + munge_exc, + skipIfTorchDynamo, +) +from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.logging_utils import ( LoggingTestCase, make_logging_test, @@ -29,7 +31,6 @@ requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) -npu_backend = torchair.get_npu_backend() def example_fn(a): @@ -49,8 +50,7 @@ def inductor_error_fn(a): return output -def npubackend_schedule_fn(a): - # official func is inductor_schedule_fn +def inductor_schedule_fn(a): output = a.add(torch.ones(1000, 1000, device="npu:0")) return output @@ -91,7 +91,7 @@ class LoggingTests(LoggingTestCase): @requires_npu() @make_logging_test(schedule=True) def test_schedule(self, records): - fn_opt = torch._dynamo.optimize(npu_backend)(npubackend_schedule_fn) + fn_opt = torch._dynamo.optimize("npu")(inductor_schedule_fn) fn_opt(torch.ones(1000, 1000, device="npu:0")) self.assertGreater(len(records), 0) self.assertLess(len(records), 5) @@ -99,7 +99,15 @@ class LoggingTests(LoggingTestCase): @requires_npu() @make_logging_test(fusion=True) def test_fusion(self, records): - fn_opt = torch._dynamo.optimize(npu_backend)(npubackend_schedule_fn) + fn_opt = torch._dynamo.optimize("npu")(inductor_schedule_fn) + fn_opt(torch.ones(1000, 1000, device="npu:0")) + self.assertGreater(len(records), 0) + self.assertLess(len(records), 8) + + @requires_npu() + @make_logging_test(cudagraphs=True) + def test_cudagraphs(self, records): + fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) fn_opt(torch.ones(1000, 1000, device="npu:0")) self.assertGreater(len(records), 0) self.assertLess(len(records), 8) @@ -117,6 +125,7 @@ class LoggingTests(LoggingTestCase): test_dynamo_debug = within_range_record_test(30, 90, dynamo=logging.DEBUG) test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO) + @skipIfTorchDynamo("too slow") @make_logging_test(dynamo=logging.DEBUG) def test_dynamo_debug_default_off_artifacts(self, records): fn_opt = torch._dynamo.optimize("inductor")(example_fn) @@ -147,12 +156,13 @@ from user code: ) test_aot = within_range_record_test(2, 6, aot=logging.INFO) - test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG) + test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG) test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO) @make_logging_test() def test_inductor_error(self, records): exitstack = contextlib.ExitStack() + import torch._inductor.lowering def throw(x): raise AssertionError() @@ -210,7 +220,7 @@ LoweringException: AssertionError: os.environ["MASTER_PORT"] = str(find_free_port()) dist.init_process_group("hccl", rank=0, world_size=1) - ddp_model = torch._dynamo.optimize(npu_backend)( + ddp_model = torch._dynamo.optimize("npu")( DDP(ToyModel().to("npu:0"), device_ids=[0], bucket_cap_mb=4) ) @@ -280,8 +290,10 @@ LoweringException: AssertionError: def test_dump_compile_times(self, records): fn_opt = torch._dynamo.optimize("inductor")(example_fn) fn_opt(torch.ones(1000, 
1000)) - # explicitly invoke the atexit registered functions - atexit._run_exitfuncs() + # This function runs during exit via atexit.register. + # We're not actually going to run atexit._run_exit_funcs() here, + # because it'll destroy state necessary for other tests. + torch._dynamo.utils.dump_compile_times() self.assertEqual( len( [r for r in records if "TorchDynamo compilation metrics" in str(r.msg)] ) @@ -324,7 +336,7 @@ LoweringException: AssertionError: if torch._logging._internal._is_torch_handler(handler): break self.assertIsNotNone(handler) - self.assertIn("[INFO]", handler.format(records[0])) + self.assertIn("I", handler.format(records[0])) self.assertEqual("custom format", handler.format(records[1])) @make_logging_test(dynamo=logging.INFO) @@ -342,8 +354,8 @@ LoweringException: AssertionError: self.assertIsNotNone(handler) for record in records: r = handler.format(record) - for l_str in r.splitlines(): - self.assertIn("[INFO]", l_str) + for l_ in r.splitlines(): + self.assertIn("I", l_) test_trace_source_simple = within_range_record_test(1, 100, trace_source=True) @@ -632,6 +644,7 @@ L['zs'][0] == 3.0 # for y, z in zip( record_str, ) + @skipIfTorchDynamo("too slow") @make_logging_test(**torch._logging.DEFAULT_LOGGING) def test_default_logging(self, records): def fn(a): @@ -654,10 +667,33 @@ L['zs'][0] == 3.0 # for y, z in zip( len([r for r in records if "return a + 1" in r.getMessage()]), 0 ) + def test_logs_out(self): + import tempfile + + with tempfile.NamedTemporaryFile() as tmp: + env = dict(os.environ) + env["TORCH_LOGS"] = "dynamo" + env["TORCH_LOGS_OUT"] = tmp.name + stdout, stderr = self.run_process_no_exception( + """\ +import torch +@torch.compile(backend="eager") +def fn(a): + return a.sum() + +fn(torch.randn(5)) + """, + env=env, + ) + with open(tmp.name) as fd: + lines = fd.read() + self.assertEqual(lines, stderr.decode("utf-8")) + # single record tests exclusions = { "bytecode", + "cudagraphs", "output_code", "schedule", "fusion", @@ -678,6 +714,8 @@ exclusions = { "onnx_diagnostics", "guards", "verbose_guards", + "sym_node", + "export", } for name in torch._logging._internal.log_registry.artifact_names: if name not in exclusions: diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 189116b4bf0be49931ccd4ab2f64902ba19a501d..fcc6958c39bc120adaf2f6dd4d2fb26bfaf92021 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,17 +1,12 @@ # Owner(s): ["module: dynamo"] import functools import unittest import torch import torch_npu import torch._dynamo from torch._dynamo.test_minifier_common import MinifierTestBase -requires_cuda = functools.partial( - unittest.skipIf, not torch.cuda.is_available(), "requires cuda" -) -requires_npu = functools.partial( - unittest.skipIf, not torch.npu.is_available(), "requires npu" -) +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") class MinifierTests(MinifierTestBase): @@ -46,38 +40,20 @@ inner(torch.randn(20, 20).to("{device}")) "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError" ) - @requires_cuda() - def test_after_dynamo_cuda_compile_error(self): - self._test_after_dynamo( - "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError" - ) - @requires_npu() - def test_after_dynamo_npu_compile_error(self): + def test_after_dynamo_cuda_compile_error(self): self._test_after_dynamo( "npu", "relu_compile_error_TESTING_ONLY", "ReluCompileError" ) - @requires_cuda() - def test_after_dynamo_cuda_runtime_error(self): -
self._test_after_dynamo( - "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError" - ) - @requires_npu() - def test_after_dynamo_npu_runtime_error(self): + def test_after_dynamo_cuda_runtime_error(self): self._test_after_dynamo( "npu", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError" ) - @requires_cuda() - def test_after_dynamo_cuda_accuracy_error(self): - self._test_after_dynamo( - "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError" - ) - @requires_npu() - def test_after_dynamo_npu_accuracy_error(self): + def test_after_dynamo_cuda_accuracy_error(self): self._test_after_dynamo( "npu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError" ) @@ -117,117 +93,30 @@ inner(torch.randn(20, 20, requires_grad=True) + 1) "cpu", "relu_accuracy_error_TESTING_ONLY" ) - @requires_cuda() - def test_after_dynamo_cuda_compile_backend_passes(self): - self._test_after_dynamo_backend_passes( - "cuda", "relu_compile_error_TESTING_ONLY" - ) - @requires_npu() - def test_after_dynamo_npu_compile_backend_passes(self): + def test_after_dynamo_cuda_compile_backend_passes(self): self._test_after_dynamo_backend_passes( "npu", "relu_compile_error_TESTING_ONLY" ) - @requires_cuda() - def test_after_dynamo_cuda_runtime_backend_passes(self): - self._test_after_dynamo_backend_passes( - "cuda", "relu_runtime_error_TESTING_ONLY" - ) - @requires_npu() - def test_after_dynamo_npu_runtime_backend_passes(self): + def test_after_dynamo_cuda_runtime_backend_passes(self): self._test_after_dynamo_backend_passes( "npu", "relu_runtime_error_TESTING_ONLY" ) - @requires_cuda() - def test_after_dynamo_cuda_accuracy_backend_passes(self): - self._test_after_dynamo_backend_passes( - "cuda", "relu_accuracy_error_TESTING_ONLY" - ) - @requires_npu() - def test_after_dynamo_npu_accuracy_backend_passes(self): + def test_after_dynamo_cuda_accuracy_backend_passes(self): self._test_after_dynamo_backend_passes( "npu", "relu_accuracy_error_TESTING_ONLY" ) - # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd - @requires_cuda() + # Test that a module with mixed cpu/npu parts with an error after dynamo can be repro'd + @requires_npu() def test_cpu_cuda_module_after_dynamo(self): backend_name = "relu_compile_error_TESTING_ONLY" run_code = f"""\ class CpuCudaModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.m_x = torch.nn.Linear(20, 20).cuda() - self.m_y = torch.nn.Linear(20, 20) - self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) - self.p_y = torch.nn.Parameter(torch.randn(20, 20)) - self.register_buffer("b_x", torch.ones(20, 20).cuda()) - self.register_buffer("b_y", torch.ones(20, 20)) - - def forward(self, x, y): - return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y - -mod = CpuCudaModule() - -@torch._dynamo.optimize({backend_name!r}) -def inner(x1, y1): - x2 = torch.randn(20, 20).cuda() - y2 = torch.randn(20, 20) - x3, y3 = mod(x1 + x2, y1 + y2) - return torch.relu(x3.cpu() + y3) - -inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) -""" - - res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False) - - self.assertExpectedInline( - res.minifier_module(), - """\ -class Repro(torch.nn.Module): - def __init__(self): - super().__init__() - self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda() - self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True) - self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda()) - self.register_buffer('G__mod___b_y', 
torch.randn([20, 20], dtype=torch.float32)) - self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="cuda")) - self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32)) - - def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor): - l_x1_ = L_x1_ - l_y1_ = L_y1_ - randn = torch.randn(20, 20) - x2 = randn.cuda(); randn = None - y2 = torch.randn(20, 20) - add = l_x1_ + x2; l_x1_ = x2 = None - add_1 = l_y1_ + y2; l_y1_ = y2 = None - g__mod___m_x = self.G__mod___m_x(add); add = None - g__mod___p_x = self.G__mod___p_x - add_2 = g__mod___m_x + g__mod___p_x; g__mod___m_x = g__mod___p_x = None - g__mod___b_x = self.G__mod___b_x - x3 = add_2 + g__mod___b_x; add_2 = g__mod___b_x = None - g__mod___m_y = self.G__mod___m_y(add_1); add_1 = None - g__mod___p_y = self.G__mod___p_y - add_4 = g__mod___m_y + g__mod___p_y; g__mod___m_y = g__mod___p_y = None - g__mod___b_y = self.G__mod___b_y - y3 = add_4 + g__mod___b_y; add_4 = g__mod___b_y = None - cpu = x3.cpu(); x3 = None - add_6 = cpu + y3; cpu = y3 = None - relu = torch.relu(add_6); add_6 = None - return (relu,)""", - ) - - # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd - @requires_npu() - def test_cpu_npu_module_after_dynamo(self): - backend_name = "relu_compile_error_TESTING_ONLY" - run_code = f"""\ -class CpuNpuModule(torch.nn.Module): def __init__(self): super().__init__() self.m_x = torch.nn.Linear(20, 20).npu() @@ -240,7 +129,7 @@ class CpuNpuModule(torch.nn.Module): def forward(self, x, y): return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y -mod = CpuNpuModule() +mod = CpuCudaModule() @torch._dynamo.optimize({backend_name!r}) def inner(x1, y1): @@ -264,7 +153,7 @@ class Repro(torch.nn.Module): self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True) self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).npu()) self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32)) - self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="npu")) + self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="npu:0")) self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32)) def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ccc979231224221f359482c12b79ab7e1bc482d9..366948a1ab5b491cdc13059005c01f6925f2cce9 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -6,12 +6,15 @@ import dataclasses import dis import enum import functools +import gc +import io import itertools import logging import math import operator import os import random +import re import sys import tempfile import threading @@ -45,6 +48,7 @@ from torch._dynamo.testing import ( same, skipIfNotPy311, unsupported, + xfailIfPy311, ) from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault from torch._inductor.utils import run_and_get_code @@ -59,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import ( constrain_unify, ConstraintViolationError, expect_true, + guard_size_oblivious, ShapeEnv, ) from torch.nn import functional as F @@ -66,7 +71,6 @@ from torch.testing import make_tensor from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, SM80OrLater, - TEST_MULTIGPU, ) from torch.testing._internal.common_methods_invocations import ( sample_inputs_take_along_dim, @@ -75,10 
+79,14 @@ from torch.testing._internal.common_utils import ( freeze_rng_state, IS_FBCODE, set_default_dtype, + wrapDeterministicFlagAPITest, ) from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.logging_utils import logs_to_string +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") TEST_NPU = torch.npu.is_available() +MULTINPU = TEST_NPU and torch.npu.device_count() >= 2 mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) T = typing.TypeVar("T") @@ -96,6 +104,16 @@ def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable: return wrapper +def cleanup_op(opname): + ns, name = opname.split("::") + if not hasattr(torch.ops, ns): + return + actual_ns = getattr(torch.ops, ns) + if not hasattr(actual_ns, name): + return + delattr(actual_ns, name) + + class MyPickledModule(torch.nn.Module): def __init__(self, z): super().__init__() @@ -119,6 +137,13 @@ uniform_qconfig_8bit = QConfig( qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]} +def closure_adder(val): + def inner(x): + return torch.sin(x + val) + + return inner + + class MiscTests(torch._dynamo.test_case.TestCase): def test_get_cache_entry(self): def f(x): @@ -139,6 +164,17 @@ class MiscTests(torch._dynamo.test_case.TestCase): except TypeError as e: self.assertIn("expected a code object!", str(e)) + # test get cache entry on skipped code object + def h(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 1 + + torch.compile(h)(torch.randn(3, 3)) + + entries = _debug_get_cache_entry_list(torch._dynamo.graph_break) + self.assertEqual(len(entries), 0) + def test_boolarg(self): def boolarg(aa, bb, flag): if flag: @@ -163,6 +199,17 @@ class MiscTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(val4, correct1)) self.assertEqual(counter.frame_count, 3) + def test_invalid_args_builtin(self): + @torch.compile(backend="eager") + def fn(x): + x = x.sin() + if isinstance(x, torch.Tensor, invalid=True): + x = x.sin() + return x + + with self.assertRaises(TypeError): + fn(torch.randn(16)) + def test_callpacked(self): def call_packed(args): a, b, c = args @@ -320,7 +367,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) _ = optimized_g(x) finally: - del torch.ops.mylib.bar + cleanup_op("mylib::bar") del lib @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) @@ -356,7 +403,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) y = optimized_g(x) finally: - del torch.ops.mylib.bar2 + cleanup_op("mylib::bar2") del lib @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) @@ -412,7 +459,84 @@ class MiscTests(torch._dynamo.test_case.TestCase): y = optimized_h(x) finally: - del torch.ops.mylib.bar3 + cleanup_op("mylib::bar3") + del lib + + def test_auto_functionalize_can_with_default(self): + lib = torch.library.Library("mylib", "FRAGMENT") + torch.library.define( + "mylib::foo", + "(Tensor a, int b, Tensor(d!)? c=None, Tensor? 
d=None, int e=-1) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + def foo_impl(a, b, c=None, d=None, e=-1): + a + b + return + + def f(a, mode): + return torch.ops.mylib.foo( + a, + 0, + ) + + a = torch.tensor([10, 10, 10], dtype=torch.int64) + + torch.compile(f)(a, 0) + + cleanup_op("mylib::foo") + del lib + + def test_closure_recompiles(self): + cnt = CompileCounter() + + def fn(x, other_fn): + return other_fn(x + 1) - 1 + + opt = torch.compile(fn, backend=cnt, fullgraph=True) + + x = torch.randn(8) + for f in ( + closure_adder(5), + closure_adder(5), + closure_adder(torch.randn(8)), + closure_adder(torch.randn(8)), + ): + self.assertEqual(opt(x, f), fn(x, f)) + + self.assertEqual(cnt.frame_count, 2) + + def test_generate_trivial_abstract_impl(self): + try: + lib = torch.library.Library("mylib", "FRAGMENT") + torch.library.define( + "mylib::foo", + "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w): + x + y[0] + w + return + + def f(x, y, z, w): + return torch.ops.mylib.foo(x, y, z, 2) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + w = torch.randn(3) + args = (x, y, z, w) + + output = torch.compile(f, backend="eager", fullgraph=True)(*args) + self.assertEqual(output, None) + finally: + cleanup_op("mylib::foo") del lib def test_can_auto_functionalize(self): @@ -422,13 +546,18 @@ class MiscTests(torch._dynamo.test_case.TestCase): "(Tensor(a!) x) -> ()", "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", ] expected_false = [ "(Tensor x) -> ()", "(Tensor(a) x) -> Tensor(a)", "(Tensor(a!) x) -> Tensor(a!)", "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", + "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? 
n) -> (Tensor, Tensor[])", ] for schema in expected_true: try: @@ -439,7 +568,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): ) self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) finally: - del torch.ops.mylib.a + cleanup_op("mylib::a") del lib for schema in expected_false: try: @@ -450,7 +579,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): ) self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) finally: - del torch.ops.mylib.a + cleanup_op("mylib::a") del lib def test_auto_functionalize(self): @@ -479,15 +608,95 @@ class MiscTests(torch._dynamo.test_case.TestCase): orig_args = (x, y, z, n) compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - torch.compile(f, backend="aot_eager_decomp_partition", fullgraph=True)( - *compiled_args + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # Check the graph under static shapes + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None + return ()""", + ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) f(*eager_args) self.assertEqual(compiled_args, eager_args) finally: - del torch.ops.mylib.foo + cleanup_op("mylib::foo") + del lib + + def test_auto_functionalize_with_returns(self): + try: + lib = torch.library.Library("mylib", "FRAGMENT") + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) 
z, SymInt w, Tensor n) -> (Tensor, Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + return y[0] + w, y[1] + n + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_abstract(x, y, z, w, n): + return y[0] + w, y[1] + n + + def f(x, y, z, n): + return torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( + *compiled_args + ) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None + getitem_4: "f32[3]" = foo_default[0] + getitem_5: "f32[3]" = foo_default[1]; foo_default = None + return (getitem_4, getitem_5)""", + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + eager_out = f(*eager_args) + self.assertEqual(compiled_args, eager_args) + self.assertEqual(compiled_out, eager_out) + finally: + cleanup_op("mylib::foo") del lib def test_auto_functionalize_on_view(self): @@ -522,8 +731,8 @@ class MiscTests(torch._dynamo.test_case.TestCase): y = f(x) self.assertEqual(y, x.sin()) finally: + cleanup_op("mylib::foo") del lib - del torch.ops.mylib.foo def test_auto_functionalize_optional(self): try: @@ -553,15 +762,30 @@ class MiscTests(torch._dynamo.test_case.TestCase): orig_args = (x, y, z, n) compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - torch.compile(f, backend="aot_eager_decomp_partition", fullgraph=True)( - *compiled_args + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(None, [arg0_1, arg3_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg1_1 = arg2_1 = None + return ()""", + ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) f(*eager_args) self.assertEqual(compiled_args, eager_args) finally: - del torch.ops.mylib.foo + cleanup_op("mylib::foo") del lib def test_shape_int_inplace_binops(self): @@ -736,6 +960,42 @@ class MiscTests(torch._dynamo.test_case.TestCase): else: self.assertExpectedInline(counts.op_count, """4""") + def test_user_defined_iter(self): + class Mod: + def __init__(self): + self.a = [torch.randn(2, 2), torch.randn(2, 2)] + + def __iter__(self): + return iter(self.a) + + def f(mod): + ret = [] + for x in mod: + ret.append(x + 1) + return ret + + 
mod = Mod() + counts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(counts, nopython=True)(f) + ref = f(mod) + res = opt_fn(mod) + res = opt_fn(mod) + res = opt_fn(mod) + res = opt_fn(mod) + self.assertTrue(same(ref, res)) + self.assertEqual(counts.frame_count, 1) + + mod.a.append(torch.randn(2, 2)) + # `for x in mod` is inlined, where iter(m.a) creates a guard on the list length of m.a + # Mutating length of mod.a causes a re-compilation. + ref2 = f(mod) + res2 = opt_fn(mod) + res2 = opt_fn(mod) + res2 = opt_fn(mod) + res2 = opt_fn(mod) + self.assertTrue(same(ref2, res2)) + self.assertEqual(counts.frame_count, 2) + def test_compare_shapes_eq(self): def compare_shapes(a, b, to_list): x = list(a.unsqueeze(-1).shape) if to_list else a.shape @@ -847,6 +1107,24 @@ class MiscTests(torch._dynamo.test_case.TestCase): torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) + @unittest.skipIf(sys.version_info[:2] <= (3, 8), "Requires astunparse") + def test_cse_dict_guards(self): + def fn(x): + ret = torch.zeros(3) + for v in x.values(): + ret = ret + v + return ret + + from torch._dynamo.guards import build_guard_function, CLOSURE_VARS + + x = {3: torch.randn(3), 2: torch.randn(3), 4: torch.randn(3)} + _, guards = torch._dynamo.export(fn, x) + + code_lists = [c for g in guards for c in g.code_list or []] + _, pycode = build_guard_function(code_lists, []) + # Make sure we just call "list(dict.keys())" once + self.assertEqual(pycode.count("keys"), 1) + def test_sys_modules(self): def fn(x, y): mod_a = sys.modules.get("aaaaaaaa") @@ -871,9 +1149,7 @@ class MiscTests(torch._dynamo.test_case.TestCase): # Filter out id-matches that won't reproduce run to run guard_code = filter( - lambda line: not any( - banned in line for banned in ["id", "lookup_backend", "config_hash"] - ), + lambda line: "id" not in line and "lookup_backend" not in line, sorted(guard_code), ) guard_code_str = "\n".join(guard_code) @@ -908,6 +1184,18 @@ utils_device.CURRENT_DEVICE == None""".split( torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) + def test_getattr_dict(self): + def fn(x): + from torch.masked.maskedtensor._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE + + return x * len(_MASKEDTENSOR_FUNCTION_TABLE) + + i = torch.randn(5) + r1 = fn(i) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + r2 = opt_fn(i) + self.assertEqual(r1, r2) + def test_shape_unpack(self): def fn(x): a, b = x.size() @@ -919,6 +1207,16 @@ utils_device.CURRENT_DEVICE == None""".split( r2 = opt_fn(i) self.assertTrue(same(r1, r2)) + def test_typing_dict(self): + def fn(d): + return d[T] + + d = {T: torch.randn(3)} + r1 = fn(d) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + r2 = opt_fn(d) + self.assertEqual(r1, r2) + def test_tensor_iter(self): def fn(x): for y in x: @@ -1009,6 +1307,8 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(cnts.frame_count, 1) @torch._dynamo.config.patch(capture_scalar_outputs=True) + # Translation validation changes the exception type, don't run with it + @torch.fx.experimental._config.patch(translation_validation=False) def test_torch_check_is_size(self): cnts = torch._dynamo.testing.CompileCounter() @@ -1016,15 +1316,13 @@ utils_device.CURRENT_DEVICE == None""".split( def f(x): y = x.item() torch._check_is_size(y) - # unsound 0/1 specialization! 
+ # Cannot conditional on unbacked SymInt if y == 0: assert False else: return torch.arange(0, y) - f(torch.tensor([3])) - f(torch.tensor([4])) - self.assertEqual(cnts.frame_count, 1) + self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3]))) def test_config_obj(self): class Cfg: @@ -1090,7 +1388,7 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(opt_fn(v, v.size())[0, 0], -10) self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10) self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10) - # One recompile per differing input type + # One recompile per differing ipt type self.assertEqual(cnts.frame_count, 3) def test_cell_output1(self): @@ -1271,7 +1569,7 @@ utils_device.CURRENT_DEVICE == None""".split( else: return torch.ones([2, 3]) - x1 = {"input": torch.rand(2, 3)} + x1 = {"ipt": torch.rand(2, 3)} x2 = torch.rand(2, 3) ref1 = fn(x1) ref2 = fn(x2) @@ -1516,6 +1814,25 @@ utils_device.CURRENT_DEVICE == None""".split( self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18) ) + def test_list_iadd_side_effect(self): + def fn(a, b): + a += [b] + torch._dynamo.graph_break() + return a + + a = [1, 2, 3] + b = torch.ones(2, 2) + + opt_fn = torch._dynamo.optimize("eager")(fn) + + exp = fn(a, b) + + a = [1, 2, 3] + b = torch.ones(2, 2) + act = opt_fn(a, b) + + self.assertEqual(exp, act) + def test_user_getattr1(self): class MyConfig(dict): def __getattr__(self, name): @@ -1779,6 +2096,17 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) + def test_numpy_subdtype(self): + def fn(x, n): + return np.issubdtype(type(n), np.integer) + x + + args = [torch.randn(10), 4096] + correct = fn(*args) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + self.assertEqual(opt_fn(*args), correct) + self.assertEqual(cnts.frame_count, 1) + def test_numpy_take_along_axis(self): def fn(x, i, a): return np.take_along_axis(x, i, a) @@ -2290,7 +2618,7 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(fn().shape, (13,)) def test_inplace_view_on_graph_input(self): - # graph break when calling methods with inplace_view tag on graph input + # graph break when calling methods with inplace_view tag on graph ipt func_args_map = { lambda x: x.resize_(6).mul_(2): torch.ones(4), lambda x: x.t_().mul_(2): torch.rand(2, 3), @@ -2324,6 +2652,29 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertTrue(same(fn(x, y), opt_fn(x.clone(), y.clone()))) self.assertEqual(cnts.frame_count, 1) + def test_out_variants_with_resizing_on_graph_inputs_with_dynamic(self): + # See pytorch/pytorch/issues/120482 + class CustomModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs): + return torch.outer(**inputs) + + compile_fn = torch.compile(CustomModel(), fullgraph=True) + + shapes = [(2, 1), (6, 1), (4, 1)] + for shape in shapes: + vec1, vec2 = shape + input_tensor1 = torch.randn(vec1) + input_tensor2 = torch.randn(vec2) + out_tensor = torch.empty(shape) + args = {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor} + res = compile_fn(args) + opt_res = res.clone() # cuz this is out and we mutate it + res = CustomModel()(args) + self.assertEqual(res, opt_res) + def test_dict_mutation_side_effect(self): def fn(d): d["c"] = d["a"] + d.pop("b") @@ -2339,6 +2690,95 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 1) + def 
test_dict_order_keys(self): + def fn(d): + return d["a"] - d["b"] + + args1 = {} + args1["a"] = torch.rand(10) + args1["b"] = torch.rand(10) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + self.assertEqual(fn(args1), opt_fn(args1)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 1) + # A different order of keys recompiles + args2 = {} + args2["b"] = args1["b"] + args2["a"] = args1["a"] + self.assertEqual(fn(args2), opt_fn(args2)) + self.assertEqual(cnts.frame_count, 2) + # Extra calls don't recompile + self.assertEqual(cnts.frame_count, 2) + + def test_dict_namedtuple(self): + def fn(d): + return d[3] * 2 + + args1 = {collections.namedtuple: None, 3: torch.randn(3)} + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + self.assertEqual(fn(args1), opt_fn(args1)) + self.assertEqual(cnts.frame_count, 1) + # Test a failing namedtuple guard + args2 = {2: None, 3: torch.randn(3)} + self.assertEqual(fn(args2), opt_fn(args2)) + self.assertEqual(cnts.frame_count, 2) + + def test_dict_order_keys_tensors(self): + def fn(d, x): + return d[x] + 3 + + args1 = {} + x = torch.randn(10) + y = torch.randn(10) + z = torch.randn(10) + args1[x] = y + args1[3] = z + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + self.assertEqual(fn(args1, x), opt_fn(args1, x)) + self.assertEqual(cnts.frame_count, 1) + + # Calling again doesn't recompile (same id and key order) + opt_fn(args1, x) + self.assertEqual(cnts.frame_count, 1) + args2 = {} + args2[3] = z + args2[x] = y + + # Different order recompiles + self.assertEqual(fn(args2, x), opt_fn(args2, x)) + self.assertEqual(cnts.frame_count, 2) + + def test_dict_order_keys_modules(self): + def fn(d, x): + return d[x](torch.ones(2, 2)) + + args1 = {} + x = torch.nn.Linear(2, 2) + y = torch.nn.Linear(2, 2) + z = torch.nn.Linear(2, 2) + args1[x] = y + args1[3] = z + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + self.assertEqual(fn(args1, x), opt_fn(args1, x)) + self.assertEqual(cnts.frame_count, 1) + + # Calling again doesn't recompile (same id and key order) + opt_fn(args1, x) + self.assertEqual(cnts.frame_count, 1) + args2 = {} + args2[3] = z + args2[x] = y + + # Different order recompiles + self.assertEqual(fn(args2, x), opt_fn(args2, x)) + self.assertEqual(cnts.frame_count, 2) + def test_dunder_new_function_inlining(self): # See pytorch/pytorch/issues/107460 @@ -2837,6 +3277,27 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(cnts.op_count, 3) cnts.clear() + def test_closure_with_mutation_and_graph_break(self): + def fn(): + x = torch.zeros(1) + + def subfunc(): + x[0] = backup + + if x[0] >= -1e5: + pass + + backup = 1 + subfunc() + return x + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + expected = fn() + actual = opt_fn() + self.assertTrue(same(expected, actual)) + self.assertEqual(cnts.frame_count, 2) + def test_closure_out_of_scope_cell_with_cond(self): # Test closure with out-of-scope cell variable, used in a cond # where the two branches read different closure variables @@ -3195,7 +3656,7 @@ utils_device.CURRENT_DEVICE == None""".split( # temporary test to check that the ci torch version is set correctly self.assertTrue(hasattr(torch, "_subclasses")) - @unittest.skipIf(not TEST_NPU, "requires npu") + @requires_npu() def test_rand(self): cnts = torch._dynamo.testing.CompileCounter() device = "npu:0" 
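Note on the new dict-ordering tests above (test_dict_order_keys and friends): Dynamo guards on the insertion order of dict keys, so calling the compiled function with the same keys in a different order fails the guard and triggers a recompile. The following is a minimal standalone sketch of the behaviour those tests assert, outside the diff, using the same CompileCounter test backend the tests use; it is illustrative only, not part of the patch.

import torch
import torch._dynamo.testing

def fn(d):
    return d["a"] - d["b"]

cnt = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnt)

args1 = {"a": torch.rand(10), "b": torch.rand(10)}
opt_fn(args1)                                # first compilation
assert cnt.frame_count == 1

args2 = {"b": args1["b"], "a": args1["a"]}   # same keys, different insertion order
opt_fn(args2)                                # key-order guard fails, so Dynamo recompiles
assert cnt.frame_count == 2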
@@ -3265,6 +3726,77 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertTrue(same(ref, res)) + def test_source_non_input_grad_access(self): + # This test creates a model, and accesses the grads + # from its parameter. This means that within dynamo, + # the tensor we are reading the grad from HAS a source, + # but is not known to graphargs. + cnts = torch._dynamo.testing.CompileCounter() + + class TrivialModel(torch.nn.Module): + def __init__(self): + super(TrivialModel, self).__init__() + self.linear = torch.nn.Linear(2, 1) + + def forward(self, x): + return self.linear(x) + + def fn(a, b): + outs = [] + for param in model.parameters(): + outs.append(torch.ones(param.grad.size())) + return outs, param.grad + 1 + + model = TrivialModel() + # Eager + a = torch.ones([2, 2], requires_grad=True) + b = torch.ones([2, 2]) + out = model(a) + out_sum = out.sum() + out_sum.backward() + ref = fn(a, b) + + # Compiled + model = TrivialModel() + a = torch.ones([2, 2], requires_grad=True) + b = torch.ones([2, 2]) + out = model(a) + out_sum = out.sum() + out_sum.backward() + + opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + res = opt_fn(a, b) + + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 3) + + def test_intermediary_tensor_grad_access(self): + # This test creates a model, and accesses the grads + # from its parameters and an entirely intermediary tensor. + cnts = torch._dynamo.testing.CompileCounter() + + def fn(a, b): + intermediary = torch.ones(2, 2) + c = a + intermediary + outs = [] + outs.append(intermediary.grad) + return outs + + # Eager + a = torch.ones([2, 2], requires_grad=True) + b = torch.ones([2, 2]) + ref = fn(a, b) + + # Compiled + a = torch.ones([2, 2], requires_grad=True) + b = torch.ones([2, 2]) + opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + res = opt_fn(a, b) + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 2) + @skipIfNotPy311 def test_linetable_311_writer1(self): def fn(): @@ -3297,13 +3829,13 @@ utils_device.CURRENT_DEVICE == None""".split( x0 = 1 x1 = 1 ... - l = [x0, x1, ...] + l_ = [x0, x1, ...] 
""" fn_str = f"""\ def fn(): foo.bar(1, 2, 3) {str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} - l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] + l_ = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] """ locals = {} exec(fn_str, {}, locals) @@ -3358,10 +3890,10 @@ def fn(): def test_tensor_is_contiguous(self): def fn(x): - input = torch.randn((1, 16, 1, 1)) + ipt = torch.randn((1, 16, 1, 1)) weight = torch.randn((8, 16, 3, 3)) weight = weight.to(memory_format=x) - output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) + output = torch.conv2d(ipt, weight, None, (2, 1), (1, 1), (1, 1), 1) return output.is_contiguous(memory_format=x) opt_fn = torch._dynamo.optimize("eager")(fn) @@ -3369,15 +3901,15 @@ def fn(): self.assertEqual(fn(x), opt_fn(x)) def test_python_slice(self): - def f1(input): + def f1(ipt): y = 0 - for i, x in enumerate(input[2:], 1): + for i, x in enumerate(ipt[2:], 1): y = y + x return y - def f2(input): + def f2(ipt): y = 0 - for i, x in enumerate(input.shape[2:], 1): + for i, x in enumerate(ipt.shape[2:], 1): y = y + x return y @@ -3450,11 +3982,19 @@ def fn(): def test_const_dict_variable_python_type(self): from torch._dynamo.variables import ConstantVariable, ConstDictVariable - d1 = {"a": ConstantVariable.create(10), "b": ConstantVariable.create(20)} + make_key = ConstantVariable.create + + d1 = { + make_key("a"): ConstantVariable.create(10), + make_key("b"): ConstantVariable.create(20), + } d2 = collections.OrderedDict( - [("x", ConstantVariable.create(12)), ("y", ConstantVariable.create(22))] + [ + (make_key("x"), ConstantVariable.create(12)), + (make_key("y"), ConstantVariable.create(22)), + ] ) - self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict) + self.assertEqual(ConstDictVariable(d1).python_type(), dict) self.assertEqual( ConstDictVariable(d2, collections.OrderedDict).python_type(), collections.OrderedDict, @@ -3631,14 +4171,14 @@ def fn(): self.assertExpectedInline(cnts.op_count, """2""") def test_inline_func_jump_on_tensor_condition(self): - def f1(input): - if input == 0: - return input + 1 + def f1(ipt): + if ipt == 0: + return ipt + 1 else: - return input + 2 + return ipt + 2 - def f2(input): - return f1(input) + def f2(ipt): + return f1(ipt) cnts = torch._dynamo.testing.CompileCounter() opt_f2 = torch._dynamo.optimize(cnts)(f2) @@ -3701,6 +4241,49 @@ def fn(): res2 = opt_f2() self.assertTrue(same(res1, res2)) + def test_inline_local_dict_clear(self): + def f(d): + d.clear() + return d + + inp = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} + out = torch.compile(f, backend="eager", fullgraph=True)(inp) + self.assertEqual(len(out), 0) + self.assertEqual(len(inp), 0) + + def test_inline_module_attr_dict_clear(self): + class MyMod(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} + + def forward(self): + self.a.clear() + return self.a + + m = MyMod() + out = torch.compile(m, backend="eager", fullgraph=True)() + self.assertEqual(len(out), 0) + self.assertEqual(len(m.a), 0) + + def test_inline_user_defined_dict_attr_clear(self): + class MyMod: + def __init__(self): + self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)} + + def f(obj, inp): + ret = len(obj.a) + inp + obj.a.clear() + return obj.a, ret + + m = MyMod() + before_len = len(m.a) + t_inp = torch.ones(1) + d, ret = torch.compile(f, backend="eager", fullgraph=True)(m, t_inp) + self.assertEqual(len(m.a), 0) + self.assertEqual(len(d), 0) + 
self.assertEqual(ret, t_inp + before_len) + def test_recursive_inline_list_mutation(self): def f1(x, y): x.append(torch.tensor([1.1])) @@ -3791,6 +4374,22 @@ def fn(): self.assertIsNone(mod_ref(), None) self.assertIsNone(mod_weight_ref(), None) + def test_release_scope_memory(self): + def inner(y): + y + + inner = torch._dynamo.optimize("eager")(inner) + + p_ref = None + + x = torch.randn((10, 10)) + inner(x) + + p_ref = weakref.ref(x) + self.assertTrue(p_ref() is not None) + del x + self.assertTrue(p_ref() is None) + def test_update_locals_and_stack_uses_shared_cache(self): def fn(x): perm = [0, 3, 5] @@ -3954,14 +4553,14 @@ def fn(): weight=rand_5, reduce=False, label_smoothing=0.5 ) opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) - input = rand_3_5 - dynamo_output = opt_loss(input, target) + ipt = rand_3_5 + dynamo_output = opt_loss(ipt, target) loss = torch.nn.CrossEntropyLoss( weight=rand_5, reduce=False, label_smoothing=0.5 ) - input = rand_3_5 - output = loss(input, target) + ipt = rand_3_5 + output = loss(ipt, target) self.assertTrue(torch.allclose(dynamo_output, output)) @@ -3971,12 +4570,12 @@ def fn(): loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) - input = rand_3_5 - dynamo_output = opt_loss(input, target) + ipt = rand_3_5 + dynamo_output = opt_loss(ipt, target) loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) - input = rand_3_5 - output = loss(input, target) + ipt = rand_3_5 + output = loss(ipt, target) self.assertTrue(torch.allclose(dynamo_output, output)) @@ -3987,12 +4586,12 @@ def fn(): loss = torch.nn.CrossEntropyLoss() opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) - input = rand_3_5 - dynamo_output = opt_loss(input, target) + ipt = rand_3_5 + dynamo_output = opt_loss(ipt, target) loss = torch.nn.CrossEntropyLoss() - input = rand_3_5 - output = loss(input, target) + ipt = rand_3_5 + output = loss(ipt, target) self.assertTrue(torch.allclose(dynamo_output, output)) @@ -4230,7 +4829,7 @@ def fn(): self.assertEqual(exp_out, opt_out) self.assertEqual(cnt.frame_count, exp_frame_count) self.assertEqual( - len(torch._dynamo.eval_frame.guarded_backend_cache.cached_backends), + len(torch._dynamo.eval_frame.cached_backends), exp_n_cached_backend, ) @@ -4279,10 +4878,11 @@ def fn(): def foo(x): return x.sin() + x.cos() - def compile_then_check_exp(*args): - self._optimize_then_check_exp(*args) - self._optimize_then_check_exp(*args) - self._optimize_then_check_exp(*args) + def compile_then_check_exp(foo, args, cnt, eager_result, exp_frame_count): + for i in range(3): + opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args) + self.assertEqual(opt_out, eager_result) + self.assertEqual(cnt.frame_count, exp_frame_count) thread_success[threading.current_thread()] = True eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs() @@ -4291,7 +4891,6 @@ def fn(): # Test dynamo recompiles but only caches a single backend for each thread eager_result = foo(x) # cnt and None - exp_n_cached_backend = 2 exp_frame_count = 1 threads = [] thread_success = {} @@ -4305,7 +4904,6 @@ def fn(): cnt, eager_result, exp_frame_count, - exp_n_cached_backend, ), ) threads.append(thread) @@ -4315,6 +4913,12 @@ def fn(): for thread in threads: thread.join() + # Threads are sharing the backend cache. 
We see two cnt backends and one None backend + self.assertEqual( + len(torch._dynamo.eval_frame.cached_backends), + 3, + ) + self.assertEqual(len(thread_success), len(threads)) def test_dynamo_min_operator_with_shape(self): @@ -4439,7 +5043,7 @@ def fn(): ) ) - # 1D input + # 1D ipt input_tensor = torch.tensor([0, 8]) static_size = 1 self.assertTrue( @@ -4458,7 +5062,7 @@ def fn(): ) ) - # 2D input + # 2D ipt input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) static_size = 5 fill_value = -100 @@ -4490,7 +5094,7 @@ def fn(): ) ) - # 3D input + # 3D ipt input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]]) static_size = 4 fill_value = -999 @@ -4831,7 +5435,7 @@ def fn(): out = F.scaled_dot_product_attention(query, key, value, None, scale=8) return out - device = "npu:0" + device = "cuda" dtype = torch.float16 seq_len_q = 1 seq_len_k = 1 @@ -5029,34 +5633,34 @@ def fn(): self.assertTrue(same(real, dynamo_result)) def test_error_on_nested_fx_trace(self): - input = torch.rand(2, 3) + ipt = torch.rand(2, 3) def f(x): x + x - real = f(input) + real = f(ipt) optimized = torch._dynamo.optimize("eager")(f) - self.assertTrue(same(optimized(input), real)) + self.assertTrue(same(optimized(ipt), real)) with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"): gm = torch.fx.symbolic_trace(optimized) @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False) def test_no_error_on_nested_fx_trace(self): - input = torch.rand(2, 3) + ipt = torch.rand(2, 3) def f(x): x + x - real = f(input) + real = f(ipt) optimized = torch._dynamo.optimize("eager")(f) - self.assertTrue(same(optimized(input), real)) + self.assertTrue(same(optimized(ipt), real)) # should not error gm = torch.fx.symbolic_trace(optimized) - self.assertTrue(same(gm(input), real)) + self.assertTrue(same(gm(ipt), real)) def test_not_dynamic_scope(self): def f(y): @@ -5068,10 +5672,10 @@ def fn(): return y + g()() - input = torch.zeros(1) - real = f(input) + ipt = torch.zeros(1) + real = f(ipt) optimized = torch._dynamo.optimize("eager")(f) - opt = optimized(input) + opt = optimized(ipt) self.assertTrue(same(opt, real)) def test_inference_mode(self): @@ -5304,7 +5908,7 @@ def fn(): res = opt_fn(x) self.assertTrue(same(ref, res)) - @unittest.skipIf(not TEST_NPU, "requires npu") + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") def test_torch_cudnn_is_acceptable(self): def fn(x): @@ -5312,13 +5916,13 @@ def fn(): return x + 1 return x - x = torch.rand(4).npu() + x = torch.rand(4).cuda() ref = fn(x) opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) res = opt_fn(x) self.assertTrue(same(ref, res)) - @unittest.skipIf(not TEST_NPU, "requires npu") + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") def test_torch_cudnn_is_acceptable_bad_inputs(self): def fn1(x): @@ -5332,20 +5936,20 @@ def fn(): return x with self.assertRaisesRegex( AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" ): - x1 = torch.rand(4).npu() + x1 = torch.rand(4).cuda() opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1) res1 = opt_fn1(x1) with self.assertRaisesRegex( AssertionError, "Expect 1 input to cudnn.is_acceptable" ): - x2 = torch.rand(4).npu() + x2 = torch.rand(4).cuda() opt_fn2 =
torch._dynamo.optimize("eager", nopython=True)(fn2) res = opt_fn2(x2) - @unittest.skipIf(not TEST_NPU, "requires npu") + @requires_npu() def test_get_device(self): def fn(x, y): x = x + 1 @@ -5373,16 +5977,35 @@ def fn(): self.assertEqual(cnt.frame_count, 0) def test_is_compiling(self): - def f(): + def f1(): if torch._dynamo.is_compiling(): return torch.ones(2, 2) else: return torch.zeros(2, 2) - opt_f = torch._dynamo.optimize("eager")(f) + def f2(): + if torch._utils.is_compiling(): + return torch.ones(2, 2) + else: + return torch.zeros(2, 2) + + def f3(): + if torch.compiler.is_compiling(): + return torch.ones(2, 2) + else: + return torch.zeros(2, 2) + + def f4(): + if torch.compiler.is_dynamo_compiling(): + return torch.ones(2, 2) + else: + return torch.zeros(2, 2) + + for f in [f1, f2, f3, f4]: + opt_f = torch._dynamo.optimize("eager")(f) - self.assertEqual(f(), torch.zeros(2, 2)) - self.assertEqual(opt_f(), torch.ones(2, 2)) + self.assertEqual(f(), torch.zeros(2, 2)) + self.assertEqual(opt_f(), torch.ones(2, 2)) def test_torch_generator_set_state(self): def fn(): @@ -5479,7 +6102,7 @@ def fn(): first_guard_failure, ) else: - self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure) + self.assertIn("""2 <= L['x'].size()[0] <= 2""", first_guard_failure) def test_guard_failure_fn2(self): def fn(x, y): @@ -5667,6 +6290,18 @@ def fn(): fn(inputs, iter(tuple(inputs))) + def fn(params): + y = tuple(params) + return inner_fn(*y) + + opt_fn = torch._dynamo.optimize("eager")(fn) + inputs = [torch.randn(10, 10) for _ in range(3)] + self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs))))) + + # Force recompilation + inputs = [torch.randn(10, 10) for _ in range(4)] + self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs))))) + def test_torch_package_working_with_trace(self): # from torch._dynamo.test_case import run_tests @@ -5729,7 +6364,7 @@ def fn(): import builtins # Cache the original builtin function ids - torch._dynamo.allowed_functions._builtin_function_ids() + torch._dynamo.trace_rules._builtin_function_ids() class MyClass: pass @@ -5892,7 +6527,7 @@ def fn(): self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"]) def test_tagging_tensors_mix_used_unused_structure(self): - def pre_attention_state_ops(input, mems, state): + def pre_attention_state_ops(ipt, mems, state): lc_key = state[0] lc_val = state[1] bar = [] @@ -6489,6 +7124,20 @@ def fn(): with self.assertRaises(ConstraintViolationError): torch._dynamo.optimize("eager")(my_dyn_fn)(y) + # Translation validation changes the exception type, don't run with it + @torch.fx.experimental._config.patch(translation_validation=False) + def test_mark_dynamic_with_ranges(self): + y = torch.randn([8, 3, 3]) + + def my_dyn_fn(x): + if x.shape[0] == 3: + return x.sin() + return x.cos() + + torch._dynamo.mark_dynamic(y, 0, min=2, max=5) + with self.assertRaises(ConstraintViolationError): + torch._dynamo.optimize("eager")(my_dyn_fn)(y) + def test_mark_static(self): counter = CompileCounter() @@ -6882,8 +7531,8 @@ def fn(): @torch._dynamo.config.patch(automatic_dynamic_shapes=False) def test_compile_profiler(self): class Model(torch.nn.Module): - def forward(self, input): - return input + input + def forward(self, ipt): + return ipt + ipt model = Model() prof = CompileProfiler() @@ -6895,8 +7544,8 @@ def fn(): .check("No graph breaks detected.") .check("Recompilation") ) - input = torch.rand((2, 3, 4)) - _ = compiled(input) + ipt = torch.rand((2, 3, 4)) + _ = compiled(ipt) base_checker().check("No 
recompilation detected.").run(prof.report()) new_shape_input = torch.rand((3, 3, 4)) @@ -6914,9 +7563,9 @@ def fn(): _ = compiled(new_shape_input) base_checker().check("Recompile Reasons").check("'forward'").check( - "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" + "tensor 'L['ipt']' size mismatch at index 0. expected 2, actual 3" ).check( - "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4" + "tensor 'L['ipt']' size mismatch at index 0. expected 3, actual 4" ).run( prof.report() ) @@ -6997,8 +7646,8 @@ def fn(): self.assertTrue(isinstance(compile_out, torch.Size)) self.assertEqual(eager_out, compile_out) - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") - def TEST_NPU_set_device(self): + @unittest.skipIf(not MULTINPU, "need multiple npu") + def test_npu_set_device(self): def fn(): a = torch.ones(2, device="npu") torch.npu.set_device(1) @@ -7263,15 +7912,15 @@ def ___make_guard_fn(): super().__init__() self.jitter_val = jitter_val - def roll_tensor(self, input): + def roll_tensor(self, ipt): h_shift = self.jitter_val - 1 w_shift = self.jitter_val + 1 return torch.roll( - torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3 + torch.roll(ipt, shifts=h_shift, dims=2), shifts=w_shift, dims=3 ) - def forward(self, input): - return self.roll_tensor(input) + def forward(self, ipt): + return self.roll_tensor(ipt) x = torch.rand([4, 4, 4, 4]) m = Jitter(jitter_val=4) @@ -7282,9 +7931,9 @@ def ___make_guard_fn(): def test_scalar_tensor_is_equivalent_to_int_list_argument(self): class MyModel(torch.nn.Module): - def forward(self, input): + def forward(self, ipt): permute = torch.tensor([0, 2, 1]) - x = input.permute(*permute) + x = ipt.permute(*permute) return x x = torch.randn(2, 3, 4) @@ -7333,6 +7982,61 @@ def ___make_guard_fn(): compiled_out = compiled_fn() self.assertTrue(same(fn_out, compiled_out)) + def test_tuple_hasattr(self): + def fn(x): + if hasattr(x, "foo"): + return x[0] + 1 + return x[1] - 1 + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + x = (torch.randn(3), torch.randn(3)) + fn_out = fn(x) + compiled_out = compiled_fn(x) + self.assertTrue(same(fn_out, compiled_out)) + + def test_fn_hasattr__name__1(self): + def fn(): + foo = lambda x: x + 1 + return hasattr(foo, "__name__") + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + fn_out = fn() + compiled_out = compiled_fn() + self.assertEqual(fn_out, compiled_out) + self.assertTrue(fn_out) + + def test_fn_hasattr__name__2(self): + def bar(x): + return torch.sin(x) + + def fn(): + return hasattr(bar, "__name__") + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + fn_out = fn() + compiled_out = compiled_fn() + self.assertEqual(fn_out, compiled_out) + self.assertTrue(fn_out) + + def test_fn_hasattr__name__3(self): + def bar(x, y): + return torch.sin(x) + torch.cos(y) + + baz = functools.partial(bar, y=4) + + def fn(): + return hasattr(baz, "__name__") + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + fn_out = fn() + compiled_out = compiled_fn() + self.assertEqual(fn_out, compiled_out) + self.assertFalse(fn_out) + def test_torch_objects_as_keys(self): remap = {torch.float16: torch.float32} @@ -7413,6 +8117,22 @@ def ___make_guard_fn(): self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) + def test_dynamic_one_hot(self): + def fn(x): + x = x + 1 + # graph break from data-dependent output shape + x = torch.nn.functional.one_hot(x) + x = x + 1 + return x + + inp = 
torch.arange(20) % 4 + counter = CompileCounter() + real_out = fn(inp) + comp_out = torch.compile(fn, backend=counter)(inp) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 2) + self.assertEqual(counter.op_count, 2) + def test_tracing_nested_py_tree_mixed_all(self): import torch.utils._pytree as pytree @@ -7475,7 +8195,7 @@ def ___make_guard_fn(): # use checkpoint to trigger a "sourceless" tensor subclass def checkpoint_fn(xs): - return checkpoint(fn, xs) + return checkpoint(fn, xs, use_reentrant=True) xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) @@ -7519,6 +8239,31 @@ def ___make_guard_fn(): f(torch.tensor([2, 3, 4]), torch.randn(9)) + # See pytorch/pytorch/issues/119689 + @unittest.expectedFailure + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_runtime_assert_replacement(self): + @torch.compile(backend="aot_eager") + def fn(x, y): + z = y.item() + torch._check(z == 3) + return x + z + + fn(torch.randn(4), torch.tensor([3])) + self.assertRaises(RuntimeError, lambda: fn(torch.randn(4), torch.tensor([4]))) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_cat_unbacked(self): + @torch.compile(backend="eager") + def fn(x, y): + z = y.item() + return torch.cat([x, torch.ones(z)]) + + fn(torch.randn(2, 3), torch.tensor([0])) + self.assertRaises( + RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1])) + ) + def test_simple_set_usage(self): def foo(x, y): setty = {x, y} @@ -7580,8 +8325,8 @@ def ___make_guard_fn(): foo = torch._dynamo.optimize(counter, nopython=True)(foo) # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part. - # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents) - # and so the guard story for the objects passed into input just isn't there atm. + # Specifically, getting a set as ipt won't ever work with how GetItemSource works (Can't arbitrary access set contents) + # and so the guard story for the objects passed into ipt just isn't there atm. 
with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, "^call_method UserDefinedObjectVariable\\(set\\).*", @@ -7664,6 +8409,85 @@ def ___make_guard_fn(): # assert no recompile self.assertEqual(counter.frame_count, 6) + def test_str_format_return1(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(img): + x = torch.sin(img) + y = f"shape {img.shape[-2:]} batch size {img.shape[0]}" + return img + x, y + + img1 = torch.randn(1, 1, 8, 8) + res, msg = fn(img1) + self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1") + self.assertEqual(res, img1 + torch.sin(img1)) + + def test_str_format_return2(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(img): + x = torch.sin(img) + y = "shape {} batch size {y:.2f}".format(img.shape[-2:], y=img.shape[0]) + return img + x, y + + img1 = torch.randn(1, 1, 8, 8) + res, msg = fn(img1) + self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00") + self.assertEqual(res, img1 + torch.sin(img1)) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_validate_outputs_unbacked(self): + class SillyCat(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, x1, i): + ctx.save_for_backward(i) + return torch.cat([x0, x1]) + + @staticmethod + def backward(ctx, grad_out): + (i,) = ctx.saved_tensors + i0, i1 = i.tolist() + g_x0, g_x1 = grad_out.split([i0, i1]) + return g_x0, g_x1, None + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(x, i): + i0, i1 = i.tolist() + x0, x1 = x.split([i0, i1]) + return SillyCat.apply(x0, x1, i) + + f(torch.randn(9, requires_grad=True), torch.tensor([3, 6])) + + def test_str_format_assert1(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(img): + x = torch.sin(img) + val = x.shape[-2:] + torch._assert(len(val) == 2, f"shape {img.shape}") + return img + x + + img1 = torch.randn(1, 1, 8, 8) + res = fn(img1) + self.assertEqual(res, img1 + torch.sin(img1)) + + def test_str_format_assert2(self): + cnt = CompileCounter() + + @torch.compile(backend=cnt) + def fn(img): + x = torch.sin(img) + torch._assert( + img.shape[-2] == 8 and img.shape[-1] == 16, f"shape {img.shape}" + ) + return img + x + + img1 = torch.randn(1, 3, 8, 16) + res = fn(img1) + self.assertEqual(res, img1 + torch.sin(img1)) + self.assertEqual(cnt.frame_count, 1) + + # trigger a recompile and graph break + img2 = torch.randn(1, 3, 8, 15) + self.assertRaises(AssertionError, lambda: fn(img2)) + def test_tolist_scalar(self): def fn(x): new_list = [] @@ -7831,14 +8655,13 @@ def ___make_guard_fn(): def test_yield_from(self): def yield_from_fn(t_list, k): - def yield_from_gen(l): - l2 = [t * k for t in l] + def yield_from_gen(l_): + l2 = [t * k for t in l_] yield from l2 return [t * k for t in yield_from_gen(t_list)] t_list = [torch.randn([2, 3])] * 3 - multiplier = torch.tensor([10]) eager = yield_from_fn(t_list, 2) counter = CompileCounter() compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2) @@ -8121,10 +8944,10 @@ def ___make_guard_fn(): counters.clear() def fn(a, b, c, d, x): - l = [a, b, c, d, x] - for i, t in enumerate(l): - l[i] = t * x - return itertools.accumulate(l) + l_ = [a, b, c, d, x] + for i, t in enumerate(l_): + l_[i] = t * x + return itertools.accumulate(l_) t_list = [torch.tensor([i + 1]) for i in range(4)] x = torch.tensor([[1, 2], [3, 4]]) @@ -8141,10 +8964,10 @@ def ___make_guard_fn(): counters.clear() def fn(a, b, c, d, x): - l = [a, b, c, d, x] - for i, t in enumerate(l): - l[i] = t * x - return itertools.accumulate(l, builtin_op) + l_ = [a, 
b, c, d, x] + for i, t in enumerate(l_): + l_[i] = t * x + return itertools.accumulate(l_, builtin_op) t_list = [torch.tensor([i + 1]) for i in range(4)] x = torch.tensor([[1, 2], [3, 4]]) @@ -8167,10 +8990,10 @@ def ___make_guard_fn(): counters.clear() def fn(a, b, c, d, x): - l = [a, b, c, d, x] - for i, t in enumerate(l): - l[i] = t * x - return itertools.accumulate(l, **kwargs) + l_ = [a, b, c, d, x] + for i, t in enumerate(l_): + l_[i] = t * x + return itertools.accumulate(l_, **kwargs) t_list = [torch.tensor([i + 1]) for i in range(4)] x = torch.tensor([[1, 2], [3, 4]]) @@ -8203,10 +9026,10 @@ def ___make_guard_fn(): torch._dynamo.reset() def fn(a, b, c, d, x): - l = [a, b, c, d, x] - for i, t in enumerate(l): - l[i] = t * x - return itertools.accumulate(l, udo_fn) + l_ = [a, b, c, d, x] + for i, t in enumerate(l_): + l_[i] = t * x + return itertools.accumulate(l_, udo_fn) t_list = [torch.tensor([i]) for i in range(4)] x = torch.tensor([[1, 2], [3, 4]]) @@ -8249,14 +9072,14 @@ def ___make_guard_fn(): def test_itertools_groupby_pure_python_default_identify_func(self): counters.clear() - def fn(l): - return [(k, list(g)) for k, g in itertools.groupby(l)] + def fn(l_): + return [(k, list(g)) for k, g in itertools.groupby(l_)] - l = [1, 2, 2, 3, 4, 4, 4, 1, 2] - eager = fn(l) + l_ = [1, 2, 2, 3, 4, 4, 4, 1, 2] + eager = fn(l_) compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) - compiled = compiled_fn(l) + compiled = compiled_fn(l_) self.assertEqual(eager, compiled) self.assertEqual(len(counters["graph_break"]), 0) @@ -8264,18 +9087,83 @@ def ___make_guard_fn(): def test_itertools_groupby_pure_python_key_func(self): counters.clear() - def fn(l): - return [(k, list(g)) for k, g in itertools.groupby(l, key=operator.neg)] + def fn(l_): + return [(k, list(g)) for k, g in itertools.groupby(l_, key=operator.neg)] - l = [1, 2, -2, 3, 4, 4, -4, 0, -2] - eager = fn(l) + l_ = [1, 2, -2, 3, 4, 4, -4, 0, -2] + eager = fn(l_) compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) - compiled = compiled_fn(l) + compiled = compiled_fn(l_) self.assertEqual(eager, compiled) self.assertEqual(len(counters["graph_break"]), 0) + def test_list_iterator_contains(self): + def fn(x): + it = iter(["my_weight", "not_my_weight"]) + next(it) + if "my_weight" in it: + return x + 2 + return x + 1 + + x = torch.zeros(3) + compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) + + self.assertEqual(fn(x), compiled_fn(x)) + + def test_storage_return(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + y = torch.sin(x + 1) + storage = x.untyped_storage() + storage.resize_(0) + y = torch.cos(y) + return y, storage + + x = torch.randn(10) + expected = torch.cos(torch.sin(x + 1)) + y, s = fn(x) + self.assertEqual(y, expected) + self.assertEqual(x.untyped_storage().size(), 0) + self.assertIs(s, x.untyped_storage()) + + def test_flat_name_to_original_fqn(self): + class FooBarModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4))) + self.register_buffer("test_buf", torch.randn(3, 4)) + self.register_parameter( + "test_param", torch.nn.Parameter(torch.randn(3, 4)) + ) + + def forward(self, x): + return ((x + self.test_buf) * getattr(self, "0")) / self.test_param + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.foo_bar = FooBarModule() + self.register_parameter( + "test_param", torch.nn.Parameter(torch.randn(3, 4)) + ) + 
self.register_buffer("test_buf", torch.randn(3, 4)) + + def forward(self, x): + return (self.foo_bar(x) + self.test_param) * self.test_buf + + gm, _ = torch._dynamo.export(TestModule(), torch.randn(3, 4)) + self.assertIn("dynamo_flat_name_to_original_fqn", gm.meta) + expected_fqn = { + "L__self___test_param": "test_param", + "L__self___test_buf": "test_buf", + "getattr_L__self___foo_bar___0__": "foo_bar.0", + "L__self___foo_bar_test_param": "foo_bar.test_param", + "L__self___foo_bar_test_buf": "foo_bar.test_buf", + } + self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"]) + def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -8349,9 +9237,6 @@ ShapeEnv not equal: field values don't match: ==> name_to_node: values don't match. > Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {} -==> runtime_var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {} ==> source_to_symbol: values don't match. > Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]} > Right: {} @@ -8384,7 +9269,7 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> name_to_node: values don't match. - > Left: {f0, i0, i1} + > Left: {f0, u0, u1} > Right: {} ==> unbacked_symfloat_counter: values don't match. > Left: 1 @@ -8393,7 +9278,7 @@ ShapeEnv not equal: field values don't match: > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {f0: ValueRanges(lower=-oo, upper=oo, is_bool=False), i0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), i1: ValueRanges(lower=0, upper=1, is_bool=False)} + > Left: {f0: ValueRanges(lower=-oo, upper=oo, is_bool=False), u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False)} > Right: {} """, ) @@ -8469,6 +9354,9 @@ ShapeEnv not equal: field values don't match: ==> replacements: values don't match. > Left: {s0: 3} > Right: {} +==> var_to_range: values don't match. + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} """, ) self._replay_and_check(main) @@ -8504,9 +9392,6 @@ ShapeEnv not equal: field values don't match: ==> name_to_node: values don't match. > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} -==> var_to_guards: values don't match. - > Left: {s0: (s0 >= 3, None)} - > Right: {} ==> var_to_range: values don't match. > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} @@ -8534,14 +9419,14 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. 
- > Left: {i0: [Eq(Mod(i0, 3), 0)]} + > Left: {u0: [Eq(Mod(u0, 3), 0)]} > Right: {} ==> divisible: values don't match. - > Left: {Mod(i0, 3)} + > Left: {Mod(u0, 3)} > Right: {} ==> name_to_node: values don't match. - > Left: {_assert, eq, i0, mod} - > Right: {i0} + > Left: {_assert, eq, mod, u0} + > Right: {u0} ==> num_deferred_runtime_asserts: values don't match. > Left: 1 > Right: 0 @@ -8575,6 +9460,30 @@ ShapeEnv not equal: field values don't match: with set_default_dtype(torch.double): foo() + def test_numpy_ufunc_out(self): + @torch.compile(backend="eager") + def foo(): + x = np.arange(5) + out = np.empty((x.shape[0], x.shape[0])) + res_out = np.sin(x, out=out) + assert res_out is out + + foo() + + # Unfortunately, we don't currently preserve the ids of + # res_out and out correctly across the graph break + @unittest.expectedFailure + def test_numpy_ufunc_out_graph_break(self): + @torch.compile(backend="eager") + def foo(): + x = np.arange(5) + out = np.empty((x.shape[0], x.shape[0])) + res_out = np.sin(x, out=out) + torch._dynamo.graph_break() + assert res_out is out + + foo() + def test_dict_subclass_cannot_be_initialized_in_graph(self): for super_class in ( collections.OrderedDict, @@ -8597,6 +9506,31 @@ ShapeEnv not equal: field values don't match: ): print(fn_opt(torch.zeros(1))) + @wrapDeterministicFlagAPITest + def test_backward_deterministic_mode_mismatch_warning(self): + @torch.compile + def func(a, b): + return a + b + + for forward_deterministic, backward_deterministic in itertools.product( + [True, False], [True, False] + ): + torch.use_deterministic_algorithms(forward_deterministic) + a = torch.randn(10, requires_grad=True) + res = func(a, 1) + grad = torch.ones_like(res) + torch.use_deterministic_algorithms(backward_deterministic) + + if not forward_deterministic and backward_deterministic: + with self.assertRaisesRegex( + RuntimeError, + "^This compiled backward function is being run with torch\.use_deterministic_algorithms", + ): + res.backward(grad) + + else: + res.backward(grad) + def test_torch_dynamo_codegen_pow(self): def pow(x): return x**2 @@ -8614,6 +9548,51 @@ ShapeEnv not equal: field values don't match: msg="Encountered an unexpected fallback to 'aten pow' in dynamo compiled code", ) + def test_compilation_metrics_size_limit(self): + def fn1(x): + return x.relu() + + def fn2(x): + return x.cos() + + def fn3(x): + return x.sin() + + def fn4(x): + return x.exp() + + import contextlib + + @contextlib.contextmanager + def metrics_limit_ctx(): + try: + torch._dynamo.utils.set_compilation_metrics_limit(3) + yield + finally: + torch._dynamo.utils.set_compilation_metrics_limit( + torch._dynamo.utils.DEFAULT_COMPILATION_METRICS_LIMIT + ) + + x = torch.rand((4, 4)) + torch._dynamo.reset() + torch.compile(fn1, backend="eager")(x) + torch.compile(fn2, backend="eager")(x) + torch.compile(fn3, backend="eager")(x) + torch.compile(fn4, backend="eager")(x) + + with metrics_limit_ctx(): + torch._dynamo.utils.clear_compilation_metrics() + torch._dynamo.reset() + self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics())) + torch.compile(fn1, backend="eager")(x) + self.assertEqual(1, len(torch._dynamo.utils.get_compilation_metrics())) + torch.compile(fn2, backend="eager")(x) + self.assertEqual(2, len(torch._dynamo.utils.get_compilation_metrics())) + torch.compile(fn3, backend="eager")(x) + self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics())) + torch.compile(fn4, backend="eager")(x) + self.assertEqual(3, 
len(torch._dynamo.utils.get_compilation_metrics())) + def test_funcname_cache(self): src = """\ import torch @@ -8678,6 +9657,262 @@ fn """, ) + def test_return_dict_with_graph_break_and_update(self): + def create(): + torch._dynamo.graph_break() + return {0: torch.tensor(3)} + + def fn(): + return {**create()} + + opt_fn = torch.compile(backend="eager")(fn) + result = opt_fn() + self.assertIn(0, result) + self.assertTrue(same(result[0], torch.tensor(3))) + + def test_dynamo_reset_clears_cache(self): + """Test that dynamo bytecode cache is freed + when dynamo reset is called + """ + + def fn(x): + return torch.sin(x) + + opt_fn = torch.compile(backend="eager")(fn) + opt_fn(torch.randn(3, 3)) + + c1 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c1), 1) + + torch._dynamo.reset() + c2 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c2), 0) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_guard_size_oblivious(self): + # This code, in fact, does NOT work in eager + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + y = torch.zeros(x.item()) + if guard_size_oblivious(y.size(0) == 0): + assert False + return y + + self.assertEqual(fn(torch.tensor([0])), torch.zeros(0)) + + def _test_compile_model_free(self, model_inp_ctr, weakref_watch): + """ + Args: + model_inp_ctr + - constructor that returns a new model and inputs to that model + weakref_watch + - function that returns a layer of the model for weakref to + finalize on, so we can check that the layer is freed after + the model goes out of scope + """ + cleared = False + + def finalize(): + nonlocal cleared + cleared = True + + def run(): + mod, inp = model_inp_ctr() + weakref.finalize(weakref_watch(mod), finalize) + torch.compile(mod, backend="eager")(inp) + + run() + gc.collect() + self.assertTrue(cleared) + + def test_custom_module_free(self): + """Test that a model is freed when it goes out of scope""" + + class Mod(torch.nn.Module): + def __init__(self): + super(Mod, self).__init__() + self.fc = torch.nn.Linear(100, 100) + + def forward(self, out): + return self.fc(out) + + self._test_compile_model_free( + lambda: (Mod(), torch.randn(100, 100)), + lambda mod: mod.fc, + ) + + @xfailIfPy311 + def test_sequential_module_free(self): + self._test_compile_model_free( + lambda: ( + torch.nn.Sequential( + torch.nn.Linear(100, 100), + torch.nn.ReLU(), + ), + torch.randn(100, 100), + ), + lambda mod: mod[0], + ) + + @unittest.expectedFailure + def test_linear_module_free(self): + self._test_compile_model_free( + lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)), + lambda mod: mod, + ) + + def test_dynamo_cache_move_to_front(self): + class Mod(torch.nn.Module): + def __init__(self): + super(Mod, self).__init__() + self.fc = torch.nn.Linear(3, 3) + + def forward(self, out): + return self.fc(out) + + def fn(x, mod): + return mod(x) + + opt_fn = torch.compile(fn, backend="eager") + + m1 = Mod() + m2 = Mod() + m3 = Mod() + inp = torch.randn(3, 3) + + # NOTE: assumes that each cache entry is guarded + # on unique Mod instance + opt_fn(inp, m1) + opt_fn(inp, m2) + opt_fn(inp, m3) + + c1 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c1), 3) + + # move cache entry to front + opt_fn(inp, m2) + c2 = _debug_get_cache_entry_list(fn.__code__) + self.assertIs(c1[1], c2[0]) + + def test_dynamo_cache_invalidate(self): + class Mod(torch.nn.Module): + def __init__(self): + super(Mod, self).__init__() + self.fc = torch.nn.Linear(3, 3) + + def forward(self, out): + 
return self.fc(out) + + def fn(x, mod): + return mod(x) + + opt_fn = torch.compile(fn, backend="eager") + + m1 = Mod() + m2 = Mod() + m3 = Mod() + inp = torch.randn(3, 3) + + # NOTE: assumes that each cache entry is guarded + # on unique Mod instance + opt_fn(inp, m1) + opt_fn(inp, m2) + opt_fn(inp, m3) + + c1 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c1), 3) + + # move cache entry to front + opt_fn(inp, m2) + c2 = _debug_get_cache_entry_list(fn.__code__) + self.assertIs(c1[1], c2[0]) + + # delete center of cache + del m3 + c3 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c3), 2) + self.assertIs(c3[0], c2[0]) + self.assertIs(c3[1], c2[2]) + + # delete end of cache + del m1 + c4 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c4), 1) + self.assertIs(c4[0], c3[0]) + + del m2 + c5 = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(c5), 0) + + def test_grad_none(self): + def fn(x, y): + x.grad = torch.abs(y) + x.grad.add_(y) + return torch.abs(y) + + y = torch.arange(4).reshape(2, 2).to(torch.float) + x = torch.randn(2, 2) + x.grad = None + + z = fn(x, y) + ref_y = torch.clone(z).detach() + ref_x_grad = torch.clone(x.grad).detach() + + y = torch.arange(4).reshape(2, 2).to(torch.float) + x = torch.randn(2, 2) + x.grad = None + + opt_fn = torch.compile(fn, backend="eager") + z = opt_fn(x, y) + self.assertEqual(z, ref_y) + self.assertEqual(x.grad, ref_x_grad) + + def test_grad_non_none(self): + def fn(x, y): + x.grad.add_(y) + return torch.abs(y) + + y = torch.ones(2, 2) + x = torch.randn(2, 2) + x.grad = torch.arange(4).reshape(2, 2).to(torch.float) + + z = fn(x, y) + ref_y = torch.clone(z).detach() + ref_x_grad = torch.clone(x.grad).detach() + + y = torch.ones(2, 2) + x = torch.randn(2, 2) + x.grad = torch.arange(4).reshape(2, 2).to(torch.float) + + cnt = torch._dynamo.testing.CompileCounterWithBackend("eager") + opt_fn = torch.compile(fn, backend=cnt) + z = opt_fn(x, y) + + # Ensure that the generated graph returns only one output. We want the + # add_ on the grad to be part of the graph itself, so that inductor can + # theoretically move the add_ and resutling copy_ nodes at the right + # place to free memory. + self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1) + self.assertEqual(z, ref_y) + self.assertEqual(x.grad, ref_x_grad) + + def test_new_with_int_list(self): + # Make sure torch.Tensor.new(int argument list) behaves the same on dynamo. 
+ def fn(x): + return x.new(*x.size()) + 5 + + optfn = torch.compile(backend="eager")(fn) + + x = torch.arange(10).view(2, 5) + + expected = fn(x) + actual = optfn(x) + + self.assertEqual(expected.dtype, actual.dtype) + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected.stride(), actual.stride()) + self.assertEqual(expected.storage_offset(), actual.storage_offset()) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index b5447d1360add0bf9bc1da67507601b883764fb4..1eb38f4193f80defef69d0b5d7ead1d84f81c128 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -4,7 +4,6 @@ import unittest.mock import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same @@ -171,7 +170,7 @@ class TestModelOutput(torch._dynamo.test_case.TestCase): class BertPooler(torch.nn.Module): def __init__(self): super().__init__() - self.dense = torch.nn.Linear(768, 768).npu() + self.dense = torch.nn.Linear(768, 768).to("npu:0") self.activation = torch.nn.Tanh() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -229,7 +228,7 @@ class TestModelOutput(torch._dynamo.test_case.TestCase): result["pooler_output"] = pooled_output return result - sequence_output = torch.rand(1, 12, 768).npu() + sequence_output = torch.rand(1, 12, 768).to("npu:0") model = BertModel() orig_result = model(sequence_output) compiled_model = torch.compile(model, backend="eager") diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 33aab22f76e7826ad496e73c49a91c63b9d1aafd..e838f28bac159156a164537456b4a6acd7118b30 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3,11 +3,11 @@ import collections import itertools import traceback -import types as tps +import types import unittest from copy import deepcopy from functools import partial -from typing import Tuple +from typing import Dict, NamedTuple, Tuple from unittest.mock import patch import torch @@ -300,7 +300,7 @@ class ModuleList(torch.nn.Module): ) def forward(self, x): - for i, _ in enumerate(self.layers): + for i in range(len(self.layers)): x = self.layers[i](x) for layer in self.layers: @@ -336,8 +336,11 @@ class CustomGetItemModuleList(torch.nn.Module): def __getitem__(self, idx: int): return self.layers[idx] + def __len__(self) -> int: + return len(self.layers) + def forward(self, x): - for i, _ in enumerate(self.layers): + for i in range(len(self)): x = self[i](x) return x @@ -537,9 +540,8 @@ class DenseNetBlocks(torch.nn.Module): class MaterializedModule(torch.nn.Module): - """Once the below lazy module is initialized with its first input, - it is transformed into this module. 
- """ + """Once the below lazy module is initialized with its first ipt, + it is transformed into this module.""" param: Parameter @@ -579,6 +581,38 @@ class LazyMLP(torch.nn.Module): return y +class MyInput(NamedTuple): + x: Dict[str, Dict[str, torch.Tensor]] + y: torch.Tensor + + +class LazyLayerWithNamedTupleInput(LazyModuleMixin, torch.nn.Module): + def __init__(self): + super().__init__() + + def initialize_parameters(self, ipt): + with torch.no_grad(): + self._param = torch.nn.Parameter( + torch.empty(ipt.x["a"][0].shape).fill_(0.5) + ) + + def forward(self, ipt): + ipt = ipt.x["a"] + x = 0 + for i in range(len(ipt)): + x = x + ipt[i] + return x + + +class LazyModuleWithNamedTupleInput(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = LazyLayerWithNamedTupleInput() + + def forward(self, ipt): + return self.layer(ipt) + + class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module): def __init__(self): super().__init__() @@ -589,7 +623,7 @@ class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module): def forward(self, ipt): x = 0 - for i, _ in enumerate(ipt): + for i in range(len(ipt)): x = x + ipt[i] return x @@ -615,6 +649,37 @@ class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module): return self.layer(x) +class LazyLayerWithInputs(LazyModuleMixin, torch.nn.Module): + def __init__(self): + super().__init__() + + def initialize_parameters(self, x, y): + with torch.no_grad(): + self._param_x = torch.nn.Parameter(torch.empty(x[0].shape).fill_(0.5)) + self._param_y = torch.nn.Parameter(torch.empty(y[0].shape).fill_(0.5)) + + def forward(self, x, y): + res_x = 0 + for i in range(len(x)): + res_x = res_x + x[i] + res_y = 0 + for i in range(len(y)): + res_y = res_y + y[i] + return res_x + res_y + + +class LazyModuleKwArgs(LazyModuleMixin, torch.nn.Module): + def __init__(self): + super().__init__() + + def initialize_parameters(self, *args, **kwargs): + with torch.no_grad(): + self.layer = LazyLayerWithInputs() + + def forward(self, x, y): + return self.layer(x, y=y) + + class LazyParentModule(LazyModuleMixin, torch.nn.Module): def __init__(self): super().__init__() @@ -804,7 +869,9 @@ class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d): def forward(self, x): if x.numel() > 0: return super().forward(x) - zip_x = zip( + output_shape = [ + ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op) + for i, p, di, k, d, op in zip( x.shape[-2:], self.padding, self.dilation, @@ -812,12 +879,9 @@ class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d): self.stride, self.output_padding, ) - output_shape = [ - ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op) - for i, p, di, k, d, op in zip_x ] output_shape = [x.shape[0], self.bias.shape[0]] + output_shape - return _NewEmptyTensorOp.apply(x, output_shape) + return _NewEmptyTensorOp.apply(x, output_shape) # noqa: F821 class ModuleNameString(torch.nn.Module): @@ -931,7 +995,7 @@ class ModuleGuardNameIsValid(torch.nn.ModuleDict): def __init__(self): super().__init__() for i in range(2): - self.add_module("l@yer-%d" % (i + 1), BasicModule()) + self.add_module("l_@yer-%d" % (i + 1), BasicModule()) def forward(self, x): for layer in self.values(): @@ -1394,7 +1458,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): res = opt_m(x) ref = m(x) self.assertTrue(torch.allclose(ref, res)) - # input shape changed and second iteration + # ipt shape changed and second iteration x = torch.rand([20, 20]) try: opt_m(x) @@ -1404,7 +1468,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): # 
RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic def test_lazy_module5(self): - # Test lazy module works well with list/tuple input + # Test lazy module works well with list/tuple ipt m = LazyModuleWithListInput() x = [torch.rand([5, 5])] * 3 + [None] opt_m = torch._dynamo.optimize("eager", nopython=True)(m) @@ -1423,6 +1487,20 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): ref = m(x) self.assertTrue(torch.allclose(ref, res)) + # RuntimeError: SymIntArrayRef expected to contain only concrete integers + @expectedFailureDynamic + def test_lazy_module7(self): + # Test lazy module works well with namedtuple/dict ipt + m = LazyModuleWithNamedTupleInput() + x = MyInput( + x={"a": [torch.rand([5, 5])] * 3, "b": torch.rand([5, 5])}, + y=torch.rand([5, 5]), + ) + opt_m = torch.compile(backend="eager", fullgraph=True)(m) + res = opt_m(x) + ref = m(x) + self.assertTrue(torch.allclose(ref, res)) + def test_lazy_module_no_cls_to_become(self): # make sure super() works in the case where cls_to_become is None m = LazyChildModuleNoClsToBecome() @@ -1432,6 +1510,14 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): ref = m(x) self.assertTrue(torch.allclose(ref, res)) + def test_lazy_module_kwargs(self): + m = LazyModuleKwArgs() + x = [torch.rand([5, 5])] * 3 + y = [torch.rand([5, 5])] * 2 + opt_m = torch.compile(backend="eager", fullgraph=True)(m) + exp_res = m(x, y) + self.assertTrue(torch.allclose(exp_res, opt_m(x, y))) + def test_call_fn_with_non_const_inputs_safe(self): class ModuleSpecialFwd(torch.nn.Module): def __init__(self): @@ -1535,7 +1621,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.frame_count, 3) def test_attr(self): - class MockModule_attr(torch.nn.Module): + class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) @@ -1544,7 +1630,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): def forward(self, x): return self.r(torch.sin(x)) + self.buf0 - mod = MockModule_attr() + mod = MockModule() opt_mod = torch._dynamo.optimize("eager")(mod) # Check parameters and buffers @@ -1562,7 +1648,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): self.assertEqual(out_dtype, torch.float32) def test_dir(self): - class MockModule_dir(torch.nn.Module): + class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) @@ -1574,7 +1660,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): def forward(self, x): return self.r(torch.sin(x)) + self.buf0 - mod = MockModule_dir() + mod = MockModule() mod_keys = dir(mod) opt_mod = torch._dynamo.optimize("eager")(mod) opt_mod_keys = dir(opt_mod) @@ -1602,7 +1688,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): a = torch.sin(torch.cos(x)) return self.linear(a) - class MockModule_toy(torch.nn.Module): + class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.mods = [SubModule() for _ in range(num_submodules)] @@ -1613,7 +1699,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): x = mod(x) return x - mod = MockModule_toy() + mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. 
with unittest.mock.patch( @@ -1640,7 +1726,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): a = torch.sin(torch.cos(x)) return self.relu(a) - class MockModule_1(torch.nn.Module): + class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.mods = [SubModule() for _ in range(num_submodules)] @@ -1651,7 +1737,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): x = mod(x) return x - mod = MockModule_1() + mod = MockModule() # For the third iteration, we would reach the cache size limit, and # therefore the total number of expected frame count is 2 * # num_submodules. @@ -1738,7 +1824,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): def test_module_patch(self): mod = ModulePatch1() - mod.forward = tps.MethodType(ModulePatch2.forward, mod) + mod.forward = types.MethodType(ModulePatch2.forward, mod) def fn(x): return mod(x) @@ -1842,7 +1928,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): handle.remove() self.assertEqual(compiled_func(inp), outer_func(inp)) self.assertEqual(compiled_func(inp).item(), 7) - self.assertTrue("forward_hooks.keys" in failure_reason) + self.assertTrue("forward_hooks" in failure_reason) self.assertEqual(cc.frame_count, 1 + 1) self.assertEqual(cc.op_count, 6 + 4) @@ -1862,9 +1948,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): m._forward_hooks[handle.id] = new_forward_hook self.assertEqual(compiled_func(inp), outer_func(inp)) self.assertEqual(compiled_func(inp).item(), 16) - self.assertRegex( - failure_reason, r"^___check_obj_id\(.*\(L\['m'\]\._forward_hooks" - ) + self.assertRegex(failure_reason, r"^___check_obj_id\(L\['m'\]._forward_hooks") @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True) def test_hooks_skip_guards(self): @@ -2225,7 +2309,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): self.buffer = torch.rand([4]) def forward(self, x): - # should be a no-op, but causes dynamo to lose the static input + # should be a no-op, but causes dynamo to lose the static ipt x = x + 1 self.buffer = self.buffer.to(x) return self.buffer + x @@ -2267,7 +2351,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): def test_no_guard_on_torch_nn_modules(self): # See pytorch/pytorch/issues/110048 - class MockModule2(torch.nn.Module): + class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) @@ -2275,7 +2359,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): def forward(self, x): return self.linear(x) - mod = MockModule2() + mod = MockModule() cnt = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py index 800ad8bbe927b6ad56e4014980e32e3a4ef369a7..c145919341645104401c3e18bfa715df1187422b 100644 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo import eval_frame diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index b67b335b1a7de6c297978294a845f41629513e7d..f85e73ac4c88b234295c44409d1773fc40347292 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -6,7 +6,6 @@ import functools # Owner(s): ["module: dynamo"] -import inspect import torch import torch_npu @@ -15,94 +14,6 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch.nn import Parameter -input1 = torch.ones([10, 10]) 
-model = torch.nn.Sequential(*[torch.nn.Linear(10, 10) for _ in range(2)]) -model(input1).sum().backward() - - -def get_optimizer_step(opt_arg, closure=None): - # run the patcher so that step has the expected structure - torch._dynamo.eval_frame.TorchPatcher.patch() - - # unwrap step to avoid a deliberate graph break due to - # a limitation of functionalization/no_grad detection - # see the [Note on graph break] in optimizer.py - # This ignores the outer _use_grad_if_differentiable wrapper, which is fine for now - # as dynamo does not support differentiable optimizers anyway - step_fn = opt_arg.step.__wrapped__ - if closure is not None: - - def fn(): - step_fn(opt_arg, closure) - - else: - - def fn(): - step_fn(opt_arg) - - return fn - - -def make_test(optim_cls, closure=None, **kwargs): - opt = optim_cls(model.parameters(), **kwargs) - - def test_fn(self): - nonlocal opt - - fn = get_optimizer_step(opt, closure=closure) - - with torch.set_grad_enabled(False): - torch.compile(fn, backend="eager", fullgraph=True)() - - return test_fn - - -class OptimizerTests(torch._dynamo.test_case.TestCase): - test_sgd = make_test(torch.optim.SGD, lr=0.01) - # lgbfs has data-dependent control and internally iterates - # calling the closure - # do for later mlazos: re-enable once we have latest pytorch with FakeTensor fix #497 - # test_lbfgs = make_test( - # torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum() - # ) - - # Has data dependent control for rectification (needs symint) - # RAdam has data-dependent control which breaks the graph; - # furthermore, the break is inside a for loop, so we bail on the frame - # entirely. This is basically an xfail; if the frame count goes up - # you done good - # test_radam = unittest.skipIf(IS_FBCODE, "TypeError: _use_grad() missing")( - # make_test(torch.optim.RAdam, exp_graph_count=0) - # ) - - -# exclude SparseAdam because other areas of the stack don't support it yet -# the others are handled specially above -exclude = { - "SGD", # Handled above - "Optimizer", - "SparseAdam", # Unsupported - "LBFGS", # Unsupported - "RAdam", # Has data dependent control for rectification (needs symint) -} - - -def check_opt(opt_ipt): - if inspect.isclass(opt_ipt) and issubclass(opt_ipt, torch.optim.Optimizer) and opt_ipt.__name__ not in exclude: - return True - return False - - -optimizers = [ - opt - for opt in torch.optim.__dict__.values() - if check_opt(opt) -] - - -for opt in optimizers: - setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt)) - class MyOptimizer(torch.optim.Optimizer): def __init__(self, params): @@ -126,7 +37,7 @@ class MyOptimizer(torch.optim.Optimizer): class End2EndTests(torch._dynamo.test_case.TestCase): - # see torchdynamo issues 1604 + # See pytorch/torchdynamo/issues/1604 def test_optimizing_over_tensor_with_requires_grad(self): class Net(torch.nn.Module): def forward(self, x, y): @@ -144,13 +55,13 @@ class End2EndTests(torch._dynamo.test_case.TestCase): return loss net = Net() - input_1 = torch.randn(2, 1, 4) - input_2 = torch.randn(2, 4, 8, requires_grad=True) - optimizer = torch.optim.Adam([input_2], lr=0.1) + input1 = torch.randn(2, 1, 4) + input2 = torch.randn(2, 4, 8, requires_grad=True) + optimizer = torch.optim.Adam([input2], lr=0.1) cnts = torch._dynamo.testing.CompileCounter() opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn) - batch = {"x": input_1, "y": input_2} + batch = {"x": input1, "y": input2} for _ in range(2): opt_training_iter_fn(batch, net, optimizer) 
self.assertEqual(cnts.frame_count, 2) @@ -181,7 +92,6 @@ class End2EndTests(torch._dynamo.test_case.TestCase): tensor = torch.randn(5, 5, dtype=dtype) params = Parameter(tensor.detach().clone(), requires_grad=False) opt_params = Parameter(tensor.detach().clone(), requires_grad=False) - print(params, opt_params) optim = MyOptimizer([params]) optim.step() @@ -189,7 +99,6 @@ class End2EndTests(torch._dynamo.test_case.TestCase): opt_optim = MyOptimizer([opt_params]) opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step) opt_step() - print(params, opt_params) self.assertEqual(params, opt_params) diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 02e7d289f74c673f91dcd82c528f6e163658df0a..ac9c7b936c111f5424c875294d57ae2c8a679edb 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -3,7 +3,6 @@ from unittest.mock import patch import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index 4365713cac718bf6bb6253ebb3c541e964778446..9ae24545ddd33958959ba5ea3515b56ffef18807 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -3,7 +3,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional import torch import torch_npu - import torch._dynamo from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import CompileCounter, same @@ -96,15 +95,15 @@ def grad(L, desired_results: List[Variable]) -> List[Variable]: # perform chain rule propagation specific to each compute dL_dinputs = entry.propagate(dL_doutputs) - # Accumulate the gradient produced for each input. + # Accumulate the gradient produced for each ipt. # Each use of a variable produces some gradient dL_dinput for that # use. The multivariate chain rule tells us it is safe to sum # all the contributions together. 
- for input1, dL_dinput in zip(entry.inputs, dL_dinputs): - if input1 not in dL_d: - dL_d[input1] = dL_dinput + for ipt, dL_dinput in zip(entry.inputs, dL_dinputs): + if ipt not in dL_d: + dL_d[ipt] = dL_dinput else: - dL_d[input1].value += dL_dinput.value + dL_d[ipt].value += dL_dinput.value # print some information to understand the values of each intermediate # for name, value in dL_d.items(): diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py new file mode 100644 index 0000000000000000000000000000000000000000..5cab7a07958c4a014ee3ad09b5869efc2fee7bca --- /dev/null +++ b/test/dynamo/test_recompile_ux.py @@ -0,0 +1,288 @@ +# Owner(s): ["module: dynamo"] +import unittest +import weakref + +import torch +import torch_npu +import torch._dynamo +import torch._dynamo.config +import torch._dynamo.test_case +import torch._dynamo.testing + +import torch._logging +from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings + + +class RecompileUxTests(torch._dynamo.test_case.TestCase): + # do for later(whc) dynamo actually recompiles one more time than the cache limit + cache_limit = 1 + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( + torch._dynamo.config.patch("cache_size_limit", cls.cache_limit) + ) + + def test_drop_cache_on_skip(self): + def model(x, i): + return x + i + + attached = False + triggered = False + + def trigger(): + nonlocal triggered + triggered = True + + def compiler(gm, ipt): + nonlocal attached + f = gm.forward + assert not attached + # NB: making this a weakref.ref causes the cycle to no + # longer be promptly GC'ed + weakref.finalize(f, trigger) + attached = True + return f + + x = torch.randn(2) + for i in range(2): + opt_model = torch._dynamo.optimize(compiler)(model) + opt_model(x, i) + + self.assertTrue(triggered) + + def test_loop_torture(self): + def loop_torture(ipt, iters): + out = ipt + # randint itself causes one graph break + for _ in range(iters): + out += ipt + return out + + compile_counter = torch._dynamo.testing.CompileCounter() + for _ in range(10): + x = torch.randn(3) + iters = torch.randint(low=0, high=1000, size=()) + opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture) + opt_loop_torture(x, iters) + + # Currently, we recompile each time, + # We'd probably like to bail out quickly and warn + # do for later(whc) these checks fail on py37. Why? + # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit) + # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit) + + # compile_counter only sees frames that were fed to the backend compiler, + # which is a subset of counters["frames"]["ok"] -- probably because + # counters["frames"]["ok"] includes frames not containing torch ops? 
+ self.assertEqual(compile_counter.frame_count, self.cache_limit) + + @torch._dynamo.config.patch("automatic_dynamic_shapes", False) + def test_dynamic_input(self): + def model(ipt): + return ipt + ipt + + expected_recompiles = 2 + compile_counter = torch._dynamo.testing.CompileCounter() + with torch._dynamo.config.patch("cache_size_limit", expected_recompiles): + with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: + for _ in range(10): + bsz = torch.randint(low=0, high=1000, size=()) + x = torch.randn((bsz, 3, 4)) + opt_model = torch._dynamo.optimize(compile_counter)(model) + opt_model(x) + + self.assertEqual(compile_counter.frame_count, expected_recompiles) + self.assertEqual(len(logs.records), 1) + print(logs.records[0]) + self.assertTrue( + logs.records[0] + .getMessage() + .startswith("torch._dynamo hit config.cache_size_limit") + ) + + @unittest.skipIf(not torch.npu.is_available(), "requires npu") + def test_nvfuser_guards(self): + # we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards + # such that we ensure dynamo is in charge of all the recompilations at the top level, + # and we could thus simplify the underlying torchscript executor + def func(a, b, c): + return a + b * c + + a = torch.rand(3, 4, 5, device="npu:0") + b = torch.rand(3, 4, 5, device="npu:0") + b_v = torch.rand(3, 5, 4, device="npu:0").view(3, 4, 5) + b_p = torch.rand(3, 5, 4, device="npu:0").permute(0, 2, 1) + c = torch.rand(3, 4, 5, device="npu:0") + compile_counter = torch._dynamo.testing.CompileCounter() + + with torch._dynamo.config.patch("cache_size_limit", 2): + opt_func = torch._dynamo.optimize(compile_counter)(func) + opt_func(a, b, c) # warmup + self.assertEqual(compile_counter.frame_count, 1) + + opt_func(a, b, c) # no guard fail or recompile + self.assertEqual(compile_counter.frame_count, 1) + + opt_func(a, b_v, c) # a view should not cause nvfuser recompile + self.assertEqual(compile_counter.frame_count, 1) + + opt_func(a, b_p, c) # a permutation should cause recompile + self.assertEqual(compile_counter.frame_count, 2) + + def assert_single_log_contains(self, logs, contains_str): + self.assertEqual(len(logs.records), 1) + self.assertTrue( + logs.records[0].getMessage().find(contains_str) > 0, + msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"', + ) + + def test_verbose_tensor_check(self): + def func(a): + # Warning: choose a function here whose meta implementation lives + # entirely in C++. If you do a Python one, Dynamo will dive into + # torch._refs which is OK but it will muddy up the warnings + return torch.add(a, 4) + + def cache_fail_test(cached_input, missed_input, expected_failure): + # do for later(whc) maybe its hacky to have a 'test within a test' but this seemed convenient + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + opt_func = torch._dynamo.optimize("eager")(func) + # warmup + opt_func(cached_input) + + with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: + opt_func = torch._dynamo.optimize("eager")(func) + opt_func(missed_input) + self.assert_single_log_contains(logs, expected_failure) + + a = torch.rand(3, 4, 5) + cache_fail_test( + a, + a[0:2, :, :], + "tensor 'L['a']' size mismatch at index 0. expected 3, actual 2", + ) + cache_fail_test( + a, + a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)), + "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1", + ) + cache_fail_test( + a, a[0, :, :], "tensor 'L['a']' rank mismatch. 
expected 3, actual 2" + ) + cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.") + cache_fail_test( + a, + a.to(torch.float16), + "tensor 'L['a']' dtype mismatch. expected Float, actual Half", + ) + a_grad = a.clone() + a_grad.requires_grad = True + cache_fail_test( + a, + a_grad, + "tensor 'L['a']' requires_grad mismatch. expected requires_grad=0", + ) + + def test_mismatched_type(self): + a = torch.rand(3, 4, 5) + b = torch.rand(3, 4, 5) + + def func(a, b): + return a + b + + opt_func = torch._dynamo.optimize("eager")(func) + # warmup + opt_func(a, b) + + with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: + opt_func = torch._dynamo.optimize("eager")(func) + opt_func(a, 1) + self.assert_single_log_contains( + logs, + "expected type of 'L['b']' to be a tensor type, ' but found ", + ) + + @torch._dynamo.config.patch("cache_size_limit", 32) + def test_multiple_guard_fails(self): + failure_reasons = [] + + def guard_fail_fn(failure): + failure_reasons.append(failure[0]) + + def f(x): + return torch.relu(x) + + opt_f = torch._dynamo.optimize( + backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False + )(f) + + for i in range(5): + failure_reasons.clear() + opt_f(torch.randn(8 + i)) + + failure_str = "\n".join(failure_reasons) + for line in """\ +tensor 'L['x']' size mismatch at index 0. expected 11, actual 12 +tensor 'L['x']' size mismatch at index 0. expected 10, actual 12 +tensor 'L['x']' size mismatch at index 0. expected 9, actual 12 +tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split( + "\n" + ): + self.assertIn( + line, + failure_str, + ) + + @torch._dynamo.config.patch("cache_size_limit", 32) + def test_multiple_guard_fails_report_all(self): + with log_settings(kwargs_to_settings(recompiles_verbose=True)): + failure_reasons = [] + + def guard_fail_fn(failure): + failure_reasons.append(failure[0]) + + def f(x): + return torch.ones(len(x), x[-1]) + + opt_f = torch._dynamo.optimize( + backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False + )(f) + + opt_f([4, 5, 6]) + + def filter_reasons(): + return "\n".join( + [ + line + for line in "\n".join(failure_reasons).splitlines() + if not line.startswith("___check_type_id") + ] + ) + + failure_reasons.clear() + opt_f([7, 8]) + + for line in """\ +len(L['x']) == 3""".split( + "\n" + ): + self.assertIn(line, filter_reasons()) + + failure_reasons.clear() + opt_f([9]) + + for line in """\ +len(L['x']) == 2 +len(L['x']) == 3""".split( + "\n" + ): + self.assertIn(line, filter_reasons()) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index 87b2d238915f02c59f11a82db77e081726e8fc07..aee230a3e170be65f9d7e486156d0f50f3adf93d 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -3,7 +3,6 @@ from unittest.mock import patch import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing @@ -38,23 +37,15 @@ class RecompileTests(torch._dynamo.test_case.TestCase): return cnt + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, 
"automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() without = run_without_automatic() self.assertEqual(without.frame_count, 5) @@ -108,23 +99,15 @@ class RecompileTests(torch._dynamo.test_case.TestCase): return cnt + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() without = run_without_automatic() self.assertEqual(without.frame_count, 5) @@ -163,23 +146,15 @@ class RecompileTests(torch._dynamo.test_case.TestCase): return cnt + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles_swap_types() + return run_foo_6_times_and_count_recompiles_swap_types() + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles_swap_types() + return run_foo_6_times_and_count_recompiles_swap_types() without = run_without_automatic() self.assertEqual(without.frame_count, 5) @@ -276,45 +251,29 @@ class RecompileTests(torch._dynamo.test_case.TestCase): return cnt + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_static_comp_default_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": True, - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_dynamic_comp_default_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": True, - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) + 
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_static_comp_dynamic_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": False, - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_dynamic_comp_dynamic_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": False, - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() torch._dynamo.reset() static_comp_default_param = run_static_comp_default_param() @@ -357,81 +316,6 @@ class RecompileTests(torch._dynamo.test_case.TestCase): model(x) self.assertEqual(counter.frame_count, 2) - def test_forbid_nopython_has_graph_break_cache_hit(self): - from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn - - for create_functions in [ - lambda f, cnt: ( - torch.compile(f, backend=cnt), - torch.compile(f, backend=cnt, fullgraph=True), - ), - lambda f, cnt: ( - torch._dynamo.optimize(backend=cnt)(f), - torch._dynamo.optimize(backend=cnt, nopython=True)(f), - ), - ]: - torch._dynamo.reset() - - def fn(x): - if len(x.size()) == 1: - x = x + 2 - torch._dynamo.graph_break() - return x + 1 - else: - return x + 1 - - cnt = torch._dynamo.testing.CompileCounter() - - opt_fn, nopython_fn = create_functions(fn, cnt) - - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "graph_break"): - nopython_fn(torch.zeros(1)) - self.assertEqual(cnt.frame_count, 0) - - opt_fn(torch.zeros(1)) - self.assertEqual(cnt.frame_count, 2) - - cache_entries = _debug_get_cache_entry_list(innermost_fn(opt_fn)) - self.assertEqual(len(cache_entries), 1) - # guarded code with graph break has `___needs_nopython` guard - self.assertTrue( - any( - "___needs_nopython" in part - for part in cache_entries[0].check_fn.code_parts - ) - ) - - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "graph_break"): - nopython_fn(torch.zeros(1)) - self.assertEqual(cnt.frame_count, 2) - - opt_fn(torch.zeros(1)) - self.assertEqual(cnt.frame_count, 2) - - nopython_fn(torch.zeros((1, 2))) - self.assertEqual(cnt.frame_count, 3) - - cache_entries = _debug_get_cache_entry_list(innermost_fn(opt_fn)) - self.assertEqual(len(cache_entries), 2) - # nopython function with no graph break does not have `___needs_nopython` guard - self.assertFalse( - any( - "___needs_nopython" in part - for part in cache_entries[0].check_fn.code_parts - ) - ) - # previous guarded code with graph break still has `___needs_nopython` guard - self.assertTrue( - any( - "___needs_nopython" in part - for part in cache_entries[1].check_fn.code_parts - ) - ) - - # nopython does not recompile - manages to hit cache entry with no graph breaks - nopython_fn(torch.zeros((1, 2))) - self.assertEqual(cnt.frame_count, 3) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py new file mode 100644 index 0000000000000000000000000000000000000000..3c893c79392c30d36c7595389df9e5d3140bf863 --- /dev/null 
+++ b/test/dynamo/test_reorder_logs.py @@ -0,0 +1,153 @@ +# Owner(s): ["module: dynamo"] +import io +import warnings +from unittest.mock import patch + +import torch +import torch_npu +import torch._dynamo +import torch._dynamo.test_case +import torch._dynamo.testing +from torch._dynamo.testing import same +from torch._dynamo.utils import counters + + +class ReorderLogsTests(torch._dynamo.test_case.TestCase): + def test_dont_reorder_print(self): + def f(x): + x = x + x + print("moo") + x = x * x + return x + + counters.clear() + x = torch.randn(3, 3) + opt_f = torch.compile(backend="eager")(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(x) + printed_output = mock_stdout.getvalue().strip() + orig_out = f(x) + + self.assertTrue(same(orig_out, opt_out)) + self.assertEqual(printed_output, "moo") + self.assertEqual(len(counters["graph_break"]), 1) + + @torch._dynamo.config.patch(reorderable_logging_functions={print}) + def test_reorder_print(self): + def f(x): + print("moo") + x1 = x + x + print(x1) + x2 = x1 * x1 + print(1, 2, 3) + x3 = x2 + x2 + return (x1, x3) + + x = torch.ones(3, 3) + opt_f = torch.compile(backend="eager", fullgraph=True)(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(x) + printed_output = mock_stdout.getvalue().strip() + orig_out = f(x) + + self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3") + self.assertTrue(same(orig_out, opt_out)) + + @torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn}) + def test_reorder_warnings(self): + import warnings + + def f(x): + x1 = x + x + warnings.warn("moo") + x2 = x1 * x1 + warnings.warn(f"{x2}") + x3 = x2 + x2 + return x3 + + x = torch.ones(3, 3) + opt_f = torch.compile(backend="eager", fullgraph=True)(f) + with warnings.catch_warnings(record=True) as w: + opt_out = opt_f(x) + warning_messages = [str(i.message) for i in w] + orig_out = f(x) + + self.assertTrue(same(orig_out, opt_out)) + self.assertIn("moo", warning_messages) + + @torch._dynamo.config.patch(reorderable_logging_functions={print}) + def test_reorder_print_graph_break(self): + def f(x): + x1 = x + x + print(f"res: {x1}") + x2 = x1 * x1 + torch._dynamo.graph_break() + x3 = x2 + x2 + print(1, 2, 3) + return x3 + + x = torch.ones(3, 3) + opt_f = torch.compile(backend="eager")(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(x) + printed_output = mock_stdout.getvalue().strip() + orig_out = f(x) + + self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3") + self.assertTrue(same(orig_out, opt_out)) + + def test_reorder_custom_log_fn(self): + custom_logs = [] + + def custom_log(s: str): + torch._dynamo.graph_break() + custom_logs.append(s) + + def f(x): + custom_log("moo") + x1 = x + x + custom_log(f"{x1}") + return x + x + + x = torch.ones(3, 3) + counters.clear() + with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}): + opt_f = torch.compile(backend="eager")(f) + opt_out = opt_f(x) + + self.assertEqual(sum(counters["graph_break"].values()), 1) + self.assertEqual(custom_logs[0], "moo") + self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}") + + @torch._dynamo.config.patch(reorderable_logging_functions={print}) + def test_constant_mutation(self): + def f(x): + alist = [x] + alist.append(x + 1) + print(alist[-1]) + alist[0].sum().item() # graph break + res = alist.pop() + print(alist[-1]) + res.sum().item() # graph break + return res + + inputs = (torch.tensor([1]),) + 
counters.clear() + opt_f = torch.compile(backend="eager")(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(*inputs) + printed_output = mock_stdout.getvalue().strip() + orig_out = f(*inputs) + + self.assertEqual(printed_output, "tensor([2])\ntensor([1])") + self.assertTrue(same(orig_out, opt_out)) + + graph_break_key = counters["graph_break"].keys() + self.assertEqual(len(graph_break_key), 1) + self.assertEqual(next(iter(graph_break_key)), "Tensor.item") + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3746fd96b095b34a0af3adff2c101294df7a59 --- /dev/null +++ b/test/dynamo/test_replay_record.py @@ -0,0 +1,196 @@ +# Owner(s): ["module: dynamo"] +import logging +import re +import shutil +import unittest + +import torch +import torch_npu +import torch._dynamo.test_case +import torch._dynamo.testing +from torch.testing._internal.common_utils import skipIfNoDill + + +class ReplayRecordTests(torch._dynamo.test_case.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( + unittest.mock.patch.object( + torch._dynamo.config, "replay_record_enabled", True + ) + ) + torch._logging.set_logs(graph_breaks=True, dynamo=logging.ERROR) + # These tests require dynamo exceptions to be propagated up to the caller + cls._exit_stack.enter_context( + unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False) + ) + cls._exit_stack.enter_context( + unittest.mock.patch.object( + torch._dynamo.config, + "debug_dir_root", + "/tmp/_torchdynamo_debug_/", + ) + ) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(torch._dynamo.config.debug_dir_root, ignore_errors=True) + torch._logging.set_logs() + cls._exit_stack.close() + + def check_replay(self, fn, *args, exp_exc_name=None): + fn_opt = torch._dynamo.optimize("eager")(fn) + try: + fn_opt(*args) + except Exception as e: + if exp_exc_name is not None: + self.assertIn(exp_exc_name, str(e)) + expected_error = str(e) + else: + self.fail("opt_fn didn't raise an exception") + + file_name_match = re.search(r"torch._dynamo\.replay\('(.*)'\)", expected_error) + + # Remove replay message from expected error + expected_error = expected_error.split("\n") + for i, line in enumerate(expected_error): + if "torch._dynamo.replay" in line: + del expected_error[i + 1] # Empty line + del expected_error[i] # Replay message + break + expected_error = "\n".join(expected_error) + + self.maxDiff = None + self.assertTrue( + file_name_match is not None, + "No record file name found in generated logs.", + ) + try: + torch._dynamo.replay(file_name_match.groups()[0]) + except Exception as e: + actual_error = str(e) + if actual_error != expected_error: + raise e + else: + self.fail("Replayed frame didn't raise an exception") + + @skipIfNoDill + def test_unsuccessful_inline(self): + def level2(): + a = {10} + z = a["z"] # RuntimeError, Illegal to getitem on a set + return z * torch.ones(1) + + def level1(): + y = torch.ones(1, 1) + return level2() + y + + def level0(): + x = torch.ones(1, 1) + return level1() + x + + self.check_replay(level0, exp_exc_name="RuntimeError") + + @skipIfNoDill + def test_successful_inline(self): + def test_fn(): + x = torch.ones(2, 2) + + def level1(a): + return a + torch.ones(2, 2) + + y = level1(x) + + return y + torch.ones(3, 3) # dimension mismatch + + 
self.check_replay(test_fn, exp_exc_name="RuntimeError") + + @skipIfNoDill + def test_nonlocal_fn_call(self): + def nonlocal_fn(x): + return x + torch.ones(2, 2) + + def test_fn(): + z = torch.ones(2, 2) + x = nonlocal_fn(z) + return x + torch.ones(3, 3) + + self.check_replay(test_fn, exp_exc_name="RuntimeError") + + @skipIfNoDill + def test_nonlocal_module_fn_call(self): + # replay when we use a module + # not defined in the replay env + try: + from . import mock_modules + except ImportError: + import mock_modules + + def test_fn(): + z = mock_modules.mock_module2.method1([], 2) + z = torch.ones(2, 2) + z[0] + return z + torch.zeros(3, 3) + + self.check_replay(test_fn, exp_exc_name="RuntimeError") + + @skipIfNoDill + def test_nonlocal_module_class(self): + try: + from .mock_modules import mock_module2 + except ImportError: + from mock_modules import mock_module2 + + def test_fn(): + z = mock_module2.Class1(1, 2) + y = z.method2(torch.ones(3, 3)) + return y + torch.zeros(3, 5) + + self.check_replay(test_fn, exp_exc_name="RuntimeError") + + @skipIfNoDill + def test_local_module(self): + try: + from .mock_modules import mock_module3 as _ # noqa: F401 + + def test_fn(x): + from .mock_modules import mock_module3 + + z = mock_module3.method1([], torch.ones(5, 1)) + return torch.ones(2, 2) + x + z[0] + + except ImportError: + + def test_fn(x): + from mock_modules import mock_module3 + + z = mock_module3.method1([], torch.ones(5, 1)) + return torch.ones(2, 2) + x + z[0] + + self.check_replay(test_fn, torch.ones(1, 1), exp_exc_name="RuntimeError") + + # Verify that we replay when we have tensor arguments to the frame being replayed + @skipIfNoDill + def test_fn_call_args(self): + def test_fn(x, y): + return x + y + torch.zeros(2, 2) + + self.check_replay( + test_fn, torch.ones(3, 3), torch.ones(2, 2), exp_exc_name="RuntimeError" + ) + + # Verify that accessing torch.nn works when frame replaying is enabled + @skipIfNoDill + def test_torch_nn(self): + def fn(x): + y = torch.nn.functional.pad(x, (10, 10, 10, 10)) + return y + torch.ones(3, 3) # dimension mismatch + + x = torch.ones(4, 4, 4, 4) + self.check_replay(fn, x, exp_exc_name="RuntimeError") + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 22e4892122ca4063c91400821c75f0a5a8a82108..215fb8b040e3fe437c6c99e504cb750dc64d2da5 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -15,40 +15,41 @@ import weakref from abc import ABC from collections import namedtuple from copy import deepcopy +from enum import Enum from functools import wraps from typing import List import numpy as np import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils import torch._functorch.config import torch.library -import torch.fx from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import CompileCounter, rand_strided, same from torch.nn import functional as F + +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( disable_translation_validation_if_dynamic_shapes, + TEST_WITH_ROCM, ) +from torch.testing._internal.two_tensor import TwoTensor _orig_module_call = torch.nn.Module.__call__ # Custom operator that only supports CPU and Meta -lib = torch.library.Library("test_sample", "DEF") +lib = torch.library.Library("test_sample", "DEF") # noqa: 
TOR901 lib.define("foo(Tensor self) -> Tensor") lib.impl("foo", torch.sin, "CPU") -requires_npu = functools.partial( - unittest.skipIf, not torch.npu.is_available(), "requires npu" -) +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") _GLOBAL_CPU_TENSOR = torch.randn(3) @@ -289,7 +290,7 @@ class _ReversibleFunction(torch.autograd.Function): # num of return vars has to match num of forward() args # return gradient for hidden_states arg and None for other args - res = ( + return ( grad_hidden_states, None, None, @@ -303,7 +304,6 @@ class _ReversibleFunction(torch.autograd.Function): None, None, ) - return res class ReformerEncoder(torch.nn.Module): @@ -317,7 +317,7 @@ class ReformerEncoder(torch.nn.Module): self, hidden_states, attention_mask=None, - head_mask=None, + head_mask=[None] * 6, num_hashes=None, use_cache=False, orig_sequence_length=64, @@ -328,7 +328,6 @@ class ReformerEncoder(torch.nn.Module): all_hidden_states = [] all_attentions = [] past_buckets_states = [((None), (None)) for i in range(len(self.layers))] - head_mask = [None] * 6 if head_mask is None else head_mask # concat same tensor for reversible ResNet hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) @@ -863,6 +862,16 @@ class MockModule(torch.nn.Module): return self.inner_fn(tensor.shape, (1, 2, 3)) +class IncByOne: + def __init__(self, x): + self.x = x + 1 + + +class IncByTwo: + def __init__(self, x): + self.x = x + 2 + + class ReproTests(torch._dynamo.test_case.TestCase): def test_do_paste_mask(self): torch._dynamo.utils.counters.clear() @@ -1018,7 +1027,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): out_test.sum().backward() self.assertEqual(leaf.grad, leaf_test.grad) - # See See pytorch/pytorch/issues/97745 + # See pytorch/pytorch/issues/97745 def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self): def f(a, b): c = torch.ones(2, 2) @@ -1029,7 +1038,6 @@ class ReproTests(torch._dynamo.test_case.TestCase): fake_d_pred = torch.matmul(b, e.detach()) d_loss = fake_d_pred.mean() d_loss.backward() - return g_loss, d_loss a_ref = torch.randn(2, 2, requires_grad=True) b_ref = torch.randn(2, 2, requires_grad=True) @@ -1043,6 +1051,27 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(a_ref.grad, a_test.grad) self.assertEqual(b_ref.grad, b_test.grad) + # See pytorch/pytorch/issues/111603 + def test_tuple_enum_as_key_dict(self): + class MyEnum(Enum): + A = "a" + + class SomeModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(1, 1) + + def forward(self, x) -> torch.Tensor: + return self.linear(x[MyEnum.A]) + + x = {MyEnum.A: torch.rand(8, 1)} + model_pytorch = SomeModel() + model = torch.compile(model_pytorch) + # Executing twice works + model(x) + y = model(x) + self.assertEqual(y, model_pytorch(x)) + def test_embedding_backward_broadcasting_decomp(self): def f(grad_output, indices): num_weights = 10 @@ -1285,7 +1314,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.op_count, 2) # rand, rand try: graph, _ = torch._dynamo.export(fn)() - # See See pytorch/pytorch/pull/87490 + # See pytorch/pytorch/pull/87490 self.fail("unexpected export success") except torch._dynamo.exc.Unsupported: pass @@ -1316,7 +1345,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): torch.nn.Linear(10, 10), torch.nn.ReLU(), ) - # this one is tricky because it mutates the list provided as an input + # this one is tricky because it mutates the list provided 
as an ipt l1 = [x] l2 = [x] correct, _ = model(x, l1) @@ -1387,7 +1416,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): opt_test_fn = torch._dynamo.optimize(cnt)(test_fn) opt_test_fn() - # See See pytorch/pytorch/issues/100067 + # See pytorch/pytorch/issues/100067 def test_copy_weird_strides(self): # This test requires inductor's copy() decomp to preserve strides properly. def test_fn(a): @@ -1491,6 +1520,35 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 4) + def test_issue114171(self): + device = torch.device("cpu") + + def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU): + layers = [ + torch.nn.Linear(in_dim, hidden_dim, device=device), + activation(), + torch.nn.Linear(hidden_dim, out_dim, device=device), + ] + return torch.nn.Sequential(*layers) + + class testmodel(torch.nn.Module): + def __init__(self): + super().__init__() + self.interaction_networks = torch.nn.ModuleList( + [fcnn(262, 1174, 400) for _ in range(4)] + ) + + def interact(self, x, cycle): + return self.interaction_networks[cycle](x) + + model = testmodel() + forward_aot = torch.compile( + model.interact, fullgraph=True, dynamic=True, backend="eager" + ) + + x = torch.rand([111, 262], device=device) + y2 = forward_aot(x, 2) # previously failed + def test_issue175(self): n_heads = 2 d_model = 64 @@ -2065,8 +2123,10 @@ class ReproTests(torch._dynamo.test_case.TestCase): x = torch.rand([1]) self.assertEqual(fn(x), torch._dynamo.optimize("eager")(fn)(x)) - @unittest.skipIf(not has_detectron2(), "requires detectron2") def test_multi_import(self): + if not has_detectron2(): + raise unittest.SkipTest("requires detectron2") + @torch._dynamo.optimize("eager", nopython=True) def to_bitmasks(boxes): from detectron2.layers.mask_ops import ( @@ -2087,6 +2147,8 @@ class ReproTests(torch._dynamo.test_case.TestCase): return torch.sin(x) def fn(x): + import torch.fx + _ = torch.fx.symbolic_trace(fn1) return x * 2 @@ -2201,7 +2263,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(mod(*args), opt_mod(*args))) def test_reinplacing(self): - class MockModule_toy(torch.nn.Module): + class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.self_layoutlm_embeddings_x_position_embeddings = ( @@ -2222,7 +2284,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings return (add_2,) - mod = MockModule_toy() + mod = MockModule() opt_mod = torch._dynamo.optimize("aot_eager_decomp_partition")(mod) args = [ @@ -2237,7 +2299,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same_two_models(mod, opt_mod, args)) def test_optimized_deepcopy(self): - # See See pytorch/pytorch/pull/88629 + # See pytorch/pytorch/pull/88629 class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -2609,7 +2671,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): def test_exception_in_dynamo_handling(self): hit_handler = False - # See See pytorch/pytorch/pull/96488 + # See pytorch/pytorch/pull/96488 @contextlib.contextmanager def ctx(): try: @@ -2631,7 +2693,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(hit_handler) def test_generator_dealloc(self): - # See See pytorch/pytorch/pull/96488 + # See pytorch/pytorch/pull/96488 # # NB: yes, [(...)] is intentional, this is a list containing a # generator @@ -2717,13 +2779,13 @@ class ReproTests(torch._dynamo.test_case.TestCase): @torch.compile(backend=cnt, 
fullgraph=True) def fn(a, b): - l_list = CustomList2() - l_list.extend([True]) - l_list.append(a) - l_list.extend([b]) - l_list.pop(0) - l_list.append(l_list.length_times_10()) - return sum(l_list) + l_ = CustomList2() + l_.extend([True]) + l_.append(a) + l_.extend([b]) + l_.pop(0) + l_.append(l_.length_times_10()) + return sum(l_) x = torch.randn(10) y = torch.randn(10) @@ -2760,16 +2822,16 @@ class ReproTests(torch._dynamo.test_case.TestCase): x = torch.randn(10) y = torch.randn(10) - l_list = CustomList2([x, y]) - self.assertIs(fn(l_list, l_list), l_list) - self.assertEqual(len(l_list), 7) - self.assertIs(l_list[0], x) - self.assertIs(l_list[1], y) - self.assertIs(l_list[2], x) - self.assertIs(l_list[3], y) - self.assertEqual(l_list[4], x + 1) - self.assertIs(l_list[5], l_list[4]) - self.assertEqual(l_list[6], y + 2) + l_ = CustomList2([x, y]) + self.assertIs(fn(l_, l_), l_) + self.assertEqual(len(l_), 7) + self.assertIs(l_[0], x) + self.assertIs(l_[1], y) + self.assertIs(l_[2], x) + self.assertIs(l_[3], y) + self.assertEqual(l_[4], x + 1) + self.assertIs(l_[5], l_[4]) + self.assertEqual(l_[6], y + 2) def test_rewrite_assert_with_msg(self): def f(x): @@ -2797,11 +2859,11 @@ class ReproTests(torch._dynamo.test_case.TestCase): return a x = torch.randn(10) - l_list = [x] - self.assertIs(fn(l_list), l_list) - self.assertEqual(len(l_list), 2) - self.assertIs(l_list[0], x) - self.assertEqual(l_list[1], torch.sin(x)) + l_ = [x] + self.assertIs(fn(l_), l_) + self.assertEqual(len(l_), 2) + self.assertIs(l_[0], x) + self.assertEqual(l_[1], torch.sin(x)) self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) @@ -2809,12 +2871,12 @@ class ReproTests(torch._dynamo.test_case.TestCase): def f(x): b = x.sin() if not x.sum() <= 3: - raise ValueError("input sum needs to be 3") + raise ValueError("ipt sum needs to be 3") return x.cos() + b args = (torch.Tensor([3, 4, 5]),) opt_fn = torch._dynamo.optimize("eager")(f) - with self.assertRaisesRegex(ValueError, "input sum needs to be 3"): + with self.assertRaisesRegex(ValueError, "ipt sum needs to be 3"): opt_fn(*args) def test_rewrite_assert_dont_change_bytecode(self): @@ -3425,7 +3487,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): compiled_opt_step() compiled_model_step(x) - param_grad_ref = weakref.ref(list(model.parameters())[0].grad) + param_grad_ref = weakref.ref(next(iter(model.parameters())).grad) optimizer.zero_grad(True) self.assertIsNone(param_grad_ref()) @@ -3468,9 +3530,9 @@ class ReproTests(torch._dynamo.test_case.TestCase): def test_odict_get_item_index_name(self): d = {float: torch.float32, np.float16: torch.float16} - @torch.compile + @torch.compile(backend="eager") def f(x, y1, y2): - return torch.zeros(5, dtype=d.get(y1)), torch.zeros(5, dtype=d.get(y2)) + return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2]) f(torch.zeros(4), float, np.float16) @@ -3484,7 +3546,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): @requires_npu() def test_guard_default_device(self): try: - torch.set_default_device("npu") + torch.set_default_device("npu:0") counter = torch._dynamo.testing.CompileCounter() @@ -3527,10 +3589,10 @@ class ReproTests(torch._dynamo.test_case.TestCase): b = torch_bmm_nd(inp3, inp4, 4) b.unsqueeze_(2) - l_res = a + b + l_ = a + b out = torch.cat([a, b, c], dim=2) - return out, l_res + return out, l_ inp1 = torch.rand(1, 64, 448) inp2 = torch.rand(1, 448, 64) @@ -3659,6 +3721,16 @@ class ReproTests(torch._dynamo.test_case.TestCase): make_fn(None)() + def 
test_call_finally_opcode_python_3_8(self): + def fn(): + try: + return torch.zeros(4) + finally: + return torch.ones(4) # noqa: SIM107, B012 + + result = torch.compile(fn, backend="aot_eager")() + self.assertEqual(result, torch.ones(4)) + def test_string_format(self): s = "temp{i}" @@ -3674,8 +3746,8 @@ class ReproTests(torch._dynamo.test_case.TestCase): # Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards' # due to bad empty list handling def test_empty_list_contains_with_jump(self): - def fn(x, l_arg): - if x in l_arg: + def fn(x, l_): + if x in l_: return x.cos() return x.sin() @@ -3723,7 +3795,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_deferred_runtime_asserts(self): - @torch.compile(backend="aot_eager", fullgraph=True) + @torch.compile(fullgraph=True) def f(x): y = x.item() torch._check_is_size(y) @@ -3814,6 +3886,51 @@ class ReproTests(torch._dynamo.test_case.TestCase): fn_opt = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1))) + @torch._dynamo.config.patch(log_compilation_metrics=True) + def test_many_views_with_mutation(self): + # When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap + # began adding shape guards - a quadratic amount relative to the number of inputs. + # Test this configuration, and test that a reasonable number of guards are added. + # Note, when dynamic shapes are turned on, this test fails and we still get quadratic guards. + def fn(x): + x[0].relu_() + return torch.cat(x).sum() + + AMT = 32 + src = torch.rand(16 * (AMT + 1)) + + x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)] + + torch._dynamo.reset() + torch._dynamo.utils.clear_compilation_metrics() + + res = torch.compile(fn, backend="aot_eager")(x) + + all_metrics = torch._dynamo.utils.get_compilation_metrics() + + total_guards = sum(metric.guard_count for metric in all_metrics) + self.assertLess(total_guards, AMT * 8) + + total_shape_env_guards = sum( + metric.shape_env_guard_count for metric in all_metrics + ) + self.assertLess(total_shape_env_guards, AMT * 8) + + # See pytorch/pytorch/issues/118799 + def test_subclass_graph_output_repro(self): + @torch._dynamo.allow_in_graph + def to_subclass(x): + return TwoTensor(x.clone(), x.clone()) + + def f(x): + tmp_subclass = to_subclass(x) + return tmp_subclass.view(-1) + + x = torch.ones(2) + out_ref = f(x) + out_test = torch.compile(f, backend="aot_eager")(x) + self.assertEqual(out_ref, out_test) + def test_numpy_tobytes_no_error(self): def fn(x): x += 1 @@ -3828,6 +3945,8 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.frame_count, 2) def test_numpy_not_ndarray_recompiles(self): + import torch + def fn(x=None): if x is None: x = np.ones(3) @@ -3867,7 +3986,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): z = x x.data = y y.data = torch.zeros([0]) - return x is z + return torch.tensor(x is z) for backend in ["eager", "aot_eager", "inductor"]: for func in [func1, func2, func3]: @@ -3893,7 +4012,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): out_compiled = compiled_fn(compiled_a, compiled_b) self.assertEqual(eager_a, compiled_a) self.assertEqual(eager_b, compiled_b) - self.assertEqual(out_eager, out_compiled) + self.assertTrue(torch.equal(out_eager, out_compiled)) # func1 hits a leaf Variable that requires grad is being used in an in-place operation if requires_grad: @@ -3907,6 
+4026,285 @@ class ReproTests(torch._dynamo.test_case.TestCase): # frame_count should stay at 1. self.assertEqual(cnt.frame_count, 1) + @unittest.skipIf( + TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "flash attention not supported", + ) + def test_flash_attn_backward_mixed_strides(self): + # in this repro, "grad_out" and "value" are transposed tensors, + # but "key" and "value" are contiguous + def gen_inputs(device): + return ( + torch.randn( + 2, 513, 16, 64, dtype=torch.float16, device=device + ).transpose(1, 2), + torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device), + torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device), + torch.randn( + 2, 513, 16, 64, dtype=torch.float16, device=device + ).transpose(1, 2), + torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device), + torch.randn(2, 16, 513, device=device), + None, + None, + 513, + 513, + 0.0, + False, + torch.tensor(1, dtype=torch.int64), + torch.tensor(1, dtype=torch.int64), + ) + + inps_cuda = gen_inputs("npu:0") + inps_meta = gen_inputs("meta") + ( + out1_ref, + out2_ref, + out3_ref, + ) = torch.ops.aten._scaled_dot_product_flash_attention_backward( + *inps_cuda, scale=0.125 + ) + from torch._meta_registrations import meta__scaled_dot_product_flash_backward + + out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward( + *inps_meta, scale=0.125 + ) + + self.assertEqual(out1_ref.shape, out1_test.shape) + self.assertEqual(out1_ref.stride(), out1_test.stride()) + self.assertEqual(out2_ref.shape, out2_test.shape) + self.assertEqual(out2_ref.stride(), out2_test.stride()) + self.assertEqual(out3_ref.shape, out3_test.shape) + self.assertEqual(out3_ref.stride(), out3_test.stride()) + + def test_user_ctor_ctx_manager(self): + class UserCtxManager: + def __enter__(self): + return 1 + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def fn(x, y): + ucm = UserCtxManager() + return x * x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) + x = torch.rand([2, 2]) + opt_fn(x, x) + self.assertEqual(cnt.frame_count, 1) + + def test_user_ctor_ctx_manager_custom_init(self): + class UserCtxManager: + def __init__(self, x): + x[0] = 10 + + def __enter__(self): + return 1 + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def fn(x, y): + ucm = UserCtxManager(y) + return x * y[0] + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) + x = torch.rand([2, 2]) + self.assertEqual(opt_fn(x, [5]), fn(x, [5])) + self.assertEqual(cnt.frame_count, 1) + + def test_user_ctor_ctx_manager_custom_init_graph_break(self): + counter = [0] + + class UserCtxManager: + def __init__(self, k): + k[0] += 1 + + def __enter__(self): + return 1 + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def fn(x, counter): + x = x * x + ucm = UserCtxManager(counter) + return x * x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.rand([2, 2]) + self.assertEqual(opt_fn(x, counter), fn(x, counter)) + self.assertEqual(counter[0], 2) + for i in range(0, 10): + opt_fn(x, counter) + self.assertEqual(counter[0], 12) + self.assertEqual(cnt.frame_count, torch._dynamo.utils.ifdynstaticdefault(3, 2)) + + @unittest.expectedFailure + def test_many_overlapping_inputs_does_not_explode_guards(self): + from torch._dynamo.backends.common import aot_autograd + + # Before, this was (9702, 0) + num_shape_guards = None + num_aot_guards = None + num_compiles 
= 0 + + def guard_count_backend(gm, *args): + nonlocal num_shape_guards + nonlocal num_aot_guards + nonlocal num_compiles + num_shape_guards = len( + torch._guards.TracingContext.try_get().fake_mode.shape_env.guards + ) + num_aot_guards = len( + torch._guards.TracingContext.try_get().guards_context.aotautograd_guards + ) + num_compiles += 1 + return gm + + aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend) + + @torch.compile(backend=aot_guard_counter, dynamic=True) + def f(*args): + for a in args: + a.add_(1) + + x = torch.ones(1000, requires_grad=True) + args = x.split(10) + + with torch.no_grad(): + f(*args) + # In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2), + # because every pair of aliased inputs needs a guard. + self.assertTrue(num_aot_guards < 5000) + # But there are no dynamic shape guards. + self.assertEqual(num_shape_guards, 0) + # don't recompile + with torch.no_grad(): + f(*args) + self.assertEqual(num_compiles, 1) + + def test_invalid_seq_unpack(self): + def myfn(arg): + (a, b) = arg + + def fn(): + return myfn((1, 2, 3)) + + try: + torch.compile(fn)() + except ValueError: + pass + else: + self.fail("expected exception") + + def test_udf_classes_reconstruction(self): + def fn(x): + o = T(5) + return o.x + x + + opt_fn = torch.compile(fn, backend="eager") + T = IncByOne + + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + + # This should recompile + T = IncByTwo + self.assertEqual(fn(x), opt_fn(x)) + + def test_dont_aggressively_write_assert(self): + record_graph = torch._dynamo.testing.EagerAndRecordGraphs() + + @torch.compile(dynamic=True, backend=record_graph) + def f(x): + assert x.shape[0] > 3 + assert x[0].sum() > 0 + assert 1 % (x.shape[0] // 2) != 0 + assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0 + return x.cos() + + f(torch.ones(6, 4)) + graph = record_graph.graphs[0] + # It is bit annoying that we generate useless statements for + # shape guards, but DCE should be able to remove them since t + # there is no backed assert on them. The reason this is ok is + # because dynamo will only skip the assert statement, but not + # the instructions before it. 
+ self.assertExpectedInline( + str(graph.code).strip(), + """\ +def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): + l_x_ = L_x_ + size = l_x_.size() + getitem = size[0]; size = None + gt = getitem > 3; getitem = None + getitem_2 = l_x_[0] + sum_1 = getitem_2.sum(); getitem_2 = None + gt_1 = sum_1 > 0; sum_1 = None + _assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = None + size_1 = l_x_.size() + getitem_3 = size_1[0]; size_1 = None + floordiv = getitem_3 // 2; getitem_3 = None + mod = 1 % floordiv; floordiv = None + ne = mod != 0; mod = None + size_2 = l_x_.size() + getitem_5 = size_2[0]; size_2 = None + floordiv_1 = getitem_5 // 2; getitem_5 = None + pow_1 = floordiv_1 ** 2; floordiv_1 = None + mul = 32 * pow_1; pow_1 = None + size_3 = l_x_.size() + getitem_7 = size_3[0]; size_3 = None + floordiv_2 = getitem_7 // 2; getitem_7 = None + mul_1 = 16 * floordiv_2; floordiv_2 = None + sub = mul - mul_1; mul = mul_1 = None + ne_1 = sub != 0; sub = None + cos = l_x_.cos(); l_x_ = None + return (cos,)""", + ) + for node in graph.graph.nodes: + if "example_value" in node.meta and isinstance( + node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor + ): + shape_env = node.meta["example_value"].fake_mode.shape_env + lower_ranges = [val.lower for val in shape_env.var_to_range.values()] + self.assertTrue(lower_ranges == [4, 2]) + + @torch.compile(dynamic=True, backend=record_graph) + def f_fail(x): + assert x.shape[0] < 3 + + # We graph-break here, so the failure should be eager + with self.assertRaisesRegex(AssertionError, ""): + f_fail(torch.ones(6, 4)) + + def test_super_in_staticmethod(self): + class A: + @staticmethod + def foo(): + return super().__init__() + + def fn(obj): + return obj.foo() + + obj = A() + + try: + fn(obj) + except Exception as e: + orig_str = str(e) + self.assertIn("no arguments", orig_str) + + try: + torch.compile(backend="eager")(fn)(obj) + except Exception as e: + compiled_str = str(e) + self.assertEqual(orig_str, compiled_str) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_sdpa.py b/test/dynamo/test_sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..9572b6ff06503fff679abdf87e2b90014026d3ec --- /dev/null +++ b/test/dynamo/test_sdpa.py @@ -0,0 +1,107 @@ +# Owner(s): ["module: dynamo"] +import contextlib +import torch +import torch_npu +import torch._dynamo.test_case +import torch._dynamo.testing +from torch._dynamo.testing import CompileCounter +from torch.backends.cuda import SDPAParams + + +@contextlib.contextmanager +def allow_in_graph_sdpa_params(): + global SDPAParams + try: + old = SDPAParams + SDPAParams = torch._dynamo.allow_in_graph(SDPAParams) + yield + finally: + SDPAParams = old + + +class TestSDPA(torch._dynamo.test_case.TestCase): + def assert_ref_equals_params(self, actual, expected): + self.assertIs(actual.query, expected.query) + self.assertIs(actual.key, expected.key) + self.assertIs(actual.value, expected.value) + self.assertIs(actual.attn_mask, expected.attn_mask) + + def test_returns_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(fullgraph=True, backend=counter) + def fn(q, k, v, m): + return SDPAParams(q, k, v, m, 0.1, True) + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + o = fn(q, k, v, m) + self.assertTrue(isinstance(o, SDPAParams)) + self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) + 
self.assertEqual(counter.frame_count, 1) + + def test_graph_break_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(backend=counter) + def fn(q, k, v, m): + z = SDPAParams(q, k, v, m, 0.1, True) + torch._dynamo.graph_break() + return z, q + 1 + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + o, _ = fn(q, k, v, m) + self.assertTrue(isinstance(o, SDPAParams)) + self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) + self.assertEqual(counter.frame_count, 2) + + def test_input_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(backend=counter) + def fn(sdpap, q): + torch._dynamo.graph_break() + return sdpap, sdpap.query + q + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + s = SDPAParams(q, k, v, m, 0.1, True) + o, _ = fn(s, q) + self.assertIs(o, s) + self.assertEqual(counter.frame_count, 1) + + def test_intermediate_attr_access_SDPAParams(self): + with allow_in_graph_sdpa_params(): + counter = CompileCounter() + + @torch.compile(fullgraph=True, backend=counter) + def fn(q, k, v, m): + q += 1 + z = SDPAParams(q, k, v, m, 0.1, True) + a = z.query + return a + 1, z, q + + q = torch.randn(10) + k = torch.randn(10) + v = torch.randn(10) + m = torch.randn(10) + _, o, _ = fn(q, k, v, m) + expected = SDPAParams(q, k, v, m, 0.1, True) + self.assert_ref_equals_params(o, expected) + self.assertEqual(counter.frame_count, 1) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index 584113e4444ca813136401cca22162efc1cb4114..e8e787188503626d2c44bfc9bb6fb223b8f44f57 100644 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -161,7 +161,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase): assert counter.op_count == 0 def test_do_not_skip_side_effects(self): - # see pytorch issue 110765 + # See pytorch/pytorch/issues/110765 # By invoking torch._utils.is_compiling(), # there may be side-effects inconsistent with eager when diff --git a/test/dynamo/test_sources.py b/test/dynamo/test_sources.py index eda4e340bacc079a88d39d6268a4196739011054..d6a197094849469e2d67b761bbf3a379b5944477 100644 --- a/test/dynamo/test_sources.py +++ b/test/dynamo/test_sources.py @@ -1,15 +1,10 @@ # Owner(s): ["module: dynamo"] -""" -PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes -with test_export_persist_assert) -""" import torch import torch_npu - import torch._dynamo import torch._dynamo.test_case -import torch._dynamo.testing +import torch.nn as nn from torch._dynamo.source import ( AttrSource, GlobalSource, @@ -18,6 +13,10 @@ from torch._dynamo.source import ( ) +class CausalLMOutputWithPast: + value = 5 + + class SourceTests(torch._dynamo.test_case.TestCase): def test_is_local(self): x_src = LocalSource("x") @@ -28,3 +27,55 @@ class SourceTests(torch._dynamo.test_case.TestCase): self.assertTrue(is_from_local_source(attr_x_a)) self.assertEqual(is_from_local_source(attr_y_b), False) + + def test_property_closure(self): + def external_property(): + closed_value = 7 + + def internal_function(self): + return closed_value + + return internal_function + + class Elements: + myprop = property(external_property()) + + def func(elements): + if not elements.myprop: + return torch.tensor([1, 2, 3]) + else: + return torch.tensor([4, 5, 
6]) + + e = Elements() + a = func(e) + b = torch.compile(func, backend="eager", fullgraph=True)(e) + self.assertEqual(a, b) + + def test_supported_nodes(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.x = torch.randn(10, 10) + + def forward(self): + if ( + torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type + == int + ): + x = torch.sin(self.x) + else: + x = torch.cos(self.x) + return x + + torch.utils._pytree.register_pytree_node( + CausalLMOutputWithPast, + lambda x: ((), None), + lambda x, _: CausalLMOutputWithPast(), + ) + + # breakpoint() + torch.export.export(Model(), ()) + + +if __name__ == "__main__": + torch._dynamo.test_case.run_tests() diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6417543a4ca5257cc88964ba5a1819d7b04d10 --- /dev/null +++ b/test/dynamo/test_structured_trace.py @@ -0,0 +1,393 @@ +# Owner(s): ["module: dynamo"] +import copy +import functools +import io +import json +import logging +import os +import shutil +import subprocess +import tempfile +import unittest.mock + +import torch +import torch_npu +import torch._dynamo.test_case +import torch._dynamo.testing +import torch._logging.structured +import torch.distributed as dist + +from torch._logging._internal import TorchLogsFormatter +from torch.nn.parallel import DistributedDataParallel as DDP + +from torch.testing._internal.common_utils import find_free_port, TestCase + +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") +requires_distributed = functools.partial( + unittest.skipIf, not dist.is_available(), "requires distributed" +) + + +def example_fn(a): + output = a.mul(torch.ones(1000, 1000)) + output = output.add(torch.ones(1000, 1000)) + return output + + +def dynamo_error_fn(a): + output = a.mul(torch.ones(1000, 1000)) + output = output.add(torch.ones(10, 10)) + return output + + +def inductor_error_fn(a): + output = torch.round(a) + return output + + +def inductor_schedule_fn(a): + output = a.add(torch.ones(1000, 1000, device="npu:0")) + return output + + +ARGS = (torch.ones(1000, 1000, requires_grad=True),) + + +class StructuredTraceTestingFilter(logging.Filter): + def filter(self, record): + if "str" in record.metadata: + return False + return True + + +class StructuredTraceTestingFormatter(logging.Formatter): + def format(self, record): + metadata = copy.deepcopy(record.metadata) + + # Stub out values that are not stable across runs + # do for later: Check that these match schema + if "has_payload" in metadata: + metadata["has_payload"] = "HASH" + if "dynamo_start" in metadata: + metadata["dynamo_start"]["stack"] = "STACK" + if "inductor_output_code" in metadata: + metadata["inductor_output_code"]["filename"] = "FILENAME" + + return json.dumps(metadata) + + +trace_log = logging.getLogger("torch.__trace") + + +class StructuredTraceTest(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + torch._logging.structured.INTERN_TABLE.clear() + self.buffer = io.StringIO() + self.old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + + self.handler = logging.StreamHandler(self.buffer) + self.handler.setFormatter(StructuredTraceTestingFormatter()) + self.handler.addFilter(StructuredTraceTestingFilter()) + trace_log.addHandler(self.handler) + + self.raw_file = tempfile.NamedTemporaryFile( + mode="w", delete=True + ) # set this to False to keep temporary files + self.raw_handler = 
logging.StreamHandler(self.raw_file) + self.raw_handler.setFormatter(TorchLogsFormatter(trace=True)) + trace_log.addHandler(self.raw_handler) + + def tearDown(self): + trace_log.removeHandler(self.handler) + trace_log.removeHandler(self.raw_handler) + self.raw_file.close() + trace_log.setLevel(self.old_level) + + def assertParses(self): + out = tempfile.mkdtemp() + try: + subprocess.check_call( + [ + "tlparse", + "-o", + out, + "--overwrite", + "--no-browser", + "--strict", + self.raw_file.name, + ] + ) + finally: + shutil.rmtree(out, ignore_errors=True) + + @requires_npu() + def test_schedule(self): + fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn) + fn_opt(torch.ones(1000, 1000, device="npu:0")) + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + @requires_npu() + def test_cudagraphs(self): + fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) + fn_opt(torch.ones(1000, 1000, device="npu:0")) + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + def test_recompiles(self): + def fn(x, y): + return torch.add(x, y) + + fn_opt = torch._dynamo.optimize("inductor")(fn) + fn_opt(torch.ones(1000, 1000), torch.ones(1000, 1000)) + fn_opt(torch.ones(1000, 1000), 1) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, 
"frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + def test_example_fn(self): + fn_opt = torch._dynamo.optimize("inductor")(example_fn) + fn_opt(torch.ones(1000, 1000)) + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + def test_dynamo_error(self): + try: + fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn) + fn_opt(*ARGS) + except Exception: + pass + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +""", # noqa: B950 + ) + + self.assertParses() + + def test_inductor_error(self): + import torch._inductor.lowering + + def throw(x): + raise AssertionError() + + # inject an error in the lowerings + dict_entries = {} + for x in list(torch._inductor.lowering.lowerings.keys()): + if "round" in x.__name__: + dict_entries[x] = throw + + with unittest.mock.patch.dict(torch._inductor.lowering.lowerings, dict_entries): + try: + fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn) + fn_opt(*ARGS) + except Exception: + pass + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_backward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + @requires_distributed() + @requires_npu() + def test_ddp_graphs(self): + class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(1024, 1024), + torch.nn.Linear(1024, 1024), + ) + + def 
forward(self, x): + return self.layers(x) + + # do for later: this isn't safely bracketed, will leak + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(find_free_port()) + dist.init_process_group("hccl", rank=0, world_size=1) + + ddp_model = torch._dynamo.optimize("npu")( + DDP(ToyModel().to("npu:0"), device_ids=[0], bucket_cap_mb=4) + ) + + ddp_model(torch.randn(1024, 1024, device="npu:0")) + + dist.destroy_process_group() + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_guards": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1024, 1024], "l__self___layers_0": [1024, 1024], "l__self___layers_1": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + def test_graph_breaks(self): + @torch._dynamo.optimize("inductor") + def fn(x): + torch._dynamo.graph_break() + return x + 1 + + fn(torch.ones(1)) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_post_grad_graph": 
{}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + # do for later: bring in the trace_source tests once we start emitting bytecode + + def test_graph_sizes_dynamic(self): + def fn(a, b): + return a @ b + + fn_opt = torch._dynamo.optimize("eager", dynamic=False)(fn) + fn_opt(torch.randn(10, 20), torch.randn(20, 30)) + + fn_opt2 = torch._dynamo.optimize("eager", dynamic=True)(fn) + fn_opt2(torch.randn(5, 10), torch.randn(10, 15)) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": [10, 20], "l_b_": [20, 30], "matmul": [10, 30]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + def test_guards_recompiles(self): + def fn(x, ys, zs): + return inner(x, ys, zs) + + def inner(x, ys, zs): + for y, z in zip(ys, zs): + x += y * z + return x + + ys = [1.0, 2.0] + zs = [3.0] + x = torch.tensor([1.0]) + + fn_opt = torch._dynamo.optimize("eager")(fn) + fn_opt(x, ys, zs) + fn_opt(x, ys[:1], zs) + + self.assertExpectedInline( + self.buffer.getvalue(), + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +""", # noqa: B950 + ) + + self.assertParses() + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index e9edb9436566aff376bcd1186faca88bddfdd8a4..10b9c7d0d5572484d57b82f73ae688dbc7f60622 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -5,8 +5,6 @@ import unittest import torch import torch_npu -import torchair - import torch._dynamo.test_case import torch._dynamo.testing import torch._functorch.config @@ -14,7 +12,7 @@ import torch.utils._pytree as pytree import torch.utils.checkpoint from torch._dynamo.testing import normalize_gm from torch._higher_order_ops.wrap import wrap -import torch.fx._symbolic_trace + from torch.fx.experimental.symbolic_shapes import ( DimDynamic, ShapeEnv, @@ -23,13 +21,29 @@ from torch.fx.experimental.symbolic_shapes import ( from 
torch.nested._internal.nested_tensor import ( jagged_from_list, jagged_from_tensor_and_lengths, - ViewBufferFromNested, + nested_view_from_values_offsets, + NestedTensor, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + subtest, ) +from torch.testing._internal.two_tensor import TwoTensor +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") compile_full_eager = torch.compile(backend="eager", fullgraph=True) +class BaseTorchFunction(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + class MockSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -50,6 +64,21 @@ class DummyNDim(torch.Tensor): return super().__torch_function__(func, types, args, kwargs) +class WrapperSubclass: + def __init__(self, tensor): + self.tensor = tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args) + kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs) + + return func(*args, **kwargs) + + class SigmoidToExpSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -62,6 +91,59 @@ class SigmoidToExpSubclass(torch.Tensor): return super().__torch_function__(func, types, args, kwargs) +# Wrapper subclass with two inner tensors: data and scale +# data has same shape as outer, and scale has single dim size +class ScaledTensor(torch.Tensor): + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + *, + constant: int = 0, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0): + self._data = data + self._scale = scale + self._constant = constant + + def __tensor_flatten__(self): + ctx = {"_constant": self._constant} + return ["_data", "_scale"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + assert len(inner_tensors) == 2 + return ScaledTensor( + inner_tensors["_data"], + inner_tensors["_scale"], + constant=metadata["_constant"], + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + scaled_tensor = args[0] + out = func(scaled_tensor._data, *args[1:], **kwargs) + return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant) + + def __repr__(self): + return f"{self._data.__repr__()}\n{self._scale.__repr__()}" + + +def func(a): + return a.sin() + + class EagerRecordGraphAndInputs: def __init__(self): self.graphs = [] @@ -73,7 +155,27 @@ class EagerRecordGraphAndInputs: return gm -GLOBAL_TEST_SUBCLASSES = {MockSubclass, DummyNDim, SigmoidToExpSubclass} +GLOBAL_TEST_SUBCLASSES = { + MockSubclass, + DummyNDim, + SigmoidToExpSubclass, + BaseTorchFunction, +} + + +# Returns True if the function recompiles between inputs1 and inputs2 with the +# specified dynamic setting. 
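+# The counter backend below bumps compile_count once per (re)compilation, so a
+# final count greater than one means the second set of inputs forced a recompile.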
+def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): + compile_count = [0] + + def counter(gm, example_inputs): + compile_count[0] += 1 + return gm + + compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) + compiled_f(*inputs1) + compiled_f(*inputs2) + return compile_count[0] > 1 class SubclassTests(torch._dynamo.test_case.TestCase): @@ -90,6 +192,40 @@ class SubclassTests(torch._dynamo.test_case.TestCase): def tearDownClass(cls): cls._exit_stack.close() + def test_no_call_to_new(self): + class BadNewTorchFunction(torch.Tensor): + def __new__(cls, *args, **kwargs): + raise RuntimeError("Oops!") + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {BadNewTorchFunction} + ): + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return torch.add(x, 1) + + ipt = torch.ones(2, 2).as_subclass(BadNewTorchFunction) + + res = fn(ipt) + self.assertIsInstance(res, BadNewTorchFunction) + + def test_base_torch_function_tracing(self): + def fn(x): + return torch.add(x, 1) + + ipt = torch.ones(2, 2).as_subclass(BaseTorchFunction) + out = fn(ipt) + out_opt = compile_full_eager(fn)(ipt) + self.assertIsInstance(out, BaseTorchFunction) + self.assertEqual(out, out_opt) + def test_torch_function_state_graph_break(self): @torch.compile(backend="eager") def fn(x): @@ -97,8 +233,8 @@ class SubclassTests(torch._dynamo.test_case.TestCase): torch._dynamo.graph_break() return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) - input1 = torch.ones(2, 2) - res, _ = fn(input1) + ipt = torch.ones(2, 2) + res, _ = fn(ipt) self.assertFalse(res) def test_torch_function_state_nested(self): @@ -150,6 +286,16 @@ class SubclassTests(torch._dynamo.test_case.TestCase): res = fn(ipt) self.assertIsInstance(res, MockSubclass) + def test_return_as_subclass(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return torch.add(x, 1.0).as_subclass(MockSubclass) + + ipt = torch.ones(2, 2) + + res = fn(ipt) + self.assertIsInstance(res, MockSubclass) + def test_return_local_subclass(self): class LocalSubclass(torch.Tensor): @classmethod @@ -169,6 +315,35 @@ class SubclassTests(torch._dynamo.test_case.TestCase): res = fn(ipt) self.assertIsInstance(res, LocalSubclass) + @parametrize( + "comparison", + [ + subtest(isinstance, "isinstance"), + subtest(lambda instance, type_: type(instance) == type_, "equality"), + subtest(lambda instance, type_: type(instance) is type_, "identity"), + ], + ) + @parametrize( + "input_type", + [ + subtest(torch.Tensor, "tensor"), + subtest(DummyNDim, "subclass"), + ], + ) + def test_type_check(self, comparison, input_type): + with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}): + + def fn(x): + if comparison(x, DummyNDim): + return torch.ones(1, 1) + else: + return torch.zeros(2, 2) + + ipt = torch.ones(2, 2).as_subclass(input_type) + exp_res = fn(ipt) + act_res = torch.compile(backend="eager", fullgraph=True)(fn)(ipt) + self.assertEqual(exp_res, act_res) + def test_torch_function_call_on_method(self): x = torch.ones(2, 2) y = torch.ones(2, 2) @@ -199,17 +374,17 @@ class SubclassTests(torch._dynamo.test_case.TestCase): def sigmoid(self): return None - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - - @torch.compile(backend="eager", fullgraph=True) - def fn(x): - x.sigmoid() + 
@torch.compile(backend="eager", fullgraph=True) + def fn(x): + x.sigmoid() msg = ( "Accessing overridden method/attribute sigmoid on a tensor" " subclass with a __torch_function__ override is not supported" ) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) @@ -223,17 +398,17 @@ class SubclassTests(torch._dynamo.test_case.TestCase): ndim = 10 - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - - @torch.compile(backend="eager", fullgraph=True) - def fn(x): - return x.ndim + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim msg = ( "Accessing overridden method/attribute ndim on a tensor" " subclass with a __torch_function__ override is not supported" ) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) @@ -256,17 +431,17 @@ class SubclassTests(torch._dynamo.test_case.TestCase): def ndim(self, value): self._ndim = value - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - - @torch.compile(backend="eager", fullgraph=True) - def fn(x): - return x.ndim + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim msg = ( "Accessing overridden method/attribute ndim on a tensor" " subclass with a __torch_function__ override is not supported" ) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) @@ -315,6 +490,32 @@ class SubclassTests(torch._dynamo.test_case.TestCase): self.assertEqual(res_exp, res_act) self.assertEqual(res_exp, torch.ones(2) + 10) + def test_torch_function_wrapper_class(self): + x = torch.ones(2, 2) + wrapped = WrapperSubclass(x) + + def fn(w): + return torch.add(w, 1.0) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped) + self.assertEqual(res_exp, res_act) + + def test_torch_function_wrapper_class_with_kwargs(self): + x = torch.ones(2, 2) + wrapped = WrapperSubclass(x) + + def fn(w): + return torch.add(w, 1.0, alpha=2.0) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(wrapped) + res_act = fn_opt(wrapped) + self.assertEqual(res_exp, res_act) + def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) @@ -607,7 +808,7 @@ class GraphModule(torch.nn.Module): return ["inner_elem"], None @staticmethod - def __tensor_unflatten__(inner_tensors, _): + def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride): return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) def __repr__(self): @@ -629,24 +830,19 @@ class GraphModule(torch.nn.Module): return DoubleSizeMaybeAddGeThreeTensor(out_inner) - lower_bound_str = None - upper_bound_str = None curr_var_to_val = None curr_var_to_sources = None + guards = None def backend(gm, args): - print(gm.code) context = torch._guards.TracingContext.get() - val_to_guards = list(context.fake_mode.shape_env.var_to_guards.values()) # Grab info on sources and guards from the shapeenv - nonlocal lower_bound_str - nonlocal 
upper_bound_str nonlocal curr_var_to_val nonlocal curr_var_to_sources + nonlocal guards - lower_bound_str = str(val_to_guards[0][0].expr) - upper_bound_str = str(val_to_guards[0][1].expr) + guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] curr_var_to_val = { str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items() } @@ -678,14 +874,70 @@ class GraphModule(torch.nn.Module): "s0": "L['x'].size()[0]", "s1": "L['x'].inner_elem.size()[0]", } - # lower bound comes from code underneath torch_dispatch (operating on the inner tensor size) - expected_lower_bound = "s1 > 3" - # upper bound comes from user code (operating on the wrapper size) - expected_upper_bound = "2*s1 < 10" self.assertEqual(curr_var_to_val, expected_var_to_val) self.assertEqual(curr_var_to_sources, expected_var_to_sources) - self.assertEqual(lower_bound_str, expected_lower_bound) - self.assertEqual(upper_bound_str, expected_upper_bound) + self.assertExpectedInline( + "\n".join(guards), + """\ +Eq(2*s1, s0) +2*s1 < 10 +s1 > 3""", + ) + + def test_wrapper_subclass_with_same_sized_inner_tensor(self): + # shouldn't recompile for different sizes when dynamic=True + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) + sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7)) + self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True)) + + # should recompile for different data size when dynamic=False + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) + sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) + self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + # avoid recompile using manual mark_dynamic() for different data size + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) + # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size + torch._dynamo.mark_dynamic(sub1, 0) + torch._dynamo.mark_dynamic(sub1, 1) + sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) + self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + def test_wrapper_subclass_with_differently_sized_inner_tensor(self): + # should recompile for different scale size when dynamic=False + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) + sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) + self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + # still recompiles using manual mark_dynamic() on outer for different scale size + sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) + # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size + torch._dynamo.mark_dynamic(sub1, 0) + torch._dynamo.mark_dynamic(sub1, 1) + sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) + self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) + + def test_torch_dispatch_subclass_guard_recompile(self): + x = torch.ones(2, 2) + x_two = TwoTensor(x.clone(), x.clone()) + + def fn(w): + return torch.add(w, 1.0) + + fn_opt = torch.compile(backend="eager")(fn) + + ref = fn(x_two) + res = fn_opt(x_two) + self.assertEqual(ref, res) + + # ensure no recompilation on same input type + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + fn_opt(TwoTensor(x + 1, x + 2)) + + # recompile! 
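+        # A plain tensor no longer matches the TwoTensor subclass guard installed
+        # by the first compilation, so this call is expected to recompile.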
+ ref = fn(x) + res = fn_opt(x) + self.assertEqual(ref, res) def test_recompile_with_symbool_inputs(self): def f(pred: bool): @@ -780,9 +1032,21 @@ class GraphModule(torch.nn.Module): ], ) + def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self): + def f(x_subclass): + tmp_subclass = torch.add(x, 1) + return torch.mul(tmp_subclass._scale, tmp_subclass._constant) + + x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2) + out_ref = f(x) + out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(out_ref, out_test) + def test_support_bases(self): import abc + import torch.fx._symbolic_trace + class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): def __new__(cls, name, bases, dct): x = super().__new__(cls, name, bases, dct) @@ -803,6 +1067,34 @@ class GraphModule(torch.nn.Module): self.assertEqual(f(torch.randn(1)), (Multistreamable,)) + @parametrize("dynamic", [False, True]) + def test_subclass_views(self, dynamic): + def _get_views(t): + # Note that any closed-over SymInts will be symbolicized during fake-ification. + yield t.narrow(dim=-1, start=3, length=8) + yield t.split(5, -1) + yield t.split_with_sizes([9, 6], -1) + yield t.unsqueeze(-1).expand(4, 15, 10) + yield t.select(-1, 6) + yield t[2:3, 5:9] + + def f(x): + return x * 2 + + compiled_f = torch.compile( + f, backend="aot_eager", fullgraph=True, dynamic=dynamic + ) + + # Take a view of a subclass to pass as input. + t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15)) + for view in _get_views(t): + out_ref = f(view) + out_test = compiled_f(view) + self.assertEqual(out_ref, out_test) + + +instantiate_parametrized_tests(SubclassTests) + class TestNestedTensor(torch._dynamo.test_case.TestCase): def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True): @@ -829,18 +1121,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): ) return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) - def _check_recompiles(self, fn, inputs1, inputs2, recompiles): - compile_count = [0] - - def counter(gm, example_inputs): - compile_count[0] += 1 - return gm - - compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=True) - out = compiled_f(*inputs1) - self.assertEqual(compile_count[0], 1) - out = compiled_f(*inputs2) - self.assertEqual(compile_count[0], 2 if recompiles else 1) + def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): + actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) + self.assertEqual(actual_recompiles, expected_recompiles) def test_unary_does_not_recompile(self): nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) @@ -854,9 +1137,11 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): else: return nt1.sin() - # Basic binary - nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None) - nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets) + # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). + # This causes a recompile later on when it realizes the batch and last dim + # should not always be equal. To avoid that, we use (3, j0, 5) here. 
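+        # (Duck sizing gives dims with the same concrete size a shared symbol at
+        # trace time, which is why two literal 3s above would be tied together.)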
+ nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) + nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) @@ -869,9 +1154,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): return nt1.sin() # Binary recompiles because singleton ints no longer match - nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None) - nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets) - nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) + nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) + nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) + nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) # do for later: cannot parametrize this test class with device for some reason @@ -879,19 +1164,20 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64) b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64) c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64) - nt, offsets = jagged_from_list([a, b, c], None) - nt2, _ = jagged_from_list([a, b, c], offsets) + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + # do for later: Switch to public API when it exists + nt2, _ = jagged_from_list([a, b, c], nt.offsets()) def fn1(nt1, nt2): return (nt1 + nt2).sin().cos() compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True) out = compiled_f(nt, nt2) - out_buffer = ViewBufferFromNested.apply(out) + out_buffer = out.values() ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c)) out_ref = fn1(nt, nt2) - out_buffer_ref = ViewBufferFromNested.apply(out_ref) + out_buffer_ref = out_ref.values() ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c)) self.assertTrue(torch.allclose(ga, ga_ref)) @@ -901,17 +1187,48 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): def test_basic_autograd(self): self._test_autograd("aot_eager") - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @requires_npu() def test_basic_autograd_inductor(self): self._test_autograd("inductor") - - @unittest.skipIf(not torch.npu.is_available(), "requires npu") - def test_basic_autograd_npu_backend(self): - npu_backend = torchair.get_npu_backend() - self._test_autograd(npu_backend) + + def test_subclass_with_mutation_in_graph(self): + # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed + # to remain in the graph. Normally this is allowed, but it's not allowed if + # the graph handles subclasses at all. + # Whether the mutation is allowed or not allowed in the graph alters the number + # of outputs from the forward graph. Previously, a bug in this handling meant + # that sometimes the expected number and actual number of outputs from the + # joint graph did not match, causing assertion failures. 
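+        # Below, y.sin_() is the in-place input mutation kept in the graph, and the
+        # jagged NT passed as x is the subclass input that exercises this handling.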
+ def fn(x, y): + z = x.sin() + y.sin_() + return z.cos(), y.cos() + + fn_c = torch.compile(fn, backend="inductor") + + values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)] + values_copy = [x.detach().clone().requires_grad_(True) for x in values] + + nt, offsets = jagged_from_list(values, None) + nt_copy, offsets = jagged_from_list(values_copy, offsets) + y = torch.rand((4, 8)) + y_copy = y.clone() + + ret = fn_c(nt, y)[0] + ref = fn(nt_copy, y_copy)[0] + + self.assertEqual(ret.values(), ref.values()) + + ret.values().sum().backward() + ref.values().sum().backward() + for ref_v, res_v in zip(values_copy, values): + self.assertEqual(ref_v.grad, res_v.grad) def test_unbind(self): - nt, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) + # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). + # This causes a recompile later on when it realizes the batch and last dim + # should not always be equal. To avoid that, we use (3, j0, 5) here. + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None) nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None) @@ -936,71 +1253,112 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): self._check_recompiles(fn, (nt,), (nt3,), True) def _get_views(self): - # There are three cases to consider here based on the logic in - # meta_utils.py - # - # (1) basic case: - # view is not a leaf and has the same requires grad as its basic case - x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True) - self.assertEqual(x.is_leaf, False) - yield x.unsqueeze(-1) - - # (2) leaf view case: - # the view has to be a leaf (w/ requires_grad True or requires_grad False) - # base w/ requires_grad True or requires_grad False - for requires_grad_1, requires_grad_2 in itertools.product( - [True, False], repeat=2 - ): - x, _ = self._get_jagged_tensor( - ((2, 3, 4), 3), None, requires_grad=requires_grad_1 - ) + # Test all cases with both an NT base and a dense base + # Subclass -> Subclass + # Dense -> Subclass + for base_is_nt in [False, True]: + # There are three cases to consider here based on the logic in + # meta_utils.py + # + # (1) basic case: + # view is not a leaf and has the same requires grad as its basic case + x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True) + x = x.clone() if base_is_nt else x + self.assertEqual(x.is_leaf, False) + yield x.unsqueeze(-1) + + # (2) leaf view case: + # the view has to be a leaf (w/ requires_grad True or requires_grad False) + # base w/ requires_grad True or requires_grad False + for requires_grad_1, requires_grad_2 in itertools.product( + [True, False], repeat=2 + ): + x, _ = self._get_jagged_tensor( + ((2, 3, 4), 3), None, requires_grad=requires_grad_1 + ) + x = x.clone() if base_is_nt else x + with torch.no_grad(): + x_view = x.unsqueeze(-1) + # The issue is this doesn't quite work + x_view.requires_grad_(requires_grad_2) + yield x_view + + # (3) obscure case: + # view is not a leaf (implies requires_grad True) + # base w/ requires_grad False) + x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False) + x = x.clone() if base_is_nt else x + # intermediate leaf view with torch.no_grad(): x_view = x.unsqueeze(-1) - # The issue is this doesn't quite work - x_view.requires_grad_(requires_grad_2) - yield x_view - - # (3) obscure case: - # view is not a leaf (implies requires_grad True) - # base w/ requires_grad False) - x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, 
requires_grad=False) - # intermediate leaf view - with torch.no_grad(): - x_view = x.unsqueeze(-1) - x_view.requires_grad_(True) - x_view_view = x_view.unsqueeze(-1) - yield x_view_view - - def test_inputs_to_compiled_fn_are_views(self): - for nt_view in self._get_views(): + x_view.requires_grad_(True) + x_view_view = x_view.unsqueeze(-1) + yield x_view_view + + # Subclass -> Dense + x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() + yield x.values() + + # Dense -> Subclass -> Dense -> Subclass + values = torch.randn(10, 5) + offsets = torch.tensor([0, 3, 6, 10]) + offsets2 = offsets.clone().detach() + yield nested_view_from_values_offsets( + nested_view_from_values_offsets(values, offsets).values(), offsets + ) - def fn(x): - return x.sin() + def _input_view_test(self, nt_view): + def fn(x): + return x.sin() - out_ref = fn(nt_view) - torch._dynamo.reset() - compile_fn = torch.compile( - fn, fullgraph=True, backend="aot_eager", dynamic=True - ) - out = compile_fn(nt_view) + out_ref = fn(nt_view) + torch._dynamo.reset() + compile_fn = torch.compile( + fn, fullgraph=True, backend="aot_eager", dynamic=True + ) + out = compile_fn(nt_view) - # Check metadata and values are correct - self.assertTrue(out.size() == out_ref.size()) - self.assertTrue(out.stride() == out_ref.stride()) + # Check metadata and values are correct + self.assertTrue(out.size() == out_ref.size()) + self.assertTrue(out.stride() == out_ref.stride()) + if out.is_nested: self.assertTrue(torch.allclose(out.values(), out_ref.values())) + else: + self.assertTrue(torch.allclose(out, out_ref)) - # Check that no guards are incurred - def backend(gm, args): - context = torch._guards.TracingContext.get() - val_to_guards = context.fake_mode.shape_env.var_to_guards.values() - self.assertEqual(len(val_to_guards), 0) - return gm + # Check that no upper/lower bound guards are incurred + def backend(gm, args): + context = torch._guards.TracingContext.get() + guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] - torch._dynamo.reset() - compile_fn = torch.compile( - fn, fullgraph=True, backend=backend, dynamic=True - ) - out = compile_fn(nt_view) + # varies based on the type of view + guard_str = "\n".join(guards) + if isinstance(nt_view._base, NestedTensor): + self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""") + else: + self.assertExpectedInline(guard_str, """""") + return gm + + torch._dynamo.reset() + compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True) + out = compile_fn(nt_view) + + def test_inputs_to_compiled_fn_are_views(self): + for nt_view in self._get_views(): + self._input_view_test(nt_view) + + # NJT1 -> Dense -> NJT2 -> Dense view + # During view replay, the Dense -> NJT2 part will construct an intermediate, + # symbolically-sized NJT that is immediately deconstructed to return the final dense + # view. To construct this intermediate properly, we need the associated nested int + # to be symbolic. This view is expected to fail compilation until symbolic nested ints + # are cached onto fake offsets to solve this problem. 
+ @unittest.expectedFailure + def test_subclass_dense_subclass_dense_view(self): + x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() + offsets2 = x.offsets().clone().detach() + nt_view = nested_view_from_values_offsets(x.values(), offsets2).values() + self._input_view_test(nt_view) if __name__ == "__main__": diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 4ea3d8ea85db31273e2c00e43a29cf98c6a11d3c..188509014685ad5586f55181896390b11e0ac148 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import unittest from unittest.mock import patch import torch @@ -565,7 +564,6 @@ class SubGraphTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.frame_count, 3) self.assertEqual(cnt.op_count, 6) - @unittest.skip("not working yet") def test_tuple_iterator_mutate(self): def fn(x, it): x = x + next(it) diff --git a/test/dynamo/test_torchrec.py b/test/dynamo/test_torchrec.py new file mode 100644 index 0000000000000000000000000000000000000000..713560c7827331a4e7226f1d8efc760cad7f5b0e --- /dev/null +++ b/test/dynamo/test_torchrec.py @@ -0,0 +1,206 @@ +# Owner(s): ["module: dynamo"] +import sys +import unittest +from typing import Dict, List + +import torch +import torch_npu +import torch._dynamo.config +import torch._dynamo.test_case +from torch import nn +from torch._dynamo.test_case import TestCase +from torch._dynamo.testing import CompileCounter +from torch.testing._internal.common_utils import NoTest + +try: + from torchrec.datasets.random import RandomRecDataset + from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + HAS_TORCHREC = True +except ImportError: + HAS_TORCHREC = False + + +@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True) +class BucketizeMod(torch.nn.Module): + def __init__(self, feature_boundaries: Dict[str, List[float]]): + super().__init__() + self.bucket_w = torch.nn.ParameterDict() + self.boundaries_dict = {} + for key, boundaries in feature_boundaries.items(): + self.bucket_w[key] = torch.nn.Parameter( + torch.empty([len(boundaries) + 1]).fill_(1.0), + requires_grad=True, + ) + buf = torch.tensor(boundaries, requires_grad=False) + self.register_buffer( + f"{key}_boundaries", + buf, + persistent=False, + ) + self.boundaries_dict[key] = buf + + def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor": + weights_list = [] + for key, boundaries in self.boundaries_dict.items(): + jt = features[key] + bucketized = torch.bucketize(jt.weights(), boundaries) + # doesn't super matter I guess + # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries)) + hashed = bucketized + weights = torch.gather(self.bucket_w[key], dim=0, index=hashed) + weights_list.append(weights) + return KeyedJaggedTensor( + keys=features.keys(), + values=features.values(), + weights=torch.cat(weights_list), + lengths=features.lengths(), + offsets=features.offsets(), + stride=features.stride(), + length_per_key=features.length_per_key(), + ) + + +if not HAS_TORCHREC: + print("torchrec not available, skipping tests", file=sys.stderr) + TestCase = NoTest # noqa: F811 + + +@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec") +class TorchRecTests(TestCase): + def test_pooled(self): + tables = [ + (nn.EmbeddingBag(2000, 8), ["a0", "b0"]), + (nn.EmbeddingBag(2000, 8), ["a1", "b1"]), + (nn.EmbeddingBag(2000, 8), ["b2"]), + ] + + embedding_groups = { + "a": ["a0", "a1"], + "b": 
["b0", "b1", "b2"], + } + + counter = CompileCounter() + + @torch.compile(backend=counter, fullgraph=True, dynamic=True) + def f(id_list_features: KeyedJaggedTensor): + id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict() + pooled_embeddings = {} + # do for later: run feature processor + for emb_module, feature_names in tables: + features_dict = id_list_jt_dict + for feature_name in feature_names: + f = features_dict[feature_name] + pooled_embeddings[feature_name] = emb_module( + f.values(), f.offsets() + ) + + pooled_embeddings_by_group = {} + for group_name, group_embedding_names in embedding_groups.items(): + group_embeddings = [ + pooled_embeddings[name] for name in group_embedding_names + ] + pooled_embeddings_by_group[group_name] = torch.cat( + group_embeddings, dim=1 + ) + + return pooled_embeddings_by_group + + dataset = RandomRecDataset( + keys=["a0", "a1", "b0", "b1", "b2"], + batch_size=4, + hash_size=2000, + ids_per_feature=3, + num_dense=0, + ) + di = iter(dataset) + + # unsync should work + + d1 = next(di).sparse_features.unsync() + d2 = next(di).sparse_features.unsync() + d3 = next(di).sparse_features.unsync() + + r1 = f(d1) + r2 = f(d2) + r3 = f(d3) + + self.assertEqual(counter.frame_count, 1) + counter.frame_count = 0 + + # sync should work too + + d1 = next(di).sparse_features.sync() + d2 = next(di).sparse_features.sync() + d3 = next(di).sparse_features.sync() + + r1 = f(d1) + r2 = f(d2) + r3 = f(d3) + + self.assertEqual(counter.frame_count, 1) + + # export only works with unsync + + gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module + gm.print_readable() + + self.assertEqual(gm(d1), r1) + self.assertEqual(gm(d2), r2) + self.assertEqual(gm(d3), r3) + + def test_bucketize(self): + mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]}) + features = KeyedJaggedTensor.from_lengths_sync( + keys=["f1"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + lengths=torch.tensor([2, 0, 1, 1, 1, 3]), + weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]), + ).unsync() + + def f(x): + # This is a trick to populate the computed cache and instruct + # ShapeEnv that they're all sizey + x.to_dict() + return mod(x) + + torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable() + + @unittest.expectedFailure + def test_simple(self): + jag_tensor1 = KeyedJaggedTensor( + values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + keys=["index_0", "index_1"], + lengths=torch.tensor([0, 0, 1, 1, 1, 3]), + ).sync() + + # ordinarily, this would trigger one specialization + self.assertEqual(jag_tensor1.length_per_key(), [1, 5]) + + counter = CompileCounter() + + @torch._dynamo.optimize(counter, nopython=True) + def f(jag_tensor): + # The indexing here requires more symbolic reasoning + # and doesn't work right now + return jag_tensor["index_0"].values().sum() + + f(jag_tensor1) + + self.assertEqual(counter.frame_count, 1) + + jag_tensor2 = KeyedJaggedTensor( + values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + keys=["index_0", "index_1"], + lengths=torch.tensor([2, 0, 1, 1, 1, 3]), + ).sync() + + f(jag_tensor2) + + self.assertEqual(counter.frame_count, 1) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index c3894752032e371043bcbca475e557a3db17ecf5..72f07d09342f2eb1c54525d5454126b588474535 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -9,6 +9,7 @@ import warnings 
from typing import Any, Dict, Set import torch +import torch_npu import torch._dynamo.config as config import torch._dynamo.test_case import torch._functorch.deprecated as deprecated_func @@ -314,6 +315,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.", ) + @unittest.skip( + "This test keeps getting broken and our disable infra is not handling well. see #120627" + ) def test_torch_name_rule_map_updated(self): # Generate the allowed objects based on heuristic defined in `allowed_functions.py`, objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 3ccfd548f988fd5eac10939436243e437f2f9456..b693a7c292b7e8a5354e8cbf72094a268c9fbfb5 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -4,10 +4,8 @@ import random import unittest import numpy as np - import torch import torch_npu - import torch._dynamo.test_case import torch._dynamo.testing import torch.nn.functional as F @@ -215,8 +213,26 @@ class UnspecTests(torch._dynamo.test_case.TestCase): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) + def test_no_recompiles_prod_backward(self): + # See pytorch/pytorch/issues/120608 + cnt = CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True, dynamic=True) + def fn(t): + return torch.prod(t, 3, keepdim=True) + + input_shapes = [(8, 10, 3, 2), (8, 3, 5, 2), (8, 4, 8, 2)] + for s in input_shapes: + t1 = torch.randn(s, requires_grad=True) + h_result = fn(t1) + grad = torch.ones_like(h_result) + h_result.backward(grad) + + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + @unittest.skipIf(not torch.npu.is_available(), "requires npu") - def test_builtin_functions_on_npu(self): + def test_builtin_functions_on_cuda(self): def fn(x, scaler): m = torch.nn.ReLU() y = m(x) * scaler @@ -270,7 +286,18 @@ class UnspecTests(torch._dynamo.test_case.TestCase): res = opt_fn(x, y) self.assertTrue(same(ref, res)) - def test_shape_graph_break(self): + def test_mark_static_inside(self): + def fn(x): + torch._dynamo.mark_static(x, 0) + comptime.assert_static(x.size(0)) + return x + 1 + + opt_fn = torch.compile(fn, dynamic=True, fullgraph=True) + opt_fn(torch.randn(12, 23)) + + def test_shape_graph_break(self): + from torch._dynamo.comptime import comptime + def fn(x): x_shape = x.size() comptime.graph_break() @@ -355,6 +382,9 @@ class UnspecTests(torch._dynamo.test_case.TestCase): def f3(v): return torch.tensor(v.item()) + def f4(v): + return torch.tensor((v.item(),)) + optimize = torch.compile(backend="aot_eager", fullgraph=True) r = torch.randn(1) @@ -362,6 +392,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase): self.assertEqual(f1(r), optimize(f1)(r)) self.assertEqual(f2(r), optimize(f2)(r)) self.assertEqual(f3(r), optimize(f3)(r)) + self.assertEqual(f4(r), optimize(f4)(r)) def test_sym_int_conversion(self): def f(x): @@ -381,6 +412,17 @@ class UnspecTests(torch._dynamo.test_case.TestCase): compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True) self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim)) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_item_max(self): + def fn(x): + return torch.ones(max(x.item(), 1024)) + + x = torch.tensor([1000]) + y = torch.tensor([2000]) + compl_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), compl_fn(x)) + 
self.assertEqual(fn(y), compl_fn(y)) + # See pytorch/pytorch/issues/104812 def test_argmin_coerces_symint_to_intlist_spec(self): def fn(x, dim): @@ -402,13 +444,55 @@ class UnspecTests(torch._dynamo.test_case.TestCase): compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True) self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict)) + def test_symbol_guard_limit_before_specialize(self): + cnts = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.optimize(cnts, dynamic=True) + def fn(x): + torch._check(x.size(0) != 3) + torch._check(x.size(0) != 4) + torch._check(x.size(0) != 5) + torch._check(x.size(0) != 6) + return x + 2 + + # Control test + fn(torch.randn(12)) + fn(torch.randn(13)) + fn(torch.randn(14)) + + self.assertExpectedInline(cnts.frame_count, """1""") + cnts.frame_count = 0 + + torch._dynamo.reset() + + with torch.fx.experimental._config.patch( + symbol_guard_limit_before_specialize=3 + ): + fn(torch.randn(12)) + fn(torch.randn(13)) + fn(torch.randn(14)) + + self.assertExpectedInline(cnts.frame_count, """3""") + + def test_defaults(self): + def g(x, i=8): + comptime.assert_static(i) + return x * i + + def fn(x): + return g(x) + + inputs = torch.randn(2, 3, 4) + compl_fn = torch.compile(fn, dynamic=True, backend="eager") + self.assertEqual(compl_fn(inputs), fn(inputs)) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_data_dependent_evaluate_expr_graph_break(self): cnts = torch._dynamo.testing.CompileCounter() # To ensure that the continuation frame is compiled, # have to write the test function in this funny way. - # See See pytorch/pytorch/issues/111918 + # See pytorch/pytorch/issues/111918 def test(y): if y > 2: return True @@ -430,6 +514,15 @@ class UnspecTests(torch._dynamo.test_case.TestCase): self.assertExpectedInline(cnts.frame_count, """2""") self.assertExpectedInline(cnts.op_count, """3""") + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_split_aot_autograd(self): + @torch.compile(backend="aot_eager", fullgraph=True) + def f(x, i): + y, z = i.tolist() + return torch.split(x, [y, z]) + + print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3]))) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py index 80080a2b534becc01b66e65aba30d6bed6976ae9..d132fca884a94b5607f13392619e90112843f789 100644 --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -3,11 +3,11 @@ import operator import torch import torch_npu - import torch._dynamo import torch._dynamo.config as config import torch._dynamo.test_case from torch._dynamo.testing import same +from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module class Seq(torch.nn.Module): @@ -36,10 +36,7 @@ class Conv_Bn_Relu(torch.nn.Module): def toy_example(a, b): - try: - x = a / (torch.abs(a) + 1) - except ZeroDivisionError: - return "ZeroDivisionError: Division by zero is not allowed" + x = a / (torch.abs(a) + 1) if b.sum() < 0: b = b * -1 return x * b @@ -67,10 +64,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase): def test_example_inputs(self): def fn(a, bc, d): b, c = bc - try: - return a / d - b / c - except ZeroDivisionError: - return "ZeroDivisionError: Division by zero is not allowed" + return a / d - b / c def compiler_fn(graph, example_inputs): nonlocal r1 @@ -93,6 +87,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase): 
self.assertEqual(r1.device, r2.device)
         self.assertEqual(r1.device, r3.device)
 
+    @_force_skip_lazy_graph_module()
     def test_torchscript(self):
         s = Seq()
         i = torch.randn(10)
diff --git a/test/unsupported_test_cases/.pytorch-disabled-tests.json b/test/unsupported_test_cases/.pytorch-disabled-tests.json
index 34dbc322456c3fec69c2d4d2f3c3792b6f6a1d52..a73b6d08ffdb927a5a36ccfc45f4038d08de374c 100644
--- a/test/unsupported_test_cases/.pytorch-disabled-tests.json
+++ b/test/unsupported_test_cases/.pytorch-disabled-tests.json
@@ -1,4 +1,5 @@
 {
+ "test_tvm (__main__.TestOptimizations)": ["", [""]],
 "test_grad_fn_with_kwargs_dynamic_shapes (__main__.DynamicShapesFuncTorchHigherOrderOpTests)": ["", [""]],
 "test_grad_pytree_dynamic_shapes (__main__.DynamicShapesFuncTorchHigherOrderOpTests)": ["", [""]],
 "test_mem_eff_backwards_throws_determinism_warning_warn_only_True_npu (__main__.TestSDPANpuOnlyPRIVATEUSE1)": ["", [""]],
diff --git a/test/unsupported_test_cases/disabled_tests_type.json b/test/unsupported_test_cases/disabled_tests_type.json
new file mode 100644
index 0000000000000000000000000000000000000000..595617c74094821c0ef5499c0b19180361e1fccb
--- /dev/null
+++ b/test/unsupported_test_cases/disabled_tests_type.json
@@ -0,0 +1,13 @@
+{
+ "test_autocast_float64 (__main__.CtxManagerTests)": {"DTYPE": "unsupported float64"},
+ "test_autocast_float64_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": {"DTYPE": "unsupported float64"},
+ "test_npu_amp_autocast (__main__.CtxManagerTests)": {"DTYPE": "unsupported float64"},
+ "test_npu_amp_autocast_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": {"DTYPE": "unsupported float64"},
+ "test_autocast_flash_attention (__main__.ActivationCheckpointingViaTagsTests)": {"NOT_SUPPORT": "unsupported aten SDPA for graph module"},
+ "test_compile_selective_checkpoint_outplace_op (__main__.ActivationCheckpointingViaTagsTests)": {"NOT_SUPPORT": "failed on both GPU and NPU"},
+ "test_tvm (__main__.TestOptimizations)": {"OTHER": "skipped in original pytorch too"},
+ "test_tags_decomps (__main__.ActivationCheckpointingViaTagsTests)": {"OTHER": "there is a precision problem on NPU"},
+ "test_after_dynamo_cuda_accuracy_error (__main__.MinifierTests)": {"NOT_SUPPORT": "CUDA-related function"},
+ "test_after_dynamo_cuda_compile_error (__main__.MinifierTests)": {"NOT_SUPPORT": "CUDA-related function"},
+ "test_cpu_cuda_module_after_dynamo (__main__.MinifierTests)": {"NOT_SUPPORT": "CUDA-related function"}
+}