xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import csv
2from abc import ABC, abstractmethod
3
4from fbscribelogger import make_scribe_logger
5
6import torch._C._instruction_counter as i_counter
7import torch._dynamo.config as config
8from torch._dynamo.utils import CompileTimeInstructionCounter
9
10
# Module-level logging callable used by BenchmarkBase.collect_all to report
# benchmark results. It is built by fbscribelogger.make_scribe_logger from a
# name ("TorchBenchmarkCompileTime") and the struct definition below, which
# lists the fields of one log entry.
#
# In this file the logger is only ever called with `name=` and
# `instruction_count=`.
# NOTE(review): the remaining fields (commit_sha, time, github_*, weight, ...)
# are presumably populated by fbscribelogger itself from the CI environment --
# confirm against the fbscribelogger documentation.
scribe_log_torch_benchmark_compile_time = make_scribe_logger(
    "TorchBenchmarkCompileTime",
    """
struct TorchBenchmarkCompileTimeLogEntry {

  # The commit SHA that triggered the workflow, e.g., 02a6b1d30f338206a71d0b75bfa09d85fac0028a. Derived from GITHUB_SHA.
  4: optional string commit_sha;

  # The unit timestamp in second for the Scuba Time Column override
  6: optional i64 time;
  7: optional i64 instruction_count; # Instruction count of compilation step
  8: optional string name; # Benchmark name

  # Commit date (not author date) of the commit in commit_sha as timestamp, e.g., 1724208105.  Increasing if merge bot is used, though not monotonic; duplicates occur when stack is landed.
  16: optional i64 commit_date;

  # A unique number for each workflow run within a repository, e.g., 19471190684. Derived from GITHUB_RUN_ID.
  17: optional string github_run_id;

  # A unique number for each attempt of a particular workflow run in a repository, e.g., 1. Derived from GITHUB_RUN_ATTEMPT.
  18: optional string github_run_attempt;

  # Indicates if branch protections or rulesets are configured for the ref that triggered the workflow run. Derived from GITHUB_REF_PROTECTED.
  20: optional bool github_ref_protected;

  # The fully-formed ref of the branch or tag that triggered the workflow run, e.g., refs/pull/133891/merge or refs/heads/main. Derived from GITHUB_REF.
  21: optional string github_ref;

  # The weight of the record according to current sampling rate
  25: optional i64 weight;

  # The name of the current job. Derived from JOB_NAME, e.g., linux-jammy-py3.8-gcc11 / test (default, 3, 4, amz2023.linux.2xlarge).
  26: optional string github_job;

  # The GitHub user who triggered the job.  Derived from GITHUB_TRIGGERING_ACTOR.
  27: optional string github_triggering_actor;

  # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER.
  28: optional string github_run_number_str;
}
""",  # noqa: B950
)
53
54
55class BenchmarkBase(ABC):
56    # measure total number of instruction spent in _work.
57    _enable_instruction_count = False
58
59    # measure total number of instruction spent in convert_frame.compile_inner
60    # TODO is there other parts we need to add ?
61    _enable_compile_time_instruction_count = False
62
63    def enable_instruction_count(self):
64        self._enable_instruction_count = True
65        return self
66
67    def enable_compile_time_instruction_count(self):
68        self._enable_compile_time_instruction_count = True
69        return self
70
71    def name(self):
72        return ""
73
74    def description(self):
75        return ""
76
77    @abstractmethod
78    def _prepare(self):
79        pass
80
81    @abstractmethod
82    def _work(self):
83        pass
84
85    def _prepare_once(self):  # noqa: B027
86        pass
87
88    def _count_instructions(self):
89        print(f"collecting instruction count for {self.name()}")
90        results = []
91        for i in range(10):
92            self._prepare()
93            id = i_counter.start()
94            self._work()
95            count = i_counter.end(id)
96            print(f"instruction count for iteration {i} is {count}")
97            results.append(count)
98        return min(results)
99
100    def _count_compile_time_instructions(self):
101        print(f"collecting compile time instruction count for {self.name()}")
102        config.record_compile_time_instruction_count = True
103
104        results = []
105        for i in range(10):
106            self._prepare()
107            # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner
108            # hence this will only count instruction count spent in compile_inner.
109            CompileTimeInstructionCounter.clear()
110            self._work()
111            count = CompileTimeInstructionCounter.value()
112            if count == 0:
113                raise RuntimeError(
114                    "compile time instruction count is 0, please check your benchmarks"
115                )
116            print(f"compile time instruction count for iteration {i} is {count}")
117            results.append(count)
118
119        config.record_compile_time_instruction_count = False
120        return min(results)
121
122    def append_results(self, path):
123        with open(path, "a", newline="") as csvfile:
124            # Create a writer object
125            writer = csv.writer(csvfile)
126            # Write the data to the CSV file
127            for entry in self.results:
128                writer.writerow(entry)
129
130    def print(self):
131        for entry in self.results:
132            print(f"{entry[0]},{entry[1]},{entry[2]}")
133
134    def collect_all(self):
135        self._prepare_once()
136        self.results = []
137        if (
138            self._enable_instruction_count
139            and self._enable_compile_time_instruction_count
140        ):
141            raise RuntimeError(
142                "not supported until we update the logger, both logs to the same field now"
143            )
144
145        if self._enable_instruction_count:
146            r = self._count_instructions()
147            self.results.append((self.name(), "instruction_count", r))
148            scribe_log_torch_benchmark_compile_time(
149                name=self.name(),
150                instruction_count=r,
151            )
152        if self._enable_compile_time_instruction_count:
153            r = self._count_compile_time_instructions()
154
155            self.results.append(
156                (
157                    self.name(),
158                    "compile_time_instruction_count",
159                    r,
160                )
161            )
162            # TODO add a new field compile_time_instruction_count to the logger.
163            scribe_log_torch_benchmark_compile_time(
164                name=self.name(),
165                instruction_count=r,
166            )
167        return self
168