xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_addmm.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;

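// addmm.out: computes out = beta * in + alpha * (mat1 @ mat2), broadcasting
// `in` to the output shape when its sizes differ from those of `out`.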
Tensor& addmm_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_addmm_args(in, mat1, mat2, beta, alpha, out),
      InvalidArgument,
      out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_mm_out_target_size(mat1, mat2, output_sizes, &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);
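  // `out` now has the matmul result shape, (mat1.size(0), mat2.size(1)).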

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx,
      tensors_have_same_dim_order(in, mat1, mat2, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "addmm.out";

  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
    CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
    CTYPE beta_val = utils::scalar_to<CTYPE>(beta);
    size_t m = mat1.size(0);
    size_t n = mat1.size(1);
    size_t p = mat2.size(1);

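    // mat1 is (m, n) and mat2 is (n, p); check_addmm_args above has verified
    // that the inner dimensions agree, so the matmul result is (m, p).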
    if (out.sizes() == in.sizes()) {
      // vec_addmm assumes that no broadcasting is required.
      vec_addmm<CTYPE, CTYPE>(
          out.mutable_data_ptr<CTYPE>(),
          in.const_data_ptr<CTYPE>(),
          mat1.const_data_ptr<CTYPE>(),
          mat2.const_data_ptr<CTYPE>(),
          m,
          n,
          p,
          beta_val,
          alpha_val);
    } else {
      // If broadcasting is required, then compute the matmul and the
      // addition separately, using apply_bitensor_elementwise_fn to
      // perform the addition while applying broadcasting.
      vec_matmul<CTYPE, CTYPE>(
          out.mutable_data_ptr<CTYPE>(),
          mat1.const_data_ptr<CTYPE>(),
          mat2.const_data_ptr<CTYPE>(),
          m,
          n,
          p);

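      // `out` now holds mat1 @ mat2; fold in `in` (broadcast to the output
      // shape) below as alpha * (mat1 @ mat2) + beta * in.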
      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
          [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
            return val_a * alpha_val + val_b * beta_val;
          },
          ctx,
          out,
          utils::SupportedTensorDtypes::REALHBF16,
          in,
          utils::SupportedTensorDtypes::REALHBF16,
          out,
          utils::SupportedTensorDtypes::REALHBF16);
    }
  });

  return out;
}
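
// For reference, a naive sketch of the computation performed above, assuming
// row-major contiguous data and `in` already broadcast to (m, p) (the name
// `in_b` is hypothetical, used only for illustration):
//
//   for (size_t i = 0; i < m; ++i) {
//     for (size_t j = 0; j < p; ++j) {
//       CTYPE acc = 0;
//       for (size_t k = 0; k < n; ++k) {
//         acc += mat1[i * n + k] * mat2[k * p + j];
//       }
//       out[i * p + j] = alpha * acc + beta * in_b[i * p + j];
//     }
//   }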

} // namespace native
} // namespace executor
} // namespace torch