xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnormalization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 #include <ATen/native/layer_norm.h>
3 #include <ATen/native/quantized/cpu/QuantizedOps.h>
4 #include <ATen/Parallel.h>
5 #include <c10/util/accumulate.h>
6 #include <torch/library.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #else
11 #include <ATen/ops/_empty_affine_quantized.h>
12 #endif
13 
14 #include <algorithm>
15 #include <vector>
16 
17 namespace at {
18 namespace native {
19 
20 DEFINE_DISPATCH(quantized_normalize_stub);
21 DEFINE_DISPATCH(quantized_groupnorm_nhwc_stub);
22 
// LayerNorm over a quantized tensor, computed directly in the quantized
// domain via quantized_normalize_stub. `weight` / `bias` may be undefined
// tensors (callers pass Tensor() when the optional is absent).
static Tensor quantized_layer_norm_impl(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  // Validates shapes and yields {M, N}: M independent rows to normalize,
  // each containing N elements.
  const auto rows_cols =
      _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  const auto rows = rows_cols.first;
  const auto cols = rows_cols.second;

  const auto X = input.expect_contiguous();
  const auto gamma = weight.expect_contiguous();
  const auto beta = bias.expect_contiguous();

  // Output mirrors the input's dtype and memory format, but carries the
  // requested output quantization parameters.
  Tensor Y = at::_empty_affine_quantized(
      X->sizes(),
      X->scalar_type(),
      output_scale,
      output_zero_point,
      X->suggest_memory_format());

  if (rows > 0) {
    // Per-channel affine and grouping are GroupNorm concepts; for LayerNorm
    // the kernel ignores them, so pass the neutral values.
    constexpr bool kAffinePerChannel = false;
    constexpr int kNumChannels = 1;
    constexpr int kNumGroups = 1;
    quantized_normalize_stub(kCPU, *X, *gamma, *beta, kAffinePerChannel,
        kNumChannels, kNumGroups, rows, cols, eps, &Y);
  }
  return Y;
}
55 
// GroupNorm over a quantized tensor. Dispatches to an NHWC-specialized
// kernel when the input is channels-last contiguous, otherwise to the
// generic contiguous kernel. `weight` / `bias` may be undefined tensors.
static Tensor quantized_group_norm_impl(
    const Tensor& qx,
    int64_t num_groups,
    const Tensor& weight, // optional
    const Tensor& bias, // optional
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  // Preserve a channels-last layout if the input already has one; otherwise
  // fall back to plain contiguous.
  const bool channels_last = qx.is_contiguous(c10::MemoryFormat::ChannelsLast);
  const auto layout = channels_last
      ? c10::MemoryFormat::ChannelsLast
      : c10::MemoryFormat::Contiguous;

  const auto& qx_contig = qx.contiguous(layout);
  const auto& weight_contig = weight.contiguous();
  const auto& bias_contig = bias.contiguous();

  TORCH_CHECK(
      qx_contig.dim() >= 3,
      "Expected normalized_shape to be at least 3-dimensional");
  TORCH_CHECK(num_groups > 0, "Expected num_groups to be positive");

  const auto input_shape = qx_contig.sizes();
  TORCH_CHECK(input_shape[1] % num_groups == 0,
      "Expected channels to be divisible by groups");

  const int64_t batches = input_shape[0];
  const int64_t num_channels = input_shape[1];
  // Every non-batch dimension contributes to the per-batch element count.
  const int64_t elements_per_batch =
      c10::multiply_integers(input_shape.cbegin() + 1, input_shape.cend());

  // M = number of independent normalization groups across the whole batch;
  // N = elements reduced over within each group.
  const int64_t M = batches * num_groups;
  const int64_t N = elements_per_batch / num_groups;

  Tensor Y = at::_empty_affine_quantized(
      qx_contig.sizes(),
      qx_contig.scalar_type(),
      output_scale,
      output_zero_point,
      qx_contig.suggest_memory_format());

  if (M > 0) {
    // Unlike LayerNorm, GroupNorm's gamma/beta are applied per channel.
    const bool affine_per_channel = true;
    if (channels_last) {
      quantized_groupnorm_nhwc_stub(kCPU, qx_contig, weight_contig, bias_contig,
          affine_per_channel, num_channels, num_groups, M, N, eps, &Y);
    } else {
      quantized_normalize_stub(kCPU, qx_contig, weight_contig, bias_contig,
          affine_per_channel, num_channels, num_groups, M, N, eps, &Y);
    }
  }
  return Y;
}
109 
// InstanceNorm over a quantized tensor. Implemented by reduction to
// GroupNorm: InstanceNorm is exactly GroupNorm with one group per channel.
static Tensor quantized_instance_norm_impl(
    const Tensor& qx,
    const Tensor& weight, // optional
    const Tensor& bias, // optional
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  TORCH_CHECK(
      qx.dim() >= 3,
      "Expected normalized_shape to be at least 3-dimensional");

  // Channel count doubles as the group count below.
  const auto channels = qx.sizes()[1];
  TORCH_CHECK(channels > 0, "Expected 2nd dimension to be positive");

  return quantized_group_norm_impl(
      qx, channels, weight, bias, eps, output_scale, output_zero_point);
}
131 
132 
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  // TODO: this is kind of... blegh
  // Each registered lambda adapts the schema's std::optional<Tensor>
  // arguments to the undefined-Tensor convention the impls expect:
  // an absent optional becomes a default-constructed (undefined) Tensor.
  m.impl(TORCH_SELECTIVE_NAME("quantized::layer_norm"), [](
    Tensor input,
    std::vector<int64_t> normalized_shape,  // because IntArrayRef doesn't work
    std::optional<Tensor> weight,
    std::optional<Tensor> bias,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
      return quantized_layer_norm_impl(
          input, normalized_shape,
          weight.value_or(Tensor()),
          bias.value_or(Tensor()),
          eps, output_scale, output_zero_point);
  });
  m.impl(TORCH_SELECTIVE_NAME("quantized::group_norm"), [](
      Tensor qx,
      int64_t num_groups,
      std::optional<Tensor> weight,
      std::optional<Tensor> bias,
      double eps,
      double output_scale,
      int64_t output_zero_point) {
    return quantized_group_norm_impl(
        qx, num_groups,
        weight.value_or(Tensor()),
        bias.value_or(Tensor()),
        eps, output_scale, output_zero_point);
  });
  m.impl(TORCH_SELECTIVE_NAME("quantized::instance_norm"), [](
      Tensor qx,
      std::optional<Tensor> weight,
      std::optional<Tensor> bias,
      double eps,
      double output_scale,
      int64_t output_zero_point) {
    return quantized_instance_norm_impl(
        qx,
        weight.value_or(Tensor()),
        bias.value_or(Tensor()),
        eps, output_scale, output_zero_point);
  });
}
177 
178 } // namespace native
179 } // namespace at
180