Searched defs:key_padding_mask (Results 1 – 8 of 8) sorted by relevance

/aosp_15_r20/external/pytorch/test/cpp/api/
transformer.cpp
  341  torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;  in transformer_decoder_layer_test_helper()  local
  939  torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;  in transformer_decoder_test_helper()  local
modules.cpp
  3436  const torch::Tensor& key_padding_mask = {},
  3543  torch::Tensor key_padding_mask;  in _multihead_attn_test_helper()  local
/aosp_15_r20/external/pytorch/test/
test_transformers.py
  1092  …def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True…  argument
  2018  key_padding_mask=None,  argument
  2083  …ct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device):  argument
  2114  key_padding_mask,  argument
test_native_mha.py
  167  def forward(self, q, k, v, key_padding_mask):  argument
test_jit.py
  14922  key_padding_mask=None,  # type: Optional[Tensor]  argument
/aosp_15_r20/external/pytorch/test/nn/
test_multihead_attention.py
  48  key_padding_mask=None,  argument
/aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/
activation.cpp
  444  const Tensor& key_padding_mask,  in forward()
/aosp_15_r20/external/pytorch/test/dynamo/
test_repros.py
  888  def _sa_block(self, x, attn_mask, key_padding_mask):  argument
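Most of the hits above pass `key_padding_mask` into attention code, and the C++ test helpers in transformer.cpp build it as `torch::zeros({2, 3}, tensor_options) == 1`, which yields an all-false boolean mask of shape {2, 3}. As a rough illustration of what such a mask does, here is a minimal pure-Python sketch, assuming the boolean convention documented for `torch.nn.MultiheadAttention` (True marks a padded key that attention should ignore). It does not use PyTorch; `masked_softmax` and `apply_key_padding_mask` are hypothetical helpers written only for illustration, and raising an error when every key is masked is a choice made in this sketch, not PyTorch behavior.

```python
import math
from typing import List


def masked_softmax(scores: List[float], key_padding_mask: List[bool]) -> List[float]:
    """Softmax over one query's attention scores, ignoring padded keys.

    Hypothetical helper. Assumes the boolean convention of
    torch.nn.MultiheadAttention: True means "this key is padding, ignore it".
    """
    if len(scores) != len(key_padding_mask):
        raise ValueError("scores and key_padding_mask must have the same length")
    # Replace padded positions with -inf so they receive zero weight.
    masked = [-math.inf if pad else s for s, pad in zip(scores, key_padding_mask)]
    m = max(masked)
    if m == -math.inf:
        # Sketch-specific choice: refuse to normalize a row with no valid keys.
        raise ValueError("every key is masked; softmax is undefined")
    # Subtract the max for numerical stability; math.exp(-inf) == 0.0.
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


def apply_key_padding_mask(
    scores: List[List[List[float]]], key_padding_mask: List[List[bool]]
) -> List[List[List[float]]]:
    """Hypothetical helper: scores is [batch][query][key], mask is [batch][key].

    The same per-batch key mask is applied to every query row in that batch.
    """
    return [
        [masked_softmax(row, key_padding_mask[b]) for row in scores[b]]
        for b in range(len(scores))
    ]


if __name__ == "__main__":
    # An all-False mask, like zeros({2, 3}) == 1, leaves the softmax unchanged.
    print(masked_softmax([1.0, 2.0, 3.0], [False, False, False]))
    # Masking the last key drives its weight to exactly zero.
    print(masked_softmax([0.0, 0.0, 5.0], [False, False, True]))
```

With an all-False mask, as in the transformer.cpp helpers, every key stays eligible and the result is an ordinary softmax; marking a key True removes it from the normalization entirely, regardless of how large its raw score is.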