/*
 * Repack a head-major KV buffer in place from a compact layout
 * (head stride = tokens * aligned_head_dim) to the full-capacity layout
 * (head stride = cache_capacity * aligned_head_dim).
 *
 * buf              - buffer holding num_heads * cache_capacity * aligned_head_dim floats
 * num_heads        - number of attention heads stored in buf
 * tokens           - number of valid tokens currently packed per head
 * cache_capacity   - token capacity of each head slot in the target layout
 * aligned_head_dim - per-token stride in floats (head_dim rounded up)
 *
 * Invalid arguments make the call a no-op. Gap regions between the old and
 * new head positions are left with stale data; only the first `tokens`
 * token rows of each head are meaningful afterwards.
 */
void kv_cache_repack_head_major_inplace(float *buf, int num_heads, int tokens,
                                        int cache_capacity, int aligned_head_dim)
{
    if (!buf) {
        return;
    }
    if (num_heads <= 0 || tokens <= 0 || cache_capacity <= 0 || aligned_head_dim <= 0) {
        return;
    }
    if (tokens > cache_capacity) {
        tokens = cache_capacity;  /* clamp: cannot hold more tokens than capacity */
    }
    if (tokens == cache_capacity) {
        return;  /* old and new strides coincide; nothing to move */
    }

    const size_t old_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t new_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
    const size_t bytes = (size_t)tokens * (size_t)aligned_head_dim * sizeof(float);

    /* New stride is larger than the old one, so destinations sit at higher
     * addresses than sources. Iterate heads from last to first so a head's
     * source bytes are consumed before any later head overwrites them;
     * memmove additionally tolerates the overlap within a single head. */
    for (int h = num_heads - 1; h >= 0; --h) {
        float *src = buf + (size_t)h * old_head_stride;
        float *dst = buf + (size_t)h * new_head_stride;
        memmove(dst, src, bytes);
    }
}
/*
 * Write one token's K and V vectors into head-major caches.
 *
 * k_token/v_token  - input vectors, laid out head-major with a per-head
 *                    stride of aligned_head_dim floats
 * k_cache/v_cache  - destination caches, laid out
 *                    [head][token][aligned_head_dim] with
 *                    cache_capacity token slots per head
 * num_kv_heads     - number of KV heads to write
 * token_index      - destination token slot, 0 <= token_index < cache_capacity
 * head_dim         - valid floats per head (head_dim <= aligned_head_dim)
 * aligned_head_dim - padded per-token stride in floats
 *
 * Copies head_dim floats per head and zero-fills the alignment padding
 * [head_dim, aligned_head_dim). Invalid arguments make the call a no-op.
 */
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token,
                               float *__restrict k_cache, float *__restrict v_cache,
                               int num_kv_heads, int token_index, int cache_capacity,
                               int head_dim, int aligned_head_dim)
{
    if (!k_token || !v_token || !k_cache || !v_cache) {
        return;
    }
    if (num_kv_heads <= 0 || token_index < 0 || cache_capacity <= 0) {
        return;
    }
    if (token_index >= cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }
    /* head_dim beyond the aligned stride would overrun the token slot. */
    if (head_dim > aligned_head_dim) {
        return;
    }

    const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
    const size_t token_stride = (size_t)aligned_head_dim;

    for (int h = 0; h < num_kv_heads; ++h) {
        const float *k_src = k_token + (size_t)h * token_stride;
        const float *v_src = v_token + (size_t)h * token_stride;

        float *k_dst = k_cache + (size_t)h * head_stride + (size_t)token_index * token_stride;
        float *v_dst = v_cache + (size_t)h * head_stride + (size_t)token_index * token_stride;

        /* Copy the valid portion of the head vector. */
        for (int d = 0; d < head_dim; ++d) {
            k_dst[d] = k_src[d];
            v_dst[d] = v_src[d];
        }
        /* Zero the alignment padding so padded lanes are deterministic. */
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_dst[d] = 0.0f;
            v_dst[d] = 0.0f;
        }
    }
}
/* NOTE(review): fragment of kv_cache_store — original lines 99-101 and
 * 105-112 are missing from this chunk, so only part of the parameter list
 * and part of a call's argument list are visible. The visible arguments
 * suggest it forwards kv_cache_k/kv_cache_v to kv_cache_write_head_major
 * with per-layer offsets — TODO confirm against the full file before
 * editing. Left byte-identical. */
102 float *__restrict kv_cache_v,
103 const float *__restrict k,
104 const float *__restrict v,
113 kv_cache_k, kv_cache_v,
/*
 * Copy one token's logits into a position-indexed slot of the output buffer.
 *
 * src        - source logits, vocab_size floats
 * dst        - destination buffer, at least (position + 1) * vocab_size floats
 * position   - destination row index, must be >= 0
 * vocab_size - number of floats per row, must be > 0
 *
 * Invalid arguments make the call a no-op. The caller guarantees dst is
 * large enough for the requested position; no capacity parameter exists
 * to bound-check against.
 */
void logits_copy_to_position(const float *__restrict src, float *__restrict dst,
                             int position, int vocab_size)
{
    if (!src || !dst || position < 0 || vocab_size <= 0) {
        return;
    }

    float *dst_pos = dst + (size_t)position * (size_t)vocab_size;
    memmove(dst_pos, src, (size_t)vocab_size * sizeof(float));
}
/* Forward declarations for the KV-cache / logits helpers. The original
 * lines were missing their terminating semicolons, which is invalid C. */
void kv_cache_repack_head_major_inplace(float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim);
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim);
void kv_cache_store(float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len);
/* Copy logits to position-indexed location in output buffer. */
void logits_copy_to_position(const float *__restrict src, float *__restrict dst, int position, int vocab_size);