xref: /aosp_15_r20/external/pytorch/aten/src/ATen/MatrixRef.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/Utils.h>
3 #include <c10/util/ArrayRef.h>
4 
5 namespace at {
6 /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
7 /// we can easily view it as a multidimensional array.
8 ///
9 /// Like ArrayRef, this class does not own the underlying data, it is expected
10 /// to be used in situations where the data resides in some other buffer.
11 ///
12 /// This is intended to be trivially copyable, so it should be passed by
13 /// value.
14 ///
15 /// For now, 2D only (so the copies are actually cheap, without having
16 /// to write a SmallVector class) and contiguous only (so we can
17 /// return non-strided ArrayRef on index).
18 ///
19 /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
20 template <typename T>
21 class MatrixRef {
22  public:
23   typedef size_t size_type;
24 
25  private:
26   /// Underlying ArrayRef
27   ArrayRef<T> arr;
28 
29   /// Stride of dim 0 (outer dimension)
30   size_type stride0;
31 
32   // Stride of dim 1 is assumed to be 1
33 
34  public:
35   /// Construct an empty Matrixref.
MatrixRef()36   /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
37 
38   /// Construct an MatrixRef from an ArrayRef and outer stride.
MatrixRef(ArrayRef<T> arr,size_type stride0)39   /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
40       : arr(arr), stride0(stride0) {
41     TORCH_CHECK(
42         arr.size() % stride0 == 0,
43         "MatrixRef: ArrayRef size ",
44         arr.size(),
45         " not divisible by stride ",
46         stride0)
47   }
48 
49   /// @}
50   /// @name Simple Operations
51   /// @{
52 
53   /// empty - Check if the matrix is empty.
empty()54   bool empty() const {
55     return arr.empty();
56   }
57 
data()58   const T* data() const {
59     return arr.data();
60   }
61 
62   /// size - Get size a dimension
size(size_t dim)63   size_t size(size_t dim) const {
64     if (dim == 0) {
65       return arr.size() / stride0;
66     } else if (dim == 1) {
67       return stride0;
68     } else {
69       TORCH_CHECK(
70           0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
71     }
72   }
73 
numel()74   size_t numel() const {
75     return arr.size();
76   }
77 
78   /// equals - Check for element-wise equality.
equals(MatrixRef RHS)79   bool equals(MatrixRef RHS) const {
80     return stride0 == RHS.stride0 && arr.equals(RHS.arr);
81   }
82 
83   /// @}
84   /// @name Operator Overloads
85   /// @{
86   ArrayRef<T> operator[](size_t Index) const {
87     return arr.slice(Index * stride0, stride0);
88   }
89 
90   /// Disallow accidental assignment from a temporary.
91   ///
92   /// The declaration here is extra complicated so that "arrayRef = {}"
93   /// continues to select the move assignment operator.
94   template <typename U>
95   std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
96       U&& Temporary) = delete;
97 
98   /// Disallow accidental assignment from a temporary.
99   ///
100   /// The declaration here is extra complicated so that "arrayRef = {}"
101   /// continues to select the move assignment operator.
102   template <typename U>
103   std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
104       std::initializer_list<U>) = delete;
105 };
106 
107 } // end namespace at
108