xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/serve_slices_op_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <fcntl.h>

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "fcp/tensorflow/serve_slices_registry.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session.h"
33 
34 namespace fcp {
35 namespace {
36 
37 using ::testing::_;
38 using ::testing::HasSubstr;
39 using ::testing::Return;
40 
41 using MockServeSlicesCallback = ::testing::MockFunction<std::string(
42     RandomToken, std::vector<tensorflow::Tensor>, int32_t, absl::string_view,
43     std::vector<std::string>, absl::string_view, absl::string_view,
44     absl::string_view)>;
45 
46 // Constants related to the GraphDef we test with
47 // See make_serve_slices_test_graph.py
48 
49 char const* const kExampleGraphPath =
50     "fcp/tensorflow/serve_slices_test.pbtxt";
51 char const* const kCallbackTokenPlaceholderName = "callback_token";
52 char const* const kServedAtTensorName = "served_at_id:0";
53 
54 // Loads the example graph created by `make_serve_slices_test_graph.py`.
LoadExampleGraph()55 tensorflow::GraphDef LoadExampleGraph() {
56   int fd = open(kExampleGraphPath, O_RDONLY);
57   CHECK(fd != -1) << "Failed to open example graph at path "
58                   << kExampleGraphPath;
59 
60   google::protobuf::io::FileInputStream fs(fd);
61   fs.SetCloseOnDelete(true);
62 
63   tensorflow::GraphDef graph;
64   bool parsed = google::protobuf::TextFormat::Parse(&fs, &graph);
65   CHECK(parsed) << "Invalid text-format GraphDef";
66 
67   return graph;
68 }
69 
70 // Loads the example graph created by `make_serve_slices_test_graph.py` into a
71 // `tensorflow::Session`.
PrepareExampleGraphSession()72 std::unique_ptr<tensorflow::Session> PrepareExampleGraphSession() {
73   tensorflow::GraphDef graph = LoadExampleGraph();
74 
75   std::unique_ptr<tensorflow::Session> session;
76   {
77     tensorflow::SessionOptions options;
78     tensorflow::Session* raw_session = nullptr;
79     tensorflow::Status session_new_status =
80         tensorflow::NewSession(options, &raw_session);
81     TF_CHECK_OK(session_new_status);
82     session = std::unique_ptr<tensorflow::Session>(raw_session);
83   }
84 
85   tensorflow::Status graph_build_status = session->Create(graph);
86   TF_CHECK_OK(graph_build_status);
87   return session;
88 }
89 
90 class ServeSlicesOpTest : public ::testing::Test {
91  protected:
SetUp()92   void SetUp() override { session_ = PrepareExampleGraphSession(); }
93 
94   // Runs a `ServeSlices` session and returns the result.
95   //
96   // Inputs:
97   //   callback_token: A `tensorflow::Tensor` to use as the `callback_token`
98   //   argument to `ServeSlices`. For successful calls, this must be a
99   //   `RandomToken` corresponding to the `HostObjectRegistration returned by
100   //    `register_serve_slices_callback`.
101   //   served_at_id_out: An output parameter into which the `served_at_id`
102   //     returned from `ServeSlices` is stored.
103   //
104   // Outputs:
105   //   The status of the `session.Run` invocation.
RunSession(tensorflow::Tensor callback_token,tensorflow::Tensor & served_at_id_out)106   tensorflow::Status RunSession(tensorflow::Tensor callback_token,
107                                 tensorflow::Tensor& served_at_id_out) {
108     std::vector<tensorflow::Tensor> outputs;
109     tensorflow::Status run_status =
110         session_->Run({{kCallbackTokenPlaceholderName, callback_token}},
111                       {kServedAtTensorName}, {}, &outputs);
112 
113     if (run_status.ok()) {
114       CHECK(outputs.size() == 1)
115           << "Expected one output, found " << outputs.size();
116       served_at_id_out = outputs[0];
117     }
118 
119     return run_status;
120   }
121 
122   // Runs a `ServeSlices` session and returns the result.
123   //
124   // This method is similar to `RunSession`, but it expects that the run is
125   // successful and enforces that the inputs and outputs are correctly-typed.
126   //
127   // Inputs:
128   //   callback_token: The `CallbackToken` of the callback to invoke from
129   //     `ServeSlices`.
130   //
131   // Outputs:
132   //   The `served_at_id` returned from `ServeSlices`.
RunSessionExpectingSuccess(RandomToken callback_token)133   std::string RunSessionExpectingSuccess(RandomToken callback_token) {
134     tensorflow::Tensor served_at_id_out;
135     TF_CHECK_OK(RunSession(tensorflow::Tensor(callback_token.ToString()),
136                            served_at_id_out));
137     return served_at_id_out.scalar<tensorflow::tstring>()();
138   }
139 
140  private:
141   std::unique_ptr<tensorflow::Session> session_;
142 };
143 
TEST_F(ServeSlicesOpTest,SessionRunCallsBackIntoCPP)144 TEST_F(ServeSlicesOpTest, SessionRunCallsBackIntoCPP) {
145   std::string mock_served_at_id = "mock_served_at_id";
146   MockServeSlicesCallback mock_callback;
147   HostObjectRegistration callback_registration =
148       register_serve_slices_callback(mock_callback.AsStdFunction());
149   RandomToken callback_token = callback_registration.token();
150   EXPECT_CALL(mock_callback, Call(callback_token, _, _, _, _, _, _, _))
151       .WillOnce(Return(mock_served_at_id));
152   std::string served_at_id = RunSessionExpectingSuccess(callback_token);
153   EXPECT_EQ(served_at_id, "mock_served_at_id");
154 }
155 
TEST_F(ServeSlicesOpTest,SessionRunFailsOnMissingCallback)156 TEST_F(ServeSlicesOpTest, SessionRunFailsOnMissingCallback) {
157   std::optional<RandomToken> callback_token;
158   {
159     MockServeSlicesCallback mock_callback;
160     HostObjectRegistration callback_registration =
161         register_serve_slices_callback(mock_callback.AsStdFunction());
162     callback_token = callback_registration.token();
163     // The registration gets destructed here.
164   }
165   tensorflow::Tensor callback_token_tensor(callback_token->ToString());
166   tensorflow::Tensor served_at_id_out;
167   tensorflow::Status status =
168       RunSession(callback_token_tensor, served_at_id_out);
169   // Remove the cast after TF 2.12 is released and used in FCP.
170   EXPECT_THAT(
171       status,
172       tensorflow::testing::StatusIs(
173           static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument),
174           HasSubstr("No `ServeSlices` callback found")));
175 }
176 
177 }  // namespace
178 }  // namespace fcp
179