RMSNorm forward/backward kernels with SIMD (SSE/AVX/AVX512) More...
#include <math.h>
#include <stddef.h>
Go to the source code of this file.
Functions | |
| void | rmsnorm_backward (const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim) |
| void | rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps) |
RMSNorm forward/backward kernels with SIMD (SSE/AVX/AVX512)
After changes: make test && make llamacpp-parity-full
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
Definition in file rmsnorm_kernels.c.
| void rmsnorm_backward | ( | const float * | d_output, |
| const float * | input, | ||
| const float * | gamma, | ||
| const float * | rstd_cache, | ||
| float * | d_input, | ||
| float * | d_gamma, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim | ||
| ) |
RMSNorm backward pass
test_rmsnorm.py::TestRMSNormBackward::test_backward_tokens
test_rmsnorm.py::TestRMSNormBackward::test_backward_single
test_parity.py::test_rmsnorm_backward_parity
Computes dX and dGamma given dY, X, gamma, and cached rstd. Using x_hat_i = x_i * rstd: dX_i = rstd * (dY_i * gamma_i - x_hat_i * m), where m = (1/d_model) * sum_j (dY_j * gamma_j * x_hat_j) is the per-token mean of the projected gradient. dGamma_i = sum over tokens t of (dY_{t,i} * x_hat_{t,i}), accumulated across the whole batch.
After changes: make test && make llamacpp-parity-full
Definition at line 184 of file rmsnorm_kernels.c.
Referenced by ck_layer_backward_rmsnorm_swiglu(), rmsnorm_backward_int4(), and rmsnorm_backward_int8().
| void rmsnorm_forward | ( | const float * | input, |
| const float * | gamma, | ||
| float * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
RMSNorm forward pass
test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens
test_rmsnorm.py::TestRMSNormForward::test_fp32_single
test_rmsnorm.py::TestRMSNormForward::test_perf_rolled
test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat
test_parity.py::test_rmsnorm_parity
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
After changes: make test && make llamacpp-parity-full
Definition at line 50 of file rmsnorm_kernels.c.
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), ck_test_rmsnorm(), mega_fused_attention_decode_q5_0(), mega_fused_attention_decode_q5_0_parallel_simd(), mega_fused_outproj_mlp_prefill(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qk_norm_forward(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), 
qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), qwen2_0_5b_decode_layer_9_prefill(), rmsnorm_forward_int4(), and rmsnorm_forward_int8().