#include <ATen/core/Tensor.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/Parallel.h>
#include <c10/util/accumulate.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#endif

#include <algorithm>
#include <vector>

namespace at {
namespace native {

DEFINE_DISPATCH(quantized_normalize_stub);
DEFINE_DISPATCH(quantized_groupnorm_nhwc_stub);

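// LayerNorm over an affine-quantized input. _check_layer_norm_inputs splits
// the input into M outer rows of N normalized elements each; the result is
// written into a freshly allocated affine-quantized tensor using the
// requested output_scale / output_zero_point.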
static Tensor quantized_layer_norm_impl(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */,
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor Y = at::_empty_affine_quantized(
    X->sizes(),
    X->scalar_type(),
    output_scale,
    output_zero_point,
    X->suggest_memory_format());

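  // When the outer size M is zero there is nothing to normalize, so the
  // freshly allocated output is returned unmodified.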
  if (M > 0) {
    bool affine_per_channel = false;
    int num_channels = 1; // not relevant for LayerNorm
    int num_groups = 1; // not relevant for LayerNorm
    quantized_normalize_stub(kCPU, *X, *gamma, *beta, affine_per_channel,
        num_channels, num_groups, M, N, eps, &Y);
  }
  return Y;
}

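// GroupNorm over an affine-quantized input of shape (N, C, ...). The input is
// made contiguous (preserving a channels-last layout when the input already
// has one) so that channels-last tensors can take the NHWC-specialized kernel
// below.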
static Tensor quantized_group_norm_impl(
    const Tensor& qx,
    int64_t num_groups,
    const Tensor& weight, // optional
    const Tensor& bias, // optional
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  bool is_channels_last = qx.is_contiguous(c10::MemoryFormat::ChannelsLast);
  auto mem_layout = is_channels_last ? c10::MemoryFormat::ChannelsLast :
      c10::MemoryFormat::Contiguous;

  const auto& qx_contig = qx.contiguous(mem_layout);
  const auto& weight_contig = weight.contiguous();
  const auto& bias_contig = bias.contiguous();

  const auto input_ndim = qx_contig.dim();
  TORCH_CHECK(
      input_ndim >= 3,
      "Expected normalized_shape to be at least 3-dimensional");
  TORCH_CHECK(num_groups > 0, "Expected num_groups to be positive");

  const auto input_shape = qx_contig.sizes();
  TORCH_CHECK(input_shape[1] % num_groups == 0,
      "Expected channels to be divisible by groups");

  const int64_t batches = input_shape[0];
  const int64_t num_channels = input_shape[1];
  const int64_t elements_per_batch =
      c10::multiply_integers(input_shape.cbegin() + 1, input_shape.cend());

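  // M is the number of (batch, group) slices that are normalized
  // independently; N is the number of elements in each slice.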
  const int64_t M = batches * num_groups;
  const int64_t N = elements_per_batch / num_groups;

  Tensor Y = at::_empty_affine_quantized(
    qx_contig.sizes(),
    qx_contig.scalar_type(),
    output_scale,
    output_zero_point,
    qx_contig.suggest_memory_format());

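  // Channels-last inputs are routed to the NHWC-specialized GroupNorm kernel;
  // all other layouts fall back to the generic normalize kernel with
  // per-channel affine parameters.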
  if (M > 0) {
    bool affine_per_channel = true;
    if (is_channels_last) {
      quantized_groupnorm_nhwc_stub(kCPU, qx_contig, weight_contig, bias_contig,
          affine_per_channel, num_channels, num_groups, M, N, eps, &Y);
    } else {
      quantized_normalize_stub(kCPU, qx_contig, weight_contig, bias_contig,
          affine_per_channel, num_channels, num_groups, M, N, eps, &Y);
    }
  }
  return Y;
}

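// InstanceNorm over an affine-quantized input; implemented as GroupNorm with
// one group per channel (num_groups == num_channels).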
static Tensor quantized_instance_norm_impl(
    const Tensor& qx,
    const Tensor& weight, // optional
    const Tensor& bias, // optional
    double eps,
    double output_scale,
    int64_t output_zero_point) {

  const auto input_ndim = qx.dim();
  TORCH_CHECK(
      input_ndim >= 3,
      "Expected normalized_shape to be at least 3-dimensional");
  const auto input_shape = qx.sizes();

  // IN is GN with num_groups == num_channels
  const auto num_channels = input_shape[1];
  TORCH_CHECK(num_channels > 0, "Expected 2nd dimension to be positive");

  return quantized_group_norm_impl(
      qx, num_channels, weight, bias, eps, output_scale, output_zero_point);
}


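// Registration of the QuantizedCPU kernels for quantized::layer_norm,
// quantized::group_norm and quantized::instance_norm. The lambdas only unwrap
// the optional weight/bias arguments before forwarding to the implementations
// above.
//
// Rough usage sketch (an assumption about the Python-side op bindings;
// argument order is taken from the lambda signatures below):
//   torch.ops.quantized.layer_norm(qx, normalized_shape, weight, bias,
//                                  eps, output_scale, output_zero_point)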
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  // TODO: this is kind of... blegh
  m.impl(TORCH_SELECTIVE_NAME("quantized::layer_norm"), [](
      Tensor input,
      std::vector<int64_t> normalized_shape, // because IntArrayRef doesn't work
      std::optional<Tensor> weight,
      std::optional<Tensor> bias,
      double eps,
      double output_scale,
      int64_t output_zero_point) {
    return quantized_layer_norm_impl(
        input, normalized_shape,
        weight.has_value() ? *weight : Tensor(),
        bias.has_value() ? *bias : Tensor(),
        eps, output_scale, output_zero_point);
  });
  m.impl(TORCH_SELECTIVE_NAME("quantized::group_norm"), [](
      Tensor qx,
      int64_t num_groups,
      std::optional<Tensor> weight,
      std::optional<Tensor> bias,
      double eps,
      double output_scale,
      int64_t output_zero_point) {
    return quantized_group_norm_impl(
        qx, num_groups,
        weight.has_value() ? *weight : Tensor(),
        bias.has_value() ? *bias : Tensor(),
        eps, output_scale, output_zero_point);
  });
  m.impl(TORCH_SELECTIVE_NAME("quantized::instance_norm"), [](
      Tensor qx,
      std::optional<Tensor> weight,
      std::optional<Tensor> bias,
      double eps,
      double output_scale,
      int64_t output_zero_point) {
    return quantized_instance_norm_impl(
        qx,
        weight.has_value() ? *weight : Tensor(),
        bias.has_value() ? *bias : Tensor(),
        eps, output_scale, output_zero_point);
  });
}

} // namespace native
} // namespace at