xref: /aosp_15_r20/external/pytorch/torch/_inductor/config.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os  # noqa: C101
2import sys
3from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
4
5import torch
6
7
8def is_fbcode() -> bool:
9    return not hasattr(torch.version, "git_version")
10
11
12def fx_graph_remote_cache_default() -> Optional[bool]:
13    if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1":
14        return True
15    if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "0":
16        return False
17    return None
18
19
20def autotune_remote_cache_default() -> Optional[bool]:
21    if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1":
22        return True
23    if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "0":
24        return False
25    return None
26
27
28# Enable auto_functionalized_v2 (off by default; set TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2=1 to enable)
29enable_auto_functionalized_v2 = (
30    os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1"
31)
32
33# add some debug printouts
34debug = False
35
36# Whether to disable the progress bar for autotuning
37disable_progress = True
38
39# Whether to enable printing the source code for each future
40verbose_progress = False
41
42# use fx aot graph codegen cache
43fx_graph_cache = (
44    os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE", "0" if is_fbcode() else "1") == "1"
45)
46
47# use remote fx aot graph codegen cache
48# False: Disables the cache
49# True: Enables the cache
50# None: Not set -- Off for OSS, JustKnobs based for internal
51fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
52
53# enable autotune local cache
54autotune_local_cache = True
55
56# enable autotune remote cache
57# False: Disables the cache
58# True: Enables the cache
59# None: Not set -- Off for OSS, JustKnobs based for internal
60autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()
61
62# Force disabled all inductor level caching -- This will override any other caching flag
63force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"
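#
# Example (a minimal sketch, not part of the defaults): the caching knobs above
# are ordinary module attributes, so besides the environment variables they can
# be set programmatically before compiling, e.g.
#
#   import torch._inductor.config as inductor_config
#   inductor_config.fx_graph_cache = True         # local FX graph cache on
#   inductor_config.fx_graph_remote_cache = None  # defer to the default policy
#   inductor_config.force_disable_caches = False  # leave caching enabled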
64
65# sleep in inductor for testing
66sleep_sec_TESTING_ONLY: Optional[int] = None
67
68# The default layout constraint for custom operators.
69# This must be the name of one of the layout constraint tags
70# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}).
71# If the custom op does not already have a layout constraint tag,
72# then we assume the following applies.
73custom_op_default_layout_constraint = "flexible_layout"
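#
# For illustration only: to make stride-exact behavior the default for custom
# ops that carry no tag, the value can be switched to the other supported tag
# name, e.g.
#
#   torch._inductor.config.custom_op_default_layout_constraint = (
#       "needs_fixed_stride_order"
#   )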
74
75# use cpp wrapper instead of python wrapper
76cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"
77
78# codegen cpp wrapper code in an ABI compatible mode
79abi_compatible = (
80    os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
81)
82
83c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2")
84
85# dead code elimination
86dce = False
87
88# assume weight tensors are fixed size
89static_weight_shapes = True
90
91# put correctness assertions in generated code
92size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
93nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
94
95# enable loop reordering based on input orders
96pick_loop_orders = True
97
98# reuse a kernel input as the output
99inplace_buffers = True
100
101# reuse a buffer for an unrelated purpose
102allow_buffer_reuse = True
103
104# Enable pooled allocations for non-output tensors
105memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"
106
107# How to organize memory under memory_planning=True:
108# - "none": do not try to pool storage, just reuse
109# - "intermediates": all non-outputs share storage, outputs each get unique storage
110# - "outputs": two pools, one for intermediates (freed on return) and one for outputs
111# - "combined": a single pool for both intermediates and outputs
112memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates")
113
114# codegen benchmark harness
115benchmark_harness = True
116
117# fuse pointwise into templates
118epilogue_fusion = True
119
120# do epilogue fusions before other fusions
121epilogue_fusion_first = False
122
123# enable pattern match+replace optimizations
124pattern_matcher = True
125
126# set to True to enable the back-to-back GEMM pass
127b2b_gemm_pass = False
128
129# Register custom graph optimization pass hooks. So far, pre/post passes are
130# only applied before/after the pattern_matcher in post_grad_passes.
131#
132# def my_custom_pre_pass(graph: torch.fx.graph.Graph):
133#     # my custom graph optimization pass
134#     ...
135#
136# def my_custom_post_pass(graph: torch.fx.graph.Graph):
137#     # my custom graph optimization pass
138#     ...
139#
140# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
141# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
142post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
143post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
144
145# Registers a custom joint graph pass.
146joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
147joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None
148
149# Registers a custom pregrad pass. Note that the pre-grad IR is 1.
150# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
151# use post-grad passes.
152pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
153
154# Registers a custom pass to be run right before fusion in Inductor scheduler.
155# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
156# hence custom IR passes built on top of it might break in the future.
157_pre_fusion_custom_pass: Optional[
158    Callable[
159        [List["torch._inductor.scheduler.BaseSchedulerNode"]],
160        List["torch._inductor.scheduler.BaseSchedulerNode"],
161    ]
162] = None
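#
# A minimal sketch of such a pass (illustrative only; the scheduler IR is a
# prototype, so treat this merely as the expected signature):
#
#   def my_pre_fusion_pass(nodes):
#       # nodes is a List[torch._inductor.scheduler.BaseSchedulerNode];
#       # inspect or reorder it here -- returning it unchanged is a no-op
#       return nodes
#
#   torch._inductor.config._pre_fusion_custom_pass = my_pre_fusion_pass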
163
164# Deprecated
165split_cat_fx_passes = True
166
167# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
168efficient_conv_bn_eval_fx_passes = False
169
170# Enable predispatch aten IR for export
171is_predispatch = False
172
173# Deprecated
174group_fusion = False
175
176# Deprecated
177batch_fusion = True
178
179# Pre-grad fusions and their options, in order; set to an empty dict to disable fusion.
180# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
181# batch fusion options:
182# batch_linear
183# batch_linear_lhs
184# batch_layernorm
185# batch_tanh
186# batch_relu
187# batch_sigmoid
188
189# split cat fusion options:
190# normalization_pass
191# remove_split_with_size_one_pass
192# merge_getitem_cat_pass
193# merge_stack_tahn_unbind
194# merge_splits_pass
195# mutate_cat_pass
196# split_cat_pass
197pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {
198    "batch_linear": {},
199    "batch_linear_lhs": {},
200    "batch_layernorm": {},
201    "batch_tanh": {},
202    "batch_relu": {},
203    "batch_sigmoid": {},
204}
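#
# Example (illustrative only): to restrict pre-grad fusion to a subset of the
# available passes, assign a smaller dict, e.g.
#   torch._inductor.config.pre_grad_fusion_options = {"batch_linear": {}}
# Setting it to {} disables pre-grad fusion entirely.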
205
206# Post-grad fusions and their options; set to an empty dict to disable fusion.
207# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
208post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
209
210# enable reordering pass for improving memory locality
211reorder_for_locality = True
212
213# Scale down RBLOCK for better occupancy
214dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"
215
216# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32
217# but the mul gets fused with other pointwise ops instead.
218force_fuse_int_mm_with_mul = False
219
220# for pattern torch.mm(a, b.to(dtype)) with cuda tensors,
221# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel.
222# Autotune will compare perf with normal cast->then->mm option
223use_mixed_mm = True
224
225# enable runtime numeric check for pre/post grad fx passes
226# floating point provides limited accuracy (about 7 decimal digits for single precision
227# floating point numbers, about 16 decimal digits for double precision floating point numbers)
228# according to PyTorch documentation.
229# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
230fx_passes_numeric_check: Dict[str, Any] = {
231    "pre_grad": False,
232    "precision": 1e-4,
233    "num_iterations": 1,
234    "requires_optimizer": True,
235}
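#
# Example (illustrative only): to turn on the numeric check for pre-grad passes
# with a looser tolerance, patch the dict before compiling, e.g.
#   torch._inductor.config.fx_passes_numeric_check["pre_grad"] = True
#   torch._inductor.config.fx_passes_numeric_check["precision"] = 1e-3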
236
237# mixed_mm_choice can be used to control the behaviour for pattern torch.mm(a, b.to(dtype)) with cuda tensors.
238# The fallback aten implementation is the normal cast->then->mm option.
239# If mixed_mm_choice is "default": this flag will be ignored.
240# If mixed_mm_choice is "triton":
241# - Always use torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel.
242# - Autotune will not compare with fallback.
243# If mixed_mm_choice is "aten": always use the fallback aten implementation.
244# If mixed_mm_choice is "heuristic":
245# - Enables the heuristic.
246# - If the heuristic decides to add a config, it will add the config as the first choice.
247# - If autotune is disabled, this config will always be chosen.
248# - If autotune is enabled, it will also compare with fallback aten implementation and fused kernel.
249# The use_mixed_mm flag will be ignored if mixed_mm_choice != "default".
250mixed_mm_choice = "heuristic"
251
252# enable reordering pass for increasing overlap between compute and communication
253reorder_for_compute_comm_overlap = False
254
255# passes (in execution order) for increasing overlap between compute and communication
256# for built-in passes, use string name; for user-defined passes, pass in the function handle
257# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
258# hence custom IR passes built on top of it might break in the future.
259reorder_for_compute_comm_overlap_passes = [
260    "reorder_compute_for_overlap",
261    "sink_waits",
262    "raise_comms",
263]
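#
# Example (a hedged sketch): built-in passes are referenced by name, while a
# user-defined pass is supplied as a function handle over the prototype
# scheduler IR, e.g.
#   torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
#       "sink_waits",
#       my_custom_reorder_pass,  # hypothetical callable, not defined here
#   ]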
264
265# runtime estimation function for ops
266# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
267estimate_op_runtime = "default"
268
269# unit: GB/s, uni-directional P2P bandwidth per card
270# default value is NVLink
271intra_node_bw = 300
272
273# unit: GB/s, uni-directional P2P bandwidth per node
274# default value is InfiniBand
275inter_node_bw = 25
276
277# enable slow autotuning passes to select algorithms
278max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
279
280# enable slow autotuning passes to select pointwise/reductions algorithms
281max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"
282
283# enable slow autotuning passes to select gemm algorithms
284max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"
285
286# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
287# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
288# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure
289# that triton does not use TF32 wherever cublas would not use TF32
290force_same_precision = (
291    True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1"
292)
293
294# Specify candidate backends for gemm autotune.
295# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP.
296# ATen: default PyTorch ATen kernels.
297# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs).
298# CUTLASS: Cutlass templates and kernels (NVidia GPUs only).
299# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only).
300# CPP: CPP templates and kernels for CPU.
301max_autotune_gemm_backends = os.environ.get(
302    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
303).upper()
304
305# As above, specify candidate backends for conv autotune.
306# NB: in some cases we emit 1x1 convs as a matmul,
307# which will use the backends from `max_autotune_gemm_backends`
308max_autotune_conv_backends = os.environ.get(
309    "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON"
310).upper()
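#
# Example: both backend lists are plain comma-separated strings and can be
# narrowed via the environment, e.g.
#   TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="ATEN,TRITON"
# or, equivalently, in code before compiling:
#   torch._inductor.config.max_autotune_gemm_backends = "ATEN,TRITON"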
311
312
313# Specify the size of the search space for GEMM autotuning.
314# DEFAULT     - balance between compile time overhead and performance
315# EXHAUSTIVE  - maximize performance
316max_autotune_gemm_search_space = os.environ.get(
317    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
318).upper()
319
320# Whether we fall back to ATen or hard error when no matches are found during autotuning
321autotune_fallback_to_aten = (
322    os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1"
323)
324
325# the value used as a fallback for the unbacked SymInts
326# that can appear in the input shapes (e.g., in autotuning)
327unbacked_symint_fallback = 8192
328
329# DEPRECATED, DO NOT USE
330search_autotune_cache = False
331
332save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"
333
334# Whether to create a subprocess for autotuning; if False, autotuning runs in the main process
335autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
336
337# The following three timeouts are applicable if autotune_in_subproc is True:
338
339# Max time that a valid benchmark result may take during autotuning
340max_autotune_subproc_result_timeout_seconds = 60.0
341# Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM
342max_autotune_subproc_graceful_timeout_seconds = 1.0
343# Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses
344max_autotune_subproc_terminate_timeout_seconds = 2.0
345
346# If autotuning in subprocess, whether to use multiple devices
347autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"
348
349coordinate_descent_tuning = (
350    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
351)
352coordinate_descent_check_all_directions = (
353    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
354)
355coordinate_descent_search_radius = int(
356    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
357)
358
359# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and
360# generate code for the learned heuristic, which is shipped with the compiler
361# Specify a list of comma-separated optimizations to collect data for
362autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
363# Specify a list of comma-separated optimizations to use learned heuristics for
364autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm")
365
366
367def run_autoheuristic(name: str) -> bool:
368    return collect_autoheuristic(name) or use_autoheuristic(name)
369
370
371def collect_autoheuristic(name: str) -> bool:
372    return name in torch._inductor.config.autoheuristic_collect.split(",")
373
374
375def use_autoheuristic(name: str) -> bool:
376    return name in torch._inductor.config.autoheuristic_use.split(",")
377
378
379# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
380# If set to another path, autoheuristic will instead log results to the given path.
381autoheuristic_log_path = os.environ.get(
382    "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT"
383)
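#
# Example (illustrative only): both knobs take comma-separated optimization
# names, so to both collect data for and use the learned heuristic for
# mixed_mm one could set
#   TORCHINDUCTOR_AUTOHEURISTIC_COLLECT="mixed_mm"
#   TORCHINDUCTOR_AUTOHEURISTIC_USE="mixed_mm"
# after which run_autoheuristic("mixed_mm") above returns True.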
384
385# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions
386layout_opt_default = "1" if not torch.version.hip else "0"
387layout_optimization = (
388    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
389)
390
391force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"
392
393
394# Whether to keep the output strides the same as eager after layout optimization.
395keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
396
397# Enabling this will let compiler print warning messages if a generated triton
398# kernel has inputs with mixed layouts.  This is helpful for perf debugging
399# since a kernel with mixed layout inputs may run much slower than one whose inputs
400# have uniform layouts.
401warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
402
403# control store vs recompute heuristic
404# For fanouts, rematerialization can lead to exponential blowup, so use a
405# smaller threshold
406realize_reads_threshold = 4
407realize_opcount_threshold = 30
408
409# Threshold to prevent excessive accumulation of ops in one buffer during lowering
410realize_acc_reads_threshold = 8
411
412# fallback to eager for random/dropout, this is slow but useful for debugging
413fallback_random = False
414
415# automatically create fallbacks when encountering an unhandled op
416implicit_fallbacks = True
417
418# fuse even in cases without common reads
419aggressive_fusion = False
420
421# For each fused kernel in the wrapper, comment with the nodes that get fused.
422# Useful for debugging fusion.
423debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
424benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
425enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
426loop_ordering_after_fusion = (
427    os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
428)
429
430# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
431benchmark_epilogue_fusion = (
432    os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
433)
434
435# How many of the top triton kernels to benchmark epilogue fusion for
436max_epilogue_benchmarked_choices = 1
437
438# how many nodes to allow into a single fusion
439max_fusion_size = 64
440
441# max number of inputs to generate cat as a pointwise op with masked loads
442max_pointwise_cat_inputs = 8
443
444# replace small reductions with pointwise, disable with `= 1`
445unroll_reductions_threshold = 8
446
447# Add extra comments to output code (causes compile cache misses)
448comment_origin = False
449
450# Convert 1x1 convs into matmuls
451conv_1x1_as_mm = False
452
453# Enable split reductions for better utilization when the dimension
454# being reduced over is large (by splitting it)
455split_reductions = True
456
457benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
458
459# Enable constant and index_expr folding
460constant_and_index_propagation = True
461
462# we always add constants into graph.constants without
463# performing any constant-inlining optimization
464always_keep_tensor_constants = False
465
466# assert that indirect indexing does not read / write out of bounds
467assert_indirect_indexing = True
468
469# compute CSE bounds on variables that do not appear in the FX graph
470compute_all_bounds = False
471
472# enable the combo kernel that combines data-independent kernels (additional
473# to foreach kernels) into a single one (Experimental)
474combo_kernels = False
475# benchmark combo kernels and only allow ones with perf gains
476benchmark_combo_kernel = False
477# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach,
478# 2 - enable for all
479combo_kernels_autotune = 1
480# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable
481# for all except for foreach, 2 - enable for all
482combo_kernel_allow_mixed_sizes = 1
483# Enable dynamic shapes for foreach kernels
484combo_kernel_foreach_dynamic_shapes = False
485
486# constant folding on the joint graph
487joint_graph_constant_folding = True
488
489# Enable indirect_indexing asserts for decompositions and lowerings
490debug_index_asserts = False
491
492# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16)
493# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after
494# For multiple, fused pointwise nodes, inductor will elide the intermediate upcasts and downcasts
495# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
496# to emulate the eager numerics.
497emulate_precision_casts = False
498
499# warnings intended for PyTorch developers, disable for point releases
500is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
501developer_warnings = is_fbcode() or is_nightly_or_source
502
503# This pattern matches a special usage of scatter
504# 1. It's applied to a constant tensor
505# 2. The index tensor has size 1 in the scatter dimension
506# Such a pattern generates a sparse matrix when the const tensor is all-zero.
507# We can lower this pattern to a pointwise kernel for more fusion opportunities
508# and a smaller memory footprint.
509optimize_scatter_upon_const_tensor = (
510    os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
511)
512
513
514# The multiprocessing start method to use for inductor workers in the codecache.
515# Can be "subprocess" or "fork".
516def decide_worker_start_method() -> str:
517    start_method = os.environ.get(
518        "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess"
519    )
520    assert start_method in (
521        "subprocess",
522        "fork",
523    ), f"Invalid start method: {start_method}"
524    return start_method
525
526
527worker_start_method = decide_worker_start_method()
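#
# Example: the start method can be forced explicitly for debugging, e.g.
#   TORCHINDUCTOR_WORKER_START=fork
# Any value other than "subprocess" or "fork" trips the assertion above.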
528
529# Flags to turn on all_reduce fusion. These two flags should be automatically turned
530# on by DDP and should not be set by users.
531_fuse_ddp_communication = False
532_fuse_ddp_bucket_size = 25
533
534# Flag to control which fusion passes to apply. Functions in the list will
535# be applied in order. There are two different fusion passes
536# -- "fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default
537# one is "fuse_ddp_with_concat_op". Users can also change this to a customized
538# fusion function.
539#
540# The fusion currently does not support multiple DDP with different PG or
541# data type. This feature will be added in the future PRs.
542#
543# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp
544# overlapping. At this moment, this pass performs better than
545# reorder_for_compute_comm_overlap_passes but we will add the logic of
546# "schedule_comm_wait" in the future and remove the one here.
547_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
548    "fuse_ddp_with_concat_op",
549    "schedule_comm_wait",
550]
551
552_micro_pipeline_tp: bool = False
553
554
555def decide_compile_threads() -> int:
556    """
557    Here is the precedence for deciding compile_threads:
558    1. User can override it by TORCHINDUCTOR_COMPILE_THREADS.  One may want to disable async compiling by
559       setting this to 1 to make pdb happy.
560    2. Set to 1 on the win32 platform or in fbcode
561    3. Otherwise, decide by the number of CPU cores (capped at 32)
562    """
563    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
564        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
565    elif sys.platform == "win32":
566        return 1
567    elif is_fbcode():
568        return 1
569    else:
570        cpu_count = (
571            len(os.sched_getaffinity(0))
572            if hasattr(os, "sched_getaffinity")
573            else os.cpu_count()
574        )
575        assert cpu_count
576        return min(32, cpu_count)
577
578
579compile_threads = decide_compile_threads()
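#
# Example: to make compilation effectively single-threaded (e.g. so pdb
# breakpoints inside codegen are hit in the main process), set
#   TORCHINDUCTOR_COMPILE_THREADS=1
# Larger values (or the default) enable parallel async compile.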
580
581# gemm autotuning global cache dir
582if is_fbcode():
583    try:
584        from libfb.py import parutil
585
586        if __package__:
587            global_cache_dir = parutil.get_dir_path(
588                os.path.join(__package__.replace(".", os.sep), "fb/cache")
589            )
590        else:
591            global_cache_dir = parutil.get_dir_path("fb/cache")
592    except (ValueError, ModuleNotFoundError):
593        global_cache_dir = None
594
595else:
596    global_cache_dir = None
597
598# If a kernel is fused, its name is generated from the origin nodes' op names;
599# for larger kernels, limit how many ops contribute to the name
600kernel_name_max_ops = 10
601
602# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
603shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
604
605# Control if we will do padding for pointwise/reductions
606comprehensive_padding = (
607    os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1"
608)
609pad_channels_last = False
610
611# Disable comprehensive padding on the CPU
612disable_padding_cpu = True
613
614# The width of comprehensive padding, in bytes.
615# CUDA max memory transaction size is 128 bytes for a warp.
616padding_alignment_bytes = 128
617
618# Threshold on the minimum stride that will be padded.
619#
620# Don't align a stride that is too small, since that causes too much memory increase.
621# Padding a too-small stride may also cause a perf loss: we may end up with many tiny
622# data blocks with gaps in between, which results in less coalesced GPU memory access!
623#
624# Initially we pick 320 as the threshold since, for alignment=16,
625# that results in at most 5% memory cost.
626#
627# But later on we raise the threshold to 1024 to avoid interfering with persistent reductions.
628# Let's say an inner reduction has a row size 513. Inductor will generate
629# persistent reduction code.
630# If we do padding, the strides are not contiguous any more. Inductor
631# uses a much smaller threshold for persistent reduction in this case and
632# generates potentially worse non-persistent reduction code.
633#
634# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
635# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
636padding_stride_threshold = 1024
637
638# Enable padding outputs, even if they would not be padded in eager mode.
639# By default, we use the same strides as eager mode.
640pad_outputs = False
641
642# Whether to treat the outputs of the backward graph as user visible.
643# For user-visible outputs, inductor will make sure their strides match eager.
644bw_outputs_user_visible = True
645
646# Whether to always use shape padding if it is enabled and possible
647force_shape_pad: bool = False
648
649# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
650permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
651
652# Mark the wrapper call in PyTorch profiler
653profiler_mark_wrapper_call = False
654
655# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
656# every intermediate for which we can correlate it with an intermediate
657# from the original FX graph
658generate_intermediate_hooks = False
659
660# Populate traceback field on IRNode; good for debugging why origin_node is
661# not populated, or finding out where an IRNode was constructed
662debug_ir_traceback = False
663
664# used for debugging to make sure config is properly set
665_raise_error_for_testing = False
666
667_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
668profile_bandwidth = _profile_var != ""
669profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
670# Specify a file where we print out the profiling results.
671# None means we do not dump results to a file.
672profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None)
673# Switch to do_bench_using_profiling to exclude the CPU overheads
674profile_bandwidth_with_do_bench_using_profiling = (
675    os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1"
676)
677
678
679# TODO: remove later
680disable_cpp_codegen = False
681
682
683# Freezing will attempt to inline weights as constants in optimization
684# and run constant folding and other optimizations on them. After freezing, weights
685# can no longer be updated.
686freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
687
688# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
689# of potentially keeping multiple copies of weights.
690freezing_discard_parameters: bool = False
691
692# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
693# should be run with this flag both on and off to make sure we have coverage.
694allow_stack_allocation: bool = (
695    os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1" if is_fbcode() else "0") == "1"
696)
697
698# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
699# to maximize performance for use cases that it can accommodate at the expense of
700# generality. In brief:
701# - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
702#   tensor must be contiguous)
703# - constant handling is unchanged because it is not a per-inference-iteration bottleneck
704#
705# When the DSO is generated in this mode, the usual interface will also be supported,
706# but performance for that interface may be degraded.
707use_minimal_arrayref_interface: bool = False
708
709# decompose some memory bound matmul/bmm to mul
710decompose_mem_bound_mm: bool = False
711
712# assume_aligned_inputs means that we assume that inputs will be aligned; we generate
713# code using this assumption, and clone tensors before use if they aren't aligned.
714# In the common case, most inputs will be aligned.
715assume_aligned_inputs: bool = False
716
717# For the user-written Triton kernels compiled with the model, ignore the unsupported
718# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
719# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
720# set unless you know what you're doing.
721unsafe_ignore_unsupported_triton_autotune_args: bool = False
722
723# When True, we will check in scheduler.py _codegen that there are no "loops"
724# in the call stack; that is to say, the same frame multiple times.  This
725# ensures that a cProfile trace to this frame will be a straight line without
726# any cycles.
727check_stack_no_cycles_TESTING_ONLY: bool = False
728
729
730# config specific to codegen/cpp.py
731class cpp:
732    # set to torch.get_num_threads()
733    threads = -1
734
735    # Do not generate loops when the condition doesn't hold, like:
736    # for(long i0=4096; i0<4096; i0+=1)
737    no_redundant_loops = (
738        os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1"
739    )
740
741    # Assume number of threads is dynamic, don't specialize thread number.
742    # Kernels don't recompile on thread number changes with this flag on.
743    # For single-threaded workload, turning it on would incur a slight
744    # performance degradation.
745    dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1"
746
747    simdlen: Optional[int] = None
748    min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
749    cxx = (
750        None,  # download gcc12 from conda-forge if conda is installed
751        # "g++-12",
752        # "g++-11",
753        # "g++-10",
754        # "clang++",
755        os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
756        # "g++.par",
757    )
758    # Allow kernel performance profiling via PyTorch profiler
759    enable_kernel_profile = (
760        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1"
761    )
762
763    # enable weight prepacking to get better performance; may lead to a large memory footprint
764    weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1"
765
766    # Inject a bug into our relu implementation; useful for testing our repro
767    # extraction and minification functionality.
768    # Valid values: "compile_error", "runtime_error", "accuracy"
769    inject_relu_bug_TESTING_ONLY: Optional[str] = None
770    inject_log1p_bug_TESTING_ONLY: Optional[str] = None
771
772    # If None, autodetect whether or not AVX512/AVX2 can be used.  Otherwise,
773    # force usage as specified, without testing.
774    vec_isa_ok: Optional[bool] = None
775
776    # similar to config.triton.descriptive_names
777    descriptive_names = "original_aten"
778
779    # how many nodes to allow into a single horizontal fusion
780    max_horizontal_fusion_size = int(
781        os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16")
782    )
783
784    # Make scatter_reduce fall back when reduce is sum, to avoid the performance
785    # regression caused by using atomic_add.
786    fallback_scatter_reduce_sum = (
787        os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1"
788    )
789
790    # Use -funsafe-math-optimizations when compiling
791    enable_unsafe_math_opt_flag = (
792        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1"
793    )
794
795    # Use -ffp-contract when compiling
796    enable_floating_point_contract_flag = (
797        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "0")
798        == "1"
799    )
800
801    # Whether to enable the tiling selection heuristics
802    enable_tiling_heuristics = (
803        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
804    )
805
806    # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
807    # the maximal parallelism of K-slicing. Since K-slicing requires extra thread
808    # synchronization and buffers,  the maximal number of slices is limited to
809    # mitigate the sync overhead and memory usage.
810    # When set to 0, the number of slices is unlimited.
811    gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1"))
812
813    # For perf tuning and debugging purposes, configure the pre-defined cache blocking for
814    # MxNxK dims respectively. The blockings are separated by commas and the unit is
815    # the number of register blocks.
816    # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively.
817    gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)
818
819    # For perf tuning and debugging purposes, configure the pre-defined thread blocking factors for
820    # MxNxK dims respectively. The factors are separated by commas and their product
821    # should be the same as the total number of threads.
822    # For example, if the total number of threads is 56, "7,4,2" means the work is
823    # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM.
824    gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None)
825
826    # Whether to enable masked vectorization for the tail_loop.
827    enable_loop_tail_vec = True
828
829
830# config specific to codegen/triton.py
831class triton:
832    # Use cudagraphs on output code
833    cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1"
834
835    # Use cudagraph trees for memory pooling if `cudagraphs` is True
836    cudagraph_trees = True
837
838    # Should we skip cudagraphing graphs with dynamic shape inputs
839    # If False, we will re-record a graph for each unique set of shape inputs
840    cudagraph_skip_dynamic_graphs = False
841
842    # assertions not on the fast path, steady state
843    slow_path_cudagraph_asserts = True
844
845    # TODO - need to debug why this prevents cleanup
846    cudagraph_trees_history_recording = False
847
848    # Enable cudagraph support for mutated inputs from prior cudagraph pool
849    cudagraph_support_input_mutation = False if is_fbcode() else True
850
851    # Maximal number of allowed cudagraph re-records for a function and
852    # a cudagraph node, due to static input tensor address changes or
853    # cudagraph-managed tensor data pointer changes.
854    # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit
855    # note: we are conservative here and choose a large limit.
856    cudagraph_unexpected_rerecord_limit = 128
857
858    # Warn loudly when the number of cudagraphs due to dynamic shape
859    # exceeds this limit
860    cudagraph_dynamic_shape_warn_limit: Optional[int] = 50
861
862    # synchronize after cudagraph invocation
863    force_cudagraph_sync = False
864
865    # always run cudagraphs in the eager warmup stage
866    # instead of recording and executing cudagraphs
867    force_cudagraphs_warmup = False
868
869    # assertions on the fast path
870    fast_path_cudagraph_asserts = False
871
872    # skip warmup for cudagraph trees
873    skip_cudagraph_warmup = False
874
875    # Synchronize before and after every compiled graph.
876    debug_sync_graph = False
877
878    # Synchronize after every kernel launch, to help pinpoint bugs
879    debug_sync_kernel = False
880
881    # Always load full blocks (rather than broadcasting inside the block)
882    dense_indexing = False
883
884    # limit tiling dimensions
885    max_tiles = 2
886
887    # Prefer higher dimensional tilings. This simplifies indexing expressions, making
888    # it easier to identify block pointers.
889    prefer_nd_tiling: bool = False
890
891    # use triton.autotune for pointwise ops with complex layouts
892    # this should only be disabled for debugging/testing
893    autotune_pointwise = True
894
895    # max autotune gemm with cublasLt
896    autotune_cublasLt = True
897
898    # Tune the generated Triton kernels at compile time instead of first time they run
899    autotune_at_compile_time = False
900
901    # should we stop a fusion to allow better tiling?
902    tiling_prevents_pointwise_fusion = True
903    tiling_prevents_reduction_fusion = True
904
905    # should we give different names to kernels
906    # Note: This is orthogonal to descriptive_names - this is deciding whether
907    # our triton kernel names should all be `triton_` (to maximize caching) or
908    # whether they should be unique.
909    unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"
910
911    # should we put op names in kernel names
912    # False: No special names (just triton__1, triton__2, etc.)
913    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
914    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
915    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
916    descriptive_names = "original_aten"
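    #
    # Example (illustrative only): to get stable but descriptive kernel names
    # while debugging, one might combine the two knobs above, e.g.
    #   TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
    #   torch._inductor.config.triton.descriptive_names = "inductor_node"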
917
918    # use alternate codegen for smaller reductions
919    persistent_reductions = (
920        os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
921    )
922
923    # 0/False: disable
924    # 1/True: enable, use tuning to pick between different subkernels
925    # 2: enable, force using persistent reduction (for debugging)
926    # 3: enable, force using non-persistent reduction (for debugging)
927    multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0"))
928
929    # hint to Triton when arguments are divisible by 16
930    divisible_by_16 = True
931
932    # Minimum RBLOCK to be used for a TritonSplitScanKernel
933    # NOTE: This also indirectly controls the size of workspace buffer required
934    min_split_scan_rblock = 256
935
936    # Store the generated cubin files for cpp wrapper code to load
937    store_cubin = False
938
939    # the max number of spills we allow for the configs we benchmark.
940    # Setting this to 0 means we skip a config if it spills even a single
941    # register.
942    # Setting it to a larger value allows a config spilling a small amount
943    # of registers being benchmarked.
944    #
945    # NOTE: triton will always report >0 register spills for kernels using sin/cos.
946    # (check this issue https://github.com/openai/triton/issues/1756 )
947    # So far we see a fixed 8 spilled registers for kernels using sin/cos.
948    # Raise the threshold to 16 to be safe.
949    # We should revisit this once we understand more of the source of register spills.
950    spill_threshold: int = 16
951
952    # Generate code containing the newer tl.make_block_ptr() API for loads/stores
953    use_block_ptr = False
954
955    # Inject a bug into our relu implementation; useful for testing our repro
956    # extraction and minification functionality.
957    # Valid values: "compile_error", "runtime_error", "accuracy"
958    inject_relu_bug_TESTING_ONLY: Optional[str] = None
959
960    # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
961    codegen_upcast_to_fp32 = True
962
963
964class aot_inductor:
965    # AOTInductor output path
966    # If an absolute path is specified, the generated lib files will be stored under the directory;
967    # If a relative path is specified, it will be used as a subdirectory under the default caching path;
968    # If not specified, a temp directory will be created under the default caching path.
969    # If the specified path contains something like "model.so", the sub-string will be used
970    # to name the generated library.
971    output_path = ""
972
973    debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"
974
975    debug_dump_consts_bin: bool = (
976        os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1"
977    )
978
979    # option for debug printing/saving for intermediate tensor values for aot inductor
980    # 0: disable debug dumping
981    # 1: enable saving intermediate tensor values
982    # 2: enable printing intermediate tensor values
983    debug_intermediate_value_printer = os.environ.get(
984        "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0"
985    )
986
987    # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2
988    filtered_kernel_names = os.environ.get(
989        "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None
990    )
991
992    # Serialized tree spec for flattening inputs
993    serialized_in_spec = ""
994
995    # Serialized tree spec for flattening outputs
996    serialized_out_spec = ""
997
998    # flag to decide whether to create a submodule for constant graph.
999    use_runtime_constant_folding: bool = False
1000
1001    # flag to force weight to be appened to the shared library and mmaped  by the runtime
1002    # rather than embedded into the data section. Needed to support 1B+ parameter models
1003    force_mmap_weights: bool = False
1004
1005    package: bool = False
1006
1007
1008class cuda:
1009    # CUDA arch to use for CUDA template kernel compilation.
1010    # e.g. "70", "75", "80", "90", etc.
1011    # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
1012    arch: Optional[str] = None
1013
1014    # CUDA version to use for CUDA template kernel compilation.
1015    # e.g. "11.4", "12.1", etc.
1016    # When version is None, Inductor uses torch.version.cuda.
1017    version: Optional[str] = None
1018
1019    # Optimization level for the host compiler.
1020    compile_opt_level = "-O1"
1021
1022    # Whether to enable device LTO (link-time-optimization).
1023    enable_cuda_lto = False
1024
1025    # Whether to keep intermediate files during compilation.
1026    enable_ptxas_info = False
1027
1028    # Whether to enable debug info, e.g. line number, cutlass debug info.
1029    enable_debug_info = False
1030
1031    # Whether to use fast math.
1032    use_fast_math = False
1033
1034    # Path to the CUTLASS repo root directory.
1035    # The default path only works under PyTorch local development environment.
1036    cutlass_dir = os.environ.get(
1037        "TORCHINDUCTOR_CUTLASS_DIR",
1038        os.path.abspath(
1039            os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
1040        ),
1041    )
1042
1043    # Configures the maximum number of CUTLASS configs to profile in max_autotune.
1044    # By default it's None, so that all CUTLASS configs are tuned.
1045    # This is mainly used to reduce test time in CI.
1046    cutlass_max_profiling_configs: Optional[int] = None
1047
1048    # Path to CUDA NVCC.
1049    # NVCC search order:
1050    # 1) cuda_cxx set in this config
1051    # 2) CUDACXX environment variable
1052    # 3) CUDA_HOME environment variable
1053    # 4) default system search PATH.
1054    cuda_cxx: Optional[str] = None
1055
1056    # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
1057    cutlass_backend_min_gemm_size: int = 1
1058
1059    # enable generation of inline standalone runner in CUDA CPP generated code
1060    # which allows compiling the generated code into a standalone executable.
1061    generate_test_runner: bool = (
1062        os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "1") == "1"
1063    )
1064
1065    # Keep only Cutlass op configs which contain this regular expression pattern
1066    # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs
1067    cutlass_op_allowlist_regex: Optional[str] = None
1068
1069    # Note: Names of Cutlass ops can be obtained by calling
1070    # op.configuration_name() on a Cutlass op instance, for example those
1071    # returned from cutlass_utils.gen_ops() or the op argument passed to
1072    # CUTLASSGemmTemplate.render(...)
1073
1074    # Filter Cutlass configs which contain this regular expression pattern
1075    # Set this to "pingpong" to avoid numerical issues
1076    # caused by the op ordering of the "pingpong" memory access
1077    # pattern used by some Cutlass Kernels.
1078    cutlass_op_denylist_regex: Optional[str] = "pingpong"
1079
1080
1081class rocm:
1082    # Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
1083    # If empty, the `native` arch is used
1084    arch: List[str] = []
1085
1086    # Enable the CK backend for CDNA2 and CDNA3 only (for now)
1087    # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
1088    ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
1089
1090    # Optimization level, use to balance compilation speed and runtime performance
1091    compile_opt_level = "-O2"
1092
1093    # Flag to keep debug information in compiled objects
1094    is_debug = False
1095
1096    # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.)
1097    save_temps = False
1098
1099    # Flag to add `-ffast-math` to compile flags
1100    use_fast_math = True
1101
1102    # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags
1103    flush_denormals = True
1104
1105    # Flag to print register and LDS usage during compilation
1106    print_kernel_resource_usage = False
1107
1108    # Path to ROCm installation, if None, use env variable ROCM_HOME
1109    rocm_home: Optional[str] = None
1110
1111    # Path to Composable Kernel library.
1112    # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`.
1113    ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR")
1114
1115    # Number of op instance choices to trade off between runtime perf and compilation time
1116    n_max_profiling_configs: Optional[int] = None
1117
1118    # Flag to use a short list of CK instances which perform well across a variety of shapes.
1119    # Currently RCR and F16 only
1120    use_preselected_instances: bool = False
1121
1122
1123# Backend to use for CPU codegen either "cpp" or "halide" (experimental)
1124cpu_backend = "cpp"
1125
1126# Backend to use for CUDA codegen either "triton" or "halide" (experimental)
1127cuda_backend = "triton"
1128
1129
1130class halide:
1131    # Base halide target to use for CPU devices
1132    cpu_target = "host"
1133
1134    # Base halide target to use for CUDA devices
1135    gpu_target = "host-cuda"
1136
1137    # Halide autoscheduler to use, choices are:
1138    # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
1139    scheduler_cuda = "Anderson2021"
1140    scheduler_cpu = "Adams2019"
1141
1142    # Controls the `no_asserts` flag passed to the Halide target (warning: can produce false positives)
1143    asserts = False
1144
1145    # Controls `debug` flag passed to Halide target
1146    debug = False
1147
1148    # Enable (or fallback on) scan kernels such as cumsum
1149    # Halide autoschedulers struggle with these kernels
1150    scan_kernels = False
1151
1152
1153# create a directory containing lots of debug information
1154class trace:
1155    # master switch for all debugging flags below
1156    enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
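    #
    # Example: setting TORCH_COMPILE_DEBUG=1 in the environment enables the
    # trace directory together with the flags below that default to True
    # (fx_graph, ir_pre_fusion, output_code, ...).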
1157
1158    # Save debug information to a temporary directory
1159    # If not specified, a temp directory will be created by system
1160    debug_dir: Optional[str] = None
1161
1162    # Save python logger call >=logging.DEBUG
1163    debug_log = False
1164
1165    # Save python logger call >=logging.INFO
1166    info_log = False
1167
1168    # Save input FX graph (post decomps, pre optimization)
1169    fx_graph = True
1170
1171    # Save FX graph after transformations
1172    fx_graph_transformed = True
1173
1174    # Save TorchInductor IR before fusion pass
1175    ir_pre_fusion = True
1176
1177    # Save TorchInductor IR after fusion pass
1178    ir_post_fusion = True
1179
1180    # Copy generated code to trace dir
1181    output_code = True
1182
1183    # SVG figure showing post-fusion graph
1184    graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"
1185
1186    # SVG figure showing fx with fusion
1187    draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"
1188
1189    # We draw our fx graphs with the "record" shape attribute by default.
1190    # Sometimes, when the graph is very complex, we may hit dot errors like below:
1191    #   "flat edge between adjacent nodes one of which has a record shape -
1192    #    replace records with HTML-like labels"
1193    # and thus fail to generate a graph. So, let's give the user an option
1194    # to specify the shape attribute for the dot graph. For example, passing
1195    # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels
1196    # to work around the above failure.
1197    dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)
1198
1199    # If not None, this is the URL that saves the SVG files of the input/output
1200    # graph of each pass that changed the graph
1201    # The nodes that are being transformed in each pass will be colored in yellow
1202    # URL only supports local directory for now
1203    log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None)
1204
1205    # Store cProfile (see snakeviz to view)
1206    compile_profile = False
1207
1208    # Upload the .tar.gz file
1209    # Needs to be overridden based on specific environment needs
1210    upload_tar: Optional[Callable[[str], None]] = None
1211
1212    log_autotuning_results: bool = False
1213
1214
1215_save_config_ignore = [
1216    # workaround: "Can't pickle <function ...>"
1217    "trace.upload_tar",
1218    "post_grad_custom_post_pass",
1219    "post_grad_custom_pre_pass",
1220    "joint_custom_pre_pass",
1221    "joint_custom_post_pass",
1222    "pre_grad_custom_pass",
1223]
1224
1225_cache_config_ignore_prefix = [
1226    # trace functions are not relevant to config caching
1227    "trace",
1228    # uses absolute path
1229    "cuda.cutlass_dir",
1230    # not relevant
1231    "compile_threads",
1232]
1233
1234if TYPE_CHECKING:
1235    from torch.utils._config_typing import *  # noqa: F401, F403
1236
1237from torch.utils._config_module import install_config_module
1238
1239
1240# adds patch, save_config, etc
1241install_config_module(sys.modules[__name__])
1242