Optimized BF16 MLP Kernels. More...
#include <stddef.h>
#include <stdint.h>
#include <math.h>
#include "bf16_utils.h"
#include "ckernel_engine.h"

Go to the source code of this file.
Functions | |
static float gelu_scalar(float x)

void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)

void mlp_token_parallel_bf16(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)

void mlp_token_parallel_bf16_fp32act(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
Optimized BF16 MLP Kernels.
After changes: make test && make llamacpp-parity-full
Uses direct BF16 GEMM instead of converting to FP32. Layout: input[T,D] -> fc1[T,4D] -> GELU -> fc2[T,D]
All functions use caller-provided scratch buffers (no internal malloc).
Definition in file mlp_kernels_bf16.c.
static float gelu_scalar(float x)   [inline, static]
Definition at line 45 of file mlp_kernels_bf16.c.
Referenced by mlp_token_parallel_bf16(), and mlp_token_parallel_bf16_fp32act().
void gemm_bf16_fp32out(const uint16_t *A,
                       const uint16_t *B,
                       const float *bias,
                       float *C,
                       int M,
                       int N,
                       int K)
Definition at line 301 of file gemm_kernels_bf16.c.
References bf16_to_float(), and C.
Referenced by mlp_token_parallel_bf16(), and mlp_token_parallel_bf16_fp32act().
void mlp_token_parallel_bf16(const uint16_t *input,
                             const uint16_t *W_fc1,
                             const uint16_t *b_fc1,
                             const uint16_t *W_fc2,
                             const uint16_t *b_fc2,
                             float *fc1_output,
                             float *output,
                             int T,
                             int aligned_dim,
                             int num_threads,
                             float *scratch_bias1_f,
                             float *scratch_bias2_f,
                             uint16_t *scratch_fc1_bf16)
Optimized MLP Forward (BF16 weights, FP32 activations)
Caller-provided scratch buffers:
  scratch_bias1_f:  [4*D] floats
  scratch_bias2_f:  [D] floats
  scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)
Definition at line 91 of file mlp_kernels_bf16.c.
References bf16_to_float(), float_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().
void mlp_token_parallel_bf16_fp32act(const uint16_t *input,
                                     const uint16_t *W_fc1,
                                     const uint16_t *b_fc1,
                                     const uint16_t *W_fc2,
                                     const uint16_t *b_fc2,
                                     float *fc1_output,
                                     float *output,
                                     int T,
                                     int aligned_dim,
                                     int num_threads,
                                     float *scratch_input_f,
                                     float *scratch_bias1_f,
                                     float *scratch_bias2_f,
                                     uint16_t *scratch_fc1_bf16)
Alternative: Fully FP32 activations throughout
Caller-provided scratch buffers:
  scratch_input_f:  [T * D] floats
  scratch_bias1_f:  [4*D] floats
  scratch_bias2_f:  [D] floats
  scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)
Definition at line 186 of file mlp_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().