← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4k_avx.c File Reference

AVX Q4_K x Q8_K matvec kernel for Sandy/Ivy Bridge. More...

#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

void gemv_q4_k_q8_k_avx (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 
void gemv_q4_k_q8_k_parallel_simd (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 
void gemv_q4_k_q8_k_ref (float *y, const void *W, const void *x_q8, int M, int K)
 

Detailed Description

AVX Q4_K x Q8_K matvec kernel for Sandy/Ivy Bridge.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After making changes, run: make test && make llamacpp-parity-full

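Designed to use _mm_maddubs_epi16 (SSSE3) for an efficient u8×s8 multiply-add, while keeping the scale format produced by unpack_q4_k_scales.

Note: in the listings below, the public entry points do not yet take a SIMD path:
- gemv_q4_k_q8_k_avx() forwards directly to the scalar gemv_q4_k_q8_k_ref().
- gemv_q4_k_q8_k_parallel_simd() forwards to gemv_q4_k_q8_k_parallel().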

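Key improvement over the SSE4.1 path: a single _mm_maddubs_epi16 multiplies 16 u8×s8 byte pairs and adds adjacent products into 8 int16 lanes. The SSE4.1 path must first widen the unsigned bytes with _mm_cvtepu8_epi16 and then call _mm_madd_epi16, which handles only 8 int16 pairs per instruction.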

Definition in file gemm_kernels_q4k_avx.c.

Function Documentation

◆ gemv_q4_k_q8_k_avx()

void gemv_q4_k_q8_k_avx ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 251 of file gemm_kernels_q4k_avx.c.

255 {
256  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
257 }
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)

References gemv_q4_k_q8_k_ref().

Referenced by gemv_q4_k_q8_k(), and gemv_q4_k_q8_k_amx().

◆ gemv_q4_k_q8_k_parallel()

void gemv_q4_k_q8_k_parallel ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Definition at line 206 of file gemm_kernels_q4k_q8k.c.

211 {
212  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
213  return;
214  }
215  if (ith < 0 || nth <= 0 || ith >= nth) {
216  return;
217  }
218 
219  /* Compute row range for this thread */
220  const int dr = (M + nth - 1) / nth;
221  const int r0 = dr * ith;
222  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
223 
224  if (r0 >= M) {
225  return; /* This thread has no work */
226  }
227 
228  const block_q4_K *blocks = (const block_q4_K *)W;
229  const block_q8_K *x = (const block_q8_K *)x_q8;
230  const int blocks_per_row = K / QK_K;
231 
232  /* Only process rows [r0, r1) */
233  for (int row = r0; row < r1; ++row) {
234  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
235  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
236  }
237 }
#define QK_K
static float dot_q4_k_q8_k_ref(const block_q4_K *w, const block_q8_K *x, int k)

Referenced by gemv_q4_k_q8_k_parallel_simd().

◆ gemv_q4_k_q8_k_parallel_simd()

void gemv_q4_k_q8_k_parallel_simd ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Definition at line 263 of file gemm_kernels_q4k_avx.c.

268 {
269  /* Fall back to reference parallel version */
270  gemv_q4_k_q8_k_parallel(y, W, x_q8, M, K, ith, nth);
271 }
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)

References gemv_q4_k_q8_k_parallel().

Referenced by decode_layer_parallel(), mlp_parallel(), and qkv_projection_parallel().

◆ gemv_q4_k_q8_k_ref()

void gemv_q4_k_q8_k_ref ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 177 of file gemm_kernels_q4k_q8k.c.

181 {
182  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
183  return;
184  }
185 
186  const block_q4_K *blocks = (const block_q4_K *)W;
187  const block_q8_K *x = (const block_q8_K *)x_q8;
188  const int blocks_per_row = K / QK_K;
189 
190  for (int row = 0; row < M; ++row) {
191  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
192  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
193  }
194 }

Referenced by gemv_q4_k_q8_k_avx().