// *** Tensor Expressions ***
//
// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to
// work with them, and outlines how they are used in the overall TorchScript
// compilation pipeline. This doc is permanently a "work in progress" since NNC
// is under active development and things change fast.
//
// This tutorial's code is compiled in the standard pytorch build, and the
// executable can be found in `build/bin/tutorial_tensorexpr`.
//
// *** What is NNC ***
//
// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT
// and it performs on-the-fly code generation for kernels, which are often a
// combination of multiple aten (torch) operators.
//
// When the JIT interpreter executes a TorchScript model, it automatically
// extracts subgraphs from the TorchScript IR graph for which specialized code
// can be JIT generated. This usually improves performance as the 'combined'
// kernel created from the subgraph could avoid unnecessary memory traffic that
// is unavoidable when the subgraph is interpreted as-is, operator by operator.
// This optimization is often referred to as 'fusion'. Relatedly, the process of
// finding and extracting subgraphs suitable for NNC code generation is done by
// a JIT pass called 'fuser'.
//
// *** What is TE ***
//
// TE stands for Tensor Expressions. TE is a commonly used approach for
// compiling kernels performing tensor (~matrix) computation. The idea behind it
// is that operators are represented as a mathematical formula describing what
// computation they do (as TEs) and then the TE engine can perform mathematical
// simplification and other optimizations using those formulas and eventually
// generate executable code that would produce the same results as the original
// sequence of operators, but more efficiently.
//
// NNC's design and implementation of TE was heavily inspired by the Halide and
// TVM projects.
#include <iostream>
#include <sstream>
#include <string>

#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;

#ifdef TORCH_ENABLE_LLVM

// Helper function to print a snippet from a big multi-line string
static void printLinesToFrom(const std::string& input_str, int from, int to);

#endif

int main(int argc, char* argv[]) {
  std::cout << "*** Structure of tensor expressions and statements ***"
            << std::endl;
  {
    // A tensor expression is a tree of expressions. Each expression has a type,
    // and that type defines what sub-expressions the current expression has.
    // For instance, an expression of type 'Mul' would have a type 'kMul' and
    // two sub-expressions: LHS and RHS. Each of these two sub-expressions could
    // also be a 'Mul' or some other expression.
    //
    // Let's construct a simple TE:
    ExprPtr lhs = alloc<IntImm>(5);
    ExprPtr rhs = alloc<Var>("x", kInt);
    ExprPtr mul = alloc<Mul>(lhs, rhs);
    std::cout << "Tensor expression: " << *mul << std::endl;
    // Prints: Tensor expression: 5 * x

    // Here we created an expression representing a 5*x computation, where x is
    // an int variable.

    // Another, probably more convenient, way to construct tensor expressions
    // is to use so-called expression handles (as opposed to raw expressions
    // like we did in the previous example). Expression handles overload common
    // operations and allow us to express the same semantics in a more natural
    // way:
    ExprHandle l = 5;
    ExprHandle r = Var::make("x", kInt);
    ExprHandle m = l * r;
    std::cout << "Tensor expression: " << *m.node() << std::endl;
    // Prints: Tensor expression: 5 * x

    // Converting from handles to raw expressions and back is easy:
    ExprHandle handle = Var::make("x", kInt);
    ExprPtr raw_expr_from_handle = handle.node();
    ExprPtr raw_expr = alloc<Var>("x", kInt);
    ExprHandle handle_from_raw_expr = ExprHandle(raw_expr);

    // We can construct arbitrarily complex expressions using mathematical
    // and logical operations, casts between various data types, and a bunch of
    // intrinsics.
    ExprHandle a = Var::make("a", kInt);
    ExprHandle b = Var::make("b", kFloat);
    ExprHandle c = Var::make("c", kFloat);
    ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f);
    std::cout << "Tensor expression: " << *x.node() << std::endl;
    // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f)

    // An ultimate purpose of tensor expressions is to optimize tensor
    // computations, and in order to represent accesses to tensor data, there
    // is a special kind of expression - a load.
    // To construct a load we need two pieces: the base and the indices. The
    // base of a load is a Buf expression, which could be thought of as a
    // placeholder similar to Var, but with dimension info.
    //
    // Let's construct a simple load:
    BufHandle A("A", {64, 32}, kInt);
    VarPtr i_var = alloc<Var>("i", kInt), j_var = alloc<Var>("j", kInt);
    ExprHandle i(i_var), j(j_var);
    ExprHandle load = Load::make(A.dtype(), A, {i, j});
    std::cout << "Tensor expression: " << *load.node() << std::endl;
    // Prints: Tensor expression: A[i, j]
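
    // The same load can also be constructed via the load() shorthand on
    // BufHandle (this form is used later in the tutorial) and should print
    // identically:
    ExprHandle load_shorthand = A.load(i, j);
    std::cout << "Tensor expression: " << *load_shorthand.node() << std::endl;
    // Prints: Tensor expression: A[i, j]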

    // Tensor expressions are used to construct tensor statements, which
    // represent the computation of a given operator or a group of operators
    // from a fusion group.
    //
    // There are three main kinds of tensor statements:
    //  - block
    //  - store
    //  - loop
    //
    // A Store represents a store to a single element of a tensor (or to a
    // group of elements if it's a vectorized store). Store statements,
    // similarly to Load expressions, have a base and indices, but on top of
    // that they also include a value - an expression representing what needs
    // to be stored at the given memory location. Let's create a Store stmt:
    StmtPtr store_a = Store::make(A, {i, j}, i + j);
    std::cout << "Store statement: " << *store_a << std::endl;
    // Prints: Store statement: A[i, j] = i + j;

    // An operator fills the entire tensor, not just a single element, and to
    // represent this we need to use a For stmt: let's wrap our store stmt with
    // two nested loops to represent that variables i and j need to iterate
    // over some ranges.
    ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a);
    ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a);

    std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl;
    // Prints:
    // Nested for loops:
    // for (const auto i : c10::irange(64)) {
    //   for (const auto j : c10::irange(32)) {
    //     A[i, j] = i + j;
    //   }
    // }

    // A Block statement is used when we need a sequence of other statements.
    // E.g. if a fusion group contains several operators, we initially define a
    // separate loopnest for each of them and put them all into a common block:
    BufHandle B("B", {64, 32}, kInt);
    StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j));
    ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b);
    ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b);

    BlockPtr block = Block::make({loop_i_a, loop_i_b});
    std::cout << "Compound Block statement: " << std::endl
              << *block << std::endl;
    // Prints:
    // Compound Block statement:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       A[i, j] = i + j;
    //     }
    //   }
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       B[i, j] = A[i, j];
    //     }
    //   }
    // }

    // Manually constructing nested loops and blocks to represent a computation
    // might be laborious, so instead we can use the 'Compute' API. This API
    // requires us to specify dimensions and a lambda to compute a single
    // element of the resulting tensor, and returns a `Tensor` structure. This
    // structure is simply a pair of a buffer that was created to represent the
    // result of the computation (BufPtr) and a statement representing the
    // computation itself (StmtPtr).
    Tensor C =
        Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return i * j;
        });
    std::cout << "Stmt produced by 'Compute' API: " << std::endl
              << *C.stmt() << std::endl;
    // Prints:
    // Stmt produced by 'Compute' API:
    // for (const auto i : c10::irange(64)) {
    //   for (const auto j : c10::irange(32)) {
    //     C[i, j] = i * j;
    //   }
    // }

    // To construct statements representing computations with reductions, we
    // can use the 'Reduce' API - it is similar to 'Compute' but takes a couple
    // of extra arguments defining how to perform the reduction. Let's define a
    // simple 2D sum of C using that:
    Tensor D = Reduce(
        "D",
        {},
        Sum(),
        [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
        {64, 32});
    std::cout << "Stmt produced by 'Reduce' API: " << std::endl
              << *D.stmt() << std::endl;
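
    // As a quick sketch of the same API with a non-empty output shape
    // (assuming the generic Reduce overload used above also accepts one
    // output dimension plus one reduction dimension), here is a per-row sum
    // of C, where 'i' indexes the output and 'j' is reduced over:
    Tensor E = Reduce(
        "E",
        {64},
        Sum(),
        [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
        {32});
    std::cout << "Stmt produced by 'Reduce' API (row sums): " << std::endl
              << *E.stmt() << std::endl;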
  }

  std::cout << "*** Loopnest transformations ***" << std::endl;
  {
    // When a statement for the computation is generated, we might want to
    // apply some optimizations to it. These transformations allow us to end up
    // with a statement producing the same results, but more efficiently.
    //
    // Let's look at a couple of transformations that are used in NNC. We will
    // begin with constructing a Block statement like we did before.

    Tensor C =
        Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return i * (j + 1);
        });
    BufHandle c_buf(C.buf());
    Tensor D =
        Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return c_buf.load(i, j) - i;
        });
    StmtPtr block = Block::make({C.stmt(), D.stmt()});
    std::cout << "Stmt produced by 'Compute' API: " << std::endl
              << *block << std::endl;
    // Prints:
    // Stmt produced by 'Compute' API:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       C[i, j] = i * (j + 1);
    //     }
    //   }
    //   for (const auto i_1 : c10::irange(64)) {
    //     for (const auto j_1 : c10::irange(32)) {
    //       D[i_1, j_1] = (C[i_1, j_1]) - i_1;
    //     }
    //   }
    // }

    // One transformation we can apply to this computation is inlining: i.e.
    // taking the expression that defines the values of C and substituting it
    // for the loads from C.
    // To do that, we first need to create a special object called LoopNest -
    // all transformations are methods of this class. To create a loopnest we
    // need to provide a list of output buffers and the root statement:
    LoopNest nest(block, {D.buf()});

    // We can always retrieve the Stmt back from LoopNest:
    std::cout << "LoopNest root stmt: " << std::endl
              << *nest.root_stmt() << std::endl;
    // Prints:
    // LoopNest root stmt:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       C[i, j] = i * (j + 1);
    //     }
    //   }
    //   for (const auto i_1 : c10::irange(64)) {
    //     for (const auto j_1 : c10::irange(32)) {
    //       D[i_1, j_1] = (C[i_1, j_1]) - i_1;
    //     }
    //   }
    // }

    // Now we can apply the inlining transformation:
    nest.computeInline(C.buf());
    std::cout << "Stmt after inlining:" << std::endl
              << *nest.root_stmt() << std::endl;
    // Prints:
    // Stmt after inlining:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i, j] = i * (j + 1) - i;
    //     }
    //   }
    // }

    // We can also apply algebraic simplification to a statement:
    StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
    std::cout << "Stmt after simplification:" << std::endl
              << *simplified << std::endl;
    // Prints:
    // Stmt after simplification:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i, j] = i * j;
    //     }
    //   }
    // }

    // Many loopnest transformations are stateless and can be applied without
    // creating a LoopNest object. In fact, we plan to make all transformations
    // stateless.
    // splitWithTail is one such transformation: it splits the iteration space
    // of a given loop into two with a given factor.
    ForPtr outer_loop = to<For>(to<Block>(simplified)->stmts().front());
    LoopNest::splitWithTail(outer_loop, 13);
    // Call the simplifier once more to fold some arithmetic.
    simplified = IRSimplifier::simplify(simplified);
    std::cout << "Stmt after splitWithTail:" << std::endl
              << *simplified << std::endl;
    // Prints:
    // Stmt after splitWithTail:
    // {
    //   for (const auto i_outer : c10::irange(4)) {
    //     for (const auto i_inner : c10::irange(13)) {
    //       for (const auto j : c10::irange(32)) {
    //         D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j);
    //       }
    //     }
    //   }
    //   for (const auto i_tail : c10::irange(12)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i_tail + 52, j] = i_tail * j + 52 * j;
    //     }
    //   }
    // }

    // NNC supports a wide range of loop nest transformations, which we are not
    // listing here. Please refer to the documentation in
    // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h
    // for more details.
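
    // As a commented-out sketch (assuming the static methods currently
    // declared in loopnest.h keep these signatures; the loop variables below
    // are placeholders), other transformations are invoked in the same style
    // as splitWithTail above:
    //
    //   LoopNest::splitWithMask(some_loop, 8);  // split without a tail loop
    //   LoopNest::vectorize(some_inner_loop);   // vectorize an innermost loop
    //   LoopNest::reorderAxis(loop_a, loop_b);  // swap two loops in a nest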
  }

  std::cout << "*** Codegen ***" << std::endl;
  {
    // An ultimate goal of tensor expressions is to provide a mechanism to
    // execute a given computation in the fastest possible way. So far we've
    // looked at how we could describe what computation we're interested in, but
    // we haven't looked at how to actually execute it.
    //
    // All we've been dealing with so far were just symbols with no actual data
    // associated with them; in this section we will look at how to bridge that
    // gap.

    // Let's start by constructing a simple computation for us to work with:
    BufHandle A("A", {64, 32}, kInt);
    BufHandle B("B", {64, 32}, kInt);
    Tensor X =
        Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return A.load(i, j) + B.load(i, j);
        });

    // And let's lower it to a loop nest, as we did in the previous section. We
    // can pass the Tensor object directly:
    LoopNest loopnest({X});
    std::cout << *loopnest.root_stmt() << std::endl;
    // Prints:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       X[i, j] = (A[i, j]) + (B[i, j]);
    //     }
    //   }
    // }

    // Now imagine that we have two actual 64x32 tensors that we want to sum
    // together. How do we pass those tensors to the computation and how do we
    // carry it out?
    //
    // The Codegen object is aimed at providing exactly that functionality.
    // Codegen is an abstract class and concrete codegens are derived from it.
    // Currently, we have three codegens:
    //  1) Simple Evaluator,
    //  2) LLVM Codegen for CPU,
    //  3) CUDA Codegen.
    // In this example we will be using the Simple Evaluator, since it's
    // available everywhere.

    // To create a codegen, we need to provide the statement - it specifies the
    // computation we want to perform - and a list of placeholders and tensors
    // used in the computation. The latter part is crucial since that's the only
    // way the codegen can correlate symbols in the statement with the actual
    // data arrays that we will be passing when we actually perform the
    // computation.
    //
    // Let's create a Simple IR Evaluator codegen for our computation:
    SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X});

    // We are using the simplest codegen, and in it almost no work is done at
    // the construction step. Real codegens such as CUDA and LLVM perform
    // compilation during that stage so that when we're about to run the
    // computation everything is ready.

    // Let's now create some inputs and run our computation with them:
    std::vector<int> data_A(64 * 32, 3); // This will be the input A
    std::vector<int> data_B(64 * 32, 5); // This will be the input B
    std::vector<int> data_X(64 * 32, 0); // This will be used for the result

    // Now let's invoke our codegen to perform the computation on our data. We
    // need to provide as many arguments as the number of placeholders and
    // tensors we passed at codegen construction time. The position in these
    // lists defines how the real data arrays from this call (these arguments
    // are referred to as 'CallArg's in our codebase) correspond to the symbols
    // (placeholders and tensors) used in the tensor expressions we constructed
    // (these are referred to as 'BufferArg's).
    // Thus, we will provide three arguments: data_A, data_B, and data_X. data_A
    // contains data for the placeholder A, data_B - for the placeholder B, and
    // data_X will be used for the contents of tensor X.
    ir_eval(data_A, data_B, data_X);

    // Let's print one of the elements from each array to verify that the
    // computation did happen:
    std::cout << "A[10] = " << data_A[10] << std::endl
              << "B[10] = " << data_B[10] << std::endl
              << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl;
    // Prints:
    // A[10] = 3
    // B[10] = 5
    // X[10] = A[10] + B[10] = 8
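
    // On an LLVM-enabled build, the very same statement could be compiled to
    // native code instead of being interpreted. A minimal commented-out
    // sketch, assuming LLVMCodeGen (from
    // torch/csrc/jit/tensorexpr/llvm_codegen.h) accepts the same
    // (stmt, buffer args) constructor and call arguments as SimpleIREvaluator:
    //
    //   LLVMCodeGen llvm_codegen(loopnest.root_stmt(), {A, B, X});
    //   llvm_codegen(data_A, data_B, data_X);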
  }

  std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl;
  {
    // This section requires an LLVM-enabled PyTorch build, so we have to use a
    // guard:
#ifdef TORCH_ENABLE_LLVM

    // Often we would like to convert TorchScript IR to TE rather than
    // construct TE IR from scratch. NNC provides an API to perform such
    // lowering: it takes a TorchScript graph and returns an object that can be
    // used to invoke the generated kernel.
    // This API is currently used by the TorchScript JIT fuser and can also be
    // used ahead of time to pre-compile parts of a model.
    //
    // To get familiar with this API let's start by defining a simple
    // TorchScript graph:
    const auto graph_string = R"IR(
        graph(%A : Float(5, 3, strides=[3, 1], device=cpu),
              %B : Float(5, 3, strides=[3, 1], device=cpu)):
          %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B)
          %one : int = prim::Constant[value=1]()
          %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB)
          %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one)
          return (%AAB_plus_B))IR";
    auto graph = std::make_shared<torch::jit::Graph>();
    parseIR(graph_string, &*graph);

    // This graph defines a simple computation of A*A*B + B where A and B are
    // input 5x3 tensors.

    // To lower this TorchScript graph to TE, we just need to create a
    // TensorExprKernel object. Its constructor builds the corresponding TE IR
    // and compiles it for the given backend (in this example, for CPU using
    // the LLVM compiler).
    TensorExprKernel kernel(graph);

    // We can retrieve the generated TE stmt from the kernel object:
    StmtPtr kernel_stmt = kernel.getCodeGenStmt();
    std::cout << "TE Stmt constructed from TorchScript: " << std::endl
              << *kernel_stmt << std::endl;
    // Prints:
    // TE Stmt constructed from TorchScript:
    // {
    //   for (const auto v : c10::irange(5)) {
    //     for (const auto _tail_tail : c10::irange(3)) {
    //       aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) *
    //       ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) +
    //       (tB[_tail_tail + 3 * v]);
    //     }
    //   }
    // }

    // We can also examine the generated LLVM IR and assembly code:
    std::cout << "Generated LLVM IR: " << std::endl;
    auto ir_str = kernel.getCodeText("ir");
    printLinesToFrom(ir_str, 15, 20);
    // Prints:
    // Generated LLVM IR:
    //   %9 = bitcast float* %2 to <8 x float>*
    //   %10 = load <8 x float>, <8 x float>* %9 ...
    //   %11 = bitcast float* %5 to <8 x float>*
    //   %12 = load <8 x float>, <8 x float>* %11 ...
    //   %13 = fmul <8 x float> %10, %12
    //   %14 = fmul <8 x float> %10, %13

    std::cout << "Generated assembly: " << std::endl;
    auto asm_str = kernel.getCodeText("asm");
    printLinesToFrom(asm_str, 10, 15);
    // Prints:
    // Generated assembly:
    //         vmulps  %ymm1, %ymm0, %ymm2
    //         vfmadd213ps     %ymm1, %ymm0, %ymm2
    //         vmovups %ymm2, (%rax)
    //         vmovss  32(%rcx), %xmm0
    //         vmovss  32(%rdx), %xmm1
    //         vmulss  %xmm1, %xmm0, %xmm2

    // We can also execute the generated kernel:
    auto A =
        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
        2.0;
    auto B =
        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
        3.0;
    std::vector<at::Tensor> inputs = {A, B};
    std::vector<torch::IValue> stack = torch::fmap<torch::IValue>(inputs);
    kernel.run(stack);
    auto R = stack[0].toTensor();

    // Let's print one of the elements from the result tensor to verify that the
    // computation did happen and was correct:
    std::cout << "R[2][2] = " << R[2][2] << std::endl;
    // Prints:
    // R[2][2] = 15
    // [ CPUFloatType{} ]
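
    // As an extra sanity check, we can also compare the whole result tensor
    // against an eager-mode ATen reference of the same computation:
    auto Ref = A * A * B + B;
    std::cout << "Result matches ATen reference: " << at::allclose(R, Ref)
              << std::endl;
    // Prints:
    // Result matches ATen reference: 1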
#endif
  }
  return 0;
}

void printLinesToFrom(const std::string& input_str, int from, int to) {
  std::istringstream f(input_str);
  std::string s;
  int idx = 0;
  while (getline(f, s)) {
    if (idx > from) {
      std::cout << s << "\n";
    }
    if (idx++ > to) {
      break;
    }
  }
}
543