// torch/csrc/jit/passes/frozen_linear_transpose.cpp
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/transpose.h>
#endif

#include <iostream>
#include <utility>
namespace torch::jit {
namespace {

using Tensor = at::Tensor;

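// Rewrites frozen aten::linear calls whose weight and bias are constants into
// an aten::matmul against a pre-transposed weight (plus an aten::add when a
// bias is present). aten::linear computes x @ W^T + b, so materializing W^T
// once at freeze time avoids re-transposing the weight on every forward call.
//
// Schematically (illustrative IR, value names made up):
//   %y = aten::linear(%x, %w, %b)
// becomes
//   %w_t = prim::Constant[value=<transposed weight>]()
//   %mm  = aten::matmul(%x, %w_t)
//   %y   = aten::add(%mm, %b, %alpha)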
class TransposeFrozenLinear {
 public:
  explicit TransposeFrozenLinear(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  bool run() {
    // Fetch the next node before processing the current one: processing may
    // destroy the current node, and the iterator must not be left pointing
    // at a deleted node.
    DepthFirstGraphNodeIterator graph_it(graph_);

    for (auto next_node = graph_it.next(); next_node != nullptr;) {
      Node* node = next_node;
      next_node = graph_it.next();

      if (is_constant_linear_op(node)) {
        replace_linear_with_matmul(node);
      }
    }
    return graph_modified_;
  }

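  // Matches aten::linear nodes whose non-input arguments (weight and bias)
  // are compile-time constants, i.e. frozen parameters.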
  bool is_constant_linear_op(Node* node) {
    if (node->kind() != aten::linear) {
      return false;
    }

    // This also filters out out-variants of the linear op.
    return !nonConstantParameters(node);
  }

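  // Swaps the linear for matmul(x, W^T), followed by an add for the bias,
  // baking the transposed weight into the graph as a new constant.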
  void replace_linear_with_matmul(Node* node) {
    graph_modified_ = true;
    Node* matmul = nullptr;

    {
      WithInsertPoint insert_guard(node);
      auto weight = node->namedInput("weight");

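      // Clone the transposed weight into a contiguous tensor: at::transpose
      // only returns a view, and the baked-in constant should be a dense
      // row-major tensor rather than a strided view.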
      Tensor weight_tensor = constant_as<Tensor>(weight).value();
      Tensor weight_t_tensor = at::transpose(weight_tensor, 1, 0)
                                   .clone(at::MemoryFormat::Contiguous);
      Value* weight_t = graph_->insertConstant(std::move(weight_t_tensor));
      matmul = graph_->create(aten::matmul, {node->inputs()[0], weight_t});
      matmul->insertAfter(node);
    }

    // Handle the bias, if there is one.
    WithInsertPoint insert_guard(matmul);
    auto bias = node->namedInput("bias");
    if (bias->type() == NoneType::get()) {
      node->replaceAllUsesWith(matmul);
    } else {
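      // aten::add(self, other, alpha) computes self + alpha * other; a
      // constant alpha of 1 gives a plain bias addition.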
      Value* bias_scale = graph_->insertConstant(1);
      Node* bias_result =
          graph_->create(aten::add, {matmul->output(), bias, bias_scale});
      bias_result->insertAfter(matmul);
      node->replaceAllUsesWith(bias_result);
    }
    node->destroy();
  }

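  // Currently a no-op: DepthFirstGraphNodeIterator in run() already visits
  // nodes inside sub-blocks, so no per-block recursion is needed here.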
  void handleBlockAndSubblocks(Block* block) {}

 private:
  std::shared_ptr<Graph> graph_;
  bool graph_modified_ = false;
};
} // namespace

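// Entry point for the pass. Returns true if any linear op was rewritten, so
// callers can tell whether the graph changed.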
TORCH_API bool FrozenLinearTranspose(std::shared_ptr<Graph>& graph) {
  TransposeFrozenLinear transposeWeight(graph);
  GRAPH_DUMP("Before FrozenLinearTranspose", graph);
  bool changed = transposeWeight.run();
  if (changed) {
    GRAPH_DUMP("After FrozenLinearTranspose", graph);
  }
  return changed;
}

} // namespace torch::jit