xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
3 #include <torch/csrc/jit/passes/onnx/helper.h>
4 
5 #include <c10/util/irange.h>
6 
7 namespace torch::jit {
8 
9 namespace onnx {
10 using namespace ::c10::onnx;
11 }
12 
DeduplicateInitializers(std::shared_ptr<Graph> & g,ValueToParamPairMap & valsToParamsMap,bool (* comp)(at::Tensor &,at::Tensor &))13 void DeduplicateInitializers(
14     std::shared_ptr<Graph>& g,
15     ValueToParamPairMap& valsToParamsMap,
16     bool (*comp)(at::Tensor&, at::Tensor&)) {
17   auto is_same_tensor_as = [&valsToParamsMap, comp](Value* v1) {
18     return [&valsToParamsMap, v1, comp](Value* v2) {
19       if ((valsToParamsMap.find(v1) == valsToParamsMap.end()) ||
20           (valsToParamsMap.find(v2) == valsToParamsMap.end())) {
21         return false;
22       }
23       auto iv1 = valsToParamsMap.find(v1)->second.second;
24       auto iv2 = valsToParamsMap.find(v2)->second.second;
25       if (!iv1.isTensor() || !iv2.isTensor()) {
26         return false;
27       }
28       auto t1 = iv1.toTensor();
29       auto t2 = iv2.toTensor();
30       return comp(t1, t2);
31     };
32   };
33   std::vector<Value*> uniqueVals;
34   std::vector<size_t> inputsIndicesToRemove;
35   auto b = g->block();
36 
37   for (auto i : c10::irange(b->inputs().size())) {
38     auto v = g->inputs().at(i);
39     if (valsToParamsMap.find(v) == valsToParamsMap.end()) {
40       // Skip model inputs
41       continue;
42     }
43     auto it = std::find_if(
44         uniqueVals.begin(), uniqueVals.end(), is_same_tensor_as(v));
45     if (it == uniqueVals.end()) {
46       uniqueVals.emplace_back(v);
47     } else {
48       inputsIndicesToRemove.emplace_back(i);
49       auto id_node = g->create(onnx::Identity);
50       id_node->insertAfter(g->block()->param_node());
51       id_node->addInput(*it);
52       id_node->output()->copyMetadata(v);
53       id_node->copyMetadata(g->block()->param_node());
54       v->replaceAllUsesWith(id_node->output());
55     }
56   }
57   for (auto it = inputsIndicesToRemove.rbegin();
58        it != inputsIndicesToRemove.rend();
59        ++it) {
60     valsToParamsMap.erase(g->inputs().at(*it));
61     g->eraseInput(*it);
62   }
63 }
64 
DeduplicateInitializersByDataPtr(at::Tensor & t1,at::Tensor & t2)65 bool DeduplicateInitializersByDataPtr(at::Tensor& t1, at::Tensor& t2) {
66   return t1.sizes().equals(t2.sizes()) && t1.strides().equals(t2.strides()) &&
67       (t1.has_storage() && t2.has_storage() && t1.data_ptr() == t2.data_ptr());
68 }
69 
DeduplicateInitializersByValue(at::Tensor & t1,at::Tensor & t2)70 bool DeduplicateInitializersByValue(at::Tensor& t1, at::Tensor& t2) {
71   if (t1.dtype() != t2.dtype() || !t1.sizes().equals(t2.sizes()) ||
72       !t1.strides().equals(t2.strides())) {
73     return false;
74   }
75 
76   if (t1.device() != t2.device()) {
77     return t1.to("cpu").equal(t2.to("cpu"));
78   }
79 
80   return t1.equal(t2);
81 }
82 
DeduplicateInitializers(std::shared_ptr<Graph> & g,std::map<std::string,IValue> & paramsDict,bool is_train)83 void DeduplicateInitializers(
84     std::shared_ptr<Graph>& g,
85     std::map<std::string, IValue>& paramsDict,
86     bool is_train) {
87   auto valsToParamsMap = buildValueToParamsMap(g->block(), paramsDict);
88   // ONNX spec does not support parameters with shared memory.
89   // This pass de-duplicate those parameters. Training is not affected.
90   DeduplicateInitializers(g, valsToParamsMap, DeduplicateInitializersByDataPtr);
91   if (!is_train) {
92     // More aggressive parameters de-duplication based on tensor values.
93     // Producing more compact model for inference.
94     // For training, this pass is disabled,
95     // because parameters may be updated differently.
96     DeduplicateInitializers(g, valsToParamsMap, DeduplicateInitializersByValue);
97   }
98   buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
99 }
100 
101 } // namespace torch::jit
102