xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/runtime_matmul_acl.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_ACL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_ACL_H_
18 
#include <cstdlib>
#include <iostream>

#include "tensorflow/core/platform/types.h"
22 
23 #ifdef XLA_CPU_USE_ACL
24 #include "arm_compute/runtime/NEON/NEFunctions.h"
25 #include "arm_compute/runtime/NEON/NEScheduler.h"
26 #include "utils/Utils.h"
27 
28 extern "C" {
29 struct acl_matmul_obj_t {
30   arm_compute::NEGEMM gemm;
31   arm_compute::NETranspose trans_lhs;
32   arm_compute::NETranspose trans_rhs;
33   arm_compute::Tensor rhs_tensor;
34   arm_compute::Tensor rhs_acc_tensor;
35   arm_compute::Tensor lhs_tensor;
36   arm_compute::Tensor lhs_acc_tensor;
37   arm_compute::Tensor out_tensor;
38 };
39 
40 struct acl_matmul_conf_t {
41   bool with_bias;
42   bool is_trans_lhs;
43   bool is_trans_rhs;
44   arm_compute::TensorInfo lhs_info;
45   arm_compute::TensorInfo lhs_acc_info;
46   arm_compute::TensorInfo rhs_info;
47   arm_compute::TensorInfo rhs_acc_info;
48   arm_compute::TensorInfo out_info;
49   arm_compute::GEMMInfo gemm_info;
50   float alpha;
51 };
52 
// ACL-backed runtime entry point for a float32 matmul. Declared here with C
// linkage (via the enclosing extern "C" block) and defined elsewhere.
//
//   run_options_ptr  opaque pointer; actually an xla::ExecutableRunOptions*.
//   out, lhs, rhs    float buffers for the result and the two operands.
//   m, n, k          matmul dimensions -- presumably lhs is m x k and rhs is
//                    k x n, giving an m x n result; TODO confirm against the
//                    implementation.
//   transpose_lhs,   integer flags -- presumably non-zero means the
//   transpose_rhs    corresponding operand is stored transposed; TODO confirm.
void __xla_cpu_runtime_ACLMatMulF32(
    const void* run_options_ptr /* xla::ExecutableRunOptions* */,
    float* out, float* lhs, float* rhs,
    int64_t m, int64_t n, int64_t k,
    int32_t transpose_lhs, int32_t transpose_rhs);
57 
// ACL-backed runtime entry point for a batched float32 matmul. Declared here
// with C linkage (via the enclosing extern "C" block) and defined elsewhere.
//
//   run_options_ptr  opaque pointer; actually an xla::ExecutableRunOptions*.
//   out, lhs, rhs    float buffers for the result and the two operands.
//   m, n, k          per-matmul dimensions -- presumably lhs is m x k and rhs
//                    is k x n for each batch element; TODO confirm against the
//                    implementation.
//   batch_size       number of matmuls in the batch (presumably stored
//                    back-to-back in each buffer; TODO confirm).
//   transpose_lhs,   integer flags -- presumably non-zero means the
//   transpose_rhs    corresponding operand is stored transposed; TODO confirm.
void __xla_cpu_runtime_ACLBatchMatMulF32(
    const void* run_options_ptr /* xla::ExecutableRunOptions* */,
    float* out, float* lhs, float* rhs,
    int64_t m, int64_t n, int64_t k, int64_t batch_size,
    int32_t transpose_lhs, int32_t transpose_rhs);
62 
63 }  // extern "C"
64 #else
65 extern "C" {
// Fallback stub used when XLA_CPU_USE_ACL is not defined. It writes a
// diagnostic to std::cerr explaining that the binary was built without ACL
// support and then terminates the process with exit status 1. It never
// returns, and none of the arguments are read.
//
// NOTE(review): this is a non-inline function *definition* in a header, so
// each translation unit that includes the header emits its own copy. Confirm
// that the header's includers cannot end up in the same link with duplicate
// symbols before relying on it more widely.
void __xla_cpu_runtime_ACLMatMulF32(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
    float* lhs, float* rhs, int64_t m, int64_t n, int64_t k,
    int32_t transpose_lhs, int32_t transpose_rhs) {
  // The trailing newline keeps the message from running into whatever is
  // printed to the terminal next.
  std::cerr
      << "Attempt to call ACL MatMul runtime library without defining "
         "XLA_CPU_USE_ACL. Add --define=build_with_acl=true to build with ACL."
         "\n";
  std::exit(1);
}
75 
// Fallback stub used when XLA_CPU_USE_ACL is not defined. It writes a
// diagnostic to std::cerr explaining that the binary was built without ACL
// support and then terminates the process with exit status 1. It never
// returns, and none of the arguments are read.
//
// NOTE(review): this is a non-inline function *definition* in a header, so
// each translation unit that includes the header emits its own copy. Confirm
// that the header's includers cannot end up in the same link with duplicate
// symbols before relying on it more widely.
void __xla_cpu_runtime_ACLBatchMatMulF32(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
    float* lhs, float* rhs, int64_t m, int64_t n, int64_t k, int64_t batch_size,
    int32_t transpose_lhs, int32_t transpose_rhs) {
  // The trailing newline keeps the message from running into whatever is
  // printed to the terminal next.
  std::cerr
      << "Attempt to call ACL MatMul runtime library without defining "
         "XLA_CPU_USE_ACL. Add --define=build_with_acl=true to build with ACL."
         "\n";
  std::exit(1);
}
85 }  // extern "C"
86 #endif  // XLA_CPU_USE_ACL
87 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_ACL_H_
88