1 #pragma once
2
3 #include <ATen/core/builtin_function.h>
4 #include <ATen/core/function_schema.h>
5 #include <ATen/core/ivalue.h>
6 #include <ATen/core/class_type.h>
7 #include <ATen/core/op_registration/infer_schema.h>
8 #include <ATen/core/stack.h>
9 #include <c10/util/C++17.h>
10 #include <c10/util/Metaprogramming.h>
11 #include <c10/util/TypeList.h>
12 #include <c10/util/TypeTraits.h>
13 #include <torch/custom_class_detail.h>
14 #include <torch/library.h>
15 #include <sstream>
16
17 namespace torch {
18
19 /// This function is used in conjunction with `class_::def()` to register
20 /// a constructor for a given C++ class type. For example,
21 /// `torch::init<int, std::string>()` would register a two-argument constructor
22 /// taking an `int` and a `std::string` as argument.
23 template <class... Types>
init()24 detail::types<void, Types...> init() {
25 return detail::types<void, Types...>{};
26 }
27
/// Holder produced by `torch::init(f)` and consumed by `class_::def()`.
///
/// `f` is a user-supplied factory callable; `class_::def()` wraps it as the
/// class's `__init__`. `ParameterTypeList` is a type-level tag (the factory's
/// parameter types, as a `c10::guts::typelist`). It is used only so
/// `class_::def()` can deduce those types, and it is never stored.
template <typename Func, typename... ParameterTypeList>
struct InitLambda {
  /// The wrapped factory callable.
  Func f;
};
32
33 template <typename Func>
decltype(auto)34 decltype(auto) init(Func&& f) {
35 using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
36 using ParameterTypeList = typename InitTraits::parameter_types;
37
38 InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
39 return init;
40 }
41
42 /// Entry point for custom C++ class registration. To register a C++ class
43 /// in PyTorch, instantiate `torch::class_` with the desired class as the
44 /// template parameter. Typically, this instantiation should be done in
45 /// the initialization of a global variable, so that the class will be
46 /// made available on dynamic library loading without any additional API
47 /// calls needed. For example, to register a class named Foo, you might
48 /// create a global variable like so:
49 ///
50 /// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
51 /// .def("myMethod", &Foo::myMethod)
52 /// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
53 /// // Do something with `self`
54 /// });
55 ///
56 /// In addition to registering the class, this registration also chains
57 /// `def()` calls to register methods. `myMethod()` is registered with
58 /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
59 /// is registered with a C++ lambda expression.
60 template <class CurClass>
61 class class_ : public ::torch::detail::class_base {
62 static_assert(
63 std::is_base_of_v<CustomClassHolder, CurClass>,
64 "torch::class_<T> requires T to inherit from CustomClassHolder");
65
66 public:
67 /// This constructor actually registers the class type.
68 /// String argument `namespaceName` is an identifier for the
69 /// namespace you would like this class to appear in.
70 /// String argument `className` is the name you would like to
71 /// see this class exposed as in Python and TorchScript. For example, if
72 /// you pass `foo` as the namespace name and `Bar` as the className, the
73 /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
74 explicit class_(
75 const std::string& namespaceName,
76 const std::string& className,
77 std::string doc_string = "")
class_base(namespaceName,className,std::move (doc_string),typeid (c10::intrusive_ptr<CurClass>),typeid (c10::tagged_capsule<CurClass>))78 : class_base(
79 namespaceName,
80 className,
81 std::move(doc_string),
82 typeid(c10::intrusive_ptr<CurClass>),
83 typeid(c10::tagged_capsule<CurClass>)) {}
84
85 /// def() can be used in conjunction with `torch::init()` to register
86 /// a constructor for a given C++ class type. For example, passing
87 /// `torch::init<int, std::string>()` would register a two-argument
88 /// constructor taking an `int` and a `std::string` as argument.
89 template <typename... Types>
90 class_& def(
91 torch::detail::types<void, Types...>,
92 std::string doc_string = "",
93 std::initializer_list<arg> default_args =
94 {}) { // Used in combination with
95 // torch::init<...>()
96 auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
97 auto classObj = c10::make_intrusive<CurClass>(args...);
98 auto object = self.ivalue.toObject();
99 object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
100 };
101
102 defineMethod(
103 "__init__",
104 std::move(func),
105 std::move(doc_string),
106 default_args);
107 return *this;
108 }
109
110 // Used in combination with torch::init([]lambda(){......})
111 template <typename Func, typename... ParameterTypes>
112 class_& def(
113 InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
114 std::string doc_string = "",
115 std::initializer_list<arg> default_args = {}) {
116 auto init_lambda_wrapper = [func = std::move(init.f)](
117 c10::tagged_capsule<CurClass> self,
118 ParameterTypes... arg) {
119 c10::intrusive_ptr<CurClass> classObj =
120 at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
121 auto object = self.ivalue.toObject();
122 object->setSlot(0, c10::IValue::make_capsule(classObj));
123 };
124
125 defineMethod(
126 "__init__",
127 std::move(init_lambda_wrapper),
128 std::move(doc_string),
129 default_args);
130
131 return *this;
132 }
133
134 /// This is the normal method registration API. `name` is the name that
135 /// the method will be made accessible by in Python and TorchScript.
136 /// `f` is a callable object that defines the method. Typically `f`
137 /// will either be a pointer to a method on `CurClass`, or a lambda
138 /// expression that takes a `c10::intrusive_ptr<CurClass>` as the first
139 /// argument (emulating a `this` argument in a C++ method.)
140 ///
141 /// Examples:
142 ///
143 /// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
144 /// // Python and TorchScript
145 /// .def("call_foo", &Foo::foo)
146 ///
147 /// // Exposes the given lambda expression as method `call_lambda()`
148 /// // in Python and TorchScript.
149 /// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
150 /// // do something
151 /// })
152 template <typename Func>
153 class_& def(
154 std::string name,
155 Func f,
156 std::string doc_string = "",
157 std::initializer_list<arg> default_args = {}) {
158 auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
159 defineMethod(
160 std::move(name),
161 std::move(wrapped_f),
162 std::move(doc_string),
163 default_args);
164 return *this;
165 }
166
167 /// Method registration API for static methods.
168 template <typename Func>
169 class_& def_static(std::string name, Func func, std::string doc_string = "") {
170 auto qualMethodName = qualClassName + "." + name;
171 auto schema =
172 c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
173
174 auto wrapped_func =
175 [func = std::move(func)](jit::Stack& stack) mutable -> void {
176 using RetType =
177 typename c10::guts::infer_function_traits_t<Func>::return_type;
178 detail::BoxedProxy<RetType, Func>()(stack, func);
179 };
180 auto method = std::make_unique<jit::BuiltinOpFunction>(
181 std::move(qualMethodName),
182 std::move(schema),
183 std::move(wrapped_func),
184 std::move(doc_string));
185
186 classTypePtr->addStaticMethod(method.get());
187 registerCustomClassMethod(std::move(method));
188 return *this;
189 }
190
191 /// Property registration API for properties with both getter and setter
192 /// functions.
193 template <typename GetterFunc, typename SetterFunc>
194 class_& def_property(
195 const std::string& name,
196 GetterFunc getter_func,
197 SetterFunc setter_func,
198 std::string doc_string = "") {
199 torch::jit::Function* getter{};
200 torch::jit::Function* setter{};
201
202 auto wrapped_getter =
203 detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
204 getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
205
206 auto wrapped_setter =
207 detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func));
208 setter = defineMethod(name + "_setter", wrapped_setter, doc_string);
209
210 classTypePtr->addProperty(name, getter, setter);
211 return *this;
212 }
213
214 /// Property registration API for properties with only getter function.
215 template <typename GetterFunc>
216 class_& def_property(
217 const std::string& name,
218 GetterFunc getter_func,
219 std::string doc_string = "") {
220 torch::jit::Function* getter{};
221
222 auto wrapped_getter =
223 detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
224 getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
225
226 classTypePtr->addProperty(name, getter, nullptr);
227 return *this;
228 }
229
230 /// Property registration API for properties with read-write access.
231 template <typename T>
def_readwrite(const std::string & name,T CurClass::* field)232 class_& def_readwrite(const std::string& name, T CurClass::*field) {
233 auto getter_func = [field =
234 field](const c10::intrusive_ptr<CurClass>& self) {
235 return self.get()->*field;
236 };
237
238 auto setter_func = [field = field](
239 const c10::intrusive_ptr<CurClass>& self, T value) {
240 self.get()->*field = value;
241 };
242
243 return def_property(name, getter_func, setter_func);
244 }
245
246 /// Property registration API for properties with read-only access.
247 template <typename T>
def_readonly(const std::string & name,T CurClass::* field)248 class_& def_readonly(const std::string& name, T CurClass::*field) {
249 auto getter_func =
250 [field = std::move(field)](const c10::intrusive_ptr<CurClass>& self) {
251 return self.get()->*field;
252 };
253
254 return def_property(name, getter_func);
255 }
256
257 /// This is an unsafe method registration API added for adding custom JIT
258 /// backend support via custom C++ classes. It is not for general purpose use.
259 class_& _def_unboxed(
260 const std::string& name,
261 std::function<void(jit::Stack&)> func,
262 c10::FunctionSchema schema,
263 std::string doc_string = "") {
264 auto method = std::make_unique<jit::BuiltinOpFunction>(
265 qualClassName + "." + name,
266 std::move(schema),
267 std::move(func),
268 std::move(doc_string));
269 classTypePtr->addMethod(method.get());
270 registerCustomClassMethod(std::move(method));
271 return *this;
272 }
273
274 /// def_pickle() is used to define exactly what state gets serialized
275 /// or deserialized for a given instance of a custom C++ class in
276 /// Python or TorchScript. This protocol is equivalent to the Pickle
277 /// concept of `__getstate__` and `__setstate__` from Python
278 /// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
279 ///
280 /// Currently, both the `get_state` and `set_state` callables must be
281 /// C++ lambda expressions. They should have the following signatures,
282 /// where `CurClass` is the class you're registering and `T1` is some object
283 /// that encapsulates the state of the object.
284 ///
285 /// __getstate__(intrusive_ptr<CurClass>) -> T1
286 /// __setstate__(T2) -> intrusive_ptr<CurClass>
287 ///
288 /// `T1` must be an object that is convertable to IValue by the same rules
289 /// for custom op/method registration.
290 ///
291 /// For the common case, T1 == T2. T1 can also be a subtype of T2. An
292 /// example where it makes sense for T1 and T2 to differ is if __setstate__
293 /// handles legacy formats in a backwards compatible way.
294 ///
295 /// Example:
296 ///
297 /// .def_pickle(
298 /// // __getstate__
299 /// [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
300 /// return self->stack_;
301 /// },
302 /// [](std::vector<std::string> state) { // __setstate__
303 /// return c10::make_intrusive<MyStackClass<std::string>>(
304 /// std::vector<std::string>{"i", "was", "deserialized"});
305 /// })
306 template <typename GetStateFn, typename SetStateFn>
307 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
def_pickle(GetStateFn && get_state,SetStateFn && set_state)308 class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
309 static_assert(
310 c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
311 c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
312 "def_pickle() currently only supports lambdas as "
313 "__getstate__ and __setstate__ arguments.");
314 def("__getstate__", std::forward<GetStateFn>(get_state));
315
316 // __setstate__ needs to be registered with some custom handling:
317 // We need to wrap the invocation of the user-provided function
318 // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
319 // and assign it to the `capsule` attribute.
320 using SetStateTraits =
321 c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
322 using SetStateArg = typename c10::guts::typelist::head_t<
323 typename SetStateTraits::parameter_types>;
324 auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
325 c10::tagged_capsule<CurClass> self,
326 SetStateArg arg) {
327 c10::intrusive_ptr<CurClass> classObj =
328 at::guts::invoke(set_state, std::move(arg));
329 auto object = self.ivalue.toObject();
330 object->setSlot(0, c10::IValue::make_capsule(classObj));
331 };
332 defineMethod(
333 "__setstate__",
334 detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
335 std::move(setstate_wrapper)));
336
337 // type validation
338 auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
339 #ifndef STRIP_ERROR_MESSAGES
340 auto format_getstate_schema = [&getstate_schema]() {
341 std::stringstream ss;
342 ss << getstate_schema;
343 return ss.str();
344 };
345 #endif
346 TORCH_CHECK(
347 getstate_schema.arguments().size() == 1,
348 "__getstate__ should take exactly one argument: self. Got: ",
349 format_getstate_schema());
350 auto first_arg_type = getstate_schema.arguments().at(0).type();
351 TORCH_CHECK(
352 *first_arg_type == *classTypePtr,
353 "self argument of __getstate__ must be the custom class type. Got ",
354 first_arg_type->repr_str());
355 TORCH_CHECK(
356 getstate_schema.returns().size() == 1,
357 "__getstate__ should return exactly one value for serialization. Got: ",
358 format_getstate_schema());
359
360 auto ser_type = getstate_schema.returns().at(0).type();
361 auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
362 auto arg_type = setstate_schema.arguments().at(1).type();
363 TORCH_CHECK(
364 ser_type->isSubtypeOf(*arg_type),
365 "__getstate__'s return type should be a subtype of "
366 "input argument of __setstate__. Got ",
367 ser_type->repr_str(),
368 " but expected ",
369 arg_type->repr_str());
370
371 return *this;
372 }
373
374 private:
375 template <typename Func>
376 torch::jit::Function* defineMethod(
377 std::string name,
378 Func func,
379 std::string doc_string = "",
380 std::initializer_list<arg> default_args = {}) {
381 auto qualMethodName = qualClassName + "." + name;
382 auto schema =
383 c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
384
385 // If default values are provided for function arguments, there must be
386 // none (no default values) or default values for all function
387 // arguments, except for self. This is because argument names are not
388 // extracted by inferFunctionSchemaSingleReturn, and so there must be a
389 // torch::arg instance in default_args even for arguments that do not
390 // have an actual default value provided.
391 TORCH_CHECK(
392 default_args.size() == 0 ||
393 default_args.size() == schema.arguments().size() - 1,
394 "Default values must be specified for none or all arguments");
395
396 // If there are default args, copy the argument names and default values to
397 // the function schema.
398 if (default_args.size() > 0) {
399 schema = withNewArguments(schema, default_args);
400 }
401
402 auto wrapped_func =
403 [func = std::move(func)](jit::Stack& stack) mutable -> void {
404 // TODO: we need to figure out how to profile calls to custom functions
405 // like this! Currently can't do it because the profiler stuff is in
406 // libtorch and not ATen
407 using RetType =
408 typename c10::guts::infer_function_traits_t<Func>::return_type;
409 detail::BoxedProxy<RetType, Func>()(stack, func);
410 };
411 auto method = std::make_unique<jit::BuiltinOpFunction>(
412 qualMethodName,
413 std::move(schema),
414 std::move(wrapped_func),
415 std::move(doc_string));
416
417 // Register the method here to keep the Method alive.
418 // ClassTypes do not hold ownership of their methods (normally it
419 // those are held by the CompilationUnit), so we need a proxy for
420 // that behavior here.
421 auto method_val = method.get();
422 classTypePtr->addMethod(method_val);
423 registerCustomClassMethod(std::move(method));
424 return method_val;
425 }
426 };
427
428 /// make_custom_class() is a convenient way to create an instance of a
429 /// registered custom class and wrap it in an IValue, for example when you want
430 /// to pass the object to TorchScript. Its syntax is equivalent to APIs like
431 /// `std::make_shared<>` or `c10::make_intrusive<>`.
432 ///
433 /// For example, if you have a custom C++ class that can be constructed from an
434 /// `int` and `std::string`, you might use this API like so:
435 ///
436 /// IValue custom_class_iv = torch::make_custom_class<MyClass>(3,
437 /// "foobarbaz");
438 template <typename CurClass, typename... CtorArgs>
make_custom_class(CtorArgs &&...args)439 c10::IValue make_custom_class(CtorArgs&&... args) {
440 auto userClassInstance =
441 c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
442 return c10::IValue(std::move(userClassInstance));
443 }
444
445 // Alternative api for creating a torchbind class over torch::class_ this api is
446 // preffered to prevent size regressions on Edge usecases. Must be used in
447 // conjunction with TORCH_SELECTIVE_CLASS macro aka
448 // selective_class<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo"))
449 template <class CurClass>
selective_class_(const std::string & namespace_name,detail::SelectiveStr<true> className)450 inline class_<CurClass> selective_class_(
451 const std::string& namespace_name,
452 detail::SelectiveStr<true> className) {
453 auto class_name = std::string(className.operator const char*());
454 return torch::class_<CurClass>(namespace_name, class_name);
455 }
456
457 template <class CurClass>
selective_class_(const std::string &,detail::SelectiveStr<false>)458 inline detail::ClassNotSelected selective_class_(
459 const std::string&,
460 detail::SelectiveStr<false>) {
461 return detail::ClassNotSelected();
462 }
463
464 // jit namespace for backward-compatibility
465 // We previously defined everything in torch::jit but moved it out to
466 // better reflect that these features are not limited only to TorchScript
467 namespace jit {
468
469 using ::torch::class_;
470 using ::torch::getCustomClass;
471 using ::torch::init;
472 using ::torch::isCustomClass;
473
474 } // namespace jit
475
476 template <class CurClass>
class_(const std::string & className)477 inline class_<CurClass> Library::class_(const std::string& className) {
478 TORCH_CHECK(
479 kind_ == DEF || kind_ == FRAGMENT,
480 "class_(\"",
481 className,
482 "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
483 "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
484 "(Error occurred at ",
485 file_,
486 ":",
487 line_,
488 ")");
489 TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
490 return torch::class_<CurClass>(*ns_, className);
491 }
492
// Declared here; the definition is out of line and not visible in this header.
// NOTE(review): judging by the name, this returns the names of all registered
// custom classes. Confirm against the definition. The `const` on the by-value
// return type inhibits move-assignment from the result, but it has to match the
// out-of-line definition, so it is left as is.
const std::unordered_set<std::string> getAllCustomClassesNames();
494
495 template <class CurClass>
class_(detail::SelectiveStr<true> className)496 inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
497 auto class_name = std::string(className.operator const char*());
498 TORCH_CHECK(
499 kind_ == DEF || kind_ == FRAGMENT,
500 "class_(\"",
501 class_name,
502 "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
503 "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
504 "(Error occurred at ",
505 file_,
506 ":",
507 line_,
508 ")");
509 TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
510 return torch::class_<CurClass>(*ns_, class_name);
511 }
512
513 template <class CurClass>
class_(detail::SelectiveStr<false>)514 inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) {
515 return detail::ClassNotSelected();
516 }
517
518 } // namespace torch
519