xref: /aosp_15_r20/external/pytorch/aten/src/ATen/Utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/EmptyTensor.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/Formatting.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/ATenGeneral.h>
6*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/Generator.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/core/ScalarType.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/core/StorageImpl.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/core/UndefinedTensorImpl.h>
10*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h>
11*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
12*da0073e9SAndroid Build Coastguard Worker #include <c10/util/accumulate.h>
13*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
16*da0073e9SAndroid Build Coastguard Worker 
// Deletes the copy-assignment operator and the copy constructor of
// `TypeName`. Place it inside the class body and terminate it with a
// semicolon:
//
//   class Foo {
//     AT_DISALLOW_COPY_AND_ASSIGN(Foo);
//   };
//
// Because this user-declares a copy constructor, the compiler also stops
// implicitly generating the default constructor and the move operations;
// declare them explicitly if they are needed.
#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  void operator=(const TypeName&) = delete;   \
  TypeName(const TypeName&) = delete
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker namespace at {
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker TORCH_API int _crash_if_asan(int);
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker // Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
26*da0073e9SAndroid Build Coastguard Worker // NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
27*da0073e9SAndroid Build Coastguard Worker // Once cat is ported entirely to ATen this can be deleted!
checked_dense_tensor_list_unwrap(ArrayRef<Tensor> tensors,const char * name,int pos,c10::DeviceType device_type,ScalarType scalar_type)28*da0073e9SAndroid Build Coastguard Worker inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
29*da0073e9SAndroid Build Coastguard Worker     ArrayRef<Tensor> tensors,
30*da0073e9SAndroid Build Coastguard Worker     const char* name,
31*da0073e9SAndroid Build Coastguard Worker     int pos,
32*da0073e9SAndroid Build Coastguard Worker     c10::DeviceType device_type,
33*da0073e9SAndroid Build Coastguard Worker     ScalarType scalar_type) {
34*da0073e9SAndroid Build Coastguard Worker   std::vector<TensorImpl*> unwrapped;
35*da0073e9SAndroid Build Coastguard Worker   unwrapped.reserve(tensors.size());
36*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(tensors.size())) {
37*da0073e9SAndroid Build Coastguard Worker     const auto& expr = tensors[i];
38*da0073e9SAndroid Build Coastguard Worker     if (expr.layout() != Layout::Strided) {
39*da0073e9SAndroid Build Coastguard Worker       AT_ERROR(
40*da0073e9SAndroid Build Coastguard Worker           "Expected dense tensor but got ",
41*da0073e9SAndroid Build Coastguard Worker           expr.layout(),
42*da0073e9SAndroid Build Coastguard Worker           " for sequence element ",
43*da0073e9SAndroid Build Coastguard Worker           i,
44*da0073e9SAndroid Build Coastguard Worker           " in sequence argument at position #",
45*da0073e9SAndroid Build Coastguard Worker           pos,
46*da0073e9SAndroid Build Coastguard Worker           " '",
47*da0073e9SAndroid Build Coastguard Worker           name,
48*da0073e9SAndroid Build Coastguard Worker           "'");
49*da0073e9SAndroid Build Coastguard Worker     }
50*da0073e9SAndroid Build Coastguard Worker     if (expr.device().type() != device_type) {
51*da0073e9SAndroid Build Coastguard Worker       AT_ERROR(
52*da0073e9SAndroid Build Coastguard Worker           "Expected object of device type ",
53*da0073e9SAndroid Build Coastguard Worker           device_type,
54*da0073e9SAndroid Build Coastguard Worker           " but got device type ",
55*da0073e9SAndroid Build Coastguard Worker           expr.device().type(),
56*da0073e9SAndroid Build Coastguard Worker           " for sequence element ",
57*da0073e9SAndroid Build Coastguard Worker           i,
58*da0073e9SAndroid Build Coastguard Worker           " in sequence argument at position #",
59*da0073e9SAndroid Build Coastguard Worker           pos,
60*da0073e9SAndroid Build Coastguard Worker           " '",
61*da0073e9SAndroid Build Coastguard Worker           name,
62*da0073e9SAndroid Build Coastguard Worker           "'");
63*da0073e9SAndroid Build Coastguard Worker     }
64*da0073e9SAndroid Build Coastguard Worker     if (expr.scalar_type() != scalar_type) {
65*da0073e9SAndroid Build Coastguard Worker       AT_ERROR(
66*da0073e9SAndroid Build Coastguard Worker           "Expected object of scalar type ",
67*da0073e9SAndroid Build Coastguard Worker           scalar_type,
68*da0073e9SAndroid Build Coastguard Worker           " but got scalar type ",
69*da0073e9SAndroid Build Coastguard Worker           expr.scalar_type(),
70*da0073e9SAndroid Build Coastguard Worker           " for sequence element ",
71*da0073e9SAndroid Build Coastguard Worker           i,
72*da0073e9SAndroid Build Coastguard Worker           " in sequence argument at position #",
73*da0073e9SAndroid Build Coastguard Worker           pos,
74*da0073e9SAndroid Build Coastguard Worker           " '",
75*da0073e9SAndroid Build Coastguard Worker           name,
76*da0073e9SAndroid Build Coastguard Worker           "'");
77*da0073e9SAndroid Build Coastguard Worker     }
78*da0073e9SAndroid Build Coastguard Worker     unwrapped.emplace_back(expr.unsafeGetTensorImpl());
79*da0073e9SAndroid Build Coastguard Worker   }
80*da0073e9SAndroid Build Coastguard Worker   return unwrapped;
81*da0073e9SAndroid Build Coastguard Worker }
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker template <size_t N>
check_intlist(ArrayRef<int64_t> list,const char * name,int pos)84*da0073e9SAndroid Build Coastguard Worker std::array<int64_t, N> check_intlist(
85*da0073e9SAndroid Build Coastguard Worker     ArrayRef<int64_t> list,
86*da0073e9SAndroid Build Coastguard Worker     const char* name,
87*da0073e9SAndroid Build Coastguard Worker     int pos) {
88*da0073e9SAndroid Build Coastguard Worker   if (list.empty()) {
89*da0073e9SAndroid Build Coastguard Worker     // TODO: is this necessary?  We used to treat nullptr-vs-not in IntList
90*da0073e9SAndroid Build Coastguard Worker     // differently with strides as a way of faking optional.
91*da0073e9SAndroid Build Coastguard Worker     list = {};
92*da0073e9SAndroid Build Coastguard Worker   }
93*da0073e9SAndroid Build Coastguard Worker   auto res = std::array<int64_t, N>();
94*da0073e9SAndroid Build Coastguard Worker   if (list.size() == 1 && N > 1) {
95*da0073e9SAndroid Build Coastguard Worker     res.fill(list[0]);
96*da0073e9SAndroid Build Coastguard Worker     return res;
97*da0073e9SAndroid Build Coastguard Worker   }
98*da0073e9SAndroid Build Coastguard Worker   if (list.size() != N) {
99*da0073e9SAndroid Build Coastguard Worker     AT_ERROR(
100*da0073e9SAndroid Build Coastguard Worker         "Expected a list of ",
101*da0073e9SAndroid Build Coastguard Worker         N,
102*da0073e9SAndroid Build Coastguard Worker         " ints but got ",
103*da0073e9SAndroid Build Coastguard Worker         list.size(),
104*da0073e9SAndroid Build Coastguard Worker         " for argument #",
105*da0073e9SAndroid Build Coastguard Worker         pos,
106*da0073e9SAndroid Build Coastguard Worker         " '",
107*da0073e9SAndroid Build Coastguard Worker         name,
108*da0073e9SAndroid Build Coastguard Worker         "'");
109*da0073e9SAndroid Build Coastguard Worker   }
110*da0073e9SAndroid Build Coastguard Worker   std::copy_n(list.begin(), N, res.begin());
111*da0073e9SAndroid Build Coastguard Worker   return res;
112*da0073e9SAndroid Build Coastguard Worker }
113*da0073e9SAndroid Build Coastguard Worker 
114*da0073e9SAndroid Build Coastguard Worker using at::detail::check_size_nonnegative;
115*da0073e9SAndroid Build Coastguard Worker 
116*da0073e9SAndroid Build Coastguard Worker namespace detail {
117*da0073e9SAndroid Build Coastguard Worker 
118*da0073e9SAndroid Build Coastguard Worker template <typename T>
119*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
120*da0073e9SAndroid Build Coastguard Worker 
121*da0073e9SAndroid Build Coastguard Worker template <typename T>
122*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor
123*da0073e9SAndroid Build Coastguard Worker tensor_backend(ArrayRef<T> values, const TensorOptions& options);
124*da0073e9SAndroid Build Coastguard Worker 
125*da0073e9SAndroid Build Coastguard Worker template <typename T>
126*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor
127*da0073e9SAndroid Build Coastguard Worker tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
128*da0073e9SAndroid Build Coastguard Worker 
129*da0073e9SAndroid Build Coastguard Worker template <typename T>
130*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor
131*da0073e9SAndroid Build Coastguard Worker tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
132*da0073e9SAndroid Build Coastguard Worker } // namespace detail
133*da0073e9SAndroid Build Coastguard Worker 
134*da0073e9SAndroid Build Coastguard Worker } // namespace at
135