1 /*
2 * We have a python unit test for exceptions in test/jit/test_exception.py .
3 * Add a CPP version here to verify that excepted exception types thrown from
4 * C++. This is hard to test in python code since C++ exceptions will be
5 * translated to python exceptions.
6 */
7 #include <gtest/gtest.h>
8 #include <pybind11/embed.h>
9 #include <torch/csrc/jit/frontend/parser.h>
10 #include <torch/csrc/jit/frontend/resolver.h>
11 #include <torch/csrc/jit/runtime/jit_exception.h>
12 #include <torch/csrc/utils/pybind.h>
13 #include <torch/jit.h>
14 #include <iostream>
15 #include <stdexcept>
16
17 namespace torch {
18 namespace jit {
19
20 namespace py = pybind11;
21
TEST(TestException, TestAssertion) {
  // An AssertionError raised from TorchScript should surface in C++ as a
  // JITException that carries no Python class name.
  const std::string pythonCode = R"PY(
  def foo():
    raise AssertionError("An assertion failed")
  )PY";
  auto cu_ptr = torch::jit::compile(pythonCode);
  // foo is compiled from TorchScript source, so it is expected to be a
  // GraphFunction; use a named cast instead of a C-style cast.
  auto* gf =
      static_cast<torch::jit::GraphFunction*>(&cu_ptr->get_function("foo"));
  std::cerr << "Graph is\n" << *gf->graph() << '\n';

  bool is_jit_exception = false;
  std::string message;
  std::optional<std::string> exception_class;
  try {
    cu_ptr->run_method("foo");
  } catch (const JITException& e) {
    is_jit_exception = true;
    message = e.what();
    exception_class = e.getPythonClassName();
  }
  // The remaining checks are meaningless unless the JITException was thrown.
  ASSERT_TRUE(is_jit_exception);
  EXPECT_FALSE(exception_class.has_value());
  EXPECT_NE(
      message.find("RuntimeError: AssertionError: An assertion failed"),
      std::string::npos)
      << "unexpected exception message: " << message;
}
48
49 struct MyPythonExceptionValue : public torch::jit::SugaredValue {
MyPythonExceptionValuetorch::jit::MyPythonExceptionValue50 explicit MyPythonExceptionValue(const py::object& exception_class) {
51 qualified_name_ =
52 (py::str(py::getattr(exception_class, "__module__", py::str(""))) +
53 py::str(".") +
54 py::str(py::getattr(exception_class, "__name__", py::str(""))))
55 .cast<std::string>();
56 }
57
kindtorch::jit::MyPythonExceptionValue58 std::string kind() const override {
59 return "My Python exception";
60 }
61
62 // Simplified from PythonExceptionValue::call
calltorch::jit::MyPythonExceptionValue63 std::shared_ptr<torch::jit::SugaredValue> call(
64 const torch::jit::SourceRange& loc,
65 torch::jit::GraphFunction& caller,
66 at::ArrayRef<torch::jit::NamedValue> args,
67 at::ArrayRef<torch::jit::NamedValue> kwargs,
68 size_t n_binders) override {
69 TORCH_CHECK(args.size() == 1);
70 Value* error_message = args.at(0).value(*caller.graph());
71 Value* qualified_class_name =
72 insertConstant(*caller.graph(), qualified_name_, loc);
73 return std::make_shared<ExceptionMessageValue>(
74 error_message, qualified_class_name);
75 }
76
77 private:
78 std::string qualified_name_;
79 };
80
81 class SimpleResolver : public torch::jit::Resolver {
82 public:
SimpleResolver()83 explicit SimpleResolver() {}
84
resolveValue(const std::string & name,torch::jit::GraphFunction & m,const torch::jit::SourceRange & loc)85 std::shared_ptr<torch::jit::SugaredValue> resolveValue(
86 const std::string& name,
87 torch::jit::GraphFunction& m,
88 const torch::jit::SourceRange& loc) override {
89 // follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
90 // a python extension. We can not add that as a cpp_binary's dep)
91 if (name == "SimpleValueError") {
92 py::object obj = py::globals()["SimpleValueError"];
93 return std::make_shared<MyPythonExceptionValue>(obj);
94 }
95 TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
96 }
97
resolveType(const std::string & name,const torch::jit::SourceRange & loc)98 torch::jit::TypePtr resolveType(
99 const std::string& name,
100 const torch::jit::SourceRange& loc) override {
101 return nullptr;
102 }
103 };
104
105 /*
106 * - The python source code parsing for TorchScript here is learned from
107 * torch::jit::compile.
108 * - The code only parses one Def. If there are multiple in the code, those
109 * except the first one are skipped.
110 */
TEST(TestException, TestCustomException) {
  // The embedded interpreter must outlive every use of Python objects below,
  // including the exception class captured by SimpleResolver.
  py::scoped_interpreter guard{};
  // Passed directly as a literal so pybind11 receives a char array.
  py::exec(R"PY(
  class SimpleValueError(ValueError):
    def __init__(self, message):
      super().__init__(message)
  )PY");

  const std::string pythonCode = R"PY(
  def foo():
    raise SimpleValueError("An assertion failed")
  )PY";

  torch::jit::Parser p(
      std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
  auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
  std::cerr << "Def is:\n" << def << '\n';
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  (void)cu->define(
      std::nullopt,
      {},
      {},
      {def},
      // class PythonResolver is defined in
      // torch/csrc/jit/python/script_init.cpp. It is not in a header file, so
      // it cannot be used here; SimpleResolver is used instead.
      {std::make_shared<SimpleResolver>()},
      nullptr);
  // foo is compiled from TorchScript source, so it is expected to be a
  // GraphFunction; use a named cast instead of a C-style cast.
  auto* gf = static_cast<torch::jit::GraphFunction*>(&cu->get_function("foo"));
  std::cerr << "Graph is\n" << *gf->graph() << '\n';

  bool is_jit_exception = false;
  std::optional<std::string> exception_class;
  std::string message;
  try {
    cu->run_method("foo");
  } catch (const JITException& e) {
    is_jit_exception = true;
    exception_class = e.getPythonClassName();
    message = e.what();
  }
  ASSERT_TRUE(is_jit_exception);
  // Dereferencing an empty optional is undefined behavior, so stop here if no
  // Python class name was attached to the exception.
  ASSERT_TRUE(exception_class.has_value());
  EXPECT_EQ("__main__.SimpleValueError", *exception_class);
  EXPECT_NE(
      message.find("__main__.SimpleValueError: An assertion failed"),
      std::string::npos)
      << "unexpected exception message: " << message;
}
158
159 } // namespace jit
160 } // namespace torch
161