xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/single_threaded_executor_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <functional>
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/executor.h"
24 #include "tensorflow/core/common_runtime/executor_factory.h"
25 #include "tensorflow/core/common_runtime/graph_constructor.h"
26 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
27 #include "tensorflow/core/common_runtime/process_util.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/rendezvous.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/graph.h"
34 #include "tensorflow/core/graph/node_builder.h"
35 #include "tensorflow/core/graph/testlib.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/lib/random/simple_philox.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/test.h"
42 #include "tensorflow/core/platform/test_benchmark.h"
43 #include "tensorflow/core/platform/tracing.h"
44 #include "tensorflow/core/public/session_options.h"
45 
46 namespace tensorflow {
47 namespace data {
48 namespace {
49 
50 class MockOp : public OpKernel {
51  public:
52   using OpKernel::OpKernel;
53 
SetCompute(std::function<void (OpKernelContext *)> compute)54   void SetCompute(std::function<void(OpKernelContext*)> compute) {
55     compute_ = std::move(compute);
56   }
57 
Compute(OpKernelContext * ctx)58   void Compute(OpKernelContext* ctx) override {
59     OP_REQUIRES(ctx, compute_ != nullptr,
60                 errors::FailedPrecondition("Compute() is not set"));
61     compute_(ctx);
62   }
63 
64  private:
65   std::function<void(OpKernelContext* ctx)> compute_;
66 };
67 REGISTER_OP("Mock")
68     .Input("x: float")
69     .Output("y: float")
70     .Output("empty_output: string")
71     .SetIsStateful();
72 REGISTER_KERNEL_BUILDER(Name("Mock").Device(DEVICE_CPU), MockOp);
73 
74 class ExecutorTest : public ::testing::Test {
75  protected:
ExecutorTest()76   ExecutorTest()
77       : device_(DeviceFactory::NewDevice("CPU", {},
78                                          "/job:localhost/replica:0/task:0")) {}
79 
~ExecutorTest()80   ~ExecutorTest() override {
81     // There should always be exactly one Ref left on the Rendezvous
82     // when the test completes.
83     CHECK(rendez_->Unref());
84   }
85 
86   // Resets executor_ with a new executor based on a graph 'gdef'.
Create(std::unique_ptr<const Graph> graph,std::function<void (OpKernelContext *)> mock_fn=nullptr)87   void Create(std::unique_ptr<const Graph> graph,
88               std::function<void(OpKernelContext*)> mock_fn = nullptr) {
89     const int version = graph->versions().producer();
90     LocalExecutorParams params;
91     params.device = device_.get();
92     params.create_kernel =
93         [this, mock_fn = std::move(mock_fn), version](
94             const std::shared_ptr<const NodeProperties>& props,
95             OpKernel** kernel) {
96           TF_RETURN_IF_ERROR(CreateNonCachedKernel(device_.get(), nullptr,
97                                                    props, version, kernel));
98           if ((*kernel)->type_string_view() == "Mock") {
99             down_cast<MockOp*>(*kernel)->SetCompute(mock_fn);
100           }
101           return OkStatus();
102         };
103     params.delete_kernel = [](OpKernel* kernel) {
104       DeleteNonCachedKernel(kernel);
105     };
106     TF_CHECK_OK(
107         NewExecutor("SINGLE_THREADED_EXECUTOR", params, *graph, &exec_));
108     runner_ = [](const std::function<void()>& fn) { fn(); };
109     rendez_ = NewLocalRendezvous();
110   }
111 
Run(Rendezvous * rendez)112   Status Run(Rendezvous* rendez) {
113     Executor::Args args;
114     args.rendezvous = rendez;
115     args.runner = runner_;
116     return exec_->Run(args);
117   }
118 
Run(CallFrameInterface * call_frame)119   Status Run(CallFrameInterface* call_frame) {
120     Executor::Args args;
121     args.call_frame = call_frame;
122     args.runner = runner_;
123     return exec_->Run(args);
124   }
125 
TestContext(Executor::Args args,std::function<void (OpKernelContext *)> test_fn)126   void TestContext(Executor::Args args,
127                    std::function<void(OpKernelContext*)> test_fn) {
128     auto g = std::make_unique<Graph>(OpRegistry::Global());
129     Node* arg = test::graph::Arg(g.get(), 0, DT_FLOAT);
130     Node* tmp;
131     TF_ASSERT_OK(NodeBuilder(g->NewName("n"), "Mock")
132                      .Input(arg)
133                      .Finalize(g.get(), &tmp));
134     auto ret = test::graph::Retval(g.get(), 0, tmp);
135     g->AddControlEdge(arg, ret);
136     FixupSourceAndSinkEdges(g.get());
137 
138     bool mock_called = false;
139     Create(std::move(g), [&](OpKernelContext* ctx) {
140       mock_called = true;
141       ctx->set_output(0, ctx->input(0));
142       test_fn(ctx);
143     });
144 
145     FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
146     TF_ASSERT_OK(call_frame.SetArgs({Tensor(DT_FLOAT, {0})}));
147     args.call_frame = &call_frame;
148     args.runner = runner_;
149     TF_ASSERT_OK(exec_->Run(args));
150     EXPECT_TRUE(mock_called);
151   }
152 
153   std::unique_ptr<Device> device_;
154   std::unique_ptr<Executor> exec_ = nullptr;
155   Executor::Args::Runner runner_;
156   Rendezvous* rendez_ = nullptr;
157 };
158 
159 // A float val -> Tensor<float>
V(const float val)160 Tensor V(const float val) {
161   Tensor tensor(DT_FLOAT, TensorShape({}));
162   tensor.scalar<float>()() = val;
163   return tensor;
164 }
165 
166 // Tensor<float> -> a float val.
V(const Tensor & tensor)167 float V(const Tensor& tensor) {
168   CHECK_EQ(tensor.dtype(), DT_FLOAT);
169   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
170   return tensor.scalar<float>()();
171 }
172 
Key(const string & sender,const uint64 incarnation,const string & receiver,const string & name)173 Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
174                           const string& receiver, const string& name) {
175   Rendezvous::ParsedKey result;
176   TF_CHECK_OK(
177       Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
178                                                  name, FrameAndIter(0, 0)),
179                            &result));
180   return result;
181 }
182 
// Checks that a thread pool passed through
// Executor::Args::user_intra_op_threadpool is the one a kernel sees on its
// device while it is being executed.
TEST_F(ExecutorTest, UserIntraOpThreadPool) {
  // Minimal pool that runs every scheduled closure inline.
  class DummyThreadPool : public thread::ThreadPoolInterface {
   public:
    void Schedule(std::function<void()> fn) override { fn(); }
    int NumThreads() const override { return 1; }
    int CurrentThreadId() const override { return -1; }
  };
  DummyThreadPool inline_pool;

  Executor::Args args;
  args.user_intra_op_threadpool = &inline_pool;

  TestContext(args, [&inline_pool](OpKernelContext* ctx) {
    auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
    EXPECT_EQ(worker_threads->workers->AsEigenThreadPool(), &inline_pool);
  });
}
202 
// Runs retval0 = arg0 + arg1 and checks both the sum and that the call
// frame's arguments are left untouched.
TEST_F(ExecutorTest, SimpleAdd) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* lhs = test::graph::Arg(graph.get(), 0, DT_FLOAT);
  Node* rhs = test::graph::Arg(graph.get(), 1, DT_FLOAT);
  Node* sum = test::graph::Add(graph.get(), lhs, rhs);
  Node* result = test::graph::Retval(graph.get(), 0, sum);
  // Extra control edge straight from an Arg node to the Retval node.
  graph->AddControlEdge(rhs, result);
  FixupSourceAndSinkEdges(graph.get());
  Create(std::move(graph));

  FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
  TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(2.0)}));
  TF_ASSERT_OK(Run(&call_frame));

  std::vector<Tensor> retvals;
  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
  EXPECT_EQ(3.0, V(retvals[0]));  // 1.0 + 2.0

  // The arguments must still hold the values they were set to.
  const double kExpectedArgs[] = {1.0, 2.0};
  for (int index = 0; index < 2; ++index) {
    const Tensor* arg;
    TF_ASSERT_OK(call_frame.GetArg(index, &arg));
    EXPECT_EQ(kExpectedArgs[index], V(*arg));
  }
}
228 
// Runs a Mock kernel that sets only its first output and checks what is
// returned for the second (string) output that was never set.
TEST_F(ExecutorTest, EmptyOutput) {
  // retval0, retval1 = Mock(arg0)
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* input = test::graph::Arg(graph.get(), 0, DT_FLOAT);
  Node* mock;
  TF_ASSERT_OK(NodeBuilder(graph->NewName("n"), "Mock")
                   .Input(input)
                   .Finalize(graph.get(), &mock));
  test::graph::Retval(graph.get(), 0, mock, 0);
  test::graph::Retval(graph.get(), 1, mock, 1);
  FixupSourceAndSinkEdges(graph.get());

  // The mock forwards its input to output 0 and never touches output 1.
  auto forward_input = [](OpKernelContext* ctx) {
    ctx->set_output(0, ctx->input(0));
  };
  Create(std::move(graph), forward_input);

  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT, DT_STRING});
  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
  TF_ASSERT_OK(Run(&call_frame));

  std::vector<Tensor> retvals;
  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
  EXPECT_EQ(1.0, V(retvals[0]));
  // The unset output comes back as a string tensor with no data.
  EXPECT_EQ(DT_STRING, retvals[1].dtype());
  EXPECT_EQ(0, retvals[1].tensor_data().size());
}
250 
// Chains ten self-additions:
//   v0 = arg0, v1 = v0 + v0, ..., v10 = v9 + v9, retval0 = v10
// so the result is 2^10 times the input.
TEST_F(ExecutorTest, SelfAdd) {
  constexpr int kNumDoublings = 10;

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* current = test::graph::Arg(graph.get(), 0, DT_FLOAT);
  for (int step = 0; step < kNumDoublings; ++step) {
    current = test::graph::Add(graph.get(), current, current);
  }
  // retval0 <- v10
  test::graph::Retval(graph.get(), 0, current);
  FixupSourceAndSinkEdges(graph.get());
  Create(std::move(graph));

  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
  // arg0 = 1.0
  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
  TF_ASSERT_OK(Run(&call_frame));

  std::vector<Tensor> retvals;
  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
  EXPECT_EQ(1024.0, V(retvals[0]));  // 2^10 * 1.0
}
278 
279 // Builds a graph which adds N copies of one variable "in". I.e.,
280 //     a + a + a + ... + a
281 // The returned graph is parenthesized ramdonly. I.e.,
282 //     a + ((a + a) + a)
283 //     (a + a) + (a + a)
284 //     ((a + a) + a) + a
285 // are all possibly generated.
BuildTree(int N,Graph * g)286 void BuildTree(int N, Graph* g) {
287   CHECK_GT(N, 1);
288   // A single input node "in".
289   auto in = test::graph::Arg(g, 0, DT_FLOAT);
290   std::vector<Node*> nodes;
291   int i = 0;
292   // Duplicate "in" N times. Each copies is named as l0, l1, l2, ....
293   for (; i < N; ++i) {
294     nodes.push_back(test::graph::Identity(g, in, 0));
295   }
296   random::PhiloxRandom philox(0, 17);
297   random::SimplePhilox rnd(&philox);
298   while (nodes.size() > 1) {
299     // Randomly pick two from nodes and add them. The resulting node
300     // is named lik n10, n11, .... and is put back into "nodes".
301     int x = rnd.Uniform(nodes.size());
302     auto in0 = nodes[x];
303     nodes[x] = nodes.back();
304     nodes.resize(nodes.size() - 1);
305     x = rnd.Uniform(nodes.size());
306     auto in1 = nodes[x];
307     // node = in0 + in1.
308     nodes[x] = test::graph::Add(g, in0, in1);
309   }
310   // The final output node "out".
311   test::graph::Retval(g, 0, nodes.back());
312   FixupSourceAndSinkEdges(g);
313 }
314 
// Sums 4096 copies of the input through a randomly shaped tree of Add nodes
// and checks the total.
TEST_F(ExecutorTest, RandomTree) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  BuildTree(4096, graph.get());
  Create(std::move(graph));

  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
  TF_ASSERT_OK(Run(&call_frame));

  std::vector<Tensor> retvals;
  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
  EXPECT_EQ(4096.0, V(retvals[0]));  // 4096 * 1.0
}
326 
// Checks that an error raised by a kernel is returned from Run().
TEST_F(ExecutorTest, OpError) {
  // Mul(CheckNumerics(Reciprocal(0.0)), 2.0)
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* zero = test::graph::Constant(graph.get(), V(0.0));
  Node* reciprocal = test::graph::Unary(graph.get(), "Reciprocal", zero);
  Node* checked =
      test::graph::CheckNumerics(graph.get(), reciprocal, "message");
  Node* two = test::graph::Constant(graph.get(), V(2.0));
  test::graph::Binary(graph.get(), "Mul", checked, two);
  FixupSourceAndSinkEdges(graph.get());
  Create(std::move(graph));

  FunctionCallFrame call_frame({}, {});
  // The run must fail with InvalidArgument.
  // NOTE(review): presumably CheckNumerics rejects the non-finite value of
  // 1/0; an earlier comment attributed this to an invalid dtype -- confirm.
  const Status run_status = Run(&call_frame);
  EXPECT_TRUE(errors::IsInvalidArgument(run_status));
}
340 
// Runs retval0 = arg0 + 2.0 with additional control edges that originate at
// the Arg node and at the Const node.
TEST_F(ExecutorTest, ControlDependenciesFromSpecialNodes) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* input = test::graph::Arg(graph.get(), 0, DT_FLOAT);
  Node* two = test::graph::Constant(graph.get(), V(2.0));
  Node* sum = test::graph::Add(graph.get(), input, two);
  Node* result = test::graph::Retval(graph.get(), 0, sum);
  graph->AddControlEdge(input, sum);   // Arg   -> Add
  graph->AddControlEdge(two, result);  // Const -> Retval
  FixupSourceAndSinkEdges(graph.get());
  Create(std::move(graph));

  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
  TF_ASSERT_OK(Run(&call_frame));

  std::vector<Tensor> retvals;
  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
  EXPECT_EQ(3.0, V(retvals[0]));  // 1.0 + 2.0
}
358 
BM_executor(::testing::benchmark::State & state)359 void BM_executor(::testing::benchmark::State& state) {
360   const int width = state.range(0);
361   const int depth = state.range(1);
362 
363   Graph* g = new Graph(OpRegistry::Global());
364   random::PhiloxRandom philox(1729, 17);
365   random::SimplePhilox rand(&philox);
366   uint64 cur = 0;
367   uint32 r = 1 + rand.Rand32() % width;
368   std::vector<Node*> ready_nodes;
369   for (int i = 0; i < r; ++i) {
370     ready_nodes.push_back(test::graph::NoOp(g, {}));
371     ++cur;
372   }
373   std::random_device random_device;
374   std::mt19937 rng(random_device());
375   for (int i = 0; i < depth; ++i) {
376     std::shuffle(ready_nodes.begin(), ready_nodes.end(), rng);
377     r = 1 + rand.Rand32() % (ready_nodes.size());
378     std::vector<Node*> control_inputs;
379     for (int j = 0; j < r; ++j) {
380       control_inputs.push_back(ready_nodes.back());
381       ready_nodes.pop_back();
382     }
383     Node* n = test::graph::NoOp(g, control_inputs);
384     ++cur;
385     r = 1 + rand.Rand32() % width;
386     for (int j = 0; j < r; ++j) {
387       ready_nodes.push_back(test::graph::NoOp(g, {n}));
388       ++cur;
389     }
390   }
391   FixupSourceAndSinkEdges(g);
392   test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
393                   "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api=*/false)
394       .Run(state);
395   state.SetLabel(strings::StrCat("Nodes = ", cur));
396   state.SetItemsProcessed(cur * static_cast<int64_t>(state.iterations()));
397 }
398 
399 // Tall skinny graphs
400 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024);
401 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192);
402 
403 // Short fat graphs
404 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16);
405 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32);
406 
407 // Tall fat graph
408 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024);
409 
BM_const_identity(::testing::benchmark::State & state)410 void BM_const_identity(::testing::benchmark::State& state) {
411   const int width = state.range(0);
412   const int outputs_per_const = state.range(1);
413 
414   Graph* g = new Graph(OpRegistry::Global());
415   for (int i = 0; i < width; ++i) {
416     Tensor i_t(i);
417     Node* const_node = test::graph::Constant(g, i_t);
418     for (int j = 0; j < outputs_per_const; ++j) {
419       test::graph::Identity(g, const_node);
420     }
421   }
422   FixupSourceAndSinkEdges(g);
423   test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
424                   "SINGLE_THREADED_EXECUTOR",
425                   /*old_benchmark_api=*/false)
426       .Run(state);
427   state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
428   state.SetItemsProcessed((1 + outputs_per_const) * width *
429                           static_cast<int64_t>(state.iterations()));
430 }
431 
432 // Graph with actual op execution.
433 BENCHMARK(BM_const_identity)->UseRealTime()->ArgPair(1, 1);
434 BENCHMARK(BM_const_identity)->UseRealTime()->ArgPair(1, 100);
435 BENCHMARK(BM_const_identity)->UseRealTime()->ArgPair(100, 1);
436 BENCHMARK(BM_const_identity)->UseRealTime()->ArgPair(100, 100);
437 
// TODO(mrry): This benchmark is disabled because it currently crashes with a
// use-after-free: test::Benchmark::RunWithArgs() assumes that the executor
// takes ownership of the given graph *and* keeps its nodes (`x`, `y` and `z`)
// alive for the duration of the benchmark. The single threaded executor does
// not retain a copy of the graph, so that assumption does not hold.
//
// TODO(mrry): Add support for Arg/Retval "function call convention" in
// `test::Benchmark::RunWithArgs()`.
#if 0
#define ALICE "/job:j/replica:0/task:0/cpu:0"
#define BOB "/job:j/replica:0/task:0/gpu:0"

static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
  Graph* g = new Graph(OpRegistry::Global());
  // Computes z = x + y, where x and y are fed in as benchmark inputs and z is
  // fetched as the benchmark output. Conceptually the caller is ALICE and the
  // benchmark is BOB.
  Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
  Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
  Node* sum = test::graph::Add(g, x, y);
  Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
  FixupSourceAndSinkEdges(g);

  // The same scalar is fed for both inputs.
  Tensor feed(DT_FLOAT, TensorShape({}));
  feed.scalar<float>()() = 3.14;
  test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
                  "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api=*/false)
      .RunWithArgs({{x, feed}, {y, feed}}, {z}, state);
  state.SetItemsProcessed(state.iterations());
}
BENCHMARK(BM_FeedInputFetchOutput);
#endif
469 
470 }  // namespace
471 }  // namespace data
472 }  // namespace tensorflow
473