/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"

#include <memory>
#include <string>
#include <thread>  // NOLINT(build/c++11)
#include <tuple>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/container/fixed_array.h"
#include "absl/time/time.h"
#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace serving {
namespace {

class FakeTask : public BatchTask {
 public:
  explicit FakeTask(size_t size) : size_(size) {}

  ~FakeTask() override = default;

  size_t size() const override { return size_; }

 private:
  const size_t size_;

  TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
};

using Queue = BatchScheduler<FakeTask>;
using Scheduler = SharedBatchScheduler<FakeTask>;
using QueueOptions = Scheduler::QueueOptions;
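// Signature of the function a queue uses to split an oversized input task into
// smaller output tasks.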
using SplitFunc =
    std::function<Status(std::unique_ptr<FakeTask>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<FakeTask>>* output_tasks)>;

// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
// that task. Returns the resulting status.
Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
  std::unique_ptr<FakeTask> task(new FakeTask(task_size));
  Status status = scheduler->Schedule(&task);
  // Schedule() should have consumed 'task' iff it returned Status::OK.
  CHECK_EQ(status.ok(), task == nullptr);
  return status;
}

// Creates a thread that waits on 'start' and then advances the fake clock in
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
// use the clock to be destroyed.
std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
    test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
  return std::unique_ptr<Thread>(Env::Default()->StartThread(
      {}, "FakeClockAdvancerThread", [env, start, stop] {
        start->WaitForNotification();
        while (!stop->HasBeenNotified()) {
          env->AdvanceByMicroseconds(10);
          Env::Default()->SleepForMicroseconds(10);
        }
      }));
}

// Creates a shared-batch-scheduler.
std::shared_ptr<Scheduler> CreateSharedBatchScheduler(
    int num_batch_threads, Env* env = Env::Default()) {
  Scheduler::Options options;
  options.num_batch_threads = num_batch_threads;
  options.env = env;

  std::shared_ptr<Scheduler> shared_batch_scheduler;
  TF_CHECK_OK(Scheduler::Create(options, &shared_batch_scheduler));

  return shared_batch_scheduler;
}

// Creates a queue with the given `queue_options`.
//
// Caller takes ownership of returned queue.
std::unique_ptr<Queue> CreateQueue(
    std::shared_ptr<Scheduler> scheduler, Scheduler::QueueOptions queue_options,
    internal::Queue<FakeTask>::ProcessBatchCallback process_batch_callback) {
  std::unique_ptr<BatchScheduler<FakeTask>> queue;
  TF_CHECK_OK(
      scheduler->AddQueue(queue_options, process_batch_callback, &queue));
  return queue;
}

// Creates QueueOptions based on input parameters.
QueueOptions CreateQueueOptions(size_t max_execution_batch_size,
                                size_t input_batch_size_limit,
                                size_t batch_timeout_micros,
                                size_t max_enqueued_batches,
                                bool enable_large_batch_splitting,
                                bool enable_lazy_split, SplitFunc split_func) {
  QueueOptions queue_options;
  queue_options.max_enqueued_batches = max_enqueued_batches;
  queue_options.max_execution_batch_size = max_execution_batch_size;
  queue_options.input_batch_size_limit = input_batch_size_limit;
  queue_options.batch_timeout_micros = batch_timeout_micros;
  queue_options.enable_large_batch_splitting = enable_large_batch_splitting;
  queue_options.enable_lazy_split = enable_lazy_split;
  if (enable_large_batch_splitting) {
    queue_options.split_input_task_func = split_func;
  }
  return queue_options;
}

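// Test fixture parameterized by (enable_input_batch_split, enable_lazy_split),
// which map to the queue options `enable_large_batch_splitting` and
// `enable_lazy_split` respectively.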
class SharedBatchSchedulerTest
    : public ::testing::TestWithParam<std::tuple<bool, bool>> {
 protected:
  QueueOptions CreateQueueOptions(size_t max_execution_batch_size,
                                  size_t input_batch_size_limit,
                                  size_t batch_timeout_micros,
                                  size_t max_enqueued_batches) {
    return tensorflow::serving::CreateQueueOptions(
        max_execution_batch_size, input_batch_size_limit, batch_timeout_micros,
        max_enqueued_batches, enable_input_batch_split(), enable_lazy_split(),
        get_split_func());
  }

  bool enable_input_batch_split() const { return std::get<0>(GetParam()); }

  bool enable_lazy_split() const { return std::get<1>(GetParam()); }

  SplitFunc get_split_func() const {
    if (enable_input_batch_split()) {
      return
          [](std::unique_ptr<FakeTask>* input_task,
             int open_batch_remaining_slot, int max_batch_size,
             std::vector<std::unique_ptr<FakeTask>>* output_tasks) -> Status {
            std::unique_ptr<FakeTask> owned_input_task = std::move(*input_task);
            const int input_task_size = owned_input_task->size();

            const internal::InputSplitMetadata input_split_metadata(
                input_task_size, open_batch_remaining_slot, max_batch_size);

            const absl::FixedArray<int> task_sizes =
                input_split_metadata.task_sizes();
            const int num_batches = task_sizes.size();

            output_tasks->resize(num_batches);
            for (int i = 0; i < num_batches; i++) {
              (*output_tasks)[i] = std::make_unique<FakeTask>(task_sizes[i]);
            }

            return OkStatus();
          };
    }
    return nullptr;
  }
};

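// Schedules tasks on two queues that share a scheduler and verifies that each
// queue's callback receives exactly the tasks scheduled on that queue,
// regardless of the number of batch threads and of whether the scheduler or
// queue 1 is destroyed early.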
TEST_P(SharedBatchSchedulerTest, Basic) {
  for (int num_batch_threads : {1, 2, 3}) {
    for (const bool delete_scheduler_early : {false, true}) {
      for (const bool delete_queue_1_early : {false, true}) {
        bool queue_0_callback_called = false;
        auto queue_0_callback =
            [&queue_0_callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
              queue_0_callback_called = true;
              ASSERT_TRUE(batch->IsClosed());
              ASSERT_EQ(3, batch->num_tasks());
              EXPECT_EQ(1, batch->task(0).size());
              EXPECT_EQ(3, batch->task(1).size());
              EXPECT_EQ(5, batch->task(2).size());
            };
        bool queue_1_callback_called = false;
        auto queue_1_callback =
            [&queue_1_callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
              queue_1_callback_called = true;
              ASSERT_TRUE(batch->IsClosed());
              ASSERT_EQ(2, batch->num_tasks());
              EXPECT_EQ(2, batch->task(0).size());
              EXPECT_EQ(4, batch->task(1).size());
            };
        {
          auto scheduler = CreateSharedBatchScheduler(num_batch_threads);

          // Create two queues.

          const size_t input_batch_size_limit = 10;
          const size_t batch_timeout_micros = 1 * 1000 * 1000;  // 1 second
          const size_t max_enqueued_batches = 2;
          const auto queue_options =
              CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                                 batch_timeout_micros, max_enqueued_batches);
          auto queue_0 =
              CreateQueue(scheduler, queue_options, queue_0_callback);

          auto queue_1 =
              CreateQueue(scheduler, queue_options, queue_1_callback);

          if (delete_scheduler_early) {
            // Delete our copy of the scheduler. The queues should keep it alive
            // under the covers.
            scheduler = nullptr;
          }

          // Submit tasks to the two queues, and (optionally) remove the queues.
          TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
          TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
          TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
          TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
          if (delete_queue_1_early) {
            queue_1 = nullptr;
          }
          TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
        }
        EXPECT_TRUE(queue_0_callback_called);
        EXPECT_TRUE(queue_1_callback_called);
      }
    }
  }
}

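// Verifies that tasks are packed into batches that respect the batch size
// limit: with input splitting enabled, an oversized task is split across the
// open batch and the next one; without it, the task starts a new batch.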
TEST_P(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
  // Set up a fake clock, which only advances when we explicitly tell it to.
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
  // Set up a callback that captures the batches' task sizes.
  mutex mu;
  std::vector<std::vector<size_t>> callback_data;
  Notification all_batches_processed;
  auto callback = [&mu, &callback_data, &all_batches_processed](
                      std::unique_ptr<Batch<FakeTask>> batch) {
    ASSERT_TRUE(batch->IsClosed());
    std::vector<size_t> batch_data;
    batch_data.reserve(batch->num_tasks());
    for (int i = 0; i < batch->num_tasks(); ++i) {
      batch_data.push_back(batch->mutable_task(i)->size());
    }
    {
      mutex_lock l(mu);
      callback_data.push_back(batch_data);
      if (callback_data.size() == 2) {
        all_batches_processed.Notify();
      }
    }
  };

  // Run a batch scheduler and inject some tasks.
  {
    auto scheduler = CreateSharedBatchScheduler(/*num_batch_threads=*/2, &env);

    const size_t input_batch_size_limit = 10;
    const size_t batch_timeout_micros = 10 * 1000;  // 10 milliseconds
    const size_t max_enqueued_batches = 2;
    auto queue = CreateQueue(
        scheduler,
        CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches),
        callback);

    if (enable_input_batch_split()) {
      // First batch.
      TF_ASSERT_OK(ScheduleTask(3, queue.get()));
      TF_ASSERT_OK(ScheduleTask(5, queue.get()));

      // Second batch.
      // This task spans the first batch and the second batch, so it
      // contributes one task to each.
      TF_ASSERT_OK(ScheduleTask(3 /* (3+5) + 3 > 10 */, queue.get()));
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
      TF_ASSERT_OK(ScheduleTask(6, queue.get()));
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    } else {
      // First batch.
      TF_ASSERT_OK(ScheduleTask(3, queue.get()));
      TF_ASSERT_OK(ScheduleTask(5, queue.get()));

      // Second batch (due to size overage).
      TF_ASSERT_OK(ScheduleTask(3 /* (3+5) + 3 > 10 */, queue.get()));
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
      TF_ASSERT_OK(ScheduleTask(6, queue.get()));
      // (No third batch is created: the second batch exactly hits the size
      // limit, and an empty batch should never be sent to the callback.)
    }

    // Advance the clock to trigger batch processing.
    env.AdvanceByMicroseconds(20 * 1000);
    all_batches_processed.WaitForNotification();
    // Expect a certain grouping of the tasks into batches.
    if (enable_input_batch_split()) {
      EXPECT_THAT(
          callback_data,
          ::testing::UnorderedElementsAreArray(std::vector<std::vector<size_t>>{
              std::vector<size_t>{3, 5, 2}, std::vector<size_t>{1, 1, 6, 1}}));
    } else {
      EXPECT_THAT(callback_data,
                  ::testing::UnorderedElementsAreArray(
                      std::vector<std::vector<size_t>>{{3, 5}, {3, 1, 6}}));
    }
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

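// Verifies that an underfull batch is processed once the queue's timeout
// elapses on the fake clock, and that a task which overflows the open batch
// causes that batch to close immediately.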
TEST_P(SharedBatchSchedulerTest, ObeysTimeout) {
  // Set up a fake clock, which only advances when we explicitly tell it to.
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    Notification first_batch_processed, second_batch_processed,
        third_batch_processed;
    bool notify_first_batch = false, notify_second_batch = false,
         notify_third_batch = false;
    auto callback = [&](std::unique_ptr<Batch<FakeTask>> batch) {
      ASSERT_TRUE(batch->IsClosed());
      if (notify_first_batch && (!first_batch_processed.HasBeenNotified())) {
        first_batch_processed.Notify();
        return;
      }
      if (notify_second_batch && (!second_batch_processed.HasBeenNotified())) {
        second_batch_processed.Notify();
        return;
      }
      if (notify_third_batch && (!third_batch_processed.HasBeenNotified())) {
        third_batch_processed.Notify();
        return;
      }

      EXPECT_TRUE(false) << "Unexpected condition";
    };

    auto scheduler = CreateSharedBatchScheduler(1, &env);

    const size_t input_batch_size_limit = 4;
    const size_t batch_timeout_micros = 10;
    const size_t max_enqueued_batches = 2;
    QueueOptions options =
        CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches);
    auto queue = CreateQueue(scheduler, options, callback);

    // Create an underfull batch, and ensure that it gets processed when the
    // clock hits the timeout.
    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    env.AdvanceByMicroseconds(9);
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(first_batch_processed.HasBeenNotified());
    notify_first_batch = true;
    env.AdvanceByMicroseconds(1);
    first_batch_processed.WaitForNotification();

    // Start creating a batch, while leaving the clock well below the timeout.
    // Then submit a new task that overflows into the next batch, causing
    // the original batch to close.
    TF_ASSERT_OK(ScheduleTask(2, queue.get()));
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(second_batch_processed.HasBeenNotified());
    notify_second_batch = true;
    TF_ASSERT_OK(ScheduleTask(3, queue.get()));
    second_batch_processed.WaitForNotification();

    // Allow the third batch to hit its timeout, and ensure it gets closed at
    // the right time.
    env.AdvanceByMicroseconds(9);
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(third_batch_processed.HasBeenNotified());
    notify_third_batch = true;
    env.AdvanceByMicroseconds(1);
    third_batch_processed.WaitForNotification();

    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

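// Same as above, but uses the real clock: an underfull batch should be
// processed once the 100 ms timeout expires.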
TEST_P(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) {
  Notification first_batch_processed, second_batch_processed;
  auto callback = [&first_batch_processed, &second_batch_processed](
                      std::unique_ptr<Batch<FakeTask>> batch) {
    ASSERT_TRUE(batch->IsClosed());
    if (batch->size() == 1) {
      first_batch_processed.Notify();
    } else if (batch->size() == 2) {
      second_batch_processed.Notify();
    } else {
      EXPECT_TRUE(false) << "Unexpected batch size";
    }
  };

  auto scheduler = CreateSharedBatchScheduler(2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 100 * 1000;  // 100 milliseconds
  const size_t max_enqueued_batches = 2;
  auto queue = CreateQueue(
      scheduler,
      CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                         batch_timeout_micros, max_enqueued_batches),
      callback);

  // Submit a single task that doesn't fill up the batch.
  // Ensure that it gets processed due to the timeout.
  TF_ASSERT_OK(ScheduleTask(1, queue.get()));
  first_batch_processed.WaitForNotification();

  // Do it again.
  TF_ASSERT_OK(ScheduleTask(2, queue.get()));
  second_batch_processed.WaitForNotification();
}

TEST_P(SharedBatchSchedulerTest,
       WithZeroTimeoutBatchesScheduledAsSoonAsThreadIsAvailable) {
  // Set up a fake clock, and never advance the time.
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    Notification first_batch_processed, second_batch_processed;
    auto callback = [&first_batch_processed, &second_batch_processed](
                        std::unique_ptr<Batch<FakeTask>> batch) {
      ASSERT_TRUE(batch->IsClosed());
      if (batch->size() == 1) {
        first_batch_processed.Notify();
      } else if (batch->size() == 2) {
        second_batch_processed.Notify();
      } else {
        EXPECT_TRUE(false) << "Unexpected batch size";
      }
    };

    auto scheduler = CreateSharedBatchScheduler(2, &env);

    // Set a large batch size, so that we don't hit the batch size limit.
    const size_t batch_size_limit = 100;
    // Process a batch as soon as a thread is available.
    const size_t batch_timeout_micros = 0;
    const size_t max_enqueued_batches = 2;
    auto queue = CreateQueue(
        scheduler,
        CreateQueueOptions(batch_size_limit, batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches),
        callback);

    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    first_batch_processed.WaitForNotification();
    TF_ASSERT_OK(ScheduleTask(2, queue.get()));
    second_batch_processed.WaitForNotification();

    // Shut everything down.
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

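// Verifies fair scheduling across queues: once queue 1's batch becomes
// eligible, it is scheduled before queue 0's remaining backlog.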
TEST_P(SharedBatchSchedulerTest, Fairness) {
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    Notification queue_0_first_batch_scheduled, queue_0_first_batch_proceed,
        queue_0_second_batch_scheduled;
    auto queue_0_callback = [&queue_0_first_batch_scheduled,
                             &queue_0_first_batch_proceed,
                             &queue_0_second_batch_scheduled](
                                std::unique_ptr<Batch<FakeTask>> batch) {
      if (!queue_0_first_batch_scheduled.HasBeenNotified()) {
        queue_0_first_batch_scheduled.Notify();
        queue_0_first_batch_proceed.WaitForNotification();
      } else if (!queue_0_second_batch_scheduled.HasBeenNotified()) {
        queue_0_second_batch_scheduled.Notify();
      }
    };

    Notification queue_1_first_batch_scheduled, queue_1_first_batch_proceed;
    auto queue_1_callback =
        [&queue_1_first_batch_scheduled,
         &queue_1_first_batch_proceed](std::unique_ptr<Batch<FakeTask>> batch) {
          queue_1_first_batch_scheduled.Notify();
          queue_1_first_batch_proceed.WaitForNotification();
        };

    auto scheduler = CreateSharedBatchScheduler(1, &env);
    size_t input_batch_size_limit = 10;
    QueueOptions queue_options = CreateQueueOptions(
        input_batch_size_limit, input_batch_size_limit,
        1 /* batch_timeout_micros */, 100 /* give plenty of room */);
    std::vector<std::unique_ptr<BatchScheduler<FakeTask>>> queues(2);
    TF_ASSERT_OK(
        scheduler->AddQueue(queue_options, queue_0_callback, &queues[0]));
    TF_ASSERT_OK(
        scheduler->AddQueue(queue_options, queue_1_callback, &queues[1]));

    // Enqueue a batch-filling task to queue 0, and wait for it to get
    // scheduled.
    TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
    env.AdvanceByMicroseconds(1);
    queue_0_first_batch_scheduled.WaitForNotification();

    // Enqueue two more batch-filling tasks to queue 0.
    TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
    TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));

    // Enqueue one task to queue 1, and then advance the clock so it becomes
    // eligible for scheduling due to the timeout. Ensure that the queue 1 batch
    // gets scheduled before the next queue 0 one.
    TF_ASSERT_OK(ScheduleTask(1, queues[1].get()));
    env.AdvanceByMicroseconds(1);
    queue_0_first_batch_proceed.Notify();
    queue_1_first_batch_scheduled.WaitForNotification();
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(queue_0_second_batch_scheduled.HasBeenNotified());

    // Shut everything down.
    queue_1_first_batch_proceed.Notify();
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

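// Exercises the queue's const accessors (max_task_size(), NumEnqueuedTasks(),
// SchedulingCapacity()) as the queue fills up to its capacity.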
TEST_P(SharedBatchSchedulerTest, ConstMethods) {
  for (const int max_enqueued_batches : {1, 2, 5}) {
    Notification processing, proceed;
    auto callback = [&processing,
                     &proceed](std::unique_ptr<Batch<FakeTask>> batch) {
      if (!processing.HasBeenNotified()) {
        processing.Notify();
      }
      proceed.WaitForNotification();
    };

    auto scheduler = CreateSharedBatchScheduler(/*num_batch_threads*/ 1);

    const size_t input_batch_size_limit = 2;
    const size_t batch_timeout_micros = 0;
    auto queue = CreateQueue(
        scheduler,
        CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches),
        callback);

    EXPECT_EQ(2, queue->max_task_size());
    EXPECT_EQ(0, queue->NumEnqueuedTasks());
    EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity());

    // Get one batch going on the thread, and keep the thread blocked until
    // we're done testing the maximum queue length.
    TF_ASSERT_OK(ScheduleTask(2, queue.get()));
    processing.WaitForNotification();
    EXPECT_EQ(0, queue->NumEnqueuedTasks());

    // We should be able to enqueue 'max_enqueued_batches'*2 tasks without
    // issue.
    for (int i = 0; i < max_enqueued_batches; ++i) {
      EXPECT_EQ(i * 2, queue->NumEnqueuedTasks());
      EXPECT_EQ((max_enqueued_batches - i) * 2, queue->SchedulingCapacity());
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
      EXPECT_EQ((i * 2) + 1, queue->NumEnqueuedTasks());
      EXPECT_EQ((max_enqueued_batches - i) * 2 - 1,
                queue->SchedulingCapacity());
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    }
    EXPECT_EQ(max_enqueued_batches * 2, queue->NumEnqueuedTasks());
    EXPECT_EQ(0, queue->SchedulingCapacity());

    // Attempting to enqueue one more task should yield an UNAVAILABLE error.
    EXPECT_THAT(
        ScheduleTask(1, queue.get()),
        testing::StatusIs(error::UNAVAILABLE,
                          "The batch scheduling queue to which this task was "
                          "submitted is full"));

    EXPECT_EQ(max_enqueued_batches * 2, queue->NumEnqueuedTasks());
    EXPECT_EQ(0, queue->SchedulingCapacity());

    proceed.Notify();
  }
}

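// Verifies that a queue whose callback is blocked and whose capacity is
// exhausted does not prevent another queue on the same scheduler from making
// progress.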
TEST_P(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) {
  Notification queue_0_processing, queue_0_proceed;
  auto queue_0_callback = [&queue_0_processing, &queue_0_proceed](
                              std::unique_ptr<Batch<FakeTask>> batch) {
    if (!queue_0_processing.HasBeenNotified()) {
      queue_0_processing.Notify();
      queue_0_proceed.WaitForNotification();
    }
  };

  Notification queue_1_first_batch_processed, queue_1_second_batch_processed,
      queue_1_third_batch_processed;
  auto queue_1_callback =
      [&queue_1_first_batch_processed, &queue_1_second_batch_processed,
       &queue_1_third_batch_processed](std::unique_ptr<Batch<FakeTask>> batch) {
        if (batch->size() == 1) {
          queue_1_first_batch_processed.Notify();
        } else if (batch->size() == 2) {
          queue_1_second_batch_processed.Notify();
        } else if (batch->size() == 3) {
          queue_1_third_batch_processed.Notify();
        } else {
          EXPECT_TRUE(false) << "Unexpected batch size";
        }
      };

  auto scheduler = CreateSharedBatchScheduler(/*num_batch_threads*/ 2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 0;
  const size_t max_enqueued_batches = 2;
  QueueOptions queue_options =
      CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                         batch_timeout_micros, max_enqueued_batches);

  std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
  TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
  std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
  TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));

  // Clog up queue 0.
  TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
  queue_0_processing.WaitForNotification();
  Status queue_0_status;
  do {
    queue_0_status = ScheduleTask(1, queue_0.get());
  } while (queue_0_status.ok());
  EXPECT_EQ(error::UNAVAILABLE, queue_0_status.code());

  // Ensure that queue 1 still behaves normally, and lets us process tasks.
  TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
  queue_1_first_batch_processed.WaitForNotification();
  TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
  queue_1_second_batch_processed.WaitForNotification();
  TF_ASSERT_OK(ScheduleTask(3, queue_1.get()));
  queue_1_third_batch_processed.WaitForNotification();

  // Let poor queue 0 drain.
  queue_0_proceed.Notify();
}

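// Verifies that destroying a queue blocks until every batch already enqueued
// on it has been processed.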
TEST_P(SharedBatchSchedulerTest, QueueDestructorBlocksUntilAllTasksProcessed) {
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    int current_batch = 0;
    Notification first_callback_started;
    const int kMaxEnqueuedBatches = 3;
    std::vector<Notification> callback_proceed(kMaxEnqueuedBatches);
    auto callback =
        [&current_batch, &first_callback_started,
         &callback_proceed](std::unique_ptr<Batch<FakeTask>> batch) {
          if (current_batch == 0) {
            first_callback_started.Notify();
          }
          callback_proceed[current_batch].WaitForNotification();
          ++current_batch;
        };

    auto scheduler = CreateSharedBatchScheduler(1, &env);

    const size_t batch_size_limit = 10;
    const size_t batch_timeout_micros = 0;
    const size_t max_enqueued_batches = 2;
    QueueOptions queue_options =
        CreateQueueOptions(batch_size_limit, batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches);
    auto queue = CreateQueue(scheduler, queue_options, callback);

    // Clog up the queue.
    int num_enqueued_batches = 0;
    TF_ASSERT_OK(ScheduleTask(10, queue.get()));
    ++num_enqueued_batches;
    env.AdvanceByMicroseconds(1);
    first_callback_started.WaitForNotification();
    for (int i = 0; i < 2; ++i) {
      TF_ASSERT_OK(ScheduleTask(10, queue.get()));
      ++num_enqueued_batches;
    }
    EXPECT_EQ(kMaxEnqueuedBatches, num_enqueued_batches);
    EXPECT_EQ(error::UNAVAILABLE, ScheduleTask(10, queue.get()).code());

    // Destroy the queue. The destructor should block until all tasks have been
    // processed.
    Notification destroy_queue_thread_started, queue_destroyed;
    std::unique_ptr<Thread> destroy_queue_thread(Env::Default()->StartThread(
        {}, "DestroyQueueThread",
        [&queue, &destroy_queue_thread_started, &queue_destroyed] {
          destroy_queue_thread_started.Notify();
          queue = nullptr;
          queue_destroyed.Notify();
        }));
    destroy_queue_thread_started.WaitForNotification();
    for (int i = 0; i < num_enqueued_batches; ++i) {
      Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
      EXPECT_FALSE(queue_destroyed.HasBeenNotified());
      callback_proceed[i].Notify();
    }

    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

// Tests that `enable_lazy_split` can be enabled only if
// `enable_large_batch_splitting` is enabled.
TEST_P(SharedBatchSchedulerTest, InvalidLazySplitOptions) {
  auto callback = [](std::unique_ptr<Batch<FakeTask>> batch) {
    // Do nothing.
  };

  auto scheduler = CreateSharedBatchScheduler(2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 100 * 1000;  // 100 milliseconds
  const size_t max_enqueued_batches = 2;
  std::unique_ptr<Queue> queue;
  EXPECT_THAT(
      scheduler->AddQueue(tensorflow::serving::CreateQueueOptions(
                              input_batch_size_limit, input_batch_size_limit,
                              batch_timeout_micros, max_enqueued_batches,
                              false /* enable_large_batch_splitting */,
                              true /* enable_lazy_split */, get_split_func()),
                          callback, &queue),
      testing::StatusIs(error::INVALID_ARGUMENT,
                        "enable_lazy_split should be enabled only if "
                        "enable_large_batch_splitting is enabled."));
}

// Tests that a queue configured with zero `max_enqueued_batches` gets
// rewritten to one.
// Note, technically an invalid-argument error should be returned.
// Since existing models (with very low QPS) rely on the rewrite, retain the
// old behavior so such models continue to work.
TEST_P(SharedBatchSchedulerTest, ZeroQueueRewrittenToOneQueue) {
  auto callback = [](std::unique_ptr<Batch<FakeTask>> batch) {
    // Do nothing.
  };

  auto scheduler = CreateSharedBatchScheduler(2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 100 * 1000;  // 100 milliseconds
  const size_t max_enqueued_batches = 0;
  std::unique_ptr<Queue> queue;
  if (enable_input_batch_split()) {
    EXPECT_THAT(
        scheduler->AddQueue(tensorflow::serving::CreateQueueOptions(
                                input_batch_size_limit, input_batch_size_limit,
                                batch_timeout_micros, max_enqueued_batches,
                                enable_input_batch_split(), enable_lazy_split(),
                                get_split_func()),
                            callback, &queue),
        testing::StatusIs(error::INVALID_ARGUMENT,
                          "max_enqueued_batches must be positive; was 0"));
  } else {
    TF_ASSERT_OK(scheduler->AddQueue(
        tensorflow::serving::CreateQueueOptions(
            input_batch_size_limit, input_batch_size_limit,
            batch_timeout_micros, max_enqueued_batches,
            enable_input_batch_split(), enable_lazy_split(), get_split_func()),
        callback, &queue));
    EXPECT_EQ(queue->SchedulingCapacity(), input_batch_size_limit);
  }
}

// TODO(b/161857471):
// Add test coverage when input-split and no-split return differently.
INSTANTIATE_TEST_SUITE_P(
    Parameter, SharedBatchSchedulerTest,
    ::testing::Values(std::make_tuple(/*enable_input_batch_split=*/true,
                                      /*enable_lazy_split=*/true),
                      std::make_tuple(/*enable_input_batch_split=*/true,
                                      /*enable_lazy_split=*/false),
                      std::make_tuple(/*enable_input_batch_split=*/false,
                                      /*enable_lazy_split=*/false)));

#ifdef PLATFORM_GOOGLE
// This benchmark relies on https://github.com/google/benchmark features
// (in particular, `Benchmark::ThreadRange`) that are not available in the
// open-source TF codebase.

static std::vector<std::unique_ptr<Queue>>* queues =
    new std::vector<std::unique_ptr<Queue>>();

// Store queue labels, which are used to label benchmark results.
static std::vector<std::string>* queue_labels = new std::vector<std::string>();

// Creates queues and adds them to `queues` to keep them alive.
// Adds labels to `queue_labels`.
void CreateQueues() {
  // The split function is guaranteed (in the context of this test) to process
  // tasks of size one, so it adds `input_task` to `output_tasks` directly, and
  // simulates a computation that takes some CPU cycles and time to complete.
  auto split_func_for_size_one_task =
      [](std::unique_ptr<FakeTask>* input_task, int open_batch_remaining_slot,
         int max_batch_size,
         std::vector<std::unique_ptr<FakeTask>>* output_tasks) -> Status {
    output_tasks->push_back(std::move(*input_task));

    Notification notify;
    std::thread busy_waiter([&] {
      while (!notify.HasBeenNotified()) {
      }
    });

    std::thread notifier([&] {
      Env::Default()->SleepForMicroseconds(1);
      notify.Notify();
    });
    busy_waiter.join();
    notifier.join();
    return OkStatus();
  };

  internal::Queue<FakeTask>::ProcessBatchCallback process_batch_callback =
      [](std::unique_ptr<Batch<FakeTask>> task) {
        // The callback takes ownership of `task`; do nothing, since `task` is
        // freed when the callback returns.
      };
  const size_t max_execution_batch_size = 64;
  const size_t input_batch_size_limit = 128;
  const size_t batch_timeout_micros = 10;
  // Each queue has its own shared-batch-scheduler with the same parameters, so
  // scheduling behavior is approximately the same.
  queues->push_back(CreateQueue(
      CreateSharedBatchScheduler(5),
      CreateQueueOptions(max_execution_batch_size, input_batch_size_limit,
                         batch_timeout_micros, INT_MAX /* unbounded queue */,
                         true /* enable_large_batch_splitting */,
                         false /* enable_lazy_split */,
                         split_func_for_size_one_task),
      process_batch_callback));
  queue_labels->push_back(std::string("EagerSplit"));

  queues->push_back(CreateQueue(
      CreateSharedBatchScheduler(5),
      CreateQueueOptions(max_execution_batch_size, input_batch_size_limit,
                         batch_timeout_micros, INT_MAX /* unbounded queue */,
                         false /* enable_large_batch_splitting */,
                         false /* enable_lazy_split */, nullptr /* no func */),
      process_batch_callback));
  queue_labels->push_back(std::string("NoSplit"));

  queues->push_back(CreateQueue(
      CreateSharedBatchScheduler(5),
      CreateQueueOptions(max_execution_batch_size, input_batch_size_limit,
                         batch_timeout_micros, INT_MAX /* unbounded queue */,
                         true /* enable_large_batch_splitting */,
                         true /* enable_lazy_split */,
                         split_func_for_size_one_task),
      process_batch_callback));
  queue_labels->push_back(std::string("LazySplit"));
}

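// Benchmarks Queue::Schedule() throughput: each benchmark iteration schedules
// `state.range(0)` size-one tasks on the queue selected by `state.range(1)`,
// across a range of scheduling threads.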
void BM_QueueSchedule(::testing::benchmark::State& state) {
  static absl::once_flag once;
  absl::call_once(once, []() { CreateQueues(); });

  const int queue_index = state.range(1);
  Queue* queue = (*queues)[queue_index].get();

  const string label = strings::StrCat(state.threads(), "-Threads",
                                       (*queue_labels)[queue_index]);
  state.SetLabel(label);
  for (auto s : state) {
    for (int i = 0; i < state.range(0); i++) {
      auto batch_task = std::make_unique<FakeTask>(1);

      auto status = queue->Schedule(&batch_task);
      tensorflow::testing::DoNotOptimize(status);
    }
  }
}

BENCHMARK(BM_QueueSchedule)->Apply([](benchmark::internal::Benchmark* b) {
  b->ThreadRange(1,
                 port::NumSchedulableCPUs() * tensorflow::port::CPUIDNumSMT());

  for (int queue_index : {0, 1, 2}) {
    b->ArgPair(10000, queue_index);
  }
});

#endif  // PLATFORM_GOOGLE

}  // namespace
}  // namespace serving
}  // namespace tensorflow