24 if ((index & 1) == 0) {
25 nibble = packed & 0x0F;
27 nibble = (packed >> 4) & 0x0F;
39 }
else if (value < -8) {
42 return (uint8_t)(value & 0x0F);
49 for (
size_t i = 0; i < count; ++i) {
50 uint8_t packed = src[i >> 1];
59 size_t bytes = (count + 1) / 2;
60 for (
size_t i = 0; i < bytes; ++i) {
63 for (
size_t i = 0; i < count; ++i) {
65 size_t byte_idx = i >> 1;
67 dst[byte_idx] = (dst[byte_idx] & 0xF0) | quant;
69 dst[byte_idx] = (dst[byte_idx] & 0x0F) | (quant << 4);
84 int aligned_embed_dim,
87 float *scratch_output)
89 if (!input || !gamma || !output)
return;
90 if (!scratch_input || !scratch_output)
return;
92 size_t total = (size_t)tokens * (
size_t)aligned_embed_dim;
96 tokens, d_model, aligned_embed_dim, eps);
105 const uint8_t *input,
107 const float *rstd_cache,
112 int aligned_embed_dim,
113 float *scratch_d_output,
114 float *scratch_input,
115 float *scratch_d_input)
117 if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma)
return;
118 if (!scratch_d_output || !scratch_input || !scratch_d_input)
return;
120 size_t total = (size_t)tokens * (
size_t)aligned_embed_dim;
125 for (
int d = 0; d < d_model; ++d) {
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static void convert_int4_to_float(const uint8_t *src, float *dst, size_t count)
void rmsnorm_backward_int4(const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
static int8_t decode_int4(uint8_t packed, int index)
static void convert_float_to_int4(const float *src, uint8_t *dst, size_t count)
static uint8_t encode_int4_nibble(int8_t value)
void rmsnorm_forward_int4(const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)