1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_BATCH_SEQ_BUILDER_H_ 16 #define TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_BATCH_SEQ_BUILDER_H_ 17 18 #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" 19 20 namespace tflite { 21 namespace delegates { 22 namespace hexagon { 23 24 class BatchSeqBuilder : public OpBuilder { 25 public: BatchSeqBuilder(GraphBuilder * graph_builder,int op_type)26 explicit BatchSeqBuilder(GraphBuilder* graph_builder, int op_type) 27 : OpBuilder(graph_builder, op_type) {} 28 29 TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs, 30 const TfLiteIntArray* outputs, 31 TfLiteContext* context) override; 32 RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)33 TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, 34 TfLiteContext* context) override { 35 // BatchSeqConfig doesn't have any outputs. 36 return kTfLiteOk; 37 } 38 SetMaxSizeForBatch(int max_size_for_batch)39 void SetMaxSizeForBatch(int max_size_for_batch) { 40 max_size_for_batch_ = max_size_for_batch; 41 } 42 SetInputBatchDimensions(TfLiteIntArray * input_batch_dimensions)43 void SetInputBatchDimensions(TfLiteIntArray* input_batch_dimensions) { 44 input_batch_dims_ = input_batch_dimensions; 45 } 46 SetOutputBatchDimensions(TfLiteIntArray * output_batch_dimensions)47 void SetOutputBatchDimensions(TfLiteIntArray* output_batch_dimensions) { 48 output_batch_dims_ = output_batch_dimensions; 49 } 50 51 private: 52 // Maximum size for the batch dimension in a single run. 53 // The graph can have input with larger batch, internally 54 // multiple runs will happen each won't have more than 'max_size_for_batch_' 55 // in batch dimension. 56 int max_size_for_batch_ = 1; 57 // Input dimension for each input in the graph. 58 // Input with fixed batch should have -1. 59 TfLiteIntArray* input_batch_dims_; 60 // Output dimension for each output in the graph. 61 // Output with fixed batch should have -1. 62 TfLiteIntArray* output_batch_dims_; 63 }; 64 65 } // namespace hexagon 66 } // namespace delegates 67 } // namespace tflite 68 69 #endif // TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_BATCH_SEQ_BUILDER_H_ 70