1 #pragma once 2 #include <ATen/Utils.h> 3 #include <c10/util/ArrayRef.h> 4 5 namespace at { 6 /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that 7 /// we can easily view it as a multidimensional array. 8 /// 9 /// Like ArrayRef, this class does not own the underlying data, it is expected 10 /// to be used in situations where the data resides in some other buffer. 11 /// 12 /// This is intended to be trivially copyable, so it should be passed by 13 /// value. 14 /// 15 /// For now, 2D only (so the copies are actually cheap, without having 16 /// to write a SmallVector class) and contiguous only (so we can 17 /// return non-strided ArrayRef on index). 18 /// 19 /// P.S. dimension 0 indexes rows, dimension 1 indexes columns 20 template <typename T> 21 class MatrixRef { 22 public: 23 typedef size_t size_type; 24 25 private: 26 /// Underlying ArrayRef 27 ArrayRef<T> arr; 28 29 /// Stride of dim 0 (outer dimension) 30 size_type stride0; 31 32 // Stride of dim 1 is assumed to be 1 33 34 public: 35 /// Construct an empty Matrixref. MatrixRef()36 /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {} 37 38 /// Construct an MatrixRef from an ArrayRef and outer stride. MatrixRef(ArrayRef<T> arr,size_type stride0)39 /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0) 40 : arr(arr), stride0(stride0) { 41 TORCH_CHECK( 42 arr.size() % stride0 == 0, 43 "MatrixRef: ArrayRef size ", 44 arr.size(), 45 " not divisible by stride ", 46 stride0) 47 } 48 49 /// @} 50 /// @name Simple Operations 51 /// @{ 52 53 /// empty - Check if the matrix is empty. empty()54 bool empty() const { 55 return arr.empty(); 56 } 57 data()58 const T* data() const { 59 return arr.data(); 60 } 61 62 /// size - Get size a dimension size(size_t dim)63 size_t size(size_t dim) const { 64 if (dim == 0) { 65 return arr.size() / stride0; 66 } else if (dim == 1) { 67 return stride0; 68 } else { 69 TORCH_CHECK( 70 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1"); 71 } 72 } 73 numel()74 size_t numel() const { 75 return arr.size(); 76 } 77 78 /// equals - Check for element-wise equality. equals(MatrixRef RHS)79 bool equals(MatrixRef RHS) const { 80 return stride0 == RHS.stride0 && arr.equals(RHS.arr); 81 } 82 83 /// @} 84 /// @name Operator Overloads 85 /// @{ 86 ArrayRef<T> operator[](size_t Index) const { 87 return arr.slice(Index * stride0, stride0); 88 } 89 90 /// Disallow accidental assignment from a temporary. 91 /// 92 /// The declaration here is extra complicated so that "arrayRef = {}" 93 /// continues to select the move assignment operator. 94 template <typename U> 95 std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=( 96 U&& Temporary) = delete; 97 98 /// Disallow accidental assignment from a temporary. 99 /// 100 /// The declaration here is extra complicated so that "arrayRef = {}" 101 /// continues to select the move assignment operator. 102 template <typename U> 103 std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=( 104 std::initializer_list<U>) = delete; 105 }; 106 107 } // end namespace at 108