Searched refs:GetPartitionGroupsForReplication (Results 1 – 2 of 2) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner.cc
203 std::vector<std::vector<int64_t>> GetPartitionGroupsForReplication( in GetPartitionGroupsForReplication() function
4352 GetPartitionGroupsForReplication(sharding, {*it}); in AllGatherShardsInternal()
4371 GetPartitionGroupsForReplication(sharding, selected_dims); in AllGatherShardsInternal()
4444 GetPartitionGroupsForReplication(sharding, selected_dims); in AllReduceAlongShardingDimsInternal()
4454 GetPartitionGroupsForReplication(sharding, {*it}); in AllReduceAlongShardingDimsInternal()
dot_handler.cc
386 std::vector<std::vector<int64_t>> GetPartitionGroupsForReplication( in GetPartitionGroupsForReplication() function
3165 GetPartitionGroupsForReplication(other_sharding, ag_replication_dims); in PrioritizeContractingDimensionsPartitioning()
3166 auto reduce_scatter_subgroups = GetPartitionGroupsForReplication( in PrioritizeContractingDimensionsPartitioning()