Home
last modified time | relevance | path

Searched refs:layout_specs (Results 1 – 6 of 6) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/
H A Din_top_k_spmd_expander.cc38 std::vector<ShardingSpec> layout_specs(2); in GetSuggestedPredictionsLayout() local
39 layout_specs[0].set_sharding_spec(layout.sharding_spec(0)); in GetSuggestedPredictionsLayout()
40 layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); in GetSuggestedPredictionsLayout()
42 return Layout::GetLayout(layout_specs, layout.mesh()); in GetSuggestedPredictionsLayout()
49 std::vector<ShardingSpec> layout_specs(layout.rank()); in MatchBatchDim() local
50 layout_specs[0].set_sharding_spec(other_layout.sharding_spec(0)); in MatchBatchDim()
52 layout_specs[i].set_sharding_spec(layout.sharding_spec(i)); in MatchBatchDim()
55 return Layout::GetLayout(layout_specs, layout.mesh()); in MatchBatchDim()
145 std::vector<std::string> layout_specs(1); in ComputeLayoutForward() local
146 layout_specs[0] = Layout::kUnshardedDim; in ComputeLayoutForward()
[all …]
H A Dsqueeze_spmd_expander.cc58 std::vector<ShardingSpec> layout_specs; in ComputeLayoutForward() local
59 layout_specs.reserve(input_layout.rank()); in ComputeLayoutForward()
63 layout_specs.push_back(input_layout.dim(i)); in ComputeLayoutForward()
67 layout_specs.push_back(input_layout.dim(i)); in ComputeLayoutForward()
73 Layout::GetLayout(layout_specs, input_layout.mesh())); in ComputeLayoutForward()
91 std::vector<ShardingSpec> layout_specs; in ComputeLayoutBackward() local
92 layout_specs.reserve(output_layout.rank()); in ComputeLayoutBackward()
97 layout_specs.push_back(output_layout.dim(j++)); in ComputeLayoutBackward()
99 layout_specs.push_back(unsharded_spec); in ComputeLayoutBackward()
103 layout_specs.push_back(output_layout.dim(j++)); in ComputeLayoutBackward()
[all …]
H A Dsoftmax_spmd_expander.cc641 std::vector<ShardingSpec> layout_specs(2); in ComputeLayoutForward() local
642 layout_specs[0].set_sharding_spec(Layout::kUnshardedDim); in ComputeLayoutForward()
643 layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); in ComputeLayoutForward()
648 layout_specs[0] = features_layout->dim(0); in ComputeLayoutForward()
652 Layout::IsUnshardedSpec(layout_specs[0])) in ComputeLayoutForward()
653 layout_specs[0] = labels_layout->dim(0); in ComputeLayoutForward()
658 (layout_specs[0].sharding_spec() != in ComputeLayoutForward()
660 layout_specs[1] = features_layout->dim(features_layout->rank() - 1); in ComputeLayoutForward()
662 Layout::IsUnshardedSpec(layout_specs[1]) && in ComputeLayoutForward()
663 (layout_specs[0].sharding_spec() != in ComputeLayoutForward()
[all …]
H A Dtop_k_spmd_expander.cc32 std::vector<ShardingSpec> layout_specs(input_layout.rank()); in GetSuggestedLayout() local
35 layout_specs[i].set_sharding_spec(input_layout.sharding_spec(i)); in GetSuggestedLayout()
37 layout_specs[input_layout.rank() - 1].set_sharding_spec( in GetSuggestedLayout()
40 return Layout::GetLayout(layout_specs, input_layout.mesh()); in GetSuggestedLayout()
H A Dmeta_spmd_expander.cc639 std::vector<std::string> layout_specs; in MakeLayoutForReshape() local
640 layout_specs.reserve(output_shape.size()); in MakeLayoutForReshape()
643 layout_specs.push_back(Layout::kUnshardedDim); in MakeLayoutForReshape()
651 layout_specs[output_segment_start[i]] = in MakeLayoutForReshape()
654 return Layout::GetLayout(layout_specs, input_layout.mesh()); in MakeLayoutForReshape()
/aosp_15_r20/external/pytorch/torch/distributed/checkpoint/
H A Doptimizer.py269 layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
312 alloc_size = layout_specs.get(spec_key, (None, value.size))[1]
341 if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
342 fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])