/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"

#include <memory>
#include <string>
#include <thread>  // NOLINT(build/c++11)
#include <tuple>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/container/fixed_array.h"
#include "absl/time/time.h"
#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace serving {
namespace {

class FakeTask : public BatchTask {
 public:
  explicit FakeTask(size_t size) : size_(size) {}

  ~FakeTask() override = default;

  size_t size() const override { return size_; }

 private:
  const size_t size_;

  TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
};

using Queue = BatchScheduler<FakeTask>;
using Scheduler = SharedBatchScheduler<FakeTask>;
using QueueOptions = Scheduler::QueueOptions;
using SplitFunc =
    std::function<Status(std::unique_ptr<FakeTask>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<FakeTask>>* output_tasks)>;

// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
// that task. Returns the resulting status.
Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
  std::unique_ptr<FakeTask> task(new FakeTask(task_size));
  Status status = scheduler->Schedule(&task);
  // Schedule() should have consumed 'task' iff it returned Status::OK.
  CHECK_EQ(status.ok(), task == nullptr);
  return status;
}

// Creates a thread that waits on 'start' and then advances the fake clock in
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
// use the clock to be destroyed.
std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
    test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
  return std::unique_ptr<Thread>(Env::Default()->StartThread(
      {}, "FakeClockAdvancerThread", [env, start, stop] {
        start->WaitForNotification();
        while (!stop->HasBeenNotified()) {
          env->AdvanceByMicroseconds(10);
          Env::Default()->SleepForMicroseconds(10);
        }
      }));
}

// Creates a shared-batch-scheduler.
std::shared_ptr<Scheduler> CreateSharedBatchScheduler(
    int num_batch_threads, Env* env = Env::Default()) {
  Scheduler::Options options;
  options.num_batch_threads = num_batch_threads;
  options.env = env;

  std::shared_ptr<Scheduler> shared_batch_scheduler;
  TF_CHECK_OK(Scheduler::Create(options, &shared_batch_scheduler));

  return shared_batch_scheduler;
}

// Creates a queue with the given `queue_options`.
//
// Caller takes ownership of returned queue.
std::unique_ptr<Queue> CreateQueue(
    std::shared_ptr<Scheduler> scheduler, Scheduler::QueueOptions queue_options,
    internal::Queue<FakeTask>::ProcessBatchCallback process_batch_callback) {
  std::unique_ptr<BatchScheduler<FakeTask>> queue;
  TF_CHECK_OK(
      scheduler->AddQueue(queue_options, process_batch_callback, &queue));
  return queue;
}

// Creates QueueOptions based on input parameters.
QueueOptions CreateQueueOptions(size_t max_execution_batch_size,
                                size_t input_batch_size_limit,
                                size_t batch_timeout_micros,
                                size_t max_enqueued_batches,
                                bool enable_large_batch_splitting,
                                bool enable_lazy_split, SplitFunc split_func) {
  QueueOptions queue_options;
  queue_options.max_enqueued_batches = max_enqueued_batches;
  queue_options.max_execution_batch_size = max_execution_batch_size;
  queue_options.input_batch_size_limit = input_batch_size_limit;
  queue_options.batch_timeout_micros = batch_timeout_micros;
  queue_options.enable_large_batch_splitting = enable_large_batch_splitting;
  queue_options.enable_lazy_split = enable_lazy_split;
  if (enable_large_batch_splitting) {
    queue_options.split_input_task_func = split_func;
  }
  return queue_options;
}

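// Test fixture parameterized over <enable_large_batch_splitting,
// enable_lazy_split>; see INSTANTIATE_TEST_SUITE_P at the bottom of this file
// for the parameter combinations that are exercised.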
class SharedBatchSchedulerTest
    : public ::testing::TestWithParam<std::tuple<bool, bool>> {
 protected:
  QueueOptions CreateQueueOptions(size_t max_execution_batch_size,
                                  size_t input_batch_size_limit,
                                  size_t batch_timeout_micros,
                                  size_t max_enqueued_batches) {
    return tensorflow::serving::CreateQueueOptions(
        max_execution_batch_size, input_batch_size_limit, batch_timeout_micros,
        max_enqueued_batches, enable_input_batch_split(), enable_lazy_split(),
        get_split_func());
  }
  bool enable_input_batch_split() const { return std::get<0>(GetParam()); }

  bool enable_lazy_split() const { return std::get<1>(GetParam()); }

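  // Returns a function that splits an oversized input task according to
  // internal::InputSplitMetadata, or nullptr when input-batch splitting is
  // disabled for this parameterization.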
  SplitFunc get_split_func() const {
    if (enable_input_batch_split()) {
      return
          [](std::unique_ptr<FakeTask>* input_task,
             int open_batch_remaining_slot, int max_batch_size,
             std::vector<std::unique_ptr<FakeTask>>* output_tasks) -> Status {
            std::unique_ptr<FakeTask> owned_input_task = std::move(*input_task);
            const int input_task_size = owned_input_task->size();

            const internal::InputSplitMetadata input_split_metadata(
                input_task_size, open_batch_remaining_slot, max_batch_size);

            const absl::FixedArray<int> task_sizes =
                input_split_metadata.task_sizes();
            const int num_batches = task_sizes.size();

            output_tasks->resize(num_batches);
            for (int i = 0; i < num_batches; i++) {
              (*output_tasks)[i] = std::make_unique<FakeTask>(task_sizes[i]);
            }

            return OkStatus();
          };
    }
    return nullptr;
  }
};

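// Exercises basic scheduling across two queues, varying the number of batch
// threads and whether the scheduler and queue 1 are destroyed early.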
TEST_P(SharedBatchSchedulerTest, Basic) {
  for (int num_batch_threads : {1, 2, 3}) {
    for (const bool delete_scheduler_early : {false, true}) {
      for (const bool delete_queue_1_early : {false, true}) {
        bool queue_0_callback_called = false;
        auto queue_0_callback =
            [&queue_0_callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
              queue_0_callback_called = true;
              ASSERT_TRUE(batch->IsClosed());
              ASSERT_EQ(3, batch->num_tasks());
              EXPECT_EQ(1, batch->task(0).size());
              EXPECT_EQ(3, batch->task(1).size());
              EXPECT_EQ(5, batch->task(2).size());
            };
        bool queue_1_callback_called = false;
        auto queue_1_callback =
            [&queue_1_callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
              queue_1_callback_called = true;
              ASSERT_TRUE(batch->IsClosed());
              ASSERT_EQ(2, batch->num_tasks());
              EXPECT_EQ(2, batch->task(0).size());
              EXPECT_EQ(4, batch->task(1).size());
            };
        {
          auto scheduler = CreateSharedBatchScheduler(num_batch_threads);

          // Create two queues.

          const size_t input_batch_size_limit = 10;
          const size_t batch_timeout_micros = 1 * 1000 * 1000;  // 1 second
          const size_t max_enqueued_batches = 2;
          const auto queue_options =
              CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                                 batch_timeout_micros, max_enqueued_batches);
          auto queue_0 =
              CreateQueue(scheduler, queue_options, queue_0_callback);

          auto queue_1 =
              CreateQueue(scheduler, queue_options, queue_1_callback);

          if (delete_scheduler_early) {
            // Delete our copy of the scheduler. The queues should keep it alive
            // under the covers.
            scheduler = nullptr;
          }

          // Submit tasks to the two queues, and (optionally) remove the queues.
          TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
          TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
          TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
          TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
          if (delete_queue_1_early) {
            queue_1 = nullptr;
          }
          TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
        }
        EXPECT_TRUE(queue_0_callback_called);
        EXPECT_TRUE(queue_1_callback_called);
      }
    }
  }
}

TEST_P(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
  // Set up a fake clock, which only advances when we explicitly tell it to.
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
  // Set up a callback that captures the batches' task sizes.
  mutex mu;
  std::vector<std::vector<size_t>> callback_data;
  Notification all_batches_processed;
  auto callback = [&mu, &callback_data, &all_batches_processed](
                      std::unique_ptr<Batch<FakeTask>> batch) {
    ASSERT_TRUE(batch->IsClosed());
    std::vector<size_t> batch_data;
    batch_data.reserve(batch->num_tasks());
    for (int i = 0; i < batch->num_tasks(); ++i) {
      batch_data.push_back(batch->mutable_task(i)->size());
    }
    {
      mutex_lock l(mu);
      callback_data.push_back(batch_data);
      if (callback_data.size() == 2) {
        all_batches_processed.Notify();
      }
    }
  };

  // Run a batch scheduler and inject some tasks.
  {
    auto scheduler = CreateSharedBatchScheduler(/*num_batch_threads=*/2, &env);

    const size_t input_batch_size_limit = 10;
    const size_t batch_timeout_micros = 10 * 1000;  // 10 milliseconds
    const size_t max_enqueued_batches = 2;
    auto queue = CreateQueue(
        scheduler,
        CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches),
        callback);

    if (enable_input_batch_split()) {
      // First batch.
      TF_ASSERT_OK(ScheduleTask(3, queue.get()));
      TF_ASSERT_OK(ScheduleTask(5, queue.get()));

      // Second batch.
      // This task spans the first and second batches, so it contributes one
      // task to each.
      TF_ASSERT_OK(ScheduleTask(3 /* (3+5) + 3 > 10 */, queue.get()));
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
      TF_ASSERT_OK(ScheduleTask(6, queue.get()));
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    } else {
      // First batch.
      TF_ASSERT_OK(ScheduleTask(3, queue.get()));
      TF_ASSERT_OK(ScheduleTask(5, queue.get()));

      // Second batch (due to size overage).
      TF_ASSERT_OK(ScheduleTask(3 /* (3+5) + 3 > 10 */, queue.get()));
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
      TF_ASSERT_OK(ScheduleTask(6, queue.get()));
      // (The third batch is empty because the second batch exactly hit the
      // size limit; an empty batch should never be sent to the callback.)
    }

    // Advance the clock to trigger batch processing.
    env.AdvanceByMicroseconds(20 * 1000);
    all_batches_processed.WaitForNotification();
    // Expect a certain grouping of the tasks into batches.
    if (enable_input_batch_split()) {
      EXPECT_THAT(
          callback_data,
          ::testing::UnorderedElementsAreArray(std::vector<std::vector<size_t>>{
              std::vector<size_t>{3, 5, 2}, std::vector<size_t>{1, 1, 6, 1}}));
    } else {
      EXPECT_THAT(callback_data,
                  ::testing::UnorderedElementsAreArray(
                      std::vector<std::vector<size_t>>{{3, 5}, {3, 1, 6}}));
    }
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

TEST_P(SharedBatchSchedulerTest, ObeysTimeout) {
  // Set up a fake clock, which only advances when we explicitly tell it to.
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    Notification first_batch_processed, second_batch_processed,
        third_batch_processed;
    bool notify_first_batch = false, notify_second_batch = false,
         notify_third_batch = false;
    auto callback = [&](std::unique_ptr<Batch<FakeTask>> batch) {
      ASSERT_TRUE(batch->IsClosed());
      if (notify_first_batch && (!first_batch_processed.HasBeenNotified())) {
        first_batch_processed.Notify();
        return;
      }
      if (notify_second_batch && (!second_batch_processed.HasBeenNotified())) {
        second_batch_processed.Notify();
        return;
      }
      if (notify_third_batch && (!third_batch_processed.HasBeenNotified())) {
        third_batch_processed.Notify();
        return;
      }

      EXPECT_TRUE(false) << "Unexpected condition";
    };

    auto scheduler = CreateSharedBatchScheduler(1, &env);

    const size_t input_batch_size_limit = 4;
    const size_t batch_timeout_micros = 10;
    const size_t max_enqueued_batches = 2;
    QueueOptions options =
        CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches);
    auto queue = CreateQueue(scheduler, options, callback);

    // Create an underfull batch, and ensure that it gets processed when the
    // clock hits the timeout.
    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    env.AdvanceByMicroseconds(9);
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(first_batch_processed.HasBeenNotified());
    notify_first_batch = true;
    env.AdvanceByMicroseconds(1);
    first_batch_processed.WaitForNotification();

    // Start creating a batch, while leaving the clock well below the timeout.
    // Then submit a new task that overflows into the next batch, causing
    // the original batch to close.
    TF_ASSERT_OK(ScheduleTask(2, queue.get()));
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(second_batch_processed.HasBeenNotified());
    notify_second_batch = true;
    TF_ASSERT_OK(ScheduleTask(3, queue.get()));
    second_batch_processed.WaitForNotification();

    // Allow the third batch to hit its timeout, and ensure it gets closed at
    // the right time.
    env.AdvanceByMicroseconds(9);
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(third_batch_processed.HasBeenNotified());
    notify_third_batch = true;
    env.AdvanceByMicroseconds(1);
    third_batch_processed.WaitForNotification();

    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

TEST_P(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) {
  Notification first_batch_processed, second_batch_processed;
  auto callback = [&first_batch_processed, &second_batch_processed](
                      std::unique_ptr<Batch<FakeTask>> batch) {
    ASSERT_TRUE(batch->IsClosed());
    if (batch->size() == 1) {
      first_batch_processed.Notify();
    } else if (batch->size() == 2) {
      second_batch_processed.Notify();
    } else {
      EXPECT_TRUE(false) << "Unexpected batch size";
    }
  };

  auto scheduler = CreateSharedBatchScheduler(2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 100 * 1000;  // 100 milliseconds
  const size_t max_enqueued_batches = 2;
  auto queue = CreateQueue(
      scheduler,
      CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                         batch_timeout_micros, max_enqueued_batches),
      callback);

  // Submit a single task that doesn't fill up the batch.
  // Ensure that it gets processed due to the timeout.
  TF_ASSERT_OK(ScheduleTask(1, queue.get()));
  first_batch_processed.WaitForNotification();

  // Do it again.
  TF_ASSERT_OK(ScheduleTask(2, queue.get()));
  second_batch_processed.WaitForNotification();
}

TEST_P(SharedBatchSchedulerTest,
       WithZeroTimeoutBatchesScheduledAsSoonAsThreadIsAvailable) {
  // Set up a fake clock, and never advance the time.
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    Notification first_batch_processed, second_batch_processed;
    auto callback = [&first_batch_processed, &second_batch_processed](
                        std::unique_ptr<Batch<FakeTask>> batch) {
      ASSERT_TRUE(batch->IsClosed());
      if (batch->size() == 1) {
        first_batch_processed.Notify();
      } else if (batch->size() == 2) {
        second_batch_processed.Notify();
      } else {
        EXPECT_TRUE(false) << "Unexpected batch size";
      }
    };

    auto scheduler = CreateSharedBatchScheduler(2, &env);

    // Set a large batch size, so that we don't hit the batch size limit.
    const size_t batch_size_limit = 100;
    // Process a batch as soon as a thread is available.
    const size_t batch_timeout_micros = 0;
    const size_t max_enqueued_batches = 2;
    auto queue = CreateQueue(
        scheduler,
        CreateQueueOptions(batch_size_limit, batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches),
        callback);

    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    first_batch_processed.WaitForNotification();
    TF_ASSERT_OK(ScheduleTask(2, queue.get()));
    second_batch_processed.WaitForNotification();

    // Shut everything down.
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

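// With a single batch thread, verifies that a timed-out batch from queue 1 is
// scheduled ahead of additional full batches already enqueued on queue 0.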
TEST_P(SharedBatchSchedulerTest, Fairness) {
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    Notification queue_0_first_batch_scheduled, queue_0_first_batch_proceed,
        queue_0_second_batch_scheduled;
    auto queue_0_callback = [&queue_0_first_batch_scheduled,
                             &queue_0_first_batch_proceed,
                             &queue_0_second_batch_scheduled](
                                std::unique_ptr<Batch<FakeTask>> batch) {
      if (!queue_0_first_batch_scheduled.HasBeenNotified()) {
        queue_0_first_batch_scheduled.Notify();
        queue_0_first_batch_proceed.WaitForNotification();
      } else if (!queue_0_second_batch_scheduled.HasBeenNotified()) {
        queue_0_second_batch_scheduled.Notify();
      }
    };

    Notification queue_1_first_batch_scheduled, queue_1_first_batch_proceed;
    auto queue_1_callback =
        [&queue_1_first_batch_scheduled,
         &queue_1_first_batch_proceed](std::unique_ptr<Batch<FakeTask>> batch) {
          queue_1_first_batch_scheduled.Notify();
          queue_1_first_batch_proceed.WaitForNotification();
        };

    auto scheduler = CreateSharedBatchScheduler(1, &env);
    size_t input_batch_size_limit = 10;
    QueueOptions queue_options = CreateQueueOptions(
        input_batch_size_limit, input_batch_size_limit,
        1 /* batch_timeout_micros */, 100 /* give plenty of room */);
    std::vector<std::unique_ptr<BatchScheduler<FakeTask>>> queues(2);
    TF_ASSERT_OK(
        scheduler->AddQueue(queue_options, queue_0_callback, &queues[0]));
    TF_ASSERT_OK(
        scheduler->AddQueue(queue_options, queue_1_callback, &queues[1]));

    // Enqueue a batch-filling task to queue 0, and wait for it to get
    // scheduled.
    TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
    env.AdvanceByMicroseconds(1);
    queue_0_first_batch_scheduled.WaitForNotification();

    // Enqueue two more batch-filling tasks to queue 0.
    TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));
    TF_ASSERT_OK(ScheduleTask(10, queues[0].get()));

    // Enqueue one task to queue 1, and then advance the clock so it becomes
    // eligible for scheduling due to the timeout. Ensure that the queue 1
    // batch gets scheduled before the next queue 0 one.
    TF_ASSERT_OK(ScheduleTask(1, queues[1].get()));
    env.AdvanceByMicroseconds(1);
    queue_0_first_batch_proceed.Notify();
    queue_1_first_batch_scheduled.WaitForNotification();
    Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
    EXPECT_FALSE(queue_0_second_batch_scheduled.HasBeenNotified());

    // Shut everything down.
    queue_1_first_batch_proceed.Notify();
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

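// Verifies the const accessors (max_task_size(), NumEnqueuedTasks(),
// SchedulingCapacity()) as tasks are enqueued up to the queue's capacity.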
TEST_P(SharedBatchSchedulerTest, ConstMethods) {
  for (const int max_enqueued_batches : {1, 2, 5}) {
    Notification processing, proceed;
    auto callback = [&processing,
                     &proceed](std::unique_ptr<Batch<FakeTask>> batch) {
      if (!processing.HasBeenNotified()) {
        processing.Notify();
      }
      proceed.WaitForNotification();
    };

    auto scheduler = CreateSharedBatchScheduler(/*num_batch_threads*/ 1);

    const size_t input_batch_size_limit = 2;
    const size_t batch_timeout_micros = 0;
    auto queue = CreateQueue(
        scheduler,
        CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches),
        callback);

    EXPECT_EQ(2, queue->max_task_size());
    EXPECT_EQ(0, queue->NumEnqueuedTasks());
    EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity());

    // Get one batch going on the thread, and keep the thread blocked until
    // we're done testing the maximum queue length.
    TF_ASSERT_OK(ScheduleTask(2, queue.get()));
    processing.WaitForNotification();
    EXPECT_EQ(0, queue->NumEnqueuedTasks());

    // We should be able to enqueue 'max_enqueued_batches'*2 tasks without
    // issue.
    for (int i = 0; i < max_enqueued_batches; ++i) {
      EXPECT_EQ(i * 2, queue->NumEnqueuedTasks());
      EXPECT_EQ((max_enqueued_batches - i) * 2, queue->SchedulingCapacity());
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
      EXPECT_EQ((i * 2) + 1, queue->NumEnqueuedTasks());
      EXPECT_EQ((max_enqueued_batches - i) * 2 - 1,
                queue->SchedulingCapacity());
      TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    }
    EXPECT_EQ(max_enqueued_batches * 2, queue->NumEnqueuedTasks());
    EXPECT_EQ(0, queue->SchedulingCapacity());

    // Attempting to enqueue one more task should yield an UNAVAILABLE error.
    EXPECT_THAT(
        ScheduleTask(1, queue.get()),
        testing::StatusIs(error::UNAVAILABLE,
                          "The batch scheduling queue to which this task was "
                          "submitted is full"));

    EXPECT_EQ(max_enqueued_batches * 2, queue->NumEnqueuedTasks());
    EXPECT_EQ(0, queue->SchedulingCapacity());

    proceed.Notify();
  }
}

TEST_P(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) {
  Notification queue_0_processing, queue_0_proceed;
  auto queue_0_callback = [&queue_0_processing, &queue_0_proceed](
                              std::unique_ptr<Batch<FakeTask>> batch) {
    if (!queue_0_processing.HasBeenNotified()) {
      queue_0_processing.Notify();
      queue_0_proceed.WaitForNotification();
    }
  };

  Notification queue_1_first_batch_processed, queue_1_second_batch_processed,
      queue_1_third_batch_processed;
  auto queue_1_callback =
      [&queue_1_first_batch_processed, &queue_1_second_batch_processed,
       &queue_1_third_batch_processed](std::unique_ptr<Batch<FakeTask>> batch) {
        if (batch->size() == 1) {
          queue_1_first_batch_processed.Notify();
        } else if (batch->size() == 2) {
          queue_1_second_batch_processed.Notify();
        } else if (batch->size() == 3) {
          queue_1_third_batch_processed.Notify();
        } else {
          EXPECT_TRUE(false) << "Unexpected batch size";
        }
      };

  auto scheduler = CreateSharedBatchScheduler(/*num_batch_threads*/ 2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 0;
  const size_t max_enqueued_batches = 2;
  QueueOptions queue_options =
      CreateQueueOptions(input_batch_size_limit, input_batch_size_limit,
                         batch_timeout_micros, max_enqueued_batches);

  std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
  TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
  std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
  TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));

  // Clog up queue 0.
  TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
  queue_0_processing.WaitForNotification();
  Status queue_0_status;
  do {
    queue_0_status = ScheduleTask(1, queue_0.get());
  } while (queue_0_status.ok());
  EXPECT_EQ(error::UNAVAILABLE, queue_0_status.code());

  // Ensure that queue 1 still behaves normally, and lets us process tasks.
  TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
  queue_1_first_batch_processed.WaitForNotification();
  TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
  queue_1_second_batch_processed.WaitForNotification();
  TF_ASSERT_OK(ScheduleTask(3, queue_1.get()));
  queue_1_third_batch_processed.WaitForNotification();

  // Let poor queue 0 drain.
  queue_0_proceed.Notify();
}

TEST_P(SharedBatchSchedulerTest, QueueDestructorBlocksUntilAllTasksProcessed) {
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);

  {
    int current_batch = 0;
    Notification first_callback_started;
    const int kMaxEnqueuedBatches = 3;
    std::vector<Notification> callback_proceed(kMaxEnqueuedBatches);
    auto callback =
        [&current_batch, &first_callback_started,
         &callback_proceed](std::unique_ptr<Batch<FakeTask>> batch) {
          if (current_batch == 0) {
            first_callback_started.Notify();
          }
          callback_proceed[current_batch].WaitForNotification();
          ++current_batch;
        };

    auto scheduler = CreateSharedBatchScheduler(1, &env);

    const size_t batch_size_limit = 10;
    const size_t batch_timeout_micros = 0;
    const size_t max_enqueued_batches = 2;
    QueueOptions queue_options =
        CreateQueueOptions(batch_size_limit, batch_size_limit,
                           batch_timeout_micros, max_enqueued_batches);
    auto queue = CreateQueue(scheduler, queue_options, callback);

    // Clog up the queue.
    int num_enqueued_batches = 0;
    TF_ASSERT_OK(ScheduleTask(10, queue.get()));
    ++num_enqueued_batches;
    env.AdvanceByMicroseconds(1);
    first_callback_started.WaitForNotification();
    for (int i = 0; i < 2; ++i) {
      TF_ASSERT_OK(ScheduleTask(10, queue.get()));
      ++num_enqueued_batches;
    }
    EXPECT_EQ(kMaxEnqueuedBatches, num_enqueued_batches);
    EXPECT_EQ(error::UNAVAILABLE, ScheduleTask(10, queue.get()).code());

    // Destroy the queue. The destructor should block until all tasks have been
    // processed.
    Notification destroy_queue_thread_started, queue_destroyed;
    std::unique_ptr<Thread> destroy_queue_thread(Env::Default()->StartThread(
        {}, "DestroyQueueThread",
        [&queue, &destroy_queue_thread_started, &queue_destroyed] {
          destroy_queue_thread_started.Notify();
          queue = nullptr;
          queue_destroyed.Notify();
        }));
    destroy_queue_thread_started.WaitForNotification();
    for (int i = 0; i < num_enqueued_batches; ++i) {
      Env::Default()->SleepForMicroseconds(10 * 1000 /* 10 milliseconds */);
      EXPECT_FALSE(queue_destroyed.HasBeenNotified());
      callback_proceed[i].Notify();
    }

    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

// Tests that `enable_lazy_split` can be enabled only if
// `enable_large_batch_splitting` is enabled.
TEST_P(SharedBatchSchedulerTest, InvalidLazySplitOptions) {
  auto callback = [](std::unique_ptr<Batch<FakeTask>> batch) {
    // Do nothing.
  };

  auto scheduler = CreateSharedBatchScheduler(2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 100 * 1000;  // 100 milliseconds
  const size_t max_enqueued_batches = 2;
  std::unique_ptr<Queue> queue;
  EXPECT_THAT(
      scheduler->AddQueue(tensorflow::serving::CreateQueueOptions(
                              input_batch_size_limit, input_batch_size_limit,
                              batch_timeout_micros, max_enqueued_batches,
                              false /* enable_large_batch_splitting */,
                              true /* enable_lazy_split */, get_split_func()),
                          callback, &queue),
      testing::StatusIs(error::INVALID_ARGUMENT,
                        "enable_lazy_split should be enabled only if "
                        "enable_large_batch_splitting is enabled."));
}

// Tests that a queue configured with zero `max_enqueued_batches` is rewritten
// to allow one enqueued batch.
// Note, technically an invalid-argument error should be returned.
// Since existing models (with very low QPS) rely on the rewrite, retain the
// old behavior so such models continue to work.
TEST_P(SharedBatchSchedulerTest, ZeroQueueRewrittenToOneQueue) {
  auto callback = [](std::unique_ptr<Batch<FakeTask>> batch) {
    // Do nothing.
  };

  auto scheduler = CreateSharedBatchScheduler(2);

  const size_t input_batch_size_limit = 10;
  const size_t batch_timeout_micros = 100 * 1000;  // 100 milliseconds
  const size_t max_enqueued_batches = 0;
  std::unique_ptr<Queue> queue;
  if (enable_input_batch_split()) {
    EXPECT_THAT(
        scheduler->AddQueue(tensorflow::serving::CreateQueueOptions(
                                input_batch_size_limit, input_batch_size_limit,
                                batch_timeout_micros, max_enqueued_batches,
                                enable_input_batch_split(), enable_lazy_split(),
                                get_split_func()),
                            callback, &queue),
        testing::StatusIs(error::INVALID_ARGUMENT,
                          "max_enqueued_batches must be positive; was 0"));
  } else {
    TF_ASSERT_OK(scheduler->AddQueue(
        tensorflow::serving::CreateQueueOptions(
            input_batch_size_limit, input_batch_size_limit,
            batch_timeout_micros, max_enqueued_batches,
            enable_input_batch_split(), enable_lazy_split(), get_split_func()),
        callback, &queue));
    EXPECT_EQ(queue->SchedulingCapacity(), input_batch_size_limit);
  }
}

// TODO(b/161857471):
// Add test coverage for cases where input-split and no-split return different
// results.
INSTANTIATE_TEST_SUITE_P(
    Parameter, SharedBatchSchedulerTest,
    ::testing::Values(std::make_tuple(/*enable_input_batch_split=*/true,
                                      /*enable_lazy_split=*/true),
                      std::make_tuple(/*enable_input_batch_split=*/true,
                                      /*enable_lazy_split=*/false),
                      std::make_tuple(/*enable_input_batch_split=*/false,
                                      /*enable_lazy_split=*/false)));

#ifdef PLATFORM_GOOGLE
// This benchmark relies on https://github.com/google/benchmark features
// (in particular, `Benchmark::ThreadRange`) that are not available in the
// open-sourced TF codebase.

static std::vector<std::unique_ptr<Queue>>* queues =
    new std::vector<std::unique_ptr<Queue>>();

// Stores queue labels, which are used to label benchmark results.
static std::vector<std::string>* queue_labels = new std::vector<std::string>();

// Creates queues and adds them to `queues` to keep them alive.
// Adds labels to `queue_labels`.
void CreateQueues() {
  // The split function is guaranteed (in the context of this test) to process
  // tasks of size one, so it adds `input_task` to `output_tasks` directly and
  // simulates a computation that takes some CPU cycles and time to complete.
  auto split_func_for_size_one_task =
      [](std::unique_ptr<FakeTask>* input_task, int open_batch_remaining_slot,
         int max_batch_size,
         std::vector<std::unique_ptr<FakeTask>>* output_tasks) -> Status {
    output_tasks->push_back(std::move(*input_task));

    Notification notify;
    std::thread busy_waiter([&] {
      while (!notify.HasBeenNotified()) {
      }
    });

    std::thread notifier([&] {
      Env::Default()->SleepForMicroseconds(1);
      notify.Notify();
    });
    busy_waiter.join();
    notifier.join();
    return OkStatus();
  };

  internal::Queue<FakeTask>::ProcessBatchCallback process_batch_callback =
      [](std::unique_ptr<Batch<FakeTask>> task) {
        // process_batch_callback is supposed to take ownership of `task`.
        // Do nothing, since `task` will be freed when the callback returns.
      };
  const size_t max_execution_batch_size = 64;
  const size_t input_batch_size_limit = 128;
  const size_t batch_timeout_micros = 10;
  // Each queue has its own shared-batch-scheduler with the same parameters, so
  // the scheduling behavior is approximately the same.
  queues->push_back(CreateQueue(
      CreateSharedBatchScheduler(5),
      CreateQueueOptions(max_execution_batch_size, input_batch_size_limit,
                         batch_timeout_micros, INT_MAX /* unbounded queue */,
                         true /* enable_large_batch_splitting */,
                         false /* enable_lazy_split */,
                         split_func_for_size_one_task),
      process_batch_callback));
  queue_labels->push_back(std::string("EagerSplit"));

  queues->push_back(CreateQueue(
      CreateSharedBatchScheduler(5),
      CreateQueueOptions(max_execution_batch_size, input_batch_size_limit,
                         batch_timeout_micros, INT_MAX /* unbounded queue */,
                         false /* enable_large_batch_splitting */,
                         false /* enable_lazy_split */, nullptr /* no func */),
      process_batch_callback));
  queue_labels->push_back(std::string("NoSplit"));

  queues->push_back(CreateQueue(
      CreateSharedBatchScheduler(5),
      CreateQueueOptions(max_execution_batch_size, input_batch_size_limit,
                         batch_timeout_micros, INT_MAX /* unbounded queue */,
                         true /* enable_large_batch_splitting */,
                         true /* enable_lazy_split */,
                         split_func_for_size_one_task),
      process_batch_callback));
  queue_labels->push_back(std::string("LazySplit"));
}

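// Benchmarks Queue::Schedule(): each iteration schedules `state.range(0)`
// size-1 tasks onto the queue selected by `state.range(1)`, across a range of
// thread counts.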
void BM_QueueSchedule(::testing::benchmark::State& state) {
  static absl::once_flag once;
  absl::call_once(once, []() { CreateQueues(); });

  const int queue_index = state.range(1);
  Queue* queue = (*queues)[queue_index].get();

  const string label = strings::StrCat(state.threads(), "-Threads",
                                       (*queue_labels)[queue_index]);
  state.SetLabel(label);
  for (auto s : state) {
    for (int i = 0; i < state.range(0); i++) {
      auto batch_task = std::make_unique<FakeTask>(1);

      auto status = queue->Schedule(&batch_task);
      tensorflow::testing::DoNotOptimize(status);
    }
  }
}

BENCHMARK(BM_QueueSchedule)->Apply([](benchmark::internal::Benchmark* b) {
  b->ThreadRange(1,
                 port::NumSchedulableCPUs() * tensorflow::port::CPUIDNumSMT());

  for (int queue_index : {0, 1, 2}) {
    b->ArgPair(10000, queue_index);
  }
});

#endif  // PLATFORM_GOOGLE

}  // namespace
}  // namespace serving
}  // namespace tensorflow