xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_custom_class.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/python/pybind_utils.h>
2 #include <torch/csrc/jit/python/python_custom_class.h>
3 
4 #include <torch/csrc/jit/frontend/sugared_value.h>
5 
6 #include <fmt/format.h>
7 
8 namespace torch::jit {
9 
10 struct CustomMethodProxy;
11 struct CustomObjectProxy;
12 
/// Python-visible constructor for a TorchScript custom C++ class.
///
/// Allocates a fresh object of `class_type_`, invokes the class's bound
/// `__init__` method on it with the given Python arguments, and returns the
/// object itself (converted with `py::cast`) rather than `__init__`'s result.
///
/// @param args    Positional arguments forwarded to `__init__`.
/// @param kwargs  Keyword arguments forwarded to `__init__`.
/// @return The newly constructed object as a Python object.
/// @throws c10::Error (via TORCH_CHECK) if the class has no `__init__` bound.
py::object ScriptClass::__call__(
    const py::args& args,
    const py::kwargs& kwargs) {
  // Fresh object of the custom class, created with a single slot.
  Object obj(at::ivalue::Object::create(class_type_, /*numSlots=*/1));
  const auto cls = obj.type();

  Function* init_fn = cls->findMethod("__init__");
  TORCH_CHECK(
      init_fn,
      fmt::format(
          "Custom C++ class: '{}' does not have an '__init__' method bound. "
          "Did you forget to add '.def(torch::init<...>)' to its registration?",
          cls->repr_str()));

  // Bind __init__ to the new object and run it; whatever it returns is
  // ignored, and the object itself is handed back to Python.
  Method bound_init(obj._ivalue(), init_fn);
  invokeScriptMethodFromPython(bound_init, args, kwargs);
  return py::cast(obj);
}
29 
30 /// Variant of StrongFunctionPtr, but for static methods of custom classes.
31 /// They do not belong to compilation units (the custom class method registry
32 /// serves that purpose in this case), so StrongFunctionPtr cannot be used here.
33 /// While it is usually unsafe to carry a raw pointer like this, the custom
34 /// class method registry that owns the pointer is never destroyed.
35 struct ScriptClassFunctionPtr {
ScriptClassFunctionPtrtorch::jit::ScriptClassFunctionPtr36   ScriptClassFunctionPtr(Function* function) : function_(function) {
37     TORCH_INTERNAL_ASSERT(function_);
38   }
39   Function* function_;
40 };
41 
initPythonCustomClassBindings(PyObject * module)42 void initPythonCustomClassBindings(PyObject* module) {
43   auto m = py::handle(module).cast<py::module>();
44 
45   py::class_<ScriptClassFunctionPtr>(
46       m, "ScriptClassFunction", py::dynamic_attr())
47       .def("__call__", [](py::args args, const py::kwargs& kwargs) {
48         auto strongPtr = py::cast<ScriptClassFunctionPtr>(args[0]);
49         Function& callee = *strongPtr.function_;
50         py::object result = invokeScriptFunctionFromPython(
51             callee, tuple_slice(std::move(args), 1), kwargs);
52         return result;
53       });
54 
55   py::class_<ScriptClass>(m, "ScriptClass")
56       .def("__call__", &ScriptClass::__call__)
57       .def(
58           "__getattr__",
59           [](ScriptClass& self, const std::string& name) {
60             // Define __getattr__ so that static functions of custom classes can
61             // be used in regular Python.
62             auto type = self.class_type_.type_->castRaw<ClassType>();
63             TORCH_INTERNAL_ASSERT(type);
64             auto* fn = type->findStaticMethod(name);
65             if (fn) {
66               return ScriptClassFunctionPtr(fn);
67             }
68 
69             throw AttributeError("%s does not exist", name.c_str());
70           })
71       .def_property_readonly("__doc__", [](const ScriptClass& self) {
72         return self.class_type_.type_->expectRef<ClassType>().doc_string();
73       });
74 
75   // This function returns a ScriptClass that wraps the constructor
76   // of the given class, specified by the qualified name passed in.
77   //
78   // This is to emulate the behavior in python where instantiation
79   // of a class is a call to a code object for the class, where that
80   // code object in turn calls __init__. Rather than calling __init__
81   // directly, we need a wrapper that at least returns the instance
82   // rather than the None return value from __init__
83   m.def(
84       "_get_custom_class_python_wrapper",
85       [](const std::string& ns, const std::string& qualname) {
86         std::string full_qualname =
87             "__torch__.torch.classes." + ns + "." + qualname;
88         auto named_type = getCustomClass(full_qualname);
89         TORCH_CHECK(
90             named_type,
91             fmt::format(
92                 "Tried to instantiate class '{}.{}', but it does not exist! "
93                 "Ensure that it is registered via torch::class_",
94                 ns,
95                 qualname));
96         c10::ClassTypePtr class_type = named_type->cast<ClassType>();
97         return ScriptClass(c10::StrongTypePtr(
98             std::shared_ptr<CompilationUnit>(), std::move(class_type)));
99       });
100 }
101 
102 } // namespace torch::jit
103