1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <fcntl.h>
18
19 #include <string>
20
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "fcp/tensorflow/serve_slices_registry.h"
26 #include "google/protobuf/io/zero_copy_stream_impl.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/platform/status_matchers.h"
31 #include "tensorflow/core/protobuf/error_codes.pb.h"
32 #include "tensorflow/core/public/session.h"
33
34 namespace fcp {
35 namespace {
36
37 using ::testing::_;
38 using ::testing::HasSubstr;
39 using ::testing::Return;
40
41 using MockServeSlicesCallback = ::testing::MockFunction<std::string(
42 RandomToken, std::vector<tensorflow::Tensor>, int32_t, absl::string_view,
43 std::vector<std::string>, absl::string_view, absl::string_view,
44 absl::string_view)>;
45
// Constants describing the GraphDef under test.
// See make_serve_slices_test_graph.py

// Path of the text-format GraphDef opened by `LoadExampleGraph`.
constexpr const char* kExampleGraphPath =
    "fcp/tensorflow/serve_slices_test.pbtxt";
// Feed name used for the callback token when running the session.
constexpr const char* kCallbackTokenPlaceholderName = "callback_token";
// Name of the tensor fetched as the `served_at_id` result.
constexpr const char* kServedAtTensorName = "served_at_id:0";
53
54 // Loads the example graph created by `make_serve_slices_test_graph.py`.
LoadExampleGraph()55 tensorflow::GraphDef LoadExampleGraph() {
56 int fd = open(kExampleGraphPath, O_RDONLY);
57 CHECK(fd != -1) << "Failed to open example graph at path "
58 << kExampleGraphPath;
59
60 google::protobuf::io::FileInputStream fs(fd);
61 fs.SetCloseOnDelete(true);
62
63 tensorflow::GraphDef graph;
64 bool parsed = google::protobuf::TextFormat::Parse(&fs, &graph);
65 CHECK(parsed) << "Invalid text-format GraphDef";
66
67 return graph;
68 }
69
70 // Loads the example graph created by `make_serve_slices_test_graph.py` into a
71 // `tensorflow::Session`.
PrepareExampleGraphSession()72 std::unique_ptr<tensorflow::Session> PrepareExampleGraphSession() {
73 tensorflow::GraphDef graph = LoadExampleGraph();
74
75 std::unique_ptr<tensorflow::Session> session;
76 {
77 tensorflow::SessionOptions options;
78 tensorflow::Session* raw_session = nullptr;
79 tensorflow::Status session_new_status =
80 tensorflow::NewSession(options, &raw_session);
81 TF_CHECK_OK(session_new_status);
82 session = std::unique_ptr<tensorflow::Session>(raw_session);
83 }
84
85 tensorflow::Status graph_build_status = session->Create(graph);
86 TF_CHECK_OK(graph_build_status);
87 return session;
88 }
89
90 class ServeSlicesOpTest : public ::testing::Test {
91 protected:
SetUp()92 void SetUp() override { session_ = PrepareExampleGraphSession(); }
93
94 // Runs a `ServeSlices` session and returns the result.
95 //
96 // Inputs:
97 // callback_token: A `tensorflow::Tensor` to use as the `callback_token`
98 // argument to `ServeSlices`. For successful calls, this must be a
99 // `RandomToken` corresponding to the `HostObjectRegistration returned by
100 // `register_serve_slices_callback`.
101 // served_at_id_out: An output parameter into which the `served_at_id`
102 // returned from `ServeSlices` is stored.
103 //
104 // Outputs:
105 // The status of the `session.Run` invocation.
RunSession(tensorflow::Tensor callback_token,tensorflow::Tensor & served_at_id_out)106 tensorflow::Status RunSession(tensorflow::Tensor callback_token,
107 tensorflow::Tensor& served_at_id_out) {
108 std::vector<tensorflow::Tensor> outputs;
109 tensorflow::Status run_status =
110 session_->Run({{kCallbackTokenPlaceholderName, callback_token}},
111 {kServedAtTensorName}, {}, &outputs);
112
113 if (run_status.ok()) {
114 CHECK(outputs.size() == 1)
115 << "Expected one output, found " << outputs.size();
116 served_at_id_out = outputs[0];
117 }
118
119 return run_status;
120 }
121
122 // Runs a `ServeSlices` session and returns the result.
123 //
124 // This method is similar to `RunSession`, but it expects that the run is
125 // successful and enforces that the inputs and outputs are correctly-typed.
126 //
127 // Inputs:
128 // callback_token: The `CallbackToken` of the callback to invoke from
129 // `ServeSlices`.
130 //
131 // Outputs:
132 // The `served_at_id` returned from `ServeSlices`.
RunSessionExpectingSuccess(RandomToken callback_token)133 std::string RunSessionExpectingSuccess(RandomToken callback_token) {
134 tensorflow::Tensor served_at_id_out;
135 TF_CHECK_OK(RunSession(tensorflow::Tensor(callback_token.ToString()),
136 served_at_id_out));
137 return served_at_id_out.scalar<tensorflow::tstring>()();
138 }
139
140 private:
141 std::unique_ptr<tensorflow::Session> session_;
142 };
143
// Running the graph with the token of a live registration must invoke the
// registered callback and surface its return value as the `served_at_id`.
TEST_F(ServeSlicesOpTest, SessionRunCallsBackIntoCPP) {
  // Arrange: register a mock callback and remember its token.
  const std::string canned_id = "mock_served_at_id";
  MockServeSlicesCallback mock_callback;
  HostObjectRegistration registration =
      register_serve_slices_callback(mock_callback.AsStdFunction());
  RandomToken callback_token = registration.token();
  EXPECT_CALL(mock_callback, Call(callback_token, _, _, _, _, _, _, _))
      .WillOnce(Return(canned_id));

  // Act: run the graph with that token.
  std::string served_at_id = RunSessionExpectingSuccess(callback_token);

  // Assert: the callback's return value came back out of the session.
  EXPECT_EQ(served_at_id, "mock_served_at_id");
}
155
// Running the graph with a token whose registration has already been
// destroyed must fail with an invalid-argument error.
TEST_F(ServeSlicesOpTest, SessionRunFailsOnMissingCallback) {
  // Register a callback and let the registration go out of scope again,
  // keeping only a copy of its token.
  const RandomToken stale_token = [] {
    MockServeSlicesCallback mock_callback;
    HostObjectRegistration registration =
        register_serve_slices_callback(mock_callback.AsStdFunction());
    return registration.token();
    // The registration gets destructed here.
  }();

  tensorflow::Tensor token_tensor(stale_token.ToString());
  tensorflow::Tensor unused_served_at_id;
  tensorflow::Status status = RunSession(token_tensor, unused_served_at_id);

  // Remove the cast after TF 2.12 is released and used in FCP.
  EXPECT_THAT(
      status,
      tensorflow::testing::StatusIs(
          static_cast<tsl::errors::Code>(absl::StatusCode::kInvalidArgument),
          HasSubstr("No `ServeSlices` callback found")));
}
176
177 } // namespace
178 } // namespace fcp
179