#pragma once

#include <cstdint>
#include <memory>
#include <mutex>
#include <type_traits>
#include <utility>
#include <variant>

#include <ATen/Context.h>
#include <c10/core/Device.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/strong_type.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/data_flow.h>
#include <torch/csrc/profiler/events.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/perf.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/python_stub.h>

namespace torch::profiler::impl {

enum class EventType : uint8_t {
  TorchOp = 0,
  Backend,
  Vulkan,
  Allocation,
  OutOfMemory,
  PyCall,
  PyCCall,
  Kineto
};

// ============================================================================
// == Value (Tensor, Scalar) summary ==========================================
// ============================================================================
struct TORCH_API RawTensorMetadataBase {
  RawTensorMetadataBase() = default;
  explicit RawTensorMetadataBase(const at::Tensor& t);

  StorageImplData data_;
  c10::ScalarType dtype_{c10::ScalarType::Undefined};
  c10::Layout layout_{c10::Layout::Strided};
  uint32_t size_dim_{0};
};

// Collected during profiling.
struct TORCH_API RawTensorMetadata : RawTensorMetadataBase {
  RawTensorMetadata() = default;
  RawTensorMetadata(const RawTensorMetadata&) = default;
  RawTensorMetadata(RawTensorMetadata&&) noexcept = default;
  RawTensorMetadata& operator=(const RawTensorMetadata&) = default;
  RawTensorMetadata& operator=(RawTensorMetadata&&) noexcept = default;
  explicit RawTensorMetadata(const at::Tensor& t);
  // Wrap `weak_self_` in `std::optional` and split the device into its
  // components to keep the struct default constructible (which the
  // `std::array` initializer requires).
  std::optional<WeakTensor> weak_self_;
  c10::DeviceType device_type_{c10::DeviceType::CPU};
  c10::DeviceIndex device_index_{-1};
};

// Used during post processing.
struct TORCH_API TensorMetadata : public RawTensorMetadataBase {
  TensorMetadata(
      const RawTensorMetadata& r,
      std::vector<int64_t> sizes,
      std::vector<int64_t> strides);

  TensorImplAddress impl() const {
    return weak_self_.get();
  }

  WeakTensor weak_self_;
  c10::Device device_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  // Set during `calculateUniqueTensorIDs`.
  std::optional<TensorID> id_;
  std::optional<AllocationID> allocation_id_;
};

// Used during post processing.
struct TORCH_API ProfilerStepInfo {
  int64_t start_time_ns; // start time of the profiler step
  int64_t end_time_ns; // end time of the profiler step
  uint64_t out_idx; // index of this profiler step in the `out` results
                    // vector in `getRecords`

  ProfilerStepInfo(int64_t start, int64_t end, uint64_t out_idx)
      : start_time_ns(start), end_time_ns(end), out_idx(out_idx) {}
};

using op_input_t = std::variant<
    TensorMetadata,
    std::vector<TensorMetadata>,
    c10::IValue,
    std::nullopt_t>;
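
// Illustrative sketch (not part of this header): an `op_input_t` is typically
// consumed with `std::visit` during post processing. The `overloaded` helper
// below is a local assumption, not something this file provides.
//
//   template <class... Ts>
//   struct overloaded : Ts... {
//     using Ts::operator()...;
//   };
//   template <class... Ts>
//   overloaded(Ts...) -> overloaded<Ts...>;
//
//   void describe(const op_input_t& input) {
//     std::visit(
//         overloaded{
//             [](const TensorMetadata& t) { /* single tensor argument */ },
//             [](const std::vector<TensorMetadata>& ts) { /* tensor list */ },
//             [](const c10::IValue& v) { /* concrete scalar/list value */ },
//             [](const std::nullopt_t&) { /* unrecorded input */ }},
//         input);
//   }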

// ============================================================================
// == ExtraFields =============================================================
// ============================================================================
template <EventType>
struct ExtraFields;

struct TorchOpBasicFields {
  int64_t sequence_number_{0};
  uint64_t forward_tid_{0};
  at::RecordScope scope_{};
  bool is_async_{false};
  uint64_t record_function_id_{0};
  int64_t debug_handle_{0};
  std::string name_;

  // Set in the exit callback.
  uint64_t end_tid_{0};
};

using jit_stack_t = std::vector<std::string>;
using jit_modules_t = std::vector<std::string>;
using extra_args_t = std::unordered_map<std::string, c10::IValue>;
using extra_meta_t = std::unordered_map<std::string, std::string>;
using kwinputs_t = std::unordered_map<std::string, c10::IValue>;

struct FallbackPair {
  ProfilerVoidEventStub device_event_start_ = nullptr;
  ProfilerVoidEventStub device_event_end_ = nullptr;
};

template <>
struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
  ExtraFields(
      TorchOpBasicFields&& f,
      uint64_t correlation_id,
      c10::time_t end_time_ns,
      std::vector<op_input_t>&& inputs,
      std::vector<op_input_t>&& concrete_inputs,
      jit_stack_t&& jit_stack,
      jit_modules_t&& jit_modules,
      extra_args_t&& extra_args,
      extra_meta_t&& extra_meta,
      kwinputs_t&& kwinputs,
      FallbackPair&& device_fallback,
      bool allow_tf32_cublas,
      std::unique_ptr<perf_counters_t>&& perf_event_counters)
      : TorchOpBasicFields(std::move(f)),
        correlation_id_{correlation_id},
        end_time_ns_{end_time_ns},
        inputs_{std::move(inputs)},
        concrete_inputs_{std::move(concrete_inputs)},
        jit_stack_{std::move(jit_stack)},
        jit_modules_{std::move(jit_modules)},
        extra_args_{std::move(extra_args)},
        extra_meta_{std::move(extra_meta)},
        kwinputs_{std::move(kwinputs)},
        device_fallback_{std::move(device_fallback)},
        allow_tf32_cublas_{allow_tf32_cublas},
        perf_event_counters_{std::move(perf_event_counters)} {}
  uint64_t correlation_id_;
  c10::time_t end_time_ns_;
  std::vector<op_input_t> inputs_;
  std::vector<op_input_t> concrete_inputs_;
  jit_stack_t jit_stack_;
  jit_modules_t jit_modules_;
  extra_args_t extra_args_;
  extra_meta_t extra_meta_;
  kwinputs_t kwinputs_;
  FallbackPair device_fallback_;
  bool allow_tf32_cublas_;
  std::unique_ptr<perf_counters_t> perf_event_counters_;
};

template <>
struct ExtraFields<EventType::Backend> {
  int64_t start_time_us_;
  int64_t end_time_us_;
  int64_t debug_handle_;
  at::RecordScope scope_;
  std::string name_;
  std::string backend_;
  jit_stack_t jit_stack_;
  jit_modules_t jit_modules_;
};

template <>
struct ExtraFields<EventType::Vulkan> {
  using raw_event_t = std::pair<c10::approx_time_t, vulkan_id_t>;
  std::string name_;
  int64_t duration_ns_{0};
  // While building the event tree, we report a Vulkan event's duration as 0
  // so that its end time doesn't exceed that of its parent CPU op.
  bool in_tree_building_{false};
};

struct RawAllocation {
  c10::approx_time_t start_time_;
  void* ptr_;
  int64_t alloc_size_;
  size_t total_allocated_;
  size_t total_reserved_;
  c10::DeviceType device_type_;
  c10::DeviceIndex device_index_;
};

// For performance.
static_assert(c10::is_pod_v<RawAllocation>, "Non-POD member of RawAllocation.");

template <>
struct ExtraFields<EventType::Allocation> : RawAllocation {
  ExtraFields(const RawAllocation& allocation) : RawAllocation(allocation) {}

  c10::Device device() const {
    return {device_type_, device_index_};
  }

  std::optional<TensorID> id_;
  std::optional<AllocationID> allocation_id_;
};

template <>
struct ExtraFields<EventType::OutOfMemory> {
  c10::approx_time_t start_time_;
  int64_t alloc_size_;
  size_t total_allocated_;
  size_t total_reserved_;
  c10::DeviceType device_type_;
  c10::DeviceIndex device_index_;
};

// For performance.
static_assert(
    c10::is_pod_v<ExtraFields<EventType::OutOfMemory>>,
    "Non-POD member of ExtraFields<EventType::OutOfMemory>.");

struct PyFrameState {
  int line_no_;
  at::StringView filename_;
  at::StringView funcname_;
};

template <typename T, typename Tag>
using strong_t = strong::
    type<T, Tag, strong::regular, strong::convertible_to<T>, strong::hashable>;

using PyModuleSelf = strong_t<PyObject*, struct PyModuleSelf_>;
using PyModuleCls = strong_t<PyObject*, struct PyModuleCls_>;
using PyMethod = strong_t</*PyMethodDef*/ void*, struct PyMethod_>;
using PyOptimizerSelf = strong_t<PyObject*, struct PyOptSelf_>;
using PyOptimizerCls = strong_t<PyObject*, struct PyOptimizer_>;
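
// Illustrative sketch (not part of this header): these strong typedefs keep
// raw pointers from being mixed up (e.g., a module `self` vs. its class),
// while `strong::convertible_to<T>` still allows cheap conversion back to the
// underlying pointer.
//
//   PyObject* raw = /* ... */;
//   PyModuleSelf self{raw};
//   PyModuleCls cls{raw};
//   // self = cls;           // would not compile: distinct strong types
//   PyObject* back = self;   // OK: convertible_to<PyObject*>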

struct NNModuleInfo {
  struct ParameterInfo {
    std::string name_;
    TensorMetadata metadata_;
    std::optional<TensorMetadata> grad_metadata_;
  };

  PyModuleSelf self_;
  PyModuleCls cls_;
  at::StringView cls_name_;

  std::vector<ParameterInfo> parameters_;
  // Indicates that `self_` is the kth instance of `cls_` observed.
  size_t id_{std::numeric_limits<size_t>::max()};
};

struct OptimizerInfo {
  struct ParameterInfo {
    TensorMetadata metadata_;
    std::optional<TensorMetadata> grad_metadata_;
    std::vector<std::pair<std::string, TensorMetadata>> state_;
  };

  PyOptimizerSelf self_;
  PyOptimizerCls cls_;
  at::StringView cls_name_;

  std::vector<ParameterInfo> parameters_;
};

struct PyExtraFieldsBase {
  PyExtraFieldsBase(
      c10::time_t end_time_ns,
      size_t python_tid,
      PyFrameState caller)
      : end_time_ns_{end_time_ns},
        python_tid_{python_tid},
        caller_{std::move(caller)} {}

  c10::time_t end_time_ns_;
  size_t python_tid_;
  PyFrameState caller_;

  // kth python event observed. (Used by TensorBoard)
  size_t id_{std::numeric_limits<size_t>::max()};
};

template <>
struct ExtraFields<EventType::PyCall> : public PyExtraFieldsBase {
  struct args_t {
    PyFrameState frame_state_;
    std::optional<NNModuleInfo> module_info_;
    std::optional<OptimizerInfo> optimizer_info_;
  };

  ExtraFields(
      c10::time_t end_time_ns,
      size_t python_tid,
      PyFrameState caller,
      args_t args)
      : PyExtraFieldsBase(end_time_ns, python_tid, std::move(caller)),
        callsite_{std::move(args.frame_state_)},
        module_{std::move(args.module_info_)},
        optimizer_{std::move(args.optimizer_info_)} {}

  PyFrameState callsite_;
  std::optional<NNModuleInfo> module_;
  std::optional<OptimizerInfo> optimizer_;
};

template <>
struct ExtraFields<EventType::PyCCall> : public PyExtraFieldsBase {
  using args_t = at::StringView;

  ExtraFields(
      c10::time_t end_time_ns,
      size_t python_tid,
      PyFrameState caller,
      args_t args)
      : PyExtraFieldsBase(end_time_ns, python_tid, std::move(caller)),
        function_name_{std::move(args)} {}

  at::StringView function_name_;
};

template <>
struct ExtraFields<EventType::Kineto> {
  // Mirrors `libkineto::GenericTraceActivity::Flow`. This information is used
  // during post processing to properly embed Kineto events into the broader
  // profiler tree structure. End users are not generally expected to use these
  // fields directly, but they are available for debugging.
  struct Flow {
    uint32_t id{0};
    uint32_t type{0};
    uint32_t start{0};
  };

  std::string name_;
  int64_t duration_ns_{0};
  uint64_t correlation_id_{0};
  libkineto::ActivityType activity_type_;
  Flow flow;
  std::weak_ptr<Result> linked_activity_{};
};

struct TORCH_API Result : public std::enable_shared_from_this<Result> {
  template <typename... Args>
  [[nodiscard]] static std::shared_ptr<Result> create(Args... args) {
    return std::shared_ptr<Result>(new Result(std::forward<Args>(args)...));
  }

  template <typename T>
  decltype(auto) visit(T&& visitor) {
    return std::visit(std::forward<T>(visitor), extra_fields_);
  }

  template <typename T>
  decltype(auto) visit(T&& visitor) const {
    return std::visit(std::forward<T>(visitor), extra_fields_);
  }

  template <typename T, typename Fn>
  void visit_if_base(Fn&& fn) const {
    visit([&](const auto& extra_fields) {
      using extra_fields_t = typename std::remove_cv_t<
          typename std::remove_reference_t<decltype(extra_fields)>>;

      if constexpr (std::is_base_of_v<T, extra_fields_t>) {
        fn(extra_fields);
      }
    });
  }

  EventType tag() const {
    return visit([](const auto& i) { return deduceTag(i); });
  }

  std::string name() const;
  libkineto::ActivityType kinetoType() const;
  uint64_t correlationID() const;
  int64_t endTimeNS() const;
  uint64_t endTID() const;
  c10::DeviceType deviceType() const;

  int64_t start_time_ns_;
  uint64_t start_tid_;
  kineto::DeviceAndResource kineto_info_;
  std::variant<
      ExtraFields<EventType::TorchOp>,
      ExtraFields<EventType::Backend>,
      ExtraFields<EventType::Vulkan>,
      ExtraFields<EventType::Allocation>,
      ExtraFields<EventType::OutOfMemory>,
      ExtraFields<EventType::PyCall>,
      ExtraFields<EventType::PyCCall>,
      ExtraFields<EventType::Kineto>>
      extra_fields_;

  std::weak_ptr<Result> parent_;
  std::vector<std::shared_ptr<Result>> children_;
  bool finished_{false};

  const torch::profiler::impl::kineto::activity_t* kineto_activity_{nullptr};

 private:
  template <EventType E>
  Result(
      int64_t start_time_ns,
      uint64_t start_tid,
      kineto::DeviceAndResource kineto_info,
      ExtraFields<E>&& extra_fields)
      : start_time_ns_{start_time_ns},
        start_tid_{start_tid},
        kineto_info_{kineto_info},
        extra_fields_{std::move(extra_fields)} {}

  template <EventType E>
  static EventType deduceTag(const ExtraFields<E>&) {
    return E;
  }
};
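
// Illustrative sketch (not part of this header): a `Result` is created with a
// tag-specific `ExtraFields` and later inspected via `visit` /
// `visit_if_base`. The timestamps and `DeviceAndResource` value below are
// placeholders.
//
//   auto result = Result::create(
//       /*start_time_ns=*/0,
//       /*start_tid=*/0,
//       kineto::DeviceAndResource{},
//       ExtraFields<EventType::Vulkan>{});
//
//   if (result->tag() == EventType::Vulkan) { /* ... */ }
//
//   // Runs the lambda only for variants deriving from PyExtraFieldsBase
//   // (i.e., PyCall and PyCCall):
//   result->visit_if_base<PyExtraFieldsBase>(
//       [](const PyExtraFieldsBase& py) { (void)py.python_tid_; });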

struct KinetoObserverContext : public at::ObserverContext {
  struct Event {
    TorchOpBasicFields basic_fields_;
    c10::approx_time_t start_time_;

    // Set in the exit callback.
    c10::approx_time_t end_time_{
        std::numeric_limits<c10::approx_time_t>::min()};

    bool allow_tf32_cublas_;
    std::unique_ptr<perf_counters_t> counters_;
  };

  explicit KinetoObserverContext(Event* event) : event_{event} {}

  Event* event_;
  FallbackPair* fallback_{nullptr};
};

constexpr int IO_ENCODER_DEFAULT_BLOCK_SIZE = 1024;

constexpr int SCALAR_LIST_LENGTH_LIMIT = 30;

// InputOutputEncoder
// Stores each op event's shapes, dtypes, and concrete values in contiguous
// AppendOnlyLists so that we no longer create vectors for shapes and dtypes
// on every op. Those vectors can instead be created during post-processing.
// It splits the data into two categories: input shapes and concrete inputs.
class InputOutputEncoder final {
 public:
  void push(c10::ArrayRef<const c10::IValue> values);

  // Used during post-processing to unpack the encoded data.
  // Each method returns a "supplier" lambda which takes no arguments;
  // invoking the lambda once will return a list of args that represent
  // the inputs for one op.
  // The data is split into two streams: "input shapes" and "concrete inputs".
  // Note: "auto" only works because these are only used in collection.cpp,
  // where they are implemented.
  auto getInputShapeGenerator();
  auto getConcreteInputGenerator();

  bool isSupportedScalarList(const c10::IValue& list_candidate);

  void clear();

  enum class Tag {
    Tensor = 0,
    UndefinedTensor,
    TensorListBegin, // TODO: generalize to other lists.
    ScalarList,
    Scalar,
    Other,
    TERMINATOR
  };

  enum class IOType { Shapes, ConcreteInputs, None };

 private:
  void push(const at::Tensor& t);

  // Implementation detail for getInputShapeGenerator and
  // getConcreteInputGenerator
  auto getIValueGenerator(const IOType& io_type);

  AppendOnlyList<Tag, IO_ENCODER_DEFAULT_BLOCK_SIZE> tags_;
  AppendOnlyList<RawTensorMetadata, IO_ENCODER_DEFAULT_BLOCK_SIZE>
      tensor_metadata_;
  AppendOnlyList<int64_t, IO_ENCODER_DEFAULT_BLOCK_SIZE> tensor_sizes_strides_;
  AppendOnlyList<c10::IValue, IO_ENCODER_DEFAULT_BLOCK_SIZE> ivalues_;
};
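
// Illustrative sketch (not part of this header; the exact supplier return
// type lives in collection.cpp and is assumed here): during collection each
// op's inputs are appended as a flat tag/metadata stream, and during post
// processing a supplier replays one op's worth of inputs per invocation.
//
//   InputOutputEncoder encoder;
//   encoder.push(op_inputs);  // `op_inputs` is a placeholder
//                             // c10::ArrayRef<const c10::IValue>, one per op
//
//   // Later, in post processing:
//   auto shape_supplier = encoder.getInputShapeGenerator();
//   auto first_op_inputs = shape_supplier();   // inputs of the first op
//   auto second_op_inputs = shape_supplier();  // inputs of the second op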

using perf_profiler_t = torch::profiler::impl::linux_perf::PerfProfiler;

class TORCH_API ThreadLocalSubqueue {
 public:
  ThreadLocalSubqueue(const uint64_t tid, ProfilerConfig config);

  std::unique_ptr<KinetoObserverContext> begin_op(const at::RecordFunction& fn);

  template <class... Args>
  void emplace_backend_event(Args&&... args) {
    backend_events_.emplace_back(std::forward<Args>(args)...);
  }

  template <class... Args>
  void emplace_vulkan_event(Args&&... args) {
    vulkan_events_.emplace_back(std::forward<Args>(args)...);
  }

  template <class... Args>
  void emplace_allocation_event(Args&&... args) {
    allocations_.emplace_back(std::forward<Args>(args)...);
  }

  template <class... Args>
  void emplace_ooms_event(Args&&... args) {
    ooms_.emplace_back(std::forward<Args>(args)...);
  }

  template <class... Args>
  void emplace_py_call(Args&&... args) {
    py_calls_.emplace_back(std::forward<Args>(args)...);
  }

  uint64_t tid() const {
    return tid_;
  }

  const kineto::DeviceAndResource& kineto_info() const {
    return kineto_info_;
  }

  inline void disable_perf_profiler(perf_counters_t& counters) const {
    perf_profiler_->Disable(counters);
  }

 private:
  uint64_t tid_;
  ProfilerConfig config_;
  kineto::DeviceAndResource kineto_info_;
  std::unique_ptr<perf_profiler_t> perf_profiler_;

  friend class RecordQueue;
  // See `containers.h` for block size benchmarks.
  static constexpr size_t BlockSize = 512;

  struct TorchOpStorage {
    // NB: This is a destructive operation.
    void materialize(
        std::vector<std::shared_ptr<Result>>& out,
        std::vector<ProfilerStepInfo>& step_info,
        const std::function<c10::time_t(c10::approx_time_t)>& time_converter,
        const uint64_t tid,
        const kineto::DeviceAndResource& kineto_info);

    template <typename T, size_t ChunkSize>
    class EventBlock : public std::array<T, ChunkSize> {
     public:
      EventBlock();
      uint64_t correlation_id(const T* ptr) const;

     private:
      uint64_t id_start_;
    };
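
    // Note (assumption based on the fields above; the definitions live in
    // collection.cpp): each EventBlock reserves a contiguous range of
    // correlation IDs when constructed, so `correlation_id(ptr)` can be
    // computed from `id_start_` plus the event's offset within the block,
    // without storing an ID per event.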

    using event_t = KinetoObserverContext::Event;
    class OpList : public AppendOnlyList<event_t, BlockSize, EventBlock> {
     public:
      template <class... Args>
      std::pair<event_t*, uint64_t> emplace_back(Args&&... args);
      static uint64_t correlationID(const OpList::Iterator& e);
    } op_events_;

    // report_input_shapes
    InputOutputEncoder inputs_outputs_;

    // with_stack (JIT)
    AppendOnlyList<jit_stack_t, BlockSize> jit_stack_;

    // with_modules
    AppendOnlyList<jit_modules_t, BlockSize> jit_modules_;

    // with_flops
    AppendOnlyList<extra_args_t, BlockSize> extra_args_;

608     AppendOnlyList<extra_meta_t, BlockSize> extra_meta_;
609 
610     // report kwinputs
611     AppendOnlyList<kwinputs_t, BlockSize> kwinputs_;
612 
613     // ProfilerState::KINETO_GPU_FALLBACK or
614     // ProfilerState::KINETO_PRIVATEUSE1_FALLBACK
615     AppendOnlyList<FallbackPair, BlockSize> device_fallback_;
616   } torch_ops_;
617 
618   // reportBackendEventToActiveKinetoProfiler
619   AppendOnlyList<ExtraFields<EventType::Backend>, BlockSize> backend_events_;
620 
621   // _reportVulkanEventToProfiler
622   AppendOnlyList<ExtraFields<EventType::Vulkan>::raw_event_t, BlockSize>
623       vulkan_events_;
624 
625   // reportMemoryUsage
626   AppendOnlyList<RawAllocation, BlockSize> allocations_;
627 
628   // reportOOMs
629   AppendOnlyList<ExtraFields<EventType::OutOfMemory>, BlockSize> ooms_;
630 
631   // with_stack (Python)
632   AppendOnlyList<
633       std::pair<python_tracer::TraceKey, c10::approx_time_t>,
634       BlockSize>
635       py_calls_;
636 };
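
// Illustrative sketch (not part of this header): callers obtain their
// thread's subqueue from the active RecordQueue and append events to it; the
// emplace_* helpers forward directly into the per-type AppendOnlyLists. All
// values below are placeholders.
//
//   ThreadLocalSubqueue* queue = record_queue.getSubqueue();
//   queue->emplace_allocation_event(
//       /*start_time=*/c10::approx_time_t{0},
//       /*ptr=*/nullptr,
//       /*alloc_size=*/int64_t{512},
//       /*total_allocated=*/size_t{4096},
//       /*total_reserved=*/size_t{8192},
//       c10::DeviceType::CPU,
//       c10::DeviceIndex{0});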

class TORCH_API RecordQueue {
 public:
  RecordQueue(ProfilerConfig config, std::set<ActivityType> activities);

  bool tracePython() const;
  ThreadLocalSubqueue* getSubqueue();
  void stop();
  void restart();

  // NB: This is a destructive operation.
  std::pair<
      std::vector<std::shared_ptr<Result>>,
      std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
  getRecords(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      uint64_t start_time_ns,
      uint64_t end_time_ns);

 private:
  uint32_t id_;
  ProfilerConfig config_;
  std::set<ActivityType> activities_;
  ska::flat_hash_map<uint64_t, std::unique_ptr<ThreadLocalSubqueue>>
      sub_queues_;
  std::mutex sub_queue_mutex_;
  std::unique_ptr<python_tracer::PythonTracerBase> python_tracer_;
};
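
// Illustrative sketch (not part of this header): a profiling session owns one
// RecordQueue; each thread appends to its own subqueue, and getRecords()
// destructively drains everything into an event list plus the Kineto trace.
// The config, activity set, time converter, and timestamps are placeholders.
//
//   RecordQueue record_queue{config, {ActivityType::CPU}};
//   // ... workload runs; threads call getSubqueue() and append events ...
//   record_queue.stop();
//   auto [events, kineto_trace] = record_queue.getRecords(
//       [](c10::approx_time_t t) { return static_cast<c10::time_t>(t); },
//       start_time_ns,
//       end_time_ns);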

TORCH_API bool get_record_concrete_inputs_enabled();
TORCH_API void set_record_concrete_inputs_enabled_fn(std::function<bool()>);
TORCH_API void set_record_concrete_inputs_enabled_val(bool);

TORCH_API bool get_fwd_bwd_enabled();
TORCH_API void set_fwd_bwd_enabled_fn(std::function<bool()>);
TORCH_API void set_fwd_bwd_enabled_val(bool);

TORCH_API bool get_cuda_sync_enabled();
TORCH_API void set_cuda_sync_enabled_fn(std::function<bool()>);
TORCH_API void set_cuda_sync_enabled_val(bool);
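
// Illustrative sketch (not part of this header): each feature toggle can be
// driven either by a fixed value or by a callback consulted by the getter,
// e.g. to key the behavior off an environment flag. The environment variable
// name below is hypothetical.
//
//   set_record_concrete_inputs_enabled_val(false);  // fixed value
//   set_fwd_bwd_enabled_fn([] {                     // dynamic value
//     return std::getenv("MY_HYPOTHETICAL_FWD_BWD_FLAG") != nullptr;
//   });
//   bool enabled = get_fwd_bwd_enabled();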

} // namespace torch::profiler::impl