Searched defs:cumulative_sequence_length_q (Results 1 – 4 of 4) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/cuda/
NestedTensorTransformerUtils.cpp
251 Tensor cumulative_sequence_length_q; in sdpa_nested_preprocessing_with_broadcast() local
464 const Tensor& cumulative_sequence_length_q, in sdpa_nested_preprocessing_backward()
NestedTensorTransformerFunctions.cpp
333 const Tensor& cumulative_sequence_length_q, in _scaled_dot_product_flash_attention_backward_nested()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/
attention_backward.cu
68 const Tensor& cumulative_sequence_length_q, in _flash_attention_backward()
739 const Tensor& cumulative_sequence_length_q, in _scaled_dot_product_flash_attention_backward_cuda()
attention.cu
847 const std::optional<Tensor>& cumulative_sequence_length_q, in _flash_attention_forward()