xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/padded_buffer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include "test/cpp/tensorexpr/padded_buffer.h"
2 
3 #include <c10/util/Logging.h>
4 #include <c10/util/irange.h>
5 #include <sstream>
6 
7 namespace torch {
8 namespace jit {
9 namespace tensorexpr {
10 
Index(const std::vector<int> & indices) const11 int PaddedBufferBase::Index(const std::vector<int>& indices) const {
12   TORCH_DCHECK_EQ(dims_.size(), indices.size());
13   int total_index = 0;
14   for (const auto i : c10::irange(dims_.size())) {
15     total_index += indices[i] * strides_[i];
16   }
17   return total_index;
18 }
19 
PaddedBufferBase(const std::vector<int> & dims,const std::string & name)20 PaddedBufferBase::PaddedBufferBase(
21     const std::vector<int>& dims,
22     // NOLINTNEXTLINE(modernize-pass-by-value)
23     const std::string& name)
24     : dims_(dims), name_(name), strides_(dims.size()) {
25   for (int i = (int)dims.size() - 1; i >= 0; --i) {
26     if (i == (int)dims.size() - 1) {
27       strides_[i] = 1;
28     } else {
29       strides_[i] = strides_[i + 1] * dims[i + 1];
30     }
31   }
32   total_size_ = strides_[0] * dims[0];
33 }
34 
35 } // namespace tensorexpr
36 } // namespace jit
37 } // namespace torch
38