1 #include <torch/csrc/jit/frontend/sugared_value.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/frontend/schema_matching.h>
5 #include <torch/csrc/jit/frontend/tree_views.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/passes/constant_propagation.h>
8
9 namespace torch::jit {
10
11 struct NoneValue : SugaredValue {
12 NoneValue() = default;
kindtorch::jit::NoneValue13 std::string kind() const override {
14 return "None";
15 }
16 };
17
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)18 std::shared_ptr<SugaredValue> PrintValue::call(
19 const SourceRange& loc,
20 GraphFunction& m,
21 at::ArrayRef<NamedValue> args,
22 at::ArrayRef<NamedValue> kwargs,
23 size_t n_binders) {
24 auto& g = *m.graph();
25 if (!kwargs.empty())
26 throw(ErrorReport(loc) << "print doesn't accept any keyword arguments");
27
28 std::vector<Value*> lowered_inputs = toValues(*m.graph(), args);
29 g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc));
30 return std::make_shared<NoneValue>();
31 }
32
33 static const std::unordered_map<std::string, at::ScalarType>&
builtin_cast_method_to_scalar_type()34 builtin_cast_method_to_scalar_type() {
35 static std::unordered_map<std::string, at::ScalarType> mapping = {
36 {"byte", at::kByte},
37 {"char", at::kChar},
38 {"double", at::kDouble},
39 {"float", at::kFloat},
40 {"cfloat", at::kComplexFloat},
41 {"cdouble", at::kComplexDouble},
42 {"int", at::kInt},
43 {"long", at::kLong},
44 {"short", at::kShort},
45 {"half", at::kHalf}};
46 return mapping;
47 }
48
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)49 std::shared_ptr<SugaredValue> BuiltinFunction::call(
50 const SourceRange& loc,
51 GraphFunction& m,
52 at::ArrayRef<NamedValue> args,
53 at::ArrayRef<NamedValue> kwargs,
54 size_t n_binders) {
55 return std::make_shared<SimpleValue>(
56 emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self));
57 }
58
// Hash functor that lets `enum class` values be used as unordered_map keys.
// Older gcc/clang standard libraries do not provide std::hash for enums, see
// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
struct EnumClassHash {
  // Hashes an enumerator to its underlying value converted to std::size_t.
  template <typename T>
  std::size_t operator()(T t) const {
    const std::size_t hashed = static_cast<std::size_t>(t);
    return hashed;
  }
};
68
hasAttr(const SourceRange & loc,GraphFunction & m,const std::string & field)69 bool SimpleValue::hasAttr(
70 const SourceRange& loc,
71 GraphFunction& m,
72 const std::string& field) {
73 if (auto class_type = value_->type()->cast<ClassType>()) {
74 return class_type->hasMethod(field) || class_type->hasAttribute(field) ||
75 class_type->hasConstant(field);
76 } else if (auto tuple_type = value_->type()->cast<TupleType>()) {
77 if (tuple_type->schema()) {
78 for (const auto& arg : tuple_type->schema()->arguments()) {
79 if (arg.name() == field) {
80 return true;
81 }
82 }
83 return false;
84 } else {
85 throw(
86 ErrorReport(loc) << "hasattr's first argument must be a object "
87 << "or NamedTuple, but got a normal Tuple "
88 << value_->type()->repr_str() << " instead");
89 }
90 }
91 throw(
92 ErrorReport(loc) << "hasattr's first argument must be an object or "
93 << "NamedTuple, got " << value_->type()->repr_str()
94 << " instead");
95 }
96
97 // support syntax sugar for x.foo(y, z) by allowing x.foo to return a
98 // callable value that will resolve to foo(x, y, z) when called.
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)99 std::shared_ptr<SugaredValue> SimpleValue::attr(
100 const SourceRange& loc,
101 GraphFunction& m,
102 const std::string& field) {
103 // Allow method-style casts on Tensor types. e.g. x.int()
104 if (value_->type()->isSubtypeOf(*TensorType::get())) {
105 if (builtin_cast_method_to_scalar_type().count(field)) {
106 return std::make_shared<TensorCastValue>(
107 builtin_cast_method_to_scalar_type().at(field),
108 NamedValue(loc, "self", value_));
109 }
110 }
111 // accessing properties of Tensor and Device that are implemented as
112 // prim:: or aten:: operators
113 using PropertiesLookup = std::unordered_map<
114 TypeKind,
115 std::unordered_map<std::string, std::string>,
116 EnumClassHash>;
117 static const PropertiesLookup builtin_properties = {
118 {TypeKind::OptionalType,
119 {
120 {"unchecked_unwrap_optional", "prim"},
121 }},
122 {TypeKind::TensorType,
123 {
124 {"dtype", "prim"},
125 {"device", "prim"},
126 {"grad", "prim"},
127 {"data", "prim"},
128 {"shape", "prim"},
129 {"is_cuda", "prim"},
130 {"is_cpu", "prim"},
131 {"is_xla", "prim"},
132 {"is_xpu", "prim"},
133 {"is_sparse", "prim"},
134 {"is_sparse_csr", "prim"},
135 {"is_mkldnn", "prim"},
136 {"is_mps", "prim"},
137 {"is_mtia", "prim"},
138 {"is_quantized", "prim"},
139 {"is_vulkan", "prim"},
140 {"is_ipu", "prim"},
141 {"is_meta", "prim"},
142 {"is_leaf", "aten"},
143 {"is_nested", "prim"},
144 {"requires_grad", "prim"},
145 {"layout", "prim"},
146 {"T", "prim"},
147 {"H", "prim"},
148 {"mT", "aten"},
149 {"mH", "aten"},
150 {"is_maia", "prim"},
151 {"itemsize", "prim"},
152 {"nbytes", "prim"},
153 {"ndim", "prim"},
154 {"name", "prim"},
155 {"real", "aten"},
156 {"imag", "aten"},
157 {"retains_grad", "aten"},
158 }},
159 {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
160 auto kind = value_->type()->kind();
161 auto types_for_builtin = builtin_properties.find(kind);
162 if (types_for_builtin != builtin_properties.end()) {
163 auto builtin_entry = types_for_builtin->second.find(field);
164 if (builtin_entry != types_for_builtin->second.end()) {
165 // A builtin was found, add it to the graph
166 auto the_namespace = builtin_entry->second;
167 auto r = m.graph()->insert(
168 Symbol::fromQualString(the_namespace + "::" + field), {value_});
169 return std::make_shared<SimpleValue>(r);
170 }
171 }
172
173 // accessing fields of named tuples
174 if (auto tuple_type = value_->type()->cast<TupleType>()) {
175 if (tuple_type->schema()) {
176 auto attrs = tuple_type->schema()->arguments();
177 for (const auto i : c10::irange(attrs.size())) {
178 if (attrs[i].name() == field) {
179 auto idx = m.graph()->insertConstant(IValue(static_cast<int64_t>(i)));
180 auto out_type = tuple_type->elements().at(i);
181 auto r = m.graph()
182 ->insertNode(
183 m.graph()->createTupleIndex(value_, idx, out_type))
184 ->output();
185 return std::make_shared<SimpleValue>(r);
186 }
187 }
188 }
189 } else if (auto awaitType = value_->type()->cast<AwaitType>()) {
190 auto elType = awaitType->getElementType();
191 auto& g = *m.graph();
192 auto v = g.insert(prim::awaitable_wait, {value_}, {}, loc);
193 auto sv = std::make_shared<SimpleValue>(v);
194 return sv->attr(loc, m, field);
195 } else if (auto classType = value_->type()->cast<ClassType>()) {
196 // This is a class, emit the proper attribute lookup
197 if (classType->findMethod(field)) {
198 return std::make_shared<MethodValue>(getValue(), field);
199 }
200 if (classType->hasAttribute(field)) {
201 auto& g = *m.graph();
202 auto n = g.insertNode(g.createGetAttr(value_, field));
203 return std::make_shared<SimpleValue>(n->output());
204 }
205 // Check and see if it's a getter attribute.
206 auto prop = classType->getProperty(field);
207 if (prop) {
208 return MethodValue(value_, prop->getter->name())
209 .call(loc, m, {}, {}, /*n_binders=*/1);
210 }
211 } else if (auto iface = value_->type()->cast<InterfaceType>()) {
212 // accessing methods of interfaces
213 if (iface->getMethod(field)) {
214 return std::make_shared<MethodValue>(getValue(), field);
215 }
216 } else if (auto enum_type = value_->type()->cast<EnumType>()) {
217 // Handle access to Enum's `name` and `value` attribute.
218 auto& g = *m.graph();
219
220 if (field == "name") {
221 auto n = g.insertNode(g.createEnumName(value_));
222 return std::make_shared<SimpleValue>(n->output());
223 }
224
225 if (field == "value") {
226 auto n = g.insertNode(g.createEnumValue(value_));
227 return std::make_shared<SimpleValue>(n->output());
228 }
229 }
230
231 // none of the more-specific cases worked, so see if this is a builtin method
232 // If field is a type, then call the aten::to op
233 if (field == "type") {
234 if (auto builtin = BuiltinFunction::tryCreate(
235 Symbol::aten("to"), NamedValue(loc, "self", value_))) {
236 return builtin;
237 }
238 }
239
240 if (auto builtin = BuiltinFunction::tryCreate(
241 Symbol::aten(field), NamedValue(loc, "self", value_))) {
242 return builtin;
243 }
244
245 // Handle calling tolist() on a Tensor.
246 if (value_->type()->isSubtypeOf(*TensorType::get()) && field == "tolist") {
247 return SpecialFormValue::create(prim::tolist);
248 }
249
250 // Handle calling __getitem__() directly on a Tensor, it needs special
251 // handling because desired method name (`__getitem__`) doesn't match `aten`
252 // operator name of `aten::index`.
253 if (value_->type()->isSubtypeOf(*TensorType::get()) &&
254 field == "__getitem__") {
255 return SpecialFormValue::create(aten::index);
256 }
257
258 if (auto generator_type = value_->type()->cast<GeneratorType>()) {
259 // Handle access to Generator's `manual_seed`, `initial_seed` and `seed`
260 // attributes.
261 if (field == "manual_seed" || field == "initial_seed" || field == "seed") {
262 if (auto builtin = BuiltinFunction::tryCreate(
263 Symbol::aten(field), NamedValue(loc, "self", value_))) {
264 return builtin;
265 }
266 }
267 }
268
269 ErrorReport report(loc);
270 report << "'" << value_->type()->repr_str()
271 << "' object has no attribute or method '" << field << "'.";
272 if (auto classType = value_->type()->cast<ClassType>()) {
273 if (classType->isUnresolvedClassAttribute(field)) {
274 report
275 << " '" << field
276 << "' is defined as a class attribute which currently is not"
277 " supported. Consider converting this to an instance attribute.";
278 } else {
279 report << " Did you forget to initialize an attribute in __init__()?";
280 }
281 }
282 throw ErrorReport(report);
283 }
284
asTuple(const SourceRange & loc,GraphFunction & m,const std::optional<size_t> & size_hint)285 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
286 const SourceRange& loc,
287 GraphFunction& m,
288 const std::optional<size_t>& size_hint) {
289 static const auto make_simple_value =
290 [](Value* v) -> std::shared_ptr<SugaredValue> {
291 return std::make_shared<SimpleValue>(v);
292 };
293 if (value_->type()->kind() == TypeKind::TupleType) {
294 auto outputs = createTupleUnpack(value_);
295 return fmap(outputs, make_simple_value);
296 } else if (value_->type()->kind() == TypeKind::ListType) {
297 if (!size_hint) {
298 throw(
299 ErrorReport(loc) << "cannot statically infer the expected size of a "
300 << "list in this context");
301 }
302 auto graph = value_->owningGraph();
303 Node* unpack =
304 graph->insertNode(graph->createListUnpack(value_, *size_hint));
305 return fmap(unpack->outputs(), make_simple_value);
306 } else if (value_->type()->kind() == TypeKind::AnyTupleType) {
307 throw(
308 ErrorReport(loc)
309 << "Provided tuple is not fully defined/refined including its element types, please provide a value of type like Tuple[int, int]");
310 }
311 throw(
312 ErrorReport(loc) << value_->type()->repr_str()
313 << " cannot be used as a tuple");
314 }
315
isRecursive(const TypePtr & classType,const TypePtr & attrType)316 static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
317 if (attrType->isSubtypeOf(*classType)) {
318 return true;
319 }
320
321 // Recursively check contained types. We need to do this because a user may do
322 // A -> B -> A.
323 for (const auto& type : attrType->containedTypes()) {
324 if (isRecursive(classType, type)) {
325 return true;
326 }
327 }
328 return false;
329 }
330
setAttr(const SourceRange & loc,GraphFunction & m,const std::string & field,Value * newValue)331 void SimpleValue::setAttr(
332 const SourceRange& loc,
333 GraphFunction& m,
334 const std::string& field,
335 Value* newValue) {
336 const auto classType = value_->type()->cast<ClassType>();
337 if (!classType) {
338 throw(
339 ErrorReport(loc) << "Tried to set an attribute: " << field
340 << " on a non-class: " << value_->type()->repr_str());
341 }
342 auto expectedType = classType->findAttribute(field);
343 if (!expectedType) {
344 // If we are still compiling the __init__ method for this class, then
345 // setting an unknown attribute adds it to the class's definition.
346
347 // We are initializing if:
348 const auto isInitializing =
349 // 1. The method we're currently inserting into is an init method
350 // TODO this can be a qualified name check
351 m.name() == "__init__" &&
352 // 2. The `self` arg matches this value's type (i.e. we are in the init
353 // method for this class, not some other class)
354 !m.graph()->inputs().empty() &&
355 m.graph()->inputs().at(0)->type() == classType;
356
357 if (isInitializing) {
358 if (isRecursive(classType, newValue->type())) {
359 throw(
360 ErrorReport(loc)
361 << "Assignment to attribute '" << field
362 << "' cannot be of a type that contains class "
363 << "'" << classType->repr_str() << "'.\n"
364 << "Classes that recursively contain instances of themselves"
365 << " are not yet supported");
366 }
367
368 classType->addAttribute(field, newValue->type());
369 expectedType = newValue->type();
370
371 const auto insertPoint = m.graph()->insertPoint();
372 const auto topLevelBlock = m.graph()->block();
373 if (insertPoint->owningBlock() != topLevelBlock) {
374 throw(
375 ErrorReport(loc)
376 << "First assignment cannot be in a control-flow block. "
377 << "Initialize the field at the top level first");
378 }
379 } else {
380 // Check and see if it's a setter attribute.
381 auto prop = classType->getProperty(field);
382 if (prop && prop->setter) {
383 MethodValue(value_, prop->setter->name())
384 .call(loc, m, {newValue}, {}, /*n_binders=*/1);
385 return;
386 }
387
388 if (prop && !prop->setter) {
389 throw(
390 ErrorReport(loc) << "Tried to set read-only attribute: " << field);
391 }
392
393 throw(
394 ErrorReport(loc)
395 << "Tried to set nonexistent attribute: " << field
396 << ". Did you forget to initialize it in __init__()?");
397 }
398 }
399
400 AT_ASSERT(expectedType);
401
402 // Check type correctness
403 const auto newType = newValue->type();
404 if (!newType->isSubtypeOf(*expectedType)) {
405 throw(
406 ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
407 << expectedType->repr_str() << " but got "
408 << newType->repr_str());
409 }
410
411 auto& g = *m.graph();
412 g.insertNode(g.createSetAttr(value_, field, newValue));
413 }
414
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)415 std::shared_ptr<SugaredValue> SimpleValue::call(
416 const SourceRange& loc,
417 GraphFunction& m,
418 at::ArrayRef<NamedValue> args,
419 at::ArrayRef<NamedValue> kwargs,
420 size_t n_binders) {
421 // allow our 'fake' closures to be called, used for fork serialization
422 // at the moment, but can be expanded later
423 Node* self = getValue()->node();
424 if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 &&
425 self->inputs().at(0)->node()->kind() == prim::Closure) {
426 std::shared_ptr<Graph> graph =
427 self->inputs().at(0)->node()->g(attr::Subgraph);
428 Value* context = self->inputs().at(1);
429 AT_ASSERT(context->node()->kind() == prim::TupleConstruct);
430
431 // fork nodes are emitted in their own block but we do not simplify
432 // tuple construction across blocks. To ensure we clean up the tuple
433 // construct create another copy of the tuple construct in the fork block
434 Value* close_context =
435 m.graph()
436 ->insertNode(m.graph()->createTuple(context->node()->inputs()))
437 ->output();
438 // TODO this needs to go in `m`s compilation unit
439 auto cu = std::make_shared<CompilationUnit>();
440 auto fn = cu->create_function(QualifiedName("anon"), graph);
441 auto ret = StrongFunctionPtr(std::move(cu), fn);
442
443 std::vector<NamedValue> ctx_inputs = {close_context};
444 ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end());
445 return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders);
446 }
447
448 if (auto class_type = getValue()->type()->cast<ClassType>()) {
449 return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders);
450 }
451
452 return SugaredValue::call(loc, m, args, kwargs, n_binders);
453 }
454
len(const SourceRange & loc,GraphFunction & m)455 Value* SimpleValue::len(const SourceRange& loc, GraphFunction& m) {
456 // List, Tuple, Tensor, fill in missing information desugaring
457 Value* val = getValue();
458 TypePtr val_type = val->type();
459 Graph& g = *m.graph();
460 if (val_type->cast<ListType>() || val_type->cast<StringType>() ||
461 val_type->isSubtypeOf(*TensorType::get())) {
462 return g.insert(aten::len, {val}, {}, loc);
463 } else {
464 throw(
465 ErrorReport(loc) << "'" << val_type->repr_str() << "'"
466 << " object is not iterable");
467 }
468 }
469
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)470 SugaredValuePtr SimpleValue::getitem(
471 const SourceRange& loc,
472 GraphFunction& m,
473 Value* idx,
474 TypePtr type_hint) {
475 Value* val = getValue();
476 TypePtr val_type = val->type();
477 Graph& g = *m.graph();
478
479 // if it's a List/String/Dict, emit a regular __getitem__ op
480 // NOLINTNEXTLINE(bugprone-branch-clone)
481 if (val_type->cast<ListType>() || val_type->cast<StringType>()) {
482 return std::make_shared<SimpleValue>(
483 g.insert(aten::__getitem__, {val, idx}, {}, loc));
484 } else if (auto dict_type = val_type->cast<DictType>()) {
485 return std::make_shared<SimpleValue>(
486 g.insert(aten::__getitem__, {val, idx}, {}, loc));
487 } else if (val_type->isSubtypeOf(*TensorType::get())) {
488 return std::make_shared<SimpleValue>(
489 g.insert(aten::select, {val, 0, idx}, {}, loc));
490 } else if (auto class_type = val_type->cast<ClassType>()) {
491 // Check if this is an indexing operation enabled by a type hint.
492 // The ModuleDict has already been checked during IR generation to make
493 // sure its contents implement the module interface referred to by
494 // type_hint.
495 if (class_type->is_module() && type_hint) {
496 auto res = g.insert(prim::ModuleContainerIndex, {val, idx}, {}, loc);
497 res->setType(type_hint);
498 return std::make_shared<SimpleValue>(res);
499 }
500
501 // Defer to the __getitem__ attr on the class.
502 return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1);
503 } else {
504 throw(
505 ErrorReport(loc) << "'" << val_type->repr_str() << "'"
506 << " object is not subscriptable");
507 }
508 }
509
iter(const SourceRange & loc,GraphFunction & m)510 SugaredValuePtr SimpleValue::iter(const SourceRange& loc, GraphFunction& m) {
511 auto value = getValue();
512 auto type = value->type();
513 // built-in iterable types
514 if (type->cast<ListType>() || type->cast<StringType>() ||
515 type->cast<TensorType>()) {
516 return std::make_shared<SimpleValue>(value);
517 }
518 // dicts iterate over keys
519 if (type->cast<DictType>()) {
520 return std::make_shared<SimpleValue>(
521 m.graph()->insert(aten::keys, {value}, {}, loc));
522 }
523 if (auto tup = type->cast<TupleType>()) {
524 auto tup_values = createTupleUnpack(value);
525 std::vector<SugaredValuePtr> tup_sugared;
526 for (Value* v : tup_values) {
527 tup_sugared.push_back(std::make_shared<SimpleValue>(v));
528 }
529 return std::make_shared<SugaredTupleValue>(tup_sugared);
530 } else {
531 throw(
532 ErrorReport(loc) << "'" << type->repr_str() << "'"
533 << " object is not iterable");
534 }
535 }
536
RangeValue(const SourceRange & loc,GraphFunction & m,std::vector<Value * > inputs,std::optional<int64_t> static_len)537 RangeValue::RangeValue(
538 const SourceRange& loc,
539 GraphFunction& m,
540 std::vector<Value*> inputs,
541 std::optional<int64_t> static_len) {
542 for (const auto i : c10::irange(inputs.size())) {
543 auto typ = inputs[i]->type();
544 if (!typ->cast<IntType>()) {
545 throw(
546 ErrorReport(loc) << "all inputs of range must be ints, found "
547 << typ->repr_str() << " in argument "
548 << std::to_string(i));
549 }
550 }
551
552 Graph& g = *m.graph();
553 if (inputs.empty()) {
554 throw(ErrorReport(loc) << "range expected at least 1 arguments, got 0");
555 } else if (inputs.size() == 1) {
556 end_ = inputs[0];
557 start_ = g.insertConstant(0, loc);
558 step_ = g.insertConstant(1, loc);
559 // range() call only contains end, easier to calculate len() and getitem()
560 has_only_end_ = true;
561 } else if (inputs.size() <= 3) {
562 start_ = inputs[0];
563 end_ = inputs[1];
564 if (inputs.size() == 3) {
565 step_ = inputs[2];
566 } else {
567 step_ = g.insertConstant(1, loc);
568 }
569 has_only_end_ = false;
570 } else {
571 throw(
572 ErrorReport(loc) << "range expected at most 3 arguments, got "
573 << inputs.size());
574 }
575
576 static_len_ = static_len;
577 }
578
iter(const SourceRange & loc,GraphFunction & m)579 SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
580 return shared_from_this();
581 };
582
len(const SourceRange & loc,GraphFunction & m)583 Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
584 if (static_len_) {
585 return insertConstant(*m.graph(), *static_len_, loc);
586 }
587 if (has_only_end_) {
588 return end_;
589 } else {
590 Graph& g = *m.graph();
591 return g.insert(aten::__range_length, {start_, end_, step_}, {}, loc);
592 }
593 }
594
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)595 SugaredValuePtr RangeValue::getitem(
596 const SourceRange& loc,
597 GraphFunction& m,
598 Value* idx,
599 TypePtr type_hint) {
600 if (has_only_end_) {
601 return std::make_shared<SimpleValue>(idx);
602 } else {
603 auto& g = *m.graph();
604 return std::make_shared<SimpleValue>(
605 g.insert(aten::__derive_index, {idx, start_, step_}, {}, loc));
606 }
607 }
608
get_base_iterables()609 std::vector<SugaredValuePtr> IterableTree::get_base_iterables() {
610 std::vector<SugaredValuePtr> base_iters{};
611
612 for (SugaredValuePtr& sv : children_) {
613 if (auto iv = std::dynamic_pointer_cast<IterableTree>(sv)) {
614 std::vector<SugaredValuePtr> child_iters = iv->get_base_iterables();
615 // merge child iters with the base_iters
616 base_iters.insert(
617 base_iters.end(),
618 std::make_move_iterator(child_iters.begin()),
619 std::make_move_iterator(child_iters.end()));
620
621 } else {
622 // IterableTree leaves, either SimpleValue or RangeValue
623 base_iters.emplace_back(sv);
624 }
625 }
626 return base_iters;
627 }
628
len(const SourceRange & loc,GraphFunction & m)629 Value* IterableTree::len(const SourceRange& loc, GraphFunction& m) {
630 // if it's a iterable tree, we get the base iterables that consists of
631 // SimpleValue or RangeValue, and then calculate the minimum length of all the
632 // base iterables to be max_trip_count_val
633 TORCH_INTERNAL_ASSERT(!unroll_length_);
634 Graph& g = *m.graph();
635 std::vector<SugaredValuePtr> base_iters = get_base_iterables();
636 std::vector<Value*> lengths;
637 lengths.reserve(base_iters.size());
638
639 for (const SugaredValuePtr& base_iter : base_iters) {
640 lengths.emplace_back(base_iter->len(loc, m));
641 }
642 Node* list_node = g.insertNode(g.createList(IntType::get(), lengths));
643 return g.insert(prim::min, {list_node->output()}, {}, loc);
644 }
645
getitem(const SourceRange & loc,GraphFunction & m,Value * idx,TypePtr type_hint)646 SugaredValuePtr IterableTree::getitem(
647 const SourceRange& loc,
648 GraphFunction& m,
649 Value* idx,
650 TypePtr type_hint) {
651 std::vector<SugaredValuePtr> child_items;
652 child_items.reserve(children_.size());
653 for (const SugaredValuePtr& child : children_) {
654 child_items.emplace_back(child->getitem(loc, m, idx));
655 }
656 return std::make_shared<SugaredTupleValue>(child_items);
657 }
658
addChild(const SourceRange & range,GraphFunction & m,const SugaredValuePtr & iter_value)659 void IterableTree::addChild(
660 const SourceRange& range,
661 GraphFunction& m,
662 const SugaredValuePtr& iter_value) {
663 std::optional<int64_t> child_len = iter_value->staticLen();
664 if (children_.empty()) {
665 unroll_length_ = child_len;
666 } else {
667 if ((unroll_length_ && !child_len) || (child_len && !unroll_length_)) {
668 throw(
669 ErrorReport(range)
670 << "Can not iterate over a module list or tuple with a value "
671 "that does not have a statically determinable length\n");
672 }
673 if (unroll_length_ && child_len) {
674 // iterables run for the minimum length of all its leaves
675 unroll_length_ = std::min(*child_len, *unroll_length_);
676 } else {
677 unroll_length_ = std::nullopt;
678 }
679 }
680 children_.push_back(iter_value);
681 }
682
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)683 std::shared_ptr<SugaredValue> MagicMethod::call(
684 const SourceRange& loc,
685 GraphFunction& m,
686 at::ArrayRef<NamedValue> args,
687 at::ArrayRef<NamedValue> kwargs,
688 size_t n_binders) {
689 if (!args.empty()) {
690 Value* self = args[0].value(*m.graph());
691 if (auto class_ptr = self->type()->cast<ClassType>()) {
692 return SimpleValue(self)
693 .attr(loc, m, desugared_name_)
694 ->call(loc, m, args.slice(1), kwargs, n_binders);
695 }
696 }
697 TORCH_INTERNAL_ASSERT(base_value_);
698 return base_value_->call(loc, m, args, kwargs, n_binders);
699 }
700
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)701 std::shared_ptr<SugaredValue> ClassValue::call(
702 const SourceRange& loc,
703 GraphFunction& m,
704 // note: names for args will be 'argument 0', 'argument 1', etc..
705 at::ArrayRef<NamedValue> args,
706 at::ArrayRef<NamedValue> kwargs,
707 size_t n_binders) {
708 AT_ASSERT(n_binders <= 1);
709
710 // Generate a new object of the right type, then call `__init__` on it
711 auto& g = *m.graph();
712 auto self = g.insertNode(g.createObject(type_))->output();
713 self->node()->setSourceRange(loc);
714 if (!type_->findMethod("__init__")) {
715 throw(
716 ErrorReport(loc) << "Class " << type_->name()->name()
717 << " does not have an __init__ function defined");
718 }
719
720 // Call the init function
721 MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders);
722
723 return std::make_shared<SimpleValue>(self);
724 }
725
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)726 std::shared_ptr<SugaredValue> ClassValue::attr(
727 const SourceRange& loc,
728 GraphFunction& m,
729 const std::string& field) {
730 // Allow import_source.cpp to resolve calls to a submodule's
731 // hooks. Edge case because normally you wouldn't allow a module to
732 // call functions of a submodule
733 if (Function* hook = type_->findHook(field)) {
734 return std::make_shared<FunctionValue>(hook);
735 }
736
737 if (field != "__new__") {
738 throw(
739 ErrorReport(loc) << "Tried to lookup unknown attribute on class "
740 << type_->annotation_str());
741 }
742 return SpecialFormValue::create(prim::CreateObject);
743 }
744
call(const SourceRange & loc,GraphFunction & m,at::ArrayRef<NamedValue> args,at::ArrayRef<NamedValue> kwargs,size_t n_binders)745 std::shared_ptr<SugaredValue> NamedTupleConstructor::call(
746 const SourceRange& loc,
747 GraphFunction& m,
748 at::ArrayRef<NamedValue> args,
749 at::ArrayRef<NamedValue> kwargs,
750 size_t n_binders) {
751 auto& g = *m.graph();
752
753 auto schema = type_->schema();
754 TORCH_INTERNAL_ASSERT(schema);
755 auto qualname = type_->name();
756 auto matched_schema = matchSchema(*schema, loc, g, args, kwargs);
757
758 auto self =
759 g.insertNode(
760 g.createTuple(matched_schema.inputs, type_)->setSourceRange(loc))
761 ->output();
762 self->setType(type_);
763
764 return std::make_shared<SimpleValue>(self);
765 }
766
tryCreate(Symbol symbol,std::optional<NamedValue> self)767 std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
768 Symbol symbol,
769 std::optional<NamedValue> self) {
770 for (const std::shared_ptr<Operator>& op : getAllOperatorsFor(symbol)) {
771 if (!self) {
772 return std::make_shared<BuiltinFunction>(symbol, nullptr);
773 }
774 if (auto index = op->schema().argumentIndexWithName("self")) {
775 std::unordered_map<std::string, TypePtr> type_env;
776 TypePtr formal_type = op->schema().arguments().at(*index).type();
777 const MatchTypeReturn matched =
778 matchTypeVariables(formal_type, self->type(), type_env);
779 if (!matched.success()) {
780 continue;
781 }
782 const auto concrete_type = tryEvalTypeVariables(formal_type, type_env);
783 if (!concrete_type || !self->type()->isSubtypeOf(*concrete_type)) {
784 continue;
785 }
786 return std::make_shared<BuiltinFunction>(symbol, self);
787 }
788 }
789 return nullptr;
790 }
791
attr(const SourceRange & loc,GraphFunction & m,const std::string & field)792 std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
793 const SourceRange& loc,
794 GraphFunction& m,
795 const std::string& field) {
796 const auto& names_values = enum_type_->enumNamesValues();
797 auto it = std::find_if(
798 names_values.begin(),
799 names_values.end(),
800 [&field](const at::EnumNameValue& nv) { return nv.first == field; });
801 if (it == names_values.end()) {
802 throw(
803 ErrorReport(loc) << enum_type_->repr_str() << "'"
804 << " has no attribute '" << field << "'");
805 }
806 auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
807 enum_type_, it->first, it->second);
808 return std::make_shared<SimpleValue>(
809 m.graph()->insertConstant(IValue(enum_holder), loc));
810 }
811
iter(const SourceRange & loc,GraphFunction & m)812 SugaredValuePtr SugaredEnumClass::iter(
813 const SourceRange& loc,
814 GraphFunction& m) {
815 const auto& names_values = enum_type_->enumNamesValues();
816 auto enum_value_ivalues = c10::impl::GenericList(enum_type_);
817 enum_value_ivalues.reserve(names_values.size());
818 for (const auto& name_value : names_values) {
819 auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
820 enum_type_, name_value.first, name_value.second);
821 enum_value_ivalues.emplace_back(enum_holder);
822 }
823
824 auto enum_values_list_constant = std::make_shared<SimpleValue>(
825 m.graph()->insertConstant(enum_value_ivalues, loc));
826 return enum_values_list_constant;
827 }
828
829 } // namespace torch::jit
830