1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file should not have any dependencies apart from the standard library,
17 // as it will be used in OSS outside of this repository.
18
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <random>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
33
// Seed for the pseudo-random generators that fill input parameters, so that
// every run of the driver feeds the module identical data.
static constexpr int kSeed = 42;
// Inclusive bounds for randomly generated integer inputs.
static constexpr int kUpperBound = 100;
static constexpr int kLowerBound = -100;
// Bounds for randomly generated floating-point inputs.
static constexpr double kLowerBoundFP = -0.1;
static constexpr double kUpperBoundFP = 0.1;
// Printed by `--help`. The results are written to stdout (std::cout in main);
// logging enabled by VERBOSE goes to stderr.
static const char* const kUsageString = R"(
Driver for executing an HLO reproducer in object form in order to let OSS
users reproduce the miscompiles.

Expected workflow:

1) In the .hlo file, rename the root computation to `EntryModule`.
2) Run the .hlo file with XLA_FLAGS=--xla_dump_to set, to obtain the .ll file.
3) Compile the .ll file from step (2) to an object file, then compile this
file and link it with that object file.
4) Run the resulting file with the buffer assignment table as an argument,
taken from step 2. The driver will print the output to stdout.
5) Compare the outputs obtained with the optimized and non-optimized .ll files
from step (2). If the outputs differ, there is a miscompile.

Run with an environment variable VERBOSE set to see logging.
)";
55
// Must be kept ABI-compatible with the real definition of XlaCustomCallStatus.
// Only the memory layout matters here: main() passes a pointer to an instance
// into EntryModule, and __xla_cpu_runtime_StatusIsSuccess reads `failed` back.
// Do not reorder, add, or remove fields.
//
// NOTE: `failed` has no default initializer, so a default-initialized
// instance holds an indeterminate value; value-initialize it (e.g. `{}`)
// before handing it to the module.
struct XlaCustomCallStatus {
  // If 'failed' is true then 'message' is present; otherwise it is absent.
  // (The 'bool' followed by 'std::string' is ABI-compatible with
  // 'std::optional<std::string>').
  bool failed;
  std::string message;
  // To account for extra struct padding at the end.
  std::string padding;
};
66
67 extern "C" {
68 // Function to be linked with.
69 extern void EntryModule(char* result_buffer, char* run_opts, char** params,
70 char** buffer_table, void* status,
71 int64_t* prof_counters);
72
73 // Must be kept in sync with the real definition of this runtime function.
__xla_cpu_runtime_StatusIsSuccess(const XlaCustomCallStatus * status)74 bool __xla_cpu_runtime_StatusIsSuccess( // NOLINT: This doesn't need a
75 // prototype.
76 const XlaCustomCallStatus* status) {
77 return !(status->failed);
78 }
79 }
80
81 namespace {
82
// Writes `msg` plus a newline to stderr, flushes, and terminates the process
// with exit status 1. Never returns to the caller.
[[noreturn]] void ExitWithMsg(const std::string& msg) {
  (std::cerr << msg).put('\n').flush();
  std::exit(1);
}
87
Check(bool cond,const std::string & msg="Precondition failed")88 void Check(bool cond, const std::string& msg = "Precondition failed") {
89 if (!cond) {
90 ExitWithMsg(msg);
91 }
92 }
93
// Verbose logging is on whenever the VERBOSE environment variable is set,
// whatever its value (even an empty string).
bool IsVerbose() {
  const char* const verbose_env = std::getenv("VERBOSE");
  return verbose_env != nullptr;
}
95
Log(const std::string & msg)96 void Log(const std::string& msg) {
97 if (IsVerbose()) {
98 std::cerr << msg << std::endl;
99 }
100 }
101
// Needs to be kept in sync with PrimitiveType in xla_data.proto.
// NOTE: the numeric values of these enumerators are local. ToString() and
// PrimitiveTypeFromString() use them as indices into primitive_strings(), so
// the enumerator order must match the order of that list exactly.
// NOTE(review): the values presumably do not equal the proto's numeric field
// values; only the type names are meant to correspond.
enum PrimitiveType {
  S16 = 0,
  S32,
  S64,
  U8,
  U16,
  U32,
  U64,
  F16,
  BF16,
  F32,
  F64,
  C64,
  C128
};
118
// Lowercase names of the supported element types, indexed by PrimitiveType.
// The list is heap-allocated once and never freed, presumably so that it stays
// valid even during static destruction.
const std::vector<std::string>& primitive_strings() {
  static const std::vector<std::string>* const kNames =
      new std::vector<std::string>{"s16", "s32", "s64", "u8",   "u16",
                                   "u32", "u64", "f16", "bf16", "f32",
                                   "f64", "c64", "c128"};
  return *kNames;
}
125
ToString(PrimitiveType type)126 std::string ToString(PrimitiveType type) { return primitive_strings()[type]; }
127
PrimitiveTypeFromString(const std::string & s)128 PrimitiveType PrimitiveTypeFromString(const std::string& s) {
129 const auto& vec = primitive_strings();
130 return static_cast<PrimitiveType>(
131 std::distance(vec.begin(), std::find(vec.begin(), vec.end(), s)));
132 }
133
ByteSize(PrimitiveType type)134 int ByteSize(PrimitiveType type) {
135 std::string s = ToString(type);
136 s = s.substr(1, s.size());
137 return std::stoi(s) / 8;
138 }
139
// Shape of a dense array: element type plus dimension sizes, e.g. f32[10,20].
// A scalar such as "f32[]" has an empty `dimensions` vector. No layout is
// stored. Field order matters, because instances are built with aggregate
// initialization ({type, dimensions}).
struct ArrayShape {
  PrimitiveType type;
  std::vector<int> dimensions;
};
144
// We support tuples only for output, and we do not support nested tuples.
// A plain (non-tuple) array shape is represented as a TupleShape with exactly
// one element; TupleShapeFromString produces this form and
// TupleShapeToString prints it as a bare array shape.
struct TupleShape {
  std::vector<ArrayShape> elements;
};
149
ArrayShapeToString(ArrayShape shape)150 std::string ArrayShapeToString(ArrayShape shape) {
151 std::ostringstream out;
152 out << ToString(shape.type) << "[";
153 for (int i = 0; i < shape.dimensions.size(); i++) {
154 out << std::to_string(shape.dimensions[i]);
155 if (i != shape.dimensions.size() - 1) {
156 out << ",";
157 }
158 }
159 out << "]";
160 return out.str();
161 }
162
163 // Input: TYPE[D1,D2,...DN]
ArrayShapeFromString(const std::string & s)164 ArrayShape ArrayShapeFromString(const std::string& s) {
165 Log("Array shape from string: " + s);
166 Check(s.find('(') == std::string::npos, "Tuple shape is not supported");
167 std::regex shape_r("([^\\[]+)\\[(.*)\\]");
168 std::smatch match;
169 Check(std::regex_match(s, match, shape_r), "Shape not found");
170 std::string type = match[1];
171 std::string dims = match[2];
172 PrimitiveType ptype = PrimitiveTypeFromString(type);
173 std::istringstream dims_stream(dims);
174 std::string dim;
175 std::vector<int> dimensions;
176 while (std::getline(dims_stream, dim, ',')) {
177 dimensions.push_back(std::stoi(dim));
178 }
179 return {ptype, dimensions};
180 }
181
182 // E.g. (f32[10,20], u32[])
TupleShapeFromString(std::string s)183 TupleShape TupleShapeFromString(std::string s) {
184 Log("Tuple shape from string: " + s);
185 if (s[0] != '(') {
186 return {{ArrayShapeFromString(s)}};
187 }
188 s = s.substr(1, s.size() - 2);
189 std::istringstream sstream(s);
190 std::string subshape;
191 std::vector<ArrayShape> out;
192 while (std::getline(sstream, subshape, ' ')) {
193 if (subshape[subshape.size() - 1] == ',') {
194 subshape = subshape.substr(0, subshape.size() - 1);
195 }
196 out.push_back(ArrayShapeFromString(subshape));
197 }
198 return {out};
199 }
200
TupleShapeToString(TupleShape shape)201 std::string TupleShapeToString(TupleShape shape) {
202 std::ostringstream out;
203 if (shape.elements.size() == 1) {
204 return ArrayShapeToString(shape.elements[0]);
205 }
206 out << "(";
207 for (int idx = 0; idx < shape.elements.size(); idx++) {
208 out << ArrayShapeToString(shape.elements[idx]);
209 if (idx != shape.elements.size() - 1) {
210 out << ", ";
211 }
212 }
213 out << ")";
214 return out.str();
215 }
216
// Information about the buffer assignment, as produced by
// ParseBufferAssignment.
struct BufferAssignment {
  // Mapping from allocation index to buffer size (in bytes).
  // Allocation indices are dense and start at 0 (the parser enforces this), so
  // the vector position is the allocation index.
  std::vector<int> buffers_size;

  // Mapping from allocation index to its shape.
  // Only allocations that hold a parameter or the output have an entry.
  std::map<int, TupleShape> buffers_shape;

  // Mapping from param index to allocation index.
  std::map<int, int> param_to_alloc_idx;

  // Allocation index of the output buffer; -1 until the output is found.
  int output_idx = -1;
};
231
BufferAssignmentToString(const BufferAssignment & assignment)232 std::string BufferAssignmentToString(const BufferAssignment& assignment) {
233 std::ostringstream out;
234 for (const auto& p : assignment.param_to_alloc_idx) {
235 int param_idx = p.first;
236 int allocation_idx = p.second;
237 out << "Param: " << param_idx << " (allocation " << allocation_idx << "): ";
238 auto p2 = assignment.buffers_shape.find(allocation_idx);
239 Check(p2 != assignment.buffers_shape.end(),
240 "Shape not found for parameter: " + std::to_string(param_idx));
241 out << TupleShapeToString(p2->second)
242 << ", size = " << assignment.buffers_size[allocation_idx] << "\n";
243 }
244 return out.str();
245 }
246
247 // RAII table for the given assignment: mapping from a allocation idx to the
248 // actual allocation.
249 class BufferTable {
250 public:
BufferTable(BufferAssignment assignment)251 explicit BufferTable(BufferAssignment assignment) : assignment_(assignment) {
252 int num_buffers = assignment.buffers_size.size();
253 ptr_ = new char*[num_buffers];
254 for (int buffer_idx = 0; buffer_idx < num_buffers; buffer_idx++) {
255 // Call malloc to ensure alignment up to std::max_align_t.
256 ptr_[buffer_idx] =
257 static_cast<char*>(malloc(assignment.buffers_size[buffer_idx]));
258 }
259 }
260
AsPtr()261 char** AsPtr() { return ptr_; }
262
~BufferTable()263 ~BufferTable() {
264 int num_buffers = assignment_.buffers_size.size();
265 for (int buffer_idx = 0; buffer_idx < num_buffers; buffer_idx++) {
266 free(ptr_[buffer_idx]);
267 }
268 delete[] ptr_;
269 }
270
271 private:
272 BufferAssignment assignment_;
273 char** ptr_;
274 };
275
276 // Parse and populate the buffer table;
277 //
278 // Example of input:
279 //
280 // BufferAssignment:
281 // allocation 0: 0x27017c46b600, size 32768, parameter 0, shape f32[256,32] at
282 // ShapeIndex {}:
283 // value: <3 parameter @0> (size=32768,offset=0): f32[256,32]{1,0}
284 // allocation 1: 0x27017c46b6b0, size 128, output shape is f32[32],
285 // maybe-live-out:
286 // value: <5 reduce @0> (size=128,offset=0): f32[32]{0}
287 // allocation 2: 0x27017c46b760, size 4, constant:
288 // value: <4 init_value @0> (size=4,offset=0): f32[]
289 // allocation 3: 0x27017c46b810, size 4, thread-local:
290 // value: <0 x.1 @0> (size=4,offset=0): f32[]
291 // allocation 4: 0x27017c46b8c0, size 4, thread-local:
292 // value: <1 y.1 @0> (size=4,offset=0): f32[]
293 // allocation 5: 0x27017c46b970, size 4, output shape is f32[], thread-local:
294 // value: <2 add.1 @0> (size=4,offset=0): f32[]
ParseBufferAssignment(const std::string & fname)295 BufferAssignment ParseBufferAssignment(const std::string& fname) {
296 BufferAssignment assignment;
297 std::ifstream infile(fname);
298 std::string line;
299 while (std::getline(infile, line)) {
300 std::regex allocation_line_r(
301 "allocation ([0-9]+): .+, size ([0-9]+), (.+)");
302 std::smatch match;
303 if (std::regex_search(line, match, allocation_line_r)) {
304 Log("Matched allocation description: " + line);
305 int allocation_idx = std::stoi(match[1]);
306 int size = std::stoi(match[2]);
307 Log("Allocation size = " + std::to_string(size));
308 const std::string& postfix = match[3];
309 Check(allocation_idx == assignment.buffers_size.size(),
310 "Unordered allocations in input");
311 assignment.buffers_size.push_back(size);
312
313 std::regex output_r("output shape is \\|([^\\|]+)\\|,");
314 std::smatch output_match;
315 if (std::regex_search(postfix, output_match, output_r)) {
316 Log("Matched out parameter: " + postfix);
317 Check(assignment.output_idx == -1, "Multiple out-parameters");
318 assignment.output_idx = allocation_idx;
319 std::string output_shape = output_match[1];
320 Log("output shape = " + output_shape);
321 TupleShape shape = TupleShapeFromString(output_shape);
322 assignment.buffers_shape[allocation_idx] = shape;
323 Log("parsed output shape = " + TupleShapeToString(shape));
324 }
325
326 std::regex parameter_r("parameter ([0-9]+), shape \\|([^\\|]+)\\|");
327 std::smatch param_match;
328 if (std::regex_search(postfix, param_match, parameter_r)) {
329 Log("Matched parameter description: " + postfix);
330 int param_idx = std::stoi(param_match[1]);
331 assignment.param_to_alloc_idx[param_idx] = allocation_idx;
332 std::string param_shape = param_match[2];
333 TupleShape shape = TupleShapeFromString(param_shape);
334 assignment.buffers_shape[allocation_idx] = shape;
335 Log("parsed parameter shape for param " + std::to_string(param_idx) +
336 " = " + TupleShapeToString(shape));
337 }
338 }
339 }
340 Check(assignment.output_idx != -1, "Output not set");
341 return assignment;
342 }
343
GetNumElements(const ArrayShape & shape)344 int GetNumElements(const ArrayShape& shape) {
345 int num_elements = 1;
346 for (int dim : shape.dimensions) {
347 num_elements *= dim;
348 }
349 return num_elements;
350 }
351
352 template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
FillIntT(void * buffer,int num_elements)353 void FillIntT(void* buffer, int num_elements) {
354 std::mt19937 generator(kSeed);
355 T* casted = static_cast<T*>(buffer);
356 std::uniform_int_distribution<> distr(kLowerBound, kUpperBound);
357 for (int i = 0; i < num_elements; i++) {
358 casted[i] = static_cast<T>(distr(generator));
359 }
360 }
361
362 template <typename T,
363 typename = std::enable_if_t<std::is_floating_point<T>::value>>
FillFloatT(void * buffer,int num_elements)364 void FillFloatT(void* buffer, int num_elements) {
365 std::mt19937 generator(kSeed);
366 T* casted = static_cast<T*>(buffer);
367 std::uniform_real_distribution<T> distr(kLowerBoundFP, kUpperBoundFP);
368 for (int i = 0; i < num_elements; i++) {
369 casted[i] = distr(generator);
370 }
371 }
372
Fill(void * buffer,const ArrayShape & shape)373 void Fill(void* buffer, const ArrayShape& shape) {
374 int num_elements = GetNumElements(shape);
375 Log("Number of elements = " + std::to_string(num_elements));
376 Log("Shape type = " + ToString(shape.type) +
377 ", shape = " + ArrayShapeToString(shape));
378 switch (shape.type) {
379 case S16:
380 return FillIntT<short>(buffer, num_elements); // NOLINT
381 case S32:
382 return FillIntT<int>(buffer, num_elements);
383 case S64:
384 return FillIntT<long long>(buffer, num_elements); // NOLINT
385 case U8:
386 return FillIntT<unsigned char>(buffer, num_elements);
387 case U16:
388 return FillIntT<unsigned short>(buffer, num_elements); // NOLINT
389 case U32:
390 return FillIntT<unsigned int>(buffer, num_elements);
391 case U64:
392 return FillIntT<unsigned long long>(buffer, num_elements); // NOLINT
393 case F32:
394 return FillFloatT<float>(buffer, num_elements);
395 case F64:
396 return FillFloatT<double>(buffer, num_elements);
397
398 case F16:
399 case BF16:
400 case C64:
401 case C128:
402 ExitWithMsg("Unsupported type: " + ToString(shape.type));
403 }
404 }
405
// Prints `num_elements` values of type T, read from `buffer`, to stdout. The
// values are separated by ", " and the line ends with a newline. With zero
// elements, only the newline is printed.
template <typename T>
#if defined(MEMORY_SANITIZER)
// NOTE(review): presumably disabled because the buffers are written by the
// separately compiled, uninstrumented module, which MSan cannot see.
__attribute__((no_sanitize_memory))
#endif
void DisplayT(const void* buffer, int num_elements) {
  const T* casted = static_cast<const T*>(buffer);
  for (int i = 0; i < num_elements; i++) {
    // Unary + applies integral promotion. Without it, operator<< prints
    // `unsigned char` (used for U8) as a raw character rather than a number.
    // Wider integer and floating-point types print unchanged.
    std::cout << +casted[i];
    if (i != num_elements - 1) {
      std::cout << ", ";
    }
  }
  std::cout << std::endl;
}
420
Display(const void * buffer,const ArrayShape & shape)421 void Display(const void* buffer, const ArrayShape& shape) {
422 int num_elements = GetNumElements(shape);
423 switch (shape.type) {
424 case S16:
425 return DisplayT<short>(buffer, num_elements); // NOLINT
426 case S32:
427 return DisplayT<int>(buffer, num_elements);
428 case S64:
429 return DisplayT<long long>(buffer, num_elements); // NOLINT
430 case U8:
431 return DisplayT<unsigned char>(buffer, num_elements);
432 case U16:
433 return DisplayT<unsigned short>(buffer, num_elements); // NOLINT
434 case U32:
435 return DisplayT<unsigned int>(buffer, num_elements);
436 case U64:
437 return DisplayT<unsigned long long>(buffer, num_elements); // NOLINT
438 case F32:
439 return DisplayT<float>(buffer, num_elements);
440 case F64:
441 return DisplayT<double>(buffer, num_elements);
442
443 case F16:
444 case BF16:
445 case C64:
446 case C128:
447 ExitWithMsg("Unsupported type: " + ToString(shape.type));
448 }
449 }
450
Display(const void * buffer,const TupleShape & shape)451 void Display(const void* buffer, const TupleShape& shape) {
452 if (shape.elements.size() == 1) {
453 return Display(buffer, shape.elements[0]);
454 }
455 std::cout << "(" << std::endl;
456 auto casted = static_cast<const void* const*>(buffer);
457 for (int tuple_idx = 0; tuple_idx < shape.elements.size(); tuple_idx++) {
458 ArrayShape array_shape = shape.elements[tuple_idx];
459 Display(casted[tuple_idx], array_shape);
460 if (tuple_idx != shape.elements.size() - 1) {
461 std::cout << ", " << std::endl;
462 }
463 }
464 std::cout << ")" << std::endl;
465 }
466
467 } // end namespace
468
main(int argc,char ** argv)469 int main(int argc, char** argv) {
470 if (argc < 2) {
471 ExitWithMsg(
472 "Please provide buffer table filename as an argument, "
473 "or invoke with --help for usage instructions.");
474 }
475 std::string arg = argv[1];
476 if (arg == "--help") {
477 std::cout << kUsageString << std::endl;
478 return 0;
479 }
480
481 BufferAssignment assignment = ParseBufferAssignment(arg);
482 Log("Buffer assignment: \n" + BufferAssignmentToString(assignment));
483 BufferTable table(assignment);
484
485 // Fill out input parameters.
486 for (const auto& p : assignment.param_to_alloc_idx) {
487 int param_idx = p.first;
488 int allocation_idx = p.second;
489 TupleShape tuple_shape = assignment.buffers_shape[allocation_idx];
490 Check(tuple_shape.elements.size() == 1,
491 "Parameters can not be tuples, got shape: " +
492 TupleShapeToString(tuple_shape));
493 ArrayShape shape = tuple_shape.elements[0];
494 Check(GetNumElements(shape) ==
495 assignment.buffers_size[allocation_idx] / ByteSize(shape.type),
496 "Unexpected number of elements");
497 Fill(table.AsPtr()[allocation_idx], shape);
498
499 if (IsVerbose()) {
500 std::cout << "Filled parameter buffer for param " << param_idx << ": "
501 << std::endl;
502 Display(table.AsPtr()[allocation_idx], shape);
503 }
504 }
505
506 XlaCustomCallStatus status;
507
508 Log("Launching module");
509 EntryModule(/*result_buffer=*/nullptr,
510 /*run_opts=*/nullptr,
511 /*params=*/nullptr, table.AsPtr(),
512 /*status=*/&status,
513 /*prof_counters=*/nullptr);
514
515 std::cout << "Output:" << std::endl;
516 Log("Output shape: " +
517 TupleShapeToString(assignment.buffers_shape[assignment.output_idx]));
518 Display(table.AsPtr()[assignment.output_idx],
519 assignment.buffers_shape[assignment.output_idx]);
520 }
521