11 #ifndef CKERNEL_QUANT_H
12 #define CKERNEL_QUANT_H
121 #define K_SCALE_SIZE 12
186 case 0:
return QK4_0;
187 case 1:
return QK8_0;
223 return (n_elements / block_size) * type_size;
/* Unpack the 12-byte (K_SCALE_SIZE) packed scale/min table of a Q4_K
 * super-block into eight 6-bit scales and eight 6-bit mins.
 *
 * Layout: the first four scales and mins are stored whole in the low 6 bits
 * of bytes 0..3 and 4..7; the last four are split, with their low 4 bits in
 * bytes 8..11 and their high 2 bits recycled from the top 2 bits of bytes
 * 0..7.
 *
 * scales: 12 packed input bytes
 * sc:     out, 8 sub-block scales (each fits in 6 bits)
 * m:      out, 8 sub-block mins   (each fits in 6 bits)
 */
static void unpack_q4_k_scales(const uint8_t *scales,
                               uint8_t *sc, uint8_t *m) {
    /* Sub-blocks 0..3: plain 6-bit fields. */
    for (int j = 0; j < 4; ++j) {
        sc[j] = scales[j] & 0x3F;
        m[j]  = scales[j + 4] & 0x3F;
    }
    /* Sub-blocks 4..7: low nibble from bytes 8..11, high 2 bits borrowed
     * from the top of bytes 0..3 (scales) / 4..7 (mins). */
    for (int j = 0; j < 4; ++j) {
        sc[j + 4] = (uint8_t)((scales[j + 8] & 0x0F) | ((scales[j] >> 6) << 4));
        m[j + 4]  = (uint8_t)((scales[j + 8] >> 4) | ((scales[j + 4] >> 6) << 4));
    }
}
286 uint8_t *sc, uint8_t *m) {
304 uint32_t sign = (h & 0x8000) << 16;
305 uint32_t exp = (h >> 10) & 0x1F;
306 uint32_t mant = h & 0x3FF;
316 while ((mant & 0x400) == 0) {
321 result = sign | ((exp + 127 - 15) << 23) | (mant << 13);
323 }
else if (exp == 31) {
324 result = sign | 0x7F800000 | (mant << 13);
326 result = sign | ((exp + 127 - 15) << 23) | (mant << 13);
329 union { uint32_t u;
float f; } u;
338 union { uint32_t u;
float f; } u;
341 uint32_t sign = (u.u >> 16) & 0x8000;
342 int32_t exp = ((u.u >> 23) & 0xFF) - 127 + 15;
343 uint32_t mant = (u.u >> 13) & 0x3FF;
349 mant = (mant | 0x400) >> (1 - exp);
351 }
else if (exp >= 31) {
352 return sign | 0x7C00;
355 return sign | (exp << 10) | mant;
362 #if defined(__F16C__)
363 #include <immintrin.h>
368 static inline float ck_fp16_to_fp32_simd(
ck_half h) {
375 static inline ck_half ck_fp32_to_fp16_simd(
float f) {
376 return (
ck_half)_cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
384 #if defined(__F16C__)
385 return ck_fp16_to_fp32_simd(h);
392 #if defined(__F16C__)
393 return ck_fp32_to_fp16_simd(f);
/* Canonical conversion macros. The non-suffixed pair dispatches to the
 * F16C path when available (see ck_fp16_to_fp32 / ck_fp32_to_fp16);
 * _SIMD and _SOFT select an implementation explicitly. */
#define CK_FP16_TO_FP32(x)      ck_fp16_to_fp32(x)
#define CK_FP32_TO_FP16(x)      ck_fp32_to_fp16(x)
#define CK_FP16_TO_FP32_SIMD(x) ck_fp16_to_fp32_simd(x)
#define CK_FP32_TO_FP16_SIMD(x) ck_fp32_to_fp16_simd(x)
#define CK_FP16_TO_FP32_SOFT(x) ck_fp16_to_fp32_soft(x)
#define CK_FP32_TO_FP16_SOFT(x) ck_fp32_to_fp16_soft(x)

/* ggml-compatible aliases so code written against ggml's FP16 helpers
 * builds against this header unchanged. */
#define ggml_fp16_to_fp32 ck_fp16_to_fp32
#define ggml_fp32_to_fp16 ck_fp32_to_fp16
#define GGML_FP16_TO_FP32 CK_FP16_TO_FP32
#define GGML_FP32_TO_FP16 CK_FP32_TO_FP16
/* GEMM with transposed B ("nt"): FP32 activations A against Q5_0 weights B,
 * optional bias added into C. SSE implementation, v2.
 * NOTE(review): M/N/K layout inferred from the gemm_nt naming — confirm
 * against the definition. */
void gemm_nt_q5_0_sse_v2(const float *A, const void *B, const float *bias,
                         float *C, int M, int N, int K);
/* GEMM with transposed B: FP32 activations A against Q6_K weights B,
 * SSE implementation. */
void gemm_nt_q6_k_sse(const float *A, const void *B, const float *bias,
                      float *C, int M, int N, int K);
/* Scalar reference GEMM with transposed B: FP32 activations A against
 * Q6_K weights B. Used to validate the SSE variant. */
void gemm_nt_q6_k_ref(const float *A, const void *B, const float *bias,
                      float *C, int M, int N, int K);
/* RMSNorm over each of `tokens` rows of `input` (scaled by `gamma`, with
 * epsilon `eps`), writing the result into `vy`.
 * NOTE(review): the _q8_k suffix suggests the output is quantized to Q8_K
 * and `aligned_embed_dim` is the padded row stride — confirm against the
 * definition. */
void rmsnorm_q8_k_fused(const float *input, const float *gamma, void *vy,
                        int tokens, int d_model, int aligned_embed_dim,
                        float eps);
/* Scalar reference GEMV: y = W x with Q5_K-quantized W (M rows of K). */
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K);
/* Scalar reference GEMM with transposed B: FP32 activations A against
 * Q5_K weights B. Used to validate the optimized variant. */
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias,
                      float *C, int M, int N, int K);
/* GEMV: y = W x with Q5_K-quantized W (M rows of K). Optimized variant of
 * gemv_q5_k_ref. */
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K);
/* GEMM with transposed B: FP32 activations A against Q5_K weights B.
 * Optimized variant of gemm_nt_q5_k_ref. */
void gemm_nt_q5_k(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K);
/* Batch GEMM for prefill: Q8_0-quantized activations A_q8 against
 * Q5_0-quantized weights B_q5, optional bias added into C. */
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias,
                       float *C, int M, int N, int K);
void gemm_nt_q5_0_sse_v2(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K)
static float ck_fp16_to_fp32_soft(ck_half h)
Convert FP16 (ck_half) to FP32 — software implementation.
void gemm_nt_q6_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
void gemm_nt_q6_k_sse(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
static size_t ck_quant_type_size(int type)
Get the byte size per block for a quant type.
void rmsnorm_q8_k_fused(const float *input, const float *gamma, void *vy, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
static ck_half ck_fp32_to_fp16(float f)
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
void gemm_nt_q5_0_q8_0_unroll_avx(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
static float ck_fp16_to_fp32(ck_half h)
static void unpack_q5_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q5_K sub-block scales and mins.
void quantize_row_q8_k_sse(const float *x, void *vy, int k)
static size_t ck_quant_block_size(int type)
Get the block size (number of weights per block) for a quant type.
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
static size_t ck_quant_row_size(int type, int64_t n_elements)
Calculate total bytes needed for n_elements with given quant type.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
static ck_half ck_fp32_to_fp16_soft(float f)
Convert FP32 to FP16 (ck_half) — software implementation.
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)