Home
last modified time | relevance | path

Searched refs:unwrap_tensors (Results 1 – 19 of 19) sorted by relevance

/aosp_15_r20/external/pytorch/torch/_higher_order_ops/
H A Dflex_attention.py405 query_unwrapped = ctx.unwrap_tensors(query)
406 key_unwrapped = ctx.unwrap_tensors(key)
407 value_unwrapped = ctx.unwrap_tensors(value)
408 block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
409 score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
410 mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)
976 query_unwrapped = ctx.unwrap_tensors(query)
977 key_unwrapped = ctx.unwrap_tensors(key)
978 value_unwrapped = ctx.unwrap_tensors(value)
979 out_unwrapped = ctx.unwrap_tensors(out)
[all …]
H A Dhints_wrap.py94 unwrapped_args = ctx.unwrap_tensors(args)
95 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
96 unwrapped_hints = ctx.unwrap_tensors(hints)
H A Dauto_functionalize.py353 unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type]
491 unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type]
497 all_basis_unwrapped = ctx.unwrap_tensors(all_bases)
615 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
710 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
H A Deffects.py265 unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type]
266 unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type]
267 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type]
H A Dwhile_loop.py238 unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
239 unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
H A Dmap.py245 unwrapped_xs = ctx.unwrap_tensors(xs)
246 unwrapped_args = ctx.unwrap_tensors(pos_args)
H A Dcond.py450 unwrapped_inputs = ctx.unwrap_tensors(inputs)
451 unwrapped_pred = ctx.unwrap_tensors(pred)
H A Drun_const_graph.py38 unwrapped_args = ctx.unwrap_tensors(args)
H A Dstrict_mode.py89 unwrapped_inputs = ctx.unwrap_tensors(inputs)
H A Dout_dtype.py162 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
H A Dexecutorch_call_delegate.py139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
H A Dtriton_kernel_wrap.py631 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
723 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
H A Dassociative_scan.py359 unwrapped_input = ctx.unwrap_tensors(input)
/aosp_15_r20/external/pytorch/torch/_prims/
H A Drng_prims.py301 unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
302 unwrapped_args = ctx.unwrap_tensors(args)
303 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
/aosp_15_r20/external/pytorch/torch/_export/
H A Dwrappers.py43 unwrapped_args = ctx.unwrap_tensors(args)
44 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
/aosp_15_r20/external/pytorch/torch/_subclasses/
H A Dfunctional_tensor.py642 def unwrap_tensors( member in BaseFunctionalizeAPI
686 def unwrap_tensors( member in PythonFunctionalizeAPI
728 def unwrap_tensors( member in CppFunctionalizeAPI
767 def unwrap_tensors( member in FunctorchFunctionalizeAPI
/aosp_15_r20/external/pytorch/torch/_functorch/
H A Deager_transforms.py123 def unwrap_tensors(x): function
128 return tree_map(unwrap_tensors, tuple(x))
132 return tree_map(unwrap_tensors, inps)
/aosp_15_r20/external/pytorch/torch/_dynamo/
H A D_trace_wrapped_higher_order_op.py125 unwrapped_args = ctx.unwrap_tensors(args)
/aosp_15_r20/external/executorch/exir/
H A Ddelegate.py139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)