/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"

#include <memory>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"

namespace xla {
namespace cpu {

class SimpleCostModel : public ParallelCostModel {
 public:
  SimpleCostModel(const int64_t max_parallelism,
                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
  ~SimpleCostModel() override {}

  int64_t GetParallelTaskCount(HloInstruction* instruction) override {
    // Simple cost model based on hlo size and typical L2 cache size.
    const int64_t instruction_cost = shape_size_(instruction->shape());
    const int64_t min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
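    // Worked example (illustrative, not from the original source): an
    // f32[1024,1024] result occupies 4 MiB, so instruction_cost /
    // min_cost_per_thread = 4 MiB / 256 KiB = 16, and the model requests
    // min(max_parallelism_, 16) tasks.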
    // Return target parallel task count in [1, max_parallelism_].
    return std::min(
        max_parallelism_,
        std::max(int64_t{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64_t max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
};

class DefaultCostModel : public ParallelCostModel {
 public:
  DefaultCostModel(const int64_t max_parallelism,
                   const HloCostAnalysis::ShapeSizeFunction& shape_size,
                   std::unique_ptr<HloCostAnalysis> cost_analysis)
      : max_parallelism_(max_parallelism),
        shape_size_(shape_size),
        cost_analysis_(std::move(cost_analysis)) {}
  ~DefaultCostModel() override {}

  int64_t GetParallelTaskCount(HloInstruction* instruction) override {
    // Parameters for parallel task count computation.
    int64_t instruction_cost;
    int64_t min_cost_per_thread;
    int64_t max_parallelism;
    // Calculate flops-to-bytes-ratio for 'instruction'.
    const int64_t bytes_accessed =
        std::max(int64_t{1}, cost_analysis_->bytes_accessed(*instruction));
    const float flops_to_bytes_ratio =
        cost_analysis_->flop_count(*instruction) /
        static_cast<float>(bytes_accessed);
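    // Worked example (illustrative, not from the original source): for an
    // elementwise f32 add over N elements, the analysis reports roughly N
    // flops against ~12*N bytes accessed (two operand reads plus one result
    // write), so flops_to_bytes_ratio ~= 0.08 and the instruction is treated
    // as I/O bound below.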
    // Check for I/O bound instructions.
    if (flops_to_bytes_ratio <= 1.0) {
      // Limit max parallelism for I/O bound instructions by assuming a
      // sub-linear scaling function (fit based on empirical benchmark results).
      // TODO(b/29630486) Develop system bandwidth model.
      max_parallelism = std::min<int64_t>(
          max_parallelism_,
          std::ceil(std::sqrt(tensorflow::port::MaxParallelism())));
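      // For instance, on a machine where MaxParallelism() reports 16 threads,
      // ceil(sqrt(16)) = 4, so an I/O bound instruction gets at most
      // min(max_parallelism_, 4) tasks.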
      // Use the shape size as the instruction cost and the L2 cache size as
      // the minimum per-thread cost.
      instruction_cost = shape_size_(instruction->shape());
      min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
    } else {
      // Use max parallelism for compute bound instructions.
      max_parallelism = max_parallelism_;
      // Calculate the instruction cost in cycles.
      // TODO(b/29630486) Improve on this linear cost model.
      // Consider making 'min_cost_per_thread' be a function of the target
      // bandwidth limit for instructions with low arithmetic complexity.
      instruction_cost =
          1 * cost_analysis_->flop_count(*instruction) +
          2 * cost_analysis_->transcendental_count(*instruction) +
          10 * cost_analysis_->bytes_accessed(*instruction);
      // Minimum per-thread cost is 100,000 cycles (roughly 50us of work on a
      // 2GHz core).
      min_cost_per_thread = 100000;
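      // Worked example under assumed analysis results (not from the original
      // source): if cost_analysis_ reports ~1.7e7 flops and ~8.4e6 bytes
      // accessed (e.g. a fused loop doing 16 flops per f32 element over a
      // 1M-element array), then flops_to_bytes_ratio ~= 2, instruction_cost
      // ~= 1.7e7 + 10 * 8.4e6 ~= 1.0e8 cycles, and roughly 1000 tasks are
      // requested before clamping to max_parallelism.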
    }
    // Return target parallel task count in [1, max_parallelism_].
    return std::min(
        max_parallelism,
        std::max(int64_t{1}, instruction_cost / min_cost_per_thread));
  }

 private:
  const int64_t max_parallelism_;
  const HloCostAnalysis::ShapeSizeFunction shape_size_;
  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};

ParallelTaskAssignment::ParallelTaskAssignment(
    const int64_t max_parallelism,
    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
    const TargetMachineFeatures* target_machine_features)
    : target_machine_features_(*target_machine_features) {
  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
  // Run cost analysis on 'module'.
  auto cost_analysis = std::make_unique<HloCostAnalysis>(shape_size);
  HloComputation* computation = module->entry_computation();
  Status status = computation->root_instruction()->Accept(cost_analysis.get());
  if (status.ok()) {
    // Set default cost model based on 'cost_analysis'.
    cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
                                           std::move(cost_analysis)));
  } else {
    // Fall back to a simple cost model based on hlo size and L2 cache size.
    // Note that HloCostAnalysis can return an error status (likely because
    // HLOs like CustomCall are not yet implemented in HloCostAnalysis).
    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
  }
}

int64_t ParallelTaskAssignment::GetTargetParallelTaskCount(
    HloInstruction* instruction) {
  // Currently, we do not assign parallel tasks to instructions with at least
  // one of the following properties:
  // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
  // *) Custom loop emission (kSelectAndScatter).
  // *) Operations that are not thread safe (like infeed and rng).
  // *) Tuple-shaped output.
  // *) Operations that might be implemented as an in-place
  //    dynamic-update-slice, because we can't know how many output elements
  //    they will write (out-of-place will touch the whole output buffer, while
  //    in-place will only touch the updated elements).
  // TODO(b/27458679) Parallelize instructions which are skipped here.
  auto opcode = instruction->opcode();
  if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
      instruction->shape().IsTuple() || opcode == HloOpcode::kRng ||
      opcode == HloOpcode::kConstant) {
    return 1;
  }

  // Only allow instructions that can be trivially parallelized (where all
  // outputs can be computed independently of each other).
  if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
      opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
      opcode == HloOpcode::kDynamicSlice ||
      opcode == HloOpcode::kDynamicUpdateSlice ||
      opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
      opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
      opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
      opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
      opcode == HloOpcode::kTranspose ||
      (opcode == HloOpcode::kConvolution &&
       !PotentiallyImplementedAsEigenConvolution(*instruction,
                                                 target_machine_features_))) {
    // Consult 'cost_model_' to compute target parallel task count.
    return cost_model_->GetParallelTaskCount(instruction);
  }

  return 1;
}

StatusOr<bool> ParallelTaskAssigner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
  XLA_VLOG_LINES(3, module->ToString());
  // Compute target parallel task counts for all instructions in 'module'.
  HloToParallelTasks hlo_to_parallel_tasks;
  ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);

  // Assign parallel tasks to target specific instructions in 'module'.
  // TODO(b/27458679) Support inter-op parallelism.
  bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);

  XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
  XLA_VLOG_LINES(3, module->ToString());
  return changed;
}

bool ParallelTaskAssigner::AssignParallelTasks(
    HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
  return AssignParallelTasksHelper(module, module->entry_computation(),
                                   hlo_to_parallel_tasks);
}

bool ParallelTaskAssigner::AssignParallelTasksHelper(
    HloModule* module, HloComputation* computation,
    const HloToParallelTasks& hlo_to_parallel_tasks) {
  bool changed = false;
  // Snapshot the set of instructions because outlining modifies it below.
  std::vector<HloInstruction*> instructions(computation->instructions().begin(),
                                            computation->instructions().end());
  for (auto* instruction : instructions) {
    // Assign parallel tasks to sub-computations for While and Call HLOs.
    // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
    // and support other callable computations like reduce.
    if (instruction->opcode() == HloOpcode::kWhile) {
      changed |= AssignParallelTasksHelper(module, instruction->while_body(),
                                           hlo_to_parallel_tasks);
      continue;
    } else if (instruction->opcode() == HloOpcode::kCall) {
      changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
                                           hlo_to_parallel_tasks);
      continue;
    }
    // Skip if no parallel tasks were computed in the first pass.
    auto it = hlo_to_parallel_tasks.find(instruction);
    if (it == hlo_to_parallel_tasks.end()) {
      continue;
    }
    // Get the target parallel task count computed for 'instruction'.
    const int64_t target_parallel_task_count = (*it).second;
    // Assign feasible dimension partitions (based on actual dimension sizes).
    auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
                                    .Run(target_parallel_task_count);
    const int64_t total_partition_count =
        ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
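    // For example, with a target of 8 tasks the assigner might return counts
    // such as {2, 4} for the two outer-most dimensions (dimension 0 split in
    // two, dimension 1 in four), for a total of 2 * 4 = 8 partitions. The
    // actual counts depend on the dimension sizes (illustrative values only).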
    if (total_partition_count <= 1) {
      // The feasible-partition calculation resulted in no partitioning, so
      // skip this instruction.
      continue;
    }

    // Outline 'instruction' in 'computation' for parallel task assignment.
    auto* call = module->OutlineExpressionFromComputation(
        {instruction}, absl::StrCat("parallel_", instruction->name()),
        computation);
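    // OutlineExpressionFromComputation replaces 'instruction' in
    // 'computation' with a kCall to a newly created computation (named
    // "parallel_<instruction name>") whose root is the outlined copy of
    // 'instruction'.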

    // Set the assigned dimension partitioning on the outlined instruction.
    auto* new_root = call->to_apply()->root_instruction();
    new_root->set_outer_dimension_partitions(dim_partition_counts);

    VLOG(2) << "Assigned parallel task count: " << total_partition_count
            << " to instruction: " << new_root->name()
            << " parent: " << new_root->parent()->name();
    changed = true;
  }
  return changed;
}

void ParallelTaskAssigner::ComputeTargetParallelTasks(
    HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
  ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
                                                  shape_size_function_, module,
                                                  &target_machine_features_);

  // Compute parallel task counts for all instructions in 'module'.
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Query ParallelTaskAssignment for the target parallel task count.
      const int64_t target_parallel_task_count =
          parallel_task_assignment.GetTargetParallelTaskCount(instruction);
      if (target_parallel_task_count > 1) {
        hlo_to_parallel_tasks->insert(
            {instruction, target_parallel_task_count});
      }
    }
  }
}

}  // namespace cpu
}  // namespace xla