/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"

namespace xla {
namespace cpu {

class SimpleCostModel : public ParallelCostModel {
 public:
  SimpleCostModel(const int64_t max_parallelism,
                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
  ~SimpleCostModel() override {}

  int64_t GetParallelTaskCount(HloInstruction* instruction) override {
    // Simple cost model based on hlo size and typical L2 cache size.
    const int64_t instruction_cost = shape_size_(instruction->shape());
    const int64_t min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
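    // For example, a 1 MiB output shape with this 256 KiB per-thread floor
    // yields a target of 4 tasks, before clamping to [1, max_parallelism_].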
    // Return target parallel task count in [1, max_parallelism_].
    return std::min(
        max_parallelism_,
        std::max(int64_t{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64_t max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
};

class DefaultCostModel : public ParallelCostModel {
 public:
  DefaultCostModel(const int64_t max_parallelism,
                   const HloCostAnalysis::ShapeSizeFunction& shape_size,
                   std::unique_ptr<HloCostAnalysis> cost_analysis)
      : max_parallelism_(max_parallelism),
        shape_size_(shape_size),
        cost_analysis_(std::move(cost_analysis)) {}
  ~DefaultCostModel() override {}

  int64_t GetParallelTaskCount(HloInstruction* instruction) override {
    // Parameters for parallel task count computation.
    int64_t instruction_cost;
    int64_t min_cost_per_thread;
    int64_t max_parallelism;
    // Calculate flops-to-bytes-ratio for 'instruction'.
    const int64_t bytes_accessed =
        std::max(int64_t{1}, cost_analysis_->bytes_accessed(*instruction));
    const float flops_to_bytes_ratio =
        cost_analysis_->flop_count(*instruction) /
        static_cast<float>(bytes_accessed);
    // Check for I/O bound instructions.
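    // A ratio of at most 1.0 means the instruction performs no more than one
    // flop per byte it accesses, so memory bandwidth (rather than compute) is
    // the likely bottleneck.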
    if (flops_to_bytes_ratio <= 1.0) {
      // Limit max parallelism for I/O bound instructions by assuming a
      // sub-linear scaling function (fit based on empirical benchmark results).
      // TODO(b/29630486) Develop system bandwidth model.
      max_parallelism = std::min<int64_t>(
          max_parallelism_,
          std::ceil(std::sqrt(tensorflow::port::MaxParallelism())));
      // Use the shape size as the instruction cost and the L2 cache size as
      // the minimum per-thread cost.
      instruction_cost = shape_size_(instruction->shape());
      min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
    } else {
      // Use max parallelism for compute bound instructions.
      max_parallelism = max_parallelism_;
      // Calculate the instruction cost in cycles.
      // TODO(b/29630486) Improve on this linear cost model.
      // Consider making 'min_cost_per_thread' be a function of the target
      // bandwidth limit for instructions with low arithmetic complexity.
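      // The weights below are rough per-operation cycle estimates: 1 cycle per
      // flop, 2 per transcendental op, and 10 per byte accessed.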
      instruction_cost =
          1 * cost_analysis_->flop_count(*instruction) +
          2 * cost_analysis_->transcendental_count(*instruction) +
          10 * cost_analysis_->bytes_accessed(*instruction);
      // Minimum per-thread cost is 100,000 cycles (roughly 50us of work on a
      // 2GHz core).
      min_cost_per_thread = 100000;
    }
    // Return target parallel task count in [1, max_parallelism].
    return std::min(
        max_parallelism,
        std::max(int64_t{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64_t max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};

ParallelTaskAssignment::ParallelTaskAssignment(
    const int64_t max_parallelism,
    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
    const TargetMachineFeatures* target_machine_features)
    : target_machine_features_(*target_machine_features) {
  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
  // Run cost analysis on 'module'.
  auto cost_analysis = std::make_unique<HloCostAnalysis>(shape_size);
  HloComputation* computation = module->entry_computation();
  Status status = computation->root_instruction()->Accept(cost_analysis.get());
  if (status.ok()) {
    // Set default cost model based on 'cost_analysis'.
    cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
                                           std::move(cost_analysis)));
  } else {
    // Fall back to a simple cost model based on hlo size and L2 cache size.
    // Note that HloCostAnalysis can return an error status (likely because
    // HLOs like CustomCall are not yet implemented in HloCostAnalysis).
    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
  }
}

int64_t ParallelTaskAssignment::GetTargetParallelTaskCount(
    HloInstruction* instruction) {
  // Currently, we do not assign parallel tasks to instructions with at least
  // one of the following properties:
  // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
  // *) Emit custom loops (kSelectAndScatter).
  // *) Operations that are not thread safe (like infeed and rng).
  // *) Tuple-shaped.
  // *) Operations that might be implemented as an in-place
  //    dynamic-update-slice, because we can't know how many output elements
  //    they will write (out-of-place will touch the whole output buffer, while
  //    in-place will only touch the updated elements).
  // TODO(b/27458679) Parallelize instructions which are skipped here.
  auto opcode = instruction->opcode();
  if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
      instruction->shape().IsTuple() || opcode == HloOpcode::kRng ||
      opcode == HloOpcode::kConstant) {
    return 1;
  }

  // Only allow instructions that can be trivially parallelized (where all
  // outputs can be computed independently of each other).
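  // kConvolution is only included when it will not be lowered to an Eigen
  // convolution call, since the Eigen path uses its own internal threading
  // (see the "Internal threading" bullet above).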
  if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
      opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
      opcode == HloOpcode::kDynamicSlice ||
      opcode == HloOpcode::kDynamicUpdateSlice ||
      opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
      opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
      opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
      opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
      opcode == HloOpcode::kTranspose ||
      (opcode == HloOpcode::kConvolution &&
       !PotentiallyImplementedAsEigenConvolution(*instruction,
                                                 target_machine_features_))) {
    // Consult 'cost_model_' to compute target parallel task count.
    return cost_model_->GetParallelTaskCount(instruction);
  }

  return 1;
}

StatusOr<bool> ParallelTaskAssigner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
  XLA_VLOG_LINES(3, module->ToString());
  // Compute target parallel task counts for all instructions in 'module'.
  HloToParallelTasks hlo_to_parallel_tasks;
  ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);

  // Assign parallel tasks to target specific instructions in 'module'.
  // TODO(b/27458679) Support inter-op parallelism.
  bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);

  XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
  XLA_VLOG_LINES(3, module->ToString());
  return changed;
}

bool ParallelTaskAssigner::AssignParallelTasks(
    HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
  return AssignParallelTasksHelper(module, module->entry_computation(),
                                   hlo_to_parallel_tasks);
}

bool ParallelTaskAssigner::AssignParallelTasksHelper(
    HloModule* module, HloComputation* computation,
    const HloToParallelTasks& hlo_to_parallel_tasks) {
  bool changed = false;
  // Snapshot set of instructions because outlining modifies the set below.
  std::vector<HloInstruction*> instructions(computation->instructions().begin(),
                                            computation->instructions().end());
  for (auto* instruction : instructions) {
    // Assign parallel tasks to sub-computations for While and Call HLOs.
    // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
    // and support other callable computations like reduce.
    if (instruction->opcode() == HloOpcode::kWhile) {
      changed |= AssignParallelTasksHelper(module, instruction->while_body(),
                                           hlo_to_parallel_tasks);
      continue;
    } else if (instruction->opcode() == HloOpcode::kCall) {
      changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
                                           hlo_to_parallel_tasks);
      continue;
    }
    // Skip if no parallel tasks were computed in the first pass.
    auto it = hlo_to_parallel_tasks.find(instruction);
    if (it == hlo_to_parallel_tasks.end()) {
      continue;
    }
    // Get target parallel task count computed for 'instruction'.
    const int64_t target_parallel_task_count = (*it).second;
    // Assign feasible dimension partitions (based on actual dimension sizes).
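    // (ShapePartitionAssigner chooses per-dimension partition counts over the
    // shape's outermost dimensions whose product approximates the requested
    // task count.)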
    auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
                                    .Run(target_parallel_task_count);
    const int64_t total_partition_count =
        ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
    if (total_partition_count <= 1) {
      // The feasible partition calculation resulted in no partitioning, so
      // skip this instruction.
      continue;
    }

    // Outline 'instruction' in 'computation' for parallel task assignment.
    auto* call = module->OutlineExpressionFromComputation(
        {instruction}, absl::StrCat("parallel_", instruction->name()),
        computation);

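    // The outlined instruction is now the root of the new call's computation;
    // the partition counts recorded on that root are consumed later when the
    // backend emits parallel loops for the outlined computation.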
    // Set the assigned dimension partitioning on the outlined instruction.
    auto* new_root = call->to_apply()->root_instruction();
    new_root->set_outer_dimension_partitions(dim_partition_counts);

    VLOG(2) << "Assigned parallel task count: " << total_partition_count
            << " to instruction: " << new_root->name()
            << " parent: " << new_root->parent()->name();
    changed = true;
  }
  return changed;
}

void ParallelTaskAssigner::ComputeTargetParallelTasks(
    HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
  ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
                                                  shape_size_function_, module,
                                                  &target_machine_features_);

  // Compute parallel task counts for all instructions in 'module'.
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Query ParallelTaskAssignment for target parallel task count.
      const int64_t target_parallel_task_count =
          parallel_task_assignment.GetTargetParallelTaskCount(instruction);
      if (target_parallel_task_count > 1) {
        hlo_to_parallel_tasks->insert(
            {instruction, target_parallel_task_count});
      }
    }
  }
}

}  // namespace cpu
}  // namespace xla