1 #pragma once 2 3 #include <memory> 4 #include <vector> 5 6 #include <test/cpp/tensorexpr/test_base.h> 7 #include <torch/csrc/jit/tensorexpr/fwd_decls.h> 8 #include <torch/csrc/jit/testing/file_check.h> 9 10 namespace torch { 11 namespace jit { 12 using namespace torch::jit::tensorexpr; 13 14 #define IS_NODE(T, node) \ 15 { \ 16 auto node_ = to<T>(node); \ 17 ASSERT_NE(nullptr, node_); \ 18 } 19 20 #define IS_NODE_WITH_NAME(T, node, name) \ 21 auto name = to<T>(node); \ 22 ASSERT_NE(nullptr, name); 23 24 #define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ 25 NodePtr<T> name = nullptr; \ 26 { \ 27 auto node_ = to<Cast>(node); \ 28 ASSERT_NE(nullptr, node_); \ 29 ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ 30 name = to<T>(node_->src_value()); \ 31 } \ 32 ASSERT_NE(nullptr, name); 33 34 #define IS_IMM_WITH_VAL(T, node, val) \ 35 { \ 36 auto node_ = to<T##Imm>(node); \ 37 ASSERT_NE(nullptr, node_); \ 38 ASSERT_EQ(node_->value(), val); \ 39 } 40 41 #define IS_VAR_WITH_NAME(node, name) \ 42 { \ 43 auto node_ = to<Var>(node); \ 44 ASSERT_NE(nullptr, node_); \ 45 ASSERT_EQ(node_->name_hint(), name); \ 46 } 47 48 #define IS_BINOP_W_VARS(T, node, name, v1, v2) \ 49 NodePtr<T> name = nullptr; \ 50 { \ 51 name = to<T>(node); \ 52 ASSERT_NE(nullptr, name); \ 53 IS_VAR_WITH_NAME(name->lhs(), v1); \ 54 IS_VAR_WITH_NAME(name->rhs(), v2); \ 55 } 56 57 #define IS_BINOP_W_CONST(T, node, name, v, c) \ 58 NodePtr<T> name = nullptr; \ 59 { \ 60 name = to<T>(node); \ 61 ASSERT_NE(nullptr, name); \ 62 IS_VAR_WITH_NAME(name->lhs(), v); \ 63 IS_IMM_WITH_VAL(Int, name->rhs(), c); \ 64 } 65 66 #define IS_RAND(node) \ 67 { \ 68 auto node_ = to<Intrinsics>(node); \ 69 ASSERT_NE(nullptr, node_); \ 70 ASSERT_EQ(node_->op_type(), kRand); \ 71 } 72 73 void checkIR(StmtPtr s, const std::string& pattern); 74 void checkExprIR(ExprPtr e, const std::string& pattern); 75 void checkExprIR(const ExprHandle& e, const std::string& pattern); 76 77 } // namespace jit 78 } // namespace torch 79