#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/transpose.h>
#endif

#include <iostream>
#include <utility>

namespace torch::jit {
namespace {

using Tensor = at::Tensor;

// Rewrites frozen `aten::linear` ops as `aten::matmul` against a
// pre-transposed constant weight, plus an `aten::add` for the bias.
// `linear(x, W, b)` computes `x @ W.T + b`; once `W` is a compile-time
// constant, transposing it ahead of time removes the implicit transpose
// from every forward pass.
class TransposeFrozenLinear {
 public:
  TransposeFrozenLinear(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    // Nodes can't be destroyed while the graph is being iterated over, so
    // advance the iterator before the current node is possibly replaced.
    DepthFirstGraphNodeIterator graph_it(graph_);

    for (auto next_node = graph_it.next(); next_node != nullptr;) {
      Node* node = next_node;
      next_node = graph_it.next();

      if (is_constant_linear_op(node)) {
        replace_linear_with_matmul(node);
      }
    }
    return graph_modified_;
  }

  bool is_constant_linear_op(Node* node) {
    if (node->kind() != aten::linear) {
      return false;
    }

    // This also filters out out-variants of the linear op.
    return !nonConstantParameters(node);
  }

  void replace_linear_with_matmul(Node* node) {
    graph_modified_ = true;
    Node* matmul = nullptr;

    {
      WithInsertPoint insert_guard(node);
      auto weight = node->namedInput("weight");

      // Transpose the frozen weight once, at compile time, and clone it
      // contiguous so the runtime matmul reads a densely laid-out tensor.
      Tensor weight_tensor = constant_as<Tensor>(weight).value();
      Tensor weight_t_tensor = at::transpose(weight_tensor, 1, 0)
                                   .clone(at::MemoryFormat::Contiguous);
      Value* weight_t = graph_->insertConstant(std::move(weight_t_tensor));
      matmul = graph_->create(aten::matmul, {node->inputs()[0], weight_t});
      matmul->insertAfter(node);
    }

    // Handle the bias, if there is one.
    WithInsertPoint insert_guard(matmul);
    auto bias = node->namedInput("bias");
    if (bias->type() == NoneType::get()) {
      node->replaceAllUsesWith(matmul);
    } else {
      // aten::add(self, other, alpha); alpha = 1 leaves the bias unscaled.
      Value* bias_scale = graph_->insertConstant(1);
      Node* bias_result =
          graph_->create(aten::add, {matmul->output(), bias, bias_scale});
      bias_result->insertAfter(matmul);
      node->replaceAllUsesWith(bias_result);
    }
    node->destroy();
  }

  // Unused: the depth-first iterator already descends into subblocks.
  void handleBlockAndSubblocks(Block* block) {}

 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified_ = false;
};
} // namespace

TORCH_API bool FrozenLinearTranspose(std::shared_ptr<Graph>& graph) {
  TransposeFrozenLinear transposeWeight(graph);
  GRAPH_DUMP("Before FrozenLinearTranspose", graph);
  bool changed = transposeWeight.run();
  if (changed) {
    GRAPH_DUMP("After FrozenLinearTranspose", graph);
  }
  return changed;
}

} // namespace torch::jit
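
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of this translation unit). The
// pass expects a frozen graph, where linear weights and biases have already
// been inlined as constants. The model path and the use of freeze_module
// below are assumptions for the example, not requirements imposed here.
//
//   #include <torch/csrc/jit/api/module.h>
//   #include <torch/csrc/jit/passes/freeze_module.h>
//   #include <torch/csrc/jit/passes/frozen_linear_transpose.h>
//   #include <torch/csrc/jit/serialization/import.h>
//
//   torch::jit::Module module = torch::jit::load("model.pt");  // hypothetical path
//   module.eval();  // freezing requires eval mode
//   torch::jit::Module frozen = torch::jit::freeze_module(module);
//   std::shared_ptr<torch::jit::Graph> graph =
//       frozen.get_method("forward").graph();
//   bool changed = torch::jit::FrozenLinearTranspose(graph);
// ---------------------------------------------------------------------------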