1 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
2
3 #include <torch/csrc/jit/passes/constant_propagation.h>
4 #include <torch/csrc/jit/passes/normalize_ops.h>
5 #include <torch/csrc/jit/runtime/operator.h>
6
7 #include <c10/util/irange.h>
8
9 namespace torch {
10 namespace jit {
11 namespace {
12
// Produce a value-semantic duplicate of `self`: no aliasing survives between
// the copy and the original for any reference type we handle.
IValue deepCopy(const IValue& self) {
  // Non-pointer (primitive) values already have value semantics.
  if (!self.isPtrType()) {
    return self;
  }

  // Tensors must be cloned; copy-assignment of a tensor creates an alias.
  if (self.isTensor()) {
    return IValue(self.toTensor().clone(at::MemoryFormat::Preserve));
  }

  // Generic lists recursively deep-copy every element.
  if (self.isList()) {
    const auto src = self.toList();
    auto copied = c10::impl::GenericList(src.elementType());
    copied.reserve(src.size());
    for (const IValue& elem : src) {
      copied.push_back(deepCopy(elem));
    }
    return copied;
  }

  // Specialized lists and strings support a direct element-wise copy.
  if (self.isIntList()) {
    return IValue(self.toIntList().copy());
  } else if (self.isDoubleList()) {
    return IValue(self.toDoubleList().copy());
  } else if (self.isComplexDoubleList()) {
    return IValue(self.toComplexDoubleList().copy());
  } else if (self.isBoolList()) {
    return IValue(self.toBoolList().copy());
  } else if (self.isString()) {
    return IValue(self.toStringRef());
  }

  // Any other reference type used by aten ops must get a case above.
  AT_ASSERT(false);
}
52
// Deep-copy every value on the stack (see the IValue overload above for the
// per-value semantics).
Stack deepCopy(const Stack& stack) {
  Stack copied;
  copied.reserve(stack.size());
  for (const auto& item : stack) {
    copied.push_back(deepCopy(item));
  }
  return copied;
}
61
deepEquals(const IValue & lhs,const IValue & rhs)62 bool deepEquals(const IValue& lhs, const IValue& rhs) {
63 if (lhs.isTensor() && rhs.isTensor()) {
64 return lhs.toTensor().equal(rhs.toTensor());
65 }
66
67 if (lhs.isTensorList() && rhs.isTensorList()) {
68 const auto a = lhs.toTensorList();
69 const auto b = rhs.toTensorList();
70 if (a.size() != b.size()) {
71 return false;
72 }
73 for (auto i = decltype(a.size()){0}; i < a.size(); ++i) {
74 if (!a[i].equal(b[i])) {
75 return false;
76 }
77 }
78 return true;
79 }
80
81 return lhs == rhs;
82 }
83
84 struct AliasAndIValue {
AliasAndIValuetorch::jit::__anon47b13e6e0111::AliasAndIValue85 AliasAndIValue(const at::AliasInfo* aliasInfo, IValue iValue)
86 : aliasInfo(aliasInfo), iValue(std::move(iValue)) {}
87
88 const at::AliasInfo* aliasInfo;
89 const IValue iValue;
90 };
91
92 // No inputs should alias each other
checkInputPreconditions(const Stack & inputs)93 void checkInputPreconditions(const Stack& inputs) {
94 for (const auto i : c10::irange(inputs.size())) {
95 for (const auto j : c10::irange(inputs.size())) {
96 if (i == j) {
97 continue;
98 }
99 const auto& lhs = inputs.at(i);
100 const auto& rhs = inputs.at(j);
101 AT_ASSERT(!lhs.isAliasOf(rhs));
102 }
103 }
104 }
105
106 // If two ivalues alias, they must share an alias set
checkAliases(const std::vector<AliasAndIValue> & inputs,const std::vector<AliasAndIValue> & outputs)107 void checkAliases(
108 const std::vector<AliasAndIValue>& inputs,
109 const std::vector<AliasAndIValue>& outputs) {
110 for (const auto& output : outputs) {
111 // if this output aliases any input, make sure that they share an alias set
112 for (const auto& input : inputs) {
113 if (output.iValue.isAliasOf(input.iValue)) {
114 const auto* inputSet = input.aliasInfo;
115 const auto* outputSet = output.aliasInfo;
116 AT_ASSERT(inputSet && outputSet);
117 bool found = false;
118 for (const auto& set : inputSet->beforeSets()) {
119 if (outputSet->beforeSets().count(set)) {
120 found = true;
121 break;
122 }
123 }
124 AT_ASSERT(found);
125 }
126 }
127 }
128 }
129
130 // If we didn't specify that we write to an input value, it must have not
131 // changed
checkWrites(const std::vector<AliasAndIValue> & inputs,const std::vector<IValue> & deepCopiedInputs)132 void checkWrites(
133 const std::vector<AliasAndIValue>& inputs,
134 const std::vector<IValue>& deepCopiedInputs) {
135 AT_ASSERT(inputs.size() == deepCopiedInputs.size());
136 for (const auto i : c10::irange(inputs.size())) {
137 const auto& input = inputs[i];
138 const auto& deepCopiedInput = deepCopiedInputs[i];
139 if (!input.aliasInfo || !input.aliasInfo->isWrite()) {
140 AT_ASSERT(deepEquals(input.iValue, deepCopiedInput));
141 }
142 }
143 }
144
findNodeForOp(const Graph & g,const std::string & unqualifiedOpName)145 const Node* findNodeForOp(
146 const Graph& g,
147 const std::string& unqualifiedOpName) {
148 const auto opName = Symbol::fromQualString("aten::" + unqualifiedOpName);
149 for (const auto* node : g.nodes()) {
150 if (node->kind() == opName) {
151 return node;
152 }
153 }
154
155 // Check for alias-ed operator names
156 const auto aliasOp = torch::jit::getOperatorAliasMap().find(opName);
157 if (aliasOp != torch::jit::getOperatorAliasMap().end()) {
158 for (const auto* node : g.nodes()) {
159 if (node->kind() == aliasOp->second) {
160 return node;
161 }
162 }
163 }
164
165 // Ideally, there will be only one ATen operator that has tensor outputs in
166 // the graph. Let's use that as the last resolve to make checkAliasAnnotation
167 // more robust.
168 for (const auto* node : g.nodes()) {
169 if (!node->maybeOperator()) {
170 continue;
171 }
172 if (!node->getOperator().isC10Op()) {
173 continue;
174 }
175
176 for (const auto* output : node->outputs()) {
177 if (output->type()->kind() == TypeKind::TensorType) {
178 return node;
179 }
180 }
181 }
182
183 AT_ASSERT(false);
184 }
185
186 // Handle a few special cases where we need to propagate constants
187 // manually
188 // TODO(suo): we should be able to move this stuff to constant prop
toIValueProp(const Value * v)189 std::optional<IValue> toIValueProp(const Value* v) {
190 if (v->node()->kind() == prim::ListConstruct) {
191 std::vector<IValue> genericList;
192 for (auto input : v->node()->inputs()) {
193 if (auto elem = toIValue(input)) {
194 genericList.push_back(*elem);
195 } else {
196 // One of the list elements isn't constant.
197 return std::nullopt;
198 }
199 }
200
201 // Specialize the list based on ListConstruct's return type
202 auto listType = v->node()->output()->type();
203 auto containedType = listType->containedTypes().at(0);
204 if (containedType == IntType::get()) {
205 return IValue(
206 fmap(genericList, [](const IValue& v) { return v.toInt(); }));
207 } else if (containedType == FloatType::get()) {
208 return IValue(
209 fmap(genericList, [](const IValue& v) { return v.toDouble(); }));
210 } else if (containedType->isSubtypeOf(*TensorType::get())) {
211 return IValue(
212 fmap(genericList, [](const IValue& v) { return v.toTensor(); }));
213 } else {
214 return std::nullopt;
215 }
216 }
217
218 if (v->node()->kind() == aten::Float) {
219 if (auto maybe_stack = runNodeIfInputsAreConstant(v->node())) {
220 return maybe_stack->at(0);
221 }
222 }
223 return std::nullopt;
224 }
225
226 // batch_norm and instance_norm have incorrect annotations, because
227 // (a!)? annotations aren't supported, so these checks would fail.
228 // Their behavior also varies depending on the `training` and
229 // `use_input_stats` arguments.
230 // There are custom implementations in alias_analysis.cpp for these ops.
shouldIgnoreNode(const Node * n)231 bool shouldIgnoreNode(const Node* n) {
232 switch (n->kind()) {
233 case aten::batch_norm:
234 case aten::instance_norm:
235 return true;
236 default:
237 return false;
238 }
239 }
240 } // namespace
241
checkAliasAnnotation(const std::shared_ptr<Graph> & graph,std::vector<IValue> pythonInputs,const std::string & unqualifiedOpName)242 void checkAliasAnnotation(
243 const std::shared_ptr<Graph>& graph,
244 std::vector<IValue> pythonInputs,
245 const std::string& unqualifiedOpName) {
246 // Find the node that corresponds to our op name
247 const auto node = findNodeForOp(*graph, unqualifiedOpName);
248 if (shouldIgnoreNode(node)) {
249 return;
250 }
251
252 // Build the stack to use as input to the op
253 Stack stack;
254 for (const auto input : node->inputs()) {
255 if (input->node() == graph->param_node()) {
256 // This value was passed as an input in python
257 push(stack, pythonInputs.at(input->offset()));
258 } else {
259 // This a generated constant, which we need to evaluate
260 auto inputValue = toIValue(input);
261 if (!inputValue) {
262 inputValue = toIValueProp(input);
263 }
264
265 if (inputValue) {
266 push(stack, *inputValue);
267 } else {
268 AT_ASSERT(input->type()->kind() == TypeKind::OptionalType);
269 push(stack, IValue());
270 }
271 }
272 }
273
274 // Precondition: no inputs should alias each other. So if we find an alias,
275 // it was created by the op.
276 checkInputPreconditions(stack);
277
278 const auto& schema = node->schema();
279
280 std::vector<AliasAndIValue> inputsToCheck;
281 for (const auto i : c10::irange(schema.arguments().size())) {
282 inputsToCheck.emplace_back(
283 schema.arguments().at(i).alias_info(), stack.at(i));
284 }
285
286 // Save a copy of the inputs so we can check whether the original inputs were
287 // written to.
288 const auto inputsDeepCopy = deepCopy(stack);
289
290 // Run the op
291 node->getOperation()(stack);
292
293 const auto outputs = std::move(stack);
294
295 std::vector<AliasAndIValue> outputsToCheck;
296 for (const auto i : c10::irange(schema.returns().size())) {
297 outputsToCheck.emplace_back(
298 schema.returns().at(i).alias_info(), outputs.at(i));
299 }
300
301 // Check that if any alias was created, we annotated it properly.
302 checkAliases(inputsToCheck, outputsToCheck);
303
304 // Check that if nothing was accidentally written to.
305 checkWrites(inputsToCheck, inputsDeepCopy);
306 }
307
308 } // namespace jit
309 } // namespace torch
310