// Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSProfiler.h>
#include <c10/util/Exception.h>
#include <fmt/format.h>

// these need to be literal strings when passed to os_signpost*()
// function macros; so no LUTs could be used
#define kMPSProfilerSubSystemStr "PyTorchMPS"
#define kMPSCategoryEventsStr "Events"
#define kMPSCategoryIntervalsStr "Intervals"
#define kIntSignpostRunOperationStr "PyTorchOperationIntervals"
#define kIntSignpostBlitCopyStr "PyTorchCopyIntervals"
#define kIntSignpostCPUFallbacksStr "PyTorchCPUFallbackIntervals"
#define kEvtSignpostRunOperationStr "PyTorchOperationEvents"
#define kEvtSignpostBlitCopyStr "PyTorchCopyEvents"
#define kEvtSignpostCPUFallbacksStr "PyTorchCPUFallbacksEvents"
#define kEVLogProfileInfoStr "PYTORCH_MPS_LOG_PROFILE_INFO"
#define kEVTraceSignpostsStr "PYTORCH_MPS_TRACE_SIGNPOSTS"

namespace at::mps {
namespace Profiler {

// Formats the optional timing suffix shared by all profiling-info types.
// The gpuTime will be non-zero mainly for event-based signposts.
// The interval-based signposts will have "duration" as well as accumulated
// total GPU time, up to the point of execution.
const std::string BaseInfo::toString(double gpuTime, double schedulingTime) const {
  return fmt::format("{}{}",
                     gpuTime > 0.0 ? fmt::format(", gpu={:.3f} ms", gpuTime) : "",
                     schedulingTime > 0.0 ? fmt::format(", cpu={:.3f} ms", schedulingTime) : "");
}

// e.g. "aten::add (id=G12, run=3, gpu=0.123 ms)" — "G" for graphs, "K" for kernels.
const std::string OperationInfo::toString(double gpuTime, double schedulingTime) const {
  return fmt::format("aten::{} (id={}{}, run={}{})",
                     strKey,
                     type == Type::GRAPH ? "G" : "K",
                     profileId,
                     runCount,
                     BaseInfo::toString(gpuTime, schedulingTime));
}

// CPU fallbacks have no GPU time of their own; only scheduling (CPU) time is reported.
const std::string CpuFbInfo::toString(double gpuTime, double schedulingTime) const {
  return fmt::format("CPU Fallback::{} (id={}, run={}, CopyOverhead={}{})",
                     strKey,
                     profileId,
                     runCount,
                     getIMPSAllocator()->formatSize(currentCopyOverhead),
                     BaseInfo::toString(0.0, schedulingTime));
}

const std::string CopyInfo::toString(double gpuTime, double schedulingTime) const {
  return fmt::format("{}Copy{}: {} --> {} (len={}{})",
                     // Copies could be using Blit Encoder, or using regular
                     // memcpy() on Unified memory
                     usesBlitter ? "Blit" : "Mem",
                     // CopySync indicates COMMIT_AND_WAIT was used to synchronize
                     // the GPU stream with CPU after the blocking copy
                     isNonBlocking ? "" : "Sync",
                     srcStrKey,
                     dstStrKey,
                     getIMPSAllocator()->formatSize(length),
                     BaseInfo::toString(gpuTime, schedulingTime));
}

// Builds a human-readable identifier for one side of a copy: uses the tensor
// when available, otherwise falls back to the raw buffer pointer.
std::string CopyInfo::buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId) {
  if (tensor.has_value()) {
    return BaseInfo::buildTensorString(*tensor, includeBufferId);
  }
  // if tensor is not defined (e.g., copy_blit_mps()), then use buffer
  // pointer to build the string.
  const bool isBufferOnMPS = isStorageOnMPS(buffer, tensor);
  return fmt::format("{}:{:p}", isBufferOnMPS ? "MPS" : "CPU", buffer);
}

// Reads the profiler configuration from the environment variables
// PYTORCH_MPS_LOG_PROFILE_INFO and PYTORCH_MPS_TRACE_SIGNPOSTS,
// validates the parsed masks, and initializes signpost logging.
MPSProfiler::MPSProfiler() : m_os_log_events(nullptr), m_os_log_intervals(nullptr) {
  // see enum LogOptions for the description.
  static const char* log_options_str = getenv(kEVLogProfileInfoStr);
  m_log_options = log_options_str ? strtol(log_options_str, nullptr, 0) : 0;
  // see enums profilerOptions and SignpostTypes for the description.
  static const char* trace_signpost_str = getenv(kEVTraceSignpostsStr);
  uint32_t trace_signposts = trace_signpost_str ? strtol(trace_signpost_str, nullptr, 0) : 0;

  TORCH_CHECK(m_log_options <= LogOptions::LOG_COUNT,
              "invalid log options ",
              m_log_options,
              " passed to ",
              kEVLogProfileInfoStr)
  // lower 16 bits used for options (see enum ProfileOptions)
  m_profile_options |= trace_signposts & 0xFFFF;
  TORCH_CHECK(m_profile_options <= ProfileOptions::OPTIONS_COUNT,
              "invalid profiling options ",
              trace_signposts,
              " passed to ",
              kEVTraceSignpostsStr)
  // upper 16 bits used for signpost types (see enum SignpostTypes)
  m_signpost_types |= trace_signposts & 0xFFFF0000;
  TORCH_CHECK(m_signpost_types <= SignpostTypes::SIGNPOST_COUNT,
              "invalid signpost types ",
              trace_signposts,
              " passed to ",
              kEVTraceSignpostsStr)
  currentSigint.sa_handler = nullptr;
  previousSigint.sa_handler = nullptr;

  initialize();
}

MPSProfiler::~MPSProfiler() {
  // first make sure completion handlers are completed
  auto stream = getDefaultMPSStream();
  dispatch_sync(stream->queue(), ^() {
    if (hasPendingCompletionHandlers) {
      stream->synchronize(SyncType::COMMIT_AND_WAIT);
    }
  });
  logProfilingStats();

  if (m_os_log_events) {
    os_release(m_os_log_events);
  }
  if (m_os_log_intervals) {
    os_release(m_os_log_intervals);
  }
}

// Expands the option/log masks into their concrete flags, creates the
// os_signpost log handles, seeds copy-stat entries, and installs the
// SIGINT handler used to dump stats on interrupt. Safe to call again
// after StartTrace() updates m_profile_options.
void MPSProfiler::initialize() {
  if ((m_signpost_types == SignpostTypes::SIGNPOST_NONE) &&
      (m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL)) {
    m_profile_options |= ProfileOptions::ALL_SIGNPOST_INTERVALS;
  }

  if (m_profile_options & (ProfileOptions::ALL_SIGNPOST_EVENTS | ProfileOptions::ALL_SIGNPOST_INTERVALS)) {
    // enable all signposts types
    m_signpost_types |= (SignpostTypes::RUN_OPERATION | SignpostTypes::CPU_FALLBACK | SignpostTypes::BLIT_COPY);

    if (m_profile_options & ProfileOptions::ALL_SIGNPOST_EVENTS) {
      m_profile_options |= ProfileOptions::USE_EVENTS;
    }
    if (m_profile_options & ProfileOptions::ALL_SIGNPOST_INTERVALS) {
      m_profile_options |= ProfileOptions::USE_INTERVALS;
    }
  }

  if (m_log_options & LogOptions::ALL_STATS) {
    m_log_options |= LogOptions::OPERATION_STATS | LogOptions::COPY_STATS | LogOptions::CPU_FALLBACK_STATS;
  }

  if (m_signpost_types != SignpostTypes::SIGNPOST_NONE) {
    // if no signpost options passed, use interval mode by default
    if (!(m_profile_options & (ProfileOptions::USE_EVENTS | ProfileOptions::USE_INTERVALS))) {
      m_profile_options |= ProfileOptions::USE_INTERVALS;
    }
    if ((m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL) &&
        (m_profile_options & ProfileOptions::USE_EVENTS)) {
      TORCH_CHECK((m_profile_options & ProfileOptions::USE_INTERVALS),
                  "the option 'INCLUDE_SCHEDULE_INTERVAL' only works for interval-based signposts");
    }

    // technically, it's possible to trace both events and intervals at the same time
    if (m_profile_options & ProfileOptions::USE_EVENTS) {
      if (!m_os_log_events) {
        m_os_log_events = os_log_create(kMPSProfilerSubSystemStr, kMPSCategoryEventsStr);
        TORCH_CHECK(m_os_log_events, "failed to create OS signpost log for events profiler");
      }
      // include GPU time in metadata for event-based intervals by default, since
      // events are marked in Metal Completion Handlers which outputs GPU time
      m_log_options |= INCLUDE_GPU_TIME;
    }
    if (m_profile_options & ProfileOptions::USE_INTERVALS) {
      if (!m_os_log_intervals) {
        m_os_log_intervals = os_log_create(kMPSProfilerSubSystemStr, kMPSCategoryIntervalsStr);
        TORCH_CHECK(m_os_log_intervals, "failed to create OS signpost log for intervals profiler");
      }
    }
  }

  if (m_log_options & LogOptions::COPY_STATS) {
    if (m_copy_stat_list.empty()) {
      m_copy_stat_list.emplace(CopyInfo::Kind::MPS_TO_MPS, std::make_unique<CopyStat>("MPS to MPS"));
      m_copy_stat_list.emplace(CopyInfo::Kind::MPS_TO_CPU, std::make_unique<CopyStat>("MPS to CPU"));
      m_copy_stat_list.emplace(CopyInfo::Kind::CPU_TO_MPS, std::make_unique<CopyStat>("CPU to MPS"));
    }
  }

  // used to capture sigint signal to log profiling stats
  if (m_log_options & (LogOptions::OPERATION_STATS | LogOptions::COPY_STATS | LogOptions::CPU_FALLBACK_STATS)) {
    if (!currentSigint.sa_handler) {
      currentSigint.sa_handler = &handleIntSignal;
      currentSigint.sa_flags = SA_RESTART;
      // NOTE: "&currentSigint" was previously corrupted to the mojibake
      // "¤tSigint" (HTML entity "&curren;"); restored here.
      sigfillset(&currentSigint.sa_mask);
      if (sigaction(SIGINT, &currentSigint, &previousSigint) == -1) {
        AT_ERROR("Cannot install SIGINT handler for MPSProfiler.");
      }
    }
  }
}

// Enables signpost tracing programmatically (e.g. from torch.mps.profiler).
// 'mode' is a comma-separated list of "interval" and/or "event".
void MPSProfiler::StartTrace(const std::string& mode, bool waitUntilCompleted) {
  TORCH_CHECK(m_profile_options == ProfileOptions::OPTIONS_NONE, "Tracing Signposts is already enabled ");

  std::stringstream ss(mode);
  std::string token;
  while (getline(ss, token, ',')) {
    if (!token.empty()) {
      if (token == "interval") {
        m_profile_options |= ProfileOptions::ALL_SIGNPOST_INTERVALS;
      } else if (token == "event") {
        m_profile_options |= ProfileOptions::ALL_SIGNPOST_EVENTS;
      } else {
        AT_ERROR("Invalid Signpost trace mode: ", token);
      }
    }
  }
  if (m_profile_options != ProfileOptions::OPTIONS_NONE) {
    if (waitUntilCompleted) {
      m_profile_options |= ProfileOptions::WAIT_UNTIL_COMPLETED;
    }
    initialize();
  }
}

void MPSProfiler::StopTrace() {
  m_profile_options = ProfileOptions::OPTIONS_NONE;
  m_signpost_types = SignpostTypes::SIGNPOST_NONE;
}

// Common entry for starting the profiling of an op/copy/fallback execution:
// optionally logs the info string, then generates signpost IDs and (for
// intervals) either begins the interval immediately or defers to a
// scheduledHandler depending on INCLUDE_SCHEDULE_INTERVAL.
void MPSProfiler::beginProfileExecution(BaseInfo& info, bool cpuExecution) {
  // see comments in isProfileInfoLoggingEnabled()
  if (isProfileInfoLoggingEnabled(info.type, /*isExecutionEnded*/ false)) {
    fmt::print(stderr, "{}\n", info.toString());
  }
  SignpostTypes signpostType = getSignpostType(info.type);
  if (!(m_signpost_types & signpostType)) {
    return;
  }
  if (m_profile_options & ProfileOptions::USE_EVENTS) {
    info.eventSignpostId = generateSignpostId(OS_SIGNPOST_EVENT);
  }
  if (m_profile_options & ProfileOptions::USE_INTERVALS) {
    info.intervalSignpostId = generateSignpostId(OS_SIGNPOST_INTERVAL_BEGIN);
    // if scheduling part is included, we begin the interval early in here,
    // otherwise we begin when the scheduledHandler callback is triggered.
    if ((m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL) || cpuExecution) {
      beginSignpostInterval(signpostType, info.intervalSignpostId, info.toString());
      info.completed = false;
      // for graphs, we add the scheduleHandler in beginProfileGPUInterval()
    } else if (info.type == BaseInfo::Type::KERNEL || info.type == BaseInfo::Type::COPY) {
      addProfilerScheduledHandler(info);
    }
  }
}

// Common exit path: accumulates stats, optionally logs, emits the event
// signpost and/or ends the interval signpost, and marks the info completed.
void MPSProfiler::endProfileExecution(BaseInfo& info,
                                      os_signpost_id_t event_signpost_id,
                                      os_signpost_id_t interval_signpost_id,
                                      double gpuTime,
                                      double schedulingTime) {
  const SignpostTypes signpostType = getSignpostType(info.type);

  if (info.type == BaseInfo::Type::COPY) {
    updateCopyStats(static_cast<CopyInfo&>(info), gpuTime, schedulingTime);
  } else {
    // "a = a + b" (not "+=") since these accumulators are atomic doubles
    info.totalGpuTime = info.totalGpuTime + gpuTime;
    info.totalSchedulingTime = info.totalSchedulingTime + schedulingTime;
  }
  // if Kernel time is not included in metadata separately, we add it to gpuTime in metadata
  if (gpuTime > 0.0 && !(m_log_options & LogOptions::INCLUDE_KERNEL_TIME)) {
    gpuTime += schedulingTime;
    schedulingTime = 0;
  }
  const std::string& infoStr = info.toString(gpuTime, schedulingTime);
  // see comments in isProfileInfoLoggingEnabled()
  if (isProfileInfoLoggingEnabled(info.type, /*isExecutionEnded*/ true)) {
    fmt::print(stderr, "{}\n", infoStr);
  }
  // it is possible to use both interval and event based signposts at the same time
  if ((m_profile_options & ProfileOptions::USE_EVENTS) && event_signpost_id) {
    emitSignpostEvent(signpostType, event_signpost_id, infoStr);
  }
  // GPU time for signpost intervals is calculated based on its duration
  if ((m_profile_options & ProfileOptions::USE_INTERVALS) && interval_signpost_id) {
    endSignpostInterval(signpostType, interval_signpost_id);
  }
  info.completed = true;
}

// Starts profiling a graph or kernel identified by 'handle'; creates (and
// caches) its OperationInfo on first use. Returns the profile ID (0 if
// operation profiling is disabled).
uint64_t MPSProfiler::beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph) {
  // only do profiling if operation execution profiling or logging are enabled
  if (!isOperationProfilingEnabled()) {
    return 0;
  }
  if (m_op_info_list.count(uintptr_t(handle)) == 0) {
    auto opInfo =
        std::make_unique<OperationInfo>(handle, isGraph, isGraph ? ++m_graph_counter : ++m_kernel_counter, strKey);
    m_op_info_list.emplace(opInfo->handle, std::move(opInfo));
  }
  auto& opInfo = *m_op_info_list[uintptr_t(handle)];
  opInfo.strKey.assign(strKey);
  opInfo.runCount++;
  beginProfileExecution(opInfo);

  return opInfo.profileId;
}

// Overload that derives the profiler key from the kernel name and its tensors.
uint64_t MPSProfiler::beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors) {
  if (isOperationProfilingEnabled()) {
    const bool includeBufferId = m_log_options & LogOptions::INCLUDE_BUFFER_ID;
    std::string profilerStrKey = OperationInfo::buildKernelString(kernelName, tensors, includeBufferId);
    return beginProfileKernel(handle, profilerStrKey, false);
  }
  return 0;
}

void MPSProfiler::beginProfileGPUInterval(const void* handle) {
  // this function is only relevant for interval-based Signposts which exclude
  // schedule time (only includes GPU run time)
  if (!(m_profile_options & ProfileOptions::USE_INTERVALS) ||
      (m_profile_options & ProfileOptions::INCLUDE_SCHEDULE_INTERVAL)) {
    return;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_op_info_list.count(uintptr_t(handle)), "Failed to get operation information!");
  auto& opInfo = *m_op_info_list[uintptr_t(handle)];
  // this begins the interval when scheduling the execution is
  // completed already (i.e., scheduling excluded from interval)
  addProfilerScheduledHandler(opInfo);
}

void MPSProfiler::endProfileKernel(const void* handle, SyncType syncType) {
  // only do profiling if operation execution profiling or logging are enabled
  if (!isOperationProfilingEnabled()) {
    return;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_op_info_list.count(uintptr_t(handle)), "Failed to get operation information!");
  auto& opInfo = *m_op_info_list[uintptr_t(handle)];
  addProfilerCompletedHandler(opInfo, syncType);
}

// Starts profiling an op that fell back to the CPU; tracks its CPU time and
// the tensor-copy overhead incurred by the fallback.
uint64_t MPSProfiler::beginProfileCPUFallback(const std::string& opName, const TensorList& tensors) {
  if (m_cpu_fb_info_list.count(opName) == 0) {
    auto cpuFbInfo = std::make_unique<CpuFbInfo>(++m_cpu_fb_counter, opName);
    m_cpu_fb_info_list.emplace(opName, std::move(cpuFbInfo));
  }
  auto& cpuFbInfo = *m_cpu_fb_info_list[opName];
  cpuFbInfo.runCount++;
  cpuFbInfo.startTime = BaseInfo::getTime();
  const bool includeBufferId = m_log_options & LogOptions::INCLUDE_BUFFER_ID;
  cpuFbInfo.strKey = OperationInfo::buildKernelString(opName, tensors, includeBufferId);
  cpuFbInfo.updateCopyOverhead(tensors);
  beginProfileExecution(cpuFbInfo, true);

  return cpuFbInfo.profileId;
}

void MPSProfiler::endProfileCPUFallback(const std::string& opName) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_cpu_fb_info_list.count(opName), "Failed to get CPU Fallback information!");
  auto& cpuFbInfo = *m_cpu_fb_info_list[opName];
  // CPU time in ms (getTime() is assumed to be in nanoseconds here)
  double cpuTime = double(BaseInfo::getTime() - cpuFbInfo.startTime) * 1e-6;
  endProfileExecution(cpuFbInfo, cpuFbInfo.eventSignpostId, cpuFbInfo.intervalSignpostId, 0, cpuTime);
}

// Starts profiling a copy (blit-encoder based, or plain memcpy on Unified
// memory); returns the copy's profile ID (0 if copy profiling is disabled).
uint64_t MPSProfiler::beginProfileCopy(const void* srcBuffer,
                                       const void* dstBuffer,
                                       const OptionalTensorRef srcTensor,
                                       const OptionalTensorRef dstTensor,
                                       size_t length,
                                       bool isNonBlocking,
                                       bool usesBlitter) {
  if (!isCopyProfilingEnabled()) {
    return 0;
  }
  const bool includeBufferId = m_log_options & LogOptions::INCLUDE_BUFFER_ID;
  const uint64_t profileId = ++m_copy_counter;
  auto copyInfo = std::make_unique<CopyInfo>(dstBuffer, length, profileId, isNonBlocking, usesBlitter);
  copyInfo->srcStrKey = CopyInfo::buildTensorString(srcBuffer, srcTensor, includeBufferId);
  copyInfo->dstStrKey = CopyInfo::buildTensorString(dstBuffer, dstTensor, includeBufferId);
  copyInfo->kind = CopyInfo::getCopyKind(srcBuffer, dstBuffer, srcTensor, dstTensor);
  if (!usesBlitter) {
    // for copies that don't use blitters, we measure CPU time
    copyInfo->startTime = BaseInfo::getTime();
  }
  // don't generate signposts if the non-blocking copy is not using the blitter
  if (usesBlitter || !isNonBlocking) {
    beginProfileExecution(*copyInfo, !usesBlitter);
  }
  // this should not happen since we erase the copy info after profiling/logging it.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_copy_info_list.count(profileId) == 0);
  // the copy info isn't retained in the list, so we erase the completed ones
  for (auto it = m_copy_info_list.begin(), last = m_copy_info_list.end(); it != last;) {
    if (it->second->completed) {
      it = m_copy_info_list.erase(it);
    } else {
      ++it;
    }
  }
  m_copy_info_list.emplace(profileId, std::move(copyInfo));

  return profileId;
}

void MPSProfiler::endProfileCopy(uint64_t profileId, SyncType syncType) {
  // this is just an identifier, and not used to access memory
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_copy_info_list.count(profileId), "Failed to get copy information!");
  auto& copyInfo = *m_copy_info_list[profileId];
  if (copyInfo.usesBlitter) {
    addProfilerCompletedHandler(copyInfo, syncType);
  } else {
    double cpuTime = double(BaseInfo::getTime() - copyInfo.startTime) * 1e-6;
    endProfileExecution(copyInfo, copyInfo.eventSignpostId, copyInfo.intervalSignpostId, 0, cpuTime);
  }
}

// Registers a Metal scheduledHandler on the default stream's command buffer
// that begins the signpost interval once GPU scheduling has finished.
void MPSProfiler::addProfilerScheduledHandler(BaseInfo& info) {
  const SignpostTypes signpostType = getSignpostType(info.type);
  const os_signpost_id_t intervalSignpostId = info.intervalSignpostId;

  auto m_stream = getDefaultMPSStream();
  // NOTE: the following block isn't thread-safe
  [m_stream->commandBuffer() addScheduledHandler:^(id<MTLCommandBuffer> cb) {
    // begin the interval once scheduling has completed (if INCLUDE_SCHEDULE_INTERVAL flag is disabled)
    beginSignpostInterval(signpostType, intervalSignpostId, info.toString());
    info.completed = false;
  }];
}

// Accumulates aggregate statistics for a completed copy into the
// per-kind CopyStat bucket (no-op unless COPY_STATS logging is enabled).
void MPSProfiler::updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime) {
  if (!(m_log_options & LogOptions::COPY_STATS)) {
    return;
  }
  auto& copyStat = *m_copy_stat_list[copyInfo.kind];
  copyStat.totalCount++;
  copyStat.length += copyInfo.length;
  // "a = a + b" (not "+=") since these accumulators are atomic doubles
  copyStat.totalGpuTime = copyStat.totalGpuTime + gpuTime;
  copyStat.totalSchedulingTime = copyStat.totalSchedulingTime + schedulingTime;
  if (copyInfo.length <= sizeof(int64_t)) {
    copyStat.scalarsCount++;
    copyStat.scalarsGpuTime = copyStat.scalarsGpuTime + gpuTime;
  }
  copyStat.blockingCount += !copyInfo.isNonBlocking ? 1 : 0;
  copyStat.memcpyCount += !copyInfo.usesBlitter ? 1 : 0;
}

// Registers a Metal completedHandler that computes GPU/scheduling times from
// the command buffer's timestamps and finalizes the profiling of 'info',
// then commits the stream (waiting if WAIT_UNTIL_COMPLETED was requested).
void MPSProfiler::addProfilerCompletedHandler(BaseInfo& info, SyncType syncType) {
  const os_signpost_id_t intervalSignpostId = info.intervalSignpostId;
  const os_signpost_id_t eventSignpostId = info.eventSignpostId;

  // signpost ID is used only for interval-based signposts, and must be non-zero
  if (m_profile_options & ProfileOptions::USE_INTERVALS) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(intervalSignpostId, "Signpost interval has no identifier!");
  }
  // reset signpostIds for sanity check on next call
  info.intervalSignpostId = 0;
  info.eventSignpostId = 0;
  hasPendingCompletionHandlers = true;

  auto m_stream = getDefaultMPSStream();
  // NOTE: the following block isn't thread-safe
  [m_stream->commandBuffer() addCompletedHandler:^(id<MTLCommandBuffer> cb) {
    CFTimeInterval gpuTime = cb.GPUEndTime > cb.GPUStartTime ? (cb.GPUEndTime - cb.GPUStartTime) * 1000.0 : 0.;
    CFTimeInterval schedulingTime =
        cb.kernelEndTime > cb.kernelStartTime ? (cb.kernelEndTime - cb.kernelStartTime) * 1000.0 : 0.;

    endProfileExecution(info, eventSignpostId, intervalSignpostId, gpuTime, schedulingTime);
    hasPendingCompletionHandlers = false;
  }];

  m_stream->synchronize((m_profile_options & ProfileOptions::WAIT_UNTIL_COMPLETED) ? SyncType::COMMIT_AND_WAIT
                                                                                   : syncType);
}

// Prints the table of per-operation profiling statistics, sorted by mean
// GPU time in descending order.
void MPSProfiler::logOperationsProfilingStats(std::FILE* f) const {
  if (m_op_info_list.empty()) {
    // this is not an error, but to let the user know that the
    // LogOptions::KERNEL_STATS that they passed to EV is not yielding anything.
    fmt::print(f, "There are no MPS operations logged for profiling\n");
    return;
  }
  // dump the ops info into a vector to sort them
  std::vector<OperationInfo*> opsList;
  std::transform(m_op_info_list.begin(), m_op_info_list.end(), std::back_inserter(opsList), [](auto& opInfo) {
    return opInfo.second.get();
  });

  // sort based on "Mean GPU time" in descending order
  std::sort(opsList.begin(), opsList.end(), [](const OperationInfo* a, const OperationInfo* b) {
    return (a->totalGpuTime / double(a->runCount)) > (b->totalGpuTime / double(b->runCount));
  });
  // print the table of operation profiling stats
  fmt::print(f,
             "\n{:-^200}\n{:^6}|{:^7}|{:^15}|{:^14}|{:^15}| {}\n{:-^200}\n",
             fmt::format(" MPS Operations Profiling: {} graphs, {} kernels ", m_graph_counter, m_kernel_counter),
             "ID",
             "#Runs",
             "Mean KRNL(ms)",
             "Mean GPU(ms)",
             "Total GPU(ms)",
             "Operation Name",
             "");

  for (const auto& opInfo : opsList) {
    fmt::print(f,
               "{:^7}{:^8}{:^16}{:^15}{:^16} {}\n",
               fmt::format("{}{}", opInfo->type == BaseInfo::Type::GRAPH ? "G" : "K", opInfo->profileId),
               opInfo->runCount,
               fmt::format("{:.3f}", opInfo->totalSchedulingTime / double(opInfo->runCount)),
               fmt::format("{:.3f}", opInfo->totalGpuTime / double(opInfo->runCount)),
               fmt::format("{:.3f}", opInfo->totalGpuTime.load()),
               opInfo->strKey);
  }
}

// Prints the table of CPU fallback statistics, sorted by mean CPU time in
// descending order, with aggregate run/time/copy-overhead totals in the header.
void MPSProfiler::logCPUFallbackProfilingStats(std::FILE* f) const {
  if (m_cpu_fb_info_list.empty()) {
    // this is not an error, but to let the user know that the
    // LogOptions::KERNEL_STATS that they passed to EV is not yielding anything.
    fmt::print(f, "There are no CPU Fallbacks logged for profiling\n");
    return;
  }
  size_t totalCopyOverhead = 0;
  size_t totalRunCount = 0;
  double totalCPUTime = 0.;
  // dump the map's info into a vector to sort them
  std::vector<CpuFbInfo*> cpuFbList;
  std::transform(
      m_cpu_fb_info_list.begin(), m_cpu_fb_info_list.end(), std::back_inserter(cpuFbList), [&](auto& cpuFbInfo) {
        auto cpuFbInfoPtr = cpuFbInfo.second.get();
        totalRunCount += cpuFbInfoPtr->runCount;
        totalCopyOverhead += cpuFbInfoPtr->totalCopyOverhead;
        totalCPUTime += cpuFbInfoPtr->totalSchedulingTime;
        return cpuFbInfoPtr;
      });

  // sort based on "Mean CPU time" in descending order
  std::sort(cpuFbList.begin(), cpuFbList.end(), [](const CpuFbInfo* a, const CpuFbInfo* b) {
    return (a->totalSchedulingTime / double(a->runCount)) > (b->totalSchedulingTime / double(b->runCount));
  });

  // print the table of CPU Fallback profiling stats
  fmt::print(f,
             "\n{:-^150}\n{:^5}|{:^7}|{:^14}|{:^15}|{:^15}| {}\n{:-^150}\n",
             fmt::format(" CPU Fallback Profiling: Total {} Runs, {:.2f} ms, {} Copies ",
                         totalRunCount,
                         totalCPUTime,
                         getIMPSAllocator()->formatSize(totalCopyOverhead)),
             "ID",
             "#Runs",
             "Mean CPU(ms)",
             "Total CPU(ms)",
             "Copy Overhead",
             "Operation Name",
             "");

  for (const auto& cpuFbInfo : cpuFbList) {
    fmt::print(f,
               "{:^6}{:^8}{:^15}{:^16}{:^16} {}\n",
               cpuFbInfo->profileId,
               cpuFbInfo->runCount,
               fmt::format("{:.3f}", cpuFbInfo->totalSchedulingTime / double(cpuFbInfo->runCount)),
               fmt::format("{:.3f}", cpuFbInfo->totalSchedulingTime.load()),
               getIMPSAllocator()->formatSize(cpuFbInfo->totalCopyOverhead),
               cpuFbInfo->opName);
  }
}

// Prints per-kind copy statistics (MPS<->MPS/CPU), including scalar-copy
// percentages and blocking/memcpy counts.
void MPSProfiler::logCopyProfilingStats(std::FILE* f) const {
  size_t totalCopiesCount = 0;
  size_t totalCopySize = 0;
  size_t totalScalarCopyCount = 0;

  for (const auto& copyStatPair : m_copy_stat_list) {
    const auto& copyStat = *copyStatPair.second;
    totalCopiesCount += copyStat.totalCount;
    totalCopySize += copyStat.length;
    totalScalarCopyCount += copyStat.scalarsCount;
  }
  if (totalCopiesCount == 0) {
    // this is not an error, but to let the user know that the
    // LogOptions::COPY_STATS that they passed to EV is not yielding anything.
    fmt::print(f, "There are no copies logged for profiling\n");
    return;
  }

  // print the table of copy profiling stats
  fmt::print(f,
             "\n{:-^160}\n{:^12}|{:^10}|{:^17}|{:^16}|{:^15}|{:^9}|{:^13}|{:^10}|{:^8}\n{:-^160}\n",
             fmt::format(" MPS Copy Profiling: {} total copies ({}), {} scalar copies ",
                         totalCopiesCount,
                         getIMPSAllocator()->formatSize(totalCopySize),
                         totalScalarCopyCount),
             "Kind",
             "Total#",
             "Total Size",
             "Total KRNL(ms)",
             "Total GPU(ms)",
             "Scalars",
             "Scalars GPU",
             "Blocking",
             "memcpy",
             "");

  for (const auto& copyStatPair : m_copy_stat_list) {
    const auto& copyStat = *copyStatPair.second;
    if (copyStat.totalCount > 0) {
      fmt::print(
          f,
          "{:^13}{:^11}{:^18}{:^17}{:^16}{:^10}{:^14}{:^11}{:^9}\n",
          copyStat.kindStr,
          copyStat.totalCount,
          getIMPSAllocator()->formatSize(copyStat.length),
          fmt::format("{:.3f}", copyStat.totalSchedulingTime.load()),
          fmt::format("{:.3f}", copyStat.totalGpuTime.load()),
          copyStat.scalarsCount,
          fmt::format("{:.2f} %",
                      copyStat.totalGpuTime > 0.0
                          ? (1.0 - ((copyStat.totalGpuTime - copyStat.scalarsGpuTime) / copyStat.totalGpuTime)) * 100.0
                          : 0.0),
          copyStat.blockingCount,
          copyStat.memcpyCount);
    }
  }
}

// Dumps all enabled stats tables to stderr; guarded so it runs at most once
// (called from the destructor and from the SIGINT handler).
void MPSProfiler::logProfilingStats() {
  if (hasLoggedStats.exchange(true)) {
    return;
  }
  // logs kernel profiling stats when the process ends (if enabled).
  if (m_log_options & LogOptions::OPERATION_STATS) {
    logOperationsProfilingStats(stderr);
  }
  // logs CPU Fallback profiling stats when the process ends (if enabled).
  if (m_log_options & LogOptions::CPU_FALLBACK_STATS) {
    logCPUFallbackProfilingStats(stderr);
  }
  // logs copies profiling stats when the process ends (if enabled).
  if (m_log_options & LogOptions::COPY_STATS) {
    logCopyProfilingStats(stderr);
  }
}

// Decides whether the info string for 'infoType' should be printed at this
// point of execution (begin vs. end), based on the LogOptions mask.
bool MPSProfiler::isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded) {
  bool isInfoLoggingEnabled = false;
  // logging the operations, copies, cpu fallbacks info during the execution
  // is enabled via the env-var defined in kEVLogProfileInfoStr
  switch (infoType) {
    case BaseInfo::Type::GRAPH:
    case BaseInfo::Type::KERNEL:
      isInfoLoggingEnabled = (m_log_options & LogOptions::OPERATION_INFO);
      break;
    case BaseInfo::Type::COPY:
      isInfoLoggingEnabled = (m_log_options & LogOptions::COPY_INFO);
      break;
    case BaseInfo::Type::CPU_FALLBACK:
      isInfoLoggingEnabled = (m_log_options & LogOptions::CPU_FALLBACK_INFO);
      break;
    default:
      AT_ERROR("invalid profiling info type");
  }
  if (!isInfoLoggingEnabled) {
    return false;
  }
  // if GPU/Kernel times are included then log info when op execution ends
  bool logWhenExecutionEnds = m_log_options & (LogOptions::INCLUDE_GPU_TIME | LogOptions::INCLUDE_KERNEL_TIME);
  return isExecutionEnded ? logWhenExecutionEnds : !logWhenExecutionEnds;
}

void MPSProfiler::emitSignpostEvent(SignpostTypes signpost_type,
                                    os_signpost_id_t signpost_id,
                                    const std::string& msg_str) const {
  if (!(m_signpost_types & signpost_type) || !signpost_id || !m_os_log_events ||
      !os_signpost_enabled(m_os_log_events)) {
    return;
  }
  const char* msg = msg_str.c_str();

  // need to use switch-case as the signpost names must be literal strings
  switch (signpost_type) {
    case SignpostTypes::RUN_OPERATION:
      os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostRunOperationStr, "%s", msg);
      break;
    case SignpostTypes::BLIT_COPY:
      os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostBlitCopyStr, "%s", msg);
      break;
    case SignpostTypes::CPU_FALLBACK:
      os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostCPUFallbacksStr, "%s", msg);
      break;
    default:
      AT_ERROR("unknown SignpostType in MPS profiler");
  }
}

void MPSProfiler::beginSignpostInterval(SignpostTypes signpost_type,
                                        os_signpost_id_t signpost_id,
                                        const std::string& msg_str) const {
  if (!(m_signpost_types & signpost_type) || !signpost_id || !m_os_log_intervals ||
      !os_signpost_enabled(m_os_log_intervals)) {
    return;
  }
  const char* msg = msg_str.c_str();

  switch (signpost_type) {
    case SignpostTypes::RUN_OPERATION:
      os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostRunOperationStr, "%s", msg);
      break;
    case SignpostTypes::BLIT_COPY:
      os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostBlitCopyStr, "%s", msg);
      break;
    case SignpostTypes::CPU_FALLBACK:
      os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostCPUFallbacksStr, "%s", msg);
      break;
    default:
      AT_ERROR("unknown SignpostType in MPS profiler");
  }
}

void MPSProfiler::endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const {
  if (!m_os_log_intervals || !os_signpost_enabled(m_os_log_intervals)) {
    return;
  }
  switch (signpost_type) {
    case SignpostTypes::RUN_OPERATION:
      os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostRunOperationStr);
      break;
    case SignpostTypes::BLIT_COPY:
      os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostBlitCopyStr);
      break;
    case SignpostTypes::CPU_FALLBACK:
      os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostCPUFallbacksStr);
      break;
    default:
      AT_ERROR("unknown SignpostType in MPS profiler");
  }
}

// Generates a signpost ID from the pointer when provided; otherwise a unique
// one from the corresponding os_log handle (events vs. intervals).
os_signpost_id_t MPSProfiler::generateSignpostId(os_signpost_type_t signpostType, const void* ptr) {
  os_log_t os_log = signpostType == OS_SIGNPOST_EVENT ? m_os_log_events : m_os_log_intervals;
  if (ptr) {
    return os_signpost_id_make_with_pointer(os_log, ptr);
  }
  return os_signpost_id_generate(os_log);
}

MPSProfiler::SignpostTypes MPSProfiler::getSignpostType(BaseInfo::Type infoType) {
  switch (infoType) {
    case BaseInfo::Type::GRAPH:
    case BaseInfo::Type::KERNEL:
      return SignpostTypes::RUN_OPERATION;
    case BaseInfo::Type::COPY:
      return SignpostTypes::BLIT_COPY;
    case BaseInfo::Type::CPU_FALLBACK:
      return SignpostTypes::CPU_FALLBACK;
    default:
      AT_ERROR("invalid profiling info type");
  }
}

// SIGINT handler: dumps the stats, then chains to the previously-installed
// handler, if any.
// NOTE(review): logProfilingStats() calls fmt::print and may allocate, which
// is not async-signal-safe; acceptable here only because the process is
// being interrupted anyway — confirm this trade-off is intended.
void MPSProfiler::handleIntSignal(int signal) {
  getMPSProfiler().logProfilingStats();
  if (previousSigint.sa_handler) {
    previousSigint.sa_handler(signal);
  }
}

// used to capture sigint signal to log profiling stats
struct sigaction MPSProfiler::currentSigint {};
struct sigaction MPSProfiler::previousSigint {};

bool MPSProfiler::isCapturing() const {
  return [captureManager isCapturing];
}

bool MPSProfiler::isCaptureEnabled() const {
  if (captureManager == nil) {
    captureManager = [MTLCaptureManager sharedCaptureManager];
  }
  // cached once: destination support cannot change during the process lifetime
  static bool isEnabled = [this]() {
    return [captureManager supportsDestination:MTLCaptureDestinationGPUTraceDocument];
  }();
  return isEnabled;
}

// Starts a Metal GPU-trace capture into "NNNN-<name>.gputrace" in the
// current directory, scoped to the stream's command queue (or the whole
// device when no stream is given).
void MPSProfiler::startCapture(const std::string& name, MPSStream* stream) {
  if (captureManager == nil) {
    captureManager = [MTLCaptureManager sharedCaptureManager];
  }
  NSError* err = nil;
  NSString* fname = [NSString stringWithFormat:@"%04d-%s.gputrace", captureCount++, name.c_str()];
  MTLCaptureDescriptor* captureDescriptor = [MTLCaptureDescriptor new];
  captureDescriptor.captureObject = stream ? (id)stream->commandQueue() : (id)MPSDevice::getInstance()->device();
  captureDescriptor.destination = MTLCaptureDestinationGPUTraceDocument;
  captureDescriptor.outputURL = [NSURL fileURLWithPath:fname];
  auto rc = [captureManager startCaptureWithDescriptor:captureDescriptor error:&err];
  TORCH_CHECK(rc, "Failed to start capture of ", [fname UTF8String], " error ", [[err description] UTF8String]);
}

void MPSProfiler::stopCapture(MPSStream* stream) {
  if (stream) {
    stream->synchronize(SyncType::COMMIT);
  }
  [captureManager stopCapture];
}

} // namespace Profiler

// Returns the process-wide profiler instance. Construction happens inside
// the function-local static initializer, so it is thread-safe under C++11
// magic statics (the previous lazy nullptr-check could race on first use).
Profiler::MPSProfiler& getMPSProfiler() {
  static std::unique_ptr<Profiler::MPSProfiler> mps_profiler = std::make_unique<Profiler::MPSProfiler>();
  return *mps_profiler;
}

} // namespace at::mps