// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/helper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 
6 namespace torch::jit {
7 
8 // Utility functions for PyTorch to ONNX conversion.
9 
// Named constants for the opset version numbers 1 and 9 through 16.
// Declared `inline constexpr` so that every translation unit including this
// header refers to one compile-time constant rather than a private `static`
// copy per translation unit.
inline constexpr int OPSET_VERSION_1 = 1;
inline constexpr int OPSET_VERSION_9 = 9;
inline constexpr int OPSET_VERSION_10 = 10;
inline constexpr int OPSET_VERSION_11 = 11;
inline constexpr int OPSET_VERSION_12 = 12;
inline constexpr int OPSET_VERSION_13 = 13;
inline constexpr int OPSET_VERSION_14 = 14;
inline constexpr int OPSET_VERSION_15 = 15;
inline constexpr int OPSET_VERSION_16 = 16;
19 
20 using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
21 
22 using ParamMap = std::map<std::string, IValue>;
23 
24 TORCH_API void buildParamsMapFromValueToParamsMap(
25     const ValueToParamPairMap& valsToParamsMap,
26     ParamMap& paramsDict);
27 TORCH_API ValueToParamPairMap
28 buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
29 TORCH_API void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
30 TORCH_API void eraseUnusedBlockInputs(Block* b);
31 TORCH_API void buildParamsMapFromValueToParamsMap(
32     const ValueToParamPairMap& valsToParamsMap,
33     ParamMap& paramsDict);
34 
35 TORCH_API Node* addNodeToBlock(
36     Block* block,
37     Symbol kind,
38     ArrayRef<Value*> inputs);
39 
40 TORCH_API Value* addInputToBlock(Block* block);
41 
42 TORCH_API std::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);
43 
44 // Use int return type as no sable way exists to forward declare protobuf enum
45 TORCH_API int ATenTypeToOnnxType(at::ScalarType at_type);
46 
47 TORCH_API void ONNXLintGraph(const std::shared_ptr<Graph>& graph);
48 
49 Node* createONNXUnsqueeze(
50     Graph* graph,
51     Node* n_to_insert_before,
52     Value* input,
53     int axis,
54     int opset_version);
55 Node* createONNXConstant(
56     Graph* graph,
57     Node* n_to_insert_before,
58     at::Tensor value);
59 
60 bool isValidToTransformToONNXConcatNode(Node* lc_node);
61 
62 Node* transformToONNXConcatNode(
63     Graph* graph,
64     Node* lc_node,
65     bool need_new_input,
66     int opset_version);
67 
68 class ScalarTypeHashFunction {
69  public:
operator()70   size_t operator()(const c10::ScalarType& type) const {
71     return static_cast<size_t>(type);
72   }
73 };
74 
75 } // namespace torch::jit
76