xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_backend_lib.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/backends/backend.h>
2 #include <torch/csrc/jit/backends/backend_debug_handler.h>
3 #include <torch/csrc/jit/backends/backend_preprocess.h>
4 
5 namespace torch {
6 namespace jit {
7 // This test JIT backend is intended to do the minimal amount of work
8 // necessary to test that the JIT backend registration endpoints and
9 // code generation are working correctly. It is not intended to
10 // produce numerically correct results.
11 template <bool isAvailable>
12 class TestBackend : public PyTorchBackendInterface {
13  public:
14   // Constructor.
15   // NOLINTNEXTLINE(modernize-use-equals-default)
TestBackend()16   explicit TestBackend() {}
17   virtual ~TestBackend() override = default;
18 
is_available()19   bool is_available() override {
20     return isAvailable;
21   }
22 
compile(c10::IValue processed,c10::impl::GenericDict method_compile_spec)23   c10::impl::GenericDict compile(
24       c10::IValue processed,
25       c10::impl::GenericDict method_compile_spec) override {
26     auto spec =
27         c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
28 
29     // Return the same string as a value for every key in method_compile_spec.
30     auto handles = c10::Dict<std::string, std::string>();
31     for (const auto& it : spec) {
32       handles.insert(it.key(), it.key());
33     }
34     return c10::impl::toGenericDict(handles);
35   }
execute(c10::IValue handle,c10::impl::GenericList inputs)36   c10::impl::GenericList execute(
37       c10::IValue handle,
38       c10::impl::GenericList inputs) override {
39     TORCH_INTERNAL_ASSERT(handle.isString());
40     TORCH_INTERNAL_ASSERT(inputs.size() > 0);
41 
42     c10::List<at::Tensor> output_list;
43 
44     // Implement simple accumulator and negative accumulator (?) ops. Return one
45     // or both of them depending on the handle to make sure multiple outputs are
46     // handled.
47     c10::IValue value = inputs[0];
48     at::Tensor accum = value.toTensor();
49     accum = accum.clone();
50     at::Tensor sub_accum = value.toTensor();
51     sub_accum = sub_accum.clone();
52 
53     for (size_t i = 1, e = inputs.size(); i < e; ++i) {
54       value = inputs[i];
55       accum.add_(value.toTensor(), 1.0);
56       sub_accum.sub_(value.toTensor(), 1.0);
57     }
58 
59     if (handle.toStringRef() == "accum") {
60       output_list.emplace_back(accum);
61     } else if (handle.toStringRef() == "sub_accum") {
62       output_list.emplace_back(sub_accum);
63     } else if (handle.toStringRef() == "forward") {
64       output_list.emplace_back(accum);
65       output_list.emplace_back(sub_accum);
66     }
67 
68     return c10::impl::toList(output_list);
69   }
70 };
71 
72 namespace {
preprocess(const Module & mod,const c10::Dict<IValue,IValue> & method_compile_spec,const BackendDebugHandleGenerator & generate_debug_handles)73 c10::IValue preprocess(
74     const Module& mod,
75     const c10::Dict<IValue, IValue>& method_compile_spec,
76     const BackendDebugHandleGenerator& generate_debug_handles) {
77   return mod._ivalue();
78 }
79 
80 constexpr auto backend_name = "test_backend";
81 static auto cls_available =
82     torch::jit::backend<TestBackend<true>>(backend_name);
83 static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
84 
85 constexpr auto backend_unavailable_name = "test_backend_unavailable";
86 static auto cls_unavailable =
87     torch::jit::backend<TestBackend<false>>(backend_unavailable_name);
88 static auto pre_reg_unavailable =
89     backend_preprocess_register(backend_unavailable_name, preprocess);
90 
91 } // namespace
92 } // namespace jit
93 } // namespace torch
94