xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/delegate.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_
16 #define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_
17 
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/delegates/flex/delegate_data.h"
20 #include "tensorflow/lite/delegates/utils/simple_delegate.h"
21 
22 namespace tflite {
23 
// Forward declaration only: FlexDelegate names this test fixture as a friend,
// and declaring it here avoids pulling test headers into this file.
namespace flex::testing {
class KernelTest;
}  // namespace flex::testing
29 
30 // WARNING: This is an experimental interface that is subject to change.
31 // Delegate that can be used to extract parts of a graph that are designed to be
32 // executed by TensorFlow's runtime via Eager.
33 //
34 // The interpreter must be constructed after the FlexDelegate and destructed
35 // before the FlexDelegate. This delegate may be used with multiple
36 // interpreters, but it is *not* thread-safe.
37 //
38 // Usage:
39 //   auto delegate = FlexDelegate::Create();
40 //   ... build interpreter ...
41 //
42 //   if (delegate) {
43 //     interpreter->ModifyGraphWithDelegate(delegate.get());
44 //   }
45 //
46 //   void* delegate_data = delegate->data_;
47 //   interpreter->SetCancellationFunction(
48 //     delegate_data,
49 //     FlexDelegate::HasCancelled);
50 //
51 //   ... run inference ...
52 //
53 //    static_cast<FlexDelegate*>(delegate_data)->Cancel();
54 //
55 //   ... destroy interpreter ...
56 //   ... destroy delegate ...
57 class FlexDelegate : public SimpleDelegateInterface {
58  public:
59   friend class flex::testing::KernelTest;
60 
61   // Creates a delegate that supports TF ops.
Create()62   static TfLiteDelegateUniquePtr Create() {
63     return Create(/*base_delegate*/ nullptr);
64   }
65 
~FlexDelegate()66   ~FlexDelegate() override {}
67 
mutable_data()68   flex::DelegateData* mutable_data() { return &delegate_data_; }
69 
70   // This method is thread safe. It does two things:
71   //   1. Calls the CancellationManager of the TF eager runtime to support
72   //      intra-op cancellation in TF.
73   //   2. Uses the CancellationManager to signal TFLite interpreter for inter-op
74   //      cancellation.
75   // Training is non-recoverable after calling this API.
76   void Cancel();
77 
78   // The param `data` must be a pointer to a FlexDelegate instance.
79   static bool HasCancelled(void* data);
80 
81  protected:
82   // We sometimes have to create certain stub data to test FlexDelegate. To
83   // achieve this, we will make a testing flex delegate class that inherits from
84   // FlexDelegate to override certain things for stub data creation. Therefore,
85   // this function accepts a FlexDelegate instance to initiliaze it properly for
86   // create a testing flex delegate in some cases, and it is only used in
87   // testing.
88   static TfLiteDelegateUniquePtr Create(
89       std::unique_ptr<FlexDelegate> base_delegate);
90 
FlexDelegate()91   FlexDelegate() {}
92 
93   const char* Name() const override;
94 
95   bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
96                                  const TfLiteNode* node,
97                                  TfLiteContext* context) const override;
98 
99   TfLiteStatus Initialize(TfLiteContext* context) override;
100 
DelegateOptions()101   SimpleDelegateInterface::Options DelegateOptions() const override {
102     // Use default options.
103     return SimpleDelegateInterface::Options();
104   }
105 
106   std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
107       override;
108 
109   TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
110                                     TfLiteBufferHandle buffer_handle,
111                                     TfLiteTensor* output);
112 
113   flex::DelegateData delegate_data_;
114 
115   // Pointer to the base TfLiteDelegate which is created from the Create call.
116   TfLiteDelegate* base_delegate_ = nullptr;
117 
118  private:
119   // A cancellation manager.
120   std::unique_ptr<tensorflow::CancellationManager> cancellation_manager_;
121 };
122 
123 }  // namespace tflite
124 
125 #endif  // TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_
126