xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/shape.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
#include <vector>

#include <c10/core/Scalar.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/lazy/core/hash.h>
9 
10 C10_DECLARE_bool(ltc_enable_symbolic_shapes);
11 
12 namespace torch {
13 namespace lazy {
14 
15 class TORCH_API Shape {
16  public:
17   Shape() = default;
18 
19   Shape(
20       at::ScalarType scalar_type,
21       c10::ArrayRef<int64_t> sizes,
22       std::optional<std::vector<bool>> is_symbolic = std::nullopt);
23 
24   std::string to_string() const;
25 
scalar_type()26   c10::ScalarType scalar_type() const {
27     return scalar_type_;
28   }
set_scalar_type(at::ScalarType value)29   void set_scalar_type(at::ScalarType value) {
30     scalar_type_ = value;
31   }
32 
dim()33   int64_t dim() const {
34     return sizes_.size();
35   }
sizes()36   c10::ArrayRef<int64_t> sizes() const {
37     return sizes_;
38   }
size(int64_t dim)39   int64_t size(int64_t dim) const {
40     return sizes_.at(dim);
41   }
set_size(int64_t dim,int64_t size)42   void set_size(int64_t dim, int64_t size) {
43     sizes_.at(dim) = size;
44   }
45 
is_symbolic()46   const std::optional<std::vector<bool>>& is_symbolic() const {
47     return is_symbolic_;
48   }
49 
50   // Makes a copy with symbolic dims applied
51   Shape with_symbolic_dims(
52       std::optional<std::vector<bool>> symbolic_dims) const;
53 
54   size_t numel() const;
55   hash_t hash(bool bakeInSizes) const;
56 
57   bool operator==(const Shape& other) const;
58 
59  private:
60   c10::ScalarType scalar_type_{c10::ScalarType::Undefined};
61 
62   // Sizes are the upper bound sizes for a tensor, used by XLA.
63   std::vector<int64_t> sizes_;
64   // Stores which dimmensions are symbolic
65   // If nullopt, either it hasn't been initialized or the symbolic
66   // dimmensions are not calculatable
67   std::optional<std::vector<bool>> is_symbolic_ = std::nullopt;
68 };
69 
70 TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape);
71 
72 TORCH_API bool symbolicShapeEnabled();
73 // Calculate and applies symbolic shapes onto the
74 // Shape objects passed to result_shapes
75 TORCH_API void applySymbolicShapesOnLT(
76     const char* schema_str,
77     std::vector<c10::IValue> args,
78     std::vector<Shape>& result_shapes);
79 } // namespace lazy
80 } // namespace torch
81