1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
3 #include <torch/csrc/jit/passes/onnx/helper.h>
4
5 #include <c10/util/irange.h>
6
7 namespace torch::jit {
8
// Makes the ::c10::onnx names visible through the nested `onnx` namespace so
// that code in this file can spell them with the short `onnx::` prefix
// (e.g. onnx::Identity).
namespace onnx {
using namespace ::c10::onnx;
}
12
DeduplicateInitializers(std::shared_ptr<Graph> & g,ValueToParamPairMap & valsToParamsMap,bool (* comp)(at::Tensor &,at::Tensor &))13 void DeduplicateInitializers(
14 std::shared_ptr<Graph>& g,
15 ValueToParamPairMap& valsToParamsMap,
16 bool (*comp)(at::Tensor&, at::Tensor&)) {
17 auto is_same_tensor_as = [&valsToParamsMap, comp](Value* v1) {
18 return [&valsToParamsMap, v1, comp](Value* v2) {
19 if ((valsToParamsMap.find(v1) == valsToParamsMap.end()) ||
20 (valsToParamsMap.find(v2) == valsToParamsMap.end())) {
21 return false;
22 }
23 auto iv1 = valsToParamsMap.find(v1)->second.second;
24 auto iv2 = valsToParamsMap.find(v2)->second.second;
25 if (!iv1.isTensor() || !iv2.isTensor()) {
26 return false;
27 }
28 auto t1 = iv1.toTensor();
29 auto t2 = iv2.toTensor();
30 return comp(t1, t2);
31 };
32 };
33 std::vector<Value*> uniqueVals;
34 std::vector<size_t> inputsIndicesToRemove;
35 auto b = g->block();
36
37 for (auto i : c10::irange(b->inputs().size())) {
38 auto v = g->inputs().at(i);
39 if (valsToParamsMap.find(v) == valsToParamsMap.end()) {
40 // Skip model inputs
41 continue;
42 }
43 auto it = std::find_if(
44 uniqueVals.begin(), uniqueVals.end(), is_same_tensor_as(v));
45 if (it == uniqueVals.end()) {
46 uniqueVals.emplace_back(v);
47 } else {
48 inputsIndicesToRemove.emplace_back(i);
49 auto id_node = g->create(onnx::Identity);
50 id_node->insertAfter(g->block()->param_node());
51 id_node->addInput(*it);
52 id_node->output()->copyMetadata(v);
53 id_node->copyMetadata(g->block()->param_node());
54 v->replaceAllUsesWith(id_node->output());
55 }
56 }
57 for (auto it = inputsIndicesToRemove.rbegin();
58 it != inputsIndicesToRemove.rend();
59 ++it) {
60 valsToParamsMap.erase(g->inputs().at(*it));
61 g->eraseInput(*it);
62 }
63 }
64
DeduplicateInitializersByDataPtr(at::Tensor & t1,at::Tensor & t2)65 bool DeduplicateInitializersByDataPtr(at::Tensor& t1, at::Tensor& t2) {
66 return t1.sizes().equals(t2.sizes()) && t1.strides().equals(t2.strides()) &&
67 (t1.has_storage() && t2.has_storage() && t1.data_ptr() == t2.data_ptr());
68 }
69
DeduplicateInitializersByValue(at::Tensor & t1,at::Tensor & t2)70 bool DeduplicateInitializersByValue(at::Tensor& t1, at::Tensor& t2) {
71 if (t1.dtype() != t2.dtype() || !t1.sizes().equals(t2.sizes()) ||
72 !t1.strides().equals(t2.strides())) {
73 return false;
74 }
75
76 if (t1.device() != t2.device()) {
77 return t1.to("cpu").equal(t2.to("cpu"));
78 }
79
80 return t1.equal(t2);
81 }
82
DeduplicateInitializers(std::shared_ptr<Graph> & g,std::map<std::string,IValue> & paramsDict,bool is_train)83 void DeduplicateInitializers(
84 std::shared_ptr<Graph>& g,
85 std::map<std::string, IValue>& paramsDict,
86 bool is_train) {
87 auto valsToParamsMap = buildValueToParamsMap(g->block(), paramsDict);
88 // ONNX spec does not support parameters with shared memory.
89 // This pass de-duplicate those parameters. Training is not affected.
90 DeduplicateInitializers(g, valsToParamsMap, DeduplicateInitializersByDataPtr);
91 if (!is_train) {
92 // More aggressive parameters de-duplication based on tensor values.
93 // Producing more compact model for inference.
94 // For training, this pass is disabled,
95 // because parameters may be updated differently.
96 DeduplicateInitializers(g, valsToParamsMap, DeduplicateInitializersByValue);
97 }
98 buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
99 }
100
101 } // namespace torch::jit
102