.. role:: hidden
    :class: hidden-section

torch._logging
==============

PyTorch has a configurable logging system, where different components can be
given different log level settings. For instance, one component's log messages
can be completely disabled, while another component's log messages can be
set to maximum verbosity.

.. warning:: This feature is in beta and may have compatibility breaking
    changes in the future.

.. warning:: This feature has not been expanded to control the log messages of
    all components in PyTorch yet.

There are two ways to configure the logging system: through the environment variable ``TORCH_LOGS``
or through the Python API ``torch._logging.set_logs``.

.. automodule:: torch._logging
.. currentmodule:: torch._logging

.. autosummary::
    :toctree: generated
    :nosignatures:

    set_logs

The environment variable ``TORCH_LOGS`` is a comma-separated list of
``[+-]<component>`` entries, where ``<component>`` is one of the components specified below. The ``+`` prefix
decreases the log level of the component, so more log messages are displayed. The ``-`` prefix
increases the log level of the component, so fewer log messages are displayed. A component that is not
listed in ``TORCH_LOGS`` keeps its default setting. In addition to components, there are
also artifacts. Artifacts are specific pieces of debug information associated with a component. An artifact
is either displayed or not displayed, so prefixing it with ``+`` or ``-`` is a no-op. Because each artifact is
associated with a component, enabling that component will typically also enable the artifact,
unless the artifact was registered as ``off_by_default``. This option is set in ``torch/_logging/_registrations.py``
for artifacts that are so verbose they should only be displayed when explicitly enabled.
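The effect of the ``+`` and ``-`` prefixes follows the standard :mod:`logging` rule that a logger emits
every record at or above its level. The sketch below assumes the prefix-to-level mapping shown in the
examples at the end of this page (``+`` to ``logging.DEBUG``, no prefix to ``logging.INFO``, ``-`` to
``logging.ERROR``). It uses a throwaway logger and a hypothetical helper, ``visible_levels``; neither is
part of PyTorch. A component that is not listed at all stays at its default, ``logging.WARN``.

.. code-block:: python

    import logging

    # Assumed mapping, taken from the TORCH_LOGS examples on this page.
    # This is an illustration, not PyTorch's own code.
    PREFIX_LEVELS = {"+": logging.DEBUG, "": logging.INFO, "-": logging.ERROR}


    def visible_levels(prefix: str) -> list[str]:
        """Return the standard level names a logger set via `prefix` would emit."""
        logger = logging.getLogger("hypothetical.component")
        logger.setLevel(PREFIX_LEVELS[prefix])
        names = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # isEnabledFor() is True for every level at or above the logger's level.
        return [n for n in names if logger.isEnabledFor(getattr(logging, n))]


    for prefix in ("+", "", "-"):
        print(repr(prefix), visible_levels(prefix))
    # '+' ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
    # '' ['INFO', 'WARNING', 'ERROR', 'CRITICAL']
    # '-' ['ERROR', 'CRITICAL']

A lower numeric level lets more records through, which is why ``+`` (lowering the level) shows more
messages and ``-`` (raising it) shows fewer.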
The following components and artifacts are configurable through the ``TORCH_LOGS`` environment
variable (see ``torch._logging.set_logs`` for the Python API):

Components:
    ``all``
        Special component which configures the default log level of all components. Default: ``logging.WARN``

    ``dynamo``
        The log level for the TorchDynamo component. Default: ``logging.WARN``

    ``aot``
        The log level for the AOTAutograd component. Default: ``logging.WARN``

    ``inductor``
        The log level for the TorchInductor component. Default: ``logging.WARN``

    ``your.custom.module``
        The log level for an arbitrary unregistered module. Provide the fully qualified module name to enable
        logging for that module. Default: ``logging.WARN``

Artifacts:
    ``bytecode``
        Whether to emit the original and generated bytecode from TorchDynamo.
        Default: ``False``

    ``aot_graphs``
        Whether to emit the graphs generated by AOTAutograd. Default: ``False``

    ``aot_joint_graph``
        Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``

    ``compiled_autograd``
        Whether to emit logs from compiled autograd. Default: ``False``

    ``ddp_graphs``
        Whether to emit graphs generated by DDPOptimizer. Default: ``False``

    ``graph``
        Whether to emit the graph captured by TorchDynamo in tabular format.
        Default: ``False``

    ``graph_code``
        Whether to emit the Python source of the graph captured by TorchDynamo.
        Default: ``False``

    ``graph_breaks``
        Whether to emit a message when a unique graph break is encountered during
        TorchDynamo tracing. Default: ``False``

    ``guards``
        Whether to emit the guards generated by TorchDynamo for each compiled
        function. Default: ``False``

    ``recompiles``
        Whether to emit a guard failure reason and message every time
        TorchDynamo recompiles a function. Default: ``False``

    ``output_code``
        Whether to emit the TorchInductor output code. Default: ``False``

    ``schedule``
        Whether to emit the TorchInductor schedule. Default: ``False``

Examples:
    ``TORCH_LOGS="+dynamo,aot"`` will set the log level of TorchDynamo to ``logging.DEBUG`` and AOT to ``logging.INFO``

    ``TORCH_LOGS="-dynamo,+inductor"`` will set the log level of TorchDynamo to ``logging.ERROR`` and TorchInductor to ``logging.DEBUG``

    ``TORCH_LOGS="aot_graphs"`` will enable the ``aot_graphs`` artifact

    ``TORCH_LOGS="+dynamo,schedule"`` will set the log level of TorchDynamo to ``logging.DEBUG`` and enable the ``schedule`` artifact

    ``TORCH_LOGS="+some.random.module,schedule"`` will set the log level of ``some.random.module`` to ``logging.DEBUG`` and enable the ``schedule`` artifact
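The ``TORCH_LOGS`` examples above can be reproduced with a small parser. ``torch_logs_to_kwargs`` is a
hypothetical helper, not PyTorch's parser. It assumes the prefix mapping shown in the examples and
hardcodes the component and artifact names listed on this page. It also skips the validation PyTorch
performs on unknown names. The dictionary it returns is shaped like keyword arguments for
``torch._logging.set_logs``, with unregistered modules collected under its ``modules`` argument. Passing
it as ``set_logs(**kwargs)`` should therefore roughly correspond to setting the environment variable.

.. code-block:: python

    import logging

    # Hypothetical copies of the names documented on this page.
    COMPONENTS = {"all", "dynamo", "aot", "inductor"}
    ARTIFACTS = {
        "bytecode", "aot_graphs", "aot_joint_graph", "compiled_autograd",
        "ddp_graphs", "graph", "graph_code", "graph_breaks", "guards",
        "recompiles", "output_code", "schedule",
    }
    # Assumed prefix-to-level mapping, taken from the examples above.
    PREFIX_LEVELS = {"+": logging.DEBUG, "": logging.INFO, "-": logging.ERROR}


    def torch_logs_to_kwargs(spec: str) -> dict:
        """Translate a TORCH_LOGS-style string into set_logs-style keyword arguments.

        Illustrative sketch only: not PyTorch's implementation, no validation.
        """
        kwargs: dict = {}
        for item in (part.strip() for part in spec.split(",")):
            if not item:
                continue  # tolerate empty entries such as a trailing comma
            prefix = item[0] if item[0] in "+-" else ""
            name = item[len(prefix):]
            if name in ARTIFACTS:
                # Artifacts are simply on or off, so any +/- prefix is ignored.
                kwargs[name] = True
            elif name in COMPONENTS:
                kwargs[name] = PREFIX_LEVELS[prefix]
            else:
                # Everything else is treated as a fully qualified module name.
                kwargs.setdefault("modules", {})[name] = PREFIX_LEVELS[prefix]
        return kwargs


    print(torch_logs_to_kwargs("+dynamo,aot"))
    # {'dynamo': 10, 'aot': 20}
    print(torch_logs_to_kwargs("+some.random.module,schedule"))
    # {'modules': {'some.random.module': 10}, 'schedule': True}

The artifact check runs before the component check, so a stray ``+`` or ``-`` on an artifact is simply
dropped. This matches the no-op behavior described for artifacts.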