
Searched refs:split_dim_shape (Results 1 – 3 of 3) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
split_op.cc
39 const TensorShape split_dim_shape = ctx->InputShape("split_dim"); in Compile() local
43 ctx, TensorShapeUtils::IsScalar(split_dim_shape), in Compile()
45 split_dim_shape.dims())); in Compile()
/aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/xnnpack/
split_tester.cc
170 std::array<int32_t, 0> split_dim_shape = {}; in CreateTfLiteModel() local
180 builder.CreateVector<int32_t>(split_dim_shape.data(), in CreateTfLiteModel()
181 split_dim_shape.size()), in CreateTfLiteModel()
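Both hits above rely on the same convention: a scalar is a rank-0 tensor, so its shape is an empty list of dimensions. The tf2xla kernel requires `split_dim` to satisfy `TensorShapeUtils::IsScalar`, and the XNNPACK tester builds that tensor's shape from a zero-length `std::array`. The following is a minimal standalone sketch of that convention that does not use TensorFlow types. The helpers `IsScalarShape`, `NumElements`, and `ScalarSplitDimShape` are hypothetical names for illustration only.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a TensorFlow API): a shape is scalar when it
// has rank 0, i.e. no dimensions at all.
bool IsScalarShape(const std::vector<int32_t>& dims) { return dims.empty(); }

// Hypothetical helper: the element count is the product of the dimensions.
// The empty product is 1, so a scalar holds exactly one element.
int64_t NumElements(const std::vector<int32_t>& dims) {
  int64_t n = 1;
  for (int32_t d : dims) {
    assert(d >= 0 && "dimensions must be non-negative");
    n *= d;
  }
  return n;
}

// Mirrors the tester's zero-length array: copying it into a vector
// yields an empty dimension list, i.e. a scalar shape.
std::vector<int32_t> ScalarSplitDimShape() {
  std::array<int32_t, 0> split_dim_shape = {};
  return std::vector<int32_t>(split_dim_shape.begin(), split_dim_shape.end());
}
```

For example, `IsScalarShape(ScalarSplitDimShape())` is true and `NumElements` of that shape is 1, while a `{2, 3}` shape is not scalar and holds 6 elements.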
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner.cc
4386 std::vector<int64_t> split_dim_shape; in AllGatherShardsInternal() local
4387 split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank()); in AllGatherShardsInternal()
4389 split_dim_shape.push_back(sharding.tile_assignment().dim(i)); in AllGatherShardsInternal()
4392 split_dim_shape.push_back(dim); in AllGatherShardsInternal()
4395 ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape), in AllGatherShardsInternal()
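The `spmd_partitioner.cc` hits show `split_dim_shape` being filled with tile-assignment sizes (line 4389), then with plain dimensions (line 4392), before it is passed to `ShapeUtil::MakeShape`. Lines 4388, 4390, and 4391 are not shown, so the loops below are an assumption: the sketch supposes the first loop runs over `tiled_dims` and the second over the operand's dimensions. It uses plain `std::vector` in place of XLA types. `MakeSplitDimShape` and its `tile_assignment_dims` parameter are hypothetical stand-ins, not the XLA API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: leading dimensions are the tile counts of the tiled
// dimensions, followed by the operand's own dimensions. The loop bounds are
// assumed, since the corresponding source lines are elided in the listing.
std::vector<int64_t> MakeSplitDimShape(
    const std::vector<int64_t>& tiled_dims,
    const std::vector<int64_t>& tile_assignment_dims,  // stand-in for tile_assignment().dim(i)
    const std::vector<int64_t>& operand_dims) {
  std::vector<int64_t> split_dim_shape;
  // Same capacity reservation as line 4387: one slot per tiled dim plus the operand rank.
  split_dim_shape.reserve(tiled_dims.size() + operand_dims.size());
  for (int64_t i : tiled_dims) {
    assert(i >= 0 && static_cast<std::size_t>(i) < tile_assignment_dims.size());
    split_dim_shape.push_back(tile_assignment_dims[i]);
  }
  for (int64_t dim : operand_dims) {
    split_dim_shape.push_back(dim);
  }
  return split_dim_shape;
}
```

Under these assumptions, tiling dimension 0 four ways over an operand of shape `{8, 16}` yields `{4, 8, 16}`. In the listing, the resulting vector is then combined with the operand's element type via `ShapeUtil::MakeShape` (line 4395).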