1 #pragma once
2
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/frontend/concrete_module_type.h>
5 #include <torch/csrc/jit/frontend/sugared_value.h>
6 #include <torch/csrc/jit/python/pybind_utils.h>
7 #include <memory>
8 #include <sstream>
9 #include <string>
10 #include <utility>
11 #include <vector>
12
13 namespace torch::jit {
14
// Returns a std::string computed from the Python handle `h`. Only the
// declaration lives in this header.
// NOTE(review): presumably the result names the Python type of `h` (e.g. for
// diagnostics) -- confirm against the definition.
std::string typeString(py::handle h);
16
toSimple(Value * v)17 inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
18 return std::make_shared<SimpleValue>(v);
19 }
20
// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
//
// Parameters:
//   obj         - the Python object to convert (taken by value).
//   m           - GraphFunction handed to the conversion.
//                 NOTE(review): presumably the function currently being
//                 compiled -- confirm against the definition.
//   loc         - source range of the expression being converted.
//                 NOTE(review): presumably used for error reporting -- confirm.
//   is_constant - defaults to false; its effect is decided by the definition,
//                 which is not visible in this header.
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    GraphFunction& m,
    const SourceRange& loc,
    bool is_constant = false);
29
// Tries to obtain a StrongFunctionPtr from the Python object `obj`; the
// std::optional return type lets the caller detect that none was available.
// NOTE(review): the exact conditions for an empty result live in the
// definition -- confirm there.
std::optional<StrongFunctionPtr> as_function(const py::object& obj);
31
32 struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
33 PythonValue(
34 py::object the_self,
35 std::optional<py::object> rcb = std::nullopt,
36 Value* module_self = nullptr)
selfPythonValue37 : self(std::move(the_self)),
38 rcb(std::move(rcb)),
39 moduleSelf_(module_self) {}
40
41 FunctionSchema getSchema(
42 const size_t n_args,
43 const size_t n_binders,
44 const SourceRange& loc);
45
46 // call it like a function, e.g. `outputs = this(inputs)`
47 std::shared_ptr<SugaredValue> call(
48 const SourceRange& loc,
49 GraphFunction& m,
50 at::ArrayRef<NamedValue> args,
51 at::ArrayRef<NamedValue> kwargs,
52 size_t n_binders) override;
53
54 std::string kind() const override;
55
56 std::vector<std::shared_ptr<SugaredValue>> asTuple(
57 const SourceRange& loc,
58 GraphFunction& m,
59 const std::optional<size_t>& size_hint = {}) override;
60
61 std::shared_ptr<SugaredValue> attr(
62 const SourceRange& loc,
63 GraphFunction& m,
64 const std::string& field) override;
65
asValuePythonValue66 Value* asValue(const SourceRange& loc, GraphFunction& m) override {
67 throw(
68 ErrorReport(loc)
69 << kind() << " cannot be used as a value. "
70 << "Perhaps it is a closed over global variable? If so, please "
71 << "consider passing it in as an argument or use a local varible "
72 << "instead.");
73 }
74
75 protected:
76 py::object getattr(const SourceRange& loc, const std::string& name);
77
78 void checkForAddToConstantsError(std::stringstream& ss);
79
80 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
81 py::object self;
82 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
83 std::optional<py::object> rcb;
84 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
85 Value* moduleSelf_ = nullptr;
86 };
87
88 struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
PythonModuleValuePythonModuleValue89 explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {}
90
91 std::shared_ptr<SugaredValue> attr(
92 const SourceRange& loc,
93 GraphFunction& m,
94 const std::string& field) override;
95 };
96
97 // Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
98 // torch.cuda.* are resolved using CUDAPythonModuleValue.
99 struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
CUDAPythonModuleValueCUDAPythonModuleValue100 explicit CUDAPythonModuleValue(py::object mod)
101 : PythonValue(std::move(mod)) {}
102
103 std::shared_ptr<SugaredValue> attr(
104 const SourceRange& loc,
105 GraphFunction& m,
106 const std::string& field) override;
107 };
108
109 // Represents all the parameters of a module as a List[Tensor]
110 struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
ConstantParameterListConstantParameterList111 ConstantParameterList(Value* the_list) : the_list_(the_list) {}
kindConstantParameterList112 std::string kind() const override {
113 return "constant parameter list";
114 }
callConstantParameterList115 std::shared_ptr<SugaredValue> call(
116 const SourceRange& loc,
117 GraphFunction& caller,
118 at::ArrayRef<NamedValue> args,
119 at::ArrayRef<NamedValue> kwargs,
120 size_t n_binders) override {
121 return toSimple(the_list_);
122 }
123
124 private:
125 Value* the_list_;
126 };
127
128 struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
ModuleDictMethodModuleDictMethod129 explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name)
130 : iterable_(std::move(iterable)), name_(std::move(name)){};
131
kindModuleDictMethod132 std::string kind() const override {
133 return name_;
134 }
135
callModuleDictMethod136 std::shared_ptr<SugaredValue> call(
137 const SourceRange& loc,
138 GraphFunction& f,
139 at::ArrayRef<NamedValue> args,
140 at::ArrayRef<NamedValue> kwargs,
141 size_t n_binders) override {
142 if (!args.empty() || !kwargs.empty()) {
143 throw(
144 ErrorReport(loc) << name_ << " method does not accept any arguments");
145 }
146 return iterable_;
147 }
148
149 SugaredValuePtr iterable_;
150 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
151 const std::string name_;
152 };
153
// Forward declaration: ModuleValue's getSugared*() members return
// std::shared_ptr<SugaredDict>, while the full definition appears later in
// this header.
struct SugaredDict;
155
156 // defines how modules/methods behave inside the script subset.
157 // for now this does not have any interaction with python.
158 // in the future, we will add the ability to resolve `self.foo` to python
159 // {functions, modules, constants} so this SugaredValue is defined here
160 // anticipating we will eventually need to replace Module with a py::object
161 // holding the actual nn.Module class.
162
163 struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
ModuleValueModuleValue164 ModuleValue(Value* self, std::shared_ptr<ConcreteModuleType> concreteType)
165 : self_(self), concreteType_(std::move(concreteType)) {}
166
kindModuleValue167 std::string kind() const override {
168 return "module";
169 }
170
171 Value* asValue(const SourceRange& loc, GraphFunction& m) override;
172
173 SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m)
174 override;
175
176 // select an attribute on it, e.g. `this.field`
177 std::shared_ptr<SugaredValue> tryGetAttr(
178 const SourceRange& loc,
179 GraphFunction& m,
180 const std::string& field);
181
182 // select an attribute on it, e.g. `this.field`
183 std::shared_ptr<SugaredValue> attr(
184 const SourceRange& loc,
185 GraphFunction& m,
186 const std::string& field) override;
187
188 // select an attribute on it, e.g. `this.field`
189 bool hasAttr(
190 const SourceRange& loc,
191 GraphFunction& m,
192 const std::string& field) override;
193
194 // call module.forward with pre_hooks and hooks
195 std::shared_ptr<SugaredValue> call(
196 const SourceRange& loc,
197 GraphFunction& caller,
198 at::ArrayRef<NamedValue> args,
199 at::ArrayRef<NamedValue> kwargs,
200 size_t n_binders) override;
201
202 std::shared_ptr<SugaredDict> getSugaredDict(
203 const SourceRange& loc,
204 GraphFunction& m);
205
206 std::shared_ptr<SugaredDict> getSugaredNamedBufferDict(
207 const SourceRange& loc,
208 GraphFunction& m);
209
210 std::shared_ptr<SugaredDict> getSugaredNamedParameterList(
211 const SourceRange& loc,
212 GraphFunction& m);
213
214 std::shared_ptr<SugaredDict> getSugaredNamedParameterDict(
215 const SourceRange& loc,
216 GraphFunction& m);
217
218 void setAttr(
219 const SourceRange& loc,
220 GraphFunction& m,
221 const std::string& field,
222 Value* newValue) override;
223
224 SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
225
226 std::shared_ptr<SugaredValue> getitem(
227 const SourceRange& loc,
228 GraphFunction& m,
229 Value* idx,
230 TypePtr type_hint) override;
231
232 private:
233 // Check that the type of all submodules is a subtype of ty. If the function
234 // returns false, more information about why it returns false (e.g. which
235 // submodule's type is not a subtype of ty) is printed it why_not if it is not
236 // null.
237 bool areAllSubmodulesSubtypeOf(
238 const TypePtr& ty,
239 std::ostream* why_not = nullptr) const;
240
241 Value* self_;
242 std::shared_ptr<ConcreteModuleType> concreteType_;
243 };
244
// Boolean predicate over the Python object `obj`.
// NOTE(review): presumably true when `obj` is a namedtuple class -- confirm
// against the definition.
bool isNamedTupleClass(const py::object& obj);
// Produces a TypePtr for the Python object `obj`; `loc` is a source range and
// `rcb` a ResolutionCallback supplied by the caller.
// NOTE(review): presumably `obj` is a namedtuple class, `loc` is used for
// error reporting and `rcb` resolves the field annotations -- confirm against
// the definition.
TypePtr registerNamedTuple(
    const py::object& obj,
    const SourceRange& loc,
    const ResolutionCallback& rcb);
250
// Declared here, defined elsewhere. `keys` and `values` are taken by
// non-const reference, so the function can append to / modify them; `self` is
// likewise a non-const reference to the ModuleValue handle.
// NOTE(review): presumably walks the submodules of `self` recursively and
// collects (prefixed name, module) pairs into `keys` / `values`, with `field`
// selecting what is collected -- confirm against the definition.
void recurseThroughNestedModules(
    const SourceRange& loc,
    GraphFunction& m,
    std::vector<SugaredValuePtr>& keys,
    std::vector<SugaredValuePtr>& values,
    std::shared_ptr<ModuleValue>& self,
    const std::string& prefix,
    const std::string& field);
259
260 // Used to support named_modules()
261 struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue {
SugaredDictSugaredDict262 explicit SugaredDict(
263 std::shared_ptr<ModuleValue> self,
264 std::shared_ptr<SugaredTupleValue> keys,
265 std::shared_ptr<SugaredTupleValue> modules)
266 : self_(std::move(self)),
267 keys_(std::move(keys)),
268 modules_(std::move(modules)) {}
269
kindSugaredDict270 std::string kind() const override {
271 return "ModuleDict";
272 }
273
getKeysSugaredDict274 std::shared_ptr<SugaredTupleValue> getKeys() {
275 return keys_;
276 }
277
getModulesSugaredDict278 std::shared_ptr<SugaredTupleValue> getModules() {
279 return modules_;
280 }
281
282 std::shared_ptr<SugaredValue> attr(
283 const SourceRange& loc,
284 GraphFunction& m,
285 const std::string& field) override;
286
iterSugaredDict287 SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override {
288 return keys_;
289 };
290
291 std::shared_ptr<ModuleValue> self_;
292 std::shared_ptr<SugaredTupleValue> keys_;
293 std::shared_ptr<SugaredTupleValue> modules_;
294 };
295
296 struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
BooleanDispatchValueBooleanDispatchValue297 BooleanDispatchValue(py::dict dispatched_fn)
298 : dispatched_fn_(std::move(dispatched_fn)) {}
299
kindBooleanDispatchValue300 std::string kind() const override {
301 return "boolean dispatch";
302 }
303
304 std::shared_ptr<SugaredValue> call(
305 const SourceRange& loc,
306 GraphFunction& caller,
307 at::ArrayRef<NamedValue> args,
308 at::ArrayRef<NamedValue> kwargs,
309 size_t n_binders) override;
310
311 private:
312 py::dict dispatched_fn_;
313 };
314
315 struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
PythonClassValuePythonClassValue316 PythonClassValue(ClassTypePtr type, py::object py_type)
317 : ClassValue(std::move(type)), py_type_(std::move(py_type)) {}
318
kindPythonClassValue319 std::string kind() const override {
320 return "Python type";
321 }
322
323 std::shared_ptr<SugaredValue> attr(
324 const SourceRange& loc,
325 GraphFunction& m,
326 const std::string& field) override;
327
328 bool hasAttr(
329 const SourceRange& loc,
330 GraphFunction& m,
331 const std::string& field) override;
332
333 private:
334 py::object py_type_;
335 };
336
337 struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
PythonExceptionValuePythonExceptionValue338 explicit PythonExceptionValue(const py::object& exception_class)
339 : ExceptionValue(
340 py::str(py::getattr(exception_class, "__name__", py::str("")))),
341 exception_class_qualified_name_(
342 py::str(py::module::import("torch._jit_internal")
343 .attr("_qualified_name")(
344 exception_class,
345 /*mangle_name=*/false))) {}
346
kindPythonExceptionValue347 std::string kind() const override {
348 return "Python exception";
349 }
350
351 std::shared_ptr<SugaredValue> call(
352 const SourceRange& loc,
353 GraphFunction& caller,
354 at::ArrayRef<NamedValue> args,
355 at::ArrayRef<NamedValue> kwargs,
356 size_t n_binders) override;
357
358 private:
359 std::string exception_class_qualified_name_;
360 };
361
362 // Python Slice class.
363 struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
364 explicit PythonSliceClass() = default;
365
kindPythonSliceClass366 std::string kind() const override {
367 return "Python slice class";
368 }
369
370 std::shared_ptr<SugaredValue> call(
371 const SourceRange& loc,
372 GraphFunction& caller,
373 at::ArrayRef<NamedValue> args,
374 at::ArrayRef<NamedValue> kwargs,
375 size_t n_binders) override;
376 };
377
378 } // namespace torch::jit
379