1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/codegen.h> 4 #include <torch/csrc/jit/tensorexpr/ir_printer.h> 5 6 namespace torch::jit::tensorexpr { 7 8 class CppVarNameRewriter; 9 10 // Generates C++ code from the IR. 11 // 12 // Vector operations are unrolled. 13 // For example: 14 // C[Ramp(0, 1, 3)] = A[Ramp(0, 2, 3)] + B[Ramp(0, 3, 3)]; 15 // is unrolled into: 16 // C[0] = A[0] + B[0]; 17 // C[1] = A[2] + B[3]; 18 // C[2] = A[4] + B[6]; 19 class TORCH_API CppPrinter : public IRPrinter { 20 public: 21 explicit CppPrinter(std::ostream* os); 22 ~CppPrinter() override; 23 24 void printPrologue(); 25 26 using IRPrinter::visit; 27 28 // Binary expressions. 29 void visit(const ModPtr&) override; 30 void visit(const MaxPtr&) override; 31 void visit(const MinPtr&) override; 32 33 // Conditional expressions. 34 void visit(const CompareSelectPtr&) override; 35 void visit(const IfThenElsePtr&) override; 36 37 // Tensor operations. 38 void visit(const AllocatePtr&) override; 39 void visit(const FreePtr&) override; 40 void visit(const LoadPtr&) override; 41 void visit(const StorePtr&) override; 42 43 // Casts. 44 void visit(const CastPtr&) override; 45 void visit(const BitCastPtr&) override; 46 47 // Calls. 48 void visit(const IntrinsicsPtr&) override; 49 void visit(const ExternalCallPtr&) override; 50 51 // Vars. 52 void visit(const LetPtr&) override; 53 void visit(const VarPtr&) override; 54 55 // Vector data types. 56 void visit(const RampPtr&) override; 57 void visit(const BroadcastPtr&) override; 58 59 private: 60 int lane_; 61 std::unordered_map<VarPtr, ExprPtr> vector_vars_; 62 }; 63 64 class TORCH_API CppCodeGen : public CodeGen { 65 public: 66 CppCodeGen( 67 StmtPtr stmt, 68 const std::vector<BufferArg>& buffer_args, 69 at::Device device = at::kCPU, 70 const std::string& kernel_func_name = "func"); 71 72 ~CppCodeGen() override; 73 74 void call(const std::vector<CallArg>& args) override; 75 void call_raw(const std::vector<void*>& args) override; 76 77 template <typename... Ts> operator()78 void operator()(const Ts&... ts) { 79 call(std::vector<CallArg>({CallArg(ts)...})); 80 } 81 82 std::string getCodeText(const std::string& attr = "") override { 83 return oss_.str(); 84 } 85 86 private: 87 void init(); 88 os()89 std::ostream& os() { 90 return printer_->os(); 91 } 92 93 std::ostringstream oss_; 94 std::unique_ptr<CppPrinter> printer_; 95 std::unique_ptr<CppVarNameRewriter> var_name_rewriter_; 96 }; 97 98 } // namespace torch::jit::tensorexpr 99