1 #include <torch/csrc/jit/backends/backend.h>
2 #include <torch/csrc/jit/backends/backend_detail.h>
3 #include <torch/csrc/jit/api/module.h>
4
5 namespace torch {
6 namespace custom_backend {
// This custom JIT backend is intended to do the minimal amount of work
// necessary to test that the JIT backend registration endpoints and
// code generation are working correctly. It is not intended to
// produce numerically correct results.
11 class CustomBackend : public torch::jit::PyTorchBackendInterface {
12 public:
13 // Constructor.
CustomBackend()14 explicit CustomBackend() {}
15 virtual ~CustomBackend() = default;
16
is_available()17 bool is_available() override {
18 return true;
19 }
20
compile(c10::IValue processed,c10::impl::GenericDict method_compile_spec)21 c10::impl::GenericDict compile(
22 c10::IValue processed,
23 c10::impl::GenericDict method_compile_spec) override {
24 auto spec =
25 c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
26
27 // Return the same string as a value for every key in method_compile_spec.
28 auto handles = c10::Dict<std::string, std::string>();
29 for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
30 handles.insert(it->key(), it->key());
31 }
32 return c10::impl::toGenericDict(handles);
33 }
execute(c10::IValue handle,c10::impl::GenericList inputs)34 c10::impl::GenericList execute(
35 c10::IValue handle,
36 c10::impl::GenericList inputs) override {
37 TORCH_INTERNAL_ASSERT(handle.isString());
38 TORCH_INTERNAL_ASSERT(inputs.size() > 0);
39
40 c10::List<at::Tensor> output_list;
41
42 // Implement simple accumulator and negative accumulator (?) ops. Return one
43 // or both of them depending on the handle to make sure multiple outputs are
44 // handled.
45 c10::IValue value = inputs[0];
46 at::Tensor accum = value.toTensor();
47 accum = accum.clone();
48 at::Tensor sub_accum = value.toTensor();
49 sub_accum = sub_accum.clone();
50
51 for (size_t i = 1, e = inputs.size(); i < e; ++i) {
52 value = inputs[i];
53 accum.add_(value.toTensor(), 1.0);
54 sub_accum.sub_(value.toTensor(), 1.0);
55 }
56
57 if (handle.toStringRef() == "accum") {
58 output_list.emplace_back(accum);
59 } else if (handle.toStringRef() == "sub_accum") {
60 output_list.emplace_back(sub_accum);
61 } else if (handle.toStringRef() == "forward") {
62 output_list.emplace_back(accum);
63 output_list.emplace_back(sub_accum);
64 }
65
66 return c10::impl::toList(output_list);
67 }
68 };
69
preprocess(const torch::jit::Module & mod,const c10::Dict<c10::IValue,c10::IValue> & method_compile_spec,const torch::jit::BackendDebugHandleGenerator & generate_debug_handles)70 c10::IValue preprocess(
71 const torch::jit::Module& mod,
72 const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
73 const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
74 return mod._ivalue();
75 }
76
// Export/import annotation for symbols in the custom backend shared library:
// on Windows, dllexport when building the library (custom_ops_EXPORTS is
// defined by the build) and dllimport for consumers; empty elsewhere.
// clang-format off
#if defined(_WIN32)
#if defined(custom_ops_EXPORTS)
#define CUSTOM_BACKEND_API __declspec(dllexport)
#else
#define CUSTOM_BACKEND_API __declspec(dllimport)
#endif
#else
#define CUSTOM_BACKEND_API
#endif
// clang-format on

// Returns the name this backend is registered under (defined in the
// backend's .cpp translation unit).
CUSTOM_BACKEND_API std::string getBackendName();
90 } // namespace custom_backend
91 } // namespace torch
92