# mypy: ignore-errors

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._utils import wrapper_set_seed
import torch.utils._pytree as pytree


def make_fx_check(
    func,
    args,
    kwargs,
    tracing_mode,
    assert_close=torch.testing.assert_close,
    randomize_data=False,
):
    # new_args is [args, kwargs, extra_args, extra_kwargs]; `f` takes them
    # positionally and reconstructs the original call (see
    # handle_sizes_for_dynamic_shapes below).
    f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs)

    def run(f, *args, **kwargs):
        return wrapper_set_seed(f, *args, **kwargs)

    traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args)

    msg = (
        "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different "
        "values. This could mean that your abstract impls (meta/FakeTensor impls) "
        "are incorrect, that your operator is not completely traceable (e.g., "
        "it relies on some global state), or that there is a bug in make_fx. "
        "Note that if you passed a python function (and not an operator) to "
        "make_fx_check, it is still possible that the python function will "
        "work with torch.compile, because torch.compile handles capturing "
        "pieces of your python code to compile."
    )

    # Randomize the data and run the traced graph with it, to catch bugs
    # where we may have baked Tensor data into the trace.
    # This is not guaranteed to succeed, because `f` might have preconditions
    # on the values of the inputs, so if we used random data and it fails,
    # we just skip the check.
    if randomize_data:
        new_args = randomize(new_args)
    try:
        expected = run(f, *new_args)
    except Exception:
        if randomize_data:
            return
        raise
    result = run(traced_f, *new_args)
    assert_close(result, expected, msg=msg)


# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes.
# Absent that, here is our strategy:
#
# If any argument is a torch.Size(), maybe get dynamic shapes for it by:
# - Create a temporary Tensor whose size is the torch.Size() we want. Note
#   that we use a real (CPU) Tensor here, as we cannot pass "meta" Tensors to
#   make_fx.
# - Pass it to make_fx such that it is converted to a proxy Tensor.
# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes
#   (in symbolic mode; a no-op otherwise).
def handle_sizes_for_dynamic_shapes(func, args, kwargs):
    def f(args, kwargs, extra_args, extra_kwargs):
        # Swap the placeholder Tensors back out for their (possibly symbolic)
        # sizes before calling the real function.
        if extra_args:
            for i, t in extra_args:
                args[i] = t.size()
        if extra_kwargs:
            for k, t in extra_kwargs.items():
                kwargs[k] = t.size()

        return func(*args, **kwargs)

    extra_args = []
    extra_kwargs = {}
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Size):
            extra_args.append((i, torch.empty(arg, device="cpu")))
    for key, value in kwargs.items():
        if isinstance(value, torch.Size):
            extra_kwargs[key] = torch.empty(value, device="cpu")

    return f, args, kwargs, extra_args, extra_kwargs


def randomize(args):
    # Replace every floating-point Tensor with a fresh uniform-[0, 1) Tensor,
    # preserving requires_grad; non-floating-point Tensors pass through as-is.
    def transform(x):
        if not x.dtype.is_floating_point:
            return x
        return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad)
    return pytree.tree_map_only(torch.Tensor, transform, args)
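

# --- Example usage (a hedged sketch, not part of the original module) ---
# Everything below is illustrative: `my_op` and its inputs are made up.
# Running this file directly exercises make_fx_check in each tracing mode.
if __name__ == "__main__":
    def my_op(x, y):
        # A fully traceable toy function: no data-dependent control flow,
        # no global state, no baked-in Tensor data.
        return x * y + y.sum()

    example_args = [torch.randn(3, 4), torch.randn(3, 4)]
    # randomize_data=True re-runs the trace on fresh random inputs to catch
    # Tensor data that was accidentally baked into the graph.
    for mode in ("real", "fake", "symbolic"):
        make_fx_check(my_op, example_args, {}, tracing_mode=mode,
                      randomize_data=True)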
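    # A torch.Size argument exercises handle_sizes_for_dynamic_shapes: the
    # sizes are smuggled through a temporary Tensor so that, under
    # tracing_mode="symbolic", they become dynamic shapes rather than
    # constants baked into the graph. (`size_to_ones` is again hypothetical.)
    def size_to_ones(size):
        return torch.ones(size)

    make_fx_check(size_to_ones, [torch.Size([2, 3])], {},
                  tracing_mode="symbolic")
    print("make_fx_check examples passed")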