1 #pragma once 2 3 #include <oneapi/dnnl/dnnl_graph.hpp> 4 #include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h> 5 #include <torch/csrc/jit/ir/ir.h> 6 7 namespace torch { 8 namespace jit { 9 namespace fuser { 10 namespace onednn { 11 12 class Operator { 13 public: Operator(const Node * node,dnnl::graph::op::kind kind)14 Operator(const Node* node, dnnl::graph::op::kind kind) 15 : n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {} 16 17 // Returns output index if the Value is a graph output. 18 // Otherwise returns -1 graphOutputIdx(Value * v)19 int32_t graphOutputIdx(Value* v) { 20 int32_t i = 0; 21 for (const Value* output : v->owningGraph()->outputs()) { 22 if (v == output) { 23 return i; 24 } 25 i++; 26 } 27 return -1; 28 } 29 setInputValue(Value * v)30 Operator& setInputValue(Value* v) { 31 if (v->mustNotBeNone()) { 32 if (v->type()->kind() == c10::TensorType::Kind) { 33 o.add_input(createLogicalTensor(v)); 34 } 35 } 36 return *this; 37 } 38 setInput(size_t offset)39 Operator& setInput(size_t offset) { 40 return setInputValue(n->input(offset)); 41 } 42 43 template <typename... Ts> setInput(size_t offset,Ts...other)44 Operator& setInput(size_t offset, Ts... other) { 45 setInput(offset); 46 return setInput(other...); 47 } 48 setOutputValue(Value * v)49 Operator& setOutputValue(Value* v) { 50 if (v->mustNotBeNone()) { 51 o.add_output(createLogicalTensor(v)); 52 } 53 return *this; 54 } 55 56 // setOutputValue & setOutput require a pointer to the LLGA graph, as output 57 // logical tensors that are graph outputs should be connected to an End LLGA 58 // op. A value of NULL can be provided for the graph pointer in order to 59 // maintain the legacy functionality of this function. setOutputValue(Value * v,std::unique_ptr<dnnl::graph::graph> & g)60 Operator& setOutputValue(Value* v, std::unique_ptr<dnnl::graph::graph>& g) { 61 if (v->mustNotBeNone()) { 62 auto output_tensor = createLogicalTensor(v); 63 o.add_output(output_tensor); 64 if (g) { 65 int32_t outputIndex = graphOutputIdx(v); 66 if (outputIndex != -1) { 67 dnnl::graph::op newEndNode( 68 LONG_MAX - outputIndex, 69 dnnl::graph::op::kind::End, 70 "EndNodeForGraphOutput"); 71 newEndNode.add_input(output_tensor); 72 g->add_op(newEndNode); 73 } 74 } 75 } 76 return *this; 77 } 78 setOutput(std::unique_ptr<dnnl::graph::graph> & g,size_t offset)79 Operator& setOutput(std::unique_ptr<dnnl::graph::graph>& g, size_t offset) { 80 return setOutputValue(n->output(offset), g); 81 } 82 setOutput(size_t offset)83 Operator& setOutput(size_t offset) { 84 return setOutputValue(n->output(offset)); 85 } 86 87 template <typename... Ts> setOutput(std::unique_ptr<dnnl::graph::graph> & g,size_t offset,Ts...other)88 Operator& setOutput( 89 std::unique_ptr<dnnl::graph::graph>& g, 90 size_t offset, 91 Ts... other) { 92 setOutput(g, offset); 93 return setOutput(g, other...); 94 } 95 96 template <typename Attr> setAttr(dnnl::graph::op::attr name,Attr && attr)97 Operator& setAttr(dnnl::graph::op::attr name, Attr&& attr) { 98 o.set_attr(name, std::forward<Attr>(attr)); 99 return *this; 100 } 101 102 template <typename F> setAttr(dnnl::graph::op::attr name,const F & fn,size_t offset)103 Operator& setAttr(dnnl::graph::op::attr name, const F& fn, size_t offset) { 104 return setAttr(name, fn(n, offset)); 105 } 106 ScalarToFloat(const Node * node,size_t offset)107 static float ScalarToFloat(const Node* node, size_t offset) { 108 return toIValue(node->input(offset))->toScalar().to<float>(); 109 } 110 Ints(const Node * node,size_t offset)111 static std::vector<int64_t> Ints(const Node* node, size_t offset) { 112 return toIValue(node->input(offset))->toIntVector(); 113 } 114 Int(const Node * node,size_t offset)115 static int64_t Int(const Node* node, size_t offset) { 116 return toIValue(node->input(offset))->toInt(); 117 } 118 Float(const Node * node,size_t offset)119 static float Float(const Node* node, size_t offset) { 120 return static_cast<float>(toIValue(node->input(offset))->toDouble()); 121 } 122 Bool(const Node * node,size_t offset)123 static bool Bool(const Node* node, size_t offset) { 124 return toIValue(node->input(offset))->toBool(); 125 } 126 getId(const Node * node)127 static uint64_t getId(const Node* node) { 128 return reinterpret_cast<uint64_t>(node); // cast node address as op id 129 } 130 kind()131 dnnl::graph::op::kind kind() const { 132 return k; 133 } 134 llgaOp()135 dnnl::graph::op llgaOp() const { 136 return o; 137 } 138 139 private: createLogicalTensor(Value * value)140 dnnl::graph::logical_tensor createLogicalTensor(Value* value) const { 141 return LlgaTensorDesc(value).logical_tensor(); 142 } 143 144 const Node* n; 145 dnnl::graph::op o; 146 dnnl::graph::op::kind k; 147 }; 148 149 } // namespace onednn 150 } // namespace fuser 151 } // namespace jit 152 } // namespace torch 153