Attention score/softmax/output kernels with SIMD (SSE/AVX/AVX512) More...
Go to the source code of this file.
Functions | |
| void | attention_backward_causal_head_major (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_backward_causal_head_major_gqa (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_backward_causal_head_major_gqa_bf16 (const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v) |
| static void | attention_flash_query_causal (const float *q_vec, const float *k_head, const float *v_head, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec) |
| static void | attention_flash_query_sliding (const float *q_vec, const float *k_head, const float *v_head, int query_pos, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec, int sliding_window) |
| void | attention_forward_causal_head_major (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_gqa (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_gqa_bf16 (const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v) |
| void | attention_forward_causal_head_major_gqa_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_gqa_flash (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim) |
| void | attention_forward_causal_head_major_gqa_flash_strided (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens) |
| void | attention_forward_causal_head_major_gqa_flash_strided_sliding (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window) |
| void | attention_forward_decode_head_major_gqa_flash (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim) |
| void | attention_forward_decode_head_major_gqa_flash_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window) |
| void | attention_forward_decode_head_major_gqa_regular (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim) |
| WARNING: This is NOT true flash attention! More... | |
| static void | convert_bf16_tensor_to_buf (const uint16_t *src, float *dst, size_t count) |
| static size_t | qkv_index (int h, int t, int d, int num_tokens, int aligned_head_dim) |
| static size_t | score_index (int h, int i, int j, int aligned_context_window) |
Attention score/softmax/output kernels with SIMD (SSE/AVX/AVX512)
After changes: make test && make llamacpp-parity-full
Attention: softmax(Q @ K^T / sqrt(d)) @ V. Supports GQA (grouped-query attention) with head broadcasting.
Definition in file attention_kernels.c.
| #define FLASH_QUERY_IMPL attention_flash_query_causal |
| #define FLASH_QUERY_IMPL attention_flash_query_causal |
| #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal |
| #define SLIDING_DECODE_IMPL attention_flash_query_sliding |
| #define SLIDING_FLASH_IMPL attention_flash_query_sliding |
| void attention_backward_causal_head_major | ( | const float * | d_output, |
| const float * | q, | ||
| const float * | k, | ||
| const float * | v, | ||
| const float * | attn_weights, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| float * | d_v, | ||
| float * | d_scores, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
Causal attention backward (non-GQA version)
test_attention_backward.py::TestAttentionBackward::test_backward
test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate
test_parity.py::test_attention_backward_parity
Non-GQA version where num_heads == num_kv_heads. Simpler than GQA, no head broadcasting needed.
After changes: make test && make llamacpp-parity-full
Definition at line 1811 of file attention_kernels.c.
References attention_backward_causal_head_major_gqa().
| void attention_backward_causal_head_major_gqa | ( | const float * | d_output, |
| const float * | q, | ||
| const float * | k, | ||
| const float * | v, | ||
| const float * | attn_weights, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| float * | d_v, | ||
| float * | d_scores, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
GQA causal attention backward (score-matrix version)
test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward
test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate
test_parity.py::test_attention_backward_parity
Computes dQ, dK, dV given dOutput and attention weights. Supports grouped-query attention with head broadcasting.
After changes: make test && make llamacpp-parity-full
Definition at line 1672 of file attention_kernels.c.
References qkv_index(), and score_index().
Referenced by attention_backward_causal_head_major(), attention_backward_causal_head_major_gqa_bf16(), and ck_layer_backward_rmsnorm_swiglu().
| void attention_backward_causal_head_major_gqa_bf16 | ( | const uint16_t * | d_output, |
| float * | d_x, | ||
| const uint16_t * | q, | ||
| const uint16_t * | k, | ||
| const uint16_t * | v, | ||
| const float * | attn_weights, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| float * | d_v, | ||
| float * | d_scores, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window, | ||
| float * | scratch_d_output, | ||
| float * | scratch_q, | ||
| float * | scratch_k, | ||
| float * | scratch_v | ||
| ) |
BF16 attention backward with caller-provided scratch buffers
Accepts BF16 inputs, converts them to FP32, and runs the FP32 backward pass. The caller provides scratch buffers (no per-call malloc).
After changes: make test
Definition at line 1619 of file attention_kernels.c.
References attention_backward_causal_head_major_gqa(), and convert_bf16_tensor_to_buf().
|
static |
|
static |
| void attention_forward_causal_head_major | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
Causal attention forward (score-matrix version)
test_attention.py::TestAttentionForward::test_causal_forward
test_attention.py::TestAttentionForward::test_gqa_broadcast
test_attention.py::TestAttentionForward::test_exact_vs_fast
test_parity.py::test_attention_parity
Computes softmax(Q @ K^T / sqrt(d)) @ V with causal masking. Uses O(N^2) memory for scores matrix.
After changes: make test && make llamacpp-parity-full
Definition at line 70 of file attention_kernels.c.
References causal_softmax_head_major(), qkv_index(), and score_index().
| void attention_forward_causal_head_major_exact | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
Causal attention forward (exact version using stdlib expf)
test_attention.py::TestAttentionForward::test_exact_single
test_attention.py::TestAttentionForward::test_exact_vs_fast
Uses standard library expf for numerical accuracy reference. Slower but provides maximum accuracy.
After changes: make test
Definition at line 146 of file attention_kernels.c.
References causal_softmax_head_major_exact(), qkv_index(), and score_index().
| void attention_forward_causal_head_major_gqa | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
GQA causal attention forward (score-matrix version)
test_attention.py::TestAttentionForward::test_gqa_forward
test_attention.py::TestAttentionForward::test_gqa_broadcast
test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward
test_parity.py::test_attention_gqa_parity
Grouped-query attention: Q has num_heads, K/V have num_kv_heads. Each query head maps to a KV head via ratio.
After changes: make test && make llamacpp-parity-full
Definition at line 224 of file attention_kernels.c.
References causal_softmax_head_major(), qkv_index(), and score_index().
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), and ck_layer_forward_rmsnorm_swiglu_ref().
| void attention_forward_causal_head_major_gqa_bf16 | ( | const uint16_t * | q, |
| const uint16_t * | k, | ||
| const uint16_t * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window, | ||
| float * | scratch_q, | ||
| float * | scratch_k, | ||
| float * | scratch_v | ||
| ) |
BF16 GQA causal attention forward
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash
Accepts BF16 inputs, converts to FP32, uses exact softmax. Caller provides scratch buffers (no per-call malloc).
After changes: make test
Definition at line 366 of file attention_kernels.c.
References attention_forward_causal_head_major_gqa_exact(), and convert_bf16_tensor_to_buf().
| void attention_forward_causal_head_major_gqa_exact | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
GQA causal attention forward (exact version using stdlib expf)
test_attention.py::TestAttentionForward::test_gqa_exact
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
Uses the standard library expf for a numerical-accuracy reference. Used by the BF16 wrapper to avoid accumulation of approximation error.
After changes: make test
Definition at line 294 of file attention_kernels.c.
References causal_softmax_head_major_exact(), qkv_index(), and score_index().
Referenced by attention_forward_causal_head_major_gqa_bf16().
| void attention_forward_causal_head_major_gqa_flash | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Flash attention forward for GQA (prefill, no score materialization)
test_flash_attention.py::TestFlashAttention::test_flash_forward
test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix
test_flash_attention.py::TestFlashAttention::test_flash_gqa
test_attention.py::TestAttentionForward::test_flash_forward
Online softmax with streaming KV. O(N) memory instead of O(N^2). For prefill: all tokens attend to previous tokens.
After changes: make test && make llamacpp-parity-full
Definition at line 800 of file attention_kernels.c.
References FLASH_QUERY_IMPL, and qkv_index().
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().
| void attention_forward_causal_head_major_gqa_flash_strided | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens | ||
| ) |
Flash attention forward with custom KV stride (for KV cache)
test_flash_attention.py::TestFlashAttention::test_flash_strided
test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention
Variant with configurable kv_stride_tokens for KV cache layouts where K/V may not be contiguous in memory.
After changes: make test
Definition at line 859 of file attention_kernels.c.
References FLASH_QUERY_IMPL, and qkv_index().
Referenced by ck_test_attention_causal(), mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().
| void attention_forward_causal_head_major_gqa_flash_strided_sliding | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| int | sliding_window | ||
| ) |
Flash attention forward with sliding window (prefill)
Sliding-window attention for prefill: each token attends to the last W tokens. When sliding_window <= 0, behaves like regular causal attention.
After changes: make test
Definition at line 1316 of file attention_kernels.c.
References qkv_index(), and SLIDING_FLASH_IMPL.
Referenced by ck_test_attention_sliding_window().
| void attention_forward_decode_head_major_gqa_flash | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Flash attention decode (single token attends to KV cache)
test_flash_attention.py::TestFlashAttention::test_flash_decode
test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode
test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode
test_attention.py::TestAttentionForward::test_flash_decode
Single query token attends to kv_tokens in KV cache. Uses true flash attention from attention_flash_true.c.
After changes: make test && make llamacpp-parity-full
Definition at line 1467 of file attention_kernels.c.
References attention_flash_decode().
Referenced by mega_fused_attention_decode_q5_0(), mega_fused_attention_decode_q5_0_parallel_simd(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void attention_forward_decode_head_major_gqa_flash_sliding | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | sliding_window | ||
| ) |
Flash attention decode with sliding window
Single query token attends to the last W tokens in the KV cache. For decode: effective_kv_tokens = min(kv_tokens, sliding_window).
After changes: make test
Definition at line 1382 of file attention_kernels.c.
References SLIDING_DECODE_IMPL.
Referenced by ck_test_attention_decode_sliding().
| void attention_forward_decode_head_major_gqa_regular | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
WARNING: This is NOT true flash attention!
Despite using FLASH_QUERY_IMPL_DECODE internally, this function implements regular attention with O(n) complexity. It is kept for reference and as a fallback.
TRUE flash attention is implemented in attention_flash_true.c
test_kv_cache_attention.py::TestKVCacheAttention::test_regular_decode
test_attention.py::TestAttentionForward::test_regular_decode
Regular attention decode (score-matrix version) for fallback.
After changes: make test
Definition at line 1524 of file attention_kernels.c.
References FLASH_QUERY_IMPL_DECODE.
Referenced by ck_attention_flash_decode_wrapper(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
|
static |
Definition at line 28 of file attention_kernels.c.
References bf16_tensor_to_float().
Referenced by attention_backward_causal_head_major_gqa_bf16(), and attention_forward_causal_head_major_gqa_bf16().
|
inlinestatic |
Definition at line 36 of file attention_kernels.c.
Referenced by attention_backward_causal_head_major_gqa(), attention_forward_causal_head_major(), attention_forward_causal_head_major_exact(), attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_exact(), attention_forward_causal_head_major_gqa_flash(), attention_forward_causal_head_major_gqa_flash_strided(), and attention_forward_causal_head_major_gqa_flash_strided_sliding().
|
inlinestatic |