xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/types.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdint>
4 #include <iosfwd>
5 
6 #include <c10/core/ScalarType.h>
7 #include <c10/util/Logging.h>
8 #include <torch/csrc/Export.h>
9 
10 #include <torch/csrc/jit/tensorexpr/exceptions.h>
11 
12 namespace torch::jit::tensorexpr {
13 
14 using int32 = std::int32_t;
15 
16 class Dtype;
17 TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype);
18 
19 using ScalarType = c10::ScalarType;
20 
// Bit flags naming categories of element types. Each category occupies its
// own bit, so categories can be combined with bitwise OR.
enum ElementType {
  // No bit set. NOTE(review): going by the name this presumably means
  // "no category restriction" — confirm with callers.
  kAllTypes = 0,
  kIntegralTypes = 0x01,
  kFloatingPointTypes = 0x02,
  kBoolType = 0x04,
  kComplexTypes = 0x08,
  kQintTypes = 0x10,
  // Union of the integral, bool and floating-point bits; the complex and
  // quantized-int bits are deliberately left out.
  kNonComplexOrQintTypes = kIntegralTypes | kBoolType | kFloatingPointTypes,
};
30 
31 // Data types for scalar and vector elements.
32 class TORCH_API Dtype {
33  public:
Dtype(int8_t type)34   explicit Dtype(int8_t type)
35       : scalar_type_(static_cast<ScalarType>(type)), lanes_(1) {}
Dtype(ScalarType type)36   explicit Dtype(ScalarType type) : scalar_type_(type), lanes_(1) {}
Dtype(int8_t type,int64_t lanes)37   Dtype(int8_t type, int64_t lanes)
38       : scalar_type_(static_cast<ScalarType>(type)), lanes_(lanes) {}
Dtype(ScalarType type,int64_t lanes)39   Dtype(ScalarType type, int64_t lanes) : scalar_type_(type), lanes_(lanes) {}
Dtype(Dtype type,int64_t lanes)40   Dtype(Dtype type, int64_t lanes)
41       : scalar_type_(type.scalar_type_), lanes_(lanes) {
42     if (type.lanes() != 1) {
43       throw malformed_input("dtype lanes dont match");
44     }
45   }
lanes()46   int64_t lanes() const {
47     return lanes_;
48   }
scalar_type()49   ScalarType scalar_type() const {
50     return scalar_type_;
51   }
52   Dtype scalar_dtype() const;
53   bool operator==(const Dtype& other) const {
54     return scalar_type_ == other.scalar_type_ && lanes_ == other.lanes_;
55   }
56   bool operator!=(const Dtype& other) const {
57     return !(*this == other);
58   }
59   int byte_size() const;
60   std::string ToCppString() const;
61 
is_integral()62   bool is_integral() const {
63     return c10::isIntegralType(scalar_type_, true);
64   }
is_floating_point()65   bool is_floating_point() const {
66     return c10::isFloatingType(scalar_type_);
67   }
is_signed()68   bool is_signed() const {
69     return c10::isSignedType(scalar_type_);
70   }
71 
cloneWithScalarType(ScalarType nt)72   Dtype cloneWithScalarType(ScalarType nt) const {
73     return Dtype(nt, lanes_);
74   }
75 
76  private:
77   friend TORCH_API std::ostream& operator<<(
78       std::ostream& stream,
79       const Dtype& dtype);
80   ScalarType scalar_type_;
81   int64_t lanes_; // the width of the element for a vector time
82 };
83 
84 extern TORCH_API Dtype kHandle;
85 
86 #define NNC_DTYPE_DECLARATION(ctype, name) extern TORCH_API Dtype k##name;
87 
88 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_DTYPE_DECLARATION)
89 NNC_DTYPE_DECLARATION(c10::quint8, QUInt8);
90 NNC_DTYPE_DECLARATION(c10::qint8, QInt8);
91 #undef NNC_DTYPE_DECLARATION
92 
93 template <typename T>
94 TORCH_API Dtype ToDtype();
95 
96 #define NNC_TODTYPE_DECLARATION(ctype, name) \
97   template <>                                \
98   inline Dtype ToDtype<ctype>() {            \
99     return k##name;                          \
100   }
101 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_TODTYPE_DECLARATION)
102 NNC_TODTYPE_DECLARATION(c10::quint8, QUInt8);
103 NNC_TODTYPE_DECLARATION(c10::qint8, QInt8);
104 #undef NNC_TODTYPE_DECLARATION
105 
106 TORCH_API Dtype ToDtype(ScalarType type);
107 
promoteTypes(Dtype a,Dtype b)108 inline Dtype promoteTypes(Dtype a, Dtype b) {
109   if (a.lanes() != b.lanes()) {
110     throw malformed_input("promoting types with different lanes");
111   }
112   return Dtype(
113       static_cast<ScalarType>(c10::promoteTypes(
114           static_cast<c10::ScalarType>(a.scalar_type()),
115           static_cast<c10::ScalarType>(b.scalar_type()))),
116       a.lanes());
117 }
118 
119 inline Dtype BinaryOpDtype(
120     Dtype op1_dtype,
121     Dtype op2_dtype,
122     ScalarType ret_type = ScalarType::Undefined) {
123   if (op1_dtype == op2_dtype) {
124     if (ret_type == ScalarType::Undefined) {
125       return op1_dtype;
126     }
127 
128     return ToDtype(ret_type);
129   }
130 
131   if (op1_dtype.lanes() != op2_dtype.lanes()) {
132     throw malformed_input("lanes dont match");
133   }
134   int64_t lanes = op1_dtype.lanes();
135 
136   Dtype resultType = promoteTypes(op1_dtype, op2_dtype);
137   if (resultType.scalar_type() == ScalarType::Undefined) {
138     throw malformed_input("scalar type doesn't match");
139   }
140 
141   if (lanes == 1) {
142     // Use the fixed scalar Dtypes.
143     return ToDtype(resultType.scalar_type());
144   }
145 
146   return resultType;
147 }
148 
149 } // namespace torch::jit::tensorexpr
150 
151 namespace std {
152 
153 using torch::jit::tensorexpr::Dtype;
154 std::string to_string(const Dtype& dtype);
155 using torch::jit::tensorexpr::ScalarType;
156 std::string to_string(const ScalarType& dtype);
157 
158 } // namespace std
159