xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/types.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 
5 #include <optional>
6 
7 #include <torch/csrc/autograd/generated/variable_factories.h>
8 #include <torch/csrc/autograd/variable.h>
9 
10 // TODO: These don't really belong here but torchvision builds in CI need them
11 // Remove once the torchvision version being compiled in CI is updated
12 #include <ATen/core/dispatch/Dispatcher.h>
13 #include <torch/library.h>
14 
15 namespace torch {
16 
17 // NOTE [ Exposing declarations in `at::` to `torch::` ]
18 //
19 // The following line `using namespace at;` is responsible for exposing all
20 // declarations in `at::` namespace to `torch::` namespace.
21 //
22 // According to the rules laid out in
23 // https://en.cppreference.com/w/cpp/language/qualified_lookup, section
24 // "Namespace members":
25 // ```
26 // Qualified lookup within the scope of a namespace N first considers all
27 // declarations that are located in N and all declarations that are located in
28 // the inline namespace members of N (and, transitively, in their inline
29 // namespace members). If there are no declarations in that set then it
30 // considers declarations in all namespaces named by using-directives found in N
31 // and in all transitive inline namespace members of N.
32 // ```
33 //
34 // This means that if both `at::` and `torch::` namespaces have a function with
35 // the same signature (e.g. both `at::func()` and `torch::func()` exist), after
36 // `namespace torch { using namespace at; }`, when we call `torch::func()`, the
37 // `func()` function defined in `torch::` namespace will always be called, and
38 // the `func()` function defined in `at::` namespace is always hidden.
39 using namespace at; // NOLINT
40 
41 using std::nullopt;
42 using std::optional;
43 
44 using Dtype = at::ScalarType;
45 
46 /// Fixed width dtypes.
47 constexpr auto kUInt8 = at::kByte;
48 constexpr auto kInt8 = at::kChar;
49 constexpr auto kInt16 = at::kShort;
50 constexpr auto kInt32 = at::kInt;
51 constexpr auto kInt64 = at::kLong;
52 constexpr auto kUInt16 = at::kUInt16;
53 constexpr auto kUInt32 = at::kUInt32;
54 constexpr auto kUInt64 = at::kUInt64;
55 constexpr auto kFloat16 = at::kHalf;
56 constexpr auto kFloat32 = at::kFloat;
57 constexpr auto kFloat64 = at::kDouble;
58 
59 /// Rust-style short dtypes.
60 constexpr auto kU8 = kUInt8;
61 constexpr auto kU16 = kUInt16;
62 constexpr auto kU32 = kUInt32;
63 constexpr auto kU64 = kUInt64;
64 constexpr auto kI8 = kInt8;
65 constexpr auto kI16 = kInt16;
66 constexpr auto kI32 = kInt32;
67 constexpr auto kI64 = kInt64;
68 constexpr auto kF16 = kFloat16;
69 constexpr auto kF32 = kFloat32;
70 constexpr auto kF64 = kFloat64;
71 } // namespace torch
72