xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorNames.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/WrapDimUtils.h>
4 
5 namespace at::namedinference {
6 
// TensorName and TensorNames are wrappers around Dimname and DimnameList
// that contain helper functions to make writing name inference rules easier.
//
// A TensorName represents a Dimname associated with some DimnameList (from a
// Tensor). It encapsulates all the information needed to check whether names
// *match* and to *unify* names.
//
// Definition: Two names in two tensors *match* if they are equal, or if at
// least one of them is a wildcard that can be *refined* to the other name.
//
// Definition: unify(name, other) fails if the names do not match. Otherwise,
// it returns the more refined of name and other.
//
// Example: checking whether two names match.
//   tensor: Tensor[A, None]
//   other:  Tensor[A]
//
// Suppose we wish to check whether tensor.names[-1] matches other.names[-1].
// None (in tensor) cannot match A (in other): if the None were refined to A,
// `tensor` would have the duplicate names [A, A]. Therefore, matching must
// also check tensor.names ([A, None]) for an existing occurrence of A.
28 struct TORCH_API TensorName {
TensorNameTensorName29   explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
30       : origin_(origin),
31         name_(origin[maybe_wrap_dim(
32             origin_idx,
33             static_cast<int64_t>(origin.size()))]),
34         origin_idx_(origin_idx) {}
35 
36   // op_name is only used for error reporting.
37   const TensorName& unify(const TensorName& other, const char* op_name) const;
38   Dimname toDimname() const;
39 
40  private:
41   ArrayRef<Dimname> origin_;
42   Dimname name_;
43   int origin_idx_; // A named tensor can have at most 64 dims.
44 
45   TORCH_API friend std::ostream& operator<<(
46       std::ostream& out,
47       const TensorName& tensorname);
48 };
49 
50 using TensorNameVec = SmallVector<TensorName, 10>;
51 
52 struct TORCH_API TensorNames {
53   explicit TensorNames(ArrayRef<Dimname> names);
54 
55   // Create TensorNames from names[start:end]. Each individual TensorName stores
56   // `names`, NOT names[start:end], because the original tensor's names are
57   // `names`.
58   explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);
59 
60   // op_name is only used for error reporting.
61   TensorNames& unifyFromRightInplace(
62       const TensorNames& other,
63       const char* op_name = "unify");
64   void checkUnique(const char* op_name) const;
65 
66   void append(TensorName name);
67   std::vector<Dimname> toDimnameVec() const;
68 
69  private:
TensorNamesTensorNames70   explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)){};
71 
72   TensorNameVec names_;
73 };
74 
75 } // namespace at::namedinference
76