39 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
40 #include <immintrin.h>
/*
 * Scalar Q5_0 output-projection kernel (fragment).
 *
 * Computes output[t][n] += sum_k dequant(W[n][k]) * attn_out_head_major[k][t],
 * where attn_out is laid out head-major: [num_heads][tokens][head_dim].
 *
 * NOTE(review): this chunk is an elided extraction — the function signature,
 * the declarations of `weights`, `qh`, `d`, `sum`, and several closing braces
 * live on lines not visible here. Comments below only describe what the
 * visible lines establish.
 */
63 const float *attn_out,
/* Null-pointer / non-positive-dimension guards: silently no-op. */
71 if (!output || !attn_out || !wo)
return;
72 if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0)
return;
/* Q5_0 packs QK5_0 values per block; assumes head_dim and embed_dim are
 * multiples of QK5_0 — TODO confirm callers guarantee this (no remainder
 * handling is visible). */
74 const int blocks_per_head = head_dim /
QK5_0;
75 const int blocks_per_row = embed_dim /
QK5_0;
/* Head-major strides: one head is a contiguous [tokens][head_dim] slab. */
79 const size_t token_stride = head_dim;
80 const size_t head_stride = (size_t)tokens * token_stride;
/* Seed every output row with the bias (body of the inner loop is elided;
 * presumably out_row[n] = bias[n] as in the sibling kernels — verify). */
84 for (
int t = 0; t < tokens; t++) {
85 float *out_row = output + (size_t)t * embed_dim;
86 for (
int n = 0; n < embed_dim; n++) {
/* No-bias path: zero the whole output buffer instead. */
91 memset(output, 0, (
size_t)tokens * embed_dim *
sizeof(
float));
/* Accumulate each head's contribution into the output. */
95 for (
int h = 0; h < num_heads; h++) {
96 const float *head_data = attn_out + (size_t)h * head_stride;
/* Block offset of this head's columns within a weight row. */
99 const int head_offset = h * blocks_per_head;
101 for (
int n_block = 0; n_block < blocks_per_head; n_block++) {
102 for (
int n = 0; n < embed_dim; n++) {
/* Row-major weight blocks: row n has blocks_per_row Q5_0 blocks. */
103 const block_q5_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
/* Copy the packed high bits out via memcpy (avoids unaligned/aliasing UB). */
108 memcpy(&qh, w_row->
qh,
sizeof(qh));
111 for (
int t = 0; t < tokens; t++) {
112 const float *token_vec = head_data + (size_t)t * token_stride + (
size_t)n_block *
QK5_0;
/* Standard ggml Q5_0 decode: each byte holds two 4-bit values; qh bit j
 * is the 5th bit of element j, bit j+16 (reached via >> (j+12), masked
 * with 0x10) is the 5th bit of element j+16; subtract 16 to re-center. */
116 for (
int j = 0; j <
QK5_0 / 2; j++) {
117 const uint8_t packed = w_row->
qs[j];
118 const int lo = (packed & 0x0F);
119 const int hi = (packed >> 4);
120 const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
121 const int xh_1 = ((qh >> (j + 12))) & 0x10;
122 const int q0 = (lo | xh_0) - 16;
123 const int q1 = (hi | xh_1) - 16;
/* `d` (block scale, declared on an elided line) rescales the integers. */
125 sum += d * (float)q0 * token_vec[j];
126 sum += d * (float)q1 * token_vec[j + 16];
/* Scatter the block's partial dot product into the output matrix. */
129 output[(size_t)t * embed_dim + n] += sum;
140 #if defined(__AVX__) && defined(__F16C__)
141 #include <immintrin.h>
/*
 * AVX/F16C Q5_0 output-projection kernel (fragment).
 *
 * Same contract as the scalar kernel: processes 8 output columns per
 * iteration with __m256 accumulators, falling back to a scalar tail loop.
 *
 * NOTE(review): elided extraction — parameter list, `weights`, `w0..w7`,
 * `d0..d7`, `qh`, `d`, `sum`, `n` declarations and many braces are on lines
 * not visible here. Two discrepancies vs. the scalar kernel are flagged
 * inline; they cannot be confirmed/fixed without the elided lines.
 */
151 void gemv_nt_q5_0_head_major_output_avx(
float *output,
152 const float *attn_out,
/* Null-pointer / non-positive-dimension guards: silently no-op. */
160 if (!output || !attn_out || !wo)
return;
161 if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0)
return;
/* Assumes head_dim and embed_dim are multiples of QK5_0 — TODO confirm. */
163 const int blocks_per_head = head_dim /
QK5_0;
164 const int blocks_per_row = embed_dim /
QK5_0;
167 const size_t token_stride = head_dim;
168 const size_t head_stride = (size_t)tokens * token_stride;
/* Initialize output rows with the bias vector. */
172 for (
int t = 0; t < tokens; t++) {
173 float *out_row = output + (size_t)t * embed_dim;
174 for (
int n = 0; n < embed_dim; n++) {
175 out_row[n] = bias[n];
/* No-bias path: zero-fill the output instead. */
179 memset(output, 0, (
size_t)tokens * embed_dim *
sizeof(
float));
183 for (
int h = 0; h < num_heads; h++) {
184 const float *head_data = attn_out + (size_t)h * head_stride;
185 const int head_offset = h * blocks_per_head;
/* Vector main loop: 8 output columns (n .. n+7) at a time. */
189 for (; n + 7 < embed_dim; n += 8) {
191 for (
int n_block = 0; n_block < blocks_per_head; n_block++) {
/* NOTE(review): this index multiplies (n + head_offset + n_block) by
 * blocks_per_row, whereas the scalar Q5_0 kernel uses
 * n * blocks_per_row + head_offset + n_block. These address different
 * weight blocks — verify which layout is intended. */
192 const size_t w_offset = (size_t)(n + head_offset + n_block) * blocks_per_row + n_block;
/* Eight independent accumulators (only acc0/acc1 are used in the
 * visible lines; acc2..acc7 may be consumed by elided code). */
194 __m256 acc0 = _mm256_setzero_ps();
195 __m256 acc1 = _mm256_setzero_ps();
196 __m256 acc2 = _mm256_setzero_ps();
197 __m256 acc3 = _mm256_setzero_ps();
198 __m256 acc4 = _mm256_setzero_ps();
199 __m256 acc5 = _mm256_setzero_ps();
200 __m256 acc6 = _mm256_setzero_ps();
201 __m256 acc7 = _mm256_setzero_ps();
204 for (
int t = 0; t < tokens; t++) {
205 const float *token_vec = head_data + (size_t)t * token_stride + (
size_t)n_block *
QK5_0;
/* Per-nibble decode for 8 weight rows (w0..w7 declared on elided lines).
 * NOTE(review): this decode uses (p & 0x0F) - 8 and (p >> 4) - 8 with no
 * qh high bit — that is Q4_0-style decoding. The scalar Q5_0 path uses
 * (lo | xh) - 16 with the qh bit. Unless elided lines fold qh in
 * elsewhere, the AVX path dequantizes differently from the scalar path
 * — verify against the scalar kernel's results. */
227 for (
int j = 0; j < 16; j++) {
228 const uint8_t p0 = w0->
qs[j];
229 const uint8_t p1 = w1->
qs[j];
230 const uint8_t p2 = w2->
qs[j];
231 const uint8_t p3 = w3->
qs[j];
232 const uint8_t p4 = w4->
qs[j];
233 const uint8_t p5 = w5->
qs[j];
234 const uint8_t p6 = w6->
qs[j];
235 const uint8_t p7 = w7->
qs[j];
/* Activation values shared across the 8 columns: element j (low nibble)
 * and element j+16 (high nibble) of the current 32-wide block. */
237 const float tv0 = token_vec[j];
238 const float tv1 = token_vec[j + 16];
/* Low nibbles, re-centered. */
241 const int lo0 = (p0 & 0x0F) - 8;
242 const int lo1 = (p1 & 0x0F) - 8;
243 const int lo2 = (p2 & 0x0F) - 8;
244 const int lo3 = (p3 & 0x0F) - 8;
245 const int lo4 = (p4 & 0x0F) - 8;
246 const int lo5 = (p5 & 0x0F) - 8;
247 const int lo6 = (p6 & 0x0F) - 8;
248 const int lo7 = (p7 & 0x0F) - 8;
/* acc0 += (q * scale) * activation, lane i = output column n+i.
 * d0..d7 (per-row block scales) are declared on elided lines. */
250 __m256 xv = _mm256_set1_ps(tv0);
251 __m256 qw = _mm256_setr_ps(lo0, lo1, lo2, lo3, lo4, lo5, lo6, lo7);
252 __m256 vw = _mm256_setr_ps(d0, d1, d2, d3, d4, d5, d6, d7);
253 acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_mul_ps(qw, vw), xv));
/* High nibbles, re-centered. */
256 const int hi0 = (p0 >> 4) - 8;
257 const int hi1 = (p1 >> 4) - 8;
258 const int hi2 = (p2 >> 4) - 8;
259 const int hi3 = (p3 >> 4) - 8;
260 const int hi4 = (p4 >> 4) - 8;
261 const int hi5 = (p5 >> 4) - 8;
262 const int hi6 = (p6 >> 4) - 8;
263 const int hi7 = (p7 >> 4) - 8;
265 xv = _mm256_set1_ps(tv1);
266 qw = _mm256_setr_ps(hi0, hi1, hi2, hi3, hi4, hi5, hi6, hi7);
267 vw = _mm256_setr_ps(d0, d1, d2, d3, d4, d5, d6, d7);
268 acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_mul_ps(qw, vw), xv));
/* Combine low/high-nibble partial sums ... */
272 __m256 total = _mm256_add_ps(acc0, acc1);
/* ... and accumulate 8 columns into the output row (unaligned safe). */
275 float *out_row = output + (size_t)t * embed_dim + n;
276 __m256 out_val = _mm256_loadu_ps(out_row);
277 out_val = _mm256_add_ps(out_val, total);
278 _mm256_storeu_ps(out_row, out_val);
/* Scalar tail for the remaining (embed_dim % 8) columns. */
284 for (; n < embed_dim; n++) {
285 for (
int n_block = 0; n_block < blocks_per_head; n_block++) {
/* NOTE(review): same indexing discrepancy as the vector loop above —
 * differs from the scalar kernel's n * blocks_per_row + head_offset
 * + n_block. Verify. */
286 const block_q5_0 *w_row = weights + (size_t)(n + head_offset + n_block) * blocks_per_row + n_block;
/* memcpy the packed high bits (avoids unaligned/aliasing UB). */
290 memcpy(&qh, w_row->
qh,
sizeof(qh));
292 for (
int t = 0; t < tokens; t++) {
293 const float *token_vec = head_data + (size_t)t * token_stride + (
size_t)n_block *
QK5_0;
/* Full ggml Q5_0 decode here (with qh, −16) — note this differs from
 * the −8/no-qh decode in the vector loop above. */
296 for (
int j = 0; j <
QK5_0 / 2; j++) {
297 const uint8_t packed = w_row->
qs[j];
298 const int lo = (packed & 0x0F);
299 const int hi = (packed >> 4);
300 const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
301 const int xh_1 = ((qh >> (j + 12))) & 0x10;
302 const int q0 = (lo | xh_0) - 16;
303 const int q1 = (hi | xh_1) - 16;
305 sum += d * (float)q0 * token_vec[j];
306 sum += d * (float)q1 * token_vec[j + 16];
309 output[(size_t)t * embed_dim + n] += sum;
/* Compile-time dispatch (fragment): when AVX + F16C are available, route to
 * the vectorized kernel; the #else branch calling the scalar kernel is on
 * elided lines (only its trailing argument line is visible below). */
337 #if defined(__AVX__) && defined(__F16C__)
338 gemv_nt_q5_0_head_major_output_avx(output, attn_out, wo, bias,
339 tokens, embed_dim, num_heads, head_dim);
342 tokens, embed_dim, num_heads, head_dim);
/*
 * Scalar Q8_0 output-projection kernel (fragment).
 *
 * Same structure as the Q5_0 scalar kernel but with 8-bit signed quants:
 * no nibble unpacking, just qs[j] * scale.
 *
 * NOTE(review): elided extraction — the signature and the declarations of
 * `weights`, `d`, `sum` plus closing braces are on lines not visible here.
 */
362 if (!output || !attn_out || !wo)
return;
363 if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0)
return;
/* Q8_0 packs QK8_0 values per block; assumes dimensions are multiples of
 * QK8_0 — TODO confirm callers guarantee this. */
365 const int blocks_per_head = head_dim /
QK8_0;
366 const int blocks_per_row = embed_dim /
QK8_0;
/* Head-major strides: one head is a contiguous [tokens][head_dim] slab. */
369 const size_t token_stride = head_dim;
370 const size_t head_stride = (size_t)tokens * token_stride;
/* Seed every output row with the bias vector. */
374 for (
int t = 0; t < tokens; t++) {
375 float *out_row = output + (size_t)t * embed_dim;
376 for (
int n = 0; n < embed_dim; n++) {
377 out_row[n] = bias[n];
/* No-bias path: zero-fill the output instead. */
381 memset(output, 0, (
size_t)tokens * embed_dim *
sizeof(
float));
/* Accumulate each head's contribution. */
385 for (
int h = 0; h < num_heads; h++) {
386 const float *head_data = attn_out + (size_t)h * head_stride;
387 const int head_offset = h * blocks_per_head;
389 for (
int n_block = 0; n_block < blocks_per_head; n_block++) {
390 for (
int n = 0; n < embed_dim; n++) {
/* Row-major weight blocks — same layout as the scalar Q5_0 kernel. */
391 const block_q8_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
394 for (
int t = 0; t < tokens; t++) {
395 const float *token_vec = head_data + (size_t)t * token_stride + (
size_t)n_block *
QK8_0;
/* Dot product of one Q8_0 block with the activation slice; `d` is the
 * block scale (declared on an elided line). */
398 for (
int j = 0; j <
QK8_0; j++) {
399 sum += d * (float)w_row->
qs[j] * token_vec[j];
402 output[(size_t)t * embed_dim + n] += sum;
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q5_0 weights; auto-dispatches between the AVX and scalar kernels)
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
void gemv_nt_q5_0_head_major_output(float *output, const float *attn_out, const void *wo, const float *bias, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection reading head-major attention output (Q5_0 weights)