1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/codegen.h> 4 #include <torch/csrc/jit/tensorexpr/ir.h> 5 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 6 #include <torch/csrc/jit/tensorexpr/tensor.h> 7 8 namespace torch::jit::tensorexpr { 9 10 // Walk the Statement looking for Half size loads/stores. 11 class HalfChecker : public IRVisitor { 12 public: HalfChecker(const std::vector<CodeGen::BufferArg> & args)13 HalfChecker(const std::vector<CodeGen::BufferArg>& args) { 14 for (const auto& BA : args) { 15 hasHalf_ |= BA.dtype().scalar_type() == ScalarType::Half; 16 } 17 } 18 hasHalf()19 bool hasHalf() const { 20 return hasHalf_; 21 } 22 hasBFloat16()23 bool hasBFloat16() const { 24 return hasBFloat16_; 25 } 26 visit(const LoadPtr & v)27 void visit(const LoadPtr& v) override { 28 hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; 29 hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16; 30 IRVisitor::visit(v); 31 } 32 visit(const StorePtr & v)33 void visit(const StorePtr& v) override { 34 hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half; 35 hasBFloat16_ |= v->buf()->dtype().scalar_type() == ScalarType::BFloat16; 36 IRVisitor::visit(v); 37 } 38 visit(const HalfImmPtr & v)39 void visit(const HalfImmPtr& v) override { 40 hasHalf_ = true; 41 } 42 visit(const BFloat16ImmPtr & v)43 void visit(const BFloat16ImmPtr& v) override { 44 hasBFloat16_ = true; 45 } 46 visit(const CastPtr & v)47 void visit(const CastPtr& v) override { 48 hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; 49 hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16; 50 IRVisitor::visit(v); 51 } 52 53 private: 54 bool hasHalf_{false}; 55 bool hasBFloat16_{false}; 56 }; 57 58 class HalfRewriter : public IRMutator { mutate(const LoadPtr & v)59 ExprPtr mutate(const LoadPtr& v) override { 60 ExprPtr child = IRMutator::mutate(v); 61 if (!isHalf(child)) { 62 return child; 63 } 64 65 ExprPtr ret = alloc<Cast>( 66 child->dtype().cloneWithScalarType(ScalarType::Float), child); 67 68 inserted_half_casts_.insert(ret); 69 return ret; 70 } 71 mutate(const StorePtr & v)72 StmtPtr mutate(const StorePtr& v) override { 73 // Since mutation changes the `value()` expression in-place, we need to 74 // get the dtype of the `value()` before that is mutated. 75 auto newType = v->value()->dtype(); 76 ExprPtr new_val = v->value()->accept_mutator(this); 77 auto bufType = v->buf()->dtype(); 78 79 if (isHalf(newType.scalar_type())) { 80 new_val = alloc<Cast>(newType, new_val); 81 inserted_half_casts_.insert(new_val); 82 } 83 84 // The scalar_type of value is not Half while the buf is Half 85 if (!isHalf(newType.scalar_type()) && isHalf(bufType.scalar_type())) { 86 new_val = alloc<Cast>( 87 newType.cloneWithScalarType(bufType.scalar_type()), new_val); 88 inserted_half_casts_.insert(new_val); 89 } 90 91 v->set_value(new_val); 92 return v; 93 } 94 mutate(const HalfImmPtr & v)95 ExprPtr mutate(const HalfImmPtr& v) override { 96 return alloc<Cast>(kFloat, v); 97 } 98 mutate(const BFloat16ImmPtr & v)99 ExprPtr mutate(const BFloat16ImmPtr& v) override { 100 return alloc<Cast>(kFloat, v); 101 } 102 mutate(const CastPtr & v)103 ExprPtr mutate(const CastPtr& v) override { 104 ExprPtr child = v->src_value()->accept_mutator(this); 105 106 // just don't allow half casts we didn't insert. 107 if (isHalf(v)) { 108 if (inserted_half_casts_.count(v) < 1) { 109 v->set_src_value(child); 110 v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat)); 111 return v; 112 } 113 } 114 115 // Remove Half(Float()) and friends. 116 CastPtr cast_child = to<Cast>(child); 117 if (cast_child) { 118 auto cast_to_double = v->dtype().scalar_type() == ScalarType::Double; 119 auto from_half = isHalf(cast_child->src_value()); 120 // Cannot simplify the double(float(half)) to double(half) as NNC does 121 // not support cast BF16 to double directly. 122 auto not_cast_half_to_doulbe = !(cast_to_double && from_half); 123 if (v->dtype().is_floating_point() && 124 cast_child->dtype().is_floating_point() && not_cast_half_to_doulbe) { 125 return alloc<Cast>(v->dtype(), cast_child->src_value()); 126 } 127 } 128 129 if (child == v->src_value()) { 130 return v; 131 } 132 133 return alloc<Cast>(v->dtype(), child); 134 } 135 mutate(const LetPtr & v)136 StmtPtr mutate(const LetPtr& v) override { 137 if (isHalf(v->var()->dtype().scalar_type())) { 138 VarPtr load_new_var = alloc<Var>(v->var()->name_hint(), kFloat); 139 ExprPtr new_value = alloc<Cast>( 140 v->var()->dtype().cloneWithScalarType(ScalarType::Float), 141 v->value()->accept_mutator(this)); 142 var_map[v->var()] = load_new_var; 143 144 return alloc<Let>(load_new_var, new_value); 145 } 146 147 return IRMutator::mutate(v); 148 } 149 mutate(const VarPtr & v)150 ExprPtr mutate(const VarPtr& v) override { 151 auto it = var_map.find(v); 152 if (it != var_map.end()) { 153 return it->second; 154 } 155 156 return v; 157 } 158 159 template <typename T> mutateArithmetic(T v)160 ExprPtr mutateArithmetic(T v) { 161 IRMutator::mutate(v); 162 if (isHalf(v)) { 163 v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat)); 164 } 165 return v; 166 } 167 mutate(const AddPtr & v)168 ExprPtr mutate(const AddPtr& v) override { 169 return mutateArithmetic(v); 170 } mutate(const SubPtr & v)171 ExprPtr mutate(const SubPtr& v) override { 172 return mutateArithmetic(v); 173 } mutate(const MulPtr & v)174 ExprPtr mutate(const MulPtr& v) override { 175 return mutateArithmetic(v); 176 } mutate(const DivPtr & v)177 ExprPtr mutate(const DivPtr& v) override { 178 return mutateArithmetic(v); 179 } mutate(const MaxPtr & v)180 ExprPtr mutate(const MaxPtr& v) override { 181 return mutateArithmetic(v); 182 } mutate(const MinPtr & v)183 ExprPtr mutate(const MinPtr& v) override { 184 return mutateArithmetic(v); 185 } mutate(const CompareSelectPtr & v)186 ExprPtr mutate(const CompareSelectPtr& v) override { 187 return mutateArithmetic(v); 188 } mutate(const BroadcastPtr & v)189 ExprPtr mutate(const BroadcastPtr& v) override { 190 return mutateArithmetic(v); 191 } mutate(const IfThenElsePtr & v)192 ExprPtr mutate(const IfThenElsePtr& v) override { 193 return mutateArithmetic(v); 194 } mutate(const IntrinsicsPtr & v)195 ExprPtr mutate(const IntrinsicsPtr& v) override { 196 return mutateArithmetic(v); 197 } 198 199 private: isHalf(ScalarType st)200 static bool isHalf(ScalarType st) { 201 return st == ScalarType::Half || st == ScalarType::BFloat16; 202 } 203 isHalf(const ExprPtr & v)204 static bool isHalf(const ExprPtr& v) { 205 return isHalf(v->dtype().scalar_type()); 206 } 207 208 std::unordered_set<ExprPtr> inserted_half_casts_; 209 std::unordered_map<VarPtr, VarPtr> var_map; 210 }; 211 212 } // namespace torch::jit::tensorexpr 213