Searched defs:batch_dimensions (Results 1 – 5 of 5) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/
gemm_broadcast_folding_rewriter.cc
57 const tensorflow::protobuf::RepeatedField<int64_t> &batch_dimensions = in HandleCustomCall() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
triangular_solve_expander.cc
487 std::vector<int64_t> batch_dimensions; in BuildTriangularSolve() local
algebraic_simplifier.cc
2092 HloInstruction* dot_operand, absl::Span<const int64_t> batch_dimensions, in NormalizeDotOperandToBatchMajorAndContractingMinor()
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
legalize_hlo.cc
1618 DotDimensionsInfo(ShapedType type, ArrayRef<int64_t> batch_dimensions, in DotDimensionsInfo()
1641 const DimensionVector &batch_dimensions() const { return batch_dimensions_; } in batch_dimensions() function in mlir::TF::__anonab43b57b0111::DotDimensionsInfo
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
legalize_tf.cc
444 auto batch_dimensions = in BatchDot() local
3257 auto batch_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)); in matchAndRewrite() local