xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/tunable/StreamTimer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Original TunableOp is from onnxruntime.
2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4 // Copyright (c) Microsoft Corporation.
5 // Licensed under the MIT license.
6 //
7 // Adapting TunableOp into PyTorch
8 // Copyright (c) Advanced Micro Devices, Inc.
9 //
10 #include <cuda_runtime.h>
11 
12 #include <c10/cuda/CUDAStream.h>
13 #include <ATen/cuda/Exceptions.h>
14 #include <ATen/cuda/tunable/StreamTimer.h>
15 
16 namespace at::cuda::tunable {
17 
StreamTimer()18 StreamTimer::StreamTimer() {
19   AT_CUDA_CHECK(cudaEventCreate(&start_));
20   AT_CUDA_CHECK(cudaEventCreate(&end_));
21 }
22 
~StreamTimer()23 StreamTimer::~StreamTimer() {
24 }
25 
Start()26 void StreamTimer::Start() {
27   AT_CUDA_CHECK(cudaDeviceSynchronize());
28   AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
29 }
30 
End()31 void StreamTimer::End() {
32   AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream()));
33   AT_CUDA_CHECK(cudaEventSynchronize(end_));
34 }
35 
Duration()36 float StreamTimer::Duration() {
37   float time;
38   // time is in ms with a resolution of 1 us
39   AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
40   return time;
41 }
42 
43 } // namespace at::cuda::tunable
44