1 #pragma once 2 3 #include <ostream> 4 #include <vector> 5 6 #include <c10/core/Scalar.h> 7 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> 8 #include <torch/csrc/lazy/core/hash.h> 9 10 C10_DECLARE_bool(ltc_enable_symbolic_shapes); 11 12 namespace torch { 13 namespace lazy { 14 15 class TORCH_API Shape { 16 public: 17 Shape() = default; 18 19 Shape( 20 at::ScalarType scalar_type, 21 c10::ArrayRef<int64_t> sizes, 22 std::optional<std::vector<bool>> is_symbolic = std::nullopt); 23 24 std::string to_string() const; 25 scalar_type()26 c10::ScalarType scalar_type() const { 27 return scalar_type_; 28 } set_scalar_type(at::ScalarType value)29 void set_scalar_type(at::ScalarType value) { 30 scalar_type_ = value; 31 } 32 dim()33 int64_t dim() const { 34 return sizes_.size(); 35 } sizes()36 c10::ArrayRef<int64_t> sizes() const { 37 return sizes_; 38 } size(int64_t dim)39 int64_t size(int64_t dim) const { 40 return sizes_.at(dim); 41 } set_size(int64_t dim,int64_t size)42 void set_size(int64_t dim, int64_t size) { 43 sizes_.at(dim) = size; 44 } 45 is_symbolic()46 const std::optional<std::vector<bool>>& is_symbolic() const { 47 return is_symbolic_; 48 } 49 50 // Makes a copy with symbolic dims applied 51 Shape with_symbolic_dims( 52 std::optional<std::vector<bool>> symbolic_dims) const; 53 54 size_t numel() const; 55 hash_t hash(bool bakeInSizes) const; 56 57 bool operator==(const Shape& other) const; 58 59 private: 60 c10::ScalarType scalar_type_{c10::ScalarType::Undefined}; 61 62 // Sizes are the upper bound sizes for a tensor, used by XLA. 63 std::vector<int64_t> sizes_; 64 // Stores which dimmensions are symbolic 65 // If nullopt, either it hasn't been initialized or the symbolic 66 // dimmensions are not calculatable 67 std::optional<std::vector<bool>> is_symbolic_ = std::nullopt; 68 }; 69 70 TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape); 71 72 TORCH_API bool symbolicShapeEnabled(); 73 // Calculate and applies symbolic shapes onto the 74 // Shape objects passed to result_shapes 75 TORCH_API void applySymbolicShapesOnLT( 76 const char* schema_str, 77 std::vector<c10::IValue> args, 78 std::vector<Shape>& result_shapes); 79 } // namespace lazy 80 } // namespace torch 81