#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

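/*
 * Copies a batch of 2D weight matrices from a CPU tensor into a GPU staging
 * buffer, rearranging each 2x2 block of the source so that its four elements
 * land at the same (h/2, w/2) offset of four consecutive planes (effectively
 * the four channels of a packed texel), then uploads the staging buffer into
 * v_weight. The staging memory is zero-filled first so that any padding
 * introduced by the destination sizes stays zero.
 */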
template <typename T>
void stage_pack_weights(
    api::Context* const context,
    vTensor& v_weight,
    const Tensor& weight,
    const int64_t src_kb_sz,
    const int64_t src_kh_sz,
    const int64_t src_kw_sz,
    const int64_t dst_kh_sz,
    const int64_t dst_kw_sz) {
  const int64_t src_matrix_sz = src_kw_sz * src_kh_sz;
  const int64_t dst_plane_sz = dst_kw_sz * dst_kh_sz;
  const int64_t dst_matrix_sz = dst_plane_sz * 4;
  const T* const src_weight_ptr = weight.const_data_ptr<T>();
  api::StorageBuffer staging(context, api::kFloat, v_weight.gpu_numel());
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

    T* dst_weight_ptr = mapping.template data<T>();

    memset(dst_weight_ptr, 0, v_weight.nbytes());

    for (const auto src_b : c10::irange(src_kb_sz)) {
      for (const auto src_h : c10::irange(src_kh_sz)) {
        for (const auto src_w : c10::irange(src_kw_sz)) {
          int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2);
          int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2);
          memcpy(
              dst_weight_ptr + src_b * dst_matrix_sz +
                  dst_plane * dst_plane_sz + dst_index,
              src_weight_ptr + src_b * src_matrix_sz + src_h * src_kw_sz +
                  src_w,
              sizeof(T));
        }
      }
    }
  }
  utils::pack_staging_to_vtensor(staging.buffer(), v_weight);
}

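/*
 * Stores the prepacked weight and bias used by the Vulkan linear/matmul
 * operators. The packed list holds the prepacked GPU weight and bias together
 * with the original weight sizes and a flag recording whether a bias was
 * provided; the unpacked list keeps the original tensors so the context can
 * be reconstructed via unpack(). The use_batch flag selects the packing used
 * for batched matrix multiplication.
 */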
class LinearPackedContext final : virtual public VulkanPackedContext,
                                  public torch::jit::CustomClassHolder {
 private:
  c10::impl::GenericList unpacked_;

 public:
  LinearPackedContext(
      const Tensor& weight,
      const std::optional<Tensor>& bias,
      const bool use_batch = false);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct Unpacked final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;

    static constexpr uint32_t NumArgs = 2u;
  };

  /*
   * Assigns a name to each index in the packed list.
   */
  struct Packed final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;
    static constexpr uint32_t WeightSizes = 2u;
    static constexpr uint32_t BiasDefined = 3u;

    static constexpr uint32_t NumArgs = 4u;
  };

  static LinearPackedContext pack(c10::impl::GenericList);

  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

    return unpacked_;
  }
};

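/*
 * Packs the given weight and optional bias into a LinearPackedContext that
 * can be reused across multiple linear invocations.
 */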
c10::intrusive_ptr<LinearPackedContext> create_linear_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias);

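/*
 * Runs the linear (fully connected) operator on Vulkan using the prepacked
 * weight and bias stored in the given context.
 */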
Tensor run_linear_context(
    const Tensor& input,
    const c10::intrusive_ptr<LinearPackedContext>& context);

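/*
 * Quantized counterpart of run_linear_context; output_scale and
 * output_zero_point describe the quantization parameters of the returned
 * tensor.
 */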
Tensor run_qlinear_context(
    const Tensor& input,
    double output_scale,
    int64_t output_zero_point,
    const c10::intrusive_ptr<LinearPackedContext>& context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */