1 #pragma once
2
3 #include <ATen/core/builtin_function.h>
4 #include <ATen/core/stack.h>
5 #include <torch/csrc/jit/backends/backend_interface.h>
6 #include <torch/custom_class.h>
7
8 namespace torch {
9 namespace jit {
10 namespace {
11 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
getIsAvailableSchema()12 inline c10::FunctionSchema getIsAvailableSchema() {
13 c10::Argument self("self", c10::AnyType::get());
14 c10::Argument available("available", c10::BoolType::get());
15 c10::FunctionSchema preprocessor_schema(
16 "is_available",
17 /*overload_name=*/"",
18 /*arguments=*/{self},
19 /*returns=*/{available});
20 return preprocessor_schema;
21 }
22
// Custom-class namespace under which every backend is registered, i.e. a
// backend named "foo" becomes the TorchScript class "__backends__.foo".
constexpr static auto kBackendsNamespace = "__backends__";
24
25 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
getCompileSchema()26 inline c10::FunctionSchema getCompileSchema() {
27 c10::Argument self("self", c10::AnyType::get());
28 c10::Argument mod("processed", c10::AnyType::get());
29 auto any_dict_ty =
30 c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
31 c10::Argument method_compile_spec("method_compile_spec", any_dict_ty);
32 c10::Argument handles("handles", any_dict_ty);
33
34 c10::FunctionSchema compile_schema(
35 "compile",
36 /*overload_name=*/"",
37 /*arguments=*/{self, mod, method_compile_spec},
38 /*returns=*/{handles});
39 return compile_schema;
40 }
41
42 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
getExecuteSchema()43 inline c10::FunctionSchema getExecuteSchema() {
44 auto any_list_ty = c10::ListType::create(c10::AnyType::get());
45 c10::Argument self("self", c10::AnyType::get());
46 c10::Argument handle("handle", c10::AnyType::get());
47 c10::Argument input("input", any_list_ty);
48 c10::Argument output("output", any_list_ty);
49 return c10::FunctionSchema(
50 "execute",
51 /*overload_name=*/"",
52 /*arguments=*/{self, handle, input},
53 /*returns=*/{output});
54 }
55
56 template <typename TBackendInterface>
getIsAvailableFunc()57 std::function<void(Stack&)> getIsAvailableFunc() {
58 return [](Stack& stack) {
59 auto self = pop(stack).toCustomClass<TBackendInterface>();
60 auto ret = self->is_available();
61 push(stack, ret);
62 };
63 }
64
65 template <typename TBackendInterface>
getCompileFunc()66 std::function<void(Stack&)> getCompileFunc() {
67 return [](Stack& stack) {
68 auto method_compile_spec = pop(stack).toGenericDict();
69 auto processed = pop(stack);
70 auto self = pop(stack).toCustomClass<TBackendInterface>();
71 auto ret = self->compile(processed, method_compile_spec);
72 push(stack, ret);
73 };
74 }
75
76 template <typename TBackendInterface>
getExecuteFunc()77 std::function<void(Stack&)> getExecuteFunc() {
78 return [](Stack& stack) {
79 auto args = pop(stack);
80 auto handle = pop(stack);
81 auto self = pop(stack);
82 auto backend = self.toCustomClass<TBackendInterface>();
83 auto res = backend->execute(handle, args.toList());
84 push(stack, res);
85 };
86 }
87 } // namespace
88
// Static registration API for backends.
//
// Constructing a torch::jit::backend<MyBackend>("my_backend") object (usually
// as a namespace-scope global) registers MyBackend as the TorchScript custom
// class "__backends__.my_backend" with unboxed is_available/compile/execute
// methods bound to the helpers defined above.
template <class TBackendInterface>
class backend {
  static_assert(
      std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value,
      "torch::jit::backend<T> requires T to inherit from PyTorchBackendInterface");
  // Name this backend was registered under (kept for the object's lifetime).
  std::string backend_name_;

 public:
  // Registers TBackendInterface as a custom class named \p name in the
  // kBackendsNamespace namespace, with a default constructor and the
  // is_available/compile/execute methods.
  //
  // NOTE(review): `cls` is a function-local static, so the registration chain
  // runs only on the FIRST construction per TBackendInterface instantiation —
  // presumably constructing a second backend<T> with a different name would
  // not register a second class; confirm this is the intended semantics.
  backend(const std::string& name) : backend_name_(name) {
    static auto cls = torch::class_<TBackendInterface>(kBackendsNamespace, name)
                          .def(torch::init<>())
                          ._def_unboxed(
                              "is_available",
                              getIsAvailableFunc<TBackendInterface>(),
                              getIsAvailableSchema())
                          ._def_unboxed(
                              "compile",
                              getCompileFunc<TBackendInterface>(),
                              getCompileSchema())
                          ._def_unboxed(
                              "execute",
                              getExecuteFunc<TBackendInterface>(),
                              getExecuteSchema());
  }
};
117
118 } // namespace jit
119 } // namespace torch
120