xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tools/driver.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file should not have any dependencies apart from the standard library,
17 // as it will be used in OSS outside of this repository.
18 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <random>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
33 
// Seed of the pseudo-random generators used to fill parameter buffers; being
// a constant, it makes every run produce the same input data.
static constexpr int kSeed = 42;
// Bounds of the distribution used to fill integer buffers.
static constexpr int kUpperBound = 100;
static constexpr int kLowerBound = -100;
// Bounds of the distribution used to fill floating point buffers.
static constexpr double kLowerBoundFP = -0.1;
static constexpr double kUpperBoundFP = 0.1;
// Text printed for --help. The result of the module is written to stdout
// (logging goes to stderr), and the text must say so.
static const char* const kUsageString = R"(
Driver for executing an HLO reproducer in object form in order to let OSS
users reproduce the miscompiles.

Expected workflow:

1) In the .hlo file, rename the root computation to `EntryModule`.
2) Run the .hlo file with XLA_FLAGS=--xla_dump_to set, to obtain the .ll file.
3) Compile and link this file with the object file from step (2).
4) Run the resulting file with the buffer assignment table as an argument,
taken from step 2. The driver will print the output to stdout.
5) Compare the output with optimized and non-optimized .ll file from step (2).
If the outputs differ, there is a miscompile.

Run with an environment variable VERBOSE set to see logging.
)";
55 
// Must be kept ABI-compatible with the real definition of XlaCustomCallStatus.
struct XlaCustomCallStatus {
  // If 'failed' is true then 'message' is present; otherwise it is absent.
  // (The 'bool' followed by 'std::string' is ABI-compatible with
  // 'std::optional<std::string>').
  //
  // Defaults to false so that a default-constructed status that nobody ever
  // writes to reads as "success" instead of holding an indeterminate value.
  // A default member initializer does not alter the layout of the struct.
  bool failed = false;
  std::string message;
  // To account for extra struct padding at the end.
  std::string padding;
};
66 
67 extern "C" {
68 // Function to be linked with.
69 extern void EntryModule(char* result_buffer, char* run_opts, char** params,
70                         char** buffer_table, void* status,
71                         int64_t* prof_counters);
72 
73 // Must be kept in sync with the real definition of this runtime function.
__xla_cpu_runtime_StatusIsSuccess(const XlaCustomCallStatus * status)74 bool __xla_cpu_runtime_StatusIsSuccess(  // NOLINT: This doesn't need a
75                                          // prototype.
76     const XlaCustomCallStatus* status) {
77   return !(status->failed);
78 }
79 }
80 
81 namespace {
82 
// Writes `msg` as one line to stderr (flushing it) and then terminates the
// process with exit status 1. Never returns.
[[noreturn]] void ExitWithMsg(const std::string& msg) {
  std::cerr << msg << '\n' << std::flush;
  std::exit(1);
}
87 
Check(bool cond,const std::string & msg="Precondition failed")88 void Check(bool cond, const std::string& msg = "Precondition failed") {
89   if (!cond) {
90     ExitWithMsg(msg);
91   }
92 }
93 
// Returns true iff the environment variable VERBOSE is set; its value is not
// inspected.
bool IsVerbose() {
  const char* const verbose_env = std::getenv("VERBOSE");
  return verbose_env != nullptr;
}
95 
Log(const std::string & msg)96 void Log(const std::string& msg) {
97   if (IsVerbose()) {
98     std::cerr << msg << std::endl;
99   }
100 }
101 
// Needs to be kept in sync with PrimitiveType in xla_data.proto.
//
// The enumerator values also serve as indices into the name table returned
// by primitive_strings(), so the two lists must stay in the same order.
enum PrimitiveType {
  S16 = 0,
  S32 = 1,
  S64 = 2,
  U8 = 3,
  U16 = 4,
  U32 = 5,
  U64 = 6,
  F16 = 7,
  BF16 = 8,
  F32 = 9,
  F64 = 10,
  C64 = 11,
  C128 = 12
};
118 
// Returns the table of textual type names. The name of a PrimitiveType is
// stored at the index equal to its enumerator value. The table is created on
// the first call and is never destroyed.
const std::vector<std::string>& primitive_strings() {
  static const std::vector<std::string>* const names =
      new std::vector<std::string>({"s16", "s32", "s64", "u8", "u16", "u32",
                                    "u64", "f16", "bf16", "f32", "f64", "c64",
                                    "c128"});
  return *names;
}
125 
ToString(PrimitiveType type)126 std::string ToString(PrimitiveType type) { return primitive_strings()[type]; }
127 
PrimitiveTypeFromString(const std::string & s)128 PrimitiveType PrimitiveTypeFromString(const std::string& s) {
129   const auto& vec = primitive_strings();
130   return static_cast<PrimitiveType>(
131       std::distance(vec.begin(), std::find(vec.begin(), vec.end(), s)));
132 }
133 
ByteSize(PrimitiveType type)134 int ByteSize(PrimitiveType type) {
135   std::string s = ToString(type);
136   s = s.substr(1, s.size());
137   return std::stoi(s) / 8;
138 }
139 
140 struct ArrayShape {
141   PrimitiveType type;
142   std::vector<int> dimensions;
143 };
144 
145 // We support tuples only for output, and we do not support nested tuples.
146 struct TupleShape {
147   std::vector<ArrayShape> elements;
148 };
149 
ArrayShapeToString(ArrayShape shape)150 std::string ArrayShapeToString(ArrayShape shape) {
151   std::ostringstream out;
152   out << ToString(shape.type) << "[";
153   for (int i = 0; i < shape.dimensions.size(); i++) {
154     out << std::to_string(shape.dimensions[i]);
155     if (i != shape.dimensions.size() - 1) {
156       out << ",";
157     }
158   }
159   out << "]";
160   return out.str();
161 }
162 
163 // Input: TYPE[D1,D2,...DN]
ArrayShapeFromString(const std::string & s)164 ArrayShape ArrayShapeFromString(const std::string& s) {
165   Log("Array shape from string: " + s);
166   Check(s.find('(') == std::string::npos, "Tuple shape is not supported");
167   std::regex shape_r("([^\\[]+)\\[(.*)\\]");
168   std::smatch match;
169   Check(std::regex_match(s, match, shape_r), "Shape not found");
170   std::string type = match[1];
171   std::string dims = match[2];
172   PrimitiveType ptype = PrimitiveTypeFromString(type);
173   std::istringstream dims_stream(dims);
174   std::string dim;
175   std::vector<int> dimensions;
176   while (std::getline(dims_stream, dim, ',')) {
177     dimensions.push_back(std::stoi(dim));
178   }
179   return {ptype, dimensions};
180 }
181 
182 // E.g. (f32[10,20], u32[])
TupleShapeFromString(std::string s)183 TupleShape TupleShapeFromString(std::string s) {
184   Log("Tuple shape from string: " + s);
185   if (s[0] != '(') {
186     return {{ArrayShapeFromString(s)}};
187   }
188   s = s.substr(1, s.size() - 2);
189   std::istringstream sstream(s);
190   std::string subshape;
191   std::vector<ArrayShape> out;
192   while (std::getline(sstream, subshape, ' ')) {
193     if (subshape[subshape.size() - 1] == ',') {
194       subshape = subshape.substr(0, subshape.size() - 1);
195     }
196     out.push_back(ArrayShapeFromString(subshape));
197   }
198   return {out};
199 }
200 
TupleShapeToString(TupleShape shape)201 std::string TupleShapeToString(TupleShape shape) {
202   std::ostringstream out;
203   if (shape.elements.size() == 1) {
204     return ArrayShapeToString(shape.elements[0]);
205   }
206   out << "(";
207   for (int idx = 0; idx < shape.elements.size(); idx++) {
208     out << ArrayShapeToString(shape.elements[idx]);
209     if (idx != shape.elements.size() - 1) {
210       out << ", ";
211     }
212   }
213   out << ")";
214   return out.str();
215 }
216 
217 // Information about the buffer assignment.
218 struct BufferAssignment {
219   // Mapping from allocation index to buffer size (in bytes).
220   std::vector<int> buffers_size;
221 
222   // Mapping from allocation index to its shape.
223   std::map<int, TupleShape> buffers_shape;
224 
225   // Mapping from param index to allocation index.
226   std::map<int, int> param_to_alloc_idx;
227 
228   // Index of the output parameter.
229   int output_idx = -1;
230 };
231 
BufferAssignmentToString(const BufferAssignment & assignment)232 std::string BufferAssignmentToString(const BufferAssignment& assignment) {
233   std::ostringstream out;
234   for (const auto& p : assignment.param_to_alloc_idx) {
235     int param_idx = p.first;
236     int allocation_idx = p.second;
237     out << "Param: " << param_idx << " (allocation " << allocation_idx << "): ";
238     auto p2 = assignment.buffers_shape.find(allocation_idx);
239     Check(p2 != assignment.buffers_shape.end(),
240           "Shape not found for parameter: " + std::to_string(param_idx));
241     out << TupleShapeToString(p2->second)
242         << ", size = " << assignment.buffers_size[allocation_idx] << "\n";
243   }
244   return out.str();
245 }
246 
247 // RAII table for the given assignment: mapping from a allocation idx to the
248 // actual allocation.
249 class BufferTable {
250  public:
BufferTable(BufferAssignment assignment)251   explicit BufferTable(BufferAssignment assignment) : assignment_(assignment) {
252     int num_buffers = assignment.buffers_size.size();
253     ptr_ = new char*[num_buffers];
254     for (int buffer_idx = 0; buffer_idx < num_buffers; buffer_idx++) {
255       // Call malloc to ensure alignment up to std::max_align_t.
256       ptr_[buffer_idx] =
257           static_cast<char*>(malloc(assignment.buffers_size[buffer_idx]));
258     }
259   }
260 
AsPtr()261   char** AsPtr() { return ptr_; }
262 
~BufferTable()263   ~BufferTable() {
264     int num_buffers = assignment_.buffers_size.size();
265     for (int buffer_idx = 0; buffer_idx < num_buffers; buffer_idx++) {
266       free(ptr_[buffer_idx]);
267     }
268     delete[] ptr_;
269   }
270 
271  private:
272   BufferAssignment assignment_;
273   char** ptr_;
274 };
275 
276 // Parse and populate the buffer table;
277 //
278 // Example of input:
279 //
280 // BufferAssignment:
281 // allocation 0: 0x27017c46b600, size 32768, parameter 0, shape f32[256,32] at
282 // ShapeIndex {}:
283 //  value: <3 parameter @0> (size=32768,offset=0): f32[256,32]{1,0}
284 // allocation 1: 0x27017c46b6b0, size 128, output shape is f32[32],
285 // maybe-live-out:
286 //  value: <5 reduce @0> (size=128,offset=0): f32[32]{0}
287 // allocation 2: 0x27017c46b760, size 4, constant:
288 //  value: <4 init_value @0> (size=4,offset=0): f32[]
289 // allocation 3: 0x27017c46b810, size 4, thread-local:
290 //  value: <0 x.1 @0> (size=4,offset=0): f32[]
291 // allocation 4: 0x27017c46b8c0, size 4, thread-local:
292 //  value: <1 y.1 @0> (size=4,offset=0): f32[]
293 // allocation 5: 0x27017c46b970, size 4, output shape is f32[], thread-local:
294 //  value: <2 add.1 @0> (size=4,offset=0): f32[]
ParseBufferAssignment(const std::string & fname)295 BufferAssignment ParseBufferAssignment(const std::string& fname) {
296   BufferAssignment assignment;
297   std::ifstream infile(fname);
298   std::string line;
299   while (std::getline(infile, line)) {
300     std::regex allocation_line_r(
301         "allocation ([0-9]+): .+, size ([0-9]+), (.+)");
302     std::smatch match;
303     if (std::regex_search(line, match, allocation_line_r)) {
304       Log("Matched allocation description: " + line);
305       int allocation_idx = std::stoi(match[1]);
306       int size = std::stoi(match[2]);
307       Log("Allocation size = " + std::to_string(size));
308       const std::string& postfix = match[3];
309       Check(allocation_idx == assignment.buffers_size.size(),
310             "Unordered allocations in input");
311       assignment.buffers_size.push_back(size);
312 
313       std::regex output_r("output shape is \\|([^\\|]+)\\|,");
314       std::smatch output_match;
315       if (std::regex_search(postfix, output_match, output_r)) {
316         Log("Matched out parameter: " + postfix);
317         Check(assignment.output_idx == -1, "Multiple out-parameters");
318         assignment.output_idx = allocation_idx;
319         std::string output_shape = output_match[1];
320         Log("output shape = " + output_shape);
321         TupleShape shape = TupleShapeFromString(output_shape);
322         assignment.buffers_shape[allocation_idx] = shape;
323         Log("parsed output shape = " + TupleShapeToString(shape));
324       }
325 
326       std::regex parameter_r("parameter ([0-9]+), shape \\|([^\\|]+)\\|");
327       std::smatch param_match;
328       if (std::regex_search(postfix, param_match, parameter_r)) {
329         Log("Matched parameter description: " + postfix);
330         int param_idx = std::stoi(param_match[1]);
331         assignment.param_to_alloc_idx[param_idx] = allocation_idx;
332         std::string param_shape = param_match[2];
333         TupleShape shape = TupleShapeFromString(param_shape);
334         assignment.buffers_shape[allocation_idx] = shape;
335         Log("parsed parameter shape for param " + std::to_string(param_idx) +
336             " = " + TupleShapeToString(shape));
337       }
338     }
339   }
340   Check(assignment.output_idx != -1, "Output not set");
341   return assignment;
342 }
343 
GetNumElements(const ArrayShape & shape)344 int GetNumElements(const ArrayShape& shape) {
345   int num_elements = 1;
346   for (int dim : shape.dimensions) {
347     num_elements *= dim;
348   }
349   return num_elements;
350 }
351 
352 template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
FillIntT(void * buffer,int num_elements)353 void FillIntT(void* buffer, int num_elements) {
354   std::mt19937 generator(kSeed);
355   T* casted = static_cast<T*>(buffer);
356   std::uniform_int_distribution<> distr(kLowerBound, kUpperBound);
357   for (int i = 0; i < num_elements; i++) {
358     casted[i] = static_cast<T>(distr(generator));
359   }
360 }
361 
362 template <typename T,
363           typename = std::enable_if_t<std::is_floating_point<T>::value>>
FillFloatT(void * buffer,int num_elements)364 void FillFloatT(void* buffer, int num_elements) {
365   std::mt19937 generator(kSeed);
366   T* casted = static_cast<T*>(buffer);
367   std::uniform_real_distribution<T> distr(kLowerBoundFP, kUpperBoundFP);
368   for (int i = 0; i < num_elements; i++) {
369     casted[i] = distr(generator);
370   }
371 }
372 
Fill(void * buffer,const ArrayShape & shape)373 void Fill(void* buffer, const ArrayShape& shape) {
374   int num_elements = GetNumElements(shape);
375   Log("Number of elements = " + std::to_string(num_elements));
376   Log("Shape type = " + ToString(shape.type) +
377       ", shape = " + ArrayShapeToString(shape));
378   switch (shape.type) {
379     case S16:
380       return FillIntT<short>(buffer, num_elements);  // NOLINT
381     case S32:
382       return FillIntT<int>(buffer, num_elements);
383     case S64:
384       return FillIntT<long long>(buffer, num_elements);  // NOLINT
385     case U8:
386       return FillIntT<unsigned char>(buffer, num_elements);
387     case U16:
388       return FillIntT<unsigned short>(buffer, num_elements);  // NOLINT
389     case U32:
390       return FillIntT<unsigned int>(buffer, num_elements);
391     case U64:
392       return FillIntT<unsigned long long>(buffer, num_elements);  // NOLINT
393     case F32:
394       return FillFloatT<float>(buffer, num_elements);
395     case F64:
396       return FillFloatT<double>(buffer, num_elements);
397 
398     case F16:
399     case BF16:
400     case C64:
401     case C128:
402       ExitWithMsg("Unsupported type: " + ToString(shape.type));
403   }
404 }
405 
// Prints the first `num_elements` elements of `buffer`, interpreted as an
// array of T, to stdout as a single ", "-separated line. An empty array
// produces an empty line.
//
// NOTE(review): memory sanitizer instrumentation is disabled here, presumably
// because the printed buffers may contain bytes nobody initialized -- confirm.
template <typename T>
#if defined(MEMORY_SANITIZER)
__attribute__((no_sanitize_memory))
#endif
void DisplayT(const void* buffer, int num_elements) {
  const T* casted = static_cast<const T*>(buffer);
  for (int i = 0; i < num_elements; i++) {
    // Unary plus applies the integral promotions, so that 8-bit elements are
    // printed as numbers instead of being streamed as raw characters. Wider
    // integers and floating point values are printed exactly as before.
    std::cout << +casted[i];
    if (i != num_elements - 1) {
      std::cout << ", ";
    }
  }
  std::cout << std::endl;
}
420 
Display(const void * buffer,const ArrayShape & shape)421 void Display(const void* buffer, const ArrayShape& shape) {
422   int num_elements = GetNumElements(shape);
423   switch (shape.type) {
424     case S16:
425       return DisplayT<short>(buffer, num_elements);  // NOLINT
426     case S32:
427       return DisplayT<int>(buffer, num_elements);
428     case S64:
429       return DisplayT<long long>(buffer, num_elements);  // NOLINT
430     case U8:
431       return DisplayT<unsigned char>(buffer, num_elements);
432     case U16:
433       return DisplayT<unsigned short>(buffer, num_elements);  // NOLINT
434     case U32:
435       return DisplayT<unsigned int>(buffer, num_elements);
436     case U64:
437       return DisplayT<unsigned long long>(buffer, num_elements);  // NOLINT
438     case F32:
439       return DisplayT<float>(buffer, num_elements);
440     case F64:
441       return DisplayT<double>(buffer, num_elements);
442 
443     case F16:
444     case BF16:
445     case C64:
446     case C128:
447       ExitWithMsg("Unsupported type: " + ToString(shape.type));
448   }
449 }
450 
Display(const void * buffer,const TupleShape & shape)451 void Display(const void* buffer, const TupleShape& shape) {
452   if (shape.elements.size() == 1) {
453     return Display(buffer, shape.elements[0]);
454   }
455   std::cout << "(" << std::endl;
456   auto casted = static_cast<const void* const*>(buffer);
457   for (int tuple_idx = 0; tuple_idx < shape.elements.size(); tuple_idx++) {
458     ArrayShape array_shape = shape.elements[tuple_idx];
459     Display(casted[tuple_idx], array_shape);
460     if (tuple_idx != shape.elements.size() - 1) {
461       std::cout << ", " << std::endl;
462     }
463   }
464   std::cout << ")" << std::endl;
465 }
466 
467 }  // end namespace
468 
main(int argc,char ** argv)469 int main(int argc, char** argv) {
470   if (argc < 2) {
471     ExitWithMsg(
472         "Please provide buffer table filename as an argument, "
473         "or invoke with --help for usage instructions.");
474   }
475   std::string arg = argv[1];
476   if (arg == "--help") {
477     std::cout << kUsageString << std::endl;
478     return 0;
479   }
480 
481   BufferAssignment assignment = ParseBufferAssignment(arg);
482   Log("Buffer assignment: \n" + BufferAssignmentToString(assignment));
483   BufferTable table(assignment);
484 
485   // Fill out input parameters.
486   for (const auto& p : assignment.param_to_alloc_idx) {
487     int param_idx = p.first;
488     int allocation_idx = p.second;
489     TupleShape tuple_shape = assignment.buffers_shape[allocation_idx];
490     Check(tuple_shape.elements.size() == 1,
491           "Parameters can not be tuples, got shape: " +
492               TupleShapeToString(tuple_shape));
493     ArrayShape shape = tuple_shape.elements[0];
494     Check(GetNumElements(shape) ==
495               assignment.buffers_size[allocation_idx] / ByteSize(shape.type),
496           "Unexpected number of elements");
497     Fill(table.AsPtr()[allocation_idx], shape);
498 
499     if (IsVerbose()) {
500       std::cout << "Filled parameter buffer for param " << param_idx << ": "
501                 << std::endl;
502       Display(table.AsPtr()[allocation_idx], shape);
503     }
504   }
505 
506   XlaCustomCallStatus status;
507 
508   Log("Launching module");
509   EntryModule(/*result_buffer=*/nullptr,
510               /*run_opts=*/nullptr,
511               /*params=*/nullptr, table.AsPtr(),
512               /*status=*/&status,
513               /*prof_counters=*/nullptr);
514 
515   std::cout << "Output:" << std::endl;
516   Log("Output shape: " +
517       TupleShapeToString(assignment.buffers_shape[assignment.output_idx]));
518   Display(table.AsPtr()[assignment.output_idx],
519           assignment.buffers_shape[assignment.output_idx]);
520 }
521