// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/argument_spec.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/argument_spec.h>

#include <iostream>

namespace torch::jit {

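// Compiles the type of one graph input into a flat "program" of Inst values
// appended to instructions_: ENTER_TUPLE/ENTER_OBJECT and LEAVE bracket
// aggregate types, the SPECIALIZE_* instructions mark values whose runtime
// properties should be recorded in the ArgumentSpec, and SKIP consumes a
// value without recording anything. create() below replays this program over
// the actual inputs.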
void ArgumentSpecCreator::scan(
    const TypePtr& typ,
    size_t depth,
    const WrittenSlots& written_slots) {
  auto finishAggregate = [&](size_t pos) {
    // it is possible that, after all the work we did to scan this aggregate,
    // we found no tensors or optionals to specialize. In this case, just
    // generate a skip for the whole aggregate.
    bool any_spec = std::any_of(
        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
          return i == SPECIALIZE_TENSOR || i == SPECIALIZE_OPTIONAL ||
              i == SPECIALIZE_OPTIONAL_TENSOR;
        });
    if (!any_spec) {
      instructions_[pos] = SKIP;
      instructions_.resize(pos + 1);
    } else {
      instructions_.emplace_back(LEAVE);
    }
  };
  // the simple vm that scans instructions_ has a limited stack depth;
  // this prevents recursing deeper than that. Emit a single SKIP for the
  // whole value and stop, so no further instructions are generated for it.
  if (depth >= ARG_SPEC_DEPTH_LIMIT) {
    instructions_.emplace_back(SKIP);
    return;
  }
  if (typ->isSubtypeOf(*TensorType::get())) {
    num_tensors_++;
    instructions_.emplace_back(SPECIALIZE_TENSOR);
  } else if (typ->isSubtypeOf(*OptionalType::ofTensor())) {
    num_tensors_++;
    num_optionals_++;
    instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR);
  } else if (typ->kind() == TypeKind::OptionalType) {
    // note that Optional[Tuple] or Optional[Class] will just register
    // as optional (previously they were not specialized at all, so this is
    // not a regression).
    num_optionals_++;
    instructions_.emplace_back(SPECIALIZE_OPTIONAL);
  } else if (auto tup = typ->cast<TupleType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_TUPLE);
    for (const auto& elem : tup->containedTypes()) {
      scan(elem, depth + 1, written_slots);
    }
    finishAggregate(pos);
  } else if (auto cls = typ->cast<ClassType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_OBJECT);
    for (size_t i = 0; i < cls->numAttributes(); ++i) {
      auto key =
          cls->name()->qualifiedName() + cls->getAttributes().at(i).getName();
      // it is only safe to specialize an attribute slot if nothing in the
      // graph writes to it; otherwise emit a SKIP
      if (!written_slots.count(key)) {
        scan(cls->containedTypes().at(i), depth + 1, written_slots);
      } else {
        instructions_.emplace_back(SKIP);
      }
    }
    finishAggregate(pos);
  } else {
    instructions_.emplace_back(SKIP);
  }
}
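
// Illustrative example: for a graph whose inputs have types
// (Tensor, Tuple[Tensor, int], Optional[Tensor]), scan() would emit roughly
//   SPECIALIZE_TENSOR,
//   ENTER_TUPLE, SPECIALIZE_TENSOR, SKIP, LEAVE,
//   SPECIALIZE_OPTIONAL_TENSOR
// and dump() would print something like
//   "SpecializeTensor Tuple[SpecializeTensor Skip ] SpecializeOptionalTensor".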

// this is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be
// read-only modules, but will be overly conservative when some classes are
// written to. Doing alias analysis and looking for writes to the class would
// be more accurate.
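// Note: the keys recorded here are the class's qualified name concatenated
// directly with the attribute name (no separator), which matches the keys
// scan() builds when deciding whether an attribute slot is safe to
// specialize.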
static void scanWrittenSlots(
    Block* block,
    ArgumentSpecCreator::WrittenSlots& written_slots) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::SetAttr) {
      if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
        written_slots.insert(cls->name()->qualifiedName() + n->s(attr::name));
      }
    }
    for (Block* subblock : n->blocks()) {
      scanWrittenSlots(subblock, written_slots);
    }
    if (n->hasAttribute(attr::Subgraph)) {
      scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
    }
  }
}

ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
    : num_inputs_(graph.inputs().size()) {
  WrittenSlots written_slots;
  scanWrittenSlots(graph.block(), written_slots);
  for (Value* input : graph.inputs()) {
    scan(input->type(), 0, written_slots);
  }
}

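// Debugging aid: prints the compiled instruction program in a compact
// bracketed form (see the illustrative example following scan() above).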
void ArgumentSpecCreator::dump() const {
  for (Inst inst : instructions_) {
    switch (inst) {
      case LEAVE:
        std::cout << "] ";
        break;
      case ENTER_TUPLE:
        std::cout << "Tuple[";
        break;
      case ENTER_OBJECT:
        std::cout << "Object[";
        break;
      case SKIP:
        std::cout << "Skip ";
        break;
      case SPECIALIZE_TENSOR:
        std::cout << "SpecializeTensor ";
        break;
      case SPECIALIZE_OPTIONAL_TENSOR:
        std::cout << "SpecializeOptionalTensor ";
        break;
      case SPECIALIZE_OPTIONAL:
        std::cout << "SpecializeOptional ";
        break;
    }
  }
  std::cout << "\n";
}

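// Replays the instruction program over the actual inputs: each
// SKIP/SPECIALIZE_* instruction consumes one IValue, and ENTER_TUPLE /
// ENTER_OBJECT descend into aggregates via a small fixed-depth stack of
// element pointers. The resulting ArgumentSpec records the observed tensor
// properties and optional presence, and is used (e.g. by the graph executor)
// as a key for looking up graphs specialized to these inputs.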
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
    const {
  ArgumentSpec spec(num_tensors_, num_optionals_);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  const IValue* stack[ARG_SPEC_DEPTH_LIMIT]; // The stack of IValue lists
  // The stack gets initialized with the input list
  stack[0] = last(input, num_inputs_).begin();
  size_t stack_top = 0; // offset to the top of the stack
  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        // consume a tensor optional and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        auto& arg = *stack[stack_top]++;
        spec.addOptional(arg);
        if (!arg.isNone()) {
          spec.addTensor(arg, with_grad);
        }
      } break;
      case SPECIALIZE_TENSOR:
        // consume a tensor and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        spec.addTensor(*stack[stack_top]++, with_grad);
        break;
      case SPECIALIZE_OPTIONAL:
        // consume a non-tensor optional and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        spec.addOptional(*stack[stack_top]++);
        break;
      case ENTER_TUPLE: {
        // consume tuple
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isTuple(), "Expected Tuple but got ", iv->tagKind());
        // read the Tuple pointer directly out of the IValue payload rather
        // than materializing an intrusive_ptr copy via toTuple()
        auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
        auto tup_ptr = &p->elements()[0];
        // push list of tuple elements to the stack
        stack[++stack_top] = tup_ptr;
      } break;
      case ENTER_OBJECT: {
        // consume object
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isObject(), "Expected Object but got ", iv->tagKind());
        auto obj_ptr = &iv->toObjectRef().slots()[0];
        // push list of object elements to the stack
        stack[++stack_top] = obj_ptr;
      } break;
      case SKIP:
        // consume and skip an element
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        stack[stack_top]++;
        break;
      case LEAVE:
        --stack_top;
        break;
    }
  }
  return spec;
}
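
// Typical usage (sketch; the real call sites live elsewhere, e.g. in the
// graph executor), given a std::shared_ptr<Graph> graph and a Stack stack:
//   ArgumentSpecCreator creator(*graph);                        // once per graph
//   ArgumentSpec spec =
//       creator.create(autograd::GradMode::is_enabled(), stack); // per call
//   // 'spec' then keys a cache of graphs specialized via specializeTypes().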

// For every input of a given graph, compute the most detailed type that can
// be inferred for it based on this ArgumentSpec, and apply it to the graph's
// inputs.
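// The walk below mirrors create(): both iterate over the same instruction
// program and must consume inputs in the same order, so the two switch
// statements need to stay in sync.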
void ArgumentSpecCreator::specializeTypes(
    Graph& graph,
    const ArgumentSpec& spec) const {
  auto input_types =
      fmap(graph.inputs(), [](Value* input) { return input->type(); });
  std::vector<std::vector<TypePtr>> result_stack;
  result_stack.emplace_back();
  std::vector<const TypePtr*> input_stack = {input_types.data()};
  std::vector<std::function<TypePtr()>> aggregate_creators;

  // number of specialized tensors seen so far
  size_t tensor_arg_spec_offset = 0;
  // number of specialized optionals seen so far
  size_t optional_arg_spec_offset = 0;

  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        auto& input_type = *input_stack.back()++;
        auto is_present = spec.isPresent(optional_arg_spec_offset++);
        if (!is_present) {
          result_stack.back().emplace_back(input_type);
          break;
        }
        auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
        AT_ASSERT(arg.defined());
        result_stack.back().emplace_back(arg.toType());
      } break;
      case SPECIALIZE_TENSOR: {
        input_stack.back()++;
        auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
        if (!arg.defined()) {
          result_stack.back().emplace_back(TensorType::get()->withUndefined());
        } else {
          result_stack.back().emplace_back(arg.toType());
        }
      } break;
      case SPECIALIZE_OPTIONAL: {
        auto is_present = spec.isPresent(optional_arg_spec_offset++);
        auto ot = (*input_stack.back()++)->expect<OptionalType>();
        if (!is_present) {
          result_stack.back().emplace_back(ot);
        } else {
          result_stack.back().emplace_back(ot->getElementType());
        }
      } break;
      case ENTER_TUPLE: {
        auto tup = (*input_stack.back()++)->expect<TupleType>();
        input_stack.emplace_back(tup->elements().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&] { return TupleType::create(result_stack.back()); });
      } break;
      case ENTER_OBJECT: {
        auto cls = (*input_stack.back()++)->expect<ClassType>();
        input_stack.emplace_back(cls->containedTypes().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&result_stack, cls] { return cls->refine(result_stack.back()); });
      } break;
      case SKIP:
        result_stack.back().emplace_back(*input_stack.back()++);
        break;
      case LEAVE:
        TypePtr result = aggregate_creators.back()();
        result_stack.pop_back();
        aggregate_creators.pop_back();
        input_stack.pop_back();
        result_stack.back().emplace_back(std::move(result));
        break;
    }
  }
  AT_ASSERT(result_stack.size() == 1);
  // FIXME: by doing this only on the inputs, we only capture graph inputs
  //        and not optionals in tuples or objects. For that to work, we
  //        would have to investigate the uses of the inputs in detail to
  //        change the accesses / unwrapping.
  auto inputs = graph.inputs();
  for (const auto i : c10::irange(inputs.size())) {
    auto t = result_stack.back()[i];
    if (auto ot = t->cast<OptionalType>()) {
      // if an optional input hasn't been specialized above, it is None
      // so we disconnect the input here and replace its uses with
      // a constant
      WithInsertPoint guard(*graph.nodes().begin());
      auto c = graph.insertConstant({});
      inputs[i]->replaceAllUsesWith(c);
    } else {
      inputs[i]->setType(t);
    }
  }
}

} // namespace torch::jit