xref: /aosp_15_r20/external/pytorch/torch/_dynamo/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2from . import convert_frame, eval_frame, resume_execution
3from .backends.registry import list_backends, lookup_backend, register_backend
4from .callback import callback_handler, on_compile_end, on_compile_start
5from .code_context import code_context
6from .convert_frame import replay
7from .decorators import (
8    allow_in_graph,
9    assume_constant_result,
10    disable,
11    disallow_in_graph,
12    forbid_in_graph,
13    graph_break,
14    mark_dynamic,
15    mark_static,
16    mark_static_address,
17    maybe_mark_dynamic,
18    run,
19)
20from .eval_frame import (
21    _reset_guarded_backend_cache,
22    explain,
23    export,
24    is_dynamo_supported,
25    is_inductor_supported,
26    optimize,
27    optimize_assert,
28    OptimizedModule,
29    reset_code,
30)
31from .external_utils import is_compiling
32from .mutation_guard import GenerationTracker
33from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
34
# Public names exported by a star-import of this package. Every entry except
# "reset" is re-exported from one of the submodule imports above; "reset" is
# defined further down in this module.
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "forbid_in_graph",
    "graph_break",
    "mark_dynamic",
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
    "lookup_backend",
]
59
# Import-time patch of torch.manual_seed. The identity test is true only while
# torch.manual_seed is still the very same object as torch.random.manual_seed.
# NOTE(review): presumably this keeps the wrapping below from being applied a
# second time once torch.manual_seed has been rebound -- confirm.
if torch.manual_seed is torch.random.manual_seed:
    import torch.jit._builtins

    # Wrap manual_seed with the disable decorator.
    # Can't do it at its implementation due to dependency issues.
    torch.manual_seed = torch._disable_dynamo(torch.manual_seed)
    # Add the new manual_seed to the builtin registry.
    # The rebound torch.manual_seed is registered under the "aten::manual_seed" name.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
68
69
def reset() -> None:
    """Clear all compile caches and restore initial state.

    The whole sequence runs while holding ``convert_frame.compile_lock``.
    """
    with convert_frame.compile_lock:
        # Runs first: it iterates ``input_codes.seen`` / ``output_codes.seen``,
        # so the trackers must not have been cleared yet.
        reset_code_caches()

        # Everything here is reset through its own ``clear()``; the tuple
        # order is the order in which the clears are performed.
        for clearable in (
            convert_frame.input_codes,
            convert_frame.output_codes,
            orig_code_map,
            guard_failures,
            graph_break_reasons,
            resume_execution.ContinueExecutionCache.cache,
        ):
            clearable.clear()

        # State that is reset through dedicated helper calls.
        _reset_guarded_backend_cache()
        reset_frame_count()
        torch._C._dynamo.compiled_autograd.clear_cache()

        # Frame counters stored as attributes of the convert_frame module.
        convert_frame.FRAME_COUNTER = 0
        convert_frame.FRAME_COMPILE_COUNTER.clear()

        callback_handler.clear()
        GenerationTracker.clear()
        torch._dynamo.utils.warn_once_cache.clear()
88
89
def reset_code_caches() -> None:
    """Clear compile caches that are keyed by code objects.

    Calls ``reset_code`` on every code object still reachable through the
    weak references recorded in ``convert_frame.input_codes`` and
    ``convert_frame.output_codes``, then clears ``code_context``. Runs while
    holding ``convert_frame.compile_lock``.
    """
    with convert_frame.compile_lock:
        # Concatenating builds a new list, so the loop walks a snapshot of
        # both ``seen`` lists (inputs first, then outputs).
        tracked_refs = (
            convert_frame.input_codes.seen + convert_frame.output_codes.seen
        )
        for code_ref in tracked_refs:
            live_code = code_ref()
            if not live_code:
                # Calling the reference produced nothing usable; skip it.
                continue
            reset_code(live_code)
        code_context.clear()
100