1 #pragma once 2 3 #include <ATen/core/symbol.h> 4 #include <c10/util/ArrayRef.h> 5 #include <optional> 6 #include <ostream> 7 8 namespace at { 9 10 enum class NameType: uint8_t { BASIC, WILDCARD }; 11 12 struct TORCH_API Dimname { 13 static Dimname fromSymbol(Symbol name); 14 static Dimname wildcard(); 15 static bool isValidName(const std::string& name); 16 typeDimname17 NameType type() const { return type_; } symbolDimname18 Symbol symbol() const { return name_; } 19 isBasicDimname20 bool isBasic() const { return type_ == NameType::BASIC; } isWildcardDimname21 bool isWildcard() const { return type_ == NameType::WILDCARD; } 22 23 bool matches(Dimname other) const; 24 std::optional<Dimname> unify(Dimname other) const; 25 26 private: DimnameDimname27 Dimname(Symbol name) 28 : name_(name), type_(NameType::BASIC) {} DimnameDimname29 Dimname(Symbol name, NameType type) 30 : name_(name), type_(type) {} 31 32 Symbol name_; 33 NameType type_; 34 }; 35 36 using DimnameList = c10::ArrayRef<Dimname>; 37 38 TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); 39 40 inline bool operator==(const Dimname& lhs, const Dimname& rhs) { 41 return lhs.symbol() == rhs.symbol(); 42 } 43 44 inline bool operator!=(const Dimname& lhs, const Dimname& rhs) { 45 return !(lhs == rhs); 46 } 47 48 } // namespace at 49