xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/unwind/fde.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/util/irange.h>
3 #include <torch/csrc/profiler/unwind/action.h>
4 #include <torch/csrc/profiler/unwind/lexer.h>
5 #include <array>
6 #include <iostream>
7 #include <sstream>
8 #include <vector>
9 
10 namespace torch::unwind {
11 
12 struct TableState {
13   Action cfa;
14   std::array<Action, D_REG_SIZE> registers;
15   friend std::ostream& operator<<(std::ostream& out, const TableState& self) {
16     out << "cfa = " << self.cfa << "; ";
17     for (auto r : c10::irange(self.registers.size())) {
18       if (self.registers.at(r).kind != A_UNDEFINED) {
19         out << "r" << r << " = " << self.registers.at(r) << "; ";
20       }
21     }
22     return out;
23   }
24 };
25 
26 // FDE - Frame Description Entry (Concept in ELF spec)
27 // This format is explained well by
28 // https://www.airs.com/blog/archives/460
29 // Details of different dwarf actions are explained
30 // in the spec document:
31 // https://web.archive.org/web/20221129184704/https://dwarfstd.org/doc/DWARF4.doc
32 // An overview of how DWARF unwinding works is given in
33 // https://dl.acm.org/doi/pdf/10.1145/3360572
34 // A similar implementation written in rust is:
35 // https://github.com/mstange/framehop/
36 
37 template <bool LOG = false>
38 struct FDE {
FDEFDE39   FDE(void* data, const char* library_name, uint64_t load_bias)
40       : library_name_(library_name), load_bias_(load_bias) {
41     Lexer L(data);
42     auto length = L.read4or8Length();
43     void* fde_start = L.loc();
44     void* cie_data = (void*)((int64_t)fde_start - L.read<uint32_t>());
45     Lexer LC(cie_data);
46     auto cie_length = LC.read4or8Length();
47     void* cie_start = LC.loc();
48     auto zero = LC.read<uint32_t>();
49     TORCH_INTERNAL_ASSERT(zero == 0, "expected 0 for CIE");
50     auto version = LC.read<uint8_t>();
51     TORCH_INTERNAL_ASSERT(
52         version == 1 || version == 3, "non-1 version for CIE");
53     augmentation_string_ = LC.readCString();
54     if (hasAugmentation("eh")) {
55       throw UnwindError("unsupported 'eh' augmentation string");
56     }
57     code_alignment_factor_ = LC.readULEB128();
58     data_alignment_factor_ = LC.readSLEB128();
59     if (version == 1) {
60       ra_register_ = LC.read<uint8_t>();
61     } else {
62       ra_register_ = LC.readULEB128();
63     }
64     // we assume this in the state
65     TORCH_INTERNAL_ASSERT(ra_register_ == 16, "unexpected number of registers");
66     if (augmentation_string_ && *augmentation_string_ == 'z') {
67       augmentation_length_ = LC.readULEB128();
68       Lexer A(LC.loc());
69       for (auto ap = augmentation_string_ + 1; *ap; ap++) {
70         switch (*ap) {
71           case 'L':
72             lsda_enc = A.read<uint8_t>();
73             break;
74           case 'R':
75             fde_enc = A.read<uint8_t>();
76             break;
77           case 'P': {
78             uint8_t personality_enc = A.read<uint8_t>();
79             A.readEncoded(personality_enc);
80           } break;
81           case 'S': {
82             // signal handler
83           } break;
84           default: {
85             throw UnwindError("unknown augmentation string");
86           } break;
87         }
88       }
89     }
90     LC.skip(augmentation_length_);
91     low_pc_ = L.readEncoded(fde_enc);
92     high_pc_ = low_pc_ + L.readEncodedValue(fde_enc);
93 
94     if (hasAugmentation("z")) {
95       augmentation_length_fde_ = L.readULEB128();
96     }
97     L.readEncodedOr(lsda_enc, 0);
98 
99     cie_begin_ = LC.loc();
100     fde_begin_ = L.loc();
101     cie_end_ = (void*)((const char*)cie_start + cie_length);
102     fde_end_ = (void*)((const char*)fde_start + length);
103   }
104 
105   // OP Code implementations
106 
advance_rawFDE107   void advance_raw(int64_t amount) {
108     auto previous_pc = current_pc_;
109     current_pc_ += amount;
110     if (LOG) {
111       (*out_) << (void*)(previous_pc - load_bias_) << "-"
112               << (void*)(current_pc_ - load_bias_) << ": " << state() << "\n";
113     }
114   }
115 
advance_locFDE116   void advance_loc(int64_t amount) {
117     if (LOG) {
118       (*out_) << "advance_loc " << amount << "\n";
119     }
120     advance_raw(amount * code_alignment_factor_);
121   }
122 
offsetFDE123   void offset(int64_t reg, int64_t offset) {
124     if (LOG) {
125       (*out_) << "offset " << reg << " " << offset << "\n";
126     }
127     if (reg > (int64_t)state().registers.size()) {
128       if (LOG) {
129         (*out_) << "OFFSET OF BIG REGISTER " << reg << "ignored...\n";
130       }
131       return;
132     }
133     state().registers.at(reg) =
134         Action{A_LOAD_CFA_OFFSET, -1, offset * data_alignment_factor_};
135   }
136 
restoreFDE137   void restore(int64_t reg) {
138     if (LOG) {
139       (*out_) << "restore " << reg << "\n";
140     }
141     if (reg > (int64_t)state().registers.size()) {
142       if (LOG) {
143         (*out_) << "RESTORE OF BIG REGISTER " << reg << "ignored...\n";
144       }
145       return;
146     }
147     state().registers.at(reg) = initial_state_.registers.at(reg);
148   }
149 
def_cfaFDE150   void def_cfa(int64_t reg, int64_t off) {
151     if (LOG) {
152       (*out_) << "def_cfa " << reg << " " << off << "\n";
153     }
154     last_reg_ = reg;
155     last_offset_ = off;
156     state().cfa = Action::regPlusData(reg, off);
157   }
def_cfa_registerFDE158   void def_cfa_register(int64_t reg) {
159     def_cfa(reg, last_offset_);
160   }
def_cfa_offsetFDE161   void def_cfa_offset(int64_t off) {
162     def_cfa(last_reg_, off);
163   }
164 
remember_stateFDE165   void remember_state() {
166     if (LOG) {
167       (*out_) << "remember_state\n";
168     }
169     state_stack_.push_back(state());
170   }
restore_stateFDE171   void restore_state() {
172     if (LOG) {
173       (*out_) << "restore_state\n";
174     }
175     state_stack_.pop_back();
176   }
177 
undefinedFDE178   void undefined(int64_t reg) {
179     if (LOG) {
180       (*out_) << "undefined " << reg << "\n";
181     }
182     state().registers.at(reg) = Action::undefined();
183   }
register_FDE184   void register_(int64_t reg, int64_t rhs_reg) {
185     if (LOG) {
186       (*out_) << "register " << reg << " " << rhs_reg << "\n";
187     }
188     state().registers.at(reg) = Action::regPlusData(reg, 0);
189   }
190 
stateFDE191   TableState& state() {
192     return state_stack_.back();
193   }
194 
dumpFDE195   void dump(std::ostream& out) {
196     out_ = &out;
197     out << "FDE(augmentation_string=" << augmentation_string_
198         << ", low_pc=" << (void*)(low_pc_ - load_bias_)
199         << ",high_pc=" << (void*)(high_pc_ - load_bias_)
200         << ",code_alignment_factor=" << code_alignment_factor_
201         << ", data_alignment_factor=" << data_alignment_factor_
202         << ", ra_register_=" << ra_register_ << ")\n";
203     readUpTo(high_pc_);
204     out_ = &std::cout;
205   }
206 
readUpToFDE207   TableState readUpTo(uint64_t addr) {
208     if (addr < low_pc_ || addr > high_pc_) {
209       throw UnwindError("Address not in range");
210     }
211     if (LOG) {
212       (*out_) << "readUpTo " << (void*)addr << " for " << library_name_
213               << " at " << (void*)load_bias_ << "\n";
214     }
215     state_stack_.emplace_back();
216     current_pc_ = low_pc_;
217     // parse instructions...
218     Lexer LC(cie_begin_);
219     while (LC.loc() < cie_end_ && current_pc_ <= addr) {
220       readInstruction(LC);
221     }
222     if (current_pc_ > addr) {
223       return state();
224     }
225 
226     initial_state_ = state_stack_.back();
227 
228     if (LOG) {
229       (*out_) << "--\n";
230     }
231 
232     Lexer L(fde_begin_);
233     while (L.loc() < fde_end_ && current_pc_ <= addr) {
234       readInstruction(L);
235     }
236     // so that we print the full range in debugging
237     if (current_pc_ <= addr) {
238       advance_raw(addr - current_pc_);
239     }
240     return state();
241   }
242 
dumpAddr2LineFDE243   void dumpAddr2Line() {
244     std::cout << "addr2line -f -e " << library_name_ << " "
245               << (void*)(low_pc_ - load_bias_) << "\n";
246   }
247 
readInstructionFDE248   void readInstruction(Lexer& L) {
249     uint8_t bc = L.read<uint8_t>();
250     auto op = bc >> 6;
251     auto lowbits = bc & 0x3F;
252     switch (op) {
253       case 0x0: {
254         switch (lowbits) {
255           case DW_CFA_nop: {
256             return; // nop
257           }
258           case DW_CFA_advance_loc1: {
259             auto delta = L.read<uint8_t>();
260             return advance_loc(delta);
261           }
262           case DW_CFA_advance_loc2: {
263             auto delta = L.read<uint16_t>();
264             return advance_loc(delta);
265           }
266           case DW_CFA_advance_loc4: {
267             auto delta = L.read<uint32_t>();
268             return advance_loc(delta);
269           }
270           case DW_CFA_restore_extended: {
271             auto reg = L.readULEB128();
272             return restore(reg);
273           }
274           case DW_CFA_undefined: {
275             auto reg = L.readULEB128();
276             return undefined(reg);
277           }
278           case DW_CFA_register: {
279             auto reg = L.readULEB128();
280             auto rhs_reg = L.readULEB128();
281             return register_(reg, rhs_reg);
282           }
283           case DW_CFA_def_cfa: {
284             auto reg = L.readULEB128();
285             auto off = L.readULEB128();
286             return def_cfa(reg, off);
287           }
288           case DW_CFA_def_cfa_register: {
289             auto reg = L.readULEB128();
290             return def_cfa_register(reg);
291           }
292           case DW_CFA_def_cfa_offset: {
293             auto off = L.readULEB128();
294             return def_cfa_offset(off);
295           }
296           case DW_CFA_offset_extended_sf: {
297             auto reg = L.readULEB128();
298             auto off = L.readSLEB128();
299             return offset(reg, off);
300           }
301           case DW_CFA_remember_state: {
302             return remember_state();
303           }
304           case DW_CFA_restore_state: {
305             return restore_state();
306           }
307           case DW_CFA_GNU_args_size: {
308             // GNU_args_size, we do not need to know it..
309             L.readULEB128();
310             return;
311           }
312           case DW_CFA_expression: {
313             auto reg = L.readULEB128();
314             auto len = L.readULEB128();
315             auto end = (void*)((uint64_t)L.loc() + len);
316             auto op = L.read<uint8_t>();
317             if ((op & 0xF0) == 0x70) { // DW_bregX
318               auto rhs_reg = (op & 0xF);
319               auto addend = L.readSLEB128();
320               if (L.loc() == end) {
321                 state().registers.at(reg) =
322                     Action::regPlusDataDeref(rhs_reg, addend);
323                 return;
324               }
325             }
326             throw UnwindError("Unsupported dwarf expression");
327           }
328           case DW_CFA_def_cfa_expression: {
329             auto len = L.readULEB128();
330             auto end = (void*)((uint64_t)L.loc() + len);
331             auto op = L.read<uint8_t>();
332             if ((op & 0xF0) == 0x70) { // DW_bregX
333               auto rhs_reg = (op & 0xF);
334               auto addend = L.readSLEB128();
335               if (L.loc() != end) {
336                 auto op2 = L.read<uint8_t>();
337                 if (op2 == DW_OP_deref && L.loc() == end) { // deref
338                   state().cfa = Action::regPlusDataDeref(rhs_reg, addend);
339                   return;
340                 }
341               }
342             }
343             throw UnwindError("Unsupported def_cfa dwarf expression");
344           }
345           default: {
346             std::stringstream ss;
347             ss << "unknown op code " << (void*)(uint64_t)lowbits;
348             throw UnwindError(ss.str());
349           }
350         }
351       }
352       case DW_CFA_advance_loc: {
353         return advance_loc(lowbits);
354       }
355       case DW_CFA_offset: {
356         auto off = L.readULEB128();
357         return offset(lowbits, off);
358       }
359       case DW_CFA_restore: {
360         return restore(lowbits);
361       }
362     }
363   }
364   // used for debug printing
365   const char* library_name_;
366   uint64_t load_bias_;
367 
368   // parsed from the eh_string data structures:
369   const char* augmentation_string_ = nullptr;
370   int64_t augmentation_length_ = 0;
371   int64_t augmentation_length_fde_ = 0;
372 
373   int64_t code_alignment_factor_;
374   int64_t data_alignment_factor_;
375   void* cie_data_;
376 
377   int64_t ra_register_;
378   uint8_t lsda_enc = DW_EH_PE_omit;
379   uint8_t fde_enc = DW_EH_PE_absptr;
380   uint64_t low_pc_ = UINT64_MAX;
381   uint64_t high_pc_ = UINT64_MAX;
382 
383   void* cie_begin_;
384   void* fde_begin_;
385   void* cie_end_;
386   void* fde_end_;
387 
388   // state accumulated while parsing instructions
389   int64_t last_reg_ = 0;
390   int64_t last_offset_ = 0;
391   uint64_t current_pc_;
392 
393   TableState
394       initial_state_; // state after the initial instructions, used by restore
395   std::vector<TableState> state_stack_;
396 
397   std::ostream* out_ = &std::cout; // for debug dumping
398  private:
hasAugmentationFDE399   bool hasAugmentation(const char* s) {
400     return strstr(augmentation_string_, s) != nullptr;
401   }
402 };
403 
404 } // namespace torch::unwind
405