#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

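/*
 * Packs 2D weight data into a staging buffer in the layout expected by the
 * Vulkan shaders: each 2x2 block of the source matrix is spread across 4
 * consecutive planes of the destination so that the 4 values of a block
 * share one texel position, then transfers the staged data into the
 * destination vTensor.
 */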
template <typename T>
void stage_pack_weights(
    api::Context* const context,
    vTensor& v_weight,
    const Tensor& weight,
    const int64_t src_kb_sz,
    const int64_t src_kh_sz,
    const int64_t src_kw_sz,
    const int64_t dst_kh_sz,
    const int64_t dst_kw_sz) {
  const int64_t src_matrix_sz = src_kw_sz * src_kh_sz;
  const int64_t dst_plane_sz = dst_kw_sz * dst_kh_sz;
  const int64_t dst_matrix_sz = dst_plane_sz * 4;
  const T* const src_weight_ptr = weight.const_data_ptr<T>();
  api::StorageBuffer staging(context, api::kFloat, v_weight.gpu_numel());
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

    T* dst_weight_ptr = mapping.template data<T>();

    memset(dst_weight_ptr, 0, v_weight.nbytes());

    for (const auto src_b : c10::irange(src_kb_sz)) {
      for (const auto src_h : c10::irange(src_kh_sz)) {
        for (const auto src_w : c10::irange(src_kw_sz)) {
          // The position within the 2x2 block selects the plane; the block
          // coordinates select the index within the plane.
          int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2);
          int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2);
          memcpy(
              dst_weight_ptr + src_b * dst_matrix_sz +
                  dst_plane * dst_plane_sz + dst_index,
              src_weight_ptr + src_b * src_matrix_sz + src_h * src_kw_sz +
                  src_w,
              sizeof(T));
        }
      }
    }
  }
  utils::pack_staging_to_vtensor(staging.buffer(), v_weight);
}
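
/*
 * For example, with a 4x4 source matrix packed into dst_kw_sz == 2,
 * element (src_h, src_w) == (1, 2) lands in plane 2 * (1 % 2) + (2 % 2)
 * == 2 at index (1 / 2) * 2 + (2 / 2) == 1; the other three elements of
 * the 2x2 block rooted at (0, 2) land at the same index in planes 0, 1,
 * and 3.
 */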
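
/*
 * Packed context for the Vulkan linear operator. Holds the prepacked
 * weight and bias consumed by the shaders, along with the original
 * unpacked arguments so the context can be rebuilt via pack()/unpack().
 */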
class LinearPackedContext final : virtual public VulkanPackedContext,
                                  public torch::jit::CustomClassHolder {
 private:
  c10::impl::GenericList unpacked_;

 public:
  LinearPackedContext(
      const Tensor& weight,
      const std::optional<Tensor>& bias,
      const bool use_batch = false);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct Unpacked final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;

    static constexpr uint32_t NumArgs = 2u;
  };

  /*
   * Assigns a name to each index in the packed list.
   */
  struct Packed final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;
    static constexpr uint32_t WeightSizes = 2u;
    static constexpr uint32_t BiasDefined = 3u;

    static constexpr uint32_t NumArgs = 4u;
  };

  static LinearPackedContext pack(c10::impl::GenericList);

  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

    return unpacked_;
  }
};

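/*
 * Constructs a LinearPackedContext that prepacks the given weight and
 * bias for use with run_linear_context / run_qlinear_context.
 */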
c10::intrusive_ptr<LinearPackedContext> create_linear_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias);

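/*
 * Applies the linear transformation stored in the packed context to a
 * floating point input.
 */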
Tensor run_linear_context(
    const Tensor& input,
    const c10::intrusive_ptr<LinearPackedContext>& context);

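/*
 * Quantized variant: applies the packed linear transformation and
 * requantizes the result with the given output scale and zero point.
 */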
Tensor run_qlinear_context(
    const Tensor& input,
    double output_scale,
    int64_t output_zero_point,
    const c10::intrusive_ptr<LinearPackedContext>& context);

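/*
 * A minimal usage sketch (hypothetical variable names; assumes a build
 * with USE_VULKAN_API):
 *
 *   c10::intrusive_ptr<LinearPackedContext> context =
 *       create_linear_context(std::move(weight), std::move(bias));
 *   Tensor output = run_linear_context(input, context);
 *   Tensor q_output = run_qlinear_context(
 *       q_input, output_scale, output_zero_point, context);
 */
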
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */