// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Exception.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
3 #ifdef USE_VULKAN_API
4 
5 #include <exception>
6 #include <ostream>
7 #include <string>
8 #include <vector>
9 
10 #include <ATen/native/vulkan/api/StringUtil.h>
11 #include <ATen/native/vulkan/api/vk_api.h>
12 
// Evaluates `function` (an expression yielding a VkResult) exactly once and
// throws ::at::native::vulkan::api::Error if the value is not VK_SUCCESS.
// The error carries the call site (__func__, __FILE__, __LINE__) and a message
// assembled by concat_str from the stringified expression, " returned ", and
// the result value.
//
// The temporary uses a macro-specific name on purpose: a plain `result` would
// shadow any caller variable of that name appearing inside `function`
// (e.g. VK_CHECK(result)), turning the initializer into a read of the
// not-yet-initialized temporary itself.
#define VK_CHECK(function)                                        \
  do {                                                            \
    const VkResult vk_check_result_ = (function);                 \
    if (VK_SUCCESS != vk_check_result_) {                         \
      throw ::at::native::vulkan::api::Error(                     \
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},  \
          ::at::native::vulkan::api::concat_str(                  \
              #function, " returned ", vk_check_result_));        \
    }                                                             \
  } while (false)
23 
// Throws ::at::native::vulkan::api::Error when `cond` evaluates to false.
// The error is constructed from the call site (__func__, __FILE__, __LINE__),
// the stringified condition, and the string produced by passing the remaining
// macro arguments to concat_str.
#define VK_CHECK_COND(cond, ...)                                      \
  do {                                                                \
    if (!(cond)) {                                                    \
      throw ::at::native::vulkan::api::Error(                         \
          ::at::native::vulkan::api::SourceLocation{                  \
              __func__, __FILE__, static_cast<uint32_t>(__LINE__)},   \
          #cond,                                                      \
          ::at::native::vulkan::api::concat_str(__VA_ARGS__));        \
    }                                                                 \
  } while (false)
33 
// Unconditionally throws ::at::native::vulkan::api::Error. The error is
// constructed from the call site (__func__, __FILE__, __LINE__) and the string
// produced by passing the macro arguments to concat_str.
#define VK_THROW(...)                                                 \
  do {                                                                \
    throw ::at::native::vulkan::api::Error(                           \
        ::at::native::vulkan::api::SourceLocation{                    \
            __func__, __FILE__, static_cast<uint32_t>(__LINE__)},     \
        ::at::native::vulkan::api::concat_str(__VA_ARGS__));          \
  } while (false)
40 
41 namespace at {
42 namespace native {
43 namespace vulkan {
44 namespace api {
45 
46 std::ostream& operator<<(std::ostream& out, const VkResult loc);
47 
// Plain aggregate describing where in the source an error was raised.
//
// The VK_CHECK / VK_CHECK_COND / VK_THROW macros fill it with __func__,
// __FILE__ and __LINE__. The strings are held as raw pointers; this struct
// neither copies nor frees them.
//
// Default member initializers give a default-constructed instance well-defined
// contents (null pointers, line 0) instead of indeterminate values, while
// brace initialization with three values keeps working as before.
struct SourceLocation {
  const char* function = nullptr; // name of the enclosing function
  const char* file = nullptr; // source file path
  uint32_t line = 0u; // line number within `file`
};
53 
54 std::ostream& operator<<(std::ostream& out, const SourceLocation& loc);
55 
56 class Error : public std::exception {
57  public:
58   Error(SourceLocation source_location, std::string msg);
59   Error(SourceLocation source_location, const char* cond, std::string msg);
60 
61  private:
62   std::string msg_;
63   SourceLocation source_location_;
64   std::string what_;
65 
66  public:
msg()67   const std::string& msg() const {
68     return msg_;
69   }
70 
what()71   const char* what() const noexcept override {
72     return what_.c_str();
73   }
74 };
75 
76 } // namespace api
77 } // namespace vulkan
78 } // namespace native
79 } // namespace at
80 
81 #endif /* USE_VULKAN_API */
82