// File: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_mixed_linear.cpp
// (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

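// Validates the arguments of quantized_mixed_linear_out, mirroring the guards
// below: `in`, `weight`, and `out` must be rank 2; `weight_scales` must be
// rank 1 (per-channel) or rank 2 (per-group); `weight` must be int8 (Char);
// `in` and `weight_scales` must share a Float or Half dtype; the inner
// dimensions of `in` and `weight` must match; and zero points, if provided,
// must match `weight_scales` in shape and `in` in dtype (non-null zero points
// are rejected outright by the final guard).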
bool check_quantized_mixed_linear_args(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_is_rank(weight_scales, 1) || tensor_is_rank(weight_scales, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 1));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensors_have_same_size_at_dims(weight_scales, 0, weight, 0));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight_scales));
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        dtype.value() == ScalarType::Float || dtype.value() == ScalarType::Half,
        "dtype must be Float or Half");
  }
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.scalar_type() == ScalarType::Char, "weight dtype must be int8");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.scalar_type() == ScalarType::Float ||
          in.scalar_type() == ScalarType::Half,
      "input dtype must be Float or Half");

  if (opt_weight_zero_points.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_shape(opt_weight_zero_points.value(), weight_scales));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_dtype(opt_weight_zero_points.value(), in));
  }

  // Support for non-null zero points is not implemented yet.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !opt_weight_zero_points.has_value(), "zero points not supported yet.");
  return true;
}

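// Mixed-dtype linear: multiplies a floating-point activation by an int8
// weight that is dequantized on the fly using per-channel (rank-1) or
// per-group (rank-2) scales. Shapes, as enforced by the checks above:
// `in` is [M, N] (Float or Half), `weight` is [P, N] (int8, used
// transposed), `weight_scales` is [P] or [P, num_groups], and `out` is
// resized to [M, P].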
Tensor& quantized_mixed_linear_out(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available.
  ET_CHECK(check_quantized_mixed_linear_args(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out));

  ScalarType out_dtype = dtype.has_value() ? dtype.value() : out.scalar_type();

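  // The output is 2-D: [in.size(0), weight.size(0)], i.e. batch rows by
  // output channels.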
  size_t output_ndim = 2;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  output_sizes[0] = in.size(0);
  output_sizes[1] = weight.size(0);

  // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available.
  ET_CHECK(resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok);

  constexpr auto name = "quantized_decomposed::mixed_linear.out";

  ET_SWITCH_TWO_TYPES(Float, Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    ET_SWITCH_FLOAT_TYPES_AND(Half, out_dtype, ctx, name, CTYPE_OUT, [&]() {
      size_t m = in.size(0);
      size_t n = in.size(1);
      size_t p = weight.size(0);
      size_t g = n;

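      // With rank-2 scales, each of the weight_scales.size(1) groups covers
      // g input columns; the rounded-up division handles n not being an
      // exact multiple of the group count. With rank-1 scales, g stays n:
      // one scale per output channel covers the whole row.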
      if (weight_scales.dim() == 2) {
        g = (n + weight_scales.size(1) - 1) / weight_scales.size(1);
      }

      // FIXME: this currently ignores dtype
      vec_quantized_matmul_transb_int8<
          CTYPE_OUT, // T *z
          CTYPE>( // U *x, U *s
          out.mutable_data_ptr<CTYPE_OUT>(),
          in.const_data_ptr<CTYPE>(),
          weight.const_data_ptr<int8_t>(),
          weight_scales.const_data_ptr<CTYPE>(),
          m,
          n,
          p,
          g);
    });
  });

  return out;
}

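// Overload taking a KernelRuntimeContext, as kernel entry points do; the
// context is currently unused (see the TODOs below) and the call forwards to
// the overload above.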
Tensor& quantized_mixed_linear_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  // TODO(mcandales): Remove the need for this wrapper
  // TODO(mkg): add support for dtype
  (void)ctx;
  return quantized_mixed_linear_out(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out);
}

} // namespace native
} // namespace executor
} // namespace torch