MLP (feed-forward) kernels with SIMD (SSE/AVX/AVX512).
Go to the source code of this file.
Functions

void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)

void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)

void mlp_token_parallel(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)

void mlp_token_parallel_exact(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
MLP (feed-forward) kernels with SIMD (SSE/AVX/AVX512)
After changes: make test && make llamacpp-parity-full
LEGACY EXCEPTION: This file contains OpenMP for backward compatibility. New kernels should NOT use OpenMP internally.
MLP: out = FC2(GELU(FC1(x)))
Definition in file mlp_kernels.c.
void fc1_backward_kernel(const float *d_output,
                         const float *fc1_input,
                         const float *W_fc1,
                         float *d_input,
                         float *d_W_fc1,
                         float *d_b_fc1,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads)
Definition at line 167 of file mlp_kernels.c.
References gemm_nn_avx512(), and gemm_tn_avx512().
Referenced by ck_layer_backward_rmsnorm_swiglu().
void fc2_backward_kernel(const float *d_output,
                         const float *fc2_input,
                         const float *W_fc2,
                         float *d_input,
                         float *d_W_fc2,
                         float *d_b_fc2,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads)
Definition at line 118 of file mlp_kernels.c.
References gemm_nn_avx512(), and gemm_tn_avx512().
Referenced by ck_attention_project_head_major_backward(), ck_layer_backward_rmsnorm_swiglu(), and ck_qkv_project_head_major_backward().
void mlp_token_parallel(const float *input,
                        const float *W_fc1,
                        const float *b_fc1,
                        const float *W_fc2,
                        const float *b_fc2,
                        float *fc1_output,
                        float *output,
                        int T,
                        int aligned_dim,
                        int num_threads)
Definition at line 41 of file mlp_kernels.c.
References gelu_fast_inplace(), and gemm_blocked_serial().
void mlp_token_parallel_exact(const float *input,
                              const float *W_fc1,
                              const float *b_fc1,
                              const float *W_fc2,
                              const float *b_fc2,
                              float *fc1_output,
                              float *output,
                              int T,
                              int aligned_dim,
                              int num_threads)
Definition at line 76 of file mlp_kernels.c.
References gelu_exact_inplace(), and gemm_blocked_serial().