34 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
35 #include <immintrin.h>
39 #define M_PI 3.14159265358979323846
/*
 * rope_precompute_cache (body excerpt): fills cos_cache / sin_cache with one
 * row of half_dim = head_dim / 2 entries for every pos in [0, max_seq_len).
 * Entry [pos * half_dim + i] holds cos / sin of pos * base^(-2i / head_dim),
 * so each cache must hold at least max_seq_len * (head_dim / 2) floats.
 */
58 int half_dim = head_dim / 2;
/* Frequencies are evaluated in long double as exp(-(2i/head_dim) * log(base))
 * and only narrowed to float afterwards. */
60 long double base_ld = (
long double)base;
61 long double head_dim_ld = (
long double)head_dim;
62 long double log_base = logl(base_ld);
63 for (
int pos = 0; pos < max_seq_len; ++pos) {
64 for (
int i = 0; i < half_dim; ++i) {
65 long double exponent = ((
long double)(2 * i)) / head_dim_ld;
66 long double freq = expl(-exponent * log_base);
67 float freq_f = (float)freq;
/* NOTE(review): the angle is a float product, so for large pos it carries
 * float32 rounding error. Presumably this is deliberate, to match a float32
 * reference implementation; confirm before widening it. */
68 float angle_f = (float)pos * freq_f;
/* NOTE(review): pos * half_dim is int arithmetic and overflows if
 * max_seq_len * half_dim exceeds INT_MAX. Consider a size_t index. */
69 cos_cache[pos * half_dim + i] = cosf(angle_f);
70 sin_cache[pos * half_dim + i] = sinf(angle_f);
/*
 * rope_apply_head (excerpt): rotates one head in place.
 *
 * Token t's row starts at x + t * aligned_head_dim and uses the cache row for
 * position pos_offset + t. Element i is paired with element i + half_dim:
 *   x[i]            <- x[i] * cos - x[i + half_dim] * sin
 *   x[i + half_dim] <- x[i] * sin + x[i + half_dim] * cos
 *
 * half_dim = head_dim / 2 truncates. With an odd head_dim, the last element
 * of the head is therefore not touched by the loops shown.
 *
 * NOTE(review): assumes pos_offset + num_tokens does not exceed the
 * max_seq_len the cache was built for. Nothing visible here checks it.
 */
80 const float *cos_cache,
81 const float *sin_cache,
87 int half_dim = head_dim / 2;
89 for (
int t = 0; t < num_tokens; ++t) {
90 int pos = pos_offset + t;
/* NOTE(review): pos * half_dim is int arithmetic; see the cache builder. */
91 const float *cos_row = cos_cache + pos * half_dim;
92 const float *sin_row = sin_cache + pos * half_dim;
93 float *x_row = x + (size_t)t * (
size_t)aligned_head_dim;
/* AVX-512 path: 16 pairs per iteration using fused multiply-add/subtract. */
95 #if defined(__AVX512F__)
98 for (; i + 16 <= half_dim; i += 16) {
99 __m512 x0 = _mm512_loadu_ps(&x_row[i]);
100 __m512 x1 = _mm512_loadu_ps(&x_row[i + half_dim]);
101 __m512 c = _mm512_loadu_ps(&cos_row[i]);
102 __m512 s = _mm512_loadu_ps(&sin_row[i]);
/* r0 = x0*c - x1*s (fmsub), r1 = x0*s + x1*c (fmadd). */
105 __m512 r0 = _mm512_fmsub_ps(x0, c, _mm512_mul_ps(x1, s));
107 __m512 r1 = _mm512_fmadd_ps(x0, s, _mm512_mul_ps(x1, c));
109 _mm512_storeu_ps(&x_row[i], r0);
110 _mm512_storeu_ps(&x_row[i + half_dim], r1);
/* Scalar tail for the remaining half_dim % 16 pairs. */
113 for (; i < half_dim; ++i) {
115 float x1 = x_row[i + half_dim];
116 float c = cos_row[i];
117 float s = sin_row[i];
118 x_row[i] = x0 * c - x1 * s;
119 x_row[i + half_dim] = x0 * s + x1 * c;
/* AVX path: 8 pairs per iteration with separate mul/sub/add (no FMA). */
122 #elif defined(__AVX__)
125 for (; i + 8 <= half_dim; i += 8) {
126 __m256 x0 = _mm256_loadu_ps(&x_row[i]);
127 __m256 x1 = _mm256_loadu_ps(&x_row[i + half_dim]);
128 __m256 c = _mm256_loadu_ps(&cos_row[i]);
129 __m256 s = _mm256_loadu_ps(&sin_row[i]);
132 __m256 x0c = _mm256_mul_ps(x0, c);
133 __m256 x1s = _mm256_mul_ps(x1, s);
134 __m256 r0 = _mm256_sub_ps(x0c, x1s);
137 __m256 x0s = _mm256_mul_ps(x0, s);
138 __m256 x1c = _mm256_mul_ps(x1, c);
139 __m256 r1 = _mm256_add_ps(x0s, x1c);
141 _mm256_storeu_ps(&x_row[i], r0);
142 _mm256_storeu_ps(&x_row[i + half_dim], r1);
/* Scalar tail for the remaining half_dim % 8 pairs. */
145 for (; i < half_dim; ++i) {
147 float x1 = x_row[i + half_dim];
148 float c = cos_row[i];
149 float s = sin_row[i];
150 x_row[i] = x0 * c - x1 * s;
151 x_row[i + half_dim] = x0 * s + x1 * c;
/* Full scalar loop. Presumably this is the non-SIMD #else branch; the
 * directive line is not shown in this excerpt. */
156 for (
int i = 0; i < half_dim; ++i) {
158 float x1 = x_row[i + half_dim];
159 float c = cos_row[i];
160 float s = sin_row[i];
162 x_row[i] = x0 * c - x1 * s;
163 x_row[i + half_dim] = x0 * s + x1 * c;
/*
 * rope_forward (excerpt): applies RoPE in place to num_heads heads, one call
 * per head.
 *
 * head_stride = num_tokens * aligned_head_dim is presumably the float offset
 * between consecutive heads; the line that applies it is not shown.
 *
 * NOTE(review): the per-head callee's name is on a line not shown. The
 * argument list matches rope_apply_head.
 */
181 const float *cos_cache,
182 const float *sin_cache,
186 int aligned_head_dim,
189 size_t head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
191 for (
int h = 0; h < num_heads; ++h) {
193 cos_cache, sin_cache,
194 num_tokens, head_dim, aligned_head_dim, pos_offset);
/*
 * rope_forward_strided (excerpt): like rope_forward, but the head stride is
 * head_stride_tokens * aligned_head_dim instead of
 * num_tokens * aligned_head_dim.
 *
 * Only the first num_tokens rows of each head are rotated. The
 * head_stride_tokens - num_tokens rows between heads are presumably
 * untouched; confirm against the callee.
 */
208 const float *cos_cache,
209 const float *sin_cache,
213 int aligned_head_dim,
215 int head_stride_tokens)
217 size_t head_stride = (size_t)head_stride_tokens * (
size_t)aligned_head_dim;
219 for (
int h = 0; h < num_heads; ++h) {
221 cos_cache, sin_cache,
222 num_tokens, head_dim, aligned_head_dim, pos_offset);
/*
 * rope_backward (excerpt): out-of-place gradient of the RoPE rotation.
 *
 * For each head h and token t, it reads d_out and writes d_x at the same
 * offset idx = h * head_stride + t * aligned_head_dim. It applies the
 * transposed (inverse-angle) rotation:
 *   d_x[i]            =  d0 * cos + d1 * sin
 *   d_x[i + half_dim] = -d0 * sin + d1 * cos
 * where d0 = d_out[i] and d1 = d_out[i + half_dim].
 *
 * NOTE(review): with an odd head_dim, index head_dim - 1 lies outside both
 * [0, 2 * half_dim) and [head_dim, aligned_head_dim). The visible loops never
 * write d_x[head_dim - 1] unless a line not shown handles it. Verify it is
 * copied from d_out; it is an unrotated element, so its gradient is identity.
 */
240 const float *cos_cache,
241 const float *sin_cache,
245 int aligned_head_dim,
248 size_t head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
249 int half_dim = head_dim / 2;
251 for (
int h = 0; h < num_heads; ++h) {
252 for (
int t = 0; t < num_tokens; ++t) {
253 int pos = pos_offset + t;
254 const float *cos_row = cos_cache + pos * half_dim;
255 const float *sin_row = sin_cache + pos * half_dim;
257 size_t idx = h * head_stride + (size_t)t * (
size_t)aligned_head_dim;
258 const float *d_out_row = d_out + idx;
259 float *d_x_row = d_x + idx;
/* AVX-512 path: 16 pairs per iteration; r0 = d0*c + d1*s, r1 = d1*c - d0*s. */
261 #if defined(__AVX512F__)
263 for (; i + 16 <= half_dim; i += 16) {
264 __m512 d0 = _mm512_loadu_ps(&d_out_row[i]);
265 __m512 d1 = _mm512_loadu_ps(&d_out_row[i + half_dim]);
266 __m512 c = _mm512_loadu_ps(&cos_row[i]);
267 __m512 s = _mm512_loadu_ps(&sin_row[i]);
270 __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
272 __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));
274 _mm512_storeu_ps(&d_x_row[i], r0);
275 _mm512_storeu_ps(&d_x_row[i + half_dim], r1);
/* Scalar tail for the remaining half_dim % 16 pairs. */
277 for (; i < half_dim; ++i) {
278 float d0 = d_out_row[i];
279 float d1 = d_out_row[i + half_dim];
280 float c = cos_row[i];
281 float s = sin_row[i];
282 d_x_row[i] = d0 * c + d1 * s;
283 d_x_row[i + half_dim] = -d0 * s + d1 * c;
/* AVX path: 8 pairs per iteration with separate mul/add/sub (no FMA). */
286 #elif defined(__AVX__)
288 for (; i + 8 <= half_dim; i += 8) {
289 __m256 d0 = _mm256_loadu_ps(&d_out_row[i]);
290 __m256 d1 = _mm256_loadu_ps(&d_out_row[i + half_dim]);
291 __m256 c = _mm256_loadu_ps(&cos_row[i]);
292 __m256 s = _mm256_loadu_ps(&sin_row[i]);
295 __m256 d0c = _mm256_mul_ps(d0, c);
296 __m256 d1s = _mm256_mul_ps(d1, s);
297 __m256 r0 = _mm256_add_ps(d0c, d1s);
300 __m256 d1c = _mm256_mul_ps(d1, c);
301 __m256 d0s = _mm256_mul_ps(d0, s);
302 __m256 r1 = _mm256_sub_ps(d1c, d0s);
304 _mm256_storeu_ps(&d_x_row[i], r0);
305 _mm256_storeu_ps(&d_x_row[i + half_dim], r1);
/* Scalar tail for the remaining half_dim % 8 pairs. */
307 for (; i < half_dim; ++i) {
308 float d0 = d_out_row[i];
309 float d1 = d_out_row[i + half_dim];
310 float c = cos_row[i];
311 float s = sin_row[i];
312 d_x_row[i] = d0 * c + d1 * s;
313 d_x_row[i + half_dim] = -d0 * s + d1 * c;
/* Full scalar loop. Presumably this is the non-SIMD #else branch; the
 * directive line is not shown. */
317 for (
int i = 0; i < half_dim; ++i) {
318 float d0 = d_out_row[i];
319 float d1 = d_out_row[i + half_dim];
320 float c = cos_row[i];
321 float s = sin_row[i];
324 d_x_row[i] = d0 * c + d1 * s;
325 d_x_row[i + half_dim] = -d0 * s + d1 * c;
/* Visits the padding columns [head_dim, aligned_head_dim) of the row. The
 * body is not shown; presumably it zero-fills them. Confirm. */
329 for (
int i = head_dim; i < aligned_head_dim; ++i) {
/*
 * rope_backward_inplace (excerpt): the same inverse rotation as
 * rope_backward, but it overwrites d_x in place. Row offset is
 * h * head_stride + t * aligned_head_dim:
 *   d[i]            =  d0 * cos + d1 * sin
 *   d[i + half_dim] = -d0 * sin + d1 * cos
 *
 * Because the update is in place, an element the loops do not visit (for
 * example index head_dim - 1 when head_dim is odd) keeps its incoming
 * gradient value.
 */
346 const float *cos_cache,
347 const float *sin_cache,
351 int aligned_head_dim,
354 size_t head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
355 int half_dim = head_dim / 2;
357 for (
int h = 0; h < num_heads; ++h) {
358 for (
int t = 0; t < num_tokens; ++t) {
359 int pos = pos_offset + t;
360 const float *cos_row = cos_cache + pos * half_dim;
361 const float *sin_row = sin_cache + pos * half_dim;
363 float *d_row = d_x + h * head_stride + (size_t)t * (
size_t)aligned_head_dim;
/* AVX-512 path: 16 pairs per iteration. Both halves are loaded before either
 * is stored, so the in-place update is safe. */
365 #if defined(__AVX512F__)
367 for (; i + 16 <= half_dim; i += 16) {
368 __m512 d0 = _mm512_loadu_ps(&d_row[i]);
369 __m512 d1 = _mm512_loadu_ps(&d_row[i + half_dim]);
370 __m512 c = _mm512_loadu_ps(&cos_row[i]);
371 __m512 s = _mm512_loadu_ps(&sin_row[i]);
373 __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
374 __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));
376 _mm512_storeu_ps(&d_row[i], r0);
377 _mm512_storeu_ps(&d_row[i + half_dim], r1);
/* Scalar tail for the remaining half_dim % 16 pairs. */
379 for (; i < half_dim; ++i) {
381 float d1 = d_row[i + half_dim];
382 float c = cos_row[i];
383 float s = sin_row[i];
384 d_row[i] = d0 * c + d1 * s;
385 d_row[i + half_dim] = -d0 * s + d1 * c;
/* AVX path: 8 pairs per iteration with separate mul/add/sub (no FMA). */
388 #elif defined(__AVX__)
390 for (; i + 8 <= half_dim; i += 8) {
391 __m256 d0 = _mm256_loadu_ps(&d_row[i]);
392 __m256 d1 = _mm256_loadu_ps(&d_row[i + half_dim]);
393 __m256 c = _mm256_loadu_ps(&cos_row[i]);
394 __m256 s = _mm256_loadu_ps(&sin_row[i]);
396 __m256 d0c = _mm256_mul_ps(d0, c);
397 __m256 d1s = _mm256_mul_ps(d1, s);
398 __m256 r0 = _mm256_add_ps(d0c, d1s);
400 __m256 d1c = _mm256_mul_ps(d1, c);
401 __m256 d0s = _mm256_mul_ps(d0, s);
402 __m256 r1 = _mm256_sub_ps(d1c, d0s);
404 _mm256_storeu_ps(&d_row[i], r0);
405 _mm256_storeu_ps(&d_row[i + half_dim], r1);
/* Scalar tail for the remaining half_dim % 8 pairs. */
407 for (; i < half_dim; ++i) {
409 float d1 = d_row[i + half_dim];
410 float c = cos_row[i];
411 float s = sin_row[i];
412 d_row[i] = d0 * c + d1 * s;
413 d_row[i + half_dim] = -d0 * s + d1 * c;
/* Full scalar loop. Presumably this is the non-SIMD #else branch; the
 * directive line is not shown. */
417 for (
int i = 0; i < half_dim; ++i) {
419 float d1 = d_row[i + half_dim];
420 float c = cos_row[i];
421 float s = sin_row[i];
424 d_row[i] = d0 * c + d1 * s;
425 d_row[i + half_dim] = -d0 * s + d1 * c;
/* Visits the padding columns [head_dim, aligned_head_dim). The body is not
 * shown; presumably it zero-fills them. Confirm. */
429 for (
int i = head_dim; i < aligned_head_dim; ++i) {
/*
 * rope_forward_qk (excerpt): rotates q (num_heads heads) and then k
 * (num_kv_heads heads) in place with rope_forward. Both use the same caches,
 * num_tokens, head_dim, aligned_head_dim and pos_offset.
 */
450 const float *cos_cache,
451 const float *sin_cache,
456 int aligned_head_dim,
459 rope_forward(q, cos_cache, sin_cache, num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
460 rope_forward(k, cos_cache, sin_cache, num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
/*
 * rope_forward_qk_strided (excerpt): rotates q and k in place with
 * rope_forward_strided. Each tensor has its own head stride in token rows:
 * q_stride_tokens for q and k_stride_tokens for k.
 */
474 const float *cos_cache,
475 const float *sin_cache,
480 int aligned_head_dim,
485 rope_forward_strided(q, cos_cache, sin_cache, num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, q_stride_tokens);
486 rope_forward_strided(k, cos_cache, sin_cache, num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, k_stride_tokens);
/*
 * rope_backward_qk (excerpt): back-propagates through the q and k rotations
 * with rope_backward:
 *   d_q_out -> d_q over num_heads heads;
 *   d_k_out -> d_k over num_kv_heads heads.
 */
498 const float *d_k_out,
501 const float *cos_cache,
502 const float *sin_cache,
507 int aligned_head_dim,
510 rope_backward(d_q_out, d_q, cos_cache, sin_cache, num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
511 rope_backward(d_k_out, d_k, cos_cache, sin_cache, num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
/* Signature index for the RoPE routines excerpted above. */
/* In-place rotation of q and k, with separate per-tensor head strides (in token rows). */
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
/* In-place rotation of num_heads contiguous heads of num_tokens rows each. */
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
/* Fills max_seq_len x (head_dim / 2) cos/sin tables for the given base. */
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
/* Out-of-place backward pass for q (num_heads) and k (num_kv_heads). */
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
/* File-local helper: in-place rotation of a single head. */
static void rope_apply_head(float *x, const float *cos_cache, const float *sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
/* In-place backward pass (inverse rotation) over num_heads heads. */
void rope_backward_inplace(float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
/* In-place rotation of q and k, with heads packed at num_tokens rows each. */
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
/* Out-of-place backward pass: reads d_out, writes d_x. */
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
/* In-place rotation where heads are head_stride_tokens rows apart. */
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)