xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/annotate_warns.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/annotate_warns.h>
2 
3 #include <atomic>
4 
5 namespace torch::jit {
6 
AnnotateWarns(Block * b)7 static void AnnotateWarns(Block* b) {
8   static std::atomic<int64_t> idx(0);
9   for (Node* n : b->nodes()) {
10     for (Block* child_b : n->blocks()) {
11       AnnotateWarns(child_b);
12     }
13 
14     if (n->kind() != aten::warn) {
15       continue;
16     }
17 
18     n->i_(attr::warn_id, idx);
19     idx++;
20   }
21 }
22 
AnnotateWarns(const std::shared_ptr<Graph> & graph)23 void AnnotateWarns(const std::shared_ptr<Graph>& graph) {
24   AnnotateWarns(graph->block());
25 }
26 
27 } // namespace torch::jit
28