1 #include <torch/csrc/jit/python/pybind_utils.h>
2 #include <torch/csrc/jit/python/python_custom_class.h>
3
4 #include <torch/csrc/jit/frontend/sugared_value.h>
5
6 #include <fmt/format.h>
7
8 namespace torch::jit {
9
// Forward declarations. Neither type is referenced in the code visible in
// this file.
// NOTE(review): confirm whether these are still needed by anything that
// includes or links against this translation unit before removing them.
struct CustomMethodProxy;
struct CustomObjectProxy;
12
/// Python-side constructor for a custom C++ (TorchBind) class.
///
/// Allocates a new object of `class_type_` with a single attribute slot,
/// resolves the class's bound `__init__` method, runs it with the caller's
/// positional and keyword arguments, and returns the initialized object as a
/// Python object.
///
/// @param args   positional arguments forwarded to `__init__`.
/// @param kwargs keyword arguments forwarded to `__init__`.
/// @return the freshly constructed object, cast to `py::object`.
/// @throws c10::Error (via TORCH_CHECK) if no `__init__` is registered.
py::object ScriptClass::__call__(
    const py::args& args,
    const py::kwargs& kwargs) {
  Object obj(at::ivalue::Object::create(class_type_, /*numSlots=*/1));
  const auto obj_class = obj.type();

  // Custom classes only get an `__init__` if their registration called
  // `.def(torch::init<...>)`; surface a targeted hint when it is missing.
  Function* ctor = obj_class->findMethod("__init__");
  TORCH_CHECK(
      ctor,
      fmt::format(
          "Custom C++ class: '{}' does not have an '__init__' method bound. "
          "Did you forget to add '.def(torch::init<...>)' to its registration?",
          obj_class->repr_str()));

  // Bind `__init__` to the new object and run it; its return value is
  // discarded because Python instantiation must yield the instance itself.
  Method bound_ctor(obj._ivalue(), ctor);
  invokeScriptMethodFromPython(bound_ctor, args, kwargs);
  return py::cast(obj);
}
29
30 /// Variant of StrongFunctionPtr, but for static methods of custom classes.
31 /// They do not belong to compilation units (the custom class method registry
32 /// serves that purpose in this case), so StrongFunctionPtr cannot be used here.
33 /// While it is usually unsafe to carry a raw pointer like this, the custom
34 /// class method registry that owns the pointer is never destroyed.
35 struct ScriptClassFunctionPtr {
ScriptClassFunctionPtrtorch::jit::ScriptClassFunctionPtr36 ScriptClassFunctionPtr(Function* function) : function_(function) {
37 TORCH_INTERNAL_ASSERT(function_);
38 }
39 Function* function_;
40 };
41
initPythonCustomClassBindings(PyObject * module)42 void initPythonCustomClassBindings(PyObject* module) {
43 auto m = py::handle(module).cast<py::module>();
44
45 py::class_<ScriptClassFunctionPtr>(
46 m, "ScriptClassFunction", py::dynamic_attr())
47 .def("__call__", [](py::args args, const py::kwargs& kwargs) {
48 auto strongPtr = py::cast<ScriptClassFunctionPtr>(args[0]);
49 Function& callee = *strongPtr.function_;
50 py::object result = invokeScriptFunctionFromPython(
51 callee, tuple_slice(std::move(args), 1), kwargs);
52 return result;
53 });
54
55 py::class_<ScriptClass>(m, "ScriptClass")
56 .def("__call__", &ScriptClass::__call__)
57 .def(
58 "__getattr__",
59 [](ScriptClass& self, const std::string& name) {
60 // Define __getattr__ so that static functions of custom classes can
61 // be used in regular Python.
62 auto type = self.class_type_.type_->castRaw<ClassType>();
63 TORCH_INTERNAL_ASSERT(type);
64 auto* fn = type->findStaticMethod(name);
65 if (fn) {
66 return ScriptClassFunctionPtr(fn);
67 }
68
69 throw AttributeError("%s does not exist", name.c_str());
70 })
71 .def_property_readonly("__doc__", [](const ScriptClass& self) {
72 return self.class_type_.type_->expectRef<ClassType>().doc_string();
73 });
74
75 // This function returns a ScriptClass that wraps the constructor
76 // of the given class, specified by the qualified name passed in.
77 //
78 // This is to emulate the behavior in python where instantiation
79 // of a class is a call to a code object for the class, where that
80 // code object in turn calls __init__. Rather than calling __init__
81 // directly, we need a wrapper that at least returns the instance
82 // rather than the None return value from __init__
83 m.def(
84 "_get_custom_class_python_wrapper",
85 [](const std::string& ns, const std::string& qualname) {
86 std::string full_qualname =
87 "__torch__.torch.classes." + ns + "." + qualname;
88 auto named_type = getCustomClass(full_qualname);
89 TORCH_CHECK(
90 named_type,
91 fmt::format(
92 "Tried to instantiate class '{}.{}', but it does not exist! "
93 "Ensure that it is registered via torch::class_",
94 ns,
95 qualname));
96 c10::ClassTypePtr class_type = named_type->cast<ClassType>();
97 return ScriptClass(c10::StrongTypePtr(
98 std::shared_ptr<CompilationUnit>(), std::move(class_type)));
99 });
100 }
101
102 } // namespace torch::jit
103