#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/resize_native.h>
#endif

int register_embedding_params();

at::Tensor PackedEmbeddingBagWeight::unpack() {
  auto packed_weight = packed_w;
  at::Tensor weight_origin;

  if (bit_rate_ == 8 || bit_rate_ == 4) {
    const auto input_rows = packed_weight.size(0);
    const auto input_columns = packed_weight.size(1);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int scale_bias_bytes;
    const auto num_elem_per_byte = 8 / bit_rate_;
    if (bit_rate_ == 8) {
      // The last 8 bytes of each row store the FP32 scale and zero_point
      // values for that row.
      scale_bias_bytes = 8;
    } else {
      // For 4-bit weights the last 4 bytes of each row store the FP16 scale
      // and zero_point values for that row.
      scale_bias_bytes = 4;
    }

    const auto* input = packed_weight.const_data_ptr<uint8_t>();
    // Calculate the output shape, accounting for the last scale_bias_bytes of
    // each row being reserved for the scale/bias; the remaining entries are
    // packed according to the bit width.
    std::vector<int64_t> output_shape = {
        input_rows,
        static_cast<std::int64_t>(input_columns - scale_bias_bytes) *
            num_elem_per_byte};
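    // For example, with bit_rate_ == 8 (num_elem_per_byte == 1,
    // scale_bias_bytes == 8) a packed row of 11 bytes unpacks to
    // (11 - 8) * 1 == 3 columns; with bit_rate_ == 4 (num_elem_per_byte == 2,
    // scale_bias_bytes == 4) a packed row of 36 bytes unpacks to
    // (36 - 4) * 2 == 64 columns.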

    auto scales = at::from_blob(
        w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
    auto zero_points = at::from_blob(
        w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kFloat));

    auto output_columns = output_shape[1];
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    uint8_t* output_data;

    // Allocate the output weight tensor based on the bit width.
    if (bit_rate_ == 8) {
      weight_origin = at::_empty_per_channel_affine_quantized(
          output_shape,
          scales.toType(c10::kFloat),
          zero_points.toType(c10::kFloat),
          0, // The output channel axis is 0
          device(c10::kCPU).dtype(c10::kQUInt8));
      output_data = static_cast<uint8_t*>(weight_origin.data_ptr());
    } else {
      // We create an empty qtensor with the full output shape and dtype set to
      // quint4x2. This will internally allocate the appropriate number of
      // storage bytes to account for the packed nature of this dtype.
      weight_origin = at::_empty_per_channel_affine_quantized(
          output_shape,
          scales.toType(c10::kFloat),
          zero_points.toType(c10::kFloat),
          0, // The output channel axis is 0
          device(c10::kCPU).dtype(c10::kQUInt4x2));
      output_data = static_cast<uint8_t*>(weight_origin.data_ptr());
    }

    // Copy over the data from the packed weight to the output.
    // For sub-byte tensors this will copy the packed bytes over since the
    // sub_byte qtensors are expected to store data in packed format.
    at::parallel_for(0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
      for (const auto row : c10::irange(start_idx, end_idx)) {
        const std::uint8_t* input_row = input + row * input_columns;
        uint8_t* output_row =
            output_data + row * output_columns / num_elem_per_byte;

        // output_columns
        for (const auto col : c10::irange(output_columns / num_elem_per_byte)) {
          output_row[col] = input_row[col];
        }
      }
    });

    return weight_origin;
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "We currently only support 8-bit and 4-bit quantization of embedding_bag.");
  return weight_origin;
}

namespace at {
namespace native {

Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight) {
  // The "last" dimension of an N-dimensional batch of embedding bags is the
  // quantization channel. E.g. for a 2D embedding bag, this has
  // [ row, col ] dimensions; for a batch of embedding bags, the dimensions
  // might be [ batch, row, col ].
  //
  // Python Batched Embedding Example:
  // weights = torch.from_numpy((np.random.random_sample((
  //          2, 10, 3)).squeeze() + 1).astype(np.float32))
  // assert(weights.size() == torch.Size([2, 10, 3]))
  // # NOTE: 8 bytes (columns) are added due to fp32 zero_point and scales
  // packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
  // assert(packed_weights.size() == torch.Size([2, 10, 11]))
  // unpacked_weights = torch.ops.quantized.embedding_bag_byte_unpack(packed_weights)
  // assert(unpacked_weights.size() == torch.Size([2, 10, 3]))
  const auto packed_weight_sizes = packed_weight.sizes();
  const auto col_dim = packed_weight_sizes.size() - 1;
  const int64_t input_rows = c10::size_to_dim_(col_dim, packed_weight_sizes);
  const int32_t input_columns = packed_weight_sizes[col_dim];
  // The last 2 values are used to store the FP32 scale and zero_point values
  // per row.
  const int32_t output_columns = input_columns - 2 * sizeof(float);
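  // E.g. with the packed_weights above, input_columns == 11 and
  // output_columns == 11 - 2 * sizeof(float) == 3, matching the original
  // number of columns.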
  const auto* input_data = packed_weight.const_data_ptr<uint8_t>();

  std::vector<int64_t> output_shape = packed_weight_sizes.vec();
  output_shape[col_dim] = output_columns;
  at::native::resize_(output, output_shape);
  auto output_contig = output.expect_contiguous();
  float* output_data = output_contig->data_ptr<float>();

#ifdef USE_FBGEMM
  at::parallel_for(0, input_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
    fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
        input_data + start_idx * input_columns,
        end_idx - start_idx,
        input_columns,
        output_data + start_idx * output_columns);
  });
#else
  for (auto row : c10::irange(input_rows)) {
    const std::uint8_t* input_row = input_data + row * input_columns;
    const float* input_row_scale_zp =
        reinterpret_cast<const float*>(input_row + output_columns);
    float* output_row = output_data + row * output_columns;

    for (auto col : c10::irange(output_columns)) {
      // Dequantize: quantized byte * scale (index 0) + bias (index 1).
      output_row[col] =
          input_row[col] * input_row_scale_zp[0] + input_row_scale_zp[1];
    } // output_columns
  } // input_rows
#endif // USE_FBGEMM
  return output;
}

namespace {
Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
  at::Tensor output = at::empty(
      {},
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
  qembeddingbag_byte_unpack_out(output, packed_weight);
  return output;
}

Tensor qembeddingbag_byte_unpack_meta(const Tensor& packed_weight) {
  const auto packed_weight_sizes = packed_weight.sym_sizes();
  const auto col_dim = packed_weight_sizes.size() - 1;
  const auto input_columns = packed_weight_sizes[col_dim];
  // The last 2 values are used to store the FP32 scale and zero_point values
  // per row.
  const auto output_columns = input_columns - 2 * sizeof(float);

  auto output_shape = packed_weight_sizes.vec();
  output_shape[col_dim] = output_columns;

  at::SymDimVector output_shape_vec(output_shape);
  return at::empty_symint(output_shape_vec, packed_weight.options().dtype(kFloat), packed_weight.suggest_memory_format());
}

Tensor _qembeddingbag_nbit_unpack_helper(
    const Tensor& packed_weight,
    int BIT_RATE) {
  const auto input_rows = packed_weight.size(0);
  const auto input_columns = packed_weight.size(1);
  const auto* input_data = packed_weight.const_data_ptr<uint8_t>();
  int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;

  // The last 4 bytes per row hold the fp16 scale and zero_point values.
  // The remaining bytes of each row hold the packed quantized values,
  // NUM_ELEM_PER_BYTE values per byte.
  std::vector<int64_t> output_dimensions = {
      input_rows,
      static_cast<std::int64_t>(input_columns - 2 * sizeof(at::Half)) *
          NUM_ELEM_PER_BYTE};
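  // For example, with BIT_RATE == 4 (NUM_ELEM_PER_BYTE == 2) and
  // input_columns == 36, each row unpacks to (36 - 4) * 2 == 64 values.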

  auto output = at::empty(
      output_dimensions,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
  float* output_data = output.data_ptr<float>();
#ifdef USE_FBGEMM
  at::parallel_for(0, input_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
    fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<float>(
        BIT_RATE,
        input_data + start_idx * input_columns,
        end_idx - start_idx,
        input_columns,
        output_data + start_idx * output_dimensions[1]);
  });
#else
  auto output_columns = output_dimensions[1];
  for (auto row : c10::irange(input_rows)) {
    float* output_row = output_data + row * output_columns;
    const std::uint8_t* input_row = input_data + row * input_columns;
    const at::Half* input_row_scale_zp = reinterpret_cast<const at::Half*>(
        input_row +
        (output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
    float scale = input_row_scale_zp[0];
    float zero_point = input_row_scale_zp[1];

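    // For example, with BIT_RATE == 4 a packed byte 0x4B holds two values:
    // col 0 is the low nibble (0x4B & 0xF == 11) and col 1 is the high nibble
    // ((0x4B >> 4) & 0xF == 4). With scale == 0.5 and zero_point == 1.0 these
    // dequantize to 6.5 and 3.0 respectively.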
    for (const auto col : c10::irange(output_columns)) {
      std::uint8_t quantized = input_row[col / NUM_ELEM_PER_BYTE];
      quantized >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE;
      quantized &= (1 << BIT_RATE) - 1;
      output_row[col] = scale * quantized + zero_point;
    } // output_columns
  } // input_rows
#endif // USE_FBGEMM

  return output;
}

// De-quantizes the result of the qembeddingbag_4bit_prepack operator.
// The input is expected to first have quantized values,
// then 2-byte fp16 scale and 2-byte zero_offset.
// The output is a matrix containing only the values, but de-quantized.
// De-quantization is performed by multiplying each value by its row's
// scale and adding its row's zero_point. The de-quantized values will
// thus not be exactly equal to the original, un-quantized floating
// point values.
Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
  return _qembeddingbag_nbit_unpack_helper(packed_weight, 4 /*BIT_RATE*/);
}

// De-quantizes the result of the qembeddingbag_2bit_prepack operator.
// The input is expected to first have quantized values,
// then 2-byte fp16 scale and 2-byte zero_offset.
// The output is a matrix containing only the values, but de-quantized.
// De-quantization is performed by multiplying each value by its row's
// scale and adding its row's zero_point. The de-quantized values will
// thus not be exactly equal to the original, un-quantized floating
// point values.
Tensor qembeddingbag_2bit_unpack(const Tensor& packed_weight) {
  return _qembeddingbag_nbit_unpack_helper(packed_weight, 2 /*BIT_RATE*/);
}

class QEmbeddingUnpackWeights final {
 public:
  static at::Tensor run(
      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight) {
    return packed_weight->unpack();
  }
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"),
      qembeddingbag_byte_unpack);
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"),
      qembeddingbag_4bit_unpack);
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"),
      qembeddingbag_2bit_unpack);
}

TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
  // Unpack the packed embedding_bag weights using TorchBind custom class.
  // TODO extend to support 4-bit qtensor.
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"),
      TORCH_FN(QEmbeddingUnpackWeights::run));
}

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      "quantized::embedding_bag_byte_unpack",
      qembeddingbag_byte_unpack_meta);
}

} // namespace
} // namespace native
} // namespace at