Home
last modified time | relevance | path

Searched refs:branch_computations (Results 1 – 22 of 22) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/
H A Ddynamic_shaped_ops.cc168 absl::Span<const XlaComputation* const> branch_computations, in DynamicConditional() argument
172 root_shapes.reserve(branch_computations.size()); in DynamicConditional()
173 for (int64_t i = 0; i < branch_computations.size(); ++i) { in DynamicConditional()
175 branch_computations[i]->GetProgramShape()); in DynamicConditional()
185 return xla::Conditional(branch_index, branch_computations, in DynamicConditional()
211 rewritten_computations.reserve(branch_computations.size()); in DynamicConditional()
213 for (int64_t i = 0; i < branch_computations.size(); ++i) { in DynamicConditional()
219 max_shape, *branch_computations[i])); in DynamicConditional()
223 rewritten_computation_ptrs.reserve(branch_computations.size()); in DynamicConditional()
224 for (int64_t i = 0; i < branch_computations.size(); ++i) { in DynamicConditional()
H A Ddynamic_shaped_ops.h40 absl::Span<const XlaComputation* const> branch_computations,
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
H A Dar_crs_combiner.cc252 for (int64_t i = 0; i < cond_instr->branch_computations().size(); ++i) { in GetAllTuples()
290 const auto& branch_computations = instruction->branch_computations(); in GetAllTuples() local
291 result_tuples.reserve(branch_computations.size()); in GetAllTuples()
292 for (HloComputation* body : branch_computations) { in GetAllTuples()
H A Dshape_inference.cc2902 absl::Span<const ProgramShape> branch_computations, in InferConditionalShape() argument
2910 TF_RET_CHECK(2 == branch_computations.size()); in InferConditionalShape()
2912 TF_RET_CHECK(!branch_computations.empty()); in InferConditionalShape()
2914 TF_RET_CHECK(branch_computations.size() == branch_operands.size()); in InferConditionalShape()
2915 Shape result = branch_computations[0].result(); in InferConditionalShape()
2916 for (int j = 0; j < branch_computations.size(); ++j) { in InferConditionalShape()
2917 if (branch_computations[j].parameters_size() != 1) { in InferConditionalShape()
2920 branch_computations[j].parameters_size()); in InferConditionalShape()
2922 if (!ShapeUtil::Compatible(branch_computations[j].parameters(0), in InferConditionalShape()
2927 ShapeUtil::HumanString(branch_computations[j])); in InferConditionalShape()
[all …]
H A Dconditional_simplifier.cc257 for (HloComputation* branch : conditional_op->branch_computations()) { in RemoveUnusedTupleElements()
376 for (const HloComputation* branch : conditional->branch_computations()) { in MergeDuplicateTupleElements()
401 absl::c_transform(conditional->branch_computations(), in MergeDuplicateTupleElements()
H A Dbfloat16_propagation.cc389 absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) { in DetermineInstructionPrecision()
556 for (auto* branch : hlo->branch_computations()) { in AdjustCalledComputationRoot()
678 for (auto* branch : hlo->branch_computations()) { in ResolveInconsistencyOfAliasingBuffersHelper()
H A Dhlo_instruction.cc1491 absl::Span<HloComputation* const> branch_computations, in CreateConditional() argument
1496 CHECK_EQ(branch_computations.size(), branch_computation_args.size()); in CreateConditional()
1497 for (int i = 0; i < branch_computations.size(); ++i) { in CreateConditional()
1498 instruction->called_computations_.push_back(branch_computations[i]); in CreateConditional()
2135 absl::MakeSpan(branch_computations()), in CloneWithNewOperands()
2872 const std::vector<HloComputation*>& HloInstruction::branch_computations() in branch_computations() function in xla::HloInstruction
3224 StrJoin(branch_computations(), ", ", in ExtraAttributesToString()
3295 StrJoin(branch_computations(), ",\n", in ExtraAttributesToString()
H A Dlayout_assignment.cc1210 const auto& branch_computations = instruction->branch_computations(); in CheckLayouts() local
1211 branch_computation_layouts.reserve(branch_computations.size()); in CheckLayouts()
1212 for (const auto branch_computation : branch_computations) { in CheckLayouts()
H A Dshape_inference.h264 absl::Span<const ProgramShape> branch_computations,
H A Dhlo_parser.cc2525 optional<std::vector<HloComputation*>> branch_computations; in CreateInstruction() local
2547 &branch_computations}; in CreateInstruction()
2553 branch_computations.emplace({*true_computation, *false_computation}); in CreateInstruction()
2555 if (branch_computations->empty() || in CreateInstruction()
2556 operands.size() != branch_computations->size() + 1) { in CreateInstruction()
2561 branch_computation_shapes.reserve(branch_computations->size()); in CreateInstruction()
2562 for (auto* computation : *branch_computations) { in CreateInstruction()
2579 absl::MakeSpan(*branch_computations), in CreateInstruction()
H A Dhlo_instruction.h1055 absl::Span<HloComputation* const> branch_computations,
1480 const std::vector<HloComputation*>& branch_computations() const;
H A Dhlo_parser_test.cc1229 …onstant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %… in CreateTestCases()
1573 …= f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Id… in CreateTestCases()
H A Dcopy_insertion.cc1796 for (HloComputation* computation : conditional->branch_computations()) { in AddCopiesForConditional()
H A Dconditional_code_motion.cc578 absl::MakeSpan(conditional->branch_computations()), in ConvertSpecialMove()
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/tests/translate/
H A Dcase.mlir43 …f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_…
90 …f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_…
150 …f32[] %[[OPERAND_1]], (f32[], f32[]) %[[TUPLE1]], () %[[TUPLE2]]), branch_computations={%[[NEGATE_…
H A Dcase_conditional.hlotxt25 …onstant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %…
90 …(f32[]) %tuple.1, ((f32[]), (f32[],f32[])) %tuple.3, () %tuple.4), branch_computations={%Negate1, …
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/
H A Dno_opt_ops.hlotxt39 …onstant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %…
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/
H A Dxla_builder.h863 absl::Span<const XlaComputation* const> branch_computations,
1462 absl::Span<const XlaComputation* const> branch_computations,
1466 absl::Span<const XlaComputation* const> branch_computations,
1523 absl::Span<const XlaComputation* const> branch_computations,
2586 absl::Span<const XlaComputation* const> branch_computations,
H A Dxla_builder.cc2518 absl::Span<const XlaComputation* const> branch_computations, in Conditional() argument
2528 return ConditionalImpl(branch_index, branch_computations, branch_operands); in Conditional()
2534 absl::Span<const XlaComputation* const> branch_computations, in ConditionalImpl() argument
2543 branch_computations.size()); in ConditionalImpl()
2548 branch_computations[j]->GetProgramShape()); in ConditionalImpl()
2556 for (const XlaComputation* branch_computation : branch_computations) { in ConditionalImpl()
4887 absl::Span<const XlaComputation* const> branch_computations, in Conditional() argument
4889 return branch_index.builder()->Conditional(branch_index, branch_computations, in Conditional()
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/xla_extension/
H A Dops.pyi141 branch_computations: Sequence[XlaComputation],
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/g3doc/
H A Doperation_semantics.md718 <b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
725 | `branch_computations` | sequence of N | XlaComputations of type \\( |
733 Executes `branch_computations[branch_index]`, and returns the result. If
734 `branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
737 Each `branch_computations[b]` must take in a single argument of type `T_b` and
739 type of the returned value of each `branch_computations[b]` must be the same.
741 Note that only one of the `branch_computations` will be executed depending on
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/
H A Dhlo_function_importer.cc948 llvm::enumerate(instruction->branch_computations())) { in ImportInstructionImpl()