1 #pragma once 2 3 #include <unordered_map> 4 #include <utility> 5 #include <vector> 6 7 #include <torch/csrc/jit/tensorexpr/analysis.h> 8 #include <torch/csrc/jit/tensorexpr/ir.h> 9 #include <torch/csrc/jit/tensorexpr/ir_mutator.h> 10 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 11 #include <torch/csrc/jit/tensorexpr/reduction.h> 12 13 namespace torch::jit::tensorexpr { 14 15 using VarMapping = std::vector<std::pair<VarPtr, ExprPtr>>; 16 17 class VarSubMutator : public IRMutator { 18 public: VarSubMutator(const VarMapping & var_mapping)19 VarSubMutator(const VarMapping& var_mapping) { 20 for (auto& entry : var_mapping) { 21 VarPtr key_var = entry.first; 22 ExprPtr value = entry.second; 23 if (!key_var) { 24 throw malformed_input("missing key in VarSubMutator"); 25 } 26 var_mapping_[std::move(key_var)] = std::move(value); 27 } 28 } 29 mutate(const VarPtr & var)30 ExprPtr mutate(const VarPtr& var) override { 31 auto iter = var_mapping_.find(var); 32 if (iter == var_mapping_.end()) { 33 return var; 34 } 35 return iter->second; 36 } 37 mutate(const ReduceOpPtr & var)38 ExprPtr mutate(const ReduceOpPtr& var) override { 39 auto body = var->body()->accept_mutator(this); 40 std::vector<VarPtr> new_inner; 41 42 for (const auto& v : var->reduce_args()) { 43 ExprPtr e = v->accept_mutator(this); 44 if (VarPtr new_var = to<Var>(e)) { 45 new_inner.push_back(std::move(new_var)); 46 } else { 47 VarFinder varFinder; 48 e->accept(&varFinder); 49 auto varlist = varFinder.vars(); 50 new_inner.insert(new_inner.end(), varlist.begin(), varlist.end()); 51 } 52 } 53 54 return alloc<ReduceOp>(body, new_inner, var->reducer()); 55 } 56 57 private: 58 std::unordered_map<VarPtr, ExprPtr> var_mapping_; 59 }; 60 61 } // namespace torch::jit::tensorexpr 62