# mypy: ignore-errors

# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized

# Testing utils
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
    freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401

# Standard library
from itertools import chain
from typing import List, Union
from torch._C import TensorType

import io

def check_output_types(self, func, ref_outputs, args, kwargs):
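    """Asserts that func.last_graph has a single output and that ref_outputs is
    an instance of that output's type. args and kwargs are accepted for
    signature compatibility with the call sites but are not inspected here."""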
    graph = getattr(func, 'last_graph', None)
    types = [o.type() for o in graph.outputs()]
    self.assertTrue(len(types) == 1)
    t = types[0]
    torch._C._jit_assert_is_instance(ref_outputs, t)

# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
    'pdist',
    'multilabel_margin_loss',
    'max_unpool3d',
    'multi_margin_loss',
    'binary_cross_entropy',
    'binary_cross_entropy_size_average',
    'ctc_loss',
    'grid_sample',
])

def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
                            allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
    """Verifies that a function performs identically to a reference implementation.

    Commonly, this is used to verify that a JIT implementation (func) matches
    the behavior of the eager implementation (reference_func). output_func is
    applied to the outputs of both before they are compared and differentiated.
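
    A minimal usage sketch (the test class, foo, and the tensors below are
    hypothetical and only illustrate the calling convention):

        class MyJitTest(JitCommonTestCase):
            def test_mul_add(self):
                def foo(x, y):
                    return x * y + y

                scripted = torch.jit.script(foo)
                x = torch.randn(3, dtype=torch.double, requires_grad=True)
                y = torch.randn(3, dtype=torch.double, requires_grad=True)
                # check_types=False: the type check reads func.last_graph,
                # which a plain scripted function may not provide.
                check_against_reference(self, scripted, foo, lambda outs: outs,
                                        (x, y), check_types=False)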
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
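        # Reduce the outputs to a single real scalar for autograd.grad: each
        # output is weighted by (i + 1) so gradients from different outputs stay
        # distinguishable, complex sums go through .abs() to yield a real value,
        # and outputs that are None or not of a floating/complex dtype are skipped.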
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_tensor(t, preserve_requires_grad):
        require_grad = preserve_requires_grad and t.requires_grad
        return t.detach().clone().requires_grad_(require_grad)

    def clone_inputs(preserve_requires_grad: bool):
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs

    # Returns the tensors in args that require grad, including tensors inside TensorList args
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))

        return recording_tensors

    # test no gradients case
    nograd_inputs = clone_inputs(preserve_requires_grad=False)
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
            return

        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)

        l2 = (allSum(grads) * l1)
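        # l2 mixes the first-order gradients back with l1, so differentiating it
        # exercises the double-backward path.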
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)

        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)

class JitCommonTestCase(TestCase):
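    """TestCase with helpers shared by the JIT test suites: export/import
    round-trips, RNG-stable execution, autodiff-node assertions, and symbolic
    shape-analysis checks."""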
    def createFunctionFromGraph(self, trace):
        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        return torch._C._create_function_from_graph("forward", graph)

    def assertExportImport(self, trace, inputs):
        m = self.createFunctionFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
        m_import = self.getExportImportCopy(m)
        a = self.runAndSaveRNG(m, inputs)
        b = self.runAndSaveRNG(m_import, inputs)
        self.assertEqual(a, b, "Results of original model and "
                               "exported/imported version of model differed")

    def runAndSaveRNG(self, func, inputs, kwargs=None):
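        """Runs func(*inputs, **kwargs) with the RNG state frozen, so that
        successive calls (e.g. the reference and the traced/scripted run) see
        identical random numbers."""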
        kwargs = kwargs if kwargs else {}
        with freeze_rng_state():
            results = func(*inputs, **kwargs)
        return results

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
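        """Round-trips m through torch.jit.save/load via an in-memory buffer
        (and, unless also_test_file is False, through a temporary file as well)
        and returns the re-imported copy."""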
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)

        if not also_test_file:
            return imported

        with TemporaryFileName() as fname:
            torch.jit.save(imported, fname)
            return torch.jit.load(fname, map_location=map_location)

    def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
                             fusion_nodes_not_found, non_fusible_nodes_being_fused,
                             fusion_nodes_found, nodes_in_diff_graph):
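        """Builds a human-readable failure message for assertAutodiffNode,
        describing which nodes were (or were not) placed in DifferentiableGraphs
        and FusionGroups contrary to expectations."""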
        err_msg = "\nFailure in testing nodes' autodifferentiation. "
        if should_autodiff_node:
            err_msg += "One or more nodes were expected to be autodiffed, " \
                "but were not found in specified fusible/nonfusible " \
                "DifferentiableGraph groups. \nSpecifically:"
            # The node is intended to appear in a differentiable graph but doesn't
            diff_nodes_missing = []
            # The node is intended to appear in a differentiable graph
            # outside of a fusion group but instead is in a fusion group
            diff_nodes_in_fusion = []
            # The node is intended to appear in a fusion group but doesn't
            fusion_nodes_missing = []
            # The node is intended to appear in a fusion group but instead
            # is just in an outer differentiable graph
            fusion_nodes_in_diff = []
            for node in nodes_not_in_diff_graph:
                if node in non_fusible_nodes_being_fused:
                    diff_nodes_in_fusion.append(node)
                else:
                    diff_nodes_missing.append(node)
            for node in fusion_nodes_not_found:
                if node in nodes_in_diff_graph:
                    fusion_nodes_in_diff.append(node)
                else:
                    fusion_nodes_missing.append(node)
            if len(diff_nodes_missing) > 0:
                err_msg += f"\n  {diff_nodes_missing} were not in one of the " \
                    "DifferentiableGraphs when they were expected to be. " \
                    "Did you intend for these nodes to be autodiffed? " \
                    "If not, remove them from the list of nonfusible nodes."
            if len(diff_nodes_in_fusion) > 0:
                err_msg += f"\n  {diff_nodes_in_fusion} were found in one of the FusionGroups " \
                    "when they were expected to be just in a DifferentiableGraph. If it was " \
                    "intended for these nodes to be in FusionGroups, reclassify these nodes as " \
                    "fusible nodes. If these nodes were not intended to be fused, your " \
                    "autodifferentiation logic might be wrong."
            if len(fusion_nodes_missing) > 0:
                err_msg += f"\n  {fusion_nodes_missing} were not in one of the FusionGroups " \
                    "of the DifferentiableGraphs when they were expected to be. " \
                    "They were also not found in an outer DifferentiableGraph. Did you " \
                    "intend for these nodes to be autodifferentiated? If not, you should " \
                    "remove these nodes from the test's fusible nodes. Otherwise your " \
                    "autodifferentiation logic might be wrong."
            if len(fusion_nodes_in_diff) > 0:
                err_msg += f"\n  {fusion_nodes_in_diff} were not in one of the FusionGroups " \
                    "of the DifferentiableGraphs when they were expected to be, " \
                    "instead they were found just in an outer DifferentiableGraph. " \
                    "Did you intend for these nodes to be fused? If not, you should " \
                    "move these nodes into the test's nonfusible nodes. Otherwise your " \
                    "autodifferentiation logic might be wrong."
        else:
            err_msg += "One or more nodes were not expected to be autodiffed " \
                "but were found in a DifferentiableGraph or in a FusionGroup " \
                "of a DifferentiableGraph. Did you intend for these nodes to be " \
                "autodiffed? If so, change this test to expect autodifferentiation. " \
                "\nSpecifically:"
            if len(fusion_nodes_found) > 0:
                err_msg += f"\n  {fusion_nodes_found} were not expected to be in " \
                    "one of the DifferentiableGraphs, but appeared in a FusionGroup " \
                    "of a DifferentiableGraph. "
            if len(nodes_in_diff_graph) > 0:
                err_msg += f"\n  {nodes_in_diff_graph} were not expected to " \
                    "be in one of the DifferentiableGraphs but were."
        return err_msg

    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
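        """Asserts that the graph was partitioned for autodifferentiation as
        expected: when should_autodiff_node is True, every nonfusible node must
        appear in a prim::DifferentiableGraph and every fusible node in one of
        their prim::FusionGroups; when False, at least one expected node must be
        missing. No assertion is made when should_autodiff_node is None."""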
        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]

        # Note: currently no tests have fusible_nodes
        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]

        # For any non-fusible node, it must show up in one of the DifferentiableGraphs.
        nodes_in_diff_graph = []
        nodes_not_in_diff_graph = []
        non_fusible_nodes_being_fused = []
        for node in nonfusible_nodes:
            if any(g.findNode(node) is not None for g in diff_subgraphs):
                nodes_in_diff_graph.append(node)
            else:
                nodes_not_in_diff_graph.append(node)
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                non_fusible_nodes_being_fused.append(node)
        found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)

        # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
        fusion_nodes_found = []
        fusion_nodes_not_found = []
        for node in fusible_nodes:
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                fusion_nodes_found.append(node)
            else:
                fusion_nodes_not_found.append(node)
        found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)

        if should_autodiff_node is not None:
            err_msg = self.autoDiffErrorMessage(should_autodiff_node,
                                                nodes_not_in_diff_graph,
                                                fusion_nodes_not_found,
                                                non_fusible_nodes_being_fused,
                                                fusion_nodes_found,
                                                nodes_in_diff_graph)
            self.assertEqual(should_autodiff_node,
                             found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)

    def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
                           traced_graph, assert_propagation, constant_prop=True):
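        """Re-runs symbolic shape propagation on traced_graph (both allowing and
        disallowing complete shapes to be substituted as constants) and checks
        that the propagated output sizes are consistent with out_sizes; exact
        equality is only required when assert_propagation is True."""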
        # repropagate the input shapes provided by tracing
        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
        for enable_test_mode in [True, False]:
            # Here we test both allowing and disallowing the substitution of complete
            # shapes as constants; disallowing constants stress-tests the partial
            # evaluation and substitution pipeline.
            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
            torch._C._jit_erase_non_input_shape_information(traced_graph)
            if constant_prop:
                torch._C._jit_pass_constant_propagation(traced_graph)
            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
            # Add the sizes to a default tensor type so that only sizes are checked,
            # avoiding out-of-scope checks and information the tracer leaves in other
            # parts of the tensor type.
            output = next(traced_graph.outputs()).type()

            def test_type(type, actual_size):
                sizes = type.symbolic_sizes()
                out_type = TensorType.get().with_sizes(sizes)
                actual_type = TensorType.get().with_sizes(actual_size)

                # always check actual shape is a subtype of the output
                self.assertTrue(actual_type.isSubtypeOf(out_type))

                # and then if assertion flag is provided, check shape analysis
                # is successful
                if assert_propagation:
                    self.assertEqual(out_type.sizes(), actual_size)

            if output.isSubtypeOf(torch._C.TensorType.get()):
                test_type(output, out_sizes)
            else:
                tuple_elements = output.elements()
                for i in range(len(tuple_elements)):
                    test_type(tuple_elements[i], out_sizes[i])

        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)