// xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/standalone/privateuse1_observer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <torch/csrc/profiler/api.h>

#include <stdexcept>
#include <unordered_set>
#include <utility>

4 namespace torch::profiler::impl {
5 
6 using CallBackFnPtr = void (*)(
7     const ProfilerConfig& config,
8     const std::unordered_set<at::RecordScope>& scopes);
9 
10 struct PushPRIVATEUSE1CallbacksStub {
11   PushPRIVATEUSE1CallbacksStub() = default;
12   PushPRIVATEUSE1CallbacksStub(const PushPRIVATEUSE1CallbacksStub&) = delete;
13   PushPRIVATEUSE1CallbacksStub& operator=(const PushPRIVATEUSE1CallbacksStub&) =
14       delete;
15 
16   template <typename... ArgTypes>
operatorPushPRIVATEUSE1CallbacksStub17   void operator()(ArgTypes&&... args) {
18     return (*push_privateuse1_callbacks_fn)(std::forward<ArgTypes>(args)...);
19   }
20 
set_privateuse1_dispatch_ptrPushPRIVATEUSE1CallbacksStub21   void set_privateuse1_dispatch_ptr(CallBackFnPtr fn_ptr) {
22     push_privateuse1_callbacks_fn = fn_ptr;
23   }
24 
25  private:
26   CallBackFnPtr push_privateuse1_callbacks_fn = nullptr;
27 };
28 
29 extern TORCH_API struct PushPRIVATEUSE1CallbacksStub
30     pushPRIVATEUSE1CallbacksStub;
31 
32 struct RegisterPRIVATEUSE1Observer {
RegisterPRIVATEUSE1ObserverRegisterPRIVATEUSE1Observer33   RegisterPRIVATEUSE1Observer(
34       PushPRIVATEUSE1CallbacksStub& stub,
35       CallBackFnPtr value) {
36     stub.set_privateuse1_dispatch_ptr(value);
37   }
38 };
39 
40 #define REGISTER_PRIVATEUSE1_OBSERVER(name, fn) \
41   static RegisterPRIVATEUSE1Observer name##__register(name, fn);
42 } // namespace torch::profiler::impl
43