xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/exceptions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
5 
6 #include <stdexcept>
7 
// Forward declarations of types
9 
// Forward declarations of the tensorexpr IR classes Stmt and Expr so they can
// be named in this header without pulling in their full definitions.
// NOTE(review): these are presumably already declared via fwd_decls.h;
// redeclaring a class is harmless.
namespace torch {
namespace jit {
namespace tensorexpr {
class Stmt;
class Expr;
} // namespace tensorexpr
} // namespace jit
} // namespace torch
14 
// Forward declarations of functions (std::to_string overloads for IR nodes)
16 namespace std {
17 TORCH_API std::string to_string(const torch::jit::tensorexpr::ExprPtr&);
18 TORCH_API std::string to_string(const torch::jit::tensorexpr::StmtPtr&);
19 } // namespace std
20 
21 namespace torch::jit::tensorexpr {
22 
// Error type for unsupported dtypes. Its what() message always begins with
// "UNSUPPORTED DTYPE"; the string overload appends ": <err>".
class unsupported_dtype : public std::runtime_error {
  using Base = std::runtime_error;

 public:
  // what() == "UNSUPPORTED DTYPE"
  explicit unsupported_dtype() : Base("UNSUPPORTED DTYPE") {}

  // what() == "UNSUPPORTED DTYPE: " followed by err
  explicit unsupported_dtype(const std::string& err)
      : Base(std::string("UNSUPPORTED DTYPE: ").append(err)) {}
};
29 
// Error type for out-of-range indices. Its what() message always begins with
// "OUT OF RANGE INDEX"; the string overload appends ": <err>".
class out_of_range_index : public std::runtime_error {
 public:
  // what() == "OUT OF RANGE INDEX"
  explicit out_of_range_index() : runtime_error("OUT OF RANGE INDEX") {}

  // what() == "OUT OF RANGE INDEX: " followed by err
  explicit out_of_range_index(const std::string& err)
      : runtime_error(std::string{"OUT OF RANGE INDEX: "} + err) {}
};
36 
37 class unimplemented_lowering : public std::runtime_error {
38  public:
unimplemented_lowering()39   explicit unimplemented_lowering()
40       : std::runtime_error("UNIMPLEMENTED LOWERING") {}
unimplemented_lowering(const ExprPtr & expr)41   explicit unimplemented_lowering(const ExprPtr& expr)
42       : std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {}
unimplemented_lowering(const StmtPtr & stmt)43   explicit unimplemented_lowering(const StmtPtr& stmt)
44       : std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {}
45 };
46 
47 class malformed_input : public std::runtime_error {
48  public:
malformed_input()49   explicit malformed_input() : std::runtime_error("MALFORMED INPUT") {}
malformed_input(const std::string & err)50   explicit malformed_input(const std::string& err)
51       : std::runtime_error("MALFORMED INPUT: " + err) {}
malformed_input(const ExprPtr & expr)52   explicit malformed_input(const ExprPtr& expr)
53       : std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {}
malformed_input(const std::string & err,const ExprPtr & expr)54   explicit malformed_input(const std::string& err, const ExprPtr& expr)
55       : std::runtime_error(
56             "MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {}
malformed_input(const StmtPtr & stmt)57   explicit malformed_input(const StmtPtr& stmt)
58       : std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
malformed_input(const std::string & err,const StmtPtr & stmt)59   explicit malformed_input(const std::string& err, const StmtPtr& stmt)
60       : std::runtime_error(
61             "MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {}
62 };
63 
64 class malformed_ir : public std::runtime_error {
65  public:
malformed_ir()66   explicit malformed_ir() : std::runtime_error("MALFORMED IR") {}
malformed_ir(const std::string & err)67   explicit malformed_ir(const std::string& err)
68       : std::runtime_error("MALFORMED IR: " + err) {}
malformed_ir(const ExprPtr & expr)69   explicit malformed_ir(const ExprPtr& expr)
70       : std::runtime_error("MALFORMED IR: " + std::to_string(expr)) {}
malformed_ir(const std::string & err,const ExprPtr & expr)71   explicit malformed_ir(const std::string& err, const ExprPtr& expr)
72       : std::runtime_error(
73             "MALFORMED IR: " + err + " - " + std::to_string(expr)) {}
malformed_ir(const StmtPtr & stmt)74   explicit malformed_ir(const StmtPtr& stmt)
75       : std::runtime_error("MALFORMED IR: " + std::to_string(stmt)) {}
malformed_ir(const std::string & err,const StmtPtr & stmt)76   explicit malformed_ir(const std::string& err, const StmtPtr& stmt)
77       : std::runtime_error(
78             "MALFORMED IR: " + err + " - " + std::to_string(stmt)) {}
79 };
80 
81 TORCH_API std::string buildErrorMessage(const std::string& s = "");
82 
83 } // namespace torch::jit::tensorexpr
84