xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSProfiler.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1//  Copyright © 2022 Apple Inc.
2
3#include <ATen/mps/MPSProfiler.h>
4#include <c10/util/Exception.h>
5#include <fmt/format.h>
6
7// these need to be literal strings when passed to os_signpost*()
8// function macros; so no LUTs could be used
9#define kMPSProfilerSubSystemStr "PyTorchMPS"
10#define kMPSCategoryEventsStr "Events"
11#define kMPSCategoryIntervalsStr "Intervals"
12#define kIntSignpostRunOperationStr "PyTorchOperationIntervals"
13#define kIntSignpostBlitCopyStr "PyTorchCopyIntervals"
14#define kIntSignpostCPUFallbacksStr "PyTorchCPUFallbackIntervals"
15#define kEvtSignpostRunOperationStr "PyTorchOperationEvents"
16#define kEvtSignpostBlitCopyStr "PyTorchCopyEvents"
17#define kEvtSignpostCPUFallbacksStr "PyTorchCPUFallbacksEvents"
18#define kEVLogProfileInfoStr "PYTORCH_MPS_LOG_PROFILE_INFO"
19#define kEVTraceSignpostsStr "PYTORCH_MPS_TRACE_SIGNPOSTS"
20
21namespace at::mps {
22namespace Profiler {
23
// Formats optional timing metadata appended to signpost/log strings.
// The gpuTime will be non-zero mainly for event-based signposts; the
// interval-based signposts carry "duration" plus the accumulated total
// GPU time, up to the point of execution.
const std::string BaseInfo::toString(double gpuTime, double schedulingTime) const {
  std::string gpuStr;
  if (gpuTime > 0.0) {
    gpuStr = fmt::format(", gpu={:.3f} ms", gpuTime);
  }
  std::string cpuStr;
  if (schedulingTime > 0.0) {
    cpuStr = fmt::format(", cpu={:.3f} ms", schedulingTime);
  }
  return gpuStr + cpuStr;
}
32
// Renders "aten::<op> (id=<G|K><id>, run=<count><timing>)"; the id prefix
// distinguishes graph ops ("G") from kernel ops ("K").
const std::string OperationInfo::toString(double gpuTime, double schedulingTime) const {
  const char* typeTag = (type == Type::GRAPH) ? "G" : "K";
  const std::string timingStr = BaseInfo::toString(gpuTime, schedulingTime);
  return fmt::format("aten::{} (id={}{}, run={}{})", strKey, typeTag, profileId, runCount, timingStr);
}
41
// Renders the CPU-fallback info string with its copy overhead.
// NOTE: gpuTime is deliberately ignored — CPU fallbacks do not execute on
// the GPU, so only the scheduling (CPU) time is appended to the metadata.
const std::string CpuFbInfo::toString(double gpuTime, double schedulingTime) const {
  const std::string overheadStr = getIMPSAllocator()->formatSize(currentCopyOverhead);
  const std::string timingStr = BaseInfo::toString(0.0, schedulingTime);
  return fmt::format("CPU Fallback::{} (id={}, run={}, CopyOverhead={}{})",
                     strKey,
                     profileId,
                     runCount,
                     overheadStr,
                     timingStr);
}
50
// Renders the copy info string. Copies either go through a Blit Encoder or
// through a regular memcpy() on Unified memory; the "Sync" suffix indicates
// COMMIT_AND_WAIT was used to synchronize the GPU stream with the CPU after
// a blocking copy.
const std::string CopyInfo::toString(double gpuTime, double schedulingTime) const {
  const char* encoderTag = usesBlitter ? "Blit" : "Mem";
  const char* syncTag = isNonBlocking ? "" : "Sync";
  return fmt::format("{}Copy{}: {} --> {} (len={}{})",
                     encoderTag,
                     syncTag,
                     srcStrKey,
                     dstStrKey,
                     getIMPSAllocator()->formatSize(length),
                     BaseInfo::toString(gpuTime, schedulingTime));
}
64
// Builds a short descriptor for one endpoint of a copy. Prefers the tensor
// when one is defined; otherwise (e.g., copy_blit_mps()) falls back to
// describing the raw buffer pointer and its device.
std::string CopyInfo::buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId) {
  if (!tensor.has_value()) {
    // no tensor available: report storage location plus the pointer value
    const char* deviceTag = isStorageOnMPS(buffer, tensor) ? "MPS" : "CPU";
    return fmt::format("{}:{:p}", deviceTag, buffer);
  }
  return BaseInfo::buildTensorString(*tensor, includeBufferId);
}
74
// Reads the profiler configuration from the PYTORCH_MPS_* environment
// variables, validates the parsed masks, and calls initialize() to set
// up logging/signposts accordingly.
MPSProfiler::MPSProfiler() : m_os_log_events(nullptr), m_os_log_intervals(nullptr) {
  // see enum LogOptions for the description.
  // NOTE(review): the getenv() results are cached in function-local statics,
  // so the environment is consulted only on first construction.
  static const char* log_options_str = getenv(kEVLogProfileInfoStr);
  // strtol with base 0 accepts decimal, hex (0x...), and octal input
  m_log_options = log_options_str ? strtol(log_options_str, nullptr, 0) : 0;
  // see enums profilerOptions and SignpostTypes for the description.
  static const char* trace_signpost_str = getenv(kEVTraceSignpostsStr);
  uint32_t trace_signposts = trace_signpost_str ? strtol(trace_signpost_str, nullptr, 0) : 0;

  TORCH_CHECK(m_log_options <= LogOptions::LOG_COUNT,
              "invalid log options ",
              m_log_options,
              " passed to ",
              kEVLogProfileInfoStr)
  // lower 16 bits used for options (see enum ProfileOptions)
  m_profile_options |= trace_signposts & 0xFFFF;
  TORCH_CHECK(m_profile_options <= ProfileOptions::OPTIONS_COUNT,
              "invalid profiling options ",
              trace_signposts,
              " passed to ",
              kEVTraceSignpostsStr)
  // upper 16 bits used for signpost types (see enum SignpostTypes)
  m_signpost_types |= trace_signposts & 0xFFFF0000;
  TORCH_CHECK(m_signpost_types <= SignpostTypes::SIGNPOST_COUNT,
              "invalid signpost types ",
              trace_signposts,
              " passed to ",
              kEVTraceSignpostsStr)
  // the SIGINT handler (for dumping stats on Ctrl-C) is installed lazily
  // in initialize() only when stats logging is requested
  currentSigint.sa_handler = nullptr;
  previousSigint.sa_handler = nullptr;

  initialize();
}
107
// Drains any pending GPU completion handlers, logs the final profiling
// stats (if enabled), then releases the os_log handles.
MPSProfiler::~MPSProfiler() {
  // first make sure completion handlers are completed
  auto stream = getDefaultMPSStream();
  // dispatch_sync on the stream's queue serializes with in-flight work so
  // the pending-handlers flag is observed consistently before waiting
  dispatch_sync(stream->queue(), ^() {
    if (hasPendingCompletionHandlers) {
      stream->synchronize(SyncType::COMMIT_AND_WAIT);
    }
  });
  // must run after the drain above so stats include the last executions
  logProfilingStats();

  if (m_os_log_events) {
    os_release(m_os_log_events);
  }
  if (m_os_log_intervals) {
    os_release(m_os_log_intervals);
  }
}
125
// Applies the parsed options: expands the ALL_* convenience flags, creates
// the os_log handles for the requested signpost modes, allocates the copy
// stat buckets, and installs a SIGINT handler for dumping stats on
// interrupt. Called from both the constructor and StartTrace().
void MPSProfiler::initialize() {
  // INCLUDE_SCHEDULE_INTERVAL alone implies interval-based tracing of all types
  if ((m_signpost_types == SignpostTypes::SIGNPOST_NONE) &&
      (m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL)) {
    m_profile_options |= ProfileOptions::ALL_SIGNPOST_INTERVALS;
  }

  if (m_profile_options & (ProfileOptions::ALL_SIGNPOST_EVENTS | ProfileOptions::ALL_SIGNPOST_INTERVALS)) {
    // enable all signposts types
    m_signpost_types |= (SignpostTypes::RUN_OPERATION | SignpostTypes::CPU_FALLBACK | SignpostTypes::BLIT_COPY);

    if (m_profile_options & ProfileOptions::ALL_SIGNPOST_EVENTS) {
      m_profile_options |= ProfileOptions::USE_EVENTS;
    }
    if (m_profile_options & ProfileOptions::ALL_SIGNPOST_INTERVALS) {
      m_profile_options |= ProfileOptions::USE_INTERVALS;
    }
  }

  // ALL_STATS expands to each individual stats category
  if (m_log_options & LogOptions::ALL_STATS) {
    m_log_options |= LogOptions::OPERATION_STATS | LogOptions::COPY_STATS | LogOptions::CPU_FALLBACK_STATS;
  }

  if (m_signpost_types != SignpostTypes::SIGNPOST_NONE) {
    // if no signpost options passed, use interval mode by default
    if (!(m_profile_options & (ProfileOptions::USE_EVENTS | ProfileOptions::USE_INTERVALS))) {
      m_profile_options |= ProfileOptions::USE_INTERVALS;
    }
    if ((m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL) &&
        (m_profile_options & ProfileOptions::USE_EVENTS)) {
      TORCH_CHECK((m_profile_options & ProfileOptions::USE_INTERVALS),
                  "the option 'INCLUDE_SCHEDULE_INTERVAL' only works for interval-based signposts");
    }

    // technically, it's possible to trace both events and intervals at the same time
    if (m_profile_options & ProfileOptions::USE_EVENTS) {
      if (!m_os_log_events) {
        m_os_log_events = os_log_create(kMPSProfilerSubSystemStr, kMPSCategoryEventsStr);
        TORCH_CHECK(m_os_log_events, "failed to create OS signpost log for events profiler");
      }
      // include GPU time in metadata for event-based intervals by default, since
      // events are marked in Metal Completion Handlers which outputs GPU time
      m_log_options |= INCLUDE_GPU_TIME;
    }
    if (m_profile_options & ProfileOptions::USE_INTERVALS) {
      if (!m_os_log_intervals) {
        m_os_log_intervals = os_log_create(kMPSProfilerSubSystemStr, kMPSCategoryIntervalsStr);
        TORCH_CHECK(m_os_log_intervals, "failed to create OS signpost log for intervals profiler");
      }
    }
  }

  // lazily create one stat bucket per copy direction
  if (m_log_options & LogOptions::COPY_STATS) {
    if (m_copy_stat_list.empty()) {
      m_copy_stat_list.emplace(CopyInfo::Kind::MPS_TO_MPS, std::make_unique<CopyStat>("MPS to MPS"));
      m_copy_stat_list.emplace(CopyInfo::Kind::MPS_TO_CPU, std::make_unique<CopyStat>("MPS to CPU"));
      m_copy_stat_list.emplace(CopyInfo::Kind::CPU_TO_MPS, std::make_unique<CopyStat>("CPU to MPS"));
    }
  }

  // used to capture sigint signal to log profiling stats
  if (m_log_options & (LogOptions::OPERATION_STATS | LogOptions::COPY_STATS | LogOptions::CPU_FALLBACK_STATS)) {
    if (!currentSigint.sa_handler) {
      currentSigint.sa_handler = &handleIntSignal;
      // SA_RESTART keeps interrupted syscalls from failing with EINTR
      currentSigint.sa_flags = SA_RESTART;
      sigfillset(&currentSigint.sa_mask);
      if (sigaction(SIGINT, &currentSigint, &previousSigint) == -1) {
        AT_ERROR("Cannot install SIGINT handler for MPSProfiler.");
      }
    }
  }
}
197
// Enables signpost tracing programmatically (the API equivalent of setting
// the PYTORCH_MPS_TRACE_SIGNPOSTS env-var). `mode` is a comma-separated
// list of "interval" and/or "event"; any other token is an error.
void MPSProfiler::StartTrace(const std::string& mode, bool waitUntilCompleted) {
  TORCH_CHECK(m_profile_options == ProfileOptions::OPTIONS_NONE, "Tracing Signposts is already enabled ");

  std::stringstream modeStream(mode);
  for (std::string token; std::getline(modeStream, token, ',');) {
    if (token.empty()) {
      continue; // skip empty segments (e.g., "interval,,event")
    }
    if (token == "interval") {
      m_profile_options |= ProfileOptions::ALL_SIGNPOST_INTERVALS;
    } else if (token == "event") {
      m_profile_options |= ProfileOptions::ALL_SIGNPOST_EVENTS;
    } else {
      AT_ERROR("Invalid Signpost trace mode: ", token);
    }
  }
  if (m_profile_options == ProfileOptions::OPTIONS_NONE) {
    return; // nothing was enabled
  }
  if (waitUntilCompleted) {
    m_profile_options |= ProfileOptions::WAIT_UNTIL_COMPLETED;
  }
  initialize();
}
221
// Disables all signpost tracing; StartTrace() may be called again later.
void MPSProfiler::StopTrace() {
  // clearing both masks makes every profiling-enabled check fail fast
  m_signpost_types = SignpostTypes::SIGNPOST_NONE;
  m_profile_options = ProfileOptions::OPTIONS_NONE;
}
226
// Common entry point for starting one profiled execution (operation, copy,
// or CPU fallback): optionally logs the info string, then generates the
// event/interval signpost IDs and begins the signpost interval — either
// immediately, or deferred to the Metal scheduled-handler when schedule
// time is excluded from intervals.
void MPSProfiler::beginProfileExecution(BaseInfo& info, bool cpuExecution) {
  // see comments in isProfileInfoLoggingEnabled()
  if (isProfileInfoLoggingEnabled(info.type, /*isExecutionEnded*/ false)) {
    fmt::print(stderr, "{}\n", info.toString());
  }
  SignpostTypes signpostType = getSignpostType(info.type);
  // nothing more to do when this info type's signposts are not enabled
  if (!(m_signpost_types & signpostType)) {
    return;
  }
  if (m_profile_options & ProfileOptions::USE_EVENTS) {
    info.eventSignpostId = generateSignpostId(OS_SIGNPOST_EVENT);
  }
  if (m_profile_options & ProfileOptions::USE_INTERVALS) {
    info.intervalSignpostId = generateSignpostId(OS_SIGNPOST_INTERVAL_BEGIN);
    // if scheduling part is included, we begin the interval early in here,
    // otherwise we begin when the scheduledHandler callback is triggered.
    // CPU executions have no Metal scheduling phase, so they always begin here.
    if ((m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL) || cpuExecution) {
      beginSignpostInterval(signpostType, info.intervalSignpostId, info.toString());
      info.completed = false;
      // for graphs, we add the scheduleHandler in beginProfileGPUInterval()
    } else if (info.type == BaseInfo::Type::KERNEL || info.type == BaseInfo::Type::COPY) {
      addProfilerScheduledHandler(info);
    }
  }
}
252
// Finalizes one profiled execution: accumulates the timing stats, optionally
// logs the info string with the measured times, emits the signpost event,
// and/or ends the signpost interval.
void MPSProfiler::endProfileExecution(BaseInfo& info,
                                      os_signpost_id_t event_signpost_id,
                                      os_signpost_id_t interval_signpost_id,
                                      double gpuTime,
                                      double schedulingTime) {
  const SignpostTypes signpostType = getSignpostType(info.type);

  // copies accumulate into the per-kind CopyStat buckets instead of `info`
  if (info.type == BaseInfo::Type::COPY) {
    updateCopyStats(static_cast<CopyInfo&>(info), gpuTime, schedulingTime);
  } else {
    info.totalGpuTime = info.totalGpuTime + gpuTime;
    info.totalSchedulingTime = info.totalSchedulingTime + schedulingTime;
  }
  // if Kernel time is not included in metadata separately, we add it to gpuTime in metadata
  if (gpuTime > 0.0 && !(m_log_options & LogOptions::INCLUDE_KERNEL_TIME)) {
    gpuTime += schedulingTime;
    schedulingTime = 0;
  }
  const std::string& infoStr = info.toString(gpuTime, schedulingTime);
  // see comments in isProfileInfoLoggingEnabled()
  if (isProfileInfoLoggingEnabled(info.type, /*isExecutionEnded*/ true)) {
    fmt::print(stderr, "{}\n", infoStr);
  }
  // it is possible to use both interval and event based signposts at the same time
  if ((m_profile_options & ProfileOptions::USE_EVENTS) && event_signpost_id) {
    emitSignpostEvent(signpostType, event_signpost_id, infoStr);
  }
  // GPU time for signpost intervals is calculated based on its duration
  if ((m_profile_options & ProfileOptions::USE_INTERVALS) && interval_signpost_id) {
    endSignpostInterval(signpostType, interval_signpost_id);
  }
  // marks the entry finished (completed copy infos get erased in beginProfileCopy)
  info.completed = true;
}
286
// Begins profiling one graph/kernel execution identified by `handle`.
// Lazily creates the OperationInfo on first sight of the handle and returns
// its profile ID, or 0 when operation profiling is disabled.
uint64_t MPSProfiler::beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph) {
  // only do profiling if operation execution profiling or logging are enabled
  if (!isOperationProfilingEnabled()) {
    return 0;
  }
  auto it = m_op_info_list.find(uintptr_t(handle));
  if (it == m_op_info_list.end()) {
    // first time seeing this handle: assign the next graph/kernel counter
    auto newInfo =
        std::make_unique<OperationInfo>(handle, isGraph, isGraph ? ++m_graph_counter : ++m_kernel_counter, strKey);
    it = m_op_info_list.emplace(newInfo->handle, std::move(newInfo)).first;
  }
  auto& opInfo = *it->second;
  opInfo.strKey.assign(strKey);
  opInfo.runCount++;
  beginProfileExecution(opInfo);

  return opInfo.profileId;
}
304
// Convenience overload: builds the profiler key from the kernel name and its
// tensors, then defers to the string-keyed overload (isGraph = false).
uint64_t MPSProfiler::beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors) {
  if (!isOperationProfilingEnabled()) {
    return 0;
  }
  const bool includeBufferId = m_log_options & LogOptions::INCLUDE_BUFFER_ID;
  return beginProfileKernel(handle, OperationInfo::buildKernelString(kernelName, tensors, includeBufferId), false);
}
313
// For graphs: arms the scheduled-handler that begins the signpost interval
// only after scheduling finishes, so the interval covers GPU run time only.
void MPSProfiler::beginProfileGPUInterval(const void* handle) {
  // relevant only for interval-based signposts that EXCLUDE schedule time
  const bool usesIntervals = m_profile_options & ProfileOptions::USE_INTERVALS;
  const bool includesSchedule = m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL;
  if (!usesIntervals || includesSchedule) {
    return;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_op_info_list.count(uintptr_t(handle)), "Failed to get operation information!");
  // begins the interval once scheduling of the execution has completed
  // (i.e., scheduling is excluded from the interval)
  addProfilerScheduledHandler(*m_op_info_list[uintptr_t(handle)]);
}
327
// Finishes profiling a graph/kernel execution; the timing itself arrives
// asynchronously in the command buffer's completion handler.
void MPSProfiler::endProfileKernel(const void* handle, SyncType syncType) {
  // only do profiling if operation execution profiling or logging are enabled
  if (!isOperationProfilingEnabled()) {
    return;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_op_info_list.count(uintptr_t(handle)), "Failed to get operation information!");
  addProfilerCompletedHandler(*m_op_info_list[uintptr_t(handle)], syncType);
}
337
// Begins profiling one CPU-fallback run of `opName`: lazily creates its
// CpuFbInfo, records the start time and copy overhead, and returns the
// fallback's profile ID.
uint64_t MPSProfiler::beginProfileCPUFallback(const std::string& opName, const TensorList& tensors) {
  auto it = m_cpu_fb_info_list.find(opName);
  if (it == m_cpu_fb_info_list.end()) {
    it = m_cpu_fb_info_list.emplace(opName, std::make_unique<CpuFbInfo>(++m_cpu_fb_counter, opName)).first;
  }
  auto& cpuFbInfo = *it->second;
  cpuFbInfo.runCount++;
  cpuFbInfo.startTime = BaseInfo::getTime();
  const bool includeBufferId = m_log_options & LogOptions::INCLUDE_BUFFER_ID;
  cpuFbInfo.strKey = OperationInfo::buildKernelString(opName, tensors, includeBufferId);
  cpuFbInfo.updateCopyOverhead(tensors);
  // cpuExecution=true: the interval begins immediately (no Metal scheduling)
  beginProfileExecution(cpuFbInfo, true);

  return cpuFbInfo.profileId;
}
353
// Finishes profiling a CPU-fallback run: computes the elapsed CPU time and
// finalizes the entry (CPU fallbacks have no GPU time, hence gpuTime=0).
void MPSProfiler::endProfileCPUFallback(const std::string& opName) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_cpu_fb_info_list.count(opName), "Failed to get CPU Fallback information!");
  auto& cpuFbInfo = *m_cpu_fb_info_list[opName];
  // elapsed CPU time in ms (the 1e-6 factor presumably converts ns to ms)
  const double cpuTimeMs = double(BaseInfo::getTime() - cpuFbInfo.startTime) * 1e-6;
  endProfileExecution(cpuFbInfo, cpuFbInfo.eventSignpostId, cpuFbInfo.intervalSignpostId, 0, cpuTimeMs);
}
361
// Begins profiling one copy (blit or memcpy). Allocates a CopyInfo keyed by
// a fresh profile ID, opportunistically erases already-completed entries
// (copy infos are not retained after completion), and returns the ID, or 0
// when copy profiling is disabled.
uint64_t MPSProfiler::beginProfileCopy(const void* srcBuffer,
                                       const void* dstBuffer,
                                       const OptionalTensorRef srcTensor,
                                       const OptionalTensorRef dstTensor,
                                       size_t length,
                                       bool isNonBlocking,
                                       bool usesBlitter) {
  if (!isCopyProfilingEnabled()) {
    return 0;
  }
  const bool includeBufferId = m_log_options & LogOptions::INCLUDE_BUFFER_ID;
  const uint64_t profileId = ++m_copy_counter;
  auto copyInfo = std::make_unique<CopyInfo>(dstBuffer, length, profileId, isNonBlocking, usesBlitter);
  copyInfo->srcStrKey = CopyInfo::buildTensorString(srcBuffer, srcTensor, includeBufferId);
  copyInfo->dstStrKey = CopyInfo::buildTensorString(dstBuffer, dstTensor, includeBufferId);
  copyInfo->kind = CopyInfo::getCopyKind(srcBuffer, dstBuffer, srcTensor, dstTensor);
  if (!usesBlitter) {
    // for copies that don't use blitters, we measure CPU time
    copyInfo->startTime = BaseInfo::getTime();
  }
  // don't generate signposts if the non-blocking copy is not using the blitter
  if (usesBlitter || !isNonBlocking) {
    beginProfileExecution(*copyInfo, !usesBlitter);
  }
  // this should not happen since we erase the copy info after profiling/logging it.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_copy_info_list.count(profileId) == 0);
  // the copy info isn't retained in the list, so we erase the completed ones
  for (auto it = m_copy_info_list.begin(), last = m_copy_info_list.end(); it != last;) {
    if (it->second->completed) {
      it = m_copy_info_list.erase(it);
    } else {
      ++it;
    }
  }
  m_copy_info_list.emplace(profileId, std::move(copyInfo));

  return profileId;
}
400
// Finishes profiling a copy previously started by beginProfileCopy().
void MPSProfiler::endProfileCopy(uint64_t profileId, SyncType syncType) {
  // this is just an identifier, and not used to access memory
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_copy_info_list.count(profileId), "Failed to get copy information!");
  auto& copyInfo = *m_copy_info_list[profileId];
  if (!copyInfo.usesBlitter) {
    // memcpy-based copy: measure the elapsed CPU time and finish immediately
    const double cpuTime = double(BaseInfo::getTime() - copyInfo.startTime) * 1e-6;
    endProfileExecution(copyInfo, copyInfo.eventSignpostId, copyInfo.intervalSignpostId, 0, cpuTime);
  } else {
    // blit copy: GPU timing arrives via the command buffer's completion handler
    addProfilerCompletedHandler(copyInfo, syncType);
  }
}
412
// Registers a Metal scheduled-handler on the stream's current command buffer
// that begins the signpost interval once GPU scheduling has finished (used
// when INCLUDE_SCHEDULE_INTERVAL is off so intervals cover GPU time only).
// NOTE(review): the ObjC block captures `info` — it relies on the profiler
// keeping the info object alive until the handler fires.
void MPSProfiler::addProfilerScheduledHandler(BaseInfo& info) {
  const SignpostTypes signpostType = getSignpostType(info.type);
  // snapshot the ID now; info.intervalSignpostId may be reset for the next
  // run before this handler executes (see addProfilerCompletedHandler)
  const os_signpost_id_t intervalSignpostId = info.intervalSignpostId;

  auto m_stream = getDefaultMPSStream();
  // NOTE: the following block isn't thread-safe
  [m_stream->commandBuffer() addScheduledHandler:^(id<MTLCommandBuffer> cb) {
    // begin the interval once scheduling has completed (if INCLUDE_SCHEDULE_INTERVAL flag is disabled)
    beginSignpostInterval(signpostType, intervalSignpostId, info.toString());
    info.completed = false;
  }];
}
425
// Accumulates per-direction (MPS<->MPS/CPU) copy statistics; a no-op unless
// COPY_STATS logging was requested.
void MPSProfiler::updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime) {
  if (!(m_log_options & LogOptions::COPY_STATS)) {
    return;
  }
  auto& stat = *m_copy_stat_list[copyInfo.kind];
  stat.totalCount++;
  stat.length += copyInfo.length;
  stat.totalGpuTime = stat.totalGpuTime + gpuTime;
  stat.totalSchedulingTime = stat.totalSchedulingTime + schedulingTime;
  // copies no larger than a single int64 are counted as scalar copies
  if (copyInfo.length <= sizeof(int64_t)) {
    stat.scalarsCount++;
    stat.scalarsGpuTime = stat.scalarsGpuTime + gpuTime;
  }
  if (!copyInfo.isNonBlocking) {
    stat.blockingCount++;
  }
  if (!copyInfo.usesBlitter) {
    stat.memcpyCount++;
  }
}
442
// Registers a Metal completed-handler that derives the GPU and kernel
// (scheduling) durations from the command buffer timestamps and finalizes
// the profiling entry; then commits the stream with the requested SyncType
// (upgraded to COMMIT_AND_WAIT when WAIT_UNTIL_COMPLETED is set).
void MPSProfiler::addProfilerCompletedHandler(BaseInfo& info, SyncType syncType) {
  // snapshot the IDs before they are reset on `info` below
  const os_signpost_id_t intervalSignpostId = info.intervalSignpostId;
  const os_signpost_id_t eventSignpostId = info.eventSignpostId;

  // signpost ID is used only for interval-based signposts, and must be non-zero
  if (m_profile_options & ProfileOptions::USE_INTERVALS) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(intervalSignpostId, "Signpost interval has no identifier!");
  }
  // reset signpostIds for sanity check on next call
  info.intervalSignpostId = 0;
  info.eventSignpostId = 0;
  // lets the destructor know it must drain the stream before logging stats
  hasPendingCompletionHandlers = true;

  auto m_stream = getDefaultMPSStream();
  // NOTE: the following block isn't thread-safe
  [m_stream->commandBuffer() addCompletedHandler:^(id<MTLCommandBuffer> cb) {
    // durations scaled by 1000.0 into milliseconds; clamped to 0 when the
    // end timestamp is not past the start (e.g., timing unavailable)
    CFTimeInterval gpuTime = cb.GPUEndTime > cb.GPUStartTime ? (cb.GPUEndTime - cb.GPUStartTime) * 1000.0 : 0.;
    CFTimeInterval schedulingTime =
        cb.kernelEndTime > cb.kernelStartTime ? (cb.kernelEndTime - cb.kernelStartTime) * 1000.0 : 0.;

    endProfileExecution(info, eventSignpostId, intervalSignpostId, gpuTime, schedulingTime);
    hasPendingCompletionHandlers = false;
  }];

  m_stream->synchronize((m_profile_options & ProfileOptions::WAIT_UNTIL_COMPLETED) ? SyncType::COMMIT_AND_WAIT
                                                                                   : syncType);
}
470
// Prints the per-operation profiling table to `f`, sorted by mean GPU time
// in descending order; enabled via LogOptions::OPERATION_STATS.
void MPSProfiler::logOperationsProfilingStats(std::FILE* f) const {
  if (m_op_info_list.empty()) {
    // this is not an error, but to let the user know that the
    // LogOptions::KERNEL_STATS that they passed to EV is not yielding anything.
    fmt::print(f, "There are no MPS operations logged for profiling\n");
    return;
  }
  // dump the ops info into a vector to sort them
  std::vector<OperationInfo*> opsList;
  std::transform(m_op_info_list.begin(), m_op_info_list.end(), std::back_inserter(opsList), [](auto& opInfo) {
    return opInfo.second.get();
  });

  // sort based on "Mean GPU time" in descending order
  // (runCount is always >= 1 here since it is incremented before any run)
  std::sort(opsList.begin(), opsList.end(), [](const OperationInfo* a, const OperationInfo* b) {
    return (a->totalGpuTime / double(a->runCount)) > (b->totalGpuTime / double(b->runCount));
  });
  // print the table of operation profiling stats
  fmt::print(f,
             "\n{:-^200}\n{:^6}|{:^7}|{:^15}|{:^14}|{:^15}| {}\n{:-^200}\n",
             fmt::format(" MPS Operations Profiling: {} graphs, {} kernels ", m_graph_counter, m_kernel_counter),
             "ID",
             "#Runs",
             "Mean KRNL(ms)",
             "Mean GPU(ms)",
             "Total GPU(ms)",
             "Operation Name",
             "");

  for (const auto& opInfo : opsList) {
    // "Mean KRNL" is derived from the accumulated scheduling (kernel) time
    fmt::print(f,
               "{:^7}{:^8}{:^16}{:^15}{:^16} {}\n",
               fmt::format("{}{}", opInfo->type == BaseInfo::Type::GRAPH ? "G" : "K", opInfo->profileId),
               opInfo->runCount,
               fmt::format("{:.3f}", opInfo->totalSchedulingTime / double(opInfo->runCount)),
               fmt::format("{:.3f}", opInfo->totalGpuTime / double(opInfo->runCount)),
               fmt::format("{:.3f}", opInfo->totalGpuTime.load()),
               opInfo->strKey);
  }
}
511
// Prints the CPU-fallback profiling table to `f`, sorted by mean CPU time
// in descending order; enabled via LogOptions::CPU_FALLBACK_STATS.
void MPSProfiler::logCPUFallbackProfilingStats(std::FILE* f) const {
  if (m_cpu_fb_info_list.empty()) {
    // this is not an error, but to let the user know that the
    // LogOptions::KERNEL_STATS that they passed to EV is not yielding anything.
    fmt::print(f, "There are no CPU Fallbacks logged for profiling\n");
    return;
  }
  size_t totalCopyOverhead = 0;
  size_t totalRunCount = 0;
  double totalCPUTime = 0.;
  // dump the map's info into a vector to sort them
  // NOTE: the transform lambda also accumulates the table-header totals
  // (run count, copy overhead, CPU time) as a side effect
  std::vector<CpuFbInfo*> cpuFbList;
  std::transform(
      m_cpu_fb_info_list.begin(), m_cpu_fb_info_list.end(), std::back_inserter(cpuFbList), [&](auto& cpuFbInfo) {
        auto cpuFbInfoPtr = cpuFbInfo.second.get();
        totalRunCount += cpuFbInfoPtr->runCount;
        totalCopyOverhead += cpuFbInfoPtr->totalCopyOverhead;
        totalCPUTime += cpuFbInfoPtr->totalSchedulingTime;
        return cpuFbInfoPtr;
      });

  // sort based on "Mean CPU time" in descending order
  std::sort(cpuFbList.begin(), cpuFbList.end(), [](const CpuFbInfo* a, const CpuFbInfo* b) {
    return (a->totalSchedulingTime / double(a->runCount)) > (b->totalSchedulingTime / double(b->runCount));
  });

  // print the table of CPU Fallback profiling stats
  fmt::print(f,
             "\n{:-^150}\n{:^5}|{:^7}|{:^14}|{:^15}|{:^15}| {}\n{:-^150}\n",
             fmt::format(" CPU Fallback Profiling: Total {} Runs, {:.2f} ms, {} Copies ",
                         totalRunCount,
                         totalCPUTime,
                         getIMPSAllocator()->formatSize(totalCopyOverhead)),
             "ID",
             "#Runs",
             "Mean CPU(ms)",
             "Total CPU(ms)",
             "Copy Overhead",
             "Operation Name",
             "");

  for (const auto& cpuFbInfo : cpuFbList) {
    fmt::print(f,
               "{:^6}{:^8}{:^15}{:^16}{:^16} {}\n",
               cpuFbInfo->profileId,
               cpuFbInfo->runCount,
               fmt::format("{:.3f}", cpuFbInfo->totalSchedulingTime / double(cpuFbInfo->runCount)),
               fmt::format("{:.3f}", cpuFbInfo->totalSchedulingTime.load()),
               getIMPSAllocator()->formatSize(cpuFbInfo->totalCopyOverhead),
               cpuFbInfo->opName);
  }
}
564
// Prints the per-direction copy profiling table to `f`; enabled via
// LogOptions::COPY_STATS.
void MPSProfiler::logCopyProfilingStats(std::FILE* f) const {
  size_t totalCopiesCount = 0;
  size_t totalCopySize = 0;
  size_t totalScalarCopyCount = 0;

  // accumulate header totals across all copy-direction buckets
  for (const auto& copyStatPair : m_copy_stat_list) {
    const auto& copyStat = *copyStatPair.second;
    totalCopiesCount += copyStat.totalCount;
    totalCopySize += copyStat.length;
    totalScalarCopyCount += copyStat.scalarsCount;
  }
  if (totalCopiesCount == 0) {
    // this is not an error, but to let the user know that the
    // LogOptions::COPY_STATS that they passed to EV is not yielding anything.
    fmt::print(f, "There are no copies logged for profiling\n");
    return;
  }

  // print the table of copy profiling stats
  fmt::print(f,
             "\n{:-^160}\n{:^12}|{:^10}|{:^17}|{:^16}|{:^15}|{:^9}|{:^13}|{:^10}|{:^8}\n{:-^160}\n",
             fmt::format(" MPS Copy Profiling: {} total copies ({}), {} scalar copies ",
                         totalCopiesCount,
                         getIMPSAllocator()->formatSize(totalCopySize),
                         totalScalarCopyCount),
             "Kind",
             "Total#",
             "Total Size",
             "Total KRNL(ms)",
             "Total GPU(ms)",
             "Scalars",
             "Scalars GPU",
             "Blocking",
             "memcpy",
             "");

  for (const auto& copyStatPair : m_copy_stat_list) {
    const auto& copyStat = *copyStatPair.second;
    if (copyStat.totalCount > 0) {
      fmt::print(
          f,
          "{:^13}{:^11}{:^18}{:^17}{:^16}{:^10}{:^14}{:^11}{:^9}\n",
          copyStat.kindStr,
          copyStat.totalCount,
          getIMPSAllocator()->formatSize(copyStat.length),
          fmt::format("{:.3f}", copyStat.totalSchedulingTime.load()),
          fmt::format("{:.3f}", copyStat.totalGpuTime.load()),
          copyStat.scalarsCount,
          // "Scalars GPU" is the percentage of this bucket's GPU time that
          // was spent on scalar-sized copies (scalarsGpuTime / totalGpuTime)
          fmt::format("{:.2f} %",
                      copyStat.totalGpuTime > 0.0
                          ? (1.0 - ((copyStat.totalGpuTime - copyStat.scalarsGpuTime) / copyStat.totalGpuTime)) * 100.0
                          : 0.0),
          copyStat.blockingCount,
          copyStat.memcpyCount);
    }
  }
}
622
// Dumps every enabled stats table to stderr exactly once per process
// (called from the destructor and from the SIGINT handler path).
void MPSProfiler::logProfilingStats() {
  // atomically latch so concurrent/repeated callers print at most once
  if (hasLoggedStats.exchange(true)) {
    return;
  }
  // emit each table only when its LogOptions bit was requested
  using StatsLogger = void (MPSProfiler::*)(std::FILE*) const;
  const std::pair<uint32_t, StatsLogger> loggers[] = {
      {LogOptions::OPERATION_STATS, &MPSProfiler::logOperationsProfilingStats},
      {LogOptions::CPU_FALLBACK_STATS, &MPSProfiler::logCPUFallbackProfilingStats},
      {LogOptions::COPY_STATS, &MPSProfiler::logCopyProfilingStats},
  };
  for (const auto& [optionBit, logger] : loggers) {
    if (m_log_options & optionBit) {
      (this->*logger)(stderr);
    }
  }
}
640
// Decides whether `infoType` info strings should be printed at this phase
// of execution (begin vs. end), based on the options parsed from the
// env-var defined in kEVLogProfileInfoStr.
bool MPSProfiler::isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded) {
  // map the info type to its corresponding *_INFO log option bit
  uint32_t requiredOption = 0;
  switch (infoType) {
    case BaseInfo::Type::GRAPH:
    case BaseInfo::Type::KERNEL:
      requiredOption = LogOptions::OPERATION_INFO;
      break;
    case BaseInfo::Type::COPY:
      requiredOption = LogOptions::COPY_INFO;
      break;
    case BaseInfo::Type::CPU_FALLBACK:
      requiredOption = LogOptions::CPU_FALLBACK_INFO;
      break;
    default:
      AT_ERROR("invalid profiling info type");
  }
  if (!(m_log_options & requiredOption)) {
    return false;
  }
  // when GPU/kernel times are requested the info string is only complete at
  // the end of execution; otherwise log it when execution begins
  const bool logWhenExecutionEnds = m_log_options & (LogOptions::INCLUDE_GPU_TIME | LogOptions::INCLUDE_KERNEL_TIME);
  return isExecutionEnded ? logWhenExecutionEnds : !logWhenExecutionEnds;
}
666
// Emits a point-in-time os_signpost event carrying `msg_str` as metadata.
// No-op when the type is disabled, the ID is zero, or signposting is not
// active (e.g., no Instruments session attached).
void MPSProfiler::emitSignpostEvent(SignpostTypes signpost_type,
                                    os_signpost_id_t signpost_id,
                                    const std::string& msg_str) const {
  if (!(m_signpost_types & signpost_type) || !signpost_id || !m_os_log_events ||
      !os_signpost_enabled(m_os_log_events)) {
    return;
  }
  const char* msg = msg_str.c_str();

  // need to use switch-case as the signpost names must be literal strings
  switch (signpost_type) {
    case SignpostTypes::RUN_OPERATION:
      os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostRunOperationStr, "%s", msg);
      break;
    case SignpostTypes::BLIT_COPY:
      os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostBlitCopyStr, "%s", msg);
      break;
    case SignpostTypes::CPU_FALLBACK:
      os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostCPUFallbacksStr, "%s", msg);
      break;
    default:
      AT_ERROR("unknown SignpostType in MPS profiler");
  }
}
691
// Begins an os_signpost interval carrying `msg_str` as metadata; paired with
// endSignpostInterval() using the same signpost_id. No-op when the type is
// disabled, the ID is zero, or signposting is not active.
void MPSProfiler::beginSignpostInterval(SignpostTypes signpost_type,
                                        os_signpost_id_t signpost_id,
                                        const std::string& msg_str) const {
  if (!(m_signpost_types & signpost_type) || !signpost_id || !m_os_log_intervals ||
      !os_signpost_enabled(m_os_log_intervals)) {
    return;
  }
  const char* msg = msg_str.c_str();

  // switch-case is required: os_signpost interval names must be literal strings
  switch (signpost_type) {
    case SignpostTypes::RUN_OPERATION:
      os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostRunOperationStr, "%s", msg);
      break;
    case SignpostTypes::BLIT_COPY:
      os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostBlitCopyStr, "%s", msg);
      break;
    case SignpostTypes::CPU_FALLBACK:
      os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostCPUFallbacksStr, "%s", msg);
      break;
    default:
      AT_ERROR("unknown SignpostType in MPS profiler");
  }
}
715
// Closes a previously-begun os_signpost interval. NOTE(review): unlike
// beginSignpostInterval, this does not filter on m_signpost_types —
// presumably so an interval that was begun is always ended; confirm intended.
void MPSProfiler::endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const {
  const bool canEnd = m_os_log_intervals && os_signpost_enabled(m_os_log_intervals);
  if (!canEnd) {
    return;
  }
  // os_signpost names must be literal strings, hence the switch-case
  switch (signpost_type) {
    case SignpostTypes::RUN_OPERATION:
      os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostRunOperationStr);
      break;
    case SignpostTypes::BLIT_COPY:
      os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostBlitCopyStr);
      break;
    case SignpostTypes::CPU_FALLBACK:
      os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostCPUFallbacksStr);
      break;
    default:
      AT_ERROR("unknown SignpostType in MPS profiler");
  }
}
734
// Creates a signpost id on the log handle matching the signpost kind:
// events and intervals are emitted to separate os_log handles.
// When a pointer is supplied, the id is derived from it so the same object
// maps to the same id; otherwise a fresh unique id is generated.
os_signpost_id_t MPSProfiler::generateSignpostId(os_signpost_type_t signpostType, const void* ptr) {
  os_log_t logHandle = (signpostType == OS_SIGNPOST_EVENT) ? m_os_log_events : m_os_log_intervals;
  if (ptr == nullptr) {
    return os_signpost_id_generate(logHandle);
  }
  return os_signpost_id_make_with_pointer(logHandle, ptr);
}
742
// Maps a profiling info type to its signpost category.
MPSProfiler::SignpostTypes MPSProfiler::getSignpostType(BaseInfo::Type infoType) {
  // graph and kernel executions share the RUN_OPERATION signpost category
  if (infoType == BaseInfo::Type::GRAPH || infoType == BaseInfo::Type::KERNEL) {
    return SignpostTypes::RUN_OPERATION;
  }
  if (infoType == BaseInfo::Type::COPY) {
    return SignpostTypes::BLIT_COPY;
  }
  if (infoType == BaseInfo::Type::CPU_FALLBACK) {
    return SignpostTypes::CPU_FALLBACK;
  }
  AT_ERROR("invalid profiling info type");
}
756
// SIGINT handler: dumps accumulated profiling stats, then chains to the
// handler that was installed before ours.
void MPSProfiler::handleIntSignal(int signal) {
  getMPSProfiler().logProfilingStats();
  // Only chain when the previous disposition is an actual handler function:
  // SIG_DFL is 0 (covered by the null check), but SIG_IGN is the non-null
  // sentinel value 1 and must not be invoked as a function pointer. Likewise,
  // when SA_SIGINFO was set the handler lives in sa_sigaction with a
  // three-argument signature and must not be called through sa_handler.
  if (previousSigint.sa_handler && previousSigint.sa_handler != SIG_IGN &&
      !(previousSigint.sa_flags & SA_SIGINFO)) {
    previousSigint.sa_handler(signal);
  }
}
763
// used to capture sigint signal to log profiling stats
// Definitions for the class-static sigaction members: previousSigint stores
// the handler that was installed before ours so handleIntSignal can chain to
// it; currentSigint presumably holds the profiler's own installed action —
// set up outside this chunk, verify against MPSProfiler.h / the installer.
struct sigaction MPSProfiler::currentSigint {};
struct sigaction MPSProfiler::previousSigint {};
767
// Returns true while a Metal GPU capture is in progress.
// Messaging a nil captureManager yields NO, so no explicit nil check is needed.
bool MPSProfiler::isCapturing() const {
  const BOOL capturing = [captureManager isCapturing];
  return capturing;
}
771
// Returns true when the capture manager can write .gputrace documents.
// Lazily grabs the shared MTLCaptureManager on first use.
bool MPSProfiler::isCaptureEnabled() const {
  if (!captureManager) {
    captureManager = [MTLCaptureManager sharedCaptureManager];
  }
  // destination support cannot change at runtime, so query it exactly once
  static const bool supportsGPUTraceDoc =
      [captureManager supportsDestination:MTLCaptureDestinationGPUTraceDocument];
  return supportsGPUTraceDoc;
}
781
// Starts a Metal GPU capture written to a .gputrace document in the current
// working directory. Captures the given stream's command queue, or the whole
// device when stream is null. Throws (TORCH_CHECK) if the capture fails to start.
void MPSProfiler::startCapture(const std::string& name, MPSStream* stream) {
  if (!captureManager) {
    captureManager = [MTLCaptureManager sharedCaptureManager];
  }
  // number the trace files so repeated captures in one run do not collide
  NSString* fname = [NSString stringWithFormat:@"%04d-%s.gputrace", captureCount++, name.c_str()];
  MTLCaptureDescriptor* descriptor = [MTLCaptureDescriptor new];
  descriptor.captureObject = stream ? (id)stream->commandQueue() : (id)MPSDevice::getInstance()->device();
  descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
  descriptor.outputURL = [NSURL fileURLWithPath:fname];
  NSError* err = nil;
  const BOOL started = [captureManager startCaptureWithDescriptor:descriptor error:&err];
  TORCH_CHECK(started, "Failed to start capture of ", [fname UTF8String], " error ", [[err description] UTF8String]);
}
795
// Ends the in-flight Metal GPU capture. When a stream is supplied, its
// pending GPU work is committed first so it is included in the trace.
void MPSProfiler::stopCapture(MPSStream* stream) {
  if (stream != nullptr) {
    stream->synchronize(SyncType::COMMIT);
  }
  [captureManager stopCapture];
}
802
803} // namespace Profiler
804
// Returns the process-wide MPS profiler singleton, constructed on first use.
Profiler::MPSProfiler& getMPSProfiler() {
  // Initialize directly in the static's initializer: C++11 guarantees
  // thread-safe one-time initialization of function-local statics ("magic
  // statics"), whereas the previous check-then-assign of a default-constructed
  // unique_ptr was racy if first called from multiple threads concurrently.
  static std::unique_ptr<Profiler::MPSProfiler> mps_profiler = std::make_unique<Profiler::MPSProfiler>();
  return *mps_profiler;
}
812
813} // namespace at::mps
814