xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_sugared_value.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/frontend/concrete_module_type.h>
5 #include <torch/csrc/jit/frontend/sugared_value.h>
6 #include <torch/csrc/jit/python/pybind_utils.h>
7 #include <memory>
8 #include <sstream>
9 #include <string>
10 #include <utility>
11 #include <vector>
12 
13 namespace torch::jit {
14 
15 std::string typeString(py::handle h);
16 
toSimple(Value * v)17 inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
18   return std::make_shared<SimpleValue>(v);
19 }
20 
21 // NB: This should be the single entry-point for instantiating a SugaredValue
22 // from a Python object. If you are adding support for converting a new Python
23 // type, *add it in this function's implementation*.
24 std::shared_ptr<SugaredValue> toSugaredValue(
25     py::object obj,
26     GraphFunction& m,
27     const SourceRange& loc,
28     bool is_constant = false);
29 
30 std::optional<StrongFunctionPtr> as_function(const py::object& obj);
31 
32 struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
33   PythonValue(
34       py::object the_self,
35       std::optional<py::object> rcb = std::nullopt,
36       Value* module_self = nullptr)
selfPythonValue37       : self(std::move(the_self)),
38         rcb(std::move(rcb)),
39         moduleSelf_(module_self) {}
40 
41   FunctionSchema getSchema(
42       const size_t n_args,
43       const size_t n_binders,
44       const SourceRange& loc);
45 
46   // call it like a function, e.g. `outputs = this(inputs)`
47   std::shared_ptr<SugaredValue> call(
48       const SourceRange& loc,
49       GraphFunction& m,
50       at::ArrayRef<NamedValue> args,
51       at::ArrayRef<NamedValue> kwargs,
52       size_t n_binders) override;
53 
54   std::string kind() const override;
55 
56   std::vector<std::shared_ptr<SugaredValue>> asTuple(
57       const SourceRange& loc,
58       GraphFunction& m,
59       const std::optional<size_t>& size_hint = {}) override;
60 
61   std::shared_ptr<SugaredValue> attr(
62       const SourceRange& loc,
63       GraphFunction& m,
64       const std::string& field) override;
65 
asValuePythonValue66   Value* asValue(const SourceRange& loc, GraphFunction& m) override {
67     throw(
68         ErrorReport(loc)
69         << kind() << " cannot be used as a value. "
70         << "Perhaps it is a closed over global variable? If so, please "
71         << "consider passing it in as an argument or use a local varible "
72         << "instead.");
73   }
74 
75  protected:
76   py::object getattr(const SourceRange& loc, const std::string& name);
77 
78   void checkForAddToConstantsError(std::stringstream& ss);
79 
80   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
81   py::object self;
82   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
83   std::optional<py::object> rcb;
84   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
85   Value* moduleSelf_ = nullptr;
86 };
87 
88 struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
PythonModuleValuePythonModuleValue89   explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {}
90 
91   std::shared_ptr<SugaredValue> attr(
92       const SourceRange& loc,
93       GraphFunction& m,
94       const std::string& field) override;
95 };
96 
97 // Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
98 // torch.cuda.* are resolved using CUDAPythonModuleValue.
99 struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
CUDAPythonModuleValueCUDAPythonModuleValue100   explicit CUDAPythonModuleValue(py::object mod)
101       : PythonValue(std::move(mod)) {}
102 
103   std::shared_ptr<SugaredValue> attr(
104       const SourceRange& loc,
105       GraphFunction& m,
106       const std::string& field) override;
107 };
108 
109 // Represents all the parameters of a module as a List[Tensor]
110 struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
ConstantParameterListConstantParameterList111   ConstantParameterList(Value* the_list) : the_list_(the_list) {}
kindConstantParameterList112   std::string kind() const override {
113     return "constant parameter list";
114   }
callConstantParameterList115   std::shared_ptr<SugaredValue> call(
116       const SourceRange& loc,
117       GraphFunction& caller,
118       at::ArrayRef<NamedValue> args,
119       at::ArrayRef<NamedValue> kwargs,
120       size_t n_binders) override {
121     return toSimple(the_list_);
122   }
123 
124  private:
125   Value* the_list_;
126 };
127 
128 struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
ModuleDictMethodModuleDictMethod129   explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name)
130       : iterable_(std::move(iterable)), name_(std::move(name)){};
131 
kindModuleDictMethod132   std::string kind() const override {
133     return name_;
134   }
135 
callModuleDictMethod136   std::shared_ptr<SugaredValue> call(
137       const SourceRange& loc,
138       GraphFunction& f,
139       at::ArrayRef<NamedValue> args,
140       at::ArrayRef<NamedValue> kwargs,
141       size_t n_binders) override {
142     if (!args.empty() || !kwargs.empty()) {
143       throw(
144           ErrorReport(loc) << name_ << " method does not accept any arguments");
145     }
146     return iterable_;
147   }
148 
149   SugaredValuePtr iterable_;
150   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
151   const std::string name_;
152 };
153 
154 struct SugaredDict;
155 
156 // defines how modules/methods behave inside the script subset.
157 // for now this does not have any interaction with python.
158 // in the future, we will add the ability to resolve `self.foo` to python
159 // {functions, modules, constants} so this SugaredValue is defined here
160 // anticipating we will eventually need to replace Module with a py::object
161 // holding the actual nn.Module class.
162 
163 struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
ModuleValueModuleValue164   ModuleValue(Value* self, std::shared_ptr<ConcreteModuleType> concreteType)
165       : self_(self), concreteType_(std::move(concreteType)) {}
166 
kindModuleValue167   std::string kind() const override {
168     return "module";
169   }
170 
171   Value* asValue(const SourceRange& loc, GraphFunction& m) override;
172 
173   SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m)
174       override;
175 
176   // select an attribute on it, e.g. `this.field`
177   std::shared_ptr<SugaredValue> tryGetAttr(
178       const SourceRange& loc,
179       GraphFunction& m,
180       const std::string& field);
181 
182   // select an attribute on it, e.g. `this.field`
183   std::shared_ptr<SugaredValue> attr(
184       const SourceRange& loc,
185       GraphFunction& m,
186       const std::string& field) override;
187 
188   // select an attribute on it, e.g. `this.field`
189   bool hasAttr(
190       const SourceRange& loc,
191       GraphFunction& m,
192       const std::string& field) override;
193 
194   // call module.forward with pre_hooks and hooks
195   std::shared_ptr<SugaredValue> call(
196       const SourceRange& loc,
197       GraphFunction& caller,
198       at::ArrayRef<NamedValue> args,
199       at::ArrayRef<NamedValue> kwargs,
200       size_t n_binders) override;
201 
202   std::shared_ptr<SugaredDict> getSugaredDict(
203       const SourceRange& loc,
204       GraphFunction& m);
205 
206   std::shared_ptr<SugaredDict> getSugaredNamedBufferDict(
207       const SourceRange& loc,
208       GraphFunction& m);
209 
210   std::shared_ptr<SugaredDict> getSugaredNamedParameterList(
211       const SourceRange& loc,
212       GraphFunction& m);
213 
214   std::shared_ptr<SugaredDict> getSugaredNamedParameterDict(
215       const SourceRange& loc,
216       GraphFunction& m);
217 
218   void setAttr(
219       const SourceRange& loc,
220       GraphFunction& m,
221       const std::string& field,
222       Value* newValue) override;
223 
224   SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
225 
226   std::shared_ptr<SugaredValue> getitem(
227       const SourceRange& loc,
228       GraphFunction& m,
229       Value* idx,
230       TypePtr type_hint) override;
231 
232  private:
233   // Check that the type of all submodules is a subtype of ty. If the function
234   // returns false, more information about why it returns false (e.g. which
235   // submodule's type is not a subtype of ty) is printed it why_not if it is not
236   // null.
237   bool areAllSubmodulesSubtypeOf(
238       const TypePtr& ty,
239       std::ostream* why_not = nullptr) const;
240 
241   Value* self_;
242   std::shared_ptr<ConcreteModuleType> concreteType_;
243 };
244 
245 bool isNamedTupleClass(const py::object& obj);
246 TypePtr registerNamedTuple(
247     const py::object& obj,
248     const SourceRange& loc,
249     const ResolutionCallback& rcb);
250 
251 void recurseThroughNestedModules(
252     const SourceRange& loc,
253     GraphFunction& m,
254     std::vector<SugaredValuePtr>& keys,
255     std::vector<SugaredValuePtr>& values,
256     std::shared_ptr<ModuleValue>& self,
257     const std::string& prefix,
258     const std::string& field);
259 
260 // Used to support named_modules()
261 struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue {
SugaredDictSugaredDict262   explicit SugaredDict(
263       std::shared_ptr<ModuleValue> self,
264       std::shared_ptr<SugaredTupleValue> keys,
265       std::shared_ptr<SugaredTupleValue> modules)
266       : self_(std::move(self)),
267         keys_(std::move(keys)),
268         modules_(std::move(modules)) {}
269 
kindSugaredDict270   std::string kind() const override {
271     return "ModuleDict";
272   }
273 
getKeysSugaredDict274   std::shared_ptr<SugaredTupleValue> getKeys() {
275     return keys_;
276   }
277 
getModulesSugaredDict278   std::shared_ptr<SugaredTupleValue> getModules() {
279     return modules_;
280   }
281 
282   std::shared_ptr<SugaredValue> attr(
283       const SourceRange& loc,
284       GraphFunction& m,
285       const std::string& field) override;
286 
iterSugaredDict287   SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override {
288     return keys_;
289   };
290 
291   std::shared_ptr<ModuleValue> self_;
292   std::shared_ptr<SugaredTupleValue> keys_;
293   std::shared_ptr<SugaredTupleValue> modules_;
294 };
295 
296 struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
BooleanDispatchValueBooleanDispatchValue297   BooleanDispatchValue(py::dict dispatched_fn)
298       : dispatched_fn_(std::move(dispatched_fn)) {}
299 
kindBooleanDispatchValue300   std::string kind() const override {
301     return "boolean dispatch";
302   }
303 
304   std::shared_ptr<SugaredValue> call(
305       const SourceRange& loc,
306       GraphFunction& caller,
307       at::ArrayRef<NamedValue> args,
308       at::ArrayRef<NamedValue> kwargs,
309       size_t n_binders) override;
310 
311  private:
312   py::dict dispatched_fn_;
313 };
314 
315 struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
PythonClassValuePythonClassValue316   PythonClassValue(ClassTypePtr type, py::object py_type)
317       : ClassValue(std::move(type)), py_type_(std::move(py_type)) {}
318 
kindPythonClassValue319   std::string kind() const override {
320     return "Python type";
321   }
322 
323   std::shared_ptr<SugaredValue> attr(
324       const SourceRange& loc,
325       GraphFunction& m,
326       const std::string& field) override;
327 
328   bool hasAttr(
329       const SourceRange& loc,
330       GraphFunction& m,
331       const std::string& field) override;
332 
333  private:
334   py::object py_type_;
335 };
336 
337 struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
PythonExceptionValuePythonExceptionValue338   explicit PythonExceptionValue(const py::object& exception_class)
339       : ExceptionValue(
340             py::str(py::getattr(exception_class, "__name__", py::str("")))),
341         exception_class_qualified_name_(
342             py::str(py::module::import("torch._jit_internal")
343                         .attr("_qualified_name")(
344                             exception_class,
345                             /*mangle_name=*/false))) {}
346 
kindPythonExceptionValue347   std::string kind() const override {
348     return "Python exception";
349   }
350 
351   std::shared_ptr<SugaredValue> call(
352       const SourceRange& loc,
353       GraphFunction& caller,
354       at::ArrayRef<NamedValue> args,
355       at::ArrayRef<NamedValue> kwargs,
356       size_t n_binders) override;
357 
358  private:
359   std::string exception_class_qualified_name_;
360 };
361 
362 // Python Slice class.
363 struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
364   explicit PythonSliceClass() = default;
365 
kindPythonSliceClass366   std::string kind() const override {
367     return "Python slice class";
368   }
369 
370   std::shared_ptr<SugaredValue> call(
371       const SourceRange& loc,
372       GraphFunction& caller,
373       at::ArrayRef<NamedValue> args,
374       at::ArrayRef<NamedValue> kwargs,
375       size_t n_binders) override;
376 };
377 
378 } // namespace torch::jit
379