xref: /aosp_15_r20/external/pytorch/tools/test/test_upload_stats_lib.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import decimal
4import inspect
5import sys
6import unittest
7from pathlib import Path
8from typing import Any
9from unittest import mock
10
11
# Temporarily put the repository root (three directories above this file) at
# the front of sys.path so the `tools.stats` package below can be imported.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(REPO_ROOT))

from tools.stats.upload_metrics import add_global_metric, emit_metric
from tools.stats.upload_stats_lib import BATCH_SIZE, upload_to_rockset


# Undo the sys.path change once the imports above are resolved.
# NOTE(review): if either import raises, this line never runs and the repo
# root stays on sys.path — consider a try/finally if that matters.
sys.path.remove(str(REPO_ROOT))
20
# Default values that setUp injects into the environment and that the tests
# then expect to see reflected in the emitted metrics.

# Repository / workflow / job identity
REPO = "some/repo"
WORKFLOW = "some-workflow"
JOB = "some-job"
JOB_ID = 234
JOB_NAME = "some-job-name"

# Build configuration
BUILD_ENV = "cuda-10.2"
TEST_CONFIG = "test-config"

# Numeric identifiers of the workflow run
RUN_ID = 56
RUN_NUMBER = 123
RUN_ATTEMPT = 3

# Only placed in the environment by tests that exercise the optional PR_NUMBER
PR_NUMBER = 6789
33
34
class TestUploadStats(unittest.TestCase):
    """Tests for ``emit_metric`` / ``add_global_metric`` and ``upload_to_rockset``.

    Every test runs against a controlled ``os.environ`` installed in ``setUp``.
    DynamoDB access is faked by patching ``boto3.Session.resource`` and
    capturing whatever ``Table(...).put_item`` receives.
    """

    def setUp(self) -> None:
        super().setUp()
        # Before each test, set the env vars to their default values.
        # clear=True drops every pre-existing env var so results do not depend
        # on the environment the tests happen to run in.
        self._patch_environ(
            {
                "CI": "true",
                "BUILD_ENVIRONMENT": BUILD_ENV,
                "TEST_CONFIG": TEST_CONFIG,
                "GITHUB_REPOSITORY": REPO,
                "GITHUB_WORKFLOW": WORKFLOW,
                "GITHUB_JOB": JOB,
                "GITHUB_RUN_ID": str(RUN_ID),
                "GITHUB_RUN_NUMBER": str(RUN_NUMBER),
                "GITHUB_RUN_ATTEMPT": str(RUN_ATTEMPT),
                "JOB_ID": str(JOB_ID),
                "JOB_NAME": str(JOB_NAME),
            },
            clear=True,
        )

    def _patch_environ(self, values: dict[str, str], clear: bool = False) -> None:
        """Overlay ``values`` onto ``os.environ`` for the rest of the current test.

        Every started patcher must be stopped, otherwise ``os.environ`` stays
        modified (and, with ``clear=True``, emptied) for all later tests in the
        same process. The stop is registered with ``addCleanup``; cleanups run
        in LIFO order, so patches applied inside a test are undone before the
        ``setUp`` patch and the real environment is restored last.

        Args:
            values: Environment variables to set while the patch is active.
            clear: If True, hide every other environment variable as well.
        """
        patcher = mock.patch.dict("os.environ", values, clear=clear)
        patcher.start()
        self.addCleanup(patcher.stop)

    @staticmethod
    def _capture_put_item(mock_resource: Any) -> list[dict[str, Any]]:
        """Route ``put_item`` calls on the mocked DynamoDB table into a list.

        Args:
            mock_resource: The mock that replaced ``boto3.Session.resource``.

        Returns:
            A list that receives the ``Item`` of every ``put_item`` call, in
            call order.
        """
        captured: list[dict[str, Any]] = []

        def mock_put_item(Item: dict[str, Any]) -> None:
            captured.append(Item)

        mock_resource.return_value.Table.return_value.put_item = mock_put_item
        return captured

    def _assert_metric_emitted(
        self,
        captured: list[dict[str, Any]],
        expected: dict[str, Any],
        msg: str | None = None,
    ) -> dict[str, Any]:
        """Assert a metric was emitted and that it contains every item of ``expected``.

        Each key of ``expected`` must be present in the emitted metric with an
        equal value; additional emitted keys are allowed. (Comparing against
        ``{**expected, **emitted}`` would let the emitted values win the merge
        and therefore only check that the keys exist.)

        Args:
            captured: List filled by ``_capture_put_item``.
            expected: Items the emitted metric must contain.
            msg: Optional failure message.

        Returns:
            The most recently emitted metric, for further assertions.
        """
        self.assertTrue(captured, msg or "No metric was emitted")
        emitted = captured[-1]
        self.assertEqual(
            {key: emitted[key] for key in expected if key in emitted},
            expected,
            msg,
        )
        return emitted

    @mock.patch("boto3.Session.resource")
    def test_emits_default_and_given_metrics(self, mock_resource: Any) -> None:
        metric = {
            "some_number": 123,
            "float_number": 32.34,
        }

        # Querying for this instead of hard coding it b/c this will change
        # based on whether we run this test directly from python or from
        # pytest
        current_module = inspect.getmodule(inspect.currentframe()).__name__  # type: ignore[union-attr]

        emit_should_include = {
            "metric_name": "metric_name",
            "calling_file": "test_upload_stats_lib.py",
            "calling_module": current_module,
            "calling_function": "test_emits_default_and_given_metrics",
            "repo": REPO,
            "workflow": WORKFLOW,
            "build_environment": BUILD_ENV,
            "job": JOB,
            "test_config": TEST_CONFIG,
            "run_id": RUN_ID,
            "run_number": RUN_NUMBER,
            "run_attempt": RUN_ATTEMPT,
            "some_number": 123,
            # Floats are expected to come out as Decimal values.
            "float_number": decimal.Decimal(str(32.34)),
            "job_id": JOB_ID,
            "job_name": JOB_NAME,
        }

        captured = self._capture_put_item(mock_resource)

        # The expected calling_* values describe this test method, so
        # emit_metric must be invoked from here rather than from a helper.
        emit_metric("metric_name", metric)

        self._assert_metric_emitted(captured, emit_should_include)

    @mock.patch("boto3.Session.resource")
    def test_when_global_metric_specified_then_it_emits_it(
        self, mock_resource: Any
    ) -> None:
        metric = {
            "some_number": 123,
        }

        global_metric_name = "global_metric"
        global_metric_value = "global_value"

        # NOTE(review): nothing visible here unregisters this global metric, so
        # it presumably stays registered for tests that run afterwards —
        # confirm whether tools.stats.upload_metrics offers a reset hook.
        add_global_metric(global_metric_name, global_metric_value)

        emit_should_include = {
            **metric,
            global_metric_name: global_metric_value,
        }

        captured = self._capture_put_item(mock_resource)

        emit_metric("metric_name", metric)

        self._assert_metric_emitted(captured, emit_should_include)

    @mock.patch("boto3.Session.resource")
    def test_when_local_and_global_metric_specified_then_global_is_overridden(
        self, mock_resource: Any
    ) -> None:
        global_metric_name = "global_metric"
        global_metric_value = "global_value"
        local_override = "local_override"

        add_global_metric(global_metric_name, global_metric_value)

        metric = {
            "some_number": 123,
            global_metric_name: local_override,
        }

        # The locally supplied value must win over the global one.
        emit_should_include = {
            **metric,
            global_metric_name: local_override,
        }

        captured = self._capture_put_item(mock_resource)

        emit_metric("metric_name", metric)

        self._assert_metric_emitted(captured, emit_should_include)

    @mock.patch("boto3.Session.resource")
    def test_when_optional_envvar_set_to_actual_value_then_emit_vars_emits_it(
        self, mock_resource: Any
    ) -> None:
        metric = {
            "some_number": 123,
        }

        emit_should_include = {
            **metric,
            "pr_number": PR_NUMBER,
        }

        self._patch_environ({"PR_NUMBER": str(PR_NUMBER)})

        captured = self._capture_put_item(mock_resource)

        emit_metric("metric_name", metric)

        self._assert_metric_emitted(captured, emit_should_include)

    @mock.patch("boto3.Session.resource")
    def test_when_optional_envvar_set_to_a_empty_str_then_emit_vars_ignores_it(
        self, mock_resource: Any
    ) -> None:
        metric = {"some_number": 123}

        emit_should_include: dict[str, Any] = metric.copy()

        # Github Actions defaults some env vars to an empty string
        default_val = ""
        self._patch_environ({"PR_NUMBER": default_val})

        captured = self._capture_put_item(mock_resource)

        emit_metric("metric_name", metric)

        emitted_metric = self._assert_metric_emitted(
            captured,
            emit_should_include,
            f"Metrics should be emitted when an option parameter is set to '{default_val}'",
        )
        self.assertFalse(
            emitted_metric.get("pr_number"),
            f"Metrics should not include optional item 'pr_number' when it's envvar is set to '{default_val}'",
        )

    @mock.patch("boto3.Session.resource")
    def test_blocks_emission_if_reserved_keyword_used(self, mock_resource: Any) -> None:
        metric = {"repo": "awesome/repo"}

        with self.assertRaises(ValueError):
            emit_metric("metric_name", metric)

    @mock.patch("boto3.Session.resource")
    def test_no_metrics_emitted_if_required_env_var_not_set(
        self, mock_resource: Any
    ) -> None:
        metric = {"some_number": 123}

        # Replace the whole environment, leaving out the GitHub run variables
        # that setUp normally provides.
        self._patch_environ(
            {
                "CI": "true",
                "BUILD_ENVIRONMENT": BUILD_ENV,
            },
            clear=True,
        )

        captured = self._capture_put_item(mock_resource)

        emit_metric("metric_name", metric)

        self.assertFalse(captured, "put_item should not have been called")

    @mock.patch("boto3.Session.resource")
    def test_no_metrics_emitted_if_required_env_var_set_to_empty_string(
        self, mock_resource: Any
    ) -> None:
        metric = {"some_number": 123}

        self._patch_environ({"GITHUB_JOB": ""})

        captured = self._capture_put_item(mock_resource)

        emit_metric("metric_name", metric)

        self.assertFalse(captured, "put_item should not have been called")

    def test_upload_to_rockset_batch_size(self) -> None:
        # Document counts around the BATCH_SIZE boundary and the number of
        # add_documents requests each one is expected to need.
        cases = [
            (BATCH_SIZE - 1, 1),
            (BATCH_SIZE, 1),
            (BATCH_SIZE + 1, 2),
        ]

        for num_docs, expected_number_of_requests in cases:
            with self.subTest(num_docs=num_docs):
                mock_client = mock.Mock()
                mock_client.Documents.add_documents.return_value = "OK"

                docs = list(range(num_docs))
                upload_to_rockset(
                    collection="test",
                    docs=docs,
                    workspace="commons",
                    client=mock_client,
                )
                self.assertEqual(
                    mock_client.Documents.add_documents.call_count,
                    expected_number_of_requests,
                )
337
338
339if __name__ == "__main__":
340    unittest.main()
341