#include <ATen/native/vulkan/ops/Common.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

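// Verifies that two inputs are broadcast-compatible for a Vulkan elementwise
// op: the channel dimensions must match, broadcasting along the batch
// dimension requires the channel count to be a multiple of 4, and height or
// width may only differ when the smaller extent is 1.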
void check_inputs_elementwise_op(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(
      get_dim<Dim4D::Channel>(input1) == get_dim<Dim4D::Channel>(input2),
      "Vulkan elementwise ops require channel dimension to be equal!");
  if (get_dim<Dim4D::Batch>(input1) != get_dim<Dim4D::Batch>(input2)) {
    TORCH_CHECK(
        get_dim<Dim4D::Channel>(input1) % 4 == 0,
        "Vulkan elementwise ops require channel to be a multiple of 4 to broadcast along batch dimension!");
  }

  const uint32_t input1_h = get_dim<Dim4D::Height>(input1);
  const uint32_t input1_w = get_dim<Dim4D::Width>(input1);
  const uint32_t input2_h = get_dim<Dim4D::Height>(input2);
  const uint32_t input2_w = get_dim<Dim4D::Width>(input2);

  const std::string broadcast_error_msg =
      "Incompatible input dimensions for broadcasting for Vulkan elementwise op!";
  if (input1_h != input2_h) {
    if (input1_h > input2_h) {
      TORCH_CHECK(input2_h == 1, broadcast_error_msg);
      TORCH_CHECK(input2_w == input1_w || input2_w == 1, broadcast_error_msg);
    } else if (input2_h > input1_h) {
      TORCH_CHECK(input1_h == 1, broadcast_error_msg);
      TORCH_CHECK(input1_w == input2_w || input1_w == 1, broadcast_error_msg);
    }
  } else if (input1_w != input2_w) {
    if (input1_w > input2_w) {
      TORCH_CHECK(input2_w == 1, broadcast_error_msg);
    } else if (input2_w > input1_w) {
      // Broadcasting along width requires the narrower input to have width 1,
      // mirroring the check in the branch above.
      TORCH_CHECK(input1_w == 1, broadcast_error_msg);
    }
  }
}

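// Out-of-place lerp with a scalar weight:
//   output = start + weight * (end - start)
// dispatched to the lerp_scalar compute shader.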
Tensor _lerp_scalar(
    const Tensor& start_arg,
    const Tensor& end_arg,
    const Scalar& weight_arg) {
  check_inputs_elementwise_op(start_arg, end_arg);
  api::Context* const context = api::context();

  const Tensor start = start_arg.is_vulkan() ? start_arg : start_arg.vulkan();
  const vTensor& v_start = convert(start);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  vTensor v_output{
      context,
      v_start.sizes(),
      v_start.dtype(),
  };

  const float weight = weight_arg.to<float>();
  // The fill_* members pad each uvec3 out to 16 bytes so the struct matches
  // the uniform buffer layout expected by the shader.
  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input1_extents;
    uint32_t fill_1;
    uvec3 input2_extents;
    float weight;
  } block{
      v_output.extents(),
      0u,
      v_start.extents(),
      0u,
      v_end.extents(),
      weight,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp_scalar),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_start.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

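// In-place variant of the scalar-weight lerp; self must already be a Vulkan
// tensor since it is both read and written by the shader.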
Tensor& _lerp_scalar_(
    Tensor& self_arg,
    const Tensor& end_arg,
    const Scalar& weight_arg) {
  check_inputs_elementwise_op(self_arg, end_arg);

  TORCH_CHECK(
      self_arg.is_vulkan(),
      "Vulkan: In-place operator is only supported on Vulkan tensors.");

  api::Context* const context = api::context();

  vTensor& v_self = convert(self_arg);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  const float weight = weight_arg.to<float>();
  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input_extents;
    float alpha;
  } block{
      v_self.extents(),
      0u,
      v_end.extents(),
      weight,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp_scalar_),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_self.extents(),
      // local work group size
      adaptive_work_group_size(v_self.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_self.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return self_arg;
}

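// Out-of-place lerp with an elementwise tensor weight:
//   output = start + weight * (end - start)
// dispatched to the lerp compute shader.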
Tensor _lerp_tensor(
    const Tensor& start_arg,
    const Tensor& end_arg,
    const Tensor& weight_arg) {
  check_inputs_elementwise_op(start_arg, end_arg);
  check_inputs_elementwise_op(start_arg, weight_arg);

  api::Context* const context = api::context();

  const Tensor start = start_arg.is_vulkan() ? start_arg : start_arg.vulkan();
  const vTensor& v_start = convert(start);

  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  // Convert the (possibly newly created) Vulkan copy, not the original
  // argument, which may live on the CPU.
  const Tensor weight =
      weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan();
  const vTensor& v_weight = convert(weight);

  vTensor v_output{
      context,
      v_start.sizes(),
      v_start.dtype(),
  };

  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input1_extents;
    uint32_t fill_1;
    uvec3 input2_extents;
    uint32_t fill_2;
    uvec3 input3_extents;
    uint32_t fill_3;
  } block{
      v_output.extents(),
      0u,
      v_start.extents(),
      0u,
      v_end.extents(),
      0u,
      v_weight.extents(),
      0u,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_start.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

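// In-place variant of the tensor-weight lerp; self must already be a Vulkan
// tensor since it is both read and written by the shader.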
Tensor& _lerp_tensor_(
    Tensor& self_arg,
    const Tensor& end_arg,
    const Tensor& weight_arg) {
  check_inputs_elementwise_op(self_arg, end_arg);
  check_inputs_elementwise_op(self_arg, weight_arg);

  TORCH_CHECK(
      self_arg.is_vulkan(),
      "Vulkan: In-place operator is only supported on Vulkan tensors.");

  api::Context* const context = api::context();

  vTensor& v_self = convert(self_arg);

  // Convert the Vulkan copies, not the original arguments, which may live on
  // the CPU.
  const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan();
  const vTensor& v_end = convert(end);

  const Tensor weight =
      weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan();
  const vTensor& v_weight = convert(weight);

  const struct Block final {
    uvec3 extents;
    uint32_t fill_0;
    uvec3 input1_extents;
    uint32_t fill_1;
    uvec3 input2_extents;
    uint32_t fill_2;
  } block{
      v_self.extents(),
      0u,
      v_end.extents(),
      0u,
      v_weight.extents(),
      0u,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(lerp_),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_self.extents(),
      // local work group size
      adaptive_work_group_size(v_self.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_self.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return self_arg;
}

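// Schema-facing wrappers; these are the entry points registered with the
// dispatcher below.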
Tensor lerp_scalar(
    const Tensor& start,
    const Tensor& end,
    const Scalar& weight) {
  return _lerp_scalar(start, end, weight);
}

Tensor& lerp_scalar_(Tensor& self, const Tensor& end, const Scalar& weight) {
  return _lerp_scalar_(self, end, weight);
}

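// A zero-dimensional weight tensor holds a single value, so it is routed to
// the scalar-weight path via item<float>().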
Tensor lerp_tensor(
    const Tensor& start,
    const Tensor& end,
    const Tensor& weight) {
  if (weight.sizes().size() == 0) {
    return _lerp_scalar(start, end, weight.item<float>());
  }
  return _lerp_tensor(start, end, weight);
}

Tensor& lerp_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
  if (weight.sizes().size() == 0) {
    return _lerp_scalar_(self, end, weight.item<float>());
  }
  return _lerp_tensor_(self, end, weight);
}

#ifdef USE_VULKAN_API

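// Register the Vulkan kernels for the scalar- and tensor-weight overloads of
// aten::lerp and aten::lerp_.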
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp.Scalar"), TORCH_FN(lerp_scalar));
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp_.Scalar"), TORCH_FN(lerp_scalar_));
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp.Tensor"), TORCH_FN(lerp_tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::lerp_.Tensor"), TORCH_FN(lerp_tensor_));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at