1 #pragma once
2
3 #include <ATen/Context.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/core/ivalue.h>
6 #include <ATen/core/stack.h>
7 #include <torch/csrc/jit/runtime/jit_exception.h>
8 #include <torch/csrc/jit/runtime/vararg_functions.h>
9
10 namespace torch::jit {
11
noop(Stack & n)12 inline void noop(Stack& n) {}
13
// Maps a (possibly negative) Python-style list index to its canonical
// non-negative form given the list's length. Defined in the corresponding
// .cpp; presumably returns idx + list_size for negative idx — confirm
// against the definition.
int64_t normalizeIndex(int64_t idx, int64_t list_size);
15
16 // reference function THPVariable_to in python_variable_methods.cpp
to_dispatch(at::Tensor self,std::optional<at::Device> device,std::optional<at::ScalarType> scalarType,bool non_blocking,bool copy)17 static C10_UNUSED at::Tensor to_dispatch(
18 at::Tensor self,
19 std::optional<at::Device> device,
20 std::optional<at::ScalarType> scalarType,
21 bool non_blocking,
22 bool copy) {
23 if (device && device->is_cuda()) {
24 at::globalContext().lazyInitCUDA();
25 }
26 if (!device && !scalarType && !copy) {
27 return self;
28 } else if (!device) {
29 return self.to(*scalarType, non_blocking, copy);
30 } else if (!scalarType) {
31 return self.to(*device, non_blocking, copy);
32 } else {
33 return self.to(*device, *scalarType, non_blocking, copy);
34 }
35 }
36
// Convert the tensor pointed to by \p data to a nested list.
// \p num_tensor_dims is the total number of dimensions in the tensor and
// \p cur_dim is the dimension being processed by the current invocation.
// \p ty is the expected output IR type of the operation and \p scalar_ty is
// the scalar type of \p data. \p sizes and \p strides are the sizes and
// strides of the tensor operand and \p element_size is the size in bytes of
// one tensor element.
IValue tensorToListRecursive(
    char* data,
    int64_t cur_dim,
    int64_t num_tensor_dims,
    at::TypePtr ty,
    at::ScalarType scalar_ty,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    size_t element_size);
52
53 } // namespace torch::jit
54