#include <ATen/native/vulkan/ops/Common.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

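// Checks that two inputs are broadcast-compatible under the Vulkan backend's
// elementwise rules: the channel counts must match exactly; broadcasting
// along the batch dimension additionally requires the channel count to be a
// multiple of 4 (the backend packs channels four to a texel); and for the
// height/width dimensions the smaller extent must be 1.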
void check_inputs_elementwise_op(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(
      get_dim<Dim4D::Channel>(input1) == get_dim<Dim4D::Channel>(input2),
      "Vulkan elementwise ops require channel dimension to be equal!");
  if (get_dim<Dim4D::Batch>(input1) != get_dim<Dim4D::Batch>(input2)) {
    TORCH_CHECK(
        get_dim<Dim4D::Channel>(input1) % 4 == 0,
        "Vulkan elementwise ops require channel to be a multiple of 4 to broadcast along batch dimension!");
  }

  const uint32_t input1_h = get_dim<Dim4D::Height>(input1);
  const uint32_t input1_w = get_dim<Dim4D::Width>(input1);
  const uint32_t input2_h = get_dim<Dim4D::Height>(input2);
  const uint32_t input2_w = get_dim<Dim4D::Width>(input2);

  const std::string broadcast_error_msg =
      "Incompatible input dimensions for broadcasting for Vulkan elementwise op!";
  if (input1_h != input2_h) {
    if (input1_h > input2_h) {
      TORCH_CHECK(input2_h == 1, broadcast_error_msg);
      TORCH_CHECK(input2_w == input1_w || input2_w == 1, broadcast_error_msg);
    } else if (input2_h > input1_h) {
      TORCH_CHECK(input1_h == 1, broadcast_error_msg);
      TORCH_CHECK(input1_w == input2_w || input1_w == 1, broadcast_error_msg);
    }
  } else if (input1_w != input2_w) {
    if (input1_w > input2_w) {
      TORCH_CHECK(input2_w == 1, broadcast_error_msg);
    } else if (input2_w > input1_w) {
      TORCH_CHECK(input1_w == 1, broadcast_error_msg);
    }
  }
}

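// Out-of-place lerp with a scalar weight: computes
//   output = start + weight * (end - start)
// via the lerp_scalar compute shader. CPU inputs are uploaded to Vulkan
// tensors first, and the result is returned as a Vulkan tensor.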
Tensor _lerp_scalar(
    const Tensor& start_arg,
    const Tensor& end_arg,
    const Scalar& weight_arg) {
  check_inputs_elementwise_op(start_arg, end_arg);
  api::Context* const context = api::context();

  const Tensor start = start_arg.is_vulkan() ? start_arg : start_arg.vulkan();
  const vTensor& v_start = convert(start);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  vTensor v_output{
      context,
      v_start.sizes(),
      v_start.dtype(),
  };

  const float weight = weight_arg.to<float>();
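  // CPU-side mirror of the shader's uniform block. The fill_* members pad
  // each uvec3 up to a 16-byte boundary, which is what the shader's uniform
  // layout (std140-style alignment) expects for vec3 members.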
  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input1_extents;
    uint32_t fill_1;
    uvec3 input2_extents;
    float weight;
  } block{
      v_output.extents(),
      0u,
      v_start.extents(),
      0u,
      v_end.extents(),
      weight,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp_scalar),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_start.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

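// In-place variant of the scalar-weight lerp: overwrites self with
//   self + weight * (end - self).
// self must already be a Vulkan tensor, since the result is written back
// into its backing image.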
Tensor& _lerp_scalar_(
    Tensor& self_arg,
    const Tensor& end_arg,
    const Scalar& weight_arg) {
  check_inputs_elementwise_op(self_arg, end_arg);

  TORCH_CHECK(
      self_arg.is_vulkan(),
      "Vulkan: In-place operator is only supported on Vulkan tensors.");

  api::Context* const context = api::context();

  vTensor& v_self = convert(self_arg);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  const float weight = weight_arg.to<float>();
  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input_extents;
    float alpha;
  } block{
      v_self.extents(),
      0u,
      v_end.extents(),
      weight,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp_scalar_),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_self.extents(),
      // local work group size
      adaptive_work_group_size(v_self.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_self.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return self_arg;
}

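// Out-of-place lerp with a tensor weight: computes
//   output = start + weight * (end - start)
// elementwise, with end and weight broadcast against start under the same
// rules checked above, via the lerp compute shader.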
Tensor _lerp_tensor(
    const Tensor& start_arg,
    const Tensor& end_arg,
    const Tensor& weight_arg) {
  check_inputs_elementwise_op(start_arg, end_arg);
  check_inputs_elementwise_op(start_arg, weight_arg);

  api::Context* const context = api::context();

  const Tensor start = start_arg.is_vulkan() ? start_arg : start_arg.vulkan();
  const vTensor& v_start = convert(start);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  const Tensor weight =
      weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan();
  const vTensor& v_weight = convert(weight);

  vTensor v_output{
      context,
      v_start.sizes(),
      v_start.dtype(),
  };

  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input1_extents;
    uint32_t fill_1;
    uvec3 input2_extents;
    uint32_t fill_2;
    uvec3 input3_extents;
    uint32_t fill_3;
  } block{
      v_output.extents(),
      0u,
      v_start.extents(),
      0u,
      v_end.extents(),
      0u,
      v_weight.extents(),
      0u,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_start.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

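// In-place variant of the tensor-weight lerp; self must be a Vulkan tensor
// and is updated in place by the lerp_ compute shader.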
Tensor& _lerp_tensor_(
    Tensor& self_arg,
    const Tensor& end_arg,
    const Tensor& weight_arg) {
  check_inputs_elementwise_op(self_arg, end_arg);
  check_inputs_elementwise_op(self_arg, weight_arg);

  TORCH_CHECK(
      self_arg.is_vulkan(),
      "Vulkan: In-place operator is only supported on Vulkan tensors.");

  api::Context* const context = api::context();

  vTensor& v_self = convert(self_arg);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  const Tensor weight =
      weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan();
  const vTensor& v_weight = convert(weight);

  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input1_extents;
    uint32_t fill_1;
    uvec3 input2_extents;
    uint32_t fill_2;
  } block{
      v_self.extents(),
      0u,
      v_end.extents(),
      0u,
      v_weight.extents(),
      0u,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp_),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_self.extents(),
      // local work group size
      adaptive_work_group_size(v_self.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_self.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return self_arg;
}

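// Dispatcher-facing wrappers matching the aten::lerp schemas registered
// below; they simply forward to the implementations above.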
Tensor lerp_scalar(
    const Tensor& start,
    const Tensor& end,
    const Scalar& weight) {
  return _lerp_scalar(start, end, weight);
}

Tensor& lerp_scalar_(Tensor& self, const Tensor& end, const Scalar& weight) {
  return _lerp_scalar_(self, end, weight);
}

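// A zero-dimensional weight tensor carries a single value, so the tensor
// overloads below route it to the scalar kernels instead.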
Tensor lerp_tensor(
    const Tensor& start,
    const Tensor& end,
    const Tensor& weight) {
  if (weight.sizes().size() == 0) {
    return _lerp_scalar(start, end, weight.item<float>());
  }
  return _lerp_tensor(start, end, weight);
}

Tensor& lerp_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
  if (weight.sizes().size() == 0) {
    return _lerp_scalar_(self, end, weight.item<float>());
  }
  return _lerp_tensor_(self, end, weight);
}

#ifdef USE_VULKAN_API

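// Register the Vulkan implementations of the lerp overloads with the ATen
// dispatcher (only when the Vulkan API is enabled).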
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp.Scalar"), TORCH_FN(lerp_scalar));
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp_.Scalar"), TORCH_FN(lerp_scalar_));
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp.Tensor"), TORCH_FN(lerp_tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp_.Tensor"), TORCH_FN(lerp_tensor_));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at