Searched defs:gather_dim (Results 1 – 5 of 5) sorted by relevance
/aosp_15_r20/external/pytorch/torch/distributed/tensor/ |
_collective_utils.py |
  32 def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): argument
  50 def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): argument
|
/aosp_15_r20/external/pytorch/test/distributed/tensor/parallel/ |
test_micro_pipeline_tp.py |
  200 def test_fuse_all_gather_matmul(self, A_dims, gather_dim): argument
  238 def test_fuse_all_gather_scaled_matmul(self, A_dims, gather_dim): argument
|
/aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ |
Functional.cpp |
  605 int64_t gather_dim, in shard_dim_alltoall()
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/ |
indexed_array_analysis.cc |
  213 for (int64_t gather_dim : source->output_dims()) { in FoldGatherOfGather() local
|
/aosp_15_r20/external/pytorch/torch/_inductor/ |
lowering.py |
  6444 def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): argument
|