/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;

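// addmm.out: out = beta * in + alpha * (mat1 @ mat2), where mat1 is (m, n)
// and mat2 is (n, p). `in` must be broadcastable to the (m, p) result.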
Tensor& addmm_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_addmm_args(in, mat1, mat2, beta, alpha, out),
      InvalidArgument,
      out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_mm_out_target_size(mat1, mat2, output_sizes, &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

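  // The bias `in` participates only in the elementwise addition, so it only
  // needs to be broadcastable to the matmul result, not match it exactly.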
  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx,
      tensors_have_same_dim_order(in, mat1, mat2, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "addmm.out";

  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
    CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
    CTYPE beta_val = utils::scalar_to<CTYPE>(beta);
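    // mat1 is (m, n) and mat2 is (n, p); the matmul result is (m, p).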
    size_t m = mat1.size(0);
    size_t n = mat1.size(1);
    size_t p = mat2.size(1);

    if (out.sizes() == in.sizes()) {
      // vec_addmm assumes that no broadcasting is required.
      vec_addmm<CTYPE, CTYPE>(
          out.mutable_data_ptr<CTYPE>(),
          in.const_data_ptr<CTYPE>(),
          mat1.const_data_ptr<CTYPE>(),
          mat2.const_data_ptr<CTYPE>(),
          m,
          n,
          p,
          beta_val,
          alpha_val);
    } else {
      // If broadcasting is required, then compute the matmul and the
      // addition separately, using apply_bitensor_elementwise_fn to
      // perform the addition while applying broadcasting.
      vec_matmul<CTYPE, CTYPE>(
          out.mutable_data_ptr<CTYPE>(),
          mat1.const_data_ptr<CTYPE>(),
          mat2.const_data_ptr<CTYPE>(),
          m,
          n,
          p);

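      // val_a is the matmul result already stored in `out`; val_b is the
      // bias `in`, broadcast as needed.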
      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
          [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
            return val_a * alpha_val + val_b * beta_val;
          },
          ctx,
          out,
          utils::SupportedTensorDtypes::REALHBF16,
          in,
          utils::SupportedTensorDtypes::REALHBF16,
          out,
          utils::SupportedTensorDtypes::REALHBF16);
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch