1 #pragma once 2 3 #include <c10/core/SymInt.h> 4 #include <c10/macros/Export.h> 5 #include <c10/macros/Macros.h> 6 #include <cstdint> 7 #include <utility> 8 9 namespace c10 { 10 11 namespace detail { 12 // This template can only be specialized at int64_t and c10::SymInt; 13 // you'll get linker errors otherwise 14 template <typename T> 15 C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); 16 } // namespace detail 17 18 template <typename T> 19 T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { 20 // Inline the fast paths 21 if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { 22 // For SymInts, we want an explicit control flow to trigger a guard, so we 23 // may as well branch too. 24 if (dim < 0) { 25 return dim + dim_post_expr; 26 } 27 return dim; 28 } 29 // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) 30 return c10::detail::maybe_wrap_dim_slow<T>( 31 std::move(dim), std::move(dim_post_expr), wrap_scalar); 32 } 33 34 inline int64_t maybe_wrap_dim( 35 int64_t dim, 36 int64_t dim_post_expr, 37 bool wrap_scalar = true) { 38 return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); 39 } 40 41 inline c10::SymInt maybe_wrap_dim( 42 c10::SymInt dim, 43 c10::SymInt dim_post_expr, 44 bool wrap_scalar = true) { 45 return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); 46 } 47 48 } // namespace c10 49