Searched refs:layout_specs (Results 1 – 6 of 6) sorted by relevance
38 std::vector<ShardingSpec> layout_specs(2); in GetSuggestedPredictionsLayout() local39 layout_specs[0].set_sharding_spec(layout.sharding_spec(0)); in GetSuggestedPredictionsLayout()40 layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); in GetSuggestedPredictionsLayout()42 return Layout::GetLayout(layout_specs, layout.mesh()); in GetSuggestedPredictionsLayout()49 std::vector<ShardingSpec> layout_specs(layout.rank()); in MatchBatchDim() local50 layout_specs[0].set_sharding_spec(other_layout.sharding_spec(0)); in MatchBatchDim()52 layout_specs[i].set_sharding_spec(layout.sharding_spec(i)); in MatchBatchDim()55 return Layout::GetLayout(layout_specs, layout.mesh()); in MatchBatchDim()145 std::vector<std::string> layout_specs(1); in ComputeLayoutForward() local146 layout_specs[0] = Layout::kUnshardedDim; in ComputeLayoutForward()[all …]
58 std::vector<ShardingSpec> layout_specs; in ComputeLayoutForward() local59 layout_specs.reserve(input_layout.rank()); in ComputeLayoutForward()63 layout_specs.push_back(input_layout.dim(i)); in ComputeLayoutForward()67 layout_specs.push_back(input_layout.dim(i)); in ComputeLayoutForward()73 Layout::GetLayout(layout_specs, input_layout.mesh())); in ComputeLayoutForward()91 std::vector<ShardingSpec> layout_specs; in ComputeLayoutBackward() local92 layout_specs.reserve(output_layout.rank()); in ComputeLayoutBackward()97 layout_specs.push_back(output_layout.dim(j++)); in ComputeLayoutBackward()99 layout_specs.push_back(unsharded_spec); in ComputeLayoutBackward()103 layout_specs.push_back(output_layout.dim(j++)); in ComputeLayoutBackward()[all …]
641 std::vector<ShardingSpec> layout_specs(2); in ComputeLayoutForward() local642 layout_specs[0].set_sharding_spec(Layout::kUnshardedDim); in ComputeLayoutForward()643 layout_specs[1].set_sharding_spec(Layout::kUnshardedDim); in ComputeLayoutForward()648 layout_specs[0] = features_layout->dim(0); in ComputeLayoutForward()652 Layout::IsUnshardedSpec(layout_specs[0])) in ComputeLayoutForward()653 layout_specs[0] = labels_layout->dim(0); in ComputeLayoutForward()658 (layout_specs[0].sharding_spec() != in ComputeLayoutForward()660 layout_specs[1] = features_layout->dim(features_layout->rank() - 1); in ComputeLayoutForward()662 Layout::IsUnshardedSpec(layout_specs[1]) && in ComputeLayoutForward()663 (layout_specs[0].sharding_spec() != in ComputeLayoutForward()[all …]
32 std::vector<ShardingSpec> layout_specs(input_layout.rank()); in GetSuggestedLayout() local35 layout_specs[i].set_sharding_spec(input_layout.sharding_spec(i)); in GetSuggestedLayout()37 layout_specs[input_layout.rank() - 1].set_sharding_spec( in GetSuggestedLayout()40 return Layout::GetLayout(layout_specs, input_layout.mesh()); in GetSuggestedLayout()
639 std::vector<std::string> layout_specs; in MakeLayoutForReshape() local640 layout_specs.reserve(output_shape.size()); in MakeLayoutForReshape()643 layout_specs.push_back(Layout::kUnshardedDim); in MakeLayoutForReshape()651 layout_specs[output_segment_start[i]] = in MakeLayoutForReshape()654 return Layout::GetLayout(layout_specs, input_layout.mesh()); in MakeLayoutForReshape()
269 layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)312 alloc_size = layout_specs.get(spec_key, (None, value.size))[1]341 if spec_key in layout_specs and layout_specs[spec_key][0] is not None:342 fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])