xref: /aosp_15_r20/external/pytorch/test/cpp/common/support.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 
5 #include <gtest/gtest.h>
6 
7 #include <stdexcept>
8 #include <string>
9 
10 namespace torch {
11 namespace test {
// Asserts that evaluating `statement` throws an exception whose message
// contains `substring`.
//
// - `statement` is evaluated exactly once inside a try block; its value (if
//   any) is discarded. It is parenthesized before the (void) cast so that
//   expressions such as `a + b` or `x = f()` are discarded as a whole.
// - A thrown c10::Error is inspected via what_without_backtrace(); any other
//   std::exception is inspected via what(). Exceptions not derived from
//   std::exception are not caught and propagate out of the macro.
// - `substring` is evaluated exactly once, after `statement`, and may be any
//   value accepted by std::string::find and streamable into a gtest message
//   (e.g. a string literal, const char*, std::string, or char).
//
// Failures are reported with gtest's FAIL(), which returns from the enclosing
// function, so this macro may only be used in functions returning void (such
// as a TEST body).
//
// The "did not throw" failure is raised outside the try block so that, when
// gtest is run with --gtest_throw_on_failure (failed assertions then throw),
// the exception raised by FAIL() is not caught by this macro's own
// std::exception handler and misreported as a substring mismatch.
//
// Note: macros are not scoped by namespaces even though this one is defined
// inside torch::test; names used in the expansion are fully qualified so it
// works when expanded from any namespace.
#define ASSERT_THROWS_WITH(statement, substring)                            \
  {                                                                         \
    ::std::string assert_throws_with_error_message;                         \
    bool assert_throws_with_did_throw = false;                              \
    try {                                                                   \
      (void)(statement);                                                    \
    } catch (const ::c10::Error& e) {                                       \
      assert_throws_with_did_throw = true;                                  \
      assert_throws_with_error_message = e.what_without_backtrace();        \
    } catch (const ::std::exception& e) {                                   \
      assert_throws_with_did_throw = true;                                  \
      assert_throws_with_error_message = e.what();                          \
    }                                                                       \
    if (!assert_throws_with_did_throw) {                                    \
      FAIL() << "Expected statement `" #statement                           \
                "` to throw an exception, but it did not";                  \
    }                                                                       \
    const auto& assert_throws_with_substring = (substring);                 \
    if (assert_throws_with_error_message.find(                              \
            assert_throws_with_substring) == ::std::string::npos) {         \
      FAIL() << "Error message \"" << assert_throws_with_error_message      \
             << "\" did not contain expected substring \""                  \
             << assert_throws_with_substring << "\"";                       \
    }                                                                       \
  }
31 
32 } // namespace test
33 } // namespace torch
34