Searched refs:all_gather_tensor (Results 1 – 15 of 15) sorted by relevance
/aosp_15_r20/external/pytorch/test/distributed/tensor/parallel/ |
H A D | test_micro_pipeline_tp.py | 18 all_gather_tensor, 53 ag = all_gather_tensor(tensor, gather_dim=0, group=group_name) 87 a = all_gather_tensor(inp, gather_dim=0, group=group.group_name) 88 b = all_gather_tensor(inp, gather_dim=1, group=group.group_name) 179 b = all_gather_tensor(inp, gather_dim=0, group=group.group_name) 184 e = all_gather_tensor(d, gather_dim=0, group=group.group_name) 207 A = all_gather_tensor(A_shard, gather_dim=gather_dim, group=group)
|
H A D | test_tp_random_state.py | 87 tensor_gather = funcol.all_gather_tensor( 110 tensor_gather = funcol.all_gather_tensor(
|
/aosp_15_r20/external/pytorch/test/distributed/_tensor/ |
H A D | test_random_ops.py | 130 local_tensor = funcol.all_gather_tensor( 156 local_tensor = funcol.all_gather_tensor( 173 local_tensor = funcol.all_gather_tensor( 313 local_tensor = funcol.all_gather_tensor( 335 local_tensor = funcol.all_gather_tensor(
|
/aosp_15_r20/external/pytorch/test/distributed/checkpoint/ |
H A D | test_state_dict_utils.py | 46 expected_gathered_dtensor = funcol.all_gather_tensor( 65 expected_gathered_dtensor = funcol.all_gather_tensor( 101 tensor = funcol.all_gather_tensor(
|
/aosp_15_r20/external/pytorch/test/distributed/_tensor/experimental/ |
H A D | test_local_map.py | 33 eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) 39 return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait() 167 X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait()
|
/aosp_15_r20/external/pytorch/test/distributed/_tensor/debug/ |
H A D | test_comm_mode.py | 51 x = funcol.all_gather_tensor(x, 0, world_pg) 77 x = funcol.all_gather_tensor(x, 0, world_pg)
|
/aosp_15_r20/external/pytorch/torch/distributed/ |
H A D | _functional_collectives.py | 179 def all_gather_tensor( function 1020 return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) 1115 output = all_gather_tensor(tensor, 0, group, tag)
|
/aosp_15_r20/external/pytorch/test/distributed/ |
H A D | test_c10d_functional_native.py | 13 all_gather_tensor, 211 output = all_gather_tensor( 650 ag0 = funcol.all_gather_tensor(arg, 0, "0")
|
H A D | test_inductor_collectives.py | 302 res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag) 326 res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag) 641 ar = _functional_collectives.all_gather_tensor(inp, 0, "0") 657 ar = _functional_collectives.all_gather_tensor(inp, 0, pg)
|
H A D | test_fake_pg.py | 95 return funcol.all_gather_tensor(tensor, 0, default_pg)
|
H A D | test_device_mesh.py | 186 global_tensor = funcol.all_gather_tensor( 805 big_tensor = funcol.all_gather_tensor(
|
H A D | test_functional_api.py | 299 gathered_tensor = ft_c.all_gather_tensor(
|
/aosp_15_r20/external/pytorch/torch/distributed/tensor/ |
H A D | placement_types.py | 244 result = funcol.all_gather_tensor( 504 result = funcol.all_gather_tensor(
|
H A D | _collective_utils.py | 58 out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
|
/aosp_15_r20/external/pytorch/test/distributed/_composable/fsdp/ |
H A D | test_fully_shard_mixed_precision.py | 219 param.grad = funcol.all_gather_tensor(
|