23 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
24 #include <immintrin.h>
#if defined(__AVX__) && !defined(__AVX512F__)
/* Horizontal sum of the eight float lanes of an AVX vector.
 * Folds 256 -> 128 bits with extract+add, then reduces the 128-bit
 * half with the standard movehdup/movehl shuffle sequence. */
static inline float ck_hsum256_ps(__m256 v)
{
    __m128 half_lo = _mm256_castps256_ps128(v);   /* lanes 0..3 */
    __m128 half_hi = _mm256_extractf128_ps(v, 1); /* lanes 4..7 */
    __m128 quad    = _mm_add_ps(half_lo, half_hi);/* 8 -> 4 partials */
    __m128 dup_odd = _mm_movehdup_ps(quad);       /* duplicate odd lanes */
    __m128 pair    = _mm_add_ps(quad, dup_odd);   /* 4 -> 2 partials */
    dup_odd        = _mm_movehl_ps(dup_odd, pair);/* move high pair down */
    pair           = _mm_add_ss(pair, dup_odd);   /* 2 -> 1 */
    return _mm_cvtss_f32(pair);
}
#endif
/* Dense dot product of two float vectors of length `len`.
 * Dispatches at compile time to the widest SIMD path available
 * (AVX-512 FMA, then AVX mul+add, then plain scalar); each SIMD
 * path finishes the remainder with a scalar tail loop.
 * `a` and `b` need not be aligned (unaligned loads are used). */
static inline float ck_dot_f32(const float *a, const float *b, int len)
{
#if defined(__AVX512F__)
    __m512 vacc = _mm512_setzero_ps();
    int idx = 0;
    /* Main body: 16 floats per iteration, fused multiply-add. */
    for (; idx <= len - 16; idx += 16) {
        __m512 x = _mm512_loadu_ps(a + idx);
        __m512 y = _mm512_loadu_ps(b + idx);
        vacc = _mm512_fmadd_ps(x, y, vacc);
    }
    float total = _mm512_reduce_add_ps(vacc);
    /* Scalar tail for the last len % 16 elements. */
    for (; idx < len; ++idx) {
        total += a[idx] * b[idx];
    }
    return total;
#elif defined(__AVX__)
    __m256 vacc = _mm256_setzero_ps();
    int idx = 0;
    /* Main body: 8 floats per iteration (separate mul + add; plain
     * AVX has no guaranteed FMA). */
    for (; idx <= len - 8; idx += 8) {
        __m256 x = _mm256_loadu_ps(a + idx);
        __m256 y = _mm256_loadu_ps(b + idx);
        vacc = _mm256_add_ps(vacc, _mm256_mul_ps(x, y));
    }
    float total = ck_hsum256_ps(vacc);
    /* Scalar tail for the last len % 8 elements. */
    for (; idx < len; ++idx) {
        total += a[idx] * b[idx];
    }
    return total;
#else
    /* Portable fallback when no AVX target is enabled. */
    float total = 0.0f;
    for (int idx = 0; idx < len; ++idx) {
        total += a[idx] * b[idx];
    }
    return total;
#endif
}
/* ck_qkv_project_head_major_token — project one token's embedding row
 * into Q, K and V buffers laid out head-major (each head owns a
 * contiguous slice of aligned_head_dim floats).
 * NOTE(review): this chunk is fragmentary — the matmul call sites at
 * original lines 96/109/111 have lost their callee name; the argument
 * shapes (1, N, aligned_embed_dim) suggest a one-row GEMM, presumably
 * gemm_blocked_serial — confirm against the full file. */
79 const float *wq,
const float *bq,
80 const float *wk,
const float *bk,
81 const float *wv,
const float *bv,
85 int aligned_embed_dim,
/* Required pointers only; the bias pointers bq/bk/bv are optional
 * (see the NULL-conditional bk/bv slices below). */
90 if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
/* Q is projected for all query heads in a single call:
 * num_heads * aligned_head_dim output columns from one input row. */
94 const int q_out = num_heads * aligned_head_dim;
96 1, q_out, aligned_embed_dim);
/* Per-KV-head weight block: one [aligned_head_dim x aligned_embed_dim]
 * matrix, computed in size_t to avoid int overflow on large models. */
98 size_t head_weight_stride = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
/* K/V are projected head by head; only spawn OpenMP threads when there
 * is more than one KV head, so single-head (MQA) decode stays serial. */
99 #pragma omp parallel for schedule(static) if(num_kv_heads > 1)
100 for (
int h = 0; h < num_kv_heads; ++h) {
101 const float *wk_h = wk + (size_t)h * head_weight_stride;
102 const float *wv_h = wv + (size_t)h * head_weight_stride;
/* Optional per-head bias slices; NULL is forwarded when no bias. */
103 const float *bk_h = bk ? (bk + (size_t)h * (
size_t)aligned_head_dim) : NULL;
104 const float *bv_h = bv ? (bv + (size_t)h * (
size_t)aligned_head_dim) : NULL;
/* Head-major destination slices inside k_token / v_token. */
105 float *k_h = k_token + (size_t)h * (
size_t)aligned_head_dim;
106 float *v_h = v_token + (size_t)h * (
size_t)aligned_head_dim;
/* One-row projection of this head's K, then V (callee name missing
 * from this view of the file). */
109 1, aligned_head_dim, aligned_embed_dim);
111 1, aligned_head_dim, aligned_embed_dim);
/* ck_attention_project_head_major_decode_token — output projection for a
 * single decoded token: out[j] = bo[j] + sum over heads of
 * dot(attn_token[head], wo[head][j]).  attn_token is head-major
 * (aligned_head_dim floats per head); wo stores, per head, one
 * aligned_head_dim-long row per output column j.
 * NOTE(review): the write of the computed sum (orig line ~134) and the
 * padding-zero body are missing from this view — confirm in full file. */
120 int aligned_embed_dim,
122 int aligned_head_dim)
/* Strides in size_t so the h*stride products cannot overflow int. */
124 const size_t head_in_stride = (size_t)aligned_head_dim;
125 const size_t head_weight_stride = (size_t)aligned_embed_dim * (
size_t)aligned_head_dim;
/* Each output column j is independent — parallelize over j. */
127 #pragma omp parallel for schedule(static)
128 for (
int j = 0; j < embed_dim; ++j) {
/* Bias is optional; start the accumulator at bo[j] or 0. */
129 float sum = bo ? bo[j] : 0.0f;
130 for (
int h = 0; h < num_heads; ++h) {
131 const float *head_in = attn_token + (size_t)h * head_in_stride;
/* Row of head h's output-projection weights for column j. */
132 const float *wo_row = wo + (size_t)h * head_weight_stride + (
size_t)j * (size_t)aligned_head_dim;
133 sum +=
ck_dot_f32(head_in, wo_row, aligned_head_dim);
/* Zero the alignment padding columns [embed_dim, aligned_embed_dim). */
138 for (
int j = embed_dim; j < aligned_embed_dim; ++j) {
/* ck_attention_project_head_major_decode_token_residual — same output
 * projection as the non-residual variant, fused with the residual add:
 * residual_out[j] = proj(attn_token)[j] + residual_in[j].
 * NOTE(review): the signature (L138 dump) also names a proj_out buffer;
 * its write is not visible in this fragmentary view — confirm whether
 * the projection is stored separately as well. */
148 const float *residual_in,
152 int aligned_embed_dim,
154 int aligned_head_dim)
/* Strides in size_t to keep h*stride products out of int overflow. */
156 const size_t head_in_stride = (size_t)aligned_head_dim;
157 const size_t head_weight_stride = (size_t)aligned_embed_dim * (
size_t)aligned_head_dim;
/* Output columns are independent — parallelize over j. */
159 #pragma omp parallel for schedule(static)
160 for (
int j = 0; j < embed_dim; ++j) {
/* Optional bias seeds the accumulator. */
161 float sum = bo ? bo[j] : 0.0f;
162 for (
int h = 0; h < num_heads; ++h) {
163 const float *head_in = attn_token + (size_t)h * head_in_stride;
/* Head h's projection weights for output column j. */
164 const float *wo_row = wo + (size_t)h * head_weight_stride + (
size_t)j * (size_t)aligned_head_dim;
165 sum +=
ck_dot_f32(head_in, wo_row, aligned_head_dim);
/* Fused residual add into the output row. */
170 residual_out[j] = sum + residual_in[j];
/* Alignment padding columns are forced to zero. */
173 for (
int j = embed_dim; j < aligned_embed_dim; ++j) {
177 residual_out[j] = 0.0f;
/* ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl — single-token
 * decode step for one transformer layer (RMSNorm -> fused attention ->
 * RMSNorm -> SwiGLU MLP), per the L135 signature.  Large parts of the
 * body are missing from this view; comments below cover only what is
 * visible. */
/* When the MLP is not fused, an external fc1 scratch buffer is required. */
194 if (!fuse_mlp && !p->
fc1_out) {
/* Reject out-of-range KV-cache slots before touching any buffer. */
197 if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
/* Decode processes exactly one token; all per-token rows live at slot 0
 * of their respective buffers. */
213 const size_t token_slot = 0;
214 const float *input_row = p->
input + token_slot * (size_t)aligned_D;
215 float *proj_row = NULL;
216 float *residual_row = p->
residual1 + token_slot * (size_t)aligned_D;
217 float *ln2_row = p->
ln2_out + token_slot * (size_t)aligned_D;
218 float *swiglu_row = p->
swiglu_out + token_slot * (size_t)aligned_intermediate;
219 float *mlp_row = p->
mlp_out + token_slot * (size_t)aligned_D;
220 float *out_row = p->
output + token_slot * (size_t)aligned_D;
/* Per-call rstd scratch for the two RMSNorm applications. */
222 float ln1_rstd_tmp = 0.0f;
223 float ln2_rstd_tmp = 0.0f;
/* NOTE(review): VLAs below (ln1_row, q_token, k_token, v_token,
 * attn_token) put model-dimension-sized arrays on the stack — risk of
 * stack overflow for large aligned_D / head counts; consider heap or
 * caller-provided scratch. */
226 float ln1_row[aligned_D];
/* Element counts computed in size_t: H query heads and H_kv KV heads,
 * each ad (aligned_head_dim) floats wide. */
237 size_t q_elems = (size_t)H * (
size_t)ad;
238 size_t kv_elems = (size_t)H_kv * (
size_t)ad;
239 float q_token[q_elems];
240 float k_token[kv_elems];
241 float v_token[kv_elems];
242 float attn_token[q_elems];
/* QKV projection of the normalized row into the stack buffers. */
248 q_token, k_token, v_token,
318 aligned_intermediate);
/* Non-fused MLP path: fc1 produces gate+up halves (2x intermediate). */
320 int up_dim = 2 * aligned_intermediate;
321 float *fc1_row = p->
fc1_out + token_slot * (size_t)up_dim;
333 aligned_intermediate);
static void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(const CKLayerForwardParams *p, int token_index, int cache_capacity, int fuse_mlp)
static float ck_dot_f32(const float *a, const float *b, int len)
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(const CKLayerForwardParams *p, int token_index, int cache_capacity)
static void ck_attention_project_head_major_decode_token_residual(const float *attn_token, const float *wo, const float *bo, const float *residual_in, float *proj_out, float *residual_out, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
int aligned_intermediate_dim