xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/register_ops_common_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Context.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/core/ivalue.h>
6 #include <ATen/core/stack.h>
7 #include <torch/csrc/jit/runtime/jit_exception.h>
8 #include <torch/csrc/jit/runtime/vararg_functions.h>
9 
10 namespace torch::jit {
11 
noop(Stack & n)12 inline void noop(Stack& n) {}
13 
14 int64_t normalizeIndex(int64_t idx, int64_t list_size);
15 
16 // reference function THPVariable_to in python_variable_methods.cpp
to_dispatch(at::Tensor self,std::optional<at::Device> device,std::optional<at::ScalarType> scalarType,bool non_blocking,bool copy)17 static C10_UNUSED at::Tensor to_dispatch(
18     at::Tensor self,
19     std::optional<at::Device> device,
20     std::optional<at::ScalarType> scalarType,
21     bool non_blocking,
22     bool copy) {
23   if (device && device->is_cuda()) {
24     at::globalContext().lazyInitCUDA();
25   }
26   if (!device && !scalarType && !copy) {
27     return self;
28   } else if (!device) {
29     return self.to(*scalarType, non_blocking, copy);
30   } else if (!scalarType) {
31     return self.to(*device, non_blocking, copy);
32   } else {
33     return self.to(*device, *scalarType, non_blocking, copy);
34   }
35 }
36 
37 // Convert the tensor pointed to by \p data to a nested list. \p dim is the
38 // number of dimensions in the tensor and \p cur_dim is the dimension being
39 // processed by the current invocation. \p ty is the expected output IR type of
40 // the operation. \p is the scalar type of \p data. \p sizes and \p strides are
41 // the sizes and strides of the tensor operand and \p element_size is the size
42 // in bytes of one tensor element.
43 IValue tensorToListRecursive(
44     char* data,
45     int64_t cur_dim,
46     int64_t num_tensor_dims,
47     at::TypePtr ty,
48     at::ScalarType scalar_ty,
49     at::IntArrayRef sizes,
50     at::IntArrayRef strides,
51     size_t element_size);
52 
53 } // namespace torch::jit
54