← Back to C-Kernel-Engine Docs | Doxygen Source Documentation
rmsnorm_kernels_int8.c File Reference

RMSNorm kernels with INT8 output quantization. More...

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "ckernel_engine.h"

Go to the source code of this file.

Functions

static int8_t clamp_int8 (float value)
 
static void convert_float_to_int8 (const float *src, int8_t *dst, size_t count)
 
static void convert_int8_to_float (const int8_t *src, float *dst, size_t count)
 
void rmsnorm_backward_int8 (const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
 
void rmsnorm_forward_int8 (const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
 

Detailed Description

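RMSNorm kernels with INT8 inputs and INT8 output quantization. Each public wrapper widens its INT8 operands into caller-provided float scratch buffers, calls the corresponding float kernel (rmsnorm_forward() or rmsnorm_backward()), and then rounds and saturates the float result back to INT8. No scale factor is applied in either direction: INT8 values are cast directly to float, and float results are rounded with lrintf() and clamped to [INT8_MIN, INT8_MAX].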

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Definition in file rmsnorm_kernels_int8.c.

Function Documentation

◆ clamp_int8()

static int8_t clamp_int8 ( float  value)
static

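Rounds value to an integer with lrintf() and saturates the result to the range [INT8_MIN, INT8_MAX] before narrowing it to int8_t.

Definition at line 34 of file rmsnorm_kernels_int8.c.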

35 {
36  int32_t q = (int32_t)lrintf(value);
37  if (q > INT8_MAX) {
38  q = INT8_MAX;
39  } else if (q < INT8_MIN) {
40  q = INT8_MIN;
41  }
42  return (int8_t)q;
43 }

Referenced by convert_float_to_int8().

◆ convert_float_to_int8()

static void convert_float_to_int8 ( const float *  src,
int8_t *  dst,
size_t  count 
)
static

Definition at line 45 of file rmsnorm_kernels_int8.c.

48 {
49  for (size_t i = 0; i < count; ++i) {
50  dst[i] = clamp_int8(src[i]);
51  }
52 }
static int8_t clamp_int8(float value)

References clamp_int8().

Referenced by rmsnorm_backward_int8(), and rmsnorm_forward_int8().

◆ convert_int8_to_float()

static void convert_int8_to_float ( const int8_t *  src,
float *  dst,
size_t  count 
)
static

Definition at line 25 of file rmsnorm_kernels_int8.c.

28 {
29  for (size_t i = 0; i < count; ++i) {
30  dst[i] = (float)src[i];
31  }
32 }

Referenced by rmsnorm_backward_int8(), and rmsnorm_forward_int8().

◆ rmsnorm_backward_int8()

void rmsnorm_backward_int8 ( const int8_t *  d_output,
const int8_t *  input,
const float *  gamma,
const float *  rstd_cache,
int8_t *  d_input,
float *  d_gamma,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float *  scratch_d_output,
float *  scratch_input,
float *  scratch_d_input 
)

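INT8 wrapper around rmsnorm_backward(). Converts d_output and input (tokens * aligned_embed_dim elements each) to float in scratch_d_output and scratch_input, zeroes the first d_model entries of d_gamma, calls rmsnorm_backward() to write scratch_d_input and accumulate d_gamma, and finally rounds and saturates scratch_d_input into d_input. Each scratch buffer must hold at least tokens * aligned_embed_dim floats. The function returns without doing anything if any pointer argument is NULL.

Definition at line 84 of file rmsnorm_kernels_int8.c.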

96 {
97  if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) return;
98  if (!scratch_d_output || !scratch_input || !scratch_d_input) return;
99 
100  size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
101 
102  convert_int8_to_float(d_output, scratch_d_output, total);
103  convert_int8_to_float(input, scratch_input, total);
104 
105  // Zero gamma gradient before accumulation.
106  for (int d = 0; d < d_model; ++d) {
107  d_gamma[d] = 0.0f;
108  }
109 
110  rmsnorm_backward(scratch_d_output,
111  scratch_input,
112  gamma,
113  rstd_cache,
114  scratch_d_input,
115  d_gamma,
116  tokens,
117  d_model,
118  aligned_embed_dim);
119 
120  convert_float_to_int8(scratch_d_input, d_input, total);
121 }
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static void convert_int8_to_float(const int8_t *src, float *dst, size_t count)
static void convert_float_to_int8(const float *src, int8_t *dst, size_t count)

References convert_float_to_int8(), convert_int8_to_float(), and rmsnorm_backward().

◆ rmsnorm_forward_int8()

void rmsnorm_forward_int8 ( const int8_t *  input,
const float *  gamma,
int8_t *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps,
float *  scratch_input,
float *  scratch_output 
)

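INT8 wrapper around rmsnorm_forward(). Converts input (tokens * aligned_embed_dim elements) to float in scratch_input, calls rmsnorm_forward() to write scratch_output and rstd_cache, and then rounds and saturates scratch_output into output. Each scratch buffer must hold at least tokens * aligned_embed_dim floats. The function returns without doing anything if input, gamma, output, or either scratch buffer is NULL; rstd_cache is not checked here and is passed to rmsnorm_forward() unchanged.

Definition at line 58 of file rmsnorm_kernels_int8.c.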

68 {
69  if (!input || !gamma || !output) return;
70  if (!scratch_input || !scratch_output) return;
71 
72  size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
73 
74  convert_int8_to_float(input, scratch_input, total);
75  rmsnorm_forward(scratch_input, gamma, scratch_output, rstd_cache,
76  tokens, d_model, aligned_embed_dim, eps);
77  convert_float_to_int8(scratch_output, output, total);
78 }
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

References convert_float_to_int8(), convert_int8_to_float(), and rmsnorm_forward().