#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { /// This function is used in conjunction with `class_::def()` to register /// a constructor for a given C++ class type. For example, /// `torch::init()` would register a two-argument constructor /// taking an `int` and a `std::string` as argument. template detail::types init() { return detail::types{}; } template struct InitLambda { Func f; }; template decltype(auto) init(Func&& f) { using InitTraits = c10::guts::infer_function_traits_t>; using ParameterTypeList = typename InitTraits::parameter_types; InitLambda init{std::forward(f)}; return init; } /// Entry point for custom C++ class registration. To register a C++ class /// in PyTorch, instantiate `torch::class_` with the desired class as the /// template parameter. Typically, this instantiation should be done in /// the initialization of a global variable, so that the class will be /// made available on dynamic library loading without any additional API /// calls needed. For example, to register a class named Foo, you might /// create a global variable like so: /// /// static auto register_foo = torch::class_("myclasses", "Foo") /// .def("myMethod", &Foo::myMethod) /// .def("lambdaMethod", [](const c10::intrusive_ptr& self) { /// // Do something with `self` /// }); /// /// In addition to registering the class, this registration also chains /// `def()` calls to register methods. `myMethod()` is registered with /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()` /// is registered with a C++ lambda expression. template class class_ : public ::torch::detail::class_base { static_assert( std::is_base_of_v, "torch::class_ requires T to inherit from CustomClassHolder"); public: /// This constructor actually registers the class type. /// String argument `namespaceName` is an identifier for the /// namespace you would like this class to appear in. /// String argument `className` is the name you would like to /// see this class exposed as in Python and TorchScript. For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript explicit class_( const std::string& namespaceName, const std::string& className, std::string doc_string = "") : class_base( namespaceName, className, std::move(doc_string), typeid(c10::intrusive_ptr), typeid(c10::tagged_capsule)) {} /// def() can be used in conjunction with `torch::init()` to register /// a constructor for a given C++ class type. For example, passing /// `torch::init()` would register a two-argument /// constructor taking an `int` and a `std::string` as argument. template class_& def( torch::detail::types, std::string doc_string = "", std::initializer_list default_args = {}) { // Used in combination with // torch::init<...>() auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); }; defineMethod( "__init__", std::move(func), std::move(doc_string), default_args); return *this; } // Used in combination with torch::init([]lambda(){......}) template class_& def( InitLambda> init, std::string doc_string = "", std::initializer_list default_args = {}) { auto init_lambda_wrapper = [func = std::move(init.f)]( c10::tagged_capsule self, ParameterTypes... arg) { c10::intrusive_ptr classObj = at::guts::invoke(func, std::forward(arg)...); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; defineMethod( "__init__", std::move(init_lambda_wrapper), std::move(doc_string), default_args); return *this; } /// This is the normal method registration API. `name` is the name that /// the method will be made accessible by in Python and TorchScript. /// `f` is a callable object that defines the method. Typically `f` /// will either be a pointer to a method on `CurClass`, or a lambda /// expression that takes a `c10::intrusive_ptr` as the first /// argument (emulating a `this` argument in a C++ method.) /// /// Examples: /// /// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in /// // Python and TorchScript /// .def("call_foo", &Foo::foo) /// /// // Exposes the given lambda expression as method `call_lambda()` /// // in Python and TorchScript. /// .def("call_lambda", [](const c10::intrusive_ptr& self) { /// // do something /// }) template class_& def( std::string name, Func f, std::string doc_string = "", std::initializer_list default_args = {}) { auto wrapped_f = detail::wrap_func(std::move(f)); defineMethod( std::move(name), std::move(wrapped_f), std::move(doc_string), default_args); return *this; } /// Method registration API for static methods. template class_& def_static(std::string name, Func func, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void { using RetType = typename c10::guts::infer_function_traits_t::return_type; detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( std::move(qualMethodName), std::move(schema), std::move(wrapped_func), std::move(doc_string)); classTypePtr->addStaticMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; } /// Property registration API for properties with both getter and setter /// functions. template class_& def_property( const std::string& name, GetterFunc getter_func, SetterFunc setter_func, std::string doc_string = "") { torch::jit::Function* getter{}; torch::jit::Function* setter{}; auto wrapped_getter = detail::wrap_func(std::move(getter_func)); getter = defineMethod(name + "_getter", wrapped_getter, doc_string); auto wrapped_setter = detail::wrap_func(std::move(setter_func)); setter = defineMethod(name + "_setter", wrapped_setter, doc_string); classTypePtr->addProperty(name, getter, setter); return *this; } /// Property registration API for properties with only getter function. template class_& def_property( const std::string& name, GetterFunc getter_func, std::string doc_string = "") { torch::jit::Function* getter{}; auto wrapped_getter = detail::wrap_func(std::move(getter_func)); getter = defineMethod(name + "_getter", wrapped_getter, doc_string); classTypePtr->addProperty(name, getter, nullptr); return *this; } /// Property registration API for properties with read-write access. template class_& def_readwrite(const std::string& name, T CurClass::*field) { auto getter_func = [field = field](const c10::intrusive_ptr& self) { return self.get()->*field; }; auto setter_func = [field = field]( const c10::intrusive_ptr& self, T value) { self.get()->*field = value; }; return def_property(name, getter_func, setter_func); } /// Property registration API for properties with read-only access. template class_& def_readonly(const std::string& name, T CurClass::*field) { auto getter_func = [field = std::move(field)](const c10::intrusive_ptr& self) { return self.get()->*field; }; return def_property(name, getter_func); } /// This is an unsafe method registration API added for adding custom JIT /// backend support via custom C++ classes. It is not for general purpose use. class_& _def_unboxed( const std::string& name, std::function func, c10::FunctionSchema schema, std::string doc_string = "") { auto method = std::make_unique( qualClassName + "." + name, std::move(schema), std::move(func), std::move(doc_string)); classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; } /// def_pickle() is used to define exactly what state gets serialized /// or deserialized for a given instance of a custom C++ class in /// Python or TorchScript. This protocol is equivalent to the Pickle /// concept of `__getstate__` and `__setstate__` from Python /// (https://docs.python.org/2/library/pickle.html#object.__getstate__) /// /// Currently, both the `get_state` and `set_state` callables must be /// C++ lambda expressions. They should have the following signatures, /// where `CurClass` is the class you're registering and `T1` is some object /// that encapsulates the state of the object. /// /// __getstate__(intrusive_ptr) -> T1 /// __setstate__(T2) -> intrusive_ptr /// /// `T1` must be an object that is convertable to IValue by the same rules /// for custom op/method registration. /// /// For the common case, T1 == T2. T1 can also be a subtype of T2. An /// example where it makes sense for T1 and T2 to differ is if __setstate__ /// handles legacy formats in a backwards compatible way. /// /// Example: /// /// .def_pickle( /// // __getstate__ /// [](const c10::intrusive_ptr>& self) { /// return self->stack_; /// }, /// [](std::vector state) { // __setstate__ /// return c10::make_intrusive>( /// std::vector{"i", "was", "deserialized"}); /// }) template // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) { static_assert( c10::guts::is_stateless_lambda>::value && c10::guts::is_stateless_lambda>::value, "def_pickle() currently only supports lambdas as " "__getstate__ and __setstate__ arguments."); def("__getstate__", std::forward(get_state)); // __setstate__ needs to be registered with some custom handling: // We need to wrap the invocation of the user-provided function // such that we take the return value (i.e. c10::intrusive_ptr) // and assign it to the `capsule` attribute. using SetStateTraits = c10::guts::infer_function_traits_t>; using SetStateArg = typename c10::guts::typelist::head_t< typename SetStateTraits::parameter_types>; auto setstate_wrapper = [set_state = std::forward(set_state)]( c10::tagged_capsule self, SetStateArg arg) { c10::intrusive_ptr classObj = at::guts::invoke(set_state, std::move(arg)); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; defineMethod( "__setstate__", detail::wrap_func( std::move(setstate_wrapper))); // type validation auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema(); #ifndef STRIP_ERROR_MESSAGES auto format_getstate_schema = [&getstate_schema]() { std::stringstream ss; ss << getstate_schema; return ss.str(); }; #endif TORCH_CHECK( getstate_schema.arguments().size() == 1, "__getstate__ should take exactly one argument: self. Got: ", format_getstate_schema()); auto first_arg_type = getstate_schema.arguments().at(0).type(); TORCH_CHECK( *first_arg_type == *classTypePtr, "self argument of __getstate__ must be the custom class type. Got ", first_arg_type->repr_str()); TORCH_CHECK( getstate_schema.returns().size() == 1, "__getstate__ should return exactly one value for serialization. Got: ", format_getstate_schema()); auto ser_type = getstate_schema.returns().at(0).type(); auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema(); auto arg_type = setstate_schema.arguments().at(1).type(); TORCH_CHECK( ser_type->isSubtypeOf(*arg_type), "__getstate__'s return type should be a subtype of " "input argument of __setstate__. Got ", ser_type->repr_str(), " but expected ", arg_type->repr_str()); return *this; } private: template torch::jit::Function* defineMethod( std::string name, Func func, std::string doc_string = "", std::initializer_list default_args = {}) { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); // If default values are provided for function arguments, there must be // none (no default values) or default values for all function // arguments, except for self. This is because argument names are not // extracted by inferFunctionSchemaSingleReturn, and so there must be a // torch::arg instance in default_args even for arguments that do not // have an actual default value provided. TORCH_CHECK( default_args.size() == 0 || default_args.size() == schema.arguments().size() - 1, "Default values must be specified for none or all arguments"); // If there are default args, copy the argument names and default values to // the function schema. if (default_args.size() > 0) { schema = withNewArguments(schema, default_args); } auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void { // TODO: we need to figure out how to profile calls to custom functions // like this! Currently can't do it because the profiler stuff is in // libtorch and not ATen using RetType = typename c10::guts::infer_function_traits_t::return_type; detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string)); // Register the method here to keep the Method alive. // ClassTypes do not hold ownership of their methods (normally it // those are held by the CompilationUnit), so we need a proxy for // that behavior here. auto method_val = method.get(); classTypePtr->addMethod(method_val); registerCustomClassMethod(std::move(method)); return method_val; } }; /// make_custom_class() is a convenient way to create an instance of a /// registered custom class and wrap it in an IValue, for example when you want /// to pass the object to TorchScript. Its syntax is equivalent to APIs like /// `std::make_shared<>` or `c10::make_intrusive<>`. /// /// For example, if you have a custom C++ class that can be constructed from an /// `int` and `std::string`, you might use this API like so: /// /// IValue custom_class_iv = torch::make_custom_class(3, /// "foobarbaz"); template c10::IValue make_custom_class(CtorArgs&&... args) { auto userClassInstance = c10::make_intrusive(std::forward(args)...); return c10::IValue(std::move(userClassInstance)); } // Alternative api for creating a torchbind class over torch::class_ this api is // preffered to prevent size regressions on Edge usecases. Must be used in // conjunction with TORCH_SELECTIVE_CLASS macro aka // selective_class("foo_namespace", TORCH_SELECTIVE_CLASS("foo")) template inline class_ selective_class_( const std::string& namespace_name, detail::SelectiveStr className) { auto class_name = std::string(className.operator const char*()); return torch::class_(namespace_name, class_name); } template inline detail::ClassNotSelected selective_class_( const std::string&, detail::SelectiveStr) { return detail::ClassNotSelected(); } // jit namespace for backward-compatibility // We previously defined everything in torch::jit but moved it out to // better reflect that these features are not limited only to TorchScript namespace jit { using ::torch::class_; using ::torch::getCustomClass; using ::torch::init; using ::torch::isCustomClass; } // namespace jit template inline class_ Library::class_(const std::string& className) { TORCH_CHECK( kind_ == DEF || kind_ == FRAGMENT, "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " "(Error occurred at ", file_, ":", line_, ")"); TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_); return torch::class_(*ns_, className); } const std::unordered_set getAllCustomClassesNames(); template inline class_ Library::class_(detail::SelectiveStr className) { auto class_name = std::string(className.operator const char*()); TORCH_CHECK( kind_ == DEF || kind_ == FRAGMENT, "class_(\"", class_name, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " "(Error occurred at ", file_, ":", line_, ")"); TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_); return torch::class_(*ns_, class_name); } template inline detail::ClassNotSelected Library::class_(detail::SelectiveStr) { return detail::ClassNotSelected(); } } // namespace torch