#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/resize_native.h>
#endif

int register_embedding_params();

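// Unpacks the packed embedding weight held by this object back into a
// per-channel affine quantized tensor (quint8 for 8-bit, quint4x2 for 4-bit),
// dropping the per-row scale/zero_point bytes that trail each packed row.
//
// Rough Python sketch of how this path is typically reached (approximate and
// for illustration only; the scales/zero_points arguments are placeholders,
// and only the *_unpack op is registered in this file):
//   qweight = torch.quantize_per_channel(
//       weights, scales, zero_points, axis=0, dtype=torch.quint8)
//   packed = torch.ops.quantized.embedding_bag_prepack(qweight)
//   unpacked = torch.ops.quantized.embedding_bag_unpack(packed)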
at::Tensor PackedEmbeddingBagWeight::unpack() {
  auto packed_weight = packed_w;
  at::Tensor weight_origin;

  if (bit_rate_ == 8 || bit_rate_ == 4) {
    const auto input_rows = packed_weight.size(0);
    const auto input_columns = packed_weight.size(1);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int scale_bias_bytes;
    const auto num_elem_per_byte = 8 / bit_rate_;
    if (bit_rate_ == 8) {
      // The last 8 bytes of each row store the FP32 scale and zero_point
      // for that row.
      scale_bias_bytes = 8;
    } else {
      // For 4-bit weights, the last 4 bytes of each row store the FP16 scale
      // and zero_point for that row.
      scale_bias_bytes = 4;
    }

    const auto* input = packed_weight.const_data_ptr<uint8_t>();
    // Calculate the output shape: the trailing scale/bias bytes of each row
    // are dropped, and each remaining byte holds num_elem_per_byte values,
    // depending on the bit width.
    std::vector<int64_t> output_shape = {
        input_rows,
        static_cast<std::int64_t>(input_columns - scale_bias_bytes) *
            num_elem_per_byte};

    auto scales = at::from_blob(
        w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
    auto zero_points = at::from_blob(
        w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kFloat));

    auto output_columns = output_shape[1];
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    uint8_t* output_data;

    // Allocate output weight tensor based on the bit_width
    if (bit_rate_ == 8) {
      weight_origin = at::_empty_per_channel_affine_quantized(
          output_shape,
          scales.toType(c10::kFloat),
          zero_points.toType(c10::kFloat),
          0, // The output channel axis is 0
          device(c10::kCPU).dtype(c10::kQUInt8));
      output_data = static_cast<uint8_t*>(weight_origin.data_ptr());
    } else {
      // We create an empty qtensor with the full output shape and dtype set to
      // quint4x2. This will internally allocate the appropriate number of
      // storage bytes to account for the packed nature of this dtype.
      weight_origin = at::_empty_per_channel_affine_quantized(
          output_shape,
          scales.toType(c10::kFloat),
          zero_points.toType(c10::kFloat),
          0, // The output channel axis is 0
          device(c10::kCPU).dtype(c10::kQUInt4x2));
      output_data = static_cast<uint8_t*>(weight_origin.data_ptr());
    }

    // Copy over the data from the packed weight to the output.
    // For sub-byte tensors this will copy the packed bytes over, since
    // sub-byte qtensors are expected to store data in packed format.
    at::parallel_for(0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
      for (const auto row : c10::irange(start_idx, end_idx)) {
        const std::uint8_t* input_row = input + row * input_columns;
        uint8_t* output_row =
            output_data + row * output_columns / num_elem_per_byte;

        for (const auto col : c10::irange(output_columns / num_elem_per_byte)) {
          output_row[col] = input_row[col];
        } // output_columns
      }
    });

    return weight_origin;
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "We currently only support 8-bit and 4-bit quantization of embedding_bag.");
  return weight_origin;
}

namespace at {
namespace native {

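// Out variant: dequantizes an 8-bit row-wise quantized (byte-prepacked) weight
// into the provided float tensor, which is resized to the unpacked shape. The
// functional wrapper further below allocates the output and forwards here.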
Tensor& qembeddingbag_byte_unpack_out(
    Tensor& output,
    const Tensor& packed_weight) {
  // The "last" dimension of an N-dimensional batch of embedding bags is the
  // quantization channel. E.g. a 2D embedding bag has [row, col] dimensions;
  // a batch of embedding bags might have [batch, row, col] dimensions.
  //
  // Python Batched Embedding Example:
  // weights = torch.from_numpy((np.random.random_sample((
  //     2, 10, 3)).squeeze() + 1).astype(np.float32))
  // assert(weights.size() == torch.Size([2, 10, 3]))
  // # NOTE: 8 bytes (columns) are added due to fp32 zero_point and scales
  // packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
  // assert(packed_weights.size() == torch.Size([2, 10, 11]))
  // unpacked_weights = torch.ops.quantized.embedding_bag_byte_unpack(packed_weights)
  // assert(unpacked_weights.size() == torch.Size([2, 10, 3]))
  const auto packed_weight_sizes = packed_weight.sizes();
  const auto col_dim = packed_weight_sizes.size() - 1;
  const int64_t input_rows = c10::size_to_dim_(col_dim, packed_weight_sizes);
  const int32_t input_columns = packed_weight_sizes[col_dim];
  // The last 2 values are used to store the FP32 scale and zero_point values
  // per row.
  const int32_t output_columns = input_columns - 2 * sizeof(float);
  const auto* input_data = packed_weight.const_data_ptr<uint8_t>();

  std::vector<int64_t> output_shape = packed_weight_sizes.vec();
  output_shape[col_dim] = output_columns;
  at::native::resize_(output, output_shape);
  auto output_contig = output.expect_contiguous();
  float* output_data = output_contig->data_ptr<float>();

#ifdef USE_FBGEMM
  at::parallel_for(0, input_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
    fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
        input_data + start_idx * input_columns,
        end_idx - start_idx,
        input_columns,
        output_data + start_idx * output_columns);
  });
#else
  for (auto row : c10::irange(input_rows)) {
    const std::uint8_t* input_row = input_data + row * input_columns;
    const float* input_row_scale_zp =
        reinterpret_cast<const float*>(input_row + output_columns);
    float* output_row = output_data + row * output_columns;

    for (auto col : c10::irange(output_columns)) {
      output_row[col] =
          input_row[col] * input_row_scale_zp[0] + input_row_scale_zp[1];
    } // output_columns
  } // input_rows
#endif // USE_FBGEMM
  return output;
}

namespace {
Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
  at::Tensor output = at::empty(
      {},
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
  qembeddingbag_byte_unpack_out(output, packed_weight);
  return output;
}

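// Shape-only (meta) variant of qembeddingbag_byte_unpack: computes the size of
// the unpacked output from symbolic sizes without touching any data, so it can
// serve shape inference (e.g. under tracing).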
Tensor qembeddingbag_byte_unpack_meta(const Tensor& packed_weight) {
  const auto packed_weight_sizes = packed_weight.sym_sizes();
  const auto col_dim = packed_weight_sizes.size() - 1;
  const auto input_columns = packed_weight_sizes[col_dim];
  // The last 2 values are used to store the FP32 scale and zero_point values
  // per row.
  const auto output_columns = input_columns - 2 * sizeof(float);

  auto output_shape = packed_weight_sizes.vec();
  output_shape[col_dim] = output_columns;

  at::SymDimVector output_shape_vec(output_shape);
  return at::empty_symint(
      output_shape_vec,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
}

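// Shared helper for the 4-bit and 2-bit unpack ops. Each packed row stores
// NUM_ELEM_PER_BYTE sub-byte values per byte, followed by a 2-byte fp16 scale
// and a 2-byte fp16 zero_point. As a worked size check (illustrative, not
// taken from a test): a 4-bit row with 64 original values occupies
// 64 / 2 + 2 * sizeof(at::Half) = 36 packed bytes, and unpacking yields
// (36 - 4) * 2 = 64 float values.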
Tensor _qembeddingbag_nbit_unpack_helper(
    const Tensor& packed_weight,
    int BIT_RATE) {
  const auto input_rows = packed_weight.size(0);
  const auto input_columns = packed_weight.size(1);
  const auto* input_data = packed_weight.const_data_ptr<uint8_t>();
  int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;

  // The last 4 bytes of each row hold the fp16 scale and zero_point. The
  // remaining bytes hold the packed values, so the original row had
  // (input_columns - 4) * NUM_ELEM_PER_BYTE values.
  std::vector<int64_t> output_dimensions = {
      input_rows,
      static_cast<std::int64_t>(input_columns - 2 * sizeof(at::Half)) *
          NUM_ELEM_PER_BYTE};

  auto output = at::empty(
      output_dimensions,
      packed_weight.options().dtype(kFloat),
      packed_weight.suggest_memory_format());
  float* output_data = output.data_ptr<float>();
#ifdef USE_FBGEMM
  at::parallel_for(0, input_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
    fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<float>(
        BIT_RATE,
        input_data + start_idx * input_columns,
        end_idx - start_idx,
        input_columns,
        output_data + start_idx * output_dimensions[1]);
  });
#else
  auto output_columns = output_dimensions[1];
  for (auto row : c10::irange(input_rows)) {
    float* output_row = output_data + row * output_columns;
    const std::uint8_t* input_row = input_data + row * input_columns;
    const at::Half* input_row_scale_zp = reinterpret_cast<const at::Half*>(
        input_row +
        (output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
    float scale = input_row_scale_zp[0];
    float zero_point = input_row_scale_zp[1];

    for (const auto col : c10::irange(output_columns)) {
      std::uint8_t quantized = input_row[col / NUM_ELEM_PER_BYTE];
      quantized >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE;
      quantized &= (1 << BIT_RATE) - 1;
      output_row[col] = scale * quantized + zero_point;
    } // output_columns
  } // input_rows
#endif // USE_FBGEMM

  return output;
}

// De-quantizes the result of the qembeddingbag_4bit_prepack operator.
// The input is expected to first have quantized values,
// then a 2-byte fp16 scale and a 2-byte zero_offset.
// The output is a matrix containing only the values, but de-quantized.
// De-quantization is performed by multiplying each value by its
// row's scale and adding the row's zero_point. The de-quantized values
// will thus not be exactly equal to the original, un-quantized
// floating point values.
Tensor qembeddingbag_4bit_unpack(const Tensor& packed_weight) {
  return _qembeddingbag_nbit_unpack_helper(packed_weight, 4 /*BIT_RATE*/);
}
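
// Rough Python sketch of the 4-bit round trip (approximate, for illustration
// only; the prepack call signature is assumed here):
//   weights = torch.rand(10, 16, dtype=torch.float32)
//   packed = torch.ops.quantized.embedding_bag_4bit_prepack(weights)
//   unpacked = torch.ops.quantized.embedding_bag_4bit_unpack(packed)
//   # unpacked has the same shape as weights; values are approximately equal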

// De-quantizes the result of the qembeddingbag_2bit_prepack operator.
// The input is expected to first have quantized values,
// then a 2-byte fp16 scale and a 2-byte zero_offset.
// The output is a matrix containing only the values, but de-quantized.
// De-quantization is performed by multiplying each value by its
// row's scale and adding the row's zero_point. The de-quantized values
// will thus not be exactly equal to the original, un-quantized
// floating point values.
Tensor qembeddingbag_2bit_unpack(const Tensor& packed_weight) {
  return _qembeddingbag_nbit_unpack_helper(packed_weight, 2 /*BIT_RATE*/);
}

class QEmbeddingUnpackWeights final {
 public:
  static at::Tensor run(
      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight) {
    return packed_weight->unpack();
  }
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"),
      qembeddingbag_byte_unpack);
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"),
      qembeddingbag_4bit_unpack);
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"),
      qembeddingbag_2bit_unpack);
}

TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
  // Unpack the packed embedding_bag weights using the TorchBind custom class.
  // TODO extend to support 4-bit qtensor.
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"),
      TORCH_FN(QEmbeddingUnpackWeights::run));
}

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      "quantized::embedding_bag_byte_unpack",
      qembeddingbag_byte_unpack_meta);
}

} // namespace
} // namespace native
} // namespace at