1 #pragma once 2 3 #include <ATen/ATen.h> 4 5 #include <optional> 6 7 #include <torch/csrc/autograd/generated/variable_factories.h> 8 #include <torch/csrc/autograd/variable.h> 9 10 // TODO: These don't really belong here but torchvision builds in CI need them 11 // Remove once the torchvision version being compiled in CI is updated 12 #include <ATen/core/dispatch/Dispatcher.h> 13 #include <torch/library.h> 14 15 namespace torch { 16 17 // NOTE [ Exposing declarations in `at::` to `torch::` ] 18 // 19 // The following line `using namespace at;` is responsible for exposing all 20 // declarations in `at::` namespace to `torch::` namespace. 21 // 22 // According to the rules laid out in 23 // https://en.cppreference.com/w/cpp/language/qualified_lookup, section 24 // "Namespace members": 25 // ``` 26 // Qualified lookup within the scope of a namespace N first considers all 27 // declarations that are located in N and all declarations that are located in 28 // the inline namespace members of N (and, transitively, in their inline 29 // namespace members). If there are no declarations in that set then it 30 // considers declarations in all namespaces named by using-directives found in N 31 // and in all transitive inline namespace members of N. 32 // ``` 33 // 34 // This means that if both `at::` and `torch::` namespaces have a function with 35 // the same signature (e.g. both `at::func()` and `torch::func()` exist), after 36 // `namespace torch { using namespace at; }`, when we call `torch::func()`, the 37 // `func()` function defined in `torch::` namespace will always be called, and 38 // the `func()` function defined in `at::` namespace is always hidden. 39 using namespace at; // NOLINT 40 41 using std::nullopt; 42 using std::optional; 43 44 using Dtype = at::ScalarType; 45 46 /// Fixed width dtypes. 47 constexpr auto kUInt8 = at::kByte; 48 constexpr auto kInt8 = at::kChar; 49 constexpr auto kInt16 = at::kShort; 50 constexpr auto kInt32 = at::kInt; 51 constexpr auto kInt64 = at::kLong; 52 constexpr auto kUInt16 = at::kUInt16; 53 constexpr auto kUInt32 = at::kUInt32; 54 constexpr auto kUInt64 = at::kUInt64; 55 constexpr auto kFloat16 = at::kHalf; 56 constexpr auto kFloat32 = at::kFloat; 57 constexpr auto kFloat64 = at::kDouble; 58 59 /// Rust-style short dtypes. 60 constexpr auto kU8 = kUInt8; 61 constexpr auto kU16 = kUInt16; 62 constexpr auto kU32 = kUInt32; 63 constexpr auto kU64 = kUInt64; 64 constexpr auto kI8 = kInt8; 65 constexpr auto kI16 = kInt16; 66 constexpr auto kI32 = kInt32; 67 constexpr auto kI64 = kInt64; 68 constexpr auto kF16 = kFloat16; 69 constexpr auto kF32 = kFloat32; 70 constexpr auto kF64 = kFloat64; 71 } // namespace torch 72