27 #include <immintrin.h>
36 #define fp16_to_fp32(x) ggml_fp16_to_fp32(x)
37 #define fp32_to_fp16(x) ggml_fp32_to_fp16(x)
40 #include <immintrin.h>
45 return _cvtss_sh(f, 0);
67 for (
int row = 0; row < M; row++) {
69 const uint16_t *w_row = &W[row * K];
71 for (
int k = 0; k < K; k++) {
86 void gemv_f16_avx512(
float *y,
91 const int K16 = K / 16 * 16;
93 for (
int row = 0; row < M; row++) {
94 __m512 acc = _mm512_setzero_ps();
95 const uint16_t *w_row = &W[row * K];
98 for (
int k = 0; k < K16; k += 16) {
100 __m256i w_f16 = _mm256_loadu_si256((
const __m256i *)&w_row[k]);
103 __m512 w_f32 = _mm512_cvtph_ps(w_f16);
106 __m512 x_vec = _mm512_loadu_ps(&x[k]);
109 acc = _mm512_fmadd_ps(w_f32, x_vec, acc);
113 float sum = _mm512_reduce_add_ps(acc);
116 for (
int k = K16; k < K; k++) {
134 gemv_f16_avx512(y, W, x, M, K);
159 for (
int n = 0; n < N; n++) {
168 void gemm_f16_avx512(
float *Y,
173 const int K16 = K / 16 * 16;
175 for (
int row = 0; row < M; row++) {
176 const uint16_t *w_row = &W[row * K];
181 for (
int n = 0; n < N; n++) {
182 __m512 acc = _mm512_setzero_ps();
183 const float *x_col = &X[n * K];
185 for (
int k = 0; k < K16; k += 16) {
186 __m256i w_f16 = _mm256_loadu_si256((
const __m256i *)&w_row[k]);
187 __m512 w_f32 = _mm512_cvtph_ps(w_f16);
188 __m512 x_vec = _mm512_loadu_ps(&x_col[k]);
189 acc = _mm512_fmadd_ps(w_f32, x_vec, acc);
192 float sum = _mm512_reduce_add_ps(acc);
194 for (
int k = K16; k < K; k++) {
198 Y[n * M + row] = sum;
213 gemm_f16_avx512(Y, W, X, M, N, K);
229 const size_t count16 = count / 16 * 16;
231 for (
size_t i = 0; i < count16; i += 16) {
232 __m256i f16 = _mm256_loadu_si256((
const __m256i *)&src[i]);
233 __m512 f32 = _mm512_cvtph_ps(f16);
234 _mm512_storeu_ps(&dst[i], f32);
237 for (
size_t i = count16; i < count; i++) {
241 for (
size_t i = 0; i < count; i++) {
253 const size_t count16 = count / 16 * 16;
255 for (
size_t i = 0; i < count16; i += 16) {
256 __m512 f32 = _mm512_loadu_ps(&src[i]);
257 __m256i f16 = _mm512_cvtps_ph(f32, 0);
258 _mm256_storeu_si256((__m256i *)&dst[i], f16);
261 for (
size_t i = count16; i < count; i++) {
265 for (
size_t i = 0; i < count; i++) {
295 for (
int k = 0; k < K; k++) {
300 for (
int row = 0; row < M; row++) {
301 const float dy = dY[row];
302 const uint16_t *w_row = &W[row * K];
304 for (
int k = 0; k < K; k++) {
315 void gemv_f16_backward_avx512(
float *dX,
320 const int K16 = K / 16 * 16;
323 for (
int k = 0; k < K16; k += 16) {
324 _mm512_storeu_ps(&dX[k], _mm512_setzero_ps());
326 for (
int k = K16; k < K; k++) {
330 for (
int row = 0; row < M; row++) {
331 const __m512 vdy = _mm512_set1_ps(dY[row]);
332 const uint16_t *w_row = &W[row * K];
334 for (
int k = 0; k < K16; k += 16) {
336 __m256i w_f16 = _mm256_loadu_si256((
const __m256i *)&w_row[k]);
337 __m512 w_f32 = _mm512_cvtph_ps(w_f16);
340 __m512 grad = _mm512_mul_ps(w_f32, vdy);
343 __m512 dx_cur = _mm512_loadu_ps(&dX[k]);
344 _mm512_storeu_ps(&dX[k], _mm512_add_ps(dx_cur, grad));
348 for (
int k = K16; k < K; k++) {
364 gemv_f16_backward_avx512(dX, W, dY, M, K);
378 for (
int n = 0; n < N; n++) {
387 float dot_f16(
const uint16_t *w_f16,
const float *x,
int K)
Quantization block structures for weight-only quantization.
void gemm_f16_ref(float *Y, const uint16_t *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with FP16 weights (scalar reference)
void gemm_f16_backward(float *dX, const uint16_t *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemm_f16(float *Y, const uint16_t *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
void gemv_f16(float *y, const uint16_t *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
void convert_f16_to_f32(float *dst, const uint16_t *src, size_t count)
Convert FP16 tensor to FP32.
void gemv_f16_backward_ref(float *dX, const uint16_t *W, const float *dY, int M, int K)
Backward pass: compute input gradient (scalar reference)
float dot_f16(const uint16_t *w_f16, const float *x, int K)
Dot product of an FP16 vector with an FP32 vector of length K.
void convert_f32_to_f16(uint16_t *dst, const float *src, size_t count)
Convert FP32 tensor to FP16.
void gemv_f16_backward(float *dX, const uint16_t *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemv_f16_ref(float *y, const uint16_t *W, const float *x, int M, int K)
Matrix-vector multiply with FP16 weights (scalar reference)