xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tensor_name_op.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <utility>
18 
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/op_requires.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/platform/stringpiece.h"
27 
28 namespace fcp {
29 
30 namespace {
31 
32 REGISTER_OP("TensorName")
33     .Attr("InputType: type")
34     .Input("input_tensor: InputType")
35     .Output("tensor_name: string")
36     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
37 
38 class TensorNameOp : public tensorflow::OpKernel {
39  public:
TensorNameOp(tensorflow::OpKernelConstruction * context)40   explicit TensorNameOp(tensorflow::OpKernelConstruction* context)
41       : OpKernel(context) {
42     const tensorflow::NodeDef& def = context->def();
43     // Note: more than one input is allowed since the "true" input node may be
44     // followed by any number of control inputs.
45     OP_REQUIRES(
46         context, def.input_size() >= 1,
47         tensorflow::errors::InvalidArgument("Expected an input, found none."));
48     input_name_ = def.input(0);
49   }
50 
Compute(tensorflow::OpKernelContext * context)51   void Compute(tensorflow::OpKernelContext* context) override {
52     tensorflow::Tensor* output_tensor;
53     OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
54     output_tensor->scalar<tensorflow::tstring>()() = input_name_;
55   }
56 
57  private:
58   tensorflow::tstring input_name_;
59 };
60 
// Registers TensorNameOp as the CPU implementation of the "TensorName" op.
// This is the only device registration for the op in this file.
REGISTER_KERNEL_BUILDER(Name("TensorName").Device(tensorflow::DEVICE_CPU),
                        TensorNameOp);
63 
64 }  // namespace
65 
66 }  // namespace fcp
67