1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <string> 18 #include <utility> 19 20 #include "google/protobuf/any.pb.h" 21 #include "absl/strings/str_format.h" 22 #include "fcp/client/federated_select.h" 23 #include "fcp/protos/plan.pb.h" 24 #include "tensorflow/core/framework/common_shape_fns.h" 25 #include "tensorflow/core/framework/op.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/op_requires.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/platform/errors.h" 30 #include "tensorflow/core/platform/status.h" 31 #include "tensorflow/core/platform/stringpiece.h" 32 33 namespace fcp { 34 35 namespace { 36 37 REGISTER_OP("MakeSlicesSelectorExampleSelector") 38 .Input("served_at_id: string") 39 .Input("keys: int32") 40 .Output("serialized_proto: string") 41 .SetShapeFn(tensorflow::shape_inference::ScalarShape); 42 43 class MakeSlicesSelectorExampleSelectorOp : public tensorflow::OpKernel { 44 public: MakeSlicesSelectorExampleSelectorOp(tensorflow::OpKernelConstruction * context)45 explicit MakeSlicesSelectorExampleSelectorOp( 46 tensorflow::OpKernelConstruction* context) 47 : OpKernel(context) {} Compute(tensorflow::OpKernelContext * context)48 void Compute(tensorflow::OpKernelContext* context) override { 49 const tensorflow::Tensor* served_at_id_tensor; 50 OP_REQUIRES_OK(context, 51 context->input("served_at_id", &served_at_id_tensor)); 52 std::string served_at_id = 53 served_at_id_tensor->scalar<tensorflow::tstring>()(); 54 55 const tensorflow::Tensor* keys_tensor; 56 OP_REQUIRES_OK(context, context->input("keys", &keys_tensor)); 57 tensorflow::TTypes<int32_t>::ConstFlat keys = keys_tensor->flat<int32_t>(); 58 59 google::internal::federated::plan::SlicesSelector slices_selector; 60 slices_selector.set_served_at_id(std::move(served_at_id)); 61 slices_selector.mutable_keys()->Reserve(keys.size()); 62 for (size_t i = 0; i < keys.size(); i++) { 63 slices_selector.add_keys(keys(i)); 64 } 65 66 google::internal::federated::plan::ExampleSelector example_selector; 67 example_selector.mutable_criteria()->PackFrom(slices_selector); 68 example_selector.set_collection_uri( 69 fcp::client::kFederatedSelectCollectionUri); 70 // `resumption_token` not set. 71 72 tensorflow::Tensor* output_tensor; 73 OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor)); 74 output_tensor->scalar<tensorflow::tstring>()() = 75 example_selector.SerializeAsString(); 76 } 77 }; 78 79 REGISTER_KERNEL_BUILDER( 80 Name("MakeSlicesSelectorExampleSelector").Device(tensorflow::DEVICE_CPU), 81 MakeSlicesSelectorExampleSelectorOp); 82 83 } // namespace 84 85 } // namespace fcp 86