Searched defs:gather_dim (Results 1 – 5 of 5) sorted by relevance

/aosp_15_r20/external/pytorch/torch/distributed/tensor/
  _collective_utils.py:32  def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):  (argument)
  _collective_utils.py:50  def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):  (argument)
/aosp_15_r20/external/pytorch/test/distributed/tensor/parallel/
  test_micro_pipeline_tp.py:200  def test_fuse_all_gather_matmul(self, A_dims, gather_dim):  (argument)
  test_micro_pipeline_tp.py:238  def test_fuse_all_gather_scaled_matmul(self, A_dims, gather_dim):  (argument)
/aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/
  Functional.cpp:605  int64_t gather_dim,  (in shard_dim_alltoall())
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
  indexed_array_analysis.cc:213  for (int64_t gather_dim : source->output_dims()) {  (local, in FoldGatherOfGather())
/aosp_15_r20/external/pytorch/torch/_inductor/
  lowering.py:6444  def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name):  (argument)
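In the PyTorch hits above, `gather_dim` is paired with `shard_dim` in `shard_dim_alltoall`. One way to read that pair, assuming the usual DTensor meaning, is a layout change: a tensor sharded along `gather_dim` ends up sharded along `shard_dim`. The sketch below models that reading in a single process with no real communication. It assumes 2-D nested lists and sizes that divide evenly across ranks. `split`, `concat`, and `simulate_shard_dim_alltoall` are hypothetical helpers written for illustration; they are not PyTorch APIs.

```python
from typing import List

Matrix = List[List[int]]


def split(mat: Matrix, dim: int, n: int) -> List[Matrix]:
    """Split a 2-D nested list into n equal chunks along dim (0 = rows, 1 = columns)."""
    if dim == 0:
        size = len(mat) // n
        return [mat[i * size:(i + 1) * size] for i in range(n)]
    size = len(mat[0]) // n
    return [[row[i * size:(i + 1) * size] for row in mat] for i in range(n)]


def concat(chunks: List[Matrix], dim: int) -> Matrix:
    """Concatenate 2-D nested lists along dim (0 = stack rows, 1 = join columns)."""
    if dim == 0:
        return [row for chunk in chunks for row in chunk]
    return [sum((chunk[r] for chunk in chunks), []) for r in range(len(chunks[0]))]


def simulate_shard_dim_alltoall(local_shards: List[Matrix], gather_dim: int, shard_dim: int) -> List[Matrix]:
    """Hypothetical single-process model: local_shards[r] is rank r's shard along gather_dim."""
    world = len(local_shards)
    # Each rank cuts its local shard into `world` pieces along shard_dim; piece j is destined for rank j.
    sent = [split(shard, shard_dim, world) for shard in local_shards]
    # Rank j concatenates the pieces it receives, in source-rank order, along gather_dim.
    return [concat([sent[src][dst] for src in range(world)], gather_dim) for dst in range(world)]


if __name__ == "__main__":
    global_mat = [[r * 4 + c for c in range(4)] for r in range(4)]
    row_shards = split(global_mat, 0, 2)  # sharded by rows across 2 ranks
    col_shards = simulate_shard_dim_alltoall(row_shards, gather_dim=0, shard_dim=1)
    print(col_shards[0])  # [[0, 1], [4, 5], [8, 9], [12, 13]]
```

In this model, concatenating every rank's shard along `gather_dim` and then keeping chunk r along `shard_dim` gives rank r the same result. So an all-gather followed by a local chunk is an equivalent substitute for the all-to-all, at the cost of moving the whole tensor to every rank.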