1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> 5 6 namespace torch::jit { 7 8 struct TORCH_API CanonicalizedSymbolicShape { 9 // TODO: Consider in the future if it is reasonable to 10 // merge code with SymbolicShape or VaryingShape while keeping 11 // the two not implicitly convertable (and cause bugs). CanonicalizedSymbolicShapeCanonicalizedSymbolicShape12 CanonicalizedSymbolicShape( 13 const c10::SymbolicShape& orig_shape, 14 std::unordered_map<int64_t, int64_t>& ss_map) { 15 init(orig_shape, ss_map); 16 } 17 CanonicalizedSymbolicShapeCanonicalizedSymbolicShape18 CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) { 19 std::unordered_map<int64_t, int64_t> new_ssmap; 20 init(orig_shape, new_ssmap); 21 } 22 23 size_t hash() const; 24 25 c10::SymbolicShape toSymbolicShape( 26 std::unordered_map<int64_t, int64_t>& inverse_ss_map) const; 27 28 TORCH_API friend bool operator==( 29 const CanonicalizedSymbolicShape& a, 30 const CanonicalizedSymbolicShape& b); 31 32 private: 33 std::optional<std::vector<int64_t>> values_; 34 35 void init( 36 const c10::SymbolicShape& orig_shape, 37 std::unordered_map<int64_t, int64_t>& ss_map); 38 }; 39 40 // SHAPE CACHE API 41 TORCH_API std::optional<std::vector<at::SymbolicShape>> 42 get_cached_shape_function( 43 const FunctionSchema* schema, 44 const std::vector<SSAInput>& arg_vec); 45 46 TORCH_API void cache_shape_function( 47 const FunctionSchema* schema, 48 const std::vector<SSAInput>& arg_vec, 49 const std::vector<at::SymbolicShape>& ret_vec); 50 51 // For use in test code 52 TORCH_API void clear_shape_cache(); 53 TORCH_API size_t get_shape_cache_size(); 54 55 } // namespace torch::jit 56