
Search for definitions of average_attn_weights: results 1 – 6 of 6, sorted by relevance

/aosp_15_r20/external/pytorch/test/nn/
  test_multihead_attention.py
    41   def test_multihead_attention(self, average_attn_weights):  (argument)
    49   average_attn_weights=average_attn_weights,  (argument)
    145  average_attn_weights=average_attn_weights,  (argument)
/aosp_15_r20/external/pytorch/test/
  test_native_mha.py
    113  …self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=…  (argument)
    278  … need_weights, average_attn_weights, use_padding, pad_all, fused):  (argument)
/aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/
  activation.cpp
    447  bool average_attn_weights) {  (in forward())
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/
  attention.cpp
    273  bool average_attn_weights,  (in native_multi_head_attention_cpu())
/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/
  attention.cu
    491  bool average_attn_weights,  (in native_multi_head_attention_cuda())
/aosp_15_r20/external/pytorch/test/cpp/api/
  modules.cpp
    3437 bool average_attn_weights = true) {
    3508 bool average_attn_weights = true) {  (in _multihead_attn_test_helper())
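For context, the parameter these results point to controls whether torch.nn.MultiheadAttention returns one attention-weight matrix per head or their mean across heads. The following is a minimal sketch, assuming a PyTorch release that accepts the average_attn_weights keyword (it is not present in very old versions); the tensor sizes are arbitrary toy values chosen for illustration. Which code path actually runs (the Python implementation or one of the native kernels listed above) depends on the PyTorch version and the inputs, but the returned shapes should be the same either way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen for illustration only (assumption, not from the listing).
batch, seq_len, embed_dim, num_heads = 2, 5, 8, 4

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha.eval()  # dropout defaults to 0.0; eval() keeps both calls deterministic

x = torch.randn(batch, seq_len, embed_dim)  # (N, L, E) because batch_first=True

with torch.no_grad():
    # average_attn_weights=True (the default): weights are averaged across heads.
    _, avg_w = mha(x, x, x, need_weights=True, average_attn_weights=True)
    # average_attn_weights=False: one weight matrix per head is returned.
    _, per_head_w = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(avg_w.shape)       # torch.Size([2, 5, 5])     -> (N, L, S)
print(per_head_w.shape)  # torch.Size([2, 4, 5, 5])  -> (N, num_heads, L, S)

# The averaged weights should match the per-head weights averaged over dim 1.
print(torch.allclose(avg_w, per_head_w.mean(dim=1), atol=1e-6))
```

Setting average_attn_weights=False is mainly useful when inspecting or visualizing individual heads; the averaged form is smaller and is the default.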