#include <torch/csrc/autograd/function.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/util.h>

#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <fmt/ranges.h>

#ifdef USE_KINETO
#include <libkineto.h>
#endif
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_DISTRIBUTED

namespace torch::profiler::impl {

namespace {
std::optional<bool> soft_assert_raises_;
} // namespace

void setSoftAssertRaises(std::optional<bool> value) {
  soft_assert_raises_ = value;
}

bool softAssertRaises() {
  return soft_assert_raises_.value_or(false);
}

void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const char* args) {
#ifdef USE_KINETO
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}

void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const std::string& args) {
#ifdef USE_KINETO
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}

// ----------------------------------------------------------------------------
// -- NVTX --------------------------------------------------------------------
// ----------------------------------------------------------------------------
std::string getNvtxStr(
    const char* name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes,
    at::RecordFunctionHandle op_id,
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  if (sequence_nr >= -1 || !shapes.empty()) {
    std::string str;
    if (sequence_nr >= 0) {
      str = fmt::format("{}, seq = {}", name, sequence_nr);
    } else if (sequence_nr == -1) {
      str = name;
    } else {
#if defined(USE_ROCM)
      // Only ROCm supports sequence_nr < -1.
      str = name;
#endif
    }
    if (op_id > 0) {
      str = fmt::format("{}, op_id = {}", str, op_id);
    }
    if (!shapes.empty()) {
      str = fmt::format("{}, sizes = {}", str, shapesToStr(shapes));
    }
    // Include the op ids of the input edges so that the network graph
    // can be reconstructed from the trace.
    if (!input_op_ids.empty()) {
      str = fmt::format(
          "{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids));
    }
    return str;
  } else {
    return name;
  }
}
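
// Illustrative example (hypothetical values): for name = "aten::mm",
// sequence_nr = 42, op_id = 7, shapes = {{64, 128}, {128, 32}} and a single
// input edge (5, 0), getNvtxStr() yields:
//   "aten::mm, seq = 42, op_id = 7, sizes = [[64, 128], [128, 32]], input_op_ids = [(5,0)]"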

// ----------------------------------------------------------------------------
// -- Op context (shapes, call stack) -----------------------------------------
// ----------------------------------------------------------------------------
std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs) {
  std::vector<FileLineFunc> entries;
  entries.reserve(cs.size());
  for (const auto& entry : cs) {
    auto& range = entry.range;
    if (range.source()) {
      auto& src = range.source();
      if (src && src->filename()) {
        auto line =
            src->starting_line_no() + src->lineno_for_offset(range.start());
        entries.emplace_back(
            // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
            FileLineFunc{*(src->filename()), line, entry.filename});
      }
    }
  }
  return entries;
}

std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
  std::vector<std::string> cs_str;
  cs_str.reserve(cs.size());
  for (const auto& entry : cs) {
    std::stringstream loc;
    loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
    cs_str.push_back(loc.str());
  }
  return cs_str;
}
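
// Illustrative example (hypothetical frame): a FileLineFunc of
// {"model.py", 42, "forward"} is rendered as "model.py(42): forward".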

std::string stacksToStr(
    const std::vector<std::string>& stacks,
    const char* delim) {
  std::ostringstream oss;
  std::transform(
      stacks.begin(),
      stacks.end(),
      std::ostream_iterator<std::string>(oss, delim),
      [](std::string s) -> std::string {
#ifdef _WIN32
        // Replace the Windows backslash with a forward slash.
        std::replace(s.begin(), s.end(), '\\', '/');
#endif
        return s;
      });
  auto rc = oss.str();
  return "\"" + rc + "\"";
}
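
// Illustrative example: stacksToStr({"a.py(1): f", "b.py(2): g"}, ";") yields
// "\"a.py(1): f;b.py(2): g;\"" (note the delimiter also trails the last
// entry, since std::ostream_iterator appends it after every element).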

static std::vector<std::vector<int64_t>> flattenList(
    const c10::List<c10::IValue>& list) {
  std::vector<std::vector<int64_t>> tensor_dims;
  for (const c10::IValue& input : list) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        tensor_dims.push_back(input.toTensor().sizes().vec());
      }
    }
  }
  return tensor_dims;
}

std::vector<std::vector<int64_t>> inputSizes(
    const at::RecordFunction& fn,
    bool flatten_list_enabled) {
  std::vector<std::vector<int64_t>> sizes;
  sizes.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        sizes.push_back(input.toTensor().sizes().vec());
      } else {
        sizes.emplace_back();
      }
    } else if (input.isList()) {
      std::vector<std::vector<int64_t>> tmp_sizes;
      if (flatten_list_enabled) {
        tmp_sizes = flattenList(input.toList());
      }
      // Extend sizes with the shapes of the tensors found inside the list.
      if (!tmp_sizes.empty()) {
        sizes.insert(sizes.end(), tmp_sizes.begin(), tmp_sizes.end());
      } else {
        sizes.emplace_back();
      }
    } else {
      sizes.emplace_back();
    }
  }
  return sizes;
}

std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
  std::string str("[");
  for (const auto t_idx : c10::irange(shapes.size())) {
    if (t_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}{}", str, shapeToStr(shapes[t_idx]));
  }
  str = fmt::format("{}]", str);
  return str;
}
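
// Illustrative example: shapesToStr({{2, 3}, {4}}) yields "[[2, 3], [4]]".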

std::string variantShapesToStr(const std::vector<shape>& shapes) {
  std::string str("[");
  for (const auto t_idx : c10::irange(shapes.size())) {
    if (t_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    if (std::holds_alternative<std::vector<int64_t>>(shapes[t_idx])) {
      const auto& shape = std::get<std::vector<int64_t>>(shapes[t_idx]);
      str = fmt::format("{}{}", str, shapeToStr(shape));
    } else if (std::holds_alternative<std::vector<std::vector<int64_t>>>(
                   shapes[t_idx])) {
      const auto& tensor_shape =
          std::get<std::vector<std::vector<int64_t>>>(shapes[t_idx]);
      if (tensor_shape.size() > TENSOR_LIST_DISPLAY_LENGTH_LIMIT) {
        // Skip the contents if the tensor list is too long.
        str = fmt::format("{}[]", str);
        continue;
      }
      str = fmt::format("{}[", str);
      for (const auto s_idx : c10::irange(tensor_shape.size())) {
        if (s_idx > 0) {
          str = fmt::format("{}, ", str);
        }
        str = fmt::format("{}{}", str, shapeToStr(tensor_shape[s_idx]));
      }
      str = fmt::format("{}]", str);
    }
  }
  str = fmt::format("{}]", str);
  return str;
}

std::string shapeToStr(const std::vector<int64_t>& shape) {
  std::string str("[");
  for (const auto s_idx : c10::irange(shape.size())) {
    if (s_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}{}", str, shape[s_idx]);
  }
  str = fmt::format("{}]", str);
  return str;
}

std::string inputOpIdsToStr(
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  std::string str("[");
  int idx = 0;

  for (const auto& op_id_info_pair : input_op_ids) {
    if (idx++ > 0) {
      str = fmt::format("{}, ", str);
    }
    // Format each pair as (OpId,OutputNr).
    str = fmt::format(
        "{}({},{})", str, op_id_info_pair.first, op_id_info_pair.second);
  }
  str = fmt::format("{}]", str);
  return str;
}
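
// Illustrative example: two input edges {(5, 0), (6, 1)} are rendered as
// "[(5,0), (6,1)]".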

std::string strListToStr(const std::vector<std::string>& types) {
  if (types.empty()) {
    return "[]";
  } else {
    std::ostringstream oss;
    std::transform(
        types.begin(),
        types.end(),
        std::ostream_iterator<std::string>(oss, ", "),
        [](const std::string& s) -> std::string { return "\"" + s + "\""; });
    auto rc = oss.str();
    rc.erase(rc.length() - 2); // remove the trailing ", "
    return "[" + rc + "]";
  }
}
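
// Illustrative example: strListToStr({"float", "int"}) yields
// "[\"float\", \"int\"]".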

std::string ivalueToStr(const c10::IValue& val, bool isString) {
  std::stringstream ss;
  if (val.isNone()) {
    return "\"None\"";
  } else {
    ss.str("");
    if (isString) {
      ss << "\"";
    }
    ss << val;
    if (isString) {
      ss << "\"";
    }
    std::string mystr = ss.str();

    // Stray double quotes can break Chrome trace parsing, so reject any value
    // that contains more than the two quotes this function itself adds.
    int count = std::count(mystr.begin(), mystr.end(), '\"');
    return count > 2 ? "\"None\"" : mystr;
  }
}
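
// Illustrative example: a string IValue holding hello with isString = true is
// rendered as "\"hello\""; a None IValue is rendered as "\"None\"".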

std::string ivalueListToStr(const std::vector<c10::IValue>& list) {
  std::vector<std::string> concrete_str_inputs;
  std::stringstream ss;
  for (const auto& val : list) {
    if (val.isNone()) {
      concrete_str_inputs.emplace_back("");
    } else {
      ss.str("");
      ss << val;
      concrete_str_inputs.emplace_back(ss.str());
    }
  }
  return strListToStr(concrete_str_inputs);
}

std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
  std::vector<std::string> types;
  types.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        types.push_back(
            static_cast<std::string>(input.toTensor().dtype().name()));
      } else {
        types.emplace_back();
      }
    } else if (input.isScalar() || input.isList()) {
      types.push_back(input.tagKind());
    } else {
      types.emplace_back();
    }
  }
  return types;
}
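
// Illustrative example (hypothetical inputs): for an op called with a float
// tensor and an integer scalar, this might return {"float", "Int"} -- the
// dtype name for tensors, and the IValue tag kind for scalars and lists.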

// ----------------------------------------------------------------------------
// -- NCCL Metadata -----------------------------------------------------------
// ----------------------------------------------------------------------------

static constexpr int32_t kTruncateLength = 30;

template <typename ListLikeType>
inline std::string format_list(ListLikeType list, bool truncate) {
  if (truncate && list.size() > kTruncateLength) {
    return fmt::format(
        "\"[{}, ...]\"",
        fmt::join(list.begin(), list.begin() + kTruncateLength, ", "));
  }
  return fmt::format("\"[{}]\"", fmt::join(list.begin(), list.end(), ", "));
}
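
// Illustrative example: with truncate = true, a list of 100 ranks is rendered
// with only its first kTruncateLength entries followed by ", ...", e.g.
// "\"[0, 1, ..., 29, ...]\""; shorter lists are printed in full.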

std::unordered_map<std::string, std::string> saveNcclMeta(
    const at::RecordFunction& fn,
    bool truncate) {
  std::unordered_map<std::string, std::string> map;
#ifdef USE_DISTRIBUTED
  auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
  if (debugInfo == nullptr) {
    LOG(WARNING) << "ParamCommsDebugInfo not available for function: "
                 << fn.name();
    return map;
  }

  auto& collective_name = debugInfo->getCollectiveName();
  map.emplace(kCommsName, fmt::format("\"{}\"", collective_name));
  map.emplace(
      kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType())));
  map.emplace(kInMsgNelems, std::to_string(debugInfo->getInMessageNelems()));
  map.emplace(kOutMsgNelems, std::to_string(debugInfo->getOutMessageNelems()));

  auto& inSplitSizes = debugInfo->getInputSplitSizes();
  map.emplace(kInSplit, format_list(inSplitSizes, truncate));

  auto& outSplitSizes = debugInfo->getOutputSplitSizes();
  map.emplace(kOutSplit, format_list(outSplitSizes, truncate));

  auto globalRankStart = debugInfo->getGlobalRankStart();
  if (globalRankStart >= 0) {
    map.emplace(kGlobalRankStart, std::to_string(globalRankStart));
  }
  auto globalRankStride = debugInfo->getGlobalRankStride();
  if (globalRankStride > 0) {
    map.emplace(kGlobalRankStride, std::to_string(globalRankStride));
  }
  map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
  auto& group_name = debugInfo->getProcessGroupName();
  if (!group_name.empty()) {
    map.emplace(kProcessGroupName, fmt::format("\"{}\"", group_name));
  }
  auto& group_desc = debugInfo->getProcessGroupDesc();
  if (!group_desc.empty()) {
    map.emplace(kProcessGroupDesc, fmt::format("\"{}\"", group_desc));
  }
  auto& groupRanks = debugInfo->getGroupRanks();
  map.emplace(kGroupRanks, format_list(groupRanks, truncate));

  auto rank = debugInfo->getRank();
  map.emplace(kRank, std::to_string(rank));
  int nRanks = static_cast<int>(groupRanks.size());
  if (collective_name == "send") {
    if (rank >= 0 && rank < nRanks) {
      map.emplace(kP2pDst, std::to_string(groupRanks[rank]));
    }
  } else if (collective_name == "recv") {
    if (rank >= 0 && rank < nRanks) {
      map.emplace(kP2pSrc, std::to_string(groupRanks[rank]));
    }
  }
#endif // USE_DISTRIBUTED
  return map;
}
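
// Illustrative example (hypothetical all-reduce on a 4-rank group): the map
// might hold entries such as kCommsName -> "\"allreduce\"",
// kDtype -> "\"Float\"", kInMsgNelems -> "1048576", kGroupSize -> "4",
// kRank -> "0", and kGroupRanks -> "\"[0, 1, 2, 3]\"".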

// ----------------------------------------------------------------------------
// -- FLOPS -------------------------------------------------------------------
// ----------------------------------------------------------------------------
static constexpr auto kConv2dStride = 3;
static constexpr auto kConv2dPadding = 4;
static constexpr auto kConv2dDilation = 5;
static constexpr auto kConv2dGroups = 6;

// List of supported operators
static constexpr auto kConv2dOp = "aten::conv2d";
static constexpr auto kMMOp = "aten::mm";
static constexpr auto kAddMMOp = "aten::addmm";
static constexpr auto kMulOp = "aten::mul";
static constexpr auto kAddOp = "aten::add";
static constexpr auto kBMMOp = "aten::bmm";
static constexpr auto kBAddBMMOp = "aten::baddbmm";

static constexpr auto kInputSize = "input_size";
static constexpr auto kWeightSize = "weight_size";
static constexpr auto kGroups = "groups";
static constexpr auto kPadding = "padding";
static constexpr auto kStride = "stride";
static constexpr auto kDilation = "dilation";
static constexpr auto kMatSize = "mat_size";
static constexpr auto kMat1Size = "mat1_size";
static constexpr auto kMat2Size = "mat2_size";

static std::vector<c10::IntArrayRef> getInputSizes(
    const std::string& op_name,
    size_t min_size,
    c10::ArrayRef<const c10::IValue> inputs,
    const c10::ArrayRef<int>& should_be_tensor) {
  std::stringstream ss;
  if (inputs.size() < min_size) {
    ss << "Failed to save extra arguments for flops computation of op "
       << op_name << ", min size: " << min_size
       << ", actual size: " << inputs.size();
    TORCH_WARN(ss.str());
    return {};
  }
  std::vector<c10::IntArrayRef> inputSizes = {};
  for (auto index : should_be_tensor) {
    if (!inputs[index].isTensor()) {
      ss << "Failed to save extra arguments for flops computation of op "
         << op_name << ", input[" << index << "] must be a tensor.";
      TORCH_WARN(ss.str());
      return {};
    }
    at::Tensor t = inputs[index].toTensor();
    if (t.is_nested()) {
      ss << "Failed to save extra arguments for flops computation of op "
         << op_name << " with input[" << index << "] as nested tensor.";
      TORCH_WARN(ss.str());
      return {};
    }
    inputSizes.emplace_back(t.sizes());
  }
  return inputSizes;
}

std::unordered_map<std::string, c10::IValue> saveExtraArgs(
    const at::RecordFunction& fn) {
  // For supported op types, save the extra arguments needed to compute flops
  // later.
  std::unordered_map<std::string, c10::IValue> map;
  auto inputs = fn.inputs();
  std::string fname(fn.name());

  if (inputs.empty()) {
    // Input shape is unavailable, return an empty map.
    return map;
  }

  if (fname == kConv2dOp) {
    const auto inputSizes =
        getInputSizes(fname, kConv2dGroups + 1, inputs, {0, 1});
    if (inputSizes.empty()) {
      return map;
    }
    if (inputSizes[1].size() != 4) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
      return map;
    }
    map[kInputSize] = at::IValue(inputSizes[0]);
    map[kWeightSize] = at::IValue(inputSizes[1]);
    map[kStride] = inputs[kConv2dStride];
    map[kPadding] = inputs[kConv2dPadding];
    map[kDilation] = inputs[kConv2dDilation];
    map[kGroups] = inputs[kConv2dGroups];
  } else if (fname == kMMOp) {
    const auto inputSizes = getInputSizes(fname, 2, inputs, {0, 1});
    if (inputSizes.empty()) {
      return map;
    }

    map[kMat1Size] = at::IValue(inputSizes[0]);
    map[kMat2Size] = at::IValue(inputSizes[1]);
  } else if (fname == kAddMMOp) {
    const auto inputSizes = getInputSizes(fname, 3, inputs, {0, 1, 2});
    if (inputSizes.empty()) {
      return map;
    }
    // The exact FLOP count depends on the scaling factors alpha and beta, but
    // we assume alpha = beta = 1
    // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
    // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM).
    map[kMat1Size] = at::IValue(inputSizes[1]);
    map[kMat2Size] = at::IValue(inputSizes[2]);
  } else if (fname == kMulOp) {
    const auto inputSizes = getInputSizes(fname, 1, inputs, {0});
    if (inputSizes.empty()) {
      return map;
    }
    map[kMatSize] = at::IValue(inputSizes[0]);
  } else if (fname == kAddOp) {
    const auto inputSizes = getInputSizes(fname, 1, inputs, {0});
    if (inputSizes.empty()) {
      return map;
    }
    map[kMatSize] = at::IValue(inputSizes[0]);
  } else if (fname == kBMMOp) {
    const auto inputSizes = getInputSizes(fname, 2, inputs, {0, 1});
    if (inputSizes.empty()) {
      return map;
    }

    map[kMat1Size] = at::IValue(inputSizes[0]);
    map[kMat2Size] = at::IValue(inputSizes[1]);
  } else if (fname == kBAddBMMOp) {
    const auto inputSizes = getInputSizes(fname, 3, inputs, {0, 1, 2});
    if (inputSizes.empty()) {
      return map;
    }

    // The exact FLOP count depends on the scaling factors alpha and beta, but
    // we assume alpha = beta = 1
    // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
    // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM).
    map[kMat1Size] = at::IValue(inputSizes[1]);
    map[kMat2Size] = at::IValue(inputSizes[2]);
  }

  return map;
}
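
// Illustrative example: for aten::mm called with tensors of sizes [64, 128]
// and [128, 32], saveExtraArgs() returns a map holding
// kMat1Size -> [64, 128] and kMat2Size -> [128, 32].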

uint64_t computeFlops(
    const std::string& op_name,
    const std::unordered_map<std::string, c10::IValue>& extra_args) {
  if (op_name == kConv2dOp) {
    if (extra_args.find(kInputSize) == extra_args.end() ||
        extra_args.find(kWeightSize) == extra_args.end() ||
        extra_args.find(kGroups) == extra_args.end() ||
        extra_args.find(kPadding) == extra_args.end() ||
        extra_args.find(kStride) == extra_args.end() ||
        extra_args.find(kDilation) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments.");
      return 0;
    }
    auto input_sizes_ref = extra_args.at(kInputSize);
    auto kernel_sizes_ref = extra_args.at(kWeightSize);
    auto groups_ref = extra_args.at(kGroups);
    auto padding_ref = extra_args.at(kPadding);
    auto stride_ref = extra_args.at(kStride);
    auto dilation_ref = extra_args.at(kDilation);
    if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes.");
      return 0;
    }
    if (!padding_ref.isIntList() || !stride_ref.isIntList() ||
        !dilation_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values.");
      return 0;
    }

    const auto input_sizes = input_sizes_ref.toDimVector();
    const auto kernel_sizes = kernel_sizes_ref.toDimVector();
    const uint64_t groups = groups_ref.toInt();
    const std::vector<int64_t> padding = padding_ref.toIntVector();
    const std::vector<int64_t> stride = stride_ref.toIntVector();
    const std::vector<int64_t> dilation = dilation_ref.toIntVector();
    if (input_sizes.size() != 4 || kernel_sizes.size() != 4) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because both input and weight must be size 4.");
      return 0;
    }
    if (!groups) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because group size must not be 0.");
      return 0;
    }
    if (padding.size() != 2 || dilation.size() != 2) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2.");
      return 0;
    }
    if (stride.size() != 2 || (stride[0] * stride[1] == 0)) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0.");
      return 0;
    }
    // The format of the input is defined in
    // torch.ao.nn.quantized.functional.conv2d().
    const uint64_t conv2d_multiply_factor = 2;
    auto [minibatch, in_channels, input_h, input_w] = std::make_tuple(
        input_sizes[0], input_sizes[1], input_sizes[2], input_sizes[3]);
    auto [out_channels, _, kernel_h, kernel_w] = std::make_tuple(
        kernel_sizes[0], kernel_sizes[1], kernel_sizes[2], kernel_sizes[3]);
    uint64_t output_h =
        (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) /
            stride[0] +
        1;
    uint64_t output_w =
        (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) /
            stride[1] +
        1;

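    // Worked example (illustrative values): input [1, 3, 224, 224], weight
    // [64, 3, 7, 7], stride {2, 2}, padding {3, 3}, dilation {1, 1}, and
    // groups = 1 gives output_h = output_w = (224 + 6 - 6 - 1) / 2 + 1 = 112,
    // so flops = 2 * 1 * 112 * 112 * 7 * 7 * 3 * 64 / 1 = 236027904 (~2.4e8).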
    return conv2d_multiply_factor * minibatch * output_h * output_w *
        kernel_h * kernel_w * in_channels * out_channels / groups;
  } else if (op_name == kMMOp || op_name == kAddMMOp) {
    if (extra_args.find(kMat1Size) == extra_args.end() ||
        extra_args.find(kMat2Size) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for ",
          op_name,
          " requires mat1_size and mat2_size in saved arguments.");
      return 0;
    }
    auto mat1_sizes_ref = extra_args.at(kMat1Size);
    auto mat2_sizes_ref = extra_args.at(kMat2Size);
    if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op ",
          op_name,
          " because it requires mat1_size and mat2_size to be IntList.");
      return 0;
    }

    const auto mat1_size = mat1_sizes_ref.toDimVector();
    const auto mat2_size = mat2_sizes_ref.toDimVector();
    if (mat1_size.empty()) {
      return 0;
    }

    int64_t overlap_dim = mat1_size.back();
    if (overlap_dim == 0) {
      return 0;
    }

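    // For mat1 of size [M, K] and mat2 of size [K, N] this computes the
    // standard GEMM count flops = 2 * M * K * N (illustrative: [64, 128] x
    // [128, 32] gives 2 * 64 * 128 * 32 = 524288).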
    const uint64_t gemm_multiply_factor = 2;
    uint64_t flops = 1;
    for (int64_t dim : mat1_size) {
      flops *= dim;
    }
    flops /= overlap_dim;
    for (int64_t dim : mat2_size) {
      flops *= dim;
    }
    flops *= gemm_multiply_factor;
    return flops;
  } else if (op_name == kBMMOp || op_name == kBAddBMMOp) {
    if (extra_args.find(kMat1Size) == extra_args.end() ||
        extra_args.find(kMat2Size) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for ",
          op_name,
          " requires mat1_size and mat2_size in saved arguments.");
      return 0;
    }
    auto mat1_sizes_ref = extra_args.at(kMat1Size);
    auto mat2_sizes_ref = extra_args.at(kMat2Size);
    if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op ",
          op_name,
          " because it requires mat1_size and mat2_size to be IntList.");
      return 0;
    }

    const auto mat1_size = mat1_sizes_ref.toDimVector();
    const auto mat2_size = mat2_sizes_ref.toDimVector();
    if (mat1_size.empty()) {
      return 0;
    }

    int64_t batch_size = mat1_size.front();
    if (batch_size == 0) {
      return 0;
    }

    int64_t overlap_dim = mat1_size.back();
    if (overlap_dim == 0) {
      return 0;
    }

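    // For batched inputs mat1 [B, M, K] and mat2 [B, K, N] this computes
    // flops = 2 * B * M * K * N: mat2's product already contributes the batch
    // dimension, so mat1's batch and overlap dims are divided out below.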
    const uint64_t gemm_multiply_factor = 2;
    uint64_t flops = 1;
    for (int64_t dim : mat1_size) {
      flops *= dim;
    }
    flops /= overlap_dim;
    flops /= batch_size;
    for (int64_t dim : mat2_size) {
      flops *= dim;
    }
    flops *= gemm_multiply_factor;
    return flops;
  } else if (op_name == kMulOp) {
    if (extra_args.find(kMatSize) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::mul.Tensor requires mat_size in saved arguments.");
      return 0;
    }
    auto mat_sizes = extra_args.at(kMatSize);
    if (!mat_sizes.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::mul because it requires mat_size to be IntList.");
      return 0;
    }

    const auto mat_size = mat_sizes.toDimVector();
    uint64_t flops = 1;
    for (int64_t dim : mat_size) {
      flops *= dim;
    }
    return flops;
  } else if (op_name == kAddOp) {
    if (extra_args.find(kMatSize) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::add.Tensor requires mat_size in saved arguments.");
      return 0;
    }
    auto mat_sizes = extra_args.at(kMatSize);
    if (!mat_sizes.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::add because it requires mat_size to be IntList.");
      return 0;
    }

    const auto mat_size = mat_sizes.toDimVector();
    uint64_t flops = 1;
    for (int64_t dim : mat_size) {
      flops *= dim;
    }
    return flops;
  }
  return 0;
}

} // namespace torch::profiler::impl