// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Dimname.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/symbol.h>
#include <c10/util/ArrayRef.h>
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
7 
8 namespace at {
9 
10 enum class NameType: uint8_t { BASIC, WILDCARD };
11 
12 struct TORCH_API Dimname {
13   static Dimname fromSymbol(Symbol name);
14   static Dimname wildcard();
15   static bool isValidName(const std::string& name);
16 
typeDimname17   NameType type() const { return type_; }
symbolDimname18   Symbol symbol() const { return name_; }
19 
isBasicDimname20   bool isBasic() const { return type_ == NameType::BASIC; }
isWildcardDimname21   bool isWildcard() const { return type_ == NameType::WILDCARD; }
22 
23   bool matches(Dimname other) const;
24   std::optional<Dimname> unify(Dimname other) const;
25 
26  private:
DimnameDimname27   Dimname(Symbol name)
28     : name_(name), type_(NameType::BASIC) {}
DimnameDimname29   Dimname(Symbol name, NameType type)
30     : name_(name), type_(type) {}
31 
32   Symbol name_;
33   NameType type_;
34 };
35 
36 using DimnameList = c10::ArrayRef<Dimname>;
37 
38 TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
39 
40 inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
41   return lhs.symbol() == rhs.symbol();
42 }
43 
44 inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {
45   return !(lhs == rhs);
46 }
47 
48 } // namespace at
49