# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
import numbers
import os
import tempfile
from typing import Any, Optional, Sequence, Tuple, Union

import executorch.exir.schema as et_schema

import numpy as np
import torch

from executorch.backends.cadence.runtime import utils
from executorch.backends.cadence.runtime.executor import Executor
from executorch.devtools import Inspector
from executorch.exir import ExecutorchProgramManager
from executorch.exir._serialize._program import deserialize_pte_binary
from executorch.exir.schema import DataLocation

from numpy import ndarray

from torch.utils._pytree import TreeSpec

class CadenceETDump:
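    """Loads the ETDump produced by a Cadence run and exposes helpers to read
    run outputs, print profiling events, and dump intermediate tensors.

    The ETRecord and debug buffer are optional; when either is missing,
    intermediate tensors cannot be dumped.
    """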
    def __init__(self, output_dir: str) -> None:
        self.tensor_dump_dir: str = os.path.join(output_dir, "tensors")
        self.etdump_path: str = os.path.join(output_dir, "etdump.etdp")
        self.etrecord_path: Optional[str] = os.path.join(output_dir, "etrecord.bin")
        self.debug_buffer_path: Optional[str] = os.path.join(
            output_dir, "debug_output.bin"
        )

        if not os.path.exists(self.etdump_path):
            raise RuntimeError(f"{self.etdump_path} does not exist")
        # pyre-ignore[6]: os.path.exists expects str, but got Optional[str]
        if not os.path.exists(self.etrecord_path):
            logging.warning(
                "ETRecord not found, intermediate tensors will not be dumped"
            )
            self.etrecord_path = None
        # pyre-ignore[6]: os.path.exists expects str, but got Optional[str]
        if not os.path.exists(self.debug_buffer_path):
            logging.warning(
                "Debug buffer not found, intermediate tensors will not be dumped"
            )
            self.debug_buffer_path = None

        self.et_inspector: Inspector = Inspector(
            etdump_path=self.etdump_path,
            debug_buffer_path=self.debug_buffer_path,
            etrecord=self.etrecord_path,
        )

    def get_outputs(self, log_to_stdout: bool = False) -> Tuple[torch.Tensor]:
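        """Return the outputs recorded for the "Execute" event block of the run."""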
        output = [
            event_block.run_output
            for event_block in self.et_inspector.event_blocks
            if event_block.name == "Execute"
        ]
        logging.debug(f"[ETdump] output: {output}")
        return output[0]

    def print_event_block(self) -> None:
        logging.debug("[ETdump] data tabular:")
        if logging.getLogger().level <= logging.DEBUG:
            self.et_inspector.print_data_tabular()

    def print_event_data(self) -> None:
        logging.debug("[ETdump] event data ")
        for event_block in self.et_inspector.event_blocks:
            for event in event_block.events:
                logging.debug(event)

    def dump_intermediate_tensors(self) -> None:
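        """Write each operator's intermediate (debug) tensors to per-instruction
        text files under `self.tensor_dump_dir`. Requires an ETRecord.
        """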
        if self.etrecord_path is None:
            logging.info("[ETdump] Intermediate tensors not available")
            return

        logging.info(f"[ETdump] Dumping intermediate tensors to {self.tensor_dump_dir}")
        os.makedirs(self.tensor_dump_dir, exist_ok=True)
        exec_blocks = [
            eb for eb in self.et_inspector.event_blocks if eb.name == "Execute"
        ]
        if len(exec_blocks) > 1:
            logging.warning(
                f'Found {len(exec_blocks)} "Execute" blocks, using the first one and ignoring the rest.'
            )
        block = exec_blocks[0]

        # OPERATOR_CALL events are duplicates that contain framework tax data. We don't need them
        op_events = [e for e in block.events if e.name != "OPERATOR_CALL"]
        torch.set_printoptions(profile="full")

        for event in op_events:
            instr_id = event._instruction_id
            if not event.debug_data:
                logging.debug(
                    f"Missing intermediate tensor data for {event.name} ({instr_id=})"
                )
                continue

            with open(f"{self.tensor_dump_dir}/{instr_id}.txt", "w") as f:
                for dd in event.debug_data:
                    f.write(f"{str(dd)}\n\n")
        torch.set_printoptions(profile="default")


def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[str]:
    """
    Get the set of operator names ("name.overload") used by a Program,
    including those inside Cadence delegate payloads.
    """

    op_names = {
        f"{op.name}.{op.overload}"
        for op in program.execution_plan[execution_plan_id].operators
    }
    for delegate in program.execution_plan[execution_plan_id].delegates:
        logging.debug(f"Delegate: {delegate.id}")
        if delegate.id == "CadenceExecutorchBackend":
            assert delegate.processed.location == DataLocation.INLINE
            op_names |= get_op_names(
                deserialize_pte_binary(
                    program.backend_delegate_data[delegate.processed.index].data
                )
            )
    return op_names


# Run an ExecutorchProgram using the specified inputs and backend
def run(
    executorch_prog: ExecutorchProgramManager,
    inputs: Any,
    ref_outputs: Optional[Sequence[torch.Tensor]] = None,
    working_dir: Optional[str] = None,
) -> Any:
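    """Run an ExecutorchProgram on the Cadence backend and return its outputs.

    The program is executed via the `Executor` in `working_dir` (a temporary
    directory is created if none is given); outputs are read back from the
    resulting ETDump and unflattened according to the program's output spec.
    """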
    # Get the Program
    program = executorch_prog.executorch_program
    out_spec = executorch_prog.exported_program().call_spec.out_spec
    # Run the program and return the outputs
    assert isinstance(
        program, et_schema.Program
    ), f"program must be Program. Got {type(program)} instead."

    if working_dir is None:
        working_dir = tempfile.mkdtemp(dir="/tmp")

    # initialize e2e Executor with executorch_cfg.
    executor = Executor(working_dir)

    # run Executor
    executor()

    etdump = CadenceETDump(output_dir=working_dir)
    outputs = etdump.get_outputs()

    assert isinstance(out_spec, TreeSpec)
    outputs = torch.utils._pytree.tree_unflatten(outputs, out_spec)

    return outputs


def compare(
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    outputs: Any,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    ref_outputs: Any,
    name: str = "",
    eps_error: float = 1e-1,
    eps_warn: float = 1e-5,
) -> None:
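    """Recursively compare `outputs` against `ref_outputs` (tensors, or nested
    lists/tuples/dicts of tensors), logging per-output error statistics.
    Raises a RuntimeError when the RMS error exceeds `eps_error` (or is NaN);
    logs a warning when it exceeds `eps_warn`.
    """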
    if isinstance(ref_outputs, dict):
        for k, v in outputs.items():
            compare(v, ref_outputs[k], f"{name}/{k}", eps_error, eps_warn)
        return

    if isinstance(ref_outputs, (list, tuple)):
        for i, (output, ref_output) in enumerate(zip(outputs, ref_outputs)):
            compare(output, ref_output, f"{name}/{i}", eps_error, eps_warn)
        return

    assert isinstance(ref_outputs, torch.Tensor), f"Got {type(ref_outputs)} instead."

    ref_outputs = to_nd_array(ref_outputs)
    outputs = to_nd_array(outputs)

    # compare
    rms = utils.rms(outputs, ref_outputs)
    norm_rms = utils.normalized_rms(outputs, ref_outputs)
    max_abs_diff = utils.max_abs_diff(outputs, ref_outputs)
    max_rel_diff = utils.max_rel_diff(outputs, ref_outputs)
    stats = (
        f"{rms = }, {norm_rms = }, {max_abs_diff = }, {max_rel_diff = :.2f}%, "
        f"{outputs.shape = }[{outputs.dtype}], {ref_outputs.shape = }[{ref_outputs.dtype}]"
    )

    if np.isnan(rms) or rms > eps_error:
        logging.error(f"\33[31m[Error]\33[0m Output {name} mismatched! {stats}")
        logging.error(f"Expected: {ref_outputs}\n")
        logging.error(f"Got instead: {outputs}\n")
        raise RuntimeError(f"\33[31m[Error]\33[0m Output {name} mismatched! {stats}")
    elif rms > eps_warn:
        logging.warning(f"\33[33m[Warning]\33[0m Output {name} mismatched! {stats}")
    else:
        logging.info(f"\33[32m[Passed]\33[0m Output {name} matched. {stats}")


def run_and_compare(
    executorch_prog: ExecutorchProgramManager,
    inputs: Any,
    ref_outputs: Optional[Sequence[torch.Tensor]] = None,
    working_dir: Optional[str] = None,
    eps_error: float = 1e-1,
    eps_warn: float = 1e-5,
) -> Any:
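    """Run `executorch_prog` via `run()` and compare its outputs against
    `ref_outputs` using `compare()`.

    Example (a sketch; assumes `prog` is an ExecutorchProgramManager and
    `expected` holds reference tensors for the model's outputs):

        run_and_compare(prog, inputs, ref_outputs=expected, eps_error=1e-3)
    """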
    outputs = run(executorch_prog, inputs, ref_outputs, working_dir)
    compare(outputs, ref_outputs, eps_error=eps_error, eps_warn=eps_warn)


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def to_nd_array(v: Union[bool, numbers.Number, ndarray, torch.Tensor]) -> np.ndarray:
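    """Convert a bool, number, numpy array, or torch.Tensor to a numpy array.

    Quantized tensors are converted via their integer representation.
    """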
    if isinstance(v, np.ndarray):
        return v

    if isinstance(v, torch.Tensor):
        # If v was quantized, we compare its int representation.
        v = v.int_repr() if v.is_quantized else v
        return v.cpu().detach().numpy()

    if isinstance(v, (numbers.Number, bool)):
        return np.array([v])

    raise RuntimeError(f"Unknown type {type(v)}")