RMSNorm kernels with INT4 output quantization.
Go to the source code of this file.
Functions | |
| static void | convert_float_to_int4 (const float *src, uint8_t *dst, size_t count) |
| static void | convert_int4_to_float (const uint8_t *src, float *dst, size_t count) |
| static int8_t | decode_int4 (uint8_t packed, int index) |
| static uint8_t | encode_int4_nibble (int8_t value) |
| void | rmsnorm_backward_int4 (const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input) |
| void | rmsnorm_forward_int4 (const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output) |
RMSNorm kernels with INT4 output quantization.
After making changes, run: `make test && make llamacpp-parity-full`
Definition in file rmsnorm_kernels_int4.c.
static void convert_float_to_int4(const float *src, uint8_t *dst, size_t count)   [static]
Definition at line 55 of file rmsnorm_kernels_int4.c.
References encode_int4_nibble().
Referenced by rmsnorm_backward_int4(), and rmsnorm_forward_int4().
static void convert_int4_to_float(const uint8_t *src, float *dst, size_t count)   [static]
Definition at line 45 of file rmsnorm_kernels_int4.c.
References decode_int4().
Referenced by rmsnorm_backward_int4(), and rmsnorm_forward_int4().
static int8_t decode_int4(uint8_t packed, int index)   [inline, static]
static uint8_t encode_int4_nibble(int8_t value)   [inline, static]
void rmsnorm_backward_int4(const uint8_t *d_output,
                           const uint8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output,
                           float *scratch_input,
                           float *scratch_d_input)
Definition at line 104 of file rmsnorm_kernels_int4.c.
References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_backward().
void rmsnorm_forward_int4(const uint8_t *input,
                          const float *gamma,
                          uint8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,
                          float *scratch_output)
Definition at line 78 of file rmsnorm_kernels_int4.c.
References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_forward().