xref: /aosp_15_r20/external/pytorch/benchmarks/static_runtime/test_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
8 #include <string>
9 #include <vector>
10 
11 #include <torch/csrc/jit/ir/ir.h>
12 #include <torch/csrc/jit/runtime/static/impl.h>
13 
// Forward declaration of c10::IValue, which the helper signatures below
// accept by const reference.
14 namespace c10 {
15 struct IValue;
16 }
17 
18 namespace torch {
19 namespace jit {
20 
// Forward declarations of the types named in the node-lookup helper
// signatures declared further down in this header.
21 struct Node;
22 class StaticModule;
23 
24 namespace test {
25 
26 // Given a model/function in jit or IR script, run the model/function
27 // with the jit interpreter and static runtime, and compare the results.
//
// Parameters:
//   source       - text of the model/function, in jit script or IR form.
//   args         - inputs handed to the model/function.
//   args2        - optional second list of inputs; defaults to empty.
//                  NOTE(review): presumably drives an additional run with
//                  different inputs -- confirm against the definition.
//   use_allclose - defaults to false. NOTE(review): presumably switches the
//                  comparison to an approximate (allclose-style) check
//                  instead of exact equality -- confirm.
//   use_equalnan - defaults to false. NOTE(review): presumably makes NaN
//                  values compare equal -- confirm.
//   check_resize - defaults to true. NOTE(review): its effect is not visible
//                  from this header; see the definition.
28 void testStaticRuntime(
29     const std::string& source,
30     const std::vector<c10::IValue>& args,
31     const std::vector<c10::IValue>& args2 = {},
32     const bool use_allclose = false,
33     const bool use_equalnan = false,
34     const bool check_resize = true);
35 
// Returns a shared_ptr to a Graph obtained from the given jit script text.
// NOTE(review): how the script is compiled, and what happens on malformed
// input, is decided by the definition and is not visible in this header.
36 std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script);
37 
// Returns a shared_ptr to a Graph obtained from the given IR text.
// NOTE(review): parsing details and error behavior live in the definition,
// not in this header.
38 std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);
39 
// Reports whether `smodule` has a processed node associated with `name`.
// NOTE(review): description inferred from the function name; which name is
// compared and whether the match is exact is defined in the implementation.
40 bool hasProcessedNodeWithName(
41     torch::jit::StaticModule& smodule,
42     const char* name);
43 
// Returns an at::Tensor for the given IValue.
// NOTE(review): presumably expects `ival` to hold a tensor; behavior for any
// other payload is not visible from this header -- confirm.
44 at::Tensor getTensor(const at::IValue& ival);
45 
// Looks up a Node by its kind, given as a string. One overload searches a
// StaticModule, the other a Graph held by shared_ptr.
// NOTE(review): the result when no node matches (e.g. nullptr) and which
// node is chosen when several match are not visible here -- confirm.
46 Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);
47 Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind);
48 
// Boolean counterparts of the getNodeWithKind overloads: report whether a
// node of the given kind exists in the StaticModule / Graph.
// NOTE(review): inferred from the names and parallel signatures; confirm
// against the definitions.
49 bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);
50 bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind);
51 
// Compares the output of an existing StaticRuntime against the JIT for the
// given graph and inputs.
// NOTE(review): description inferred from the name and parameters; how the
// JIT result is produced and how mismatches are reported is defined in the
// implementation.
//
// Parameters:
//   runtime      - static runtime instance under test (taken by mutable
//                  reference).
//   graph        - graph used on the JIT side of the comparison.
//   args         - inputs supplied to the runs.
//   use_allclose - defaults to false. NOTE(review): presumably selects an
//                  approximate comparison -- confirm.
//   use_equalnan - defaults to false. NOTE(review): presumably makes NaN
//                  values compare equal -- confirm.
52 void compareResultsWithJIT(
53     StaticRuntime& runtime,
54     const std::shared_ptr<Graph>& graph,
55     const std::vector<c10::IValue>& args,
56     const bool use_allclose = false,
57     const bool use_equalnan = false);
58 
// Compares an expected IValue against an actual one.
// NOTE(review): how a mismatch is reported (test assertion, exception, ...)
// and which IValue kinds are supported are not visible from this header.
//
// Parameters:
//   expect       - reference value.
//   actual       - value being checked.
//   use_allclose - defaults to false. NOTE(review): presumably selects an
//                  approximate comparison for tensors -- confirm.
//   use_equalnan - defaults to false. NOTE(review): presumably makes NaN
//                  values compare equal -- confirm.
59 void compareResults(
60     const IValue& expect,
61     const IValue& actual,
62     const bool use_allclose = false,
63     const bool use_equalnan = false);
64 
65 } // namespace test
66 } // namespace jit
67 } // namespace torch
68