1 #include <aten/src/ATen/core/jit_type.h>
2 #include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
3 #include <torch/csrc/jit/passes/dead_code_elimination.h>
4 #include <torch/csrc/jit/passes/shape_analysis.h>
5
6 namespace torch {
7 namespace jit {
8 namespace fuser {
9 namespace onednn {
10
compareConstValue(Value * v,double d)11 static bool compareConstValue(Value* v, double d) {
12 auto ival = toIValue(v);
13 return ival.has_value() &&
14 ((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
15 (ival->isDouble() && ival->toDouble() == d));
16 }
17
// Makes the two inputs of a binary op acceptable to oneDNN Graph, which
// requires both operands to be tensors of the same dtype:
//  - scalar second input  -> materialized as a 1-D tensor of the first
//    input's dtype (aten::as_tensor + aten::unsqueeze);
//  - tensor second input  -> an aten::to cast is inserted on whichever
//    side needs promotion, per PyTorch's type-promotion rules.
// The node's output dtype in the IR is updated to the promoted dtype.
static void handleBinaryOpInputs(Node* node) {
  // We do not handle binary ops with two scalar inputs,
  // and we assume scalar is always at the second place.
  if (node->input(0)->type()->isSubtypeOf(TensorType::get())) {
    // NOTE(review): scalarType() is unwrapped without a has_value() check —
    // this assumes the first input's dtype is always present in the IR at
    // this point; confirm, otherwise this throws bad_optional_access.
    auto dtypeOfFirstInput =
        node->input(0)->type()->cast<TensorType>()->scalarType().value();
    if (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
        node->input(1)->type()->isSubtypeOf(IntType::get())) {
      // If a scalar is added to be a tensor, we would assume that the
      // scalar is of the same dtype as the tensor, as oneDNN graph
      // currently requires inputs of binary ops to have the same dtype.
      // We create a 1D tensor from the scalar input & "promote" its
      // dtype to that of the first input. Doing so helps us satisfy PyTorch's
      // type promotion rules.
      // Although we convert the scalar to a tensor, we still need to promote
      // types, as if the second input were still a scalar.
      // The following sample code-snippet illustrates that converting a scalar
      // input to a 1-D tensor may result in a different output dtype than would
      // otherwise have been the case.
      // clang-format off
      //  >>> (1. + torch.rand([2]).half()).dtype
      //  torch.float16
      //  >>> (torch.tensor(1.).unsqueeze(0) + (torch.rand([2]).half())).dtype
      //  torch.float32
      // clang-format on
      auto promotedDtype = dtypeOfFirstInput;
      auto scalar = node->input(1);
      // New nodes are inserted right before `node` so they dominate its use.
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      // 42 : Scalar --> tensor(42.0) : Float([])
      auto t = g->insert(aten::as_tensor, {scalar}, {{"dtype", promotedDtype}});
      // add dim & stride info to IR
      std::optional<size_t> t_dim = 1;
      auto target_type = TensorTypePtr(
          TensorType::create(promotedDtype, at::kCPU, t_dim, false));
      target_type = target_type->withSizes({1});
      t->setType(target_type);

      // tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1])
      auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
      unsqueezed->setType(target_type);
      // Swap the scalar operand for the freshly built 1-D tensor.
      node->replaceInput(1, unsqueezed);

      // dtype might have changed, so needs to be updated in IR as well
      node->output()->setType(
          node->output()->type()->expect<TensorType>()->withScalarType(
              promotedDtype));
    } else if (node->input(1)->type()->isSubtypeOf(TensorType::get())) {
      // Here, both inputs are tensors, and we just wanna make sure that they
      // are the same dtype, as oneDNN Graph requires both inputs to have the
      // same dtype. We'll follow PyTorch's type-promotion rules here.
      auto second_input_typeptr = node->input(1)->type()->expect<TensorType>();
      std::optional<at::ScalarType> second_input_type =
          second_input_typeptr->scalarType();
      if (second_input_type != std::nullopt) {
        // dtype of the second tensor might not be available in the IR
        auto dtypeOfSecondInput = second_input_type.value();
        if (dtypeOfFirstInput != dtypeOfSecondInput) {
          // Type promotion is required
          auto promotedDtype =
              c10::promoteTypes(dtypeOfFirstInput, dtypeOfSecondInput);
          WithInsertPoint guard(node);
          auto g = node->owningGraph();
          // Cast whichever input is NOT already at the promoted dtype.
          if (promotedDtype == dtypeOfFirstInput) {
            auto to_node_output = g->insert(
                aten::to, {node->input(1)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(1)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(1, to_node_output);
          } else {
            auto to_node_output = g->insert(
                aten::to, {node->input(0)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(0)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(0, to_node_output);
          }
          // dtype might have changed, so needs to be updated in IR as well
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  promotedDtype));
        } else {
          // both dtypes are same
          // IR info of dtypes is missing sometimes in JIT IR,
          // and we shouldn't treat those tensors as FP32 tensors by default.
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  dtypeOfFirstInput));
        }
      } // end inner if block
    } // end outer if block
  }
}
112
ConvertScalarToTensor(Block * block)113 static void ConvertScalarToTensor(Block* block) {
114 for (auto node : block->nodes()) {
115 for (auto sub : node->blocks()) {
116 ConvertScalarToTensor(sub);
117 }
118
119 if (node->kind() == aten::add || node->kind() == aten::mul ||
120 node->kind() == aten::div) {
121 handleBinaryOpInputs(node);
122 }
123 }
124 }
125
mayDecomposeAdd(Node * node)126 static void mayDecomposeAdd(Node* node) {
127 if (node->inputs().size() < 3) {
128 return; // corner-case in BERT-mrpc that's not in line with
129 // native_functions.yaml
130 }
131 if (toIValue(node->namedInput("alpha")).has_value()) {
132 auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
133 if (!alphaEqualsOne) {
134 WithInsertPoint guard(node);
135 auto g = node->owningGraph();
136 auto mul = g->insert(
137 aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
138 if (node->namedInput("other")->type()->isSubtypeOf(TensorType::get())) {
139 auto mulTensorTypePtr = node->namedInput("other")->type();
140 mul->setType(mulTensorTypePtr);
141 }
142 node->replaceInput(1, mul);
143 auto one = g->insertConstant(1.0);
144 node->replaceInput(2, one);
145 }
146 }
147 }
148
DecomposeFusedAdd(Block * block)149 static void DecomposeFusedAdd(Block* block) {
150 for (auto node : block->nodes()) {
151 for (auto sub : node->blocks()) {
152 DecomposeFusedAdd(sub);
153 }
154
155 if (node->kind() == aten::add) {
156 mayDecomposeAdd(node);
157 }
158 }
159 }
160
EliminateIdentityMulAdd(Block * block)161 static void EliminateIdentityMulAdd(Block* block) {
162 for (auto node : block->nodes()) {
163 for (auto sub : node->blocks()) {
164 EliminateIdentityMulAdd(sub);
165 }
166
167 if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
168 (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
169 node->output()->replaceAllUsesWith(node->namedInput("self"));
170 }
171 }
172 }
173
PrepareBinaryForLLGA(const std::shared_ptr<Graph> & graph)174 void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
175 DecomposeFusedAdd(graph->block());
176 EliminateIdentityMulAdd(graph->block());
177 EliminateDeadCode(graph);
178 // ConvertScalarToTensor must be placed after EliminateIdentityMulAdd
179 ConvertScalarToTensor(graph->block());
180 }
181
182 } // namespace onednn
183 } // namespace fuser
184 } // namespace jit
185 } // namespace torch
186