← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rope_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file rope_kernels_bf16.c
3  * @brief RoPE (Rotary Position Embedding) kernels for BF16
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  */
14 
15 #include <stdint.h>
16 
17 #include "bf16_utils.h"
18 #include "ckernel_engine.h"
19 
20 /* Suppress false positive warnings about uninitialized variables */
21 #pragma GCC diagnostic push
22 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
23 
24 /*
25  * BF16 RoPE forward with caller-provided scratch buffer.
26  * scratch: [num_heads * num_tokens * aligned_head_dim] floats
27  */
/*
 * BF16 RoPE forward with caller-provided scratch buffer.
 *
 * Widens x (BF16) to float in scratch, applies the float rope_forward
 * kernel, then narrows the result back into x in place.
 *
 * x:       [num_heads * num_tokens * aligned_head_dim] BF16, updated in place
 * scratch: [num_heads * num_tokens * aligned_head_dim] floats, caller-owned
 *          (no allocation here, per CK-ENGINE kernel rules)
 *
 * No-op when x or scratch is NULL, or when any dimension used in the
 * element count is non-positive: a negative int cast to size_t wraps to
 * a huge value and the conversion helpers would run far out of bounds.
 */
void rope_forward_bf16(uint16_t *x,
                       const float *cos_cache,
                       const float *sin_cache,
                       int num_heads,
                       int num_tokens,
                       int head_dim,
                       int aligned_head_dim,
                       int pos_offset,
                       float *scratch)
{
    if (!x || !scratch) return;
    /* Guard against size_t wraparound from negative/zero dimensions. */
    if (num_heads <= 0 || num_tokens <= 0 || aligned_head_dim <= 0) return;

    size_t total = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;

    bf16_tensor_to_float(x, scratch, total);
    rope_forward(scratch, cos_cache, sin_cache,
                 num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
    float_tensor_to_bf16(scratch, x, total);
}
47 
48 /*
49  * BF16 RoPE backward with caller-provided scratch buffers.
50  * scratch_d_out, scratch_d_x: each [num_heads * num_tokens * aligned_head_dim] floats
51  */
/*
 * BF16 RoPE backward with caller-provided scratch buffers.
 *
 * Widens d_out (BF16) to float, applies the float rope_backward kernel
 * producing the gradient in scratch_d_x, then narrows it into d_x.
 *
 * d_out:         upstream gradient, [num_heads * num_tokens * aligned_head_dim] BF16
 * d_x:           output gradient, same shape, BF16 (written, not read)
 * scratch_d_out,
 * scratch_d_x:   each [num_heads * num_tokens * aligned_head_dim] floats, caller-owned
 *
 * No-op when any required pointer is NULL, or when any dimension used in
 * the element count is non-positive: a negative int cast to size_t wraps
 * to a huge value and the conversion helpers would run out of bounds.
 */
void rope_backward_bf16(const uint16_t *d_out,
                        uint16_t *d_x,
                        const float *cos_cache,
                        const float *sin_cache,
                        int num_heads,
                        int num_tokens,
                        int head_dim,
                        int aligned_head_dim,
                        int pos_offset,
                        float *scratch_d_out,
                        float *scratch_d_x)
{
    if (!d_out || !d_x || !scratch_d_out || !scratch_d_x) return;
    /* Guard against size_t wraparound from negative/zero dimensions. */
    if (num_heads <= 0 || num_tokens <= 0 || aligned_head_dim <= 0) return;

    size_t total = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;

    bf16_tensor_to_float(d_out, scratch_d_out, total);
    rope_backward(scratch_d_out, scratch_d_x, cos_cache, sin_cache,
                  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
    float_tensor_to_bf16(scratch_d_x, d_x, total);
}
73 
74 /*
75  * BF16 RoPE forward for Q and K with caller-provided scratch buffers.
76  * scratch_q: [num_heads * num_tokens * aligned_head_dim] floats
77  * scratch_k: [num_kv_heads * num_tokens * aligned_head_dim] floats
78  */
79 void rope_forward_qk_bf16(uint16_t *q,
80  uint16_t *k,
81  const float *cos_cache,
82  const float *sin_cache,
83  int num_heads,
84  int num_kv_heads,
85  int num_tokens,
86  int head_dim,
87  int aligned_head_dim,
88  int pos_offset,
89  float *scratch_q,
90  float *scratch_k)
91 {
92  if (!q || !k) return;
93 
94  rope_forward_bf16(q, cos_cache, sin_cache,
95  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, scratch_q);
96  rope_forward_bf16(k, cos_cache, sin_cache,
97  num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, scratch_k);
98 }
99 
100 /*
101  * BF16 RoPE backward for Q and K with caller-provided scratch buffers.
102  */
103 void rope_backward_qk_bf16(const uint16_t *d_q_out,
104  const uint16_t *d_k_out,
105  uint16_t *d_q,
106  uint16_t *d_k,
107  const float *cos_cache,
108  const float *sin_cache,
109  int num_heads,
110  int num_kv_heads,
111  int num_tokens,
112  int head_dim,
113  int aligned_head_dim,
114  int pos_offset,
115  float *scratch_dq_out,
116  float *scratch_dq,
117  float *scratch_dk_out,
118  float *scratch_dk)
119 {
120  if (!d_q_out || !d_k_out || !d_q || !d_k) return;
121 
122  rope_backward_bf16(d_q_out, d_q, cos_cache, sin_cache,
123  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset,
124  scratch_dq_out, scratch_dq);
125  rope_backward_bf16(d_k_out, d_k, cos_cache, sin_cache,
126  num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset,
127  scratch_dk_out, scratch_dk);
128 }
129 
130 #pragma GCC diagnostic pop
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:238
void rope_backward_bf16(const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x)
void rope_backward_qk_bf16(const uint16_t *d_q_out, const uint16_t *d_k_out, uint16_t *d_q, uint16_t *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_dq_out, float *scratch_dq, float *scratch_dk_out, float *scratch_dk)
void rope_forward_bf16(uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch)
void rope_forward_qk_bf16(uint16_t *q, uint16_t *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_q, float *scratch_k)