# mypy: allow-untyped-defs
import traceback
from contextlib import contextmanager
from typing import List, Any, Dict
from ._compatibility import compatibility

__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
           'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
           'format_stack', 'set_current_meta', 'get_current_meta']

# Metadata recorded by the helpers below while preservation is enabled.
current_meta: Dict[str, Any] = {}
# Master switch; the mutating helpers in this module are no-ops while False.
should_preserve_node_meta = False


@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta():
    """Enable node-metadata preservation for the duration of a ``with`` block.

    On entry ``should_preserve_node_meta`` is set to ``True``. On exit, normal
    or exceptional, the flag is restored to its previous value. ``current_meta``
    is then replaced with a shallow copy of the contents it had on entry.
    """
    global should_preserve_node_meta
    global current_meta

    saved_should_preserve_node_meta = should_preserve_node_meta
    # Shallow copy is OK since fields of current_meta are not mutated in place:
    # the helpers in this module only add, replace or delete keys (e.g. the
    # grad_fn_seq_nr list is rebuilt rather than appended to).
    saved_current_meta = current_meta.copy()
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta = saved_should_preserve_node_meta
        current_meta = saved_current_meta


@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
    """Store ``stack`` as the current stack trace under ``"stack_trace"``.

    This is a no-op unless preservation is enabled and ``stack`` is non-empty.
    The entries are concatenated with no separator. Each entry is therefore
    expected to carry its own trailing newline, as ``traceback.format_list``
    output does.
    """
    global current_meta

    if should_preserve_node_meta and stack:
        current_meta["stack_trace"] = "".join(stack)


@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """Push ``seq_nr`` and increase the ``in_grad_fn`` nesting depth by one.

    This is a no-op unless preservation is enabled. Each call should be paired
    with a later ``reset_grad_fn_seq_nr()``.
    """
    global current_meta

    if should_preserve_node_meta:
        # The seq_nr is captured by eager mode in the grad_fn during forward.
        # A new list is built (rather than appending) so that shallow copies of
        # current_meta, such as the one saved by preserve_node_meta, are not
        # affected.
        current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
        current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1


@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """Undo the most recent ``set_grad_fn_seq_nr()`` call.

    This pops the last pushed sequence number and decrements ``in_grad_fn``.
    When the depth reaches zero, both keys are removed from ``current_meta``.
    This is a no-op unless preservation is enabled.

    Raises:
        AssertionError: if called while ``in_grad_fn`` is not positive, i.e.
            without a matching ``set_grad_fn_seq_nr()``.
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    global current_meta
    if should_preserve_node_meta:
        current_level = current_meta.get("in_grad_fn", 0)
        # Explicit check instead of `assert`, so the invariant still holds
        # under `python -O`. The exception type is unchanged for callers.
        if current_level <= 0:
            raise AssertionError(
                "reset_grad_fn_seq_nr called without a matching set_grad_fn_seq_nr"
            )
        if current_level == 1:
            del current_meta["in_grad_fn"]
            del current_meta["grad_fn_seq_nr"]
        else:
            current_meta["in_grad_fn"] = current_level - 1
            # Slice rather than pop(), so shared list objects are not mutated.
            current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]


@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    """Return the stack trace to attach to new metadata.

    When preservation is enabled, this returns a one-element list holding the
    recorded ``"stack_trace"`` (``""`` if none was set). Otherwise it returns
    the live Python stack formatted by ``traceback``, excluding this function's
    own frame.
    """
    if should_preserve_node_meta:
        return [current_meta.get("stack_trace", "")]
    else:
        # fallback to traceback.format_stack(); [:-1] drops this frame
        return traceback.format_list(traceback.extract_stack()[:-1])


@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return whether node-metadata preservation is currently enabled."""
    return should_preserve_node_meta


@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node):
    """Temporarily make a shallow copy of ``node.meta`` the current metadata.

    This applies only when preservation is enabled and ``node.meta`` is
    non-empty. In that case ``current_meta`` is replaced, for the duration of
    the ``with`` block, by a shallow copy of ``node.meta``. The pair
    ``(node.name, node.target)`` is appended to the copy's ``"from_node"``
    provenance list, unless the last entry already carries this node's name.
    The previous ``current_meta`` object is restored on exit.

    In all other cases this is a no-op context manager.

    Only ``node.meta``, ``node.name`` and ``node.target`` are read. ``node`` is
    presumably a ``torch.fx.Node`` (assumption; not established by this module).
    """
    global current_meta
    if should_preserve_node_meta and node.meta:
        saved_meta = current_meta
        try:
            current_meta = node.meta.copy()

            # Append (node.name, node.target) onto "from_node" for provenance
            # tracking. A new list is built so node.meta's own list is not
            # mutated through the shallow copy. A missing, None or empty list
            # starts a fresh one, rather than failing on the [-1] lookup.
            from_node = current_meta.get("from_node")
            if not from_node:
                current_meta["from_node"] = [(node.name, node.target)]
            elif from_node[-1][0] != node.name:
                current_meta["from_node"] = from_node + [(node.name, node.target)]

            yield
        finally:
            current_meta = saved_meta
    else:
        yield


@compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]:
    """Return the live ``current_meta`` dict (not a copy).

    Mutations made by the caller are visible to every user of this module.
    """
    return current_meta