/aosp_15_r20/external/pytorch/torch/_higher_order_ops/ |
H A D | flex_attention.py | 405 query_unwrapped = ctx.unwrap_tensors(query) 406 key_unwrapped = ctx.unwrap_tensors(key) 407 value_unwrapped = ctx.unwrap_tensors(value) 408 block_mask_unwrapped = ctx.unwrap_tensors(block_mask) 409 score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers) 410 mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers) 976 query_unwrapped = ctx.unwrap_tensors(query) 977 key_unwrapped = ctx.unwrap_tensors(key) 978 value_unwrapped = ctx.unwrap_tensors(value) 979 out_unwrapped = ctx.unwrap_tensors(out) [all …]
|
H A D | hints_wrap.py | 94 unwrapped_args = ctx.unwrap_tensors(args) 95 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) 96 unwrapped_hints = ctx.unwrap_tensors(hints)
|
H A D | auto_functionalize.py | 353 unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] 491 unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] 497 all_basis_unwrapped = ctx.unwrap_tensors(all_bases) 615 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) 710 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
H A D | effects.py | 265 unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type] 266 unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type] 267 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type]
|
H A D | while_loop.py | 238 unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs) 239 unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
|
H A D | map.py | 245 unwrapped_xs = ctx.unwrap_tensors(xs) 246 unwrapped_args = ctx.unwrap_tensors(pos_args)
|
H A D | cond.py | 450 unwrapped_inputs = ctx.unwrap_tensors(inputs) 451 unwrapped_pred = ctx.unwrap_tensors(pred)
|
H A D | run_const_graph.py | 38 unwrapped_args = ctx.unwrap_tensors(args)
|
H A D | strict_mode.py | 89 unwrapped_inputs = ctx.unwrap_tensors(inputs)
|
H A D | out_dtype.py | 162 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
|
H A D | executorch_call_delegate.py | 139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
|
H A D | triton_kernel_wrap.py | 631 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) 723 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
H A D | associative_scan.py | 359 unwrapped_input = ctx.unwrap_tensors(input)
|
/aosp_15_r20/external/pytorch/torch/_prims/ |
H A D | rng_prims.py | 301 unwrapped_rng_state = ctx.unwrap_tensors(rng_state) 302 unwrapped_args = ctx.unwrap_tensors(args) 303 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
/aosp_15_r20/external/pytorch/torch/_export/ |
H A D | wrappers.py | 43 unwrapped_args = ctx.unwrap_tensors(args) 44 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
/aosp_15_r20/external/pytorch/torch/_subclasses/ |
H A D | functional_tensor.py | 642 def unwrap_tensors( member in BaseFunctionalizeAPI 686 def unwrap_tensors( member in PythonFunctionalizeAPI 728 def unwrap_tensors( member in CppFunctionalizeAPI 767 def unwrap_tensors( member in FunctorchFunctionalizeAPI
|
/aosp_15_r20/external/pytorch/torch/_functorch/ |
H A D | eager_transforms.py | 123 def unwrap_tensors(x): function 128 return tree_map(unwrap_tensors, tuple(x)) 132 return tree_map(unwrap_tensors, inps)
|
/aosp_15_r20/external/pytorch/torch/_dynamo/ |
H A D | _trace_wrapped_higher_order_op.py | 125 unwrapped_args = ctx.unwrap_tensors(args)
|
/aosp_15_r20/external/executorch/exir/ |
H A D | delegate.py | 139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
|