1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
#include <ATen/EmptyTensor.h>
#include <ATen/Formatting.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <array>
#include <vector>
16*da0073e9SAndroid Build Coastguard Worker
// Makes a class non-copyable by declaring its copy-assignment operator and
// copy constructor as deleted. Use it inside the class body and supply the
// trailing semicolon yourself, e.g. `AT_DISALLOW_COPY_AND_ASSIGN(Foo);`.
#define AT_DISALLOW_COPY_AND_ASSIGN(ClassName) \
  void operator=(const ClassName&) = delete;  \
  ClassName(const ClassName&) = delete
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker namespace at {
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker TORCH_API int _crash_if_asan(int);
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker // Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
26*da0073e9SAndroid Build Coastguard Worker // NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
27*da0073e9SAndroid Build Coastguard Worker // Once cat is ported entirely to ATen this can be deleted!
checked_dense_tensor_list_unwrap(ArrayRef<Tensor> tensors,const char * name,int pos,c10::DeviceType device_type,ScalarType scalar_type)28*da0073e9SAndroid Build Coastguard Worker inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
29*da0073e9SAndroid Build Coastguard Worker ArrayRef<Tensor> tensors,
30*da0073e9SAndroid Build Coastguard Worker const char* name,
31*da0073e9SAndroid Build Coastguard Worker int pos,
32*da0073e9SAndroid Build Coastguard Worker c10::DeviceType device_type,
33*da0073e9SAndroid Build Coastguard Worker ScalarType scalar_type) {
34*da0073e9SAndroid Build Coastguard Worker std::vector<TensorImpl*> unwrapped;
35*da0073e9SAndroid Build Coastguard Worker unwrapped.reserve(tensors.size());
36*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(tensors.size())) {
37*da0073e9SAndroid Build Coastguard Worker const auto& expr = tensors[i];
38*da0073e9SAndroid Build Coastguard Worker if (expr.layout() != Layout::Strided) {
39*da0073e9SAndroid Build Coastguard Worker AT_ERROR(
40*da0073e9SAndroid Build Coastguard Worker "Expected dense tensor but got ",
41*da0073e9SAndroid Build Coastguard Worker expr.layout(),
42*da0073e9SAndroid Build Coastguard Worker " for sequence element ",
43*da0073e9SAndroid Build Coastguard Worker i,
44*da0073e9SAndroid Build Coastguard Worker " in sequence argument at position #",
45*da0073e9SAndroid Build Coastguard Worker pos,
46*da0073e9SAndroid Build Coastguard Worker " '",
47*da0073e9SAndroid Build Coastguard Worker name,
48*da0073e9SAndroid Build Coastguard Worker "'");
49*da0073e9SAndroid Build Coastguard Worker }
50*da0073e9SAndroid Build Coastguard Worker if (expr.device().type() != device_type) {
51*da0073e9SAndroid Build Coastguard Worker AT_ERROR(
52*da0073e9SAndroid Build Coastguard Worker "Expected object of device type ",
53*da0073e9SAndroid Build Coastguard Worker device_type,
54*da0073e9SAndroid Build Coastguard Worker " but got device type ",
55*da0073e9SAndroid Build Coastguard Worker expr.device().type(),
56*da0073e9SAndroid Build Coastguard Worker " for sequence element ",
57*da0073e9SAndroid Build Coastguard Worker i,
58*da0073e9SAndroid Build Coastguard Worker " in sequence argument at position #",
59*da0073e9SAndroid Build Coastguard Worker pos,
60*da0073e9SAndroid Build Coastguard Worker " '",
61*da0073e9SAndroid Build Coastguard Worker name,
62*da0073e9SAndroid Build Coastguard Worker "'");
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker if (expr.scalar_type() != scalar_type) {
65*da0073e9SAndroid Build Coastguard Worker AT_ERROR(
66*da0073e9SAndroid Build Coastguard Worker "Expected object of scalar type ",
67*da0073e9SAndroid Build Coastguard Worker scalar_type,
68*da0073e9SAndroid Build Coastguard Worker " but got scalar type ",
69*da0073e9SAndroid Build Coastguard Worker expr.scalar_type(),
70*da0073e9SAndroid Build Coastguard Worker " for sequence element ",
71*da0073e9SAndroid Build Coastguard Worker i,
72*da0073e9SAndroid Build Coastguard Worker " in sequence argument at position #",
73*da0073e9SAndroid Build Coastguard Worker pos,
74*da0073e9SAndroid Build Coastguard Worker " '",
75*da0073e9SAndroid Build Coastguard Worker name,
76*da0073e9SAndroid Build Coastguard Worker "'");
77*da0073e9SAndroid Build Coastguard Worker }
78*da0073e9SAndroid Build Coastguard Worker unwrapped.emplace_back(expr.unsafeGetTensorImpl());
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker return unwrapped;
81*da0073e9SAndroid Build Coastguard Worker }
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker template <size_t N>
check_intlist(ArrayRef<int64_t> list,const char * name,int pos)84*da0073e9SAndroid Build Coastguard Worker std::array<int64_t, N> check_intlist(
85*da0073e9SAndroid Build Coastguard Worker ArrayRef<int64_t> list,
86*da0073e9SAndroid Build Coastguard Worker const char* name,
87*da0073e9SAndroid Build Coastguard Worker int pos) {
88*da0073e9SAndroid Build Coastguard Worker if (list.empty()) {
89*da0073e9SAndroid Build Coastguard Worker // TODO: is this necessary? We used to treat nullptr-vs-not in IntList
90*da0073e9SAndroid Build Coastguard Worker // differently with strides as a way of faking optional.
91*da0073e9SAndroid Build Coastguard Worker list = {};
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker auto res = std::array<int64_t, N>();
94*da0073e9SAndroid Build Coastguard Worker if (list.size() == 1 && N > 1) {
95*da0073e9SAndroid Build Coastguard Worker res.fill(list[0]);
96*da0073e9SAndroid Build Coastguard Worker return res;
97*da0073e9SAndroid Build Coastguard Worker }
98*da0073e9SAndroid Build Coastguard Worker if (list.size() != N) {
99*da0073e9SAndroid Build Coastguard Worker AT_ERROR(
100*da0073e9SAndroid Build Coastguard Worker "Expected a list of ",
101*da0073e9SAndroid Build Coastguard Worker N,
102*da0073e9SAndroid Build Coastguard Worker " ints but got ",
103*da0073e9SAndroid Build Coastguard Worker list.size(),
104*da0073e9SAndroid Build Coastguard Worker " for argument #",
105*da0073e9SAndroid Build Coastguard Worker pos,
106*da0073e9SAndroid Build Coastguard Worker " '",
107*da0073e9SAndroid Build Coastguard Worker name,
108*da0073e9SAndroid Build Coastguard Worker "'");
109*da0073e9SAndroid Build Coastguard Worker }
110*da0073e9SAndroid Build Coastguard Worker std::copy_n(list.begin(), N, res.begin());
111*da0073e9SAndroid Build Coastguard Worker return res;
112*da0073e9SAndroid Build Coastguard Worker }
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker using at::detail::check_size_nonnegative;
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker namespace detail {
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker template <typename T>
119*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker template <typename T>
122*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor
123*da0073e9SAndroid Build Coastguard Worker tensor_backend(ArrayRef<T> values, const TensorOptions& options);
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker template <typename T>
126*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor
127*da0073e9SAndroid Build Coastguard Worker tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker template <typename T>
130*da0073e9SAndroid Build Coastguard Worker TORCH_API Tensor
131*da0073e9SAndroid Build Coastguard Worker tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
132*da0073e9SAndroid Build Coastguard Worker } // namespace detail
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker } // namespace at
135