xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <algorithm>
#include <cctype>
#include <exception>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
7 
namespace {
// Normalizes `s` in place so that exception messages can be compared by
// substring regardless of incidental formatting. In order, it:
//   1. strips leading and trailing whitespace (as classified by std::isspace),
//   2. deletes every '\n' character left in the string,
//   3. collapses each run of consecutive ' ' characters into a single ' '.
// Only the plain space character is collapsed; other interior whitespace
// (e.g. '\t') is preserved as-is.
//
// ASSERT_THROWS_WITH_MESSAGE calls this function unqualified.
//
// Each step is a single linear pass (erase-remove / std::unique) rather than
// erasing one character at a time, which was quadratic in the string length.
inline void trim(std::string& s) {
  // Takes unsigned char so std::isspace never sees a negative value.
  const auto not_space = [](unsigned char ch) { return !std::isspace(ch); };

  s.erase(s.begin(), std::find_if(s.begin(), s.end(), not_space));
  s.erase(std::find_if(s.rbegin(), s.rend(), not_space).base(), s.end());

  s.erase(std::remove(s.begin(), s.end(), '\n'), s.end());

  // std::unique drops every ' ' that directly follows a retained ' '.
  s.erase(
      std::unique(
          s.begin(),
          s.end(),
          [](char lhs, char rhs) { return lhs == ' ' && rhs == ' '; }),
      s.end());
}
} // namespace
34 
// GoogleTest-style assertion: `statement` must throw an exception derived
// from std::exception whose what() contains `substring`. Both the message and
// `substring` are normalized with trim() before the search, so differences in
// surrounding whitespace, newlines and repeated spaces do not matter.
// - If `statement` completes without throwing, the test fails via FAIL().
// - If the message does not contain `substring`, ASSERT_NE fails and prints
//   the normalized message.
// - Exceptions not derived from std::exception are not caught here.
//
// Implementation notes:
// - The do { ... } while (0) wrapper makes the expansion a single statement,
//   so the macro is safe in an un-braced if/else.
// - FAIL() is issued outside the try block (the handler `break`s out of the
//   loop instead). When GoogleTest is configured to throw on failure, FAIL()
//   raises an exception derived from std::runtime_error, and it must not be
//   caught and misinterpreted by the handler below.
// - `statement` is parenthesized so expressions such as `a = b` are cast to
//   void as a whole.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)               \
  do {                                                                 \
    try {                                                              \
      (void)(statement);                                               \
    } catch (const std::exception& e) {                                \
      std::string substring_s(substring);                              \
      trim(substring_s);                                               \
      auto exception_string = std::string(e.what());                   \
      trim(exception_string);                                          \
      ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
          << " Error was: \n"                                          \
          << exception_string;                                         \
      break;                                                           \
    }                                                                  \
    FAIL();                                                            \
  } while (0)
48 
49 namespace torch {
50 namespace jit {
51 
52 using tensor_list = std::vector<at::Tensor>;
53 using namespace torch::autograd;
54 
55 // work around the fact that variable_tensor_list doesn't duplicate all
56 // of std::vector's constructors.
57 // most constructors are never used in the implementation, just in our tests.
58 Stack createStack(std::vector<at::Tensor>&& list);
59 
60 void assertAllClose(const tensor_list& a, const tensor_list& b);
61 
62 std::vector<at::Tensor> run(
63     InterpreterState& interp,
64     const std::vector<at::Tensor>& inputs);
65 
66 std::pair<tensor_list, tensor_list> runGradient(
67     Gradient& grad_spec,
68     tensor_list& tensors_in,
69     tensor_list& tensor_grads_in);
70 
71 std::shared_ptr<Graph> build_lstm();
72 std::shared_ptr<Graph> build_mobile_export_analysis_graph();
73 std::shared_ptr<Graph> build_mobile_export_with_out();
74 std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
75 std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
76 std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();
77 
78 at::Tensor t_use(at::Tensor x);
79 at::Tensor t_def(at::Tensor x);
80 
81 // given the difference of output vs expected tensor, check whether the
82 // difference is within a relative tolerance range. This is a standard way of
83 // matching tensor values up to certain precision
84 bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
85 bool almostEqual(const at::Tensor& a, const at::Tensor& b);
86 
87 bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
88 bool exactlyEqual(
89     const std::vector<at::Tensor>& a,
90     const std::vector<at::Tensor>& b);
91 
92 std::vector<at::Tensor> runGraph(
93     std::shared_ptr<Graph> graph,
94     const std::vector<at::Tensor>& inputs);
95 
96 std::pair<at::Tensor, at::Tensor> lstm(
97     at::Tensor input,
98     at::Tensor hx,
99     at::Tensor cx,
100     at::Tensor w_ih,
101     at::Tensor w_hh);
102 
103 } // namespace jit
104 } // namespace torch
105