//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "TosaRefLayerSupport.hpp"

#include <tosaCommon/TosaMappings.hpp>

#include <armnn/Types.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <graph_status.h>
#include <model_runner.h>

#include <cstdio>
#include <vector>

18 namespace armnn
19 {
20
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmInputParamsInfo,Optional<std::string &> reasonIfUnsupported) const21 bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
22 const std::vector<TensorInfo>& infos,
23 const BaseDescriptor& descriptor,
24 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
25 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
26 Optional<std::string&> reasonIfUnsupported) const
27 {
28 IgnoreUnused(lstmParamsInfo);
29 IgnoreUnused(quantizedLstmInputParamsInfo);
30 IgnoreUnused(reasonIfUnsupported);
31
32 std::vector<const TensorInfo*> inputInfos;
33 std::vector<const TensorInfo*> outputInfos;
34
35 switch (type)
36 {
37 case LayerType::Input:
38 case LayerType::Output:
39 return true;
40 case LayerType::Addition:
41 case LayerType::Multiplication:
42 case LayerType::Subtraction:
43 // Setup inputs and outputs
44 inputInfos.push_back(&infos[0]);
45 inputInfos.push_back(&infos[1]);
46 outputInfos.push_back(&infos[2]);
47 break;
48 case LayerType::Concat:
49 for (unsigned int i = 0; i < infos.size() - 1; ++i)
50 {
51 inputInfos.push_back(&infos[i]);
52 }
53 outputInfos.push_back(&infos.back());
54 break;
55 case LayerType::Constant:
56 outputInfos.push_back(&infos[0]);
57 break;
58 case LayerType::Convolution2d:
59 {
60 inputInfos.push_back(&infos[0]); // input
61 outputInfos.push_back(&infos[1]); // output
62 inputInfos.push_back(&infos[2]); // weights
63
64 auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
65 if(conv2dDesc->m_BiasEnabled)
66 {
67 inputInfos.push_back(&infos[3]); // bias
68 }
69 break;
70 }
71 case LayerType::ElementwiseUnary:
72 case LayerType::Pooling2d:
73 case LayerType::Reshape:
74 case LayerType::Slice:
75 case LayerType::Transpose:
76 inputInfos.push_back(&infos[0]);
77 outputInfos.push_back(&infos[1]);
78 break;
79 case LayerType::TransposeConvolution2d:
80 {
81 inputInfos.push_back(&infos[0]); // input
82 outputInfos.push_back(&infos[1]); // output
83 inputInfos.push_back(&infos[2]); // weights
84
85 auto conv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
86 if(conv2dDesc->m_BiasEnabled)
87 {
88 inputInfos.push_back(&infos[3]); // bias
89 }
90 break;
91 }
92 default:
93 // Default to false for all unsupported layers.
94 return false;
95 }
96
97 auto mappings = GetTosaMapping(nullptr, type, inputInfos, outputInfos, descriptor);
98 if (mappings->GetName() == "")
99 {
100 // There currently isn't a TOSA mapping for this layer, as the default was returned.
101 return false;
102 }
103
104 TosaSerializationHandler handler;
105
106 // Add mappings to main block as the TOSA Reference Model requires the graph to be in one block called main.
107 auto* block = new TosaSerializationBasicBlock("main",
108 mappings->GetOperators(),
109 mappings->GetTensors(),
110 mappings->GetInputs(),
111 mappings->GetOutputs());
112 handler.GetBlocks().emplace_back(block);
113
114 GraphStatus status;
115 TosaReference::IModelRunner runner;
116
117 #if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
118 // There currently isn't a way to disable the output from the TOSA Reference Model, but it does have a file pointer
119 // to write debug output to, so set this to /dev/null (if it exists on the system) to hide the output.
120 func_debug_t funcDebug;
121
122 FILE* file = fopen("/dev/null", "w");
123 funcDebug.func_debug_file = (file == nullptr) ? stderr : file;
124
125 runner.setFuncDebug(funcDebug);
126 #endif
127
128 // Initialise the model runner with the TosaSerializationHandler, which runs validation on the mapping.
129 status = runner.initialize(handler);
130
131 #if !defined(TOSA_REFERENCE_MODEL_OUTPUT)
132 // Reset FuncDebug as they can persist across multiple IModelRunner instances.
133 funcDebug.func_debug_file = stderr;
134 runner.setFuncDebug(funcDebug);
135 #endif
136
137 if(status == GraphStatus::TOSA_ERROR || status == GraphStatus::TOSA_UNPREDICTABLE)
138 {
139 return false;
140 }
141 else
142 {
143 return true;
144 }
145 }
146
147 } // namespace armnn
148