xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/symbolic_shape_cache.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
5 
6 namespace torch::jit {
7 
8 struct TORCH_API CanonicalizedSymbolicShape {
9   // TODO: Consider in the future if it is reasonable to
10   // merge code with SymbolicShape or VaryingShape while keeping
11   // the two not implicitly convertable (and cause bugs).
CanonicalizedSymbolicShapeCanonicalizedSymbolicShape12   CanonicalizedSymbolicShape(
13       const c10::SymbolicShape& orig_shape,
14       std::unordered_map<int64_t, int64_t>& ss_map) {
15     init(orig_shape, ss_map);
16   }
17 
CanonicalizedSymbolicShapeCanonicalizedSymbolicShape18   CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
19     std::unordered_map<int64_t, int64_t> new_ssmap;
20     init(orig_shape, new_ssmap);
21   }
22 
23   size_t hash() const;
24 
25   c10::SymbolicShape toSymbolicShape(
26       std::unordered_map<int64_t, int64_t>& inverse_ss_map) const;
27 
28   TORCH_API friend bool operator==(
29       const CanonicalizedSymbolicShape& a,
30       const CanonicalizedSymbolicShape& b);
31 
32  private:
33   std::optional<std::vector<int64_t>> values_;
34 
35   void init(
36       const c10::SymbolicShape& orig_shape,
37       std::unordered_map<int64_t, int64_t>& ss_map);
38 };
39 
40 // SHAPE CACHE API
41 TORCH_API std::optional<std::vector<at::SymbolicShape>>
42 get_cached_shape_function(
43     const FunctionSchema* schema,
44     const std::vector<SSAInput>& arg_vec);
45 
46 TORCH_API void cache_shape_function(
47     const FunctionSchema* schema,
48     const std::vector<SSAInput>& arg_vec,
49     const std::vector<at::SymbolicShape>& ret_vec);
50 
51 // For use in test code
52 TORCH_API void clear_shape_cache();
53 TORCH_API size_t get_shape_cache_size();
54 
55 } // namespace torch::jit
56