← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_kernels_int4.c
Go to the documentation of this file.
1 /**
2  * @file rmsnorm_kernels_int4.c
3  * @brief RMSNorm kernels with INT4 output quantization
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  */
14 
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18 
19 #include "ckernel_engine.h"
20 
/*
 * Extract one signed 4-bit value from a packed byte.
 *
 * Even `index` selects the low nibble, odd `index` selects the high nibble.
 * The nibble is interpreted as two's complement: 0x0..0x7 -> 0..7 and
 * 0x8..0xF -> -8..-1.
 */
static inline int8_t decode_int4(uint8_t packed, int index)
{
    const unsigned shift = (index & 1) ? 4u : 0u;
    const int raw = (packed >> shift) & 0x0F;

    /* Sign-extend the 4-bit field. */
    return (int8_t)(raw < 8 ? raw : raw - 16);
}
34 
/*
 * Encode a signed value as a 4-bit two's-complement nibble.
 *
 * The value is saturated to [-8, 7] first. The result occupies the low
 * 4 bits of the returned byte (upper 4 bits are zero).
 */
static inline uint8_t encode_int4_nibble(int8_t value)
{
    const int clamped = (value < -8) ? -8 : ((value > 7) ? 7 : value);
    return (uint8_t)(clamped & 0x0F);
}
44 
/*
 * Expand `count` packed signed 4-bit values into floats.
 *
 * Element i is read from byte i/2 (low nibble for even i, high nibble for
 * odd i), so `src` must hold at least (count + 1) / 2 bytes and `dst` at
 * least `count` floats. Values are converted as raw integers in [-8, 7];
 * no scale factor is applied.
 */
static void convert_int4_to_float(const uint8_t *src,
                                  float *dst,
                                  size_t count)
{
    size_t pos = 0;
    while (pos < count) {
        const uint8_t byte = src[pos / 2];
        dst[pos] = (float)decode_int4(byte, (int)(pos % 2));
        ++pos;
    }
}
54 
/*
 * Round and saturate one float to the signed 4-bit range [-8, 7].
 *
 * Saturation is done in the float domain BEFORE rounding. Narrowing
 * lrintf()'s long result straight to int8_t is implementation-defined for
 * values outside [-128, 127] (and wraps in practice: 200.0f -> -56 -> -8
 * instead of 7). lrintf() on NaN or huge inputs is also unspecified.
 * NaN is mapped to 0.
 */
static inline int8_t round_saturate_int4(float x)
{
    if (isnan(x)) {
        return 0;
    }
    if (x >= 7.0f) {
        return 7;
    }
    if (x <= -8.0f) {
        return -8;
    }
    /* x is in (-8, 7): lrintf (current rounding mode, round-half-even by
     * default) yields a value in [-8, 7], which fits int8_t. */
    return (int8_t)lrintf(x);
}

/*
 * Quantize `count` floats into packed signed 4-bit values.
 *
 * Element i is stored in byte i/2: low nibble for even i, high nibble for
 * odd i. `dst` must hold (count + 1) / 2 bytes. When `count` is odd, the
 * high nibble of the last byte is zero. No scale factor is applied; values
 * are rounded and saturated to [-8, 7].
 */
static void convert_float_to_int4(const float *src,
                                  uint8_t *dst,
                                  size_t count)
{
    for (size_t i = 0; i < count; i += 2) {
        uint8_t byte = encode_int4_nibble(round_saturate_int4(src[i]));
        if (i + 1 < count) {
            byte |= (uint8_t)(encode_int4_nibble(round_saturate_int4(src[i + 1])) << 4);
        }
        /* Each byte is written exactly once, so no pre-zeroing pass is needed. */
        dst[i >> 1] = byte;
    }
}
73 
74 /*
/*
 * INT4 RMSNorm forward with caller-provided scratch buffers.
 *
 * Pipeline:
 *   1. Dequantize the packed INT4 `input` into `scratch_input`.
 *   2. Run the float rmsnorm_forward() kernel into `scratch_output`.
 *   3. Re-quantize `scratch_output` into packed INT4 `output` via
 *      convert_float_to_int4().
 *
 * INT4 codes are converted as raw integers in [-8, 7]; no scale factor is
 * applied in either direction.
 *
 * Buffers:
 *   input, output  : tokens * aligned_embed_dim signed 4-bit values, two per
 *                    byte (element i in byte i/2, low nibble for even i),
 *                    i.e. (tokens * aligned_embed_dim + 1) / 2 bytes each.
 *   gamma          : forwarded to rmsnorm_forward (presumably d_model floats;
 *                    confirm against that kernel).
 *   rstd_cache     : forwarded unchanged to rmsnorm_forward; NOT checked for
 *                    NULL here.
 *   scratch_input, scratch_output: each [tokens * aligned_embed_dim] floats.
 *
 * Returns without writing anything if a required pointer is NULL, or if the
 * dimensions are invalid: tokens < 0, d_model < 0, or
 * aligned_embed_dim < d_model. The last check assumes aligned_embed_dim is
 * the per-token row stride (TODO confirm). The sign checks matter because
 * the element count is computed in size_t, and a negative int would wrap to
 * an enormous count and overrun every buffer.
 */
void rmsnorm_forward_int4(const uint8_t *input,
                          const float *gamma,
                          uint8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,
                          float *scratch_output)
{
    if (!input || !gamma || !output) return;
    if (!scratch_input || !scratch_output) return;
    if (tokens < 0 || d_model < 0 || aligned_embed_dim < d_model) return;

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int4_to_float(input, scratch_input, total);
    rmsnorm_forward(scratch_input, gamma, scratch_output, rstd_cache,
                    tokens, d_model, aligned_embed_dim, eps);
    convert_float_to_int4(scratch_output, output, total);
}
99 
100 /*
101  * INT4 RMSNorm backward with caller-provided scratch buffers.
102  * scratch_d_output, scratch_input, scratch_d_input: each [tokens * aligned_embed_dim] floats
103  */
104 void rmsnorm_backward_int4(const uint8_t *d_output,
105  const uint8_t *input,
106  const float *gamma,
107  const float *rstd_cache,
108  uint8_t *d_input,
109  float *d_gamma,
110  int tokens,
111  int d_model,
112  int aligned_embed_dim,
113  float *scratch_d_output,
114  float *scratch_input,
115  float *scratch_d_input)
116 {
117  if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) return;
118  if (!scratch_d_output || !scratch_input || !scratch_d_input) return;
119 
120  size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
121 
122  convert_int4_to_float(d_output, scratch_d_output, total);
123  convert_int4_to_float(input, scratch_input, total);
124 
125  for (int d = 0; d < d_model; ++d) {
126  d_gamma[d] = 0.0f;
127  }
128 
129  rmsnorm_backward(scratch_d_output,
130  scratch_input,
131  gamma,
132  rstd_cache,
133  scratch_d_input,
134  d_gamma,
135  tokens,
136  d_model,
137  aligned_embed_dim);
138 
139  convert_float_to_int4(scratch_d_input, d_input, total);
140 }
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static void convert_int4_to_float(const uint8_t *src, float *dst, size_t count)
void rmsnorm_backward_int4(const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
static int8_t decode_int4(uint8_t packed, int index)
static void convert_float_to_int4(const float *src, uint8_t *dst, size_t count)
static uint8_t encode_int4_nibble(int8_t value)
void rmsnorm_forward_int4(const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)