
Searched defs:score_mod (Results 1 – 4 of 4) sorted by relevance

/aosp_15_r20/external/pytorch/test/inductor/
test_flex_decoding.py
43 def create_attention(score_mod, block_mask, enable_gqa=False): argument
52 def create_block_mask_test(score_mod, query, key): argument
144 def score_mod(score, b, h, m, n): function
375 def sdpa_hop(q, k, v, score_mod, block_mask): argument
458 def score_mod(score, b, h, q, kv): function
482 def score_mod(score, b, h, m, n): function
558 def test_non_equal_head_dims(self, dtype, score_mod, head_dims): argument
768 def score_mod(score, b, h, m, n): function
873 def test_logsumexp_correctness(self, dtype, score_mod): argument
891 def sdpa_hop(q, k, v, score_mod): argument
[all …]
test_flex_attention.py
47 def create_attention(score_mod): argument
84 def score_mod(score, b, h, m, n): function
445 def score_mod(score, b, h, q, kv): function
468 def score_mod(score, b, h, m, n): function
823 def score_mod(score, b, h, m, n): function
847 def test_logsumexp_correctness(self, dtype, score_mod): argument
849 def sdpa_hop(q, k, v, score_mod): argument
853 def eager_sdpa_hop(q, k, v, score_mod): argument
912 def func(q, k, v, score_mod): argument
933 def func(q, k, v, score_mod): argument
[all …]
/aosp_15_r20/external/pytorch/torch/testing/_internal/
hop_db.py
120 def score_mod(score, b, h, m, n): function
/aosp_15_r20/external/pytorch/torch/_higher_order_ops/
flex_attention.py
481 def create_fw_bw_graph(score_mod, index_values, other_buffers): argument
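Every `score_mod` function listed above has the same calling convention. It takes a score, then the batch, head, query-position and key-position indices (named `m, n` in some tests and `q, kv` in others), and it returns a new score. The following is a minimal pure-Python sketch, not PyTorch code. It assumes, as in PyTorch's flex attention, that `score_mod` rewrites each pre-softmax attention score. The names `apply_score_mod`, `causal_mask` and `softmax` are hypothetical helpers chosen only for this illustration.

```python
import math
from typing import Callable, List

# Hypothetical alias: a score_mod maps (score, b, h, m, n) to a new score.
ScoreMod = Callable[[float, int, int, int, int], float]


def causal_mask(score: float, b: int, h: int, m: int, n: int) -> float:
    # Illustrative score_mod: a query at position m may not attend to a later key n.
    return score if m >= n else -math.inf


def apply_score_mod(
    scores: List[List[List[List[float]]]], score_mod: ScoreMod
) -> List[List[List[List[float]]]]:
    # scores[b][h][m][n] is the raw (pre-softmax) score for batch b, head h,
    # query m, key n; score_mod is called once per element with its indices.
    return [
        [
            [
                [score_mod(s, b, h, m, n) for n, s in enumerate(row)]
                for m, row in enumerate(head)
            ]
            for h, head in enumerate(batch)
        ]
        for b, batch in enumerate(scores)
    ]


def softmax(row: List[float]) -> List[float]:
    # Numerically stable softmax; exp(-inf) evaluates to 0.0, so masked keys get zero weight.
    mx = max(row)
    exps = [math.exp(x - mx) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # One batch, one head, two queries, two keys, all raw scores equal.
    raw = [[[[0.0, 0.0], [0.0, 0.0]]]]
    modded = apply_score_mod(raw, causal_mask)
    for row in modded[0][0]:
        print(softmax(row))  # [1.0, 0.0] then [0.5, 0.5]
```

Because `score_mod` sees only a scalar and its indices, a mask and a positional bias take the same form. This uniform per-element shape is what lets the test files above swap in many different `score_mod` definitions under one harness.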