From 3e18b3deefa50f2884db7d83b699965bed4dd9c5 Mon Sep 17 00:00:00 2001 From: luqingyun Date: Wed, 29 Nov 2023 15:51:00 +0800 Subject: [PATCH] Modify test cases Modify test cases --- test/distributed/_tensor/test_api.py | 8 +- .../_tensor/test_basic_strategy.py | 15 +- test/distributed/_tensor/test_common_rules.py | 13 +- test/distributed/_tensor/test_device_mesh.py | 557 ++++++++++++++++++ test/distributed/_tensor/test_dtensor.py | 27 +- .../_tensor/test_dtensor_compile.py | 12 +- .../_tensor/test_dtensor_custom_ops.py | 8 +- .../distributed/_tensor/test_embedding_ops.py | 37 +- test/distributed/_tensor/test_init.py | 11 +- test/distributed/_tensor/test_math_ops.py | 4 +- test/distributed/_tensor/test_matrix_ops.py | 2 +- .../distributed/_tensor/test_pointwise_ops.py | 13 +- test/distributed/_tensor/test_random_ops.py | 2 +- test/distributed/_tensor/test_redistribute.py | 10 +- test/distributed/_tensor/test_tensor_ops.py | 28 +- test/distributed/_tensor/test_utils.py | 12 +- test/distributed/_tensor/test_view_ops.py | 2 +- 17 files changed, 693 insertions(+), 68 deletions(-) create mode 100644 test/distributed/_tensor/test_device_mesh.py diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py index a882766b9d..1384f308b7 100644 --- a/test/distributed/_tensor/test_api.py +++ b/test/distributed/_tensor/test_api.py @@ -8,12 +8,12 @@ from torch.distributed._tensor import ( Replicate, Shard, ) -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests class MyModel(nn.Module): @@ -31,7 +31,6 @@ class MyModel(nn.Module): m.reset_parameters() -@skipIfUnsupportMultiNPU(4) class DTensorAPITest(DTensorTestBase): @property def world_size(self) -> int: @@ -39,6 +38,7 @@ class DTensorAPITest(DTensorTestBase): # at least with 2d mesh return 4 + @skipIfUnsupportMultiNPU(4) @with_comms def test_distribute_tensor(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -56,6 +56,7 @@ class DTensorAPITest(DTensorTestBase): self.assertTrue(dist_tensor.requires_grad) self.assertTrue(dist_tensor.is_leaf) + @skipIfUnsupportMultiNPU(4) @with_comms def test_distribute_tensor_uneven_sharding(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -78,6 +79,7 @@ class DTensorAPITest(DTensorTestBase): local_tensor = dist_tensor.to_local() self.assertEqual(local_tensor, splitted_tensor_list[self.rank]) + @skipIfUnsupportMultiNPU(4) @with_comms def test_distribute_module(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -141,6 +143,7 @@ class DTensorAPITest(DTensorTestBase): else: self.assertEqual(param.placements, replica_spec) + @skipIfUnsupportMultiNPU(4) @with_comms def test_distribute_module_input_fn_output_fn(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -185,6 +188,7 @@ class DTensorAPITest(DTensorTestBase): self.assertTrue(isinstance(param_grad, DTensor)) self.assertTrue(isinstance(param_grad.placements[0], Replicate)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_distribute_module_meta(self): # If the model is too big, the user may first the create entire model on the meta device and then initialize diff --git a/test/distributed/_tensor/test_basic_strategy.py b/test/distributed/_tensor/test_basic_strategy.py index 
766a406b7c..f25081ef45 100644 --- a/test/distributed/_tensor/test_basic_strategy.py +++ b/test/distributed/_tensor/test_basic_strategy.py @@ -5,15 +5,16 @@ from torch.distributed._tensor.ops.basic_strategy import ( gen_einsum_strategies, ) -from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase +from torch.testing._internal.common_utils import TestCase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class TestEinsumDims(TestCase): + @skipIfUnsupportMultiNPU(4) def test_batch_dims(self): equation = "abc,abc->abc" input_dims, output_dim = EinsumDims.parse_equation(equation) @@ -24,6 +25,7 @@ class TestEinsumDims(TestCase): self.assertEqual(edims.lhs_out_only_dims, []) self.assertEqual(edims.rhs_out_only_dims, []) + @skipIfUnsupportMultiNPU(4) def test_mm_dims(self): equation = "mk,kn->mn" input_dims, output_dim = EinsumDims.parse_equation(equation) @@ -34,6 +36,7 @@ class TestEinsumDims(TestCase): self.assertEqual(edims.lhs_out_only_dims, ["m"]) self.assertEqual(edims.rhs_out_only_dims, ["n"]) + @skipIfUnsupportMultiNPU(4) def test_bmm_dims(self): equation = "bmk,bkn->bmn" input_dims, output_dim = EinsumDims.parse_equation(equation) @@ -53,6 +56,7 @@ class TestEinsumDims(TestCase): self.assertEqual(edims.lhs_out_only_dims, ["m"]) self.assertEqual(edims.rhs_out_only_dims, ["n"]) + @skipIfUnsupportMultiNPU(4) def test_free_dims(self): equation = "abc,ab->abc" input_dims, output_dim = EinsumDims.parse_equation(equation) @@ -73,12 +77,12 @@ class TestEinsumDims(TestCase): self.assertEqual(edims.rhs_out_only_dims, ["f"]) -@skipIfUnsupportMultiNPU(4) class TestEinsumStrategies(DTensorTestBase): @property def world_size(self) -> int: return 4 + @skipIfUnsupportMultiNPU(4) @with_comms def test_mm_1d_mesh(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -86,6 +90,7 @@ class TestEinsumStrategies(DTensorTestBase): all_strats = gen_einsum_strategies("mk,kn->mn", mesh) self.assertEqual(len(all_strats.strategies), 4) + @skipIfUnsupportMultiNPU(4) @with_comms def test_mm_2d_mesh(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2)) @@ -93,6 +98,7 @@ class TestEinsumStrategies(DTensorTestBase): all_strats = gen_einsum_strategies("mk,kn->mn", mesh) self.assertEqual(len(all_strats.strategies), 16) + @skipIfUnsupportMultiNPU(4) @with_comms def test_bmm_1d_mesh(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -100,6 +106,7 @@ class TestEinsumStrategies(DTensorTestBase): all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh) self.assertEqual(len(all_strats.strategies), 5) + @skipIfUnsupportMultiNPU(4) @with_comms def test_bmm_2d_mesh(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2)) @@ -107,6 +114,7 @@ class TestEinsumStrategies(DTensorTestBase): all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh) self.assertEqual(len(all_strats.strategies), 25) + @skipIfUnsupportMultiNPU(4) @with_comms def test_pointwise_1d_mesh(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -117,6 +125,7 @@ class TestEinsumStrategies(DTensorTestBase): broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh) self.assertEqual(len(broadcast_strats.strategies), 5) + @skipIfUnsupportMultiNPU(4) @with_comms def test_linearity_1d_mesh(self): mesh = 
DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py index cb0bb878f9..6acc1b89bc 100644 --- a/test/distributed/_tensor/test_common_rules.py +++ b/test/distributed/_tensor/test_common_rules.py @@ -9,14 +9,13 @@ from torch.distributed._tensor.ops.common_rules import ( ) from torch.distributed._tensor.placement_types import DTensorSpec from torch.fx.passes.shape_prop import _extract_tensor_metadata -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class CommonRulesTest(DTensorTestBase): @property def world_size(self) -> int: @@ -28,6 +27,7 @@ class CommonRulesTest(DTensorTestBase): empty_tensor = torch.empty(shape) return _extract_tensor_metadata(empty_tensor) + @skipIfUnsupportMultiNPU(4) @with_comms def test_einop_basic_propagation(self): # plain einsum, mm @@ -82,6 +82,7 @@ class CommonRulesTest(DTensorTestBase): self.assertIsNotNone(output_spec) self.assertTrue(output_spec.placements[0].is_partial()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_einop_pointwise_propagation(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -136,6 +137,7 @@ class CommonRulesTest(DTensorTestBase): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, -1, -1]) + @skipIfUnsupportMultiNPU(4) @with_comms def test_einop_merge_sharding(self): # 2d mesh einop merge sharding @@ -162,6 +164,7 @@ class CommonRulesTest(DTensorTestBase): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, 1]) + @skipIfUnsupportMultiNPU(4) @with_comms def test_einop_linearity(self): mesh_shape = torch.arange(self.world_size).reshape( @@ -232,6 +235,7 @@ class CommonRulesTest(DTensorTestBase): # mat2 mesh dim 1 should become partial now! 
self.assertTrue(mat2_spec.placements[1].is_partial()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_einop_multi_sharding_on_mesh_dim(self): # einop prop with multi sharding on same mesh dim @@ -261,6 +265,7 @@ class CommonRulesTest(DTensorTestBase): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) + @skipIfUnsupportMultiNPU(4) @with_comms def test_einop_errors(self): mesh_shape = torch.arange(self.world_size).reshape( @@ -284,6 +289,7 @@ class CommonRulesTest(DTensorTestBase): with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"): einop_rule("ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {})) + @skipIfUnsupportMultiNPU(4) @with_comms def test_pointwise_rules_broadcasting(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -312,6 +318,7 @@ class CommonRulesTest(DTensorTestBase): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [-1, 0]) + @skipIfUnsupportMultiNPU(4) @with_comms def test_pointwise_rules_suggestion(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -342,6 +349,7 @@ class CommonRulesTest(DTensorTestBase): self.assertEqual(len(schema_suggestion.args_schema), 3) self.assertEqual(schema_suggestion.args_schema[2], -1) + @skipIfUnsupportMultiNPU(4) @with_comms def test_pointwise_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding @@ -394,6 +402,7 @@ class CommonRulesTest(DTensorTestBase): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) + @skipIfUnsupportMultiNPU(4) @with_comms def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py new file mode 100644 index 0000000000..82bb2a5226 --- /dev/null +++ b/test/distributed/_tensor/test_device_mesh.py @@ -0,0 +1,557 @@ +import os + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed._tensor._collective_utils import ( + mesh_all_to_all, + mesh_broadcast, + mesh_scatter, +) +from torch.distributed._tensor.device_mesh import ( + _mesh_resources, + DeviceMesh, + init_device_mesh, +) +from torch.distributed._tensor.placement_types import Shard + +from torch.distributed.distributed_c10d import ( + get_global_rank, + get_world_size, + init_process_group, + is_initialized, + ProcessGroup, +) +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase +from torch.testing._internal.distributed.fake_pg import FakeStore + +import torch_npu +from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests + + +def _get_device_type(world_size): + if ( + torch.npu.is_available() + and torch.npu.device_count() >= world_size + ): + device_type = "npu" + else: + device_type = "cpu" + return device_type + + +def _set_env_var(addr="localhost", port="29500", world_size=1, rank=0): + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = port + os.environ["WORLD_SIZE"] = f"{world_size}" + os.environ["RANK"] = f"{rank}" + + +class DeviceMeshTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @skipIfUnsupportMultiNPU(8) + def test_init_process_group(self): + device_type = _get_device_type(self.world_size) + mesh_tensor = torch.arange(4).reshape(2, 2) 
+        self.assertTrue(not is_initialized())
+        _set_env_var(world_size=self.world_size, rank=self.rank)
+        DeviceMesh(device_type, mesh_tensor)
+        self.assertTrue(is_initialized())
+        self.destroy_pg()
+
+    @skipIfUnsupportMultiNPU(8)
+    def test_fake_pg_device_mesh(self):
+        fake_store = FakeStore()
+        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
+        torch.npu.set_device(0)
+        device_type = "npu" if torch.npu.is_available() else "cpu"
+        mesh = DeviceMesh(device_type, torch.arange(self.world_size))
+
+        local_tensor = torch.randn(2, 8)
+        global_tensor = funcol.all_gather_tensor(
+            local_tensor, gather_dim=0, group=(mesh, 0)
+        )
+        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))
+
+
+class DeviceMeshTestDim(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 4
+
+    @skipIfUnsupportMultiNPU(4)
+    @with_comms
+    def test_device_mesh_2d(self):
+        mesh_tensor = torch.arange(4).reshape(2, 2)
+        # construct a npu device mesh
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+
+        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            self.assertTrue(dim < 2)
+            dim_ranks = expected_ranks_by_dim[dim]
+
+            dim_group_size = get_world_size(dim_group)
+            self.assertIsInstance(dim_group, ProcessGroup)
+            self.assertEqual(dim_group_size, 2)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            current_rank_expected_group_ranks = (
+                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
+            )
+            self.assertEqual(global_ranks, current_rank_expected_group_ranks)
+
+    @skipIfUnsupportMultiNPU(4)
+    @with_comms
+    def test_lazy_init_device_mesh(self):
+        mesh = DeviceMesh(self.device_type, [1], _init_process_groups=False)
+
+        with self.assertRaisesRegex(RuntimeError, "process groups not initialized!"):
+            mesh.get_dim_groups()
+
+    @skipIfUnsupportMultiNPU(4)
+    @with_comms
+    def test_validate_device_mesh(self):
+        mesh = torch.arange(self.world_size).reshape(2, -1)
+        mesh_subpg_1 = mesh[0]
+        mesh_subpg_2 = mesh[1]
+        with self.assertRaisesRegex(RuntimeError, "different mesh"):
+            if self.rank in mesh_subpg_1:
+                mesh = DeviceMesh(self.device_type, mesh_subpg_1)
+            else:
+                mesh = DeviceMesh(self.device_type, mesh_subpg_2)
+
+
+class DeviceMeshTestNDim(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 8
+
+    @skipIfUnsupportMultiNPU(8)
+    @with_comms
+    def test_device_mesh_nd(self):
+        # construct a npu device mesh
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            self.assertTrue(dim < mesh_tensor.ndim)
+            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)
+
+            dim_group_size = get_world_size(dim_group)
+            self.assertIsInstance(dim_group, ProcessGroup)
+            self.assertEqual(dim_group_size, 2)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            for ranks in dim_ranks:
+                if self.rank in ranks:
+                    self.assertEqual(global_ranks, ranks.tolist())
+
+    @skipIfUnsupportMultiNPU(8)
+    @with_comms
+    def test_device_mesh_hash(self):
+        mesh_tensor_2d = torch.arange(8).reshape(4, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
+        mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
+        self.assertNotEqual(hash(mesh), hash(mesh2))
+        mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
+        mesh3 =
DeviceMesh(self.device_type, mesh_tensor_3d) + self.assertNotEqual(hash(mesh), hash(mesh3)) + self.assertNotEqual(hash(mesh2), hash(mesh3)) + + +class InitDeviceMeshTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_init_device_mesh(self): + mesh_shape = (2, 4) + ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape)) + + # test init_device_mesh with mesh_dim_names + mesh_dim_names = ("DP", "TP") + two_d_mesh = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + self.assertEqual(two_d_mesh, ref_mesh) + self.assertEqual(two_d_mesh.mesh_dim_names, mesh_dim_names) + + # test init_device_mesh without mesh_dim_names + two_d_mesh = init_device_mesh(self.device_type, mesh_shape) + self.assertEqual(two_d_mesh, ref_mesh) + + +class TestDeviceMeshGetItem(DTensorTestBase): + @property + def world_size(self): + return 8 + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_raises_mesh_dim_less_than_2(self): + with self.assertRaisesRegex(RuntimeError, "Cannot slice a DeviceMesh"): + mesh = init_device_mesh(self.device_type, (8,)) + child_mesh = mesh["DP"] + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_raises_no_mesh_dim_found(self): + with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."): + mesh = init_device_mesh(self.device_type, (2, 4)) + child_mesh = mesh["DP"] + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_raises_invalid_mesh_dim_name(self): + child_mesh_dim_name = "PP" + with self.assertRaisesRegex( + KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist." + ): + mesh_dim_names = ("DP", "TP") + mesh = init_device_mesh( + self.device_type, (2, 4), mesh_dim_names=mesh_dim_names + ) + child_mesh = mesh[child_mesh_dim_name] + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_get_item(self): + mesh_shape = (2, 4) + mesh_dim_names = ("DP", "TP") + two_d_mesh = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + pg_ranks_by_dim_name = {} + for mesh_dim_name in mesh_dim_names: + mesh_dim = mesh_dim_names.index(mesh_dim_name) + pg_ranks_by_dim_name[mesh_dim_name] = two_d_mesh.mesh.swapdims( + -1, mesh_dim + ).reshape(-1, two_d_mesh.mesh.size(mesh_dim)) + + tp_mesh = two_d_mesh["TP"] + tp_group_idx = self.rank // 4 + self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name.get("TP")[tp_group_idx]) + + dp_mesh = two_d_mesh["DP"] + dp_group_idx = self.rank % 4 + self.assertEqual( + two_d_mesh["DP"].mesh, pg_ranks_by_dim_name.get("DP")[dp_group_idx] + ) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_get_parent_mesh(self): + mesh_shape = (2, 4) + mesh_dim_names = ("DP", "TP") + two_d_mesh = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + self.assertEqual(_mesh_resources.get_parent_mesh(two_d_mesh["DP"]), two_d_mesh) + self.assertEqual(_mesh_resources.get_parent_mesh(two_d_mesh["TP"]), two_d_mesh) + + +class DeviceMeshCollectiveTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_broadcast_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + mesh_broadcast(local_tensor, mesh, mesh_dim=0) + self.assertEqual(local_tensor, torch.zeros(3, 3)) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_scatter_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + 
scatter_tensor_shape = [3, 3, 3] + for scatter_dim, _ in enumerate(scatter_tensor_shape): + shard_placement = Shard(scatter_dim) + scatter_tensor_shape[scatter_dim] *= self.world_size + # make the random seed same across rank + torch.manual_seed(0) + global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type) + splitted_list, _ = shard_placement._split_tensor( + global_tensor, mesh.size(), with_padding=True, contiguous=True + ) + recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()]) + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0) + self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()]) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_scatter_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = torch.randn( + device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + + tensor_to_scatter = tensor_to_split.clone() + tensor_splitted_list = list( + torch.chunk(tensor_to_split, self.world_size, dim=shard_dim) + ) + for _ in range(self.world_size - len(tensor_splitted_list)): + tensor_splitted_list.append(torch.tensor([], device=self.device_type)) + + padded_tensor_list, pad_sizes = shard_placement._split_tensor( + tensor_to_scatter, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + + scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) + mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0) + + if pad_sizes[my_rank] != 0: + scattered_tensor = shard_placement._unpad_tensor( + scattered_tensor, pad_sizes[my_rank] + ) + + if scattered_tensor.numel() == 0: + # We need to check numel() instead of size if a tensor is ([]) after unpadding, + # since the size could be ([0, 8]) after unpadding. 
+ self.assertEqual( + scattered_tensor.numel(), tensor_splitted_list[my_rank].numel() + ) + else: + self.assertEqual( + scattered_tensor.size(), tensor_splitted_list[my_rank].size() + ) + self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank]) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_all_gather_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = torch.ones( + device_mesh.size() + 3, + device_mesh.size() + 1, + device=self.device_type, + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_padded_list, pad_sizes = shard_placement._split_tensor( + tensor_to_split, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + local_tensor = tensor_padded_list[my_rank] + big_tensor = funcol.all_gather_tensor( + local_tensor, gather_dim=shard_dim, group=(device_mesh, 0) + ) + big_tensor_chunks = list( + torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim) + ) + unpadded_list = [ + shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i]) + if pad_sizes[i] > 0 + else big_tensor_chunks[i] + for i, big_tensor in enumerate(big_tensor_chunks) + ] + all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim) + + self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size()) + self.assertEqual(all_gathered_tensor, tensor_to_split) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_reduce_scatter_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = ( + torch.ones( + device_mesh.size() + 3, + device_mesh.size() + 1, + device=self.device_type, + ) + * self.rank + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_to_scatter = tensor_to_split.clone() + + tensor_splitted_list = list( + torch.chunk(tensor_to_split, self.world_size, dim=shard_dim) + ) + for _ in range(self.world_size - len(tensor_splitted_list)): + tensor_splitted_list.append(torch.tensor([], device=self.device_type)) + + padded_tensor_list, pad_sizes = shard_placement._split_tensor( + tensor_to_scatter, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + + tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim) + + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + + scattered_tensor = funcol.reduce_scatter_tensor( + tensor_to_reduce, + reduceOp="sum", + scatter_dim=shard_dim, + group=(device_mesh, 0), + ) + + # unpad scattered_tensor + if pad_sizes[my_rank] > 0: + scattered_tensor = shard_placement._unpad_tensor( + scattered_tensor, pad_sizes[my_rank] + ) + + if scattered_tensor.numel() == 0: + # We need to check numel() instead of size if a tensor is ([]) after unpadding, + # since the size could be ([0, 8]) after unpadding. 
+ self.assertEqual( + scattered_tensor.numel(), tensor_splitted_list[my_rank].numel() + ) + else: + self.assertEqual( + scattered_tensor.size(), tensor_splitted_list[my_rank].size() + ) + self.assertEqual( + scattered_tensor, + torch.ones_like(tensor_splitted_list[my_rank]) * res_num, + ) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_broadcast_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + cloned_local_tensor = local_tensor.clone() + mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim) + res_num = global_ranks[0] + self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_scatter_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + scattered_tensors = [ + torch.ones(3, 3, device=self.device_type) * global_rank + for global_rank in global_ranks + ] + received_tensor = torch.empty_like( + scattered_tensors[mesh.get_coordinate()[dim]] + ) + mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim) + self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_all_to_all_1d(self): + # transpose on a 2D tensor distributed over N nodes: + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + tensor_shape = [3, 3] + input_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (rank + self.rank * self.world_size) + for rank in range(self.world_size) + ] + expected_tensor_list = [ + torch.ones(tensor_shape, device=self.device_type) + * (self.rank + rank * self.world_size) # i.e. 
transpose + for rank in range(self.world_size) + ] + for scatter_dim in range(len(tensor_shape)): + output_tensor_list = [ + torch.empty_like(input_tensor_list[idx]) + for idx in range(len(input_tensor_list)) + ] + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0) + output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) + expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) + + self.assertEqual(output_tensor, expected_tensor) + + @skipIfUnsupportMultiNPU(8) + @with_comms + def test_all_to_all_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + tensor_shape = [3, 3, 3] + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + my_coordinate = mesh.get_coordinate()[dim] + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + input_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (i + self.rank * dim_group_size) + for i in range(dim_group_size) + ] + expected_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (my_coordinate + global_rank * dim_group_size) # i.e. transpose + for global_rank in global_ranks + ] + for scatter_dim in range(len(tensor_shape)): + # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim) + output_tensor_list = [ + torch.empty_like(input_tensor_list[idx]) + for idx in range(len(input_tensor_list)) + ] + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh_all_to_all( + output_tensor_list, input_tensor_list, mesh, mesh_dim=dim + ) + output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) + expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) + self.assertEqual(output_tensor, expected_tensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 22d9be7948..48c2494eb4 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -8,11 +8,11 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests class DummyMLP(torch.nn.Module): @@ -33,8 +33,8 @@ class DummyMLP(torch.nn.Module): self.net2.bias.fill_(1.2) -@skipIfUnsupportMultiNPU(4) class DTensorTest(DTensorTestBase): + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_constructor(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -75,6 +75,7 @@ class DTensorTest(DTensorTestBase): stride=local_tensor.stride(), ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_meta_dtensor(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -98,6 +99,7 @@ class DTensorTest(DTensorTestBase): value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5) self.assertEqual(meta_dtensor.to_local(), value_tensor) + 
@skipIfUnsupportMultiNPU(4) @with_comms def test_modules_w_meta_dtensor(self): model = DummyMLP("meta") @@ -129,6 +131,7 @@ class DTensorTest(DTensorTestBase): inp = torch.randn(20, 5, device=self.device_type) self.assertEqual(model_tp(inp), model_regular_tp(inp)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_stride(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -155,6 +158,7 @@ class DTensorTest(DTensorTestBase): global_stride = (8 * self.world_size, 1, 32 * self.world_size) self.assertEqual(dist_tensor.stride(), global_stride) + @skipIfUnsupportMultiNPU(4) @with_comms def test_from_local(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -192,6 +196,7 @@ class DTensorTest(DTensorTestBase): expected_grad = torch.ones(3, 3) * 9 self.assertEqual(local_tensor_with_grad.grad, expected_grad) + @skipIfUnsupportMultiNPU(4) @with_comms def test_to_local(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -243,6 +248,7 @@ class DTensorTest(DTensorTestBase): except RuntimeError: self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size]) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_new_empty_strided(self): device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -266,6 +272,7 @@ class DTensorTest(DTensorTestBase): local_tensor.grad, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_async_output(self): # Tests that if the output of some dtensor operations isn't used in any compute, @@ -305,6 +312,7 @@ class DTensorTest(DTensorTestBase): self.assertEqual(type(out_data), torch.Tensor) self.assertEqual(out_data, ref) + @skipIfUnsupportMultiNPU(4) @with_comms def test_from_local_then_to_local(self): # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works @@ -338,6 +346,7 @@ class DTensorTest(DTensorTestBase): expected_grad = torch.ones(3, 3) * 6 self.assertEqual(local_tensor_with_grad.grad, expected_grad) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_spec_read_only_after_set(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -350,6 +359,7 @@ class DTensorTest(DTensorTestBase): self.assertTrue(sharded_tensor.placements is not shard_spec) self.assertNotEqual(sharded_tensor.placements, shard_spec) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_spec_hash(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -370,6 +380,7 @@ class DTensorTest(DTensorTestBase): ) self.assertNotEqual(hash(sharded_tensor._spec), hash(replica_tensor._spec)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_properties(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -378,6 +389,7 @@ class DTensorTest(DTensorTestBase): sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) self.assertEqual(sharded_tensor.device.type, self.device_type) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_save_load(self): import io @@ -393,7 +405,6 @@ class DTensorTest(DTensorTestBase): self.assertEqual(sharded_tensor, reloaded_st) -@skipIfUnsupportMultiNPU(4) class DTensorMeshTest(DTensorTestBase): @property def world_size(self): @@ -405,6 +416,7 @@ class DTensorMeshTest(DTensorTestBase): else: self.assertEqual(tensor, exp_out_of_mesh) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_device_mesh_device_conversion(self): # construct a cuda device mesh @@ -418,6 +430,7 @@ class 
DTensorMeshTest(DTensorTestBase): self.assertEqual(dist_tensor.device.type, self.device_type) self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_api_device_mesh_context_manager(self): with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: @@ -457,6 +470,7 @@ class DTensorMeshTest(DTensorTestBase): sharded_after_2d = distribute_tensor(global_tensor, placements=shard_spec) self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3])) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_2d_mesh(self): mesh_tensor = torch.arange(self.world_size).reshape(2, 4) @@ -480,6 +494,7 @@ class DTensorMeshTest(DTensorTestBase): dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec) self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3])) + @skipIfUnsupportMultiNPU(4) @with_comms def test_device_mesh_nd(self): # construct a cuda device mesh @@ -501,6 +516,7 @@ class DTensorMeshTest(DTensorTestBase): self.assertEqual(dist_tensor.device.type, self.device_type) self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_spec_local_shard_offset(self): device_mesh = DeviceMesh( @@ -538,6 +554,7 @@ class DTensorMeshTest(DTensorTestBase): ) self.assertEqual(expected_shard_offsets, offset) + @skipIfUnsupportMultiNPU(4) @with_comms def test_from_local_sub_mesh(self): mesh = DeviceMesh(self.device_type, [0, 2]) @@ -565,6 +582,7 @@ class DTensorMeshTest(DTensorTestBase): dtensor.to_local(), ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_default_value_sub_mesh(self): mesh = DeviceMesh(self.device_type, [0, 2]) @@ -603,6 +621,7 @@ class DTensorMeshTest(DTensorTestBase): [dt.to_local() for dt in dtensor_list], ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_redistribute_sub_mesh(self): mesh = DeviceMesh(self.device_type, [0, 2]) @@ -620,7 +639,6 @@ class DTensorMeshTest(DTensorTestBase): ) -@skipIfUnsupportMultiNPU(4) class TestDTensorPlacementTypes(DTensorTestBase): @property def world_size(self): @@ -635,6 +653,7 @@ class TestDTensorPlacementTypes(DTensorTestBase): else: return tensor + @skipIfUnsupportMultiNPU(4) @with_comms def test_split_tensor(self) -> None: mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index dfcc009299..0f4c587f2c 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -7,9 +7,6 @@ import torch.nn as nn from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module -from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, MLPModule @@ -18,6 +15,7 @@ from torch.testing._internal.distributed.fake_pg import FakeStore import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests class SimpleModel(nn.Module): @@ -30,12 +28,12 @@ class SimpleModel(nn.Module): return self.mlp_1(self.mlp_0(input_x)) 
-@skipIfUnsupportMultiNPU(4) class TestDTensorCompile(DTensorTestBase): @property def world_size(self) -> int: return 2 + @skipIfUnsupportMultiNPU(4) @with_comms def test_fakify_dtensor(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -51,6 +49,7 @@ class TestDTensorCompile(DTensorTestBase): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dynamo_dtensor(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -66,6 +65,7 @@ class TestDTensorCompile(DTensorTestBase): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dynamo_dtensor_from_local(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -81,6 +81,7 @@ class TestDTensorCompile(DTensorTestBase): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dynamo_dtensor_from_local_redistribute(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -98,12 +99,12 @@ class TestDTensorCompile(DTensorTestBase): self.assertEqual(res, ref) -@skipIfUnsupportMultiNPU(4) class TestDTensorCompileE2E(DTensorTestBase): @property def world_size(self): return 4 + @skipIfUnsupportMultiNPU(4) @with_comms def test_tp_compile_fullgraph(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -122,7 +123,6 @@ class TestDTensorCompileE2E(DTensorTestBase): data_parallel_size = 2 model = SimpleModel(self.device_type) model_copy = copy.deepcopy(model) - enable_2d_with_fsdp() # 2-D mesh is [dp, tp] twod_mesh = DeviceMesh( diff --git a/test/distributed/_tensor/test_dtensor_custom_ops.py b/test/distributed/_tensor/test_dtensor_custom_ops.py index cc8147b1ad..55a398c4dc 100644 --- a/test/distributed/_tensor/test_dtensor_custom_ops.py +++ b/test/distributed/_tensor/test_dtensor_custom_ops.py @@ -5,14 +5,13 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Replicate, Shard -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class TestDTensorCustomOps(DTensorTestBase): @property def world_size(self): @@ -20,6 +19,7 @@ class TestDTensorCustomOps(DTensorTestBase): # at least with 2d mesh return 4 + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_npu_bmmV2(self): npu_input1 = torch.randn(4, 12, 8).npu() @@ -47,6 +47,7 @@ class TestDTensorCustomOps(DTensorTestBase): ) self.assertEqual(dist_res.to_local(), local_res) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_fast_gelu(self): npu_input = torch.randn(4, 3).npu() @@ -60,6 +61,7 @@ class TestDTensorCustomOps(DTensorTestBase): self.assertEqual(npu_input.shape, dist_res.to_local().shape) self.assertEqual(local_res, dist_res.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_npu_fast_gelu(self): npu_input = torch.randn(4, 3).npu() @@ -71,6 +73,7 @@ class TestDTensorCustomOps(DTensorTestBase): dist_res = torch_npu.npu_fast_gelu(dist_tensor).redistribute(device_mesh, [Replicate()]) self.assertEqual(local_res, dist_res.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_npu_dtype_cast(self): npu_input = torch.randn((2, 3), 
dtype=torch.float32).npu() @@ -92,6 +95,7 @@ class TestDTensorCustomOps(DTensorTestBase): self.assertEqual(dist_res.to_local().dtype, dst_dtype) self.assertEqual(dist_res.to_local(), local_result) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dtensor_npu_transpose(self): npu_input = torch.randn(5, 3, 6, 4).npu() diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py index 85005bf69f..6163cbc925 100644 --- a/test/distributed/_tensor/test_embedding_ops.py +++ b/test/distributed/_tensor/test_embedding_ops.py @@ -3,17 +3,13 @@ import sys import torch from torch.distributed._tensor import distribute_tensor, DTensor, DeviceMesh from torch.distributed._tensor.placement_types import Replicate, Shard -from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -if TEST_WITH_DEV_DBG_ASAN: - raise RuntimeError("Skip dev-asan as torch + multiprocessing spawn have known issues") - -@skipIfUnsupportMultiNPU(4) class TestEmbeddingOp(DTensorTestBase): def _run_embedding_op_test( self, @@ -24,7 +20,7 @@ class TestEmbeddingOp(DTensorTestBase): **kwargs, ): # Use same seed. - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() torch.manual_seed(0) local_embedding = torch.nn.Embedding( num_embeddings, @@ -41,7 +37,14 @@ class TestEmbeddingOp(DTensorTestBase): # Shard the parameter of local embedding and set it to sharded embedding. sharded_embedding.weight = torch.nn.Parameter( - distribute_tensor(local_embedding.weight, device_mesh, [Shard(shard_dim)]) + local_embedding.weight.clone().detach() + ) + parallelize_module( + module=sharded_embedding, + device_mesh=device_mesh, + parallelize_plan=ColwiseParallel(output_layouts=Replicate()) + if shard_dim == 1 + else RowwiseParallel(), ) # Run sharded computation @@ -52,12 +55,7 @@ class TestEmbeddingOp(DTensorTestBase): target = torch.empty( *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type ).random_(0, 1) - placements = [Replicate()] - replicate_inp = DTensor.from_local(inp, device_mesh, placements) - sharded_output = sharded_embedding(replicate_inp) - output = sharded_output.redistribute( - sharded_output.device_mesh, [Replicate()] - ).to_local() + output = sharded_embedding(inp) # Run local computation local_output = local_embedding(inp) @@ -79,7 +77,7 @@ class TestEmbeddingOp(DTensorTestBase): attn_dup_loss.backward() gradient = sharded_embedding.weight.grad.redistribute( - sharded_output.device_mesh, [Replicate()] + device_mesh, [Replicate()] ).to_local() local_grad = local_embedding.weight.grad @@ -94,17 +92,13 @@ class TestEmbeddingOp(DTensorTestBase): **kwargs, ) sharded_output = torch.nn.functional.embedding( - replicate_inp, + DTensor.from_local(inp, device_mesh, [Replicate()]), sharded_embedding.weight, **kwargs, ) - self.assertEqual( - local_output, - sharded_output.redistribute( - sharded_output.device_mesh, [Replicate()] - ).to_local(), - ) + self.assertEqual(local_output, sharded_output.full_tensor()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_sharded_embedding_colwise_errors(self): with self.assertRaisesRegex( @@ -115,6 +109,7 @@ class TestEmbeddingOp(DTensorTestBase): 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0 ) + 
@skipIfUnsupportMultiNPU(4) @with_comms def test_sharded_embedding_rowwise(self): with self.assertRaisesRegex( diff --git a/test/distributed/_tensor/test_init.py b/test/distributed/_tensor/test_init.py index f1993fb5ee..43b0b6bcd0 100644 --- a/test/distributed/_tensor/test_init.py +++ b/test/distributed/_tensor/test_init.py @@ -1,13 +1,12 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class DTensorInitOpsTest(DTensorTestBase): def _run_init_op(self, init_op, *args, **kwargs): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -22,13 +21,13 @@ class DTensorInitOpsTest(DTensorTestBase): dtensor = init_op(dtensor, *args, **kwargs) self.assertEqual(local_tensor_clone, dtensor.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_init_ops(self): # NOTE: random init tests are moved to test_random_ops.py self._run_init_op(torch.nn.init.constant_, 2.4) -@skipIfUnsupportMultiNPU(4) class DTensorConstructorTest(DTensorTestBase): @property def world_size(self): @@ -82,6 +81,7 @@ class DTensorConstructorTest(DTensorTestBase): exp_tensor = init_op(tensor_size, *args, **kwargs) eq_op(exp_tensor, dist_tensor.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_ones(self): self._run_init_op( @@ -91,6 +91,7 @@ class DTensorConstructorTest(DTensorTestBase): requires_grad=True, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_empty(self): self._run_init_op( @@ -102,6 +103,7 @@ class DTensorConstructorTest(DTensorTestBase): requires_grad=True, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_full(self): self._run_init_op( @@ -112,6 +114,7 @@ class DTensorConstructorTest(DTensorTestBase): requires_grad=True, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_zeros(self): self._run_init_op( @@ -121,6 +124,7 @@ class DTensorConstructorTest(DTensorTestBase): requires_grad=True, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_zeros_full_mesh(self): # construct a npu device 1d mesh @@ -186,6 +190,7 @@ class DTensorConstructorTest(DTensorTestBase): elif self.rank == 3: self.assertEqual(local_tensor, torch.zeros([15, 1])) + @skipIfUnsupportMultiNPU(4) @with_comms def test_zeros_submesh(self): # default world_size is 4 diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index 56cebbacb9..a5f1b83c26 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -4,15 +4,15 @@ import torch from torch.distributed._tensor import distribute_tensor, DeviceMesh from torch.distributed._tensor.placement_types import Replicate, Shard -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class DistMathOpsTest(DTensorTestBase): + @skipIfUnsupportMultiNPU(4) @with_comms def test_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/_tensor/test_matrix_ops.py 
b/test/distributed/_tensor/test_matrix_ops.py index 976c6b0dda..87412f4a0d 100644 --- a/test/distributed/_tensor/test_matrix_ops.py +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -10,11 +10,11 @@ from torch.distributed._tensor.placement_types import ( Replicate, Shard, ) -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests class DistMatrixOpsTest(DTensorTestBase): diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py index 5110727c5a..8bc88b8a54 100644 --- a/test/distributed/_tensor/test_pointwise_ops.py +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -14,11 +14,11 @@ from torch.distributed._tensor.placement_types import ( Shard, ) from torch.distributed.distributed_c10d import ReduceOp -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests def no_op(): @@ -60,16 +60,12 @@ def deepcopy_convert_from_dtensor(val: Any) -> Any: def f(x): if isinstance(x, DTensor): - return x.redistribute( - device_mesh=x.device_mesh, - placements=[Replicate()] * x.device_mesh.ndim, - ).to_local() + return x.full_tensor() return x return pytree.tree_map(f, [val])[0] -@skipIfUnsupportMultiNPU(4) class DistElementwiseOpsTest(DTensorTestBase): def _compare_pairwise_ops( self, @@ -134,6 +130,7 @@ class DistElementwiseOpsTest(DTensorTestBase): kwargs=kwargs, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_partial_add(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -142,6 +139,7 @@ class DistElementwiseOpsTest(DTensorTestBase): d_3 = d_1 + d_2 self.assertEqual(d_3._spec.placements[0].is_partial(), True) + @skipIfUnsupportMultiNPU(4) @with_comms def test_activations(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -182,6 +180,7 @@ class DistElementwiseOpsTest(DTensorTestBase): op=torch.sigmoid, ) + @skipIfUnsupportMultiNPU(4) @skip("testing RNG based ops is broken") @with_comms def test_dropout(self): @@ -209,6 +208,7 @@ class DistElementwiseOpsTest(DTensorTestBase): training=True, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_dropout_errors(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -220,6 +220,7 @@ class DistElementwiseOpsTest(DTensorTestBase): op=torch.nn.functional.dropout, ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_mul_out(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/_tensor/test_random_ops.py b/test/distributed/_tensor/test_random_ops.py index 3539dbaad6..506ce502a5 100644 --- a/test/distributed/_tensor/test_random_ops.py +++ b/test/distributed/_tensor/test_random_ops.py @@ -13,11 +13,11 @@ from torch.distributed._tensor.random import is_rng_supported_mesh, manual_seed from torch.distributed.distributed_c10d import broadcast_object_list -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, 
skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests class DistTensorRandomInitTest(DTensorTestBase): diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index 99b5f4c4dd..4ae48715b2 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ b/test/distributed/_tensor/test_redistribute.py @@ -3,15 +3,16 @@ import itertools import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests + -@skipIfUnsupportMultiNPU(4) class RedistributeTest(DTensorTestBase): + @skipIfUnsupportMultiNPU(4) @with_comms def test_shard_to_replicate_forward_backward(self): # 1) test shard -> replicate forward @@ -47,6 +48,7 @@ class RedistributeTest(DTensorTestBase): grad_input.to_local(), torch.ones(dtensor.to_local().size()) ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_replicate_to_replicate_forward_backward(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -66,6 +68,7 @@ class RedistributeTest(DTensorTestBase): self.assertEqual(grad_input.placements, replica_spec) self.assertEqual(grad_input.to_local(), torch.ones(12, 3)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_replicate_to_shard_forward_backward(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -105,6 +108,7 @@ class RedistributeTest(DTensorTestBase): self.assertEqual(grad_input.placements, replica_spec) self.assertEqual(grad_input.to_local(), torch.ones(input_size).npu()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_partial_to_replicate_forward_backward(self): # Although we don't allow user to reshard to produce a partial @@ -130,6 +134,7 @@ class RedistributeTest(DTensorTestBase): if device_mesh.get_rank() == 0: self.assertEqual(partial_local.grad, torch.ones_like(partial_local)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_replicate_to_partial(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -173,6 +178,7 @@ class RedistributeTest(DTensorTestBase): else: self.assertEqual(replica_tensor.to_local(), torch.zeros_like(local_tensor)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_partial_to_shard(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index 38a6fad182..93d311bee3 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -1,7 +1,6 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorConverter, DTensorTestBase @@ -9,10 +8,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import ( import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class 
DistTensorOpsTest(DTensorTestBase): + @skipIfUnsupportMultiNPU(4) @with_comms def test_aten_contiguous(self): # this op not covered by dtensor_ops @@ -23,6 +23,7 @@ class DistTensorOpsTest(DTensorTestBase): torch.randn(16, 32), ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_detach(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -33,6 +34,7 @@ class DistTensorOpsTest(DTensorTestBase): detached_mat = mat.detach() self.assertFalse(detached_mat is mat) + @skipIfUnsupportMultiNPU(4) @with_comms def test_clone(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -44,6 +46,7 @@ class DistTensorOpsTest(DTensorTestBase): self.assertFalse(cloned_mat is mat) self.assertEqual(cloned_mat.to_local(), mat.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_inplace_op(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -71,6 +74,7 @@ class DistTensorOpsTest(DTensorTestBase): self.assertTrue(res is dt_to_inplace_add) self.assertTrue(res.placements == tuple(shard_spec)) + @skipIfUnsupportMultiNPU(4) @with_comms def test_op_out_variant(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -92,6 +96,7 @@ class DistTensorOpsTest(DTensorTestBase): self.assertTrue(res.placements == tuple(replica_spec)) self.assertEqual(replicate_out.to_local(), expected_dt.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_empty_like(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -103,6 +108,7 @@ class DistTensorOpsTest(DTensorTestBase): # empty is not deterministic, so we only check that the shard propagation worked self.assertEqual((4, 8), empty_like_dt.to_local().shape) + @skipIfUnsupportMultiNPU(4) @with_comms def test_fill_inplace(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -115,6 +121,7 @@ class DistTensorOpsTest(DTensorTestBase): self.assertEqual(full_expected, full_like_dt.to_local()) self.assertEqual(full_expected, dist_tensor.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_full_like(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -126,6 +133,7 @@ class DistTensorOpsTest(DTensorTestBase): full_expected = torch.full((4, 8), 42.0) self.assertEqual(full_expected, full_like_dt.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_ones_like(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -137,6 +145,7 @@ class DistTensorOpsTest(DTensorTestBase): ones_expected = torch.ones(4, 8) self.assertEqual(ones_expected, ones_like_dt.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_ones_like_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -150,9 +159,10 @@ class DistTensorOpsTest(DTensorTestBase): ones_expected = torch.ones(dist_tensor.shape) self.assertEqual( ones_expected, - ones_like_dt.redistribute(device_mesh, [Replicate()]).to_local(), + ones_like_dt.full_tensor() ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_fill_inplace_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -166,9 +176,10 @@ class DistTensorOpsTest(DTensorTestBase): fill_expected = torch.full(dist_tensor.shape, 42, dtype=input_tensor.dtype) self.assertEqual( fill_expected, - dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(), + dist_tensor.full_tensor() ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_zeros_like_partial_sum(self): 
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -182,9 +193,10 @@ class DistTensorOpsTest(DTensorTestBase): zeros_expected = torch.zeros(dist_tensor.shape) self.assertEqual( zeros_expected, - zeros_like_dt.redistribute(device_mesh, [Replicate()]).to_local(), + zeros_like_dt.full_tensor() ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_zero_inplace(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -197,6 +209,7 @@ class DistTensorOpsTest(DTensorTestBase): self.assertEqual(zeros_expected, zeros_like_dt.to_local()) self.assertEqual(zeros_expected, dist_tensor.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_zeros_like(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -208,6 +221,7 @@ class DistTensorOpsTest(DTensorTestBase): zeros_expected = torch.zeros(4, 8) self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + @skipIfUnsupportMultiNPU(4) @with_comms def test_equal(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -241,8 +255,8 @@ class DistTensorOpsTest(DTensorTestBase): self.assertTrue(dtc.successful()) d_out = op_call(*d_args, **d_kwargs) self.assertEqual( - d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local(), - out, + d_out.full_tensor(), + out ) diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index d278898991..179a28735c 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -1,26 +1,26 @@ import itertools + import torch from torch.distributed._tensor import distribute_tensor +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed._tensor._utils import ( compute_local_shape, compute_local_shape_and_global_offset, ) -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import Replicate, Shard - -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests -@skipIfUnsupportMultiNPU(4) class UtilTest(DTensorTestBase): @property def world_size(self): return 8 + @skipIfUnsupportMultiNPU(4) @with_comms def test_compute_local_shape_2d_uneven(self): # mesh: 4 * 2 @@ -51,6 +51,7 @@ class UtilTest(DTensorTestBase): else: self.assertEqual(local_size3[1], 3) + @skipIfUnsupportMultiNPU(4) @with_comms def test_compute_local_shape_and_global_offset_1D(self): one_d_placements = [[Shard(0)], [Replicate()]] @@ -76,6 +77,7 @@ class UtilTest(DTensorTestBase): global_tensor[dim0_start:dim0_end], ) + @skipIfUnsupportMultiNPU(4) @with_comms def test_compute_local_shape_and_global_offset_2D(self): two_d_placements_options = [Shard(0), Shard(1), Replicate()] diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index da566e6f55..5621843b3c 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -16,7 +16,6 @@ from torch.distributed._tensor.ops.view_ops import ( view_groups, ) from torch.distributed._tensor.placement_types import Placement -from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( 
DTensorTestBase, redistribute_profiler @@ -25,6 +24,7 @@ from torch.utils._pytree import tree_flatten import torch_npu from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU +from torch_npu.testing.testcase import run_tests class TestViewOps(DTensorTestBase): -- Gitee
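For reference, the structure this patch applies to every test file above is the same: the multi-NPU guard moves from the class to each test method, and run_tests is taken from torch_npu.testing.testcase. A minimal sketch of that structure is shown below; the class name, test body, and mesh contents are illustrative only and not part of the patch.

from torch.distributed._tensor import DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase

import torch_npu  # imported for its side effects (registers the NPU backend), as in the files above
from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU
from torch_npu.testing.testcase import run_tests


class ExampleDTensorNPUTest(DTensorTestBase):  # illustrative class, not from the patch
    @property
    def world_size(self) -> int:
        # the suites above use 4 or 8 ranks depending on the mesh shape
        return 4

    @skipIfUnsupportMultiNPU(4)  # per-method guard: skip when fewer than 4 NPUs are available
    @with_comms                  # sets up and tears down the process group around the test
    def test_example(self):
        # 1-D mesh over all ranks, mirroring the pattern used throughout the tests above
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        self.assertEqual(mesh.size(), self.world_size)


if __name__ == "__main__":
    run_tests()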