xref: /aosp_15_r20/external/pytorch/c10/core/SymBool.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/SymBool.h>
2 #include <c10/core/SymNodeImpl.h>
3 
4 namespace c10 {
5 
toSymNodeImpl() const6 SymNode SymBool::toSymNodeImpl() const {
7   TORCH_CHECK(is_heap_allocated());
8   return SymNode::reclaim_copy(toSymNodeImplUnowned());
9 }
10 
wrap_node(const SymNode & base) const11 SymNode SymBool::wrap_node(const SymNode& base) const {
12   if (auto ma = maybe_as_bool()) {
13     return base->wrap_bool(*ma);
14   } else {
15     return toSymNodeImpl();
16   }
17 }
18 
19 #define DEFINE_BINARY(API, OP, METHOD, RET)                              \
20   RET SymBool::API(const SymBool& sci) const {                           \
21     if (auto ma = maybe_as_bool()) {                                     \
22       if (auto mb = sci.maybe_as_bool()) {                               \
23         return RET(OP(*ma, *mb));                                        \
24       } else {                                                           \
25         auto b = sci.toSymNodeImpl();                                    \
26         return RET(b->wrap_bool(*ma)->METHOD(b));                        \
27       }                                                                  \
28     } else {                                                             \
29       if (auto mb = sci.maybe_as_bool()) {                               \
30         auto a = toSymNodeImplUnowned();                                 \
31         return RET(a->METHOD(a->wrap_bool(*mb)));                        \
32       } else {                                                           \
33         return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNodeImpl())); \
34       }                                                                  \
35     }                                                                    \
36   }
37 
38 // clang-format off
DEFINE_BINARY(sym_and,std::logical_and<> (),sym_and,SymBool)39 DEFINE_BINARY(sym_and, std::logical_and<>(), sym_and, SymBool)
40 DEFINE_BINARY(sym_or, std::logical_or<>(), sym_or, SymBool)
41 // clang-format on
42 
43 SymBool SymBool::sym_not() const {
44   if (auto ma = maybe_as_bool()) {
45     return SymBool(!*ma);
46   }
47   return SymBool(toSymNodeImpl()->sym_not());
48 }
49 
operator <<(std::ostream & os,const SymBool & s)50 std::ostream& operator<<(std::ostream& os, const SymBool& s) {
51   if (auto ma = s.maybe_as_bool()) {
52     os << *ma;
53   } else {
54     os << s.toSymNodeImpl()->str();
55   }
56   return os;
57 }
58 
guard_bool(const char * file,int64_t line) const59 bool SymBool::guard_bool(const char* file, int64_t line) const {
60   if (auto ma = maybe_as_bool()) {
61     return *ma;
62   }
63   SymNode a = toSymNodeImpl();
64   return a->guard_bool(file, line);
65 }
66 
guard_size_oblivious(const char * file,int64_t line) const67 bool SymBool::guard_size_oblivious(const char* file, int64_t line) const {
68   if (auto ma = maybe_as_bool()) {
69     return *ma;
70   }
71   SymNode a = toSymNodeImpl();
72   return a->guard_size_oblivious(file, line);
73 }
74 
expect_true(const char * file,int64_t line) const75 bool SymBool::expect_true(const char* file, int64_t line) const {
76   if (auto ma = maybe_as_bool()) {
77     return *ma;
78   }
79   SymNode a = toSymNodeImpl();
80   return a->expect_true(file, line);
81 }
82 
has_hint() const83 bool SymBool::has_hint() const {
84   if (maybe_as_bool()) {
85     return true;
86   }
87   return toSymNodeImpl()->has_hint();
88 }
89 
90 } // namespace c10
91