1 #pragma once
2
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <algorithm>
#include <cctype>
#include <string>
#include <utility>
#include <vector>
7
namespace {
// Normalizes whitespace in `s` in place so that error messages can be compared
// without caring about line wrapping or indentation. Specifically it:
//   1. strips leading and trailing characters for which std::isspace is true,
//   2. deletes every remaining '\n' (so "a\nb" becomes "ab"), and
//   3. collapses each run of consecutive ' ' characters into a single ' '.
// Other interior whitespace such as '\t' or '\r' is left untouched.
//
// Implemented as one linear pass into a fresh buffer; erasing characters one
// at a time inside a loop is quadratic in the length of `s`.
//
// NOTE: this lives in an unnamed namespace so every including translation unit
// gets its own copy; ASSERT_THROWS_WITH_MESSAGE calls it by unqualified name.
inline void trim(std::string& s) {
  const auto is_not_space = [](unsigned char ch) { return !std::isspace(ch); };
  const auto first = std::find_if(s.begin(), s.end(), is_not_space);
  const auto last = std::find_if(s.rbegin(), s.rend(), is_not_space).base();
  if (first >= last) {
    // Empty or all-whitespace input normalizes to the empty string.
    s.clear();
    return;
  }

  std::string normalized;
  normalized.reserve(static_cast<std::string::size_type>(last - first));
  for (auto it = first; it != last; ++it) {
    const char ch = *it;
    if (ch == '\n') {
      continue;
    }
    // Skip a space that would directly follow another emitted space. Newlines
    // are dropped before this check, so spaces on both sides of a removed
    // newline also collapse into one.
    if (ch == ' ' && !normalized.empty() && normalized.back() == ' ') {
      continue;
    }
    normalized.push_back(ch);
  }
  s = std::move(normalized);
}
} // namespace
34
// Asserts that evaluating `statement` throws an exception derived from
// std::exception whose what() text contains `substring`.
//
// Both `substring` and the exception's what() text are normalized with trim()
// before searching. Differences in surrounding whitespace, line breaks and runs
// of spaces therefore do not cause spurious failures.
//
// If `statement` completes without throwing, FAIL() is invoked. Exceptions that
// do not derive from std::exception are not caught and propagate out.
//
// `statement` is parenthesized so that the whole expression is discarded.
// Without the parentheses, an argument such as `a + b` would expand to
// `(void)a + b;`, which does not compile.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)             \
  try {                                                              \
    (void)(statement);                                               \
    FAIL();                                                          \
  } catch (const std::exception& e) {                                \
    std::string substring_s(substring);                              \
    trim(substring_s);                                               \
    auto exception_string = std::string(e.what());                   \
    trim(exception_string);                                          \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
        << " Error was: \n"                                          \
        << exception_string;                                         \
  }
48
49 namespace torch {
50 namespace jit {
51
52 using tensor_list = std::vector<at::Tensor>;
53 using namespace torch::autograd;
54
55 // work around the fact that variable_tensor_list doesn't duplicate all
56 // of std::vector's constructors.
57 // most constructors are never used in the implementation, just in our tests.
58 Stack createStack(std::vector<at::Tensor>&& list);
59
60 void assertAllClose(const tensor_list& a, const tensor_list& b);
61
62 std::vector<at::Tensor> run(
63 InterpreterState& interp,
64 const std::vector<at::Tensor>& inputs);
65
66 std::pair<tensor_list, tensor_list> runGradient(
67 Gradient& grad_spec,
68 tensor_list& tensors_in,
69 tensor_list& tensor_grads_in);
70
71 std::shared_ptr<Graph> build_lstm();
72 std::shared_ptr<Graph> build_mobile_export_analysis_graph();
73 std::shared_ptr<Graph> build_mobile_export_with_out();
74 std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
75 std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
76 std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();
77
78 at::Tensor t_use(at::Tensor x);
79 at::Tensor t_def(at::Tensor x);
80
81 // given the difference of output vs expected tensor, check whether the
82 // difference is within a relative tolerance range. This is a standard way of
83 // matching tensor values up to certain precision
84 bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
85 bool almostEqual(const at::Tensor& a, const at::Tensor& b);
86
87 bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
88 bool exactlyEqual(
89 const std::vector<at::Tensor>& a,
90 const std::vector<at::Tensor>& b);
91
92 std::vector<at::Tensor> runGraph(
93 std::shared_ptr<Graph> graph,
94 const std::vector<at::Tensor>& inputs);
95
96 std::pair<at::Tensor, at::Tensor> lstm(
97 at::Tensor input,
98 at::Tensor hx,
99 at::Tensor cx,
100 at::Tensor w_ih,
101 at::Tensor w_hh);
102
103 } // namespace jit
104 } // namespace torch
105