1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "fcp/tensorflow/serve_slices_registry.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "fcp/base/random_token.h"
#include "fcp/tensorflow/host_object.h"
#include "tensorflow/core/framework/tensor.h"
24
25 namespace fcp {
26 namespace {
27
28 using ::testing::_;
29 using ::testing::Return;
30
31 using MockServeSlicesCallback = ::testing::MockFunction<std::string(
32 RandomToken, std::vector<tensorflow::Tensor>, int32_t, absl::string_view,
33 std::vector<std::string>, absl::string_view, absl::string_view,
34 absl::string_view)>;
35
// Registers one mock callback, looks it up by its token, invokes it through
// the registry, and verifies the lookup fails once the registration handle has
// been destroyed.
TEST(ServeSlicesRegistryTest, CanRegisterGetAndUnregisterCallback) {
  MockServeSlicesCallback mock_callback;
  // Declared outside the inner scope so the token is still available for the
  // post-unregistration lookup at the end of the test.
  std::optional<RandomToken> token;
  {
    // Scoped so that its destructor runs before the final check.
    HostObjectRegistration registration =
        register_serve_slices_callback(mock_callback.AsStdFunction());
    token = registration.token();

    std::optional<std::shared_ptr<ServeSlicesCallback>> looked_up =
        get_serve_slices_callback(*token);
    ASSERT_TRUE(looked_up.has_value());

    const std::string expected_served_at_id = "served_at_id";
    EXPECT_CALL(mock_callback, Call(*token, _, _, _, _, _, _, _))
        .WillOnce(Return(expected_served_at_id));
    const std::string actual_served_at_id =
        (**looked_up)(*token, {}, 0, "", {}, "", "", "");
    EXPECT_EQ(expected_served_at_id, actual_served_at_id);
  }
  // Check that it is gone after `registration` has been destroyed.
  EXPECT_FALSE(get_serve_slices_callback(*token).has_value());
}
56
// Registers several mock callbacks at once and verifies that:
//  * every registration receives a distinct token,
//  * each token resolves to a callback that dispatches to its own mock, and
//  * every callback is unregistered once its registration handle is destroyed.
TEST(ServeSlicesRegistryTest, CanRegisterMultipleDifferentCallbacks) {
  constexpr int kNumCallbacks = 5;
  // Declared before `registrations` so the mocks outlive the registry entries
  // that reference them.
  MockServeSlicesCallback mock_callbacks[kNumCallbacks];
  // Registration handles; destroying one removes its callback from the
  // registry (see the checks at the end of this test).
  std::vector<HostObjectRegistration> registrations;
  registrations.reserve(kNumCallbacks);
  // Tokens, kept separately so they stay usable after `registrations` is
  // cleared.
  std::vector<RandomToken> ids;
  ids.reserve(kNumCallbacks);

  // Register all callbacks.
  for (int i = 0; i < kNumCallbacks; ++i) {
    registrations.push_back(
        register_serve_slices_callback(mock_callbacks[i].AsStdFunction()));
    ids.push_back(registrations.back().token());
  }

  // Every registration must be addressable by its own token.
  for (int i = 0; i < kNumCallbacks; ++i) {
    for (int j = i + 1; j < kNumCallbacks; ++j) {
      EXPECT_FALSE(ids[i] == ids[j])
          << "callbacks " << i << " and " << j << " share a token";
    }
  }

  // Get and invoke all callbacks; each must route to its own mock.
  for (int i = 0; i < kNumCallbacks; ++i) {
    const RandomToken& id = ids[i];
    std::optional<std::shared_ptr<ServeSlicesCallback>> returned_callback =
        get_serve_slices_callback(id);
    ASSERT_TRUE(returned_callback.has_value());

    std::string mock_served_at_id = absl::StrCat("served_at_id_", i);
    EXPECT_CALL(mock_callbacks[i], Call(id, _, _, _, _, _, _, _))
        .WillOnce(Return(mock_served_at_id));
    EXPECT_EQ(mock_served_at_id,
              (**returned_callback)(id, {}, 0, "", {}, "", "", ""));
  }

  // Destroying the handles must unregister every callback.
  registrations.clear();
  for (const RandomToken& id : ids) {
    EXPECT_FALSE(get_serve_slices_callback(id).has_value());
  }
}
80
81 } // namespace
82 } // namespace fcp
83