← Back to C-Kernel-Engine Docs Doxygen Source Documentation
v6.6/test_generated/ck-kernel-inference.c File Reference

AUTO-GENERATED: qwen2_0.5b_decode Implementation (IR v6 - Explicit Unrolled) More...

#include "ck-kernel-inference.h"
#include "ckernel_engine.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <math.h>

Go to the source code of this file.

Macros

#define _GNU_SOURCE   /* For MAP_ANONYMOUS, MAP_HUGETLB */
 

Functions

struct __attribute__ ((packed))
 
 _Static_assert (sizeof(MagicHeader)==64, "MagicHeader must be 64 bytes")
 
static int qwen2_0_5b_decode_align_elems (int elems, int elem_bytes, int align_bytes)
 
void qwen2_0_5b_decode_decode (QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
 
static void qwen2_0_5b_decode_decode_token (QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
 
void qwen2_0_5b_decode_forward (QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
 
static void qwen2_0_5b_decode_forward_prefill_impl (QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
 
static void qwen2_0_5b_decode_layer_0_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_0_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_10_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_10_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_11_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_11_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_12_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_12_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_13_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_13_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_14_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_14_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_15_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_15_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_16_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_16_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_17_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_17_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_18_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_18_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_19_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_19_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_1_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_1_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_20_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_20_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_21_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_21_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_22_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_22_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_23_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_23_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_2_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_2_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_3_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_3_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_4_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_4_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_5_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_5_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_6_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_6_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_7_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_7_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_8_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_8_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_9_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_9_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
int qwen2_0_5b_decode_model_allocate (QWEN2_0_5B_DECODEModel *model)
 
void qwen2_0_5b_decode_model_free (QWEN2_0_5B_DECODEModel *model)
 
void qwen2_0_5b_decode_precompute_rope (QWEN2_0_5B_DECODEModel *model)
 
static void qwen2_0_5b_decode_residual_add_token_major (const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
 
int qwen2_0_5b_decode_verify_canaries (QWEN2_0_5B_DECODEModel *model)
 

Variables

 MagicHeader
 

Detailed Description

AUTO-GENERATED: qwen2_0.5b_decode Implementation (IR v6 - Explicit Unrolled)

Generated: 2026-01-12T04:06:36.662558 UTC Total Memory: 3.57 GB Mode: decode Layers: 24 (fully unrolled)

Per-layer quant types: Layer 0: wq=q4_k wk=q4_k wv=q4_k wo=q4_k w1=q4_k w2=q4_k Layer 1: wq=q4_k wk=q4_k wv=q4_k wo=q4_k w1=q4_k w2=q4_k Layer 2: wq=q4_k wk=q4_k wv=q4_k wo=q4_k w1=q4_k w2=q4_k ... (21 more layers)

DO NOT EDIT - Regenerate with build_ir_v6.py or codegen_v6.py

Definition in file v6.6/test_generated/ck-kernel-inference.c.

Macro Definition Documentation

◆ _GNU_SOURCE

#define _GNU_SOURCE   /* For MAP_ANONYMOUS, MAP_HUGETLB */

Definition at line 19 of file v6.6/test_generated/ck-kernel-inference.c.

Function Documentation

◆ __attribute__()

struct __attribute__((packed))

Definition at line 43 of file v6.6/test_generated/ck-kernel-inference.c.

67  {
68  uint32_t magic; /* 0x434B454E */
69  uint32_t version; /* IR version */
70  uint64_t total_bytes;
71  uint64_t weight_bytes;
72  uint64_t activation_bytes;
73  uint32_t num_layers;
74  uint32_t embed_dim;
75  uint32_t num_heads;
76  uint32_t vocab_size;
77  uint32_t max_seq_len;
78  uint32_t canary_count;
79  uint8_t reserved[8]; /* Pad to 64 bytes */
80 } MagicHeader;
int vocab_size
Definition: true_bpe.h:185

◆ _Static_assert()

_Static_assert ( sizeof(MagicHeader) == 64,
"MagicHeader must be 64 bytes"   
)

◆ qwen2_0_5b_decode_align_elems()

static int qwen2_0_5b_decode_align_elems ( int  elems,
int  elem_bytes,
int  align_bytes 
)
static

Definition at line 176 of file v6.6/test_generated/ck-kernel-inference.c.

176  {
177  int bytes = elems * elem_bytes;
178  int aligned = (bytes + align_bytes - 1) / align_bytes * align_bytes;
179  return aligned / elem_bytes;
180 }

◆ qwen2_0_5b_decode_decode()

void qwen2_0_5b_decode_decode ( QWEN2_0_5B_DECODEModel model,
const int *  token,
int  token_index 
)

Definition at line 8022 of file v6.6/test_generated/ck-kernel-inference.c.

8022  {
8023  qwen2_0_5b_decode_decode_token(model, token, token_index);
8024 }
const char * token
Definition: tokenizer.h:306
static void qwen2_0_5b_decode_decode_token(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)

References qwen2_0_5b_decode_decode_token(), and token.

◆ qwen2_0_5b_decode_decode_token()

static void qwen2_0_5b_decode_decode_token ( QWEN2_0_5B_DECODEModel model,
const int *  token,
int  token_index 
)
static

Definition at line 7934 of file v6.6/test_generated/ck-kernel-inference.c.

7938  {
7939  if (!model || !token) return;
7940 
7941  const int aligned_embed_dim = 1024;
7942  const int aligned_head_dim = 64;
7943  const int aligned_intermediate_dim = 4864;
7944  const int aligned_context_window = 131072;
7945 
7946  if (token_index < 0 || token_index >= aligned_context_window) return;
7947 
7948  /* Embedding lookup */
7950  const void *embed_weight = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.token_emb);
7951  /* Embedding: Q4_K -> embedding_forward_q4_k */
7952  embedding_forward_q4_k((const int32_t *)token,
7953  1,
7955  embed_weight,
7956  NULL,
7957  embed_out,
7959  aligned_embed_dim,
7960  1,
7961  0);
7962 
7963  /* Process each layer explicitly */
7964  qwen2_0_5b_decode_layer_0_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7965  qwen2_0_5b_decode_layer_1_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7966  qwen2_0_5b_decode_layer_2_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7967  qwen2_0_5b_decode_layer_3_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7968  qwen2_0_5b_decode_layer_4_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7969  qwen2_0_5b_decode_layer_5_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7970  qwen2_0_5b_decode_layer_6_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7971  qwen2_0_5b_decode_layer_7_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7972  qwen2_0_5b_decode_layer_8_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7973  qwen2_0_5b_decode_layer_9_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7974  qwen2_0_5b_decode_layer_10_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7975  qwen2_0_5b_decode_layer_11_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7976  qwen2_0_5b_decode_layer_12_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7977  qwen2_0_5b_decode_layer_13_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7978  qwen2_0_5b_decode_layer_14_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7979  qwen2_0_5b_decode_layer_15_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7980  qwen2_0_5b_decode_layer_16_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7981  qwen2_0_5b_decode_layer_17_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7982  qwen2_0_5b_decode_layer_18_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7983  qwen2_0_5b_decode_layer_19_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7984  qwen2_0_5b_decode_layer_20_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7985  qwen2_0_5b_decode_layer_21_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7986  qwen2_0_5b_decode_layer_22_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7987  qwen2_0_5b_decode_layer_23_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
7988 
7989  /* Final RMSNorm */
7990  float *last_hidden = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[23].output);
7991  float *final_ln_weight = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_ln_weight);
7992  float *final_out = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_output);
7993  rmsnorm_forward(last_hidden,
7994  final_ln_weight,
7995  final_out,
7996  NULL,
7997  1,
7999  aligned_embed_dim,
8000  1e-06f);
8001 
8002  /* LM head projection */
8003  float *logits = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.logits);
8004  const void *lm_head = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.lm_head_weight);
8005  /* LM head: Q4_K -> gemm_nt_q4_k */
8006  gemm_nt_q4_k(final_out, lm_head, NULL, logits, 1, QWEN2_0_5B_DECODE_VOCAB_SIZE, aligned_embed_dim);
8007 }
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
#define QWEN2_0_5B_DECODE_PTR(model, offset)
static const QWEN2_0_5B_DECODEFooterOffsets QWEN2_0_5B_DECODE_FOOTER
static const QWEN2_0_5B_DECODELayerOffsets QWEN2_0_5B_DECODE_LAYERS[24]
#define QWEN2_0_5B_DECODE_VOCAB_SIZE
static const QWEN2_0_5B_DECODEHeaderOffsets QWEN2_0_5B_DECODE_HEADER
static void qwen2_0_5b_decode_layer_3_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_18_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_4_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_5_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_0_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_6_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_21_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_15_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_17_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

References QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, embedding_forward_q4_k(), QWEN2_0_5B_DECODEFooterOffsets::final_ln_weight, QWEN2_0_5B_DECODEFooterOffsets::final_output, gemm_nt_q4_k(), QWEN2_0_5B_DECODEFooterOffsets::lm_head_weight, QWEN2_0_5B_DECODEFooterOffsets::logits, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_FOOTER, QWEN2_0_5B_DECODE_HEADER, qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_PTR, QWEN2_0_5B_DECODE_VOCAB_SIZE, rmsnorm_forward(), token, and QWEN2_0_5B_DECODEHeaderOffsets::token_emb.

Referenced by qwen2_0_5b_decode_decode().

◆ qwen2_0_5b_decode_forward()

void qwen2_0_5b_decode_forward ( QWEN2_0_5B_DECODEModel model,
const int *  tokens,
int  num_tokens 
)

Definition at line 8013 of file v6.6/test_generated/ck-kernel-inference.c.

8017  {
8018  if (!model || !tokens || num_tokens <= 0) return;
8019  qwen2_0_5b_decode_forward_prefill_impl(model, tokens, num_tokens);
8020 }
static void qwen2_0_5b_decode_forward_prefill_impl(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)

References qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_forward_prefill_impl()

static void qwen2_0_5b_decode_forward_prefill_impl ( QWEN2_0_5B_DECODEModel model,
const int *  tokens,
int  num_tokens 
)
static

Definition at line 4076 of file v6.6/test_generated/ck-kernel-inference.c.

4080  {
4081  if (!model || !tokens || num_tokens <= 0) {
4082  return;
4083  }
4084 
4085  const int elem_bytes = QWEN2_0_5B_DECODE_DTYPE_BYTES;
4086  const int aligned_embed_dim = 1024;
4087  const int aligned_head_dim = 64;
4088  const int aligned_intermediate_dim = 4864;
4089  const int aligned_context_window = 131072;
4090 
4092  const void *embed_weight = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.token_emb);
4093  embedding_forward_q4_k((const int32_t *)tokens,
4094  num_tokens,
4096  embed_weight,
4097  NULL,
4098  embed_out,
4100  aligned_embed_dim,
4101  num_tokens,
4102  0);
4103 
4105  model,
4106  num_tokens,
4107  aligned_embed_dim,
4108  aligned_head_dim,
4109  aligned_intermediate_dim,
4110  aligned_context_window);
4111 
4115  num_tokens,
4116  aligned_context_window,
4117  aligned_head_dim);
4121  num_tokens,
4122  aligned_context_window,
4123  aligned_head_dim);
4124 
4126  model,
4127  num_tokens,
4128  aligned_embed_dim,
4129  aligned_head_dim,
4130  aligned_intermediate_dim,
4131  aligned_context_window);
4132 
4136  num_tokens,
4137  aligned_context_window,
4138  aligned_head_dim);
4142  num_tokens,
4143  aligned_context_window,
4144  aligned_head_dim);
4145 
4147  model,
4148  num_tokens,
4149  aligned_embed_dim,
4150  aligned_head_dim,
4151  aligned_intermediate_dim,
4152  aligned_context_window);
4153 
4157  num_tokens,
4158  aligned_context_window,
4159  aligned_head_dim);
4163  num_tokens,
4164  aligned_context_window,
4165  aligned_head_dim);
4166 
4168  model,
4169  num_tokens,
4170  aligned_embed_dim,
4171  aligned_head_dim,
4172  aligned_intermediate_dim,
4173  aligned_context_window);
4174 
4178  num_tokens,
4179  aligned_context_window,
4180  aligned_head_dim);
4184  num_tokens,
4185  aligned_context_window,
4186  aligned_head_dim);
4187 
4189  model,
4190  num_tokens,
4191  aligned_embed_dim,
4192  aligned_head_dim,
4193  aligned_intermediate_dim,
4194  aligned_context_window);
4195 
4199  num_tokens,
4200  aligned_context_window,
4201  aligned_head_dim);
4205  num_tokens,
4206  aligned_context_window,
4207  aligned_head_dim);
4208 
4210  model,
4211  num_tokens,
4212  aligned_embed_dim,
4213  aligned_head_dim,
4214  aligned_intermediate_dim,
4215  aligned_context_window);
4216 
4220  num_tokens,
4221  aligned_context_window,
4222  aligned_head_dim);
4226  num_tokens,
4227  aligned_context_window,
4228  aligned_head_dim);
4229 
4231  model,
4232  num_tokens,
4233  aligned_embed_dim,
4234  aligned_head_dim,
4235  aligned_intermediate_dim,
4236  aligned_context_window);
4237 
4241  num_tokens,
4242  aligned_context_window,
4243  aligned_head_dim);
4247  num_tokens,
4248  aligned_context_window,
4249  aligned_head_dim);
4250 
4252  model,
4253  num_tokens,
4254  aligned_embed_dim,
4255  aligned_head_dim,
4256  aligned_intermediate_dim,
4257  aligned_context_window);
4258 
4262  num_tokens,
4263  aligned_context_window,
4264  aligned_head_dim);
4268  num_tokens,
4269  aligned_context_window,
4270  aligned_head_dim);
4271 
4273  model,
4274  num_tokens,
4275  aligned_embed_dim,
4276  aligned_head_dim,
4277  aligned_intermediate_dim,
4278  aligned_context_window);
4279 
4283  num_tokens,
4284  aligned_context_window,
4285  aligned_head_dim);
4289  num_tokens,
4290  aligned_context_window,
4291  aligned_head_dim);
4292 
4294  model,
4295  num_tokens,
4296  aligned_embed_dim,
4297  aligned_head_dim,
4298  aligned_intermediate_dim,
4299  aligned_context_window);
4300 
4304  num_tokens,
4305  aligned_context_window,
4306  aligned_head_dim);
4310  num_tokens,
4311  aligned_context_window,
4312  aligned_head_dim);
4313 
4315  model,
4316  num_tokens,
4317  aligned_embed_dim,
4318  aligned_head_dim,
4319  aligned_intermediate_dim,
4320  aligned_context_window);
4321 
4325  num_tokens,
4326  aligned_context_window,
4327  aligned_head_dim);
4331  num_tokens,
4332  aligned_context_window,
4333  aligned_head_dim);
4334 
4336  model,
4337  num_tokens,
4338  aligned_embed_dim,
4339  aligned_head_dim,
4340  aligned_intermediate_dim,
4341  aligned_context_window);
4342 
4346  num_tokens,
4347  aligned_context_window,
4348  aligned_head_dim);
4352  num_tokens,
4353  aligned_context_window,
4354  aligned_head_dim);
4355 
4357  model,
4358  num_tokens,
4359  aligned_embed_dim,
4360  aligned_head_dim,
4361  aligned_intermediate_dim,
4362  aligned_context_window);
4363 
4367  num_tokens,
4368  aligned_context_window,
4369  aligned_head_dim);
4373  num_tokens,
4374  aligned_context_window,
4375  aligned_head_dim);
4376 
4378  model,
4379  num_tokens,
4380  aligned_embed_dim,
4381  aligned_head_dim,
4382  aligned_intermediate_dim,
4383  aligned_context_window);
4384 
4388  num_tokens,
4389  aligned_context_window,
4390  aligned_head_dim);
4394  num_tokens,
4395  aligned_context_window,
4396  aligned_head_dim);
4397 
4399  model,
4400  num_tokens,
4401  aligned_embed_dim,
4402  aligned_head_dim,
4403  aligned_intermediate_dim,
4404  aligned_context_window);
4405 
4409  num_tokens,
4410  aligned_context_window,
4411  aligned_head_dim);
4415  num_tokens,
4416  aligned_context_window,
4417  aligned_head_dim);
4418 
4420  model,
4421  num_tokens,
4422  aligned_embed_dim,
4423  aligned_head_dim,
4424  aligned_intermediate_dim,
4425  aligned_context_window);
4426 
4430  num_tokens,
4431  aligned_context_window,
4432  aligned_head_dim);
4436  num_tokens,
4437  aligned_context_window,
4438  aligned_head_dim);
4439 
4441  model,
4442  num_tokens,
4443  aligned_embed_dim,
4444  aligned_head_dim,
4445  aligned_intermediate_dim,
4446  aligned_context_window);
4447 
4451  num_tokens,
4452  aligned_context_window,
4453  aligned_head_dim);
4457  num_tokens,
4458  aligned_context_window,
4459  aligned_head_dim);
4460 
4462  model,
4463  num_tokens,
4464  aligned_embed_dim,
4465  aligned_head_dim,
4466  aligned_intermediate_dim,
4467  aligned_context_window);
4468 
4472  num_tokens,
4473  aligned_context_window,
4474  aligned_head_dim);
4478  num_tokens,
4479  aligned_context_window,
4480  aligned_head_dim);
4481 
4483  model,
4484  num_tokens,
4485  aligned_embed_dim,
4486  aligned_head_dim,
4487  aligned_intermediate_dim,
4488  aligned_context_window);
4489 
4493  num_tokens,
4494  aligned_context_window,
4495  aligned_head_dim);
4499  num_tokens,
4500  aligned_context_window,
4501  aligned_head_dim);
4502 
4504  model,
4505  num_tokens,
4506  aligned_embed_dim,
4507  aligned_head_dim,
4508  aligned_intermediate_dim,
4509  aligned_context_window);
4510 
4514  num_tokens,
4515  aligned_context_window,
4516  aligned_head_dim);
4520  num_tokens,
4521  aligned_context_window,
4522  aligned_head_dim);
4523 
4525  model,
4526  num_tokens,
4527  aligned_embed_dim,
4528  aligned_head_dim,
4529  aligned_intermediate_dim,
4530  aligned_context_window);
4531 
4535  num_tokens,
4536  aligned_context_window,
4537  aligned_head_dim);
4541  num_tokens,
4542  aligned_context_window,
4543  aligned_head_dim);
4544 
4546  model,
4547  num_tokens,
4548  aligned_embed_dim,
4549  aligned_head_dim,
4550  aligned_intermediate_dim,
4551  aligned_context_window);
4552 
4556  num_tokens,
4557  aligned_context_window,
4558  aligned_head_dim);
4562  num_tokens,
4563  aligned_context_window,
4564  aligned_head_dim);
4565 
4567  model,
4568  num_tokens,
4569  aligned_embed_dim,
4570  aligned_head_dim,
4571  aligned_intermediate_dim,
4572  aligned_context_window);
4573 
4577  num_tokens,
4578  aligned_context_window,
4579  aligned_head_dim);
4583  num_tokens,
4584  aligned_context_window,
4585  aligned_head_dim);
4586 
4588  model,
4589  num_tokens,
4590  aligned_embed_dim,
4591  aligned_head_dim,
4592  aligned_intermediate_dim,
4593  aligned_context_window);
4594 
4598  num_tokens,
4599  aligned_context_window,
4600  aligned_head_dim);
4604  num_tokens,
4605  aligned_context_window,
4606  aligned_head_dim);
4607 
4608  float *last_hidden = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[QWEN2_0_5B_DECODE_NUM_LAYERS - 1].output);
4609  float *final_ln_weight = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_ln_weight);
4610  float *final_out = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_output);
4611  rmsnorm_forward(last_hidden,
4612  final_ln_weight,
4613  final_out,
4614  NULL,
4615  num_tokens,
4617  aligned_embed_dim,
4618  1e-06f);
4619 
4620  float *logits = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.logits);
4621  const void *lm_head = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.lm_head_weight);
4622  const size_t q8_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_embed_dim);
4623  for (int t = 0; t < num_tokens; ++t) {
4624  uint8_t q8_buf[q8_bytes];
4625  const float *row = final_out + (size_t)t * (size_t)aligned_embed_dim;
4626  float *logits_row = logits + (size_t)t * (size_t)QWEN2_0_5B_DECODE_VOCAB_SIZE;
4627  quantize_row_q8_k(row, q8_buf, aligned_embed_dim);
4628  gemm_nt_q4_k_q8_k(q8_buf,
4629  lm_head,
4630  NULL,
4631  logits_row,
4632  1,
4634  aligned_embed_dim);
4635  }
4636 }
@ CK_DT_Q8_K
Definition: ckernel_dtype.h:43
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void kv_cache_repack_head_major_inplace(float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void quantize_row_q8_k(const float *x, void *y, int k)
#define QWEN2_0_5B_DECODE_DTYPE_BYTES
#define QWEN2_0_5B_DECODE_NUM_LAYERS
#define QWEN2_0_5B_DECODE_NUM_KV_HEADS
static void qwen2_0_5b_decode_layer_18_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_0_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_6_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_15_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_17_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_3_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_21_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_4_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_5_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

References CK_DT_Q8_K, ck_dtype_row_bytes(), QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, embedding_forward_q4_k(), QWEN2_0_5B_DECODEFooterOffsets::final_ln_weight, QWEN2_0_5B_DECODEFooterOffsets::final_output, gemm_nt_q4_k_q8_k(), kv_cache_repack_head_major_inplace(), QWEN2_0_5B_DECODEFooterOffsets::lm_head_weight, QWEN2_0_5B_DECODEFooterOffsets::logits, quantize_row_q8_k(), QWEN2_0_5B_DECODE_DTYPE_BYTES, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_FOOTER, QWEN2_0_5B_DECODE_HEADER, qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_prefill(), QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_NUM_LAYERS, QWEN2_0_5B_DECODE_PTR, QWEN2_0_5B_DECODE_VOCAB_SIZE, rmsnorm_forward(), and QWEN2_0_5B_DECODEHeaderOffsets::token_emb.

Referenced by qwen2_0_5b_decode_forward().

◆ qwen2_0_5b_decode_layer_0_decode()

static void qwen2_0_5b_decode_layer_0_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 4645 of file v6.6/test_generated/ck-kernel-inference.c.

4652  {
      /*
       * Single-token (decode) step for transformer layer 0.
       * Visible pipeline: RMSNorm -> Q/K/V projection (Q4_K GEMMs) -> RoPE at
       * absolute position token_index -> KV-cache write -> decode attention
       * over the cache -> output projection -> residual add -> RMSNorm ->
       * SwiGLU MLP -> final residual add into `output`.
       *
       * NOTE(review): this listing elides hyperlink-only lines 4653, 4655,
       * 4677-4678, 4699, 4737 and 4761. Presumably they declare the layer
       * offset pointer `L` and the `input` pointer (layer 0 reads
       * QWEN2_0_5B_DECODE_HEADER.embedded_input per the reference list),
       * fetch rope_cos/rope_sin from QWEN2_0_5B_DECODE_GLOBALS, pass
       * QWEN2_0_5B_DECODE_EMBED_DIM to rmsnorm_forward(), and name the
       * attention_forward_decode_head_major_gqa_regular() call -- confirm
       * against the generated .c file.
       */
4654 
4656 
      /* Per-layer activation / scratch buffers resolved from the model blob. */
4657  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
4658  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
4659  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
4660  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
4661  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
4662  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
4663  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4664  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4665  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4666  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4667  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4668 
4669  /* Weights (explicit types for layer 0) */
4670  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
4671  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
4672  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
4673  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
4674  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
4675  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
4676 
4679 
4680  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4681  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4682  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4683 
      /* NOTE(review): runtime-sized VLAs on the stack; VLAs are optional since
       * C11 and fc1_out is 2 * aligned_intermediate_dim floats -- verify the
       * stack budget / compiler support for the target configs. */
4684  float q_token[H * aligned_head_dim];
4685  float k_token[H_kv * aligned_head_dim];
4686  float v_token[H_kv * aligned_head_dim];
4687  float attn_token[H * aligned_head_dim];
4688 
4689  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
4690  float fc1_out[2 * aligned_intermediate_dim];
4691  float swiglu_out[aligned_intermediate_dim];
4692 
4693  /* Step 1: RMSNorm before attention */
      /* Single row (tokens = 1); NULL bias; eps = 1e-6. */
4694  rmsnorm_forward(input,
4695  ln1_gamma,
4696  ln1_out,
4697  NULL,
4698  1,
4700  aligned_embed_dim,
4701  1e-06f);
4702 
4703  /* Step 2: QKV projection */
4704  /* Q projection: Q4_K -> gemm_nt_q4_k */
4705  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4706 
4707  /* K projection: Q4_K -> gemm_nt_q4_k */
4708  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
4709 
4710  /* V projection: Q4_K -> gemm_nt_q4_k */
4711  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
4712 
4713  /* Step 3: RoPE */
      /* Rotates Q and K in place; pos_offset = token_index gives the
       * absolute sequence position for this single token. */
4714  rope_forward_qk(q_token,
4715  k_token,
4716  rope_cos,
4717  rope_sin,
4718  H,
4719  H_kv,
4720  1,
4721  head_dim,
4722  aligned_head_dim,
4723  token_index);
4724 
4725  /* Step 4: KV cache write */
4726  kv_cache_write_head_major(k_token,
4727  v_token,
4728  k_cache,
4729  v_cache,
4730  H_kv,
4731  token_index,
4732  aligned_context_window,
4733  head_dim,
4734  aligned_head_dim);
4735 
4736  /* Step 5: Attention (decode) */
      /* kv token count is token_index + 1: all cached tokens including the
       * one just written. Call name elided by the listing (line 4737). */
4738  k_cache,
4739  v_cache,
4740  attn_token,
4741  H,
4742  H_kv,
4743  token_index + 1,
4744  aligned_context_window,
4745  head_dim,
4746  aligned_head_dim);
4747 
4748  /* Step 6: Output projection */
4749  /* WO projection: Q4_K -> gemm_nt_q4_k */
4750  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4751 
4752  /* Step 7: Residual add */
4753  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
4754 
4755  /* Step 8: RMSNorm before MLP */
4756  rmsnorm_forward(residual1,
4757  ln2_gamma,
4758  ln2_out,
4759  NULL,
4760  1,
4762  aligned_embed_dim,
4763  1e-06f);
4764 
4765  /* Step 9: MLP (SwiGLU) */
4766  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
4767  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4768 
4769  /* SwiGLU activation */
4770  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4771 
4772  /* Down projection: Q4_K -> gemm_nt_q4_k */
4773  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4774 
4775  /* Step 10: Final residual add */
4776  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
4777 }
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
static const QWEN2_0_5B_DECODEGlobalOffsets QWEN2_0_5B_DECODE_GLOBALS
static void qwen2_0_5b_decode_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)

References attention_forward_decode_head_major_gqa_regular(), QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_HEADER, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_0_prefill()

static void qwen2_0_5b_decode_layer_0_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 211 of file v6.6/test_generated/ck-kernel-inference.c.

218  {
      /*
       * Prefill (multi-token, causal) step for transformer layer 0.
       * Visible pipeline: RMSNorm -> per-head Q/K/V projections into
       * head-major layout -> RoPE from position 0 -> causal flash-style GQA
       * attention -> head-major->token-major flatten -> output projection ->
       * residual add -> RMSNorm -> SwiGLU MLP -> final residual add.
       *
       * NOTE(review): this listing elides hyperlink-only lines 219, 221,
       * 251-252, 267, 314 and 356. Presumably they declare `L` and `input`
       * (layer 0 reads QWEN2_0_5B_DECODE_HEADER.embedded_input per the
       * reference list), fetch rope_cos/rope_sin, pass
       * QWEN2_0_5B_DECODE_EMBED_DIM to rmsnorm_forward(), and name the
       * attention_forward_causal_head_major_gqa_flash() call -- confirm
       * against the generated .c file.
       */
220 
222  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
223  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
224  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
225  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
226  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
227  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
228  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
229  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
230  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
231  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
232  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
233  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
234  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
235  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
236  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
237 
238  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
239  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
240  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
241  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
242  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
243  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
      /* Bias pointers are all NULL for this model; kept so the generated
       * projection loops have a uniform shape across architectures. */
244  const float *BQ = NULL;
245  const float *BK = NULL;
246  const float *BV = NULL;
247  const float *BO = NULL;
248  const float *B1 = NULL;
249  const float *B2 = NULL;
250 
253 
254  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
255  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
256  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
      /* head_w_elems: logical element count of one head's weight slice;
       * q/kv_head_stride: float stride between heads in the head-major
       * activation buffers (num_tokens rows of aligned_head_dim). */
257  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
258  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
259  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
260 
261  /* RMSNorm before attention */
262  rmsnorm_forward(input,
263  ln1_gamma,
264  ln1_out,
265  NULL,
266  num_tokens,
268  aligned_embed_dim,
269  1e-06f);
270 
271  /* Q projection (head-major) */
      /* One GEMM per head: the Q4_K weight blob is sliced by byte offset
       * (ck_dtype_row_bytes of one head's elements) rather than element
       * offset, because Q4_K rows are block-compressed. */
272  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
273  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
274  for (int h = 0; h < H; ++h) {
275  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
276  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
277  float *q_h = q + (size_t)h * q_head_stride;
278  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
279  }
280 
281  /* K projection (head-major) */
282  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
283  const uint8_t *WK_bytes = (const uint8_t *)WK;
284  for (int h = 0; h < H_kv; ++h) {
285  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
286  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
287  float *k_h = k + (size_t)h * kv_head_stride;
288  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
289  }
290 
291  /* V projection (head-major) */
292  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
293  const uint8_t *WV_bytes = (const uint8_t *)WV;
294  for (int h = 0; h < H_kv; ++h) {
295  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
296  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
297  float *v_h = v + (size_t)h * kv_head_stride;
298  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
299  }
300 
301  /* RoPE */
      /* Prefill starts at absolute position 0 (last argument). */
302  rope_forward_qk(q,
303  k,
304  rope_cos,
305  rope_sin,
306  H,
307  H_kv,
308  num_tokens,
309  head_dim,
310  aligned_head_dim,
311  0);
312 
313  /* Attention (prefill, causal) */
      /* Call name elided by the listing (line 314). */
315  k,
316  v,
317  attn_out,
318  H,
319  H_kv,
320  num_tokens,
321  head_dim,
322  aligned_head_dim);
323 
324  /* Output projection (flatten head-major to token-major) */
325  const int K = H * aligned_head_dim;
      /* NOTE(review): silent early return if the flattened head width does
       * not equal aligned_embed_dim -- the layer output is then left
       * unwritten with no error reported; worth surfacing upstream. */
326  if (K != aligned_embed_dim) {
327  return;
328  }
329  const float *proj_in = attn_out;
330  if (H > 1) {
331  if (!proj_scratch) {
332  return;
333  }
      /* Gather each token's per-head slices into one contiguous row so the
       * WO GEMM can treat the input as token-major [num_tokens x K]. */
334  for (int t = 0; t < num_tokens; ++t) {
335  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
336  for (int h = 0; h < H; ++h) {
337  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
338  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
339  src,
340  (size_t)aligned_head_dim * sizeof(float));
341  }
342  }
343  proj_in = proj_scratch;
344  }
345  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
346 
347  /* Residual add */
348  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
349 
350  /* RMSNorm before MLP */
351  rmsnorm_forward(residual1,
352  ln2_gamma,
353  ln2_out,
354  NULL,
355  num_tokens,
357  aligned_embed_dim,
358  1e-06f);
359 
360  /* MLP (SwiGLU) */
361  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
362  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
363  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
364 
365  /* Final residual add */
366  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
367 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_HEADER, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_10_decode()

static void qwen2_0_5b_decode_layer_10_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6015 of file v6.6/test_generated/ck-kernel-inference.c.

6022  {
      /*
       * Single-token (decode) step for transformer layer 10.
       * Reads layer 9's output, runs: RMSNorm -> Q/K/V projection (Q4_K
       * GEMMs) -> RoPE at absolute position token_index -> KV-cache write ->
       * decode attention over the cache -> output projection -> residual add
       * -> RMSNorm -> SwiGLU MLP -> final residual add into `output`.
       *
       * NOTE(review): this listing elides hyperlink-only lines 6023,
       * 6047-6048, 6069, 6107 and 6131. Presumably they declare the layer
       * offset pointer `L`, fetch rope_cos/rope_sin from
       * QWEN2_0_5B_DECODE_GLOBALS, pass QWEN2_0_5B_DECODE_EMBED_DIM to
       * rmsnorm_forward(), and name the
       * attention_forward_decode_head_major_gqa_regular() call -- confirm
       * against the generated .c file.
       */
6024 
      /* Input is the previous layer's (layer 9) output buffer. */
6025  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[9].output);
6026 
      /* Per-layer activation / scratch buffers resolved from the model blob. */
6027  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6028  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6029  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6030  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6031  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6032  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6033  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6034  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6035  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6036  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6037  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6038 
6039  /* Weights (explicit types for layer 10) */
6040  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6041  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6042  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6043  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6044  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6045  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6046 
6049 
6050  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6051  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6052  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6053 
      /* NOTE(review): runtime-sized VLAs on the stack; VLAs are optional since
       * C11 and fc1_out is 2 * aligned_intermediate_dim floats -- verify the
       * stack budget / compiler support for the target configs. */
6054  float q_token[H * aligned_head_dim];
6055  float k_token[H_kv * aligned_head_dim];
6056  float v_token[H_kv * aligned_head_dim];
6057  float attn_token[H * aligned_head_dim];
6058 
6059  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6060  float fc1_out[2 * aligned_intermediate_dim];
6061  float swiglu_out[aligned_intermediate_dim];
6062 
6063  /* Step 1: RMSNorm before attention */
      /* Single row (tokens = 1); NULL bias; eps = 1e-6. */
6064  rmsnorm_forward(input,
6065  ln1_gamma,
6066  ln1_out,
6067  NULL,
6068  1,
6070  aligned_embed_dim,
6071  1e-06f);
6072 
6073  /* Step 2: QKV projection */
6074  /* Q projection: Q4_K -> gemm_nt_q4_k */
6075  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6076 
6077  /* K projection: Q4_K -> gemm_nt_q4_k */
6078  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6079 
6080  /* V projection: Q4_K -> gemm_nt_q4_k */
6081  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6082 
6083  /* Step 3: RoPE */
      /* Rotates Q and K in place; pos_offset = token_index gives the
       * absolute sequence position for this single token. */
6084  rope_forward_qk(q_token,
6085  k_token,
6086  rope_cos,
6087  rope_sin,
6088  H,
6089  H_kv,
6090  1,
6091  head_dim,
6092  aligned_head_dim,
6093  token_index);
6094 
6095  /* Step 4: KV cache write */
6096  kv_cache_write_head_major(k_token,
6097  v_token,
6098  k_cache,
6099  v_cache,
6100  H_kv,
6101  token_index,
6102  aligned_context_window,
6103  head_dim,
6104  aligned_head_dim);
6105 
6106  /* Step 5: Attention (decode) */
      /* kv token count is token_index + 1: all cached tokens including the
       * one just written. Call name elided by the listing (line 6107). */
6108  k_cache,
6109  v_cache,
6110  attn_token,
6111  H,
6112  H_kv,
6113  token_index + 1,
6114  aligned_context_window,
6115  head_dim,
6116  aligned_head_dim);
6117 
6118  /* Step 6: Output projection */
6119  /* WO projection: Q4_K -> gemm_nt_q4_k */
6120  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6121 
6122  /* Step 7: Residual add */
6123  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6124 
6125  /* Step 8: RMSNorm before MLP */
6126  rmsnorm_forward(residual1,
6127  ln2_gamma,
6128  ln2_out,
6129  NULL,
6130  1,
6132  aligned_embed_dim,
6133  1e-06f);
6134 
6135  /* Step 9: MLP (SwiGLU) */
6136  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6137  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6138 
6139  /* SwiGLU activation */
6140  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6141 
6142  /* Down projection: Q4_K -> gemm_nt_q4_k */
6143  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6144 
6145  /* Step 10: Final residual add */
6146  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6147 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_10_prefill()

static void qwen2_0_5b_decode_layer_10_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1821 of file v6.6/test_generated/ck-kernel-inference.c.

1828  {
      /*
       * Prefill (multi-token, causal) step for transformer layer 10.
       * Reads layer 9's output, runs: RMSNorm -> per-head Q/K/V projections
       * into head-major layout -> RoPE from position 0 -> causal flash-style
       * GQA attention -> head-major->token-major flatten -> output projection
       * -> residual add -> RMSNorm -> SwiGLU MLP -> final residual add.
       *
       * NOTE(review): this listing elides hyperlink-only lines 1829,
       * 1861-1862, 1877, 1924 and 1966. Presumably they declare `L`, fetch
       * rope_cos/rope_sin, pass QWEN2_0_5B_DECODE_EMBED_DIM to
       * rmsnorm_forward(), and name the
       * attention_forward_causal_head_major_gqa_flash() call -- confirm
       * against the generated .c file.
       */
1830 
      /* Input is the previous layer's (layer 9) output buffer. */
1831  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[9].output);
1832  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1833  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1834  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1835  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1836  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1837  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1838  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1839  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1840  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1841  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1842  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1843  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1844  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1845  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1846  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1847 
1848  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1849  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1850  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1851  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1852  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1853  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
      /* Bias pointers are all NULL for this model; kept so the generated
       * projection loops have a uniform shape across architectures. */
1854  const float *BQ = NULL;
1855  const float *BK = NULL;
1856  const float *BV = NULL;
1857  const float *BO = NULL;
1858  const float *B1 = NULL;
1859  const float *B2 = NULL;
1860 
1863 
1864  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1865  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1866  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
      /* head_w_elems: logical element count of one head's weight slice;
       * q/kv_head_stride: float stride between heads in the head-major
       * activation buffers (num_tokens rows of aligned_head_dim). */
1867  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1868  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1869  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1870 
1871  /* RMSNorm before attention */
1872  rmsnorm_forward(input,
1873  ln1_gamma,
1874  ln1_out,
1875  NULL,
1876  num_tokens,
1878  aligned_embed_dim,
1879  1e-06f);
1880 
1881  /* Q projection (head-major) */
      /* One GEMM per head: the Q4_K weight blob is sliced by byte offset
       * (ck_dtype_row_bytes of one head's elements) rather than element
       * offset, because Q4_K rows are block-compressed. */
1882  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1883  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1884  for (int h = 0; h < H; ++h) {
1885  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1886  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1887  float *q_h = q + (size_t)h * q_head_stride;
1888  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1889  }
1890 
1891  /* K projection (head-major) */
1892  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1893  const uint8_t *WK_bytes = (const uint8_t *)WK;
1894  for (int h = 0; h < H_kv; ++h) {
1895  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1896  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1897  float *k_h = k + (size_t)h * kv_head_stride;
1898  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1899  }
1900 
1901  /* V projection (head-major) */
1902  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1903  const uint8_t *WV_bytes = (const uint8_t *)WV;
1904  for (int h = 0; h < H_kv; ++h) {
1905  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1906  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1907  float *v_h = v + (size_t)h * kv_head_stride;
1908  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1909  }
1910 
1911  /* RoPE */
      /* Prefill starts at absolute position 0 (last argument). */
1912  rope_forward_qk(q,
1913  k,
1914  rope_cos,
1915  rope_sin,
1916  H,
1917  H_kv,
1918  num_tokens,
1919  head_dim,
1920  aligned_head_dim,
1921  0);
1922 
1923  /* Attention (prefill, causal) */
      /* Call name elided by the listing (line 1924). */
1925  k,
1926  v,
1927  attn_out,
1928  H,
1929  H_kv,
1930  num_tokens,
1931  head_dim,
1932  aligned_head_dim);
1933 
1934  /* Output projection (flatten head-major to token-major) */
1935  const int K = H * aligned_head_dim;
      /* NOTE(review): silent early return if the flattened head width does
       * not equal aligned_embed_dim -- the layer output is then left
       * unwritten with no error reported; worth surfacing upstream. */
1936  if (K != aligned_embed_dim) {
1937  return;
1938  }
1939  const float *proj_in = attn_out;
1940  if (H > 1) {
1941  if (!proj_scratch) {
1942  return;
1943  }
      /* Gather each token's per-head slices into one contiguous row so the
       * WO GEMM can treat the input as token-major [num_tokens x K]. */
1944  for (int t = 0; t < num_tokens; ++t) {
1945  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1946  for (int h = 0; h < H; ++h) {
1947  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1948  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1949  src,
1950  (size_t)aligned_head_dim * sizeof(float));
1951  }
1952  }
1953  proj_in = proj_scratch;
1954  }
1955  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1956 
1957  /* Residual add */
1958  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1959 
1960  /* RMSNorm before MLP */
1961  rmsnorm_forward(residual1,
1962  ln2_gamma,
1963  ln2_out,
1964  NULL,
1965  num_tokens,
1967  aligned_embed_dim,
1968  1e-06f);
1969 
1970  /* MLP (SwiGLU) */
1971  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1972  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1973  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1974 
1975  /* Final residual add */
1976  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1977 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_11_decode()

static void qwen2_0_5b_decode_layer_11_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6152 of file v6.6/test_generated/ck-kernel-inference.c.

6159  {
6161 
6162  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[10].output);
6163 
6164  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6165  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6166  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6167  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6168  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6169  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6170  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6171  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6172  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6173  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6174  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6175 
6176  /* Weights (explicit types for layer 11) */
6177  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6178  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6179  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6180  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6181  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6182  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6183 
6186 
6187  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6188  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6189  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6190 
6191  float q_token[H * aligned_head_dim];
6192  float k_token[H_kv * aligned_head_dim];
6193  float v_token[H_kv * aligned_head_dim];
6194  float attn_token[H * aligned_head_dim];
6195 
6196  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6197  float fc1_out[2 * aligned_intermediate_dim];
6198  float swiglu_out[aligned_intermediate_dim];
6199 
6200  /* Step 1: RMSNorm before attention */
6201  rmsnorm_forward(input,
6202  ln1_gamma,
6203  ln1_out,
6204  NULL,
6205  1,
6207  aligned_embed_dim,
6208  1e-06f);
6209 
6210  /* Step 2: QKV projection */
6211  /* Q projection: Q4_K -> gemm_nt_q4_k */
6212  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6213 
6214  /* K projection: Q4_K -> gemm_nt_q4_k */
6215  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6216 
6217  /* V projection: Q4_K -> gemm_nt_q4_k */
6218  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6219 
6220  /* Step 3: RoPE */
6221  rope_forward_qk(q_token,
6222  k_token,
6223  rope_cos,
6224  rope_sin,
6225  H,
6226  H_kv,
6227  1,
6228  head_dim,
6229  aligned_head_dim,
6230  token_index);
6231 
6232  /* Step 4: KV cache write */
6233  kv_cache_write_head_major(k_token,
6234  v_token,
6235  k_cache,
6236  v_cache,
6237  H_kv,
6238  token_index,
6239  aligned_context_window,
6240  head_dim,
6241  aligned_head_dim);
6242 
6243  /* Step 5: Attention (decode) */
6245  k_cache,
6246  v_cache,
6247  attn_token,
6248  H,
6249  H_kv,
6250  token_index + 1,
6251  aligned_context_window,
6252  head_dim,
6253  aligned_head_dim);
6254 
6255  /* Step 6: Output projection */
6256  /* WO projection: Q4_K -> gemm_nt_q4_k */
6257  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6258 
6259  /* Step 7: Residual add */
6260  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6261 
6262  /* Step 8: RMSNorm before MLP */
6263  rmsnorm_forward(residual1,
6264  ln2_gamma,
6265  ln2_out,
6266  NULL,
6267  1,
6269  aligned_embed_dim,
6270  1e-06f);
6271 
6272  /* Step 9: MLP (SwiGLU) */
6273  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6274  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6275 
6276  /* SwiGLU activation */
6277  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6278 
6279  /* Down projection: Q4_K -> gemm_nt_q4_k */
6280  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6281 
6282  /* Step 10: Final residual add */
6283  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6284 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_11_prefill()

static void qwen2_0_5b_decode_layer_11_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1982 of file v6.6/test_generated/ck-kernel-inference.c.

1989  {
1991 
1992  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[10].output);
1993  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1994  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1995  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1996  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1997  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1998  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1999  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2000  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2001  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2002  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2003  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2004  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2005  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2006  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2007  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2008 
2009  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2010  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2011  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2012  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2013  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2014  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2015  const float *BQ = NULL;
2016  const float *BK = NULL;
2017  const float *BV = NULL;
2018  const float *BO = NULL;
2019  const float *B1 = NULL;
2020  const float *B2 = NULL;
2021 
2024 
2025  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2026  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2027  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2028  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2029  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2030  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2031 
2032  /* RMSNorm before attention */
2033  rmsnorm_forward(input,
2034  ln1_gamma,
2035  ln1_out,
2036  NULL,
2037  num_tokens,
2039  aligned_embed_dim,
2040  1e-06f);
2041 
2042  /* Q projection (head-major) */
2043  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2044  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2045  for (int h = 0; h < H; ++h) {
2046  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2047  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2048  float *q_h = q + (size_t)h * q_head_stride;
2049  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2050  }
2051 
2052  /* K projection (head-major) */
2053  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2054  const uint8_t *WK_bytes = (const uint8_t *)WK;
2055  for (int h = 0; h < H_kv; ++h) {
2056  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2057  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2058  float *k_h = k + (size_t)h * kv_head_stride;
2059  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2060  }
2061 
2062  /* V projection (head-major) */
2063  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2064  const uint8_t *WV_bytes = (const uint8_t *)WV;
2065  for (int h = 0; h < H_kv; ++h) {
2066  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2067  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2068  float *v_h = v + (size_t)h * kv_head_stride;
2069  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2070  }
2071 
2072  /* RoPE */
2073  rope_forward_qk(q,
2074  k,
2075  rope_cos,
2076  rope_sin,
2077  H,
2078  H_kv,
2079  num_tokens,
2080  head_dim,
2081  aligned_head_dim,
2082  0);
2083 
2084  /* Attention (prefill, causal) */
2086  k,
2087  v,
2088  attn_out,
2089  H,
2090  H_kv,
2091  num_tokens,
2092  head_dim,
2093  aligned_head_dim);
2094 
2095  /* Output projection (flatten head-major to token-major) */
2096  const int K = H * aligned_head_dim;
2097  if (K != aligned_embed_dim) {
2098  return;
2099  }
2100  const float *proj_in = attn_out;
2101  if (H > 1) {
2102  if (!proj_scratch) {
2103  return;
2104  }
2105  for (int t = 0; t < num_tokens; ++t) {
2106  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2107  for (int h = 0; h < H; ++h) {
2108  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2109  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2110  src,
2111  (size_t)aligned_head_dim * sizeof(float));
2112  }
2113  }
2114  proj_in = proj_scratch;
2115  }
2116  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2117 
2118  /* Residual add */
2119  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2120 
2121  /* RMSNorm before MLP */
2122  rmsnorm_forward(residual1,
2123  ln2_gamma,
2124  ln2_out,
2125  NULL,
2126  num_tokens,
2128  aligned_embed_dim,
2129  1e-06f);
2130 
2131  /* MLP (SwiGLU) */
2132  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2133  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2134  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2135 
2136  /* Final residual add */
2137  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2138 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_12_decode()

static void qwen2_0_5b_decode_layer_12_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6289 of file v6.6/test_generated/ck-kernel-inference.c.

6296  {
6298 
6299  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[11].output);
6300 
6301  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6302  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6303  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6304  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6305  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6306  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6307  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6308  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6309  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6310  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6311  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6312 
6313  /* Weights (explicit types for layer 12) */
6314  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6315  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6316  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6317  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6318  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6319  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6320 
6323 
6324  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6325  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6326  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6327 
6328  float q_token[H * aligned_head_dim];
6329  float k_token[H_kv * aligned_head_dim];
6330  float v_token[H_kv * aligned_head_dim];
6331  float attn_token[H * aligned_head_dim];
6332 
6333  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6334  float fc1_out[2 * aligned_intermediate_dim];
6335  float swiglu_out[aligned_intermediate_dim];
6336 
6337  /* Step 1: RMSNorm before attention */
6338  rmsnorm_forward(input,
6339  ln1_gamma,
6340  ln1_out,
6341  NULL,
6342  1,
6344  aligned_embed_dim,
6345  1e-06f);
6346 
6347  /* Step 2: QKV projection */
6348  /* Q projection: Q4_K -> gemm_nt_q4_k */
6349  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6350 
6351  /* K projection: Q4_K -> gemm_nt_q4_k */
6352  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6353 
6354  /* V projection: Q4_K -> gemm_nt_q4_k */
6355  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6356 
6357  /* Step 3: RoPE */
6358  rope_forward_qk(q_token,
6359  k_token,
6360  rope_cos,
6361  rope_sin,
6362  H,
6363  H_kv,
6364  1,
6365  head_dim,
6366  aligned_head_dim,
6367  token_index);
6368 
6369  /* Step 4: KV cache write */
6370  kv_cache_write_head_major(k_token,
6371  v_token,
6372  k_cache,
6373  v_cache,
6374  H_kv,
6375  token_index,
6376  aligned_context_window,
6377  head_dim,
6378  aligned_head_dim);
6379 
6380  /* Step 5: Attention (decode) */
6382  k_cache,
6383  v_cache,
6384  attn_token,
6385  H,
6386  H_kv,
6387  token_index + 1,
6388  aligned_context_window,
6389  head_dim,
6390  aligned_head_dim);
6391 
6392  /* Step 6: Output projection */
6393  /* WO projection: Q4_K -> gemm_nt_q4_k */
6394  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6395 
6396  /* Step 7: Residual add */
6397  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6398 
6399  /* Step 8: RMSNorm before MLP */
6400  rmsnorm_forward(residual1,
6401  ln2_gamma,
6402  ln2_out,
6403  NULL,
6404  1,
6406  aligned_embed_dim,
6407  1e-06f);
6408 
6409  /* Step 9: MLP (SwiGLU) */
6410  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6411  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6412 
6413  /* SwiGLU activation */
6414  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6415 
6416  /* Down projection: Q4_K -> gemm_nt_q4_k */
6417  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6418 
6419  /* Step 10: Final residual add */
6420  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6421 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_12_prefill()

static void qwen2_0_5b_decode_layer_12_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2143 of file v6.6/test_generated/ck-kernel-inference.c.

2150  {
2152 
2153  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[11].output);
2154  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2155  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2156  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2157  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2158  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2159  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2160  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2161  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2162  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2163  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2164  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2165  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2166  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2167  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2168  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2169 
2170  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2171  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2172  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2173  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2174  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2175  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2176  const float *BQ = NULL;
2177  const float *BK = NULL;
2178  const float *BV = NULL;
2179  const float *BO = NULL;
2180  const float *B1 = NULL;
2181  const float *B2 = NULL;
2182 
2185 
2186  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2187  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2188  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2189  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2190  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2191  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2192 
2193  /* RMSNorm before attention */
2194  rmsnorm_forward(input,
2195  ln1_gamma,
2196  ln1_out,
2197  NULL,
2198  num_tokens,
2200  aligned_embed_dim,
2201  1e-06f);
2202 
2203  /* Q projection (head-major) */
2204  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2205  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2206  for (int h = 0; h < H; ++h) {
2207  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2208  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2209  float *q_h = q + (size_t)h * q_head_stride;
2210  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2211  }
2212 
2213  /* K projection (head-major) */
2214  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2215  const uint8_t *WK_bytes = (const uint8_t *)WK;
2216  for (int h = 0; h < H_kv; ++h) {
2217  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2218  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2219  float *k_h = k + (size_t)h * kv_head_stride;
2220  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2221  }
2222 
2223  /* V projection (head-major) */
2224  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2225  const uint8_t *WV_bytes = (const uint8_t *)WV;
2226  for (int h = 0; h < H_kv; ++h) {
2227  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2228  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2229  float *v_h = v + (size_t)h * kv_head_stride;
2230  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2231  }
2232 
2233  /* RoPE */
2234  rope_forward_qk(q,
2235  k,
2236  rope_cos,
2237  rope_sin,
2238  H,
2239  H_kv,
2240  num_tokens,
2241  head_dim,
2242  aligned_head_dim,
2243  0);
2244 
2245  /* Attention (prefill, causal) */
2247  k,
2248  v,
2249  attn_out,
2250  H,
2251  H_kv,
2252  num_tokens,
2253  head_dim,
2254  aligned_head_dim);
2255 
2256  /* Output projection (flatten head-major to token-major) */
2257  const int K = H * aligned_head_dim;
2258  if (K != aligned_embed_dim) {
2259  return;
2260  }
2261  const float *proj_in = attn_out;
2262  if (H > 1) {
2263  if (!proj_scratch) {
2264  return;
2265  }
2266  for (int t = 0; t < num_tokens; ++t) {
2267  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2268  for (int h = 0; h < H; ++h) {
2269  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2270  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2271  src,
2272  (size_t)aligned_head_dim * sizeof(float));
2273  }
2274  }
2275  proj_in = proj_scratch;
2276  }
2277  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2278 
2279  /* Residual add */
2280  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2281 
2282  /* RMSNorm before MLP */
2283  rmsnorm_forward(residual1,
2284  ln2_gamma,
2285  ln2_out,
2286  NULL,
2287  num_tokens,
2289  aligned_embed_dim,
2290  1e-06f);
2291 
2292  /* MLP (SwiGLU) */
2293  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2294  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2295  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2296 
2297  /* Final residual add */
2298  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2299 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_13_decode()

static void qwen2_0_5b_decode_layer_13_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6426 of file v6.6/test_generated/ck-kernel-inference.c.

6433  {
6435 
6436  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[12].output);
6437 
6438  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6439  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6440  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6441  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6442  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6443  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6444  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6445  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6446  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6447  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6448  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6449 
6450  /* Weights (explicit types for layer 13) */
6451  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6452  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6453  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6454  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6455  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6456  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6457 
6460 
6461  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6462  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6463  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6464 
6465  float q_token[H * aligned_head_dim];
6466  float k_token[H_kv * aligned_head_dim];
6467  float v_token[H_kv * aligned_head_dim];
6468  float attn_token[H * aligned_head_dim];
6469 
6470  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6471  float fc1_out[2 * aligned_intermediate_dim];
6472  float swiglu_out[aligned_intermediate_dim];
6473 
6474  /* Step 1: RMSNorm before attention */
6475  rmsnorm_forward(input,
6476  ln1_gamma,
6477  ln1_out,
6478  NULL,
6479  1,
6481  aligned_embed_dim,
6482  1e-06f);
6483 
6484  /* Step 2: QKV projection */
6485  /* Q projection: Q4_K -> gemm_nt_q4_k */
6486  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6487 
6488  /* K projection: Q4_K -> gemm_nt_q4_k */
6489  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6490 
6491  /* V projection: Q4_K -> gemm_nt_q4_k */
6492  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6493 
6494  /* Step 3: RoPE */
6495  rope_forward_qk(q_token,
6496  k_token,
6497  rope_cos,
6498  rope_sin,
6499  H,
6500  H_kv,
6501  1,
6502  head_dim,
6503  aligned_head_dim,
6504  token_index);
6505 
6506  /* Step 4: KV cache write */
6507  kv_cache_write_head_major(k_token,
6508  v_token,
6509  k_cache,
6510  v_cache,
6511  H_kv,
6512  token_index,
6513  aligned_context_window,
6514  head_dim,
6515  aligned_head_dim);
6516 
6517  /* Step 5: Attention (decode) */
6519  k_cache,
6520  v_cache,
6521  attn_token,
6522  H,
6523  H_kv,
6524  token_index + 1,
6525  aligned_context_window,
6526  head_dim,
6527  aligned_head_dim);
6528 
6529  /* Step 6: Output projection */
6530  /* WO projection: Q4_K -> gemm_nt_q4_k */
6531  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6532 
6533  /* Step 7: Residual add */
6534  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6535 
6536  /* Step 8: RMSNorm before MLP */
6537  rmsnorm_forward(residual1,
6538  ln2_gamma,
6539  ln2_out,
6540  NULL,
6541  1,
6543  aligned_embed_dim,
6544  1e-06f);
6545 
6546  /* Step 9: MLP (SwiGLU) */
6547  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6548  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6549 
6550  /* SwiGLU activation */
6551  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6552 
6553  /* Down projection: Q4_K -> gemm_nt_q4_k */
6554  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6555 
6556  /* Step 10: Final residual add */
6557  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6558 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_13_prefill()

static void qwen2_0_5b_decode_layer_13_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2304 of file v6.6/test_generated/ck-kernel-inference.c.

2311  {
2313 
2314  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[12].output);
2315  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2316  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2317  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2318  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2319  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2320  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2321  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2322  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2323  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2324  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2325  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2326  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2327  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2328  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2329  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2330 
2331  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2332  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2333  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2334  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2335  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2336  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2337  const float *BQ = NULL;
2338  const float *BK = NULL;
2339  const float *BV = NULL;
2340  const float *BO = NULL;
2341  const float *B1 = NULL;
2342  const float *B2 = NULL;
2343 
2346 
2347  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2348  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2349  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2350  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2351  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2352  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2353 
2354  /* RMSNorm before attention */
2355  rmsnorm_forward(input,
2356  ln1_gamma,
2357  ln1_out,
2358  NULL,
2359  num_tokens,
2361  aligned_embed_dim,
2362  1e-06f);
2363 
2364  /* Q projection (head-major) */
2365  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2366  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2367  for (int h = 0; h < H; ++h) {
2368  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2369  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2370  float *q_h = q + (size_t)h * q_head_stride;
2371  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2372  }
2373 
2374  /* K projection (head-major) */
2375  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2376  const uint8_t *WK_bytes = (const uint8_t *)WK;
2377  for (int h = 0; h < H_kv; ++h) {
2378  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2379  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2380  float *k_h = k + (size_t)h * kv_head_stride;
2381  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2382  }
2383 
2384  /* V projection (head-major) */
2385  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2386  const uint8_t *WV_bytes = (const uint8_t *)WV;
2387  for (int h = 0; h < H_kv; ++h) {
2388  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2389  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2390  float *v_h = v + (size_t)h * kv_head_stride;
2391  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2392  }
2393 
2394  /* RoPE */
2395  rope_forward_qk(q,
2396  k,
2397  rope_cos,
2398  rope_sin,
2399  H,
2400  H_kv,
2401  num_tokens,
2402  head_dim,
2403  aligned_head_dim,
2404  0);
2405 
2406  /* Attention (prefill, causal) */
2408  k,
2409  v,
2410  attn_out,
2411  H,
2412  H_kv,
2413  num_tokens,
2414  head_dim,
2415  aligned_head_dim);
2416 
2417  /* Output projection (flatten head-major to token-major) */
2418  const int K = H * aligned_head_dim;
2419  if (K != aligned_embed_dim) {
2420  return;
2421  }
2422  const float *proj_in = attn_out;
2423  if (H > 1) {
2424  if (!proj_scratch) {
2425  return;
2426  }
2427  for (int t = 0; t < num_tokens; ++t) {
2428  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2429  for (int h = 0; h < H; ++h) {
2430  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2431  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2432  src,
2433  (size_t)aligned_head_dim * sizeof(float));
2434  }
2435  }
2436  proj_in = proj_scratch;
2437  }
2438  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2439 
2440  /* Residual add */
2441  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2442 
2443  /* RMSNorm before MLP */
2444  rmsnorm_forward(residual1,
2445  ln2_gamma,
2446  ln2_out,
2447  NULL,
2448  num_tokens,
2450  aligned_embed_dim,
2451  1e-06f);
2452 
2453  /* MLP (SwiGLU) */
2454  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2455  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2456  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2457 
2458  /* Final residual add */
2459  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2460 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_14_decode()

static void qwen2_0_5b_decode_layer_14_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6563 of file v6.6/test_generated/ck-kernel-inference.c.

6570  {
6572 
6573  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[13].output);
6574 
6575  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6576  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6577  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6578  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6579  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6580  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6581  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6582  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6583  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6584  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6585  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6586 
6587  /* Weights (explicit types for layer 14) */
6588  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6589  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6590  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6591  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6592  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6593  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6594 
6597 
6598  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6599  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6600  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6601 
6602  float q_token[H * aligned_head_dim];
6603  float k_token[H_kv * aligned_head_dim];
6604  float v_token[H_kv * aligned_head_dim];
6605  float attn_token[H * aligned_head_dim];
6606 
6607  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6608  float fc1_out[2 * aligned_intermediate_dim];
6609  float swiglu_out[aligned_intermediate_dim];
6610 
6611  /* Step 1: RMSNorm before attention */
6612  rmsnorm_forward(input,
6613  ln1_gamma,
6614  ln1_out,
6615  NULL,
6616  1,
6618  aligned_embed_dim,
6619  1e-06f);
6620 
6621  /* Step 2: QKV projection */
6622  /* Q projection: Q4_K -> gemm_nt_q4_k */
6623  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6624 
6625  /* K projection: Q4_K -> gemm_nt_q4_k */
6626  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6627 
6628  /* V projection: Q4_K -> gemm_nt_q4_k */
6629  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6630 
6631  /* Step 3: RoPE */
6632  rope_forward_qk(q_token,
6633  k_token,
6634  rope_cos,
6635  rope_sin,
6636  H,
6637  H_kv,
6638  1,
6639  head_dim,
6640  aligned_head_dim,
6641  token_index);
6642 
6643  /* Step 4: KV cache write */
6644  kv_cache_write_head_major(k_token,
6645  v_token,
6646  k_cache,
6647  v_cache,
6648  H_kv,
6649  token_index,
6650  aligned_context_window,
6651  head_dim,
6652  aligned_head_dim);
6653 
6654  /* Step 5: Attention (decode) */
6656  k_cache,
6657  v_cache,
6658  attn_token,
6659  H,
6660  H_kv,
6661  token_index + 1,
6662  aligned_context_window,
6663  head_dim,
6664  aligned_head_dim);
6665 
6666  /* Step 6: Output projection */
6667  /* WO projection: Q4_K -> gemm_nt_q4_k */
6668  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6669 
6670  /* Step 7: Residual add */
6671  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6672 
6673  /* Step 8: RMSNorm before MLP */
6674  rmsnorm_forward(residual1,
6675  ln2_gamma,
6676  ln2_out,
6677  NULL,
6678  1,
6680  aligned_embed_dim,
6681  1e-06f);
6682 
6683  /* Step 9: MLP (SwiGLU) */
6684  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6685  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6686 
6687  /* SwiGLU activation */
6688  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6689 
6690  /* Down projection: Q4_K -> gemm_nt_q4_k */
6691  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6692 
6693  /* Step 10: Final residual add */
6694  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6695 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_14_prefill()

static void qwen2_0_5b_decode_layer_14_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2465 of file v6.6/test_generated/ck-kernel-inference.c.

2472  {
2474 
2475  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[13].output);
2476  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2477  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2478  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2479  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2480  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2481  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2482  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2483  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2484  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2485  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2486  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2487  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2488  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2489  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2490  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2491 
2492  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2493  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2494  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2495  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2496  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2497  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2498  const float *BQ = NULL;
2499  const float *BK = NULL;
2500  const float *BV = NULL;
2501  const float *BO = NULL;
2502  const float *B1 = NULL;
2503  const float *B2 = NULL;
2504 
2507 
2508  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2509  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2510  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2511  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2512  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2513  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2514 
2515  /* RMSNorm before attention */
2516  rmsnorm_forward(input,
2517  ln1_gamma,
2518  ln1_out,
2519  NULL,
2520  num_tokens,
2522  aligned_embed_dim,
2523  1e-06f);
2524 
2525  /* Q projection (head-major) */
2526  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2527  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2528  for (int h = 0; h < H; ++h) {
2529  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2530  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2531  float *q_h = q + (size_t)h * q_head_stride;
2532  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2533  }
2534 
2535  /* K projection (head-major) */
2536  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2537  const uint8_t *WK_bytes = (const uint8_t *)WK;
2538  for (int h = 0; h < H_kv; ++h) {
2539  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2540  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2541  float *k_h = k + (size_t)h * kv_head_stride;
2542  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2543  }
2544 
2545  /* V projection (head-major) */
2546  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2547  const uint8_t *WV_bytes = (const uint8_t *)WV;
2548  for (int h = 0; h < H_kv; ++h) {
2549  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2550  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2551  float *v_h = v + (size_t)h * kv_head_stride;
2552  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2553  }
2554 
2555  /* RoPE */
2556  rope_forward_qk(q,
2557  k,
2558  rope_cos,
2559  rope_sin,
2560  H,
2561  H_kv,
2562  num_tokens,
2563  head_dim,
2564  aligned_head_dim,
2565  0);
2566 
2567  /* Attention (prefill, causal) */
2569  k,
2570  v,
2571  attn_out,
2572  H,
2573  H_kv,
2574  num_tokens,
2575  head_dim,
2576  aligned_head_dim);
2577 
2578  /* Output projection (flatten head-major to token-major) */
2579  const int K = H * aligned_head_dim;
2580  if (K != aligned_embed_dim) {
2581  return;
2582  }
2583  const float *proj_in = attn_out;
2584  if (H > 1) {
2585  if (!proj_scratch) {
2586  return;
2587  }
2588  for (int t = 0; t < num_tokens; ++t) {
2589  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2590  for (int h = 0; h < H; ++h) {
2591  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2592  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2593  src,
2594  (size_t)aligned_head_dim * sizeof(float));
2595  }
2596  }
2597  proj_in = proj_scratch;
2598  }
2599  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2600 
2601  /* Residual add */
2602  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2603 
2604  /* RMSNorm before MLP */
2605  rmsnorm_forward(residual1,
2606  ln2_gamma,
2607  ln2_out,
2608  NULL,
2609  num_tokens,
2611  aligned_embed_dim,
2612  1e-06f);
2613 
2614  /* MLP (SwiGLU) */
2615  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2616  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2617  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2618 
2619  /* Final residual add */
2620  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2621 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_15_decode()

static void qwen2_0_5b_decode_layer_15_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6700 of file v6.6/test_generated/ck-kernel-inference.c.

6707  {
6709 
6710  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[14].output);
6711 
6712  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6713  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6714  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6715  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6716  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6717  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6718  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6719  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6720  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6721  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6722  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6723 
6724  /* Weights (explicit types for layer 15) */
6725  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6726  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6727  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6728  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6729  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6730  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6731 
6734 
6735  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6736  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6737  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6738 
6739  float q_token[H * aligned_head_dim];
6740  float k_token[H_kv * aligned_head_dim];
6741  float v_token[H_kv * aligned_head_dim];
6742  float attn_token[H * aligned_head_dim];
6743 
6744  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6745  float fc1_out[2 * aligned_intermediate_dim];
6746  float swiglu_out[aligned_intermediate_dim];
6747 
6748  /* Step 1: RMSNorm before attention */
6749  rmsnorm_forward(input,
6750  ln1_gamma,
6751  ln1_out,
6752  NULL,
6753  1,
6755  aligned_embed_dim,
6756  1e-06f);
6757 
6758  /* Step 2: QKV projection */
6759  /* Q projection: Q4_K -> gemm_nt_q4_k */
6760  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6761 
6762  /* K projection: Q4_K -> gemm_nt_q4_k */
6763  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6764 
6765  /* V projection: Q4_K -> gemm_nt_q4_k */
6766  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6767 
6768  /* Step 3: RoPE */
6769  rope_forward_qk(q_token,
6770  k_token,
6771  rope_cos,
6772  rope_sin,
6773  H,
6774  H_kv,
6775  1,
6776  head_dim,
6777  aligned_head_dim,
6778  token_index);
6779 
6780  /* Step 4: KV cache write */
6781  kv_cache_write_head_major(k_token,
6782  v_token,
6783  k_cache,
6784  v_cache,
6785  H_kv,
6786  token_index,
6787  aligned_context_window,
6788  head_dim,
6789  aligned_head_dim);
6790 
6791  /* Step 5: Attention (decode) */
6793  k_cache,
6794  v_cache,
6795  attn_token,
6796  H,
6797  H_kv,
6798  token_index + 1,
6799  aligned_context_window,
6800  head_dim,
6801  aligned_head_dim);
6802 
6803  /* Step 6: Output projection */
6804  /* WO projection: Q4_K -> gemm_nt_q4_k */
6805  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6806 
6807  /* Step 7: Residual add */
6808  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6809 
6810  /* Step 8: RMSNorm before MLP */
6811  rmsnorm_forward(residual1,
6812  ln2_gamma,
6813  ln2_out,
6814  NULL,
6815  1,
6817  aligned_embed_dim,
6818  1e-06f);
6819 
6820  /* Step 9: MLP (SwiGLU) */
6821  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6822  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6823 
6824  /* SwiGLU activation */
6825  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6826 
6827  /* Down projection: Q4_K -> gemm_nt_q4_k */
6828  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6829 
6830  /* Step 10: Final residual add */
6831  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6832 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_15_prefill()

static void qwen2_0_5b_decode_layer_15_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2626 of file v6.6/test_generated/ck-kernel-inference.c.

2633  {
2635 
2636  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[14].output);
2637  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2638  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2639  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2640  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2641  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2642  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2643  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2644  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2645  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2646  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2647  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2648  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2649  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2650  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2651  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2652 
2653  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2654  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2655  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2656  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2657  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2658  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2659  const float *BQ = NULL;
2660  const float *BK = NULL;
2661  const float *BV = NULL;
2662  const float *BO = NULL;
2663  const float *B1 = NULL;
2664  const float *B2 = NULL;
2665 
2668 
2669  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2670  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2671  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2672  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2673  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2674  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2675 
2676  /* RMSNorm before attention */
2677  rmsnorm_forward(input,
2678  ln1_gamma,
2679  ln1_out,
2680  NULL,
2681  num_tokens,
2683  aligned_embed_dim,
2684  1e-06f);
2685 
2686  /* Q projection (head-major) */
2687  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2688  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2689  for (int h = 0; h < H; ++h) {
2690  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2691  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2692  float *q_h = q + (size_t)h * q_head_stride;
2693  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2694  }
2695 
2696  /* K projection (head-major) */
2697  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2698  const uint8_t *WK_bytes = (const uint8_t *)WK;
2699  for (int h = 0; h < H_kv; ++h) {
2700  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2701  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2702  float *k_h = k + (size_t)h * kv_head_stride;
2703  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2704  }
2705 
2706  /* V projection (head-major) */
2707  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2708  const uint8_t *WV_bytes = (const uint8_t *)WV;
2709  for (int h = 0; h < H_kv; ++h) {
2710  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2711  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2712  float *v_h = v + (size_t)h * kv_head_stride;
2713  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2714  }
2715 
2716  /* RoPE */
2717  rope_forward_qk(q,
2718  k,
2719  rope_cos,
2720  rope_sin,
2721  H,
2722  H_kv,
2723  num_tokens,
2724  head_dim,
2725  aligned_head_dim,
2726  0);
2727 
2728  /* Attention (prefill, causal) */
2730  k,
2731  v,
2732  attn_out,
2733  H,
2734  H_kv,
2735  num_tokens,
2736  head_dim,
2737  aligned_head_dim);
2738 
2739  /* Output projection (flatten head-major to token-major) */
2740  const int K = H * aligned_head_dim;
2741  if (K != aligned_embed_dim) {
2742  return;
2743  }
2744  const float *proj_in = attn_out;
2745  if (H > 1) {
2746  if (!proj_scratch) {
2747  return;
2748  }
2749  for (int t = 0; t < num_tokens; ++t) {
2750  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2751  for (int h = 0; h < H; ++h) {
2752  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2753  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2754  src,
2755  (size_t)aligned_head_dim * sizeof(float));
2756  }
2757  }
2758  proj_in = proj_scratch;
2759  }
2760  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2761 
2762  /* Residual add */
2763  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2764 
2765  /* RMSNorm before MLP */
2766  rmsnorm_forward(residual1,
2767  ln2_gamma,
2768  ln2_out,
2769  NULL,
2770  num_tokens,
2772  aligned_embed_dim,
2773  1e-06f);
2774 
2775  /* MLP (SwiGLU) */
2776  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2777  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2778  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2779 
2780  /* Final residual add */
2781  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2782 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_16_decode()

static void qwen2_0_5b_decode_layer_16_decode ( QWEN2_0_5B_DECODEModel *  model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6837 of file v6.6/test_generated/ck-kernel-inference.c.

6844  {
6845  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[16];
6846 
6847  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[15].output);
6848 
6849  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6850  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6851  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6852  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6853  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6854  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6855  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6856  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6857  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6858  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6859  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6860 
6861  /* Weights (explicit types for layer 16) */
6862  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6863  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6864  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6865  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6866  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6867  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6868 
6869  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
6870  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
6871 
6872  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6873  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6874  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6875 
6876  float q_token[H * aligned_head_dim];
6877  float k_token[H_kv * aligned_head_dim];
6878  float v_token[H_kv * aligned_head_dim];
6879  float attn_token[H * aligned_head_dim];
6880 
6881  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6882  float fc1_out[2 * aligned_intermediate_dim];
6883  float swiglu_out[aligned_intermediate_dim];
6884 
6885  /* Step 1: RMSNorm before attention */
6886  rmsnorm_forward(input,
6887  ln1_gamma,
6888  ln1_out,
6889  NULL,
6890  1,
6891  QWEN2_0_5B_DECODE_EMBED_DIM,
6892  aligned_embed_dim,
6893  1e-06f);
6894 
6895  /* Step 2: QKV projection */
6896  /* Q projection: Q4_K -> gemm_nt_q4_k */
6897  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6898 
6899  /* K projection: Q4_K -> gemm_nt_q4_k */
6900  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6901 
6902  /* V projection: Q4_K -> gemm_nt_q4_k */
6903  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6904 
6905  /* Step 3: RoPE */
6906  rope_forward_qk(q_token,
6907  k_token,
6908  rope_cos,
6909  rope_sin,
6910  H,
6911  H_kv,
6912  1,
6913  head_dim,
6914  aligned_head_dim,
6915  token_index);
6916 
6917  /* Step 4: KV cache write */
6918  kv_cache_write_head_major(k_token,
6919  v_token,
6920  k_cache,
6921  v_cache,
6922  H_kv,
6923  token_index,
6924  aligned_context_window,
6925  head_dim,
6926  aligned_head_dim);
6927 
6928  /* Step 5: Attention (decode) */
6929  attention_forward_decode_head_major_gqa_regular(q_token,
6930  k_cache,
6931  v_cache,
6932  attn_token,
6933  H,
6934  H_kv,
6935  token_index + 1,
6936  aligned_context_window,
6937  head_dim,
6938  aligned_head_dim);
6939 
6940  /* Step 6: Output projection */
6941  /* WO projection: Q4_K -> gemm_nt_q4_k */
6942  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6943 
6944  /* Step 7: Residual add */
6945  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6946 
6947  /* Step 8: RMSNorm before MLP */
6948  rmsnorm_forward(residual1,
6949  ln2_gamma,
6950  ln2_out,
6951  NULL,
6952  1,
6953  QWEN2_0_5B_DECODE_EMBED_DIM,
6954  aligned_embed_dim,
6955  1e-06f);
6956 
6957  /* Step 9: MLP (SwiGLU) */
6958  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6959  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6960 
6961  /* SwiGLU activation */
6962  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6963 
6964  /* Down projection: Q4_K -> gemm_nt_q4_k */
6965  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6966 
6967  /* Step 10: Final residual add */
6968  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6969 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_16_prefill()

static void qwen2_0_5b_decode_layer_16_prefill ( QWEN2_0_5B_DECODEModel *  model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2787 of file v6.6/test_generated/ck-kernel-inference.c.

2794  {
2795  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[16];
2796 
2797  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[15].output);
2798  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2799  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2800  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2801  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2802  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2803  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2804  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2805  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2806  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2807  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2808  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2809  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2810  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2811  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2812  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2813 
2814  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2815  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2816  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2817  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2818  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2819  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2820  const float *BQ = NULL;
2821  const float *BK = NULL;
2822  const float *BV = NULL;
2823  const float *BO = NULL;
2824  const float *B1 = NULL;
2825  const float *B2 = NULL;
2826 
2827  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
2828  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
2829 
2830  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2831  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2832  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2833  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2834  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2835  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2836 
2837  /* RMSNorm before attention */
2838  rmsnorm_forward(input,
2839  ln1_gamma,
2840  ln1_out,
2841  NULL,
2842  num_tokens,
2843  QWEN2_0_5B_DECODE_EMBED_DIM,
2844  aligned_embed_dim,
2845  1e-06f);
2846 
2847  /* Q projection (head-major) */
2848  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2849  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2850  for (int h = 0; h < H; ++h) {
2851  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2852  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2853  float *q_h = q + (size_t)h * q_head_stride;
2854  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2855  }
2856 
2857  /* K projection (head-major) */
2858  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2859  const uint8_t *WK_bytes = (const uint8_t *)WK;
2860  for (int h = 0; h < H_kv; ++h) {
2861  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2862  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2863  float *k_h = k + (size_t)h * kv_head_stride;
2864  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2865  }
2866 
2867  /* V projection (head-major) */
2868  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2869  const uint8_t *WV_bytes = (const uint8_t *)WV;
2870  for (int h = 0; h < H_kv; ++h) {
2871  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2872  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2873  float *v_h = v + (size_t)h * kv_head_stride;
2874  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2875  }
2876 
2877  /* RoPE */
2878  rope_forward_qk(q,
2879  k,
2880  rope_cos,
2881  rope_sin,
2882  H,
2883  H_kv,
2884  num_tokens,
2885  head_dim,
2886  aligned_head_dim,
2887  0);
2888 
2889  /* Attention (prefill, causal) */
2890  attention_forward_causal_head_major_gqa_flash(q,
2891  k,
2892  v,
2893  attn_out,
2894  H,
2895  H_kv,
2896  num_tokens,
2897  head_dim,
2898  aligned_head_dim);
2899 
2900  /* Output projection (flatten head-major to token-major) */
2901  const int K = H * aligned_head_dim;
2902  if (K != aligned_embed_dim) {
2903  return;
2904  }
2905  const float *proj_in = attn_out;
2906  if (H > 1) {
2907  if (!proj_scratch) {
2908  return;
2909  }
2910  for (int t = 0; t < num_tokens; ++t) {
2911  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2912  for (int h = 0; h < H; ++h) {
2913  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2914  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2915  src,
2916  (size_t)aligned_head_dim * sizeof(float));
2917  }
2918  }
2919  proj_in = proj_scratch;
2920  }
2921  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2922 
2923  /* Residual add */
2924  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2925 
2926  /* RMSNorm before MLP */
2927  rmsnorm_forward(residual1,
2928  ln2_gamma,
2929  ln2_out,
2930  NULL,
2931  num_tokens,
2932  QWEN2_0_5B_DECODE_EMBED_DIM,
2933  aligned_embed_dim,
2934  1e-06f);
2935 
2936  /* MLP (SwiGLU) */
2937  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2938  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2939  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2940 
2941  /* Final residual add */
2942  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2943 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_17_decode()

static void qwen2_0_5b_decode_layer_17_decode ( QWEN2_0_5B_DECODEModel *  model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6974 of file v6.6/test_generated/ck-kernel-inference.c.

6981  {
6982  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[17];
6983 
6984  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[16].output);
6985 
6986  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6987  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6988  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6989  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6990  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6991  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6992  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6993  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6994  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6995  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6996  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6997 
6998  /* Weights (explicit types for layer 17) */
6999  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7000  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7001  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7002  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7003  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7004  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7005 
7006  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7007  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7008 
7009  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7010  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7011  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7012 
7013  float q_token[H * aligned_head_dim];
7014  float k_token[H_kv * aligned_head_dim];
7015  float v_token[H_kv * aligned_head_dim];
7016  float attn_token[H * aligned_head_dim];
7017 
7018  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7019  float fc1_out[2 * aligned_intermediate_dim];
7020  float swiglu_out[aligned_intermediate_dim];
7021 
7022  /* Step 1: RMSNorm before attention */
7023  rmsnorm_forward(input,
7024  ln1_gamma,
7025  ln1_out,
7026  NULL,
7027  1,
7028  QWEN2_0_5B_DECODE_EMBED_DIM,
7029  aligned_embed_dim,
7030  1e-06f);
7031 
7032  /* Step 2: QKV projection */
7033  /* Q projection: Q4_K -> gemm_nt_q4_k */
7034  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7035 
7036  /* K projection: Q4_K -> gemm_nt_q4_k */
7037  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7038 
7039  /* V projection: Q4_K -> gemm_nt_q4_k */
7040  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7041 
7042  /* Step 3: RoPE */
7043  rope_forward_qk(q_token,
7044  k_token,
7045  rope_cos,
7046  rope_sin,
7047  H,
7048  H_kv,
7049  1,
7050  head_dim,
7051  aligned_head_dim,
7052  token_index);
7053 
7054  /* Step 4: KV cache write */
7055  kv_cache_write_head_major(k_token,
7056  v_token,
7057  k_cache,
7058  v_cache,
7059  H_kv,
7060  token_index,
7061  aligned_context_window,
7062  head_dim,
7063  aligned_head_dim);
7064 
7065  /* Step 5: Attention (decode) */
7066  attention_forward_decode_head_major_gqa_regular(q_token,
7067  k_cache,
7068  v_cache,
7069  attn_token,
7070  H,
7071  H_kv,
7072  token_index + 1,
7073  aligned_context_window,
7074  head_dim,
7075  aligned_head_dim);
7076 
7077  /* Step 6: Output projection */
7078  /* WO projection: Q4_K -> gemm_nt_q4_k */
7079  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7080 
7081  /* Step 7: Residual add */
7082  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7083 
7084  /* Step 8: RMSNorm before MLP */
7085  rmsnorm_forward(residual1,
7086  ln2_gamma,
7087  ln2_out,
7088  NULL,
7089  1,
7090  QWEN2_0_5B_DECODE_EMBED_DIM,
7091  aligned_embed_dim,
7092  1e-06f);
7093 
7094  /* Step 9: MLP (SwiGLU) */
7095  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7096  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7097 
7098  /* SwiGLU activation */
7099  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7100 
7101  /* Down projection: Q4_K -> gemm_nt_q4_k */
7102  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7103 
7104  /* Step 10: Final residual add */
7105  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7106 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_17_prefill()

static void qwen2_0_5b_decode_layer_17_prefill ( QWEN2_0_5B_DECODEModel *  model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2948 of file v6.6/test_generated/ck-kernel-inference.c.

2955  {
2956  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[17];
2957 
2958  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[16].output);
2959  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2960  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2961  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2962  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2963  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2964  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2965  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2966  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2967  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2968  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2969  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2970  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2971  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2972  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2973  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2974 
2975  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2976  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2977  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2978  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2979  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2980  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2981  const float *BQ = NULL;
2982  const float *BK = NULL;
2983  const float *BV = NULL;
2984  const float *BO = NULL;
2985  const float *B1 = NULL;
2986  const float *B2 = NULL;
2987 
2988  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
2989  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
2990 
2991  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2992  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2993  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2994  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2995  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2996  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2997 
2998  /* RMSNorm before attention */
2999  rmsnorm_forward(input,
3000  ln1_gamma,
3001  ln1_out,
3002  NULL,
3003  num_tokens,
3004  QWEN2_0_5B_DECODE_EMBED_DIM,
3005  aligned_embed_dim,
3006  1e-06f);
3007 
3008  /* Q projection (head-major) */
3009  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3010  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3011  for (int h = 0; h < H; ++h) {
3012  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3013  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3014  float *q_h = q + (size_t)h * q_head_stride;
3015  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3016  }
3017 
3018  /* K projection (head-major) */
3019  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3020  const uint8_t *WK_bytes = (const uint8_t *)WK;
3021  for (int h = 0; h < H_kv; ++h) {
3022  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3023  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3024  float *k_h = k + (size_t)h * kv_head_stride;
3025  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3026  }
3027 
3028  /* V projection (head-major) */
3029  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3030  const uint8_t *WV_bytes = (const uint8_t *)WV;
3031  for (int h = 0; h < H_kv; ++h) {
3032  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3033  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3034  float *v_h = v + (size_t)h * kv_head_stride;
3035  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3036  }
3037 
3038  /* RoPE */
3039  rope_forward_qk(q,
3040  k,
3041  rope_cos,
3042  rope_sin,
3043  H,
3044  H_kv,
3045  num_tokens,
3046  head_dim,
3047  aligned_head_dim,
3048  0);
3049 
3050  /* Attention (prefill, causal) */
3051  attention_forward_causal_head_major_gqa_flash(q,
3052  k,
3053  v,
3054  attn_out,
3055  H,
3056  H_kv,
3057  num_tokens,
3058  head_dim,
3059  aligned_head_dim);
3060 
3061  /* Output projection (flatten head-major to token-major) */
3062  const int K = H * aligned_head_dim;
3063  if (K != aligned_embed_dim) {
3064  return;
3065  }
3066  const float *proj_in = attn_out;
3067  if (H > 1) {
3068  if (!proj_scratch) {
3069  return;
3070  }
3071  for (int t = 0; t < num_tokens; ++t) {
3072  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3073  for (int h = 0; h < H; ++h) {
3074  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3075  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3076  src,
3077  (size_t)aligned_head_dim * sizeof(float));
3078  }
3079  }
3080  proj_in = proj_scratch;
3081  }
3082  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3083 
3084  /* Residual add */
3085  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3086 
3087  /* RMSNorm before MLP */
3088  rmsnorm_forward(residual1,
3089  ln2_gamma,
3090  ln2_out,
3091  NULL,
3092  num_tokens,
3093  QWEN2_0_5B_DECODE_EMBED_DIM,
3094  aligned_embed_dim,
3095  1e-06f);
3096 
3097  /* MLP (SwiGLU) */
3098  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3099  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3100  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3101 
3102  /* Final residual add */
3103  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3104 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_18_decode()

static void qwen2_0_5b_decode_layer_18_decode ( QWEN2_0_5B_DECODEModel *  model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7111 of file v6.6/test_generated/ck-kernel-inference.c.

7118  {
7119  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[18];
7120 
7121  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[17].output);
7122 
7123  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7124  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7125  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7126  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7127  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7128  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7129  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7130  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7131  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7132  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7133  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7134 
7135  /* Weights (explicit types for layer 18) */
7136  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7137  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7138  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7139  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7140  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7141  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7142 
7143  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7144  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7145 
7146  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7147  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7148  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7149 
7150  float q_token[H * aligned_head_dim];
7151  float k_token[H_kv * aligned_head_dim];
7152  float v_token[H_kv * aligned_head_dim];
7153  float attn_token[H * aligned_head_dim];
7154 
7155  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7156  float fc1_out[2 * aligned_intermediate_dim];
7157  float swiglu_out[aligned_intermediate_dim];
7158 
7159  /* Step 1: RMSNorm before attention */
7160  rmsnorm_forward(input,
7161  ln1_gamma,
7162  ln1_out,
7163  NULL,
7164  1,
7165  QWEN2_0_5B_DECODE_EMBED_DIM,
7166  aligned_embed_dim,
7167  1e-06f);
7168 
7169  /* Step 2: QKV projection */
7170  /* Q projection: Q4_K -> gemm_nt_q4_k */
7171  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7172 
7173  /* K projection: Q4_K -> gemm_nt_q4_k */
7174  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7175 
7176  /* V projection: Q4_K -> gemm_nt_q4_k */
7177  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7178 
7179  /* Step 3: RoPE */
7180  rope_forward_qk(q_token,
7181  k_token,
7182  rope_cos,
7183  rope_sin,
7184  H,
7185  H_kv,
7186  1,
7187  head_dim,
7188  aligned_head_dim,
7189  token_index);
7190 
7191  /* Step 4: KV cache write */
7192  kv_cache_write_head_major(k_token,
7193  v_token,
7194  k_cache,
7195  v_cache,
7196  H_kv,
7197  token_index,
7198  aligned_context_window,
7199  head_dim,
7200  aligned_head_dim);
7201 
7202  /* Step 5: Attention (decode) */
7203  attention_forward_decode_head_major_gqa_regular(q_token,
7204  k_cache,
7205  v_cache,
7206  attn_token,
7207  H,
7208  H_kv,
7209  token_index + 1,
7210  aligned_context_window,
7211  head_dim,
7212  aligned_head_dim);
7213 
7214  /* Step 6: Output projection */
7215  /* WO projection: Q4_K -> gemm_nt_q4_k */
7216  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7217 
7218  /* Step 7: Residual add */
7219  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7220 
7221  /* Step 8: RMSNorm before MLP */
7222  rmsnorm_forward(residual1,
7223  ln2_gamma,
7224  ln2_out,
7225  NULL,
7226  1,
7227  QWEN2_0_5B_DECODE_EMBED_DIM,
7228  aligned_embed_dim,
7229  1e-06f);
7230 
7231  /* Step 9: MLP (SwiGLU) */
7232  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7233  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7234 
7235  /* SwiGLU activation */
7236  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7237 
7238  /* Down projection: Q4_K -> gemm_nt_q4_k */
7239  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7240 
7241  /* Step 10: Final residual add */
7242  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7243 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_18_prefill()

static void qwen2_0_5b_decode_layer_18_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3109 of file v6.6/test_generated/ck-kernel-inference.c.

3116  {
3118 
3119  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[17].output);
3120  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3121  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3122  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3123  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3124  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3125  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3126  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3127  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3128  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3129  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3130  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3131  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3132  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3133  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3134  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3135 
3136  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3137  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3138  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3139  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3140  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3141  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3142  const float *BQ = NULL;
3143  const float *BK = NULL;
3144  const float *BV = NULL;
3145  const float *BO = NULL;
3146  const float *B1 = NULL;
3147  const float *B2 = NULL;
3148 
3151 
3152  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3153  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3154  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3155  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3156  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3157  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3158 
3159  /* RMSNorm before attention */
3160  rmsnorm_forward(input,
3161  ln1_gamma,
3162  ln1_out,
3163  NULL,
3164  num_tokens,
3166  aligned_embed_dim,
3167  1e-06f);
3168 
3169  /* Q projection (head-major) */
3170  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3171  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3172  for (int h = 0; h < H; ++h) {
3173  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3174  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3175  float *q_h = q + (size_t)h * q_head_stride;
3176  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3177  }
3178 
3179  /* K projection (head-major) */
3180  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3181  const uint8_t *WK_bytes = (const uint8_t *)WK;
3182  for (int h = 0; h < H_kv; ++h) {
3183  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3184  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3185  float *k_h = k + (size_t)h * kv_head_stride;
3186  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3187  }
3188 
3189  /* V projection (head-major) */
3190  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3191  const uint8_t *WV_bytes = (const uint8_t *)WV;
3192  for (int h = 0; h < H_kv; ++h) {
3193  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3194  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3195  float *v_h = v + (size_t)h * kv_head_stride;
3196  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3197  }
3198 
3199  /* RoPE */
3200  rope_forward_qk(q,
3201  k,
3202  rope_cos,
3203  rope_sin,
3204  H,
3205  H_kv,
3206  num_tokens,
3207  head_dim,
3208  aligned_head_dim,
3209  0);
3210 
3211  /* Attention (prefill, causal) */
3213  k,
3214  v,
3215  attn_out,
3216  H,
3217  H_kv,
3218  num_tokens,
3219  head_dim,
3220  aligned_head_dim);
3221 
3222  /* Output projection (flatten head-major to token-major) */
3223  const int K = H * aligned_head_dim;
3224  if (K != aligned_embed_dim) {
3225  return;
3226  }
3227  const float *proj_in = attn_out;
3228  if (H > 1) {
3229  if (!proj_scratch) {
3230  return;
3231  }
3232  for (int t = 0; t < num_tokens; ++t) {
3233  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3234  for (int h = 0; h < H; ++h) {
3235  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3236  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3237  src,
3238  (size_t)aligned_head_dim * sizeof(float));
3239  }
3240  }
3241  proj_in = proj_scratch;
3242  }
3243  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3244 
3245  /* Residual add */
3246  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3247 
3248  /* RMSNorm before MLP */
3249  rmsnorm_forward(residual1,
3250  ln2_gamma,
3251  ln2_out,
3252  NULL,
3253  num_tokens,
3255  aligned_embed_dim,
3256  1e-06f);
3257 
3258  /* MLP (SwiGLU) */
3259  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3260  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3261  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3262 
3263  /* Final residual add */
3264  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3265 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_19_decode()

static void qwen2_0_5b_decode_layer_19_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7248 of file v6.6/test_generated/ck-kernel-inference.c.

7255  {
7257 
7258  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[18].output);
7259 
7260  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7261  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7262  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7263  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7264  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7265  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7266  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7267  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7268  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7269  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7270  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7271 
7272  /* Weights (explicit types for layer 19) */
7273  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7274  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7275  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7276  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7277  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7278  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7279 
7282 
7283  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7284  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7285  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7286 
7287  float q_token[H * aligned_head_dim];
7288  float k_token[H_kv * aligned_head_dim];
7289  float v_token[H_kv * aligned_head_dim];
7290  float attn_token[H * aligned_head_dim];
7291 
7292  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7293  float fc1_out[2 * aligned_intermediate_dim];
7294  float swiglu_out[aligned_intermediate_dim];
7295 
7296  /* Step 1: RMSNorm before attention */
7297  rmsnorm_forward(input,
7298  ln1_gamma,
7299  ln1_out,
7300  NULL,
7301  1,
7303  aligned_embed_dim,
7304  1e-06f);
7305 
7306  /* Step 2: QKV projection */
7307  /* Q projection: Q4_K -> gemm_nt_q4_k */
7308  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7309 
7310  /* K projection: Q4_K -> gemm_nt_q4_k */
7311  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7312 
7313  /* V projection: Q4_K -> gemm_nt_q4_k */
7314  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7315 
7316  /* Step 3: RoPE */
7317  rope_forward_qk(q_token,
7318  k_token,
7319  rope_cos,
7320  rope_sin,
7321  H,
7322  H_kv,
7323  1,
7324  head_dim,
7325  aligned_head_dim,
7326  token_index);
7327 
7328  /* Step 4: KV cache write */
7329  kv_cache_write_head_major(k_token,
7330  v_token,
7331  k_cache,
7332  v_cache,
7333  H_kv,
7334  token_index,
7335  aligned_context_window,
7336  head_dim,
7337  aligned_head_dim);
7338 
7339  /* Step 5: Attention (decode) */
7341  k_cache,
7342  v_cache,
7343  attn_token,
7344  H,
7345  H_kv,
7346  token_index + 1,
7347  aligned_context_window,
7348  head_dim,
7349  aligned_head_dim);
7350 
7351  /* Step 6: Output projection */
7352  /* WO projection: Q4_K -> gemm_nt_q4_k */
7353  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7354 
7355  /* Step 7: Residual add */
7356  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7357 
7358  /* Step 8: RMSNorm before MLP */
7359  rmsnorm_forward(residual1,
7360  ln2_gamma,
7361  ln2_out,
7362  NULL,
7363  1,
7365  aligned_embed_dim,
7366  1e-06f);
7367 
7368  /* Step 9: MLP (SwiGLU) */
7369  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7370  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7371 
7372  /* SwiGLU activation */
7373  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7374 
7375  /* Down projection: Q4_K -> gemm_nt_q4_k */
7376  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7377 
7378  /* Step 10: Final residual add */
7379  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7380 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_19_prefill()

static void qwen2_0_5b_decode_layer_19_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3270 of file v6.6/test_generated/ck-kernel-inference.c.

3277  {
3279 
3280  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[18].output);
3281  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3282  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3283  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3284  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3285  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3286  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3287  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3288  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3289  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3290  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3291  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3292  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3293  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3294  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3295  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3296 
3297  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3298  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3299  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3300  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3301  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3302  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3303  const float *BQ = NULL;
3304  const float *BK = NULL;
3305  const float *BV = NULL;
3306  const float *BO = NULL;
3307  const float *B1 = NULL;
3308  const float *B2 = NULL;
3309 
3312 
3313  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3314  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3315  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3316  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3317  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3318  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3319 
3320  /* RMSNorm before attention */
3321  rmsnorm_forward(input,
3322  ln1_gamma,
3323  ln1_out,
3324  NULL,
3325  num_tokens,
3327  aligned_embed_dim,
3328  1e-06f);
3329 
3330  /* Q projection (head-major) */
3331  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3332  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3333  for (int h = 0; h < H; ++h) {
3334  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3335  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3336  float *q_h = q + (size_t)h * q_head_stride;
3337  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3338  }
3339 
3340  /* K projection (head-major) */
3341  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3342  const uint8_t *WK_bytes = (const uint8_t *)WK;
3343  for (int h = 0; h < H_kv; ++h) {
3344  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3345  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3346  float *k_h = k + (size_t)h * kv_head_stride;
3347  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3348  }
3349 
3350  /* V projection (head-major) */
3351  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3352  const uint8_t *WV_bytes = (const uint8_t *)WV;
3353  for (int h = 0; h < H_kv; ++h) {
3354  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3355  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3356  float *v_h = v + (size_t)h * kv_head_stride;
3357  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3358  }
3359 
3360  /* RoPE */
3361  rope_forward_qk(q,
3362  k,
3363  rope_cos,
3364  rope_sin,
3365  H,
3366  H_kv,
3367  num_tokens,
3368  head_dim,
3369  aligned_head_dim,
3370  0);
3371 
3372  /* Attention (prefill, causal) */
3374  k,
3375  v,
3376  attn_out,
3377  H,
3378  H_kv,
3379  num_tokens,
3380  head_dim,
3381  aligned_head_dim);
3382 
3383  /* Output projection (flatten head-major to token-major) */
3384  const int K = H * aligned_head_dim;
3385  if (K != aligned_embed_dim) {
3386  return;
3387  }
3388  const float *proj_in = attn_out;
3389  if (H > 1) {
3390  if (!proj_scratch) {
3391  return;
3392  }
3393  for (int t = 0; t < num_tokens; ++t) {
3394  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3395  for (int h = 0; h < H; ++h) {
3396  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3397  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3398  src,
3399  (size_t)aligned_head_dim * sizeof(float));
3400  }
3401  }
3402  proj_in = proj_scratch;
3403  }
3404  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3405 
3406  /* Residual add */
3407  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3408 
3409  /* RMSNorm before MLP */
3410  rmsnorm_forward(residual1,
3411  ln2_gamma,
3412  ln2_out,
3413  NULL,
3414  num_tokens,
3416  aligned_embed_dim,
3417  1e-06f);
3418 
3419  /* MLP (SwiGLU) */
3420  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3421  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3422  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3423 
3424  /* Final residual add */
3425  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3426 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_1_decode()

static void qwen2_0_5b_decode_layer_1_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 4782 of file v6.6/test_generated/ck-kernel-inference.c.

4789  {
4791 
4792  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[0].output);
4793 
4794  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
4795  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
4796  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
4797  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
4798  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
4799  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
4800  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4801  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4802  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4803  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4804  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4805 
4806  /* Weights (explicit types for layer 1) */
4807  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
4808  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
4809  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
4810  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
4811  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
4812  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
4813 
4816 
4817  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4818  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4819  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4820 
4821  float q_token[H * aligned_head_dim];
4822  float k_token[H_kv * aligned_head_dim];
4823  float v_token[H_kv * aligned_head_dim];
4824  float attn_token[H * aligned_head_dim];
4825 
4826  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
4827  float fc1_out[2 * aligned_intermediate_dim];
4828  float swiglu_out[aligned_intermediate_dim];
4829 
4830  /* Step 1: RMSNorm before attention */
4831  rmsnorm_forward(input,
4832  ln1_gamma,
4833  ln1_out,
4834  NULL,
4835  1,
4837  aligned_embed_dim,
4838  1e-06f);
4839 
4840  /* Step 2: QKV projection */
4841  /* Q projection: Q4_K -> gemm_nt_q4_k */
4842  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4843 
4844  /* K projection: Q4_K -> gemm_nt_q4_k */
4845  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
4846 
4847  /* V projection: Q4_K -> gemm_nt_q4_k */
4848  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
4849 
4850  /* Step 3: RoPE */
4851  rope_forward_qk(q_token,
4852  k_token,
4853  rope_cos,
4854  rope_sin,
4855  H,
4856  H_kv,
4857  1,
4858  head_dim,
4859  aligned_head_dim,
4860  token_index);
4861 
4862  /* Step 4: KV cache write */
4863  kv_cache_write_head_major(k_token,
4864  v_token,
4865  k_cache,
4866  v_cache,
4867  H_kv,
4868  token_index,
4869  aligned_context_window,
4870  head_dim,
4871  aligned_head_dim);
4872 
4873  /* Step 5: Attention (decode) */
4875  k_cache,
4876  v_cache,
4877  attn_token,
4878  H,
4879  H_kv,
4880  token_index + 1,
4881  aligned_context_window,
4882  head_dim,
4883  aligned_head_dim);
4884 
4885  /* Step 6: Output projection */
4886  /* WO projection: Q4_K -> gemm_nt_q4_k */
4887  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4888 
4889  /* Step 7: Residual add */
4890  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
4891 
4892  /* Step 8: RMSNorm before MLP */
4893  rmsnorm_forward(residual1,
4894  ln2_gamma,
4895  ln2_out,
4896  NULL,
4897  1,
4899  aligned_embed_dim,
4900  1e-06f);
4901 
4902  /* Step 9: MLP (SwiGLU) */
4903  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
4904  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4905 
4906  /* SwiGLU activation */
4907  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4908 
4909  /* Down projection: Q4_K -> gemm_nt_q4_k */
4910  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4911 
4912  /* Step 10: Final residual add */
4913  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
4914 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_1_prefill()

static void qwen2_0_5b_decode_layer_1_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 372 of file v6.6/test_generated/ck-kernel-inference.c.

379  {
381 
382  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[0].output);
383  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
384  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
385  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
386  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
387  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
388  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
389  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
390  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
391  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
392  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
393  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
394  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
395  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
396  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
397  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
398 
399  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
400  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
401  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
402  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
403  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
404  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
405  const float *BQ = NULL;
406  const float *BK = NULL;
407  const float *BV = NULL;
408  const float *BO = NULL;
409  const float *B1 = NULL;
410  const float *B2 = NULL;
411 
414 
415  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
416  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
417  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
418  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
419  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
420  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
421 
422  /* RMSNorm before attention */
423  rmsnorm_forward(input,
424  ln1_gamma,
425  ln1_out,
426  NULL,
427  num_tokens,
429  aligned_embed_dim,
430  1e-06f);
431 
432  /* Q projection (head-major) */
433  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
434  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
435  for (int h = 0; h < H; ++h) {
436  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
437  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
438  float *q_h = q + (size_t)h * q_head_stride;
439  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
440  }
441 
442  /* K projection (head-major) */
443  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
444  const uint8_t *WK_bytes = (const uint8_t *)WK;
445  for (int h = 0; h < H_kv; ++h) {
446  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
447  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
448  float *k_h = k + (size_t)h * kv_head_stride;
449  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
450  }
451 
452  /* V projection (head-major) */
453  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
454  const uint8_t *WV_bytes = (const uint8_t *)WV;
455  for (int h = 0; h < H_kv; ++h) {
456  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
457  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
458  float *v_h = v + (size_t)h * kv_head_stride;
459  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
460  }
461 
462  /* RoPE */
463  rope_forward_qk(q,
464  k,
465  rope_cos,
466  rope_sin,
467  H,
468  H_kv,
469  num_tokens,
470  head_dim,
471  aligned_head_dim,
472  0);
473 
474  /* Attention (prefill, causal) */
476  k,
477  v,
478  attn_out,
479  H,
480  H_kv,
481  num_tokens,
482  head_dim,
483  aligned_head_dim);
484 
485  /* Output projection (flatten head-major to token-major) */
486  const int K = H * aligned_head_dim;
487  if (K != aligned_embed_dim) {
488  return;
489  }
490  const float *proj_in = attn_out;
491  if (H > 1) {
492  if (!proj_scratch) {
493  return;
494  }
495  for (int t = 0; t < num_tokens; ++t) {
496  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
497  for (int h = 0; h < H; ++h) {
498  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
499  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
500  src,
501  (size_t)aligned_head_dim * sizeof(float));
502  }
503  }
504  proj_in = proj_scratch;
505  }
506  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
507 
508  /* Residual add */
509  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
510 
511  /* RMSNorm before MLP */
512  rmsnorm_forward(residual1,
513  ln2_gamma,
514  ln2_out,
515  NULL,
516  num_tokens,
518  aligned_embed_dim,
519  1e-06f);
520 
521  /* MLP (SwiGLU) */
522  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
523  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
524  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
525 
526  /* Final residual add */
527  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
528 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_20_decode()

static void qwen2_0_5b_decode_layer_20_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7385 of file v6.6/test_generated/ck-kernel-inference.c.

7392  {
7393  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[20];
7394 
7395  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[19].output);
7396 
7397  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7398  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7399  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7400  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7401  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7402  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7403  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7404  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7405  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7406  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7407  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7408 
7409  /* Weights (explicit types for layer 20) */
7410  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7411  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7412  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7413  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7414  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7415  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7416 
7417  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7418  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7419 
7420  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7421  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7422  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7423 
7424  float q_token[H * aligned_head_dim];
7425  float k_token[H_kv * aligned_head_dim];
7426  float v_token[H_kv * aligned_head_dim];
7427  float attn_token[H * aligned_head_dim];
7428 
7429  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7430  float fc1_out[2 * aligned_intermediate_dim];
7431  float swiglu_out[aligned_intermediate_dim];
7432 
7433  /* Step 1: RMSNorm before attention */
7434  rmsnorm_forward(input,
7435  ln1_gamma,
7436  ln1_out,
7437  NULL,
7438  1,
7439  QWEN2_0_5B_DECODE_EMBED_DIM,
7440  aligned_embed_dim,
7441  1e-06f);
7442 
7443  /* Step 2: QKV projection */
7444  /* Q projection: Q4_K -> gemm_nt_q4_k */
7445  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7446 
7447  /* K projection: Q4_K -> gemm_nt_q4_k */
7448  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7449 
7450  /* V projection: Q4_K -> gemm_nt_q4_k */
7451  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7452 
7453  /* Step 3: RoPE */
7454  rope_forward_qk(q_token,
7455  k_token,
7456  rope_cos,
7457  rope_sin,
7458  H,
7459  H_kv,
7460  1,
7461  head_dim,
7462  aligned_head_dim,
7463  token_index);
7464 
7465  /* Step 4: KV cache write */
7466  kv_cache_write_head_major(k_token,
7467  v_token,
7468  k_cache,
7469  v_cache,
7470  H_kv,
7471  token_index,
7472  aligned_context_window,
7473  head_dim,
7474  aligned_head_dim);
7475 
7476  /* Step 5: Attention (decode) */
7477  attention_forward_decode_head_major_gqa_regular(q_token,
7478  k_cache,
7479  v_cache,
7480  attn_token,
7481  H,
7482  H_kv,
7483  token_index + 1,
7484  aligned_context_window,
7485  head_dim,
7486  aligned_head_dim);
7487 
7488  /* Step 6: Output projection */
7489  /* WO projection: Q4_K -> gemm_nt_q4_k */
7490  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7491 
7492  /* Step 7: Residual add */
7493  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7494 
7495  /* Step 8: RMSNorm before MLP */
7496  rmsnorm_forward(residual1,
7497  ln2_gamma,
7498  ln2_out,
7499  NULL,
7500  1,
7501  QWEN2_0_5B_DECODE_EMBED_DIM,
7502  aligned_embed_dim,
7503  1e-06f);
7504 
7505  /* Step 9: MLP (SwiGLU) */
7506  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7507  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7508 
7509  /* SwiGLU activation */
7510  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7511 
7512  /* Down projection: Q4_K -> gemm_nt_q4_k */
7513  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7514 
7515  /* Step 10: Final residual add */
7516  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7517 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_20_prefill()

static void qwen2_0_5b_decode_layer_20_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3431 of file v6.6/test_generated/ck-kernel-inference.c.

3438  {
3439  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[20];
3440 
3441  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[19].output);
3442  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3443  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3444  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3445  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3446  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3447  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3448  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3449  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3450  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3451  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3452  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3453  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3454  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3455  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3456  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3457 
3458  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3459  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3460  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3461  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3462  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3463  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3464  const float *BQ = NULL;
3465  const float *BK = NULL;
3466  const float *BV = NULL;
3467  const float *BO = NULL;
3468  const float *B1 = NULL;
3469  const float *B2 = NULL;
3470 
3471  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
3472  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
3473 
3474  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3475  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3476  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3477  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3478  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3479  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3480 
3481  /* RMSNorm before attention */
3482  rmsnorm_forward(input,
3483  ln1_gamma,
3484  ln1_out,
3485  NULL,
3486  num_tokens,
3487  QWEN2_0_5B_DECODE_EMBED_DIM,
3488  aligned_embed_dim,
3489  1e-06f);
3490 
3491  /* Q projection (head-major) */
3492  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3493  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3494  for (int h = 0; h < H; ++h) {
3495  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3496  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3497  float *q_h = q + (size_t)h * q_head_stride;
3498  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3499  }
3500 
3501  /* K projection (head-major) */
3502  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3503  const uint8_t *WK_bytes = (const uint8_t *)WK;
3504  for (int h = 0; h < H_kv; ++h) {
3505  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3506  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3507  float *k_h = k + (size_t)h * kv_head_stride;
3508  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3509  }
3510 
3511  /* V projection (head-major) */
3512  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3513  const uint8_t *WV_bytes = (const uint8_t *)WV;
3514  for (int h = 0; h < H_kv; ++h) {
3515  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3516  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3517  float *v_h = v + (size_t)h * kv_head_stride;
3518  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3519  }
3520 
3521  /* RoPE */
3522  rope_forward_qk(q,
3523  k,
3524  rope_cos,
3525  rope_sin,
3526  H,
3527  H_kv,
3528  num_tokens,
3529  head_dim,
3530  aligned_head_dim,
3531  0);
3532 
3533  /* Attention (prefill, causal) */
3534  attention_forward_causal_head_major_gqa_flash(q,
3535  k,
3536  v,
3537  attn_out,
3538  H,
3539  H_kv,
3540  num_tokens,
3541  head_dim,
3542  aligned_head_dim);
3543 
3544  /* Output projection (flatten head-major to token-major) */
3545  const int K = H * aligned_head_dim;
3546  if (K != aligned_embed_dim) {
3547  return;
3548  }
3549  const float *proj_in = attn_out;
3550  if (H > 1) {
3551  if (!proj_scratch) {
3552  return;
3553  }
3554  for (int t = 0; t < num_tokens; ++t) {
3555  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3556  for (int h = 0; h < H; ++h) {
3557  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3558  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3559  src,
3560  (size_t)aligned_head_dim * sizeof(float));
3561  }
3562  }
3563  proj_in = proj_scratch;
3564  }
3565  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3566 
3567  /* Residual add */
3568  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3569 
3570  /* RMSNorm before MLP */
3571  rmsnorm_forward(residual1,
3572  ln2_gamma,
3573  ln2_out,
3574  NULL,
3575  num_tokens,
3576  QWEN2_0_5B_DECODE_EMBED_DIM,
3577  aligned_embed_dim,
3578  1e-06f);
3579 
3580  /* MLP (SwiGLU) */
3581  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3582  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3583  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3584 
3585  /* Final residual add */
3586  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3587 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_21_decode()

static void qwen2_0_5b_decode_layer_21_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7522 of file v6.6/test_generated/ck-kernel-inference.c.

7529  {
7530  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[21];
7531 
7532  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[20].output);
7533 
7534  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7535  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7536  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7537  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7538  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7539  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7540  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7541  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7542  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7543  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7544  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7545 
7546  /* Weights (explicit types for layer 21) */
7547  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7548  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7549  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7550  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7551  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7552  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7553 
7554  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7555  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7556 
7557  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7558  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7559  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7560 
7561  float q_token[H * aligned_head_dim];
7562  float k_token[H_kv * aligned_head_dim];
7563  float v_token[H_kv * aligned_head_dim];
7564  float attn_token[H * aligned_head_dim];
7565 
7566  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7567  float fc1_out[2 * aligned_intermediate_dim];
7568  float swiglu_out[aligned_intermediate_dim];
7569 
7570  /* Step 1: RMSNorm before attention */
7571  rmsnorm_forward(input,
7572  ln1_gamma,
7573  ln1_out,
7574  NULL,
7575  1,
7576  QWEN2_0_5B_DECODE_EMBED_DIM,
7577  aligned_embed_dim,
7578  1e-06f);
7579 
7580  /* Step 2: QKV projection */
7581  /* Q projection: Q4_K -> gemm_nt_q4_k */
7582  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7583 
7584  /* K projection: Q4_K -> gemm_nt_q4_k */
7585  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7586 
7587  /* V projection: Q4_K -> gemm_nt_q4_k */
7588  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7589 
7590  /* Step 3: RoPE */
7591  rope_forward_qk(q_token,
7592  k_token,
7593  rope_cos,
7594  rope_sin,
7595  H,
7596  H_kv,
7597  1,
7598  head_dim,
7599  aligned_head_dim,
7600  token_index);
7601 
7602  /* Step 4: KV cache write */
7603  kv_cache_write_head_major(k_token,
7604  v_token,
7605  k_cache,
7606  v_cache,
7607  H_kv,
7608  token_index,
7609  aligned_context_window,
7610  head_dim,
7611  aligned_head_dim);
7612 
7613  /* Step 5: Attention (decode) */
7614  attention_forward_decode_head_major_gqa_regular(q_token,
7615  k_cache,
7616  v_cache,
7617  attn_token,
7618  H,
7619  H_kv,
7620  token_index + 1,
7621  aligned_context_window,
7622  head_dim,
7623  aligned_head_dim);
7624 
7625  /* Step 6: Output projection */
7626  /* WO projection: Q4_K -> gemm_nt_q4_k */
7627  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7628 
7629  /* Step 7: Residual add */
7630  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7631 
7632  /* Step 8: RMSNorm before MLP */
7633  rmsnorm_forward(residual1,
7634  ln2_gamma,
7635  ln2_out,
7636  NULL,
7637  1,
7638  QWEN2_0_5B_DECODE_EMBED_DIM,
7639  aligned_embed_dim,
7640  1e-06f);
7641 
7642  /* Step 9: MLP (SwiGLU) */
7643  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7644  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7645 
7646  /* SwiGLU activation */
7647  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7648 
7649  /* Down projection: Q4_K -> gemm_nt_q4_k */
7650  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7651 
7652  /* Step 10: Final residual add */
7653  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7654 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_21_prefill()

static void qwen2_0_5b_decode_layer_21_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3592 of file v6.6/test_generated/ck-kernel-inference.c.

3599  {
3600  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[21];
3601 
3602  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[20].output);
3603  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3604  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3605  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3606  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3607  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3608  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3609  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3610  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3611  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3612  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3613  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3614  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3615  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3616  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3617  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3618 
3619  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3620  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3621  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3622  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3623  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3624  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3625  const float *BQ = NULL;
3626  const float *BK = NULL;
3627  const float *BV = NULL;
3628  const float *BO = NULL;
3629  const float *B1 = NULL;
3630  const float *B2 = NULL;
3631 
3632  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
3633  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
3634 
3635  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3636  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3637  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3638  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3639  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3640  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3641 
3642  /* RMSNorm before attention */
3643  rmsnorm_forward(input,
3644  ln1_gamma,
3645  ln1_out,
3646  NULL,
3647  num_tokens,
3648  QWEN2_0_5B_DECODE_EMBED_DIM,
3649  aligned_embed_dim,
3650  1e-06f);
3651 
3652  /* Q projection (head-major) */
3653  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3654  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3655  for (int h = 0; h < H; ++h) {
3656  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3657  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3658  float *q_h = q + (size_t)h * q_head_stride;
3659  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3660  }
3661 
3662  /* K projection (head-major) */
3663  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3664  const uint8_t *WK_bytes = (const uint8_t *)WK;
3665  for (int h = 0; h < H_kv; ++h) {
3666  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3667  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3668  float *k_h = k + (size_t)h * kv_head_stride;
3669  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3670  }
3671 
3672  /* V projection (head-major) */
3673  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3674  const uint8_t *WV_bytes = (const uint8_t *)WV;
3675  for (int h = 0; h < H_kv; ++h) {
3676  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3677  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3678  float *v_h = v + (size_t)h * kv_head_stride;
3679  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3680  }
3681 
3682  /* RoPE */
3683  rope_forward_qk(q,
3684  k,
3685  rope_cos,
3686  rope_sin,
3687  H,
3688  H_kv,
3689  num_tokens,
3690  head_dim,
3691  aligned_head_dim,
3692  0);
3693 
3694  /* Attention (prefill, causal) */
3695  attention_forward_causal_head_major_gqa_flash(q,
3696  k,
3697  v,
3698  attn_out,
3699  H,
3700  H_kv,
3701  num_tokens,
3702  head_dim,
3703  aligned_head_dim);
3704 
3705  /* Output projection (flatten head-major to token-major) */
3706  const int K = H * aligned_head_dim;
3707  if (K != aligned_embed_dim) {
3708  return;
3709  }
3710  const float *proj_in = attn_out;
3711  if (H > 1) {
3712  if (!proj_scratch) {
3713  return;
3714  }
3715  for (int t = 0; t < num_tokens; ++t) {
3716  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3717  for (int h = 0; h < H; ++h) {
3718  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3719  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3720  src,
3721  (size_t)aligned_head_dim * sizeof(float));
3722  }
3723  }
3724  proj_in = proj_scratch;
3725  }
3726  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3727 
3728  /* Residual add */
3729  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3730 
3731  /* RMSNorm before MLP */
3732  rmsnorm_forward(residual1,
3733  ln2_gamma,
3734  ln2_out,
3735  NULL,
3736  num_tokens,
3737  QWEN2_0_5B_DECODE_EMBED_DIM,
3738  aligned_embed_dim,
3739  1e-06f);
3740 
3741  /* MLP (SwiGLU) */
3742  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3743  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3744  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3745 
3746  /* Final residual add */
3747  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3748 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_22_decode()

static void qwen2_0_5b_decode_layer_22_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7659 of file v6.6/test_generated/ck-kernel-inference.c.

7666  {
7667  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[22];
7668 
7669  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[21].output);
7670 
7671  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7672  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7673  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7674  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7675  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7676  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7677  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7678  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7679  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7680  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7681  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7682 
7683  /* Weights (explicit types for layer 22) */
7684  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7685  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7686  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7687  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7688  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7689  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7690 
7691  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7692  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7693 
7694  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7695  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7696  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7697 
7698  float q_token[H * aligned_head_dim];
7699  float k_token[H_kv * aligned_head_dim];
7700  float v_token[H_kv * aligned_head_dim];
7701  float attn_token[H * aligned_head_dim];
7702 
7703  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7704  float fc1_out[2 * aligned_intermediate_dim];
7705  float swiglu_out[aligned_intermediate_dim];
7706 
7707  /* Step 1: RMSNorm before attention */
7708  rmsnorm_forward(input,
7709  ln1_gamma,
7710  ln1_out,
7711  NULL,
7712  1,
7713  QWEN2_0_5B_DECODE_EMBED_DIM,
7714  aligned_embed_dim,
7715  1e-06f);
7716 
7717  /* Step 2: QKV projection */
7718  /* Q projection: Q4_K -> gemm_nt_q4_k */
7719  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7720 
7721  /* K projection: Q4_K -> gemm_nt_q4_k */
7722  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7723 
7724  /* V projection: Q4_K -> gemm_nt_q4_k */
7725  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7726 
7727  /* Step 3: RoPE */
7728  rope_forward_qk(q_token,
7729  k_token,
7730  rope_cos,
7731  rope_sin,
7732  H,
7733  H_kv,
7734  1,
7735  head_dim,
7736  aligned_head_dim,
7737  token_index);
7738 
7739  /* Step 4: KV cache write */
7740  kv_cache_write_head_major(k_token,
7741  v_token,
7742  k_cache,
7743  v_cache,
7744  H_kv,
7745  token_index,
7746  aligned_context_window,
7747  head_dim,
7748  aligned_head_dim);
7749 
7750  /* Step 5: Attention (decode) */
7751  attention_forward_decode_head_major_gqa_regular(q_token,
7752  k_cache,
7753  v_cache,
7754  attn_token,
7755  H,
7756  H_kv,
7757  token_index + 1,
7758  aligned_context_window,
7759  head_dim,
7760  aligned_head_dim);
7761 
7762  /* Step 6: Output projection */
7763  /* WO projection: Q4_K -> gemm_nt_q4_k */
7764  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7765 
7766  /* Step 7: Residual add */
7767  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7768 
7769  /* Step 8: RMSNorm before MLP */
7770  rmsnorm_forward(residual1,
7771  ln2_gamma,
7772  ln2_out,
7773  NULL,
7774  1,
7775  QWEN2_0_5B_DECODE_EMBED_DIM,
7776  aligned_embed_dim,
7777  1e-06f);
7778 
7779  /* Step 9: MLP (SwiGLU) */
7780  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7781  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7782 
7783  /* SwiGLU activation */
7784  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7785 
7786  /* Down projection: Q4_K -> gemm_nt_q4_k */
7787  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7788 
7789  /* Step 10: Final residual add */
7790  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7791 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_22_prefill()

static void qwen2_0_5b_decode_layer_22_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3753 of file v6.6/test_generated/ck-kernel-inference.c.

3760  {
3762 
3763  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[21].output);
3764  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3765  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3766  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3767  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3768  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3769  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3770  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3771  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3772  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3773  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3774  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3775  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3776  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3777  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3778  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3779 
3780  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3781  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3782  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3783  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3784  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3785  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3786  const float *BQ = NULL;
3787  const float *BK = NULL;
3788  const float *BV = NULL;
3789  const float *BO = NULL;
3790  const float *B1 = NULL;
3791  const float *B2 = NULL;
3792 
3795 
3796  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3797  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3798  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3799  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3800  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3801  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3802 
3803  /* RMSNorm before attention */
3804  rmsnorm_forward(input,
3805  ln1_gamma,
3806  ln1_out,
3807  NULL,
3808  num_tokens,
3810  aligned_embed_dim,
3811  1e-06f);
3812 
3813  /* Q projection (head-major) */
3814  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3815  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3816  for (int h = 0; h < H; ++h) {
3817  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3818  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3819  float *q_h = q + (size_t)h * q_head_stride;
3820  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3821  }
3822 
3823  /* K projection (head-major) */
3824  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3825  const uint8_t *WK_bytes = (const uint8_t *)WK;
3826  for (int h = 0; h < H_kv; ++h) {
3827  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3828  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3829  float *k_h = k + (size_t)h * kv_head_stride;
3830  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3831  }
3832 
3833  /* V projection (head-major) */
3834  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3835  const uint8_t *WV_bytes = (const uint8_t *)WV;
3836  for (int h = 0; h < H_kv; ++h) {
3837  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3838  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3839  float *v_h = v + (size_t)h * kv_head_stride;
3840  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3841  }
3842 
3843  /* RoPE */
3844  rope_forward_qk(q,
3845  k,
3846  rope_cos,
3847  rope_sin,
3848  H,
3849  H_kv,
3850  num_tokens,
3851  head_dim,
3852  aligned_head_dim,
3853  0);
3854 
3855  /* Attention (prefill, causal) */
3857  k,
3858  v,
3859  attn_out,
3860  H,
3861  H_kv,
3862  num_tokens,
3863  head_dim,
3864  aligned_head_dim);
3865 
3866  /* Output projection (flatten head-major to token-major) */
3867  const int K = H * aligned_head_dim;
3868  if (K != aligned_embed_dim) {
3869  return;
3870  }
3871  const float *proj_in = attn_out;
3872  if (H > 1) {
3873  if (!proj_scratch) {
3874  return;
3875  }
3876  for (int t = 0; t < num_tokens; ++t) {
3877  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3878  for (int h = 0; h < H; ++h) {
3879  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3880  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3881  src,
3882  (size_t)aligned_head_dim * sizeof(float));
3883  }
3884  }
3885  proj_in = proj_scratch;
3886  }
3887  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3888 
3889  /* Residual add */
3890  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3891 
3892  /* RMSNorm before MLP */
3893  rmsnorm_forward(residual1,
3894  ln2_gamma,
3895  ln2_out,
3896  NULL,
3897  num_tokens,
3899  aligned_embed_dim,
3900  1e-06f);
3901 
3902  /* MLP (SwiGLU) */
3903  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3904  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3905  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3906 
3907  /* Final residual add */
3908  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3909 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_23_decode()

static void qwen2_0_5b_decode_layer_23_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7796 of file v6.6/test_generated/ck-kernel-inference.c.

7803  {
7805 
7806  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[22].output);
7807 
7808  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7809  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7810  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7811  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7812  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7813  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7814  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7815  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7816  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7817  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7818  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7819 
7820  /* Weights (explicit types for layer 23) */
7821  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7822  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7823  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7824  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7825  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7826  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7827 
7830 
7831  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7832  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7833  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7834 
7835  float q_token[H * aligned_head_dim];
7836  float k_token[H_kv * aligned_head_dim];
7837  float v_token[H_kv * aligned_head_dim];
7838  float attn_token[H * aligned_head_dim];
7839 
7840  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7841  float fc1_out[2 * aligned_intermediate_dim];
7842  float swiglu_out[aligned_intermediate_dim];
7843 
7844  /* Step 1: RMSNorm before attention */
7845  rmsnorm_forward(input,
7846  ln1_gamma,
7847  ln1_out,
7848  NULL,
7849  1,
7851  aligned_embed_dim,
7852  1e-06f);
7853 
7854  /* Step 2: QKV projection */
7855  /* Q projection: Q4_K -> gemm_nt_q4_k */
7856  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7857 
7858  /* K projection: Q4_K -> gemm_nt_q4_k */
7859  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7860 
7861  /* V projection: Q4_K -> gemm_nt_q4_k */
7862  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7863 
7864  /* Step 3: RoPE */
7865  rope_forward_qk(q_token,
7866  k_token,
7867  rope_cos,
7868  rope_sin,
7869  H,
7870  H_kv,
7871  1,
7872  head_dim,
7873  aligned_head_dim,
7874  token_index);
7875 
7876  /* Step 4: KV cache write */
7877  kv_cache_write_head_major(k_token,
7878  v_token,
7879  k_cache,
7880  v_cache,
7881  H_kv,
7882  token_index,
7883  aligned_context_window,
7884  head_dim,
7885  aligned_head_dim);
7886 
7887  /* Step 5: Attention (decode) */
7889  k_cache,
7890  v_cache,
7891  attn_token,
7892  H,
7893  H_kv,
7894  token_index + 1,
7895  aligned_context_window,
7896  head_dim,
7897  aligned_head_dim);
7898 
7899  /* Step 6: Output projection */
7900  /* WO projection: Q4_K -> gemm_nt_q4_k */
7901  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7902 
7903  /* Step 7: Residual add */
7904  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7905 
7906  /* Step 8: RMSNorm before MLP */
7907  rmsnorm_forward(residual1,
7908  ln2_gamma,
7909  ln2_out,
7910  NULL,
7911  1,
7913  aligned_embed_dim,
7914  1e-06f);
7915 
7916  /* Step 9: MLP (SwiGLU) */
7917  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
7918  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7919 
7920  /* SwiGLU activation */
7921  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7922 
7923  /* Down projection: Q4_K -> gemm_nt_q4_k */
7924  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7925 
7926  /* Step 10: Final residual add */
7927  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7928 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_23_prefill()

static void qwen2_0_5b_decode_layer_23_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3914 of file v6.6/test_generated/ck-kernel-inference.c.

3921  {
3923 
3924  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[22].output);
3925  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3926  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3927  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3928  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3929  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3930  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3931  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3932  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3933  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3934  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3935  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3936  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3937  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3938  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3939  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3940 
3941  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3942  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3943  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3944  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3945  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3946  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3947  const float *BQ = NULL;
3948  const float *BK = NULL;
3949  const float *BV = NULL;
3950  const float *BO = NULL;
3951  const float *B1 = NULL;
3952  const float *B2 = NULL;
3953 
3956 
3957  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3958  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3959  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3960  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3961  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3962  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3963 
3964  /* RMSNorm before attention */
3965  rmsnorm_forward(input,
3966  ln1_gamma,
3967  ln1_out,
3968  NULL,
3969  num_tokens,
3971  aligned_embed_dim,
3972  1e-06f);
3973 
3974  /* Q projection (head-major) */
3975  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3976  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3977  for (int h = 0; h < H; ++h) {
3978  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3979  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3980  float *q_h = q + (size_t)h * q_head_stride;
3981  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3982  }
3983 
3984  /* K projection (head-major) */
3985  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3986  const uint8_t *WK_bytes = (const uint8_t *)WK;
3987  for (int h = 0; h < H_kv; ++h) {
3988  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3989  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3990  float *k_h = k + (size_t)h * kv_head_stride;
3991  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3992  }
3993 
3994  /* V projection (head-major) */
3995  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3996  const uint8_t *WV_bytes = (const uint8_t *)WV;
3997  for (int h = 0; h < H_kv; ++h) {
3998  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3999  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
4000  float *v_h = v + (size_t)h * kv_head_stride;
4001  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4002  }
4003 
4004  /* RoPE */
4005  rope_forward_qk(q,
4006  k,
4007  rope_cos,
4008  rope_sin,
4009  H,
4010  H_kv,
4011  num_tokens,
4012  head_dim,
4013  aligned_head_dim,
4014  0);
4015 
4016  /* Attention (prefill, causal) */
4018  k,
4019  v,
4020  attn_out,
4021  H,
4022  H_kv,
4023  num_tokens,
4024  head_dim,
4025  aligned_head_dim);
4026 
4027  /* Output projection (flatten head-major to token-major) */
4028  const int K = H * aligned_head_dim;
4029  if (K != aligned_embed_dim) {
4030  return;
4031  }
4032  const float *proj_in = attn_out;
4033  if (H > 1) {
4034  if (!proj_scratch) {
4035  return;
4036  }
4037  for (int t = 0; t < num_tokens; ++t) {
4038  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
4039  for (int h = 0; h < H; ++h) {
4040  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
4041  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
4042  src,
4043  (size_t)aligned_head_dim * sizeof(float));
4044  }
4045  }
4046  proj_in = proj_scratch;
4047  }
4048  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
4049 
4050  /* Residual add */
4051  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
4052 
4053  /* RMSNorm before MLP */
4054  rmsnorm_forward(residual1,
4055  ln2_gamma,
4056  ln2_out,
4057  NULL,
4058  num_tokens,
4060  aligned_embed_dim,
4061  1e-06f);
4062 
4063  /* MLP (SwiGLU) */
4064  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
4065  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
4066  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
4067 
4068  /* Final residual add */
4069  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
4070 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_2_decode()

static void qwen2_0_5b_decode_layer_2_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 4919 of file v6.6/test_generated/ck-kernel-inference.c.

4926  {
4928 
4929  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[1].output);
4930 
4931  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
4932  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
4933  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
4934  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
4935  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
4936  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
4937  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4938  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4939  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4940  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4941  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4942 
4943  /* Weights (explicit types for layer 2) */
4944  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
4945  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
4946  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
4947  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
4948  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
4949  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
4950 
4953 
4954  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4955  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4956  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4957 
4958  float q_token[H * aligned_head_dim];
4959  float k_token[H_kv * aligned_head_dim];
4960  float v_token[H_kv * aligned_head_dim];
4961  float attn_token[H * aligned_head_dim];
4962 
4963  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
4964  float fc1_out[2 * aligned_intermediate_dim];
4965  float swiglu_out[aligned_intermediate_dim];
4966 
4967  /* Step 1: RMSNorm before attention */
4968  rmsnorm_forward(input,
4969  ln1_gamma,
4970  ln1_out,
4971  NULL,
4972  1,
4974  aligned_embed_dim,
4975  1e-06f);
4976 
4977  /* Step 2: QKV projection */
4978  /* Q projection: Q4_K -> gemm_nt_q4_k */
4979  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4980 
4981  /* K projection: Q4_K -> gemm_nt_q4_k */
4982  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
4983 
4984  /* V projection: Q4_K -> gemm_nt_q4_k */
4985  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
4986 
4987  /* Step 3: RoPE */
4988  rope_forward_qk(q_token,
4989  k_token,
4990  rope_cos,
4991  rope_sin,
4992  H,
4993  H_kv,
4994  1,
4995  head_dim,
4996  aligned_head_dim,
4997  token_index);
4998 
4999  /* Step 4: KV cache write */
5000  kv_cache_write_head_major(k_token,
5001  v_token,
5002  k_cache,
5003  v_cache,
5004  H_kv,
5005  token_index,
5006  aligned_context_window,
5007  head_dim,
5008  aligned_head_dim);
5009 
5010  /* Step 5: Attention (decode) */
5012  k_cache,
5013  v_cache,
5014  attn_token,
5015  H,
5016  H_kv,
5017  token_index + 1,
5018  aligned_context_window,
5019  head_dim,
5020  aligned_head_dim);
5021 
5022  /* Step 6: Output projection */
5023  /* WO projection: Q4_K -> gemm_nt_q4_k */
5024  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5025 
5026  /* Step 7: Residual add */
5027  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5028 
5029  /* Step 8: RMSNorm before MLP */
5030  rmsnorm_forward(residual1,
5031  ln2_gamma,
5032  ln2_out,
5033  NULL,
5034  1,
5036  aligned_embed_dim,
5037  1e-06f);
5038 
5039  /* Step 9: MLP (SwiGLU) */
5040  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5041  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5042 
5043  /* SwiGLU activation */
5044  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5045 
5046  /* Down projection: Q4_K -> gemm_nt_q4_k */
5047  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5048 
5049  /* Step 10: Final residual add */
5050  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5051 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_2_prefill()

static void qwen2_0_5b_decode_layer_2_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 533 of file v6.6/test_generated/ck-kernel-inference.c.

540  {
542 
543  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[1].output);
544  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
545  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
546  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
547  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
548  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
549  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
550  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
551  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
552  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
553  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
554  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
555  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
556  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
557  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
558  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
559 
560  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
561  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
562  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
563  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
564  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
565  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
566  const float *BQ = NULL;
567  const float *BK = NULL;
568  const float *BV = NULL;
569  const float *BO = NULL;
570  const float *B1 = NULL;
571  const float *B2 = NULL;
572 
575 
576  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
577  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
578  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
579  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
580  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
581  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
582 
583  /* RMSNorm before attention */
584  rmsnorm_forward(input,
585  ln1_gamma,
586  ln1_out,
587  NULL,
588  num_tokens,
590  aligned_embed_dim,
591  1e-06f);
592 
593  /* Q projection (head-major) */
594  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
595  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
596  for (int h = 0; h < H; ++h) {
597  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
598  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
599  float *q_h = q + (size_t)h * q_head_stride;
600  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
601  }
602 
603  /* K projection (head-major) */
604  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
605  const uint8_t *WK_bytes = (const uint8_t *)WK;
606  for (int h = 0; h < H_kv; ++h) {
607  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
608  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
609  float *k_h = k + (size_t)h * kv_head_stride;
610  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
611  }
612 
613  /* V projection (head-major) */
614  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
615  const uint8_t *WV_bytes = (const uint8_t *)WV;
616  for (int h = 0; h < H_kv; ++h) {
617  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
618  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
619  float *v_h = v + (size_t)h * kv_head_stride;
620  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
621  }
622 
623  /* RoPE */
624  rope_forward_qk(q,
625  k,
626  rope_cos,
627  rope_sin,
628  H,
629  H_kv,
630  num_tokens,
631  head_dim,
632  aligned_head_dim,
633  0);
634 
635  /* Attention (prefill, causal) */
637  k,
638  v,
639  attn_out,
640  H,
641  H_kv,
642  num_tokens,
643  head_dim,
644  aligned_head_dim);
645 
646  /* Output projection (flatten head-major to token-major) */
647  const int K = H * aligned_head_dim;
648  if (K != aligned_embed_dim) {
649  return;
650  }
651  const float *proj_in = attn_out;
652  if (H > 1) {
653  if (!proj_scratch) {
654  return;
655  }
656  for (int t = 0; t < num_tokens; ++t) {
657  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
658  for (int h = 0; h < H; ++h) {
659  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
660  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
661  src,
662  (size_t)aligned_head_dim * sizeof(float));
663  }
664  }
665  proj_in = proj_scratch;
666  }
667  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
668 
669  /* Residual add */
670  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
671 
672  /* RMSNorm before MLP */
673  rmsnorm_forward(residual1,
674  ln2_gamma,
675  ln2_out,
676  NULL,
677  num_tokens,
679  aligned_embed_dim,
680  1e-06f);
681 
682  /* MLP (SwiGLU) */
683  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
684  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
685  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
686 
687  /* Final residual add */
688  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
689 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_3_decode()

static void qwen2_0_5b_decode_layer_3_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5056 of file v6.6/test_generated/ck-kernel-inference.c.

5063  {
5065 
5066  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[2].output);
5067 
5068  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5069  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5070  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5071  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5072  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5073  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5074  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5075  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5076  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5077  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5078  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5079 
5080  /* Weights (explicit types for layer 3) */
5081  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5082  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5083  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5084  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5085  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5086  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5087 
5090 
5091  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5092  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5093  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5094 
5095  float q_token[H * aligned_head_dim];
5096  float k_token[H_kv * aligned_head_dim];
5097  float v_token[H_kv * aligned_head_dim];
5098  float attn_token[H * aligned_head_dim];
5099 
5100  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5101  float fc1_out[2 * aligned_intermediate_dim];
5102  float swiglu_out[aligned_intermediate_dim];
5103 
5104  /* Step 1: RMSNorm before attention */
5105  rmsnorm_forward(input,
5106  ln1_gamma,
5107  ln1_out,
5108  NULL,
5109  1,
5111  aligned_embed_dim,
5112  1e-06f);
5113 
5114  /* Step 2: QKV projection */
5115  /* Q projection: Q4_K -> gemm_nt_q4_k */
5116  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5117 
5118  /* K projection: Q4_K -> gemm_nt_q4_k */
5119  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5120 
5121  /* V projection: Q4_K -> gemm_nt_q4_k */
5122  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5123 
5124  /* Step 3: RoPE */
5125  rope_forward_qk(q_token,
5126  k_token,
5127  rope_cos,
5128  rope_sin,
5129  H,
5130  H_kv,
5131  1,
5132  head_dim,
5133  aligned_head_dim,
5134  token_index);
5135 
5136  /* Step 4: KV cache write */
5137  kv_cache_write_head_major(k_token,
5138  v_token,
5139  k_cache,
5140  v_cache,
5141  H_kv,
5142  token_index,
5143  aligned_context_window,
5144  head_dim,
5145  aligned_head_dim);
5146 
5147  /* Step 5: Attention (decode) */
5149  k_cache,
5150  v_cache,
5151  attn_token,
5152  H,
5153  H_kv,
5154  token_index + 1,
5155  aligned_context_window,
5156  head_dim,
5157  aligned_head_dim);
5158 
5159  /* Step 6: Output projection */
5160  /* WO projection: Q4_K -> gemm_nt_q4_k */
5161  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5162 
5163  /* Step 7: Residual add */
5164  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5165 
5166  /* Step 8: RMSNorm before MLP */
5167  rmsnorm_forward(residual1,
5168  ln2_gamma,
5169  ln2_out,
5170  NULL,
5171  1,
5173  aligned_embed_dim,
5174  1e-06f);
5175 
5176  /* Step 9: MLP (SwiGLU) */
5177  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5178  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5179 
5180  /* SwiGLU activation */
5181  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5182 
5183  /* Down projection: Q4_K -> gemm_nt_q4_k */
5184  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5185 
5186  /* Step 10: Final residual add */
5187  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5188 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_3_prefill()

static void qwen2_0_5b_decode_layer_3_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 694 of file v6.6/test_generated/ck-kernel-inference.c.

701  {
703 
704  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[2].output);
705  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
706  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
707  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
708  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
709  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
710  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
711  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
712  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
713  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
714  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
715  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
716  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
717  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
718  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
719  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
720 
721  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
722  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
723  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
724  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
725  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
726  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
727  const float *BQ = NULL;
728  const float *BK = NULL;
729  const float *BV = NULL;
730  const float *BO = NULL;
731  const float *B1 = NULL;
732  const float *B2 = NULL;
733 
736 
737  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
738  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
739  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
740  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
741  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
742  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
743 
744  /* RMSNorm before attention */
745  rmsnorm_forward(input,
746  ln1_gamma,
747  ln1_out,
748  NULL,
749  num_tokens,
751  aligned_embed_dim,
752  1e-06f);
753 
754  /* Q projection (head-major) */
755  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
756  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
757  for (int h = 0; h < H; ++h) {
758  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
759  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
760  float *q_h = q + (size_t)h * q_head_stride;
761  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
762  }
763 
764  /* K projection (head-major) */
765  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
766  const uint8_t *WK_bytes = (const uint8_t *)WK;
767  for (int h = 0; h < H_kv; ++h) {
768  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
769  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
770  float *k_h = k + (size_t)h * kv_head_stride;
771  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
772  }
773 
774  /* V projection (head-major) */
775  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
776  const uint8_t *WV_bytes = (const uint8_t *)WV;
777  for (int h = 0; h < H_kv; ++h) {
778  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
779  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
780  float *v_h = v + (size_t)h * kv_head_stride;
781  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
782  }
783 
784  /* RoPE */
785  rope_forward_qk(q,
786  k,
787  rope_cos,
788  rope_sin,
789  H,
790  H_kv,
791  num_tokens,
792  head_dim,
793  aligned_head_dim,
794  0);
795 
796  /* Attention (prefill, causal) */
798  k,
799  v,
800  attn_out,
801  H,
802  H_kv,
803  num_tokens,
804  head_dim,
805  aligned_head_dim);
806 
807  /* Output projection (flatten head-major to token-major) */
808  const int K = H * aligned_head_dim;
809  if (K != aligned_embed_dim) {
810  return;
811  }
812  const float *proj_in = attn_out;
813  if (H > 1) {
814  if (!proj_scratch) {
815  return;
816  }
817  for (int t = 0; t < num_tokens; ++t) {
818  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
819  for (int h = 0; h < H; ++h) {
820  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
821  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
822  src,
823  (size_t)aligned_head_dim * sizeof(float));
824  }
825  }
826  proj_in = proj_scratch;
827  }
828  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
829 
830  /* Residual add */
831  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
832 
833  /* RMSNorm before MLP */
834  rmsnorm_forward(residual1,
835  ln2_gamma,
836  ln2_out,
837  NULL,
838  num_tokens,
840  aligned_embed_dim,
841  1e-06f);
842 
843  /* MLP (SwiGLU) */
844  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
845  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
846  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
847 
848  /* Final residual add */
849  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
850 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_4_decode()

static void qwen2_0_5b_decode_layer_4_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5193 of file v6.6/test_generated/ck-kernel-inference.c.

5200  {
5202 
5203  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[3].output);
5204 
5205  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5206  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5207  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5208  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5209  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5210  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5211  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5212  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5213  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5214  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5215  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5216 
5217  /* Weights (explicit types for layer 4) */
5218  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5219  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5220  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5221  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5222  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5223  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5224 
5227 
5228  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5229  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5230  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5231 
5232  float q_token[H * aligned_head_dim];
5233  float k_token[H_kv * aligned_head_dim];
5234  float v_token[H_kv * aligned_head_dim];
5235  float attn_token[H * aligned_head_dim];
5236 
5237  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5238  float fc1_out[2 * aligned_intermediate_dim];
5239  float swiglu_out[aligned_intermediate_dim];
5240 
5241  /* Step 1: RMSNorm before attention */
5242  rmsnorm_forward(input,
5243  ln1_gamma,
5244  ln1_out,
5245  NULL,
5246  1,
5248  aligned_embed_dim,
5249  1e-06f);
5250 
5251  /* Step 2: QKV projection */
5252  /* Q projection: Q4_K -> gemm_nt_q4_k */
5253  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5254 
5255  /* K projection: Q4_K -> gemm_nt_q4_k */
5256  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5257 
5258  /* V projection: Q4_K -> gemm_nt_q4_k */
5259  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5260 
5261  /* Step 3: RoPE */
5262  rope_forward_qk(q_token,
5263  k_token,
5264  rope_cos,
5265  rope_sin,
5266  H,
5267  H_kv,
5268  1,
5269  head_dim,
5270  aligned_head_dim,
5271  token_index);
5272 
5273  /* Step 4: KV cache write */
5274  kv_cache_write_head_major(k_token,
5275  v_token,
5276  k_cache,
5277  v_cache,
5278  H_kv,
5279  token_index,
5280  aligned_context_window,
5281  head_dim,
5282  aligned_head_dim);
5283 
5284  /* Step 5: Attention (decode) */
5286  k_cache,
5287  v_cache,
5288  attn_token,
5289  H,
5290  H_kv,
5291  token_index + 1,
5292  aligned_context_window,
5293  head_dim,
5294  aligned_head_dim);
5295 
5296  /* Step 6: Output projection */
5297  /* WO projection: Q4_K -> gemm_nt_q4_k */
5298  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5299 
5300  /* Step 7: Residual add */
5301  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5302 
5303  /* Step 8: RMSNorm before MLP */
5304  rmsnorm_forward(residual1,
5305  ln2_gamma,
5306  ln2_out,
5307  NULL,
5308  1,
5310  aligned_embed_dim,
5311  1e-06f);
5312 
5313  /* Step 9: MLP (SwiGLU) */
5314  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5315  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5316 
5317  /* SwiGLU activation */
5318  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5319 
5320  /* Down projection: Q4_K -> gemm_nt_q4_k */
5321  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5322 
5323  /* Step 10: Final residual add */
5324  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5325 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_4_prefill()

static void qwen2_0_5b_decode_layer_4_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 855 of file v6.6/test_generated/ck-kernel-inference.c.

862  {
864 
865  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[3].output);
866  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
867  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
868  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
869  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
870  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
871  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
872  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
873  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
874  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
875  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
876  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
877  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
878  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
879  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
880  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
881 
882  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
883  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
884  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
885  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
886  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
887  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
888  const float *BQ = NULL;
889  const float *BK = NULL;
890  const float *BV = NULL;
891  const float *BO = NULL;
892  const float *B1 = NULL;
893  const float *B2 = NULL;
894 
897 
898  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
899  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
900  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
901  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
902  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
903  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
904 
905  /* RMSNorm before attention */
906  rmsnorm_forward(input,
907  ln1_gamma,
908  ln1_out,
909  NULL,
910  num_tokens,
912  aligned_embed_dim,
913  1e-06f);
914 
915  /* Q projection (head-major) */
916  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
917  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
918  for (int h = 0; h < H; ++h) {
919  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
920  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
921  float *q_h = q + (size_t)h * q_head_stride;
922  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
923  }
924 
925  /* K projection (head-major) */
926  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
927  const uint8_t *WK_bytes = (const uint8_t *)WK;
928  for (int h = 0; h < H_kv; ++h) {
929  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
930  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
931  float *k_h = k + (size_t)h * kv_head_stride;
932  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
933  }
934 
935  /* V projection (head-major) */
936  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
937  const uint8_t *WV_bytes = (const uint8_t *)WV;
938  for (int h = 0; h < H_kv; ++h) {
939  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
940  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
941  float *v_h = v + (size_t)h * kv_head_stride;
942  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
943  }
944 
945  /* RoPE */
946  rope_forward_qk(q,
947  k,
948  rope_cos,
949  rope_sin,
950  H,
951  H_kv,
952  num_tokens,
953  head_dim,
954  aligned_head_dim,
955  0);
956 
957  /* Attention (prefill, causal) */
959  k,
960  v,
961  attn_out,
962  H,
963  H_kv,
964  num_tokens,
965  head_dim,
966  aligned_head_dim);
967 
968  /* Output projection (flatten head-major to token-major) */
969  const int K = H * aligned_head_dim;
970  if (K != aligned_embed_dim) {
971  return;
972  }
973  const float *proj_in = attn_out;
974  if (H > 1) {
975  if (!proj_scratch) {
976  return;
977  }
978  for (int t = 0; t < num_tokens; ++t) {
979  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
980  for (int h = 0; h < H; ++h) {
981  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
982  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
983  src,
984  (size_t)aligned_head_dim * sizeof(float));
985  }
986  }
987  proj_in = proj_scratch;
988  }
989  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
990 
991  /* Residual add */
992  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
993 
994  /* RMSNorm before MLP */
995  rmsnorm_forward(residual1,
996  ln2_gamma,
997  ln2_out,
998  NULL,
999  num_tokens,
1001  aligned_embed_dim,
1002  1e-06f);
1003 
1004  /* MLP (SwiGLU) */
1005  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1006  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1007  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1008 
1009  /* Final residual add */
1010  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1011 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_5_decode()

static void qwen2_0_5b_decode_layer_5_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5330 of file v6.6/test_generated/ck-kernel-inference.c.

5337  {
5339 
5340  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[4].output);
5341 
5342  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5343  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5344  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5345  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5346  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5347  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5348  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5349  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5350  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5351  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5352  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5353 
5354  /* Weights (explicit types for layer 5) */
5355  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5356  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5357  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5358  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5359  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5360  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5361 
5364 
5365  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5366  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5367  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5368 
5369  float q_token[H * aligned_head_dim];
5370  float k_token[H_kv * aligned_head_dim];
5371  float v_token[H_kv * aligned_head_dim];
5372  float attn_token[H * aligned_head_dim];
5373 
5374  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5375  float fc1_out[2 * aligned_intermediate_dim];
5376  float swiglu_out[aligned_intermediate_dim];
5377 
5378  /* Step 1: RMSNorm before attention */
5379  rmsnorm_forward(input,
5380  ln1_gamma,
5381  ln1_out,
5382  NULL,
5383  1,
5385  aligned_embed_dim,
5386  1e-06f);
5387 
5388  /* Step 2: QKV projection */
5389  /* Q projection: Q4_K -> gemm_nt_q4_k */
5390  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5391 
5392  /* K projection: Q4_K -> gemm_nt_q4_k */
5393  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5394 
5395  /* V projection: Q4_K -> gemm_nt_q4_k */
5396  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5397 
5398  /* Step 3: RoPE */
5399  rope_forward_qk(q_token,
5400  k_token,
5401  rope_cos,
5402  rope_sin,
5403  H,
5404  H_kv,
5405  1,
5406  head_dim,
5407  aligned_head_dim,
5408  token_index);
5409 
5410  /* Step 4: KV cache write */
5411  kv_cache_write_head_major(k_token,
5412  v_token,
5413  k_cache,
5414  v_cache,
5415  H_kv,
5416  token_index,
5417  aligned_context_window,
5418  head_dim,
5419  aligned_head_dim);
5420 
5421  /* Step 5: Attention (decode) */
5423  k_cache,
5424  v_cache,
5425  attn_token,
5426  H,
5427  H_kv,
5428  token_index + 1,
5429  aligned_context_window,
5430  head_dim,
5431  aligned_head_dim);
5432 
5433  /* Step 6: Output projection */
5434  /* WO projection: Q4_K -> gemm_nt_q4_k */
5435  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5436 
5437  /* Step 7: Residual add */
5438  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5439 
5440  /* Step 8: RMSNorm before MLP */
5441  rmsnorm_forward(residual1,
5442  ln2_gamma,
5443  ln2_out,
5444  NULL,
5445  1,
5447  aligned_embed_dim,
5448  1e-06f);
5449 
5450  /* Step 9: MLP (SwiGLU) */
5451  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5452  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5453 
5454  /* SwiGLU activation */
5455  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5456 
5457  /* Down projection: Q4_K -> gemm_nt_q4_k */
5458  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5459 
5460  /* Step 10: Final residual add */
5461  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5462 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_5_prefill()

static void qwen2_0_5b_decode_layer_5_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1016 of file v6.6/test_generated/ck-kernel-inference.c.

1023  {
1024  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[5];
1025 
1026  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[4].output);
1027  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1028  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1029  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1030  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1031  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1032  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1033  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1034  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1035  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1036  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1037  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1038  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1039  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1040  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1041  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1042 
1043  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1044  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1045  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1046  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1047  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1048  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1049  const float *BQ = NULL;
1050  const float *BK = NULL;
1051  const float *BV = NULL;
1052  const float *BO = NULL;
1053  const float *B1 = NULL;
1054  const float *B2 = NULL;
1055 
1056  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1057  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1058 
1059  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1060  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1061  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1062  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1063  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1064  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1065 
1066  /* RMSNorm before attention */
1067  rmsnorm_forward(input,
1068  ln1_gamma,
1069  ln1_out,
1070  NULL,
1071  num_tokens,
1072  QWEN2_0_5B_DECODE_EMBED_DIM,
1073  aligned_embed_dim,
1074  1e-06f);
1075 
1076  /* Q projection (head-major) */
1077  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1078  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1079  for (int h = 0; h < H; ++h) {
1080  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1081  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1082  float *q_h = q + (size_t)h * q_head_stride;
1083  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1084  }
1085 
1086  /* K projection (head-major) */
1087  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1088  const uint8_t *WK_bytes = (const uint8_t *)WK;
1089  for (int h = 0; h < H_kv; ++h) {
1090  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1091  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1092  float *k_h = k + (size_t)h * kv_head_stride;
1093  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1094  }
1095 
1096  /* V projection (head-major) */
1097  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1098  const uint8_t *WV_bytes = (const uint8_t *)WV;
1099  for (int h = 0; h < H_kv; ++h) {
1100  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1101  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1102  float *v_h = v + (size_t)h * kv_head_stride;
1103  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1104  }
1105 
1106  /* RoPE */
1107  rope_forward_qk(q,
1108  k,
1109  rope_cos,
1110  rope_sin,
1111  H,
1112  H_kv,
1113  num_tokens,
1114  head_dim,
1115  aligned_head_dim,
1116  0);
1117 
1118  /* Attention (prefill, causal) */
1119  attention_forward_causal_head_major_gqa_flash(q,
1120  k,
1121  v,
1122  attn_out,
1123  H,
1124  H_kv,
1125  num_tokens,
1126  head_dim,
1127  aligned_head_dim);
1128 
1129  /* Output projection (flatten head-major to token-major) */
1130  const int K = H * aligned_head_dim;
1131  if (K != aligned_embed_dim) {
1132  return;
1133  }
1134  const float *proj_in = attn_out;
1135  if (H > 1) {
1136  if (!proj_scratch) {
1137  return;
1138  }
1139  for (int t = 0; t < num_tokens; ++t) {
1140  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1141  for (int h = 0; h < H; ++h) {
1142  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1143  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1144  src,
1145  (size_t)aligned_head_dim * sizeof(float));
1146  }
1147  }
1148  proj_in = proj_scratch;
1149  }
1150  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1151 
1152  /* Residual add */
1153  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1154 
1155  /* RMSNorm before MLP */
1156  rmsnorm_forward(residual1,
1157  ln2_gamma,
1158  ln2_out,
1159  NULL,
1160  num_tokens,
1161  QWEN2_0_5B_DECODE_EMBED_DIM,
1162  aligned_embed_dim,
1163  1e-06f);
1164 
1165  /* MLP (SwiGLU) */
1166  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1167  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1168  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1169 
1170  /* Final residual add */
1171  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1172 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_6_decode()

static void qwen2_0_5b_decode_layer_6_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5467 of file v6.6/test_generated/ck-kernel-inference.c.

5474  {
5475  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[6];
5476 
5477  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[5].output);
5478 
5479  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5480  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5481  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5482  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5483  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5484  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5485  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5486  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5487  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5488  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5489  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5490 
5491  /* Weights (explicit types for layer 6) */
5492  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5493  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5494  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5495  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5496  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5497  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5498 
5499  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5500  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5501 
5502  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5503  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5504  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5505 
5506  float q_token[H * aligned_head_dim];
5507  float k_token[H_kv * aligned_head_dim];
5508  float v_token[H_kv * aligned_head_dim];
5509  float attn_token[H * aligned_head_dim];
5510 
5511  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5512  float fc1_out[2 * aligned_intermediate_dim];
5513  float swiglu_out[aligned_intermediate_dim];
5514 
5515  /* Step 1: RMSNorm before attention */
5516  rmsnorm_forward(input,
5517  ln1_gamma,
5518  ln1_out,
5519  NULL,
5520  1,
5521  QWEN2_0_5B_DECODE_EMBED_DIM,
5522  aligned_embed_dim,
5523  1e-06f);
5524 
5525  /* Step 2: QKV projection */
5526  /* Q projection: Q4_K -> gemm_nt_q4_k */
5527  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5528 
5529  /* K projection: Q4_K -> gemm_nt_q4_k */
5530  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5531 
5532  /* V projection: Q4_K -> gemm_nt_q4_k */
5533  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5534 
5535  /* Step 3: RoPE */
5536  rope_forward_qk(q_token,
5537  k_token,
5538  rope_cos,
5539  rope_sin,
5540  H,
5541  H_kv,
5542  1,
5543  head_dim,
5544  aligned_head_dim,
5545  token_index);
5546 
5547  /* Step 4: KV cache write */
5548  kv_cache_write_head_major(k_token,
5549  v_token,
5550  k_cache,
5551  v_cache,
5552  H_kv,
5553  token_index,
5554  aligned_context_window,
5555  head_dim,
5556  aligned_head_dim);
5557 
5558  /* Step 5: Attention (decode) */
5559  attention_forward_decode_head_major_gqa_regular(q_token,
5560  k_cache,
5561  v_cache,
5562  attn_token,
5563  H,
5564  H_kv,
5565  token_index + 1,
5566  aligned_context_window,
5567  head_dim,
5568  aligned_head_dim);
5569 
5570  /* Step 6: Output projection */
5571  /* WO projection: Q4_K -> gemm_nt_q4_k */
5572  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5573 
5574  /* Step 7: Residual add */
5575  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5576 
5577  /* Step 8: RMSNorm before MLP */
5578  rmsnorm_forward(residual1,
5579  ln2_gamma,
5580  ln2_out,
5581  NULL,
5582  1,
5583  QWEN2_0_5B_DECODE_EMBED_DIM,
5584  aligned_embed_dim,
5585  1e-06f);
5586 
5587  /* Step 9: MLP (SwiGLU) */
5588  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5589  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5590 
5591  /* SwiGLU activation */
5592  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5593 
5594  /* Down projection: Q4_K -> gemm_nt_q4_k */
5595  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5596 
5597  /* Step 10: Final residual add */
5598  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5599 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_6_prefill()

static void qwen2_0_5b_decode_layer_6_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1177 of file v6.6/test_generated/ck-kernel-inference.c.

1184  {
1185  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[6];
1186 
1187  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[5].output);
1188  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1189  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1190  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1191  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1192  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1193  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1194  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1195  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1196  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1197  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1198  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1199  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1200  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1201  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1202  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1203 
1204  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1205  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1206  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1207  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1208  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1209  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1210  const float *BQ = NULL;
1211  const float *BK = NULL;
1212  const float *BV = NULL;
1213  const float *BO = NULL;
1214  const float *B1 = NULL;
1215  const float *B2 = NULL;
1216 
1217  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1218  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1219 
1220  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1221  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1222  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1223  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1224  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1225  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1226 
1227  /* RMSNorm before attention */
1228  rmsnorm_forward(input,
1229  ln1_gamma,
1230  ln1_out,
1231  NULL,
1232  num_tokens,
1233  QWEN2_0_5B_DECODE_EMBED_DIM,
1234  aligned_embed_dim,
1235  1e-06f);
1236 
1237  /* Q projection (head-major) */
1238  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1239  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1240  for (int h = 0; h < H; ++h) {
1241  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1242  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1243  float *q_h = q + (size_t)h * q_head_stride;
1244  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1245  }
1246 
1247  /* K projection (head-major) */
1248  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1249  const uint8_t *WK_bytes = (const uint8_t *)WK;
1250  for (int h = 0; h < H_kv; ++h) {
1251  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1252  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1253  float *k_h = k + (size_t)h * kv_head_stride;
1254  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1255  }
1256 
1257  /* V projection (head-major) */
1258  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1259  const uint8_t *WV_bytes = (const uint8_t *)WV;
1260  for (int h = 0; h < H_kv; ++h) {
1261  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1262  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1263  float *v_h = v + (size_t)h * kv_head_stride;
1264  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1265  }
1266 
1267  /* RoPE */
1268  rope_forward_qk(q,
1269  k,
1270  rope_cos,
1271  rope_sin,
1272  H,
1273  H_kv,
1274  num_tokens,
1275  head_dim,
1276  aligned_head_dim,
1277  0);
1278 
1279  /* Attention (prefill, causal) */
1280  attention_forward_causal_head_major_gqa_flash(q,
1281  k,
1282  v,
1283  attn_out,
1284  H,
1285  H_kv,
1286  num_tokens,
1287  head_dim,
1288  aligned_head_dim);
1289 
1290  /* Output projection (flatten head-major to token-major) */
1291  const int K = H * aligned_head_dim;
1292  if (K != aligned_embed_dim) {
1293  return;
1294  }
1295  const float *proj_in = attn_out;
1296  if (H > 1) {
1297  if (!proj_scratch) {
1298  return;
1299  }
1300  for (int t = 0; t < num_tokens; ++t) {
1301  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1302  for (int h = 0; h < H; ++h) {
1303  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1304  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1305  src,
1306  (size_t)aligned_head_dim * sizeof(float));
1307  }
1308  }
1309  proj_in = proj_scratch;
1310  }
1311  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1312 
1313  /* Residual add */
1314  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1315 
1316  /* RMSNorm before MLP */
1317  rmsnorm_forward(residual1,
1318  ln2_gamma,
1319  ln2_out,
1320  NULL,
1321  num_tokens,
1322  QWEN2_0_5B_DECODE_EMBED_DIM,
1323  aligned_embed_dim,
1324  1e-06f);
1325 
1326  /* MLP (SwiGLU) */
1327  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1328  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1329  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1330 
1331  /* Final residual add */
1332  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1333 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_7_decode()

static void qwen2_0_5b_decode_layer_7_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5604 of file v6.6/test_generated/ck-kernel-inference.c.

5611  {
5612  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[7];
5613 
5614  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[6].output);
5615 
5616  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5617  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5618  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5619  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5620  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5621  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5622  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5623  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5624  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5625  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5626  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5627 
5628  /* Weights (explicit types for layer 7) */
5629  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5630  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5631  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5632  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5633  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5634  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5635 
5636  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5637  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5638 
5639  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5640  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5641  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5642 
5643  float q_token[H * aligned_head_dim];
5644  float k_token[H_kv * aligned_head_dim];
5645  float v_token[H_kv * aligned_head_dim];
5646  float attn_token[H * aligned_head_dim];
5647 
5648  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5649  float fc1_out[2 * aligned_intermediate_dim];
5650  float swiglu_out[aligned_intermediate_dim];
5651 
5652  /* Step 1: RMSNorm before attention */
5653  rmsnorm_forward(input,
5654  ln1_gamma,
5655  ln1_out,
5656  NULL,
5657  1,
5658  QWEN2_0_5B_DECODE_EMBED_DIM,
5659  aligned_embed_dim,
5660  1e-06f);
5661 
5662  /* Step 2: QKV projection */
5663  /* Q projection: Q4_K -> gemm_nt_q4_k */
5664  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5665 
5666  /* K projection: Q4_K -> gemm_nt_q4_k */
5667  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5668 
5669  /* V projection: Q4_K -> gemm_nt_q4_k */
5670  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5671 
5672  /* Step 3: RoPE */
5673  rope_forward_qk(q_token,
5674  k_token,
5675  rope_cos,
5676  rope_sin,
5677  H,
5678  H_kv,
5679  1,
5680  head_dim,
5681  aligned_head_dim,
5682  token_index);
5683 
5684  /* Step 4: KV cache write */
5685  kv_cache_write_head_major(k_token,
5686  v_token,
5687  k_cache,
5688  v_cache,
5689  H_kv,
5690  token_index,
5691  aligned_context_window,
5692  head_dim,
5693  aligned_head_dim);
5694 
5695  /* Step 5: Attention (decode) */
5696  attention_forward_decode_head_major_gqa_regular(q_token,
5697  k_cache,
5698  v_cache,
5699  attn_token,
5700  H,
5701  H_kv,
5702  token_index + 1,
5703  aligned_context_window,
5704  head_dim,
5705  aligned_head_dim);
5706 
5707  /* Step 6: Output projection */
5708  /* WO projection: Q4_K -> gemm_nt_q4_k */
5709  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5710 
5711  /* Step 7: Residual add */
5712  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5713 
5714  /* Step 8: RMSNorm before MLP */
5715  rmsnorm_forward(residual1,
5716  ln2_gamma,
5717  ln2_out,
5718  NULL,
5719  1,
5720  QWEN2_0_5B_DECODE_EMBED_DIM,
5721  aligned_embed_dim,
5722  1e-06f);
5723 
5724  /* Step 9: MLP (SwiGLU) */
5725  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5726  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5727 
5728  /* SwiGLU activation */
5729  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5730 
5731  /* Down projection: Q4_K -> gemm_nt_q4_k */
5732  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5733 
5734  /* Step 10: Final residual add */
5735  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5736 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_7_prefill()

static void qwen2_0_5b_decode_layer_7_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1338 of file v6.6/test_generated/ck-kernel-inference.c.

1345  {
1346  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[7];
1347 
1348  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[6].output);
1349  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1350  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1351  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1352  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1353  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1354  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1355  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1356  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1357  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1358  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1359  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1360  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1361  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1362  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1363  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1364 
1365  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1366  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1367  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1368  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1369  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1370  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1371  const float *BQ = NULL;
1372  const float *BK = NULL;
1373  const float *BV = NULL;
1374  const float *BO = NULL;
1375  const float *B1 = NULL;
1376  const float *B2 = NULL;
1377 
1378  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1379  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1380 
1381  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1382  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1383  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1384  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1385  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1386  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1387 
1388  /* RMSNorm before attention */
1389  rmsnorm_forward(input,
1390  ln1_gamma,
1391  ln1_out,
1392  NULL,
1393  num_tokens,
1394  QWEN2_0_5B_DECODE_EMBED_DIM,
1395  aligned_embed_dim,
1396  1e-06f);
1397 
1398  /* Q projection (head-major) */
1399  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1400  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1401  for (int h = 0; h < H; ++h) {
1402  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1403  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1404  float *q_h = q + (size_t)h * q_head_stride;
1405  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1406  }
1407 
1408  /* K projection (head-major) */
1409  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1410  const uint8_t *WK_bytes = (const uint8_t *)WK;
1411  for (int h = 0; h < H_kv; ++h) {
1412  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1413  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1414  float *k_h = k + (size_t)h * kv_head_stride;
1415  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1416  }
1417 
1418  /* V projection (head-major) */
1419  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1420  const uint8_t *WV_bytes = (const uint8_t *)WV;
1421  for (int h = 0; h < H_kv; ++h) {
1422  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1423  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1424  float *v_h = v + (size_t)h * kv_head_stride;
1425  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1426  }
1427 
1428  /* RoPE */
1429  rope_forward_qk(q,
1430  k,
1431  rope_cos,
1432  rope_sin,
1433  H,
1434  H_kv,
1435  num_tokens,
1436  head_dim,
1437  aligned_head_dim,
1438  0);
1439 
1440  /* Attention (prefill, causal) */
1441  attention_forward_causal_head_major_gqa_flash(q,
1442  k,
1443  v,
1444  attn_out,
1445  H,
1446  H_kv,
1447  num_tokens,
1448  head_dim,
1449  aligned_head_dim);
1450 
1451  /* Output projection (flatten head-major to token-major) */
1452  const int K = H * aligned_head_dim;
1453  if (K != aligned_embed_dim) {
1454  return;
1455  }
1456  const float *proj_in = attn_out;
1457  if (H > 1) {
1458  if (!proj_scratch) {
1459  return;
1460  }
1461  for (int t = 0; t < num_tokens; ++t) {
1462  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1463  for (int h = 0; h < H; ++h) {
1464  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1465  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1466  src,
1467  (size_t)aligned_head_dim * sizeof(float));
1468  }
1469  }
1470  proj_in = proj_scratch;
1471  }
1472  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1473 
1474  /* Residual add */
1475  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1476 
1477  /* RMSNorm before MLP */
1478  rmsnorm_forward(residual1,
1479  ln2_gamma,
1480  ln2_out,
1481  NULL,
1482  num_tokens,
1484  aligned_embed_dim,
1485  1e-06f);
1486 
1487  /* MLP (SwiGLU) */
1488  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1489  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1490  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1491 
1492  /* Final residual add */
1493  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1494 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_8_decode()

static void qwen2_0_5b_decode_layer_8_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5741 of file v6.6/test_generated/ck-kernel-inference.c.

5748  {
5750 
5751  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[7].output);
5752 
5753  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5754  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5755  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5756  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5757  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5758  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5759  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5760  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5761  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5762  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5763  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5764 
5765  /* Weights (explicit types for layer 8) */
5766  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5767  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5768  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5769  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5770  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5771  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5772 
5775 
5776  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5777  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5778  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5779 
5780  float q_token[H * aligned_head_dim];
5781  float k_token[H_kv * aligned_head_dim];
5782  float v_token[H_kv * aligned_head_dim];
5783  float attn_token[H * aligned_head_dim];
5784 
5785  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5786  float fc1_out[2 * aligned_intermediate_dim];
5787  float swiglu_out[aligned_intermediate_dim];
5788 
5789  /* Step 1: RMSNorm before attention */
5790  rmsnorm_forward(input,
5791  ln1_gamma,
5792  ln1_out,
5793  NULL,
5794  1,
5796  aligned_embed_dim,
5797  1e-06f);
5798 
5799  /* Step 2: QKV projection */
5800  /* Q projection: Q4_K -> gemm_nt_q4_k */
5801  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5802 
5803  /* K projection: Q4_K -> gemm_nt_q4_k */
5804  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5805 
5806  /* V projection: Q4_K -> gemm_nt_q4_k */
5807  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5808 
5809  /* Step 3: RoPE */
5810  rope_forward_qk(q_token,
5811  k_token,
5812  rope_cos,
5813  rope_sin,
5814  H,
5815  H_kv,
5816  1,
5817  head_dim,
5818  aligned_head_dim,
5819  token_index);
5820 
5821  /* Step 4: KV cache write */
5822  kv_cache_write_head_major(k_token,
5823  v_token,
5824  k_cache,
5825  v_cache,
5826  H_kv,
5827  token_index,
5828  aligned_context_window,
5829  head_dim,
5830  aligned_head_dim);
5831 
5832  /* Step 5: Attention (decode) */
5834  k_cache,
5835  v_cache,
5836  attn_token,
5837  H,
5838  H_kv,
5839  token_index + 1,
5840  aligned_context_window,
5841  head_dim,
5842  aligned_head_dim);
5843 
5844  /* Step 6: Output projection */
5845  /* WO projection: Q4_K -> gemm_nt_q4_k */
5846  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5847 
5848  /* Step 7: Residual add */
5849  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5850 
5851  /* Step 8: RMSNorm before MLP */
5852  rmsnorm_forward(residual1,
5853  ln2_gamma,
5854  ln2_out,
5855  NULL,
5856  1,
5858  aligned_embed_dim,
5859  1e-06f);
5860 
5861  /* Step 9: MLP (SwiGLU) */
5862  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
5863  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5864 
5865  /* SwiGLU activation */
5866  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5867 
5868  /* Down projection: Q4_K -> gemm_nt_q4_k */
5869  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5870 
5871  /* Step 10: Final residual add */
5872  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5873 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_8_prefill()

static void qwen2_0_5b_decode_layer_8_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1499 of file v6.6/test_generated/ck-kernel-inference.c.

1506  {
1508 
1509  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[7].output);
1510  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1511  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1512  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1513  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1514  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1515  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1516  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1517  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1518  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1519  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1520  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1521  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1522  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1523  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1524  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1525 
1526  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1527  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1528  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1529  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1530  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1531  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1532  const float *BQ = NULL;
1533  const float *BK = NULL;
1534  const float *BV = NULL;
1535  const float *BO = NULL;
1536  const float *B1 = NULL;
1537  const float *B2 = NULL;
1538 
1541 
1542  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1543  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1544  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1545  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1546  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1547  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1548 
1549  /* RMSNorm before attention */
1550  rmsnorm_forward(input,
1551  ln1_gamma,
1552  ln1_out,
1553  NULL,
1554  num_tokens,
1556  aligned_embed_dim,
1557  1e-06f);
1558 
1559  /* Q projection (head-major) */
1560  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1561  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1562  for (int h = 0; h < H; ++h) {
1563  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1564  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1565  float *q_h = q + (size_t)h * q_head_stride;
1566  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1567  }
1568 
1569  /* K projection (head-major) */
1570  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1571  const uint8_t *WK_bytes = (const uint8_t *)WK;
1572  for (int h = 0; h < H_kv; ++h) {
1573  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1574  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1575  float *k_h = k + (size_t)h * kv_head_stride;
1576  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1577  }
1578 
1579  /* V projection (head-major) */
1580  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1581  const uint8_t *WV_bytes = (const uint8_t *)WV;
1582  for (int h = 0; h < H_kv; ++h) {
1583  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1584  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1585  float *v_h = v + (size_t)h * kv_head_stride;
1586  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1587  }
1588 
1589  /* RoPE */
1590  rope_forward_qk(q,
1591  k,
1592  rope_cos,
1593  rope_sin,
1594  H,
1595  H_kv,
1596  num_tokens,
1597  head_dim,
1598  aligned_head_dim,
1599  0);
1600 
1601  /* Attention (prefill, causal) */
1603  k,
1604  v,
1605  attn_out,
1606  H,
1607  H_kv,
1608  num_tokens,
1609  head_dim,
1610  aligned_head_dim);
1611 
1612  /* Output projection (flatten head-major to token-major) */
1613  const int K = H * aligned_head_dim;
1614  if (K != aligned_embed_dim) {
1615  return;
1616  }
1617  const float *proj_in = attn_out;
1618  if (H > 1) {
1619  if (!proj_scratch) {
1620  return;
1621  }
1622  for (int t = 0; t < num_tokens; ++t) {
1623  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1624  for (int h = 0; h < H; ++h) {
1625  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1626  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1627  src,
1628  (size_t)aligned_head_dim * sizeof(float));
1629  }
1630  }
1631  proj_in = proj_scratch;
1632  }
1633  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1634 
1635  /* Residual add */
1636  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1637 
1638  /* RMSNorm before MLP */
1639  rmsnorm_forward(residual1,
1640  ln2_gamma,
1641  ln2_out,
1642  NULL,
1643  num_tokens,
1645  aligned_embed_dim,
1646  1e-06f);
1647 
1648  /* MLP (SwiGLU) */
1649  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1650  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1651  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1652 
1653  /* Final residual add */
1654  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1655 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_9_decode()

static void qwen2_0_5b_decode_layer_9_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5878 of file v6.6/test_generated/ck-kernel-inference.c.

5885  {
5887 
5888  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[8].output);
5889 
5890  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5891  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5892  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5893  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5894  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5895  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5896  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5897  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5898  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5899  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5900  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5901 
5902  /* Weights (explicit types for layer 9) */
5903  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5904  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5905  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5906  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5907  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5908  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5909 
5912 
5913  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5914  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5915  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5916 
5917  float q_token[H * aligned_head_dim];
5918  float k_token[H_kv * aligned_head_dim];
5919  float v_token[H_kv * aligned_head_dim];
5920  float attn_token[H * aligned_head_dim];
5921 
5922  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5923  float fc1_out[2 * aligned_intermediate_dim];
5924  float swiglu_out[aligned_intermediate_dim];
5925 
5926  /* Step 1: RMSNorm before attention */
5927  rmsnorm_forward(input,
5928  ln1_gamma,
5929  ln1_out,
5930  NULL,
5931  1,
5933  aligned_embed_dim,
5934  1e-06f);
5935 
5936  /* Step 2: QKV projection */
5937  /* Q projection: Q4_K -> gemm_nt_q4_k */
5938  gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5939 
5940  /* K projection: Q4_K -> gemm_nt_q4_k */
5941  gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5942 
5943  /* V projection: Q4_K -> gemm_nt_q4_k */
5944  gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5945 
5946  /* Step 3: RoPE */
5947  rope_forward_qk(q_token,
5948  k_token,
5949  rope_cos,
5950  rope_sin,
5951  H,
5952  H_kv,
5953  1,
5954  head_dim,
5955  aligned_head_dim,
5956  token_index);
5957 
5958  /* Step 4: KV cache write */
5959  kv_cache_write_head_major(k_token,
5960  v_token,
5961  k_cache,
5962  v_cache,
5963  H_kv,
5964  token_index,
5965  aligned_context_window,
5966  head_dim,
5967  aligned_head_dim);
5968 
5969  /* Step 5: Attention (decode) */
5971  k_cache,
5972  v_cache,
5973  attn_token,
5974  H,
5975  H_kv,
5976  token_index + 1,
5977  aligned_context_window,
5978  head_dim,
5979  aligned_head_dim);
5980 
5981  /* Step 6: Output projection */
5982  /* WO projection: Q4_K -> gemm_nt_q4_k */
5983  gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5984 
5985  /* Step 7: Residual add */
5986  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5987 
5988  /* Step 8: RMSNorm before MLP */
5989  rmsnorm_forward(residual1,
5990  ln2_gamma,
5991  ln2_out,
5992  NULL,
5993  1,
5995  aligned_embed_dim,
5996  1e-06f);
5997 
5998  /* Step 9: MLP (SwiGLU) */
5999  /* Gate+Up projection: Q4_K -> gemm_nt_q4_k */
6000  gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6001 
6002  /* SwiGLU activation */
6003  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6004 
6005  /* Down projection: Q4_K -> gemm_nt_q4_k */
6006  gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6007 
6008  /* Step 10: Final residual add */
6009  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6010 }

References attention_forward_decode_head_major_gqa_regular(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, kv_cache_write_head_major(), QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_9_prefill()

static void qwen2_0_5b_decode_layer_9_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1660 of file v6.6/test_generated/ck-kernel-inference.c.

1667  {
1669 
1670  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[8].output);
1671  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1672  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1673  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1674  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1675  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1676  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1677  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1678  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1679  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1680  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1681  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1682  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1683  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1684  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1685  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1686 
1687  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1688  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1689  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1690  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1691  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1692  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1693  const float *BQ = NULL;
1694  const float *BK = NULL;
1695  const float *BV = NULL;
1696  const float *BO = NULL;
1697  const float *B1 = NULL;
1698  const float *B2 = NULL;
1699 
1702 
1703  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1704  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1705  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1706  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1707  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1708  const size_t kv_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1709 
1710  /* RMSNorm before attention */
1711  rmsnorm_forward(input,
1712  ln1_gamma,
1713  ln1_out,
1714  NULL,
1715  num_tokens,
1717  aligned_embed_dim,
1718  1e-06f);
1719 
1720  /* Q projection (head-major) */
1721  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1722  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1723  for (int h = 0; h < H; ++h) {
1724  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1725  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1726  float *q_h = q + (size_t)h * q_head_stride;
1727  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1728  }
1729 
1730  /* K projection (head-major) */
1731  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1732  const uint8_t *WK_bytes = (const uint8_t *)WK;
1733  for (int h = 0; h < H_kv; ++h) {
1734  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1735  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1736  float *k_h = k + (size_t)h * kv_head_stride;
1737  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1738  }
1739 
1740  /* V projection (head-major) */
1741  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1742  const uint8_t *WV_bytes = (const uint8_t *)WV;
1743  for (int h = 0; h < H_kv; ++h) {
1744  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1745  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1746  float *v_h = v + (size_t)h * kv_head_stride;
1747  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1748  }
1749 
1750  /* RoPE */
1751  rope_forward_qk(q,
1752  k,
1753  rope_cos,
1754  rope_sin,
1755  H,
1756  H_kv,
1757  num_tokens,
1758  head_dim,
1759  aligned_head_dim,
1760  0);
1761 
1762  /* Attention (prefill, causal) */
1764  k,
1765  v,
1766  attn_out,
1767  H,
1768  H_kv,
1769  num_tokens,
1770  head_dim,
1771  aligned_head_dim);
1772 
1773  /* Output projection (flatten head-major to token-major) */
1774  const int K = H * aligned_head_dim;
1775  if (K != aligned_embed_dim) {
1776  return;
1777  }
1778  const float *proj_in = attn_out;
1779  if (H > 1) {
1780  if (!proj_scratch) {
1781  return;
1782  }
1783  for (int t = 0; t < num_tokens; ++t) {
1784  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1785  for (int h = 0; h < H; ++h) {
1786  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1787  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1788  src,
1789  (size_t)aligned_head_dim * sizeof(float));
1790  }
1791  }
1792  proj_in = proj_scratch;
1793  }
1794  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1795 
1796  /* Residual add */
1797  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1798 
1799  /* RMSNorm before MLP */
1800  rmsnorm_forward(residual1,
1801  ln2_gamma,
1802  ln2_out,
1803  NULL,
1804  num_tokens,
1806  aligned_embed_dim,
1807  1e-06f);
1808 
1809  /* MLP (SwiGLU) */
1810  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1811  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1812  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1813 
1814  /* Final residual add */
1815  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1816 }

References attention_forward_causal_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_model_allocate()

int qwen2_0_5b_decode_model_allocate ( QWEN2_0_5B_DECODEModel model)

Definition at line 88 of file v6.6/test_generated/ck-kernel-inference.c.

88  {
89  size_t total = QWEN2_0_5B_DECODE_TOTAL_BYTES;
90 
91 #ifdef __linux__
92  model->base = mmap(NULL, total,
93  PROT_READ | PROT_WRITE,
94  MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
95  -1, 0);
96  if (model->base == MAP_FAILED) {
97  model->base = mmap(NULL, total,
98  PROT_READ | PROT_WRITE,
99  MAP_PRIVATE | MAP_ANONYMOUS,
100  -1, 0);
101  }
102  if (model->base == MAP_FAILED) {
103  perror("mmap failed");
104  return -1;
105  }
106 #else
107  model->base = aligned_alloc(64, total);
108  if (!model->base) {
109  perror("aligned_alloc failed");
110  return -1;
111  }
112 #endif
113 
114  model->total_bytes = total;
115 
116  /* Initialize magic header */
117  MagicHeader *header = (MagicHeader *)model->base;
118  header->magic = QWEN2_0_5B_DECODE_MAGIC;
119  header->version = 5;
120  header->total_bytes = QWEN2_0_5B_DECODE_TOTAL_BYTES;
121  header->weight_bytes = QWEN2_0_5B_DECODE_WEIGHT_BYTES;
122  header->activation_bytes = QWEN2_0_5B_DECODE_ACTIVATION_BYTES;
123  header->num_layers = QWEN2_0_5B_DECODE_NUM_LAYERS;
124  header->embed_dim = QWEN2_0_5B_DECODE_EMBED_DIM;
125  header->num_heads = QWEN2_0_5B_DECODE_NUM_HEADS;
126  header->vocab_size = QWEN2_0_5B_DECODE_VOCAB_SIZE;
127  header->max_seq_len = QWEN2_0_5B_DECODE_MAX_SEQ_LEN;
128  header->canary_count = QWEN2_0_5B_DECODE_CANARY_COUNT;
129 
130  /* Initialize canary guards */
131  for (int i = 0; i < QWEN2_0_5B_DECODE_CANARY_COUNT; i++) {
132  uint32_t *ptr = (uint32_t*)((char*)model->base + QWEN2_0_5B_DECODE_CANARIES[i].offset);
133  for (int j = 0; j < (QWEN2_0_5B_DECODE_CANARY_SIZE / 4); j++) {
135  }
136  }
137 
138  return 0;
139 }
#define QWEN2_0_5B_DECODE_TOTAL_BYTES
#define QWEN2_0_5B_DECODE_ACTIVATION_BYTES
#define QWEN2_0_5B_DECODE_MAX_SEQ_LEN
#define QWEN2_0_5B_DECODE_WEIGHT_BYTES
#define QWEN2_0_5B_DECODE_CANARY_VALUE
#define QWEN2_0_5B_DECODE_CANARY_SIZE
static const QWEN2_0_5B_DECODECanary QWEN2_0_5B_DECODE_CANARIES[]

References QWEN2_0_5B_DECODEModel::base, MagicHeader, QWEN2_0_5B_DECODECanary::offset, QWEN2_0_5B_DECODE_ACTIVATION_BYTES, QWEN2_0_5B_DECODE_CANARIES, QWEN2_0_5B_DECODE_CANARY_COUNT, QWEN2_0_5B_DECODE_CANARY_SIZE, QWEN2_0_5B_DECODE_CANARY_VALUE, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_MAGIC, QWEN2_0_5B_DECODE_MAX_SEQ_LEN, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_LAYERS, QWEN2_0_5B_DECODE_TOTAL_BYTES, QWEN2_0_5B_DECODE_VOCAB_SIZE, QWEN2_0_5B_DECODE_WEIGHT_BYTES, and QWEN2_0_5B_DECODEModel::total_bytes.

◆ qwen2_0_5b_decode_model_free()

void qwen2_0_5b_decode_model_free ( QWEN2_0_5B_DECODEModel model)

Definition at line 141 of file v6.6/test_generated/ck-kernel-inference.c.

141  {
142  if (!model || !model->base) return;
143 #ifdef __linux__
144  munmap(model->base, model->total_bytes);
145 #else
146  free(model->base);
147 #endif
148  model->base = NULL;
149  model->total_bytes = 0;
150 }

References QWEN2_0_5B_DECODEModel::base, and QWEN2_0_5B_DECODEModel::total_bytes.

◆ qwen2_0_5b_decode_precompute_rope()

void qwen2_0_5b_decode_precompute_rope ( QWEN2_0_5B_DECODEModel model)

Definition at line 186 of file v6.6/test_generated/ck-kernel-inference.c.

186  {
187  const int T = QWEN2_0_5B_DECODE_MAX_SEQ_LEN;
188  const int D = QWEN2_0_5B_DECODE_HEAD_DIM / 2;
189  const float theta = 1000000.0f;
190 
193 
194  for (int pos = 0; pos < T; pos++) {
195  for (int i = 0; i < D; i++) {
196  float freq = 1.0f / powf(theta, (float)(2 * i) / (float)(D * 2));
197  float angle = (float)pos * freq;
198  cos_ptr[pos * D + i] = cosf(angle);
199  sin_ptr[pos * D + i] = sinf(angle);
200  }
201  }
202 }

References QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_MAX_SEQ_LEN, QWEN2_0_5B_DECODE_PTR, QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, and QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache.

◆ qwen2_0_5b_decode_residual_add_token_major()

static void qwen2_0_5b_decode_residual_add_token_major ( const float *  a,
const float *  b,
float *  out,
int  tokens,
int  aligned_embed_dim 
)
static

Definition at line 43 of file v6.6/test_generated/ck-kernel-inference.c.

/* Element-wise residual addition over token-major activations.
 *
 * a, b : inputs, `tokens` rows of `aligned_embed_dim` floats each
 * out  : destination, same layout (element-wise, so out may alias a or b)
 *
 * Silently does nothing when any pointer is NULL.
 */
static void qwen2_0_5b_decode_residual_add_token_major(const float *a,
                                                       const float *b,
                                                       float *out,
                                                       int tokens,
                                                       int aligned_embed_dim) {
    if (a == NULL || b == NULL || out == NULL) {
        return;
    }
    const size_t row = (size_t)aligned_embed_dim;
    for (int t = 0; t < tokens; ++t) {
        const size_t base = (size_t)t * row;
        for (int d = 0; d < aligned_embed_dim; ++d) {
            out[base + (size_t)d] = a[base + (size_t)d] + b[base + (size_t)d];
        }
    }
}

Referenced by qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ qwen2_0_5b_decode_verify_canaries()

int qwen2_0_5b_decode_verify_canaries ( QWEN2_0_5B_DECODEModel model)

Definition at line 152 of file v6.6/test_generated/ck-kernel-inference.c.

152  {
153  int errors = 0;
154  uint32_t *ptr;
155 
156  for (int i = 0; i < QWEN2_0_5B_DECODE_CANARY_COUNT; i++) {
157  ptr = (uint32_t*)((char*)model->base + QWEN2_0_5B_DECODE_CANARIES[i].offset);
158  for (int j = 0; j < 4; j++) {
159  if (ptr[j] != QWEN2_0_5B_DECODE_CANARY_VALUE) {
160  fprintf(stderr, "CANARY CORRUPTION: %s at offset 0x%lX\n",
162  QWEN2_0_5B_DECODE_CANARIES[i].offset);
163  errors++;
164  break;
165  }
166  }
167  }
168 
169  return errors;
170 }

References QWEN2_0_5B_DECODEModel::base, QWEN2_0_5B_DECODECanary::offset, QWEN2_0_5B_DECODE_CANARIES, QWEN2_0_5B_DECODE_CANARY_COUNT, and QWEN2_0_5B_DECODE_CANARY_VALUE.

Variable Documentation

◆ MagicHeader

MagicHeader