1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnnTestUtils/LayerTestResult.hpp> 9 10 #include <ResolveType.hpp> 11 12 #include <armnn/backends/IBackendInternal.hpp> 13 14 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>, std::size_t NumDims> 15 LayerTestResult<T, NumDims> BatchMatMulTestImpl( 16 armnn::IWorkloadFactory& workloadFactory, 17 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 18 const armnn::ITensorHandleFactory& tensorHandleFactory, 19 armnn::BatchMatMulDescriptor descriptor, 20 const std::vector<T>& inputX, 21 const std::vector<T>& inputY, 22 const std::vector<T>& outputExpected, 23 const armnn::TensorInfo& inputXInfo, 24 const armnn::TensorInfo& inputYInfo, 25 const armnn::TensorInfo& outputInfo); 26 27 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 28 LayerTestResult<T, 2> BatchMatMul2DSimpleTest( 29 armnn::IWorkloadFactory& workloadFactory, 30 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 31 const armnn::ITensorHandleFactory& tensorHandleFactory); 32 33 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 34 LayerTestResult<T, 3> BatchMatMul3DSimpleTest( 35 armnn::IWorkloadFactory& workloadFactory, 36 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 37 const armnn::ITensorHandleFactory& tensorHandleFactory); 38 39 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 40 LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest( 41 armnn::IWorkloadFactory& workloadFactory, 42 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 43 const armnn::ITensorHandleFactory& tensorHandleFactory); 44 45 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 46 LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest( 47 armnn::IWorkloadFactory& workloadFactory, 48 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 49 const armnn::ITensorHandleFactory& tensorHandleFactory); 50 51 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 52 LayerTestResult<T, 3> BatchMatMul3DBatchTest( 53 armnn::IWorkloadFactory& workloadFactory, 54 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 55 const armnn::ITensorHandleFactory& tensorHandleFactory); 56 57 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 58 LayerTestResult<T, 3> BatchMatMul3DBroadcastTest( 59 armnn::IWorkloadFactory& workloadFactory, 60 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 61 const armnn::ITensorHandleFactory& tensorHandleFactory); 62 63 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 64 LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest( 65 armnn::IWorkloadFactory& workloadFactory, 66 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 67 const armnn::ITensorHandleFactory& tensorHandleFactory); 68 69 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 70 LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest( 71 armnn::IWorkloadFactory& workloadFactory, 72 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 73 const armnn::ITensorHandleFactory& tensorHandleFactory); 74 75 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 76 LayerTestResult<T, 2> BatchMatMul2DTinyTest( 77 armnn::IWorkloadFactory& workloadFactory, 78 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 79 const armnn::ITensorHandleFactory& tensorHandleFactory); 80 81 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 82 LayerTestResult<T, 3> BatchMatMul3DNonSquareTest( 83 armnn::IWorkloadFactory& workloadFactory, 84 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 85 const armnn::ITensorHandleFactory& tensorHandleFactory); 86 87 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 88 LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest( 89 armnn::IWorkloadFactory& workloadFactory, 90 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 91 const armnn::ITensorHandleFactory& tensorHandleFactory); 92 93 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 94 LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest( 95 armnn::IWorkloadFactory& workloadFactory, 96 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 97 const armnn::ITensorHandleFactory& tensorHandleFactory); 98 99 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> 100 LayerTestResult<T, 4> BatchMatMulNHWCParamsTest( 101 armnn::IWorkloadFactory& workloadFactory, 102 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, 103 const armnn::ITensorHandleFactory& tensorHandleFactory);