← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rope_kernels.c
Go to the documentation of this file.
1 /**
2  * @file rope_kernels.c
3  * @brief RoPE (Rotary Position Embedding) kernels with SIMD
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Applies rotary position embeddings to query and key vectors.
15  * Used by Llama, SmolLM, and most modern transformer architectures.
16  *
17  * Math (Llama-style rotate-half):
18  * Split head_dim into two halves (0..half-1, half..head_dim-1).
19  * For each position m and index i in [0, half):
20  * x0 = x[i], x1 = x[i + half]
21  * x'[i] = x0 * cos(m * theta_i) - x1 * sin(m * theta_i)
22  * x'[i+half] = x0 * sin(m * theta_i) + x1 * cos(m * theta_i)
23  *
24  * Where theta_i = 1 / (base^(2i/d)), typically base=10000.
25  *
26  * Layout:
27  * x: [num_heads, num_tokens, head_dim] head-major
28  * cos_cache, sin_cache: [max_seq_len, head_dim/2] precomputed
29  */
30 
31 #include <math.h>
32 #include <stddef.h>
33 
34 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
35 #include <immintrin.h>
36 #endif
37 
38 #ifndef M_PI
39 #define M_PI 3.14159265358979323846
40 #endif
41 
42 /**
43  * Precompute RoPE cos/sin cache
44  * @test test_rope.py::TestRoPECache::test_cache_computation
45  * @test test_rope.py::TestRoPECache::test_cache_values
46  *
47  * Precomputes cos(m * theta_i) and sin(m * theta_i) for positions 0..max_seq_len-1.
48  * cos_cache, sin_cache: [max_seq_len, head_dim/2]
49  *
50  * After changes: make test
51  */
52 void rope_precompute_cache(float *cos_cache,
53  float *sin_cache,
54  int max_seq_len,
55  int head_dim,
56  float base)
57 {
58  int half_dim = head_dim / 2;
59 
60  long double base_ld = (long double)base;
61  long double head_dim_ld = (long double)head_dim;
62  long double log_base = logl(base_ld);
63  for (int pos = 0; pos < max_seq_len; ++pos) {
64  for (int i = 0; i < half_dim; ++i) {
65  long double exponent = ((long double)(2 * i)) / head_dim_ld;
66  long double freq = expl(-exponent * log_base);
67  float freq_f = (float)freq;
68  float angle_f = (float)pos * freq_f;
69  cos_cache[pos * half_dim + i] = cosf(angle_f);
70  sin_cache[pos * half_dim + i] = sinf(angle_f);
71  }
72  }
73 }
74 
// Apply RoPE (rotate-half) to a single head's Q or K tensor in-place.
//
// x: [num_tokens, aligned_head_dim] for one head. Only the first head_dim
//    lanes of each row are rotated; padding lanes [head_dim, aligned_head_dim)
//    are left untouched.
// cos_cache, sin_cache: [max_seq_len, head_dim/2] from rope_precompute_cache.
// pos_offset: absolute position of token 0 (for KV cache continuation).
//    Caller must ensure pos_offset + num_tokens <= max_seq_len, since cache
//    rows are indexed by absolute position.
//
// Three code paths (AVX-512 / AVX / scalar) selected at compile time; each
// computes the same rotate-half update:
//   x'[i]      = x[i] * cos - x[i+half] * sin
//   x'[i+half] = x[i] * sin + x[i+half] * cos
static inline void rope_apply_head(float *x,
                                   const float *cos_cache,
                                   const float *sin_cache,
                                   int num_tokens,
                                   int head_dim,
                                   int aligned_head_dim,
                                   int pos_offset)
{
    int half_dim = head_dim / 2;

    for (int t = 0; t < num_tokens; ++t) {
        int pos = pos_offset + t;
        // Precomputed cos(pos * theta_i) / sin(pos * theta_i) for this position.
        const float *cos_row = cos_cache + pos * half_dim;
        const float *sin_row = sin_cache + pos * half_dim;
        // Rows are aligned_head_dim apart (may exceed head_dim for padding).
        float *x_row = x + (size_t)t * (size_t)aligned_head_dim;

#if defined(__AVX512F__)
        // Process 16 (x0, x1) pairs at a time.
        int i = 0;
        for (; i + 16 <= half_dim; i += 16) {
            __m512 x0 = _mm512_loadu_ps(&x_row[i]);
            __m512 x1 = _mm512_loadu_ps(&x_row[i + half_dim]);
            __m512 c = _mm512_loadu_ps(&cos_row[i]);
            __m512 s = _mm512_loadu_ps(&sin_row[i]);

            // x'[i] = x0 * c - x1 * s
            __m512 r0 = _mm512_fmsub_ps(x0, c, _mm512_mul_ps(x1, s));
            // x'[i+half] = x0 * s + x1 * c
            __m512 r1 = _mm512_fmadd_ps(x0, s, _mm512_mul_ps(x1, c));

            _mm512_storeu_ps(&x_row[i], r0);
            _mm512_storeu_ps(&x_row[i + half_dim], r1);
        }
        // Scalar tail for half_dim not a multiple of 16.
        for (; i < half_dim; ++i) {
            float x0 = x_row[i];
            float x1 = x_row[i + half_dim];
            float c = cos_row[i];
            float s = sin_row[i];
            x_row[i] = x0 * c - x1 * s;
            x_row[i + half_dim] = x0 * s + x1 * c;
        }

#elif defined(__AVX__)
        // Process 8 pairs at a time.
        int i = 0;
        for (; i + 8 <= half_dim; i += 8) {
            __m256 x0 = _mm256_loadu_ps(&x_row[i]);
            __m256 x1 = _mm256_loadu_ps(&x_row[i + half_dim]);
            __m256 c = _mm256_loadu_ps(&cos_row[i]);
            __m256 s = _mm256_loadu_ps(&sin_row[i]);

            // x'[i] = x0 * c - x1 * s (no FMA in AVX1 -- mul then sub)
            __m256 x0c = _mm256_mul_ps(x0, c);
            __m256 x1s = _mm256_mul_ps(x1, s);
            __m256 r0 = _mm256_sub_ps(x0c, x1s);

            // x'[i+half] = x0 * s + x1 * c
            __m256 x0s = _mm256_mul_ps(x0, s);
            __m256 x1c = _mm256_mul_ps(x1, c);
            __m256 r1 = _mm256_add_ps(x0s, x1c);

            _mm256_storeu_ps(&x_row[i], r0);
            _mm256_storeu_ps(&x_row[i + half_dim], r1);
        }
        // Scalar tail for half_dim not a multiple of 8.
        for (; i < half_dim; ++i) {
            float x0 = x_row[i];
            float x1 = x_row[i + half_dim];
            float c = cos_row[i];
            float s = sin_row[i];
            x_row[i] = x0 * c - x1 * s;
            x_row[i + half_dim] = x0 * s + x1 * c;
        }

#else
        // Scalar fallback: same rotate-half update, one pair at a time.
        for (int i = 0; i < half_dim; ++i) {
            float x0 = x_row[i];
            float x1 = x_row[i + half_dim];
            float c = cos_row[i];
            float s = sin_row[i];

            x_row[i] = x0 * c - x1 * s;
            x_row[i + half_dim] = x0 * s + x1 * c;
        }
#endif
    }
}
168 
/**
 * RoPE forward (head-major layout, in-place)
 * @test test_rope.py::TestRoPEForward::test_rope_forward
 * @test test_rope.py::TestRoPEForward::test_rope_vs_separate
 * @test test_parity.py::test_rope_parity
 *
 * Applies rotary position embeddings in-place to a Q or K tensor.
 * x: [num_heads, num_tokens, head_dim] head-major; rows padded out to
 * aligned_head_dim floats.
 *
 * After changes: make test && make llamacpp-parity-full
 */
void rope_forward(float *x,
                  const float *cos_cache,
                  const float *sin_cache,
                  int num_heads,
                  int num_tokens,
                  int head_dim,
                  int aligned_head_dim,
                  int pos_offset)
{
    /* Consecutive heads are num_tokens rows apart. */
    const size_t per_head = (size_t)num_tokens * (size_t)aligned_head_dim;
    float *head_base = x;

    for (int head = 0; head < num_heads; ++head) {
        rope_apply_head(head_base, cos_cache, sin_cache,
                        num_tokens, head_dim, aligned_head_dim, pos_offset);
        head_base += per_head;
    }
}
197 
/**
 * RoPE forward with custom head stride (for KV cache layouts)
 * @test test_rope.py::TestRoPEForward::test_rope_strided
 * @test test_kv_cache_attention.py::TestKVCacheAttention::test_rope_decode
 *
 * Like rope_forward(), but consecutive heads are head_stride_tokens rows
 * apart instead of num_tokens, supporting non-contiguous head layouts.
 *
 * After changes: make test
 */
void rope_forward_strided(float *x,
                          const float *cos_cache,
                          const float *sin_cache,
                          int num_heads,
                          int num_tokens,
                          int head_dim,
                          int aligned_head_dim,
                          int pos_offset,
                          int head_stride_tokens)
{
    /* Stride between heads is caller-provided, in token rows. */
    const size_t stride = (size_t)head_stride_tokens * (size_t)aligned_head_dim;
    float *head_base = x;

    for (int head = 0; head < num_heads; ++head) {
        rope_apply_head(head_base, cos_cache, sin_cache,
                        num_tokens, head_dim, aligned_head_dim, pos_offset);
        head_base += stride;
    }
}
225 
/**
 * RoPE backward (inverse rotation)
 * @test test_rope.py::TestRoPEBackward::test_rope_backward
 * @test test_rope.py::TestRoPEBackward::test_rope_backward_vs_separate
 *
 * RoPE backward: inverse rotation (rotate by -θ).
 * Since cos(-θ) = cos(θ) and sin(-θ) = -sin(θ), with the rotate-half
 * layout used throughout this file:
 *   d_x[i]      =  d0 * c + d1 * s
 *   d_x[i+half] = -d0 * s + d1 * c
 * where d0 = d_out[i], d1 = d_out[i+half].
 *
 * Out-of-place: reads d_out, writes d_x (same head-major layout as
 * rope_forward). Padding lanes [head_dim, aligned_head_dim) of each
 * output row are zeroed so d_x carries no stale values.
 *
 * After changes: make test
 */
void rope_backward(const float *d_out,
                   float *d_x,
                   const float *cos_cache,
                   const float *sin_cache,
                   int num_heads,
                   int num_tokens,
                   int head_dim,
                   int aligned_head_dim,
                   int pos_offset)
{
    size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    int half_dim = head_dim / 2;

    for (int h = 0; h < num_heads; ++h) {
        for (int t = 0; t < num_tokens; ++t) {
            int pos = pos_offset + t;
            // Same cos/sin rows the forward pass used for this position.
            const float *cos_row = cos_cache + pos * half_dim;
            const float *sin_row = sin_cache + pos * half_dim;

            size_t idx = h * head_stride + (size_t)t * (size_t)aligned_head_dim;
            const float *d_out_row = d_out + idx;
            float *d_x_row = d_x + idx;

#if defined(__AVX512F__)
            // 16 gradient pairs per iteration.
            int i = 0;
            for (; i + 16 <= half_dim; i += 16) {
                __m512 d0 = _mm512_loadu_ps(&d_out_row[i]);
                __m512 d1 = _mm512_loadu_ps(&d_out_row[i + half_dim]);
                __m512 c = _mm512_loadu_ps(&cos_row[i]);
                __m512 s = _mm512_loadu_ps(&sin_row[i]);

                // Inverse: d_x[i] = d0 * c + d1 * s
                __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
                // Inverse: d_x[i+half] = -d0 * s + d1 * c = d1 * c - d0 * s
                __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));

                _mm512_storeu_ps(&d_x_row[i], r0);
                _mm512_storeu_ps(&d_x_row[i + half_dim], r1);
            }
            // Scalar tail.
            for (; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }

#elif defined(__AVX__)
            // 8 gradient pairs per iteration (no FMA in AVX1).
            int i = 0;
            for (; i + 8 <= half_dim; i += 8) {
                __m256 d0 = _mm256_loadu_ps(&d_out_row[i]);
                __m256 d1 = _mm256_loadu_ps(&d_out_row[i + half_dim]);
                __m256 c = _mm256_loadu_ps(&cos_row[i]);
                __m256 s = _mm256_loadu_ps(&sin_row[i]);

                // Inverse: d_x[i] = d0 * c + d1 * s
                __m256 d0c = _mm256_mul_ps(d0, c);
                __m256 d1s = _mm256_mul_ps(d1, s);
                __m256 r0 = _mm256_add_ps(d0c, d1s);

                // Inverse: d_x[i+half] = -d0 * s + d1 * c = d1 * c - d0 * s
                __m256 d1c = _mm256_mul_ps(d1, c);
                __m256 d0s = _mm256_mul_ps(d0, s);
                __m256 r1 = _mm256_sub_ps(d1c, d0s);

                _mm256_storeu_ps(&d_x_row[i], r0);
                _mm256_storeu_ps(&d_x_row[i + half_dim], r1);
            }
            // Scalar tail.
            for (; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }

#else
            // Scalar fallback.
            for (int i = 0; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];

                // Inverse rotation: rotate by -θ
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }
#endif

            // Zero output padding lanes (forward left its padding untouched,
            // but backward writes a fresh buffer, so clear them explicitly).
            for (int i = head_dim; i < aligned_head_dim; ++i) {
                d_x_row[i] = 0.0f;
            }
        }
    }
}
335 
/**
 * RoPE backward in-place (overwrite with inverse rotation)
 * @test test_rope.py::TestRoPEBackward::test_rope_backward_inplace
 *
 * In-place variant of rope_backward(): d_x is overwritten with the
 * inverse-rotated gradients (rotate by -θ). Useful when d_x == d_out is
 * acceptable (saves one gradient buffer).
 *
 * Same math as rope_backward():
 *   d_x[i]      =  d0 * c + d1 * s
 *   d_x[i+half] = -d0 * s + d1 * c
 * Padding lanes [head_dim, aligned_head_dim) of each row are zeroed.
 *
 * After changes: make test
 */
void rope_backward_inplace(float *d_x,
                           const float *cos_cache,
                           const float *sin_cache,
                           int num_heads,
                           int num_tokens,
                           int head_dim,
                           int aligned_head_dim,
                           int pos_offset)
{
    size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    int half_dim = head_dim / 2;

    for (int h = 0; h < num_heads; ++h) {
        for (int t = 0; t < num_tokens; ++t) {
            int pos = pos_offset + t;
            // Same cos/sin rows the forward pass used for this position.
            const float *cos_row = cos_cache + pos * half_dim;
            const float *sin_row = sin_cache + pos * half_dim;

            float *d_row = d_x + h * head_stride + (size_t)t * (size_t)aligned_head_dim;

#if defined(__AVX512F__)
            // 16 gradient pairs per iteration; loads complete before stores,
            // so the in-place update is safe within each iteration.
            int i = 0;
            for (; i + 16 <= half_dim; i += 16) {
                __m512 d0 = _mm512_loadu_ps(&d_row[i]);
                __m512 d1 = _mm512_loadu_ps(&d_row[i + half_dim]);
                __m512 c = _mm512_loadu_ps(&cos_row[i]);
                __m512 s = _mm512_loadu_ps(&sin_row[i]);

                // d_x[i] = d0 * c + d1 * s ; d_x[i+half] = d1 * c - d0 * s
                __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
                __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));

                _mm512_storeu_ps(&d_row[i], r0);
                _mm512_storeu_ps(&d_row[i + half_dim], r1);
            }
            // Scalar tail.
            for (; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }

#elif defined(__AVX__)
            // 8 gradient pairs per iteration (no FMA in AVX1).
            int i = 0;
            for (; i + 8 <= half_dim; i += 8) {
                __m256 d0 = _mm256_loadu_ps(&d_row[i]);
                __m256 d1 = _mm256_loadu_ps(&d_row[i + half_dim]);
                __m256 c = _mm256_loadu_ps(&cos_row[i]);
                __m256 s = _mm256_loadu_ps(&sin_row[i]);

                // d_x[i] = d0 * c + d1 * s
                __m256 d0c = _mm256_mul_ps(d0, c);
                __m256 d1s = _mm256_mul_ps(d1, s);
                __m256 r0 = _mm256_add_ps(d0c, d1s);

                // d_x[i+half] = d1 * c - d0 * s
                __m256 d1c = _mm256_mul_ps(d1, c);
                __m256 d0s = _mm256_mul_ps(d0, s);
                __m256 r1 = _mm256_sub_ps(d1c, d0s);

                _mm256_storeu_ps(&d_row[i], r0);
                _mm256_storeu_ps(&d_row[i + half_dim], r1);
            }
            // Scalar tail.
            for (; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }

#else
            // Scalar fallback: d0/d1 are read into locals before either
            // element is overwritten, keeping the in-place update correct.
            for (int i = 0; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];

                // Inverse rotation: rotate by -θ
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }
#endif

            // Zero padding lanes so stale forward activations do not leak
            // into the gradient buffer.
            for (int i = head_dim; i < aligned_head_dim; ++i) {
                d_row[i] = 0.0f;
            }
        }
    }
}
435 
/**
 * RoPE forward for both Q and K (common inference pattern)
 * @test test_rope.py::TestRoPEForward::test_rope_forward_qk
 * @test test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope
 * @test test_parity.py::test_rope_qk_parity
 *
 * Combined RoPE forward for both Q and K in one call.
 * q: [num_heads, num_tokens, head_dim]
 * k: [num_kv_heads, num_tokens, head_dim]
 *
 * After changes: make test && make llamacpp-parity-full
 */
void rope_forward_qk(float *q,
                     float *k,
                     const float *cos_cache,
                     const float *sin_cache,
                     int num_heads,
                     int num_kv_heads,
                     int num_tokens,
                     int head_dim,
                     int aligned_head_dim,
                     int pos_offset)
{
    /* Queries use num_heads, keys use num_kv_heads (GQA/MQA may differ);
     * both share the same cache and position offset. */
    rope_forward(q, cos_cache, sin_cache,
                 num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
    rope_forward(k, cos_cache, sin_cache,
                 num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
}
462 
/**
 * RoPE forward for both Q and K with custom strides (KV cache layouts)
 * @test test_rope.py::TestRoPEForward::test_rope_forward_qk_strided
 * @test test_kv_cache_attention.py::TestKVCacheAttention::test_qk_rope_strided
 *
 * Combined QK RoPE with configurable per-tensor head strides (in token
 * rows) for non-contiguous KV cache layouts. Queries use num_heads with
 * q_stride_tokens; keys use num_kv_heads with k_stride_tokens.
 *
 * NOTE(review): the extracted source was missing this function's opening
 * signature line; restored to match the declared prototype.
 *
 * After changes: make test
 */
void rope_forward_qk_strided(float *q,
                             float *k,
                             const float *cos_cache,
                             const float *sin_cache,
                             int num_heads,
                             int num_kv_heads,
                             int num_tokens,
                             int head_dim,
                             int aligned_head_dim,
                             int pos_offset,
                             int q_stride_tokens,
                             int k_stride_tokens)
{
    rope_forward_strided(q, cos_cache, sin_cache, num_heads, num_tokens,
                         head_dim, aligned_head_dim, pos_offset, q_stride_tokens);
    rope_forward_strided(k, cos_cache, sin_cache, num_kv_heads, num_tokens,
                         head_dim, aligned_head_dim, pos_offset, k_stride_tokens);
}
488 
/**
 * RoPE backward for both dQ and dK
 * @test test_rope.py::TestRoPEBackward::test_rope_backward_qk
 *
 * Combined RoPE backward (inverse rotation) for both query and key
 * gradients in one call.
 *
 * After changes: make test
 */
void rope_backward_qk(const float *d_q_out,
                      const float *d_k_out,
                      float *d_q,
                      float *d_k,
                      const float *cos_cache,
                      const float *sin_cache,
                      int num_heads,
                      int num_kv_heads,
                      int num_tokens,
                      int head_dim,
                      int aligned_head_dim,
                      int pos_offset)
{
    /* Query gradients use num_heads; key gradients use num_kv_heads. */
    rope_backward(d_q_out, d_q, cos_cache, sin_cache,
                  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
    rope_backward(d_k_out, d_k, cos_cache, sin_cache,
                  num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
}
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
Definition: rope_kernels.c:52
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:497
static void rope_apply_head(float *x, const float *cos_cache, const float *sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:79
void rope_backward_inplace(float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:345
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:238
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
Definition: rope_kernels.c:207