xref: /aosp_15_r20/external/pytorch/test/custom_backend/custom_backend.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/backends/backend.h>
2 #include <torch/csrc/jit/backends/backend_detail.h>
3 #include <torch/csrc/jit/api/module.h>
4 
5 namespace torch {
6 namespace custom_backend {
7 // This custom JIT backend is intended to do the minimal amount of work
8 // necessary to test that the JIT backend registration endpoints and
9 // code generation are working correctly. It is not intended to
10 // produce numerically correct results.
11 class CustomBackend : public torch::jit::PyTorchBackendInterface {
12  public:
13   // Constructor.
CustomBackend()14   explicit CustomBackend() {}
15   virtual ~CustomBackend() = default;
16 
is_available()17   bool is_available() override {
18     return true;
19   }
20 
compile(c10::IValue processed,c10::impl::GenericDict method_compile_spec)21   c10::impl::GenericDict compile(
22       c10::IValue processed,
23       c10::impl::GenericDict method_compile_spec) override {
24     auto spec =
25         c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
26 
27     // Return the same string as a value for every key in method_compile_spec.
28     auto handles = c10::Dict<std::string, std::string>();
29     for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
30       handles.insert(it->key(), it->key());
31     }
32     return c10::impl::toGenericDict(handles);
33   }
execute(c10::IValue handle,c10::impl::GenericList inputs)34   c10::impl::GenericList execute(
35       c10::IValue handle,
36       c10::impl::GenericList inputs) override {
37     TORCH_INTERNAL_ASSERT(handle.isString());
38     TORCH_INTERNAL_ASSERT(inputs.size() > 0);
39 
40     c10::List<at::Tensor> output_list;
41 
42     // Implement simple accumulator and negative accumulator (?) ops. Return one
43     // or both of them depending on the handle to make sure multiple outputs are
44     // handled.
45     c10::IValue value = inputs[0];
46     at::Tensor accum = value.toTensor();
47     accum = accum.clone();
48     at::Tensor sub_accum = value.toTensor();
49     sub_accum = sub_accum.clone();
50 
51     for (size_t i = 1, e = inputs.size(); i < e; ++i) {
52       value = inputs[i];
53       accum.add_(value.toTensor(), 1.0);
54       sub_accum.sub_(value.toTensor(), 1.0);
55     }
56 
57     if (handle.toStringRef() == "accum") {
58       output_list.emplace_back(accum);
59     } else if (handle.toStringRef() == "sub_accum") {
60       output_list.emplace_back(sub_accum);
61     } else if (handle.toStringRef() == "forward") {
62       output_list.emplace_back(accum);
63       output_list.emplace_back(sub_accum);
64     }
65 
66     return c10::impl::toList(output_list);
67   }
68 };
69 
preprocess(const torch::jit::Module & mod,const c10::Dict<c10::IValue,c10::IValue> & method_compile_spec,const torch::jit::BackendDebugHandleGenerator & generate_debug_handles)70 c10::IValue preprocess(
71     const torch::jit::Module& mod,
72     const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
73     const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
74   return mod._ivalue();
75 }
76 
// CUSTOM_BACKEND_API marks symbols exported from the custom backend library.
// On Windows, the build of the library itself must define an *_EXPORTS macro
// so that declarations become __declspec(dllexport); every other consumer sees
// __declspec(dllimport). On other platforms it expands to nothing.
//
// custom_backend_EXPORTS is what CMake defines by default when building a
// shared-library target named `custom_backend`. custom_ops_EXPORTS is still
// accepted for backward compatibility.
// NOTE(review): confirm against the actual CMake target name.
// clang-format off
#  if defined(_WIN32)
#    if defined(custom_ops_EXPORTS) || defined(custom_backend_EXPORTS)
#      define CUSTOM_BACKEND_API __declspec(dllexport)
#    else
#      define CUSTOM_BACKEND_API __declspec(dllimport)
#    endif
#  else
#    define CUSTOM_BACKEND_API
#  endif
// clang-format on

// Exported accessor whose definition lives outside this header. Presumably it
// returns the name under which CustomBackend is registered; confirm against
// its definition.
CUSTOM_BACKEND_API std::string getBackendName();
90 } // namespace custom_backend
91 } // namespace torch
92