← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_kernels_int8.c
Go to the documentation of this file.
1 /**
2  * @file rmsnorm_kernels_int8.c
3  * @brief RMSNorm kernels with INT8 output quantization
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  */
14 
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18 
19 #include "ckernel_engine.h"
20 
21 /* Suppress false positive warnings about uninitialized variables */
22 #pragma GCC diagnostic push
23 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
24 
/*
 * Widen `count` signed 8-bit values from `src` into `dst` as floats.
 * Every int8_t value is exactly representable as a float, so the
 * conversion is lossless. No scale or zero-point is applied.
 */
static void convert_int8_to_float(const int8_t *src,
                                  float *dst,
                                  size_t count)
{
    size_t remaining = count;

    while (remaining-- > 0) {
        *dst++ = (float)*src++;
    }
}
33 
34 static int8_t clamp_int8(float value)
35 {
36  int32_t q = (int32_t)lrintf(value);
37  if (q > INT8_MAX) {
38  q = INT8_MAX;
39  } else if (q < INT8_MIN) {
40  q = INT8_MIN;
41  }
42  return (int8_t)q;
43 }
44 
/*
 * Quantize `count` floats from `src` into `dst`. Each element is rounded
 * and saturated to the int8_t range by clamp_int8(). No scale is applied.
 */
static void convert_float_to_int8(const float *src,
                                  int8_t *dst,
                                  size_t count)
{
    size_t remaining = count;

    while (remaining-- > 0) {
        *dst++ = clamp_int8(*src++);
    }
}
53 
/*
 * INT8 RMSNorm forward with caller-provided scratch buffers.
 *
 * Casts `input` to float (plain int8 -> float, no dequantization scale),
 * runs the float rmsnorm_forward(), then rounds and saturates the float
 * result back to int8 in `output`.
 *
 * @param input              [tokens * aligned_embed_dim] int8 values
 * @param gamma              forwarded to rmsnorm_forward() (presumably
 *                           d_model scale weights; confirm)
 * @param output             [tokens * aligned_embed_dim] int8 result
 * @param rstd_cache         forwarded unchanged to rmsnorm_forward();
 *                           not null-checked here
 * @param tokens             number of token rows; negative values are rejected
 * @param d_model            forwarded to rmsnorm_forward()
 * @param aligned_embed_dim  elements per row in every buffer; negative values
 *                           are rejected. NOTE(review): presumably must be
 *                           >= d_model; not enforced here.
 * @param eps                forwarded to rmsnorm_forward()
 * @param scratch_input, scratch_output
 *                           each [tokens * aligned_embed_dim] floats
 *
 * Returns without writing anything if a required pointer is NULL or a
 * dimension is negative.
 */
void rmsnorm_forward_int8(const int8_t *input,
                          const float *gamma,
                          int8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,
                          float *scratch_output)
{
    if (!input || !gamma || !output) return;
    if (!scratch_input || !scratch_output) return;
    /* A negative size would wrap to a huge size_t below and overrun every buffer. */
    if (tokens < 0 || aligned_embed_dim < 0) return;

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int8_to_float(input, scratch_input, total);
    rmsnorm_forward(scratch_input, gamma, scratch_output, rstd_cache,
                    tokens, d_model, aligned_embed_dim, eps);
    convert_float_to_int8(scratch_output, output, total);
}
79 
/*
 * INT8 RMSNorm backward with caller-provided scratch buffers.
 *
 * Casts `d_output` and `input` to float (plain int8 -> float, no scale) and
 * zeroes the first d_model entries of `d_gamma`. It then runs the float
 * rmsnorm_backward(), which accumulates into `d_gamma`. Finally it rounds and
 * saturates the float input gradient back to int8 in `d_input`.
 *
 * NOTE(review): because no scale is applied on the way back to int8,
 * gradient elements with magnitude below 0.5 quantize to 0.
 *
 * @param d_output           [tokens * aligned_embed_dim] int8 upstream gradient
 * @param input              [tokens * aligned_embed_dim] int8 forward input
 * @param gamma, rstd_cache  forwarded to rmsnorm_backward()
 * @param d_input            [tokens * aligned_embed_dim] int8 result
 * @param d_gamma            at least d_model floats; zeroed, then accumulated
 * @param tokens             number of token rows; negative values are rejected
 * @param d_model            forwarded to rmsnorm_backward()
 * @param aligned_embed_dim  elements per row in every int8/scratch buffer;
 *                           negative values are rejected. NOTE(review):
 *                           presumably must be >= d_model; not enforced.
 * @param scratch_d_output, scratch_input, scratch_d_input
 *                           each [tokens * aligned_embed_dim] floats
 *
 * Returns without writing anything if a pointer is NULL or a dimension is
 * negative. With tokens == 0, d_gamma is still zeroed.
 */
void rmsnorm_backward_int8(const int8_t *d_output,
                           const int8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           int8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output,
                           float *scratch_input,
                           float *scratch_d_input)
{
    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) return;
    if (!scratch_d_output || !scratch_input || !scratch_d_input) return;
    /* A negative size would wrap to a huge size_t below and overrun every buffer. */
    if (tokens < 0 || aligned_embed_dim < 0) return;

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int8_to_float(d_output, scratch_d_output, total);
    convert_int8_to_float(input, scratch_input, total);

    // Zero gamma gradient before accumulation.
    for (int d = 0; d < d_model; ++d) {
        d_gamma[d] = 0.0f;
    }

    rmsnorm_backward(scratch_d_output,
                     scratch_input,
                     gamma,
                     rstd_cache,
                     scratch_d_input,
                     d_gamma,
                     tokens,
                     d_model,
                     aligned_embed_dim);

    convert_float_to_int8(scratch_d_input, d_input, total);
}
122 
123 #pragma GCC diagnostic pop
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static int8_t clamp_int8(float value)
void rmsnorm_forward_int8(const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
static void convert_int8_to_float(const int8_t *src, float *dst, size_t count)
void rmsnorm_backward_int8(const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
static void convert_float_to_int8(const float *src, int8_t *dst, size_t count)