// Source: /aosp_15_r20/external/pytorch/test/cpp/jit/test_exception.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * We have a Python unit test for exceptions in test/jit/test_exception.py.
 * This C++ version verifies that the expected exception types are thrown from
 * C++. That is hard to test from Python, because C++ exceptions get
 * translated into Python exceptions.
 */
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/jit.h>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
16 
17 namespace torch {
18 namespace jit {
19 
20 namespace py = pybind11;
21 
// Compiles a TorchScript function that raises a builtin AssertionError and
// checks that running it throws a JITException that carries no Python class
// name and whose message contains the RuntimeError-wrapped assertion text.
TEST(TestException, TestAssertion) {
  std::string pythonCode = R"PY(
  def foo():
    raise AssertionError("An assertion failed")
  )PY";
  auto cu_ptr = torch::jit::compile(pythonCode);
  // Checked downcast: unlike a C-style cast, this fails loudly if "foo" is
  // not actually a GraphFunction.
  torch::jit::GraphFunction& gf =
      torch::jit::toGraphFunction(cu_ptr->get_function("foo"));
  std::cerr << "Graph is\n" << *gf.graph() << '\n';

  bool is_jit_exception = false;
  std::string message;
  std::optional<std::string> exception_class;
  try {
    cu_ptr->run_method("foo");
  } catch (const JITException& e) {
    is_jit_exception = true;
    message = e.what();
    exception_class = e.getPythonClassName();
  }
  // Nothing below is meaningful unless the JITException was actually thrown.
  ASSERT_TRUE(is_jit_exception);
  // A builtin AssertionError must not carry a custom Python class name.
  EXPECT_FALSE(exception_class.has_value());
  EXPECT_NE(
      message.find("RuntimeError: AssertionError: An assertion failed"),
      std::string::npos)
      << "Unexpected exception message: " << message;
}
48 
49 struct MyPythonExceptionValue : public torch::jit::SugaredValue {
MyPythonExceptionValuetorch::jit::MyPythonExceptionValue50   explicit MyPythonExceptionValue(const py::object& exception_class) {
51     qualified_name_ =
52         (py::str(py::getattr(exception_class, "__module__", py::str(""))) +
53          py::str(".") +
54          py::str(py::getattr(exception_class, "__name__", py::str(""))))
55             .cast<std::string>();
56   }
57 
kindtorch::jit::MyPythonExceptionValue58   std::string kind() const override {
59     return "My Python exception";
60   }
61 
62   // Simplified from PythonExceptionValue::call
calltorch::jit::MyPythonExceptionValue63   std::shared_ptr<torch::jit::SugaredValue> call(
64       const torch::jit::SourceRange& loc,
65       torch::jit::GraphFunction& caller,
66       at::ArrayRef<torch::jit::NamedValue> args,
67       at::ArrayRef<torch::jit::NamedValue> kwargs,
68       size_t n_binders) override {
69     TORCH_CHECK(args.size() == 1);
70     Value* error_message = args.at(0).value(*caller.graph());
71     Value* qualified_class_name =
72         insertConstant(*caller.graph(), qualified_name_, loc);
73     return std::make_shared<ExceptionMessageValue>(
74         error_message, qualified_class_name);
75   }
76 
77  private:
78   std::string qualified_name_;
79 };
80 
81 class SimpleResolver : public torch::jit::Resolver {
82  public:
SimpleResolver()83   explicit SimpleResolver() {}
84 
resolveValue(const std::string & name,torch::jit::GraphFunction & m,const torch::jit::SourceRange & loc)85   std::shared_ptr<torch::jit::SugaredValue> resolveValue(
86       const std::string& name,
87       torch::jit::GraphFunction& m,
88       const torch::jit::SourceRange& loc) override {
89     // follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
90     // a python extension. We can not add that as a cpp_binary's dep)
91     if (name == "SimpleValueError") {
92       py::object obj = py::globals()["SimpleValueError"];
93       return std::make_shared<MyPythonExceptionValue>(obj);
94     }
95     TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
96   }
97 
resolveType(const std::string & name,const torch::jit::SourceRange & loc)98   torch::jit::TypePtr resolveType(
99       const std::string& name,
100       const torch::jit::SourceRange& loc) override {
101     return nullptr;
102   }
103 };
104 
/*
 * - The Python source parsing for TorchScript here is adapted from
 * torch::jit::compile.
 * - Only the first Def in the code is parsed; any additional Defs are
 * ignored.
 */
// Defines a Python exception class in the embedded interpreter, compiles a
// TorchScript function that raises it (resolved through SimpleResolver), and
// checks that running it throws a JITException carrying the qualified Python
// class name and the original message.
TEST(TestException, TestCustomException) {
  py::scoped_interpreter guard{};
  py::exec(R"PY(
  class SimpleValueError(ValueError):
    def __init__(self, message):
      super().__init__(message)
  )PY");

  std::string pythonCode = R"PY(
  def foo():
    raise SimpleValueError("An assertion failed")
  )PY";

  torch::jit::Parser p(
      std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
  auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
  std::cerr << "Def is:\n" << def << '\n';
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  (void)cu->define(
      std::nullopt,
      {},
      {},
      {def},
      // class PythonResolver is defined in
      // torch/csrc/jit/python/script_init.cpp and is not exposed in a header,
      // so it cannot be used here. SimpleResolver stands in for it.
      {std::make_shared<SimpleResolver>()},
      nullptr);
  // Checked downcast instead of a C-style cast.
  torch::jit::GraphFunction& gf =
      torch::jit::toGraphFunction(cu->get_function("foo"));
  std::cerr << "Graph is\n" << *gf.graph() << '\n';

  bool is_jit_exception = false;
  std::optional<std::string> exception_class;
  std::string message;
  try {
    cu->run_method("foo");
  } catch (const JITException& e) {
    is_jit_exception = true;
    exception_class = e.getPythonClassName();
    message = e.what();
  }
  ASSERT_TRUE(is_jit_exception);
  // Must stop here if no class name was attached; dereferencing an empty
  // std::optional below would be undefined behavior.
  ASSERT_TRUE(exception_class.has_value());
  EXPECT_EQ("__main__.SimpleValueError", *exception_class);
  EXPECT_NE(
      message.find("__main__.SimpleValueError: An assertion failed"),
      std::string::npos)
      << "Unexpected exception message: " << message;
}
158 
159 } // namespace jit
160 } // namespace torch
161