1 #pragma once 2 #include <torch/csrc/profiler/api.h> 3 4 namespace torch::profiler::impl { 5 6 using CallBackFnPtr = void (*)( 7 const ProfilerConfig& config, 8 const std::unordered_set<at::RecordScope>& scopes); 9 10 struct PushPRIVATEUSE1CallbacksStub { 11 PushPRIVATEUSE1CallbacksStub() = default; 12 PushPRIVATEUSE1CallbacksStub(const PushPRIVATEUSE1CallbacksStub&) = delete; 13 PushPRIVATEUSE1CallbacksStub& operator=(const PushPRIVATEUSE1CallbacksStub&) = 14 delete; 15 16 template <typename... ArgTypes> operatorPushPRIVATEUSE1CallbacksStub17 void operator()(ArgTypes&&... args) { 18 return (*push_privateuse1_callbacks_fn)(std::forward<ArgTypes>(args)...); 19 } 20 set_privateuse1_dispatch_ptrPushPRIVATEUSE1CallbacksStub21 void set_privateuse1_dispatch_ptr(CallBackFnPtr fn_ptr) { 22 push_privateuse1_callbacks_fn = fn_ptr; 23 } 24 25 private: 26 CallBackFnPtr push_privateuse1_callbacks_fn = nullptr; 27 }; 28 29 extern TORCH_API struct PushPRIVATEUSE1CallbacksStub 30 pushPRIVATEUSE1CallbacksStub; 31 32 struct RegisterPRIVATEUSE1Observer { RegisterPRIVATEUSE1ObserverRegisterPRIVATEUSE1Observer33 RegisterPRIVATEUSE1Observer( 34 PushPRIVATEUSE1CallbacksStub& stub, 35 CallBackFnPtr value) { 36 stub.set_privateuse1_dispatch_ptr(value); 37 } 38 }; 39 40 #define REGISTER_PRIVATEUSE1_OBSERVER(name, fn) \ 41 static RegisterPRIVATEUSE1Observer name##__register(name, fn); 42 } // namespace torch::profiler::impl 43