xref: /aosp_15_r20/external/pytorch/torch/custom_class.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/builtin_function.h>
4 #include <ATen/core/function_schema.h>
5 #include <ATen/core/ivalue.h>
6 #include <ATen/core/class_type.h>
7 #include <ATen/core/op_registration/infer_schema.h>
8 #include <ATen/core/stack.h>
9 #include <c10/util/C++17.h>
10 #include <c10/util/Metaprogramming.h>
11 #include <c10/util/TypeList.h>
12 #include <c10/util/TypeTraits.h>
13 #include <torch/custom_class_detail.h>
14 #include <torch/library.h>
15 #include <sstream>
16 
17 namespace torch {
18 
19 /// This function is used in conjunction with `class_::def()` to register
20 /// a constructor for a given C++ class type. For example,
21 /// `torch::init<int, std::string>()` would register a two-argument constructor
22 /// taking an `int` and a `std::string` as argument.
23 template <class... Types>
init()24 detail::types<void, Types...> init() {
25   return detail::types<void, Types...>{};
26 }
27 
28 template <typename Func, typename... ParameterTypeList>
29 struct InitLambda {
30   Func f;
31 };
32 
33 template <typename Func>
decltype(auto)34 decltype(auto) init(Func&& f) {
35   using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
36   using ParameterTypeList = typename InitTraits::parameter_types;
37 
38   InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
39   return init;
40 }
41 
42 /// Entry point for custom C++ class registration. To register a C++ class
43 /// in PyTorch, instantiate `torch::class_` with the desired class as the
44 /// template parameter. Typically, this instantiation should be done in
45 /// the initialization of a global variable, so that the class will be
46 /// made available on dynamic library loading without any additional API
47 /// calls needed. For example, to register a class named Foo, you might
48 /// create a global variable like so:
49 ///
50 ///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
51 ///       .def("myMethod", &Foo::myMethod)
52 ///       .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
53 ///         // Do something with `self`
54 ///       });
55 ///
56 /// In addition to registering the class, this registration also chains
57 /// `def()` calls to register methods. `myMethod()` is registered with
58 /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
59 /// is registered with a C++ lambda expression.
60 template <class CurClass>
61 class class_ : public ::torch::detail::class_base {
62   static_assert(
63       std::is_base_of_v<CustomClassHolder, CurClass>,
64       "torch::class_<T> requires T to inherit from CustomClassHolder");
65 
66  public:
67   /// This constructor actually registers the class type.
68   /// String argument `namespaceName` is an identifier for the
69   /// namespace you would like this class to appear in.
70   /// String argument `className` is the name you would like to
71   /// see this class exposed as in Python and TorchScript. For example, if
72   /// you pass `foo` as the namespace name and `Bar` as the className, the
73   /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
74   explicit class_(
75       const std::string& namespaceName,
76       const std::string& className,
77       std::string doc_string = "")
class_base(namespaceName,className,std::move (doc_string),typeid (c10::intrusive_ptr<CurClass>),typeid (c10::tagged_capsule<CurClass>))78       : class_base(
79             namespaceName,
80             className,
81             std::move(doc_string),
82             typeid(c10::intrusive_ptr<CurClass>),
83             typeid(c10::tagged_capsule<CurClass>)) {}
84 
85   /// def() can be used in conjunction with `torch::init()` to register
86   /// a constructor for a given C++ class type. For example, passing
87   /// `torch::init<int, std::string>()` would register a two-argument
88   /// constructor taking an `int` and a `std::string` as argument.
89   template <typename... Types>
90   class_& def(
91       torch::detail::types<void, Types...>,
92       std::string doc_string = "",
93       std::initializer_list<arg> default_args =
94           {}) { // Used in combination with
95     // torch::init<...>()
96     auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
97       auto classObj = c10::make_intrusive<CurClass>(args...);
98       auto object = self.ivalue.toObject();
99       object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
100     };
101 
102     defineMethod(
103         "__init__",
104         std::move(func),
105         std::move(doc_string),
106         default_args);
107     return *this;
108   }
109 
110   // Used in combination with torch::init([]lambda(){......})
111   template <typename Func, typename... ParameterTypes>
112   class_& def(
113       InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
114       std::string doc_string = "",
115       std::initializer_list<arg> default_args = {}) {
116     auto init_lambda_wrapper = [func = std::move(init.f)](
117                                    c10::tagged_capsule<CurClass> self,
118                                    ParameterTypes... arg) {
119       c10::intrusive_ptr<CurClass> classObj =
120           at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
121       auto object = self.ivalue.toObject();
122       object->setSlot(0, c10::IValue::make_capsule(classObj));
123     };
124 
125     defineMethod(
126         "__init__",
127         std::move(init_lambda_wrapper),
128         std::move(doc_string),
129         default_args);
130 
131     return *this;
132   }
133 
134   /// This is the normal method registration API. `name` is the name that
135   /// the method will be made accessible by in Python and TorchScript.
136   /// `f` is a callable object that defines the method. Typically `f`
137   /// will either be a pointer to a method on `CurClass`, or a lambda
138   /// expression that takes a `c10::intrusive_ptr<CurClass>` as the first
139   /// argument (emulating a `this` argument in a C++ method.)
140   ///
141   /// Examples:
142   ///
143   ///     // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
144   ///     // Python and TorchScript
145   ///     .def("call_foo", &Foo::foo)
146   ///
147   ///     // Exposes the given lambda expression as method `call_lambda()`
148   ///     // in Python and TorchScript.
149   ///     .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
150   ///       // do something
151   ///     })
152   template <typename Func>
153   class_& def(
154       std::string name,
155       Func f,
156       std::string doc_string = "",
157       std::initializer_list<arg> default_args = {}) {
158     auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
159     defineMethod(
160         std::move(name),
161         std::move(wrapped_f),
162         std::move(doc_string),
163         default_args);
164     return *this;
165   }
166 
167   /// Method registration API for static methods.
168   template <typename Func>
169   class_& def_static(std::string name, Func func, std::string doc_string = "") {
170     auto qualMethodName = qualClassName + "." + name;
171     auto schema =
172         c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
173 
174     auto wrapped_func =
175         [func = std::move(func)](jit::Stack& stack) mutable -> void {
176       using RetType =
177           typename c10::guts::infer_function_traits_t<Func>::return_type;
178       detail::BoxedProxy<RetType, Func>()(stack, func);
179     };
180     auto method = std::make_unique<jit::BuiltinOpFunction>(
181         std::move(qualMethodName),
182         std::move(schema),
183         std::move(wrapped_func),
184         std::move(doc_string));
185 
186     classTypePtr->addStaticMethod(method.get());
187     registerCustomClassMethod(std::move(method));
188     return *this;
189   }
190 
191   /// Property registration API for properties with both getter and setter
192   /// functions.
193   template <typename GetterFunc, typename SetterFunc>
194   class_& def_property(
195       const std::string& name,
196       GetterFunc getter_func,
197       SetterFunc setter_func,
198       std::string doc_string = "") {
199     torch::jit::Function* getter{};
200     torch::jit::Function* setter{};
201 
202     auto wrapped_getter =
203         detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
204     getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
205 
206     auto wrapped_setter =
207         detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func));
208     setter = defineMethod(name + "_setter", wrapped_setter, doc_string);
209 
210     classTypePtr->addProperty(name, getter, setter);
211     return *this;
212   }
213 
214   /// Property registration API for properties with only getter function.
215   template <typename GetterFunc>
216   class_& def_property(
217       const std::string& name,
218       GetterFunc getter_func,
219       std::string doc_string = "") {
220     torch::jit::Function* getter{};
221 
222     auto wrapped_getter =
223         detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
224     getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
225 
226     classTypePtr->addProperty(name, getter, nullptr);
227     return *this;
228   }
229 
230   /// Property registration API for properties with read-write access.
231   template <typename T>
def_readwrite(const std::string & name,T CurClass::* field)232   class_& def_readwrite(const std::string& name, T CurClass::*field) {
233     auto getter_func = [field =
234                             field](const c10::intrusive_ptr<CurClass>& self) {
235       return self.get()->*field;
236     };
237 
238     auto setter_func = [field = field](
239                            const c10::intrusive_ptr<CurClass>& self, T value) {
240       self.get()->*field = value;
241     };
242 
243     return def_property(name, getter_func, setter_func);
244   }
245 
246   /// Property registration API for properties with read-only access.
247   template <typename T>
def_readonly(const std::string & name,T CurClass::* field)248   class_& def_readonly(const std::string& name, T CurClass::*field) {
249     auto getter_func =
250         [field = std::move(field)](const c10::intrusive_ptr<CurClass>& self) {
251           return self.get()->*field;
252         };
253 
254     return def_property(name, getter_func);
255   }
256 
257   /// This is an unsafe method registration API added for adding custom JIT
258   /// backend support via custom C++ classes. It is not for general purpose use.
259   class_& _def_unboxed(
260       const std::string& name,
261       std::function<void(jit::Stack&)> func,
262       c10::FunctionSchema schema,
263       std::string doc_string = "") {
264     auto method = std::make_unique<jit::BuiltinOpFunction>(
265         qualClassName + "." + name,
266         std::move(schema),
267         std::move(func),
268         std::move(doc_string));
269     classTypePtr->addMethod(method.get());
270     registerCustomClassMethod(std::move(method));
271     return *this;
272   }
273 
274   /// def_pickle() is used to define exactly what state gets serialized
275   /// or deserialized for a given instance of a custom C++ class in
276   /// Python or TorchScript. This protocol is equivalent to the Pickle
277   /// concept of `__getstate__` and `__setstate__` from Python
278   /// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
279   ///
280   /// Currently, both the `get_state` and `set_state` callables must be
281   /// C++ lambda expressions. They should have the following signatures,
282   /// where `CurClass` is the class you're registering and `T1` is some object
283   /// that encapsulates the state of the object.
284   ///
285   ///     __getstate__(intrusive_ptr<CurClass>) -> T1
286   ///     __setstate__(T2) -> intrusive_ptr<CurClass>
287   ///
288   /// `T1` must be an object that is convertable to IValue by the same rules
289   /// for custom op/method registration.
290   ///
291   /// For the common case, T1 == T2. T1 can also be a subtype of T2. An
292   /// example where it makes sense for T1 and T2 to differ is if __setstate__
293   /// handles legacy formats in a backwards compatible way.
294   ///
295   /// Example:
296   ///
297   ///     .def_pickle(
298   ///         // __getstate__
299   ///         [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
300   ///           return self->stack_;
301   ///         },
302   ///         [](std::vector<std::string> state) { // __setstate__
303   ///            return c10::make_intrusive<MyStackClass<std::string>>(
304   ///               std::vector<std::string>{"i", "was", "deserialized"});
305   ///         })
306   template <typename GetStateFn, typename SetStateFn>
307   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
def_pickle(GetStateFn && get_state,SetStateFn && set_state)308   class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
309     static_assert(
310         c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
311             c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
312         "def_pickle() currently only supports lambdas as "
313         "__getstate__ and __setstate__ arguments.");
314     def("__getstate__", std::forward<GetStateFn>(get_state));
315 
316     // __setstate__ needs to be registered with some custom handling:
317     // We need to wrap the invocation of the user-provided function
318     // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
319     // and assign it to the `capsule` attribute.
320     using SetStateTraits =
321         c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
322     using SetStateArg = typename c10::guts::typelist::head_t<
323         typename SetStateTraits::parameter_types>;
324     auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
325                                 c10::tagged_capsule<CurClass> self,
326                                 SetStateArg arg) {
327       c10::intrusive_ptr<CurClass> classObj =
328           at::guts::invoke(set_state, std::move(arg));
329       auto object = self.ivalue.toObject();
330       object->setSlot(0, c10::IValue::make_capsule(classObj));
331     };
332     defineMethod(
333         "__setstate__",
334         detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
335             std::move(setstate_wrapper)));
336 
337     // type validation
338     auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
339 #ifndef STRIP_ERROR_MESSAGES
340     auto format_getstate_schema = [&getstate_schema]() {
341       std::stringstream ss;
342       ss << getstate_schema;
343       return ss.str();
344     };
345 #endif
346     TORCH_CHECK(
347         getstate_schema.arguments().size() == 1,
348         "__getstate__ should take exactly one argument: self. Got: ",
349         format_getstate_schema());
350     auto first_arg_type = getstate_schema.arguments().at(0).type();
351     TORCH_CHECK(
352         *first_arg_type == *classTypePtr,
353         "self argument of __getstate__ must be the custom class type. Got ",
354         first_arg_type->repr_str());
355     TORCH_CHECK(
356         getstate_schema.returns().size() == 1,
357         "__getstate__ should return exactly one value for serialization. Got: ",
358         format_getstate_schema());
359 
360     auto ser_type = getstate_schema.returns().at(0).type();
361     auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
362     auto arg_type = setstate_schema.arguments().at(1).type();
363     TORCH_CHECK(
364         ser_type->isSubtypeOf(*arg_type),
365         "__getstate__'s return type should be a subtype of "
366         "input argument of __setstate__. Got ",
367         ser_type->repr_str(),
368         " but expected ",
369         arg_type->repr_str());
370 
371     return *this;
372   }
373 
374  private:
375   template <typename Func>
376   torch::jit::Function* defineMethod(
377       std::string name,
378       Func func,
379       std::string doc_string = "",
380       std::initializer_list<arg> default_args = {}) {
381     auto qualMethodName = qualClassName + "." + name;
382     auto schema =
383         c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
384 
385     // If default values are provided for function arguments, there must be
386     // none (no default values) or default values for all function
387     // arguments, except for self. This is because argument names are not
388     // extracted by inferFunctionSchemaSingleReturn, and so there must be a
389     // torch::arg instance in default_args even for arguments that do not
390     // have an actual default value provided.
391     TORCH_CHECK(
392         default_args.size() == 0 ||
393             default_args.size() == schema.arguments().size() - 1,
394         "Default values must be specified for none or all arguments");
395 
396     // If there are default args, copy the argument names and default values to
397     // the function schema.
398     if (default_args.size() > 0) {
399       schema = withNewArguments(schema, default_args);
400     }
401 
402     auto wrapped_func =
403         [func = std::move(func)](jit::Stack& stack) mutable -> void {
404       // TODO: we need to figure out how to profile calls to custom functions
405       // like this! Currently can't do it because the profiler stuff is in
406       // libtorch and not ATen
407       using RetType =
408           typename c10::guts::infer_function_traits_t<Func>::return_type;
409       detail::BoxedProxy<RetType, Func>()(stack, func);
410     };
411     auto method = std::make_unique<jit::BuiltinOpFunction>(
412         qualMethodName,
413         std::move(schema),
414         std::move(wrapped_func),
415         std::move(doc_string));
416 
417     // Register the method here to keep the Method alive.
418     // ClassTypes do not hold ownership of their methods (normally it
419     // those are held by the CompilationUnit), so we need a proxy for
420     // that behavior here.
421     auto method_val = method.get();
422     classTypePtr->addMethod(method_val);
423     registerCustomClassMethod(std::move(method));
424     return method_val;
425   }
426 };
427 
428 /// make_custom_class() is a convenient way to create an instance of a
429 /// registered custom class and wrap it in an IValue, for example when you want
430 /// to pass the object to TorchScript. Its syntax is equivalent to APIs like
431 /// `std::make_shared<>` or `c10::make_intrusive<>`.
432 ///
433 /// For example, if you have a custom C++ class that can be constructed from an
434 /// `int` and `std::string`, you might use this API like so:
435 ///
436 ///     IValue custom_class_iv = torch::make_custom_class<MyClass>(3,
437 ///     "foobarbaz");
438 template <typename CurClass, typename... CtorArgs>
make_custom_class(CtorArgs &&...args)439 c10::IValue make_custom_class(CtorArgs&&... args) {
440   auto userClassInstance =
441       c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
442   return c10::IValue(std::move(userClassInstance));
443 }
444 
445 // Alternative api for creating a torchbind class over torch::class_ this api is
446 // preffered to prevent size regressions on Edge usecases. Must be used in
447 // conjunction with TORCH_SELECTIVE_CLASS macro aka
448 // selective_class<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo"))
449 template <class CurClass>
selective_class_(const std::string & namespace_name,detail::SelectiveStr<true> className)450 inline class_<CurClass> selective_class_(
451     const std::string& namespace_name,
452     detail::SelectiveStr<true> className) {
453   auto class_name = std::string(className.operator const char*());
454   return torch::class_<CurClass>(namespace_name, class_name);
455 }
456 
457 template <class CurClass>
selective_class_(const std::string &,detail::SelectiveStr<false>)458 inline detail::ClassNotSelected selective_class_(
459     const std::string&,
460     detail::SelectiveStr<false>) {
461   return detail::ClassNotSelected();
462 }
463 
464 // jit namespace for backward-compatibility
465 // We previously defined everything in torch::jit but moved it out to
466 // better reflect that these features are not limited only to TorchScript
467 namespace jit {
468 
469 using ::torch::class_;
470 using ::torch::getCustomClass;
471 using ::torch::init;
472 using ::torch::isCustomClass;
473 
474 } // namespace jit
475 
476 template <class CurClass>
class_(const std::string & className)477 inline class_<CurClass> Library::class_(const std::string& className) {
478   TORCH_CHECK(
479       kind_ == DEF || kind_ == FRAGMENT,
480       "class_(\"",
481       className,
482       "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block.  "
483       "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace.  "
484       "(Error occurred at ",
485       file_,
486       ":",
487       line_,
488       ")");
489   TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
490   return torch::class_<CurClass>(*ns_, className);
491 }
492 
493 const std::unordered_set<std::string> getAllCustomClassesNames();
494 
495 template <class CurClass>
class_(detail::SelectiveStr<true> className)496 inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
497   auto class_name = std::string(className.operator const char*());
498   TORCH_CHECK(
499       kind_ == DEF || kind_ == FRAGMENT,
500       "class_(\"",
501       class_name,
502       "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block.  "
503       "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace.  "
504       "(Error occurred at ",
505       file_,
506       ":",
507       line_,
508       ")");
509   TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
510   return torch::class_<CurClass>(*ns_, class_name);
511 }
512 
513 template <class CurClass>
class_(detail::SelectiveStr<false>)514 inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) {
515   return detail::ClassNotSelected();
516 }
517 
518 } // namespace torch
519