#include <aten/src/ATen/core/jit_type.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

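// Returns true if v is a constant int or double that compares equal to d.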
static bool compareConstValue(Value* v, double d) {
  auto ival = toIValue(v);
  return ival.has_value() &&
      ((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
       (ival->isDouble() && ival->toDouble() == d));
}

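// Makes the inputs of a binary op conform to oneDNN graph's requirement that
// both inputs have the same dtype. An illustrative sketch of the scalar case
// (the IR value names and types below are hypothetical):
//   %r : Half(2) = aten::add(%x : Half(2), %s : float)
// becomes
//   %t : Half([]) = aten::as_tensor(%s, dtype=Half)
//   %u : Half([1]) = aten::unsqueeze(%t, 0)
//   %r : Half(2) = aten::add(%x, %u)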
static void handleBinaryOpInputs(Node* node) {
  // We do not handle binary ops with two scalar inputs,
  // and we assume the scalar is always the second input.
  if (node->input(0)->type()->isSubtypeOf(TensorType::get())) {
    auto dtypeOfFirstInput =
        node->input(0)->type()->cast<TensorType>()->scalarType().value();
    if (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
        node->input(1)->type()->isSubtypeOf(IntType::get())) {
      // If a scalar is added to a tensor, we assume that the
      // scalar has the same dtype as the tensor, as oneDNN graph
      // currently requires inputs of binary ops to have the same dtype.
      // We create a 1-D tensor from the scalar input & "promote" its
      // dtype to that of the first input. Doing so helps us satisfy
      // PyTorch's type-promotion rules.
      // Although we convert the scalar to a tensor, we still need to promote
      // types as if the second input were still a scalar.
      // The following sample code snippet illustrates that converting a
      // scalar input to a 1-D tensor may result in a different output dtype
      // than would otherwise have been the case.
      // clang-format off
      //   >>> (1. + torch.rand([2]).half()).dtype
      //       torch.float16
      //   >>> (torch.tensor(1.).unsqueeze(0) + (torch.rand([2]).half())).dtype
      //       torch.float32
      // clang-format on
      auto promotedDtype = dtypeOfFirstInput;
      auto scalar = node->input(1);
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      // 42 : Scalar  -->  tensor(42.0) : Float([])
      auto t = g->insert(aten::as_tensor, {scalar}, {{"dtype", promotedDtype}});
      // add dim & stride info to IR
      std::optional<size_t> t_dim = 1;
      auto target_type = TensorTypePtr(
          TensorType::create(promotedDtype, at::kCPU, t_dim, false));
      target_type = target_type->withSizes({1});
      t->setType(target_type);

      // tensor(42.0) : Float([])  -->  tensor([42.0]) : Float([1])
      auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
      unsqueezed->setType(target_type);
      node->replaceInput(1, unsqueezed);

      // The dtype might have changed, so it needs to be updated in the IR
      // as well.
      node->output()->setType(
          node->output()->type()->expect<TensorType>()->withScalarType(
              promotedDtype));
    } else if (node->input(1)->type()->isSubtypeOf(TensorType::get())) {
      // Here, both inputs are tensors, and we just need to make sure that
      // they have the same dtype, as oneDNN graph requires both inputs to
      // have the same dtype. We follow PyTorch's type-promotion rules here.
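      // For example (a hypothetical case): promoting Half and Float yields
      // Float, so an aten::to cast would be inserted on the Half input.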
      auto second_input_typeptr = node->input(1)->type()->expect<TensorType>();
      std::optional<at::ScalarType> second_input_type =
          second_input_typeptr->scalarType();
      if (second_input_type != std::nullopt) {
        // The dtype of the second tensor might not be available in the IR,
        // so we only handle the case in which it is known.
        auto dtypeOfSecondInput = second_input_type.value();
        if (dtypeOfFirstInput != dtypeOfSecondInput) {
          // Type promotion is required
          auto promotedDtype =
              c10::promoteTypes(dtypeOfFirstInput, dtypeOfSecondInput);
          WithInsertPoint guard(node);
          auto g = node->owningGraph();
          if (promotedDtype == dtypeOfFirstInput) {
            auto to_node_output = g->insert(
                aten::to, {node->input(1)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(1)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(1, to_node_output);
          } else {
            auto to_node_output = g->insert(
                aten::to, {node->input(0)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(0)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(0, to_node_output);
          }
          // The dtype might have changed, so it needs to be updated in the
          // IR as well.
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  promotedDtype));
        } else {
          // Both dtypes are the same.
          // dtype info is sometimes missing in the JIT IR, and we shouldn't
          // treat such tensors as FP32 tensors by default, so set the output
          // dtype explicitly.
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  dtypeOfFirstInput));
        }
      } // end inner if block
    } // end outer if block
  }
}

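// Recursively walks all blocks and applies handleBinaryOpInputs (above) to
// every aten::add, aten::mul & aten::div node.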
static void ConvertScalarToTensor(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      ConvertScalarToTensor(sub);
    }

    if (node->kind() == aten::add || node->kind() == aten::mul ||
        node->kind() == aten::div) {
      handleBinaryOpInputs(node);
    }
  }
}

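// Decomposes a fused add with alpha != 1 into an explicit mul followed by a
// plain add with alpha == 1, presumably so that it maps onto oneDNN graph's
// binary ops, which have no alpha operand. An illustrative sketch (the IR
// value names are hypothetical):
//   %r = aten::add(%self, %other, %alpha)  // alpha != 1
// becomes
//   %m = aten::mul(%other, %alpha)
//   %r = aten::add(%self, %m, 1)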
static void mayDecomposeAdd(Node* node) {
  if (node->inputs().size() < 3) {
    return; // a corner case from BERT-MRPC whose add node does not match the
            // schema in native_functions.yaml
  }
  if (toIValue(node->namedInput("alpha")).has_value()) {
    auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
    if (!alphaEqualsOne) {
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      auto mul = g->insert(
          aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
      if (node->namedInput("other")->type()->isSubtypeOf(TensorType::get())) {
        auto mulTensorTypePtr = node->namedInput("other")->type();
        mul->setType(mulTensorTypePtr);
      }
      node->replaceInput(1, mul);
      auto one = g->insertConstant(1.0);
      node->replaceInput(2, one);
    }
  }
}

static void DecomposeFusedAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      DecomposeFusedAdd(sub);
    }

    if (node->kind() == aten::add) {
      mayDecomposeAdd(node);
    }
  }
}

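// Replaces the outputs of the identity expressions x + 0 and x * 1 with x
// (the "self" input) itself; the now-dead nodes are removed by the
// EliminateDeadCode pass invoked below.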
static void EliminateIdentityMulAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      EliminateIdentityMulAdd(sub);
    }

    if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
      node->output()->replaceAllUsesWith(node->namedInput("self"));
    }
  }
}

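// Entry point of this pass: prepares binary ops for oneDNN graph (LLGA)
// fusion by decomposing fused adds, eliminating identity mul/add nodes, and
// converting scalar operands to 1-D tensors.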
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
  DecomposeFusedAdd(graph->block());
  EliminateIdentityMulAdd(graph->block());
  EliminateDeadCode(graph);
  // ConvertScalarToTensor must run after EliminateIdentityMulAdd, because
  // identity mul/add nodes are recognized via their scalar constant inputs,
  // which ConvertScalarToTensor would otherwise turn into tensors.
  ConvertScalarToTensor(graph->block());
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch