//===- llvm/ADT/STLExtras.h - Useful STL related functions ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some templates that are useful if you are working with the
// STL at all.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//

// c10: modified from llvm::function_ref
// c10: added more SFINAE to enable use in overloaded functions

#pragma once

#include <cstdint>
#include <type_traits>
#include <utility>

namespace c10 {

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a function_ref.
template <typename Fn>
class function_ref;

template <typename Ret, typename... Params>
class function_ref<Ret(Params...)> {
  // Type-erased trampoline that knows the concrete callable type; nullptr
  // when this function_ref is empty.
  Ret (*callback)(intptr_t callable, Params... params) = nullptr;
  // Address of the referenced callable, erased to an integer.
  intptr_t callable{};

  // Recovers the concrete callable type from the erased address and invokes
  // it (as an lvalue) with the forwarded arguments.
  template <typename Callable>
  static Ret callback_fn(intptr_t callable, Params... params) {
    return (*reinterpret_cast<Callable*>(callable))(
        std::forward<Params>(params)...);
  }

 public:
  /// Constructs an empty function_ref. Calling an empty function_ref is
  /// undefined behavior (it dereferences a null callback).
  function_ref() noexcept = default;

  /// Constructs an empty function_ref from `nullptr`.
  function_ref(std::nullptr_t) noexcept {}

  /// Refers to `callable` without taking ownership of it.
  ///
  /// @param callable Any object or function invocable (as an lvalue) with
  ///   `Params...` whose result converts to `Ret`. It must outlive every call
  ///   made through this function_ref.
  template <typename Callable>
  function_ref(
      // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
      Callable&& callable,
      // This must never act as the copy constructor. Strip cv-qualifiers as
      // well as references: otherwise a `const function_ref` rvalue would
      // select this overload (rvalue-reference binding beats the copy
      // constructor's `const&`) and the result would point at the other
      // function_ref object instead of copying it.
      std::enable_if_t<!std::is_same_v<
          std::remove_cv_t<std::remove_reference_t<Callable>>,
          function_ref>>* = nullptr,
      // callback_fn invokes the callable as an lvalue, so check that exact
      // call expression for viability and a result convertible to Ret.
      std::enable_if_t<std::is_convertible_v<
          std::invoke_result_t<std::remove_reference_t<Callable>&, Params...>,
          Ret>>* = nullptr) noexcept
      : callback(callback_fn<std::remove_reference_t<Callable>>),
        callable(reinterpret_cast<intptr_t>(&callable)) {}

  /// Invokes the referenced callable with `params`.
  /// Precondition: this function_ref is non-empty.
  Ret operator()(Params... params) const {
    return callback(callable, std::forward<Params>(params)...);
  }

  /// Returns true when this function_ref refers to a callable.
  // NOTE(review): kept implicit for backward compatibility with existing
  // callers; upstream LLVM makes this `explicit`, which would stop
  // expressions like `ref_a == ref_b` from compiling as a bool comparison.
  operator bool() const noexcept {
    return callback != nullptr;
  }
};

} // namespace c10