xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/vulkan_quantized_api_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_VULKAN_API
2 
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/vulkan/api/api.h>
#include <ATen/native/vulkan/impl/Packing.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Convert.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/vulkan/ops/Factory.h>
#include <ATen/native/vulkan/ops/Mm.h>
#include <ATen/native/vulkan/ops/QuantizedFunctions.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>

#include <math.h>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <optional>
#include <random>
#include <tuple>
#include <vector>
24 
25 using namespace at::native::vulkan::api::utils;
26 
27 /*
28  * TODO: rename this file to something like vulkan_experimental_test and move
29  * this under caffe2/fb/vulkan. This file should be used to test experimental
30  * features of the Vulkan backend. vulkan_api_test cannot serve this purpose
31  * because it cannot link against symbols in the ATen/native/vulkan folder.
32  */
33 
34 namespace {
35 
36 using namespace at::native::vulkan;
37 
// Relative tolerance for comparing Vulkan results against CPU references.
// FP16 inference keeps fewer mantissa bits, so the bound is looser there.
#ifndef USE_VULKAN_FP16_INFERENCE
constexpr float kTolerance = 1e-5f;
#else
constexpr float kTolerance = 1e-2f;
#endif
43 
checkRtol(const at::Tensor & diff,const std::vector<at::Tensor> & inputs,const float tolerated_error=0)44 bool checkRtol(
45     const at::Tensor& diff,
46     const std::vector<at::Tensor>& inputs,
47     const float tolerated_error = 0) {
48   double maxValue = 0.0;
49 
50   for (const auto& tensor : inputs) {
51     maxValue = fmax(tensor.abs().max().item<double>(), maxValue);
52   }
53 
54 #ifdef USE_VULKAN_FP16_INFERENCE
55   constexpr float tolerance = 1e-2;
56 #else
57   constexpr float tolerance = 1e-5;
58 #endif
59 
60   return diff.abs().max().item<double>() <=
61       (tolerance * maxValue + tolerated_error);
62 }
63 
almostEqual(const at::Tensor & a,const at::Tensor & b,const float tolerated_error=0)64 bool almostEqual(
65     const at::Tensor& a,
66     const at::Tensor& b,
67     const float tolerated_error = 0) {
68   return checkRtol(a - b, {a, b}, tolerated_error);
69 }
70 
71 /* Unused function
72 bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
73   return (a - b).abs().max().item<float>() == 0.0f;
74 }
75 */
76 
showRtol(const at::Tensor & a,const at::Tensor & b,long * xpos=nullptr,long * ypos=nullptr)77 void showRtol(
78     const at::Tensor& a,
79     const at::Tensor& b,
80     long* xpos = nullptr,
81     long* ypos = nullptr) {
82   const auto diff = (a - b).abs();
83 
84   double maxValue = a.abs().max().item<double>();
85   maxValue = fmax(b.abs().max().item<double>(), maxValue);
86 
87 #ifdef USE_VULKAN_FP16_INFERENCE
88   constexpr float tolerance = 1e-2;
89 #else
90   constexpr float tolerance = 1e-5;
91 #endif
92 
93   const double maxDiff = maxValue * tolerance;
94   std::cout << "Max Diff allowed: " << maxDiff << std::endl;
95   std::cout << "Max Diff found is: " << diff.max().item<double>() << std::endl;
96   if (diff.sizes().size() == 2) {
97     for (const auto y : c10::irange(diff.sizes()[0])) {
98       std::cout << y << ":";
99       for (const auto x : c10::irange(diff.sizes()[1])) {
100         double diff_xy = diff[y][x].item<double>();
101         if (diff_xy > maxDiff) {
102           std::cout << std::setw(5) << x;
103           if (diff.max().item<double>() == diff_xy) {
104             std::cout << " : " << diff_xy;
105             if (xpos && ypos) {
106               *xpos = x;
107               *ypos = y;
108               return;
109             }
110           }
111         } else {
112           std::cout << std::setw(5) << " ";
113         }
114       }
115       std::cout << std::endl;
116     }
117   }
118 }
119 
120 template <class... Inputs>
makeStack(Inputs &&...inputs)121 inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
122   return {std::forward<Inputs>(inputs)...};
123 }
124 
125 template <class... Args>
callOpByHandle(const c10::OperatorHandle & op,Args...args)126 inline std::vector<c10::IValue> callOpByHandle(
127     const c10::OperatorHandle& op,
128     Args... args) {
129   auto stack = makeStack(std::forward<Args>(args)...);
130   c10::Dispatcher::singleton().callBoxed(op, &stack);
131   return stack;
132 }
133 
134 template <class... Args>
callOpByName(const char * func_name,const char * overload_name,Args...args)135 inline std::vector<c10::IValue> callOpByName(
136     const char* func_name,
137     const char* overload_name,
138     Args... args) {
139   const std::optional<c10::OperatorHandle> op_handle =
140       c10::Dispatcher::singleton().findSchema({func_name, overload_name});
141   assert(op_handle.has_value());
142   return callOpByHandle(op_handle.value(), std::forward<Args>(args)...);
143 }
144 
145 using namespace at::native::vulkan;
146 using at::native::vulkan::api::utils::ivec3;
147 using at::native::vulkan::api::utils::ivec4;
148 using at::native::vulkan::api::utils::vec4;
149 
operator <<(std::ostream & os,const vec4 & v)150 std::ostream& operator<<(std::ostream& os, const vec4& v) {
151   os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
152      << v.data[3u] << ")";
153   return os;
154 }
155 
operator <<(std::ostream & os,const ivec3 & v)156 std::ostream& operator<<(std::ostream& os, const ivec3& v) {
157   os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
158   return os;
159 }
160 
operator <<(std::ostream & os,const ivec4 & v)161 std::ostream& operator<<(std::ostream& os, const ivec4& v) {
162   os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
163      << v.data[3u] << ")";
164   return os;
165 }
166 
167 } // namespace
168 
169 namespace {
170 
/**
 * Returns a uniformly distributed double, nominally in [min, max).
 *
 * When min and max are (nearly) equal, min is returned directly.
 *
 * The engine is seeded once per thread rather than on every call.
 * Constructing std::random_device and reseeding per call is comparatively
 * expensive. On platforms where random_device is deterministic, it would
 * also make every call return the same value.
 */
double rand_double(const double min, const double max) {
  if (std::fabs(max - min) < std::numeric_limits<double>::epsilon()) {
    return min;
  }
  static thread_local std::mt19937 gen{std::random_device{}()};
  return std::uniform_real_distribution<double>(min, max)(gen);
}
179 
/**
 * Returns a uniformly distributed int in the closed range [min, max].
 *
 * The engine is seeded once per thread rather than on every call.
 * Reseeding from std::random_device per call is comparatively expensive.
 * On platforms where random_device is deterministic, it would also make
 * every call return the same value.
 */
int rand_int(const int min, const int max) {
  static thread_local std::mt19937 gen{std::random_device{}()};
  return std::uniform_int_distribution<int>(min, max)(gen);
}
185 
rand_pos_int(const int max_val)186 int rand_pos_int(const int max_val) {
187   TORCH_CHECK(max_val > 0, "max value must be positive");
188   return 1 + rand_int(0, max_val);
189 }
190 
/**
 * Builds a random float CPU tensor of shape `tensor_shape`.
 *
 * A scalar s is drawn via rand_double(s_min, s_max). The values are
 * s * (U - shift), with U from at::rand, so they lie in
 * [-shift * s, (1 - shift) * s). With the defaults, s is in [1, 100], so the
 * range is anywhere from [-0.45, 0.55) up to [-45.0, 55.0).
 *
 * Throws (TORCH_CHECK) unless 0 < s_min <= s_max.
 */
at::Tensor produce_random_tensor(
    const at::IntArrayRef tensor_shape,
    const double s_min = 1.0,
    const double s_max = 100.0,
    const double shift = 0.45) {
  TORCH_CHECK(s_min > 0, "scalar lower bound must be positive");
  TORCH_CHECK(s_min <= s_max, "scalar lower bound must be <= upper bound");
  const auto float_cpu = at::device(at::kCPU).dtype(at::kFloat);
  const auto scalar = rand_double(s_min, s_max);
  const at::Tensor centered = at::rand(tensor_shape, float_cpu) - shift;
  return scalar * centered;
}
208 
produce_random_scale(const double scale_min=0.001,const double scale_max=2.0)209 double produce_random_scale(
210     const double scale_min = 0.001,
211     const double scale_max = 2.0) {
212   TORCH_CHECK(scale_min <= scale_max, "scale min must be <= scale max");
213   // scale is randomly generated in the range [scale_min, scale_max)
214   return rand_double(scale_min, scale_max);
215   ;
216 }
217 
produce_random_zero_point(const c10::ScalarType dtype)218 int produce_random_zero_point(const c10::ScalarType dtype) {
219   int zero_point = 0;
220   switch (dtype) {
221     case c10::ScalarType::QUInt8:
222       zero_point = rand_int(0, 255);
223       break;
224     case c10::ScalarType::QInt8:
225       zero_point = rand_int(-128, 127);
226       break;
227     case c10::ScalarType::QInt32:
228       zero_point = rand_int(-100000, 100000);
229       break;
230     default:
231       TORCH_CHECK(
232           false,
233           "Vulkan quantization currently not supported for dtype ",
234           dtype);
235   }
236   return zero_point;
237 }
238 
compute_quant_params(const at::Tensor & tensor,const c10::ScalarType dtype=c10::ScalarType::QUInt8)239 std::tuple<double, int> compute_quant_params(
240     const at::Tensor& tensor,
241     const c10::ScalarType dtype = c10::ScalarType::QUInt8) {
242   int zero_point_min = 0;
243   int zero_point_max = 255;
244   if (dtype == c10::ScalarType::QUInt8) {
245     zero_point_min = 0;
246     zero_point_max = 255;
247   } else if (dtype == c10::ScalarType::QInt8) {
248     zero_point_min = -128;
249     zero_point_max = 127;
250   } else {
251     TORCH_CHECK(
252         false,
253         "Computation of quant params only available for dtypes",
254         "QUInt8 and QInt8");
255   }
256   const auto tensor_max = tensor.max().item<double>();
257   const auto tensor_min = tensor.min().item<double>();
258   auto q_params = quant_utils::ChooseQuantizationParams(
259       /*min=*/safe_downcast<float>(tensor_min),
260       /*max=*/safe_downcast<float>(tensor_max),
261       /*qmin=*/zero_point_min,
262       /*qmax=*/zero_point_max,
263       /*preserve_sparsity=*/false,
264       /*force_scale_power_of_two=*/false,
265       /*reduce_range=*/false);
266   return std::tuple<double, int>(q_params.scale, q_params.zero_point);
267 }
268 
269 } // namespace
270 
271 namespace {
272 
273 class VulkanAPITest : public ::testing::Test {
274  public:
SetUp()275   void SetUp() override {
276     if (!at::is_vulkan_available()) {
277       GTEST_SKIP() << "Vulkan is not available";
278     }
279 #if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__)
280     at::native::vulkan::api::context()->reset_querypool();
281 #endif
282   }
283 
TearDown()284   void TearDown() override {
285 #if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__)
286     try {
287       at::native::vulkan::api::context()->querypool().extract_results();
288       at::native::vulkan::api::context()->querypool().print_results();
289     } catch (const std::exception& e) {
290       std::cout << "Could not get querypool results!"
291                 << " Reason: " << e.what() << std::endl;
292     }
293 #endif
294   }
295 };
296 
cpu_to_vulkan(at::Tensor in_cpu)297 at::Tensor cpu_to_vulkan(at::Tensor in_cpu) {
298   auto options = in_cpu.options();
299   if (options.dtype().toScalarType() == c10::ScalarType::QUInt8 ||
300       options.dtype().toScalarType() == c10::ScalarType::QInt8 ||
301       options.dtype().toScalarType() == c10::ScalarType::QInt32) {
302     auto ret = at::native::vulkan::ops::_empty_affine_quantized(
303         in_cpu.sizes(),
304         options.dtype().toScalarType(),
305         options.layout(),
306         options.device(),
307         options.pinned_memory(),
308         in_cpu.q_scale(),
309         in_cpu.q_zero_point(),
310         c10::MemoryFormat::Contiguous);
311     at::native::vulkan::ops::copy_(ret, in_cpu);
312     return ret;
313   } else {
314     auto ret = at::empty(in_cpu.sizes(), options);
315     at::native::vulkan::ops::copy_(ret, in_cpu);
316     return ret;
317   }
318 }
319 
vulkan_to_cpu(at::Tensor vulkan,at::Tensor in_cpu)320 at::Tensor vulkan_to_cpu(at::Tensor vulkan, at::Tensor in_cpu) {
321   auto q_options = in_cpu.options();
322   if (q_options.dtype().toScalarType() == c10::ScalarType::QUInt8 ||
323       q_options.dtype().toScalarType() == c10::ScalarType::QInt8 ||
324       q_options.dtype().toScalarType() == c10::ScalarType::QInt32) {
325     auto output = at::native::empty_affine_quantized(
326         in_cpu.sizes(),
327         q_options.dtype().toScalarType(),
328         q_options.layout(),
329         q_options.device(),
330         q_options.pinned_memory(),
331         in_cpu.q_scale(),
332         in_cpu.q_zero_point());
333     at::native::vulkan::ops::copy_(output, vulkan);
334     return output;
335   } else {
336     auto output = at::empty(in_cpu.sizes(), q_options);
337     at::native::vulkan::ops::copy_(output, vulkan);
338     return output;
339   }
340 }
341 
// A copy of a UniformParamsBuffer, mapped for reading, must contain the same
// field values as the struct the original buffer was created from.
TEST_F(VulkanAPITest, uniform_buffer_copy) {
  using namespace at::native::vulkan;

  struct TestStruct {
    int a;
    int b;
    int c;
  };

  TestStruct test_struct{4, 9, 10};

  api::UniformParamsBuffer params(api::context(), test_struct);
  api::UniformParamsBuffer params_copy = params;

  api::MemoryMap copy_mapping(
      params_copy.buffer(), api::MemoryAccessType::READ);

  TestStruct* test_copy_p = copy_mapping.template data<TestStruct>();
  ASSERT_NE(test_copy_p, nullptr);

  // ASSERT_EQ reports both operands on failure, unlike ASSERT_TRUE(x == y).
  ASSERT_EQ(test_copy_p->a, test_struct.a);
  ASSERT_EQ(test_copy_p->b, test_struct.b);
  ASSERT_EQ(test_copy_p->c, test_struct.c);
}
365 
// Round-trips float CPU tensors of rank 4, 3, 2 and 1 through a BUFFER-backed
// vTensor and checks the data comes back within almostEqual's tolerance.
TEST_F(VulkanAPITest, copy_to_buffer) {
  using namespace at::native::vulkan;

  const auto options = at::TensorOptions(at::kCPU).dtype(at::kFloat);
  std::array<at::Tensor, 4> test_tensors = {
      // 4D
      at::rand({7, 17, 134, 213}, options),
      // 3D
      at::rand({67, 134, 213}, options),
      // 2D
      at::rand({229, 213}, options),
      // 1D
      at::rand({1902}, options),
  };

  // Iterate by reference: no need to copy each Tensor handle.
  for (auto& in_cpu : test_tensors) {
    vTensor in_vk_copied = ops::to_vulkan(in_cpu, api::StorageType::BUFFER);
    at::Tensor out_copied = ops::from_vulkan(in_vk_copied);

    const auto check_copy = almostEqual(out_copied, in_cpu);

    if (!check_copy) {
      std::cout << "Copy failed on size " << in_cpu.sizes() << " with dtype "
                << in_cpu.dtype() << std::endl;
    }

    ASSERT_TRUE(check_copy);
  }
}
395 
// Round-trips a channels-last float CPU tensor through a BUFFER-backed
// vTensor and checks the data comes back within almostEqual's tolerance.
TEST_F(VulkanAPITest, copy_to_buffer_channels_last) {
  using namespace at::native::vulkan;

  at::TensorOptions options(at::kCPU);
  options = options.dtype(at::kFloat);

  std::array<at::Tensor, 1> test_tensors = {
      // 4D
      at::rand({7, 17, 134, 213}, options).to(at::MemoryFormat::ChannelsLast),
  };

  // Iterate by reference: no need to copy each Tensor handle.
  for (auto& in_cpu : test_tensors) {
    vTensor in_vk_copied = ops::to_vulkan(in_cpu, api::StorageType::BUFFER);
    at::Tensor out_copied = ops::from_vulkan(in_vk_copied);

    const auto check_copy = almostEqual(out_copied, in_cpu);

    if (!check_copy) {
      std::cout << "Copy failed on size " << in_cpu.sizes() << " with dtype "
                << in_cpu.dtype() << std::endl;
    }

    ASSERT_TRUE(check_copy);
  }
}
421 
422 // TODO: Fix vulkan to cpu on Android
// Quantizes a random tensor to QUInt8, copies it to Vulkan, requests read and
// write image access on the underlying vTensor, copies it back, and checks the
// integer representations match.
TEST_F(VulkanAPITest, DISABLED_support_vulkan) {
  const double scale = 0.1;
  const int zero_point = 10;

  auto in_cpu =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 12 -
      6;
  auto in_cpu_quantized = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);

  auto in_vulkan_quantized = cpu_to_vulkan(in_cpu_quantized);
  at::native::vulkan::api::PipelineBarrier pipeline_barrier{};
  at::native::vulkan::vTensor& v_self =
      at::native::vulkan::ops::convert(in_vulkan_quantized);
  // Test the quantized tensor's dtype: in_cpu is the float source tensor, so
  // its dtype can never be QUInt8 and this branch would be unreachable.
  if (in_cpu_quantized.scalar_type() == c10::kQUInt8) {
    v_self.image(
        pipeline_barrier,
        at::native::vulkan::api::PipelineStage::COMPUTE,
        at::native::vulkan::api::MemoryAccessType::READ);
    v_self.image(
        pipeline_barrier,
        at::native::vulkan::api::PipelineStage::COMPUTE,
        at::native::vulkan::api::MemoryAccessType::WRITE);
  }
  auto output = vulkan_to_cpu(in_vulkan_quantized, in_cpu_quantized);

  const auto expected = at::native::int_repr_quantized_cpu(in_cpu_quantized);
  const auto actual = at::native::int_repr_quantized_cpu(output);
  const auto check = almostEqual(expected, actual);

  if (!check) {
    showRtol(expected, actual);
  }

  ASSERT_TRUE(check);
}
460 
test_cpu_to_vulkan_and_vulkan_to_cpu(const at::IntArrayRef input_shape,const double scale,const int zero_point,const c10::ScalarType dtype=c10::ScalarType::QUInt8)461 void test_cpu_to_vulkan_and_vulkan_to_cpu(
462     const at::IntArrayRef input_shape,
463     const double scale,
464     const int zero_point,
465     const c10::ScalarType dtype = c10::ScalarType::QUInt8) {
466   // produce random quantized cpu tensor
467   auto in_cpu = produce_random_tensor(input_shape);
468   auto in_q_cpu = at::quantize_per_tensor(in_cpu, scale, zero_point, dtype);
469 
470   // copy quantized cpu tensor to vulkan
471   auto in_q_cpu_vk = cpu_to_vulkan(in_q_cpu);
472 
473   // copy quantized vulkan tensor to cpu
474   auto out_q_cpu = vulkan_to_cpu(in_q_cpu_vk, in_q_cpu);
475 
476   // check that the copy equals the original
477   const auto diff = at::native::int_repr_quantized_cpu(in_q_cpu) -
478       at::native::int_repr_quantized_cpu(out_q_cpu);
479 
480   const int error = diff.abs().max().item<int>();
481 
482   const auto check = (error == 0);
483 
484   if (!check) {
485     std::cout << "Copy to vulkan and back to cpu failed with input shape: "
486               << input_shape << " scale: " << scale
487               << " and zero point: " << zero_point << std::endl;
488     std::cout << "Error: " << error << std::endl;
489   }
490 
491   ASSERT_TRUE(check);
492 }
493 
test_cpu_to_vulkan_and_vulkan_to_cpu_random(const c10::ScalarType dtype)494 void test_cpu_to_vulkan_and_vulkan_to_cpu_random(const c10::ScalarType dtype) {
495   const double scale = produce_random_scale();
496   const int zero_point = produce_random_zero_point(dtype);
497   const at::IntArrayRef tensor_shape = {
498       rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)};
499   test_cpu_to_vulkan_and_vulkan_to_cpu(tensor_shape, scale, zero_point, dtype);
500 }
501 
502 // TODO: Fix vulkan to cpu on Android
// QUInt8 CPU -> Vulkan -> CPU round trip: fixed cases, then 20 random ones.
TEST_F(VulkanAPITest, DISABLED_cpu_to_vulkan_and_vulkan_to_cpu_quint8) {
  const c10::ScalarType dtype = c10::ScalarType::QUInt8;

  struct RoundTripCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };
  const std::vector<RoundTripCase> cases = {
      {{1, 1, 1, 1}, 0.13, 21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, 120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, 10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, 15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, 10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, 10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, 43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, 19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, 19},
      {{11, 17, 25, 29}, 0.027, 89},
  };
  for (const auto& c : cases) {
    test_cpu_to_vulkan_and_vulkan_to_cpu(c.shape, c.scale, c.zero_point, dtype);
  }

  for (int trial = 0; trial < 20; ++trial) {
    test_cpu_to_vulkan_and_vulkan_to_cpu_random(dtype);
  }
}
528 
529 // TODO: Fix vulkan to cpu on Android
// QInt8 CPU -> Vulkan -> CPU round trip: fixed cases, then 20 random ones.
TEST_F(VulkanAPITest, DISABLED_cpu_to_vulkan_and_vulkan_to_cpu_qint8) {
  const c10::ScalarType dtype = c10::ScalarType::QInt8;

  struct RoundTripCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };
  const std::vector<RoundTripCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, -120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, -10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, -15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, -10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, -10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, -43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, -19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, -19},
      {{11, 17, 25, 29}, 0.027, 89},
  };
  for (const auto& c : cases) {
    test_cpu_to_vulkan_and_vulkan_to_cpu(c.shape, c.scale, c.zero_point, dtype);
  }

  for (int trial = 0; trial < 20; ++trial) {
    test_cpu_to_vulkan_and_vulkan_to_cpu_random(dtype);
  }
}
555 
556 // TODO: Fix vulkan to cpu on Android
// QInt32 CPU -> Vulkan -> CPU round trip: fixed cases, then 20 random ones.
TEST_F(VulkanAPITest, DISABLED_cpu_to_vulkan_and_vulkan_to_cpu_qint32) {
  const c10::ScalarType dtype = c10::ScalarType::QInt32;

  struct RoundTripCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };
  const std::vector<RoundTripCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21123},
      {{1, 1, 1, 4}, 0.339, 8734},
      {{1, 1, 4, 1}, 0.228, -12023},
      {{1, 1, 7, 7}, 0.338, 8723},
      {{1, 1, 8, 8}, 0.193, -1023},
      {{3, 5, 8, 8}, 0.0449, 972},
      {{1, 1, 11, 17}, 0.073, -15},
      {{1, 1, 12, 17}, 0.1572, 102},
      {{3, 5, 12, 17}, 0.147, -156},
      {{1, 1, 17, 12}, 0.129, 10448},
      {{2, 4, 17, 12}, 0.137, -10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, -43267},
      {{3, 5, 10, 15}, 0.1243, 19},
      {{4, 4, 9, 17}, 0.1889, -19784},
      {{3, 5, 25, 29}, 0.1345, 196},
      {{4, 4, 25, 29}, 0.129, -19489},
      {{11, 17, 25, 29}, 0.027, 89},
  };
  for (const auto& c : cases) {
    test_cpu_to_vulkan_and_vulkan_to_cpu(c.shape, c.scale, c.zero_point, dtype);
  }

  for (int trial = 0; trial < 20; ++trial) {
    test_cpu_to_vulkan_and_vulkan_to_cpu_random(dtype);
  }
}
582 
test_cpu_to_vulkan_and_dequantize(const at::IntArrayRef input_shape,const double scale,const int zero_point,const c10::ScalarType dtype=c10::ScalarType::QUInt8)583 void test_cpu_to_vulkan_and_dequantize(
584     const at::IntArrayRef input_shape,
585     const double scale,
586     const int zero_point,
587     const c10::ScalarType dtype = c10::ScalarType::QUInt8) {
588   // produce random quantized cpu tensor
589   auto in_cpu = produce_random_tensor(input_shape);
590   auto in_q_cpu = at::quantize_per_tensor(in_cpu, scale, zero_point, dtype);
591 
592   // copy quantized cpu tensor to vulkan
593   auto in_q_cpu_vk = cpu_to_vulkan(in_q_cpu);
594 
595   // dequantize tensors
596   const auto out_cpu_deq = at::dequantize(in_q_cpu);
597   const auto out_vk_deq = at::dequantize(in_q_cpu_vk);
598   const auto out_vk_deq_cpu = out_vk_deq.cpu();
599 
600   // check dequantized tensors are equal
601   const auto check = almostEqual(out_cpu_deq, out_vk_deq_cpu);
602 
603   if (!check) {
604     const auto error =
605         at::abs(out_vk_deq_cpu - out_cpu_deq).max().item<float>();
606     std::cout << "Copy cpu to vulkan and dequantize failed with input shape: "
607               << input_shape << " scale: " << scale
608               << " and zero point: " << zero_point << std::endl;
609     std::cout << "Error: " << error << std::endl;
610   }
611   ASSERT_TRUE(check);
612 }
613 
test_cpu_to_vulkan_and_dequantize_random(const c10::ScalarType dtype)614 void test_cpu_to_vulkan_and_dequantize_random(const c10::ScalarType dtype) {
615   const double scale = produce_random_scale();
616   const int zero_point = produce_random_zero_point(dtype);
617   const at::IntArrayRef tensor_shape = {
618       rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)};
619   test_cpu_to_vulkan_and_dequantize(tensor_shape, scale, zero_point, dtype);
620 }
621 
// QUInt8 copy-to-Vulkan + dequantize: fixed cases, then 20 random ones.
TEST_F(VulkanAPITest, cpu_to_vulkan_and_dequantize_quint8) {
  const c10::ScalarType dtype = c10::ScalarType::QUInt8;

  struct DequantCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };
  const std::vector<DequantCase> cases = {
      {{1, 1, 1, 1}, 0.13, 21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, 120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, 10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, 15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, 10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, 10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, 43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, 19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, 19},
      {{11, 17, 25, 29}, 0.027, 89},
  };
  for (const auto& c : cases) {
    test_cpu_to_vulkan_and_dequantize(c.shape, c.scale, c.zero_point, dtype);
  }

  for (int trial = 0; trial < 20; ++trial) {
    test_cpu_to_vulkan_and_dequantize_random(dtype);
  }
}
647 
// QInt8 copy-to-Vulkan + dequantize: fixed cases, then 20 random ones.
TEST_F(VulkanAPITest, cpu_to_vulkan_and_dequantize_qint8) {
  const c10::ScalarType dtype = c10::ScalarType::QInt8;

  struct DequantCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };
  const std::vector<DequantCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, -120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, -10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, -15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, -10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, -10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, -43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, -19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, -19},
      {{11, 17, 25, 29}, 0.027, 89},
  };
  for (const auto& c : cases) {
    test_cpu_to_vulkan_and_dequantize(c.shape, c.scale, c.zero_point, dtype);
  }

  for (int trial = 0; trial < 20; ++trial) {
    test_cpu_to_vulkan_and_dequantize_random(dtype);
  }
}
673 
// QInt32 copy-to-Vulkan + dequantize: fixed cases, then 20 random ones.
TEST_F(VulkanAPITest, cpu_to_vulkan_and_dequantize_qint32) {
  const c10::ScalarType dtype = c10::ScalarType::QInt32;

  struct DequantCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };
  const std::vector<DequantCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21123},
      {{1, 1, 1, 4}, 0.339, 8734},
      {{1, 1, 4, 1}, 0.228, -12023},
      {{1, 1, 7, 7}, 0.338, 8723},
      {{1, 1, 8, 8}, 0.193, -1023},
      {{3, 5, 8, 8}, 0.0449, 972},
      {{1, 1, 11, 17}, 0.073, -15},
      {{1, 1, 12, 17}, 0.1572, 102},
      {{3, 5, 12, 17}, 0.147, -156},
      {{1, 1, 17, 12}, 0.129, 10448},
      {{2, 4, 17, 12}, 0.137, -10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, -43267},
      {{3, 5, 10, 15}, 0.1243, 19},
      {{4, 4, 9, 17}, 0.1889, -19784},
      {{3, 5, 25, 29}, 0.1345, 196},
      {{4, 4, 25, 29}, 0.129, -19489},
      {{11, 17, 25, 29}, 0.027, 89},
  };
  for (const auto& c : cases) {
    test_cpu_to_vulkan_and_dequantize(c.shape, c.scale, c.zero_point, dtype);
  }

  for (int trial = 0; trial < 20; ++trial) {
    test_cpu_to_vulkan_and_dequantize_random(dtype);
  }
}
699 
700 // TODO: Fix vulkan to cpu on Android
// Quantizes the same float tensor on CPU and with the Vulkan
// quantize_per_tensor op, and checks the integer representations differ by at
// most one unit.
TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor) {
  const auto in_cpu =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();

  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);

  auto output_for_quantized_vulkan = vulkan_to_cpu(out_vulkan, out_cpu);

  // Allow a difference of one quantized unit (rounding). at::allclose's third
  // positional argument is rtol, so passing 1 there would permit a 100%
  // relative difference rather than one unit. Compare the absolute difference
  // instead, widened to int64 so uint8 subtraction cannot wrap.
  const int64_t max_allowed_diff = 1;
  const auto diff =
      at::native::int_repr_quantized_cpu(out_cpu).to(at::kLong) -
      at::native::int_repr_quantized_cpu(output_for_quantized_vulkan)
          .to(at::kLong);
  const int64_t max_diff = diff.abs().max().item<int64_t>();
  const auto check = (max_diff <= max_allowed_diff);

  if (!check) {
    std::cout << "Max Diff allowed: " << max_allowed_diff << std::endl;
    std::cout << "Max Diff found: " << max_diff << std::endl;
  }

  ASSERT_TRUE(check);
}
728 
test_quantize_per_tensor_and_vulkan_to_cpu(const at::IntArrayRef input_shape,const double input_scale,const int input_zero_point,const c10::ScalarType dtype=c10::ScalarType::QUInt8,const int tolerance=1)729 void test_quantize_per_tensor_and_vulkan_to_cpu(
730     const at::IntArrayRef input_shape,
731     const double input_scale,
732     const int input_zero_point,
733     const c10::ScalarType dtype = c10::ScalarType::QUInt8,
734     const int tolerance = 1) {
735   // tolerance = 1, to allow for precision differences after dividing by random
736   // scale which could result on a difference of 1 unit in the quantized result
737 
738   at::Tensor input = produce_random_tensor(input_shape);
739 
740   // quantize tensor
741   at::Tensor out_q_cpu =
742       at::quantize_per_tensor(input, input_scale, input_zero_point, dtype);
743 
744   at::Tensor out_q_vk = at::quantize_per_tensor(
745       input.vulkan(), input_scale, input_zero_point, dtype);
746 
747   // copy vulkan tensor to cpu
748   at::Tensor out_q_vk_cpu = vulkan_to_cpu(out_q_vk, out_q_cpu);
749 
750   const auto diff = at::native::int_repr_quantized_cpu(out_q_vk_cpu) -
751       at::native::int_repr_quantized_cpu(out_q_cpu);
752 
753   const int error = diff.abs().max().item<int>();
754 
755   const auto check = (error <= tolerance);
756 
757   if (!check) {
758     std::cout << "Quantize and copy to cpu failed with input shape: "
759               << input_shape << " scale: " << input_scale
760               << " and zero point: " << input_zero_point << std::endl;
761     std::cout << "Error: " << error << std::endl;
762   }
763 
764   ASSERT_TRUE(check);
765 }
766 
test_quantize_per_tensor_and_vulkan_to_cpu_random(const c10::ScalarType dtype)767 void test_quantize_per_tensor_and_vulkan_to_cpu_random(
768     const c10::ScalarType dtype) {
769   const double scale = produce_random_scale();
770   const int zero_point = produce_random_zero_point(dtype);
771   const at::IntArrayRef tensor_shape = {
772       rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)};
773   test_quantize_per_tensor_and_vulkan_to_cpu(
774       tensor_shape, scale, zero_point, dtype);
775 }
776 
777 // TODO: Fix vulkan to cpu on Android
TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor_and_vulkan_to_cpu_quint8) {
  // Fixed (shape, scale, zero point) cases for QUInt8 quantization followed by
  // Vulkan -> CPU copy. 20 randomized configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QUInt8;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, 21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, 120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, 10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, 15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, 10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, 10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, 43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, 19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, 19},
      {{11, 17, 25, 29}, 0.027, 89},
      {{3, 16, 77, 54}, 0.204173, 229},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_vulkan_to_cpu(
        tc.shape, tc.scale, tc.zero_point, dtype);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_vulkan_to_cpu_random(dtype);
  }
}
807 
808 // TODO: Fix vulkan to cpu on Android
// Compares Vulkan and CPU per-tensor quantization to QInt8 (with Vulkan -> CPU
// copy) on fixed shapes/qparams, then on 20 randomized configurations.
// Every zero point must lie in the QInt8 range [-128, 127]. The CPU reference
// at::quantize_per_tensor rejects zero points outside the dtype's range.
TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor_and_vulkan_to_cpu_qint8) {
  const c10::ScalarType dtype = c10::ScalarType::QInt8;
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, -21, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 4}, 0.3, 87, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 4, 1}, 0.2, -120, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 7, 7}, 0.3, 87, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 8, 8}, 0.1, -10, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 8, 8}, 0.04, 97, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 11, 17}, 0.07, -15, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1, 10, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 12, 17}, 0.1, -10, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 17, 12}, 0.1, 10, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({2, 4, 17, 12}, 0.1, -10, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu(
      {1, 1, 10, 14}, 0.0001, 101, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, -43, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1, 19, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1, -19, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1, 19, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 25, 29}, 0.1, -19, dtype);
  test_quantize_per_tensor_and_vulkan_to_cpu(
      {11, 17, 25, 29}, 0.027, 89, dtype);
  // Zero point 101 = 229 - 128: the QUInt8 test's zero point shifted into the
  // QInt8 range. 229 itself does not fit in int8.
  test_quantize_per_tensor_and_vulkan_to_cpu(
      {3, 16, 77, 54}, 0.204173, 101, dtype);

  for (int i = 0; i < 20; i += 1) {
    test_quantize_per_tensor_and_vulkan_to_cpu_random(dtype);
  }
}
838 
839 // TODO: Fix vulkan to cpu on Android
TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor_and_vulkan_to_cpu_qint32) {
  // Fixed (shape, scale, zero point) cases for QInt32 quantization followed by
  // Vulkan -> CPU copy. Each case uses the helper's default tolerance of one
  // quantized unit. 20 randomized configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QInt32;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21123},
      {{1, 1, 1, 4}, 0.339, 8734},
      {{1, 1, 4, 1}, 0.228, -12023},
      {{1, 1, 7, 7}, 0.338, 8723},
      {{1, 1, 8, 8}, 0.193, -1023},
      {{3, 5, 8, 8}, 0.0449, 972},
      {{1, 1, 11, 17}, 0.073, -15},
      {{1, 1, 12, 17}, 0.1572, 102},
      {{3, 5, 12, 17}, 0.147, -156},
      {{1, 1, 17, 12}, 0.129, 10448},
      {{2, 4, 17, 12}, 0.137, -10},
      {{1, 1, 10, 14}, 0.0001, 101},
      {{3, 5, 10, 14}, 0.009, -43267},
      {{3, 5, 10, 15}, 0.1243, 19},
      {{4, 4, 9, 17}, 0.1889, -19784},
      {{3, 5, 25, 29}, 0.1345, 196},
      {{4, 4, 25, 29}, 0.129, -19489},
      {{11, 17, 25, 29}, 0.027, 89},
      {{3, 16, 77, 54}, 0.204173, 229},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_vulkan_to_cpu(
        tc.shape, tc.scale, tc.zero_point, dtype);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_vulkan_to_cpu_random(dtype);
  }
}
877 
// Quantizes a random float tensor (values in [0, 6)) to QUInt8 on CPU and on
// Vulkan, then dequantizes both. The Vulkan result is checked against the
// original input and against the CPU dequantized result.
TEST_F(VulkanAPITest, quantize_dequantize) {
  const auto in_cpu =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();

  const double scale = 0.1;
  const int zero_point = 10;
  // quantize tensors
  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  // dequantize tensors
  const auto out_cpu_deq = at::dequantize(out_cpu);
  const auto out_vulkan_deq = at::native::vulkan::ops::dequantize(out_vulkan);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu);

  // Use a purely absolute tolerance, as the quantized_add tests do.
  // at::allclose checks |a - b| <= atol + rtol * |b|. With rtol = 1, any
  // element within its own magnitude (plus atol) passed, so for inputs in
  // [0, 6) the check could hardly fail.
  const float rtol = 0;
  const float atol = 0.5;
  const auto check =
      at::allclose(in_cpu, output_for_dequantized_vulkan, rtol, atol);

  if (!check) {
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);

  const auto check_two =
      at::allclose(out_cpu_deq, output_for_dequantized_vulkan, rtol, atol);

  if (!check_two) {
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check_two);
}
915 
test_quantize_per_tensor_and_dequantize(const at::IntArrayRef input_shape,const double input_scale,const int input_zero_point,const c10::ScalarType dtype=c10::ScalarType::QUInt8,bool use_qparams=false)916 void test_quantize_per_tensor_and_dequantize(
917     const at::IntArrayRef input_shape,
918     const double input_scale,
919     const int input_zero_point,
920     const c10::ScalarType dtype = c10::ScalarType::QUInt8,
921     bool use_qparams = false) {
922   at::Tensor input = produce_random_tensor(input_shape);
923 
924   at::Tensor input_scale_qparam = at::empty({1});
925   input_scale_qparam[0] = input_scale;
926   at::Tensor input_zero_point_qparam = at::empty({1});
927   input_zero_point_qparam[0] = input_zero_point;
928 
929   // quantize tensors
930   at::Tensor out_q_cpu = use_qparams
931       ? at::quantize_per_tensor(
932             input, input_scale_qparam, input_zero_point_qparam, dtype)
933       : at::quantize_per_tensor(input, input_scale, input_zero_point, dtype);
934   at::Tensor out_q_vk = use_qparams
935       ? at::quantize_per_tensor(
936             input.vulkan(), input_scale_qparam, input_zero_point_qparam, dtype)
937       : at::quantize_per_tensor(
938             input.vulkan(), input_scale, input_zero_point, dtype);
939 
940   // dequantize tensors
941   const auto out_cpu_deq = at::dequantize(out_q_cpu);
942   const auto out_vk_deq = at::dequantize(out_q_vk);
943   const auto out_vk_deq_cpu = out_vk_deq.cpu();
944 
945   // check dequantized tensor are equal
946   const float tolerance = safe_downcast<float>(input_scale);
947   // tolerated error = scale, to allow for precision differences after dividing
948   // by random scale, which could result on a difference of 1 unit in the
949   // quantized result.
950   const auto check = almostEqual(out_cpu_deq, out_vk_deq_cpu, tolerance);
951 
952   if (!check) {
953     const auto error =
954         at::abs(out_vk_deq_cpu - out_cpu_deq).max().item<float>();
955     std::cout << "Quantize and Dequantize failed with input shape: "
956               << input_shape << " scale: " << input_scale
957               << " and zero point: " << input_zero_point << std::endl;
958     std::cout << "Error: " << error << std::endl;
959   }
960   ASSERT_TRUE(check);
961 }
962 
test_quantize_per_tensor_and_dequantize_random(const c10::ScalarType dtype,bool use_qparams=false)963 void test_quantize_per_tensor_and_dequantize_random(
964     const c10::ScalarType dtype,
965     bool use_qparams = false) {
966   const double scale = produce_random_scale();
967   const int zero_point = produce_random_zero_point(dtype);
968   const at::IntArrayRef tensor_shape = {
969       rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)};
970   test_quantize_per_tensor_and_dequantize(
971       tensor_shape, scale, zero_point, dtype, use_qparams);
972 }
973 
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_quint8) {
  // Fixed (shape, scale, zero point) cases for a QUInt8 quantize/dequantize
  // round trip with scalar qparams. 20 randomized configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QUInt8;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, 21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, 120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, 10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, 15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, 10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, 10},
      {{1, 1, 10, 14}, 0.001, 101},
      {{3, 5, 10, 14}, 0.009, 43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, 19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, 19},
      {{11, 17, 25, 29}, 0.027, 89},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_dequantize(
        tc.shape, tc.scale, tc.zero_point, dtype);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_dequantize_random(dtype);
  }
}
999 
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_quint8_qparams) {
  // Same QUInt8 round-trip cases as the scalar-qparams test, but scale and
  // zero point are passed as tensors (use_qparams = true). 20 randomized
  // configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QUInt8;
  const bool use_qparams = true;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, 21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, 120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, 10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, 15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, 10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, 10},
      {{1, 1, 10, 14}, 0.001, 101},
      {{3, 5, 10, 14}, 0.009, 43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, 19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, 19},
      {{11, 17, 25, 29}, 0.027, 89},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_dequantize(
        tc.shape, tc.scale, tc.zero_point, dtype, use_qparams);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_dequantize_random(dtype, use_qparams);
  }
}
1029 
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint8) {
  // Fixed (shape, scale, zero point) cases for a QInt8 quantize/dequantize
  // round trip with scalar qparams. 20 randomized configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QInt8;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, -120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, -10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, -15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, -10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, -10},
      {{1, 1, 10, 14}, 0.001, 101},
      {{3, 5, 10, 14}, 0.009, -43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, -19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, -19},
      {{11, 17, 25, 29}, 0.027, 89},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_dequantize(
        tc.shape, tc.scale, tc.zero_point, dtype);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_dequantize_random(dtype);
  }
}
1055 
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint8_qparams) {
  // Same QInt8 round-trip cases as the scalar-qparams test, but scale and
  // zero point are passed as tensors (use_qparams = true). 20 randomized
  // configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QInt8;
  const bool use_qparams = true;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21},
      {{1, 1, 1, 4}, 0.3, 87},
      {{1, 1, 4, 1}, 0.2, -120},
      {{1, 1, 7, 7}, 0.3, 87},
      {{1, 1, 8, 8}, 0.1, -10},
      {{3, 5, 8, 8}, 0.04, 97},
      {{1, 1, 11, 17}, 0.07, -15},
      {{1, 1, 12, 17}, 0.1, 10},
      {{3, 5, 12, 17}, 0.1, -10},
      {{1, 1, 17, 12}, 0.1, 10},
      {{2, 4, 17, 12}, 0.1, -10},
      {{1, 1, 10, 14}, 0.001, 101},
      {{3, 5, 10, 14}, 0.009, -43},
      {{3, 5, 10, 15}, 0.1, 19},
      {{4, 4, 9, 17}, 0.1, -19},
      {{3, 5, 25, 29}, 0.1, 19},
      {{4, 4, 25, 29}, 0.1, -19},
      {{11, 17, 25, 29}, 0.027, 89},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_dequantize(
        tc.shape, tc.scale, tc.zero_point, dtype, use_qparams);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_dequantize_random(dtype, use_qparams);
  }
}
1088 
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint32) {
  // Fixed (shape, scale, zero point) cases for a QInt32 quantize/dequantize
  // round trip with scalar qparams. 20 randomized configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QInt32;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21123},
      {{1, 1, 1, 4}, 0.339, 8734},
      {{1, 1, 4, 1}, 0.228, -12023},
      {{1, 1, 7, 7}, 0.338, 8723},
      {{1, 1, 8, 8}, 0.193, -1023},
      {{3, 5, 8, 8}, 0.0449, 972},
      {{1, 1, 11, 17}, 0.073, -15},
      {{1, 1, 12, 17}, 0.1572, 102},
      {{3, 5, 12, 17}, 0.147, -156},
      {{1, 1, 17, 12}, 0.129, 10448},
      {{2, 4, 17, 12}, 0.137, -10},
      {{1, 1, 10, 14}, 0.001, 101},
      {{3, 5, 10, 14}, 0.009, -43267},
      {{3, 5, 10, 15}, 0.1243, 19},
      {{4, 4, 9, 17}, 0.1889, -19784},
      {{3, 5, 25, 29}, 0.1345, 196},
      {{4, 4, 25, 29}, 0.129, -19489},
      {{11, 17, 25, 29}, 0.027, 89},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_dequantize(
        tc.shape, tc.scale, tc.zero_point, dtype);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_dequantize_random(dtype);
  }
}
1114 
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint32_qparams) {
  // Same QInt32 round-trip cases as the scalar-qparams test, but scale and
  // zero point are passed as tensors (use_qparams = true). 20 randomized
  // configurations follow them.
  struct QuantizeCase {
    std::vector<int64_t> shape;
    double scale;
    int zero_point;
  };

  const c10::ScalarType dtype = c10::ScalarType::QInt32;
  const bool use_qparams = true;
  const std::vector<QuantizeCase> cases = {
      {{1, 1, 1, 1}, 0.13, -21123},
      {{1, 1, 1, 4}, 0.339, 8734},
      {{1, 1, 4, 1}, 0.228, -12023},
      {{1, 1, 7, 7}, 0.338, 8723},
      {{1, 1, 8, 8}, 0.193, -1023},
      {{3, 5, 8, 8}, 0.0449, 972},
      {{1, 1, 11, 17}, 0.073, -15},
      {{1, 1, 12, 17}, 0.1572, 102},
      {{3, 5, 12, 17}, 0.147, -156},
      {{1, 1, 17, 12}, 0.129, 10448},
      {{2, 4, 17, 12}, 0.137, -10},
      {{1, 1, 10, 14}, 0.001, 101},
      {{3, 5, 10, 14}, 0.009, -43267},
      {{3, 5, 10, 15}, 0.1243, 19},
      {{4, 4, 9, 17}, 0.1889, -19784},
      {{3, 5, 25, 29}, 0.1345, 196},
      {{4, 4, 25, 29}, 0.129, -19489},
      {{11, 17, 25, 29}, 0.027, 89},
  };

  for (const auto& tc : cases) {
    test_quantize_per_tensor_and_dequantize(
        tc.shape, tc.scale, tc.zero_point, dtype, use_qparams);
  }

  for (int run = 0; run < 20; ++run) {
    test_quantize_per_tensor_and_dequantize_random(dtype, use_qparams);
  }
}
1158 
// Quantizes two random float tensors to QUInt8 (scale 0.1, zero point 10) and
// adds them with the CPU `quantized::add` op and with the Vulkan quantized_add
// op (output scale 0.15, zero point 15). The dequantized results must agree
// within an absolute tolerance.
TEST_F(VulkanAPITest, quantized_add) {
  const auto in_cpu =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_added_tensors = callOpByName(
      "quantized::add", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_added_tensors);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      at::dequantize(reg_added_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the allowed difference is governed entirely by atol.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1204 
// Broadcasting variant of quantized_add: adds a {2, 13, 1, 27} and a
// {2, 13, 32, 1} QUInt8 tensor on CPU (`quantized::add`) and on Vulkan, and
// compares the dequantized results within an absolute tolerance.
TEST_F(VulkanAPITest, quantized_add_broadcast) {
  const auto in_cpu =
      at::rand({2, 13, 1, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({2, 13, 32, 1}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_added_tensors = callOpByName(
      "quantized::add", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  // in_cpu3 has the broadcast result's shape. It is only passed to
  // vulkan_to_cpu, presumably as a shape/dtype template for the copy
  // (NOTE(review): confirm); its values are not compared.
  const auto in_cpu3 =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_added_tensors);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      at::dequantize(reg_added_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the allowed difference is governed entirely by atol.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1252 
// Broadcasting variant of quantized_add: adds a {2, 12, 32, 27} tensor and a
// rank-3 {12, 1, 1} QUInt8 tensor on CPU (`quantized::add`) and on Vulkan, and
// compares the dequantized results within an absolute tolerance. Returns early
// when Vulkan is unavailable.
TEST_F(VulkanAPITest, quantized_add_broadcast1) {
  if (!at::is_vulkan_available()) {
    return;
  }

  const auto in_cpu =
      at::rand({2, 12, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({12, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_added_tensors = callOpByName(
      "quantized::add", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  // in_cpu3 has the broadcast result's shape. It is only passed to
  // vulkan_to_cpu, presumably as a shape/dtype template for the copy
  // (NOTE(review): confirm); its values are not compared.
  const auto in_cpu3 =
      at::rand({2, 12, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_added_tensors);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      at::dequantize(reg_added_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the allowed difference is governed entirely by atol.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1304 
// Broadcasting variant of quantized_add on 2-D tensors: adds a {32, 1} and a
// {1, 27} QUInt8 tensor on CPU (`quantized::add`) and on Vulkan, and compares
// the dequantized results within an absolute tolerance. Returns early when
// Vulkan is unavailable.
TEST_F(VulkanAPITest, quantized_add_broadcast2) {
  if (!at::is_vulkan_available()) {
    return;
  }

  const auto in_cpu =
      at::rand({32, 1}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({1, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_added_tensors = callOpByName(
      "quantized::add", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  // in_cpu3 has the broadcast result's shape. It is only passed to
  // vulkan_to_cpu, presumably as a shape/dtype template for the copy
  // (NOTE(review): confirm); its values are not compared.
  const auto in_cpu3 =
      at::rand({32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_added_tensors);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      at::dequantize(reg_added_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the allowed difference is governed entirely by atol.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1356 
// Broadcast add of a {32, 24} tensor and a single-element {1} tensor: the
// Vulkan quantized_add result is compared against CPU quantized::add, with
// both outputs dequantized before comparison.
TEST_F(VulkanAPITest, quantized_add_broadcast3) {
  if (!at::is_vulkan_available()) {
    return;
  }

  const auto in_cpu =
      at::rand({32, 24}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({1}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  // Both inputs are quantized with the same parameters.
  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  // Quantization parameters of the sum.
  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_added_tensors = callOpByName(
      "quantized::add", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  // Template tensor with the result shape {32, 24}; presumably only its
  // shape/options are used by vulkan_to_cpu -- TODO confirm.
  const auto in_cpu3 =
      at::rand({32, 24}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_added_tensors);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      at::dequantize(reg_added_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the absolute tolerance is the bound actually enforced.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1408 
// Element-wise add of two same-shaped tensors that are quantized with
// *different* parameters; the output uses a third set of parameters. The
// Vulkan result is compared against CPU quantized::add after dequantizing.
TEST_F(VulkanAPITest, quantized_add_dif_params) {
  const auto in_cpu =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  // First input parameters.
  const double scale = 0.1;
  const int zero_point = 10;
  // Second input parameters.
  const double scale2 = 0.2;
  const int zero_point2 = 20;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale2, zero_point2, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale2, zero_point2, c10::ScalarType::QUInt8);

  // Quantization parameters of the sum.
  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_added_tensors = callOpByName(
      "quantized::add", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_added_tensors);
  // in_cpu2 has the same shape as the result and serves as the template.
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      at::dequantize(reg_added_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the absolute tolerance is the bound actually enforced.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1455 
test_conv2d(bool bias_quantized)1456 void test_conv2d(bool bias_quantized) {
1457   constexpr int64_t groups = 1;
1458   constexpr std::array<int64_t, 2u> stride{2, 2};
1459   constexpr std::array<int64_t, 2u> padding{1, 1};
1460   // TODO: Support conv2d with dilation != 1
1461   constexpr std::array<int64_t, 2u> dilation{1, 1};
1462 
1463   constexpr struct {
1464     uint32_t batches;
1465     uint32_t channels;
1466     uint32_t width;
1467     uint32_t height;
1468 
1469     std::array<int64_t, 4u> size() const {
1470       return {
1471           batches,
1472           channels,
1473           width,
1474           height,
1475       };
1476     }
1477   } input{1, 3, 8, 8};
1478 
1479   constexpr struct {
1480     uint32_t output_channels;
1481     uint32_t input_channels;
1482     uint32_t width;
1483     uint32_t height;
1484 
1485     std::array<int64_t, 4u> size() const {
1486       return {
1487           output_channels,
1488           input_channels,
1489           width,
1490           height,
1491       };
1492     }
1493   } weights{1, input.channels, 3, 3};
1494 
1495   float r1 = 0.1;
1496   float r2 = 0.7;
1497   const auto input_cpu = (r1 - r2) *
1498           at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)) +
1499       r2;
1500   const auto weights_cpu = (r1 - r2) *
1501           at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)) +
1502       r2;
1503   const auto bias_cpu = (r1 - r2) *
1504           at::rand({weights.output_channels},
1505                    at::device(at::kCPU).dtype(at::kFloat)) +
1506       r2;
1507 
1508   const double w_scale = 0.1;
1509   const int w_zero_point = 10;
1510 
1511   const double b_scale = 0.1;
1512   const int b_zero_point = 10;
1513 
1514   const auto weight_q = at::quantize_per_tensor(
1515       weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8);
1516   const auto bias_q = at::quantize_per_tensor(
1517       bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8);
1518 
1519   const auto output_cpu = at::conv2d(
1520       input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups);
1521 
1522   const double scale = 0.10;
1523   const int zero_point = 10;
1524   const auto shape_match =
1525       at::rand({1, 1, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
1526   const auto in_vulkan = input_cpu.vulkan();
1527   const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
1528       in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
1529 
1530   const double scale2 = 0.15;
1531   const int zero_point2 = 15;
1532   const auto output_vulkan = at::native::vulkan::ops::quantized_conv2d(
1533       out_vulkan,
1534       weight_q,
1535       bias_quantized ? bias_q : bias_cpu,
1536       stride,
1537       padding,
1538       dilation,
1539       groups,
1540       scale2,
1541       zero_point2);
1542 
1543   const auto out_vulkan_deq =
1544       at::native::vulkan::ops::dequantize(output_vulkan);
1545   auto output_for_dequantized_vulkan =
1546       vulkan_to_cpu(out_vulkan_deq, shape_match);
1547 
1548   float rtol = 0;
1549   float atol = 1.5;
1550   const auto check =
1551       at::allclose(output_cpu, output_for_dequantized_vulkan, rtol, atol);
1552 
1553   if (!check) {
1554     std::cout << "Max Diff allowed: " << rtol << std::endl;
1555   }
1556 
1557   ASSERT_TRUE(check);
1558 }
1559 
// Runs the quantized conv2d comparison twice: first with a float bias, then
// with a quantized bias.
TEST_F(VulkanAPITest, conv2d) {
  for (const bool quantize_bias : {false, true}) {
    test_conv2d(quantize_bias);
  }
}
1564 
TEST_F(VulkanAPITest,conv2d_pw)1565 TEST_F(VulkanAPITest, conv2d_pw) {
1566   constexpr int64_t groups = 1;
1567   constexpr std::array<int64_t, 2u> stride{1, 1};
1568   constexpr std::array<int64_t, 2u> padding{0, 0};
1569   constexpr std::array<int64_t, 2u> dilation{1, 1};
1570 
1571   constexpr struct {
1572     uint32_t batches;
1573     uint32_t channels;
1574     uint32_t width;
1575     uint32_t height;
1576 
1577     std::array<int64_t, 4u> size() const {
1578       return {
1579           batches,
1580           channels,
1581           width,
1582           height,
1583       };
1584     }
1585   } input{1, 17, 127, 397};
1586 
1587   constexpr struct {
1588     uint32_t output_channels;
1589     uint32_t input_channels;
1590     uint32_t width;
1591     uint32_t height;
1592 
1593     std::array<int64_t, 4u> size() const {
1594       return {
1595           output_channels,
1596           input_channels,
1597           width,
1598           height,
1599       };
1600     }
1601   } weights{29, input.channels, 1, 1};
1602 
1603   float r1 = 0.1;
1604   float r2 = 0.7;
1605   const auto input_cpu = (r1 - r2) *
1606           at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)) +
1607       r2;
1608   const auto weights_cpu = (r1 - r2) *
1609           at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)) +
1610       r2;
1611   const auto bias_cpu = (r1 - r2) *
1612           at::rand({weights.output_channels},
1613                    at::device(at::kCPU).dtype(at::kFloat)) +
1614       r2;
1615 
1616   const double w_scale = 0.1;
1617   const int w_zero_point = 10;
1618 
1619   const double b_scale = 0.1;
1620   const int b_zero_point = 10;
1621 
1622   const auto weight_q = at::quantize_per_tensor(
1623       weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8);
1624   const auto bias_q = at::quantize_per_tensor(
1625       bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8);
1626 
1627   const auto output_cpu = at::conv2d(
1628       input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups);
1629 
1630   const double scale = 0.10;
1631   const int zero_point = 10;
1632   const auto shape_match =
1633       at::rand({1, 29, 127, 397}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
1634   const auto in_vulkan = input_cpu.vulkan();
1635   const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
1636       in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
1637 
1638   const double scale2 = 0.15;
1639   const int zero_point2 = 15;
1640   const auto output_vulkan = at::native::vulkan::ops::quantized_conv2d(
1641       out_vulkan,
1642       weight_q,
1643       bias_q,
1644       stride,
1645       padding,
1646       dilation,
1647       groups,
1648       scale2,
1649       zero_point2);
1650 
1651   const auto out_vulkan_deq =
1652       at::native::vulkan::ops::dequantize(output_vulkan);
1653   auto output_for_dequantized_vulkan =
1654       vulkan_to_cpu(out_vulkan_deq, shape_match);
1655 
1656   float rtol = 0;
1657   float atol = 1.5;
1658   const auto check =
1659       at::allclose(output_cpu, output_for_dequantized_vulkan, rtol, atol);
1660 
1661   if (!check) {
1662     std::cout << "Max Diff allowed: " << rtol << std::endl;
1663   }
1664 
1665   ASSERT_TRUE(check);
1666 }
1667 
TEST_F(VulkanAPITest,conv2d_dw)1668 TEST_F(VulkanAPITest, conv2d_dw) {
1669   constexpr int64_t groups = 7;
1670   constexpr std::array<int64_t, 2u> stride{2, 3};
1671   constexpr std::array<int64_t, 2u> padding{0, 4};
1672   constexpr std::array<int64_t, 2u> dilation{3, 1};
1673 
1674   constexpr struct {
1675     uint32_t batches;
1676     uint32_t channels;
1677     uint32_t width;
1678     uint32_t height;
1679 
1680     std::array<int64_t, 4u> size() const {
1681       return {
1682           batches,
1683           channels,
1684           width,
1685           height,
1686       };
1687     }
1688   } input{1, groups, 137, 199};
1689 
1690   constexpr struct {
1691     uint32_t output_channels;
1692     uint32_t input_channels;
1693     uint32_t width;
1694     uint32_t height;
1695 
1696     std::array<int64_t, 4u> size() const {
1697       return {
1698           output_channels,
1699           input_channels,
1700           width,
1701           height,
1702       };
1703     }
1704   } weights{groups, 1, 17, 7};
1705 
1706   float r1 = 0;
1707   float r2 = 0.2;
1708   const auto input_cpu = (r1 - r2) *
1709           at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)) +
1710       r2;
1711   const auto weights_cpu = (r1 - r2) *
1712           at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)) +
1713       r2;
1714   const auto bias_cpu = (r1 - r2) *
1715           at::rand({weights.output_channels},
1716                    at::device(at::kCPU).dtype(at::kFloat)) +
1717       r2;
1718 
1719   const double w_scale = 0.1;
1720   const int w_zero_point = 10;
1721 
1722   const double b_scale = 0.1;
1723   const int b_zero_point = 10;
1724 
1725   const auto weight_q = at::quantize_per_tensor(
1726       weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8);
1727   const auto bias_q = at::quantize_per_tensor(
1728       bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8);
1729 
1730   const auto output_cpu = at::conv2d(
1731       input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups);
1732 
1733   const double scale = 0.10;
1734   const int zero_point = 10;
1735   const auto shape_match =
1736       at::rand({1, 7, 45, 67}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
1737   const auto in_vulkan = input_cpu.vulkan();
1738   const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
1739       in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
1740 
1741   const double scale2 = 0.15;
1742   const int zero_point2 = 15;
1743   const auto output_vulkan = at::native::vulkan::ops::quantized_conv2d(
1744       out_vulkan,
1745       weight_q,
1746       bias_q,
1747       stride,
1748       padding,
1749       dilation,
1750       groups,
1751       scale2,
1752       zero_point2);
1753 
1754   const auto out_vulkan_deq =
1755       at::native::vulkan::ops::dequantize(output_vulkan);
1756   auto output_for_dequantized_vulkan =
1757       vulkan_to_cpu(out_vulkan_deq, shape_match);
1758 
1759   float rtol = 0;
1760   float atol = 1;
1761   const auto check =
1762       at::allclose(output_cpu, output_for_dequantized_vulkan, rtol, atol);
1763 
1764   if (!check) {
1765     std::cout << "Max Diff allowed: " << rtol << std::endl;
1766   }
1767 
1768   ASSERT_TRUE(check);
1769 }
1770 
test_quantized_conv_transpose2d(const at::IntArrayRef input_shape,const at::IntArrayRef weight_shape,const at::IntArrayRef bias_shape,const c10::ScalarType w_dtype,const c10::ScalarType bias_dtype,std::vector<int64_t> stride,std::vector<int64_t> padding,std::vector<int64_t> output_padding,std::vector<int64_t> dilation,int64_t groups)1771 static void test_quantized_conv_transpose2d(
1772     const at::IntArrayRef input_shape,
1773     const at::IntArrayRef weight_shape,
1774     const at::IntArrayRef bias_shape,
1775     const c10::ScalarType w_dtype,
1776     const c10::ScalarType bias_dtype,
1777     std::vector<int64_t> stride,
1778     std::vector<int64_t> padding,
1779     std::vector<int64_t> output_padding,
1780     std::vector<int64_t> dilation,
1781     int64_t groups) {
1782   c10::InferenceMode mode;
1783 
1784   const at::Tensor input =
1785       at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
1786   const at::Tensor weight =
1787       at::rand(weight_shape, at::device(at::kCPU).dtype(at::kFloat));
1788   const at::Tensor bias =
1789       at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat));
1790 
1791   const auto input_quant_params =
1792       compute_quant_params(input, c10::ScalarType::QUInt8);
1793   double input_scale = std::get<0>(input_quant_params);
1794   input_scale = safe_downcast<float>(input_scale);
1795   int32_t input_zero_point = std::get<1>(input_quant_params);
1796   auto input_cpu_q = at::quantize_per_tensor(
1797       input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
1798 
1799   const auto weight_quant_params = compute_quant_params(weight, w_dtype);
1800   double weight_scale = std::get<0>(weight_quant_params);
1801   weight_scale = safe_downcast<float>(weight_scale);
1802   int32_t weight_zero_point = std::get<1>(weight_quant_params);
1803   auto weight_cpu_q =
1804       at::quantize_per_tensor(weight, weight_scale, weight_zero_point, w_dtype);
1805 
1806   double out_scale = produce_random_scale();
1807   out_scale = safe_downcast<float>(out_scale);
1808   int out_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8);
1809 
1810   at::Tensor bias_cpu_q;
1811   // quantize bias
1812   if (bias_dtype != c10::ScalarType::Float) {
1813     const auto bias_quant_params = compute_quant_params(bias, bias_dtype);
1814     double bias_scale = std::get<0>(weight_quant_params);
1815     bias_scale = safe_downcast<float>(bias_scale);
1816     int32_t bias_zero_point = std::get<1>(bias_quant_params);
1817     bias_cpu_q =
1818         at::quantize_per_tensor(bias, bias_scale, bias_zero_point, bias_dtype);
1819   } else {
1820     bias_cpu_q = bias;
1821   }
1822 
1823   auto pack = callOpByName(
1824       "quantized::conv_transpose2d_prepack",
1825       "",
1826       weight_cpu_q,
1827       bias_cpu_q,
1828       stride,
1829       padding,
1830       output_padding,
1831       dilation,
1832       groups);
1833 
1834   auto out_cpu_quant = callOpByName(
1835       "quantized::conv_transpose2d",
1836       "",
1837       input_cpu_q,
1838       pack[0],
1839       out_scale,
1840       out_zero_point);
1841 
1842   const at::Tensor out_cpu = at::dequantize(out_cpu_quant[0].toTensor());
1843 
1844   // vulkan
1845   const auto prepack_vulkan = callOpByName(
1846       "vulkan_prepack::create_qtconv2d_context",
1847       "",
1848       weight_cpu_q,
1849       bias_cpu_q,
1850       stride,
1851       padding,
1852       output_padding,
1853       dilation,
1854       groups,
1855       std::nullopt,
1856       std::nullopt);
1857 
1858   const auto input_vk_q = at::quantize_per_tensor(
1859       input.vulkan(), input_scale, input_zero_point, c10::ScalarType::QUInt8);
1860   auto vulkan_output = callOpByName(
1861       "vulkan_prepack::run_qconv2d_context",
1862       "",
1863       input_vk_q,
1864       out_scale,
1865       out_zero_point,
1866       prepack_vulkan[0]);
1867 
1868   const auto out_vk_dequant = at::dequantize(vulkan_output[0].toTensor());
1869   const auto out_vk_cpu = out_vk_dequant.cpu();
1870 
1871   // check
1872   const auto check = almostEqual(out_cpu, out_vk_cpu, out_scale);
1873   if (!check) {
1874     showRtol(out_cpu, out_vk_cpu);
1875   }
1876 
1877   ASSERT_TRUE(check);
1878 }
1879 
// Quantized transposed convolution with QInt8 weights and a float bias,
// exercised on a small and a larger channel configuration.
TEST_F(VulkanAPITest, conv_tranpose2d_quantized_int8_float) {
  // Both cases share dtypes and convolution hyper-parameters; only the tensor
  // shapes differ.
  const auto run_case = [](const at::IntArrayRef input_shape,
                           const at::IntArrayRef weight_shape,
                           const at::IntArrayRef bias_shape) {
    test_quantized_conv_transpose2d(
        input_shape,
        weight_shape,
        bias_shape,
        /*w_dtype=*/c10::ScalarType::QInt8,
        /*bias_dtype=*/c10::ScalarType::Float,
        /*stride=*/{1, 2},
        /*padding=*/{1, 0},
        /*output_padding=*/{0, 1},
        /*dilation=*/{1, 1},
        /*groups=*/1);
  };

  // 3 -> 3 channels on a 2x2 input.
  run_case({1, 3, 2, 2}, {3, 3, 2, 2}, {3});
  // 55 -> 47 channels on a 7x19 input.
  run_case({1, 55, 7, 19}, {55, 47, 2, 3}, {47});
}
1905 
// Vulkan quantized_sub on two {2, 13, 32, 27} tensors, compared against float
// at::sub of the unquantized inputs.
TEST_F(VulkanAPITest, quantized_sub) {
  // (lo - hi) * U[0, 1) + hi yields values in (lo, hi]:
  // first input in (4, 7], second input in (2, 5].
  float r1 = 4.0;
  float r2 = 7.0;

  float r3 = 2.0;
  float r4 = 5.0;
  const auto in_cpu = (r1 - r2) *
          at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) +
      r2;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 = (r3 - r4) *
          at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) +
      r4;
  const auto in_vulkan2 = in_cpu2.vulkan();

  // Both inputs are quantized with the same parameters.
  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  // Float reference computed from the unquantized inputs.
  const auto reg_subtracted_tensors = at::sub(in_cpu, in_cpu2);

  // Quantization parameters of the difference.
  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto vulk_subtracted_tensors = at::native::vulkan::ops::quantized_sub(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_subtracted_tensors);
  // in_cpu2 has the same shape as the result and serves as the template.
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2);

  const float rtol = 0;
  const float atol = 0.5;
  const auto check = at::allclose(
      reg_subtracted_tensors, output_for_dequantized_vulkan, rtol, atol);

  if (!check) {
    // rtol is 0, so the absolute tolerance is the bound actually enforced.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
1955 
// Vulkan quantized_mul on two {2, 13, 32, 27} tensors, compared against CPU
// quantized::mul after dequantizing both outputs.
TEST_F(VulkanAPITest, quantized_mul) {
  const auto in_cpu =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 =
      at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6;
  const auto in_vulkan2 = in_cpu2.vulkan();

  // Both inputs are quantized with the same parameters.
  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_cpu = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_cpu2 = at::quantize_per_tensor(
      in_cpu2, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  // Quantization parameters of the product.
  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto reg_mul_tensors = callOpByName(
      "quantized::mul", "", out_cpu, out_cpu2, scale3, zero_point3);
  const auto vulk_mul_tensors = at::native::vulkan::ops::quantized_mul(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_mul_tensors);
  // in_cpu2 has the same shape as the result and serves as the template.
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2);

  const float rtol = 0;
  const float atol = 1.5;
  const auto check = at::allclose(
      at::dequantize(reg_mul_tensors[0].toTensor()),
      output_for_dequantized_vulkan,
      rtol,
      atol);

  if (!check) {
    // rtol is 0, so the absolute tolerance is the bound actually enforced.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
2001 
// Vulkan quantized_div on two {2, 13, 32, 27} tensors, compared against float
// at::div of the unquantized inputs. The divisor is kept in (4, 5.5], well
// away from zero.
TEST_F(VulkanAPITest, quantized_div) {
  // (lo - hi) * U[0, 1) + hi yields values in (lo, hi]:
  // dividend in (2, 3.5], divisor in (4, 5.5].
  float r1 = 2.0;
  float r2 = 3.5;

  float r3 = 4.0;
  float r4 = 5.5;
  const auto in_cpu = (r1 - r2) *
          at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) +
      r2;
  const auto in_vulkan = in_cpu.vulkan();
  const auto in_cpu2 = (r3 - r4) *
          at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) +
      r4;
  const auto in_vulkan2 = in_cpu2.vulkan();

  // Both inputs are quantized with the same parameters.
  const double scale = 0.1;
  const int zero_point = 10;

  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);

  // Float reference computed from the unquantized inputs.
  const auto reg_div_tensors = at::div(in_cpu, in_cpu2);

  // Quantization parameters of the quotient.
  const double scale3 = 0.15;
  const int zero_point3 = 15;
  const auto vulk_div_tensors = at::native::vulkan::ops::quantized_div(
      out_vulkan, out_vulkan2, scale3, zero_point3);

  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(vulk_div_tensors);
  // in_cpu2 has the same shape as the result and serves as the template.
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2);

  const float rtol = 0;
  const float atol = 1;
  const auto check =
      at::allclose(reg_div_tensors, output_for_dequantized_vulkan, rtol, atol);

  if (!check) {
    // rtol is 0, so the absolute tolerance is the bound actually enforced.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
2051 
// Nearest-neighbour resize of a quantized {2, 13, 12, 27} Vulkan tensor to a
// 4x6 output, compared against float at::upsample_nearest2d on CPU.
TEST_F(VulkanAPITest, quantized_upsample_nearest2d) {
  const auto in_cpu =
      at::rand({2, 13, 12, 27}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  // NOTE(review): explicit scales of 1 are passed together with output_size
  // {4, 6}; confirm this is intended, since the kernels may use the scales
  // rather than the input/output size ratio to pick source indices. Both
  // backends receive the same arguments, so the comparison stays consistent.
  const auto out_cpu = at::upsample_nearest2d(in_cpu, {4, 6}, 1, 1);

  const double scale = 0.1;
  const int zero_point = 10;

  const auto in_vulkan = in_cpu.vulkan();
  const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);
  const auto upsample_vulkan = at::upsample_nearest2d(out_vulkan, {4, 6}, 1, 1);

  // Template tensor with the output shape {2, 13, 4, 6}.
  const auto in_cpu2 =
      at::rand({2, 13, 4, 6}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  const auto out_vulkan_deq =
      at::native::vulkan::ops::dequantize(upsample_vulkan);
  auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2);

  const float rtol = 0;
  const float atol = 1;
  const auto check =
      at::allclose(out_cpu, output_for_dequantized_vulkan, rtol, atol);

  if (!check) {
    // rtol is 0, so the absolute tolerance is the bound actually enforced.
    std::cout << "Max Diff allowed: " << atol << std::endl;
  }

  ASSERT_TRUE(check);
}
2082 
produce_inputs_for_binary_op(const bool compute_quantization_params,const bool random_quantization_params,const char * op_name,const at::IntArrayRef input1_shape,const at::IntArrayRef input2_shape,double in1_scale,double in2_scale,int in1_zero_point,int in2_zero_point,at::Tensor & input1_cpu,at::Tensor & input1_cpu_q,at::Tensor & input1_cpu_deq,at::Tensor & input1_vk,at::Tensor & input1_vk_q,at::Tensor & input1_vk_deq,at::Tensor & input1_vk_deq_cpu,at::Tensor & input2_cpu,at::Tensor & input2_cpu_q,at::Tensor & input2_cpu_deq,at::Tensor & input2_vk,at::Tensor & input2_vk_q,at::Tensor & input2_vk_deq,at::Tensor & input2_vk_deq_cpu)2083 std::tuple<double, double, int, int> produce_inputs_for_binary_op(
2084     const bool compute_quantization_params,
2085     const bool random_quantization_params,
2086     const char* op_name,
2087     const at::IntArrayRef input1_shape,
2088     const at::IntArrayRef input2_shape,
2089     double in1_scale,
2090     double in2_scale,
2091     int in1_zero_point,
2092     int in2_zero_point,
2093     at::Tensor& input1_cpu,
2094     at::Tensor& input1_cpu_q,
2095     at::Tensor& input1_cpu_deq,
2096     at::Tensor& input1_vk,
2097     at::Tensor& input1_vk_q,
2098     at::Tensor& input1_vk_deq,
2099     at::Tensor& input1_vk_deq_cpu,
2100     at::Tensor& input2_cpu,
2101     at::Tensor& input2_cpu_q,
2102     at::Tensor& input2_cpu_deq,
2103     at::Tensor& input2_vk,
2104     at::Tensor& input2_vk_q,
2105     at::Tensor& input2_vk_deq,
2106     at::Tensor& input2_vk_deq_cpu) {
2107   int num_attempts = 5;
2108   // in order to make sure we start with input tensors that are numerically
2109   // the same (cpu vs vulkan), we allow multiple attempts when randomly
2110   // generating the inputs. If the cpu quantized tensor and the vk quantized
2111   // tensors are not the same (maybe off by 1 due to differences in rounding
2112   // and precision), we try again.
2113   for (int i = 0; i < num_attempts; i += 1) {
2114     // produce random inputs
2115     input1_cpu = produce_random_tensor(input1_shape);
2116     input2_cpu = produce_random_tensor(input2_shape);
2117 
2118     if (compute_quantization_params) {
2119       // compute appropiate scale and zero point for inputs
2120       const auto in1_quant_params = compute_quant_params(input1_cpu);
2121       in1_scale = std::get<0>(in1_quant_params);
2122       in1_zero_point = std::get<1>(in1_quant_params);
2123 
2124       const auto in2_quant_params = compute_quant_params(input2_cpu);
2125       in2_scale = std::get<0>(in2_quant_params);
2126       in2_zero_point = std::get<1>(in2_quant_params);
2127     } else if (random_quantization_params) {
2128       // produce random scale and zero point for inputs
2129       in1_scale = produce_random_scale();
2130       in1_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8);
2131 
2132       in2_scale = produce_random_scale();
2133       in2_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8);
2134     }
2135 
2136     // we do this, to avoid dividing by zero
2137     if (strcmp(op_name, "quantized::div") == 0) {
2138       // we might end up dividing by 0, if we allow random scale and zero point
2139       // of the divisor.
2140       if (random_quantization_params) {
2141         const auto in2_quant_params = compute_quant_params(input2_cpu);
2142         in2_scale = std::get<0>(in2_quant_params);
2143         in2_zero_point = std::get<1>(in2_quant_params);
2144       }
2145 
2146       const auto non_zero_sign =
2147           input2_cpu.sign() - input2_cpu.sign().abs() + 1;
2148       // non_zero_sign = 1 if the value is non negative, and -1 if it is
2149       // negative
2150       input2_cpu = input2_cpu + in2_scale * non_zero_sign;
2151       // this will force abs(input2_cpu) >= in2_scale, which means that none of
2152       // the quantized values of the second input will be equal to the zero
2153       // point.
2154     }
2155 
2156     // quantize cpu inputs
2157     input1_cpu_q = at::quantize_per_tensor(
2158         input1_cpu, in1_scale, in1_zero_point, c10::ScalarType::QUInt8);
2159     input2_cpu_q = at::quantize_per_tensor(
2160         input2_cpu, in2_scale, in2_zero_point, c10::ScalarType::QUInt8);
2161 
2162     // dequantize quantized cpu inputs
2163     input1_cpu_deq = at::dequantize(input1_cpu_q);
2164     input2_cpu_deq = at::dequantize(input2_cpu_q);
2165 
2166     // vulkan quantized inputs
2167     input1_vk = input1_cpu.vulkan();
2168     input1_vk_q = at::quantize_per_tensor(
2169         input1_vk, in1_scale, in1_zero_point, c10::ScalarType::QUInt8);
2170     input2_vk = input2_cpu.vulkan();
2171     input2_vk_q = at::quantize_per_tensor(
2172         input2_vk, in2_scale, in2_zero_point, c10::ScalarType::QUInt8);
2173 
2174     // dequantize quantized vulkan inputs
2175     input1_vk_deq = at::dequantize(input1_vk_q);
2176     input2_vk_deq = at::dequantize(input2_vk_q);
2177 
2178     input1_vk_deq_cpu = input1_vk_deq.cpu();
2179     input2_vk_deq_cpu = input2_vk_deq.cpu();
2180 
2181     const float input1_dif =
2182         at::abs(input1_cpu_deq - input1_vk_deq_cpu).max().item<float>();
2183     const float input2_dif =
2184         at::abs(input2_cpu_deq - input2_vk_deq_cpu).max().item<float>();
2185     if (input1_dif < 1e-5 && input2_dif < 1e-5 && input1_dif < in1_scale / 2 &&
2186         input2_dif < in2_scale / 2) {
2187       break;
2188     }
2189   }
2190 
2191   return {in1_scale, in2_scale, in1_zero_point, in2_zero_point};
2192 }
2193 
apply_cpu_quantized_binary_op(const char * op_name,at::Tensor input1_cpu_deq,at::Tensor input2_cpu_deq)2194 at::Tensor apply_cpu_quantized_binary_op(
2195     const char* op_name,
2196     at::Tensor input1_cpu_deq,
2197     at::Tensor input2_cpu_deq) {
2198   if (strcmp(op_name, "quantized::add") == 0) {
2199     return at::add(input1_cpu_deq, input2_cpu_deq);
2200   } else if (strcmp(op_name, "quantized::sub") == 0) {
2201     return at::sub(input1_cpu_deq, input2_cpu_deq);
2202   } else if (strcmp(op_name, "quantized::mul") == 0) {
2203     return at::mul(input1_cpu_deq, input2_cpu_deq);
2204   } else if (strcmp(op_name, "quantized::div") == 0) {
2205     return at::div(input1_cpu_deq, input2_cpu_deq);
2206   } else {
2207     TORCH_CHECK(false, "Invalid op");
2208   }
2209 }
2210 
apply_vulkan_quantized_binary_op(const char * op_name,at::Tensor input1_vk_q,at::Tensor input2_vk_q,double out_scale,int out_zero_point)2211 at::Tensor apply_vulkan_quantized_binary_op(
2212     const char* op_name,
2213     at::Tensor input1_vk_q,
2214     at::Tensor input2_vk_q,
2215     double out_scale,
2216     int out_zero_point) {
2217   if (strcmp(op_name, "quantized::add") == 0) {
2218     return at::native::vulkan::ops::quantized_add(
2219         input1_vk_q, input2_vk_q, out_scale, out_zero_point);
2220   } else if (strcmp(op_name, "quantized::sub") == 0) {
2221     return at::native::vulkan::ops::quantized_sub(
2222         input1_vk_q, input2_vk_q, out_scale, out_zero_point);
2223   } else if (strcmp(op_name, "quantized::mul") == 0) {
2224     return at::native::vulkan::ops::quantized_mul(
2225         input1_vk_q, input2_vk_q, out_scale, out_zero_point);
2226   } else if (strcmp(op_name, "quantized::div") == 0) {
2227     return at::native::vulkan::ops::quantized_div(
2228         input1_vk_q, input2_vk_q, out_scale, out_zero_point);
2229   } else {
2230     TORCH_CHECK(false, "Invalid op");
2231   }
2232 }
2233 
test_quantized_binary_op(const bool compute_quantization_params,const bool random_quantization_params,const char * op_name,const at::IntArrayRef input1_shape,const at::IntArrayRef input2_shape,double in1_scale_default=0.103,double in2_scale_default=0.171,double out_scale_default=0.139,int in1_zero_point_default=11,int in2_zero_point_default=9,int out_zero_point_default=17)2234 void test_quantized_binary_op(
2235     const bool compute_quantization_params,
2236     const bool random_quantization_params,
2237     const char* op_name,
2238     const at::IntArrayRef input1_shape,
2239     const at::IntArrayRef input2_shape,
2240     double in1_scale_default = 0.103,
2241     double in2_scale_default = 0.171,
2242     double out_scale_default = 0.139,
2243     int in1_zero_point_default = 11,
2244     int in2_zero_point_default = 9,
2245     int out_zero_point_default = 17) {
2246   // produce inputs
2247   at::Tensor input1_cpu, input1_cpu_q, input1_cpu_deq;
2248   at::Tensor input1_vk, input1_vk_q, input1_vk_deq, input1_vk_deq_cpu;
2249   at::Tensor input2_cpu, input2_cpu_q, input2_cpu_deq;
2250   at::Tensor input2_vk, input2_vk_q, input2_vk_deq, input2_vk_deq_cpu;
2251 
2252   auto input_params = produce_inputs_for_binary_op(
2253       compute_quantization_params,
2254       random_quantization_params,
2255       op_name,
2256       input1_shape,
2257       input2_shape,
2258       in1_scale_default,
2259       in2_scale_default,
2260       in1_zero_point_default,
2261       in2_zero_point_default,
2262       input1_cpu,
2263       input1_cpu_q,
2264       input1_cpu_deq,
2265       input1_vk,
2266       input1_vk_q,
2267       input1_vk_deq,
2268       input1_vk_deq_cpu,
2269       input2_cpu,
2270       input2_cpu_q,
2271       input2_cpu_deq,
2272       input2_vk,
2273       input2_vk_q,
2274       input2_vk_deq,
2275       input2_vk_deq_cpu);
2276 
2277   double in1_scale = std::get<0>(input_params);
2278   double in2_scale = std::get<1>(input_params);
2279   int in1_zero_point = std::get<2>(input_params);
2280   int in2_zero_point = std::get<3>(input_params);
2281 
2282   double out_scale = out_scale_default;
2283   int out_zero_point = out_zero_point_default;
2284 
2285   // apply op on dequantized cpu tensors
2286   at::Tensor output_cpu =
2287       apply_cpu_quantized_binary_op(op_name, input1_cpu_deq, input2_cpu_deq);
2288 
2289   if (compute_quantization_params || random_quantization_params) {
2290     // compute appropiate scale and zero point for output
2291     const auto out_quant_params = compute_quant_params(output_cpu);
2292     out_scale = std::get<0>(out_quant_params);
2293     out_zero_point = std::get<1>(out_quant_params);
2294   }
2295 
2296   // quantize and dequantize cpu output
2297   const auto output_cpu_q = at::quantize_per_tensor(
2298       output_cpu, out_scale, out_zero_point, c10::ScalarType::QUInt8);
2299   const auto output_cpu_deq = at::dequantize(output_cpu_q);
2300 
2301   // vulkan quantized output
2302   at::Tensor output_vk_q = apply_vulkan_quantized_binary_op(
2303       op_name, input1_vk_q, input2_vk_q, out_scale, out_zero_point);
2304 
2305   const auto output_vk_deq = at::dequantize(output_vk_q);
2306   const auto output_vk_deq_cpu = output_vk_deq.cpu();
2307 
2308   // check
2309   const float tolerance =
2310       (compute_quantization_params || random_quantization_params)
2311       ? safe_downcast<float>(out_scale)
2312       : 0;
2313   const auto check = almostEqual(output_cpu_deq, output_vk_deq_cpu, tolerance);
2314 
2315   if (!check) {
2316     const auto vk_q_error =
2317         at::abs(output_vk_deq_cpu - output_cpu_deq).max().item<float>();
2318     std::cout << "Binary op " << op_name
2319               << " failed with inputs: " << std::endl;
2320     std::cout << "input1: shape " << input1_shape << " scale " << in1_scale
2321               << " and zero point " << in1_zero_point << std::endl;
2322     std::cout << "input2: shape " << input2_shape << " scale " << in2_scale
2323               << " and zero point " << in2_zero_point << std::endl;
2324     std::cout << "output scale " << out_scale << " and zero point "
2325               << out_zero_point << std::endl;
2326     std::cout << "error: " << vk_q_error << std::endl;
2327   }
2328   ASSERT_TRUE(check);
2329 }
2330 
quantized_binary_op_test_set(const char * op_name)2331 void quantized_binary_op_test_set(const char* op_name) {
2332   // fixed params
2333   test_quantized_binary_op(false, false, op_name, {1, 1, 1, 1}, {1, 1, 1, 1});
2334   test_quantized_binary_op(false, false, op_name, {1, 1, 8, 8}, {1, 1, 8, 8});
2335   test_quantized_binary_op(
2336       false, false, op_name, {1, 1, 12, 17}, {1, 1, 12, 17});
2337   test_quantized_binary_op(
2338       false, false, op_name, {2, 13, 32, 27}, {2, 13, 32, 27});
2339   test_quantized_binary_op(
2340       false, false, op_name, {7, 15, 6, 1}, {7, 15, 6, 17}); // broadcasting
2341   test_quantized_binary_op(
2342       false, false, op_name, {7, 15, 6, 17}, {7, 15, 6, 1}); // broadcasting
2343   test_quantized_binary_op(
2344       false, false, op_name, {7, 15, 1, 17}, {7, 15, 6, 17}); // broadcasting
2345   test_quantized_binary_op(
2346       false, false, op_name, {7, 15, 6, 17}, {7, 15, 1, 17}); // broadcasting
2347   test_quantized_binary_op(
2348       false, false, op_name, {1, 1, 6, 17}, {7, 15, 6, 17}); // broadcasting
2349   test_quantized_binary_op(
2350       false, false, op_name, {7, 15, 6, 17}, {1, 1, 6, 17}); // broadcasting
2351   test_quantized_binary_op(
2352       false, false, op_name, {1, 15, 6, 17}, {7, 15, 6, 17}); // broadcasting
2353   test_quantized_binary_op(
2354       false, false, op_name, {7, 15, 6, 17}, {1, 15, 6, 17}); // broadcasting
2355 
2356   // compute params
2357   test_quantized_binary_op(true, false, op_name, {1, 1, 1, 1}, {1, 1, 1, 1});
2358   test_quantized_binary_op(true, false, op_name, {1, 1, 8, 8}, {1, 1, 8, 8});
2359   test_quantized_binary_op(
2360       true, false, op_name, {1, 1, 12, 17}, {1, 1, 12, 17});
2361   test_quantized_binary_op(
2362       true, false, op_name, {2, 13, 32, 27}, {2, 13, 32, 27});
2363   test_quantized_binary_op(
2364       true, false, op_name, {7, 15, 6, 1}, {7, 15, 6, 17}); // broadcasting
2365   test_quantized_binary_op(
2366       true, false, op_name, {7, 15, 6, 17}, {7, 15, 6, 1}); // broadcasting
2367   test_quantized_binary_op(
2368       true, false, op_name, {7, 15, 1, 17}, {7, 15, 6, 17}); // broadcasting
2369   test_quantized_binary_op(
2370       true, false, op_name, {7, 15, 6, 17}, {7, 15, 1, 17}); // broadcasting
2371   test_quantized_binary_op(
2372       true, false, op_name, {1, 1, 6, 17}, {7, 15, 6, 17}); // broadcasting
2373   test_quantized_binary_op(
2374       true, false, op_name, {7, 15, 6, 17}, {1, 1, 6, 17}); // broadcasting
2375   test_quantized_binary_op(
2376       true, false, op_name, {1, 15, 6, 17}, {7, 15, 6, 17}); // broadcasting
2377   test_quantized_binary_op(
2378       true, false, op_name, {7, 15, 6, 17}, {1, 15, 6, 17}); // broadcasting
2379 
2380   // random params
2381   test_quantized_binary_op(false, true, op_name, {1, 1, 1, 1}, {1, 1, 1, 1});
2382   test_quantized_binary_op(false, true, op_name, {1, 1, 8, 8}, {1, 1, 8, 8});
2383   test_quantized_binary_op(
2384       false, true, op_name, {1, 1, 12, 17}, {1, 1, 12, 17});
2385   test_quantized_binary_op(
2386       false, true, op_name, {2, 13, 32, 27}, {2, 13, 32, 27});
2387   test_quantized_binary_op(
2388       false, true, op_name, {7, 15, 6, 1}, {7, 15, 6, 17}); // broadcasting
2389   test_quantized_binary_op(
2390       false, true, op_name, {7, 15, 6, 17}, {7, 15, 6, 1}); // broadcasting
2391   test_quantized_binary_op(
2392       false, true, op_name, {7, 15, 1, 17}, {7, 15, 6, 17}); // broadcasting
2393   test_quantized_binary_op(
2394       false, true, op_name, {7, 15, 6, 17}, {7, 15, 1, 17}); // broadcasting
2395   test_quantized_binary_op(
2396       false, true, op_name, {1, 1, 6, 17}, {7, 15, 6, 17}); // broadcasting
2397   test_quantized_binary_op(
2398       false, true, op_name, {7, 15, 6, 17}, {1, 1, 6, 17}); // broadcasting
2399   test_quantized_binary_op(
2400       false, true, op_name, {1, 15, 6, 17}, {7, 15, 6, 17}); // broadcasting
2401   test_quantized_binary_op(
2402       false, true, op_name, {7, 15, 6, 17}, {1, 15, 6, 17}); // broadcasting
2403 
2404   // random shape and params
2405   for (int i = 0; i < 10; i += 1) {
2406     const at::IntArrayRef tensor_shape = {
2407         rand_pos_int(30),
2408         rand_pos_int(30),
2409         rand_pos_int(100),
2410         rand_pos_int(100)};
2411     test_quantized_binary_op(false, true, op_name, tensor_shape, tensor_shape);
2412   }
2413 }
2414 
test_max_pool2d(const at::IntArrayRef input_shape,const c10::ScalarType dtype)2415 void test_max_pool2d(
2416     const at::IntArrayRef input_shape,
2417     const c10::ScalarType dtype) {
2418   const auto in_cpu = produce_random_tensor(input_shape);
2419 
2420   const auto input_quant_params = compute_quant_params(in_cpu, dtype);
2421   double scale = std::get<0>(input_quant_params);
2422   scale = safe_downcast<float>(scale);
2423   int zero_point = std::get<1>(input_quant_params);
2424 
2425   auto in_cpu_quantized =
2426       at::quantize_per_tensor(in_cpu, scale, zero_point, dtype);
2427 
2428   const auto out_cpu_quantized =
2429       at::max_pool2d(in_cpu_quantized, {3, 4}, {2, 1}, {1, 1}, {1, 1}, false);
2430   auto in_vk_quantized =
2431       at::quantize_per_tensor(in_cpu.vulkan(), scale, zero_point, dtype);
2432 
2433   const auto out_vk_quantized =
2434       at::max_pool2d(in_vk_quantized, {3, 4}, {2, 1}, {1, 1}, {1, 1}, false);
2435 
2436   const auto out_cpu_deq = at::dequantize(out_cpu_quantized);
2437   const auto out_vk_deq = at::dequantize(out_vk_quantized);
2438   const auto out_vk_deq_cpu = out_vk_deq.cpu();
2439 
2440   const auto check =
2441       almostEqual(out_vk_deq_cpu, out_cpu_deq, safe_downcast<float>(scale));
2442 
2443   if (!check) {
2444     showRtol(out_cpu_deq, out_vk_deq_cpu);
2445   }
2446   ASSERT_TRUE(check);
2447 }
2448 
TEST_F(VulkanAPITest, max_pool2d_qint8) {
  c10::InferenceMode inference_guard;
  // Signed 8-bit quantization on a small and a larger, oddly sized input.
  constexpr auto kDtype = c10::ScalarType::QInt8;
  test_max_pool2d({1, 3, 72, 96}, kDtype);
  test_max_pool2d({5, 13, 55, 68}, kDtype);
}
2454 
TEST_F(VulkanAPITest, max_pool2d_quint8) {
  c10::InferenceMode inference_guard;
  // Unsigned 8-bit quantization; the second shape has an odd, narrow width.
  constexpr auto kDtype = c10::ScalarType::QUInt8;
  test_max_pool2d({5, 13, 55, 68}, kDtype);
  test_max_pool2d({5, 13, 55, 19}, kDtype);
}
2460 
TEST_F(VulkanAPITest, quantized_add_tests) {
  // Full fixed / computed / random parameter sweep for quantized addition.
  const char* const op = "quantized::add";
  quantized_binary_op_test_set(op);
}
2464 
TEST_F(VulkanAPITest, quantized_sub_tests) {
  // Full fixed / computed / random parameter sweep for quantized subtraction.
  const char* const op = "quantized::sub";
  quantized_binary_op_test_set(op);
}
2468 
TEST_F(VulkanAPITest, quantized_mul_tests) {
  // Full fixed / computed / random parameter sweep for quantized
  // multiplication.
  const char* const op = "quantized::mul";
  quantized_binary_op_test_set(op);
}
2472 
TEST_F(VulkanAPITest, quantized_div_tests) {
  // Full fixed / computed / random parameter sweep for quantized division.
  const char* const op = "quantized::div";
  quantized_binary_op_test_set(op);
}
2476 
test_quantized_conv2d(const bool prepacking,const bool compute_quantization_params,const bool random_quantization_params,const at::IntArrayRef input_shape,const at::IntArrayRef weight_shape,const at::IntArrayRef bias_shape,const c10::ScalarType w_dtype,const c10::ScalarType b_dtype,std::vector<int64_t> stride,std::vector<int64_t> padding,std::vector<int64_t> dilation,int64_t groups,double in_scale=0.13,double w_scale=0.29,double b_scale=0.19,double out_scale=0.15,int in_zero_point=11,int w_zero_point=19,int b_zero_point=27,int out_zero_point=10)2477 void test_quantized_conv2d(
2478     const bool prepacking,
2479     const bool compute_quantization_params,
2480     const bool random_quantization_params,
2481     const at::IntArrayRef input_shape,
2482     const at::IntArrayRef weight_shape,
2483     const at::IntArrayRef bias_shape,
2484     const c10::ScalarType w_dtype,
2485     const c10::ScalarType b_dtype,
2486     std::vector<int64_t> stride,
2487     std::vector<int64_t> padding,
2488     std::vector<int64_t> dilation,
2489     int64_t groups,
2490     double in_scale = 0.13,
2491     double w_scale = 0.29,
2492     double b_scale = 0.19,
2493     double out_scale = 0.15,
2494     int in_zero_point = 11,
2495     int w_zero_point = 19,
2496     int b_zero_point = 27,
2497     int out_zero_point = 10) {
2498   c10::InferenceMode mode;
2499 
2500   const c10::ScalarType in_dtype = c10::ScalarType::QUInt8;
2501   const c10::ScalarType out_dtype = c10::ScalarType::QUInt8;
2502 
2503   // input cpu
2504   at::Tensor input_cpu; // input cpu tensor
2505   at::Tensor input_cpu_q; // input cpu tensor -> quantized
2506   at::Tensor input_cpu_deq; // input cpu tensor -> quantized -> dequantized
2507 
2508   // input vulkan
2509   at::Tensor input_vk; // input cpu tensor -> to vulkan
2510   at::Tensor input_vk_q; // input cpu tensor -> to vulkan -> quantized
2511   at::Tensor
2512       input_vk_deq; // input cpu tensor -> to vulkan -> quantized -> dequantized
2513   at::Tensor input_vk_deq_cpu; // input cpu tensor -> to vulkan -> quantized ->
2514                                // dequantized -> to cpu
2515 
2516   // weight cpu
2517   at::Tensor weight_cpu; // weight cpu tensor
2518   at::Tensor weight_cpu_q; // weight cpu tensor -> quantized
2519   at::Tensor weight_cpu_deq; // weight cpu tensor -> quantized -> dequantized
2520 
2521   // bias cpu
2522   at::Tensor bias_cpu; // bias cpu tensor
2523   at::Tensor bias_cpu_q; // bias cpu tensor -> quantized
2524   at::Tensor bias_cpu_deq; // bias cpu tensor -> quantized -> dequantized
2525 
2526   // When we randomly generate the input tensor, we might get unlucky
2527   // and one of the entries might be generated such that when it is divided
2528   // by the scale we get something like 2.50003 for example which could be
2529   // rounded to 2 or 3 depending on the precision and rounding method.
2530   // Because of that possibility, we generate the input and check the
2531   // difference between input_cpu_deq and input_vk_deq_cpu
2532   // If they are different we regenerated them again (up to 3 times)
2533   // The goal is to start with input tensors that remain equal after
2534   // quantization.
2535   int num_attempts = 5;
2536   for (int i = 0; i < num_attempts; i += 1) {
2537     // produce random input, weight and bias
2538     input_cpu = produce_random_tensor(input_shape, 1.26, 5.97, 0.59);
2539     weight_cpu = produce_random_tensor(weight_shape, 1.26, 5.97, 0.59);
2540     bias_cpu = produce_random_tensor(bias_shape, 1.26, 5.97, 0.59);
2541 
2542     if (compute_quantization_params) {
2543       // compute appropiate scale and zero point for input, weight and bias
2544       const auto in_quant_params = compute_quant_params(input_cpu, in_dtype);
2545       in_scale = std::get<0>(in_quant_params);
2546       in_zero_point = std::get<1>(in_quant_params);
2547 
2548       const auto w_quant_params = compute_quant_params(weight_cpu, w_dtype);
2549       w_scale = std::get<0>(w_quant_params);
2550       w_zero_point = std::get<1>(w_quant_params);
2551 
2552       const auto input_max = input_cpu.max().item<float>();
2553       const auto input_min = input_cpu.min().item<float>();
2554       const auto input_range = input_max - input_min;
2555 
2556       bias_cpu = input_range *
2557               at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat)) +
2558           input_min;
2559       b_scale = in_scale;
2560       b_zero_point = in_zero_point;
2561       if (b_dtype == c10::ScalarType::QInt32) {
2562         b_scale = in_scale * w_scale;
2563         b_zero_point = 0;
2564       }
2565     } else if (random_quantization_params) {
2566       // produce random scale and zero point for inputs
2567       in_scale = produce_random_scale();
2568       in_zero_point = produce_random_zero_point(in_dtype);
2569 
2570       w_scale = produce_random_scale();
2571       w_zero_point = produce_random_zero_point(w_dtype);
2572 
2573       b_scale = produce_random_scale();
2574       b_zero_point = produce_random_zero_point(b_dtype);
2575     }
2576 
2577     // quantize cpu input, weight and bias
2578     input_cpu_q =
2579         at::quantize_per_tensor(input_cpu, in_scale, in_zero_point, in_dtype);
2580     weight_cpu_q =
2581         at::quantize_per_tensor(weight_cpu, w_scale, w_zero_point, w_dtype);
2582     bias_cpu_q =
2583         at::quantize_per_tensor(bias_cpu, b_scale, b_zero_point, b_dtype);
2584 
2585     // dequantize quantized cpu input, weight and bias
2586     input_cpu_deq = at::dequantize(input_cpu_q);
2587     weight_cpu_deq = at::dequantize(weight_cpu_q);
2588     bias_cpu_deq = at::dequantize(bias_cpu_q);
2589 
2590     // vulkan quantized input
2591     input_vk = input_cpu.vulkan();
2592     input_vk_q =
2593         at::quantize_per_tensor(input_vk, in_scale, in_zero_point, in_dtype);
2594 
2595     // dequantize quantized vulkan input
2596     input_vk_deq = at::dequantize(input_vk_q);
2597     input_vk_deq_cpu = input_vk_deq.cpu();
2598 
2599     const float input_dif =
2600         at::abs(input_cpu_deq - input_vk_deq_cpu).max().item<float>();
2601 
2602     if (input_dif < 1e-5 && input_dif < in_scale / 2) {
2603       break;
2604     } else {
2605       std::cout << "input_dif too big: " << input_dif;
2606       if (i + 1 < num_attempts) {
2607         std::cout << ". generating input again ..." << std::endl;
2608       } else {
2609         std::cout << std::endl;
2610       }
2611     }
2612   }
2613 
2614   // conv2d on dequantized cpu tensors
2615   // Note: we apply the convolution to the dequantized quantized tensors, that
2616   // way we are performing the operations on the same numeric values.
2617   const auto output_cpu = at::conv2d(
2618       input_cpu_deq,
2619       weight_cpu_deq,
2620       bias_cpu_deq,
2621       stride,
2622       padding,
2623       dilation,
2624       groups);
2625 
2626   if (compute_quantization_params || random_quantization_params) {
2627     // compute appropiate scale and zero point for output
2628     const auto out_quant_params = compute_quant_params(output_cpu, out_dtype);
2629     out_scale = std::get<0>(out_quant_params);
2630     out_zero_point = std::get<1>(out_quant_params);
2631   }
2632 
2633   // quantize and dequantize cpu output
2634   at::Tensor output_cpu_q =
2635       at::quantize_per_tensor(output_cpu, out_scale, out_zero_point, out_dtype);
2636   at::Tensor output_cpu_deq = at::dequantize(output_cpu_q);
2637 
2638   // vulkan quantized output
2639   at::Tensor output_vk_q;
2640 
2641   if (!prepacking) {
2642     // vulkan quantized conv2d
2643     output_vk_q = at::native::vulkan::ops::quantized_conv2d(
2644         input_vk_q,
2645         weight_cpu_q,
2646         bias_cpu_q,
2647         stride,
2648         padding,
2649         dilation,
2650         groups,
2651         out_scale,
2652         out_zero_point);
2653   } else {
2654     // vulkan quantized conv2d call by name
2655     const auto prepack_vulkan_call_by_name = callOpByName(
2656         "vulkan_prepack::create_qconv2d_context",
2657         "",
2658         weight_cpu_q,
2659         bias_cpu_q,
2660         stride,
2661         padding,
2662         dilation,
2663         groups,
2664         std::nullopt,
2665         std::nullopt);
2666     const auto vulkan_output = callOpByName(
2667         "vulkan_prepack::run_qconv2d_context",
2668         "",
2669         input_vk_q,
2670         out_scale,
2671         out_zero_point,
2672         prepack_vulkan_call_by_name[0]);
2673     output_vk_q = vulkan_output[0].toTensor();
2674   }
2675 
2676   // dequantize vulkan output
2677   const auto output_vk_deq = at::dequantize(output_vk_q);
2678   const auto output_vk_deq_cpu = output_vk_deq.cpu();
2679 
2680   // check
2681   const float tolerance = safe_downcast<float>(out_scale);
2682   const auto check = almostEqual(output_cpu_deq, output_vk_deq_cpu, tolerance);
2683 
2684   if (!check) {
2685     const auto vk_q_error =
2686         at::abs(output_vk_deq_cpu - output_cpu_deq).max().item<float>();
2687     std::cout << "Quantized Conv2d failed with: " << std::endl;
2688     std::cout << "input: shape " << input_shape << " scale " << in_scale
2689               << " and zero point " << in_zero_point << std::endl;
2690     std::cout << "weight: shape " << weight_shape << " scale " << w_scale
2691               << " and zero point " << w_zero_point << std::endl;
2692     std::cout << "bias: shape " << bias_shape << " scale " << b_scale
2693               << " and zero point " << b_zero_point << std::endl;
2694     std::cout << "output scale " << out_scale << " and zero point "
2695               << out_zero_point << std::endl;
2696     std::cout << "error: " << vk_q_error << std::endl;
2697   }
2698   ASSERT_TRUE(check);
2699 }
2700 
TEST_F(VulkanAPITest, conv2d_quantized_fixed_params_uint8) {
  // 3x3 conv2d (stride 2, padding 1) with the default quantization parameters
  // of test_quantized_conv2d, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};
  const std::vector<int64_t> weight_shape = {1, 3, 3, 3};
  const std::vector<int64_t> bias_shape = {1};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 2},
      /*padding=*/{1, 1},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2716 
TEST_F(VulkanAPITest, conv2d_quantized_computed_params_uint8) {
  // 3x3 conv2d (stride 2, padding 1) with quantization parameters derived from
  // the generated data, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};
  const std::vector<int64_t> weight_shape = {1, 3, 3, 3};
  const std::vector<int64_t> bias_shape = {1};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/true,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 2},
      /*padding=*/{1, 1},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2732 
TEST_F(VulkanAPITest, conv2d_quantized_random_params_uint8) {
  // 3x3 conv2d (stride 2, padding 1) with randomly drawn quantization
  // parameters, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};
  const std::vector<int64_t> weight_shape = {1, 3, 3, 3};
  const std::vector<int64_t> bias_shape = {1};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/true,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 2},
      /*padding=*/{1, 1},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2748 
TEST_F(VulkanAPITest, conv2d_quantized_prepack_fixed_params_uint8) {
  // 3x3 conv2d (stride 2, padding 1) with the default quantization parameters,
  // run through the prepacked qconv2d context.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};
  const std::vector<int64_t> weight_shape = {1, 3, 3, 3};
  const std::vector<int64_t> bias_shape = {1};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/true,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 2},
      /*padding=*/{1, 1},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2764 
TEST_F(VulkanAPITest, conv2d_quantized_prepack_computed_params_uint8) {
  // 3x3 conv2d (stride 2, padding 1) with data-derived quantization
  // parameters, run through the prepacked qconv2d context.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};
  const std::vector<int64_t> weight_shape = {1, 3, 3, 3};
  const std::vector<int64_t> bias_shape = {1};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/true,
      /*compute_quantization_params=*/true,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 2},
      /*padding=*/{1, 1},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2780 
TEST_F(VulkanAPITest, conv2d_quantized_prepack_random_params_uint8) {
  // 3x3 conv2d (stride 2, padding 1) with randomly drawn quantization
  // parameters, run through the prepacked qconv2d context.
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};
  const std::vector<int64_t> weight_shape = {1, 3, 3, 3};
  const std::vector<int64_t> bias_shape = {1};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/true,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/true,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 2},
      /*padding=*/{1, 1},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2796 
TEST_F(VulkanAPITest, conv2d_dw_quantized_fixed_params_uint8) {
  // Depthwise conv2d (groups == channels == 7, 17x7 kernel, asymmetric
  // stride/padding/dilation) with the default quantization parameters,
  // using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 7, 137, 199};
  const std::vector<int64_t> weight_shape = {7, 1, 17, 7};
  const std::vector<int64_t> bias_shape = {7};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 3},
      /*padding=*/{0, 4},
      /*dilation=*/{3, 1},
      /*groups=*/7);
}
2812 
TEST_F(VulkanAPITest, conv2d_dw_quantized_computed_params_uint8) {
  // Depthwise conv2d (groups == channels == 7) with data-derived quantization
  // parameters, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 7, 137, 199};
  const std::vector<int64_t> weight_shape = {7, 1, 17, 7};
  const std::vector<int64_t> bias_shape = {7};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/true,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 3},
      /*padding=*/{0, 4},
      /*dilation=*/{3, 1},
      /*groups=*/7);
}
2828 
TEST_F(VulkanAPITest, conv2d_dw_quantized_random_params_uint8) {
  // Depthwise conv2d (groups == channels == 7) with randomly drawn
  // quantization parameters, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 7, 137, 199};
  const std::vector<int64_t> weight_shape = {7, 1, 17, 7};
  const std::vector<int64_t> bias_shape = {7};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/true,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 3},
      /*padding=*/{0, 4},
      /*dilation=*/{3, 1},
      /*groups=*/7);
}
2844 
TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_fixed_params_uint8) {
  // Depthwise conv2d (groups == channels == 7) with the default quantization
  // parameters, run through the prepacked qconv2d context.
  const std::vector<int64_t> input_shape = {1, 7, 137, 199};
  const std::vector<int64_t> weight_shape = {7, 1, 17, 7};
  const std::vector<int64_t> bias_shape = {7};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/true,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 3},
      /*padding=*/{0, 4},
      /*dilation=*/{3, 1},
      /*groups=*/7);
}
2860 
TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_computed_params_uint8) {
  // Depthwise conv2d (groups == channels == 7) with data-derived quantization
  // parameters, run through the prepacked qconv2d context.
  const std::vector<int64_t> input_shape = {1, 7, 137, 199};
  const std::vector<int64_t> weight_shape = {7, 1, 17, 7};
  const std::vector<int64_t> bias_shape = {7};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/true,
      /*compute_quantization_params=*/true,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 3},
      /*padding=*/{0, 4},
      /*dilation=*/{3, 1},
      /*groups=*/7);
}
2876 
TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_random_params_uint8) {
  // Depthwise conv2d (groups == channels == 7) with randomly drawn
  // quantization parameters, run through the prepacked qconv2d context.
  const std::vector<int64_t> input_shape = {1, 7, 137, 199};
  const std::vector<int64_t> weight_shape = {7, 1, 17, 7};
  const std::vector<int64_t> bias_shape = {7};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/true,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/true,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{2, 3},
      /*padding=*/{0, 4},
      /*dilation=*/{3, 1},
      /*groups=*/7);
}
2892 
TEST_F(VulkanAPITest, conv2d_pw_quantized_fixed_params_uint8) {
  // Pointwise (1x1, 17 -> 29 channels) conv2d with the default quantization
  // parameters, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 17, 127, 397};
  const std::vector<int64_t> weight_shape = {29, 17, 1, 1};
  const std::vector<int64_t> bias_shape = {29};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/false,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{1, 1},
      /*padding=*/{0, 0},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2908 
TEST_F(VulkanAPITest, conv2d_pw_quantized_computed_params_uint8) {
  // Pointwise (1x1, 17 -> 29 channels) conv2d with data-derived quantization
  // parameters, using the direct (non-prepacked) path.
  const std::vector<int64_t> input_shape = {1, 17, 127, 397};
  const std::vector<int64_t> weight_shape = {29, 17, 1, 1};
  const std::vector<int64_t> bias_shape = {29};
  const c10::ScalarType qtype = c10::ScalarType::QUInt8;
  test_quantized_conv2d(
      /*prepacking=*/false,
      /*compute_quantization_params=*/true,
      /*random_quantization_params=*/false,
      input_shape,
      weight_shape,
      bias_shape,
      /*w_dtype=*/qtype,
      /*b_dtype=*/qtype,
      /*stride=*/{1, 1},
      /*padding=*/{0, 0},
      /*dilation=*/{1, 1},
      /*groups=*/1);
}
2924 
TEST_F(VulkanAPITest,conv2d_pw_quantized_random_params_uint8)2925 TEST_F(VulkanAPITest, conv2d_pw_quantized_random_params_uint8) {
2926   test_quantized_conv2d(
2927       /* prepacking? */ false,
2928       /* compute params */ false,
2929       /* random params */ true,
2930       /* input_shape */ {1, 17, 127, 397},
2931       /* weight_shape */ {29, 17, 1, 1},
2932       /* bias_shape */ {29},
2933       /* weight_dtype */ c10::ScalarType::QUInt8,
2934       /* bias_dtype */ c10::ScalarType::QUInt8,
2935       /* stride */ {1, 1},
2936       /* padding */ {0, 0},
2937       /* dilation */ {1, 1},
2938       /* groups */ 1);
2939 }
2940 
TEST_F(VulkanAPITest,conv2d_pw_quantized_prepack_fixed_params_uint8)2941 TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_fixed_params_uint8) {
2942   test_quantized_conv2d(
2943       /* prepacking? */ true,
2944       /* compute params */ false,
2945       /* random params */ false,
2946       /* input_shape */ {1, 17, 127, 397},
2947       /* weight_shape */ {29, 17, 1, 1},
2948       /* bias_shape */ {29},
2949       /* weight_dtype */ c10::ScalarType::QUInt8,
2950       /* bias_dtype */ c10::ScalarType::QUInt8,
2951       /* stride */ {1, 1},
2952       /* padding */ {0, 0},
2953       /* dilation */ {1, 1},
2954       /* groups */ 1);
2955 }
2956 
TEST_F(VulkanAPITest,conv2d_pw_quantized_prepack_computed_params_uint8)2957 TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_computed_params_uint8) {
2958   test_quantized_conv2d(
2959       /* prepacking? */ true,
2960       /* compute params */ true,
2961       /* random params */ false,
2962       /* input_shape */ {1, 17, 127, 397},
2963       /* weight_shape */ {29, 17, 1, 1},
2964       /* bias_shape */ {29},
2965       /* weight_dtype */ c10::ScalarType::QUInt8,
2966       /* bias_dtype */ c10::ScalarType::QUInt8,
2967       /* stride */ {1, 1},
2968       /* padding */ {0, 0},
2969       /* dilation */ {1, 1},
2970       /* groups */ 1);
2971 }
2972 
TEST_F(VulkanAPITest,conv2d_pw_quantized_prepack_random_params_uint8)2973 TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_random_params_uint8) {
2974   test_quantized_conv2d(
2975       /* prepacking? */ true,
2976       /* compute params */ false,
2977       /* random params */ true,
2978       /* input_shape */ {1, 17, 127, 397},
2979       /* weight_shape */ {29, 17, 1, 1},
2980       /* bias_shape */ {29},
2981       /* weight_dtype */ c10::ScalarType::QUInt8,
2982       /* bias_dtype */ c10::ScalarType::QUInt8,
2983       /* stride */ {1, 1},
2984       /* padding */ {0, 0},
2985       /* dilation */ {1, 1},
2986       /* groups */ 1);
2987 }
2988 
TEST_F(VulkanAPITest,conv2d_quantized_fixed_params_int8_int32)2989 TEST_F(VulkanAPITest, conv2d_quantized_fixed_params_int8_int32) {
2990   test_quantized_conv2d(
2991       /* prepacking? */ false,
2992       /* compute params */ false,
2993       /* random params */ false,
2994       /* input_shape */ {1, 3, 8, 8},
2995       /* weight_shape */ {1, 3, 3, 3},
2996       /* bias_shape */ {1},
2997       /* weight_dtype */ c10::ScalarType::QInt8,
2998       /* bias_dtype */ c10::ScalarType::QInt32,
2999       /* stride */ {2, 2},
3000       /* padding */ {1, 1},
3001       /* dilation */ {1, 1},
3002       /* groups */ 1);
3003 }
3004 
TEST_F(VulkanAPITest,conv2d_quantized_computed_params_int8_int32)3005 TEST_F(VulkanAPITest, conv2d_quantized_computed_params_int8_int32) {
3006   test_quantized_conv2d(
3007       /* prepacking? */ false,
3008       /* compute params */ true,
3009       /* random params */ false,
3010       /* input_shape */ {1, 3, 8, 8},
3011       /* weight_shape */ {1, 3, 3, 3},
3012       /* bias_shape */ {1},
3013       /* weight_dtype */ c10::ScalarType::QInt8,
3014       /* bias_dtype */ c10::ScalarType::QInt32,
3015       /* stride */ {2, 2},
3016       /* padding */ {1, 1},
3017       /* dilation */ {1, 1},
3018       /* groups */ 1);
3019 }
3020 
TEST_F(VulkanAPITest,conv2d_quantized_random_params_int8_int32)3021 TEST_F(VulkanAPITest, conv2d_quantized_random_params_int8_int32) {
3022   test_quantized_conv2d(
3023       /* prepacking? */ false,
3024       /* compute params */ false,
3025       /* random params */ true,
3026       /* input_shape */ {1, 3, 8, 8},
3027       /* weight_shape */ {1, 3, 3, 3},
3028       /* bias_shape */ {1},
3029       /* weight_dtype */ c10::ScalarType::QInt8,
3030       /* bias_dtype */ c10::ScalarType::QInt32,
3031       /* stride */ {2, 2},
3032       /* padding */ {1, 1},
3033       /* dilation */ {1, 1},
3034       /* groups */ 1);
3035 }
3036 
TEST_F(VulkanAPITest,conv2d_quantized_prepack_fixed_params_int8_int32)3037 TEST_F(VulkanAPITest, conv2d_quantized_prepack_fixed_params_int8_int32) {
3038   test_quantized_conv2d(
3039       /* prepacking? */ true,
3040       /* compute params */ false,
3041       /* random params */ false,
3042       /* input_shape */ {1, 3, 8, 8},
3043       /* weight_shape */ {1, 3, 3, 3},
3044       /* bias_shape */ {1},
3045       /* weight_dtype */ c10::ScalarType::QInt8,
3046       /* bias_dtype */ c10::ScalarType::QInt32,
3047       /* stride */ {2, 2},
3048       /* padding */ {1, 1},
3049       /* dilation */ {1, 1},
3050       /* groups */ 1);
3051 }
3052 
TEST_F(VulkanAPITest,conv2d_quantized_prepack_computed_params_int8_int32)3053 TEST_F(VulkanAPITest, conv2d_quantized_prepack_computed_params_int8_int32) {
3054   test_quantized_conv2d(
3055       /* prepacking? */ true,
3056       /* compute params */ true,
3057       /* random params */ false,
3058       /* input_shape */ {1, 3, 8, 8},
3059       /* weight_shape */ {1, 3, 3, 3},
3060       /* bias_shape */ {1},
3061       /* weight_dtype */ c10::ScalarType::QInt8,
3062       /* bias_dtype */ c10::ScalarType::QInt32,
3063       /* stride */ {2, 2},
3064       /* padding */ {1, 1},
3065       /* dilation */ {1, 1},
3066       /* groups */ 1);
3067 }
3068 
TEST_F(VulkanAPITest,conv2d_quantized_prepack_random_params_int8_int32)3069 TEST_F(VulkanAPITest, conv2d_quantized_prepack_random_params_int8_int32) {
3070   test_quantized_conv2d(
3071       /* prepacking? */ true,
3072       /* compute params */ false,
3073       /* random params */ true,
3074       /* input_shape */ {1, 3, 8, 8},
3075       /* weight_shape */ {1, 3, 3, 3},
3076       /* bias_shape */ {1},
3077       /* weight_dtype */ c10::ScalarType::QInt8,
3078       /* bias_dtype */ c10::ScalarType::QInt32,
3079       /* stride */ {2, 2},
3080       /* padding */ {1, 1},
3081       /* dilation */ {1, 1},
3082       /* groups */ 1);
3083 }
3084 
TEST_F(VulkanAPITest,conv2d_dw_quantized_fixed_params_int8_int32)3085 TEST_F(VulkanAPITest, conv2d_dw_quantized_fixed_params_int8_int32) {
3086   test_quantized_conv2d(
3087       /* prepacking? */ false,
3088       /* compute params */ false,
3089       /* random params */ false,
3090       /* input_shape */ {1, 7, 137, 199},
3091       /* weight_shape */ {7, 1, 17, 7},
3092       /* bias_shape */ {7},
3093       /* weight_dtype */ c10::ScalarType::QInt8,
3094       /* bias_dtype */ c10::ScalarType::QInt32,
3095       /* stride */ {2, 3},
3096       /* padding */ {0, 4},
3097       /* dilation */ {3, 1},
3098       /* groups */ 7);
3099 }
3100 
TEST_F(VulkanAPITest,conv2d_dw_quantized_computed_params_int8_int32)3101 TEST_F(VulkanAPITest, conv2d_dw_quantized_computed_params_int8_int32) {
3102   test_quantized_conv2d(
3103       /* prepacking? */ false,
3104       /* compute params */ true,
3105       /* random params */ false,
3106       /* input_shape */ {1, 7, 137, 199},
3107       /* weight_shape */ {7, 1, 17, 7},
3108       /* bias_shape */ {7},
3109       /* weight_dtype */ c10::ScalarType::QInt8,
3110       /* bias_dtype */ c10::ScalarType::QInt32,
3111       /* stride */ {2, 3},
3112       /* padding */ {0, 4},
3113       /* dilation */ {3, 1},
3114       /* groups */ 7);
3115 }
3116 
TEST_F(VulkanAPITest,conv2d_dw_quantized_random_params_int8_int32)3117 TEST_F(VulkanAPITest, conv2d_dw_quantized_random_params_int8_int32) {
3118   test_quantized_conv2d(
3119       /* prepacking? */ false,
3120       /* compute params */ false,
3121       /* random params */ true,
3122       /* input_shape */ {1, 7, 137, 199},
3123       /* weight_shape */ {7, 1, 17, 7},
3124       /* bias_shape */ {7},
3125       /* weight_dtype */ c10::ScalarType::QInt8,
3126       /* bias_dtype */ c10::ScalarType::QInt32,
3127       /* stride */ {2, 3},
3128       /* padding */ {0, 4},
3129       /* dilation */ {3, 1},
3130       /* groups */ 7);
3131 }
3132 
TEST_F(VulkanAPITest,conv2d_dw_quantized_prepack_fixed_params_int8_int32)3133 TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_fixed_params_int8_int32) {
3134   test_quantized_conv2d(
3135       /* prepacking? */ true,
3136       /* compute params */ false,
3137       /* random params */ false,
3138       /* input_shape */ {1, 7, 137, 199},
3139       /* weight_shape */ {7, 1, 17, 7},
3140       /* bias_shape */ {7},
3141       /* weight_dtype */ c10::ScalarType::QInt8,
3142       /* bias_dtype */ c10::ScalarType::QInt32,
3143       /* stride */ {2, 3},
3144       /* padding */ {0, 4},
3145       /* dilation */ {3, 1},
3146       /* groups */ 7);
3147 }
3148 
TEST_F(VulkanAPITest,conv2d_dw_quantized_prepack_computed_params_int8_int32)3149 TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_computed_params_int8_int32) {
3150   test_quantized_conv2d(
3151       /* prepacking? */ true,
3152       /* compute params */ true,
3153       /* random params */ false,
3154       /* input_shape */ {1, 7, 137, 199},
3155       /* weight_shape */ {7, 1, 17, 7},
3156       /* bias_shape */ {7},
3157       /* weight_dtype */ c10::ScalarType::QInt8,
3158       /* bias_dtype */ c10::ScalarType::QInt32,
3159       /* stride */ {2, 3},
3160       /* padding */ {0, 4},
3161       /* dilation */ {3, 1},
3162       /* groups */ 7);
3163 }
3164 
TEST_F(VulkanAPITest,conv2d_dw_quantized_prepack_random_params_int8_int32)3165 TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_random_params_int8_int32) {
3166   test_quantized_conv2d(
3167       /* prepacking? */ true,
3168       /* compute params */ false,
3169       /* random params */ true,
3170       /* input_shape */ {1, 7, 137, 199},
3171       /* weight_shape */ {7, 1, 17, 7},
3172       /* bias_shape */ {7},
3173       /* weight_dtype */ c10::ScalarType::QInt8,
3174       /* bias_dtype */ c10::ScalarType::QInt32,
3175       /* stride */ {2, 3},
3176       /* padding */ {0, 4},
3177       /* dilation */ {3, 1},
3178       /* groups */ 7);
3179 }
3180 
TEST_F(VulkanAPITest,conv2d_pw_quantized_fixed_params_int8_int32)3181 TEST_F(VulkanAPITest, conv2d_pw_quantized_fixed_params_int8_int32) {
3182   test_quantized_conv2d(
3183       /* prepacking? */ false,
3184       /* compute params */ false,
3185       /* random params */ false,
3186       /* input_shape */ {1, 17, 127, 397},
3187       /* weight_shape */ {29, 17, 1, 1},
3188       /* bias_shape */ {29},
3189       /* weight_dtype */ c10::ScalarType::QInt8,
3190       /* bias_dtype */ c10::ScalarType::QInt32,
3191       /* stride */ {1, 1},
3192       /* padding */ {0, 0},
3193       /* dilation */ {1, 1},
3194       /* groups */ 1);
3195 }
3196 
TEST_F(VulkanAPITest,conv2d_pw_quantized_computed_params_int8_int32)3197 TEST_F(VulkanAPITest, conv2d_pw_quantized_computed_params_int8_int32) {
3198   test_quantized_conv2d(
3199       /* prepacking? */ false,
3200       /* compute params */ true,
3201       /* random params */ false,
3202       /* input_shape */ {1, 17, 127, 397},
3203       /* weight_shape */ {29, 17, 1, 1},
3204       /* bias_shape */ {29},
3205       /* weight_dtype */ c10::ScalarType::QInt8,
3206       /* bias_dtype */ c10::ScalarType::QInt32,
3207       /* stride */ {1, 1},
3208       /* padding */ {0, 0},
3209       /* dilation */ {1, 1},
3210       /* groups */ 1);
3211 }
3212 
TEST_F(VulkanAPITest,conv2d_pw_quantized_random_params_int8_int32)3213 TEST_F(VulkanAPITest, conv2d_pw_quantized_random_params_int8_int32) {
3214   test_quantized_conv2d(
3215       /* prepacking? */ false,
3216       /* compute params */ false,
3217       /* random params */ true,
3218       /* input_shape */ {1, 17, 127, 397},
3219       /* weight_shape */ {29, 17, 1, 1},
3220       /* bias_shape */ {29},
3221       /* weight_dtype */ c10::ScalarType::QInt8,
3222       /* bias_dtype */ c10::ScalarType::QInt32,
3223       /* stride */ {1, 1},
3224       /* padding */ {0, 0},
3225       /* dilation */ {1, 1},
3226       /* groups */ 1);
3227 }
3228 
TEST_F(VulkanAPITest,conv2d_pw_quantized_prepack_fixed_params_int8_int32)3229 TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_fixed_params_int8_int32) {
3230   test_quantized_conv2d(
3231       /* prepacking? */ true,
3232       /* compute params */ false,
3233       /* random params */ false,
3234       /* input_shape */ {1, 17, 127, 397},
3235       /* weight_shape */ {29, 17, 1, 1},
3236       /* bias_shape */ {29},
3237       /* weight_dtype */ c10::ScalarType::QInt8,
3238       /* bias_dtype */ c10::ScalarType::QInt32,
3239       /* stride */ {1, 1},
3240       /* padding */ {0, 0},
3241       /* dilation */ {1, 1},
3242       /* groups */ 1);
3243 }
3244 
TEST_F(VulkanAPITest,conv2d_pw_quantized_prepack_computed_params_int8_int32)3245 TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_computed_params_int8_int32) {
3246   test_quantized_conv2d(
3247       /* prepacking? */ true,
3248       /* compute params */ true,
3249       /* random params */ false,
3250       /* input_shape */ {1, 17, 127, 397},
3251       /* weight_shape */ {29, 17, 1, 1},
3252       /* bias_shape */ {29},
3253       /* weight_dtype */ c10::ScalarType::QInt8,
3254       /* bias_dtype */ c10::ScalarType::QInt32,
3255       /* stride */ {1, 1},
3256       /* padding */ {0, 0},
3257       /* dilation */ {1, 1},
3258       /* groups */ 1);
3259 }
3260 
TEST_F(VulkanAPITest,conv2d_pw_quantized_prepack_random_params_int8_int32)3261 TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_random_params_int8_int32) {
3262   test_quantized_conv2d(
3263       /* prepacking? */ true,
3264       /* compute params */ false,
3265       /* random params */ true,
3266       /* input_shape */ {1, 17, 127, 397},
3267       /* weight_shape */ {29, 17, 1, 1},
3268       /* bias_shape */ {29},
3269       /* weight_dtype */ c10::ScalarType::QInt8,
3270       /* bias_dtype */ c10::ScalarType::QInt32,
3271       /* stride */ {1, 1},
3272       /* padding */ {0, 0},
3273       /* dilation */ {1, 1},
3274       /* groups */ 1);
3275 }
3276 
// Checks that a tensor quantized on Vulkan reports the same per-tensor scale
// and zero point as the CPU-quantized tensor. Both are produced from the same
// float input with identical (scale, zero_point, dtype) arguments.
TEST_F(VulkanAPITest, quantized_tensor_get_scale_zero_point) {
  const auto in_cpu =
      at::rand({2, 13, 12, 27}, at::TensorOptions(at::kCPU).dtype(at::kFloat));

  const double scale = 0.1;
  const int zero_point = 10;

  const auto cpu_quantized = at::quantize_per_tensor(
      in_cpu, scale, zero_point, c10::ScalarType::QUInt8);

  const auto in_vulkan = in_cpu.vulkan();
  const auto vulkan_quantized = at::native::vulkan::ops::quantize_per_tensor(
      in_vulkan, scale, zero_point, c10::ScalarType::QUInt8);

  const double cpu_quantized_scale = cpu_quantized.q_scale();
  const int64_t cpu_quantized_zero_point = cpu_quantized.q_zero_point();
  const double vulkan_quantized_scale = vulkan_quantized.q_scale();
  const int64_t vulkan_quantized_zero_point = vulkan_quantized.q_zero_point();

  // Assert each quantity separately so a failure reports which one differs
  // and both values. Exact equality is intentional; it mirrors the original
  // `==` comparison.
  ASSERT_EQ(cpu_quantized_scale, vulkan_quantized_scale);
  ASSERT_EQ(cpu_quantized_zero_point, vulkan_quantized_zero_point);
}
3300 
_test_quantized_linear(const at::Tensor & input_cpu,const at::Tensor & weight,const at::Tensor & bias,double out_scale,int out_zero_point,bool input_quant_dtype_int8,bool weight_quant_dtype_int8)3301 bool _test_quantized_linear(
3302     const at::Tensor& input_cpu,
3303     const at::Tensor& weight,
3304     const at::Tensor& bias,
3305     double out_scale,
3306     int out_zero_point,
3307     bool input_quant_dtype_int8,
3308     bool weight_quant_dtype_int8) {
3309   const auto input_quant_params = compute_quant_params(
3310       input_cpu,
3311       input_quant_dtype_int8 ? c10::ScalarType::QInt8
3312                              : c10::ScalarType::QUInt8);
3313   double scale = std::get<0>(input_quant_params);
3314   scale = safe_downcast<float>(scale);
3315   int zero_point = std::get<1>(input_quant_params);
3316   auto input_cpu_quantized = at::quantize_per_tensor(
3317       input_cpu,
3318       scale,
3319       zero_point,
3320       input_quant_dtype_int8 ? c10::ScalarType::QInt8
3321                              : c10::ScalarType::QUInt8);
3322 
3323   const auto weight_quant_params = compute_quant_params(
3324       weight,
3325       weight_quant_dtype_int8 ? c10::ScalarType::QInt8
3326                               : c10::ScalarType::QUInt8);
3327   double w_scale = std::get<0>(weight_quant_params);
3328   w_scale = safe_downcast<float>(w_scale);
3329   // Weight zero point is expected to always be 0
3330   int w_zero_point = 0;
3331   const auto weight_cpu_quantized = at::quantize_per_tensor(
3332       weight,
3333       w_scale,
3334       w_zero_point,
3335       weight_quant_dtype_int8 ? c10::ScalarType::QInt8
3336                               : c10::ScalarType::QUInt8);
3337 
3338   auto pack =
3339       callOpByName("quantized::linear_prepack", "", weight_cpu_quantized, bias);
3340 
3341   auto out_cpu_quant = callOpByName(
3342       "quantized::linear",
3343       "",
3344       input_cpu_quantized,
3345       pack[0],
3346       out_scale,
3347       out_zero_point);
3348 
3349   at::Tensor out_cpu_dequant = at::dequantize(out_cpu_quant[0].toTensor());
3350 
3351   // Vulkan
3352   auto input_vk_quantized = at::quantize_per_tensor(
3353       input_cpu.vulkan(),
3354       scale,
3355       zero_point,
3356       input_quant_dtype_int8 ? c10::ScalarType::QInt8
3357                              : c10::ScalarType::QUInt8);
3358 
3359   at::Tensor out_vk_quant;
3360 
3361   c10::intrusive_ptr<at::native::vulkan::ops::LinearPackedContext> vk_pack =
3362       at::native::vulkan::ops::create_linear_context(
3363           weight_cpu_quantized.t(), bias);
3364 
3365   out_vk_quant = at::native::vulkan::ops::run_qlinear_context(
3366       input_vk_quantized, out_scale, out_zero_point, vk_pack);
3367 
3368   auto out_vk_dequant = at::dequantize(out_vk_quant);
3369   auto out_vk_to_cpu_dequant = vulkan_to_cpu(out_vk_dequant, out_cpu_dequant);
3370 
3371   const auto check = almostEqual(
3372       out_cpu_dequant, out_vk_to_cpu_dequant, safe_downcast<float>(out_scale));
3373   if (!check) {
3374     long xpos = -1, ypos = -1;
3375     if (input_cpu.sizes().size() == 2) {
3376       // for 2D tensor get the row col that caused failure
3377       showRtol(out_cpu_dequant, out_vk_to_cpu_dequant, &xpos, &ypos);
3378     } else {
3379       showRtol(out_cpu_dequant, out_vk_to_cpu_dequant);
3380     }
3381     if (xpos != -1 && ypos != -1) {
3382       std::cout << "\nFailure caused on row/col: " << ypos << "/" << xpos
3383                 << "\n";
3384       std::cout << "Input tensor scale: " << scale << " zerop: " << zero_point
3385                 << "\n";
3386       std::cout << "Input tensor row " << ypos << "\n";
3387       for (int i = 0; i < input_cpu.sizes()[1]; i++) {
3388         std::cout << input_cpu[ypos][i].item<double>() << ", ";
3389       }
3390       std::cout << "\n";
3391 
3392       std::cout << "Weight tensor scale: " << w_scale
3393                 << " zerop: " << w_zero_point << "\n";
3394       std::cout << "Weight tensor col " << xpos << "\n";
3395       for (int i = 0; i < weight.sizes()[1]; i++) {
3396         std::cout << weight[xpos][i].item<double>() << ", ";
3397       }
3398       std::cout << "\n";
3399 
3400       std::cout << "Input tensor quantized row " << ypos << " with dtype "
3401                 << (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
3402       for (int i = 0; i < input_cpu.sizes()[1]; i++) {
3403         std::cout << input_cpu_quantized[ypos][i].item<double>() << ", ";
3404       }
3405       std::cout << "\n";
3406 
3407       std::cout << "Weight tensor quantized col " << xpos << " with dtype "
3408                 << (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
3409       for (int i = 0; i < weight.sizes()[1]; i++) {
3410         std::cout << weight_cpu_quantized[xpos][i].item<double>() << ", ";
3411       }
3412       std::cout << "\n";
3413 
3414       std::cout << "bias tensor\n";
3415       for (int i = 0; i < bias.sizes()[0]; i++) {
3416         std::cout << bias[i].item<double>() << ", ";
3417       }
3418       std::cout << "\n";
3419 
3420       std::cout << "out_scale: " << out_scale
3421                 << " out_zero_point: " << out_zero_point << "\n";
3422 
3423       std::cout << "cpu unmatched output: "
3424                 << out_cpu_dequant[ypos][xpos].item<double>() << "\n";
3425       std::cout << "vk unmatched output: "
3426                 << out_vk_to_cpu_dequant[ypos][xpos].item<double>() << "\n";
3427     }
3428   }
3429   return check;
3430 }
3431 
test_quantized_linear_for_dtypes(const at::Tensor & input_cpu,const at::Tensor & weight,const at::Tensor & bias,bool input_quant_dtype_int8,bool weight_quant_dtype_int8)3432 bool test_quantized_linear_for_dtypes(
3433     const at::Tensor& input_cpu,
3434     const at::Tensor& weight,
3435     const at::Tensor& bias,
3436     bool input_quant_dtype_int8,
3437     bool weight_quant_dtype_int8) {
3438   double out_scale = produce_random_scale();
3439   out_scale = safe_downcast<float>(out_scale);
3440   int out_zero_point = produce_random_zero_point(
3441       input_quant_dtype_int8 ? c10::ScalarType::QInt8
3442                              : c10::ScalarType::QUInt8);
3443   const auto check = _test_quantized_linear(
3444       input_cpu,
3445       weight,
3446       bias,
3447       out_scale,
3448       out_zero_point,
3449       input_quant_dtype_int8,
3450       weight_quant_dtype_int8);
3451   if (!check) {
3452     // on failure we want to print the exact row/col that causes the
3453     // failure in 2D, so we can debug
3454     if (input_cpu.sizes().size() != 2) {
3455       const auto d = c10::multiply_integers(
3456           input_cpu.sizes().cbegin(), input_cpu.sizes().end() - 1);
3457       auto input_cpu_2d = input_cpu.view({d, input_cpu.size(-1)});
3458 
3459       _test_quantized_linear(
3460           input_cpu_2d,
3461           weight,
3462           bias,
3463           out_scale,
3464           out_zero_point,
3465           input_quant_dtype_int8,
3466           weight_quant_dtype_int8);
3467     }
3468   }
3469   return check;
3470 }
3471 
test_quantized_linear(const at::IntArrayRef input_shape,const at::IntArrayRef weight_shape,const at::IntArrayRef bias_shape)3472 void test_quantized_linear(
3473     const at::IntArrayRef input_shape,
3474     const at::IntArrayRef weight_shape,
3475     const at::IntArrayRef bias_shape) {
3476   c10::InferenceMode mode;
3477 
3478   const auto input_cpu = produce_random_tensor(input_shape);
3479 
3480   const auto weight = produce_random_tensor(weight_shape);
3481 
3482   const auto bias = produce_random_tensor(bias_shape);
3483 
3484   bool check =
3485       test_quantized_linear_for_dtypes(input_cpu, weight, bias, false, true);
3486   ASSERT_TRUE(check);
3487   check = test_quantized_linear_for_dtypes(input_cpu, weight, bias, true, true);
3488   ASSERT_TRUE(check);
3489 }
3490 
// Quantized linear tests over 2-D, 3-D and 4-D inputs at flat, small and large
// sizes. In every case the weight's second dimension equals the input's last
// dimension and the bias length equals the weight's first dimension.
TEST_F(VulkanAPITest, linear_2d_flat) {
  test_quantized_linear(
      /* input_shape */ {1, 100},
      /* weight_shape */ {1, 100},
      /* bias_shape */ {1});
}

TEST_F(VulkanAPITest, linear_2d_small) {
  test_quantized_linear(
      /* input_shape */ {2, 3},
      /* weight_shape */ {4, 3},
      /* bias_shape */ {4});
}

TEST_F(VulkanAPITest, linear_2d_large) {
  test_quantized_linear(
      /* input_shape */ {1287, 17},
      /* weight_shape */ {23, 17},
      /* bias_shape */ {23});
}

TEST_F(VulkanAPITest, linear_3d_flat) {
  test_quantized_linear(
      /* input_shape */ {1, 1, 37},
      /* weight_shape */ {41, 37},
      /* bias_shape */ {41});
}

TEST_F(VulkanAPITest, linear_3d_small) {
  test_quantized_linear(
      /* input_shape */ {2, 3, 4},
      /* weight_shape */ {5, 4},
      /* bias_shape */ {5});
}

TEST_F(VulkanAPITest, linear_3d_large) {
  test_quantized_linear(
      /* input_shape */ {23, 17, 41},
      /* weight_shape */ {15, 41},
      /* bias_shape */ {15});
}

TEST_F(VulkanAPITest, linear_4d_flat) {
  test_quantized_linear(
      /* input_shape */ {1, 1, 1, 37},
      /* weight_shape */ {41, 37},
      /* bias_shape */ {41});
}

TEST_F(VulkanAPITest, linear_4d_small) {
  test_quantized_linear(
      /* input_shape */ {2, 3, 4, 5},
      /* weight_shape */ {6, 5},
      /* bias_shape */ {6});
}

TEST_F(VulkanAPITest, linear_4d_large) {
  test_quantized_linear(
      /* input_shape */ {9, 13, 11, 17},
      /* weight_shape */ {23, 17},
      /* bias_shape */ {23});
}
3526 
// The following code is not directly related to quantization. We put it here
// because we are not able to run this test on GH's CI: for some unknown
// reason, we are not able to reference symbols in the vulkan directory, hence
// the build on GH fails. Moving the test here lets us still run it internally
// on devservers and laptops.
3532 
texel_almost_equal(int expected,float actual)3533 bool texel_almost_equal(int expected, float actual) {
3534   // -1 is a don't care value.
3535   return (expected == -1) || (fabs(expected - actual) < kTolerance);
3536 }
3537 
texel_almost_equal(const ivec4 & expected,const vec4 & actual)3538 bool texel_almost_equal(const ivec4& expected, const vec4& actual) {
3539   return (
3540       texel_almost_equal(expected.data[0], actual.data[0]) &&
3541       texel_almost_equal(expected.data[1], actual.data[1]) &&
3542       texel_almost_equal(expected.data[2], actual.data[2]) &&
3543       texel_almost_equal(expected.data[3], actual.data[3]));
3544 }
3545 
// Verifies ops::utils::extract_texel on the default channels-packed layout by
// reading individual texels of a tensor whose values equal their flat index.
TEST_F(VulkanAPITest, extract_texel_test) {
  const int n = 3;
  const int c = 5;
  const int h = 6;
  const int w = 7;
  const int hw = h * w;
  const int chw = c * h * w;

  // The input tensor is a consecutive range of whole numbers in
  // [0, n * c * h * w); at::range includes its end point, hence the "- 1".
  auto cpu =
      at::range(0, n * c * h * w - 1, at::device(at::kCPU).dtype(at::kFloat))
          .reshape({n, c, h, w});
  auto vk = cpu.vulkan();

  // By default, we are using channel-packed 3d tensors.
  // x and y index the usual width/height plane.
  // The z coordinate is packed with batch and channel: every 4 channels are
  // packed into one texel. Hence, to access a tensor at batch nn and channel
  // cc, the z coordinate is nn * ceil(c / 4) + cc / 4, where c is the channel
  // count.
  // Each batch starts on a new z. Hence, when c is not divisible by 4, there
  // are undefined values in the padding area. We use -1 to indicate that no
  // comparison is performed on those values.
  std::tuple<ivec3, ivec4> test_cases[]{
      {{0, 0, 0}, {0, hw, 2 * hw, 3 * hw}},
      {{1, 0, 0}, {1, hw + 1, 2 * hw + 1, 3 * hw + 1}},
      {{0, 0, 1}, {4 * hw, -1, -1, -1}},
      {{0, 0, 2}, {chw, chw + hw, chw + 2 * hw, chw + 3 * hw}},
      {{0, 1, 2}, {chw + w, chw + hw + w, chw + 2 * hw + w, chw + 3 * hw + w}},
      {{0, 0, 3}, {chw + 4 * hw, -1, -1, -1}},
      {{0, 1, 3}, {chw + 4 * hw + w, -1, -1, -1}},
      {{0, 0, 4}, {2 * chw, 2 * chw + hw, 2 * chw + 2 * hw, 2 * chw + 3 * hw}},
      {{0, 1, 4},
       {2 * chw + w,
        2 * chw + hw + w,
        2 * chw + 2 * hw + w,
        2 * chw + 3 * hw + w}},
  };

  // Check every case (rather than stopping at the first mismatch) so all
  // failing locations are reported.
  bool has_failure = false;
  for (const auto& [loc, expected] : test_cases) {
    vec4 actual = ops::utils::extract_texel(vk, loc);
    if (!texel_almost_equal(expected, actual)) {
      std::cout << "On loc: " << loc << " expected: " << expected
                << " actual: " << actual << std::endl;
      has_failure = true;
    }
  }
  ASSERT_FALSE(has_failure);
}
3599 
// Converts a channels-packed Vulkan tensor to the height-packed layout and
// verifies both the reported memory layout and the contents of selected
// texels. Tensor values equal their flat index.
TEST_F(VulkanAPITest, channel_to_height_packing_test) {
  const int n = 3;
  const int c = 5;
  const int h = 6;
  const int w = 7;
  const int hw = h * w;
  const int chw = c * h * w;

  // Values are 0 .. n * c * h * w - 1 (at::range includes its end point).
  auto data =
      at::range(0, n * c * h * w - 1, at::device(at::kCPU).dtype(at::kFloat))
          .reshape({n, c, h, w});

  auto v_input = at::native::vulkan::ops::convert(data.vulkan());
  auto v_output =
      packing::convert_image_channels_packed_to_height_packed(v_input);
  ASSERT_EQ(
      v_output.gpu_memory_layout(), api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED);

  // Keep the output on Vulkan: we want to inspect the actual height-packed
  // texel layout, not a CPU copy.
  at::Tensor output = at::native::vulkan::ops::convert(v_output);

  // This tensor is height-packed: each texel holds four consecutive elements
  // along the height dimension, so elements within a texel differ by `w`.
  // With h = 6, the second texel along y holds rows 4 and 5 plus two padding
  // slots (-1 = not compared). The z coordinate selects batch * c + channel,
  // e.g. z = 4 + 2 * c is batch 2, channel 4.
  std::tuple<ivec3, ivec4> test_cases[]{
      {{0, 0, 0}, {0, w, 2 * w, 3 * w}},
      {{0, 1, 0}, {4 * w, 5 * w, -1, -1}},
      {{1, 0, 0}, {0 * w + 1, 1 * w + 1, 2 * w + 1, 3 * w + 1}},
      {{1, 1, 0}, {4 * w + 1, 5 * w + 1, -1, -1}},
      {{0, 0, 4}, {4 * hw, 4 * hw + w, 4 * hw + 2 * w, 4 * hw + 3 * w}},
      {{0, 0, 4 + 2 * c},
       {2 * chw + 4 * hw,
        2 * chw + 4 * hw + w,
        2 * chw + 4 * hw + 2 * w,
        2 * chw + 4 * hw + 3 * w}},
  };

  // Check every case so all failing locations are reported.
  bool has_failure = false;
  for (const auto& [loc, expected] : test_cases) {
    vec4 actual = ops::utils::extract_texel(output, loc);
    if (!texel_almost_equal(expected, actual)) {
      std::cout << "On loc: " << loc << " expected: " << expected
                << " actual: " << actual << std::endl;
      has_failure = true;
    }
  }
  ASSERT_FALSE(has_failure);
}
3651 
// Converts a channels-packed Vulkan tensor to the width-packed layout and
// verifies both the reported memory layout and the contents of selected
// texels. Tensor values equal their flat index.
TEST_F(VulkanAPITest, channel_to_width_packing_test) {
  const int n = 3;
  const int c = 5;
  const int h = 6;
  const int w = 7;
  const int hw = h * w;
  const int chw = c * h * w;

  // Values are 0 .. n * c * h * w - 1 (at::range includes its end point).
  auto data =
      at::range(0, n * c * h * w - 1, at::device(at::kCPU).dtype(at::kFloat))
          .reshape({n, c, h, w});

  auto v_input = at::native::vulkan::ops::convert(data.vulkan());
  auto v_output =
      packing::convert_image_channels_packed_to_width_packed(v_input);
  ASSERT_EQ(
      v_output.gpu_memory_layout(), api::GPUMemoryLayout::TENSOR_WIDTH_PACKED);

  // Keep the output on Vulkan: we want to inspect the actual width-packed
  // texel layout, not a CPU copy.
  at::Tensor output = at::native::vulkan::ops::convert(v_output);

  // This tensor is width-packed: each texel holds four consecutive elements
  // along the width dimension, so elements within a texel differ by 1. With
  // w = 7, the second texel along x holds columns 4-6 plus one padding slot
  // (-1 = not compared). The z coordinate selects batch * c + channel, e.g.
  // z = 4 + 2 * c is batch 2, channel 4.
  std::tuple<ivec3, ivec4> test_cases[]{
      {{0, 0, 0}, {0, 1, 2, 3}},
      {{1, 0, 0}, {4, 5, 6, -1}},
      {{0, 2, 0}, {2 * w + 0, 2 * w + 1, 2 * w + 2, 2 * w + 3}},
      {{1, 2, 0}, {2 * w + 4, 2 * w + 5, 2 * w + 6, -1}},
      {{0, 0, 4}, {4 * hw + 0, 4 * hw + 1, 4 * hw + 2, 4 * hw + 3}},
      {{1, 0, 4}, {4 * hw + 4, 4 * hw + 5, 4 * hw + 6, -1}},
      {{0, 0, 4 + 2 * c},
       {2 * chw + 4 * hw,
        2 * chw + 4 * hw + 1,
        2 * chw + 4 * hw + 2,
        2 * chw + 4 * hw + 3}},
  };

  // Check every case so all failing locations are reported.
  bool has_failure = false;
  for (const auto& [loc, expected] : test_cases) {
    vec4 actual = ops::utils::extract_texel(output, loc);
    if (!texel_almost_equal(expected, actual)) {
      std::cout << "On loc: " << loc << " expected: " << expected
                << " actual: " << actual << std::endl;
      has_failure = true;
    }
  }
  ASSERT_FALSE(has_failure);
}
3704 
test_gelu(const at::IntArrayRef input_shape,const c10::ScalarType dtype,bool self_test)3705 void test_gelu(
3706     const at::IntArrayRef input_shape,
3707     const c10::ScalarType dtype,
3708     bool self_test) {
3709   const auto& in_cpu = produce_random_tensor(input_shape);
3710 
3711   auto [scale, zero_point] = compute_quant_params(in_cpu, dtype);
3712   scale = safe_downcast<float>(scale);
3713 
3714   auto in_cpu_quantized =
3715       at::quantize_per_tensor(in_cpu, scale, zero_point, dtype);
3716 
3717   auto in_vk_quantized =
3718       at::quantize_per_tensor(in_cpu.vulkan(), scale, zero_point, dtype);
3719 
3720   auto approximate = "tanh";
3721 
3722   const auto& out_cpu_quantized = self_test
3723       ? at::gelu_(in_cpu_quantized, approximate)
3724       : at::gelu(in_cpu_quantized, approximate);
3725 
3726   const auto& out_vk_quantized = self_test
3727       ? at::gelu_(in_vk_quantized, approximate)
3728       : at::gelu(in_vk_quantized, approximate);
3729 
3730   const auto& out_cpu_deq = at::dequantize(out_cpu_quantized);
3731   const auto& out_vk_deq = at::dequantize(out_vk_quantized);
3732   const auto& out_vk_deq_cpu = out_vk_deq.cpu();
3733 
3734   const auto check = almostEqual(out_vk_deq_cpu, out_cpu_deq, scale);
3735 
3736   if (!check) {
3737     showRtol(out_cpu_deq, out_vk_deq_cpu);
3738   }
3739   ASSERT_TRUE(check);
3740 }
3741 
// Out-of-place quantized gelu (QInt8) on 2-D, 3-D and 4-D inputs.
TEST_F(VulkanAPITest, gelu_qint8) {
  const std::vector<std::vector<int64_t>> shapes{
      {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_gelu(shape, c10::ScalarType::QInt8, false);
  }
}
3747 
// In-place quantized gelu_ (QInt8), starting with a small shape that has a
// singleton middle dimension.
TEST_F(VulkanAPITest, gelu_qint8_self) {
  const std::vector<std::vector<int64_t>> shapes{
      {4, 1, 4}, {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_gelu(shape, c10::ScalarType::QInt8, true);
  }
}
3754 
// Out-of-place quantized gelu (QUInt8) on 2-D, 3-D and 4-D inputs.
TEST_F(VulkanAPITest, gelu_quint8) {
  const std::vector<std::vector<int64_t>> shapes{
      {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_gelu(shape, c10::ScalarType::QUInt8, false);
  }
}
3760 
// In-place quantized gelu_ (QUInt8), starting with a small shape that has a
// singleton middle dimension.
TEST_F(VulkanAPITest, gelu_quint8_self) {
  const std::vector<std::vector<int64_t>> shapes{
      {4, 1, 4}, {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_gelu(shape, c10::ScalarType::QUInt8, true);
  }
}
3767 
test_relu(const at::IntArrayRef input_shape,const c10::ScalarType dtype,bool inplace)3768 void test_relu(
3769     const at::IntArrayRef input_shape,
3770     const c10::ScalarType dtype,
3771     bool inplace) {
3772   const auto in_cpu = produce_random_tensor(input_shape);
3773 
3774   const auto input_quant_params = compute_quant_params(in_cpu, dtype);
3775   double scale = std::get<0>(input_quant_params);
3776   scale = safe_downcast<float>(scale);
3777   int zero_point = std::get<1>(input_quant_params);
3778 
3779   auto in_cpu_quantized =
3780       at::quantize_per_tensor(in_cpu, scale, zero_point, dtype);
3781 
3782   auto in_vk_quantized =
3783       at::quantize_per_tensor(in_cpu.vulkan(), scale, zero_point, dtype);
3784 
3785   const auto out_cpu_quantized =
3786       inplace ? at::relu_(in_cpu_quantized) : at::relu(in_cpu_quantized);
3787 
3788   const auto out_vk_quantized =
3789       inplace ? at::relu_(in_vk_quantized) : at::relu(in_vk_quantized);
3790 
3791   const auto out_cpu_deq = at::dequantize(out_cpu_quantized);
3792   const auto out_vk_deq = at::dequantize(out_vk_quantized);
3793   const auto out_vk_deq_cpu = out_vk_deq.cpu();
3794 
3795   const auto check =
3796       almostEqual(out_vk_deq_cpu, out_cpu_deq, safe_downcast<float>(scale));
3797 
3798   if (!check) {
3799     showRtol(out_cpu_deq, out_vk_deq_cpu);
3800   }
3801   ASSERT_TRUE(check);
3802 }
3803 
// Out-of-place quantized relu (QInt8) on 2-D, 3-D and 4-D inputs.
TEST_F(VulkanAPITest, relu_qint8) {
  const std::vector<std::vector<int64_t>> shapes{
      {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_relu(shape, c10::ScalarType::QInt8, false);
  }
}
3809 
// In-place quantized relu_ (QInt8), starting with a small shape that has a
// singleton middle dimension.
TEST_F(VulkanAPITest, relu_qint8_inplace) {
  const std::vector<std::vector<int64_t>> shapes{
      {4, 1, 4}, {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_relu(shape, c10::ScalarType::QInt8, true);
  }
}
3816 
// Out-of-place quantized relu (QUInt8) on 2-D, 3-D and 4-D inputs.
TEST_F(VulkanAPITest, relu_quint8) {
  const std::vector<std::vector<int64_t>> shapes{
      {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_relu(shape, c10::ScalarType::QUInt8, false);
  }
}
3822 
// In-place quantized relu_ (QUInt8), starting with a small shape that has a
// singleton middle dimension.
TEST_F(VulkanAPITest, relu_quint8_inplace) {
  const std::vector<std::vector<int64_t>> shapes{
      {4, 1, 4}, {200, 20}, {200, 20, 10}, {200, 20, 30, 10}};
  for (const auto& shape : shapes) {
    test_relu(shape, c10::ScalarType::QUInt8, true);
  }
}
3829 
3830 } // namespace
3831 
3832 #endif /* USE_VULKAN_API */
3833