xref: /aosp_15_r20/external/pytorch/c10/core/ScalarTypeToTypeMeta.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/ScalarType.h>
4 #include <c10/util/Optional.h>
5 #include <c10/util/typeid.h>
6 
7 // these just expose TypeMeta/ScalarType bridge functions in c10
8 // TODO move to typeid.h (or codemod away) when TypeMeta et al
9 // are moved from caffe2 to c10 (see note at top of typeid.h)
10 
11 namespace c10 {
12 
13 /**
14  * convert ScalarType enum values to TypeMeta handles
15  */
scalarTypeToTypeMeta(ScalarType scalar_type)16 inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
17   return caffe2::TypeMeta::fromScalarType(scalar_type);
18 }
19 
20 /**
21  * convert TypeMeta handles to ScalarType enum values
22  */
typeMetaToScalarType(caffe2::TypeMeta dtype)23 inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
24   return dtype.toScalarType();
25 }
26 
27 /**
28  * typeMetaToScalarType(), lifted to optional
29  */
optTypeMetaToScalarType(std::optional<caffe2::TypeMeta> type_meta)30 inline std::optional<at::ScalarType> optTypeMetaToScalarType(
31     std::optional<caffe2::TypeMeta> type_meta) {
32   if (!type_meta.has_value()) {
33     return std::nullopt;
34   }
35   return type_meta->toScalarType();
36 }
37 
38 /**
39  * convenience: equality across TypeMeta/ScalarType conversion
40  */
41 inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
42   return m.isScalarType(t);
43 }
44 
45 inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
46   return t == m;
47 }
48 
49 inline bool operator!=(ScalarType t, caffe2::TypeMeta m) {
50   return !(t == m);
51 }
52 
53 inline bool operator!=(caffe2::TypeMeta m, ScalarType t) {
54   return !(t == m);
55 }
56 
57 } // namespace c10
58