xref: /aosp_15_r20/external/tensorflow/tensorflow/c/eager/c_api_unified_experimental_graph.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/abstract_context.h"
22 #include "tensorflow/c/eager/c_api_internal.h"
23 #include "tensorflow/c/eager/c_api_unified_experimental.h"
24 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
25 #include "tensorflow/c/eager/graph_function.h"
26 #include "tensorflow/c/tf_datatype.h"
27 #include "tensorflow/c/tf_status.h"
28 #include "tensorflow/c/tf_status_helper.h"
29 #include "tensorflow/core/framework/shape_inference.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/types.pb.h"
32 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/strcat.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 using tensorflow::dyn_cast;
38 using tensorflow::string;
39 using tensorflow::gtl::ArraySlice;
40 
41 namespace tensorflow {
42 namespace tracing {
43 namespace graph {
44 
45 class GraphContext;
46 class GraphOperation;
47 class GraphTensor;
48 
49 auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
50 auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
51 
52 // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
53 // into the list of outputs for the operation.
54 class GraphTensor : public TracingTensorHandle {
55  public:
GraphTensor(TF_Output output,TF_Graph * graph)56   explicit GraphTensor(TF_Output output, TF_Graph* graph)
57       : TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
58 
DataType() const59   tensorflow::DataType DataType() const override {
60     return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
61   }
62 
Shape(tensorflow::PartialTensorShape * shape) const63   tensorflow::Status Shape(
64       tensorflow::PartialTensorShape* shape) const override {
65     DCHECK(shape != nullptr);
66     TF_Status status;
67     int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
68     DCHECK_GE(num_dims, -1);
69     TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
70     if (num_dims == kUnknownRank) {
71       return OkStatus();
72     }
73 
74     std::vector<int64_t> dims(num_dims, kUnknownDim);
75     TF_GraphGetTensorShape(graph_, output_,
76                            reinterpret_cast<int64_t*>(dims.data()), num_dims,
77                            &status);
78     TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
79     TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
80 
81     return OkStatus();
82   }
83 
84   TF_Output output_;
85 
86   // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)87   static bool classof(const AbstractTensorHandle* ptr) {
88     return ptr->getKind() == kGraph;
89   }
90 
91  private:
92   TF_Graph* graph_;  // For shape inference.
93 };
94 
95 // GraphOperation wraps and populates a TF_OperationDescription.
96 class GraphOperation : public TracingOperation {
97  public:
GraphOperation(TF_Graph * g)98   explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
Release()99   void Release() override { delete this; }
Reset(const char * op,const char * raw_device_name)100   Status Reset(const char* op, const char* raw_device_name) override {
101     if (op_) {
102       return errors::FailedPrecondition("Reset called on already built op.");
103     }
104     if (raw_device_name) {
105       device_name_ = raw_device_name;
106     }
107     op_type_ = op;
108     return OkStatus();
109   }
SetOpName(const char * const op_name)110   Status SetOpName(const char* const op_name) override {
111     if (op_) {
112       return errors::FailedPrecondition(
113           "SetOpName called on already built op.");
114     }
115     if (op_type_.empty()) {
116       return errors::FailedPrecondition(
117           "GraphOperation::Reset must be called before calling SetOpName.");
118     }
119     // TODO(b/145674566): We use Graph::NewName to get a unique name here but
120     // this may not be consistent with python's naming policy.
121     mutex_lock l(g_->mu);
122     op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
123                                           g_->graph.NewName(op_name).c_str()));
124     return OkStatus();
125   }
Name() const126   const string& Name() const override { return op_type_; }
DeviceName() const127   const string& DeviceName() const override { return device_name_; }
128 
SetDeviceName(const char * name)129   Status SetDeviceName(const char* name) override {
130     // TODO(srbs): Implement this.
131     device_name_ = name;
132     return OkStatus();
133   }
134 
AddInput(AbstractTensorHandle * input)135   Status AddInput(AbstractTensorHandle* input) override {
136     GraphTensor* t = dyn_cast<GraphTensor>(input);
137     if (!t) {
138       return tensorflow::errors::InvalidArgument(
139           "Unable to cast input to GraphTensor");
140     }
141     TF_AddInput(op_.get(), t->output_);
142     return OkStatus();
143   }
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)144   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
145     std::vector<TF_Output> tf_outputs(inputs.size());
146     for (int i = 0; i < inputs.size(); i++) {
147       GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
148       if (!t) {
149         return tensorflow::errors::InvalidArgument(
150             "Unable to cast input to GraphTensor");
151       }
152       tf_outputs[i] = t->output_;
153     }
154     TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
155     return OkStatus();
156   }
Execute(absl::Span<AbstractTensorHandle * > retvals,int * num_retvals)157   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
158                  int* num_retvals) override {
159     auto* tf_opdesc = op_.release();
160     if (tf_opdesc == nullptr) {
161       return errors::InvalidArgument("AbstractOp is incomplete.");
162     }
163     TF_Status* s = TF_NewStatus();
164     auto* operation = TF_FinishOperation(tf_opdesc, s);
165     TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
166     TF_DeleteStatus(s);
167     *num_retvals = TF_OperationNumOutputs(operation);
168     for (int i = 0; i < *num_retvals; ++i) {
169       retvals[i] = new GraphTensor({operation, i}, g_);
170     }
171     return OkStatus();
172   }
173 
SetAttrString(const char * attr_name,const char * data,size_t length)174   Status SetAttrString(const char* attr_name, const char* data,
175                        size_t length) override {
176     tensorflow::StringPiece s(data, length);
177     op_->node_builder.Attr(attr_name, s);
178     return OkStatus();
179   }
SetAttrInt(const char * attr_name,int64_t value)180   Status SetAttrInt(const char* attr_name, int64_t value) override {
181     op_->node_builder.Attr(attr_name, static_cast<int64_t>(value));
182     return OkStatus();
183   }
SetAttrFloat(const char * attr_name,float value)184   Status SetAttrFloat(const char* attr_name, float value) override {
185     op_->node_builder.Attr(attr_name, value);
186     return OkStatus();
187   }
SetAttrBool(const char * attr_name,bool value)188   Status SetAttrBool(const char* attr_name, bool value) override {
189     op_->node_builder.Attr(attr_name, value);
190     return OkStatus();
191   }
SetAttrType(const char * const attr_name,DataType value)192   Status SetAttrType(const char* const attr_name, DataType value) override {
193     if (!op_) {
194       return Status(
195           error::Code::FAILED_PRECONDITION,
196           "op_type and op_name must be specified before specifying attrs.");
197     }
198     op_->node_builder.Attr(attr_name, value);
199     return OkStatus();
200   }
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)201   Status SetAttrShape(const char* attr_name, const int64_t* dims,
202                       const int num_dims) override {
203     PartialTensorShape shape;
204     if (num_dims >= 0) {
205       shape = PartialTensorShape(ArraySlice<int64_t>(
206           reinterpret_cast<const int64_t*>(dims), num_dims));
207     }
208     op_->node_builder.Attr(attr_name, shape);
209     return OkStatus();
210   }
SetAttrFunction(const char * attr_name,const AbstractOperation * value)211   Status SetAttrFunction(const char* attr_name,
212                          const AbstractOperation* value) override {
213     return tensorflow::errors::Unimplemented(
214         "SetAttrFunction has not been implemented yet.");
215   }
SetAttrFunctionName(const char * attr_name,const char * value,size_t length)216   Status SetAttrFunctionName(const char* attr_name, const char* value,
217                              size_t length) override {
218     tensorflow::NameAttrList func_name;
219     func_name.set_name(string(value, value + length));
220     op_->node_builder.Attr(attr_name, func_name);
221     return OkStatus();
222   }
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)223   Status SetAttrTensor(const char* attr_name,
224                        AbstractTensorInterface* tensor) override {
225     return tensorflow::errors::Unimplemented(
226         "SetAttrTensor has not been implemented yet.");
227   }
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)228   Status SetAttrStringList(const char* attr_name, const void* const* values,
229                            const size_t* lengths, int num_values) override {
230     if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
231       op_->colocation_constraints.clear();
232       for (int i = 0; i < num_values; ++i) {
233         op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
234                                             lengths[i]);
235       }
236     } else {
237       std::vector<tensorflow::StringPiece> v;
238       v.reserve(num_values);
239       for (int i = 0; i < num_values; ++i) {
240         v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
241       }
242       op_->node_builder.Attr(attr_name, v);
243     }
244     return OkStatus();
245   }
SetAttrFloatList(const char * attr_name,const float * values,int num_values)246   Status SetAttrFloatList(const char* attr_name, const float* values,
247                           int num_values) override {
248     op_->node_builder.Attr(attr_name,
249                            ArraySlice<const float>(values, num_values));
250     return OkStatus();
251   }
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)252   Status SetAttrIntList(const char* attr_name, const int64_t* values,
253                         int num_values) override {
254     op_->node_builder.Attr(
255         attr_name, ArraySlice<const int64_t>(
256                        reinterpret_cast<const int64_t*>(values), num_values));
257     return OkStatus();
258   }
SetAttrTypeList(const char * attr_name,const DataType * values,int num_values)259   Status SetAttrTypeList(const char* attr_name, const DataType* values,
260                          int num_values) override {
261     op_->node_builder.Attr(attr_name,
262                            ArraySlice<const DataType>(values, num_values));
263     return OkStatus();
264   }
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)265   Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
266                          int num_values) override {
267     std::unique_ptr<bool[]> b(new bool[num_values]);
268     for (int i = 0; i < num_values; ++i) {
269       b[i] = values[i];
270     }
271     op_->node_builder.Attr(attr_name,
272                            ArraySlice<const bool>(b.get(), num_values));
273 
274     return OkStatus();
275   }
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)276   Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
277                           const int* num_dims, int num_values) override {
278     std::vector<PartialTensorShape> shapes;
279     shapes.reserve(num_values);
280     for (int i = 0; i < num_values; ++i) {
281       if (num_dims[i] < 0) {
282         shapes.emplace_back();
283       } else {
284         shapes.emplace_back(ArraySlice<int64_t>(
285             reinterpret_cast<const int64_t*>(dims[i]), num_dims[i]));
286       }
287     }
288     op_->node_builder.Attr(attr_name, shapes);
289     return OkStatus();
290   }
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)291   Status SetAttrFunctionList(
292       const char* attr_name,
293       absl::Span<const AbstractOperation*> values) override {
294     return tensorflow::errors::Unimplemented(
295         "SetAttrFunctionList has not been implemented yet.");
296   }
297   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)298   static bool classof(const AbstractOperation* ptr) {
299     return ptr->getKind() == kGraph;
300   }
~GraphOperation()301   ~GraphOperation() override {}
302 
303  private:
304   friend class GraphContext;  // For access to op_.
305   TF_Graph* g_;
306   std::unique_ptr<TF_OperationDescription> op_;
307   // Hold `op_type` and `op_name` till both are available since we need both
308   // to build a graph operation.
309   string op_type_;
310   const char* op_name_ = nullptr;
311   // TODO(srbs): Use this.
312   string device_name_;
313 };
314 
315 // GraphContext wraps a TF_Graph modeling a single function and manages the
316 // "execution" of operation, i.e. adding them to the function.
317 class GraphContext : public TracingContext {
318  public:
GraphContext(const char * name)319   explicit GraphContext(const char* name)
320       : TracingContext(kGraph),
321         graph_(new TF_Graph(), TF_DeleteGraph),
322         name_(name) {}
323 
Release()324   void Release() override { delete this; }
325 
CreateOperation()326   TracingOperation* CreateOperation() override {
327     return new GraphOperation(graph_.get());
328   }
329 
AddParameter(DataType dtype,const PartialTensorShape & shape,TracingTensorHandle ** output)330   Status AddParameter(DataType dtype, const PartialTensorShape& shape,
331                       TracingTensorHandle** output) override {
332     TracingOperationPtr operation(CreateOperation());
333     TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
334     TF_RETURN_IF_ERROR(
335         operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
336     TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
337     if (!shape.unknown_rank()) {
338       TF_RETURN_IF_ERROR(operation->SetAttrShape(
339           "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
340           shape.dims()));
341     }
342     int num_outputs = 1;
343     std::vector<AbstractTensorHandle*> outputs(num_outputs);
344     TF_RETURN_IF_ERROR(operation->Execute(
345         absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
346 
347     if (num_outputs != 1) {
348       return errors::Internal("Expected 1 output but found ", num_outputs);
349     }
350     auto* t = dyn_cast<GraphTensor>(outputs[0]);
351     if (!t) {
352       return tensorflow::errors::InvalidArgument(
353           "Unable to cast input to GraphTensor");
354     }
355     inputs_.push_back(t->output_);
356     *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
357     return OkStatus();
358   }
359 
Finalize(OutputList * outputs,AbstractFunction ** f)360   Status Finalize(OutputList* outputs, AbstractFunction** f) override {
361     std::vector<TF_Output> graph_outputs;
362     graph_outputs.reserve(outputs->outputs.size());
363     for (auto* abstract_output : outputs->outputs) {
364       GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
365       if (!output) {
366         return errors::Unimplemented(
367             "Returning a non-graph tensor from a function has not "
368             "been implemented yet.");
369       }
370       graph_outputs.push_back(output->output_);
371     }
372 
373     auto s = TF_NewStatus();
374     auto func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
375                                    inputs_.size(), inputs_.data(),
376                                    graph_outputs.size(), graph_outputs.data(),
377                                    nullptr, nullptr, name_.data(), s);
378     *f = new GraphFunction(std::move(func->fdef));
379     TF_DeleteFunction(func);
380     TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
381     TF_DeleteStatus(s);
382     return OkStatus();
383   }
384 
RegisterFunction(AbstractFunction * func)385   Status RegisterFunction(AbstractFunction* func) override {
386     return errors::Unimplemented(
387         "Registering graph functions has not been implemented yet.");
388   }
389 
RemoveFunction(const string & func)390   Status RemoveFunction(const string& func) override {
391     return errors::Unimplemented(
392         "GraphContext::RemoveFunction has not been implemented yet.");
393   }
394   // For LLVM style RTTI.
classof(const AbstractContext * ptr)395   static bool classof(const AbstractContext* ptr) {
396     return ptr->getKind() == kGraph;
397   }
398 
399  private:
400   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
401   std::vector<TF_Output> inputs_;
402   string name_;
403 };
404 
GraphTracingFactory(const char * name,TF_Status * s)405 static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
406   return new GraphContext(name);
407 }
408 
409 // Register the tracing implemented in this file as the default tracing engine.
__anonf59f71810102null410 static bool register_tracing = [] {
411   RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
412   SetDefaultTracingEngine("graphdef").IgnoreError();
413   return true;
414 }();
415 
416 }  // namespace graph
417 }  // namespace tracing
418 }  // namespace tensorflow
419