1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <utility>
19 
20 #include "google/protobuf/any.pb.h"
21 #include "absl/strings/str_format.h"
22 #include "fcp/client/federated_select.h"
23 #include "fcp/protos/plan.pb.h"
24 #include "tensorflow/core/framework/common_shape_fns.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/op_requires.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/status.h"
31 #include "tensorflow/core/platform/stringpiece.h"
32 
33 namespace fcp {
34 
35 namespace {
36 
37 REGISTER_OP("MakeSlicesSelectorExampleSelector")
38     .Input("served_at_id: string")
39     .Input("keys: int32")
40     .Output("serialized_proto: string")
41     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
42 
43 class MakeSlicesSelectorExampleSelectorOp : public tensorflow::OpKernel {
44  public:
MakeSlicesSelectorExampleSelectorOp(tensorflow::OpKernelConstruction * context)45   explicit MakeSlicesSelectorExampleSelectorOp(
46       tensorflow::OpKernelConstruction* context)
47       : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)48   void Compute(tensorflow::OpKernelContext* context) override {
49     const tensorflow::Tensor* served_at_id_tensor;
50     OP_REQUIRES_OK(context,
51                    context->input("served_at_id", &served_at_id_tensor));
52     std::string served_at_id =
53         served_at_id_tensor->scalar<tensorflow::tstring>()();
54 
55     const tensorflow::Tensor* keys_tensor;
56     OP_REQUIRES_OK(context, context->input("keys", &keys_tensor));
57     tensorflow::TTypes<int32_t>::ConstFlat keys = keys_tensor->flat<int32_t>();
58 
59     google::internal::federated::plan::SlicesSelector slices_selector;
60     slices_selector.set_served_at_id(std::move(served_at_id));
61     slices_selector.mutable_keys()->Reserve(keys.size());
62     for (size_t i = 0; i < keys.size(); i++) {
63       slices_selector.add_keys(keys(i));
64     }
65 
66     google::internal::federated::plan::ExampleSelector example_selector;
67     example_selector.mutable_criteria()->PackFrom(slices_selector);
68     example_selector.set_collection_uri(
69         fcp::client::kFederatedSelectCollectionUri);
70     // `resumption_token` not set.
71 
72     tensorflow::Tensor* output_tensor;
73     OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
74     output_tensor->scalar<tensorflow::tstring>()() =
75         example_selector.SerializeAsString();
76   }
77 };
78 
// Binds the `MakeSlicesSelectorExampleSelector` op to
// `MakeSlicesSelectorExampleSelectorOp` on CPU devices. No other device
// registration appears in this file.
REGISTER_KERNEL_BUILDER(
    Name("MakeSlicesSelectorExampleSelector").Device(tensorflow::DEVICE_CPU),
    MakeSlicesSelectorExampleSelectorOp);
82 
83 }  // namespace
84 
85 }  // namespace fcp
86