xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/utils/aten_types.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

####################
## ATen C++ Types ##
####################

# Each constant below is the exact C++ spelling of an ATen/C++ type, stored
# as a plain string.
# NOTE(review): these strings are presumably compared against, or spliced
# into, generated C++ code. Treat every spelling as byte-exact, including the
# absence of a space after the commas in the tuple types. Confirm against the
# code generator before changing any of them.

# Non-optional scalar, tensor, and array-reference types.
AT_INT_ARRAY_REF = "at::IntArrayRef"
AT_SCALAR = "at::Scalar"
AT_TENSOR = "at::Tensor"
AT_TENSOR_LIST = "at::TensorList"
BOOL = "bool"
DOUBLE = "double"
INT = "int64_t"

# Optional types. OPT_AT_INT_ARRAY_REF is spelled as the
# at::OptionalIntArrayRef type rather than as a ::std::optional<...> wrapper
# like the others.
OPT_AT_DOUBLE_ARRAY_REF = "::std::optional<at::ArrayRef<double>>"
OPT_AT_INT_ARRAY_REF = "at::OptionalIntArrayRef"
OPT_AT_TENSOR = "::std::optional<at::Tensor>"
OPT_BOOL = "::std::optional<bool>"
OPT_INT64 = "::std::optional<int64_t>"
OPT_DEVICE = "::std::optional<at::Device>"
OPT_LAYOUT = "::std::optional<at::Layout>"
OPT_MEMORY_FORMAT = "::std::optional<at::MemoryFormat>"
OPT_SCALAR_TYPE = "::std::optional<at::ScalarType>"

# String view type.
STRING = "c10::string_view"

# Multi-tensor container types.
TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"
TENSOR_VECTOR = "::std::vector<at::Tensor>"