# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union

from . import DimList

# NOTE(review): current_level, _vmap_increment_nesting and
# _vmap_decrement_nesting are neither defined nor imported in this module
# (hence the F821 suppressions below); presumably they are provided elsewhere
# at runtime -- confirm before relying on this code path.

# Stack of vmap levels opened by bound Dims, in binding order.
# Dim._vmap_stack is an index into this list.
_vmap_levels = []


@dataclass
class LevelInfo:
    """Bookkeeping for one vmap level opened when a Dim's size was bound."""

    # Value returned by _vmap_increment_nesting when the owning Dim was bound.
    level: int
    # Cleared by the owning Dim's __del__; dead entries are popped lazily.
    alive: bool = True


class Dim:
    """A named dimension whose size may be bound at construction or later.

    Binding a size opens a vmap nesting level and pushes a LevelInfo onto
    ``_vmap_levels``. Destroying a bound Dim marks its entry dead and unwinds
    dead entries from the top of that stack.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
        """Mark this Dim's vmap level dead and unwind dead levels on top."""
        if self._vmap_level is not None:
            # _vmap_stack was recorded as this Dim's index in _vmap_levels
            # when its size was bound (see the size setter).
            _vmap_levels[self._vmap_stack].alive = False
            # Only unwind from the top of the stack: keep popping while the
            # top entry is dead and matches current_level(). A dead entry lower
            # down stays until everything above it has been released. The
            # emptiness check stops the loop once every level is gone.
            while (
                _vmap_levels
                and not _vmap_levels[-1].alive
                and current_level() == _vmap_levels[-1].level  # noqa: F821
            ):
                _vmap_decrement_nesting()  # noqa: F821
                _vmap_levels.pop()

    @property
    def size(self):
        """The bound size; asserts that the Dim has been bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind the size once; rebinding to a different size raises.

        Raises:
            DimensionBindError: if already bound to a different size.
        """
        from . import DimensionBindError

        if self._size is None:
            # Acquire the vmap level before recording the size, so a failure
            # here does not leave the Dim looking bound without a level.
            level = _vmap_increment_nesting(size, "same")  # noqa: F821
            self._size = size
            self._vmap_level = level
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        """True once a size has been assigned."""
        return self._size is not None

    def __repr__(self):
        return self.name


def extract_name(inst):
    """Return the variable name stored by a STORE_FAST/STORE_NAME instruction."""
    assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
    return inst.argval


# Maps (caller code object, caller f_lasti) to a zero-argument factory that
# builds fresh Dim/DimList objects for that call site.
_cache = {}


def dims(lists=0):
    """Create Dims named after the variables the result is assigned to.

    Inspects the caller's bytecode at the call site:

    - ``x = dims()`` returns one ``Dim`` named ``"x"`` (or a ``DimList`` if
      ``lists`` is nonzero).
    - ``x, y = dims()`` returns a tuple. The last ``lists`` targets are
      ``DimList`` and the rest are ``Dim``.

    The bytecode analysis is cached per call site. Each call still returns
    new objects.
    """
    frame = inspect.currentframe()
    try:
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        code, lasti = calling_frame.f_code, calling_frame.f_lasti
    finally:
        # Drop the frame references. `frame` is this function's own frame, so
        # holding it in a local forms a reference cycle. That cycle would keep
        # the frames (and the Dims they hold) alive until the garbage collector
        # runs.
        frame = calling_frame = None

    key = (code, lasti)
    if key not in _cache:
        instructions = list(dis.get_instructions(code))
        # f_lasti is a byte offset into the caller's bytecode at the in-progress
        # call. Find the next instruction by comparing offsets rather than
        # assuming one 2-byte entry per code unit (`lasti // 2 + 1`). That
        # shortcut miscounts whenever get_instructions' listing does not map
        # one-to-one onto code units.
        first = next(
            (i for i, inst in enumerate(instructions) if inst.offset > lasti),
            None,
        )
        assert first is not None
        unpack = instructions[first]

        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            assert unpack.opname == "UNPACK_SEQUENCE"
            ndims = unpack.argval
            names = tuple(
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
            )
            # The trailing `lists` targets become DimLists.
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(
                Dim(n) if i < first_list else DimList(name=n)
                for i, n in enumerate(names)
            )
    return _cache[key]()


def _dim_set(positional, arg):
    """Normalize ``arg`` to a tuple of Dims.

    ``None`` returns ``positional`` unchanged. A Dim or int becomes a
    1-tuple. Any other iterable is converted element-wise. Ints are used as
    indices into ``positional``.
    """

    def convert(a):
        if isinstance(a, Dim):
            return a
        else:
            assert isinstance(a, int)
            return positional[a]

    if arg is None:
        return positional
    elif not isinstance(arg, (Dim, int)):
        return tuple(convert(a) for a in arg)
    else:
        return (convert(arg),)