xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/sugared_value.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/sugared_value.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/frontend/schema_matching.h>
5 #include <torch/csrc/jit/frontend/tree_views.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/passes/constant_propagation.h>
8 
9 namespace torch::jit {
10 
11 struct NoneValue : SugaredValue {
12   NoneValue() = default;
kindtorch::jit::NoneValue13   std::string kind() const override {
14     return "None";
15   }
16 };
17 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)18 std::shared_ptr<SugaredValue> PrintValue::call(
19     const SourceRange& loc,
20     GraphFunction& m,
21     at::ArrayRef<NamedValue> args,
22     at::ArrayRef<NamedValue> kwargs,
23     size_t n_binders) {
24   auto& g = *m.graph();
25   if (!kwargs.empty())
26     throw(ErrorReport(loc) << "print doesn't accept any keyword arguments");
27 
28   std::vector<Value*> lowered_inputs = toValues(*m.graph(), args);
29   g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc));
30   return std::make_shared<NoneValue>();
31 }
32 
33 static const std::unordered_map<std::string, at::ScalarType>&
builtin_cast_method_to_scalar_type()34 builtin_cast_method_to_scalar_type() {
35   static std::unordered_map<std::string, at::ScalarType> mapping = {
36       {"byte", at::kByte},
37       {"char", at::kChar},
38       {"double", at::kDouble},
39       {"float", at::kFloat},
40       {"cfloat", at::kComplexFloat},
41       {"cdouble", at::kComplexDouble},
42       {"int", at::kInt},
43       {"long", at::kLong},
44       {"short", at::kShort},
45       {"half", at::kHalf}};
46   return mapping;
47 }
48 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)49 std::shared_ptr<SugaredValue> BuiltinFunction::call(
50     const SourceRange& loc,
51     GraphFunction& m,
52     at::ArrayRef<NamedValue> args,
53     at::ArrayRef<NamedValue> kwargs,
54     size_t n_binders) {
55   return std::make_shared<SimpleValue>(
56       emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self));
57 }
58 
// Hash functor for enum types (used below for TypeKind keys). Some older
// gcc/clang versions lack a default std::hash for enums, so an enum cannot be
// used as an unordered_map key without an explicit hasher.
// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
struct EnumClassHash {
  template <typename Enum>
  std::size_t operator()(Enum value) const {
    // The enumerator's integral value is used directly as the hash.
    return static_cast<std::size_t>(value);
  }
};
68 
hasAttr(const SourceRange & loc,GraphFunction & m,const std::string & field)69 bool SimpleValue::hasAttr(
70     const SourceRange& loc,
71     GraphFunction& m,
72     const std::string& field) {
73   if (auto class_type = value_->type()->cast<ClassType>()) {
74     return class_type->hasMethod(field) || class_type->hasAttribute(field) ||
75         class_type->hasConstant(field);
76   } else if (auto tuple_type = value_->type()->cast<TupleType>()) {
77     if (tuple_type->schema()) {
78       for (const auto& arg : tuple_type->schema()->arguments()) {
79         if (arg.name() == field) {
80           return true;
81         }
82       }
83       return false;
84     } else {
85       throw(
86           ErrorReport(loc) << "hasattr's first argument must be a object "
87                            << "or NamedTuple, but got a normal Tuple "
88                            << value_->type()->repr_str() << " instead");
89     }
90   }
91   throw(
92       ErrorReport(loc) << "hasattr's first argument must be an object or "
93                        << "NamedTuple, got " << value_->type()->repr_str()
94                        << " instead");
95 }
96 
97 // support syntax sugar for x.foo(y, z) by allowing x.foo to return a
98 // callable value that will resolve to foo(x, y, z) when called.
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)99 std::shared_ptr<SugaredValue> SimpleValue::attr(
100     const SourceRange& loc,
101     GraphFunction& m,
102     const std::string& field) {
103   // Allow method-style casts on Tensor types. e.g. x.int()
104   if (value_->type()->isSubtypeOf(*TensorType::get())) {
105     if (builtin_cast_method_to_scalar_type().count(field)) {
106       return std::make_shared<TensorCastValue>(
107           builtin_cast_method_to_scalar_type().at(field),
108           NamedValue(loc, "self", value_));
109     }
110   }
111   // accessing properties of Tensor and Device that are implemented as
112   // prim:: or aten:: operators
113   using PropertiesLookup = std::unordered_map<
114       TypeKind,
115       std::unordered_map<std::string, std::string>,
116       EnumClassHash>;
117   static const PropertiesLookup builtin_properties = {
118       {TypeKind::OptionalType,
119        {
120            {"unchecked_unwrap_optional", "prim"},
121        }},
122       {TypeKind::TensorType,
123        {
124            {"dtype", "prim"},
125            {"device", "prim"},
126            {"grad", "prim"},
127            {"data", "prim"},
128            {"shape", "prim"},
129            {"is_cuda", "prim"},
130            {"is_cpu", "prim"},
131            {"is_xla", "prim"},
132            {"is_xpu", "prim"},
133            {"is_sparse", "prim"},
134            {"is_sparse_csr", "prim"},
135            {"is_mkldnn", "prim"},
136            {"is_mps", "prim"},
137            {"is_mtia", "prim"},
138            {"is_quantized", "prim"},
139            {"is_vulkan", "prim"},
140            {"is_ipu", "prim"},
141            {"is_meta", "prim"},
142            {"is_leaf", "aten"},
143            {"is_nested", "prim"},
144            {"requires_grad", "prim"},
145            {"layout", "prim"},
146            {"T", "prim"},
147            {"H", "prim"},
148            {"mT", "aten"},
149            {"mH", "aten"},
150            {"is_maia", "prim"},
151            {"itemsize", "prim"},
152            {"nbytes", "prim"},
153            {"ndim", "prim"},
154            {"name", "prim"},
155            {"real", "aten"},
156            {"imag", "aten"},
157            {"retains_grad", "aten"},
158        }},
159       {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
160   auto kind = value_->type()->kind();
161   auto types_for_builtin = builtin_properties.find(kind);
162   if (types_for_builtin != builtin_properties.end()) {
163     auto builtin_entry = types_for_builtin->second.find(field);
164     if (builtin_entry != types_for_builtin->second.end()) {
165       // A builtin was found, add it to the graph
166       auto the_namespace = builtin_entry->second;
167       auto r = m.graph()->insert(
168           Symbol::fromQualString(the_namespace + "::" + field), {value_});
169       return std::make_shared<SimpleValue>(r);
170     }
171   }
172 
173   // accessing fields of named tuples
174   if (auto tuple_type = value_->type()->cast<TupleType>()) {
175     if (tuple_type->schema()) {
176       auto attrs = tuple_type->schema()->arguments();
177       for (const auto i : c10::irange(attrs.size())) {
178         if (attrs[i].name() == field) {
179           auto idx = m.graph()->insertConstant(IValue(static_cast<int64_t>(i)));
180           auto out_type = tuple_type->elements().at(i);
181           auto r = m.graph()
182                        ->insertNode(
183                            m.graph()->createTupleIndex(value_, idx, out_type))
184                        ->output();
185           return std::make_shared<SimpleValue>(r);
186         }
187       }
188     }
189   } else if (auto awaitType = value_->type()->cast<AwaitType>()) {
190     auto elType = awaitType->getElementType();
191     auto& g = *m.graph();
192     auto v = g.insert(prim::awaitable_wait, {value_}, {}, loc);
193     auto sv = std::make_shared<SimpleValue>(v);
194     return sv->attr(loc, m, field);
195   } else if (auto classType = value_->type()->cast<ClassType>()) {
196     // This is a class, emit the proper attribute lookup
197     if (classType->findMethod(field)) {
198       return std::make_shared<MethodValue>(getValue(), field);
199     }
200     if (classType->hasAttribute(field)) {
201       auto& g = *m.graph();
202       auto n = g.insertNode(g.createGetAttr(value_, field));
203       return std::make_shared<SimpleValue>(n->output());
204     }
205     // Check and see if it's a getter attribute.
206     auto prop = classType->getProperty(field);
207     if (prop) {
208       return MethodValue(value_, prop->getter->name())
209           .call(loc, m, {}, {}, /*n_binders=*/1);
210     }
211   } else if (auto iface = value_->type()->cast<InterfaceType>()) {
212     // accessing methods of interfaces
213     if (iface->getMethod(field)) {
214       return std::make_shared<MethodValue>(getValue(), field);
215     }
216   } else if (auto enum_type = value_->type()->cast<EnumType>()) {
217     // Handle access to Enum's `name` and `value` attribute.
218     auto& g = *m.graph();
219 
220     if (field == "name") {
221       auto n = g.insertNode(g.createEnumName(value_));
222       return std::make_shared<SimpleValue>(n->output());
223     }
224 
225     if (field == "value") {
226       auto n = g.insertNode(g.createEnumValue(value_));
227       return std::make_shared<SimpleValue>(n->output());
228     }
229   }
230 
231   // none of the more-specific cases worked, so see if this is a builtin method
232   // If field is a type, then call the aten::to op
233   if (field == "type") {
234     if (auto builtin = BuiltinFunction::tryCreate(
235             Symbol::aten("to"), NamedValue(loc, "self", value_))) {
236       return builtin;
237     }
238   }
239 
240   if (auto builtin = BuiltinFunction::tryCreate(
241           Symbol::aten(field), NamedValue(loc, "self", value_))) {
242     return builtin;
243   }
244 
245   // Handle calling tolist() on a Tensor.
246   if (value_->type()->isSubtypeOf(*TensorType::get()) && field == "tolist") {
247     return SpecialFormValue::create(prim::tolist);
248   }
249 
250   // Handle calling __getitem__() directly on a Tensor, it needs special
251   // handling because desired method name (`__getitem__`) doesn't match `aten`
252   // operator name of `aten::index`.
253   if (value_->type()->isSubtypeOf(*TensorType::get()) &&
254       field == "__getitem__") {
255     return SpecialFormValue::create(aten::index);
256   }
257 
258   if (auto generator_type = value_->type()->cast<GeneratorType>()) {
259     // Handle access to Generator's `manual_seed`, `initial_seed` and `seed`
260     // attributes.
261     if (field == "manual_seed" || field == "initial_seed" || field == "seed") {
262       if (auto builtin = BuiltinFunction::tryCreate(
263               Symbol::aten(field), NamedValue(loc, "self", value_))) {
264         return builtin;
265       }
266     }
267   }
268 
269   ErrorReport report(loc);
270   report << "'" << value_->type()->repr_str()
271          << "' object has no attribute or method '" << field << "'.";
272   if (auto classType = value_->type()->cast<ClassType>()) {
273     if (classType->isUnresolvedClassAttribute(field)) {
274       report
275           << " '" << field
276           << "' is defined as a class attribute which currently is not"
277              " supported. Consider converting this to an instance attribute.";
278     } else {
279       report << " Did you forget to initialize an attribute in __init__()?";
280     }
281   }
282   throw ErrorReport(report);
283 }
284 
asTuple(const SourceRange & loc,GraphFunction & m,const std::optional<size_t> & size_hint)285 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
286     const SourceRange& loc,
287     GraphFunction& m,
288     const std::optional<size_t>& size_hint) {
289   static const auto make_simple_value =
290       [](Value* v) -> std::shared_ptr<SugaredValue> {
291     return std::make_shared<SimpleValue>(v);
292   };
293   if (value_->type()->kind() == TypeKind::TupleType) {
294     auto outputs = createTupleUnpack(value_);
295     return fmap(outputs, make_simple_value);
296   } else if (value_->type()->kind() == TypeKind::ListType) {
297     if (!size_hint) {
298       throw(
299           ErrorReport(loc) << "cannot statically infer the expected size of a "
300                            << "list in this context");
301     }
302     auto graph = value_->owningGraph();
303     Node* unpack =
304         graph->insertNode(graph->createListUnpack(value_, *size_hint));
305     return fmap(unpack->outputs(), make_simple_value);
306   } else if (value_->type()->kind() == TypeKind::AnyTupleType) {
307     throw(
308         ErrorReport(loc)
309         << "Provided tuple is not fully defined/refined including its element types, please provide a value of type like Tuple[int, int]");
310   }
311   throw(
312       ErrorReport(loc) << value_->type()->repr_str()
313                        << " cannot be used as a tuple");
314 }
315 
isRecursive(const TypePtr & classType,const TypePtr & attrType)316 static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
317   if (attrType->isSubtypeOf(*classType)) {
318     return true;
319   }
320 
321   // Recursively check contained types. We need to do this because a user may do
322   // A -> B -> A.
323   for (const auto& type : attrType->containedTypes()) {
324     if (isRecursive(classType, type)) {
325       return true;
326     }
327   }
328   return false;
329 }
330 
setAttr(const SourceRange & loc,GraphFunction & m,const std::string & field,Value * newValue)331 void SimpleValue::setAttr(
332     const SourceRange& loc,
333     GraphFunction& m,
334     const std::string& field,
335     Value* newValue) {
336   const auto classType = value_->type()->cast<ClassType>();
337   if (!classType) {
338     throw(
339         ErrorReport(loc) << "Tried to set an attribute: " << field
340                          << " on a non-class: " << value_->type()->repr_str());
341   }
342   auto expectedType = classType->findAttribute(field);
343   if (!expectedType) {
344     // If we are still compiling the __init__ method for this class, then
345     // setting an unknown attribute adds it to the class's definition.
346 
347     // We are initializing if:
348     const auto isInitializing =
349         // 1. The method we're currently inserting into is an init method
350         // TODO this can be a qualified name check
351         m.name() == "__init__" &&
352         // 2. The `self` arg matches this value's type (i.e. we are in the init
353         // method for this class, not some other class)
354         !m.graph()->inputs().empty() &&
355         m.graph()->inputs().at(0)->type() == classType;
356 
357     if (isInitializing) {
358       if (isRecursive(classType, newValue->type())) {
359         throw(
360             ErrorReport(loc)
361             << "Assignment to attribute '" << field
362             << "' cannot be of a type that contains class "
363             << "'" << classType->repr_str() << "'.\n"
364             << "Classes that recursively contain instances of themselves"
365             << " are not yet supported");
366       }
367 
368       classType->addAttribute(field, newValue->type());
369       expectedType = newValue->type();
370 
371       const auto insertPoint = m.graph()->insertPoint();
372       const auto topLevelBlock = m.graph()->block();
373       if (insertPoint->owningBlock() != topLevelBlock) {
374         throw(
375             ErrorReport(loc)
376             << "First assignment cannot be in a control-flow block. "
377             << "Initialize the field at the top level first");
378       }
379     } else {
380       // Check and see if it's a setter attribute.
381       auto prop = classType->getProperty(field);
382       if (prop && prop->setter) {
383         MethodValue(value_, prop->setter->name())
384             .call(loc, m, {newValue}, {}, /*n_binders=*/1);
385         return;
386       }
387 
388       if (prop && !prop->setter) {
389         throw(
390             ErrorReport(loc) << "Tried to set read-only attribute: " << field);
391       }
392 
393       throw(
394           ErrorReport(loc)
395           << "Tried to set nonexistent attribute: " << field
396           << ". Did you forget to initialize it in __init__()?");
397     }
398   }
399 
400   AT_ASSERT(expectedType);
401 
402   // Check type correctness
403   const auto newType = newValue->type();
404   if (!newType->isSubtypeOf(*expectedType)) {
405     throw(
406         ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
407                          << expectedType->repr_str() << " but got "
408                          << newType->repr_str());
409   }
410 
411   auto& g = *m.graph();
412   g.insertNode(g.createSetAttr(value_, field, newValue));
413 }
414 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)415 std::shared_ptr<SugaredValue> SimpleValue::call(
416     const SourceRange& loc,
417     GraphFunction& m,
418     at::ArrayRef<NamedValue> args,
419     at::ArrayRef<NamedValue> kwargs,
420     size_t n_binders) {
421   // allow our 'fake' closures to be called, used for fork serialization
422   // at the moment, but can be expanded later
423   Node* self = getValue()->node();
424   if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 &&
425       self->inputs().at(0)->node()->kind() == prim::Closure) {
426     std::shared_ptr<Graph> graph =
427         self->inputs().at(0)->node()->g(attr::Subgraph);
428     Value* context = self->inputs().at(1);
429     AT_ASSERT(context->node()->kind() == prim::TupleConstruct);
430 
431     // fork nodes are emitted in their own block but we do not simplify
432     // tuple construction across blocks. To ensure we clean up the tuple
433     // construct create another copy of the tuple construct in the fork block
434     Value* close_context =
435         m.graph()
436             ->insertNode(m.graph()->createTuple(context->node()->inputs()))
437             ->output();
438     // TODO this needs to go in `m`s compilation unit
439     auto cu = std::make_shared<CompilationUnit>();
440     auto fn = cu->create_function(QualifiedName("anon"), graph);
441     auto ret = StrongFunctionPtr(std::move(cu), fn);
442 
443     std::vector<NamedValue> ctx_inputs = {close_context};
444     ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end());
445     return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders);
446   }
447 
448   if (auto class_type = getValue()->type()->cast<ClassType>()) {
449     return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders);
450   }
451 
452   return SugaredValue::call(loc, m, args, kwargs, n_binders);
453 }
454 
len(const SourceRange & loc,GraphFunction & m)455 Value* SimpleValue::len(const SourceRange& loc, GraphFunction& m) {
456   // List, Tuple, Tensor, fill in missing information desugaring
457   Value* val = getValue();
458   TypePtr val_type = val->type();
459   Graph& g = *m.graph();
460   if (val_type->cast<ListType>() || val_type->cast<StringType>() ||
461       val_type->isSubtypeOf(*TensorType::get())) {
462     return g.insert(aten::len, {val}, {}, loc);
463   } else {
464     throw(
465         ErrorReport(loc) << "'" << val_type->repr_str() << "'"
466                          << " object is not iterable");
467   }
468 }
469 
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)470 SugaredValuePtr SimpleValue::getitem(
471     const SourceRange& loc,
472     GraphFunction& m,
473     Value* idx,
474     TypePtr type_hint) {
475   Value* val = getValue();
476   TypePtr val_type = val->type();
477   Graph& g = *m.graph();
478 
479   // if it's a List/String/Dict, emit a regular __getitem__ op
480   // NOLINTNEXTLINE(bugprone-branch-clone)
481   if (val_type->cast<ListType>() || val_type->cast<StringType>()) {
482     return std::make_shared<SimpleValue>(
483         g.insert(aten::__getitem__, {val, idx}, {}, loc));
484   } else if (auto dict_type = val_type->cast<DictType>()) {
485     return std::make_shared<SimpleValue>(
486         g.insert(aten::__getitem__, {val, idx}, {}, loc));
487   } else if (val_type->isSubtypeOf(*TensorType::get())) {
488     return std::make_shared<SimpleValue>(
489         g.insert(aten::select, {val, 0, idx}, {}, loc));
490   } else if (auto class_type = val_type->cast<ClassType>()) {
491     // Check if this is an indexing operation enabled by a type hint.
492     // The ModuleDict has already been checked during IR generation to make
493     // sure its contents implement the module interface referred to by
494     // type_hint.
495     if (class_type->is_module() && type_hint) {
496       auto res = g.insert(prim::ModuleContainerIndex, {val, idx}, {}, loc);
497       res->setType(type_hint);
498       return std::make_shared<SimpleValue>(res);
499     }
500 
501     // Defer to the __getitem__ attr on the class.
502     return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1);
503   } else {
504     throw(
505         ErrorReport(loc) << "'" << val_type->repr_str() << "'"
506                          << " object is not subscriptable");
507   }
508 }
509 
iter(const SourceRange & loc,GraphFunction & m)510 SugaredValuePtr SimpleValue::iter(const SourceRange& loc, GraphFunction& m) {
511   auto value = getValue();
512   auto type = value->type();
513   // built-in iterable types
514   if (type->cast<ListType>() || type->cast<StringType>() ||
515       type->cast<TensorType>()) {
516     return std::make_shared<SimpleValue>(value);
517   }
518   // dicts iterate over keys
519   if (type->cast<DictType>()) {
520     return std::make_shared<SimpleValue>(
521         m.graph()->insert(aten::keys, {value}, {}, loc));
522   }
523   if (auto tup = type->cast<TupleType>()) {
524     auto tup_values = createTupleUnpack(value);
525     std::vector<SugaredValuePtr> tup_sugared;
526     for (Value* v : tup_values) {
527       tup_sugared.push_back(std::make_shared<SimpleValue>(v));
528     }
529     return std::make_shared<SugaredTupleValue>(tup_sugared);
530   } else {
531     throw(
532         ErrorReport(loc) << "'" << type->repr_str() << "'"
533                          << " object is not iterable");
534   }
535 }
536 
RangeValue(const SourceRange & loc,GraphFunction & m,std::vector<Value * > inputs,std::optional<int64_t> static_len)537 RangeValue::RangeValue(
538     const SourceRange& loc,
539     GraphFunction& m,
540     std::vector<Value*> inputs,
541     std::optional<int64_t> static_len) {
542   for (const auto i : c10::irange(inputs.size())) {
543     auto typ = inputs[i]->type();
544     if (!typ->cast<IntType>()) {
545       throw(
546           ErrorReport(loc) << "all inputs of range must be ints, found "
547                            << typ->repr_str() << " in argument "
548                            << std::to_string(i));
549     }
550   }
551 
552   Graph& g = *m.graph();
553   if (inputs.empty()) {
554     throw(ErrorReport(loc) << "range expected at least 1 arguments, got 0");
555   } else if (inputs.size() == 1) {
556     end_ = inputs[0];
557     start_ = g.insertConstant(0, loc);
558     step_ = g.insertConstant(1, loc);
559     // range() call only contains end, easier to calculate len() and getitem()
560     has_only_end_ = true;
561   } else if (inputs.size() <= 3) {
562     start_ = inputs[0];
563     end_ = inputs[1];
564     if (inputs.size() == 3) {
565       step_ = inputs[2];
566     } else {
567       step_ = g.insertConstant(1, loc);
568     }
569     has_only_end_ = false;
570   } else {
571     throw(
572         ErrorReport(loc) << "range expected at most 3 arguments, got "
573                          << inputs.size());
574   }
575 
576   static_len_ = static_len;
577 }
578 
iter(const SourceRange & loc,GraphFunction & m)579 SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
580   return shared_from_this();
581 };
582 
len(const SourceRange & loc,GraphFunction & m)583 Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
584   if (static_len_) {
585     return insertConstant(*m.graph(), *static_len_, loc);
586   }
587   if (has_only_end_) {
588     return end_;
589   } else {
590     Graph& g = *m.graph();
591     return g.insert(aten::__range_length, {start_, end_, step_}, {}, loc);
592   }
593 }
594 
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)595 SugaredValuePtr RangeValue::getitem(
596     const SourceRange& loc,
597     GraphFunction& m,
598     Value* idx,
599     TypePtr type_hint) {
600   if (has_only_end_) {
601     return std::make_shared<SimpleValue>(idx);
602   } else {
603     auto& g = *m.graph();
604     return std::make_shared<SimpleValue>(
605         g.insert(aten::__derive_index, {idx, start_, step_}, {}, loc));
606   }
607 }
608 
get_base_iterables()609 std::vector<SugaredValuePtr> IterableTree::get_base_iterables() {
610   std::vector<SugaredValuePtr> base_iters{};
611 
612   for (SugaredValuePtr& sv : children_) {
613     if (auto iv = std::dynamic_pointer_cast<IterableTree>(sv)) {
614       std::vector<SugaredValuePtr> child_iters = iv->get_base_iterables();
615       // merge child iters with the base_iters
616       base_iters.insert(
617           base_iters.end(),
618           std::make_move_iterator(child_iters.begin()),
619           std::make_move_iterator(child_iters.end()));
620 
621     } else {
622       // IterableTree leaves, either SimpleValue or RangeValue
623       base_iters.emplace_back(sv);
624     }
625   }
626   return base_iters;
627 }
628 
len(const SourceRange & loc,GraphFunction & m)629 Value* IterableTree::len(const SourceRange& loc, GraphFunction& m) {
630   // if it's a iterable tree, we get the base iterables that consists of
631   // SimpleValue or RangeValue, and then calculate the minimum length of all the
632   // base iterables to be max_trip_count_val
633   TORCH_INTERNAL_ASSERT(!unroll_length_);
634   Graph& g = *m.graph();
635   std::vector<SugaredValuePtr> base_iters = get_base_iterables();
636   std::vector<Value*> lengths;
637   lengths.reserve(base_iters.size());
638 
639   for (const SugaredValuePtr& base_iter : base_iters) {
640     lengths.emplace_back(base_iter->len(loc, m));
641   }
642   Node* list_node = g.insertNode(g.createList(IntType::get(), lengths));
643   return g.insert(prim::min, {list_node->output()}, {}, loc);
644 }
645 
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)646 SugaredValuePtr IterableTree::getitem(
647     const SourceRange& loc,
648     GraphFunction& m,
649     Value* idx,
650     TypePtr type_hint) {
651   std::vector<SugaredValuePtr> child_items;
652   child_items.reserve(children_.size());
653   for (const SugaredValuePtr& child : children_) {
654     child_items.emplace_back(child->getitem(loc, m, idx));
655   }
656   return std::make_shared<SugaredTupleValue>(child_items);
657 }
658 
addChild(const SourceRange & range,GraphFunction & m,const SugaredValuePtr & iter_value)659 void IterableTree::addChild(
660     const SourceRange& range,
661     GraphFunction& m,
662     const SugaredValuePtr& iter_value) {
663   std::optional<int64_t> child_len = iter_value->staticLen();
664   if (children_.empty()) {
665     unroll_length_ = child_len;
666   } else {
667     if ((unroll_length_ && !child_len) || (child_len && !unroll_length_)) {
668       throw(
669           ErrorReport(range)
670           << "Can not iterate over a module list or tuple with a value "
671              "that does not have a statically determinable length\n");
672     }
673     if (unroll_length_ && child_len) {
674       // iterables run for the minimum length of all its leaves
675       unroll_length_ = std::min(*child_len, *unroll_length_);
676     } else {
677       unroll_length_ = std::nullopt;
678     }
679   }
680   children_.push_back(iter_value);
681 }
682 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)683 std::shared_ptr<SugaredValue> MagicMethod::call(
684     const SourceRange& loc,
685     GraphFunction& m,
686     at::ArrayRef<NamedValue> args,
687     at::ArrayRef<NamedValue> kwargs,
688     size_t n_binders) {
689   if (!args.empty()) {
690     Value* self = args[0].value(*m.graph());
691     if (auto class_ptr = self->type()->cast<ClassType>()) {
692       return SimpleValue(self)
693           .attr(loc, m, desugared_name_)
694           ->call(loc, m, args.slice(1), kwargs, n_binders);
695     }
696   }
697   TORCH_INTERNAL_ASSERT(base_value_);
698   return base_value_->call(loc, m, args, kwargs, n_binders);
699 }
700 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)701 std::shared_ptr<SugaredValue> ClassValue::call(
702     const SourceRange& loc,
703     GraphFunction& m,
704     // note: names for args will be 'argument 0', 'argument 1', etc..
705     at::ArrayRef<NamedValue> args,
706     at::ArrayRef<NamedValue> kwargs,
707     size_t n_binders) {
708   AT_ASSERT(n_binders <= 1);
709 
710   // Generate a new object of the right type, then call `__init__` on it
711   auto& g = *m.graph();
712   auto self = g.insertNode(g.createObject(type_))->output();
713   self->node()->setSourceRange(loc);
714   if (!type_->findMethod("__init__")) {
715     throw(
716         ErrorReport(loc) << "Class " << type_->name()->name()
717                          << " does not have an __init__ function defined");
718   }
719 
720   // Call the init function
721   MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders);
722 
723   return std::make_shared<SimpleValue>(self);
724 }
725 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)726 std::shared_ptr<SugaredValue> ClassValue::attr(
727     const SourceRange& loc,
728     GraphFunction& m,
729     const std::string& field) {
730   // Allow import_source.cpp to resolve calls to a submodule's
731   // hooks. Edge case because normally you wouldn't allow a module to
732   // call functions of a submodule
733   if (Function* hook = type_->findHook(field)) {
734     return std::make_shared<FunctionValue>(hook);
735   }
736 
737   if (field != "__new__") {
738     throw(
739         ErrorReport(loc) << "Tried to lookup unknown attribute on class "
740                          << type_->annotation_str());
741   }
742   return SpecialFormValue::create(prim::CreateObject);
743 }
744 
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)745 std::shared_ptr<SugaredValue> NamedTupleConstructor::call(
746     const SourceRange& loc,
747     GraphFunction& m,
748     at::ArrayRef<NamedValue> args,
749     at::ArrayRef<NamedValue> kwargs,
750     size_t n_binders) {
751   auto& g = *m.graph();
752 
753   auto schema = type_->schema();
754   TORCH_INTERNAL_ASSERT(schema);
755   auto qualname = type_->name();
756   auto matched_schema = matchSchema(*schema, loc, g, args, kwargs);
757 
758   auto self =
759       g.insertNode(
760            g.createTuple(matched_schema.inputs, type_)->setSourceRange(loc))
761           ->output();
762   self->setType(type_);
763 
764   return std::make_shared<SimpleValue>(self);
765 }
766 
tryCreate(Symbol symbol,std::optional<NamedValue> self)767 std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
768     Symbol symbol,
769     std::optional<NamedValue> self) {
770   for (const std::shared_ptr<Operator>& op : getAllOperatorsFor(symbol)) {
771     if (!self) {
772       return std::make_shared<BuiltinFunction>(symbol, nullptr);
773     }
774     if (auto index = op->schema().argumentIndexWithName("self")) {
775       std::unordered_map<std::string, TypePtr> type_env;
776       TypePtr formal_type = op->schema().arguments().at(*index).type();
777       const MatchTypeReturn matched =
778           matchTypeVariables(formal_type, self->type(), type_env);
779       if (!matched.success()) {
780         continue;
781       }
782       const auto concrete_type = tryEvalTypeVariables(formal_type, type_env);
783       if (!concrete_type || !self->type()->isSubtypeOf(*concrete_type)) {
784         continue;
785       }
786       return std::make_shared<BuiltinFunction>(symbol, self);
787     }
788   }
789   return nullptr;
790 }
791 
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)792 std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
793     const SourceRange& loc,
794     GraphFunction& m,
795     const std::string& field) {
796   const auto& names_values = enum_type_->enumNamesValues();
797   auto it = std::find_if(
798       names_values.begin(),
799       names_values.end(),
800       [&field](const at::EnumNameValue& nv) { return nv.first == field; });
801   if (it == names_values.end()) {
802     throw(
803         ErrorReport(loc) << enum_type_->repr_str() << "'"
804                          << " has no attribute '" << field << "'");
805   }
806   auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
807       enum_type_, it->first, it->second);
808   return std::make_shared<SimpleValue>(
809       m.graph()->insertConstant(IValue(enum_holder), loc));
810 }
811 
iter(const SourceRange & loc,GraphFunction & m)812 SugaredValuePtr SugaredEnumClass::iter(
813     const SourceRange& loc,
814     GraphFunction& m) {
815   const auto& names_values = enum_type_->enumNamesValues();
816   auto enum_value_ivalues = c10::impl::GenericList(enum_type_);
817   enum_value_ivalues.reserve(names_values.size());
818   for (const auto& name_value : names_values) {
819     auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
820         enum_type_, name_value.first, name_value.second);
821     enum_value_ivalues.emplace_back(enum_holder);
822   }
823 
824   auto enum_values_list_constant = std::make_shared<SimpleValue>(
825       m.graph()->insertConstant(enum_value_ivalues, loc));
826   return enum_values_list_constant;
827 }
828 
829 } // namespace torch::jit
830