xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/half_support.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/codegen.h>
4 #include <torch/csrc/jit/tensorexpr/ir.h>
5 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7 
8 namespace torch::jit::tensorexpr {
9 
10 // Walk the Statement looking for Half size loads/stores.
11 class HalfChecker : public IRVisitor {
12  public:
HalfChecker(const std::vector<CodeGen::BufferArg> & args)13   HalfChecker(const std::vector<CodeGen::BufferArg>& args) {
14     for (const auto& BA : args) {
15       hasHalf_ |= BA.dtype().scalar_type() == ScalarType::Half;
16     }
17   }
18 
hasHalf()19   bool hasHalf() const {
20     return hasHalf_;
21   }
22 
hasBFloat16()23   bool hasBFloat16() const {
24     return hasBFloat16_;
25   }
26 
visit(const LoadPtr & v)27   void visit(const LoadPtr& v) override {
28     hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
29     hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16;
30     IRVisitor::visit(v);
31   }
32 
visit(const StorePtr & v)33   void visit(const StorePtr& v) override {
34     hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
35     hasBFloat16_ |= v->buf()->dtype().scalar_type() == ScalarType::BFloat16;
36     IRVisitor::visit(v);
37   }
38 
visit(const HalfImmPtr & v)39   void visit(const HalfImmPtr& v) override {
40     hasHalf_ = true;
41   }
42 
visit(const BFloat16ImmPtr & v)43   void visit(const BFloat16ImmPtr& v) override {
44     hasBFloat16_ = true;
45   }
46 
visit(const CastPtr & v)47   void visit(const CastPtr& v) override {
48     hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
49     hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16;
50     IRVisitor::visit(v);
51   }
52 
53  private:
54   bool hasHalf_{false};
55   bool hasBFloat16_{false};
56 };
57 
58 class HalfRewriter : public IRMutator {
mutate(const LoadPtr & v)59   ExprPtr mutate(const LoadPtr& v) override {
60     ExprPtr child = IRMutator::mutate(v);
61     if (!isHalf(child)) {
62       return child;
63     }
64 
65     ExprPtr ret = alloc<Cast>(
66         child->dtype().cloneWithScalarType(ScalarType::Float), child);
67 
68     inserted_half_casts_.insert(ret);
69     return ret;
70   }
71 
mutate(const StorePtr & v)72   StmtPtr mutate(const StorePtr& v) override {
73     // Since mutation changes the `value()` expression in-place, we need to
74     // get the dtype of the `value()` before that is mutated.
75     auto newType = v->value()->dtype();
76     ExprPtr new_val = v->value()->accept_mutator(this);
77     auto bufType = v->buf()->dtype();
78 
79     if (isHalf(newType.scalar_type())) {
80       new_val = alloc<Cast>(newType, new_val);
81       inserted_half_casts_.insert(new_val);
82     }
83 
84     // The scalar_type of value is not Half while the buf is Half
85     if (!isHalf(newType.scalar_type()) && isHalf(bufType.scalar_type())) {
86       new_val = alloc<Cast>(
87           newType.cloneWithScalarType(bufType.scalar_type()), new_val);
88       inserted_half_casts_.insert(new_val);
89     }
90 
91     v->set_value(new_val);
92     return v;
93   }
94 
mutate(const HalfImmPtr & v)95   ExprPtr mutate(const HalfImmPtr& v) override {
96     return alloc<Cast>(kFloat, v);
97   }
98 
mutate(const BFloat16ImmPtr & v)99   ExprPtr mutate(const BFloat16ImmPtr& v) override {
100     return alloc<Cast>(kFloat, v);
101   }
102 
mutate(const CastPtr & v)103   ExprPtr mutate(const CastPtr& v) override {
104     ExprPtr child = v->src_value()->accept_mutator(this);
105 
106     // just don't allow half casts we didn't insert.
107     if (isHalf(v)) {
108       if (inserted_half_casts_.count(v) < 1) {
109         v->set_src_value(child);
110         v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat));
111         return v;
112       }
113     }
114 
115     // Remove Half(Float()) and friends.
116     CastPtr cast_child = to<Cast>(child);
117     if (cast_child) {
118       auto cast_to_double = v->dtype().scalar_type() == ScalarType::Double;
119       auto from_half = isHalf(cast_child->src_value());
120       // Cannot simplify the double(float(half)) to double(half) as NNC does
121       // not support cast BF16 to double directly.
122       auto not_cast_half_to_doulbe = !(cast_to_double && from_half);
123       if (v->dtype().is_floating_point() &&
124           cast_child->dtype().is_floating_point() && not_cast_half_to_doulbe) {
125         return alloc<Cast>(v->dtype(), cast_child->src_value());
126       }
127     }
128 
129     if (child == v->src_value()) {
130       return v;
131     }
132 
133     return alloc<Cast>(v->dtype(), child);
134   }
135 
mutate(const LetPtr & v)136   StmtPtr mutate(const LetPtr& v) override {
137     if (isHalf(v->var()->dtype().scalar_type())) {
138       VarPtr load_new_var = alloc<Var>(v->var()->name_hint(), kFloat);
139       ExprPtr new_value = alloc<Cast>(
140           v->var()->dtype().cloneWithScalarType(ScalarType::Float),
141           v->value()->accept_mutator(this));
142       var_map[v->var()] = load_new_var;
143 
144       return alloc<Let>(load_new_var, new_value);
145     }
146 
147     return IRMutator::mutate(v);
148   }
149 
mutate(const VarPtr & v)150   ExprPtr mutate(const VarPtr& v) override {
151     auto it = var_map.find(v);
152     if (it != var_map.end()) {
153       return it->second;
154     }
155 
156     return v;
157   }
158 
159   template <typename T>
mutateArithmetic(T v)160   ExprPtr mutateArithmetic(T v) {
161     IRMutator::mutate(v);
162     if (isHalf(v)) {
163       v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat));
164     }
165     return v;
166   }
167 
mutate(const AddPtr & v)168   ExprPtr mutate(const AddPtr& v) override {
169     return mutateArithmetic(v);
170   }
mutate(const SubPtr & v)171   ExprPtr mutate(const SubPtr& v) override {
172     return mutateArithmetic(v);
173   }
mutate(const MulPtr & v)174   ExprPtr mutate(const MulPtr& v) override {
175     return mutateArithmetic(v);
176   }
mutate(const DivPtr & v)177   ExprPtr mutate(const DivPtr& v) override {
178     return mutateArithmetic(v);
179   }
mutate(const MaxPtr & v)180   ExprPtr mutate(const MaxPtr& v) override {
181     return mutateArithmetic(v);
182   }
mutate(const MinPtr & v)183   ExprPtr mutate(const MinPtr& v) override {
184     return mutateArithmetic(v);
185   }
mutate(const CompareSelectPtr & v)186   ExprPtr mutate(const CompareSelectPtr& v) override {
187     return mutateArithmetic(v);
188   }
mutate(const BroadcastPtr & v)189   ExprPtr mutate(const BroadcastPtr& v) override {
190     return mutateArithmetic(v);
191   }
mutate(const IfThenElsePtr & v)192   ExprPtr mutate(const IfThenElsePtr& v) override {
193     return mutateArithmetic(v);
194   }
mutate(const IntrinsicsPtr & v)195   ExprPtr mutate(const IntrinsicsPtr& v) override {
196     return mutateArithmetic(v);
197   }
198 
199  private:
isHalf(ScalarType st)200   static bool isHalf(ScalarType st) {
201     return st == ScalarType::Half || st == ScalarType::BFloat16;
202   }
203 
isHalf(const ExprPtr & v)204   static bool isHalf(const ExprPtr& v) {
205     return isHalf(v->dtype().scalar_type());
206   }
207 
208   std::unordered_set<ExprPtr> inserted_half_casts_;
209   std::unordered_map<VarPtr, VarPtr> var_map;
210 };
211 
212 } // namespace torch::jit::tensorexpr
213