// xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/shape.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/tensor.h>

C10_DEFINE_bool(
    ltc_enable_symbolic_shapes,
    false,
    "Enables calculation of whether dims are symbolic");

namespace torch {
namespace lazy {

Shape::Shape(
    at::ScalarType scalar_type,
    c10::ArrayRef<int64_t> sizes,
    std::optional<std::vector<bool>> is_symbolic)
    : scalar_type_(scalar_type),
      sizes_(sizes.begin(), sizes.end()),
      is_symbolic_(std::move(is_symbolic)) {}

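// Construction example for the constructor above (an illustrative sketch;
// assumes the is_symbolic parameter defaults to std::nullopt in shape.h):
//   Shape s(at::kFloat, {2, 3});              // fully concrete dims
//   Shape t(at::kFloat, {2, 3},
//           std::vector<bool>{true, false});  // dim 0 marked symbolic
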
std::string Shape::to_string() const {
  return c10::str(toString(scalar_type_), "[", c10::Join(",", sizes_), "]");
}

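// For the example Shape above, to_string() produces "Float[2,3]"; the
// scalar-type spelling comes from c10's toString overload.
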
// Note: equality compares only the scalar type and concrete sizes; the
// is_symbolic_ annotation does not participate.
bool Shape::operator==(const Shape& other) const {
  return scalar_type_ == other.scalar_type_ && sizes_ == other.sizes_;
}

std::ostream& operator<<(std::ostream& out, const Shape& shape) {
  return out << shape.to_string();
}

// Total element count: the product of all dims (1 for a rank-0 shape).
size_t Shape::numel() const {
  size_t elts = 1;
  for (auto size : sizes_) {
    elts *= size;
  }
  return elts;
}

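// Example for numel above: Shape(at::kFloat, {2, 3}).numel() == 6.
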
// With bakeInSizes, the concrete dim values are folded into the hash;
// otherwise only the rank (number of dims) is hashed.
hash_t Shape::hash(bool bakeInSizes) const {
  if (bakeInSizes) {
    return HashCombine(
        Hash(scalar_type_),
        DataHash(sizes_.data(), sizes_.size() * sizeof(int64_t)));
  } else {
    return HashCombine(Hash(scalar_type_), Hash(sizes_.size()));
  }
}

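// Example for Shape::hash above: Shape(at::kFloat, {2, 3}) and
// Shape(at::kFloat, {4, 5}) hash equal under hash(false) (same scalar type
// and rank) but differ under hash(true).
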
Shape Shape::with_symbolic_dims(
    std::optional<std::vector<bool>> symbolic_dims) const {
  Shape copy = *this;
  copy.is_symbolic_ = symbolic_dims;
  return copy;
}

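// Example for with_symbolic_dims above (illustrative):
//   Shape u = t.with_symbolic_dims(std::nullopt);  // strips the annotation;
//                                                  // t itself is unchanged.
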
// Symbolic shape propagation is enabled either by the
// LTC_ENABLE_SYMBOLIC_SHAPES environment variable (checked once, on first
// call) or by the ltc_enable_symbolic_shapes flag defined above.
bool symbolicShapeEnabled() {
  static bool enabled = std::getenv("LTC_ENABLE_SYMBOLIC_SHAPES") != nullptr;
  return enabled || FLAGS_ltc_enable_symbolic_shapes;
}

static c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) {
  auto ltc_tensor = TryGetLtcTensor(tensor);
  if (!ltc_tensor) {
    // Non-lazy tensors get concrete sizes.
    return c10::SymbolicShape(tensor.sizes());
  }
  const Shape& input_shape = ltc_tensor->GetIrValue()->shape();
  auto& is_symbolic = input_shape.is_symbolic();
  if (!is_symbolic.has_value()) {
    return c10::SymbolicShape();
  }
  auto sizes = input_shape.sizes();
  TORCH_INTERNAL_ASSERT(
      sizes.size() == is_symbolic->size(),
      "sizes and is_symbolic must have the same number of dims");
  std::vector<std::optional<int64_t>> symbolic_dims;
  for (size_t i = 0; i < sizes.size(); i++) {
    if (is_symbolic->at(i)) {
      symbolic_dims.emplace_back(std::nullopt);
    } else {
      symbolic_dims.emplace_back(sizes.at(i));
    }
  }
  return c10::SymbolicShape(symbolic_dims);
}

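// Example for get_symbolic_shape above: a lazy tensor whose Shape is
// Float[2,3] with is_symbolic = {true, false} yields a c10::SymbolicShape
// with dims {std::nullopt, 3}; std::nullopt marks the symbolic dim.
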
void applySymbolicShapesOnLT(
    const char* schema_str,
    std::vector<c10::IValue> args,
    std::vector<Shape>& result_shapes) {
  std::vector<jit::SSAInput> converted_args;
  // TODO: Determine if there are any unknown values in LazyTensor
  const c10::FunctionSchema& schema =
      jit::getOperatorForLiteral(schema_str)->schema();

  for (auto& arg : args) {
    // Handle lists of tensors by flattening each element.
    if (arg.isTensorList()) {
      at::List<at::Tensor> tensor_list = arg.toTensorList();
      for (at::Tensor tensor : tensor_list) {
        converted_args.emplace_back(get_symbolic_shape(tensor));
      }
    } else if (arg.isTensor()) {
      auto ss = get_symbolic_shape(arg.toTensor());
      converted_args.emplace_back(ss);
    } else {
      // If we ever need to support symbolic ints, this is the place
      // to add it.
      converted_args.emplace_back(arg);
    }
  }
  auto res_symbolic = jit::calculateSymbolicShapesOnOp(&schema, converted_args);
  if (!res_symbolic) {
    // Shape propagation failed; clear any symbolic-dim annotations.
    for (auto& result_shape : result_shapes) {
      result_shape = result_shape.with_symbolic_dims(std::nullopt);
    }
  } else {
    TORCH_INTERNAL_ASSERT(
        res_symbolic->size() == result_shapes.size(),
        "Expected one symbolic result per result shape");
    for (size_t i = 0; i < res_symbolic->size(); i++) {
      auto sym_dims = res_symbolic->at(i).symbolicDims();
      if (sym_dims.has_value()) {
        result_shapes[i] = result_shapes[i].with_symbolic_dims(*sym_dims);
      }
    }
  }
}

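// Usage sketch for applySymbolicShapesOnLT above (hypothetical call site;
// the tensor `t` and the choice of relu's schema are illustrative only):
//   std::vector<Shape> shapes = {Shape(at::kFloat, {2, 3})};
//   applySymbolicShapesOnLT(
//       "aten::relu(Tensor self) -> Tensor", {c10::IValue(t)}, shapes);
//   // On success, shapes[0] now carries per-dim symbolic annotations;
//   // otherwise its annotations have been cleared.
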
} // namespace lazy
} // namespace torch