xref: /aosp_15_r20/external/emboss/compiler/util/traverse_ir.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Routines for fully traversing an IR."""
16
17import inspect
18
19from compiler.util import ir_data
20from compiler.util import ir_data_fields
21from compiler.util import ir_data_utils
22from compiler.util import simple_memoizer
23
24
25class _FunctionCaller:
26  """Provides a template for setting up a generic call to a function.
27
28  The function parameters are inspected at run-time to build up a set of valid
29  and required arguments. When invoking the function unneccessary parameters
30  will be trimmed out. If arguments are missing an assertion will be triggered.
31
32  This is currently limited to functions that have at least one positional
33  parameter.
34
35  Example usage:
36  ```
37  def func_1(a, b, c=2): pass
38  def func_2(a, d): pass
39  caller_1 = _FunctionCaller(func_1)
40  caller_2 = _FunctionCaller(func_2)
41  generic_params = {"b": 2, "c": 3, "d": 4}
42
43  # Equivalent of: func_1(a, b=2, c=3)
44  caller_1.invoke(a, generic_params)
45
46  # Equivalent of: func_2(a, d=4)
47  caller_2.invoke(a, generic_params)
48  """
49
50  def __init__(self, function):
51    self.function = function
52    self.needs_filtering = True
53    self.valid_arg_names = set()
54    self.required_arg_names = set()
55
56    argspec = inspect.getfullargspec(function)
57    if argspec.varkw:
58      # If the function accepts a kwargs parameter, then it will accept all
59      # arguments.
60      # Note: this isn't technically true if one of the keyword arguments has the
61      # same name as one of the positional arguments.
62      self.needs_filtering = False
63    else:
64      # argspec.args is a list of all parameter names excluding keyword only
65      # args. The first element is our required positional_arg and should be
66      # ignored.
67      args = argspec.args[1:]
68      self.valid_arg_names.update(args)
69
70      # args.kwonlyargs gives us the list of keyword only args which are
71      # also valid.
72      self.valid_arg_names.update(argspec.kwonlyargs)
73
74      # Required args are positional arguments that don't have defaults.
75      # Keyword only args are always optional and can be ignored. Args with
76      # defaults are the last elements of the argsepec.args list and should
77      # be ignored.
78      if argspec.defaults:
79        # Trim the arguments with defaults.
80        args = args[: -len(argspec.defaults)]
81      self.required_arg_names.update(args)
82
83  def invoke(self, positional_arg, keyword_args):
84    """Invokes the function with the given args."""
85    if self.needs_filtering:
86      # Trim to just recognized args.
87      matched_args = {
88          k: v for k, v in keyword_args.items() if k in self.valid_arg_names
89      }
90      # Check if any required args are missing.
91      missing_args = self.required_arg_names.difference(matched_args.keys())
92      assert not missing_args, (
93          f"Attempting to call '{self.function.__name__}'; "
94          f"missing {missing_args} (have {set(keyword_args.keys())})"
95      )
96      keyword_args = matched_args
97
98    return self.function(positional_arg, **keyword_args)
99
100
@simple_memoizer.memoize
def _memoized_caller(function):
  """Returns the `_FunctionCaller` for `function`, building it at most once.

  Lambdas are rejected: the assertion below fires for any callable whose
  `__name__` is the one Python gives to lambda expressions.
  """
  lambda_name = (lambda: None).__name__
  is_named_callable = callable(function) and function.__name__ != lambda_name
  assert (
      is_named_callable
  ), "For performance reasons actions must be defined as static functions"
  return _FunctionCaller(function)
108
109
def _call_with_optional_args(function, positional_arg, keyword_args):
  """Calls `function`, forwarding only the `keyword_args` it accepts.

  The argument inspection for `function` is obtained from
  `_memoized_caller`, and the call itself is delegated to the resulting
  caller's `invoke` method.
  """
  return _memoized_caller(function).invoke(positional_arg, keyword_args)
114
115
def _fast_traverse_proto_top_down(proto, incidental_actions, pattern,
                                  skip_descendants_of, action, parameters):
  """Traverses an IR, calling `action` on some nodes.

  Arguments:
    proto: The node currently being visited.
    incidental_actions: A dict from node type to an iterable of callables that
        are run on every visited node of that type.
    pattern: The remaining sequence of node types to match; `action` is only
        called once the pattern is down to its final element.
    skip_descendants_of: A container of node types whose children are not
        visited.
    action: The callable to run on nodes that complete `pattern`.
    parameters: A dict of keyword arguments available to the callables.  It is
        copied, so updates made here only reach this node's descendants.

  Returns:
    None
  """
  node_type = type(proto)

  # Work on a private copy: parameters are scoped to this branch of the tree,
  # and both incidental actions and `action` may add to them.
  scoped_parameters = parameters.copy()

  # Run every incidental action registered for this node type, folding any
  # dict it returns into the parameters seen by later calls and descendants.
  for hook in incidental_actions.get(node_type, ()):
    scoped_parameters.update(
        _call_with_optional_args(hook, proto, scoped_parameters) or {})

  head_matches = pattern[0] == node_type
  on_last_element = len(pattern) == 1

  # `action` fires only when this node matches the final pattern element.
  if on_last_element and head_matches:
    scoped_parameters.update(
        _call_with_optional_args(action, proto, scoped_parameters) or {})

  # A match on a non-final element consumes it; otherwise the pattern is
  # passed down unchanged (the final element is never consumed).
  if head_matches and not on_last_element:
    remaining_pattern = pattern[1:]
  else:
    remaining_pattern = pattern

  # Bail out on types whose branches should be skipped.  This has to happen
  # after `action` is called, because clients rely on being able to, e.g., get
  # a callback for the "root" Expression without getting callbacks for every
  # sub-Expression.
  if node_type in skip_descendants_of:
    return

  # _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET says, for this node type and the next
  # type being searched for, which fields are worth descending into.
  singular_names, repeated_names = _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET[
      node_type, remaining_pattern[0]]
  for name in singular_names:
    if not proto.HasField(name):
      continue
    _fast_traverse_proto_top_down(getattr(proto, name), incidental_actions,
                                  remaining_pattern, skip_descendants_of,
                                  action, scoped_parameters)
  for name in repeated_names:
    for element in getattr(proto, name) or []:
      _fast_traverse_proto_top_down(element, incidental_actions,
                                    remaining_pattern, skip_descendants_of,
                                    action, scoped_parameters)
166
167
def _fields_to_scan_by_current_and_target():
  """Generates _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET.

  Returns:
    A dict keyed by `(current_node_type, target_node_type)` whose values are
    `(singular_field_names, repeated_field_names)` tuples: the fields of
    `current_node_type` that either hold `target_node_type` directly or hold a
    type from which `target_node_type` can be reached.
  """
  # In order to avoid spending a *lot* of time just walking the IR, this
  # function sets up a dict that allows `_fast_traverse_proto_top_down()` to
  # skip traversing large portions of the IR, depending on what node types it is
  # targeting.
  #
  # Without this branch culling scheme, the Emboss front end (at time of
  # writing) spends roughly 70% (19s out of 31s) of its time just walking the
  # IR.  With branch culling, that goes down to 6% (0.7s out of 12.2s).

  # child_types_by_type maps each IR type to a {field name: field type} dict
  # covering only its dataclass-typed fields.
  child_types_by_type = {}

  # container_by_type_and_field maps (type, field name) to the field's
  # container kind, which is needed later because singular and repeated fields
  # are accessed differently.
  container_by_type_and_field = {}

  # Walk the IR type graph from the root, visiting each type exactly once.
  pending_types = [ir_data.EmbossIr]
  while pending_types:
    node_type = pending_types.pop()
    if node_type in child_types_by_type:
      continue
    child_fields = {}
    for name, spec in ir_data_utils.field_specs(node_type).items():
      if not spec.is_dataclass:
        continue
      child_type = spec.data_type
      child_fields[name] = child_type
      pending_types.append(child_type)
      container_by_type_and_field[node_type, name] = spec.container
    child_types_by_type[node_type] = child_fields

  # reachable_types maps each type to the set of every type that can be reached
  # from it through one or more fields.  It starts out holding only the direct
  # field types, and is then expanded by repeatedly merging in the known
  # descendants of each known descendant until nothing changes.  That is a bit
  # brute-force, but in practice only iterates a few times.
  reachable_types = {
      parent: set(children.values())
      for parent, children in child_types_by_type.items()
  }
  snapshot = {}
  while reachable_types != snapshot:
    # The snapshot must copy the sets themselves, not just the dict; otherwise
    # updating reachable_types would also update the snapshot, and the loop
    # would exit prematurely.
    snapshot = {parent: set(found) for parent, found in reachable_types.items()}
    for ancestor, found in snapshot.items():
      for descendant in found:
        reachable_types[ancestor].update(snapshot[descendant])

  # Finally, build the map we really want: given a current node type and a
  # target node type, which fields should be checked?  (This implicitly skips
  # fields that *can't* contain the target type.)
  list_container = ir_data_fields.FieldContainer.LIST
  fields_to_scan = {}
  for current_type, child_fields in child_types_by_type.items():
    for target_type in child_types_by_type:
      singular_names = []
      repeated_names = []
      for name, child_type in child_fields.items():
        # A field is relevant if it holds the target type itself (which
        # matters when the target cannot contain another instance of itself)
        # or a type from which the target is reachable.
        if (target_type != child_type and
            target_type not in reachable_types[child_type]):
          continue
        # Singular and repeated fields go to different lists, so that they can
        # be handled separately.
        if container_by_type_and_field[current_type, name] is list_container:
          repeated_names.append(name)
        else:
          singular_names.append(name)
      fields_to_scan[current_type, target_type] = (
          singular_names, repeated_names)
  return fields_to_scan
253
254
# Lookup table from (current node type, target node type) to the
# (singular field names, repeated field names) worth scanning.  Built once, at
# import time, and read by _fast_traverse_proto_top_down().
_FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET = _fields_to_scan_by_current_and_target()
256
257def _emboss_ir_action(ir):
258  return {"ir": ir}
259
260def _module_action(m):
261  return {"source_file_name": m.source_file_name}
262
263def _type_definition_action(t):
264  return {"type_definition": t}
265
266def _field_action(f):
267  return {"field": f}
268
def fast_traverse_ir_top_down(ir, pattern, action, incidental_actions=None,
                              skip_descendants_of=(), parameters=None):
  """Traverses an IR from the top down, executing the given actions.

  `fast_traverse_ir_top_down` walks the given IR in preorder, looking for nodes
  whose path from the root of the tree matches `pattern`.  `action` is called
  on every node that matches.

  `pattern` is a list of node types.  For example, to execute `print` on every
  `ir_data.Word` in the IR:

      fast_traverse_ir_top_down(ir, [ir_data.Word], print)

  If more than one type is specified, then each one must be found inside the
  previous.  For example, to print only the Words inside of import statements:

      fast_traverse_ir_top_down(ir, [ir_data.Import, ir_data.Word], print)

  The optional arguments provide additional control.

  `skip_descendants_of` is a collection of types that are treated as leaf
  nodes when they are encountered.  That is, traversal skips any node with an
  ancestor whose type is in `skip_descendants_of`.  For example, to
  `do_something` only on outermost `Expression`s:

      fast_traverse_ir_top_down(ir, [ir_data.Expression], do_something,
                                skip_descendants_of={ir_data.Expression})

  `parameters` is a dictionary of initial parameters which can be passed as
  keyword arguments to `action` and `incidental_actions`; each callable only
  receives the parameters it declares.  Parameters can be overridden for parts
  of the tree by `action` and `incidental_actions`.  A parameter can also hold
  an object that `action` updates, such as a list of errors:

      def check_structure(structure, errors):
        if structure_is_bad(structure):
          errors.append(error_for_structure(structure))

      errors = []
      fast_traverse_ir_top_down(ir, [ir_data.Structure], check_structure,
                                parameters={"errors": errors})
      if errors:
        print("Errors: {}".format(errors))
        sys.exit(1)

  `incidental_actions` is a map from node types to functions (or tuples or
  lists of functions) which are called on nodes of those types.  Because
  `fast_traverse_ir_top_down` may skip branches that can't contain `pattern`,
  functions in `incidental_actions` should generally not have side effects:
  instead, they may return a dictionary, which is used to override
  `parameters` for the node they were called on and its children.  For
  example:

      def do_something(expression, field_name=None):
        if field_name:
          print("Found {} inside {}".format(expression, field_name))
        else:
          print("Found {} not in any field".format(expression))

      def set_field_name(field):
        return {"field_name": field.name}

      fast_traverse_ir_top_down(
          ir, [ir_data.Expression], do_something,
          incidental_actions={ir_data.Field: set_field_name})

  (The `action` may also return a dict in the same way.)

  Note that `action` and every incidental action must be a named function:
  lambdas are rejected with an assertion when they are first called.

  A few `incidental_actions` are built into `fast_traverse_ir_top_down`, so
  that certain parameters are contextually available with well-known names:

      ir: The complete IR (the root ir_data.EmbossIr node).
      source_file_name: The file name from which the current node was sourced.
      type_definition: The most-immediate ancestor type definition.
      field: The field containing the current node, if any.

  The built-in actions run before any caller-supplied action for the same node
  type.

  Arguments:
    ir: An ir_data.EmbossIr object to walk.
    pattern: A list of node types to match.
    action: A callable, which will be called on nodes matching `pattern`.
    incidental_actions: A dict of node types to callables (or lists or tuples
        of callables), which can be used to set new parameters for `action`
        for part of the IR tree.
    skip_descendants_of: A collection of types whose children should be
        skipped when traversing `ir`.
    parameters: A dict of top-level parameters.

  Returns:
    None
  """
  # Start from the built-in actions, in fresh lists so that the caller's
  # actions can be appended without touching any shared state.
  merged_actions = {
      ir_data.EmbossIr: [_emboss_ir_action],
      ir_data.Module: [_module_action],
      ir_data.TypeDefinition: [_type_definition_action],
      ir_data.Field: [_field_action],
  }
  for node_type, extra in (incidental_actions or {}).items():
    # A bare callable is shorthand for a one-element list.
    extra_actions = extra if isinstance(extra, (list, tuple)) else [extra]
    merged_actions.setdefault(node_type, []).extend(extra_actions)
  _fast_traverse_proto_top_down(ir, merged_actions, pattern,
                                skip_descendants_of, action, parameters or {})
368
369
def fast_traverse_node_top_down(node, pattern, action, incidental_actions=None,
                                skip_descendants_of=(), parameters=None):
  """Traverse a subtree of an IR, executing the given actions.

  fast_traverse_node_top_down is like fast_traverse_ir_top_down, except that:

  It may be called on a subtree, instead of the top of the IR.

  It does not have any built-in incidental actions.

  Arguments:
    node: The IR node at the root of the subtree to walk.
    pattern: A list of node types to match.
    action: A callable, which will be called on nodes matching `pattern`.
    incidental_actions: A dict of node types to callables (or lists or tuples
        of callables), which can be used to set new parameters for `action`
        for part of the IR tree.
    skip_descendants_of: A collection of types whose children should be
        skipped when traversing `node`.
    parameters: A dict of top-level parameters.

  Returns:
    None
  """
  # The traversal iterates over the actions registered for each node type, so
  # a bare callable has to be wrapped in a list first; otherwise the documented
  # "dict of node types to callables" form would fail with a TypeError.
  normalized_actions = {}
  for node_type, actions in (incidental_actions or {}).items():
    if not isinstance(actions, (list, tuple)):
      actions = [actions]
    normalized_actions[node_type] = list(actions)
  _fast_traverse_proto_top_down(node, normalized_actions, pattern,
                                skip_descendants_of or (), action,
                                parameters or {})
396