xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/serve_slices_registry_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "fcp/tensorflow/serve_slices_registry.h"
16 
17 #include <string>
18 
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 #include "fcp/base/random_token.h"
22 #include "fcp/tensorflow/host_object.h"
23 #include "tensorflow/core/framework/tensor.h"
24 
25 namespace fcp {
26 namespace {
27 
28 using ::testing::_;
29 using ::testing::Return;
30 
31 using MockServeSlicesCallback = ::testing::MockFunction<std::string(
32     RandomToken, std::vector<tensorflow::Tensor>, int32_t, absl::string_view,
33     std::vector<std::string>, absl::string_view, absl::string_view,
34     absl::string_view)>;
35 
// Registers one mock callback, looks it up by its token, invokes it through
// the registry, and then verifies the lookup fails once the registration
// object has gone out of scope.
TEST(ServeSlicesRegistryTest, CanRegisterGetAndUnregisterCallback) {
  MockServeSlicesCallback callback_mock;
  // Declared outside the inner scope so the token is still available for the
  // final lookup after the registration is destroyed.
  std::optional<RandomToken> token;
  {
    HostObjectRegistration scoped_registration =
        register_serve_slices_callback(callback_mock.AsStdFunction());
    token = scoped_registration.token();

    auto looked_up = get_serve_slices_callback(*token);
    ASSERT_TRUE(looked_up.has_value());

    // The mock must be called exactly once with the token; the remaining
    // arguments are not inspected.
    const std::string expected_served_at_id = "served_at_id";
    EXPECT_CALL(callback_mock, Call(*token, _, _, _, _, _, _, _))
        .WillOnce(Return(expected_served_at_id));

    // Unwrap optional -> shared_ptr -> callable before invoking.
    ServeSlicesCallback& registered_callback = **looked_up;
    EXPECT_EQ(expected_served_at_id,
              registered_callback(*token, {}, 0, "", {}, "", "", ""));
  }
  // With `scoped_registration` destroyed, the token must no longer resolve.
  EXPECT_FALSE(get_serve_slices_callback(*token).has_value());
}
56 
// Registers several distinct mock callbacks at once and checks that each
// registration's token resolves to a callable that forwards to its own mock.
TEST(ServeSlicesRegistryTest, CanRegisterMultipleDifferentCallbacks) {
  constexpr int kCallbackCount = 5;
  MockServeSlicesCallback mocks[kCallbackCount];
  std::vector<HostObjectRegistration> registrations;

  // First pass: register every mock, keeping all registrations alive.
  for (MockServeSlicesCallback& mock : mocks) {
    registrations.push_back(
        register_serve_slices_callback(mock.AsStdFunction()));
  }

  // Second pass: look each callback up by token and invoke it.
  for (int index = 0; index < kCallbackCount; ++index) {
    const RandomToken token = registrations[index].token();
    auto looked_up = get_serve_slices_callback(token);
    ASSERT_TRUE(looked_up.has_value());

    // Give every mock a distinct return value so that a call routed to the
    // wrong mock is detected.
    const std::string expected_served_at_id =
        absl::StrCat("served_at_id_", index);
    EXPECT_CALL(mocks[index], Call(token, _, _, _, _, _, _, _))
        .WillOnce(Return(expected_served_at_id));

    ServeSlicesCallback& registered_callback = **looked_up;
    EXPECT_EQ(expected_served_at_id,
              registered_callback(token, {}, 0, "", {}, "", "", ""));
  }
}
80 
81 }  // namespace
82 }  // namespace fcp
83