← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_kernels_int4.c File Reference

RMSNorm kernels with INT4 output quantization. More...

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "ckernel_engine.h"

Go to the source code of this file.

Functions

static void convert_float_to_int4 (const float *src, uint8_t *dst, size_t count)
 
static void convert_int4_to_float (const uint8_t *src, float *dst, size_t count)
 
static int8_t decode_int4 (uint8_t packed, int index)
 
static uint8_t encode_int4_nibble (int8_t value)
 
void rmsnorm_backward_int4 (const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
 
void rmsnorm_forward_int4 (const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
 

Detailed Description

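RMSNorm kernels for packed INT4 tensors. Inputs are stored as signed 4-bit values packed two per byte; in the backward pass this also applies to the upstream gradients. These are dequantized into caller-provided float scratch buffers and processed by the float RMSNorm kernels. The results are then re-quantized to INT4 by rounding and saturating to [-8, 7], with no scale factor.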

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Definition in file rmsnorm_kernels_int4.c.

Function Documentation

◆ convert_float_to_int4()

static void convert_float_to_int4 ( const float *  src,
uint8_t *  dst,
size_t  count 
)
static

Definition at line 55 of file rmsnorm_kernels_int4.c.

/**
 * Quantize `count` floats to signed 4-bit integers packed two per byte.
 *
 * Element i is stored in byte i / 2: even indices go to the low nibble,
 * odd indices to the high nibble. Each value is rounded with lrintf()
 * (current rounding mode) and saturated to [-8, 7]; no scale factor is
 * applied. NaN inputs are stored as 0.
 *
 * @param src   Source floats (count elements).
 * @param dst   Destination of at least (count + 1) / 2 bytes. Every byte in
 *              that range is overwritten; when count is odd the unused high
 *              nibble of the last byte is left as 0.
 * @param count Number of logical INT4 elements.
 */
static void convert_float_to_int4(const float *src, uint8_t *dst, size_t count)
{
    size_t bytes = (count + 1) / 2;
    for (size_t i = 0; i < bytes; ++i) {
        dst[i] = 0;
    }
    for (size_t i = 0; i < count; ++i) {
        float v = src[i];
        long q;
        /* Saturate in the float domain BEFORE converting to an integer:
         * lrintf() has an unspecified result for NaN/out-of-range inputs, and
         * narrowing a large long to int8_t wraps (e.g. 200 -> -56), which would
         * bypass the clamp in encode_int4_nibble() and flip the sign. */
        if (isnan(v)) {
            q = 0;
        } else if (v >= 7.0f) {
            q = 7;
        } else if (v <= -8.0f) {
            q = -8;
        } else {
            q = lrintf(v);
        }
        uint8_t quant = encode_int4_nibble((int8_t)q);
        size_t byte_idx = i >> 1;
        if ((i & 1) == 0) {
            dst[byte_idx] = (uint8_t)((dst[byte_idx] & 0xF0) | quant);
        } else {
            dst[byte_idx] = (uint8_t)((dst[byte_idx] & 0x0F) | (quant << 4));
        }
    }
}
static uint8_t encode_int4_nibble(int8_t value)

References encode_int4_nibble().

Referenced by rmsnorm_backward_int4(), and rmsnorm_forward_int4().

◆ convert_int4_to_float()

static void convert_int4_to_float ( const uint8_t *  src,
float *  dst,
size_t  count 
)
static

Definition at line 45 of file rmsnorm_kernels_int4.c.

/**
 * Dequantize `count` packed signed INT4 values into floats.
 *
 * Element k is read from byte k / 2 (low nibble for even k, high nibble for
 * odd k) and sign-extended by decode_int4(). No scale is applied, so every
 * output value is an integer in [-8, 7].
 *
 * @param src   Packed source of at least (count + 1) / 2 bytes.
 * @param dst   Destination floats (count elements).
 * @param count Number of logical INT4 elements.
 */
static void convert_int4_to_float(const uint8_t *src, float *dst, size_t count)
{
    size_t k = 0;
    while (k < count) {
        const int nibble_sel = (int)(k % 2);
        dst[k] = (float)decode_int4(src[k / 2], nibble_sel);
        ++k;
    }
}
static int8_t decode_int4(uint8_t packed, int index)

References decode_int4().

Referenced by rmsnorm_backward_int4(), and rmsnorm_forward_int4().

◆ decode_int4()

static int8_t decode_int4 ( uint8_t  packed,
int  index 
)
inlinestatic

Definition at line 21 of file rmsnorm_kernels_int4.c.

/**
 * Extract one signed 4-bit value from a packed byte.
 *
 * @param packed Byte holding two INT4 values.
 * @param index  Only its lowest bit is used: 0 selects the low nibble,
 *               1 selects the high nibble.
 * @return The selected nibble sign-extended from 4-bit two's complement,
 *         in the range [-8, 7].
 */
static inline int8_t decode_int4(uint8_t packed, int index)
{
    const unsigned shift = (index & 1) ? 4u : 0u;
    const int raw = (packed >> shift) & 0x0F;
    /* Raw codes 8..15 represent -8..-1 in 4-bit two's complement. */
    return (int8_t)(raw < 8 ? raw : raw - 16);
}

Referenced by convert_int4_to_float().

◆ encode_int4_nibble()

static uint8_t encode_int4_nibble ( int8_t  value)
inlinestatic

Definition at line 35 of file rmsnorm_kernels_int4.c.

/**
 * Saturate a signed value to the INT4 range [-8, 7] and return its 4-bit
 * two's-complement code in the low nibble (the high nibble is always 0).
 *
 * @param value Integer to encode.
 * @return Encoded nibble in [0x0, 0xF].
 */
static inline uint8_t encode_int4_nibble(int8_t value)
{
    const int8_t clamped = (value < -8) ? (int8_t)-8
                         : (value > 7)  ? (int8_t)7
                         : value;
    return (uint8_t)(clamped & 0x0F);
}

Referenced by convert_float_to_int4().

◆ rmsnorm_backward_int4()

void rmsnorm_backward_int4 ( const uint8_t *  d_output,
const uint8_t *  input,
const float *  gamma,
const float *  rstd_cache,
uint8_t *  d_input,
float *  d_gamma,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float *  scratch_d_output,
float *  scratch_input,
float *  scratch_d_input 
)

Definition at line 104 of file rmsnorm_kernels_int4.c.

/**
 * RMSNorm backward pass for packed INT4 activations.
 *
 * Dequantizes d_output and input into float scratch buffers, zeroes d_gamma,
 * delegates to the float rmsnorm_backward(), then re-quantizes the input
 * gradient into d_input via convert_float_to_int4() (round, saturate to
 * [-8, 7], no scale factor).
 *
 * Layout: `tokens` rows of `aligned_embed_dim` INT4 values, packed two per
 * byte, so each packed tensor spans (tokens * aligned_embed_dim + 1) / 2
 * bytes. Assumes d_model <= aligned_embed_dim (rows padded for alignment) —
 * NOTE(review): confirm against rmsnorm_backward().
 *
 * @param d_output         Packed INT4 upstream gradient.
 * @param input            Packed INT4 forward-pass input.
 * @param gamma            Scale weights, passed through to rmsnorm_backward().
 * @param rstd_cache       Presumably per-token 1/rms values written by the
 *                         forward pass — TODO confirm.
 * @param d_input          Packed INT4 gradient w.r.t. input (output).
 * @param d_gamma          Gradient w.r.t. gamma, d_model floats (output);
 *                         zeroed here before rmsnorm_backward() runs.
 * @param tokens           Number of rows.
 * @param d_model          Number of active features per row.
 * @param aligned_embed_dim Row stride in INT4 elements.
 * @param scratch_d_output Float workspace, tokens * aligned_embed_dim elements.
 * @param scratch_input    Float workspace, tokens * aligned_embed_dim elements.
 * @param scratch_d_input  Float workspace, tokens * aligned_embed_dim elements.
 *
 * Returns without writing anything if a pointer is NULL or d_model /
 * aligned_embed_dim are invalid. With tokens <= 0 only d_gamma is zeroed.
 */
void rmsnorm_backward_int4(const uint8_t *d_output,
                           const uint8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output,
                           float *scratch_input,
                           float *scratch_d_input)
{
    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) {
        return;
    }
    if (!scratch_d_output || !scratch_input || !scratch_d_input) {
        return;
    }
    /* A negative int cast to size_t below would yield an enormous element
     * count and out-of-bounds accesses; rows shorter than d_model cannot be
     * valid either. */
    if (d_model <= 0 || aligned_embed_dim < d_model) {
        return;
    }

    for (int d = 0; d < d_model; ++d) {
        d_gamma[d] = 0.0f;
    }

    /* No tokens: the gamma gradient is zero and there is nothing to propagate. */
    if (tokens <= 0) {
        return;
    }

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int4_to_float(d_output, scratch_d_output, total);
    convert_int4_to_float(input, scratch_input, total);

    rmsnorm_backward(scratch_d_output,
                     scratch_input,
                     gamma,
                     rstd_cache,
                     scratch_d_input,
                     d_gamma,
                     tokens,
                     d_model,
                     aligned_embed_dim);

    convert_float_to_int4(scratch_d_input, d_input, total);
}
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static void convert_int4_to_float(const uint8_t *src, float *dst, size_t count)
static void convert_float_to_int4(const float *src, uint8_t *dst, size_t count)

References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_backward().

◆ rmsnorm_forward_int4()

void rmsnorm_forward_int4 ( const uint8_t *  input,
const float *  gamma,
uint8_t *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps,
float *  scratch_input,
float *  scratch_output 
)

Definition at line 78 of file rmsnorm_kernels_int4.c.

/**
 * RMSNorm forward pass for packed INT4 activations.
 *
 * Dequantizes input into scratch_input, runs the float rmsnorm_forward()
 * into scratch_output, then re-quantizes the result into output via
 * convert_float_to_int4() (round, saturate to [-8, 7], no scale factor).
 *
 * NOTE(review): because no scale is applied, normalized values with
 * |x| < 0.5 quantize to 0 — confirm that is the intended numeric contract.
 *
 * Layout: `tokens` rows of `aligned_embed_dim` INT4 values, packed two per
 * byte, so input/output each span (tokens * aligned_embed_dim + 1) / 2 bytes.
 * Assumes d_model <= aligned_embed_dim (rows padded for alignment) —
 * NOTE(review): confirm against rmsnorm_forward().
 *
 * @param input            Packed INT4 input.
 * @param gamma            Scale weights, passed through to rmsnorm_forward().
 * @param output           Packed INT4 result (output).
 * @param rstd_cache       Passed through to rmsnorm_forward(); not NULL-checked
 *                         here (presumably optional there — TODO confirm).
 * @param tokens           Number of rows.
 * @param d_model          Number of active features per row.
 * @param aligned_embed_dim Row stride in INT4 elements.
 * @param eps              Epsilon forwarded to rmsnorm_forward().
 * @param scratch_input    Float workspace, tokens * aligned_embed_dim elements.
 * @param scratch_output   Float workspace, tokens * aligned_embed_dim elements.
 *
 * Returns without writing anything if a required pointer is NULL or the
 * dimensions are non-positive or inconsistent.
 */
void rmsnorm_forward_int4(const uint8_t *input,
                          const float *gamma,
                          uint8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,
                          float *scratch_output)
{
    if (!input || !gamma || !output) {
        return;
    }
    if (!scratch_input || !scratch_output) {
        return;
    }
    /* A negative int cast to size_t below would yield an enormous element
     * count and out-of-bounds accesses; empty work is a no-op. */
    if (tokens <= 0 || d_model <= 0 || aligned_embed_dim < d_model) {
        return;
    }

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int4_to_float(input, scratch_input, total);
    rmsnorm_forward(scratch_input, gamma, scratch_output, rstd_cache,
                    tokens, d_model, aligned_embed_dim, eps);
    convert_float_to_int4(scratch_output, output, total);
}
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_forward().