// Static row partition: split [0, n) into nth near-equal chunks; this thread
// (ith) owns the half-open range [r0, r1). dr = ceil(n / nth).
67 const int dr = (n + nth - 1) / nth;
68 const int r0 = dr * ith;
// Clamp so the last thread does not run past n when n % nth != 0.
69 const int r1 = (r0 + dr < n) ? (r0 + dr) : n;
// NOTE(review): loop body not visible in this extract (original lines 70-81
// are missing from the chunk) — cannot document what each iteration does.
73 for (
int i = r0; i < r1; i++) {
// Same per-thread partition idiom as the helper above: thread ith owns
// rows [r0, r1) of [0, n), with dr = ceil(n / nth) and r1 clamped to n.
82 const int dr = (n + nth - 1) / nth;
83 const int r0 = dr * ith;
84 const int r1 = (r0 + dr < n) ? (r0 + dr) : n;
// NOTE(review): loop body missing from this extract (original lines 85-95).
88 for (
int i = r0; i < r1; i++) {
// Interior of vec_zero_parallel (per the symbol index): each thread zeroes
// only its own contiguous slice [r0, r1) of y, so no synchronization is
// needed between threads.
96 const int dr = (n + nth - 1) / nth;
97 const int r0 = dr * ith;
98 const int r1 = (r0 + dr < n) ? (r0 + dr) : n;
// One memset for the whole slice instead of an element loop.
102 memset(y + r0, 0, (r1 - r0) *
sizeof(
float));
// Interior of qkv_projection_parallel: tail of the parameter list, derived
// output widths, and the start of the OpenMP region.
136 int head_dim,
int embed_dim,
// Total output widths: H query heads and H_kv key/value heads (GQA layout,
// where H_kv <= H — presumably; the attention code is not visible here).
139 const int q_dim = H * head_dim;
140 const int kv_dim = H_kv * head_dim;
// Round embed_dim up to a multiple of 256 — the Q8_K quantization
// super-block size used by the gemv kernels (see quantize_row_q8_k).
143 const int aligned_embed = ((embed_dim + 255) / 256) * 256;
// num_threads <= 0 means "auto": use the OpenMP default team size.
145 if (num_threads <= 0) {
146 num_threads = omp_get_max_threads();
150 #pragma omp parallel num_threads(num_threads)
// Per-thread id / team size used by the parallel gemv helpers below
// (body not visible in this extract; original lines 155-196 missing).
152 const int ith = omp_get_thread_num();
153 const int nth = omp_get_num_threads();
// Interior of mlp_parallel: pad both dims to the 256-wide Q8_K block size,
// spin up the thread team, then apply SwiGLU element-wise.
197 const int aligned_embed = ((embed_dim + 255) / 256) * 256;
198 const int aligned_inter = ((intermediate + 255) / 256) * 256;
200 if (num_threads <= 0) {
201 num_threads = omp_get_max_threads();
204 #pragma omp parallel num_threads(num_threads)
206 const int ith = omp_get_thread_num();
207 const int nth = omp_get_num_threads();
// Threads are partitioned over the PADDED length so chunk boundaries match
// the quantized gemv partitioning above; the `i < intermediate` guard in the
// loop skips the padding lanes so they are never read or written here.
217 const int dr = (aligned_inter + nth - 1) / nth;
218 const int r0 = dr * ith;
219 const int r1 = (r0 + dr < aligned_inter) ? (r0 + dr) : aligned_inter;
221 for (
int i = r0; i < r1 && i < intermediate; i++) {
// SwiGLU: silu(gate) * up, with silu(g) = g * sigmoid(g) = g / (1 + e^-g).
222 float g = gate_buf[i];
223 float silu_g = g / (1.0f + expf(-g));
224 swiglu_buf[i] = silu_g * up_buf[i];
// Interior of decode_layer_parallel: one transformer decoder layer for a
// single token (decode step). Large gaps in this extract (original lines
// 259-281, 315-322, 324-338, 345-376, 378-389, 398+) hide the projections,
// attention, and residual adds — comments below cover only what is visible.
257 const void *ln1_weight,
258 const void *ln2_weight,
282 if (num_threads <= 0) {
283 num_threads = omp_get_max_threads();
// Pad embed/intermediate to the 256-wide Q8_K block size and head_dim to a
// 32-float boundary (presumably for SIMD alignment — TODO confirm).
287 const int aligned_embed = ((embed_dim + 255) / 256) * 256;
288 const int aligned_inter = ((intermediate + 255) / 256) * 256;
289 const int aligned_head = ((head_dim + 31) / 32) * 32;
// Carve the caller-provided scratch arena into contiguous float buffers.
// Layout (in order): ln1_out | q | k | v | attn_out | o | ln2_out |
// gate | up | swiglu | mlp_out. Q/K/V and attn_out are stored per head
// with aligned_head stride, so each head's vector starts aligned.
292 float *ln1_out = scratch;
293 float *q_vec = ln1_out + aligned_embed;
294 float *k_vec = q_vec + H * aligned_head;
295 float *v_vec = k_vec + H_kv * aligned_head;
296 float *attn_out = v_vec + H_kv * aligned_head;
297 float *o_out = attn_out + H * aligned_head;
298 float *ln2_out = o_out + aligned_embed;
299 float *gate_buf = ln2_out + aligned_embed;
300 float *up_buf = gate_buf + aligned_inter;
301 float *swiglu_buf = up_buf + aligned_inter;
302 float *mlp_out = swiglu_buf + aligned_inter;
// Q8_K staging buffers for activation quantization: 292 bytes per 256-value
// super-block (presumably sizeof(block_q8_K): 4B scale + 256 quants +
// 32B of group sums — TODO confirm against the block struct definition).
305 const size_t q8_embed_bytes = ((aligned_embed + 255) / 256) * 292;
306 const size_t q8_inter_bytes = ((aligned_inter + 255) / 256) * 292;
// The q8 buffers live in the same arena, directly after mlp_out.
307 uint8_t *ln1_q8 = (uint8_t *)(mlp_out + aligned_embed);
308 uint8_t *ln2_q8 = ln1_q8 + q8_embed_bytes;
309 uint8_t *down_q8 = ln2_q8 + q8_embed_bytes;
// Single parallel region for the whole layer; helpers take (ith, nth).
311 #pragma omp parallel num_threads(num_threads)
313 const int ith = omp_get_thread_num();
314 const int nth = omp_get_num_threads();
// Pre-attention RMSNorm of the residual stream into ln1_out.
323 rmsnorm(hidden, (
const float *)ln1_weight, ln1_out, embed_dim, eps);
// Append this token's K/V to the cache. Cache layout: per KV head, max_seq
// slots of aligned_head floats; only the first head_dim floats of each slot
// are written here (the padding tail is not touched by these memcpys).
339 const int kv_head_stride = max_seq * aligned_head;
340 for (
int h = 0; h < H_kv; h++) {
341 memcpy(k_cache + h * kv_head_stride + token_index * aligned_head,
342 k_vec + h * head_dim, head_dim *
sizeof(
float));
343 memcpy(v_cache + h * kv_head_stride + token_index * aligned_head,
344 v_vec + h * head_dim, head_dim *
sizeof(
float));
// Seed attn_out with the query vectors (the attention code that consumes
// this is in the missing lines 350-376 — cannot document it from here).
349 memcpy(attn_out, q_vec, H * head_dim *
sizeof(
float));
// Pre-MLP RMSNorm of the (updated) residual stream into ln2_out.
377 rmsnorm(hidden, (
const float *)ln2_weight, ln2_out, embed_dim, eps);
// SwiGLU over the true (unpadded) intermediate width, split across threads.
// NOTE(review): unlike mlp_parallel above, this partitions over the raw
// `intermediate` — verify the two call sites intend different chunking.
390 const int dr = (intermediate + nth - 1) / nth;
391 const int r0 = dr * ith;
392 const int r1 = (r0 + dr < intermediate) ? (r0 + dr) : intermediate;
394 for (
int i = r0; i < r1; i++) {
// silu(g) = g * sigmoid(g), gated by the up-projection.
395 float g = gate_buf[i];
396 float silu_g = g / (1.0f + expf(-g));
397 swiglu_buf[i] = silu_g * up_buf[i];
// Interior of get_optimal_decode_threads: start from the OpenMP default team
// size; the branch body (original lines 430+) is missing from this extract,
// so the actual thread-count policy cannot be documented here.
424 int max_threads = omp_get_max_threads();
429 if (max_threads >= 4) {
void quantize_row_q8_k(const float *x, void *y, int k)
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Quantization block structures for weight-only quantization.
static void vec_zero_parallel(float *y, int n, int ith, int nth)
static void vec_scale_parallel(float *y, float scale, int n, int ith, int nth)
int get_optimal_decode_threads(void)
void qkv_projection_parallel(const void *ln1_q8, const void *WQ, const void *WK, const void *WV, float *q_out, float *k_out, float *v_out, int H, int H_kv, int head_dim, int embed_dim, int num_threads)
static void residual_add_parallel(const float *a, const float *b, float *out, int n, int ith, int nth)
void mlp_parallel(const void *ln2_q8, const void *W_gate, const void *W_up, const void *W_down, float *gate_buf, float *up_buf, float *swiglu_buf, void *down_q8, float *mlp_out, int intermediate, int embed_dim, int num_threads)
void decode_layer_parallel(float *hidden, const void *ln1_weight, const void *ln2_weight, const void *WQ, const void *WK, const void *WV, const void *WO, const void *W_gate, const void *W_up, const void *W_down, float *k_cache, float *v_cache, int token_index, float *scratch, int embed_dim, int intermediate, int H, int H_kv, int head_dim, int max_seq, float eps, int num_threads)