1 #include <c10/core/SymBool.h>
2 #include <c10/core/SymNodeImpl.h>
3
4 namespace c10 {
5
toSymNodeImpl() const6 SymNode SymBool::toSymNodeImpl() const {
7 TORCH_CHECK(is_heap_allocated());
8 return SymNode::reclaim_copy(toSymNodeImplUnowned());
9 }
10
// Returns this value as a SymNode in every case.
//
// A symbolic value hands back its own node via toSymNodeImpl(). A constant
// bool has no node of its own, so it is wrapped by calling wrap_bool() on
// `base`.
SymNode SymBool::wrap_node(const SymNode& base) const {
  const auto constant = maybe_as_bool();
  if (!constant) {
    return toSymNodeImpl();
  }
  return base->wrap_bool(*constant);
}
18
19 #define DEFINE_BINARY(API, OP, METHOD, RET) \
20 RET SymBool::API(const SymBool& sci) const { \
21 if (auto ma = maybe_as_bool()) { \
22 if (auto mb = sci.maybe_as_bool()) { \
23 return RET(OP(*ma, *mb)); \
24 } else { \
25 auto b = sci.toSymNodeImpl(); \
26 return RET(b->wrap_bool(*ma)->METHOD(b)); \
27 } \
28 } else { \
29 if (auto mb = sci.maybe_as_bool()) { \
30 auto a = toSymNodeImplUnowned(); \
31 return RET(a->METHOD(a->wrap_bool(*mb))); \
32 } else { \
33 return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNodeImpl())); \
34 } \
35 } \
36 }
37
38 // clang-format off
DEFINE_BINARY(sym_and,std::logical_and<> (),sym_and,SymBool)39 DEFINE_BINARY(sym_and, std::logical_and<>(), sym_and, SymBool)
40 DEFINE_BINARY(sym_or, std::logical_or<>(), sym_or, SymBool)
41 // clang-format on
42
43 SymBool SymBool::sym_not() const {
44 if (auto ma = maybe_as_bool()) {
45 return SymBool(!*ma);
46 }
47 return SymBool(toSymNodeImpl()->sym_not());
48 }
49
operator <<(std::ostream & os,const SymBool & s)50 std::ostream& operator<<(std::ostream& os, const SymBool& s) {
51 if (auto ma = s.maybe_as_bool()) {
52 os << *ma;
53 } else {
54 os << s.toSymNodeImpl()->str();
55 }
56 return os;
57 }
58
guard_bool(const char * file,int64_t line) const59 bool SymBool::guard_bool(const char* file, int64_t line) const {
60 if (auto ma = maybe_as_bool()) {
61 return *ma;
62 }
63 SymNode a = toSymNodeImpl();
64 return a->guard_bool(file, line);
65 }
66
guard_size_oblivious(const char * file,int64_t line) const67 bool SymBool::guard_size_oblivious(const char* file, int64_t line) const {
68 if (auto ma = maybe_as_bool()) {
69 return *ma;
70 }
71 SymNode a = toSymNodeImpl();
72 return a->guard_size_oblivious(file, line);
73 }
74
expect_true(const char * file,int64_t line) const75 bool SymBool::expect_true(const char* file, int64_t line) const {
76 if (auto ma = maybe_as_bool()) {
77 return *ma;
78 }
79 SymNode a = toSymNodeImpl();
80 return a->expect_true(file, line);
81 }
82
has_hint() const83 bool SymBool::has_hint() const {
84 if (maybe_as_bool()) {
85 return true;
86 }
87 return toSymNodeImpl()->has_hint();
88 }
89
90 } // namespace c10
91