#pragma once

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

class Operator {
 public:
  // Wraps a JIT node as an LLGA (oneDNN Graph) op of the given kind, using the
  // node's address as the op id and its qualified name as the op name.
  Operator(const Node* node, dnnl::graph::op::kind kind)
      : n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {}

  // Returns the output index if the Value is a graph output;
  // otherwise returns -1.
  int32_t graphOutputIdx(Value* v) {
    int32_t i = 0;
    for (const Value* output : v->owningGraph()->outputs()) {
      if (v == output) {
        return i;
      }
      i++;
    }
    return -1;
  }

  // Adds v as an input logical tensor of the op, provided it is a defined
  // (non-None) Tensor value.
  Operator& setInputValue(Value* v) {
    if (v->mustNotBeNone()) {
      if (v->type()->kind() == c10::TensorType::Kind) {
        o.add_input(createLogicalTensor(v));
      }
    }
    return *this;
  }

  // Adds the node inputs at the given offsets as op inputs.
  Operator& setInput(size_t offset) {
    return setInputValue(n->input(offset));
  }

  template <typename... Ts>
  Operator& setInput(size_t offset, Ts... other) {
    setInput(offset);
    return setInput(other...);
  }

  // Adds v as an output logical tensor of the op if it is a defined
  // (non-None) value.
  Operator& setOutputValue(Value* v) {
    if (v->mustNotBeNone()) {
      o.add_output(createLogicalTensor(v));
    }
    return *this;
  }

  // This overload of setOutputValue (and the setOutput that calls it) takes
  // the LLGA graph, because output logical tensors that are graph outputs
  // should be connected to an End LLGA op. An empty (null) graph pointer can
  // be passed in order to retain the legacy behavior of this function.
  Operator& setOutputValue(Value* v, std::unique_ptr<dnnl::graph::graph>& g) {
    if (v->mustNotBeNone()) {
      auto output_tensor = createLogicalTensor(v);
      o.add_output(output_tensor);
      if (g) {
        int32_t outputIndex = graphOutputIdx(v);
        if (outputIndex != -1) {
          dnnl::graph::op newEndNode(
              LONG_MAX - outputIndex,
              dnnl::graph::op::kind::End,
              "EndNodeForGraphOutput");
          newEndNode.add_input(output_tensor);
          g->add_op(newEndNode);
        }
      }
    }
    return *this;
  }

  // Adds the node outputs at the given offsets as op outputs; the overloads
  // taking the LLGA graph g also connect graph outputs to an End op.
  Operator& setOutput(std::unique_ptr<dnnl::graph::graph>& g, size_t offset) {
    return setOutputValue(n->output(offset), g);
  }

  Operator& setOutput(size_t offset) {
    return setOutputValue(n->output(offset));
  }

  template <typename... Ts>
  Operator& setOutput(
      std::unique_ptr<dnnl::graph::graph>& g,
      size_t offset,
      Ts... other) {
    setOutput(g, offset);
    return setOutput(g, other...);
  }
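
  // Example (a hypothetical usage sketch, not taken from this file): the
  // builder is typically chained to map a JIT node onto an LLGA op and then
  // register it with the graph; `node` and `g` below stand for a caller-owned
  // Node* and std::unique_ptr<dnnl::graph::graph>.
  //
  //   auto op = Operator(node, dnnl::graph::op::kind::Add)
  //                 .setInput(0, 1)
  //                 .setOutput(g, 0);
  //   g->add_op(op.llgaOp());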

  // Sets an attribute on the op, either from a value directly or from a
  // callable that extracts the value from the node input at the given offset.
  template <typename Attr>
  Operator& setAttr(dnnl::graph::op::attr name, Attr&& attr) {
    o.set_attr(name, std::forward<Attr>(attr));
    return *this;
  }

  template <typename F>
  Operator& setAttr(dnnl::graph::op::attr name, const F& fn, size_t offset) {
    return setAttr(name, fn(n, offset));
  }

  // Helpers for the functional setAttr overload: each reads the constant node
  // input at the given offset and converts it to the corresponding C++ type.
  static float ScalarToFloat(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toScalar().to<float>();
  }

  static std::vector<int64_t> Ints(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toIntVector();
  }

  static int64_t Int(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toInt();
  }

  static float Float(const Node* node, size_t offset) {
    return static_cast<float>(toIValue(node->input(offset))->toDouble());
  }

  static bool Bool(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toBool();
  }
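
  // Example (hypothetical, for illustration only): the functional setAttr
  // overload pairs with the helpers above to pull an attribute value straight
  // out of a constant node input; the op kind, attribute name, and offset
  // below are illustrative.
  //
  //   Operator(node, dnnl::graph::op::kind::SoftMax)
  //       .setAttr(dnnl::graph::op::attr::axis, Operator::Int, 1);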

  static uint64_t getId(const Node* node) {
    return reinterpret_cast<uint64_t>(node); // cast the node address to an op id
  }

  dnnl::graph::op::kind kind() const {
    return k;
  }

  dnnl::graph::op llgaOp() const {
    return o;
  }

 private:
  dnnl::graph::logical_tensor createLogicalTensor(Value* value) const {
    return LlgaTensorDesc(value).logical_tensor();
  }

  const Node* n; // the wrapped JIT node
  dnnl::graph::op o; // the corresponding LLGA op
  dnnl::graph::op::kind k; // the LLGA op kind
};

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch