#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>

#include <torch/csrc/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>

#include <cstring>

// Counters observed from Python (via custom_op_called()) to verify that the
// custom kernels below actually ran.
static uint64_t op_counter = 0;
static uint64_t last_saved_value = 0;

// register guard (a no-op implementation is enough for this test backend)
namespace at {
namespace detail {

C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);

}} // namespace at::detail

// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  op_counter += 1;
  // Since this custom device is just for testing, not bothering to implement kernels.
  return at::empty(self.sizes(), self.options());
}

// basic dummy mul function
at::Tensor custom_mul_Tensor(const at::Tensor & self, const at::Tensor & other) {
  op_counter += 1;
  // Since this custom device is just for testing, not bothering to implement kernels.
  return at::empty(self.sizes(), self.options());
}

// basic dummy to() function: only supports copying between cpu and the dummy device
at::Tensor custom_to_device(
    const at::Tensor & self,
    at::Device device,
    at::ScalarType dtype,
    bool non_blocking,
    bool copy,
    std::optional<at::MemoryFormat> memory_format) {
  TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
  TORCH_CHECK(device.is_cpu() || device.type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
  TORCH_CHECK(self.scalar_type() == dtype);
  TORCH_CHECK(self.is_contiguous());

  op_counter += 1;
  if (device != at::DeviceType::CPU) {
    return at::empty(self.sizes(), self.options());
  }

  // Copying back to CPU: allocate a real CPU tensor and copy the raw bytes over.
  auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, false, memory_format);
  std::memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
  // Since this custom device is just for testing, not bothering to implement kernels.
  return out;
}

// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
  DummyCustomAllocator() = default;
  at::DataPtr allocate(size_t nbytes) override {
    void* data = c10::alloc_cpu(nbytes);
    // DataPtr fields are (data, ctx, deleter, device); the raw pointer doubles
    // as its own context since ReportAndDelete needs no extra state.
    return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
  }

  static void ReportAndDelete(void* ptr) {
    if (!ptr) {
      return;
    }
    c10::free_cpu(ptr);
  }

  at::DeleterFnPtr raw_deleter() const override {
    return &ReportAndDelete;
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    default_copy_data(dest, src, count);
  }
};

// Register our dummy allocator
static DummyCustomAllocator global_custom_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);

at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
  TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows dummy device.");
  TORCH_CHECK(self.is_contiguous());
  TORCH_CHECK(self.scalar_type() == c10::ScalarType::Float);

  op_counter += 1;
  auto _data = static_cast<float*>(self.mutable_data_ptr());
  // numel() returns int64_t, so use a signed index to avoid a sign-compare warning.
  for (int64_t idx = 0; idx < self.numel(); idx++) {
    _data[idx] = value.toFloat();
  }

  return self;
}

// basic dummy copy_() function, so we can copy from the custom device to/from CPU
at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
  TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
  TORCH_CHECK(dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");

  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
  TORCH_CHECK(self.sizes() == dst.sizes());
  TORCH_CHECK(self.scalar_type() == dst.scalar_type());
  TORCH_CHECK(self.is_contiguous() && dst.is_contiguous());

  op_counter += 1;
  std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), self.storage().nbytes());
  return dst;
}

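// Note (an assumption about dispatcher plumbing, not asserted by this file):
// from Python, both `cpu_t.to(custom_dev)` and `custom_t.copy_(cpu_t)` are
// expected to bottom out in the `_copy_from` kernel above for cross-device copies.
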
at::Tensor custom_empty_memory_format(at::IntArrayRef size,
                                      std::optional<at::ScalarType> dtype,
                                      std::optional<at::Layout> layout,
                                      std::optional<at::Device> device,
                                      std::optional<bool> pin_memory,
                                      std::optional<at::MemoryFormat> memory_format) {
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_generic(size,
                                   &global_custom_alloc,
                                   private_use_ks,
                                   c10::dtype_or_default(dtype),
                                   memory_format);
}

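// A minimal sketch (for illustration, assuming the registrations below are in
// place) of the allocation path this factory enables:
//
//   at::Tensor t = at::empty({2, 2},
//       at::TensorOptions().device(c10::DeviceType::PrivateUse1));
//   // -> the dispatcher routes empty.memory_format to custom_empty_memory_format
//   // -> empty_generic allocates via global_custom_alloc.allocate(nbytes)
//   // -> which is plain CPU memory underneath (c10::alloc_cpu)
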
at::Tensor custom_empty_strided(c10::IntArrayRef size,
                                c10::IntArrayRef stride,
                                std::optional<at::ScalarType> dtype_opt,
                                std::optional<at::Layout> layout_opt,
                                std::optional<at::Device> device_opt,
                                std::optional<bool> pin_memory_opt) {
  op_counter += 1;

  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  auto dtype = c10::dtype_or_default(dtype_opt);
  return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
}

// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
// into the kernels registered here.
//
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("mul.Tensor", &custom_mul_Tensor);
  m.impl("to.Device", &custom_to_device);
  m.impl("fill_.Scalar", &custom_fill__scalar);
  m.impl("_copy_from", &custom__copy_from);
  m.impl("empty.memory_format", &custom_empty_memory_format);
  m.impl("empty_strided", &custom_empty_strided);
}
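
// A hedged sketch (not part of the test) of how a call reaches these kernels,
// assuming two tensors already live on the dummy device:
//
//   at::Tensor a = at::empty({2, 2}, at::TensorOptions().device(c10::DeviceType::PrivateUse1));
//   at::Tensor b = at::empty({2, 2}, at::TensorOptions().device(c10::DeviceType::PrivateUse1));
//   at::Tensor c = a.add(b);  // PrivateUse1 dispatch key -> custom_add_Tensor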

// This basic implementation doesn't bother dealing with different device indices
// (e.g. custom_device:0 vs. custom_device:1).
// We could do that by letting the user pass in a device index in our exposed device function.
// Note that if you do that, you'll also need to register a device guard to core.
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
c10::Device get_custom_device() {
  return c10::Device(c10::DeviceType::PrivateUse1, 0);
}

// Reports whether any of the custom kernels above ran since the last query,
// and advances the watermark so each call only reports new activity.
bool custom_op_called() {
  bool called = false;
  if (op_counter > last_saved_value) {
    called = true;
    last_saved_value = op_counter;
  }
  return called;
}

class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
  // Constructors
  PrivateGeneratorImpl(c10::DeviceIndex device_index) {
    device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
    key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
  }
  ~PrivateGeneratorImpl() override = default;
};

// this is used to register the generator factory for the custom device
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
  return at::make_generator<PrivateGeneratorImpl>(device_index);
}

void register_generator() {
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}

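// A hedged sketch of what the registration above enables from Python (assuming
// the extension is loaded and register_generator() has been called):
//
//   gen = torch.Generator(device="privateuse1:0")
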
// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
// that's implemented in C++.
// The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("custom_device", &get_custom_device, "get custom device object");
  m.def("custom_op_called", &custom_op_called, "check if our custom function was called");
  m.def("register_generator", &register_generator, "register generator for custom device");
}
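
// A hedged end-to-end usage sketch from Python (the module/extension names here
// are assumptions; the real wiring lives in the accompanying inductor test):
//
//   import torch
//   from torch.utils.cpp_extension import load
//   module = load(name="extension_device", sources=["extension_device.cpp"])
//   torch.utils.rename_privateuse1_backend("extension_device")
//   x = torch.empty(4, 4, device=module.custom_device())  # custom_empty_memory_format
//   x.fill_(3.14)                                         # custom_fill__scalar
//   assert module.custom_op_called()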