← Back to C-Kernel-Engine Docs Doxygen Source Documentation
layernorm_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file layernorm_kernels_bf16.c
3  * @brief LayerNorm kernels for BF16 tensors
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * LayerNorm: y = gamma * (x - mean) / sqrt(var + eps) + beta
15  */
16 
17 #include <stdint.h>
18 
19 #include "bf16_utils.h"
20 #include "ckernel_engine.h"
21 
22 /* Suppress false positive warnings about uninitialized variables */
23 #pragma GCC diagnostic push
24 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
25 
/*
 * BF16 LayerNorm forward (rolled variant) with caller-provided FP32 scratch.
 *
 * Steps performed here:
 *   1. widen the BF16 input into scratch_input (bf16_tensor_to_float),
 *   2. run layernorm_forward_rolled_slice() on the FP32 data,
 *   3. narrow scratch_output back to BF16 into output_slice_base
 *      (float_tensor_to_bf16).
 *
 * Parameters:
 *   input_slice_base    BF16 input; num_tokens_in_slice * aligned_embed_dim
 *                       elements are read.
 *   gamma, beta         FP32 parameters, forwarded unchanged to the FP32 kernel.
 *   output_slice_base   BF16 output; num_tokens_in_slice * aligned_embed_dim
 *                       elements are written.
 *   mean_cache_slice,
 *   rstd_cache_slice    forwarded unchanged to the FP32 kernel (presumably the
 *                       per-token mean / reciprocal std-dev caches - confirm
 *                       against layernorm_forward_rolled_slice).
 *   num_tokens_in_slice,
 *   d_model,
 *   aligned_embed_dim,
 *   eps                 forwarded unchanged to the FP32 kernel.
 *   scratch_input,
 *   scratch_output      each [num_tokens_in_slice * aligned_embed_dim] floats.
 *
 * The function returns without touching any buffer when either scratch
 * pointer is NULL, or when num_tokens_in_slice or aligned_embed_dim is
 * negative.
 */
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base,
                                         const float *__restrict gamma,
                                         const float *__restrict beta,
                                         uint16_t *__restrict output_slice_base,
                                         float *__restrict mean_cache_slice,
                                         float *__restrict rstd_cache_slice,
                                         int num_tokens_in_slice,
                                         int d_model,
                                         int aligned_embed_dim,
                                         float eps,
                                         float *scratch_input,
                                         float *scratch_output)
{
    if (!scratch_input || !scratch_output) {
        return;
    }

    /* A negative int converted to size_t wraps to an enormous value; the
     * element count below would then send both conversions far out of
     * bounds. Reject such extents before the cast. */
    if (num_tokens_in_slice < 0 || aligned_embed_dim < 0) {
        return;
    }

    const size_t total = (size_t)num_tokens_in_slice * (size_t)aligned_embed_dim;

    bf16_tensor_to_float(input_slice_base, scratch_input, total);
    layernorm_forward_rolled_slice(scratch_input, gamma, beta,
                                   scratch_output, mean_cache_slice, rstd_cache_slice,
                                   num_tokens_in_slice, d_model, aligned_embed_dim, eps);
    float_tensor_to_bf16(scratch_output, output_slice_base, total);
}
53 
/*
 * BF16 LayerNorm forward (unrolled variant) with caller-provided FP32 scratch.
 *
 * Steps performed here:
 *   1. widen the BF16 input into scratch_input (bf16_tensor_to_float),
 *   2. run layernorm_forward_unrolled_slice() on the FP32 data,
 *   3. narrow scratch_output back to BF16 into output_slice_base
 *      (float_tensor_to_bf16).
 *
 * Unlike the rolled variant, this entry point takes no aligned_embed_dim:
 * the element count is num_tokens_in_slice * d_model.
 *
 * Parameters:
 *   input_slice_base    BF16 input; num_tokens_in_slice * d_model elements
 *                       are read.
 *   gamma, beta         FP32 parameters, forwarded unchanged to the FP32 kernel.
 *   output_slice_base   BF16 output; num_tokens_in_slice * d_model elements
 *                       are written.
 *   mean_cache_slice,
 *   rstd_cache_slice    forwarded unchanged to the FP32 kernel (presumably the
 *                       per-token mean / reciprocal std-dev caches - confirm
 *                       against layernorm_forward_unrolled_slice).
 *   num_tokens_in_slice,
 *   d_model, eps        forwarded unchanged to the FP32 kernel.
 *   scratch_input,
 *   scratch_output      must each hold at least
 *                       num_tokens_in_slice * d_model floats.
 *
 * The function returns without touching any buffer when either scratch
 * pointer is NULL, or when num_tokens_in_slice or d_model is negative.
 */
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base,
                                           const float *__restrict gamma,
                                           const float *__restrict beta,
                                           uint16_t *__restrict output_slice_base,
                                           float *__restrict mean_cache_slice,
                                           float *__restrict rstd_cache_slice,
                                           int num_tokens_in_slice,
                                           int d_model,
                                           float eps,
                                           float *scratch_input,
                                           float *scratch_output)
{
    if (!scratch_input || !scratch_output) {
        return;
    }

    /* A negative int converted to size_t wraps to an enormous value; the
     * element count below would then send both conversions far out of
     * bounds. Reject such extents before the cast. */
    if (num_tokens_in_slice < 0 || d_model < 0) {
        return;
    }

    const size_t total = (size_t)num_tokens_in_slice * (size_t)d_model;

    bf16_tensor_to_float(input_slice_base, scratch_input, total);
    layernorm_forward_unrolled_slice(scratch_input, gamma, beta,
                                     scratch_output, mean_cache_slice, rstd_cache_slice,
                                     num_tokens_in_slice, d_model, eps);
    float_tensor_to_bf16(scratch_output, output_slice_base, total);
}
79 
/*
 * BF16 LayerNorm backward with caller-provided FP32 scratch.
 *
 * Steps performed here:
 *   1. widen BF16 d_output and input into scratch_d_output / scratch_input
 *      (bf16_tensor_to_float),
 *   2. run layernorm_backward_kernel() on the FP32 data,
 *   3. narrow scratch_d_input back to BF16 into d_input
 *      (float_tensor_to_bf16).
 *
 * Parameters:
 *   d_output, input     BF16 tensors; tokens * aligned_embed_dim elements are
 *                       read from each.
 *   gamma, mean, rstd   FP32 arrays, forwarded unchanged to the FP32 kernel.
 *   d_input             BF16 output; tokens * aligned_embed_dim elements are
 *                       written.
 *   d_gamma, d_beta     FP32 arrays, forwarded unchanged to the FP32 kernel
 *                       (whether they are overwritten or accumulated into is
 *                       decided by layernorm_backward_kernel - confirm there).
 *   tokens, d_model,
 *   aligned_embed_dim   forwarded unchanged to the FP32 kernel.
 *   scratch_d_output,
 *   scratch_input,
 *   scratch_d_input     each [tokens * aligned_embed_dim] floats.
 *
 * The function returns without touching any buffer when any scratch pointer
 * is NULL, or when tokens or aligned_embed_dim is negative.
 */
void layernorm_backward_kernel_bf16(const uint16_t *d_output,
                                    const uint16_t *input,
                                    const float *gamma,
                                    const float *mean,
                                    const float *rstd,
                                    uint16_t *d_input,
                                    float *d_gamma,
                                    float *d_beta,
                                    int tokens, int d_model, int aligned_embed_dim,
                                    float *scratch_d_output,
                                    float *scratch_input,
                                    float *scratch_d_input)
{
    if (!scratch_d_output || !scratch_input || !scratch_d_input) {
        return;
    }

    /* A negative int converted to size_t wraps to an enormous value; the
     * element count below would then send all three conversions far out of
     * bounds. Reject such extents before the cast. */
    if (tokens < 0 || aligned_embed_dim < 0) {
        return;
    }

    const size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    bf16_tensor_to_float(d_output, scratch_d_output, total);
    bf16_tensor_to_float(input, scratch_input, total);

    layernorm_backward_kernel(scratch_d_output, scratch_input, gamma, mean, rstd,
                              scratch_d_input, d_gamma, d_beta,
                              tokens, d_model, aligned_embed_dim);

    float_tensor_to_bf16(scratch_d_input, d_input, total);
}
110 
111 #pragma GCC diagnostic pop
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
void layernorm_backward_kernel(const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim)
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
void layernorm_backward_kernel_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *mean, const float *rstd, uint16_t *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float *scratch_input, float *scratch_output)
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)