Flash-style attention (online softmax, causal, streaming)
#include <math.h>
#include <stddef.h>
#include <stdint.h>
Macros
  #define CK_FLASH_ATTN_FAST_EXP 0
  #define CK_FLASH_ATTN_TILE_K 32
Functions
  void attention_flash_cleanup (void)
      Clean up flash attention resources.
  void attention_flash_decode (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
      Main flash attention function with SIMD dispatch.
  static void attention_flash_decode_scalar (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
      Scalar flash-style attention (online softmax).
  void attention_flash_init (int max_context, int max_heads, int max_head_dim)
      Initialize flash attention buffers.
  static float ck_expf (float x)
  static float ck_fast_expf (float x)
  int ck_flash_attn_choose_tile_k (int D_h)
  int ck_flash_attn_fast_exp_kind (void)
  static int ck_flash_attn_tile_k (int D_h)
  static int max_k_for_query (int t_q, int T_q, int T_k)
Flash-style attention (online softmax, causal, streaming)
After changes: make test && make llamacpp-parity-full
Layout: Q/K/V/Out: [T, H, D_h] contiguous
Causal alignment: Queries are assumed to correspond to the last T_q positions in the KV cache. This makes T_q == T_k behave like standard causal prefill, and T_q == 1 behave like decode over a full KV cache.
Definition in file attention_flash_true.c.
#define CK_FLASH_ATTN_FAST_EXP 0
Definition at line 44 of file attention_flash_true.c.
#define CK_FLASH_ATTN_TILE_K 32
Definition at line 40 of file attention_flash_true.c.
void attention_flash_cleanup (void)
Clean up flash attention resources.
Definition at line 739 of file attention_flash_true.c.
void attention_flash_decode (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
Parameters
  out    Output [T_q, H, D_h]
  q      Query [T_q, H, D_h]
  k      Key [T_k, H, D_h]
  v      Value [T_k, H, D_h]
  T_q    Number of query tokens (1 for decode)
  T_k    Number of key/value tokens (context length)
  H      Number of heads
  D_h    Head dimension
  scale  1/sqrt(D_h)
Definition at line 696 of file attention_flash_true.c.
References attention_flash_decode_scalar().
Referenced by attention_forward_decode_head_major_gqa_flash(), ck_attention_flash_decode_wrapper(), mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().
static void attention_flash_decode_scalar (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)  [static]
Scalar flash-style attention (online softmax)
Definition at line 142 of file attention_flash_true.c.
References ck_expf(), CK_FLASH_ATTN_TILE_K, ck_flash_attn_tile_k(), max_k_for_query(), and score.
Referenced by attention_flash_decode().
void attention_flash_init (int max_context, int max_heads, int max_head_dim)
Initialize flash attention buffers.
Definition at line 731 of file attention_flash_true.c.
static float ck_expf (float x)  [inline, static]
Definition at line 80 of file attention_flash_true.c.
References ck_fast_expf().
Referenced by attention_flash_decode_scalar().
static float ck_fast_expf (float x)  [inline, static]
int ck_flash_attn_choose_tile_k (int D_h)
Definition at line 108 of file attention_flash_true.c.
References ck_flash_attn_tile_k().
int ck_flash_attn_fast_exp_kind (void)
Definition at line 112 of file attention_flash_true.c.
static int ck_flash_attn_tile_k (int D_h)  [inline, static]
Definition at line 88 of file attention_flash_true.c.
References CK_FLASH_ATTN_TILE_K.
Referenced by attention_flash_decode_scalar(), and ck_flash_attn_choose_tile_k().
static int max_k_for_query (int t_q, int T_q, int T_k)  [inline, static]
Definition at line 126 of file attention_flash_true.c.
Referenced by attention_flash_decode_scalar().