1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/runtime/argument_spec.h>
3
4 #include <iostream>
5
6 namespace torch::jit {
7
scan(const TypePtr & typ,size_t depth,const WrittenSlots & written_slots)8 void ArgumentSpecCreator::scan(
9 const TypePtr& typ,
10 size_t depth,
11 const WrittenSlots& written_slots) {
12 auto finishAggregate = [&](size_t pos) {
13 // it is possible after all the work we did to scan this aggregate,
14 // we found no tensors or optionals to specialize. In this case, just
15 // generate a skip for the whole aggregate.
16 bool any_spec = std::any_of(
17 instructions_.begin() + pos, instructions_.end(), [](Inst i) {
18 return i == SPECIALIZE_TENSOR || i == SPECIALIZE_OPTIONAL ||
19 i == SPECIALIZE_OPTIONAL_TENSOR;
20 });
21 if (!any_spec) {
22 instructions_[pos] = SKIP;
23 instructions_.resize(pos + 1);
24 } else {
25 instructions_.emplace_back(LEAVE);
26 }
27 };
28 // the simple vm that scans instructions_ has a limited stack depth,
29 // this prevents going deeper than that.
30 if (depth >= ARG_SPEC_DEPTH_LIMIT) {
31 instructions_.emplace_back(SKIP);
32 }
33 if (typ->isSubtypeOf(*TensorType::get())) {
34 num_tensors_++;
35 instructions_.emplace_back(SPECIALIZE_TENSOR);
36 } else if (typ->isSubtypeOf(*OptionalType::ofTensor())) {
37 num_tensors_++;
38 num_optionals_++;
39 instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR);
40 } else if (typ->kind() == TypeKind::OptionalType) {
41 // note that Optional[Tuple] or Optional[Class] will just register
42 // as optional (previously they didn't at all, so it's not a regression).
43 num_optionals_++;
44 instructions_.emplace_back(SPECIALIZE_OPTIONAL);
45 } else if (auto tup = typ->cast<TupleType>()) {
46 size_t pos = instructions_.size();
47 instructions_.emplace_back(ENTER_TUPLE);
48 for (const auto& elem : tup->containedTypes()) {
49 scan(elem, depth + 1, written_slots);
50 }
51 finishAggregate(pos);
52 } else if (auto cls = typ->cast<ClassType>()) {
53 size_t pos = instructions_.size();
54 instructions_.emplace_back(ENTER_OBJECT);
55 for (size_t i = 0; i < cls->numAttributes(); ++i) {
56 auto key =
57 cls->name()->qualifiedName() + cls->getAttributes().at(i).getName();
58 // it is only safe to specialize because someone might have written to it
59 if (!written_slots.count(key)) {
60 scan(cls->containedTypes().at(i), depth + 1, written_slots);
61 } else {
62 instructions_.emplace_back(SKIP);
63 }
64 }
65 finishAggregate(pos);
66 } else {
67 instructions_.emplace_back(SKIP);
68 }
69 };
70
// this is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be
// read-only modules, but will be overly conservative when some classes are
// written to. Doing alias analysis and looking for writes to the class
// would be more accurate.
scanWrittenSlots(Block * block,ArgumentSpecCreator::WrittenSlots & written_slots)76 static void scanWrittenSlots(
77 Block* block,
78 ArgumentSpecCreator::WrittenSlots& written_slots) {
79 for (Node* n : block->nodes()) {
80 if (n->kind() == prim::SetAttr) {
81 if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
82 written_slots.insert(cls->name()->qualifiedName() + n->s(attr::name));
83 }
84 }
85 for (Block* subblock : n->blocks()) {
86 scanWrittenSlots(subblock, written_slots);
87 }
88 if (n->hasAttribute(attr::Subgraph)) {
89 scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
90 }
91 }
92 }
93
ArgumentSpecCreator(Graph & graph)94 ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
95 : num_inputs_(graph.inputs().size()) {
96 WrittenSlots written_slots;
97 scanWrittenSlots(graph.block(), written_slots);
98 for (Value* input : graph.inputs()) {
99 scan(input->type(), 0, written_slots);
100 }
101 }
102
dump() const103 void ArgumentSpecCreator::dump() const {
104 for (Inst inst : instructions_) {
105 switch (inst) {
106 case LEAVE:
107 std::cout << "] ";
108 break;
109 case ENTER_TUPLE:
110 std::cout << "Tuple[";
111 break;
112 case ENTER_OBJECT:
113 std::cout << "Object[";
114 break;
115 case SKIP:
116 std::cout << "Skip ";
117 break;
118 case SPECIALIZE_TENSOR:
119 std::cout << "SpecializeTensor ";
120 break;
121 case SPECIALIZE_OPTIONAL_TENSOR:
122 std::cout << "SpecializeOptionalTensor ";
123 break;
124 case SPECIALIZE_OPTIONAL:
125 std::cout << "SpecializeOptional ";
126 break;
127 }
128 }
129 std::cout << "\n";
130 }
131
// Replays the instruction program over the actual argument IValues (the last
// num_inputs_ entries of `input`) to produce an ArgumentSpec recording the
// concrete tensor properties and optional-presence of this call.
// `with_grad` is forwarded to ArgumentSpec::addTensor for each tensor seen.
// Invariant: instructions_ was built by scan() from the same input types, so
// each opcode consumes exactly the IValues it expects.
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
    const {
  ArgumentSpec spec(num_tensors_, num_optionals_);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  const IValue* stack[ARG_SPEC_DEPTH_LIMIT]; // The stack of IValue lists
  // The stack gets initialized with the input list
  stack[0] = last(input, num_inputs_).begin();
  size_t stack_top = 0; // offset to the top of the stack
  // Each opcode advances the cursor at the current stack level
  // (stack[stack_top]++), so every instruction consumes one element.
  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        // consume a tensor optional and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        auto& arg = *stack[stack_top]++;
        spec.addOptional(arg);
        if (!arg.isNone()) {
          spec.addTensor(arg, with_grad);
        }
      } break;
      case SPECIALIZE_TENSOR:
        // consume a tensor and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        spec.addTensor(*stack[stack_top]++, with_grad);
        break;
      case SPECIALIZE_OPTIONAL:
        // consume a non-tensor optional and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        spec.addOptional(*stack[stack_top]++);
        break;
      case ENTER_TUPLE: {
        // consume tuple
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isTuple(), "Expected Tuple but got ", iv->tagKind());
        // read the Tuple pointer straight out of the IValue's payload without
        // touching the refcount — relies on IValue's internal layout
        auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
        auto tup_ptr = &p->elements()[0];
        // push list of tuple elements to the stack
        stack[++stack_top] = tup_ptr;
      } break;
      case ENTER_OBJECT: {
        // consume object
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isObject(), "Expected Object but got ", iv->tagKind());
        auto obj_ptr = &iv->toObjectRef().slots()[0];
        // push list of object elements to the stack
        stack[++stack_top] = obj_ptr;
      } break;
      case SKIP:
        // consume and skip an element
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        stack[stack_top]++;
        break;
      case LEAVE:
        // done with this aggregate; resume at the parent list's cursor
        --stack_top;
        break;
    }
  }
  return spec;
}
192
// For every input of a given graph, returns the most detailed type that can
// be inferred for it based on this ArgumentSpec.
// Replays the instruction program over the graph's input *types*, mirroring
// the value-level walk in create(): for each specialized slot it substitutes
// the concrete type recorded in `spec`, rebuilding aggregate (tuple/class)
// types bottom-up, then writes the refined types onto the graph inputs.
// Optional inputs that spec says are None are replaced by a None constant.
void ArgumentSpecCreator::specializeTypes(
    Graph& graph,
    const ArgumentSpec& spec) const {
  auto input_types =
      fmap(graph.inputs(), [](Value* input) { return input->type(); });
  // result_stack collects refined types per aggregate level; input_stack
  // holds a read cursor into the types at each level; aggregate_creators
  // holds one deferred builder per open aggregate, invoked at LEAVE.
  std::vector<std::vector<TypePtr>> result_stack;
  result_stack.emplace_back();
  std::vector<const TypePtr*> input_stack = {input_types.data()};
  std::vector<std::function<TypePtr()>> aggregate_creators;

  // Offsets advance in the same order scan() emitted the specialize
  // opcodes, so they index spec's tensors/optionals correctly.
  size_t tensor_arg_spec_offset =
      0; // number of specialized tensors seen so far
  size_t optional_arg_spec_offset =
      0; // number of specialized optionals seen so far

  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        auto& input_type = *input_stack.back()++;
        auto is_present = spec.isPresent(optional_arg_spec_offset++);
        if (!is_present) {
          // absent: keep the optional type; no tensor was recorded for it
          result_stack.back().emplace_back(input_type);
          break;
        }
        auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
        AT_ASSERT(arg.defined());
        result_stack.back().emplace_back(arg.toType());
      } break;
      case SPECIALIZE_TENSOR: {
        input_stack.back()++;
        auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
        if (!arg.defined()) {
          result_stack.back().emplace_back(TensorType::get()->withUndefined());
        } else {
          result_stack.back().emplace_back(arg.toType());
        }
      } break;
      case SPECIALIZE_OPTIONAL: {
        auto is_present = spec.isPresent(optional_arg_spec_offset++);
        auto ot = (*input_stack.back()++)->expect<OptionalType>();
        if (!is_present) {
          result_stack.back().emplace_back(ot);
        } else {
          // present: unwrap Optional[T] to T
          result_stack.back().emplace_back(ot->getElementType());
        }
      } break;
      case ENTER_TUPLE: {
        // open a new level; the builder runs later (at LEAVE) and reads
        // result_stack by reference, which must still be alive then
        auto tup = (*input_stack.back()++)->expect<TupleType>();
        input_stack.emplace_back(tup->elements().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&] { return TupleType::create(result_stack.back()); });
      } break;
      case ENTER_OBJECT: {
        auto cls = (*input_stack.back()++)->expect<ClassType>();
        input_stack.emplace_back(cls->containedTypes().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&result_stack, cls] { return cls->refine(result_stack.back()); });
      } break;
      case SKIP:
        // unspecialized slot: pass the input type through unchanged
        result_stack.back().emplace_back(*input_stack.back()++);
        break;
      case LEAVE:
        // build the aggregate from this level's results, pop the level,
        // and append the built type to the parent level
        TypePtr result = aggregate_creators.back()();
        result_stack.pop_back();
        aggregate_creators.pop_back();
        input_stack.pop_back();
        result_stack.back().emplace_back(std::move(result));
        break;
    }
  }
  AT_ASSERT(result_stack.size() == 1);
  // FIXME: by doing this only on the inputs, we only capture graph inputs and
  // not
  //        optionals in tuples or objects. For that to work, we would have
  //        to investigate the uses of the inputs in detail to change the
  //        accesses/ unwrapping
  auto inputs = graph.inputs();
  for (const auto i : c10::irange(inputs.size())) {
    auto t = result_stack.back()[i];
    if (auto ot = t->cast<OptionalType>()) {
      // if an optional input hasn't been specialized above, it is None
      // so we disconnect the input here and replace its uses with
      // a constant
      WithInsertPoint guard(*graph.nodes().begin());
      auto c = graph.insertConstant({});
      inputs[i]->replaceAllUsesWith(c);
    } else {
      inputs[i]->setType(t);
    }
  }
}
288
289 } // namespace torch::jit
290