1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/vec_ops.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11
12 namespace torch {
13 namespace executor {
14 namespace native {
15
16 using Tensor = exec_aten::Tensor;
17
check_quantized_mixed_linear_args(const Tensor & in,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const exec_aten::optional<ScalarType> dtype,Tensor & out)18 bool check_quantized_mixed_linear_args(
19 const Tensor& in,
20 const Tensor& weight,
21 const Tensor& weight_scales,
22 const exec_aten::optional<Tensor>& opt_weight_zero_points,
23 const exec_aten::optional<ScalarType> dtype,
24 Tensor& out) {
25 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
26 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2));
27 ET_LOG_AND_RETURN_IF_FALSE(
28 tensor_is_rank(weight_scales, 1) || tensor_is_rank(weight_scales, 2));
29 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));
30
31 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 1));
32 ET_LOG_AND_RETURN_IF_FALSE(
33 tensors_have_same_size_at_dims(weight_scales, 0, weight, 0));
34 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 1));
35
36 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight_scales));
37 if (dtype.has_value()) {
38 ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
39 ET_LOG_MSG_AND_RETURN_IF_FALSE(
40 dtype.value() == ScalarType::Float || dtype.value() == ScalarType::Half,
41 "dtype must be Float or Half");
42 }
43 ET_LOG_MSG_AND_RETURN_IF_FALSE(
44 weight.scalar_type() == ScalarType::Char, "weight dtype must be int8");
45 ET_LOG_MSG_AND_RETURN_IF_FALSE(
46 in.scalar_type() == ScalarType::Float ||
47 in.scalar_type() == ScalarType::Half,
48 "input dtype must be Float or Half");
49
50 if (opt_weight_zero_points.has_value()) {
51 ET_LOG_AND_RETURN_IF_FALSE(
52 tensors_have_same_shape(opt_weight_zero_points.value(), weight_scales));
53 ET_LOG_AND_RETURN_IF_FALSE(
54 tensors_have_same_dtype(opt_weight_zero_points.value(), in));
55 }
56
57 // Support for non-null zero points is not implemented yet.
58 ET_LOG_MSG_AND_RETURN_IF_FALSE(
59 !opt_weight_zero_points.has_value(), "zero points not supported yet.");
60 return true;
61 }
62
/**
 * Computes out = in @ weight^T, dequantizing the int8 weight on the fly
 * with weight_scales (per output channel, optionally grouped along the
 * inner dimension when weight_scales is 2-D).
 *
 * in: (m, n) Float or Half; weight: (p, n) int8; out is resized to (m, p).
 * Argument validity is established by check_quantized_mixed_linear_args.
 *
 * @returns the (resized and filled) out tensor.
 */
Tensor& quantized_mixed_linear_out(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available.
  ET_CHECK(check_quantized_mixed_linear_args(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out));

  const ScalarType out_dtype =
      dtype.has_value() ? dtype.value() : out.scalar_type();

  // Output shape is (m, p): one row per input row, one column per weight row.
  size_t out_rank = 2;
  exec_aten::SizesType out_sizes[kTensorDimensionLimit];
  out_sizes[0] = in.size(0);
  out_sizes[1] = weight.size(0);

  // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available.
  ET_CHECK(resize_tensor(out, {out_sizes, out_rank}) == Error::Ok);

  constexpr auto name = "quantized_decomposed::mixed_linear.out";

  ET_SWITCH_TWO_TYPES(Float, Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    ET_SWITCH_FLOAT_TYPES_AND(Half, out_dtype, ctx, name, CTYPE_OUT, [&]() {
      const size_t m = in.size(0); // rows of in and out
      const size_t n = in.size(1); // shared inner dimension
      const size_t p = weight.size(0); // rows of weight, columns of out
      // Group size: with 2-D scales each scale covers ceil(n / num_scales)
      // consecutive inner elements; with 1-D scales one scale spans the row.
      const size_t g = (weight_scales.dim() == 2)
          ? (n + weight_scales.size(1) - 1) / weight_scales.size(1)
          : n;

      // FIXME: this currently ignores dtype
      vec_quantized_matmul_transb_int8<
          CTYPE_OUT, // T *z
          CTYPE>( // U *x, U *s
          out.mutable_data_ptr<CTYPE_OUT>(),
          in.const_data_ptr<CTYPE>(),
          weight.const_data_ptr<int8_t>(),
          weight_scales.const_data_ptr<CTYPE>(),
          m,
          n,
          p,
          g);
    });
  });

  return out;
}
114
quantized_mixed_linear_out(KernelRuntimeContext & ctx,const Tensor & in,const Tensor & weight,const Tensor & weight_scales,const exec_aten::optional<Tensor> & opt_weight_zero_points,const exec_aten::optional<ScalarType> dtype,Tensor & out)115 Tensor& quantized_mixed_linear_out(
116 KernelRuntimeContext& ctx,
117 const Tensor& in,
118 const Tensor& weight,
119 const Tensor& weight_scales,
120 const exec_aten::optional<Tensor>& opt_weight_zero_points,
121 const exec_aten::optional<ScalarType> dtype,
122 Tensor& out) {
123 // TODO(mcandales): Remove the need for this wrapper
124 // TODO(mkg): add support for dtype
125 (void)ctx;
126 return quantized_mixed_linear_out(
127 in, weight, weight_scales, opt_weight_zero_points, dtype, out);
128 }
129
130 } // namespace native
131 } // namespace executor
132 } // namespace torch
133