
Full-text search for "attn_bias" (results 1 – 22 of 22), sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/
attention_backward.cu
176 const Tensor& attn_bias, in _scaled_dot_product_cudnn_attention_backward_cuda() argument
204 if (attn_bias.defined()) { in _scaled_dot_product_cudnn_attention_backward_cuda()
205 attn_bias_ = attn_bias; in _scaled_dot_product_cudnn_attention_backward_cuda()
215 …TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bi… in _scaled_dot_product_cudnn_attention_backward_cuda()
235 attn_bias_ /*const std::optional<Tensor>& attn_bias*/, in _scaled_dot_product_cudnn_attention_backward_cuda()
593 "attn_bias: wrong shape (batch dimension)"); in _efficient_attention_backward()
596 "attn_bias: wrong shape (head dimension)"); in _efficient_attention_backward()
599 "attn_bias: wrong shape (seqlenQ dimension)"); in _efficient_attention_backward()
602 "attn_bias: wrong shape (seqlenKV dimension)"); in _efficient_attention_backward()
605 "attn_bias: wrong alignment (last dimension must be contiguous)"); in _efficient_attention_backward()
[all …]
attention.cu
793 const std::optional<at::Tensor>& attn_bias, in _scaled_dot_product_efficient_attention_cuda() argument
815 attn_bias, in _scaled_dot_product_efficient_attention_cuda()
1248 "attn_bias: wrong shape (batch dimension)"); in _efficient_attention_forward()
1251 "attn_bias: wrong shape (head dimension)"); in _efficient_attention_forward()
1254 "attn_bias: wrong shape (seqlenQ dimension)"); in _efficient_attention_forward()
1257 "attn_bias: wrong shape (seqlenKV dimension)"); in _efficient_attention_forward()
1263 "attn_bias: wrong alignment (last dimension must be contiguous)"); in _efficient_attention_forward()
/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/
BatchRulesLinearAlgebra.cpp
542 const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim, in _scaled_dot_product_efficient_attention_batch_rule() argument
566 if (attn_bias.has_value() && attn_bias->defined()) { in _scaled_dot_product_efficient_attention_batch_rule()
567 …bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value(… in _scaled_dot_product_efficient_attention_batch_rule()
583 const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim, in _scaled_dot_product_cudnn_attention_batch_rule() argument
607 if (attn_bias.has_value() && attn_bias->defined()) { in _scaled_dot_product_cudnn_attention_batch_rule()
608 …bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value(… in _scaled_dot_product_cudnn_attention_batch_rule()
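The batch rules above fold a vmapped dimension of attn_bias into its existing batch dimension with reshape_dim_into before calling the underlying kernel. A rough Python sketch of that reshaping, for illustration only (it is not the functorch implementation):

    import torch

    vmap_size, batch, heads, seq_q, seq_kv = 2, 3, 4, 8, 8
    # Suppose the vmap dimension of the bias landed at index 1 (attn_bias_bdim == 1).
    attn_bias = torch.randn(batch, vmap_size, heads, seq_q, seq_kv)
    bdim = 1

    # Move the vmap dimension to the front and merge it into the batch dimension,
    # which is roughly what reshape_dim_into(bdim, 0, attn_bias) computes.
    folded = attn_bias.movedim(bdim, 0).reshape(vmap_size * batch, heads, seq_q, seq_kv)
    assert folded.shape == (vmap_size * batch, heads, seq_q, seq_kv)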
/aosp_15_r20/external/pytorch/torch/distributed/tensor/experimental/
_attention.py
151 attn_bias: Optional[torch.Tensor] = None,
158 if attn_bias is not None:
159 raise NotImplementedError("attn_bias is not supported yet")
170 attn_bias=attn_bias,
523 attn_bias=bias,
/aosp_15_r20/external/pytorch/torch/distributed/tensor/_ops/
_matrix_ops.py
434 # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present;
440 all_replicate[3] = None # grad bias is None if attn_bias is not present
462 # the place for optional input attn_bias,
466 # input sharding of attn_bias on heads dim if present
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/
attention.cpp
550 at::Tensor pad_bias(const at::Tensor& attn_bias) { in pad_bias() argument
551 auto last_dim_size = attn_bias.sym_size(-1); in pad_bias()
553 auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count}); in pad_bias()
579 at::Tensor pad_last_dim(const at::Tensor& attn_bias) { in pad_last_dim() argument
580 auto last_dim_size = attn_bias.sym_size(-1); in pad_last_dim()
582 return attn_bias; in pad_last_dim()
585 auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count}); in pad_last_dim()
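pad_bias and pad_last_dim above pad the trailing (seqlen_kv) dimension of attn_bias so the fused kernels see aligned strides. A minimal Python sketch of the same idea; the alignment value of 8 and the slice back to the logical length are assumptions for illustration, not taken verbatim from these snippets:

    import torch
    import torch.nn.functional as F

    ALIGNMENT = 8  # illustrative only; the kernels choose their own value
    attn_bias = torch.randn(2, 4, 64, 61)

    orig_len = attn_bias.size(-1)
    pad_count = (-orig_len) % ALIGNMENT  # 0 when already aligned
    if pad_count:
        # Pad the last dimension, then slice back to the logical length so only
        # the strides change while the visible shape stays the same.
        attn_bias = F.pad(attn_bias, (0, pad_count))[..., :orig_len]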
/aosp_15_r20/external/pytorch/test/
test_transformers.py
2019 attn_bias=None, argument
2030 attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
2059 if attn_bias is not None:
2060 scores = scores + attn_bias.to(dtype=scores.dtype)
3051 kwargs["attn_bias"] = None
3422 attn_bias=None, argument
3433 realized = attn_bias._materialize(device) if attn_bias is not None else None
3447 attn_mask=attn_bias,
3491 attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3493 attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
[all …]
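The reference attention used by these tests adds attn_bias to the raw attention scores before the softmax, as the scores = scores + attn_bias hit above shows. A compact sketch of that math; the shapes and scaling below are illustrative assumptions:

    import math
    import torch

    batch, heads, seq_q, seq_kv, head_dim = 2, 4, 16, 16, 8
    q = torch.randn(batch, heads, seq_q, head_dim)
    k = torch.randn(batch, heads, seq_kv, head_dim)
    v = torch.randn(batch, heads, seq_kv, head_dim)
    attn_bias = torch.randn(batch, heads, seq_q, seq_kv)  # anything broadcastable works

    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    scores = scores + attn_bias.to(dtype=scores.dtype)  # additive bias, as in the test
    out = torch.softmax(scores, dim=-1) @ v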
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/generated/
c_shim_cuda.h
37 … query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute…
38 …sorHandle philox_seed, AtenTensorHandle philox_offset, AtenTensorHandle attn_bias, AtenTensorHandl…
39 … query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute…
40 …e query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandl…
/aosp_15_r20/external/pytorch/torch/nn/attention/
bias.py
102 attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
108 out = F.scaled_dot_product_attention(q, k, v, attn_bias)
283 …e behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
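The docstring hits above show the intended use of the causal-bias helpers: build an AttnBias with causal_lower_right (or causal_upper_left) and pass it to scaled_dot_product_attention in place of a dense mask. A small usage sketch; the shapes are illustrative assumptions:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention.bias import causal_lower_right

    seqlen_q, seqlen_kv = 12, 16
    q = torch.randn(2, 4, seqlen_q, 8)
    k = torch.randn(2, 4, seqlen_kv, 8)
    v = torch.randn(2, 4, seqlen_kv, 8)

    # Lower-right alignment: the last query position attends to the last key position.
    attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
    out = F.scaled_dot_product_attention(q, k, v, attn_bias)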
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/
MHA.h
21 const std::optional<Tensor>& attn_bias,
40 const std::optional<Tensor>& attn_bias,
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/
kernel_forward.h
580 "attn_bias is not correctly aligned (strideB). ", in check_supported()
581 "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a " in check_supported()
585 "attn_bias is not correctly aligned (strideH). " in check_supported()
586 "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " in check_supported()
590 "attn_bias is not correctly aligned (strideM). " in check_supported()
591 "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a " in check_supported()
kernel_backward.h
1224 "attn_bias is not correctly aligned (strideB). ", in check_supported()
1225 "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a " in check_supported()
1229 "attn_bias is not correctly aligned (strideH) ." in check_supported()
1230 "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " in check_supported()
1234 "attn_bias is not correctly aligned (strideM). " in check_supported()
1235 "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ", in check_supported()
1241 "attn_bias.grad is not correctly aligned (strideB)"); in check_supported()
1244 "attn_bias.grad is not correctly aligned (strideH)"); in check_supported()
1247 "attn_bias.grad is not correctly aligned (strideM)"); in check_supported()
/aosp_15_r20/external/pytorch/test/cpp_extensions/
open_registration_extension.cpp
449 const std::optional<at::Tensor> & attn_bias, in custom_scaled_dot_product_fused_attention_overrideable() argument
477 const at::Tensor & attn_bias, in custom_scaled_dot_product_fused_attention_overrideable_backward() argument
494 at::empty_like(attn_bias)); in custom_scaled_dot_product_fused_attention_overrideable_backward()
/aosp_15_r20/external/pytorch/torch/
_meta_registrations.py
5106 attn_bias: Optional[Tensor],
5259 attn_bias: Optional[Tensor],
5304 attn_bias: Optional[Tensor],
5341 if attn_bias is not None and grad_input_mask[3]:
5342 lastDim = attn_bias.size(-1)
5344 new_sizes = list(attn_bias.size())
5347 new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
5368 attn_bias: Tensor,
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/
shim_common.cpp
604 AtenTensorHandle attn_bias, // optional argument in aoti_torch__scaled_dot_product_efficient_attention() argument
619 pointer_to_optional(tensor_handle_to_tensor_pointer(attn_bias)); in aoti_torch__scaled_dot_product_efficient_attention()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/cuda/
NestedTensorTransformerFunctions.cpp
286 const std::optional<at::Tensor>& attn_bias, in _scaled_dot_product_efficient_attention_nestedtensor_cuda() argument
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/c/
shim.h
377 AtenTensorHandle attn_bias, // optional argument
/aosp_15_r20/external/pytorch/torch/nn/
functional.py
5623 attn_bias = torch.zeros(L, S, dtype=query.dtype)
5627 attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
5628 attn_bias.to(query.dtype)
5632 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
5634 attn_bias += attn_mask
5641 attn_weight += attn_bias
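These hits are from the documented reference implementation of scaled_dot_product_attention: a boolean attn_mask is turned into an additive bias by filling masked-out positions with -inf, a floating-point mask is added directly, and the resulting attn_bias is added to the attention weights. A standalone sketch of that conversion; names follow the docstring, shapes are illustrative assumptions:

    import torch

    L, S = 4, 6                              # query and key sequence lengths
    query = torch.randn(1, 2, L, 8)
    attn_mask = torch.rand(L, S) > 0.5       # boolean mask: True means "attend"

    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if attn_mask.dtype == torch.bool:
        attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
    else:
        attn_bias = attn_bias + attn_mask    # float masks are purely additive

    # attn_bias is later added to the attention scores before the softmax.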
/aosp_15_r20/external/pytorch/tools/autograd/
derivatives.yaml
2807 …uct_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_lo…
2809 …query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, v…
/aosp_15_r20/external/pytorch/test/inductor/
test_aot_inductor.py
2847 def forward(self, q, k, v, attn_bias): argument
2849 q, k, v, attn_bias, False
test_torchinductor.py
9655 def fn(q, k, v, attn_bias, compute_log_sumexp): argument
9657 q, k, v, attn_bias, compute_log_sumexp
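Both inductor tests call the private efficient-attention ATen op directly instead of going through F.scaled_dot_product_attention. A hedged sketch of such a call; this backend is CUDA-only, the positional arguments mirror what the tests pass, and the four-tensor return (output, log_sumexp, philox_seed, philox_offset) is an assumption based on the shims and schemas elsewhere in these results:

    import torch

    if torch.cuda.is_available():
        q = torch.randn(2, 4, 32, 64, device="cuda", dtype=torch.float16)
        k = torch.randn(2, 4, 32, 64, device="cuda", dtype=torch.float16)
        v = torch.randn(2, 4, 32, 64, device="cuda", dtype=torch.float16)
        attn_bias = torch.randn(2, 4, 32, 32, device="cuda", dtype=torch.float16)

        out, lse, seed, offset = torch.ops.aten._scaled_dot_product_efficient_attention(
            q, k, v, attn_bias, False  # compute_log_sumexp=False, as in the tests
        )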
/aosp_15_r20/external/pytorch/aten/src/ATen/native/
native_functions.yaml
14719 …uct_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_lo…
14725 …ckward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Ten…