xref: /aosp_15_r20/external/pytorch/test/cpp_api_parity/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2import shutil
3import unittest
4import warnings
5from collections import namedtuple
6
7import torch
8import torch.testing._internal.common_nn as common_nn
9import torch.utils.cpp_extension
10from torch.testing._internal.common_cuda import TEST_CUDA
11
12
13# Note that this namedtuple is for C++ parity test mechanism's internal use.
14# For guidance on how to add a new C++ parity test, please see
15# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
TorchNNModuleTestParams = namedtuple(
    "TorchNNModuleTestParams",
    (
        # Name of the NN module under test, e.g. "BCELoss".
        "module_name",
        # Identifier unique to this particular module configuration,
        # e.g. "BCELoss_weights_cuda".
        "module_variant_name",
        # NN test class instance (e.g. a `CriterionTest`) carrying what the
        # Python side of the test needs: input / target / extra_args.
        "test_instance",
        # Arguments handed to the C++ module constructor. They must be strictly
        # equivalent to the Python constructor arguments; for example
        # `torch::nn::BCELossOptions().weight(torch::rand(10))` mirrors passing
        # `torch.rand(10)` to the `torch.nn.BCELoss` constructor in Python.
        "cpp_constructor_args",
        # Every argument used by the module's forward pass. See the
        # `compute_arg_dict` function for how this dict is assembled, e.g.
        # ```
        # arg_dict = {
        #     'input': [python_input_tensor],
        #     'target': [python_target_tensor],
        #     'extra_args': [],
        #     'other': [],
        # }
        # ```
        "arg_dict",
        # True when this module test is expected to pass the Python/C++
        # parity check (e.g. `True`).
        "has_parity",
        # Device the test runs on, e.g. "cuda".
        "device",
        # Temporary folder where C++ outputs are stored so that they can be
        # compared with the Python outputs later.
        "cpp_tmp_folder",
    ),
)
54
55# Note that this namedtuple is for C++ parity test mechanism's internal use.
56# For guidance on how to add a new C++ parity test, please see
57# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
TorchNNFunctionalTestParams = namedtuple(
    "TorchNNFunctionalTestParams",
    (
        # Name of the NN functional under test, e.g. "binary_cross_entropy".
        "functional_name",
        # Identifier unique to this particular functional configuration,
        # e.g. "BCELoss_no_reduce_cuda".
        "functional_variant_name",
        # NN test class instance (e.g. a `NewModuleTest`) carrying what the
        # Python side of the test needs: input / target / extra_args.
        "test_instance",
        # C++ function call that is strictly equivalent to the Python call, e.g.
        # "F::binary_cross_entropy(
        #      i, t.to(i.options()),F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))"
        # mirrors `F.binary_cross_entropy(i, t.type_as(i), reduction='none')` in Python.
        "cpp_function_call",
        # Every argument used by the functional's call. See the
        # `compute_arg_dict` function for how this dict is assembled, e.g.
        # ```
        # arg_dict = {
        #     'input': [python_input_tensor],
        #     'target': [python_target_tensor],
        #     'extra_args': [],
        #     'other': [],
        # }
        # ```
        "arg_dict",
        # True when this functional test is expected to pass the Python/C++
        # parity check (e.g. `True`).
        "has_parity",
        # Device the test runs on, e.g. "cuda".
        "device",
        # Temporary folder where C++ outputs are stored so that they can be
        # compared with the Python outputs later.
        "cpp_tmp_folder",
    ),
)
95
# A (name, value) pair: `name` is the C++ variable name the value is bound to,
# `value` is the Python-side value stored under that name.
CppArg = namedtuple("CppArg", ("name", "value"))
97
# C++ source snippet with helper functions for the parity tests
# (presumably prepended to the generated C++ test sources -- confirm at the
# call sites, which are outside this file). It defines:
#   - `write_ivalue_to_file`: pickles an IValue with `torch::jit::pickle_save`
#     and writes the bytes to `file_path` in binary mode.
#   - `load_dict_from_file`: loads a TorchScript module from `file_path` and
#     collects its top-level (non-recursive) named buffers into a
#     name -> tensor dict; this is the reader for what
#     `serialize_arg_dict_as_script_module` produces.
#   - `_rand_tensor_non_equal`: a double tensor of the given size filled with a
#     random permutation, so that no two elements are equal.
# NOTE(review): `std::ofstream` is used without an explicit `#include <fstream>`;
# presumably it is pulled in transitively via `torch/script.h` -- confirm.
TORCH_NN_COMMON_TEST_HARNESS = """
#include <torch/script.h>

void write_ivalue_to_file(const torch::IValue& ivalue, const std::string& file_path) {
    auto bytes = torch::jit::pickle_save(ivalue);
    std::ofstream fout(file_path, std::ios::out | std::ios::binary);
    fout.write(bytes.data(), bytes.size());
    fout.close();
}

c10::Dict<std::string, torch::Tensor> load_dict_from_file(const std::string& file_path) {
    c10::Dict<std::string, torch::Tensor> arg_dict;
    auto arg_dict_module = torch::jit::load(file_path);
    for (const auto& p : arg_dict_module.named_buffers(/*recurse=*/false)) {
        arg_dict.insert(p.name, p.value);
    }
    return arg_dict;
}

// Generates rand tensor with non-equal values. This ensures that duplicate
// values won't be causing test failure for modules like MaxPooling.
// size should be small, otherwise randperm fails / long overflows.
torch::Tensor _rand_tensor_non_equal(torch::IntArrayRef size) {
    int64_t total = 1;
    for (int64_t elem : size) {
        total *= elem;
    }
    return torch::randperm(total).view(size).to(torch::kDouble);
}
"""
128
129
def compile_cpp_code_inline(name, cpp_sources, functions):
    """Build `cpp_sources` as an inline C++ extension and return the loaded module.

    Args:
        name: Name given to the extension (forwarded to `load_inline`).
        cpp_sources: C++ source code to compile (forwarded to `load_inline`).
        functions: Functions to expose from the extension (forwarded to
            `load_inline`).

    Returns:
        Whatever `torch.utils.cpp_extension.load_inline` returns for the build.
    """
    # Debug symbols are enabled by default to make test failures easier to debug.
    debug_cflags = ["-g"]
    return torch.utils.cpp_extension.load_inline(
        name=name,
        cpp_sources=cpp_sources,
        functions=functions,
        extra_cflags=debug_cflags,
        verbose=False,
    )
141
142
def compute_temp_file_path(cpp_tmp_folder, variant_name, file_suffix):
    """Return the path `<cpp_tmp_folder>/<variant_name>_<file_suffix>.pt`."""
    file_name = f"{variant_name}_{file_suffix}.pt"
    return os.path.join(cpp_tmp_folder, file_name)
145
146
def is_torch_nn_functional_test(test_params_dict):
    """Tell whether a test params dict describes a functional test.

    The check is textual: the string form of the dict's "constructor" entry
    (an empty string when the entry is absent) must contain "wrap_functional".
    """
    constructor_repr = str(test_params_dict.get("constructor", ""))
    return "wrap_functional" in constructor_repr
149
150
def convert_to_list(python_input):
    """Normalize `python_input` to a list.

    A single tensor is wrapped in a one-element list; any other value is
    treated as an iterable and materialized with `list(...)`.
    """
    if not isinstance(python_input, torch.Tensor):
        return list(python_input)
    return [python_input]
156
157
def set_python_tensors_requires_grad(python_tensors):
    """Enable gradient tracking on every tensor that can support it.

    Only floating-point and complex tensors can require gradients, so tensors
    of any other dtype (e.g. `torch.long`, `torch.int32`, `torch.bool`) are
    passed through untouched instead of raising inside `requires_grad_`.

    Args:
        python_tensors: Iterable of `torch.Tensor`.

    Returns:
        A list with the same tensors, in the same order. Eligible tensors have
        been modified in place via `requires_grad_(True)`.
    """
    return [
        tensor.requires_grad_(True)
        if tensor.is_floating_point() or tensor.is_complex()
        else tensor
        for tensor in python_tensors
    ]
163
164
def move_python_tensors_to_device(python_tensors, device):
    """Return a list holding `t.to(device)` for each tensor `t`, in order."""
    return [python_tensor.to(device) for python_tensor in python_tensors]
167
168
def has_test(unit_test_class, test_name):
    """Return True when `unit_test_class` already has an attribute `test_name`."""
    already_defined = hasattr(unit_test_class, test_name)
    return already_defined
171
172
def add_test(unit_test_class, test_name, test_fn):
    """Attach `test_fn` to `unit_test_class` under the name `test_name`.

    Raises:
        RuntimeError: If the class already has an attribute with that name.
    """
    if not has_test(unit_test_class, test_name):
        setattr(unit_test_class, test_name, test_fn)
        return
    raise RuntimeError("Found two tests with the same name: " + test_name)
177
178
def set_cpp_tensors_requires_grad(cpp_tensor_stmts, python_tensors):
    """Append `.requires_grad_(true)` to the C++ statements of differentiable tensors.

    Only floating-point and complex tensors can require gradients, so the
    suffix is emitted just for those; statements for tensors of any other
    dtype (e.g. `torch.long`, `torch.int32`, `torch.bool`) are left unchanged.

    Args:
        cpp_tensor_stmts: C++ statement strings, one per tensor.
        python_tensors: `(name, tensor)` pairs aligned one-to-one with
            `cpp_tensor_stmts`; only the tensor's dtype is inspected.

    Returns:
        A new list of statement strings in the same order.
    """
    assert len(cpp_tensor_stmts) == len(python_tensors)
    return [
        f"{tensor_stmt}.requires_grad_(true)"
        if tensor.is_floating_point() or tensor.is_complex()
        else tensor_stmt
        for tensor_stmt, (_, tensor) in zip(cpp_tensor_stmts, python_tensors)
    ]
187
188
def move_cpp_tensors_to_device(cpp_tensor_stmts, device):
    """Append a C++ `.to("<device>")` call to every statement string."""
    device_suffix = f'.to("{device}")'
    return [f"{stmt}{device_suffix}" for stmt in cpp_tensor_stmts]
191
192
def is_criterion_test(test_instance):
    """Return True when `test_instance` is a `common_nn.CriterionTest`."""
    criterion_cls = common_nn.CriterionTest
    return isinstance(test_instance, criterion_cls)
195
196
197# This function computes the following:
198# - What variable declaration statements should show up in the C++ parity test function
199# - What arguments should be passed into the C++ module/functional's forward function
200#
201# For example, for the "L1Loss" test, the return values from this function are:
202# ```
203# // Note that `arg_dict` stores all tensor values we transfer from Python to C++
204# cpp_args_construction_stmts = [
205#   "auto i0 = arg_dict.at("i0").to("cpu").requires_grad_(true)",
206#   "auto t0 = arg_dict.at("t0").to("cpu")",
207# ],
208# cpp_forward_args_symbols = [
209#   "i0",
210#   "t0",
211# ]
212# ```
def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params):
    """Build the C++ variable declarations and forward-argument names for a test.

    Args:
        test_params: Object exposing `device` and an `arg_dict` with the keys
            "input", "target", "extra_args" and "other", each holding
            `(name, value)` pairs.

    Returns:
        A tuple `(cpp_args_construction_stmts, cpp_forward_args_symbols)`:
        the first lists one declaration statement per argument (inputs,
        targets, extra args, then others), the second lists the names of the
        input, target and extra arguments in that order. "other" arguments are
        declared but are not part of the forward-argument names.
    """
    device = test_params.device
    arg_dict = test_params.arg_dict

    def declare_on_device(args):
        # One `auto <name> = arg_dict.at("<name>")` per argument, each followed
        # by a move to the test's device.
        declarations = [f'auto {name} = arg_dict.at("{name}")' for name, _ in args]
        return move_cpp_tensors_to_device(declarations, device)

    # Only the inputs get the requires-grad treatment.
    input_stmts = set_cpp_tensors_requires_grad(
        declare_on_device(arg_dict["input"]), arg_dict["input"]
    )
    target_stmts = declare_on_device(arg_dict["target"])
    extra_args_stmts = declare_on_device(arg_dict["extra_args"])
    other_stmts = declare_on_device(arg_dict["other"])

    forward_arg_symbols = [
        name
        for section in ("input", "target", "extra_args")
        for name, _ in arg_dict[section]
    ]
    construction_stmts = input_stmts + target_stmts + extra_args_stmts + other_stmts

    return construction_stmts, forward_arg_symbols
251
252
def serialize_arg_dict_as_script_module(arg_dict):
    """Pack every tensor of `arg_dict` into a scripted module as named buffers.

    Args:
        arg_dict: Dict with the keys "input", "target", "extra_args" and
            "other", each a list of `(name, tensor)` pairs.

    Returns:
        The result of `torch.jit.script` on a fresh `torch.nn.Module` that has
        one buffer per distinct argument name.
    """
    # Evaluate the concatenation first so a missing key fails before any
    # module is built; a later duplicate name overrides an earlier one.
    all_pairs = (
        arg_dict["input"]
        + arg_dict["target"]
        + arg_dict["extra_args"]
        + arg_dict["other"]
    )
    tensors_by_name = {}
    for name, tensor in all_pairs:
        tensors_by_name[name] = tensor

    holder = torch.nn.Module()
    for name, tensor in tensors_by_name.items():
        assert isinstance(tensor, torch.Tensor)
        holder.register_buffer(name, tensor)

    return torch.jit.script(holder)
266
267
268# NOTE: any argument symbol used in `cpp_constructor_args` / `cpp_options_args` / `cpp_function_call`
269# must have a mapping in `cpp_var_map`.
270#
271# The mapping can take one of the following formats:
272#
273# 1. `argument_name` -> Python value
274# 2. `argument_name` -> '_get_input()' (which means `argument_name` in C++ will be bound to `test_instance._get_input()`)
275#
276# For example:
277# ```
278# def bceloss_weights_no_reduce_test():
279#     t = torch.randn(15, 10).gt(0).double()
280#     weights = torch.rand(10)
281#     return dict(
282#         fullname='BCELoss_weights_no_reduce',
283#         constructor=wrap_functional(
284#             lambda i: F.binary_cross_entropy(i, t.type_as(i),
285#                                              weight=weights.type_as(i), reduction='none')),
286#         cpp_function_call='''F::binary_cross_entropy(
287#                              i, t.to(i.options()),
288#                              F::BinaryCrossEntropyFuncOptions()
289#                              .weight(weights.to(i.options()))
290#                              .reduction(torch::kNone))''',
291#         input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
292#         cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
293#         reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
294#     )
295# ```
def compute_arg_dict(test_params_dict, test_instance):
    """Collect all arguments of a test into a dict of `CppArg` lists.

    Args:
        test_params_dict: Test params dict; its optional "cpp_var_map" entry
            maps C++ argument names to a tensor or to the string
            "_get_input()".
        test_instance: Test instance providing `_get_input()`, `extra_args`
            and, for criterion tests, `_get_target()`.

    Returns:
        A dict with the keys "input", "target", "extra_args" and "other".
        Inputs, targets and extra args are named "i<N>", "t<N>" and "e<N>"
        by position; "other" holds the `cpp_var_map` entries under their
        own names.

    Raises:
        RuntimeError: If a `cpp_var_map` value is neither a tensor nor the
            string "_get_input()".
    """
    arg_dict = {"input": [], "target": [], "extra_args": [], "other": []}

    def record_positional(section, prefix, values):
        # Names are the section prefix followed by the positional index.
        arg_dict[section].extend(
            CppArg(name=prefix + str(index), value=value)
            for index, value in enumerate(values)
        )

    record_positional("input", "i", convert_to_list(test_instance._get_input()))
    if is_criterion_test(test_instance):
        record_positional("target", "t", convert_to_list(test_instance._get_target()))
    if test_instance.extra_args:
        record_positional("extra_args", "e", convert_to_list(test_instance.extra_args))

    other_args = arg_dict["other"]
    for arg_name, arg_value in test_params_dict.get("cpp_var_map", {}).items():
        if isinstance(arg_value, torch.Tensor):
            other_args.append(CppArg(name=arg_name, value=arg_value))
            continue
        if not isinstance(arg_value, str):
            raise RuntimeError(f"`{arg_name}` has unsupported value: {arg_value}")
        if arg_value != "_get_input()":
            raise RuntimeError(
                f"`{arg_name}` has unsupported string value: {arg_value}"
            )
        # "_get_input()" binds the name to the test instance's input.
        other_args.append(CppArg(name=arg_name, value=test_instance._get_input()))

    return arg_dict
335
336
def decorate_test_fn(test_fn, test_cuda, has_impl_parity, device):
    """Wrap `test_fn` with the skip / expected-failure decorators it needs.

    Args:
        test_fn: The test function to decorate.
        test_cuda: Whether this test takes part in CUDA testing.
        has_impl_parity: Whether the test is expected to pass the parity check.
        device: Device the test targets (e.g. "cuda").

    Returns:
        The decorated test function.
    """
    if device == "cuda":
        skip_rules = (
            (not TEST_CUDA, "CUDA unavailable"),
            (not test_cuda, "Excluded from CUDA tests"),
        )
        for should_skip, reason in skip_rules:
            test_fn = unittest.skipIf(should_skip, reason)(test_fn)

    # A test is marked as an expected failure when the `Implementation Parity`
    # entry in the parity table for this module is `No`, or when the
    # `has_parity` entry in the test params dict is `False`.
    if has_impl_parity:
        return test_fn
    return unittest.expectedFailure(test_fn)
349
350
# Guidance text for developers hitting a C++ API parity test failure.
# `generate_error_msg` appends it to every parity-failure message it builds.
MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE = """
What should I do when C++ API parity test is failing?

- If you are changing the implementation of an existing `torch.nn` module / `torch.nn.functional` function:
Answer: Ideally you should also change the C++ API implementation for that module / function
(you can start by searching for the module / function name in `torch/csrc/api/` folder).

- If you are adding a new test for an existing `torch.nn` module / `torch.nn.functional` function:
Answer: Ideally you should fix the C++ API implementation for that module / function
to exactly match the Python API implementation (you can start by searching for the module /
function name in `torch/csrc/api/` folder).

- If you are adding a test for a *new* `torch.nn` module / `torch.nn.functional` function:
Answer: Ideally you should add the corresponding C++ API implementation for that module / function,
and it should exactly match the Python API implementation. (We have done a large effort on this
which is tracked at https://github.com/pytorch/pytorch/issues/25883.)

However, if any of the above is proven to be too complicated, you can just add
`test_cpp_api_parity=False` to any failing test in `torch/testing/_internal/common_nn.py`,
and the C++ API parity test will be skipped accordingly. Note that you should
also file an issue when you do this.

For more details on how to add a C++ API parity test, please see:
NOTE [How to check NN module / functional API parity between Python and C++ frontends]
"""
376
377
def generate_error_msg(name, cpp_value, python_value):
    """Build the failure message for a Python/C++ value mismatch.

    The message names the mismatching item with both values, followed by the
    text of `MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE`.
    """
    mismatch = (
        f"Parity test failed: {name} in C++ has value: {cpp_value}, "
        f"which does not match the corresponding value in Python: {python_value}."
    )
    return f"{mismatch}\n{MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE}"
383
384
def try_remove_folder(folder_path):
    """Delete `folder_path` recursively on a best-effort basis.

    Nothing happens when the path does not exist. A failure during removal
    does not propagate; it is reported through `warnings.warn` instead.
    """
    if not os.path.exists(folder_path):
        return
    try:
        shutil.rmtree(folder_path)
    except Exception as error:
        # Removal must never block the process, so the error is only surfaced
        # as a warning.
        warnings.warn(
            "Non-blocking folder removal fails with the following error:\n"
            + str(error)
        )
394