23 #pragma GCC diagnostic push
24 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
31 const float *__restrict gamma,
32 const float *__restrict beta,
33 uint16_t *__restrict output_slice_base,
34 float *__restrict mean_cache_slice,
35 float *__restrict rstd_cache_slice,
36 int num_tokens_in_slice,
38 int aligned_embed_dim,
41 float *scratch_output)
43 if (!scratch_input || !scratch_output)
return;
45 size_t total = (size_t)num_tokens_in_slice * (
size_t)aligned_embed_dim;
49 scratch_output, mean_cache_slice, rstd_cache_slice,
50 num_tokens_in_slice, d_model, aligned_embed_dim, eps);
58 const float *__restrict gamma,
59 const float *__restrict beta,
60 uint16_t *__restrict output_slice_base,
61 float *__restrict mean_cache_slice,
62 float *__restrict rstd_cache_slice,
63 int num_tokens_in_slice,
67 float *scratch_output)
69 if (!scratch_input || !scratch_output)
return;
71 size_t total = (size_t)num_tokens_in_slice * (
size_t)d_model;
75 scratch_output, mean_cache_slice, rstd_cache_slice,
76 num_tokens_in_slice, d_model, eps);
85 const uint16_t *input,
92 int tokens,
int d_model,
int aligned_embed_dim,
93 float *scratch_d_output,
95 float *scratch_d_input)
97 if (!scratch_d_output || !scratch_input || !scratch_d_input)
return;
99 size_t total = (size_t)tokens * (
size_t)aligned_embed_dim;
105 scratch_d_input, d_gamma, d_beta,
106 tokens, d_model, aligned_embed_dim);
111 #pragma GCC diagnostic pop
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void layernorm_backward_kernel(const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim)
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
void layernorm_backward_kernel_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *mean, const float *rstd, uint16_t *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float *scratch_input, float *scratch_output)
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)