# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Routines for fully traversing an IR."""

import inspect

from compiler.util import ir_data
from compiler.util import ir_data_fields
from compiler.util import ir_data_utils
from compiler.util import simple_memoizer


class _FunctionCaller:
    """Provides a template for setting up a generic call to a function.

    The function parameters are inspected at run-time to build up a set of
    valid and required arguments.  When invoking the function, unnecessary
    parameters will be trimmed out.  If arguments are missing an assertion
    will be triggered.

    This is currently limited to functions that have at least one positional
    parameter.

    Example usage:
    ```
    def func_1(a, b, c=2): pass
    def func_2(a, d): pass
    caller_1 = _FunctionCaller(func_1)
    caller_2 = _FunctionCaller(func_2)
    generic_params = {"b": 2, "c": 3, "d": 4}

    # Equivalent of: func_1(a, b=2, c=3)
    caller_1.invoke(a, generic_params)

    # Equivalent of: func_2(a, d=4)
    caller_2.invoke(a, generic_params)
    ```
    """

    def __init__(self, function):
        """Inspects `function` and records which keyword args it accepts.

        Args:
            function: The callable to wrap.  Its first positional parameter
                is reserved for the `positional_arg` passed to `invoke()`.
        """
        self.function = function
        self.needs_filtering = True
        self.valid_arg_names = set()
        self.required_arg_names = set()

        argspec = inspect.getfullargspec(function)
        if argspec.varkw:
            # If the function accepts a kwargs parameter, then it will accept
            # all arguments, so `invoke()` neither filters nor checks for
            # missing arguments.
            # Note: this isn't technically true if one of the keyword
            # arguments has the same name as one of the positional arguments.
            self.needs_filtering = False
        else:
            # argspec.args is a list of all parameter names excluding keyword
            # only args.  The first element is our required positional_arg
            # and should be ignored.
            args = argspec.args[1:]
            self.valid_arg_names.update(args)

            # argspec.kwonlyargs gives us the list of keyword only args which
            # are also valid.
            self.valid_arg_names.update(argspec.kwonlyargs)

            # Required args are positional arguments that don't have
            # defaults.  Args with defaults are the last elements of the
            # argspec.args list and should be ignored.  Keyword only args are
            # not tracked as required here: if one without a default is
            # missing, the call itself raises TypeError.
            if argspec.defaults:
                # Trim the arguments with defaults.
                args = args[: -len(argspec.defaults)]
            self.required_arg_names.update(args)

    def invoke(self, positional_arg, keyword_args):
        """Invokes the function with the given args.

        Args:
            positional_arg: Passed as the first positional argument.
            keyword_args: A dict of candidate keyword arguments; entries the
                function does not accept are dropped (unless the function
                takes `**kwargs`, in which case all are forwarded).

        Returns:
            Whatever the wrapped function returns.
        """
        if self.needs_filtering:
            # Trim to just recognized args.
            matched_args = {
                k: v
                for k, v in keyword_args.items()
                if k in self.valid_arg_names
            }
            # Check if any required args are missing.
            missing_args = self.required_arg_names.difference(
                matched_args.keys()
            )
            assert not missing_args, (
                f"Attempting to call '{self.function.__name__}'; "
                f"missing {missing_args} (have {set(keyword_args.keys())})"
            )
            keyword_args = matched_args

        return self.function(positional_arg, **keyword_args)


@simple_memoizer.memoize
def _memoized_caller(function):
    """Returns the memoized `_FunctionCaller` for `function`.

    Lambdas are rejected by the assertion below; the assertion message gives
    performance as the reason.
    """
    default_lambda_name = (lambda: None).__name__
    assert (
        callable(function) and function.__name__ != default_lambda_name
    ), "For performance reasons actions must be defined as static functions"
    return _FunctionCaller(function)


def _call_with_optional_args(function, positional_arg, keyword_args):
    """Calls function with whatever keyword_args it will accept."""
    caller = _memoized_caller(function)
    return caller.invoke(positional_arg, keyword_args)


def _fast_traverse_proto_top_down(proto, incidental_actions, pattern,
                                  skip_descendants_of, action, parameters):
    """Traverses an IR, calling `action` on some nodes.

    Args:
        proto: The IR node to visit.
        incidental_actions: A dict mapping node types to iterables of
            callables to run on nodes of exactly that type.
        pattern: A non-empty sequence of node types still to be matched.
        skip_descendants_of: A container of node types whose children are
            not visited.
        action: The callable to run on nodes that complete `pattern`.
        parameters: A dict of keyword arguments available to `action` and
            the incidental actions; it is copied, never mutated.
    """
    # Parameters are scoped to the branch of the tree, so make a copy here,
    # before any action or incidental_action can update them.
    parameters = parameters.copy()

    # All lookups below use the node's exact type (not isinstance), because
    # the lookup tables are keyed by concrete type.
    node_type = type(proto)

    # If there is an incidental action for this node type, run it.
    if node_type in incidental_actions:
        for incidental_action in incidental_actions[node_type]:
            parameters.update(
                _call_with_optional_args(incidental_action, proto, parameters)
                or {}
            )

    # If we are at the end of pattern, check to see if we should call action.
    if len(pattern) == 1:
        new_pattern = pattern
        if pattern[0] == node_type:
            parameters.update(
                _call_with_optional_args(action, proto, parameters) or {}
            )
    else:
        # Otherwise, if this node's type matches the head of pattern, recurse
        # with the tail of the pattern.
        if pattern[0] == node_type:
            new_pattern = pattern[1:]
        else:
            new_pattern = pattern

    # If the current node's type is one of the types whose branch should be
    # skipped, then bail.  This has to happen after `action` is called,
    # because clients rely on being able to, e.g., get a callback for the
    # "root" Expression without getting callbacks for every sub-Expression.
    if node_type in skip_descendants_of:
        return

    # Otherwise, recurse.  _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET tells us,
    # given the current node's type and the current target type, which fields
    # to check.
    singular_fields, repeated_fields = _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET[
        node_type, new_pattern[0]
    ]
    for member_name in singular_fields:
        if proto.HasField(member_name):
            _fast_traverse_proto_top_down(
                getattr(proto, member_name),
                incidental_actions,
                new_pattern,
                skip_descendants_of,
                action,
                parameters,
            )
    for member_name in repeated_fields:
        for array_element in getattr(proto, member_name) or []:
            _fast_traverse_proto_top_down(
                array_element,
                incidental_actions,
                new_pattern,
                skip_descendants_of,
                action,
                parameters,
            )


def _fields_to_scan_by_current_and_target():
    """Generates _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET.

    Returns:
        A dict mapping `(current_node_type, target_node_type)` to a
        `(singular_field_names, repeated_field_names)` pair of lists naming
        the fields of `current_node_type` that can lead to a
        `target_node_type`.
    """
    # In order to avoid spending a *lot* of time just walking the IR, this
    # function sets up a dict that allows `_fast_traverse_proto_top_down()`
    # to skip traversing large portions of the IR, depending on what node
    # types it is targeting.
    #
    # Without this branch culling scheme, the Emboss front end (at time of
    # writing) spends roughly 70% (19s out of 31s) of its time just walking
    # the IR.  With branch culling, that goes down to 6% (0.7s out of 12.2s).

    # type_to_fields is a map of types to maps of field names to field types.
    # That is, type_to_fields[T][name] is the IR dataclass type held by field
    # `name` of type T.  Only fields whose spec has `is_dataclass` set are
    # recorded.
    type_to_fields = {}

    # Later, we need to know which fields are singular and which are
    # repeated, because the access methods are not uniform.  This maps
    # (type, field_name) tuples to the field's `container` value, e.g.
    # ir_data_fields.FieldContainer.LIST for a repeated field.
    type_fields_to_cardinality = {}

    # Fill out the above maps by recursively walking the IR type tree,
    # starting from the root.
    types_to_check = [ir_data.EmbossIr]
    while types_to_check:
        type_to_check = types_to_check.pop()
        if type_to_check in type_to_fields:
            continue
        fields = {}
        field_specs = ir_data_utils.field_specs(type_to_check)
        for field_name, field_type in field_specs.items():
            if field_type.is_dataclass:
                fields[field_name] = field_type.data_type
                types_to_check.append(field_type.data_type)
                type_fields_to_cardinality[type_to_check, field_name] = (
                    field_type.container
                )
        type_to_fields[type_to_check] = fields

    # type_to_descendant_types is a map of all types that can be reached from
    # a particular type.  After the setup,
    # type_to_descendant_types[ir_data.EmbossIr] == set(<all types>) and
    # type_to_descendant_types[ir_data.Reference] == {ir_data.CanonicalName,
    # ir_data.Word, ir_data.Location} and
    # type_to_descendant_types[ir_data.Word] == set().
    #
    # The while loop basically ors in the known descendants of each known
    # descendant of each type until the dict stops changing, which is a bit
    # brute-force, but in practice only iterates a few times.
    type_to_descendant_types = {}
    for parent_type, field_map in type_to_fields.items():
        type_to_descendant_types[parent_type] = set(field_map.values())
    previous_map = {}
    while type_to_descendant_types != previous_map:
        # In order to check the previous iteration against the current
        # iteration, it is necessary to make a two-level copy.  Otherwise,
        # the updates to the values will also update previous_map's values,
        # which causes the loop to exit prematurely.
        previous_map = {
            k: set(v) for k, v in type_to_descendant_types.items()
        }
        for ancestor_type, descendents in previous_map.items():
            for descendent in descendents:
                type_to_descendant_types[ancestor_type] |= (
                    previous_map[descendent]
                )

    # Finally, we have all of the information we need to make the map we
    # really want: given a current node type and a target node type, which
    # fields should be checked?  (This implicitly skips fields that *can't*
    # contain the target type.)
    fields_to_scan_by_current_and_target = {}
    for current_node_type in type_to_fields:
        for target_node_type in type_to_fields:
            singular_fields_to_scan = []
            repeated_fields_to_scan = []
            current_fields = type_to_fields[current_node_type]
            for field_name, field_type in current_fields.items():
                # If the target node type cannot contain another instance of
                # itself, it is still necessary to scan fields that have the
                # actual target type.
                if (target_node_type == field_type or
                        target_node_type in
                        type_to_descendant_types[field_type]):
                    # Singular and repeated fields go to different lists, so
                    # that they can be handled separately.
                    cardinality = type_fields_to_cardinality[
                        current_node_type, field_name
                    ]
                    if cardinality is not ir_data_fields.FieldContainer.LIST:
                        singular_fields_to_scan.append(field_name)
                    else:
                        repeated_fields_to_scan.append(field_name)
            fields_to_scan_by_current_and_target[
                current_node_type, target_node_type
            ] = (singular_fields_to_scan, repeated_fields_to_scan)
    return fields_to_scan_by_current_and_target


_FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET = _fields_to_scan_by_current_and_target()


def _emboss_ir_action(ir):
    """Built-in incidental action: exposes the node as parameter `ir`."""
    return {"ir": ir}


def _module_action(m):
    """Built-in incidental action: exposes `m.source_file_name`."""
    return {"source_file_name": m.source_file_name}


def _type_definition_action(t):
    """Built-in incidental action: exposes the node as `type_definition`."""
    return {"type_definition": t}


def _field_action(f):
    """Built-in incidental action: exposes the node as parameter `field`."""
    return {"field": f}


def fast_traverse_ir_top_down(ir, pattern, action, incidental_actions=None,
                              skip_descendants_of=(), parameters=None):
    """Traverses an IR from the top down, executing the given actions.

    `fast_traverse_ir_top_down` walks the given IR in preorder traversal,
    specifically looking for nodes whose path from the root of the tree
    matches `pattern`.  For every node which matches `pattern`, `action` will
    be called.

    `pattern` is just a list of node types.  For example, to execute `print`
    on every `ir_data.Word` in the IR:

        fast_traverse_ir_top_down(ir, [ir_data.Word], print)

    If more than one type is specified, then each one must be found inside
    the previous.  For example, to print only the Words inside of import
    statements:

        fast_traverse_ir_top_down(ir, [ir_data.Import, ir_data.Word], print)

    The optional arguments provide additional control.

    `skip_descendants_of` is a list of types that should be treated as if
    they are leaf nodes when they are encountered.  That is, traversal will
    skip any nodes with any ancestor node whose type is in
    `skip_descendants_of`.  For example, to `do_something` only on outermost
    `Expression`s:

        fast_traverse_ir_top_down(ir, [ir_data.Expression], do_something,
                                  skip_descendants_of={ir_data.Expression})

    `parameters` specifies a dictionary of initial parameters which can be
    passed as arguments to `action` and `incidental_actions`.  Note that the
    parameters can be overridden for parts of the tree by `action` and
    `incidental_actions`.  Parameters can be used to set an object which may
    be updated by `action`, such as a list of errors generated by some check
    in `action`:

        def check_structure(structure, errors):
            if structure_is_bad(structure):
                errors.append(error_for_structure(structure))

        errors = []
        fast_traverse_ir_top_down(ir, [ir_data.Structure], check_structure,
                                  parameters={"errors": errors})
        if errors:
            print("Errors: {}".format(errors))
            sys.exit(1)

    `incidental_actions` is a map from node types to functions (or tuples of
    functions or lists of functions) which should be called on those nodes.
    Because `fast_traverse_ir_top_down` may skip branches that can't contain
    `pattern`, functions in `incidental_actions` should generally not have
    any side effects: instead, they may return a dictionary, which will be
    used to override `parameters` for any children of the node they were
    called on.  Actions must be named functions; lambdas are rejected.  For
    example:

        def do_something(expression, field_name=None):
            if field_name:
                print("Found {} inside {}".format(expression, field_name))
            else:
                print("Found {} not in any field".format(expression))

        def field_name_action(f):
            return {"field_name": f.name}

        fast_traverse_ir_top_down(
            ir, [ir_data.Expression], do_something,
            incidental_actions={ir_data.Field: field_name_action})

    (The `action` may also return a dict in the same way.)

    A few `incidental_actions` are built into `fast_traverse_ir_top_down`, so
    that certain parameters are contextually available with well-known names:

        ir: The complete IR (the root ir_data.EmbossIr node).
        source_file_name: The file name from which the current node was
            sourced.
        type_definition: The most-immediate ancestor type definition.
        field: The field containing the current node, if any.

    Arguments:
        ir: The IR object to walk (normally the root ir_data.EmbossIr).
        pattern: A list of node types to match.
        action: A callable, which will be called on nodes matching
            `pattern`.
        incidental_actions: A dict of node types to callables (or lists or
            tuples of callables), which can be used to set new parameters for
            `action` for part of the IR tree.
        skip_descendants_of: A list of types whose children should be skipped
            when traversing `ir`.
        parameters: A dict of top-level parameters.

    Returns:
        None
    """
    # Built fresh on every call, so extending these lists below never leaks
    # caller-supplied actions into later traversals.
    all_incidental_actions = {
        ir_data.EmbossIr: [_emboss_ir_action],
        ir_data.Module: [_module_action],
        ir_data.TypeDefinition: [_type_definition_action],
        ir_data.Field: [_field_action],
    }
    if incidental_actions:
        for key, incidental_action in incidental_actions.items():
            if not isinstance(incidental_action, (list, tuple)):
                incidental_action = [incidental_action]
            all_incidental_actions.setdefault(key, []).extend(
                incidental_action
            )
    _fast_traverse_proto_top_down(ir, all_incidental_actions, pattern,
                                  skip_descendants_of, action,
                                  parameters or {})


def fast_traverse_node_top_down(node, pattern, action,
                                incidental_actions=None,
                                skip_descendants_of=(), parameters=None):
    """Traverse a subtree of an IR, executing the given actions.

    fast_traverse_node_top_down is like fast_traverse_ir_top_down, except
    that:

    It may be called on a subtree, instead of the top of the IR.

    It does not have any built-in incidental actions.

    Arguments:
        node: The IR node at the root of the subtree to walk.
        pattern: A list of node types to match.
        action: A callable, which will be called on nodes matching
            `pattern`.
        incidental_actions: A dict of node types to callables (or lists or
            tuples of callables), which can be used to set new parameters for
            `action` for part of the IR tree.
        skip_descendants_of: A list of types whose children should be skipped
            when traversing `node`.
        parameters: A dict of top-level parameters.

    Returns:
        None
    """
    # The traversal iterates over each dict value, so wrap a bare callable in
    # a list.  This accepts the same `{type: function}` shape that
    # fast_traverse_ir_top_down does; iterables of callables pass through
    # unchanged.
    normalized_actions = {
        node_type: [actions] if callable(actions) else actions
        for node_type, actions in (incidental_actions or {}).items()
    }
    _fast_traverse_proto_top_down(node, normalized_actions, pattern,
                                  skip_descendants_of or {}, action,
                                  parameters or {})