//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
#pragma once
#include <armnn/ArmNN.hpp>

#include <CpuExecutor.h>
#include <nnapi/OperandTypes.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>

#include <cstdint>
#include <fstream>
#include <iomanip>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker extern const armnn::PermutationVector g_DontPermute;
25*89c4ff92SAndroid Build Coastguard Worker
/// Exception type describing an operand whose type is unsupported.
///
/// The exception message is always "Operand type is unsupported"; the offending
/// type itself is kept in @ref m_type so that a handler can inspect it.
///
/// @tparam OperandType Type of the value identifying the operand type.
template <typename OperandType>
class UnsupportedOperand : public std::runtime_error
{
public:
    /// @param type The operand type that is not supported; a copy is stored in m_type.
    // NOTE(review): left non-explicit on purpose so existing implicit conversions keep compiling.
    UnsupportedOperand(const OperandType type)
        : std::runtime_error("Operand type is unsupported")
        , m_type(type)
    {}

    /// The operand type supplied at construction.
    OperandType m_type;
};
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker /// Swizzles tensor data in @a input according to the dimension mappings.
39*89c4ff92SAndroid Build Coastguard Worker void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensor,
40*89c4ff92SAndroid Build Coastguard Worker const void* input,
41*89c4ff92SAndroid Build Coastguard Worker void* output,
42*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& mappings);
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker /// Returns a pointer to a specific location in a pool`
45*89c4ff92SAndroid Build Coastguard Worker void* GetMemoryFromPool(DataLocation location,
46*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::RunTimePoolInfo>& memPools);
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker void* GetMemoryFromPointer(const Request::Argument& requestArg);
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo GetTensorInfoForOperand(const Operand& operand);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker std::string GetOperandSummary(const Operand& operand);
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker bool isQuantizedOperand(const OperandType& operandType);
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker std::string GetModelSummary(const Model& model);
57*89c4ff92SAndroid Build Coastguard Worker
/// Dumps @a tensor.
///
/// Only the declaration is visible in this header; no definition of the template
/// appears here.
///
/// @tparam TensorType Type of the tensor being dumped.
/// @param dumpDir     Dump directory.
/// @param requestName Name of the request the tensor belongs to.
/// @param tensorName  Name of the tensor.
/// @param tensor      Tensor to dump.
/// NOTE(review): presumably the three strings together determine the output file
/// location — confirm against the definition.
template <typename TensorType>
void DumpTensor(const std::string& dumpDir,
                const std::string& requestName,
                const std::string& tensorName,
                const TensorType& tensor);
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker void DumpJsonProfilingIfRequired(bool gpuProfilingEnabled,
65*89c4ff92SAndroid Build Coastguard Worker const std::string& dumpDir,
66*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId networkId,
67*89c4ff92SAndroid Build Coastguard Worker const armnn::IProfiler* profiler);
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker std::string ExportNetworkGraphToDotFile(const armnn::IOptimizedNetwork& optimizedNetwork,
70*89c4ff92SAndroid Build Coastguard Worker const std::string& dumpDir);
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker std::string SerializeNetwork(const armnn::INetwork& network,
73*89c4ff92SAndroid Build Coastguard Worker const std::string& dumpDir,
74*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t>& dataCacheData,
75*89c4ff92SAndroid Build Coastguard Worker bool dataCachingActive = true);
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker void RenameExportedFiles(const std::string& existingSerializedFileName,
78*89c4ff92SAndroid Build Coastguard Worker const std::string& existingDotFileName,
79*89c4ff92SAndroid Build Coastguard Worker const std::string& dumpDir,
80*89c4ff92SAndroid Build Coastguard Worker const armnn::NetworkId networkId);
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker void RenameFile(const std::string& existingName,
83*89c4ff92SAndroid Build Coastguard Worker const std::string& extension,
84*89c4ff92SAndroid Build Coastguard Worker const std::string& dumpDir,
85*89c4ff92SAndroid Build Coastguard Worker const armnn::NetworkId networkId);
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker /// Checks if a tensor info represents a dynamic tensor
88*89c4ff92SAndroid Build Coastguard Worker bool IsDynamicTensor(const armnn::TensorInfo& outputInfo);
89*89c4ff92SAndroid Build Coastguard Worker
/// Checks for ArmNN support of dynamic tensors.
///
/// @return true if dynamic tensors are supported.
bool AreDynamicTensorsSupported();
92*89c4ff92SAndroid Build Coastguard Worker
/// Returns a timestamp as a string.
/// NOTE(review): judging by the name it is presumably meant for building file names;
/// the format is defined by the implementation — confirm against the definition.
std::string GetFileTimestamp();
94*89c4ff92SAndroid Build Coastguard Worker
ComputeShape(const armnn::TensorInfo & info)95*89c4ff92SAndroid Build Coastguard Worker inline OutputShape ComputeShape(const armnn::TensorInfo& info)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker OutputShape shape;
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape tensorShape = info.GetShape();
100*89c4ff92SAndroid Build Coastguard Worker // Android will expect scalars as a zero dimensional tensor
101*89c4ff92SAndroid Build Coastguard Worker if(tensorShape.GetDimensionality() == armnn::Dimensionality::Scalar)
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker shape.dimensions = std::vector<uint32_t>{};
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker else
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> dimensions;
108*89c4ff92SAndroid Build Coastguard Worker const unsigned int numDims = tensorShape.GetNumDimensions();
109*89c4ff92SAndroid Build Coastguard Worker dimensions.resize(numDims);
110*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker dimensions[outputIdx] = tensorShape[outputIdx];
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker shape.dimensions = dimensions;
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker shape.isSufficient = true;
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker return shape;
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools);
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
125