← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_f16.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_f16.c
3  * @brief GEMM kernels with FP16 (half-precision) weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Implements matrix multiplication where:
15  * - Weights: FP16 (IEEE half-precision, used by vision encoders)
16  * - Activations: FP32
17  * - Output: FP32
18  *
19  * Used for multimodal projection layers (mmproj-*.gguf files).
20  */
21 
22 #include <stdint.h>
23 #include <stddef.h>
24 #include "ckernel_quant.h" /* For ck_fp16_to_fp32 */
25 
26 #ifdef __AVX512F__
27 #include <immintrin.h>
28 #endif
29 
/* ============================================================================
 * FP16 Conversion Utilities (if not using F16C)
 * ============================================================================ */

#ifndef __F16C__
/*
 * Software FP16 <-> FP32 conversion.
 *
 * FIX: these macros previously forwarded to ggml_fp16_to_fp32 /
 * ggml_fp32_to_fp16, but this file only includes "ckernel_quant.h",
 * which (per the include comment above) provides ck_fp16_to_fp32.
 * The ggml_* names were never brought into scope here, so non-F16C
 * builds failed to resolve them. Forward to the ck_* names instead.
 * NOTE(review): assumes ckernel_quant.h also declares ck_fp32_to_fp16
 * alongside ck_fp16_to_fp32 — confirm against the header.
 */
#define fp16_to_fp32(x) ck_fp16_to_fp32(x)
#define fp32_to_fp16(x) ck_fp32_to_fp16(x)
#else
/* Hardware F16C support: single-instruction scalar conversions. */
#include <immintrin.h>

/* Convert one IEEE-754 binary16 value (raw bit pattern) to float. */
static inline float fp16_to_fp32(uint16_t h) {
    return _cvtsh_ss(h);
}

/* Convert one float to IEEE-754 binary16 (imm 0 = round-to-nearest). */
static inline uint16_t fp32_to_fp16(float f) {
    return _cvtss_sh(f, 0);
}
#endif
48 
49 /* ============================================================================
50  * GEMV: y = W @ x (W is FP16, x and y are FP32)
51  * ============================================================================ */
52 
/**
 * @brief Matrix-vector multiply with FP16 weights (scalar reference)
 *
 * Computes y = W @ x, widening each FP16 weight to FP32 before the
 * multiply-accumulate. Pure computation, no allocation (kernel rule #1).
 *
 * @param y Output vector [M]
 * @param W Weight matrix in FP16 [M x K], row-major
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of columns
 */
void gemv_f16_ref(float *y,
                  const uint16_t *W,
                  const float *x,
                  int M, int K)
{
    for (int row = 0; row < M; row++) {
        float sum = 0.0f;
        /* size_t cast: row * K in int overflows (UB) for large M x K */
        const uint16_t *w_row = &W[(size_t)row * K];

        for (int k = 0; k < K; k++) {
            float w = fp16_to_fp32(w_row[k]);
            sum += w * x[k];
        }

        y[row] = sum;
    }
}
79 
#ifdef __AVX512F__
/**
 * @brief Matrix-vector multiply with FP16 weights (AVX-512)
 *
 * Widens 16 FP16 weights per iteration to FP32 in registers (VCVTPH2PS)
 * and accumulates with FMA; a scalar loop handles the K % 16 tail.
 *
 * @param y Output vector [M]
 * @param W Weight matrix in FP16 [M x K], row-major
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of columns
 */
void gemv_f16_avx512(float *y,
                     const uint16_t *W,
                     const float *x,
                     int M, int K)
{
    const int K16 = K / 16 * 16; /* largest multiple of 16 <= K */

    for (int row = 0; row < M; row++) {
        __m512 acc = _mm512_setzero_ps();
        /* size_t cast: row * K in int overflows (UB) for large M x K */
        const uint16_t *w_row = &W[(size_t)row * K];

        /* Process 16 elements at a time */
        for (int k = 0; k < K16; k += 16) {
            /* Load 16 x FP16 weights */
            __m256i w_f16 = _mm256_loadu_si256((const __m256i *)&w_row[k]);

            /* Convert FP16 to FP32 */
            __m512 w_f32 = _mm512_cvtph_ps(w_f16);

            /* Load 16 x FP32 inputs */
            __m512 x_vec = _mm512_loadu_ps(&x[k]);

            /* FMA */
            acc = _mm512_fmadd_ps(w_f32, x_vec, acc);
        }

        /* Horizontal sum */
        float sum = _mm512_reduce_add_ps(acc);

        /* Scalar remainder for K not a multiple of 16 */
        for (int k = K16; k < K; k++) {
            sum += fp16_to_fp32(w_row[k]) * x[k];
        }

        y[row] = sum;
    }
}
#endif /* __AVX512F__ */
124 
/**
 * @brief Auto-dispatch GEMV based on available SIMD
 *
 * Compile-time dispatcher: falls back to the scalar reference kernel
 * unless this translation unit was built with AVX-512F.
 */
void gemv_f16(float *y,
              const uint16_t *W,
              const float *x,
              int M, int K)
{
#ifndef __AVX512F__
    gemv_f16_ref(y, W, x, M, K);
#else
    gemv_f16_avx512(y, W, x, M, K);
#endif
}
139 
140 /* ============================================================================
141  * GEMM: Y = W @ X (W is FP16, X and Y are FP32)
142  * ============================================================================ */
143 
/**
 * @brief Matrix-matrix multiply with FP16 weights (scalar reference)
 *
 * Computes Y = W @ X one batch column at a time via gemv_f16_ref.
 * X and Y are laid out as N contiguous columns of K and M floats.
 *
 * @param Y Output matrix [M x N]
 * @param W Weight matrix in FP16 [M x K]
 * @param X Input matrix [K x N]
 * @param M Number of output rows
 * @param N Batch size
 * @param K Hidden dimension
 */
void gemm_f16_ref(float *Y,
                  const uint16_t *W,
                  const float *X,
                  int M, int N, int K)
{
    for (int n = 0; n < N; n++) {
        /* size_t casts: n*M and n*K in int overflow (UB) for large N */
        gemv_f16_ref(&Y[(size_t)n * M], W, &X[(size_t)n * K], M, K);
    }
}
163 
#ifdef __AVX512F__
/**
 * @brief Matrix-matrix multiply with FP16 weights (AVX-512)
 *
 * Rows are the outer loop so each FP16 weight row stays cache-hot while
 * all N batch columns consume it. Weights are widened FP16 -> FP32 in
 * registers on the fly (no scratch workspace, per kernel rule #1).
 *
 * @param Y Output matrix [M x N]
 * @param W Weight matrix in FP16 [M x K]
 * @param X Input matrix [K x N]
 * @param M Number of output rows
 * @param N Batch size
 * @param K Hidden dimension
 */
void gemm_f16_avx512(float *Y,
                     const uint16_t *W,
                     const float *X,
                     int M, int N, int K)
{
    const int K16 = K / 16 * 16;

    for (int row = 0; row < M; row++) {
        /* size_t cast: row * K in int overflows (UB) for large M x K */
        const uint16_t *w_row = &W[(size_t)row * K];

        /* Pre-convert weight row to FP32 in cache-sized chunks */
        /* For now, convert on-the-fly per batch element */

        for (int n = 0; n < N; n++) {
            __m512 acc = _mm512_setzero_ps();
            const float *x_col = &X[(size_t)n * K];

            for (int k = 0; k < K16; k += 16) {
                __m256i w_f16 = _mm256_loadu_si256((const __m256i *)&w_row[k]);
                __m512 w_f32 = _mm512_cvtph_ps(w_f16);
                __m512 x_vec = _mm512_loadu_ps(&x_col[k]);
                acc = _mm512_fmadd_ps(w_f32, x_vec, acc);
            }

            float sum = _mm512_reduce_add_ps(acc);

            /* Scalar remainder for K not a multiple of 16 */
            for (int k = K16; k < K; k++) {
                sum += fp16_to_fp32(w_row[k]) * x_col[k];
            }

            Y[(size_t)n * M + row] = sum;
        }
    }
}
#endif /* __AVX512F__ */
203 
/**
 * @brief Auto-dispatch GEMM based on available SIMD
 *
 * Compile-time dispatcher: falls back to the scalar reference kernel
 * unless this translation unit was built with AVX-512F.
 */
void gemm_f16(float *Y,
              const uint16_t *W,
              const float *X,
              int M, int N, int K)
{
#ifndef __AVX512F__
    gemm_f16_ref(Y, W, X, M, N, K);
#else
    gemm_f16_avx512(Y, W, X, M, N, K);
#endif
}
218 
219 /* ============================================================================
220  * FP16 Tensor Conversion Utilities
221  * ============================================================================ */
222 
/**
 * @brief Convert FP16 tensor to FP32
 *
 * @param dst Destination buffer of @p count floats
 * @param src Source buffer of @p count FP16 values (raw bits)
 * @param count Number of elements to convert
 */
void convert_f16_to_f32(float *dst, const uint16_t *src, size_t count)
{
#ifdef __AVX512F__
    size_t i = 0;

    /* Vector body: widen 16 halves per iteration via VCVTPH2PS. */
    for (; i + 16 <= count; i += 16) {
        __m256i half = _mm256_loadu_si256((const __m256i *)(src + i));
        _mm512_storeu_ps(dst + i, _mm512_cvtph_ps(half));
    }

    /* Scalar tail. */
    for (; i < count; i++) {
        dst[i] = fp16_to_fp32(src[i]);
    }
#else
    for (size_t i = 0; i < count; i++) {
        dst[i] = fp16_to_fp32(src[i]);
    }
#endif
}
246 
/**
 * @brief Convert FP32 tensor to FP16
 *
 * @param dst Destination buffer of @p count FP16 values (raw bits)
 * @param src Source buffer of @p count floats
 * @param count Number of elements to convert
 */
void convert_f32_to_f16(uint16_t *dst, const float *src, size_t count)
{
#ifdef __AVX512F__
    size_t i = 0;

    /* Vector body: narrow 16 floats per iteration (imm 0 = nearest). */
    for (; i + 16 <= count; i += 16) {
        __m512 wide = _mm512_loadu_ps(src + i);
        _mm256_storeu_si256((__m256i *)(dst + i), _mm512_cvtps_ph(wide, 0));
    }

    /* Scalar tail. */
    for (; i < count; i++) {
        dst[i] = fp32_to_fp16(src[i]);
    }
#else
    for (size_t i = 0; i < count; i++) {
        dst[i] = fp32_to_fp16(src[i]);
    }
#endif
}
270 
271 /* ============================================================================
272  * Backward Pass: Gradient w.r.t. Input
273  *
274  * Given: dL/dY (gradient of loss w.r.t. output)
275  * Compute: dL/dX = W^T @ dL/dY
276  *
277  * For F16 weights, we convert to FP32 on-the-fly during backprop.
278  * ============================================================================ */
279 
/**
 * @brief Backward pass: compute input gradient (scalar reference)
 *
 * Computes dX = W^T @ dY by sweeping weight rows: each row contributes
 * dY[row] * W[row][k] to dX[k].
 *
 * @param dX Output gradient w.r.t. input [K]
 * @param W Weight matrix in FP16 format [M x K]
 * @param dY Gradient w.r.t. output [M]
 * @param M Number of output rows
 * @param K Number of columns (input dimension)
 */
void gemv_f16_backward_ref(float *dX,
                           const uint16_t *W,
                           const float *dY,
                           int M, int K)
{
    /* Zero output gradient */
    for (int k = 0; k < K; k++) {
        dX[k] = 0.0f;
    }

    /* Accumulate: dX += W^T @ dY */
    for (int row = 0; row < M; row++) {
        const float dy = dY[row];
        /* size_t cast: row * K in int overflows (UB) for large M x K */
        const uint16_t *w_row = &W[(size_t)row * K];

        for (int k = 0; k < K; k++) {
            float w = fp16_to_fp32(w_row[k]);
            dX[k] += w * dy;
        }
    }
}
310 
#ifdef __AVX512F__
/**
 * @brief Backward pass with AVX-512
 *
 * Computes dX = W^T @ dY. Keeps the original mul-then-add accumulation
 * (no FMA) so results bit-match the scalar reference.
 *
 * @param dX Output gradient w.r.t. input [K]
 * @param W Weight matrix in FP16 format [M x K]
 * @param dY Gradient w.r.t. output [M]
 * @param M Number of output rows
 * @param K Number of columns (input dimension)
 */
void gemv_f16_backward_avx512(float *dX,
                              const uint16_t *W,
                              const float *dY,
                              int M, int K)
{
    const int K16 = K / 16 * 16;

    /* Zero output */
    for (int k = 0; k < K16; k += 16) {
        _mm512_storeu_ps(&dX[k], _mm512_setzero_ps());
    }
    for (int k = K16; k < K; k++) {
        dX[k] = 0.0f;
    }

    for (int row = 0; row < M; row++) {
        const float dy = dY[row]; /* hoisted: also reused by the tail */
        const __m512 vdy = _mm512_set1_ps(dy);
        /* size_t cast: row * K in int overflows (UB) for large M x K */
        const uint16_t *w_row = &W[(size_t)row * K];

        for (int k = 0; k < K16; k += 16) {
            /* Load and convert F16 weights */
            __m256i w_f16 = _mm256_loadu_si256((const __m256i *)&w_row[k]);
            __m512 w_f32 = _mm512_cvtph_ps(w_f16);

            /* Compute gradient */
            __m512 grad = _mm512_mul_ps(w_f32, vdy);

            /* Accumulate */
            __m512 dx_cur = _mm512_loadu_ps(&dX[k]);
            _mm512_storeu_ps(&dX[k], _mm512_add_ps(dx_cur, grad));
        }

        /* Remainder */
        for (int k = K16; k < K; k++) {
            dX[k] += fp16_to_fp32(w_row[k]) * dy;
        }
    }
}
#endif
354 
/**
 * @brief Auto-dispatch backward
 *
 * Compile-time dispatcher: falls back to the scalar reference kernel
 * unless this translation unit was built with AVX-512F.
 */
void gemv_f16_backward(float *dX,
                       const uint16_t *W,
                       const float *dY,
                       int M, int K)
{
#ifndef __AVX512F__
    gemv_f16_backward_ref(dX, W, dY, M, K);
#else
    gemv_f16_backward_avx512(dX, W, dY, M, K);
#endif
}
369 
/**
 * @brief Batched backward pass
 *
 * Applies gemv_f16_backward independently to each of the N batch
 * columns of dY, writing the N corresponding columns of dX.
 *
 * @param dX Output gradient w.r.t. input [K x N]
 * @param W Weight matrix in FP16 format [M x K]
 * @param dY Gradient w.r.t. output [M x N]
 * @param M Number of output rows
 * @param N Batch size
 * @param K Input dimension
 */
void gemm_f16_backward(float *dX,
                       const uint16_t *W,
                       const float *dY,
                       int M, int N, int K)
{
    for (int n = 0; n < N; n++) {
        /* size_t casts: n*K and n*M in int overflow (UB) for large N */
        gemv_f16_backward(&dX[(size_t)n * K], W, &dY[(size_t)n * M], M, K);
    }
}
382 
/* ============================================================================
 * Dot Product Utility
 * ============================================================================ */

/**
 * @brief Dot product of an FP16 vector with an FP32 vector
 *
 * Implemented as a one-row GEMV so it inherits the SIMD dispatch.
 *
 * @param w_f16 FP16 operand [K] (raw bits)
 * @param x FP32 operand [K]
 * @param K Element count
 * @return FP32 dot product
 */
float dot_f16(const uint16_t *w_f16, const float *x, int K)
{
    float out;
    gemv_f16(&out, w_f16, x, 1, K);
    return out;
}
Quantization block structures for weight-only quantization.
void gemm_f16_ref(float *Y, const uint16_t *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with FP16 weights (scalar reference)
void gemm_f16_backward(float *dX, const uint16_t *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemm_f16(float *Y, const uint16_t *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
void gemv_f16(float *y, const uint16_t *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
void convert_f16_to_f32(float *dst, const uint16_t *src, size_t count)
Convert FP16 tensor to FP32.
#define fp16_to_fp32(x)
void gemv_f16_backward_ref(float *dX, const uint16_t *W, const float *dY, int M, int K)
Backward pass: compute input gradient (scalar reference)
float dot_f16(const uint16_t *w_f16, const float *x, int K)
void convert_f32_to_f16(uint16_t *dst, const float *src, size_t count)
Convert FP32 tensor to FP16.
void gemv_f16_backward(float *dX, const uint16_t *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemv_f16_ref(float *y, const uint16_t *W, const float *x, int M, int K)
Matrix-vector multiply with FP16 weights (scalar reference)
#define fp32_to_fp16(x)