← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rope_kernels.c File Reference

RoPE (Rotary Position Embedding) kernels with SIMD. More...

#include <math.h>
#include <stddef.h>

Go to the source code of this file.

Macros

#define M_PI   3.14159265358979323846
 

Functions

static void rope_apply_head (float *x, const float *cos_cache, const float *sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_backward (const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_backward_inplace (float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_backward_qk (const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_forward (float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_forward_qk (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_forward_qk_strided (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
 
void rope_forward_strided (float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
 
void rope_precompute_cache (float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
 

Detailed Description

RoPE (Rotary Position Embedding) kernels with SIMD.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Applies rotary position embeddings to query and key vectors. Used by Llama, SmolLM, and most modern transformer architectures.

Math (Llama-style rotate-half): split head_dim into two halves (0..half-1, half..head_dim-1). For each position m and index i in [0, half):

  x0 = x[i], x1 = x[i + half]
  x'[i]      = x0 * cos(m * theta_i) - x1 * sin(m * theta_i)
  x'[i+half] = x0 * sin(m * theta_i) + x1 * cos(m * theta_i)

Where theta_i = 1 / (base^(2i/d)), typically base=10000.

Layout:
  x: [num_heads, num_tokens, head_dim] head-major
  cos_cache, sin_cache: [max_seq_len, head_dim/2] precomputed

Definition in file rope_kernels.c.

Macro Definition Documentation

◆ M_PI

#define M_PI   3.14159265358979323846

Definition at line 39 of file rope_kernels.c.

Function Documentation

◆ rope_apply_head()

/**
 * Apply rotate-half RoPE in place to one head's token rows.
 *
 * @param x                 [num_tokens, aligned_head_dim] rows for a single head
 * @param cos_cache         precomputed [max_seq_len, head_dim/2] cosine table
 * @param sin_cache         precomputed [max_seq_len, head_dim/2] sine table
 * @param num_tokens        number of token rows to rotate
 * @param head_dim          logical head width (assumed even)
 * @param aligned_head_dim  row stride in floats (>= head_dim, may include padding)
 * @param pos_offset        absolute position of token 0 (KV-cache decode offset)
 *
 * For each absolute position m and pair index i in [0, head_dim/2):
 *   x'[i]      = x[i] * cos(m*theta_i) - x[i+half] * sin(m*theta_i)
 *   x'[i+half] = x[i] * sin(m*theta_i) + x[i+half] * cos(m*theta_i)
 *
 * Fix: cache-row offsets are now computed in size_t; the previous
 * `pos * half_dim` int product could overflow (UB) for very long contexts,
 * and is now consistent with the size_t row-stride arithmetic below.
 */
static inline void rope_apply_head(float *x,
                                   const float *cos_cache,
                                   const float *sin_cache,
                                   int num_tokens,
                                   int head_dim,
                                   int aligned_head_dim,
                                   int pos_offset)
{
    int half_dim = head_dim / 2;

    for (int t = 0; t < num_tokens; ++t) {
        size_t pos = (size_t)pos_offset + (size_t)t;
        const float *cos_row = cos_cache + pos * (size_t)half_dim;
        const float *sin_row = sin_cache + pos * (size_t)half_dim;
        float *x_row = x + (size_t)t * (size_t)aligned_head_dim;

#if defined(__AVX512F__)
        /* 16 rotation pairs per iteration. */
        int i = 0;
        for (; i + 16 <= half_dim; i += 16) {
            __m512 x0 = _mm512_loadu_ps(&x_row[i]);
            __m512 x1 = _mm512_loadu_ps(&x_row[i + half_dim]);
            __m512 c = _mm512_loadu_ps(&cos_row[i]);
            __m512 s = _mm512_loadu_ps(&sin_row[i]);

            /* x'[i] = x0*c - x1*s */
            __m512 r0 = _mm512_fmsub_ps(x0, c, _mm512_mul_ps(x1, s));
            /* x'[i+half] = x0*s + x1*c */
            __m512 r1 = _mm512_fmadd_ps(x0, s, _mm512_mul_ps(x1, c));

            _mm512_storeu_ps(&x_row[i], r0);
            _mm512_storeu_ps(&x_row[i + half_dim], r1);
        }
        /* Scalar tail for half_dim % 16 leftovers. */
        for (; i < half_dim; ++i) {
            float x0 = x_row[i];
            float x1 = x_row[i + half_dim];
            float c = cos_row[i];
            float s = sin_row[i];
            x_row[i] = x0 * c - x1 * s;
            x_row[i + half_dim] = x0 * s + x1 * c;
        }

#elif defined(__AVX__)
        /* 8 rotation pairs per iteration (AVX1: no FMA available). */
        int i = 0;
        for (; i + 8 <= half_dim; i += 8) {
            __m256 x0 = _mm256_loadu_ps(&x_row[i]);
            __m256 x1 = _mm256_loadu_ps(&x_row[i + half_dim]);
            __m256 c = _mm256_loadu_ps(&cos_row[i]);
            __m256 s = _mm256_loadu_ps(&sin_row[i]);

            /* x'[i] = x0*c - x1*s */
            __m256 r0 = _mm256_sub_ps(_mm256_mul_ps(x0, c), _mm256_mul_ps(x1, s));
            /* x'[i+half] = x0*s + x1*c */
            __m256 r1 = _mm256_add_ps(_mm256_mul_ps(x0, s), _mm256_mul_ps(x1, c));

            _mm256_storeu_ps(&x_row[i], r0);
            _mm256_storeu_ps(&x_row[i + half_dim], r1);
        }
        /* Scalar tail for half_dim % 8 leftovers. */
        for (; i < half_dim; ++i) {
            float x0 = x_row[i];
            float x1 = x_row[i + half_dim];
            float c = cos_row[i];
            float s = sin_row[i];
            x_row[i] = x0 * c - x1 * s;
            x_row[i + half_dim] = x0 * s + x1 * c;
        }

#else
        /* Scalar fallback. */
        for (int i = 0; i < half_dim; ++i) {
            float x0 = x_row[i];
            float x1 = x_row[i + half_dim];
            float c = cos_row[i];
            float s = sin_row[i];

            x_row[i] = x0 * c - x1 * s;
            x_row[i + half_dim] = x0 * s + x1 * c;
        }
#endif
    }
}

Referenced by rope_forward(), and rope_forward_strided().

◆ rope_backward()

/**
 * RoPE backward: inverse rotation (rotate by -theta).
 *
 * Since cos(-t) = cos(t) and sin(-t) = -sin(t), for each pair index i
 * (rotate-half layout, matching the forward kernel):
 *   d_x[i]      =  d0 * cos + d1 * sin
 *   d_x[i+half] = -d0 * sin + d1 * cos
 * where d0 = d_out[i], d1 = d_out[i + half].
 *
 * @param d_out  upstream gradients, [num_heads, num_tokens, aligned_head_dim]
 * @param d_x    output gradients, same layout; padding lanes
 *               [head_dim, aligned_head_dim) are zeroed
 *
 * Tests: test_rope.py::TestRoPEBackward::test_rope_backward,
 *        test_rope.py::TestRoPEBackward::test_rope_backward_vs_separate
 * After changes: make test
 *
 * Fix: cache-row and tensor offsets are now computed in size_t; the
 * previous `pos * half_dim` int product could overflow (UB) for very
 * long contexts.
 */
void rope_backward(const float *d_out, float *d_x,
                   const float *cos_cache, const float *sin_cache,
                   int num_heads, int num_tokens, int head_dim,
                   int aligned_head_dim, int pos_offset)
{
    size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    int half_dim = head_dim / 2;

    for (int h = 0; h < num_heads; ++h) {
        for (int t = 0; t < num_tokens; ++t) {
            size_t pos = (size_t)pos_offset + (size_t)t;
            const float *cos_row = cos_cache + pos * (size_t)half_dim;
            const float *sin_row = sin_cache + pos * (size_t)half_dim;

            size_t idx = (size_t)h * head_stride + (size_t)t * (size_t)aligned_head_dim;
            const float *d_out_row = d_out + idx;
            float *d_x_row = d_x + idx;

#if defined(__AVX512F__)
            int i = 0;
            for (; i + 16 <= half_dim; i += 16) {
                __m512 d0 = _mm512_loadu_ps(&d_out_row[i]);
                __m512 d1 = _mm512_loadu_ps(&d_out_row[i + half_dim]);
                __m512 c = _mm512_loadu_ps(&cos_row[i]);
                __m512 s = _mm512_loadu_ps(&sin_row[i]);

                /* d_x[i] = d0*c + d1*s */
                __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
                /* d_x[i+half] = d1*c - d0*s */
                __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));

                _mm512_storeu_ps(&d_x_row[i], r0);
                _mm512_storeu_ps(&d_x_row[i + half_dim], r1);
            }
            /* Scalar tail. */
            for (; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }

#elif defined(__AVX__)
            int i = 0;
            for (; i + 8 <= half_dim; i += 8) {
                __m256 d0 = _mm256_loadu_ps(&d_out_row[i]);
                __m256 d1 = _mm256_loadu_ps(&d_out_row[i + half_dim]);
                __m256 c = _mm256_loadu_ps(&cos_row[i]);
                __m256 s = _mm256_loadu_ps(&sin_row[i]);

                /* d_x[i] = d0*c + d1*s (no FMA in AVX1) */
                __m256 r0 = _mm256_add_ps(_mm256_mul_ps(d0, c), _mm256_mul_ps(d1, s));
                /* d_x[i+half] = d1*c - d0*s */
                __m256 r1 = _mm256_sub_ps(_mm256_mul_ps(d1, c), _mm256_mul_ps(d0, s));

                _mm256_storeu_ps(&d_x_row[i], r0);
                _mm256_storeu_ps(&d_x_row[i + half_dim], r1);
            }
            /* Scalar tail. */
            for (; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }

#else
            for (int i = 0; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];

                /* Inverse rotation: rotate by -theta. */
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }
#endif

            /* Zero the alignment padding so downstream reductions see clean data. */
            for (int i = head_dim; i < aligned_head_dim; ++i) {
                d_x_row[i] = 0.0f;
            }
        }
    }
}

Referenced by rope_backward_bf16(), and rope_backward_qk().

◆ rope_backward_inplace()

/**
 * RoPE backward in place: overwrite gradients with the inverse rotation.
 *
 * Same math as rope_backward (rotate by -theta), but reads and writes the
 * same buffer. Useful when d_x == d_out is acceptable (saves memory).
 * Padding lanes [head_dim, aligned_head_dim) are zeroed.
 *
 * @param d_x  gradients, [num_heads, num_tokens, aligned_head_dim], updated in place
 *
 * Test: test_rope.py::TestRoPEBackward::test_rope_backward_inplace
 * After changes: make test
 *
 * Fix: cache-row and tensor offsets are now computed in size_t; the
 * previous `pos * half_dim` int product could overflow (UB) for very
 * long contexts.
 */
void rope_backward_inplace(float *d_x,
                           const float *cos_cache, const float *sin_cache,
                           int num_heads, int num_tokens, int head_dim,
                           int aligned_head_dim, int pos_offset)
{
    size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    int half_dim = head_dim / 2;

    for (int h = 0; h < num_heads; ++h) {
        for (int t = 0; t < num_tokens; ++t) {
            size_t pos = (size_t)pos_offset + (size_t)t;
            const float *cos_row = cos_cache + pos * (size_t)half_dim;
            const float *sin_row = sin_cache + pos * (size_t)half_dim;

            float *d_row = d_x + (size_t)h * head_stride
                               + (size_t)t * (size_t)aligned_head_dim;

#if defined(__AVX512F__)
            int i = 0;
            for (; i + 16 <= half_dim; i += 16) {
                __m512 d0 = _mm512_loadu_ps(&d_row[i]);
                __m512 d1 = _mm512_loadu_ps(&d_row[i + half_dim]);
                __m512 c = _mm512_loadu_ps(&cos_row[i]);
                __m512 s = _mm512_loadu_ps(&sin_row[i]);

                /* d[i] = d0*c + d1*s ; d[i+half] = d1*c - d0*s */
                __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
                __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));

                _mm512_storeu_ps(&d_row[i], r0);
                _mm512_storeu_ps(&d_row[i + half_dim], r1);
            }
            /* Scalar tail. */
            for (; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }

#elif defined(__AVX__)
            int i = 0;
            for (; i + 8 <= half_dim; i += 8) {
                __m256 d0 = _mm256_loadu_ps(&d_row[i]);
                __m256 d1 = _mm256_loadu_ps(&d_row[i + half_dim]);
                __m256 c = _mm256_loadu_ps(&cos_row[i]);
                __m256 s = _mm256_loadu_ps(&sin_row[i]);

                /* d[i] = d0*c + d1*s (no FMA in AVX1) */
                __m256 r0 = _mm256_add_ps(_mm256_mul_ps(d0, c), _mm256_mul_ps(d1, s));
                /* d[i+half] = d1*c - d0*s */
                __m256 r1 = _mm256_sub_ps(_mm256_mul_ps(d1, c), _mm256_mul_ps(d0, s));

                _mm256_storeu_ps(&d_row[i], r0);
                _mm256_storeu_ps(&d_row[i + half_dim], r1);
            }
            /* Scalar tail. */
            for (; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }

#else
            for (int i = 0; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];

                /* Inverse rotation: rotate by -theta. */
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }
#endif

            /* Zero the alignment padding so downstream reductions see clean data. */
            for (int i = head_dim; i < aligned_head_dim; ++i) {
                d_row[i] = 0.0f;
            }
        }
    }
}

◆ rope_backward_qk()

/**
 * Combined RoPE backward for both dQ and dK gradients.
 *
 * d_q_out/d_q: [num_heads, num_tokens, aligned_head_dim]
 * d_k_out/d_k: [num_kv_heads, num_tokens, aligned_head_dim]
 *
 * Test: test_rope.py::TestRoPEBackward::test_rope_backward_qk
 * After changes: make test
 */
void rope_backward_qk(const float *d_q_out, const float *d_k_out,
                      float *d_q, float *d_k,
                      const float *cos_cache, const float *sin_cache,
                      int num_heads, int num_kv_heads, int num_tokens,
                      int head_dim, int aligned_head_dim, int pos_offset)
{
    /* Query gradients: one inverse rotation per query head. */
    rope_backward(d_q_out, d_q, cos_cache, sin_cache,
                  num_heads, num_tokens, head_dim,
                  aligned_head_dim, pos_offset);

    /* Key gradients: GQA models may use fewer KV heads. */
    rope_backward(d_k_out, d_k, cos_cache, sin_cache,
                  num_kv_heads, num_tokens, head_dim,
                  aligned_head_dim, pos_offset);
}
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:238

References rope_backward().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ rope_forward()

/**
 * Apply rotary position embeddings in place to a head-major tensor (Q or K).
 *
 * x: [num_heads, num_tokens, aligned_head_dim], rotated head by head.
 *
 * Tests: test_rope.py::TestRoPEForward::test_rope_forward,
 *        test_rope.py::TestRoPEForward::test_rope_vs_separate,
 *        test_parity.py::test_rope_parity
 * After changes: make test && make llamacpp-parity-full
 */
void rope_forward(float *x, const float *cos_cache, const float *sin_cache,
                  int num_heads, int num_tokens, int head_dim,
                  int aligned_head_dim, int pos_offset)
{
    const size_t per_head = (size_t)num_tokens * (size_t)aligned_head_dim;

    /* Walk the tensor one head at a time via a running pointer. */
    float *head_ptr = x;
    for (int head = 0; head < num_heads; ++head, head_ptr += per_head) {
        rope_apply_head(head_ptr, cos_cache, sin_cache,
                        num_tokens, head_dim, aligned_head_dim, pos_offset);
    }
}
static void rope_apply_head(float *x, const float *cos_cache, const float *sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:79

References rope_apply_head().

Referenced by model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), rope_forward_bf16(), and rope_forward_qk().

◆ rope_forward_qk()

/**
 * Combined RoPE forward for both Q and K in one call (common inference pattern).
 *
 * q: [num_heads, num_tokens, aligned_head_dim]
 * k: [num_kv_heads, num_tokens, aligned_head_dim]
 *
 * Tests: test_rope.py::TestRoPEForward::test_rope_forward_qk,
 *        test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope,
 *        test_parity.py::test_rope_qk_parity
 * After changes: make test && make llamacpp-parity-full
 */
void rope_forward_qk(float *q, float *k,
                     const float *cos_cache, const float *sin_cache,
                     int num_heads, int num_kv_heads, int num_tokens,
                     int head_dim, int aligned_head_dim, int pos_offset)
{
    /* Queries first, then keys; GQA models may use fewer KV heads. */
    rope_forward(q, cos_cache, sin_cache, num_heads,
                 num_tokens, head_dim, aligned_head_dim, pos_offset);
    rope_forward(k, cos_cache, sin_cache, num_kv_heads,
                 num_tokens, head_dim, aligned_head_dim, pos_offset);
}
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180

References rope_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), ck_test_rope(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), 
qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ rope_forward_qk_strided()

/**
 * Combined QK RoPE with configurable per-tensor head strides, for KV-cache
 * layouts where heads are not packed contiguously by num_tokens.
 *
 * q_stride_tokens / k_stride_tokens give each tensor's head stride in
 * token rows (each row is aligned_head_dim floats).
 *
 * Tests: test_rope.py::TestRoPEForward::test_rope_forward_qk_strided,
 *        test_kv_cache_attention.py::TestKVCacheAttention::test_qk_rope_strided
 * After changes: make test
 */
void rope_forward_qk_strided(float *q, float *k,
                             const float *cos_cache, const float *sin_cache,
                             int num_heads, int num_kv_heads, int num_tokens,
                             int head_dim, int aligned_head_dim, int pos_offset,
                             int q_stride_tokens, int k_stride_tokens)
{
    /* Rotate queries with the Q stride... */
    rope_forward_strided(q, cos_cache, sin_cache, num_heads,
                         num_tokens, head_dim, aligned_head_dim,
                         pos_offset, q_stride_tokens);
    /* ...then keys with the (possibly different) K stride. */
    rope_forward_strided(k, cos_cache, sin_cache, num_kv_heads,
                         num_tokens, head_dim, aligned_head_dim,
                         pos_offset, k_stride_tokens);
}
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
Definition: rope_kernels.c:207

References rope_forward_strided().

Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().

◆ rope_forward_strided()

/**
 * RoPE forward with a custom head stride, for non-contiguous head layouts
 * (e.g. KV caches where each head owns head_stride_tokens rows even though
 * only num_tokens of them are rotated).
 *
 * Tests: test_rope.py::TestRoPEForward::test_rope_strided,
 *        test_kv_cache_attention.py::TestKVCacheAttention::test_rope_decode
 * After changes: make test
 */
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache,
                          int num_heads, int num_tokens, int head_dim,
                          int aligned_head_dim, int pos_offset,
                          int head_stride_tokens)
{
    /* Stride between consecutive heads, in floats. */
    const size_t stride = (size_t)head_stride_tokens * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        float *head_base = x + (size_t)h * stride;
        rope_apply_head(head_base, cos_cache, sin_cache,
                        num_tokens, head_dim, aligned_head_dim, pos_offset);
    }
}

References rope_apply_head().

Referenced by rope_forward_qk_strided().

◆ rope_precompute_cache()

void rope_precompute_cache ( float *  cos_cache,
float *  sin_cache,
int  max_seq_len,
int  head_dim,
float  base 
)

Precompute RoPE cos/sin cache

Test:

test_rope.py::TestRoPECache::test_cache_computation

test_rope.py::TestRoPECache::test_cache_values

Precomputes cos(m * theta_i) and sin(m * theta_i) for positions 0..max_seq_len-1. cos_cache, sin_cache: [max_seq_len, head_dim/2]

After changes: make test

Definition at line 52 of file rope_kernels.c.

57 {
58  int half_dim = head_dim / 2;
59 
60  long double base_ld = (long double)base;
61  long double head_dim_ld = (long double)head_dim;
62  long double log_base = logl(base_ld);
63  for (int pos = 0; pos < max_seq_len; ++pos) {
64  for (int i = 0; i < half_dim; ++i) {
65  long double exponent = ((long double)(2 * i)) / head_dim_ld;
66  long double freq = expl(-exponent * log_base);
67  float freq_f = (float)freq;
68  float angle_f = (float)pos * freq_f;
69  cos_cache[pos * half_dim + i] = cosf(angle_f);
70  sin_cache[pos * half_dim + i] = sinf(angle_f);
71  }
72  }
73 }

Referenced by ck_test_rope().