1 #pragma once 2 3 #include <ATen/WrapDimUtils.h> 4 5 namespace at::namedinference { 6 7 // TensorName and TensorNames are wrappers around Dimname and DimnameList 8 // that contain helper functions to make writing name inference rules easier. 9 // 10 // A TensorName represents a Dimname associated with some DimnameList (from a 11 // Tensor). This encapsulates all the information that is needed to check if 12 // names *match* and to *unify* names. 13 // 14 // Definition: Two names in two tensors *match* if they are equal, or if at 15 // least one of them is a wildcard that can be *refined* to the other name. 16 // 17 // Definition: unify(name, other) fails if the names do not match. Otherwise, 18 // it returns the most refined of name and other. 19 // 20 // Here is an example of checking if two names match. 21 // tensor: Tensor[A, None] 22 // other: Tensor[A] 23 // 24 // Let's say we wish to check if tensor.names[-1] matches other.names[-1]. 25 // None (in tensor) cannot match A (in other) because if the None were refined 26 // to A, `tensor` would have duplicate names [A, A]. Therefore we need to check 27 // tensor.names [A, None] for the existence of A. 28 struct TORCH_API TensorName { TensorNameTensorName29 explicit TensorName(ArrayRef<Dimname> origin, int origin_idx) 30 : origin_(origin), 31 name_(origin[maybe_wrap_dim( 32 origin_idx, 33 static_cast<int64_t>(origin.size()))]), 34 origin_idx_(origin_idx) {} 35 36 // op_name is only used for error reporting. 37 const TensorName& unify(const TensorName& other, const char* op_name) const; 38 Dimname toDimname() const; 39 40 private: 41 ArrayRef<Dimname> origin_; 42 Dimname name_; 43 int origin_idx_; // A named tensor can have at most 64 dims. 44 45 TORCH_API friend std::ostream& operator<<( 46 std::ostream& out, 47 const TensorName& tensorname); 48 }; 49 50 using TensorNameVec = SmallVector<TensorName, 10>; 51 52 struct TORCH_API TensorNames { 53 explicit TensorNames(ArrayRef<Dimname> names); 54 55 // Create TensorNames from names[start:end]. Each individual TensorName stores 56 // `names`, NOT names[start:end], because the original tensor's names are 57 // `names`. 58 explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end); 59 60 // op_name is only used for error reporting. 61 TensorNames& unifyFromRightInplace( 62 const TensorNames& other, 63 const char* op_name = "unify"); 64 void checkUnique(const char* op_name) const; 65 66 void append(TensorName name); 67 std::vector<Dimname> toDimnameVec() const; 68 69 private: TensorNamesTensorNames70 explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)){}; 71 72 TensorNameVec names_; 73 }; 74 75 } // namespace at::namedinference 76