xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/backend.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/builtin_function.h>
#include <ATen/core/stack.h>
#include <torch/csrc/jit/backends/backend_interface.h>
#include <torch/custom_class.h>

#include <functional>
#include <string>
#include <type_traits>
#include <utility>
7 
8 namespace torch {
9 namespace jit {
10 namespace {
11 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
getIsAvailableSchema()12 inline c10::FunctionSchema getIsAvailableSchema() {
13   c10::Argument self("self", c10::AnyType::get());
14   c10::Argument available("available", c10::BoolType::get());
15   c10::FunctionSchema preprocessor_schema(
16       "is_available",
17       /*overload_name=*/"",
18       /*arguments=*/{self},
19       /*returns=*/{available});
20   return preprocessor_schema;
21 }
22 
// Namespace passed to torch::class_ when a backend class is registered.
constexpr const char* kBackendsNamespace = "__backends__";
24 
25 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
getCompileSchema()26 inline c10::FunctionSchema getCompileSchema() {
27   c10::Argument self("self", c10::AnyType::get());
28   c10::Argument mod("processed", c10::AnyType::get());
29   auto any_dict_ty =
30       c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
31   c10::Argument method_compile_spec("method_compile_spec", any_dict_ty);
32   c10::Argument handles("handles", any_dict_ty);
33 
34   c10::FunctionSchema compile_schema(
35       "compile",
36       /*overload_name=*/"",
37       /*arguments=*/{self, mod, method_compile_spec},
38       /*returns=*/{handles});
39   return compile_schema;
40 }
41 
42 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
getExecuteSchema()43 inline c10::FunctionSchema getExecuteSchema() {
44   auto any_list_ty = c10::ListType::create(c10::AnyType::get());
45   c10::Argument self("self", c10::AnyType::get());
46   c10::Argument handle("handle", c10::AnyType::get());
47   c10::Argument input("input", any_list_ty);
48   c10::Argument output("output", any_list_ty);
49   return c10::FunctionSchema(
50       "execute",
51       /*overload_name=*/"",
52       /*arguments=*/{self, handle, input},
53       /*returns=*/{output});
54 }
55 
56 template <typename TBackendInterface>
getIsAvailableFunc()57 std::function<void(Stack&)> getIsAvailableFunc() {
58   return [](Stack& stack) {
59     auto self = pop(stack).toCustomClass<TBackendInterface>();
60     auto ret = self->is_available();
61     push(stack, ret);
62   };
63 }
64 
65 template <typename TBackendInterface>
getCompileFunc()66 std::function<void(Stack&)> getCompileFunc() {
67   return [](Stack& stack) {
68     auto method_compile_spec = pop(stack).toGenericDict();
69     auto processed = pop(stack);
70     auto self = pop(stack).toCustomClass<TBackendInterface>();
71     auto ret = self->compile(processed, method_compile_spec);
72     push(stack, ret);
73   };
74 }
75 
76 template <typename TBackendInterface>
getExecuteFunc()77 std::function<void(Stack&)> getExecuteFunc() {
78   return [](Stack& stack) {
79     auto args = pop(stack);
80     auto handle = pop(stack);
81     auto self = pop(stack);
82     auto backend = self.toCustomClass<TBackendInterface>();
83     auto res = backend->execute(handle, args.toList());
84     push(stack, res);
85   };
86 }
87 } // namespace
88 
89 // Static registration API for backends.
90 template <class TBackendInterface>
91 class backend {
92   static_assert(
93       std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value,
94       "torch::jit::backend<T> requires T to inherit from PyTorchBackendInterface");
95   std::string backend_name_;
96 
97  public:
98   // Registers a new backend with /p name, and the given /p preprocess
99   // function.
backend(const std::string & name)100   backend(const std::string& name) : backend_name_(name) {
101     static auto cls = torch::class_<TBackendInterface>(kBackendsNamespace, name)
102                           .def(torch::init<>())
103                           ._def_unboxed(
104                               "is_available",
105                               getIsAvailableFunc<TBackendInterface>(),
106                               getIsAvailableSchema())
107                           ._def_unboxed(
108                               "compile",
109                               getCompileFunc<TBackendInterface>(),
110                               getCompileSchema())
111                           ._def_unboxed(
112                               "execute",
113                               getExecuteFunc<TBackendInterface>(),
114                               getExecuteSchema());
115   }
116 };
117 
118 } // namespace jit
119 } // namespace torch
120