← Back to C-Kernel-Engine Docs Doxygen Source Documentation
v6.6/test_generated/qwen2_int8.c File Reference

AUTO-GENERATED: qwen2_0.5b_decode Implementation (IR v6.6 - Explicit Unrolled) More...

#include "ck-kernel-inference.h"
#include "ckernel_engine.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <math.h>
#include "ck_model_api.h"

Go to the source code of this file.

Macros

#define _GNU_SOURCE   /* For MAP_ANONYMOUS, MAP_HUGETLB */
 

Functions

struct __attribute__ ((packed))
 
 _Static_assert (sizeof(MagicHeader)==64, "MagicHeader must be 64 bytes")
 
void * ck_model_create (void)
 
void ck_model_decode (void *model, const int *token, int token_index)
 
void ck_model_forward (void *model, const int *tokens, int num_tokens)
 
void ck_model_free (void *model)
 
void * ck_model_get_base (void *model)
 
const CKModelConfig * ck_model_get_config (void)
 
float * ck_model_get_logits (void *model)
 
size_t ck_model_get_total_bytes (void *model)
 
void ck_model_precompute_rope (void *model)
 
int ck_model_verify_canaries (void *model)
 
static int qwen2_0_5b_decode_align_elems (int elems, int elem_bytes, int align_bytes)
 
void qwen2_0_5b_decode_decode (QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
 
static void qwen2_0_5b_decode_decode_token (QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
 
void qwen2_0_5b_decode_forward (QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
 
static void qwen2_0_5b_decode_forward_prefill_impl (QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
 
static void qwen2_0_5b_decode_layer_0_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_0_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_10_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_10_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_11_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_11_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_12_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_12_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_13_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_13_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_14_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_14_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_15_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_15_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_16_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_16_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_17_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_17_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_18_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_18_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_19_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_19_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_1_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_1_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_20_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_20_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_21_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_21_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_22_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_22_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_23_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_23_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_2_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_2_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_3_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_3_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_4_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_4_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_5_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_5_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_6_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_6_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_7_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_7_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_8_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_8_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_9_decode (QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
static void qwen2_0_5b_decode_layer_9_prefill (QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
 
int qwen2_0_5b_decode_model_allocate (QWEN2_0_5B_DECODEModel *model)
 
void qwen2_0_5b_decode_model_free (QWEN2_0_5B_DECODEModel *model)
 
void qwen2_0_5b_decode_precompute_rope (QWEN2_0_5B_DECODEModel *model)
 
static void qwen2_0_5b_decode_residual_add_token_major (const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
 
int qwen2_0_5b_decode_verify_canaries (QWEN2_0_5B_DECODEModel *model)
 

Variables

static CKModelConfig g_model_config
 
 MagicHeader
 

Detailed Description

AUTO-GENERATED: qwen2_0.5b_decode Implementation (IR v6.6 - Explicit Unrolled)

Generated: 2026-01-12T11:58:55.212793 UTC Total Memory: 3.57 GB Mode: decode Layers: 24 (fully unrolled)

Per-layer quant types: Layer 0: wq=q4_k wk=q4_k wv=q4_k wo=q4_k w1=q4_k w2=q4_k Layer 1: wq=q4_k wk=q4_k wv=q4_k wo=q4_k w1=q4_k w2=q4_k Layer 2: wq=q4_k wk=q4_k wv=q4_k wo=q4_k w1=q4_k w2=q4_k ... (21 more layers)

DO NOT EDIT - Regenerate with build_ir_v6.6.py or codegen_v6.6.py

Definition in file v6.6/test_generated/qwen2_int8.c.

Macro Definition Documentation

◆ _GNU_SOURCE

#define _GNU_SOURCE   /* For MAP_ANONYMOUS, MAP_HUGETLB */

Definition at line 19 of file v6.6/test_generated/qwen2_int8.c.

Function Documentation

◆ __attribute__()

struct __attribute__ ( (packed)  )

Definition at line 43 of file v6.6/test_generated/qwen2_int8.c.

67  {
68  uint32_t magic; /* 0x434B454E */
69  uint32_t version; /* IR version */
70  uint64_t total_bytes;
71  uint64_t weight_bytes;
72  uint64_t activation_bytes;
73  uint32_t num_layers;
74  uint32_t embed_dim;
75  uint32_t num_heads;
76  uint32_t vocab_size;
77  uint32_t max_seq_len;
78  uint32_t canary_count;
79  uint8_t reserved[8]; /* Pad to 64 bytes */
80 } MagicHeader;
int vocab_size
Definition: true_bpe.h:185

◆ _Static_assert()

_Static_assert ( sizeof(MagicHeader) == 64,
"MagicHeader must be 64 bytes"   
)

◆ ck_model_create()

void* ck_model_create ( void  )

Create and allocate model memory. Returns opaque model pointer, or NULL on failure.

Definition at line 8873 of file v6.6/test_generated/qwen2_int8.c.

8873  {
8874  QWEN2_0_5B_DECODEModel *model = malloc(sizeof(QWEN2_0_5B_DECODEModel));
8875  if (!model) return NULL;
8876  if (qwen2_0_5b_decode_model_allocate(model) != 0) {
8877  free(model);
8878  return NULL;
8879  }
8880  return model;
8881 }
int qwen2_0_5b_decode_model_allocate(QWEN2_0_5B_DECODEModel *model)

References model_model_allocate(), and qwen2_0_5b_decode_model_allocate().

Referenced by main().

◆ ck_model_decode()

void ck_model_decode ( void *  model,
const int *  token,
int  token_index 
)

Decode single token at position token_index. Used for autoregressive generation.

Definition at line 8897 of file v6.6/test_generated/qwen2_int8.c.

8897  {
8898  qwen2_0_5b_decode_decode((QWEN2_0_5B_DECODEModel *)model, token, token_index);
8899 }
const char * token
Definition: tokenizer.h:306
void qwen2_0_5b_decode_decode(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)

References model_decode(), qwen2_0_5b_decode_decode(), and token.

Referenced by run_benchmark(), and run_generation_test().

◆ ck_model_forward()

void ck_model_forward ( void *  model,
const int *  tokens,
int  num_tokens 
)

Forward pass (prefill) - process multiple tokens. Used for initial prompt processing.

Definition at line 8893 of file v6.6/test_generated/qwen2_int8.c.

8893  {
8894  qwen2_0_5b_decode_forward((QWEN2_0_5B_DECODEModel *)model, tokens, num_tokens);
8895 }
void qwen2_0_5b_decode_forward(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)

References model_forward(), and qwen2_0_5b_decode_forward().

◆ ck_model_free()

void ck_model_free ( void *  model)

Free model memory.

Definition at line 8883 of file v6.6/test_generated/qwen2_int8.c.

8883  {
8884  if (!model) return;
8886  free(model);
8887 }
void qwen2_0_5b_decode_model_free(QWEN2_0_5B_DECODEModel *model)

References model_model_free(), and qwen2_0_5b_decode_model_free().

Referenced by main().

◆ ck_model_get_base()

void* ck_model_get_base ( void *  model)

Get model base pointer (for weight loading).

Definition at line 8910 of file v6.6/test_generated/qwen2_int8.c.

8910  {
8911  return ((QWEN2_0_5B_DECODEModel *)model)->base;
8912 }

Referenced by load_weights_from_bump().

◆ ck_model_get_config()

const CKModelConfig* ck_model_get_config ( void  )

Get model configuration (dimensions, sizes, etc.) This is available before allocation.

Definition at line 8869 of file v6.6/test_generated/qwen2_int8.c.

8869  {
8870  return &g_model_config;
8871 }
static CKModelConfig g_model_config

References g_model_config.

Referenced by load_weights_from_bump(), main(), run_benchmark(), and run_generation_test().

◆ ck_model_get_logits()

float* ck_model_get_logits ( void *  model)

Get pointer to output logits buffer. Size is vocab_size floats.

Definition at line 8901 of file v6.6/test_generated/qwen2_int8.c.

8901  {
8904 }
#define QWEN2_0_5B_DECODE_PTR(model, offset)
static const QWEN2_0_5B_DECODEFooterOffsets QWEN2_0_5B_DECODE_FOOTER

References QWEN2_0_5B_DECODEFooterOffsets::logits, QWEN2_0_5B_DECODE_FOOTER, and QWEN2_0_5B_DECODE_PTR.

Referenced by run_generation_test().

◆ ck_model_get_total_bytes()

size_t ck_model_get_total_bytes ( void *  model)

Get total model size in bytes.

Definition at line 8914 of file v6.6/test_generated/qwen2_int8.c.

8914  {
8915  return ((QWEN2_0_5B_DECODEModel *)model)->total_bytes;
8916 }

Referenced by load_weights_from_bump().

◆ ck_model_precompute_rope()

void ck_model_precompute_rope ( void *  model)

Precompute RoPE cos/sin caches. Call once after allocation, before inference.

Definition at line 8889 of file v6.6/test_generated/qwen2_int8.c.

8889  {
8891 }
void qwen2_0_5b_decode_precompute_rope(QWEN2_0_5B_DECODEModel *model)

References model_precompute_rope(), and qwen2_0_5b_decode_precompute_rope().

Referenced by main().

◆ ck_model_verify_canaries()

int ck_model_verify_canaries ( void *  model)

Verify memory canaries (debug). Returns number of corrupted canaries (0 = OK).

Definition at line 8906 of file v6.6/test_generated/qwen2_int8.c.

8906  {
8908 }
int qwen2_0_5b_decode_verify_canaries(QWEN2_0_5B_DECODEModel *model)

References model_verify_canaries(), and qwen2_0_5b_decode_verify_canaries().

Referenced by run_benchmark().

◆ qwen2_0_5b_decode_align_elems()

static int qwen2_0_5b_decode_align_elems ( int  elems,
int  elem_bytes,
int  align_bytes 
)
static

Definition at line 176 of file v6.6/test_generated/qwen2_int8.c.

176  {
177  int bytes = elems * elem_bytes;
178  int aligned = (bytes + align_bytes - 1) / align_bytes * align_bytes;
179  return aligned / elem_bytes;
180 }

◆ qwen2_0_5b_decode_decode()

void qwen2_0_5b_decode_decode ( QWEN2_0_5B_DECODEModel *  model,
const int *  token,
int  token_index 
)

Definition at line 8841 of file v6.6/test_generated/qwen2_int8.c.

8841  {
8842  qwen2_0_5b_decode_decode_token(model, token, token_index);
8843 }
static void qwen2_0_5b_decode_decode_token(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)

References qwen2_0_5b_decode_decode_token(), and token.

Referenced by ck_model_decode().

◆ qwen2_0_5b_decode_decode_token()

static void qwen2_0_5b_decode_decode_token ( QWEN2_0_5B_DECODEModel *  model,
const int *  token,
int  token_index 
)
static

Definition at line 8750 of file v6.6/test_generated/qwen2_int8.c.

8754  {
8755  if (!model || !token) return;
8756 
8757  const int aligned_embed_dim = 1024;
8758  const int aligned_head_dim = 64;
8759  const int aligned_intermediate_dim = 4864;
8760  const int aligned_context_window = 131072;
8761 
8762  if (token_index < 0 || token_index >= aligned_context_window) return;
8763 
8764  /* Embedding lookup */
8766  const void *embed_weight = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.token_emb);
8767  /* Embedding: Q4_K -> embedding_forward_q4_k */
8768  embedding_forward_q4_k((const int32_t *)token,
8769  1,
8771  embed_weight,
8772  NULL,
8773  embed_out,
8775  aligned_embed_dim,
8776  1,
8777  0);
8778 
8779  /* Process each layer explicitly */
8780  qwen2_0_5b_decode_layer_0_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8781  qwen2_0_5b_decode_layer_1_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8782  qwen2_0_5b_decode_layer_2_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8783  qwen2_0_5b_decode_layer_3_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8784  qwen2_0_5b_decode_layer_4_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8785  qwen2_0_5b_decode_layer_5_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8786  qwen2_0_5b_decode_layer_6_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8787  qwen2_0_5b_decode_layer_7_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8788  qwen2_0_5b_decode_layer_8_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8789  qwen2_0_5b_decode_layer_9_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8790  qwen2_0_5b_decode_layer_10_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8791  qwen2_0_5b_decode_layer_11_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8792  qwen2_0_5b_decode_layer_12_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8793  qwen2_0_5b_decode_layer_13_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8794  qwen2_0_5b_decode_layer_14_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8795  qwen2_0_5b_decode_layer_15_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8796  qwen2_0_5b_decode_layer_16_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8797  qwen2_0_5b_decode_layer_17_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8798  qwen2_0_5b_decode_layer_18_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8799  qwen2_0_5b_decode_layer_19_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8800  qwen2_0_5b_decode_layer_20_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8801  qwen2_0_5b_decode_layer_21_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8802  qwen2_0_5b_decode_layer_22_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8803  qwen2_0_5b_decode_layer_23_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8804 
8805  /* Final RMSNorm */
8806  float *last_hidden = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[23].output);
8807  float *final_ln_weight = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_ln_weight);
8808  float *final_out = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_output);
8809  rmsnorm_forward(last_hidden,
8810  final_ln_weight,
8811  final_out,
8812  NULL,
8813  1,
8815  aligned_embed_dim,
8816  1e-06f);
8817 
8818  /* LM head projection */
8819  float *logits = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.logits);
8820  const void *lm_head = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.lm_head_weight);
8821  /* LM head (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8822  const size_t final_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8823  uint8_t final_q8[final_q8_bytes];
8824  quantize_row_q8_k(final_out, final_q8, aligned_embed_dim);
8825  gemv_q4_k_q8_k(logits, lm_head, final_q8, QWEN2_0_5B_DECODE_VOCAB_SIZE, aligned_embed_dim);
8826 }
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void quantize_row_q8_k(const float *x, void *y, int k)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
static const QWEN2_0_5B_DECODELayerOffsets QWEN2_0_5B_DECODE_LAYERS[24]
#define QWEN2_0_5B_DECODE_VOCAB_SIZE
static const QWEN2_0_5B_DECODEHeaderOffsets QWEN2_0_5B_DECODE_HEADER
static void qwen2_0_5b_decode_layer_3_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_18_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_4_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_5_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_0_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_6_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_21_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_15_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_17_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

References QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, embedding_forward_q4_k(), QWEN2_0_5B_DECODEFooterOffsets::final_ln_weight, QWEN2_0_5B_DECODEFooterOffsets::final_output, gemv_q4_k_q8_k(), QWEN2_0_5B_DECODEFooterOffsets::lm_head_weight, QWEN2_0_5B_DECODEFooterOffsets::logits, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_FOOTER, QWEN2_0_5B_DECODE_HEADER, qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_PTR, QWEN2_0_5B_DECODE_VOCAB_SIZE, rmsnorm_forward(), token, and QWEN2_0_5B_DECODEHeaderOffsets::token_emb.

Referenced by qwen2_0_5b_decode_decode().

◆ qwen2_0_5b_decode_forward()

void qwen2_0_5b_decode_forward ( QWEN2_0_5B_DECODEModel *  model,
const int *  tokens,
int  num_tokens 
)

Definition at line 8832 of file v6.6/test_generated/qwen2_int8.c.

8836  {
8837  if (!model || !tokens || num_tokens <= 0) return;
8838  qwen2_0_5b_decode_forward_prefill_impl(model, tokens, num_tokens);
8839 }
static void qwen2_0_5b_decode_forward_prefill_impl(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)

References qwen2_0_5b_decode_forward_prefill_impl().

Referenced by ck_model_forward().

◆ qwen2_0_5b_decode_forward_prefill_impl()

static void qwen2_0_5b_decode_forward_prefill_impl ( QWEN2_0_5B_DECODEModel *  model,
const int *  tokens,
int  num_tokens 
)
static

Definition at line 4148 of file v6.6/test_generated/qwen2_int8.c.

4152  {
4153  if (!model || !tokens || num_tokens <= 0) {
4154  return;
4155  }
4156 
4157  const int elem_bytes = QWEN2_0_5B_DECODE_DTYPE_BYTES;
4158  const int aligned_embed_dim = 1024;
4159  const int aligned_head_dim = 64;
4160  const int aligned_intermediate_dim = 4864;
4161  const int aligned_context_window = 131072;
4162 
4163  float *embed_out = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.embedded_input);
4164  const void *embed_weight = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.token_emb);
4165  embedding_forward_q4_k((const int32_t *)tokens,
4166  num_tokens,
4168  embed_weight,
4169  NULL,
4170  embed_out,
4172  aligned_embed_dim,
4173  num_tokens,
4174  0);
4175 
4176  qwen2_0_5b_decode_layer_0_prefill(
4177  model,
4178  num_tokens,
4179  aligned_embed_dim,
4180  aligned_head_dim,
4181  aligned_intermediate_dim,
4182  aligned_context_window);
4183 
4184  qwen2_0_5b_decode_layer_1_prefill(
4185  model,
4186  num_tokens,
4187  aligned_embed_dim,
4188  aligned_head_dim,
4189  aligned_intermediate_dim,
4190  aligned_context_window);
4191 
4192  qwen2_0_5b_decode_layer_2_prefill(
4193  model,
4194  num_tokens,
4195  aligned_embed_dim,
4196  aligned_head_dim,
4197  aligned_intermediate_dim,
4198  aligned_context_window);
4199 
4200  qwen2_0_5b_decode_layer_3_prefill(
4201  model,
4202  num_tokens,
4203  aligned_embed_dim,
4204  aligned_head_dim,
4205  aligned_intermediate_dim,
4206  aligned_context_window);
4207 
4208  qwen2_0_5b_decode_layer_4_prefill(
4209  model,
4210  num_tokens,
4211  aligned_embed_dim,
4212  aligned_head_dim,
4213  aligned_intermediate_dim,
4214  aligned_context_window);
4215 
4216  qwen2_0_5b_decode_layer_5_prefill(
4217  model,
4218  num_tokens,
4219  aligned_embed_dim,
4220  aligned_head_dim,
4221  aligned_intermediate_dim,
4222  aligned_context_window);
4223 
4224  qwen2_0_5b_decode_layer_6_prefill(
4225  model,
4226  num_tokens,
4227  aligned_embed_dim,
4228  aligned_head_dim,
4229  aligned_intermediate_dim,
4230  aligned_context_window);
4231 
4232  qwen2_0_5b_decode_layer_7_prefill(
4233  model,
4234  num_tokens,
4235  aligned_embed_dim,
4236  aligned_head_dim,
4237  aligned_intermediate_dim,
4238  aligned_context_window);
4239 
4240  qwen2_0_5b_decode_layer_8_prefill(
4241  model,
4242  num_tokens,
4243  aligned_embed_dim,
4244  aligned_head_dim,
4245  aligned_intermediate_dim,
4246  aligned_context_window);
4247 
4248  qwen2_0_5b_decode_layer_9_prefill(
4249  model,
4250  num_tokens,
4251  aligned_embed_dim,
4252  aligned_head_dim,
4253  aligned_intermediate_dim,
4254  aligned_context_window);
4255 
4256  qwen2_0_5b_decode_layer_10_prefill(
4257  model,
4258  num_tokens,
4259  aligned_embed_dim,
4260  aligned_head_dim,
4261  aligned_intermediate_dim,
4262  aligned_context_window);
4263 
4264  qwen2_0_5b_decode_layer_11_prefill(
4265  model,
4266  num_tokens,
4267  aligned_embed_dim,
4268  aligned_head_dim,
4269  aligned_intermediate_dim,
4270  aligned_context_window);
4271 
4272  qwen2_0_5b_decode_layer_12_prefill(
4273  model,
4274  num_tokens,
4275  aligned_embed_dim,
4276  aligned_head_dim,
4277  aligned_intermediate_dim,
4278  aligned_context_window);
4279 
4280  qwen2_0_5b_decode_layer_13_prefill(
4281  model,
4282  num_tokens,
4283  aligned_embed_dim,
4284  aligned_head_dim,
4285  aligned_intermediate_dim,
4286  aligned_context_window);
4287 
4288  qwen2_0_5b_decode_layer_14_prefill(
4289  model,
4290  num_tokens,
4291  aligned_embed_dim,
4292  aligned_head_dim,
4293  aligned_intermediate_dim,
4294  aligned_context_window);
4295 
4296  qwen2_0_5b_decode_layer_15_prefill(
4297  model,
4298  num_tokens,
4299  aligned_embed_dim,
4300  aligned_head_dim,
4301  aligned_intermediate_dim,
4302  aligned_context_window);
4303 
4304  qwen2_0_5b_decode_layer_16_prefill(
4305  model,
4306  num_tokens,
4307  aligned_embed_dim,
4308  aligned_head_dim,
4309  aligned_intermediate_dim,
4310  aligned_context_window);
4311 
4312  qwen2_0_5b_decode_layer_17_prefill(
4313  model,
4314  num_tokens,
4315  aligned_embed_dim,
4316  aligned_head_dim,
4317  aligned_intermediate_dim,
4318  aligned_context_window);
4319 
4320  qwen2_0_5b_decode_layer_18_prefill(
4321  model,
4322  num_tokens,
4323  aligned_embed_dim,
4324  aligned_head_dim,
4325  aligned_intermediate_dim,
4326  aligned_context_window);
4327 
4328  qwen2_0_5b_decode_layer_19_prefill(
4329  model,
4330  num_tokens,
4331  aligned_embed_dim,
4332  aligned_head_dim,
4333  aligned_intermediate_dim,
4334  aligned_context_window);
4335 
4336  qwen2_0_5b_decode_layer_20_prefill(
4337  model,
4338  num_tokens,
4339  aligned_embed_dim,
4340  aligned_head_dim,
4341  aligned_intermediate_dim,
4342  aligned_context_window);
4343 
4344  qwen2_0_5b_decode_layer_21_prefill(
4345  model,
4346  num_tokens,
4347  aligned_embed_dim,
4348  aligned_head_dim,
4349  aligned_intermediate_dim,
4350  aligned_context_window);
4351 
4352  qwen2_0_5b_decode_layer_22_prefill(
4353  model,
4354  num_tokens,
4355  aligned_embed_dim,
4356  aligned_head_dim,
4357  aligned_intermediate_dim,
4358  aligned_context_window);
4359 
4360  qwen2_0_5b_decode_layer_23_prefill(
4361  model,
4362  num_tokens,
4363  aligned_embed_dim,
4364  aligned_head_dim,
4365  aligned_intermediate_dim,
4366  aligned_context_window);
4367 
4368  float *last_hidden = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[QWEN2_0_5B_DECODE_NUM_LAYERS - 1].output);
4369  float *final_ln_weight = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_ln_weight);
4370  float *final_out = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.final_output);
4371  rmsnorm_forward(last_hidden,
4372  final_ln_weight,
4373  final_out,
4374  NULL,
4375  num_tokens,
4376  QWEN2_0_5B_DECODE_EMBED_DIM,
4377  aligned_embed_dim,
4378  1e-06f);
4379 
4380  float *logits = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.logits);
4381  const void *lm_head = (const void *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_FOOTER.lm_head_weight);
4382  const size_t q8_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_embed_dim);
4383  for (int t = 0; t < num_tokens; ++t) {
4384  uint8_t q8_buf[q8_bytes];
4385  const float *row = final_out + (size_t)t * (size_t)aligned_embed_dim;
4386  float *logits_row = logits + (size_t)t * (size_t)QWEN2_0_5B_DECODE_VOCAB_SIZE;
4387  quantize_row_q8_k(row, q8_buf, aligned_embed_dim);
4388  gemm_nt_q4_k_q8_k(q8_buf,
4389  lm_head,
4390  NULL,
4391  logits_row,
4392  1,
4393  QWEN2_0_5B_DECODE_VOCAB_SIZE,
4394  aligned_embed_dim);
4395  }
4396 }
@ CK_DT_Q8_K
Definition: ckernel_dtype.h:43
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
#define QWEN2_0_5B_DECODE_DTYPE_BYTES
#define QWEN2_0_5B_DECODE_NUM_LAYERS
static void qwen2_0_5b_decode_layer_18_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_0_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_6_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_15_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_17_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_3_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_21_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_4_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_5_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

References CK_DT_Q8_K, ck_dtype_row_bytes(), QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, embedding_forward_q4_k(), QWEN2_0_5B_DECODEFooterOffsets::final_ln_weight, QWEN2_0_5B_DECODEFooterOffsets::final_output, gemm_nt_q4_k_q8_k(), QWEN2_0_5B_DECODEFooterOffsets::lm_head_weight, QWEN2_0_5B_DECODEFooterOffsets::logits, quantize_row_q8_k(), QWEN2_0_5B_DECODE_DTYPE_BYTES, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_FOOTER, QWEN2_0_5B_DECODE_HEADER, qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_prefill(), QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_LAYERS, QWEN2_0_5B_DECODE_PTR, QWEN2_0_5B_DECODE_VOCAB_SIZE, rmsnorm_forward(), and QWEN2_0_5B_DECODEHeaderOffsets::token_emb.

Referenced by qwen2_0_5b_decode_forward().

◆ qwen2_0_5b_decode_layer_0_decode()

static void qwen2_0_5b_decode_layer_0_decode ( QWEN2_0_5B_DECODEModel *  model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 4405 of file v6.6/test_generated/qwen2_int8.c.

4412  {
4413  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[0];
4414 
4415  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.embedded_input);
4416 
4417  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
4418  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
4419  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
4420  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
4421  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
4422  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
4423  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4424  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4425  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4426  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4427  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4428 
4429  /* Weights (explicit types for layer 0) */
4430  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
4431  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
4432  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
4433  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
4434  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
4435  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
4436 
4437  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
4438  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
4439 
4440  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4441  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4442  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4443 
4444  float q_token[H * aligned_head_dim];
4445  float k_token[H_kv * aligned_head_dim];
4446  float v_token[H_kv * aligned_head_dim];
4447  float attn_token[H * aligned_head_dim];
4448 
4449  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
4450  float fc1_out[2 * aligned_intermediate_dim];
4451  float swiglu_out[aligned_intermediate_dim];
4452 
4453  /* Step 1: RMSNorm before attention */
4454  rmsnorm_forward(input,
4455  ln1_gamma,
4456  ln1_out,
4457  NULL,
4458  1,
4459  QWEN2_0_5B_DECODE_EMBED_DIM,
4460  aligned_embed_dim,
4461  1e-06f);
4462 
4463  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
4464 
4465  /* Step 2: QKV projection */
4466  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4467  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4468  uint8_t ln1_q8[ln1_q8_bytes];
4469  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
4470  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4471  if (aligned_head_dim > head_dim) {
4472  for (int h = 0; h < H; ++h) {
4473  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
4474  for (int d = head_dim; d < aligned_head_dim; ++d) {
4475  q_head[d] = 0.0f;
4476  }
4477  }
4478  }
4479 
4480  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
4481  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
4482  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
4483  const uint8_t *WK_bytes = (const uint8_t *)WK;
4484  /* ln1_q8 already quantized above */
4485  for (int h = 0; h < H_kv; ++h) {
4486  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
4487  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
4488  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4489  for (int d = head_dim; d < aligned_head_dim; ++d) {
4490  k_head[d] = 0.0f;
4491  }
4492  }
4493 
4494  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
4495  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
4496  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
4497  const uint8_t *WV_bytes = (const uint8_t *)WV;
4498  /* ln1_q8 already quantized above */
4499  for (int h = 0; h < H_kv; ++h) {
4500  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
4501  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
4502  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4503  for (int d = head_dim; d < aligned_head_dim; ++d) {
4504  v_head[d] = 0.0f;
4505  }
4506  }
4507 
4508  /* Step 3: RoPE */
4509  rope_forward(q_token,
4510  rope_cos,
4511  rope_sin,
4512  H,
4513  1,
4514  head_dim,
4515  aligned_head_dim,
4516  token_index);
4517  for (int h = 0; h < H_kv; ++h) {
4518  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
4519  rope_forward(k_head,
4520  rope_cos,
4521  rope_sin,
4522  1,
4523  1,
4524  head_dim,
4525  aligned_head_dim,
4526  token_index);
4527  }
4528 
4529  /* Step 4: KV cache write (direct-to-cache) */
4530 
4531  /* Step 5: Attention (decode, flash) */
4532  attention_forward_decode_head_major_gqa_flash(q_token,
4533  k_cache,
4534  v_cache,
4535  attn_token,
4536  H,
4537  H_kv,
4538  token_index + 1,
4539  aligned_context_window,
4540  head_dim,
4541  aligned_head_dim);
4542 
4543  /* Step 6: Output projection */
4544  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4545  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4546  uint8_t attn_q8[attn_q8_bytes];
4547  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
4548  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4549 
4550  /* Step 7: Residual add */
4551  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
4552 
4553  /* Step 8: RMSNorm before MLP */
4554  rmsnorm_forward(residual1,
4555  ln2_gamma,
4556  ln2_out,
4557  NULL,
4558  1,
4559  QWEN2_0_5B_DECODE_EMBED_DIM,
4560  aligned_embed_dim,
4561  1e-06f);
4562 
4563  /* Step 9: MLP (SwiGLU) */
4564  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4565  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4566  uint8_t ln2_q8[ln2_q8_bytes];
4567  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
4568  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4569 
4570  /* SwiGLU activation */
4571  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4572 
4573  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4574  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4575  uint8_t swiglu_q8[swiglu_q8_bytes];
4576  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
4577  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4578 
4579  /* Step 10: Final residual add */
4580  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
4581 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
static const QWEN2_0_5B_DECODEGlobalOffsets QWEN2_0_5B_DECODE_GLOBALS
#define QWEN2_0_5B_DECODE_NUM_KV_HEADS
static void qwen2_0_5b_decode_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_HEADER, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_0_prefill()

static void qwen2_0_5b_decode_layer_0_prefill ( QWEN2_0_5B_DECODEModel *  model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 211 of file v6.6/test_generated/qwen2_int8.c.

218  {
219  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[0];
220 
221  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_HEADER.embedded_input);
222  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
223  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
224  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
225  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
226  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
227  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
228  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
229  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
230  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
231  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
232  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
233  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
234  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
235  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
236  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
237 
238  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
239  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
240  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
241  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
242  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
243  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
244  const float *BQ = NULL;
245  const float *BK = NULL;
246  const float *BV = NULL;
247  const float *BO = NULL;
248  const float *B1 = NULL;
249  const float *B2 = NULL;
250 
251  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
252  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
253 
254  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
255  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
256  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
257  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
258  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
259  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
260 
261  /* RMSNorm before attention */
262  rmsnorm_forward(input,
263  ln1_gamma,
264  ln1_out,
265  NULL,
266  num_tokens,
267  QWEN2_0_5B_DECODE_EMBED_DIM,
268  aligned_embed_dim,
269  1e-06f);
270 
271  /* Q projection (head-major) */
272  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
273  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
274  for (int h = 0; h < H; ++h) {
275  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
276  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
277  float *q_h = q + (size_t)h * q_head_stride;
278  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
279  }
280 
281  /* K projection (head-major) */
282  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
283  const uint8_t *WK_bytes = (const uint8_t *)WK;
284  for (int h = 0; h < H_kv; ++h) {
285  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
286  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
287  float *k_h = k + (size_t)h * kv_head_stride;
288  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
289  }
290 
291  /* V projection (head-major) */
292  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
293  const uint8_t *WV_bytes = (const uint8_t *)WV;
294  for (int h = 0; h < H_kv; ++h) {
295  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
296  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
297  float *v_h = v + (size_t)h * kv_head_stride;
298  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
299  }
300 
301  /* RoPE */
302  rope_forward_qk_strided(q,
303  k,
304  rope_cos,
305  rope_sin,
306  H,
307  H_kv,
308  num_tokens,
309  head_dim,
310  aligned_head_dim,
311  0,
312  num_tokens,
313  aligned_context_window);
314 
315  /* Attention (prefill, causal) */
316  attention_forward_causal_head_major_gqa_flash_strided(q,
317  k,
318  v,
319  attn_out,
320  H,
321  H_kv,
322  num_tokens,
323  head_dim,
324  aligned_head_dim,
325  aligned_context_window);
326 
327  /* Output projection (flatten head-major to token-major) */
328  const int K = H * aligned_head_dim;
329  if (K != aligned_embed_dim) {
330  return;
331  }
332  const float *proj_in = attn_out;
333  if (H > 1) {
334  if (!proj_scratch) {
335  return;
336  }
337  for (int t = 0; t < num_tokens; ++t) {
338  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
339  for (int h = 0; h < H; ++h) {
340  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
341  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
342  src,
343  (size_t)aligned_head_dim * sizeof(float));
344  }
345  }
346  proj_in = proj_scratch;
347  }
348  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
349 
350  /* Residual add */
351  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
352 
353  /* RMSNorm before MLP */
354  rmsnorm_forward(residual1,
355  ln2_gamma,
356  ln2_out,
357  NULL,
358  num_tokens,
359  QWEN2_0_5B_DECODE_EMBED_DIM,
360  aligned_embed_dim,
361  1e-06f);
362 
363  /* MLP (SwiGLU) */
364  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
365  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
366  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
367 
368  /* Final residual add */
369  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
370 }
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), QWEN2_0_5B_DECODEHeaderOffsets::embedded_input, gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_HEADER, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_10_decode()

static void qwen2_0_5b_decode_layer_10_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6215 of file v6.6/test_generated/qwen2_int8.c.

/*
 * qwen2_0_5b_decode_layer_10_decode — single-token (decode-phase) pass of
 * transformer layer 10 of Qwen2-0.5B (Doxygen-rendered listing).
 *
 * Visible pipeline: RMSNorm -> Q/K/V projection (Q4_K weights x Q8_K
 * quantized activations via gemv_q4_k_q8_k; K and V are written directly
 * into the head-major KV cache at token_index) -> RoPE on Q and the newly
 * cached K -> flash decode attention over token_index+1 cached positions
 * -> output projection -> residual add -> RMSNorm -> SwiGLU MLP -> final
 * residual add into this layer's output buffer.
 *
 * NOTE(review): this extracted listing is missing the original source
 * lines that carried hyperlinked identifiers: 6223 (layer-offsets pointer
 * `L` into QWEN2_0_5B_DECODE_LAYERS[10]), 6247-6248 (rope_cos/rope_sin
 * cache lookups), 6269/6350 (macro arguments), and 6342 (the
 * attention_forward_decode_head_major_gqa_flash call head with the
 * q_token argument). Consult v6.6/test_generated/qwen2_int8.c directly.
 */
6222  {
6224 
/* Input activations are the previous layer's (layer 9) output buffer. */
6225  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[9].output);
6226 
6227  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6228  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6229  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6230  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6231  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6232  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6233  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6234  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6235  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6236  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6237  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6238 
6239  /* Weights (explicit types for layer 10) */
6240  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6241  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6242  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6243  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6244  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6245  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6246 
/* (rope_cos / rope_sin pointers were fetched on the stripped lines 6247-6248) */
6249 
6250  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6251  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6252  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6253 
/* NOTE(review): runtime-sized VLAs on the stack. k_token and v_token are
 * declared but never referenced in this listing (K/V are projected
 * direct-to-cache below) — presumably dead code from the generator. */
6254  float q_token[H * aligned_head_dim];
6255  float k_token[H_kv * aligned_head_dim];
6256  float v_token[H_kv * aligned_head_dim];
6257  float attn_token[H * aligned_head_dim];
6258 
6259  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6260  float fc1_out[2 * aligned_intermediate_dim];
6261  float swiglu_out[aligned_intermediate_dim];
6262 
6263  /* Step 1: RMSNorm before attention */
6264  rmsnorm_forward(input,
6265  ln1_gamma,
6266  ln1_out,
6267  NULL,
6268  1,
6270  aligned_embed_dim,
6271  1e-06f);
6272 
/* One KV-cache head occupies aligned_context_window * aligned_head_dim floats. */
6273  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
6274 
6275  /* Step 2: QKV projection */
6276  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
/* Q8_K row size: one 292-byte super-block per 256 elements (presumably
 * sizeof(block_q8_K) == 292 — confirm against the quantizer). */
6277  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6278  uint8_t ln1_q8[ln1_q8_bytes];
6279  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
6280  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
/* NOTE(review): the GEMV above packs H head outputs contiguously at
 * head_dim stride, while this padding loop indexes q_token at
 * aligned_head_dim stride — only coherent when aligned_head_dim ==
 * head_dim (the branch is then dead). Verify against the generator. */
6281  if (aligned_head_dim > head_dim) {
6282  for (int h = 0; h < H; ++h) {
6283  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
6284  for (int d = head_dim; d < aligned_head_dim; ++d) {
6285  q_head[d] = 0.0f;
6286  }
6287  }
6288  }
6289 
6290  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
6291  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
6292  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
6293  const uint8_t *WK_bytes = (const uint8_t *)WK;
6294  /* ln1_q8 already quantized above */
6295  for (int h = 0; h < H_kv; ++h) {
6296  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
6297  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6298  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
/* Zero the alignment padding tail of the cached K row. */
6299  for (int d = head_dim; d < aligned_head_dim; ++d) {
6300  k_head[d] = 0.0f;
6301  }
6302  }
6303 
6304  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
6305  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
6306  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
6307  const uint8_t *WV_bytes = (const uint8_t *)WV;
6308  /* ln1_q8 already quantized above */
6309  for (int h = 0; h < H_kv; ++h) {
6310  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
6311  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6312  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6313  for (int d = head_dim; d < aligned_head_dim; ++d) {
6314  v_head[d] = 0.0f;
6315  }
6316  }
6317 
6318  /* Step 3: RoPE */
/* Rotate Q for this single position (pos_offset = token_index). */
6319  rope_forward(q_token,
6320  rope_cos,
6321  rope_sin,
6322  H,
6323  1,
6324  head_dim,
6325  aligned_head_dim,
6326  token_index);
/* Rotate the just-written K row of each KV head, in place in the cache. */
6327  for (int h = 0; h < H_kv; ++h) {
6328  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6329  rope_forward(k_head,
6330  rope_cos,
6331  rope_sin,
6332  1,
6333  1,
6334  head_dim,
6335  aligned_head_dim,
6336  token_index);
6337  }
6338 
6339  /* Step 4: KV cache write (direct-to-cache) */
6340 
6341  /* Step 5: Attention (decode, flash) */
/* NOTE: the call head (attention_forward_decode_head_major_gqa_flash(
 * q_token, ... ) on original line 6342) was stripped by the extraction;
 * arguments below: K cache, V cache, output, #heads, #kv-heads, sequence
 * length = token_index+1, cache stride, head dims. */
6343  k_cache,
6344  v_cache,
6345  attn_token,
6346  H,
6347  H_kv,
6348  token_index + 1,
6349  aligned_context_window,
6350  head_dim,
6351  aligned_head_dim);
6352 
6353  /* Step 6: Output projection */
6354  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6355  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6356  uint8_t attn_q8[attn_q8_bytes];
6357  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
6358  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6359 
6360  /* Step 7: Residual add */
6361  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6362 
6363  /* Step 8: RMSNorm before MLP */
6364  rmsnorm_forward(residual1,
6365  ln2_gamma,
6366  ln2_out,
6367  NULL,
6368  1,
6370  aligned_embed_dim,
6371  1e-06f);
6372 
6373  /* Step 9: MLP (SwiGLU) */
6374  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6375  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6376  uint8_t ln2_q8[ln2_q8_bytes];
6377  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
6378  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6379 
6380  /* SwiGLU activation */
6381  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6382 
6383  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6384  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6385  uint8_t swiglu_q8[swiglu_q8_bytes];
6386  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
6387  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6388 
6389  /* Step 10: Final residual add */
6390  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6391 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_10_prefill()

static void qwen2_0_5b_decode_layer_10_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1851 of file v6.6/test_generated/qwen2_int8.c.

/*
 * qwen2_0_5b_decode_layer_10_prefill — multi-token (prefill-phase) pass of
 * transformer layer 10 of Qwen2-0.5B (Doxygen-rendered listing).
 *
 * Visible pipeline: RMSNorm -> head-major Q/K/V GEMMs (Q4_K weights,
 * float activations, gemm_nt_q4_k; K/V land directly in cache-strided
 * layout) -> strided RoPE on Q and K -> causal GQA flash attention ->
 * flatten head-major attention output to token-major via proj_scratch ->
 * output projection -> residual add -> RMSNorm -> SwiGLU MLP -> final
 * residual add. All bias pointers (BQ..B2) are NULL in this build.
 *
 * NOTE(review): lines carrying hyperlinked identifiers are missing from
 * this listing: 1859 (layer-offsets pointer `L` into
 * QWEN2_0_5B_DECODE_LAYERS[10]), 1891-1892 (rope_cos/rope_sin lookups),
 * 1907/2000 (macro arguments), 1942 (rope_forward_qk_strided(q, ...) call
 * head), 1956 (attention_forward_causal_head_major_gqa_flash_strided(q,
 * ...) call head). Consult v6.6/test_generated/qwen2_int8.c directly.
 */
1858  {
1860 
/* Input activations are the previous layer's (layer 9) output buffer. */
1861  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[9].output);
1862  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1863  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1864  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1865  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1866  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1867  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1868  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1869  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1870  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1871  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1872  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1873  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1874  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1875  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1876  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1877 
1878  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1879  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1880  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1881  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1882  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1883  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
/* No projection biases in this model configuration. */
1884  const float *BQ = NULL;
1885  const float *BK = NULL;
1886  const float *BV = NULL;
1887  const float *BO = NULL;
1888  const float *B1 = NULL;
1889  const float *B2 = NULL;
1890 
/* (rope_cos / rope_sin pointers were fetched on the stripped lines 1891-1892) */
1893 
1894  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1895  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1896  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1897  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
/* Q is packed tokens-per-head; K/V use full cache-window stride per head. */
1898  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1899  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
1900 
1901  /* RMSNorm before attention */
1902  rmsnorm_forward(input,
1903  ln1_gamma,
1904  ln1_out,
1905  NULL,
1906  num_tokens,
1908  aligned_embed_dim,
1909  1e-06f);
1910 
1911  /* Q projection (head-major) */
1912  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1913  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1914  for (int h = 0; h < H; ++h) {
1915  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1916  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1917  float *q_h = q + (size_t)h * q_head_stride;
1918  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1919  }
1920 
1921  /* K projection (head-major) */
1922  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1923  const uint8_t *WK_bytes = (const uint8_t *)WK;
1924  for (int h = 0; h < H_kv; ++h) {
1925  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1926  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1927  float *k_h = k + (size_t)h * kv_head_stride;
1928  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1929  }
1930 
1931  /* V projection (head-major) */
1932  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1933  const uint8_t *WV_bytes = (const uint8_t *)WV;
1934  for (int h = 0; h < H_kv; ++h) {
1935  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1936  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1937  float *v_h = v + (size_t)h * kv_head_stride;
1938  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1939  }
1940 
1941  /* RoPE */
/* NOTE: call head (rope_forward_qk_strided(q, ... ) on original line
 * 1942) was stripped; trailing args: pos_offset=0, q stride = num_tokens,
 * k stride = aligned_context_window (cache layout). */
1943  k,
1944  rope_cos,
1945  rope_sin,
1946  H,
1947  H_kv,
1948  num_tokens,
1949  head_dim,
1950  aligned_head_dim,
1951  0,
1952  num_tokens,
1953  aligned_context_window);
1954 
1955  /* Attention (prefill, causal) */
/* NOTE: call head (attention_forward_causal_head_major_gqa_flash_strided(
 * q, ... ) on original line 1956) was stripped by the extraction. */
1957  k,
1958  v,
1959  attn_out,
1960  H,
1961  H_kv,
1962  num_tokens,
1963  head_dim,
1964  aligned_head_dim,
1965  aligned_context_window);
1966 
1967  /* Output projection (flatten head-major to token-major) */
1968  const int K = H * aligned_head_dim;
/* Guard: layout contract requires H * aligned_head_dim == embed dim. */
1969  if (K != aligned_embed_dim) {
1970  return;
1971  }
1972  const float *proj_in = attn_out;
1973  if (H > 1) {
1974  if (!proj_scratch) {
1975  return;
1976  }
/* Re-interleave per-head rows into contiguous token-major vectors. */
1977  for (int t = 0; t < num_tokens; ++t) {
1978  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1979  for (int h = 0; h < H; ++h) {
1980  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1981  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1982  src,
1983  (size_t)aligned_head_dim * sizeof(float));
1984  }
1985  }
1986  proj_in = proj_scratch;
1987  }
1988  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1989 
1990  /* Residual add */
1991  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1992 
1993  /* RMSNorm before MLP */
1994  rmsnorm_forward(residual1,
1995  ln2_gamma,
1996  ln2_out,
1997  NULL,
1998  num_tokens,
2000  aligned_embed_dim,
2001  1e-06f);
2002 
2003  /* MLP (SwiGLU) */
2004  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2005  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2006  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2007 
2008  /* Final residual add */
2009  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2010 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_11_decode()

static void qwen2_0_5b_decode_layer_11_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6396 of file v6.6/test_generated/qwen2_int8.c.

/*
 * qwen2_0_5b_decode_layer_11_decode — single-token (decode-phase) pass of
 * transformer layer 11 of Qwen2-0.5B (Doxygen-rendered listing).
 * Structurally identical to layer 10's decode function; only the input
 * buffer (layer 10's output) and the layer-offsets table index differ.
 *
 * Pipeline: RMSNorm -> Q/K/V Q4_K x Q8_K GEMVs (K/V direct-to-cache at
 * token_index) -> RoPE -> flash decode attention -> output projection ->
 * residual add -> RMSNorm -> SwiGLU MLP -> final residual add.
 *
 * NOTE(review): lines with hyperlinked identifiers are missing: 6404
 * (layer-offsets pointer `L` into QWEN2_0_5B_DECODE_LAYERS[11]),
 * 6428-6429 (rope_cos/rope_sin lookups), 6450/6531 (macro arguments),
 * 6523 (attention_forward_decode_head_major_gqa_flash call head).
 * Consult v6.6/test_generated/qwen2_int8.c directly.
 */
6403  {
6405 
/* Input activations are the previous layer's (layer 10) output buffer. */
6406  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[10].output);
6407 
6408  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6409  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6410  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6411  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6412  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6413  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6414  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6415  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6416  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6417  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6418  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6419 
6420  /* Weights (explicit types for layer 11) */
6421  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6422  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6423  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6424  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6425  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6426  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6427 
/* (rope_cos / rope_sin pointers were fetched on the stripped lines 6428-6429) */
6430 
6431  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6432  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6433  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6434 
/* NOTE(review): runtime-sized VLAs; k_token and v_token are never
 * referenced in this listing (K/V are projected direct-to-cache below) —
 * presumably dead code from the generator. */
6435  float q_token[H * aligned_head_dim];
6436  float k_token[H_kv * aligned_head_dim];
6437  float v_token[H_kv * aligned_head_dim];
6438  float attn_token[H * aligned_head_dim];
6439 
6440  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6441  float fc1_out[2 * aligned_intermediate_dim];
6442  float swiglu_out[aligned_intermediate_dim];
6443 
6444  /* Step 1: RMSNorm before attention */
6445  rmsnorm_forward(input,
6446  ln1_gamma,
6447  ln1_out,
6448  NULL,
6449  1,
6451  aligned_embed_dim,
6452  1e-06f);
6453 
6454  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
6455 
6456  /* Step 2: QKV projection */
6457  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
/* Q8_K row size: one 292-byte super-block per 256 elements (presumably
 * sizeof(block_q8_K) == 292 — confirm against the quantizer). */
6458  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6459  uint8_t ln1_q8[ln1_q8_bytes];
6460  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
6461  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
/* NOTE(review): as in layer 10, the GEMV packs heads at head_dim stride
 * while this padding loop assumes aligned_head_dim stride — only coherent
 * when aligned_head_dim == head_dim; verify. */
6462  if (aligned_head_dim > head_dim) {
6463  for (int h = 0; h < H; ++h) {
6464  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
6465  for (int d = head_dim; d < aligned_head_dim; ++d) {
6466  q_head[d] = 0.0f;
6467  }
6468  }
6469  }
6470 
6471  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
6472  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
6473  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
6474  const uint8_t *WK_bytes = (const uint8_t *)WK;
6475  /* ln1_q8 already quantized above */
6476  for (int h = 0; h < H_kv; ++h) {
6477  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
6478  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6479  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6480  for (int d = head_dim; d < aligned_head_dim; ++d) {
6481  k_head[d] = 0.0f;
6482  }
6483  }
6484 
6485  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
6486  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
6487  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
6488  const uint8_t *WV_bytes = (const uint8_t *)WV;
6489  /* ln1_q8 already quantized above */
6490  for (int h = 0; h < H_kv; ++h) {
6491  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
6492  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6493  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6494  for (int d = head_dim; d < aligned_head_dim; ++d) {
6495  v_head[d] = 0.0f;
6496  }
6497  }
6498 
6499  /* Step 3: RoPE */
6500  rope_forward(q_token,
6501  rope_cos,
6502  rope_sin,
6503  H,
6504  1,
6505  head_dim,
6506  aligned_head_dim,
6507  token_index);
/* Rotate the just-written K row of each KV head, in place in the cache. */
6508  for (int h = 0; h < H_kv; ++h) {
6509  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6510  rope_forward(k_head,
6511  rope_cos,
6512  rope_sin,
6513  1,
6514  1,
6515  head_dim,
6516  aligned_head_dim,
6517  token_index);
6518  }
6519 
6520  /* Step 4: KV cache write (direct-to-cache) */
6521 
6522  /* Step 5: Attention (decode, flash) */
/* NOTE: the call head (attention_forward_decode_head_major_gqa_flash(
 * q_token, ... ) on original line 6523) was stripped by the extraction. */
6524  k_cache,
6525  v_cache,
6526  attn_token,
6527  H,
6528  H_kv,
6529  token_index + 1,
6530  aligned_context_window,
6531  head_dim,
6532  aligned_head_dim);
6533 
6534  /* Step 6: Output projection */
6535  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6536  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6537  uint8_t attn_q8[attn_q8_bytes];
6538  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
6539  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6540 
6541  /* Step 7: Residual add */
6542  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6543 
6544  /* Step 8: RMSNorm before MLP */
6545  rmsnorm_forward(residual1,
6546  ln2_gamma,
6547  ln2_out,
6548  NULL,
6549  1,
6551  aligned_embed_dim,
6552  1e-06f);
6553 
6554  /* Step 9: MLP (SwiGLU) */
6555  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6556  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6557  uint8_t ln2_q8[ln2_q8_bytes];
6558  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
6559  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6560 
6561  /* SwiGLU activation */
6562  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6563 
6564  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6565  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6566  uint8_t swiglu_q8[swiglu_q8_bytes];
6567  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
6568  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6569 
6570  /* Step 10: Final residual add */
6571  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6572 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_11_prefill()

static void qwen2_0_5b_decode_layer_11_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2015 of file v6.6/test_generated/qwen2_int8.c.

/*
 * qwen2_0_5b_decode_layer_11_prefill — multi-token (prefill-phase) pass of
 * transformer layer 11 of Qwen2-0.5B (Doxygen-rendered listing).
 * Structurally identical to layer 10's prefill function; only the input
 * buffer (layer 10's output) and the layer-offsets table index differ.
 *
 * Pipeline: RMSNorm -> head-major Q/K/V GEMMs (gemm_nt_q4_k; K/V written
 * in cache-strided layout) -> strided RoPE -> causal GQA flash attention
 * -> head-major-to-token-major flatten via proj_scratch -> output
 * projection -> residual add -> RMSNorm -> SwiGLU MLP -> final residual
 * add. All bias pointers (BQ..B2) are NULL in this build.
 *
 * NOTE(review): lines with hyperlinked identifiers are missing: 2023
 * (layer-offsets pointer `L` into QWEN2_0_5B_DECODE_LAYERS[11]),
 * 2055-2056 (rope_cos/rope_sin lookups), 2071/2163 (macro arguments),
 * 2106 (rope_forward_qk_strided call head), 2120
 * (attention_forward_causal_head_major_gqa_flash_strided call head).
 * Consult v6.6/test_generated/qwen2_int8.c directly.
 */
2022  {
2024 
/* Input activations are the previous layer's (layer 10) output buffer. */
2025  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[10].output);
2026  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2027  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2028  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2029  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2030  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2031  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2032  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2033  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2034  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2035  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2036  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2037  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2038  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2039  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2040  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2041 
2042  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2043  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2044  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2045  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2046  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2047  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
/* No projection biases in this model configuration. */
2048  const float *BQ = NULL;
2049  const float *BK = NULL;
2050  const float *BV = NULL;
2051  const float *BO = NULL;
2052  const float *B1 = NULL;
2053  const float *B2 = NULL;
2054 
/* (rope_cos / rope_sin pointers were fetched on the stripped lines 2055-2056) */
2057 
2058  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2059  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2060  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2061  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
/* Q is packed tokens-per-head; K/V use full cache-window stride per head. */
2062  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2063  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
2064 
2065  /* RMSNorm before attention */
2066  rmsnorm_forward(input,
2067  ln1_gamma,
2068  ln1_out,
2069  NULL,
2070  num_tokens,
2072  aligned_embed_dim,
2073  1e-06f);
2074 
2075  /* Q projection (head-major) */
2076  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2077  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2078  for (int h = 0; h < H; ++h) {
2079  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2080  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2081  float *q_h = q + (size_t)h * q_head_stride;
2082  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2083  }
2084 
2085  /* K projection (head-major) */
2086  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2087  const uint8_t *WK_bytes = (const uint8_t *)WK;
2088  for (int h = 0; h < H_kv; ++h) {
2089  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2090  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2091  float *k_h = k + (size_t)h * kv_head_stride;
2092  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2093  }
2094 
2095  /* V projection (head-major) */
2096  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2097  const uint8_t *WV_bytes = (const uint8_t *)WV;
2098  for (int h = 0; h < H_kv; ++h) {
2099  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2100  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2101  float *v_h = v + (size_t)h * kv_head_stride;
2102  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2103  }
2104 
2105  /* RoPE */
/* NOTE: call head (rope_forward_qk_strided(q, ... ) on original line
 * 2106) was stripped; trailing args: pos_offset=0, q stride = num_tokens,
 * k stride = aligned_context_window (cache layout). */
2107  k,
2108  rope_cos,
2109  rope_sin,
2110  H,
2111  H_kv,
2112  num_tokens,
2113  head_dim,
2114  aligned_head_dim,
2115  0,
2116  num_tokens,
2117  aligned_context_window);
2118 
2119  /* Attention (prefill, causal) */
/* NOTE: call head (attention_forward_causal_head_major_gqa_flash_strided(
 * q, ... ) on original line 2120) was stripped by the extraction. */
2121  k,
2122  v,
2123  attn_out,
2124  H,
2125  H_kv,
2126  num_tokens,
2127  head_dim,
2128  aligned_head_dim,
2129  aligned_context_window);
2130 
2131  /* Output projection (flatten head-major to token-major) */
2132  const int K = H * aligned_head_dim;
/* Guard: layout contract requires H * aligned_head_dim == embed dim. */
2133  if (K != aligned_embed_dim) {
2134  return;
2135  }
2136  const float *proj_in = attn_out;
2137  if (H > 1) {
2138  if (!proj_scratch) {
2139  return;
2140  }
/* Re-interleave per-head rows into contiguous token-major vectors. */
2141  for (int t = 0; t < num_tokens; ++t) {
2142  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2143  for (int h = 0; h < H; ++h) {
2144  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2145  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2146  src,
2147  (size_t)aligned_head_dim * sizeof(float));
2148  }
2149  }
2150  proj_in = proj_scratch;
2151  }
2152  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2153 
2154  /* Residual add */
2155  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2156 
2157  /* RMSNorm before MLP */
2158  rmsnorm_forward(residual1,
2159  ln2_gamma,
2160  ln2_out,
2161  NULL,
2162  num_tokens,
2164  aligned_embed_dim,
2165  1e-06f);
2166 
2167  /* MLP (SwiGLU) */
2168  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2169  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2170  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2171 
2172  /* Final residual add */
2173  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2174 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_12_decode()

/*
 * Decode-path forward pass for transformer layer 12 (one new token).
 *
 * Pipeline: RMSNorm -> Q/K/V projections (Q4_K weights x Q8_K-quantized
 * activations) with K/V written directly into the per-layer cache row ->
 * RoPE on Q and the new K row -> flash decode attention over the cache ->
 * output projection -> residual add -> RMSNorm -> SwiGLU MLP -> final
 * residual add into L->output.
 *
 * model              model arena; buffers resolved via QWEN2_0_5B_DECODE_PTR.
 * token_index        absolute position of the decoded token; attention spans
 *                    token_index + 1 cached rows.
 * aligned_*          padded layout dimensions of the arena buffers.
 *
 * NOTE(review): this definition was reconstructed from a Doxygen rendering
 * that stripped hyperlinked identifiers (layer-offset binding, rope cache
 * declarations, the attention kernel name, the EMBED_DIM argument of
 * rmsnorm_forward); restored names come from the page's reference list and
 * the parallel layer functions — confirm against the generator output.
 */
static void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel *model,
                                              int token_index,
                                              int aligned_embed_dim,
                                              int aligned_head_dim,
                                              int aligned_intermediate_dim,
                                              int aligned_context_window)
{
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[12];

    /* Input activations are the previous layer's output. */
    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[11].output);

    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    /* Weights (explicit types for layer 12; all Q4_K). */
    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* gate+up */
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* down */

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;

    /* Per-token scratch (VLAs; sizes bounded by the model configuration).
     * K/V are projected straight into the cache, so the generator's unused
     * k_token/v_token/proj_scratch locals were dropped. */
    float q_token[H * aligned_head_dim];
    float attn_token[H * aligned_head_dim];
    float fc1_out[2 * aligned_intermediate_dim];
    float swiglu_out[aligned_intermediate_dim];

    /* Step 1: RMSNorm before attention (eps 1e-6, no beta). */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL,
                    1, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* Step 2: QKV projections.  Quantize the normed activations to Q8_K
     * once and reuse the row for Q, K and V. */
    const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292); /* Q8_K: 292 B per 256-elem superblock */
    uint8_t ln1_q8[ln1_q8_bytes];
    quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);

    /* Q projection into local scratch; zero any padding lanes.
     * NOTE(review): the GEMV emits H*head_dim contiguous values, so the
     * head-major layout is only correct when aligned_head_dim == head_dim
     * (padding lanes are zeroed, not re-strided) — verify generator invariant. */
    gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
    if (aligned_head_dim > head_dim) {
        for (int h = 0; h < H; ++h) {
            float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
            for (int d = head_dim; d < aligned_head_dim; ++d) {
                q_head[d] = 0.0f;
            }
        }
    }

    /* K and V projections, one KV head at a time, written directly into
     * this token's cache row; padding lanes zeroed. */
    const size_t kv_w_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t kv_w_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, kv_w_head_elems);
    const uint8_t *WK_bytes = (const uint8_t *)WK;
    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const size_t row = (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        float *k_head = k_cache + row;
        float *v_head = v_cache + row;
        gemv_q4_k_q8_k(k_head, WK_bytes + (size_t)h * kv_w_head_bytes, ln1_q8, head_dim, aligned_embed_dim);
        gemv_q4_k_q8_k(v_head, WV_bytes + (size_t)h * kv_w_head_bytes, ln1_q8, head_dim, aligned_embed_dim);
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_head[d] = 0.0f;
            v_head[d] = 0.0f;
        }
    }

    /* Step 3: RoPE on Q (all heads) and the newly written K cache rows. */
    rope_forward(q_token, rope_cos, rope_sin, H, 1, head_dim, aligned_head_dim, token_index);
    for (int h = 0; h < H_kv; ++h) {
        float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        rope_forward(k_head, rope_cos, rope_sin, 1, 1, head_dim, aligned_head_dim, token_index);
    }

    /* Steps 4-5: KV already in cache; flash decode attention over
     * token_index + 1 rows. */
    attention_forward_decode_head_major_gqa_flash(q_token, k_cache, v_cache, attn_token,
                                                  H, H_kv, token_index + 1,
                                                  aligned_context_window, head_dim,
                                                  aligned_head_dim);

    /* Step 6: output projection (Q4_K x Q8_K). */
    const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
    uint8_t attn_q8[attn_q8_bytes];
    quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
    gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);

    /* Step 7: residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);

    /* Step 8: RMSNorm before the MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL,
                    1, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Step 9: SwiGLU MLP — fused gate+up GEMV, activation, down GEMV. */
    const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
    uint8_t ln2_q8[ln2_q8_bytes];
    quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
    gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);

    const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
    uint8_t swiglu_q8[swiglu_q8_bytes];
    quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
    gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);

    /* Step 10: final residual add into this layer's output. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
}

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_12_prefill()

/*
 * Prefill-path forward pass for transformer layer 12 over num_tokens
 * prompt tokens.
 *
 * Pipeline: RMSNorm -> head-major Q/K/V GEMMs (Q4_K weights, fp32
 * activations) with K/V written at KV-cache stride -> strided RoPE over
 * Q and K -> causal GQA flash attention -> flatten head-major attention
 * output to token-major -> output projection -> residual add -> RMSNorm
 * -> SwiGLU MLP -> final residual add into L->output.
 *
 * NOTE(review): reconstructed from a Doxygen rendering that stripped
 * hyperlinked identifiers (layer-offset binding, rope cache declarations,
 * the rope_forward_qk_strided and attention kernel names, the EMBED_DIM
 * argument of rmsnorm_forward) — confirm against the generator output.
 */
static void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel *model,
                                               int num_tokens,
                                               int aligned_embed_dim,
                                               int aligned_head_dim,
                                               int aligned_intermediate_dim,
                                               int aligned_context_window)
{
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[12];

    /* Input activations are the previous layer's output. */
    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[11].output);
    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
    float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
    float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    /* Weights (all Q4_K for layer 12). */
    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);

    /* This build carries no bias tensors; NULLs satisfy the GEMM API. */
    const float *BQ = NULL;
    const float *BK = NULL;
    const float *BV = NULL;
    const float *BO = NULL;
    const float *B1 = NULL;
    const float *B2 = NULL;

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    /* Q/K/V per-head Q4_K weight slabs share one row-byte size; computed
     * once instead of the generator's three identical computations. */
    const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
    const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* RMSNorm before attention (eps 1e-6, no beta). */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL,
                    num_tokens, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Q projection: one GEMM per head, head-major output. */
    const uint8_t *WQ_bytes = (const uint8_t *)WQ;
    for (int h = 0; h < H; ++h) {
        const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
        gemm_nt_q4_k(ln1_out, WQ_bytes + (size_t)h * head_w_bytes, bq_h,
                     q + (size_t)h * q_head_stride,
                     num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* K/V projections: head-major, written at KV-cache stride so the
     * decode path can append directly after prefill. */
    const uint8_t *WK_bytes = (const uint8_t *)WK;
    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
        gemm_nt_q4_k(ln1_out, WK_bytes + (size_t)h * head_w_bytes, bk_h,
                     k + (size_t)h * kv_head_stride,
                     num_tokens, aligned_head_dim, aligned_embed_dim);
        gemm_nt_q4_k(ln1_out, WV_bytes + (size_t)h * head_w_bytes, bv_h,
                     v + (size_t)h * kv_head_stride,
                     num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* RoPE over Q and K (K strided by the cache layout, positions 0..n-1). */
    rope_forward_qk_strided(q, k, rope_cos, rope_sin,
                            H, H_kv, num_tokens, head_dim, aligned_head_dim,
                            0, num_tokens, aligned_context_window);

    /* Causal flash attention (prefill, GQA, strided K/V). */
    attention_forward_causal_head_major_gqa_flash_strided(q, k, v, attn_out,
                                                          H, H_kv, num_tokens,
                                                          head_dim, aligned_head_dim,
                                                          aligned_context_window);

    /* Output projection: flatten head-major attn_out to token-major first.
     * The generated layout guarantees H * aligned_head_dim == aligned_embed_dim;
     * bail out defensively if that invariant is broken. */
    const int K = H * aligned_head_dim;
    if (K != aligned_embed_dim) {
        return;
    }
    const float *proj_in = attn_out;
    if (H > 1) {
        if (!proj_scratch) {
            return;
        }
        for (int t = 0; t < num_tokens; ++t) {
            float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
            for (int h = 0; h < H; ++h) {
                const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
                memcpy(dst + (size_t)h * (size_t)aligned_head_dim, src,
                       (size_t)aligned_head_dim * sizeof(float));
            }
        }
        proj_in = proj_scratch;
    }
    gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);

    /* Residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);

    /* RMSNorm before the MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL,
                    num_tokens, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* SwiGLU MLP: fused gate+up GEMM, activation, down GEMM. */
    gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
    gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);

    /* Final residual add into this layer's output. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
}

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_13_decode()

/*
 * Decode-path forward pass for transformer layer 13 (one new token).
 * Structurally identical to layer 12's decode step; only the layer
 * offset table index differs (input comes from layer 12's output).
 *
 * NOTE(review): reconstructed from a Doxygen rendering that stripped
 * hyperlinked identifiers (layer-offset binding, rope cache
 * declarations, the attention kernel name, the EMBED_DIM argument of
 * rmsnorm_forward) — confirm against the generator output.
 */
static void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel *model,
                                              int token_index,
                                              int aligned_embed_dim,
                                              int aligned_head_dim,
                                              int aligned_intermediate_dim,
                                              int aligned_context_window)
{
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[13];

    /* Input activations are the previous layer's output. */
    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[12].output);

    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    /* Weights (explicit types for layer 13; all Q4_K). */
    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* gate+up */
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* down */

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;

    /* Per-token scratch (VLAs; sizes bounded by the model configuration).
     * K/V are projected straight into the cache, so the generator's unused
     * k_token/v_token/proj_scratch locals were dropped. */
    float q_token[H * aligned_head_dim];
    float attn_token[H * aligned_head_dim];
    float fc1_out[2 * aligned_intermediate_dim];
    float swiglu_out[aligned_intermediate_dim];

    /* Step 1: RMSNorm before attention (eps 1e-6, no beta). */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL,
                    1, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* Step 2: QKV projections.  Quantize the normed activations to Q8_K
     * once and reuse the row for Q, K and V. */
    const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292); /* Q8_K: 292 B per 256-elem superblock */
    uint8_t ln1_q8[ln1_q8_bytes];
    quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);

    /* Q projection into local scratch; zero any padding lanes.
     * NOTE(review): the GEMV emits H*head_dim contiguous values, so the
     * head-major layout is only correct when aligned_head_dim == head_dim
     * (padding lanes are zeroed, not re-strided) — verify generator invariant. */
    gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
    if (aligned_head_dim > head_dim) {
        for (int h = 0; h < H; ++h) {
            float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
            for (int d = head_dim; d < aligned_head_dim; ++d) {
                q_head[d] = 0.0f;
            }
        }
    }

    /* K and V projections, one KV head at a time, written directly into
     * this token's cache row; padding lanes zeroed. */
    const size_t kv_w_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t kv_w_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, kv_w_head_elems);
    const uint8_t *WK_bytes = (const uint8_t *)WK;
    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const size_t row = (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        float *k_head = k_cache + row;
        float *v_head = v_cache + row;
        gemv_q4_k_q8_k(k_head, WK_bytes + (size_t)h * kv_w_head_bytes, ln1_q8, head_dim, aligned_embed_dim);
        gemv_q4_k_q8_k(v_head, WV_bytes + (size_t)h * kv_w_head_bytes, ln1_q8, head_dim, aligned_embed_dim);
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_head[d] = 0.0f;
            v_head[d] = 0.0f;
        }
    }

    /* Step 3: RoPE on Q (all heads) and the newly written K cache rows. */
    rope_forward(q_token, rope_cos, rope_sin, H, 1, head_dim, aligned_head_dim, token_index);
    for (int h = 0; h < H_kv; ++h) {
        float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        rope_forward(k_head, rope_cos, rope_sin, 1, 1, head_dim, aligned_head_dim, token_index);
    }

    /* Steps 4-5: KV already in cache; flash decode attention over
     * token_index + 1 rows. */
    attention_forward_decode_head_major_gqa_flash(q_token, k_cache, v_cache, attn_token,
                                                  H, H_kv, token_index + 1,
                                                  aligned_context_window, head_dim,
                                                  aligned_head_dim);

    /* Step 6: output projection (Q4_K x Q8_K). */
    const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
    uint8_t attn_q8[attn_q8_bytes];
    quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
    gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);

    /* Step 7: residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);

    /* Step 8: RMSNorm before the MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL,
                    1, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Step 9: SwiGLU MLP — fused gate+up GEMV, activation, down GEMV. */
    const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
    uint8_t ln2_q8[ln2_q8_bytes];
    quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
    gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);

    const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
    uint8_t swiglu_q8[swiglu_q8_bytes];
    quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
    gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);

    /* Step 10: final residual add into this layer's output. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
}

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_13_prefill()

/*
 * Prefill-path forward pass for transformer layer 13 over num_tokens
 * prompt tokens.  Structurally identical to layer 12's prefill step;
 * only the layer offset table index differs (input comes from layer
 * 12's output).
 *
 * NOTE(review): reconstructed from a Doxygen rendering that stripped
 * hyperlinked identifiers (layer-offset binding, rope cache declarations,
 * the rope_forward_qk_strided and attention kernel names, the EMBED_DIM
 * argument of rmsnorm_forward) — confirm against the generator output.
 */
static void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel *model,
                                               int num_tokens,
                                               int aligned_embed_dim,
                                               int aligned_head_dim,
                                               int aligned_intermediate_dim,
                                               int aligned_context_window)
{
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[13];

    /* Input activations are the previous layer's output. */
    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[12].output);
    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
    float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
    float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    /* Weights (all Q4_K for layer 13). */
    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);

    /* This build carries no bias tensors; NULLs satisfy the GEMM API. */
    const float *BQ = NULL;
    const float *BK = NULL;
    const float *BV = NULL;
    const float *BO = NULL;
    const float *B1 = NULL;
    const float *B2 = NULL;

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    /* Q/K/V per-head Q4_K weight slabs share one row-byte size; computed
     * once instead of the generator's three identical computations. */
    const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
    const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* RMSNorm before attention (eps 1e-6, no beta). */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL,
                    num_tokens, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Q projection: one GEMM per head, head-major output. */
    const uint8_t *WQ_bytes = (const uint8_t *)WQ;
    for (int h = 0; h < H; ++h) {
        const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
        gemm_nt_q4_k(ln1_out, WQ_bytes + (size_t)h * head_w_bytes, bq_h,
                     q + (size_t)h * q_head_stride,
                     num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* K/V projections: head-major, written at KV-cache stride so the
     * decode path can append directly after prefill. */
    const uint8_t *WK_bytes = (const uint8_t *)WK;
    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
        gemm_nt_q4_k(ln1_out, WK_bytes + (size_t)h * head_w_bytes, bk_h,
                     k + (size_t)h * kv_head_stride,
                     num_tokens, aligned_head_dim, aligned_embed_dim);
        gemm_nt_q4_k(ln1_out, WV_bytes + (size_t)h * head_w_bytes, bv_h,
                     v + (size_t)h * kv_head_stride,
                     num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* RoPE over Q and K (K strided by the cache layout, positions 0..n-1). */
    rope_forward_qk_strided(q, k, rope_cos, rope_sin,
                            H, H_kv, num_tokens, head_dim, aligned_head_dim,
                            0, num_tokens, aligned_context_window);

    /* Causal flash attention (prefill, GQA, strided K/V). */
    attention_forward_causal_head_major_gqa_flash_strided(q, k, v, attn_out,
                                                          H, H_kv, num_tokens,
                                                          head_dim, aligned_head_dim,
                                                          aligned_context_window);

    /* Output projection: flatten head-major attn_out to token-major first.
     * The generated layout guarantees H * aligned_head_dim == aligned_embed_dim;
     * bail out defensively if that invariant is broken. */
    const int K = H * aligned_head_dim;
    if (K != aligned_embed_dim) {
        return;
    }
    const float *proj_in = attn_out;
    if (H > 1) {
        if (!proj_scratch) {
            return;
        }
        for (int t = 0; t < num_tokens; ++t) {
            float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
            for (int h = 0; h < H; ++h) {
                const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
                memcpy(dst + (size_t)h * (size_t)aligned_head_dim, src,
                       (size_t)aligned_head_dim * sizeof(float));
            }
        }
        proj_in = proj_scratch;
    }
    gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);

    /* Residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);

    /* RMSNorm before the MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL,
                    num_tokens, QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* SwiGLU MLP: fused gate+up GEMM, activation, down GEMM. */
    gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
    gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);

    /* Final residual add into this layer's output. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
}

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_14_decode()

static void qwen2_0_5b_decode_layer_14_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6939 of file v6.6/test_generated/qwen2_int8.c.

6946  {
6947  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[14];
6948 
6949  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[13].output);
6950 
6951  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6952  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6953  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6954  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6955  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6956  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6957  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6958  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6959  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6960  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6961  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6962 
6963  /* Weights (explicit types for layer 14) */
6964  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6965  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6966  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6967  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6968  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6969  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6970 
6971  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
6972  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
6973 
6974  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6975  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6976  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6977 
6978  float q_token[H * aligned_head_dim];
6979  float k_token[H_kv * aligned_head_dim];
6980  float v_token[H_kv * aligned_head_dim];
6981  float attn_token[H * aligned_head_dim];
6982 
6983  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6984  float fc1_out[2 * aligned_intermediate_dim];
6985  float swiglu_out[aligned_intermediate_dim];
6986 
6987  /* Step 1: RMSNorm before attention */
6988  rmsnorm_forward(input,
6989  ln1_gamma,
6990  ln1_out,
6991  NULL,
6992  1,
6993  QWEN2_0_5B_DECODE_EMBED_DIM,
6994  aligned_embed_dim,
6995  1e-06f);
6996 
6997  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
6998 
6999  /* Step 2: QKV projection */
7000  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7001  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7002  uint8_t ln1_q8[ln1_q8_bytes];
7003  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
7004  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7005  if (aligned_head_dim > head_dim) {
7006  for (int h = 0; h < H; ++h) {
7007  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
7008  for (int d = head_dim; d < aligned_head_dim; ++d) {
7009  q_head[d] = 0.0f;
7010  }
7011  }
7012  }
7013 
7014  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7015  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7016  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
7017  const uint8_t *WK_bytes = (const uint8_t *)WK;
7018  /* ln1_q8 already quantized above */
7019  for (int h = 0; h < H_kv; ++h) {
7020  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
7021  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7022  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7023  for (int d = head_dim; d < aligned_head_dim; ++d) {
7024  k_head[d] = 0.0f;
7025  }
7026  }
7027 
7028  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7029  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7030  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
7031  const uint8_t *WV_bytes = (const uint8_t *)WV;
7032  /* ln1_q8 already quantized above */
7033  for (int h = 0; h < H_kv; ++h) {
7034  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
7035  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7036  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7037  for (int d = head_dim; d < aligned_head_dim; ++d) {
7038  v_head[d] = 0.0f;
7039  }
7040  }
7041 
7042  /* Step 3: RoPE */
7043  rope_forward(q_token,
7044  rope_cos,
7045  rope_sin,
7046  H,
7047  1,
7048  head_dim,
7049  aligned_head_dim,
7050  token_index);
7051  for (int h = 0; h < H_kv; ++h) {
7052  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7053  rope_forward(k_head,
7054  rope_cos,
7055  rope_sin,
7056  1,
7057  1,
7058  head_dim,
7059  aligned_head_dim,
7060  token_index);
7061  }
7062 
7063  /* Step 4: KV cache write (direct-to-cache) */
7064 
7065  /* Step 5: Attention (decode, flash) */
7066  attention_forward_decode_head_major_gqa_flash(q_token,
7067  k_cache,
7068  v_cache,
7069  attn_token,
7070  H,
7071  H_kv,
7072  token_index + 1,
7073  aligned_context_window,
7074  head_dim,
7075  aligned_head_dim);
7076 
7077  /* Step 6: Output projection */
7078  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7079  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7080  uint8_t attn_q8[attn_q8_bytes];
7081  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
7082  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7083 
7084  /* Step 7: Residual add */
7085  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7086 
7087  /* Step 8: RMSNorm before MLP */
7088  rmsnorm_forward(residual1,
7089  ln2_gamma,
7090  ln2_out,
7091  NULL,
7092  1,
7093  QWEN2_0_5B_DECODE_EMBED_DIM,
7094  aligned_embed_dim,
7095  1e-06f);
7096 
7097  /* Step 9: MLP (SwiGLU) */
7098  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7099  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7100  uint8_t ln2_q8[ln2_q8_bytes];
7101  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
7102  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7103 
7104  /* SwiGLU activation */
7105  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7106 
7107  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7108  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7109  uint8_t swiglu_q8[swiglu_q8_bytes];
7110  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
7111  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7112 
7113  /* Step 10: Final residual add */
7114  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7115 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_14_prefill()

static void qwen2_0_5b_decode_layer_14_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2507 of file v6.6/test_generated/qwen2_int8.c.

2514  {
2515  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[14];
2516 
2517  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[13].output);
2518  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2519  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2520  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2521  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2522  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2523  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2524  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2525  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2526  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2527  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2528  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2529  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2530  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2531  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2532  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2533 
2534  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2535  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2536  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2537  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2538  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2539  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2540  const float *BQ = NULL;
2541  const float *BK = NULL;
2542  const float *BV = NULL;
2543  const float *BO = NULL;
2544  const float *B1 = NULL;
2545  const float *B2 = NULL;
2546 
2547  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
2548  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
2549 
2550  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2551  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2552  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2553  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2554  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2555  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
2556 
2557  /* RMSNorm before attention */
2558  rmsnorm_forward(input,
2559  ln1_gamma,
2560  ln1_out,
2561  NULL,
2562  num_tokens,
2563  QWEN2_0_5B_DECODE_EMBED_DIM,
2564  aligned_embed_dim,
2565  1e-06f);
2566 
2567  /* Q projection (head-major) */
2568  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2569  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2570  for (int h = 0; h < H; ++h) {
2571  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2572  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2573  float *q_h = q + (size_t)h * q_head_stride;
2574  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2575  }
2576 
2577  /* K projection (head-major) */
2578  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2579  const uint8_t *WK_bytes = (const uint8_t *)WK;
2580  for (int h = 0; h < H_kv; ++h) {
2581  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2582  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2583  float *k_h = k + (size_t)h * kv_head_stride;
2584  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2585  }
2586 
2587  /* V projection (head-major) */
2588  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2589  const uint8_t *WV_bytes = (const uint8_t *)WV;
2590  for (int h = 0; h < H_kv; ++h) {
2591  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2592  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2593  float *v_h = v + (size_t)h * kv_head_stride;
2594  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2595  }
2596 
2597  /* RoPE */
2598  rope_forward_qk_strided(q,
2599  k,
2600  rope_cos,
2601  rope_sin,
2602  H,
2603  H_kv,
2604  num_tokens,
2605  head_dim,
2606  aligned_head_dim,
2607  0,
2608  num_tokens,
2609  aligned_context_window);
2610 
2611  /* Attention (prefill, causal) */
2612  attention_forward_causal_head_major_gqa_flash_strided(q,
2613  k,
2614  v,
2615  attn_out,
2616  H,
2617  H_kv,
2618  num_tokens,
2619  head_dim,
2620  aligned_head_dim,
2621  aligned_context_window);
2622 
2623  /* Output projection (flatten head-major to token-major) */
2624  const int K = H * aligned_head_dim;
2625  if (K != aligned_embed_dim) {
2626  return;
2627  }
2628  const float *proj_in = attn_out;
2629  if (H > 1) {
2630  if (!proj_scratch) {
2631  return;
2632  }
2633  for (int t = 0; t < num_tokens; ++t) {
2634  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2635  for (int h = 0; h < H; ++h) {
2636  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2637  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2638  src,
2639  (size_t)aligned_head_dim * sizeof(float));
2640  }
2641  }
2642  proj_in = proj_scratch;
2643  }
2644  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2645 
2646  /* Residual add */
2647  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2648 
2649  /* RMSNorm before MLP */
2650  rmsnorm_forward(residual1,
2651  ln2_gamma,
2652  ln2_out,
2653  NULL,
2654  num_tokens,
2655  QWEN2_0_5B_DECODE_EMBED_DIM,
2656  aligned_embed_dim,
2657  1e-06f);
2658 
2659  /* MLP (SwiGLU) */
2660  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2661  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2662  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2663 
2664  /* Final residual add */
2665  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2666 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_15_decode()

static void qwen2_0_5b_decode_layer_15_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7120 of file v6.6/test_generated/qwen2_int8.c.

7127  {
7128  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[15];
7129 
7130  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[14].output);
7131 
7132  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7133  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7134  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7135  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7136  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7137  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7138  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7139  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7140  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7141  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7142  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7143 
7144  /* Weights (explicit types for layer 15) */
7145  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7146  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7147  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7148  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7149  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7150  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7151 
7152  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7153  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7154 
7155  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7156  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7157  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7158 
7159  float q_token[H * aligned_head_dim];
7160  float k_token[H_kv * aligned_head_dim];
7161  float v_token[H_kv * aligned_head_dim];
7162  float attn_token[H * aligned_head_dim];
7163 
7164  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7165  float fc1_out[2 * aligned_intermediate_dim];
7166  float swiglu_out[aligned_intermediate_dim];
7167 
7168  /* Step 1: RMSNorm before attention */
7169  rmsnorm_forward(input,
7170  ln1_gamma,
7171  ln1_out,
7172  NULL,
7173  1,
7174  QWEN2_0_5B_DECODE_EMBED_DIM,
7175  aligned_embed_dim,
7176  1e-06f);
7177 
7178  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
7179 
7180  /* Step 2: QKV projection */
7181  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7182  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7183  uint8_t ln1_q8[ln1_q8_bytes];
7184  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
7185  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7186  if (aligned_head_dim > head_dim) {
7187  for (int h = 0; h < H; ++h) {
7188  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
7189  for (int d = head_dim; d < aligned_head_dim; ++d) {
7190  q_head[d] = 0.0f;
7191  }
7192  }
7193  }
7194 
7195  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7196  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7197  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
7198  const uint8_t *WK_bytes = (const uint8_t *)WK;
7199  /* ln1_q8 already quantized above */
7200  for (int h = 0; h < H_kv; ++h) {
7201  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
7202  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7203  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7204  for (int d = head_dim; d < aligned_head_dim; ++d) {
7205  k_head[d] = 0.0f;
7206  }
7207  }
7208 
7209  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7210  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7211  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
7212  const uint8_t *WV_bytes = (const uint8_t *)WV;
7213  /* ln1_q8 already quantized above */
7214  for (int h = 0; h < H_kv; ++h) {
7215  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
7216  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7217  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7218  for (int d = head_dim; d < aligned_head_dim; ++d) {
7219  v_head[d] = 0.0f;
7220  }
7221  }
7222 
7223  /* Step 3: RoPE */
7224  rope_forward(q_token,
7225  rope_cos,
7226  rope_sin,
7227  H,
7228  1,
7229  head_dim,
7230  aligned_head_dim,
7231  token_index);
7232  for (int h = 0; h < H_kv; ++h) {
7233  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7234  rope_forward(k_head,
7235  rope_cos,
7236  rope_sin,
7237  1,
7238  1,
7239  head_dim,
7240  aligned_head_dim,
7241  token_index);
7242  }
7243 
7244  /* Step 4: KV cache write (direct-to-cache) */
7245 
7246  /* Step 5: Attention (decode, flash) */
7247  attention_forward_decode_head_major_gqa_flash(q_token,
7248  k_cache,
7249  v_cache,
7250  attn_token,
7251  H,
7252  H_kv,
7253  token_index + 1,
7254  aligned_context_window,
7255  head_dim,
7256  aligned_head_dim);
7257 
7258  /* Step 6: Output projection */
7259  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7260  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7261  uint8_t attn_q8[attn_q8_bytes];
7262  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
7263  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7264 
7265  /* Step 7: Residual add */
7266  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7267 
7268  /* Step 8: RMSNorm before MLP */
7269  rmsnorm_forward(residual1,
7270  ln2_gamma,
7271  ln2_out,
7272  NULL,
7273  1,
7274  QWEN2_0_5B_DECODE_EMBED_DIM,
7275  aligned_embed_dim,
7276  1e-06f);
7277 
7278  /* Step 9: MLP (SwiGLU) */
7279  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7280  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7281  uint8_t ln2_q8[ln2_q8_bytes];
7282  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
7283  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7284 
7285  /* SwiGLU activation */
7286  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7287 
7288  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7289  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7290  uint8_t swiglu_q8[swiglu_q8_bytes];
7291  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
7292  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7293 
7294  /* Step 10: Final residual add */
7295  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7296 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_15_prefill()

static void qwen2_0_5b_decode_layer_15_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2671 of file v6.6/test_generated/qwen2_int8.c.

2678  {
2679  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[15];
2680 
2681  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[14].output);
2682  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2683  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2684  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2685  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2686  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2687  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2688  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2689  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2690  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2691  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2692  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2693  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2694  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2695  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2696  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2697 
2698  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2699  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2700  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2701  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2702  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2703  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2704  const float *BQ = NULL;
2705  const float *BK = NULL;
2706  const float *BV = NULL;
2707  const float *BO = NULL;
2708  const float *B1 = NULL;
2709  const float *B2 = NULL;
2710 
2711  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
2712  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
2713 
2714  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2715  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2716  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2717  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2718  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2719  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
2720 
2721  /* RMSNorm before attention */
2722  rmsnorm_forward(input,
2723  ln1_gamma,
2724  ln1_out,
2725  NULL,
2726  num_tokens,
2727  QWEN2_0_5B_DECODE_EMBED_DIM,
2728  aligned_embed_dim,
2729  1e-06f);
2730 
2731  /* Q projection (head-major) */
2732  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2733  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2734  for (int h = 0; h < H; ++h) {
2735  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2736  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2737  float *q_h = q + (size_t)h * q_head_stride;
2738  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2739  }
2740 
2741  /* K projection (head-major) */
2742  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2743  const uint8_t *WK_bytes = (const uint8_t *)WK;
2744  for (int h = 0; h < H_kv; ++h) {
2745  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2746  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2747  float *k_h = k + (size_t)h * kv_head_stride;
2748  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2749  }
2750 
2751  /* V projection (head-major) */
2752  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2753  const uint8_t *WV_bytes = (const uint8_t *)WV;
2754  for (int h = 0; h < H_kv; ++h) {
2755  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2756  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2757  float *v_h = v + (size_t)h * kv_head_stride;
2758  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2759  }
2760 
2761  /* RoPE */
2762  rope_forward_qk_strided(q,
2763  k,
2764  rope_cos,
2765  rope_sin,
2766  H,
2767  H_kv,
2768  num_tokens,
2769  head_dim,
2770  aligned_head_dim,
2771  0,
2772  num_tokens,
2773  aligned_context_window);
2774 
2775  /* Attention (prefill, causal) */
2776  attention_forward_causal_head_major_gqa_flash_strided(q,
2777  k,
2778  v,
2779  attn_out,
2780  H,
2781  H_kv,
2782  num_tokens,
2783  head_dim,
2784  aligned_head_dim,
2785  aligned_context_window);
2786 
2787  /* Output projection (flatten head-major to token-major) */
2788  const int K = H * aligned_head_dim;
2789  if (K != aligned_embed_dim) {
2790  return;
2791  }
2792  const float *proj_in = attn_out;
2793  if (H > 1) {
2794  if (!proj_scratch) {
2795  return;
2796  }
2797  for (int t = 0; t < num_tokens; ++t) {
2798  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2799  for (int h = 0; h < H; ++h) {
2800  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2801  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2802  src,
2803  (size_t)aligned_head_dim * sizeof(float));
2804  }
2805  }
2806  proj_in = proj_scratch;
2807  }
2808  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2809 
2810  /* Residual add */
2811  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2812 
2813  /* RMSNorm before MLP */
2814  rmsnorm_forward(residual1,
2815  ln2_gamma,
2816  ln2_out,
2817  NULL,
2818  num_tokens,
2819  QWEN2_0_5B_DECODE_EMBED_DIM,
2820  aligned_embed_dim,
2821  1e-06f);
2822 
2823  /* MLP (SwiGLU) */
2824  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2825  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2826  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2827 
2828  /* Final residual add */
2829  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2830 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_16_decode()

static void qwen2_0_5b_decode_layer_16_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7301 of file v6.6/test_generated/qwen2_int8.c.

7308  {
7309  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[16];
7310 
7311  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[15].output);
7312 
7313  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7314  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7315  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7316  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7317  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7318  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7319  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7320  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7321  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7322  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7323  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7324 
7325  /* Weights (explicit types for layer 16) */
7326  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7327  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7328  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7329  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7330  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7331  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7332 
7335 
7336  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7337  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7338  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7339 
7340  float q_token[H * aligned_head_dim];
7341  float k_token[H_kv * aligned_head_dim];
7342  float v_token[H_kv * aligned_head_dim];
7343  float attn_token[H * aligned_head_dim];
7344 
7345  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7346  float fc1_out[2 * aligned_intermediate_dim];
7347  float swiglu_out[aligned_intermediate_dim];
7348 
7349  /* Step 1: RMSNorm before attention */
7350  rmsnorm_forward(input,
7351  ln1_gamma,
7352  ln1_out,
7353  NULL,
7354  1,
7355  QWEN2_0_5B_DECODE_EMBED_DIM,
7356  aligned_embed_dim,
7357  1e-06f);
7358 
7359  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
7360 
7361  /* Step 2: QKV projection */
7362  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7363  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7364  uint8_t ln1_q8[ln1_q8_bytes];
7365  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
7366  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7367  if (aligned_head_dim > head_dim) {
7368  for (int h = 0; h < H; ++h) {
7369  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
7370  for (int d = head_dim; d < aligned_head_dim; ++d) {
7371  q_head[d] = 0.0f;
7372  }
7373  }
7374  }
7375 
7376  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7377  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7378  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
7379  const uint8_t *WK_bytes = (const uint8_t *)WK;
7380  /* ln1_q8 already quantized above */
7381  for (int h = 0; h < H_kv; ++h) {
7382  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
7383  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7384  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7385  for (int d = head_dim; d < aligned_head_dim; ++d) {
7386  k_head[d] = 0.0f;
7387  }
7388  }
7389 
7390  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7391  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7392  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
7393  const uint8_t *WV_bytes = (const uint8_t *)WV;
7394  /* ln1_q8 already quantized above */
7395  for (int h = 0; h < H_kv; ++h) {
7396  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
7397  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7398  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7399  for (int d = head_dim; d < aligned_head_dim; ++d) {
7400  v_head[d] = 0.0f;
7401  }
7402  }
7403 
7404  /* Step 3: RoPE */
7405  rope_forward(q_token,
7406  rope_cos,
7407  rope_sin,
7408  H,
7409  1,
7410  head_dim,
7411  aligned_head_dim,
7412  token_index);
7413  for (int h = 0; h < H_kv; ++h) {
7414  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7415  rope_forward(k_head,
7416  rope_cos,
7417  rope_sin,
7418  1,
7419  1,
7420  head_dim,
7421  aligned_head_dim,
7422  token_index);
7423  }
7424 
7425  /* Step 4: KV cache write (direct-to-cache) */
7426 
7427  /* Step 5: Attention (decode, flash) */
7428  attention_forward_decode_head_major_gqa_flash(q_token,
7429  k_cache,
7430  v_cache,
7431  attn_token,
7432  H,
7433  H_kv,
7434  token_index + 1,
7435  aligned_context_window,
7436  head_dim,
7437  aligned_head_dim);
7438 
7439  /* Step 6: Output projection */
7440  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7441  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7442  uint8_t attn_q8[attn_q8_bytes];
7443  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
7444  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7445 
7446  /* Step 7: Residual add */
7447  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7448 
7449  /* Step 8: RMSNorm before MLP */
7450  rmsnorm_forward(residual1,
7451  ln2_gamma,
7452  ln2_out,
7453  NULL,
7454  1,
7455  QWEN2_0_5B_DECODE_EMBED_DIM,
7456  aligned_embed_dim,
7457  1e-06f);
7458 
7459  /* Step 9: MLP (SwiGLU) */
7460  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7461  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7462  uint8_t ln2_q8[ln2_q8_bytes];
7463  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
7464  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7465 
7466  /* SwiGLU activation */
7467  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7468 
7469  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7470  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7471  uint8_t swiglu_q8[swiglu_q8_bytes];
7472  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
7473  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7474 
7475  /* Step 10: Final residual add */
7476  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7477 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_16_prefill()

static void qwen2_0_5b_decode_layer_16_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2835 of file v6.6/test_generated/qwen2_int8.c.

2842  {
2843  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[16];
2844 
2845  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[15].output);
2846  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
2847  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
2848  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
2849  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
2850  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
2851  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
2852  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
2853  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
2854  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
2855  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
2856  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
2857  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
2858  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
2859  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
2860  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
2861 
2862  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
2863  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
2864  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
2865  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
2866  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
2867  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
2868  const float *BQ = NULL;
2869  const float *BK = NULL;
2870  const float *BV = NULL;
2871  const float *BO = NULL;
2872  const float *B1 = NULL;
2873  const float *B2 = NULL;
2874 
2877 
2878  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
2879  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
2880  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
2881  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
2882  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
2883  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
2884 
2885  /* RMSNorm before attention */
2886  rmsnorm_forward(input,
2887  ln1_gamma,
2888  ln1_out,
2889  NULL,
2890  num_tokens,
2891  QWEN2_0_5B_DECODE_EMBED_DIM,
2892  aligned_embed_dim,
2893  1e-06f);
2894 
2895  /* Q projection (head-major) */
2896  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2897  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
2898  for (int h = 0; h < H; ++h) {
2899  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
2900  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
2901  float *q_h = q + (size_t)h * q_head_stride;
2902  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2903  }
2904 
2905  /* K projection (head-major) */
2906  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2907  const uint8_t *WK_bytes = (const uint8_t *)WK;
2908  for (int h = 0; h < H_kv; ++h) {
2909  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
2910  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
2911  float *k_h = k + (size_t)h * kv_head_stride;
2912  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2913  }
2914 
2915  /* V projection (head-major) */
2916  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
2917  const uint8_t *WV_bytes = (const uint8_t *)WV;
2918  for (int h = 0; h < H_kv; ++h) {
2919  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
2920  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
2921  float *v_h = v + (size_t)h * kv_head_stride;
2922  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2923  }
2924 
2925  /* RoPE */
2926  rope_forward_qk_strided(q,
2927  k,
2928  rope_cos,
2929  rope_sin,
2930  H,
2931  H_kv,
2932  num_tokens,
2933  head_dim,
2934  aligned_head_dim,
2935  0,
2936  num_tokens,
2937  aligned_context_window);
2938 
2939  /* Attention (prefill, causal) */
2940  attention_forward_causal_head_major_gqa_flash_strided(q,
2941  k,
2942  v,
2943  attn_out,
2944  H,
2945  H_kv,
2946  num_tokens,
2947  head_dim,
2948  aligned_head_dim,
2949  aligned_context_window);
2950 
2951  /* Output projection (flatten head-major to token-major) */
2952  const int K = H * aligned_head_dim;
2953  if (K != aligned_embed_dim) {
2954  return;
2955  }
2956  const float *proj_in = attn_out;
2957  if (H > 1) {
2958  if (!proj_scratch) {
2959  return;
2960  }
2961  for (int t = 0; t < num_tokens; ++t) {
2962  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
2963  for (int h = 0; h < H; ++h) {
2964  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
2965  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
2966  src,
2967  (size_t)aligned_head_dim * sizeof(float));
2968  }
2969  }
2970  proj_in = proj_scratch;
2971  }
2972  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2973 
2974  /* Residual add */
2975  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
2976 
2977  /* RMSNorm before MLP */
2978  rmsnorm_forward(residual1,
2979  ln2_gamma,
2980  ln2_out,
2981  NULL,
2982  num_tokens,
2983  QWEN2_0_5B_DECODE_EMBED_DIM,
2984  aligned_embed_dim,
2985  1e-06f);
2986 
2987  /* MLP (SwiGLU) */
2988  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2989  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2990  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2991 
2992  /* Final residual add */
2993  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
2994 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_17_decode()

static void qwen2_0_5b_decode_layer_17_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7482 of file v6.6/test_generated/qwen2_int8.c.

7489  {
7490  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[17];
7491 
7492  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[16].output);
7493 
7494  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7495  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7496  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7497  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7498  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7499  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7500  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7501  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7502  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7503  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7504  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7505 
7506  /* Weights (explicit types for layer 17) */
7507  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7508  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7509  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7510  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7511  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7512  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7513 
7516 
7517  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7518  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7519  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7520 
7521  float q_token[H * aligned_head_dim];
7522  float k_token[H_kv * aligned_head_dim];
7523  float v_token[H_kv * aligned_head_dim];
7524  float attn_token[H * aligned_head_dim];
7525 
7526  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7527  float fc1_out[2 * aligned_intermediate_dim];
7528  float swiglu_out[aligned_intermediate_dim];
7529 
7530  /* Step 1: RMSNorm before attention */
7531  rmsnorm_forward(input,
7532  ln1_gamma,
7533  ln1_out,
7534  NULL,
7535  1,
7536  QWEN2_0_5B_DECODE_EMBED_DIM,
7537  aligned_embed_dim,
7538  1e-06f);
7539 
7540  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
7541 
7542  /* Step 2: QKV projection */
7543  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7544  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7545  uint8_t ln1_q8[ln1_q8_bytes];
7546  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
7547  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7548  if (aligned_head_dim > head_dim) {
7549  for (int h = 0; h < H; ++h) {
7550  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
7551  for (int d = head_dim; d < aligned_head_dim; ++d) {
7552  q_head[d] = 0.0f;
7553  }
7554  }
7555  }
7556 
7557  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7558  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7559  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
7560  const uint8_t *WK_bytes = (const uint8_t *)WK;
7561  /* ln1_q8 already quantized above */
7562  for (int h = 0; h < H_kv; ++h) {
7563  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
7564  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7565  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7566  for (int d = head_dim; d < aligned_head_dim; ++d) {
7567  k_head[d] = 0.0f;
7568  }
7569  }
7570 
7571  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7572  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7573  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
7574  const uint8_t *WV_bytes = (const uint8_t *)WV;
7575  /* ln1_q8 already quantized above */
7576  for (int h = 0; h < H_kv; ++h) {
7577  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
7578  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7579  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7580  for (int d = head_dim; d < aligned_head_dim; ++d) {
7581  v_head[d] = 0.0f;
7582  }
7583  }
7584 
7585  /* Step 3: RoPE */
7586  rope_forward(q_token,
7587  rope_cos,
7588  rope_sin,
7589  H,
7590  1,
7591  head_dim,
7592  aligned_head_dim,
7593  token_index);
7594  for (int h = 0; h < H_kv; ++h) {
7595  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7596  rope_forward(k_head,
7597  rope_cos,
7598  rope_sin,
7599  1,
7600  1,
7601  head_dim,
7602  aligned_head_dim,
7603  token_index);
7604  }
7605 
7606  /* Step 4: KV cache write (direct-to-cache) */
7607 
7608  /* Step 5: Attention (decode, flash) */
7609  attention_forward_decode_head_major_gqa_flash(q_token,
7610  k_cache,
7611  v_cache,
7612  attn_token,
7613  H,
7614  H_kv,
7615  token_index + 1,
7616  aligned_context_window,
7617  head_dim,
7618  aligned_head_dim);
7619 
7620  /* Step 6: Output projection */
7621  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7622  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7623  uint8_t attn_q8[attn_q8_bytes];
7624  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
7625  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7626 
7627  /* Step 7: Residual add */
7628  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7629 
7630  /* Step 8: RMSNorm before MLP */
7631  rmsnorm_forward(residual1,
7632  ln2_gamma,
7633  ln2_out,
7634  NULL,
7635  1,
7636  QWEN2_0_5B_DECODE_EMBED_DIM,
7637  aligned_embed_dim,
7638  1e-06f);
7639 
7640  /* Step 9: MLP (SwiGLU) */
7641  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7642  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7643  uint8_t ln2_q8[ln2_q8_bytes];
7644  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
7645  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7646 
7647  /* SwiGLU activation */
7648  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7649 
7650  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7651  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7652  uint8_t swiglu_q8[swiglu_q8_bytes];
7653  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
7654  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7655 
7656  /* Step 10: Final residual add */
7657  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7658 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_17_prefill()

static void qwen2_0_5b_decode_layer_17_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 2999 of file v6.6/test_generated/qwen2_int8.c.

3006  {
3007  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[17];
3008 
3009  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[16].output);
3010  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3011  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3012  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3013  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3014  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3015  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3016  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3017  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3018  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3019  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3020  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3021  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3022  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3023  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3024  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3025 
3026  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3027  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3028  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3029  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3030  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3031  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3032  const float *BQ = NULL;
3033  const float *BK = NULL;
3034  const float *BV = NULL;
3035  const float *BO = NULL;
3036  const float *B1 = NULL;
3037  const float *B2 = NULL;
3038 
3041 
3042  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3043  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3044  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3045  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3046  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3047  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
3048 
3049  /* RMSNorm before attention */
3050  rmsnorm_forward(input,
3051  ln1_gamma,
3052  ln1_out,
3053  NULL,
3054  num_tokens,
3055  QWEN2_0_5B_DECODE_EMBED_DIM,
3056  aligned_embed_dim,
3057  1e-06f);
3058 
3059  /* Q projection (head-major) */
3060  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3061  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3062  for (int h = 0; h < H; ++h) {
3063  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3064  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3065  float *q_h = q + (size_t)h * q_head_stride;
3066  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3067  }
3068 
3069  /* K projection (head-major) */
3070  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3071  const uint8_t *WK_bytes = (const uint8_t *)WK;
3072  for (int h = 0; h < H_kv; ++h) {
3073  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3074  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3075  float *k_h = k + (size_t)h * kv_head_stride;
3076  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3077  }
3078 
3079  /* V projection (head-major) */
3080  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3081  const uint8_t *WV_bytes = (const uint8_t *)WV;
3082  for (int h = 0; h < H_kv; ++h) {
3083  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3084  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3085  float *v_h = v + (size_t)h * kv_head_stride;
3086  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3087  }
3088 
3089  /* RoPE */
3090  rope_forward_qk_strided(q,
3091  k,
3092  rope_cos,
3093  rope_sin,
3094  H,
3095  H_kv,
3096  num_tokens,
3097  head_dim,
3098  aligned_head_dim,
3099  0,
3100  num_tokens,
3101  aligned_context_window);
3102 
3103  /* Attention (prefill, causal) */
3104  attention_forward_causal_head_major_gqa_flash_strided(q,
3105  k,
3106  v,
3107  attn_out,
3108  H,
3109  H_kv,
3110  num_tokens,
3111  head_dim,
3112  aligned_head_dim,
3113  aligned_context_window);
3114 
3115  /* Output projection (flatten head-major to token-major) */
3116  const int K = H * aligned_head_dim;
3117  if (K != aligned_embed_dim) {
3118  return;
3119  }
3120  const float *proj_in = attn_out;
3121  if (H > 1) {
3122  if (!proj_scratch) {
3123  return;
3124  }
3125  for (int t = 0; t < num_tokens; ++t) {
3126  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3127  for (int h = 0; h < H; ++h) {
3128  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3129  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3130  src,
3131  (size_t)aligned_head_dim * sizeof(float));
3132  }
3133  }
3134  proj_in = proj_scratch;
3135  }
3136  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3137 
3138  /* Residual add */
3139  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3140 
3141  /* RMSNorm before MLP */
3142  rmsnorm_forward(residual1,
3143  ln2_gamma,
3144  ln2_out,
3145  NULL,
3146  num_tokens,
3147  QWEN2_0_5B_DECODE_EMBED_DIM,
3148  aligned_embed_dim,
3149  1e-06f);
3150 
3151  /* MLP (SwiGLU) */
3152  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3153  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3154  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3155 
3156  /* Final residual add */
3157  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3158 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_18_decode()

static void qwen2_0_5b_decode_layer_18_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7663 of file v6.6/test_generated/qwen2_int8.c.

7670  {
7671  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[18];
7672 
7673  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[17].output);
7674 
7675  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7676  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7677  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7678  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7679  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7680  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7681  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7682  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7683  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7684  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7685  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7686 
7687  /* Weights (explicit types for layer 18) */
7688  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7689  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7690  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7691  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7692  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7693  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7694 
7695  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7696  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7697 
7698  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7699  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7700  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7701 
7702  float q_token[H * aligned_head_dim];
7703  float k_token[H_kv * aligned_head_dim];
7704  float v_token[H_kv * aligned_head_dim];
7705  float attn_token[H * aligned_head_dim];
7706 
7707  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7708  float fc1_out[2 * aligned_intermediate_dim];
7709  float swiglu_out[aligned_intermediate_dim];
7710 
7711  /* Step 1: RMSNorm before attention */
7712  rmsnorm_forward(input,
7713  ln1_gamma,
7714  ln1_out,
7715  NULL,
7716  1,
7718  aligned_embed_dim,
7719  1e-06f);
7720 
7721  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
7722 
7723  /* Step 2: QKV projection */
7724  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7725  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7726  uint8_t ln1_q8[ln1_q8_bytes];
7727  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
7728  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7729  if (aligned_head_dim > head_dim) {
7730  for (int h = 0; h < H; ++h) {
7731  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
7732  for (int d = head_dim; d < aligned_head_dim; ++d) {
7733  q_head[d] = 0.0f;
7734  }
7735  }
7736  }
7737 
7738  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7739  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7740  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
7741  const uint8_t *WK_bytes = (const uint8_t *)WK;
7742  /* ln1_q8 already quantized above */
7743  for (int h = 0; h < H_kv; ++h) {
7744  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
7745  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7746  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7747  for (int d = head_dim; d < aligned_head_dim; ++d) {
7748  k_head[d] = 0.0f;
7749  }
7750  }
7751 
7752  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7753  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7754  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
7755  const uint8_t *WV_bytes = (const uint8_t *)WV;
7756  /* ln1_q8 already quantized above */
7757  for (int h = 0; h < H_kv; ++h) {
7758  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
7759  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7760  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7761  for (int d = head_dim; d < aligned_head_dim; ++d) {
7762  v_head[d] = 0.0f;
7763  }
7764  }
7765 
7766  /* Step 3: RoPE */
7767  rope_forward(q_token,
7768  rope_cos,
7769  rope_sin,
7770  H,
7771  1,
7772  head_dim,
7773  aligned_head_dim,
7774  token_index);
7775  for (int h = 0; h < H_kv; ++h) {
7776  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7777  rope_forward(k_head,
7778  rope_cos,
7779  rope_sin,
7780  1,
7781  1,
7782  head_dim,
7783  aligned_head_dim,
7784  token_index);
7785  }
7786 
7787  /* Step 4: KV cache write (direct-to-cache) */
7788 
7789  /* Step 5: Attention (decode, flash) */
7790  attention_forward_decode_head_major_gqa_flash(q_token,
7791  k_cache,
7792  v_cache,
7793  attn_token,
7794  H,
7795  H_kv,
7796  token_index + 1,
7797  aligned_context_window,
7798  head_dim,
7799  aligned_head_dim);
7800 
7801  /* Step 6: Output projection */
7802  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7803  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7804  uint8_t attn_q8[attn_q8_bytes];
7805  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
7806  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7807 
7808  /* Step 7: Residual add */
7809  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7810 
7811  /* Step 8: RMSNorm before MLP */
7812  rmsnorm_forward(residual1,
7813  ln2_gamma,
7814  ln2_out,
7815  NULL,
7816  1,
7818  aligned_embed_dim,
7819  1e-06f);
7820 
7821  /* Step 9: MLP (SwiGLU) */
7822  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7823  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7824  uint8_t ln2_q8[ln2_q8_bytes];
7825  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
7826  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7827 
7828  /* SwiGLU activation */
7829  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7830 
7831  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7832  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7833  uint8_t swiglu_q8[swiglu_q8_bytes];
7834  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
7835  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7836 
7837  /* Step 10: Final residual add */
7838  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
7839 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_18_prefill()

static void qwen2_0_5b_decode_layer_18_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3163 of file v6.6/test_generated/qwen2_int8.c.

3170  {
3171  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[18];
3172 
3173  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[17].output);
3174  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3175  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3176  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3177  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3178  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3179  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3180  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3181  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3182  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3183  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3184  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3185  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3186  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3187  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3188  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3189 
3190  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3191  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3192  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3193  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3194  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3195  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3196  const float *BQ = NULL;
3197  const float *BK = NULL;
3198  const float *BV = NULL;
3199  const float *BO = NULL;
3200  const float *B1 = NULL;
3201  const float *B2 = NULL;
3202 
3203  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
3204  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
3205 
3206  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3207  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3208  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3209  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3210  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3211  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
3212 
3213  /* RMSNorm before attention */
3214  rmsnorm_forward(input,
3215  ln1_gamma,
3216  ln1_out,
3217  NULL,
3218  num_tokens,
3220  aligned_embed_dim,
3221  1e-06f);
3222 
3223  /* Q projection (head-major) */
3224  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3225  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3226  for (int h = 0; h < H; ++h) {
3227  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3228  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3229  float *q_h = q + (size_t)h * q_head_stride;
3230  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3231  }
3232 
3233  /* K projection (head-major) */
3234  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3235  const uint8_t *WK_bytes = (const uint8_t *)WK;
3236  for (int h = 0; h < H_kv; ++h) {
3237  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3238  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3239  float *k_h = k + (size_t)h * kv_head_stride;
3240  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3241  }
3242 
3243  /* V projection (head-major) */
3244  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3245  const uint8_t *WV_bytes = (const uint8_t *)WV;
3246  for (int h = 0; h < H_kv; ++h) {
3247  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3248  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3249  float *v_h = v + (size_t)h * kv_head_stride;
3250  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3251  }
3252 
3253  /* RoPE */
3254  rope_forward_qk_strided(q,
3255  k,
3256  rope_cos,
3257  rope_sin,
3258  H,
3259  H_kv,
3260  num_tokens,
3261  head_dim,
3262  aligned_head_dim,
3263  0,
3264  num_tokens,
3265  aligned_context_window);
3266 
3267  /* Attention (prefill, causal) */
3268  attention_forward_causal_head_major_gqa_flash_strided(q,
3269  k,
3270  v,
3271  attn_out,
3272  H,
3273  H_kv,
3274  num_tokens,
3275  head_dim,
3276  aligned_head_dim,
3277  aligned_context_window);
3278 
3279  /* Output projection (flatten head-major to token-major) */
3280  const int K = H * aligned_head_dim;
3281  if (K != aligned_embed_dim) {
3282  return;
3283  }
3284  const float *proj_in = attn_out;
3285  if (H > 1) {
3286  if (!proj_scratch) {
3287  return;
3288  }
3289  for (int t = 0; t < num_tokens; ++t) {
3290  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3291  for (int h = 0; h < H; ++h) {
3292  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3293  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3294  src,
3295  (size_t)aligned_head_dim * sizeof(float));
3296  }
3297  }
3298  proj_in = proj_scratch;
3299  }
3300  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3301 
3302  /* Residual add */
3303  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3304 
3305  /* RMSNorm before MLP */
3306  rmsnorm_forward(residual1,
3307  ln2_gamma,
3308  ln2_out,
3309  NULL,
3310  num_tokens,
3312  aligned_embed_dim,
3313  1e-06f);
3314 
3315  /* MLP (SwiGLU) */
3316  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3317  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3318  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3319 
3320  /* Final residual add */
3321  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3322 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_19_decode()

static void qwen2_0_5b_decode_layer_19_decode ( QWEN2_0_5B_DECODEModel *model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 7844 of file v6.6/test_generated/qwen2_int8.c.

7851  {
7852  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[19];
7853 
7854  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[18].output);
7855 
7856  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
7857  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
7858  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
7859  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
7860  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
7861  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
7862  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
7863  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
7864  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
7865  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
7866  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
7867 
7868  /* Weights (explicit types for layer 19) */
7869  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
7870  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
7871  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
7872  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
7873  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
7874  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
7875 
7876  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
7877  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
7878 
7879  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
7880  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
7881  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
7882 
7883  float q_token[H * aligned_head_dim];
7884  float k_token[H_kv * aligned_head_dim];
7885  float v_token[H_kv * aligned_head_dim];
7886  float attn_token[H * aligned_head_dim];
7887 
7888  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
7889  float fc1_out[2 * aligned_intermediate_dim];
7890  float swiglu_out[aligned_intermediate_dim];
7891 
7892  /* Step 1: RMSNorm before attention */
7893  rmsnorm_forward(input,
7894  ln1_gamma,
7895  ln1_out,
7896  NULL,
7897  1,
7899  aligned_embed_dim,
7900  1e-06f);
7901 
7902  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
7903 
7904  /* Step 2: QKV projection */
7905  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7906  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7907  uint8_t ln1_q8[ln1_q8_bytes];
7908  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
7909  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7910  if (aligned_head_dim > head_dim) {
7911  for (int h = 0; h < H; ++h) {
7912  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
7913  for (int d = head_dim; d < aligned_head_dim; ++d) {
7914  q_head[d] = 0.0f;
7915  }
7916  }
7917  }
7918 
7919  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7920  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7921  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
7922  const uint8_t *WK_bytes = (const uint8_t *)WK;
7923  /* ln1_q8 already quantized above */
7924  for (int h = 0; h < H_kv; ++h) {
7925  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
7926  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7927  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7928  for (int d = head_dim; d < aligned_head_dim; ++d) {
7929  k_head[d] = 0.0f;
7930  }
7931  }
7932 
7933  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
7934  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
7935  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
7936  const uint8_t *WV_bytes = (const uint8_t *)WV;
7937  /* ln1_q8 already quantized above */
7938  for (int h = 0; h < H_kv; ++h) {
7939  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
7940  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7941  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7942  for (int d = head_dim; d < aligned_head_dim; ++d) {
7943  v_head[d] = 0.0f;
7944  }
7945  }
7946 
7947  /* Step 3: RoPE */
7948  rope_forward(q_token,
7949  rope_cos,
7950  rope_sin,
7951  H,
7952  1,
7953  head_dim,
7954  aligned_head_dim,
7955  token_index);
7956  for (int h = 0; h < H_kv; ++h) {
7957  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
7958  rope_forward(k_head,
7959  rope_cos,
7960  rope_sin,
7961  1,
7962  1,
7963  head_dim,
7964  aligned_head_dim,
7965  token_index);
7966  }
7967 
7968  /* Step 4: KV cache write (direct-to-cache) */
7969 
7970  /* Step 5: Attention (decode, flash) */
7971  attention_forward_decode_head_major_gqa_flash(q_token,
7972  k_cache,
7973  v_cache,
7974  attn_token,
7975  H,
7976  H_kv,
7977  token_index + 1,
7978  aligned_context_window,
7979  head_dim,
7980  aligned_head_dim);
7981 
7982  /* Step 6: Output projection */
7983  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
7984  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7985  uint8_t attn_q8[attn_q8_bytes];
7986  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
7987  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7988 
7989  /* Step 7: Residual add */
7990  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
7991 
7992  /* Step 8: RMSNorm before MLP */
7993  rmsnorm_forward(residual1,
7994  ln2_gamma,
7995  ln2_out,
7996  NULL,
7997  1,
7999  aligned_embed_dim,
8000  1e-06f);
8001 
8002  /* Step 9: MLP (SwiGLU) */
8003  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8004  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8005  uint8_t ln2_q8[ln2_q8_bytes];
8006  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
8007  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8008 
8009  /* SwiGLU activation */
8010  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8011 
8012  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8013  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8014  uint8_t swiglu_q8[swiglu_q8_bytes];
8015  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
8016  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8017 
8018  /* Step 10: Final residual add */
8019  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
8020 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_19_prefill()

static void qwen2_0_5b_decode_layer_19_prefill ( QWEN2_0_5B_DECODEModel *model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3327 of file v6.6/test_generated/qwen2_int8.c.

3334  {
3335  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[19];
3336 
3337  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[18].output);
3338  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3339  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3340  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3341  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3342  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3343  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3344  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3345  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3346  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3347  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3348  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3349  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3350  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3351  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3352  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3353 
3354  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3355  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3356  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3357  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3358  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3359  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3360  const float *BQ = NULL;
3361  const float *BK = NULL;
3362  const float *BV = NULL;
3363  const float *BO = NULL;
3364  const float *B1 = NULL;
3365  const float *B2 = NULL;
3366 
3367  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
3368  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
3369 
3370  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3371  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3372  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3373  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3374  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3375  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
3376 
3377  /* RMSNorm before attention */
3378  rmsnorm_forward(input,
3379  ln1_gamma,
3380  ln1_out,
3381  NULL,
3382  num_tokens,
3384  aligned_embed_dim,
3385  1e-06f);
3386 
3387  /* Q projection (head-major) */
3388  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3389  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3390  for (int h = 0; h < H; ++h) {
3391  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3392  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3393  float *q_h = q + (size_t)h * q_head_stride;
3394  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3395  }
3396 
3397  /* K projection (head-major) */
3398  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3399  const uint8_t *WK_bytes = (const uint8_t *)WK;
3400  for (int h = 0; h < H_kv; ++h) {
3401  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3402  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3403  float *k_h = k + (size_t)h * kv_head_stride;
3404  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3405  }
3406 
3407  /* V projection (head-major) */
3408  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3409  const uint8_t *WV_bytes = (const uint8_t *)WV;
3410  for (int h = 0; h < H_kv; ++h) {
3411  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3412  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3413  float *v_h = v + (size_t)h * kv_head_stride;
3414  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3415  }
3416 
3417  /* RoPE */
3418  rope_forward_qk_strided(q,
3419  k,
3420  rope_cos,
3421  rope_sin,
3422  H,
3423  H_kv,
3424  num_tokens,
3425  head_dim,
3426  aligned_head_dim,
3427  0,
3428  num_tokens,
3429  aligned_context_window);
3430 
3431  /* Attention (prefill, causal) */
3432  attention_forward_causal_head_major_gqa_flash_strided(q,
3433  k,
3434  v,
3435  attn_out,
3436  H,
3437  H_kv,
3438  num_tokens,
3439  head_dim,
3440  aligned_head_dim,
3441  aligned_context_window);
3442 
3443  /* Output projection (flatten head-major to token-major) */
3444  const int K = H * aligned_head_dim;
3445  if (K != aligned_embed_dim) {
3446  return;
3447  }
3448  const float *proj_in = attn_out;
3449  if (H > 1) {
3450  if (!proj_scratch) {
3451  return;
3452  }
3453  for (int t = 0; t < num_tokens; ++t) {
3454  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3455  for (int h = 0; h < H; ++h) {
3456  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3457  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3458  src,
3459  (size_t)aligned_head_dim * sizeof(float));
3460  }
3461  }
3462  proj_in = proj_scratch;
3463  }
3464  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3465 
3466  /* Residual add */
3467  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3468 
3469  /* RMSNorm before MLP */
3470  rmsnorm_forward(residual1,
3471  ln2_gamma,
3472  ln2_out,
3473  NULL,
3474  num_tokens,
3476  aligned_embed_dim,
3477  1e-06f);
3478 
3479  /* MLP (SwiGLU) */
3480  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3481  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3482  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3483 
3484  /* Final residual add */
3485  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3486 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_1_decode()

/*
 * Decode-phase (single token) forward pass for transformer layer 1 of
 * qwen2_0.5b (Q4_K weights, Q8_K-quantized activations).
 *
 * Pipeline: RMSNorm -> Q/K/V projection (K and V written directly into
 * this layer's KV cache at row `token_index`) -> RoPE on Q and the new
 * K row -> flash GQA decode attention over the first token_index+1
 * cache entries -> output projection -> residual add -> RMSNorm ->
 * SwiGLU MLP -> residual add.  Reads QWEN2_0_5B_DECODE_LAYERS[0].output
 * and writes QWEN2_0_5B_DECODE_LAYERS[1].output.
 *
 * NOTE(review): several hyperlinked lines were elided by the doc
 * extractor; the layer-offset pointer, the RoPE table pointers, the
 * rmsnorm dim argument, and the attention call head were reconstructed
 * from the parallel layer functions and the "References" list -- confirm
 * against the generated qwen2_int8.c.
 */
static void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel *model,
                                             int token_index,
                                             int aligned_embed_dim,
                                             int aligned_head_dim,
                                             int aligned_intermediate_dim,
                                             int aligned_context_window) {
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[1];

    /* Previous layer's output is this layer's input. */
    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[0].output);

    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    /* Weights (all Q4_K for layer 1). */
    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* gate+up */
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* down */

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;

    /* Per-token scratch (VLAs; single token, so small).  The generator's
     * unused k_token/v_token buffers were dropped: K and V go straight
     * into the cache below. */
    float q_token[H * aligned_head_dim];
    float attn_token[H * aligned_head_dim];
    float fc1_out[2 * aligned_intermediate_dim];
    float swiglu_out[aligned_intermediate_dim];

    /* Step 1: RMSNorm before attention. */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL, 1,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* Step 2: QKV projection.  Quantize the normed row to Q8_K once and
     * reuse it for all three projections (292 bytes per 256-elem block). */
    const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
    uint8_t ln1_q8[ln1_q8_bytes];
    quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);

    /* Q projection. */
    gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
    /* NOTE(review): this padding assumes head-major layout with stride
     * aligned_head_dim, while the gemv above writes H*head_dim contiguous
     * values; the branch only runs when aligned_head_dim > head_dim. */
    if (aligned_head_dim > head_dim) {
        for (int h = 0; h < H; ++h) {
            float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
            for (int d = head_dim; d < aligned_head_dim; ++d) {
                q_head[d] = 0.0f;
            }
        }
    }

    /* K projection, written directly into the cache row for token_index. */
    const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
    const uint8_t *WK_bytes = (const uint8_t *)WK;
    for (int h = 0; h < H_kv; ++h) {
        const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
        float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_head[d] = 0.0f; /* zero-pad so strided attention reads clean data */
        }
    }

    /* V projection, also direct-to-cache. */
    const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
        float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            v_head[d] = 0.0f;
        }
    }

    /* Step 3: RoPE on Q (all heads) and on the freshly written K rows. */
    rope_forward(q_token, rope_cos, rope_sin, H, 1, head_dim, aligned_head_dim, token_index);
    for (int h = 0; h < H_kv; ++h) {
        float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        rope_forward(k_head, rope_cos, rope_sin, 1, 1, head_dim, aligned_head_dim, token_index);
    }

    /* Step 4: KV cache write -- already done above (direct-to-cache). */

    /* Step 5: flash GQA decode attention over token_index+1 positions. */
    attention_forward_decode_head_major_gqa_flash(q_token,
                                                  k_cache,
                                                  v_cache,
                                                  attn_token,
                                                  H,
                                                  H_kv,
                                                  token_index + 1,
                                                  aligned_context_window,
                                                  head_dim,
                                                  aligned_head_dim);

    /* Step 6: output projection (quantize attention output to Q8_K). */
    const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
    uint8_t attn_q8[attn_q8_bytes];
    quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
    gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);

    /* Step 7: residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);

    /* Step 8: RMSNorm before MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL, 1,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Step 9: SwiGLU MLP -- fused gate+up GEMV, activation, down GEMV. */
    const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
    uint8_t ln2_q8[ln2_q8_bytes];
    quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
    gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);

    swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);

    const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
    uint8_t swiglu_q8[swiglu_q8_bytes];
    quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
    gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);

    /* Step 10: final residual add into this layer's output. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
}

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_1_prefill()

/*
 * Prefill-phase (batched) forward pass for transformer layer 1 of
 * qwen2_0.5b.  Processes num_tokens tokens at once with head-major Q4_K
 * GEMMs, strided RoPE, causal flash GQA attention, and a SwiGLU MLP.
 * K/V are computed straight into the layer's KV-cache layout (stride
 * aligned_context_window * aligned_head_dim per head).  Reads
 * QWEN2_0_5B_DECODE_LAYERS[0].output, writes LAYERS[1].output.
 *
 * NOTE(review): hyperlinked lines elided by the doc extractor (layer
 * offsets pointer, RoPE tables, the rmsnorm dim argument, and the
 * rope/attention call heads) were reconstructed from the parallel layer
 * functions and the "References" list -- confirm against the source.
 */
static void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel *model,
                                              int num_tokens,
                                              int aligned_embed_dim,
                                              int aligned_head_dim,
                                              int aligned_intermediate_dim,
                                              int aligned_context_window) {
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[1];

    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[0].output);
    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
    float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
    float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
    /* This generated model carries no projection biases. */
    const float *BQ = NULL;
    const float *BK = NULL;
    const float *BV = NULL;
    const float *BO = NULL;
    const float *B1 = NULL;
    const float *B2 = NULL;

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* RMSNorm before attention. */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL, num_tokens,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Q/K/V projections, one head at a time (head-major outputs).  The
     * per-head weight slab size is identical for Q, K and V, so compute
     * it once instead of three times. */
    const size_t w_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);

    const uint8_t *WQ_bytes = (const uint8_t *)WQ;
    for (int h = 0; h < H; ++h) {
        const void *wq_h = (const void *)(WQ_bytes + (size_t)h * w_head_bytes);
        const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q + (size_t)h * q_head_stride;
        gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    const uint8_t *WK_bytes = (const uint8_t *)WK;
    for (int h = 0; h < H_kv; ++h) {
        const void *wk_h = (const void *)(WK_bytes + (size_t)h * w_head_bytes);
        const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k + (size_t)h * kv_head_stride; /* straight into the KV cache */
        gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const void *wv_h = (const void *)(WV_bytes + (size_t)h * w_head_bytes);
        const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *v_h = v + (size_t)h * kv_head_stride;
        gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* RoPE over Q and K (strided layout, positions 0..num_tokens-1). */
    rope_forward_qk_strided(q,
                            k,
                            rope_cos,
                            rope_sin,
                            H,
                            H_kv,
                            num_tokens,
                            head_dim,
                            aligned_head_dim,
                            0,
                            num_tokens,
                            aligned_context_window);

    /* Causal flash GQA attention (prefill). */
    attention_forward_causal_head_major_gqa_flash_strided(q,
                                                          k,
                                                          v,
                                                          attn_out,
                                                          H,
                                                          H_kv,
                                                          num_tokens,
                                                          head_dim,
                                                          aligned_head_dim,
                                                          aligned_context_window);

    /* Output projection: flatten head-major attn_out to token-major. */
    const int K = H * aligned_head_dim;
    if (K != aligned_embed_dim) {
        return; /* generated-layout invariant; bail rather than corrupt memory */
    }
    const float *proj_in = attn_out;
    if (H > 1) {
        if (!proj_scratch) {
            return;
        }
        for (int t = 0; t < num_tokens; ++t) {
            float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
            for (int h = 0; h < H; ++h) {
                const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
                memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
                       src,
                       (size_t)aligned_head_dim * sizeof(float));
            }
        }
        proj_in = proj_scratch;
    }
    gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);

    /* Residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);

    /* RMSNorm before MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL, num_tokens,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* SwiGLU MLP: fused gate+up GEMM, activation, down GEMM. */
    gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
    gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);

    /* Final residual add. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
}

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_20_decode()

/*
 * Decode-phase (single token) forward pass for transformer layer 20 of
 * qwen2_0.5b (Q4_K weights, Q8_K-quantized activations).  Structurally
 * identical to the other per-layer decode functions; only the layer
 * index (20) and input layer (19) differ.
 *
 * Pipeline: RMSNorm -> Q/K/V projection (K and V written directly into
 * this layer's KV cache at row `token_index`) -> RoPE on Q and the new
 * K row -> flash GQA decode attention over token_index+1 cache entries
 * -> output projection -> residual add -> RMSNorm -> SwiGLU MLP ->
 * residual add.
 *
 * NOTE(review): hyperlink-elided lines (layer offsets pointer, RoPE
 * tables, rmsnorm dim argument, attention call head) were reconstructed
 * from parallel layers and the "References" list -- confirm against the
 * generated qwen2_int8.c.
 */
static void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel *model,
                                              int token_index,
                                              int aligned_embed_dim,
                                              int aligned_head_dim,
                                              int aligned_intermediate_dim,
                                              int aligned_context_window) {
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[20];

    /* Layer 19's output is this layer's input. */
    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[19].output);

    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    /* Weights (all Q4_K for layer 20). */
    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* gate+up */
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* down */

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;

    /* Per-token scratch (VLAs).  Unused k_token/v_token buffers from the
     * generator were dropped: K and V go straight into the cache. */
    float q_token[H * aligned_head_dim];
    float attn_token[H * aligned_head_dim];
    float fc1_out[2 * aligned_intermediate_dim];
    float swiglu_out[aligned_intermediate_dim];

    /* Step 1: RMSNorm before attention. */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL, 1,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* Step 2: QKV projection.  Quantize the normed row to Q8_K once and
     * reuse it for all three projections (292 bytes per 256-elem block). */
    const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
    uint8_t ln1_q8[ln1_q8_bytes];
    quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);

    /* Q projection. */
    gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
    /* NOTE(review): padding assumes head-major stride aligned_head_dim,
     * while the gemv writes H*head_dim contiguous values; the branch
     * only runs when aligned_head_dim > head_dim. */
    if (aligned_head_dim > head_dim) {
        for (int h = 0; h < H; ++h) {
            float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
            for (int d = head_dim; d < aligned_head_dim; ++d) {
                q_head[d] = 0.0f;
            }
        }
    }

    /* K projection, written directly into the cache row for token_index. */
    const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
    const uint8_t *WK_bytes = (const uint8_t *)WK;
    for (int h = 0; h < H_kv; ++h) {
        const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
        float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_head[d] = 0.0f; /* zero-pad so strided attention reads clean data */
        }
    }

    /* V projection, also direct-to-cache. */
    const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
        float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            v_head[d] = 0.0f;
        }
    }

    /* Step 3: RoPE on Q (all heads) and on the freshly written K rows. */
    rope_forward(q_token, rope_cos, rope_sin, H, 1, head_dim, aligned_head_dim, token_index);
    for (int h = 0; h < H_kv; ++h) {
        float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
        rope_forward(k_head, rope_cos, rope_sin, 1, 1, head_dim, aligned_head_dim, token_index);
    }

    /* Step 4: KV cache write -- already done above (direct-to-cache). */

    /* Step 5: flash GQA decode attention over token_index+1 positions. */
    attention_forward_decode_head_major_gqa_flash(q_token,
                                                  k_cache,
                                                  v_cache,
                                                  attn_token,
                                                  H,
                                                  H_kv,
                                                  token_index + 1,
                                                  aligned_context_window,
                                                  head_dim,
                                                  aligned_head_dim);

    /* Step 6: output projection (quantize attention output to Q8_K). */
    const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
    uint8_t attn_q8[attn_q8_bytes];
    quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
    gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);

    /* Step 7: residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);

    /* Step 8: RMSNorm before MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL, 1,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Step 9: SwiGLU MLP -- fused gate+up GEMV, activation, down GEMV. */
    const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
    uint8_t ln2_q8[ln2_q8_bytes];
    quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
    gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);

    swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);

    const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
    uint8_t swiglu_q8[swiglu_q8_bytes];
    quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
    gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);

    /* Step 10: final residual add into this layer's output. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
}

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_20_prefill()

/*
 * Prefill-phase (batched) forward pass for transformer layer 20 of
 * qwen2_0.5b.  Structurally identical to the other per-layer prefill
 * functions; only the layer index (20) and input layer (19) differ.
 * Processes num_tokens tokens with head-major Q4_K GEMMs, strided RoPE,
 * causal flash GQA attention, and a SwiGLU MLP; K/V go straight into
 * the KV-cache layout (stride aligned_context_window * aligned_head_dim
 * per head).
 *
 * NOTE(review): hyperlink-elided lines (layer offsets pointer, RoPE
 * tables, rmsnorm dim argument, rope/attention call heads) were
 * reconstructed from parallel layers and the "References" list --
 * confirm against the generated qwen2_int8.c.
 */
static void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel *model,
                                               int num_tokens,
                                               int aligned_embed_dim,
                                               int aligned_head_dim,
                                               int aligned_intermediate_dim,
                                               int aligned_context_window) {
    const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[20];

    float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[19].output);
    float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
    float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
    float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
    float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
    float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
    float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
    float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
    float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
    float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
    float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
    float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
    float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
    float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
    float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
    float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);

    const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
    const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
    const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
    const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
    const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
    const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
    /* This generated model carries no projection biases. */
    const float *BQ = NULL;
    const float *BK = NULL;
    const float *BV = NULL;
    const float *BO = NULL;
    const float *B1 = NULL;
    const float *B2 = NULL;

    /* Precomputed RoPE tables (global, shared by all layers). */
    const float *rope_cos = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
    const float *rope_sin = (const float *)QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);

    const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
    const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
    const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;

    /* RMSNorm before attention. */
    rmsnorm_forward(input, ln1_gamma, ln1_out, NULL, num_tokens,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* Q/K/V projections, one head at a time (head-major outputs).  The
     * per-head weight slab size is identical for Q, K and V. */
    const size_t w_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);

    const uint8_t *WQ_bytes = (const uint8_t *)WQ;
    for (int h = 0; h < H; ++h) {
        const void *wq_h = (const void *)(WQ_bytes + (size_t)h * w_head_bytes);
        const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q + (size_t)h * q_head_stride;
        gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    const uint8_t *WK_bytes = (const uint8_t *)WK;
    for (int h = 0; h < H_kv; ++h) {
        const void *wk_h = (const void *)(WK_bytes + (size_t)h * w_head_bytes);
        const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k + (size_t)h * kv_head_stride; /* straight into the KV cache */
        gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    const uint8_t *WV_bytes = (const uint8_t *)WV;
    for (int h = 0; h < H_kv; ++h) {
        const void *wv_h = (const void *)(WV_bytes + (size_t)h * w_head_bytes);
        const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *v_h = v + (size_t)h * kv_head_stride;
        gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* RoPE over Q and K (strided layout, positions 0..num_tokens-1). */
    rope_forward_qk_strided(q,
                            k,
                            rope_cos,
                            rope_sin,
                            H,
                            H_kv,
                            num_tokens,
                            head_dim,
                            aligned_head_dim,
                            0,
                            num_tokens,
                            aligned_context_window);

    /* Causal flash GQA attention (prefill). */
    attention_forward_causal_head_major_gqa_flash_strided(q,
                                                          k,
                                                          v,
                                                          attn_out,
                                                          H,
                                                          H_kv,
                                                          num_tokens,
                                                          head_dim,
                                                          aligned_head_dim,
                                                          aligned_context_window);

    /* Output projection: flatten head-major attn_out to token-major. */
    const int K = H * aligned_head_dim;
    if (K != aligned_embed_dim) {
        return; /* generated-layout invariant; bail rather than corrupt memory */
    }
    const float *proj_in = attn_out;
    if (H > 1) {
        if (!proj_scratch) {
            return;
        }
        for (int t = 0; t < num_tokens; ++t) {
            float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
            for (int h = 0; h < H; ++h) {
                const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
                memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
                       src,
                       (size_t)aligned_head_dim * sizeof(float));
            }
        }
        proj_in = proj_scratch;
    }
    gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);

    /* Residual add. */
    qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);

    /* RMSNorm before MLP. */
    rmsnorm_forward(residual1, ln2_gamma, ln2_out, NULL, num_tokens,
                    QWEN2_0_5B_DECODE_EMBED_DIM, aligned_embed_dim, 1e-06f);

    /* SwiGLU MLP: fused gate+up GEMM, activation, down GEMM. */
    gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
    gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);

    /* Final residual add. */
    qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
}

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_21_decode()

static void qwen2_0_5b_decode_layer_21_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 8206 of file v6.6/test_generated/qwen2_int8.c.

8213  {
8214  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[21];
8215 
8216  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[20].output);
8217 
8218  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
8219  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
8220  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
8221  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
8222  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
8223  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
8224  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
8225  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
8226  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
8227  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
8228  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
8229 
8230  /* Weights (explicit types for layer 21) */
8231  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
8232  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
8233  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
8234  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
8235  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
8236  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
8237 
8238  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
8239  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
8240 
8241  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
8242  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
8243  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
8244 
8245  float q_token[H * aligned_head_dim];
8246  float k_token[H_kv * aligned_head_dim];
8247  float v_token[H_kv * aligned_head_dim];
8248  float attn_token[H * aligned_head_dim];
8249 
8250  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
8251  float fc1_out[2 * aligned_intermediate_dim];
8252  float swiglu_out[aligned_intermediate_dim];
8253 
8254  /* Step 1: RMSNorm before attention */
8255  rmsnorm_forward(input,
8256  ln1_gamma,
8257  ln1_out,
8258  NULL,
8259  1,
8261  aligned_embed_dim,
8262  1e-06f);
8263 
8264  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
8265 
8266  /* Step 2: QKV projection */
8267  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8268  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8269  uint8_t ln1_q8[ln1_q8_bytes];
8270  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
8271  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8272  if (aligned_head_dim > head_dim) {
8273  for (int h = 0; h < H; ++h) {
8274  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
8275  for (int d = head_dim; d < aligned_head_dim; ++d) {
8276  q_head[d] = 0.0f;
8277  }
8278  }
8279  }
8280 
8281  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
8282  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
8283  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
8284  const uint8_t *WK_bytes = (const uint8_t *)WK;
8285  /* ln1_q8 already quantized above */
8286  for (int h = 0; h < H_kv; ++h) {
8287  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
8288  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8289  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8290  for (int d = head_dim; d < aligned_head_dim; ++d) {
8291  k_head[d] = 0.0f;
8292  }
8293  }
8294 
8295  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
8296  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
8297  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
8298  const uint8_t *WV_bytes = (const uint8_t *)WV;
8299  /* ln1_q8 already quantized above */
8300  for (int h = 0; h < H_kv; ++h) {
8301  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
8302  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8303  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8304  for (int d = head_dim; d < aligned_head_dim; ++d) {
8305  v_head[d] = 0.0f;
8306  }
8307  }
8308 
8309  /* Step 3: RoPE */
8310  rope_forward(q_token,
8311  rope_cos,
8312  rope_sin,
8313  H,
8314  1,
8315  head_dim,
8316  aligned_head_dim,
8317  token_index);
8318  for (int h = 0; h < H_kv; ++h) {
8319  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8320  rope_forward(k_head,
8321  rope_cos,
8322  rope_sin,
8323  1,
8324  1,
8325  head_dim,
8326  aligned_head_dim,
8327  token_index);
8328  }
8329 
8330  /* Step 4: KV cache write (direct-to-cache) */
8331 
8332  /* Step 5: Attention (decode, flash) */
8333  attention_forward_decode_head_major_gqa_flash(q_token,
8334  k_cache,
8335  v_cache,
8336  attn_token,
8337  H,
8338  H_kv,
8339  token_index + 1,
8340  aligned_context_window,
8341  head_dim,
8342  aligned_head_dim);
8343 
8344  /* Step 6: Output projection */
8345  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8346  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8347  uint8_t attn_q8[attn_q8_bytes];
8348  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
8349  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8350 
8351  /* Step 7: Residual add */
8352  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
8353 
8354  /* Step 8: RMSNorm before MLP */
8355  rmsnorm_forward(residual1,
8356  ln2_gamma,
8357  ln2_out,
8358  NULL,
8359  1,
8361  aligned_embed_dim,
8362  1e-06f);
8363 
8364  /* Step 9: MLP (SwiGLU) */
8365  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8366  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8367  uint8_t ln2_q8[ln2_q8_bytes];
8368  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
8369  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8370 
8371  /* SwiGLU activation */
8372  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8373 
8374  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8375  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8376  uint8_t swiglu_q8[swiglu_q8_bytes];
8377  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
8378  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8379 
8380  /* Step 10: Final residual add */
8381  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
8382 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_21_prefill()

static void qwen2_0_5b_decode_layer_21_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3655 of file v6.6/test_generated/qwen2_int8.c.

3662  {
3663  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[21];
3664 
3665  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[20].output);
3666  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3667  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3668  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3669  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3670  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3671  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3672  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3673  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3674  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3675  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3676  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3677  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3678  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3679  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3680  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3681 
3682  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3683  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3684  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3685  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3686  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3687  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3688  const float *BQ = NULL;
3689  const float *BK = NULL;
3690  const float *BV = NULL;
3691  const float *BO = NULL;
3692  const float *B1 = NULL;
3693  const float *B2 = NULL;
3694 
3695  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
3696  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
3697 
3698  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3699  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3700  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3701  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3702  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3703  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
3704 
3705  /* RMSNorm before attention */
3706  rmsnorm_forward(input,
3707  ln1_gamma,
3708  ln1_out,
3709  NULL,
3710  num_tokens,
3712  aligned_embed_dim,
3713  1e-06f);
3714 
3715  /* Q projection (head-major) */
3716  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3717  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3718  for (int h = 0; h < H; ++h) {
3719  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3720  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3721  float *q_h = q + (size_t)h * q_head_stride;
3722  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3723  }
3724 
3725  /* K projection (head-major) */
3726  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3727  const uint8_t *WK_bytes = (const uint8_t *)WK;
3728  for (int h = 0; h < H_kv; ++h) {
3729  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3730  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3731  float *k_h = k + (size_t)h * kv_head_stride;
3732  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3733  }
3734 
3735  /* V projection (head-major) */
3736  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3737  const uint8_t *WV_bytes = (const uint8_t *)WV;
3738  for (int h = 0; h < H_kv; ++h) {
3739  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3740  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3741  float *v_h = v + (size_t)h * kv_head_stride;
3742  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3743  }
3744 
3745  /* RoPE */
3746  rope_forward_qk_strided(q,
3747  k,
3748  rope_cos,
3749  rope_sin,
3750  H,
3751  H_kv,
3752  num_tokens,
3753  head_dim,
3754  aligned_head_dim,
3755  0,
3756  num_tokens,
3757  aligned_context_window);
3758 
3759  /* Attention (prefill, causal) */
3760  attention_forward_causal_head_major_gqa_flash_strided(q,
3761  k,
3762  v,
3763  attn_out,
3764  H,
3765  H_kv,
3766  num_tokens,
3767  head_dim,
3768  aligned_head_dim,
3769  aligned_context_window);
3770 
3771  /* Output projection (flatten head-major to token-major) */
3772  const int K = H * aligned_head_dim;
3773  if (K != aligned_embed_dim) {
3774  return;
3775  }
3776  const float *proj_in = attn_out;
3777  if (H > 1) {
3778  if (!proj_scratch) {
3779  return;
3780  }
3781  for (int t = 0; t < num_tokens; ++t) {
3782  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3783  for (int h = 0; h < H; ++h) {
3784  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3785  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3786  src,
3787  (size_t)aligned_head_dim * sizeof(float));
3788  }
3789  }
3790  proj_in = proj_scratch;
3791  }
3792  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3793 
3794  /* Residual add */
3795  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3796 
3797  /* RMSNorm before MLP */
3798  rmsnorm_forward(residual1,
3799  ln2_gamma,
3800  ln2_out,
3801  NULL,
3802  num_tokens,
3804  aligned_embed_dim,
3805  1e-06f);
3806 
3807  /* MLP (SwiGLU) */
3808  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3809  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3810  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3811 
3812  /* Final residual add */
3813  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3814 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_22_decode()

static void qwen2_0_5b_decode_layer_22_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 8387 of file v6.6/test_generated/qwen2_int8.c.

8394  {
8395  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[22];
8396 
8397  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[21].output);
8398 
8399  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
8400  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
8401  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
8402  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
8403  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
8404  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
8405  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
8406  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
8407  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
8408  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
8409  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
8410 
8411  /* Weights (explicit types for layer 22) */
8412  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
8413  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
8414  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
8415  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
8416  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
8417  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
8418 
8419  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
8420  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
8421 
8422  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
8423  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
8424  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
8425 
8426  float q_token[H * aligned_head_dim];
8427  float k_token[H_kv * aligned_head_dim];
8428  float v_token[H_kv * aligned_head_dim];
8429  float attn_token[H * aligned_head_dim];
8430 
8431  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
8432  float fc1_out[2 * aligned_intermediate_dim];
8433  float swiglu_out[aligned_intermediate_dim];
8434 
8435  /* Step 1: RMSNorm before attention */
8436  rmsnorm_forward(input,
8437  ln1_gamma,
8438  ln1_out,
8439  NULL,
8440  1,
8442  aligned_embed_dim,
8443  1e-06f);
8444 
8445  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
8446 
8447  /* Step 2: QKV projection */
8448  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8449  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8450  uint8_t ln1_q8[ln1_q8_bytes];
8451  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
8452  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8453  if (aligned_head_dim > head_dim) {
8454  for (int h = 0; h < H; ++h) {
8455  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
8456  for (int d = head_dim; d < aligned_head_dim; ++d) {
8457  q_head[d] = 0.0f;
8458  }
8459  }
8460  }
8461 
8462  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
8463  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
8464  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
8465  const uint8_t *WK_bytes = (const uint8_t *)WK;
8466  /* ln1_q8 already quantized above */
8467  for (int h = 0; h < H_kv; ++h) {
8468  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
8469  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8470  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8471  for (int d = head_dim; d < aligned_head_dim; ++d) {
8472  k_head[d] = 0.0f;
8473  }
8474  }
8475 
8476  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
8477  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
8478  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
8479  const uint8_t *WV_bytes = (const uint8_t *)WV;
8480  /* ln1_q8 already quantized above */
8481  for (int h = 0; h < H_kv; ++h) {
8482  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
8483  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8484  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8485  for (int d = head_dim; d < aligned_head_dim; ++d) {
8486  v_head[d] = 0.0f;
8487  }
8488  }
8489 
8490  /* Step 3: RoPE */
8491  rope_forward(q_token,
8492  rope_cos,
8493  rope_sin,
8494  H,
8495  1,
8496  head_dim,
8497  aligned_head_dim,
8498  token_index);
8499  for (int h = 0; h < H_kv; ++h) {
8500  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8501  rope_forward(k_head,
8502  rope_cos,
8503  rope_sin,
8504  1,
8505  1,
8506  head_dim,
8507  aligned_head_dim,
8508  token_index);
8509  }
8510 
8511  /* Step 4: KV cache write (direct-to-cache) */
8512 
8513  /* Step 5: Attention (decode, flash) */
8514  attention_forward_decode_head_major_gqa_flash(q_token,
8515  k_cache,
8516  v_cache,
8517  attn_token,
8518  H,
8519  H_kv,
8520  token_index + 1,
8521  aligned_context_window,
8522  head_dim,
8523  aligned_head_dim);
8524 
8525  /* Step 6: Output projection */
8526  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8527  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8528  uint8_t attn_q8[attn_q8_bytes];
8529  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
8530  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8531 
8532  /* Step 7: Residual add */
8533  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
8534 
8535  /* Step 8: RMSNorm before MLP */
8536  rmsnorm_forward(residual1,
8537  ln2_gamma,
8538  ln2_out,
8539  NULL,
8540  1,
8542  aligned_embed_dim,
8543  1e-06f);
8544 
8545  /* Step 9: MLP (SwiGLU) */
8546  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8547  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8548  uint8_t ln2_q8[ln2_q8_bytes];
8549  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
8550  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8551 
8552  /* SwiGLU activation */
8553  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8554 
8555  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8556  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8557  uint8_t swiglu_q8[swiglu_q8_bytes];
8558  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
8559  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8560 
8561  /* Step 10: Final residual add */
8562  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
8563 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_22_prefill()

static void qwen2_0_5b_decode_layer_22_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3819 of file v6.6/test_generated/qwen2_int8.c.

3826  {
3827  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[22];
3828 
3829  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[21].output);
3830  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3831  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3832  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3833  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3834  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3835  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
3836  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
3837  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
3838  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
3839  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
3840  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
3841  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
3842  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
3843  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
3844  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
3845 
3846  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
3847  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
3848  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
3849  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
3850  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
3851  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
3852  const float *BQ = NULL;
3853  const float *BK = NULL;
3854  const float *BV = NULL;
3855  const float *BO = NULL;
3856  const float *B1 = NULL;
3857  const float *B2 = NULL;
3858 
3859  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
3860  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
3861 
3862  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
3863  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
3864  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
3865  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
3866  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
3867  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
3868 
3869  /* RMSNorm before attention */
3870  rmsnorm_forward(input,
3871  ln1_gamma,
3872  ln1_out,
3873  NULL,
3874  num_tokens,
3876  aligned_embed_dim,
3877  1e-06f);
3878 
3879  /* Q projection (head-major) */
3880  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3881  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
3882  for (int h = 0; h < H; ++h) {
3883  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
3884  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
3885  float *q_h = q + (size_t)h * q_head_stride;
3886  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3887  }
3888 
3889  /* K projection (head-major) */
3890  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3891  const uint8_t *WK_bytes = (const uint8_t *)WK;
3892  for (int h = 0; h < H_kv; ++h) {
3893  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
3894  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
3895  float *k_h = k + (size_t)h * kv_head_stride;
3896  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3897  }
3898 
3899  /* V projection (head-major) */
3900  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
3901  const uint8_t *WV_bytes = (const uint8_t *)WV;
3902  for (int h = 0; h < H_kv; ++h) {
3903  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
3904  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
3905  float *v_h = v + (size_t)h * kv_head_stride;
3906  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3907  }
3908 
3909  /* RoPE */
3910  rope_forward_qk_strided(q,
3911  k,
3912  rope_cos,
3913  rope_sin,
3914  H,
3915  H_kv,
3916  num_tokens,
3917  head_dim,
3918  aligned_head_dim,
3919  0,
3920  num_tokens,
3921  aligned_context_window);
3922 
3923  /* Attention (prefill, causal) */
3924  attention_forward_causal_head_major_gqa_flash_strided(q,
3925  k,
3926  v,
3927  attn_out,
3928  H,
3929  H_kv,
3930  num_tokens,
3931  head_dim,
3932  aligned_head_dim,
3933  aligned_context_window);
3934 
3935  /* Output projection (flatten head-major to token-major) */
3936  const int K = H * aligned_head_dim;
3937  if (K != aligned_embed_dim) {
3938  return;
3939  }
3940  const float *proj_in = attn_out;
3941  if (H > 1) {
3942  if (!proj_scratch) {
3943  return;
3944  }
3945  for (int t = 0; t < num_tokens; ++t) {
3946  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
3947  for (int h = 0; h < H; ++h) {
3948  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
3949  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
3950  src,
3951  (size_t)aligned_head_dim * sizeof(float));
3952  }
3953  }
3954  proj_in = proj_scratch;
3955  }
3956  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3957 
3958  /* Residual add */
3959  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
3960 
3961  /* RMSNorm before MLP */
3962  rmsnorm_forward(residual1,
3963  ln2_gamma,
3964  ln2_out,
3965  NULL,
3966  num_tokens,
3968  aligned_embed_dim,
3969  1e-06f);
3970 
3971  /* MLP (SwiGLU) */
3972  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3973  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3974  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3975 
3976  /* Final residual add */
3977  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
3978 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_23_decode()

static void qwen2_0_5b_decode_layer_23_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 8568 of file v6.6/test_generated/qwen2_int8.c.

8575  {
8576  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[23];
8577 
8578  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[22].output);
8579 
8580  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
8581  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
8582  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
8583  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
8584  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
8585  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
8586  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
8587  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
8588  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
8589  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
8590  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
8591 
8592  /* Weights (explicit types for layer 23) */
8593  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
8594  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
8595  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
8596  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
8597  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
8598  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
8599 
8600  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
8601  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
8602 
8603  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
8604  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
8605  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
8606 
8607  float q_token[H * aligned_head_dim];
8608  float k_token[H_kv * aligned_head_dim];
8609  float v_token[H_kv * aligned_head_dim];
8610  float attn_token[H * aligned_head_dim];
8611 
8612  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
8613  float fc1_out[2 * aligned_intermediate_dim];
8614  float swiglu_out[aligned_intermediate_dim];
8615 
8616  /* Step 1: RMSNorm before attention */
8617  rmsnorm_forward(input,
8618  ln1_gamma,
8619  ln1_out,
8620  NULL,
8621  1,
 8622  QWEN2_0_5B_DECODE_EMBED_DIM,
 8623  aligned_embed_dim,
8624  1e-06f);
8625 
8626  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
8627 
8628  /* Step 2: QKV projection */
8629  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8630  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8631  uint8_t ln1_q8[ln1_q8_bytes];
8632  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
8633  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8634  if (aligned_head_dim > head_dim) {
8635  for (int h = 0; h < H; ++h) {
8636  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
8637  for (int d = head_dim; d < aligned_head_dim; ++d) {
8638  q_head[d] = 0.0f;
8639  }
8640  }
8641  }
8642 
8643  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
8644  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
8645  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
8646  const uint8_t *WK_bytes = (const uint8_t *)WK;
8647  /* ln1_q8 already quantized above */
8648  for (int h = 0; h < H_kv; ++h) {
8649  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
8650  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8651  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8652  for (int d = head_dim; d < aligned_head_dim; ++d) {
8653  k_head[d] = 0.0f;
8654  }
8655  }
8656 
8657  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
8658  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
8659  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
8660  const uint8_t *WV_bytes = (const uint8_t *)WV;
8661  /* ln1_q8 already quantized above */
8662  for (int h = 0; h < H_kv; ++h) {
8663  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
8664  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8665  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8666  for (int d = head_dim; d < aligned_head_dim; ++d) {
8667  v_head[d] = 0.0f;
8668  }
8669  }
8670 
8671  /* Step 3: RoPE */
8672  rope_forward(q_token,
8673  rope_cos,
8674  rope_sin,
8675  H,
8676  1,
8677  head_dim,
8678  aligned_head_dim,
8679  token_index);
8680  for (int h = 0; h < H_kv; ++h) {
8681  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
8682  rope_forward(k_head,
8683  rope_cos,
8684  rope_sin,
8685  1,
8686  1,
8687  head_dim,
8688  aligned_head_dim,
8689  token_index);
8690  }
8691 
8692  /* Step 4: KV cache write (direct-to-cache) */
8693 
8694  /* Step 5: Attention (decode, flash) */
8695  attention_forward_decode_head_major_gqa_flash(q_token,
8696  k_cache,
8697  v_cache,
8698  attn_token,
8699  H,
8700  H_kv,
8701  token_index + 1,
8702  aligned_context_window,
8703  head_dim,
8704  aligned_head_dim);
8705 
8706  /* Step 6: Output projection */
8707  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8708  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8709  uint8_t attn_q8[attn_q8_bytes];
8710  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
8711  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8712 
8713  /* Step 7: Residual add */
8714  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
8715 
8716  /* Step 8: RMSNorm before MLP */
8717  rmsnorm_forward(residual1,
8718  ln2_gamma,
8719  ln2_out,
8720  NULL,
8721  1,
 8722  QWEN2_0_5B_DECODE_EMBED_DIM,
 8723  aligned_embed_dim,
8724  1e-06f);
8725 
8726  /* Step 9: MLP (SwiGLU) */
8727  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8728  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8729  uint8_t ln2_q8[ln2_q8_bytes];
8730  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
8731  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8732 
8733  /* SwiGLU activation */
8734  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8735 
8736  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
8737  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8738  uint8_t swiglu_q8[swiglu_q8_bytes];
8739  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
8740  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8741 
8742  /* Step 10: Final residual add */
8743  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
8744 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_23_prefill()

static void qwen2_0_5b_decode_layer_23_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 3983 of file v6.6/test_generated/qwen2_int8.c.

3990  {
3991  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[23];
3992 
3993  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[22].output);
3994  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
3995  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
3996  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
3997  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
3998  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
3999  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
4000  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
4001  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
4002  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4003  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4004  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4005  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
4006  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
4007  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4008  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4009 
4010  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
4011  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
4012  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
4013  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
4014  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
4015  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
4016  const float *BQ = NULL;
4017  const float *BK = NULL;
4018  const float *BV = NULL;
4019  const float *BO = NULL;
4020  const float *B1 = NULL;
4021  const float *B2 = NULL;
4022 
4023  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
4024  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
4025 
4026  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4027  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4028  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4029  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
4030  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
4031  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
4032 
4033  /* RMSNorm before attention */
4034  rmsnorm_forward(input,
4035  ln1_gamma,
4036  ln1_out,
4037  NULL,
4038  num_tokens,
 4039  QWEN2_0_5B_DECODE_EMBED_DIM,
 4040  aligned_embed_dim,
4041  1e-06f);
4042 
4043  /* Q projection (head-major) */
4044  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
4045  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
4046  for (int h = 0; h < H; ++h) {
4047  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
4048  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
4049  float *q_h = q + (size_t)h * q_head_stride;
4050  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4051  }
4052 
4053  /* K projection (head-major) */
4054  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
4055  const uint8_t *WK_bytes = (const uint8_t *)WK;
4056  for (int h = 0; h < H_kv; ++h) {
4057  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
4058  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
4059  float *k_h = k + (size_t)h * kv_head_stride;
4060  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4061  }
4062 
4063  /* V projection (head-major) */
4064  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
4065  const uint8_t *WV_bytes = (const uint8_t *)WV;
4066  for (int h = 0; h < H_kv; ++h) {
4067  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
4068  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
4069  float *v_h = v + (size_t)h * kv_head_stride;
4070  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4071  }
4072 
4073  /* RoPE */
4074  rope_forward_qk_strided(q,
4075  k,
4076  rope_cos,
4077  rope_sin,
4078  H,
4079  H_kv,
4080  num_tokens,
4081  head_dim,
4082  aligned_head_dim,
4083  0,
4084  num_tokens,
4085  aligned_context_window);
4086 
4087  /* Attention (prefill, causal) */
4088  attention_forward_causal_head_major_gqa_flash_strided(q,
4089  k,
4090  v,
4091  attn_out,
4092  H,
4093  H_kv,
4094  num_tokens,
4095  head_dim,
4096  aligned_head_dim,
4097  aligned_context_window);
4098 
4099  /* Output projection (flatten head-major to token-major) */
4100  const int K = H * aligned_head_dim;
4101  if (K != aligned_embed_dim) {
4102  return;
4103  }
4104  const float *proj_in = attn_out;
4105  if (H > 1) {
4106  if (!proj_scratch) {
4107  return;
4108  }
4109  for (int t = 0; t < num_tokens; ++t) {
4110  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
4111  for (int h = 0; h < H; ++h) {
4112  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
4113  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
4114  src,
4115  (size_t)aligned_head_dim * sizeof(float));
4116  }
4117  }
4118  proj_in = proj_scratch;
4119  }
4120  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
4121 
4122  /* Residual add */
4123  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
4124 
4125  /* RMSNorm before MLP */
4126  rmsnorm_forward(residual1,
4127  ln2_gamma,
4128  ln2_out,
4129  NULL,
4130  num_tokens,
 4131  QWEN2_0_5B_DECODE_EMBED_DIM,
 4132  aligned_embed_dim,
4133  1e-06f);
4134 
4135  /* MLP (SwiGLU) */
4136  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
4137  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
4138  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
4139 
4140  /* Final residual add */
4141  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
4142 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_2_decode()

static void qwen2_0_5b_decode_layer_2_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 4767 of file v6.6/test_generated/qwen2_int8.c.

4774  {
4775  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[2];
4776 
4777  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[1].output);
4778 
4779  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
4780  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
4781  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
4782  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
4783  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
4784  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
4785  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4786  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4787  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4788  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4789  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4790 
4791  /* Weights (explicit types for layer 2) */
4792  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
4793  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
4794  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
4795  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
4796  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
4797  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
4798 
4799  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
4800  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
4801 
4802  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4803  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4804  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4805 
4806  float q_token[H * aligned_head_dim];
4807  float k_token[H_kv * aligned_head_dim];
4808  float v_token[H_kv * aligned_head_dim];
4809  float attn_token[H * aligned_head_dim];
4810 
4811  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
4812  float fc1_out[2 * aligned_intermediate_dim];
4813  float swiglu_out[aligned_intermediate_dim];
4814 
4815  /* Step 1: RMSNorm before attention */
4816  rmsnorm_forward(input,
4817  ln1_gamma,
4818  ln1_out,
4819  NULL,
4820  1,
 4821  QWEN2_0_5B_DECODE_EMBED_DIM,
 4822  aligned_embed_dim,
4823  1e-06f);
4824 
4825  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
4826 
4827  /* Step 2: QKV projection */
4828  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4829  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4830  uint8_t ln1_q8[ln1_q8_bytes];
4831  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
4832  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4833  if (aligned_head_dim > head_dim) {
4834  for (int h = 0; h < H; ++h) {
4835  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
4836  for (int d = head_dim; d < aligned_head_dim; ++d) {
4837  q_head[d] = 0.0f;
4838  }
4839  }
4840  }
4841 
4842  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
4843  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
4844  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
4845  const uint8_t *WK_bytes = (const uint8_t *)WK;
4846  /* ln1_q8 already quantized above */
4847  for (int h = 0; h < H_kv; ++h) {
4848  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
4849  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
4850  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4851  for (int d = head_dim; d < aligned_head_dim; ++d) {
4852  k_head[d] = 0.0f;
4853  }
4854  }
4855 
4856  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
4857  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
4858  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
4859  const uint8_t *WV_bytes = (const uint8_t *)WV;
4860  /* ln1_q8 already quantized above */
4861  for (int h = 0; h < H_kv; ++h) {
4862  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
4863  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
4864  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4865  for (int d = head_dim; d < aligned_head_dim; ++d) {
4866  v_head[d] = 0.0f;
4867  }
4868  }
4869 
4870  /* Step 3: RoPE */
4871  rope_forward(q_token,
4872  rope_cos,
4873  rope_sin,
4874  H,
4875  1,
4876  head_dim,
4877  aligned_head_dim,
4878  token_index);
4879  for (int h = 0; h < H_kv; ++h) {
4880  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
4881  rope_forward(k_head,
4882  rope_cos,
4883  rope_sin,
4884  1,
4885  1,
4886  head_dim,
4887  aligned_head_dim,
4888  token_index);
4889  }
4890 
4891  /* Step 4: KV cache write (direct-to-cache) */
4892 
4893  /* Step 5: Attention (decode, flash) */
4894  attention_forward_decode_head_major_gqa_flash(q_token,
4895  k_cache,
4896  v_cache,
4897  attn_token,
4898  H,
4899  H_kv,
4900  token_index + 1,
4901  aligned_context_window,
4902  head_dim,
4903  aligned_head_dim);
4904 
4905  /* Step 6: Output projection */
4906  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4907  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4908  uint8_t attn_q8[attn_q8_bytes];
4909  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
4910  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4911 
4912  /* Step 7: Residual add */
4913  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
4914 
4915  /* Step 8: RMSNorm before MLP */
4916  rmsnorm_forward(residual1,
4917  ln2_gamma,
4918  ln2_out,
4919  NULL,
4920  1,
 4921  QWEN2_0_5B_DECODE_EMBED_DIM,
 4922  aligned_embed_dim,
4923  1e-06f);
4924 
4925  /* Step 9: MLP (SwiGLU) */
4926  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4927  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4928  uint8_t ln2_q8[ln2_q8_bytes];
4929  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
4930  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4931 
4932  /* SwiGLU activation */
4933  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4934 
4935  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
4936  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4937  uint8_t swiglu_q8[swiglu_q8_bytes];
4938  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
4939  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4940 
4941  /* Step 10: Final residual add */
4942  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
4943 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_2_prefill()

static void qwen2_0_5b_decode_layer_2_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 539 of file v6.6/test_generated/qwen2_int8.c.

546  {
547  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[2];
548 
549  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[1].output);
550  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
551  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
552  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
553  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
554  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
555  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
556  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
557  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
558  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
559  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
560  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
561  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
562  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
563  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
564  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
565 
566  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
567  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
568  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
569  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
570  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
571  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
572  const float *BQ = NULL;
573  const float *BK = NULL;
574  const float *BV = NULL;
575  const float *BO = NULL;
576  const float *B1 = NULL;
577  const float *B2 = NULL;
578 
579  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
580  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
581 
582  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
583  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
584  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
585  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
586  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
587  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
588 
589  /* RMSNorm before attention */
590  rmsnorm_forward(input,
591  ln1_gamma,
592  ln1_out,
593  NULL,
594  num_tokens,
 595  QWEN2_0_5B_DECODE_EMBED_DIM,
 596  aligned_embed_dim,
597  1e-06f);
598 
599  /* Q projection (head-major) */
600  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
601  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
602  for (int h = 0; h < H; ++h) {
603  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
604  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
605  float *q_h = q + (size_t)h * q_head_stride;
606  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
607  }
608 
609  /* K projection (head-major) */
610  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
611  const uint8_t *WK_bytes = (const uint8_t *)WK;
612  for (int h = 0; h < H_kv; ++h) {
613  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
614  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
615  float *k_h = k + (size_t)h * kv_head_stride;
616  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
617  }
618 
619  /* V projection (head-major) */
620  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
621  const uint8_t *WV_bytes = (const uint8_t *)WV;
622  for (int h = 0; h < H_kv; ++h) {
623  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
624  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
625  float *v_h = v + (size_t)h * kv_head_stride;
626  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
627  }
628 
629  /* RoPE */
630  rope_forward_qk_strided(q,
631  k,
632  rope_cos,
633  rope_sin,
634  H,
635  H_kv,
636  num_tokens,
637  head_dim,
638  aligned_head_dim,
639  0,
640  num_tokens,
641  aligned_context_window);
642 
643  /* Attention (prefill, causal) */
644  attention_forward_causal_head_major_gqa_flash_strided(q,
645  k,
646  v,
647  attn_out,
648  H,
649  H_kv,
650  num_tokens,
651  head_dim,
652  aligned_head_dim,
653  aligned_context_window);
654 
655  /* Output projection (flatten head-major to token-major) */
656  const int K = H * aligned_head_dim;
657  if (K != aligned_embed_dim) {
658  return;
659  }
660  const float *proj_in = attn_out;
661  if (H > 1) {
662  if (!proj_scratch) {
663  return;
664  }
665  for (int t = 0; t < num_tokens; ++t) {
666  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
667  for (int h = 0; h < H; ++h) {
668  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
669  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
670  src,
671  (size_t)aligned_head_dim * sizeof(float));
672  }
673  }
674  proj_in = proj_scratch;
675  }
676  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
677 
678  /* Residual add */
679  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
680 
681  /* RMSNorm before MLP */
682  rmsnorm_forward(residual1,
683  ln2_gamma,
684  ln2_out,
685  NULL,
686  num_tokens,
 687  QWEN2_0_5B_DECODE_EMBED_DIM,
 688  aligned_embed_dim,
689  1e-06f);
690 
691  /* MLP (SwiGLU) */
692  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
693  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
694  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
695 
696  /* Final residual add */
697  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
698 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_3_decode()

static void qwen2_0_5b_decode_layer_3_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 4948 of file v6.6/test_generated/qwen2_int8.c.

4955  {
4956  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[3];
4957 
4958  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[2].output);
4959 
4960  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
4961  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
4962  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
4963  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
4964  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
4965  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
4966  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
4967  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
4968  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
4969  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
4970  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
4971 
4972  /* Weights (explicit types for layer 3) */
4973  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
4974  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
4975  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
4976  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
4977  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
4978  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
4979 
4980  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
4981  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
4982 
4983  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
4984  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
4985  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
4986 
4987  float q_token[H * aligned_head_dim];
4988  float k_token[H_kv * aligned_head_dim];
4989  float v_token[H_kv * aligned_head_dim];
4990  float attn_token[H * aligned_head_dim];
4991 
4992  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
4993  float fc1_out[2 * aligned_intermediate_dim];
4994  float swiglu_out[aligned_intermediate_dim];
4995 
4996  /* Step 1: RMSNorm before attention */
4997  rmsnorm_forward(input,
4998  ln1_gamma,
4999  ln1_out,
5000  NULL,
5001  1,
5002  QWEN2_0_5B_DECODE_EMBED_DIM,
5003  aligned_embed_dim,
5004  1e-06f);
5005 
5006  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
5007 
5008  /* Step 2: QKV projection */
5009  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5010  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5011  uint8_t ln1_q8[ln1_q8_bytes];
5012  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
5013  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5014  if (aligned_head_dim > head_dim) {
5015  for (int h = 0; h < H; ++h) {
5016  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
5017  for (int d = head_dim; d < aligned_head_dim; ++d) {
5018  q_head[d] = 0.0f;
5019  }
5020  }
5021  }
5022 
5023  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5024  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5025  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
5026  const uint8_t *WK_bytes = (const uint8_t *)WK;
5027  /* ln1_q8 already quantized above */
5028  for (int h = 0; h < H_kv; ++h) {
5029  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
5030  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5031  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5032  for (int d = head_dim; d < aligned_head_dim; ++d) {
5033  k_head[d] = 0.0f;
5034  }
5035  }
5036 
5037  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5038  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5039  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
5040  const uint8_t *WV_bytes = (const uint8_t *)WV;
5041  /* ln1_q8 already quantized above */
5042  for (int h = 0; h < H_kv; ++h) {
5043  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
5044  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5045  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5046  for (int d = head_dim; d < aligned_head_dim; ++d) {
5047  v_head[d] = 0.0f;
5048  }
5049  }
5050 
5051  /* Step 3: RoPE */
5052  rope_forward(q_token,
5053  rope_cos,
5054  rope_sin,
5055  H,
5056  1,
5057  head_dim,
5058  aligned_head_dim,
5059  token_index);
5060  for (int h = 0; h < H_kv; ++h) {
5061  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5062  rope_forward(k_head,
5063  rope_cos,
5064  rope_sin,
5065  1,
5066  1,
5067  head_dim,
5068  aligned_head_dim,
5069  token_index);
5070  }
5071 
5072  /* Step 4: KV cache write (direct-to-cache) */
5073 
5074  /* Step 5: Attention (decode, flash) */
5075  attention_forward_decode_head_major_gqa_flash(q_token,
5076  k_cache,
5077  v_cache,
5078  attn_token,
5079  H,
5080  H_kv,
5081  token_index + 1,
5082  aligned_context_window,
5083  head_dim,
5084  aligned_head_dim);
5085 
5086  /* Step 6: Output projection */
5087  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5088  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5089  uint8_t attn_q8[attn_q8_bytes];
5090  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
5091  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5092 
5093  /* Step 7: Residual add */
5094  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5095 
5096  /* Step 8: RMSNorm before MLP */
5097  rmsnorm_forward(residual1,
5098  ln2_gamma,
5099  ln2_out,
5100  NULL,
5101  1,
5102  QWEN2_0_5B_DECODE_EMBED_DIM,
5103  aligned_embed_dim,
5104  1e-06f);
5105 
5106  /* Step 9: MLP (SwiGLU) */
5107  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5108  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5109  uint8_t ln2_q8[ln2_q8_bytes];
5110  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
5111  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5112 
5113  /* SwiGLU activation */
5114  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5115 
5116  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5117  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5118  uint8_t swiglu_q8[swiglu_q8_bytes];
5119  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
5120  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5121 
5122  /* Step 10: Final residual add */
5123  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5124 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_3_prefill()

static void qwen2_0_5b_decode_layer_3_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 703 of file v6.6/test_generated/qwen2_int8.c.

710  {
 711  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[3];
 712 
713  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[2].output);
714  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
715  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
716  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
717  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
718  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
719  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
720  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
721  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
722  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
723  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
724  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
725  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
726  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
727  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
728  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
729 
730  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
731  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
732  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
733  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
734  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
735  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
736  const float *BQ = NULL;
737  const float *BK = NULL;
738  const float *BV = NULL;
739  const float *BO = NULL;
740  const float *B1 = NULL;
741  const float *B2 = NULL;
742 
 743  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
 744  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
 745 
746  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
747  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
748  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
749  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
750  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
751  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
752 
753  /* RMSNorm before attention */
754  rmsnorm_forward(input,
755  ln1_gamma,
756  ln1_out,
757  NULL,
758  num_tokens,
 759  QWEN2_0_5B_DECODE_EMBED_DIM,
 760  aligned_embed_dim,
761  1e-06f);
762 
763  /* Q projection (head-major) */
764  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
765  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
766  for (int h = 0; h < H; ++h) {
767  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
768  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
769  float *q_h = q + (size_t)h * q_head_stride;
770  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
771  }
772 
773  /* K projection (head-major) */
774  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
775  const uint8_t *WK_bytes = (const uint8_t *)WK;
776  for (int h = 0; h < H_kv; ++h) {
777  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
778  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
779  float *k_h = k + (size_t)h * kv_head_stride;
780  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
781  }
782 
783  /* V projection (head-major) */
784  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
785  const uint8_t *WV_bytes = (const uint8_t *)WV;
786  for (int h = 0; h < H_kv; ++h) {
787  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
788  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
789  float *v_h = v + (size_t)h * kv_head_stride;
790  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
791  }
792 
793  /* RoPE */
 794  rope_forward_qk_strided(q,
 795  k,
796  rope_cos,
797  rope_sin,
798  H,
799  H_kv,
800  num_tokens,
801  head_dim,
802  aligned_head_dim,
803  0,
804  num_tokens,
805  aligned_context_window);
806 
807  /* Attention (prefill, causal) */
 808  attention_forward_causal_head_major_gqa_flash_strided(q,
 809  k,
810  v,
811  attn_out,
812  H,
813  H_kv,
814  num_tokens,
815  head_dim,
816  aligned_head_dim,
817  aligned_context_window);
818 
819  /* Output projection (flatten head-major to token-major) */
820  const int K = H * aligned_head_dim;
821  if (K != aligned_embed_dim) {
822  return;
823  }
824  const float *proj_in = attn_out;
825  if (H > 1) {
826  if (!proj_scratch) {
827  return;
828  }
829  for (int t = 0; t < num_tokens; ++t) {
830  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
831  for (int h = 0; h < H; ++h) {
832  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
833  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
834  src,
835  (size_t)aligned_head_dim * sizeof(float));
836  }
837  }
838  proj_in = proj_scratch;
839  }
840  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
841 
842  /* Residual add */
843  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
844 
845  /* RMSNorm before MLP */
846  rmsnorm_forward(residual1,
847  ln2_gamma,
848  ln2_out,
849  NULL,
850  num_tokens,
 851  QWEN2_0_5B_DECODE_EMBED_DIM,
 852  aligned_embed_dim,
853  1e-06f);
854 
855  /* MLP (SwiGLU) */
856  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
857  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
858  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
859 
860  /* Final residual add */
861  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
862 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_4_decode()

static void qwen2_0_5b_decode_layer_4_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5129 of file v6.6/test_generated/qwen2_int8.c.

5136  {
5137  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[4];
5138 
5139  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[3].output);
5140 
5141  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5142  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5143  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5144  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5145  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5146  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5147  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5148  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5149  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5150  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5151  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5152 
5153  /* Weights (explicit types for layer 4) */
5154  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5155  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5156  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5157  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5158  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5159  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5160 
5161  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5162  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5163 
5164  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5165  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5166  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5167 
5168  float q_token[H * aligned_head_dim];
5169  float k_token[H_kv * aligned_head_dim];
5170  float v_token[H_kv * aligned_head_dim];
5171  float attn_token[H * aligned_head_dim];
5172 
5173  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5174  float fc1_out[2 * aligned_intermediate_dim];
5175  float swiglu_out[aligned_intermediate_dim];
5176 
5177  /* Step 1: RMSNorm before attention */
5178  rmsnorm_forward(input,
5179  ln1_gamma,
5180  ln1_out,
5181  NULL,
5182  1,
5183  QWEN2_0_5B_DECODE_EMBED_DIM,
5184  aligned_embed_dim,
5185  1e-06f);
5186 
5187  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
5188 
5189  /* Step 2: QKV projection */
5190  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5191  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5192  uint8_t ln1_q8[ln1_q8_bytes];
5193  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
5194  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5195  if (aligned_head_dim > head_dim) {
5196  for (int h = 0; h < H; ++h) {
5197  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
5198  for (int d = head_dim; d < aligned_head_dim; ++d) {
5199  q_head[d] = 0.0f;
5200  }
5201  }
5202  }
5203 
5204  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5205  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5206  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
5207  const uint8_t *WK_bytes = (const uint8_t *)WK;
5208  /* ln1_q8 already quantized above */
5209  for (int h = 0; h < H_kv; ++h) {
5210  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
5211  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5212  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5213  for (int d = head_dim; d < aligned_head_dim; ++d) {
5214  k_head[d] = 0.0f;
5215  }
5216  }
5217 
5218  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5219  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5220  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
5221  const uint8_t *WV_bytes = (const uint8_t *)WV;
5222  /* ln1_q8 already quantized above */
5223  for (int h = 0; h < H_kv; ++h) {
5224  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
5225  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5226  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5227  for (int d = head_dim; d < aligned_head_dim; ++d) {
5228  v_head[d] = 0.0f;
5229  }
5230  }
5231 
5232  /* Step 3: RoPE */
5233  rope_forward(q_token,
5234  rope_cos,
5235  rope_sin,
5236  H,
5237  1,
5238  head_dim,
5239  aligned_head_dim,
5240  token_index);
5241  for (int h = 0; h < H_kv; ++h) {
5242  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5243  rope_forward(k_head,
5244  rope_cos,
5245  rope_sin,
5246  1,
5247  1,
5248  head_dim,
5249  aligned_head_dim,
5250  token_index);
5251  }
5252 
5253  /* Step 4: KV cache write (direct-to-cache) */
5254 
5255  /* Step 5: Attention (decode, flash) */
5256  attention_forward_decode_head_major_gqa_flash(q_token,
5257  k_cache,
5258  v_cache,
5259  attn_token,
5260  H,
5261  H_kv,
5262  token_index + 1,
5263  aligned_context_window,
5264  head_dim,
5265  aligned_head_dim);
5266 
5267  /* Step 6: Output projection */
5268  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5269  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5270  uint8_t attn_q8[attn_q8_bytes];
5271  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
5272  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5273 
5274  /* Step 7: Residual add */
5275  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5276 
5277  /* Step 8: RMSNorm before MLP */
5278  rmsnorm_forward(residual1,
5279  ln2_gamma,
5280  ln2_out,
5281  NULL,
5282  1,
5283  QWEN2_0_5B_DECODE_EMBED_DIM,
5284  aligned_embed_dim,
5285  1e-06f);
5286 
5287  /* Step 9: MLP (SwiGLU) */
5288  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5289  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5290  uint8_t ln2_q8[ln2_q8_bytes];
5291  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
5292  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5293 
5294  /* SwiGLU activation */
5295  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5296 
5297  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5298  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5299  uint8_t swiglu_q8[swiglu_q8_bytes];
5300  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
5301  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5302 
5303  /* Step 10: Final residual add */
5304  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5305 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_4_prefill()

static void qwen2_0_5b_decode_layer_4_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 867 of file v6.6/test_generated/qwen2_int8.c.

874  {
 875  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[4];
 876 
877  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[3].output);
878  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
879  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
880  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
881  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
882  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
883  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
884  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
885  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
886  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
887  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
888  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
889  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
890  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
891  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
892  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
893 
894  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
895  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
896  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
897  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
898  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
899  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
900  const float *BQ = NULL;
901  const float *BK = NULL;
902  const float *BV = NULL;
903  const float *BO = NULL;
904  const float *B1 = NULL;
905  const float *B2 = NULL;
906 
 907  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
 908  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
 909 
910  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
911  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
912  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
913  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
914  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
915  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
916 
917  /* RMSNorm before attention */
918  rmsnorm_forward(input,
919  ln1_gamma,
920  ln1_out,
921  NULL,
922  num_tokens,
 923  QWEN2_0_5B_DECODE_EMBED_DIM,
 924  aligned_embed_dim,
925  1e-06f);
926 
927  /* Q projection (head-major) */
928  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
929  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
930  for (int h = 0; h < H; ++h) {
931  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
932  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
933  float *q_h = q + (size_t)h * q_head_stride;
934  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
935  }
936 
937  /* K projection (head-major) */
938  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
939  const uint8_t *WK_bytes = (const uint8_t *)WK;
940  for (int h = 0; h < H_kv; ++h) {
941  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
942  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
943  float *k_h = k + (size_t)h * kv_head_stride;
944  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
945  }
946 
947  /* V projection (head-major) */
948  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
949  const uint8_t *WV_bytes = (const uint8_t *)WV;
950  for (int h = 0; h < H_kv; ++h) {
951  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
952  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
953  float *v_h = v + (size_t)h * kv_head_stride;
954  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
955  }
956 
957  /* RoPE */
 958  rope_forward_qk_strided(q,
 959  k,
960  rope_cos,
961  rope_sin,
962  H,
963  H_kv,
964  num_tokens,
965  head_dim,
966  aligned_head_dim,
967  0,
968  num_tokens,
969  aligned_context_window);
970 
971  /* Attention (prefill, causal) */
 972  attention_forward_causal_head_major_gqa_flash_strided(q,
 973  k,
974  v,
975  attn_out,
976  H,
977  H_kv,
978  num_tokens,
979  head_dim,
980  aligned_head_dim,
981  aligned_context_window);
982 
983  /* Output projection (flatten head-major to token-major) */
984  const int K = H * aligned_head_dim;
985  if (K != aligned_embed_dim) {
986  return;
987  }
988  const float *proj_in = attn_out;
989  if (H > 1) {
990  if (!proj_scratch) {
991  return;
992  }
993  for (int t = 0; t < num_tokens; ++t) {
994  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
995  for (int h = 0; h < H; ++h) {
996  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
997  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
998  src,
999  (size_t)aligned_head_dim * sizeof(float));
1000  }
1001  }
1002  proj_in = proj_scratch;
1003  }
1004  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1005 
1006  /* Residual add */
1007  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1008 
1009  /* RMSNorm before MLP */
1010  rmsnorm_forward(residual1,
1011  ln2_gamma,
1012  ln2_out,
1013  NULL,
1014  num_tokens,
1015  QWEN2_0_5B_DECODE_EMBED_DIM,
1016  aligned_embed_dim,
1017  1e-06f);
1018 
1019  /* MLP (SwiGLU) */
1020  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1021  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1022  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1023 
1024  /* Final residual add */
1025  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1026 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_5_decode()

static void qwen2_0_5b_decode_layer_5_decode ( QWEN2_0_5B_DECODEModel *  model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5310 of file v6.6/test_generated/qwen2_int8.c.

5317  {
5318  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[5];
5319 
5320  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[4].output);
5321 
5322  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5323  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5324  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5325  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5326  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5327  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5328  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5329  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5330  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5331  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5332  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5333 
5334  /* Weights (explicit types for layer 5) */
5335  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5336  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5337  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5338  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5339  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5340  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5341 
5342  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5343  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5344 
5345  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5346  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5347  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5348 
5349  float q_token[H * aligned_head_dim];
5350  float k_token[H_kv * aligned_head_dim];
5351  float v_token[H_kv * aligned_head_dim];
5352  float attn_token[H * aligned_head_dim];
5353 
5354  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5355  float fc1_out[2 * aligned_intermediate_dim];
5356  float swiglu_out[aligned_intermediate_dim];
5357 
5358  /* Step 1: RMSNorm before attention */
5359  rmsnorm_forward(input,
5360  ln1_gamma,
5361  ln1_out,
5362  NULL,
5363  1,
5364  QWEN2_0_5B_DECODE_EMBED_DIM,
5365  aligned_embed_dim,
5366  1e-06f);
5367 
5368  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
5369 
5370  /* Step 2: QKV projection */
5371  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5372  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5373  uint8_t ln1_q8[ln1_q8_bytes];
5374  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
5375  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5376  if (aligned_head_dim > head_dim) {
5377  for (int h = 0; h < H; ++h) {
5378  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
5379  for (int d = head_dim; d < aligned_head_dim; ++d) {
5380  q_head[d] = 0.0f;
5381  }
5382  }
5383  }
5384 
5385  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5386  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5387  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
5388  const uint8_t *WK_bytes = (const uint8_t *)WK;
5389  /* ln1_q8 already quantized above */
5390  for (int h = 0; h < H_kv; ++h) {
5391  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
5392  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5393  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5394  for (int d = head_dim; d < aligned_head_dim; ++d) {
5395  k_head[d] = 0.0f;
5396  }
5397  }
5398 
5399  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5400  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5401  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
5402  const uint8_t *WV_bytes = (const uint8_t *)WV;
5403  /* ln1_q8 already quantized above */
5404  for (int h = 0; h < H_kv; ++h) {
5405  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
5406  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5407  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5408  for (int d = head_dim; d < aligned_head_dim; ++d) {
5409  v_head[d] = 0.0f;
5410  }
5411  }
5412 
5413  /* Step 3: RoPE */
5414  rope_forward(q_token,
5415  rope_cos,
5416  rope_sin,
5417  H,
5418  1,
5419  head_dim,
5420  aligned_head_dim,
5421  token_index);
5422  for (int h = 0; h < H_kv; ++h) {
5423  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5424  rope_forward(k_head,
5425  rope_cos,
5426  rope_sin,
5427  1,
5428  1,
5429  head_dim,
5430  aligned_head_dim,
5431  token_index);
5432  }
5433 
5434  /* Step 4: KV cache write (direct-to-cache) */
5435 
5436  /* Step 5: Attention (decode, flash) */
5437  attention_forward_decode_head_major_gqa_flash(q_token,
5438  k_cache,
5439  v_cache,
5440  attn_token,
5441  H,
5442  H_kv,
5443  token_index + 1,
5444  aligned_context_window,
5445  head_dim,
5446  aligned_head_dim);
5447 
5448  /* Step 6: Output projection */
5449  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5450  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5451  uint8_t attn_q8[attn_q8_bytes];
5452  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
5453  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5454 
5455  /* Step 7: Residual add */
5456  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5457 
5458  /* Step 8: RMSNorm before MLP */
5459  rmsnorm_forward(residual1,
5460  ln2_gamma,
5461  ln2_out,
5462  NULL,
5463  1,
5464  QWEN2_0_5B_DECODE_EMBED_DIM,
5465  aligned_embed_dim,
5466  1e-06f);
5467 
5468  /* Step 9: MLP (SwiGLU) */
5469  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5470  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5471  uint8_t ln2_q8[ln2_q8_bytes];
5472  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
5473  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5474 
5475  /* SwiGLU activation */
5476  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5477 
5478  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5479  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5480  uint8_t swiglu_q8[swiglu_q8_bytes];
5481  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
5482  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5483 
5484  /* Step 10: Final residual add */
5485  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5486 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_5_prefill()

static void qwen2_0_5b_decode_layer_5_prefill ( QWEN2_0_5B_DECODEModel *  model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1031 of file v6.6/test_generated/qwen2_int8.c.

1038  {
1039  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[5];
1040 
1041  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[4].output);
1042  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1043  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1044  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1045  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1046  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1047  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1048  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1049  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1050  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1051  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1052  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1053  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1054  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1055  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1056  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1057 
1058  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1059  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1060  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1061  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1062  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1063  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1064  const float *BQ = NULL;
1065  const float *BK = NULL;
1066  const float *BV = NULL;
1067  const float *BO = NULL;
1068  const float *B1 = NULL;
1069  const float *B2 = NULL;
1070 
1071  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1072  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1073 
1074  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1075  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1076  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1077  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1078  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1079  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
1080 
1081  /* RMSNorm before attention */
1082  rmsnorm_forward(input,
1083  ln1_gamma,
1084  ln1_out,
1085  NULL,
1086  num_tokens,
1087  QWEN2_0_5B_DECODE_EMBED_DIM,
1088  aligned_embed_dim,
1089  1e-06f);
1090 
1091  /* Q projection (head-major) */
1092  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1093  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1094  for (int h = 0; h < H; ++h) {
1095  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1096  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1097  float *q_h = q + (size_t)h * q_head_stride;
1098  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1099  }
1100 
1101  /* K projection (head-major) */
1102  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1103  const uint8_t *WK_bytes = (const uint8_t *)WK;
1104  for (int h = 0; h < H_kv; ++h) {
1105  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1106  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1107  float *k_h = k + (size_t)h * kv_head_stride;
1108  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1109  }
1110 
1111  /* V projection (head-major) */
1112  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1113  const uint8_t *WV_bytes = (const uint8_t *)WV;
1114  for (int h = 0; h < H_kv; ++h) {
1115  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1116  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1117  float *v_h = v + (size_t)h * kv_head_stride;
1118  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1119  }
1120 
1121  /* RoPE */
1122  rope_forward_qk_strided(q,
1123  k,
1124  rope_cos,
1125  rope_sin,
1126  H,
1127  H_kv,
1128  num_tokens,
1129  head_dim,
1130  aligned_head_dim,
1131  0,
1132  num_tokens,
1133  aligned_context_window);
1134 
1135  /* Attention (prefill, causal) */
1136  attention_forward_causal_head_major_gqa_flash_strided(q,
1137  k,
1138  v,
1139  attn_out,
1140  H,
1141  H_kv,
1142  num_tokens,
1143  head_dim,
1144  aligned_head_dim,
1145  aligned_context_window);
1146 
1147  /* Output projection (flatten head-major to token-major) */
1148  const int K = H * aligned_head_dim;
1149  if (K != aligned_embed_dim) {
1150  return;
1151  }
1152  const float *proj_in = attn_out;
1153  if (H > 1) {
1154  if (!proj_scratch) {
1155  return;
1156  }
1157  for (int t = 0; t < num_tokens; ++t) {
1158  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1159  for (int h = 0; h < H; ++h) {
1160  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1161  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1162  src,
1163  (size_t)aligned_head_dim * sizeof(float));
1164  }
1165  }
1166  proj_in = proj_scratch;
1167  }
1168  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1169 
1170  /* Residual add */
1171  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1172 
1173  /* RMSNorm before MLP */
1174  rmsnorm_forward(residual1,
1175  ln2_gamma,
1176  ln2_out,
1177  NULL,
1178  num_tokens,
1179  QWEN2_0_5B_DECODE_EMBED_DIM,
1180  aligned_embed_dim,
1181  1e-06f);
1182 
1183  /* MLP (SwiGLU) */
1184  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1185  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1186  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1187 
1188  /* Final residual add */
1189  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1190 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_6_decode()

static void qwen2_0_5b_decode_layer_6_decode ( QWEN2_0_5B_DECODEModel *  model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5491 of file v6.6/test_generated/qwen2_int8.c.

5498  {
5499  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[6];
5500 
5501  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[5].output);
5502 
5503  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5504  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5505  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5506  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5507  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5508  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5509  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5510  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5511  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5512  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5513  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5514 
5515  /* Weights (explicit types for layer 6) */
5516  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5517  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5518  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5519  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5520  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5521  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5522 
5523  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5524  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5525 
5526  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5527  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5528  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5529 
5530  float q_token[H * aligned_head_dim];
5531  float k_token[H_kv * aligned_head_dim];
5532  float v_token[H_kv * aligned_head_dim];
5533  float attn_token[H * aligned_head_dim];
5534 
5535  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5536  float fc1_out[2 * aligned_intermediate_dim];
5537  float swiglu_out[aligned_intermediate_dim];
5538 
5539  /* Step 1: RMSNorm before attention */
5540  rmsnorm_forward(input,
5541  ln1_gamma,
5542  ln1_out,
5543  NULL,
5544  1,
5545  QWEN2_0_5B_DECODE_EMBED_DIM,
5546  aligned_embed_dim,
5547  1e-06f);
5548 
5549  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
5550 
5551  /* Step 2: QKV projection */
5552  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5553  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5554  uint8_t ln1_q8[ln1_q8_bytes];
5555  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
5556  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5557  if (aligned_head_dim > head_dim) {
5558  for (int h = 0; h < H; ++h) {
5559  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
5560  for (int d = head_dim; d < aligned_head_dim; ++d) {
5561  q_head[d] = 0.0f;
5562  }
5563  }
5564  }
5565 
5566  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5567  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5568  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
5569  const uint8_t *WK_bytes = (const uint8_t *)WK;
5570  /* ln1_q8 already quantized above */
5571  for (int h = 0; h < H_kv; ++h) {
5572  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
5573  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5574  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5575  for (int d = head_dim; d < aligned_head_dim; ++d) {
5576  k_head[d] = 0.0f;
5577  }
5578  }
5579 
5580  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5581  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5582  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
5583  const uint8_t *WV_bytes = (const uint8_t *)WV;
5584  /* ln1_q8 already quantized above */
5585  for (int h = 0; h < H_kv; ++h) {
5586  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
5587  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5588  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5589  for (int d = head_dim; d < aligned_head_dim; ++d) {
5590  v_head[d] = 0.0f;
5591  }
5592  }
5593 
5594  /* Step 3: RoPE */
5595  rope_forward(q_token,
5596  rope_cos,
5597  rope_sin,
5598  H,
5599  1,
5600  head_dim,
5601  aligned_head_dim,
5602  token_index);
5603  for (int h = 0; h < H_kv; ++h) {
5604  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5605  rope_forward(k_head,
5606  rope_cos,
5607  rope_sin,
5608  1,
5609  1,
5610  head_dim,
5611  aligned_head_dim,
5612  token_index);
5613  }
5614 
5615  /* Step 4: KV cache write (direct-to-cache) */
5616 
5617  /* Step 5: Attention (decode, flash) */
5618  attention_forward_decode_head_major_gqa_flash(q_token,
5619  k_cache,
5620  v_cache,
5621  attn_token,
5622  H,
5623  H_kv,
5624  token_index + 1,
5625  aligned_context_window,
5626  head_dim,
5627  aligned_head_dim);
5628 
5629  /* Step 6: Output projection */
5630  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5631  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5632  uint8_t attn_q8[attn_q8_bytes];
5633  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
5634  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5635 
5636  /* Step 7: Residual add */
5637  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5638 
5639  /* Step 8: RMSNorm before MLP */
5640  rmsnorm_forward(residual1,
5641  ln2_gamma,
5642  ln2_out,
5643  NULL,
5644  1,
5645  QWEN2_0_5B_DECODE_EMBED_DIM,
5646  aligned_embed_dim,
5647  1e-06f);
5648 
5649  /* Step 9: MLP (SwiGLU) */
5650  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5651  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5652  uint8_t ln2_q8[ln2_q8_bytes];
5653  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
5654  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5655 
5656  /* SwiGLU activation */
5657  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5658 
5659  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5660  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5661  uint8_t swiglu_q8[swiglu_q8_bytes];
5662  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
5663  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5664 
5665  /* Step 10: Final residual add */
5666  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5667 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_6_prefill()

static void qwen2_0_5b_decode_layer_6_prefill ( QWEN2_0_5B_DECODEModel *  model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1195 of file v6.6/test_generated/qwen2_int8.c.

1202  {
1203  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[6];
1204 
1205  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[5].output);
1206  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1207  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1208  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1209  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1210  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1211  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1212  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1213  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1214  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1215  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1216  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1217  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1218  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1219  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1220  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1221 
1222  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1223  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1224  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1225  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1226  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1227  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1228  const float *BQ = NULL;
1229  const float *BK = NULL;
1230  const float *BV = NULL;
1231  const float *BO = NULL;
1232  const float *B1 = NULL;
1233  const float *B2 = NULL;
1234 
1235  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1236  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1237 
1238  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1239  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1240  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1241  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1242  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1243  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
1244 
1245  /* RMSNorm before attention */
1246  rmsnorm_forward(input,
1247  ln1_gamma,
1248  ln1_out,
1249  NULL,
1250  num_tokens,
1251  QWEN2_0_5B_DECODE_EMBED_DIM,
1252  aligned_embed_dim,
1253  1e-06f);
1254 
1255  /* Q projection (head-major) */
1256  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1257  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1258  for (int h = 0; h < H; ++h) {
1259  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1260  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1261  float *q_h = q + (size_t)h * q_head_stride;
1262  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1263  }
1264 
1265  /* K projection (head-major) */
1266  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1267  const uint8_t *WK_bytes = (const uint8_t *)WK;
1268  for (int h = 0; h < H_kv; ++h) {
1269  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1270  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1271  float *k_h = k + (size_t)h * kv_head_stride;
1272  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1273  }
1274 
1275  /* V projection (head-major) */
1276  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1277  const uint8_t *WV_bytes = (const uint8_t *)WV;
1278  for (int h = 0; h < H_kv; ++h) {
1279  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1280  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1281  float *v_h = v + (size_t)h * kv_head_stride;
1282  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1283  }
1284 
1285  /* RoPE */
1286  rope_forward_qk_strided(q,
1287  k,
1288  rope_cos,
1289  rope_sin,
1290  H,
1291  H_kv,
1292  num_tokens,
1293  head_dim,
1294  aligned_head_dim,
1295  0,
1296  num_tokens,
1297  aligned_context_window);
1298 
1299  /* Attention (prefill, causal) */
1300  attention_forward_causal_head_major_gqa_flash_strided(q,
1301  k,
1302  v,
1303  attn_out,
1304  H,
1305  H_kv,
1306  num_tokens,
1307  head_dim,
1308  aligned_head_dim,
1309  aligned_context_window);
1310 
1311  /* Output projection (flatten head-major to token-major) */
1312  const int K = H * aligned_head_dim;
1313  if (K != aligned_embed_dim) {
1314  return;
1315  }
1316  const float *proj_in = attn_out;
1317  if (H > 1) {
1318  if (!proj_scratch) {
1319  return;
1320  }
1321  for (int t = 0; t < num_tokens; ++t) {
1322  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1323  for (int h = 0; h < H; ++h) {
1324  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1325  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1326  src,
1327  (size_t)aligned_head_dim * sizeof(float));
1328  }
1329  }
1330  proj_in = proj_scratch;
1331  }
1332  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1333 
1334  /* Residual add */
1335  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1336 
1337  /* RMSNorm before MLP */
1338  rmsnorm_forward(residual1,
1339  ln2_gamma,
1340  ln2_out,
1341  NULL,
1342  num_tokens,
1343  QWEN2_0_5B_DECODE_EMBED_DIM,
1344  aligned_embed_dim,
1345  1e-06f);
1346 
1347  /* MLP (SwiGLU) */
1348  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1349  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1350  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1351 
1352  /* Final residual add */
1353  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1354 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_7_decode()

static void qwen2_0_5b_decode_layer_7_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5672 of file v6.6/test_generated/qwen2_int8.c.

5679  {
5680  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[7];
5681 
5682  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[6].output);
5683 
5684  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5685  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5686  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5687  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5688  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5689  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5690  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5691  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5692  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5693  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5694  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5695 
5696  /* Weights (explicit types for layer 7) */
5697  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5698  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5699  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5700  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5701  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5702  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5703 
5704  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5705  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5706 
5707  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5708  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5709  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5710 
5711  float q_token[H * aligned_head_dim];
5712  float k_token[H_kv * aligned_head_dim];
5713  float v_token[H_kv * aligned_head_dim];
5714  float attn_token[H * aligned_head_dim];
5715 
5716  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5717  float fc1_out[2 * aligned_intermediate_dim];
5718  float swiglu_out[aligned_intermediate_dim];
5719 
5720  /* Step 1: RMSNorm before attention */
5721  rmsnorm_forward(input,
5722  ln1_gamma,
5723  ln1_out,
5724  NULL,
5725  1,
5726  QWEN2_0_5B_DECODE_EMBED_DIM,
5727  aligned_embed_dim,
5728  1e-06f);
5729 
5730  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
5731 
5732  /* Step 2: QKV projection */
5733  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5734  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5735  uint8_t ln1_q8[ln1_q8_bytes];
5736  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
5737  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5738  if (aligned_head_dim > head_dim) {
5739  for (int h = 0; h < H; ++h) {
5740  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
5741  for (int d = head_dim; d < aligned_head_dim; ++d) {
5742  q_head[d] = 0.0f;
5743  }
5744  }
5745  }
5746 
5747  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5748  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5749  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
5750  const uint8_t *WK_bytes = (const uint8_t *)WK;
5751  /* ln1_q8 already quantized above */
5752  for (int h = 0; h < H_kv; ++h) {
5753  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
5754  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5755  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5756  for (int d = head_dim; d < aligned_head_dim; ++d) {
5757  k_head[d] = 0.0f;
5758  }
5759  }
5760 
5761  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5762  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5763  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
5764  const uint8_t *WV_bytes = (const uint8_t *)WV;
5765  /* ln1_q8 already quantized above */
5766  for (int h = 0; h < H_kv; ++h) {
5767  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
5768  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5769  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5770  for (int d = head_dim; d < aligned_head_dim; ++d) {
5771  v_head[d] = 0.0f;
5772  }
5773  }
5774 
5775  /* Step 3: RoPE */
5776  rope_forward(q_token,
5777  rope_cos,
5778  rope_sin,
5779  H,
5780  1,
5781  head_dim,
5782  aligned_head_dim,
5783  token_index);
5784  for (int h = 0; h < H_kv; ++h) {
5785  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5786  rope_forward(k_head,
5787  rope_cos,
5788  rope_sin,
5789  1,
5790  1,
5791  head_dim,
5792  aligned_head_dim,
5793  token_index);
5794  }
5795 
5796  /* Step 4: KV cache write (direct-to-cache) */
5797 
5798  /* Step 5: Attention (decode, flash) */
5799  attention_forward_decode_head_major_gqa_flash(q_token,
5800  k_cache,
5801  v_cache,
5802  attn_token,
5803  H,
5804  H_kv,
5805  token_index + 1,
5806  aligned_context_window,
5807  head_dim,
5808  aligned_head_dim);
5809 
5810  /* Step 6: Output projection */
5811  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5812  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5813  uint8_t attn_q8[attn_q8_bytes];
5814  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
5815  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5816 
5817  /* Step 7: Residual add */
5818  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
5819 
5820  /* Step 8: RMSNorm before MLP */
5821  rmsnorm_forward(residual1,
5822  ln2_gamma,
5823  ln2_out,
5824  NULL,
5825  1,
5826  QWEN2_0_5B_DECODE_EMBED_DIM,
5827  aligned_embed_dim,
5828  1e-06f);
5829 
5830  /* Step 9: MLP (SwiGLU) */
5831  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5832  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5833  uint8_t ln2_q8[ln2_q8_bytes];
5834  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
5835  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5836 
5837  /* SwiGLU activation */
5838  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5839 
5840  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5841  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5842  uint8_t swiglu_q8[swiglu_q8_bytes];
5843  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
5844  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5845 
5846  /* Step 10: Final residual add */
5847  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
5848 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_7_prefill()

static void qwen2_0_5b_decode_layer_7_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1359 of file v6.6/test_generated/qwen2_int8.c.

1366  {
1367  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[7];
1368 
1369  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[6].output);
1370  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1371  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1372  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1373  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1374  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1375  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1376  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1377  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1378  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1379  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1380  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1381  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1382  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1383  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1384  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1385 
1386  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1387  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1388  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1389  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1390  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1391  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1392  const float *BQ = NULL;
1393  const float *BK = NULL;
1394  const float *BV = NULL;
1395  const float *BO = NULL;
1396  const float *B1 = NULL;
1397  const float *B2 = NULL;
1398 
1399  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1400  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1401 
1402  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1403  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1404  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1405  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1406  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1407  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
1408 
1409  /* RMSNorm before attention */
1410  rmsnorm_forward(input,
1411  ln1_gamma,
1412  ln1_out,
1413  NULL,
1414  num_tokens,
1415  QWEN2_0_5B_DECODE_EMBED_DIM,
1416  aligned_embed_dim,
1417  1e-06f);
1418 
1419  /* Q projection (head-major) */
1420  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1421  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1422  for (int h = 0; h < H; ++h) {
1423  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1424  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1425  float *q_h = q + (size_t)h * q_head_stride;
1426  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1427  }
1428 
1429  /* K projection (head-major) */
1430  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1431  const uint8_t *WK_bytes = (const uint8_t *)WK;
1432  for (int h = 0; h < H_kv; ++h) {
1433  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1434  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1435  float *k_h = k + (size_t)h * kv_head_stride;
1436  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1437  }
1438 
1439  /* V projection (head-major) */
1440  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1441  const uint8_t *WV_bytes = (const uint8_t *)WV;
1442  for (int h = 0; h < H_kv; ++h) {
1443  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1444  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1445  float *v_h = v + (size_t)h * kv_head_stride;
1446  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1447  }
1448 
1449  /* RoPE */
1450  rope_forward_qk_strided(q,
1451  k,
1452  rope_cos,
1453  rope_sin,
1454  H,
1455  H_kv,
1456  num_tokens,
1457  head_dim,
1458  aligned_head_dim,
1459  0,
1460  num_tokens,
1461  aligned_context_window);
1462 
1463  /* Attention (prefill, causal) */
1464  attention_forward_causal_head_major_gqa_flash_strided(q,
1465  k,
1466  v,
1467  attn_out,
1468  H,
1469  H_kv,
1470  num_tokens,
1471  head_dim,
1472  aligned_head_dim,
1473  aligned_context_window);
1474 
1475  /* Output projection (flatten head-major to token-major) */
1476  const int K = H * aligned_head_dim;
1477  if (K != aligned_embed_dim) {
1478  return;
1479  }
1480  const float *proj_in = attn_out;
1481  if (H > 1) {
1482  if (!proj_scratch) {
1483  return;
1484  }
1485  for (int t = 0; t < num_tokens; ++t) {
1486  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1487  for (int h = 0; h < H; ++h) {
1488  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1489  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1490  src,
1491  (size_t)aligned_head_dim * sizeof(float));
1492  }
1493  }
1494  proj_in = proj_scratch;
1495  }
1496  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1497 
1498  /* Residual add */
1499  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1500 
1501  /* RMSNorm before MLP */
1502  rmsnorm_forward(residual1,
1503  ln2_gamma,
1504  ln2_out,
1505  NULL,
1506  num_tokens,
1507  QWEN2_0_5B_DECODE_EMBED_DIM,
1508  aligned_embed_dim,
1509  1e-06f);
1510 
1511  /* MLP (SwiGLU) */
1512  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1513  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1514  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1515 
1516  /* Final residual add */
1517  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1518 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_8_decode()

static void qwen2_0_5b_decode_layer_8_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 5853 of file v6.6/test_generated/qwen2_int8.c.

5860  {
5861  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[8];
5862 
5863  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[7].output);
5864 
5865  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
5866  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
5867  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
5868  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
5869  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
5870  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
5871  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
5872  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
5873  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
5874  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
5875  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
5876 
5877  /* Weights (explicit types for layer 8) */
5878  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
5879  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
5880  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
5881  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
5882  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
5883  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
5884 
5885  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
5886  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
5887 
5888  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
5889  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
5890  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
5891 
5892  float q_token[H * aligned_head_dim];
5893  float k_token[H_kv * aligned_head_dim];
5894  float v_token[H_kv * aligned_head_dim];
5895  float attn_token[H * aligned_head_dim];
5896 
5897  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
5898  float fc1_out[2 * aligned_intermediate_dim];
5899  float swiglu_out[aligned_intermediate_dim];
5900 
5901  /* Step 1: RMSNorm before attention */
5902  rmsnorm_forward(input,
5903  ln1_gamma,
5904  ln1_out,
5905  NULL,
5906  1,
5907  QWEN2_0_5B_DECODE_EMBED_DIM,
5908  aligned_embed_dim,
5909  1e-06f);
5910 
5911  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
5912 
5913  /* Step 2: QKV projection */
5914  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5915  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5916  uint8_t ln1_q8[ln1_q8_bytes];
5917  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
5918  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5919  if (aligned_head_dim > head_dim) {
5920  for (int h = 0; h < H; ++h) {
5921  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
5922  for (int d = head_dim; d < aligned_head_dim; ++d) {
5923  q_head[d] = 0.0f;
5924  }
5925  }
5926  }
5927 
5928  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5929  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5930  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
5931  const uint8_t *WK_bytes = (const uint8_t *)WK;
5932  /* ln1_q8 already quantized above */
5933  for (int h = 0; h < H_kv; ++h) {
5934  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
5935  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5936  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5937  for (int d = head_dim; d < aligned_head_dim; ++d) {
5938  k_head[d] = 0.0f;
5939  }
5940  }
5941 
5942  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
5943  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
5944  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
5945  const uint8_t *WV_bytes = (const uint8_t *)WV;
5946  /* ln1_q8 already quantized above */
5947  for (int h = 0; h < H_kv; ++h) {
5948  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
5949  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5950  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5951  for (int d = head_dim; d < aligned_head_dim; ++d) {
5952  v_head[d] = 0.0f;
5953  }
5954  }
5955 
5956  /* Step 3: RoPE */
5957  rope_forward(q_token,
5958  rope_cos,
5959  rope_sin,
5960  H,
5961  1,
5962  head_dim,
5963  aligned_head_dim,
5964  token_index);
5965  for (int h = 0; h < H_kv; ++h) {
5966  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
5967  rope_forward(k_head,
5968  rope_cos,
5969  rope_sin,
5970  1,
5971  1,
5972  head_dim,
5973  aligned_head_dim,
5974  token_index);
5975  }
5976 
5977  /* Step 4: KV cache write (direct-to-cache) */
5978 
5979  /* Step 5: Attention (decode, flash) */
5980  attention_forward_decode_head_major_gqa_flash(q_token,
5981  k_cache,
5982  v_cache,
5983  attn_token,
5984  H,
5985  H_kv,
5986  token_index + 1,
5987  aligned_context_window,
5988  head_dim,
5989  aligned_head_dim);
5990 
5991  /* Step 6: Output projection */
5992  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
5993  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5994  uint8_t attn_q8[attn_q8_bytes];
5995  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
5996  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5997 
5998  /* Step 7: Residual add */
5999  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6000 
6001  /* Step 8: RMSNorm before MLP */
6002  rmsnorm_forward(residual1,
6003  ln2_gamma,
6004  ln2_out,
6005  NULL,
6006  1,
6007  QWEN2_0_5B_DECODE_EMBED_DIM,
6008  aligned_embed_dim,
6009  1e-06f);
6010 
6011  /* Step 9: MLP (SwiGLU) */
6012  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6013  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6014  uint8_t ln2_q8[ln2_q8_bytes];
6015  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
6016  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6017 
6018  /* SwiGLU activation */
6019  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6020 
6021  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6022  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6023  uint8_t swiglu_q8[swiglu_q8_bytes];
6024  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
6025  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6026 
6027  /* Step 10: Final residual add */
6028  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6029 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_8_prefill()

static void qwen2_0_5b_decode_layer_8_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1523 of file v6.6/test_generated/qwen2_int8.c.

1530  {
1531  const QWEN2_0_5B_DECODELayerOffsets *L = &QWEN2_0_5B_DECODE_LAYERS[8];
1532 
1533  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[7].output);
1534  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1535  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1536  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1537  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1538  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1539  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1540  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1541  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1542  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1543  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1544  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1545  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1546  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1547  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1548  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1549 
1550  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1551  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1552  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1553  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1554  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1555  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1556  const float *BQ = NULL;
1557  const float *BK = NULL;
1558  const float *BV = NULL;
1559  const float *BO = NULL;
1560  const float *B1 = NULL;
1561  const float *B2 = NULL;
1562 
1563  const float *rope_cos = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_cos_cache);
1564  const float *rope_sin = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_GLOBALS.rope_sin_cache);
1565 
1566  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1567  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1568  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1569  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1570  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1571  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
1572 
1573  /* RMSNorm before attention */
1574  rmsnorm_forward(input,
1575  ln1_gamma,
1576  ln1_out,
1577  NULL,
1578  num_tokens,
1579  QWEN2_0_5B_DECODE_EMBED_DIM,
1580  aligned_embed_dim,
1581  1e-06f);
1582 
1583  /* Q projection (head-major) */
1584  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1585  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1586  for (int h = 0; h < H; ++h) {
1587  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1588  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1589  float *q_h = q + (size_t)h * q_head_stride;
1590  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1591  }
1592 
1593  /* K projection (head-major) */
1594  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1595  const uint8_t *WK_bytes = (const uint8_t *)WK;
1596  for (int h = 0; h < H_kv; ++h) {
1597  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1598  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1599  float *k_h = k + (size_t)h * kv_head_stride;
1600  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1601  }
1602 
1603  /* V projection (head-major) */
1604  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1605  const uint8_t *WV_bytes = (const uint8_t *)WV;
1606  for (int h = 0; h < H_kv; ++h) {
1607  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1608  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1609  float *v_h = v + (size_t)h * kv_head_stride;
1610  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1611  }
1612 
1613  /* RoPE */
1614  rope_forward_qk_strided(q,
1615  k,
1616  rope_cos,
1617  rope_sin,
1618  H,
1619  H_kv,
1620  num_tokens,
1621  head_dim,
1622  aligned_head_dim,
1623  0,
1624  num_tokens,
1625  aligned_context_window);
1626 
1627  /* Attention (prefill, causal) */
1628  attention_forward_causal_head_major_gqa_flash_strided(q,
1629  k,
1630  v,
1631  attn_out,
1632  H,
1633  H_kv,
1634  num_tokens,
1635  head_dim,
1636  aligned_head_dim,
1637  aligned_context_window);
1638 
1639  /* Output projection (flatten head-major to token-major) */
1640  const int K = H * aligned_head_dim;
1641  if (K != aligned_embed_dim) {
1642  return;
1643  }
1644  const float *proj_in = attn_out;
1645  if (H > 1) {
1646  if (!proj_scratch) {
1647  return;
1648  }
1649  for (int t = 0; t < num_tokens; ++t) {
1650  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1651  for (int h = 0; h < H; ++h) {
1652  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1653  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1654  src,
1655  (size_t)aligned_head_dim * sizeof(float));
1656  }
1657  }
1658  proj_in = proj_scratch;
1659  }
1660  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1661 
1662  /* Residual add */
1663  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1664 
1665  /* RMSNorm before MLP */
1666  rmsnorm_forward(residual1,
1667  ln2_gamma,
1668  ln2_out,
1669  NULL,
1670  num_tokens,
1671  QWEN2_0_5B_DECODE_EMBED_DIM,
1672  aligned_embed_dim,
1673  1e-06f);
1674 
1675  /* MLP (SwiGLU) */
1676  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1677  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1678  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1679 
1680  /* Final residual add */
1681  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1682 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_layer_9_decode()

static void qwen2_0_5b_decode_layer_9_decode ( QWEN2_0_5B_DECODEModel model,
int  token_index,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 6034 of file v6.6/test_generated/qwen2_int8.c.

6041  {
6043 
6044  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[8].output);
6045 
6046  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
6047  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
6048  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
6049  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
6050  float *k_cache = QWEN2_0_5B_DECODE_PTR(model, L->k);
6051  float *v_cache = QWEN2_0_5B_DECODE_PTR(model, L->v);
6052  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
6053  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
6054  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
6055  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
6056  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
6057 
6058  /* Weights (explicit types for layer 9) */
6059  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq); /* Q4_K */
6060  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk); /* Q4_K */
6061  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv); /* Q4_K */
6062  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo); /* Q4_K */
6063  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1); /* Q4_K (gate+up) */
6064  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2); /* Q4_K (down) */
6065 
6068 
6069  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
6070  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
6071  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
6072 
6073  float q_token[H * aligned_head_dim];
6074  float k_token[H_kv * aligned_head_dim];
6075  float v_token[H_kv * aligned_head_dim];
6076  float attn_token[H * aligned_head_dim];
6077 
6078  /* Local MLP buffers (avoid layout dependencies for intermediate values) */
6079  float fc1_out[2 * aligned_intermediate_dim];
6080  float swiglu_out[aligned_intermediate_dim];
6081 
6082  /* Step 1: RMSNorm before attention */
6083  rmsnorm_forward(input,
6084  ln1_gamma,
6085  ln1_out,
6086  NULL,
6087  1,
6089  aligned_embed_dim,
6090  1e-06f);
6091 
6092  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
6093 
6094  /* Step 2: QKV projection */
6095  /* Q projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6096  const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6097  uint8_t ln1_q8[ln1_q8_bytes];
6098  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed_dim);
6099  gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6100  if (aligned_head_dim > head_dim) {
6101  for (int h = 0; h < H; ++h) {
6102  float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
6103  for (int d = head_dim; d < aligned_head_dim; ++d) {
6104  q_head[d] = 0.0f;
6105  }
6106  }
6107  }
6108 
6109  /* K projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
6110  const size_t wk_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
6111  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wk_head_elems);
6112  const uint8_t *WK_bytes = (const uint8_t *)WK;
6113  /* ln1_q8 already quantized above */
6114  for (int h = 0; h < H_kv; ++h) {
6115  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
6116  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6117  gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6118  for (int d = head_dim; d < aligned_head_dim; ++d) {
6119  k_head[d] = 0.0f;
6120  }
6121  }
6122 
6123  /* V projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k (direct-to-cache) */
6124  const size_t wv_head_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
6125  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, wv_head_elems);
6126  const uint8_t *WV_bytes = (const uint8_t *)WV;
6127  /* ln1_q8 already quantized above */
6128  for (int h = 0; h < H_kv; ++h) {
6129  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
6130  float *v_head = v_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6131  gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6132  for (int d = head_dim; d < aligned_head_dim; ++d) {
6133  v_head[d] = 0.0f;
6134  }
6135  }
6136 
6137  /* Step 3: RoPE */
6138  rope_forward(q_token,
6139  rope_cos,
6140  rope_sin,
6141  H,
6142  1,
6143  head_dim,
6144  aligned_head_dim,
6145  token_index);
6146  for (int h = 0; h < H_kv; ++h) {
6147  float *k_head = k_cache + (size_t)h * kv_head_stride + (size_t)token_index * (size_t)aligned_head_dim;
6148  rope_forward(k_head,
6149  rope_cos,
6150  rope_sin,
6151  1,
6152  1,
6153  head_dim,
6154  aligned_head_dim,
6155  token_index);
6156  }
6157 
6158  /* Step 4: KV cache write (direct-to-cache) */
6159 
6160  /* Step 5: Attention (decode, flash) */
6162  k_cache,
6163  v_cache,
6164  attn_token,
6165  H,
6166  H_kv,
6167  token_index + 1,
6168  aligned_context_window,
6169  head_dim,
6170  aligned_head_dim);
6171 
6172  /* Step 6: Output projection */
6173  /* WO projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6174  const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6175  uint8_t attn_q8[attn_q8_bytes];
6176  quantize_row_q8_k(attn_token, attn_q8, H * head_dim);
6177  gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6178 
6179  /* Step 7: Residual add */
6180  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, 1, aligned_embed_dim);
6181 
6182  /* Step 8: RMSNorm before MLP */
6183  rmsnorm_forward(residual1,
6184  ln2_gamma,
6185  ln2_out,
6186  NULL,
6187  1,
6189  aligned_embed_dim,
6190  1e-06f);
6191 
6192  /* Step 9: MLP (SwiGLU) */
6193  /* Gate+Up projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6194  const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6195  uint8_t ln2_q8[ln2_q8_bytes];
6196  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed_dim);
6197  gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6198 
6199  /* SwiGLU activation */
6200  swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6201 
6202  /* Down projection (INT8): Q4_K x Q8_K -> gemv_q4_k_q8_k */
6203  const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6204  uint8_t swiglu_q8[swiglu_q8_bytes];
6205  quantize_row_q8_k(swiglu_out, swiglu_q8, aligned_intermediate_dim);
6206  gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6207 
6208  /* Step 10: Final residual add */
6209  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, 1, aligned_embed_dim);
6210 }

References attention_forward_decode_head_major_gqa_flash(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemv_q4_k_q8_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, quantize_row_q8_k(), QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_decode_token().

◆ qwen2_0_5b_decode_layer_9_prefill()

static void qwen2_0_5b_decode_layer_9_prefill ( QWEN2_0_5B_DECODEModel model,
int  num_tokens,
int  aligned_embed_dim,
int  aligned_head_dim,
int  aligned_intermediate_dim,
int  aligned_context_window 
)
static

Definition at line 1687 of file v6.6/test_generated/qwen2_int8.c.

1694  {
1696 
1697  float *input = QWEN2_0_5B_DECODE_PTR(model, QWEN2_0_5B_DECODE_LAYERS[8].output);
1698  float *ln1_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln1_gamma);
1699  float *ln1_out = QWEN2_0_5B_DECODE_PTR(model, L->ln1_out);
1700  float *ln2_gamma = QWEN2_0_5B_DECODE_PTR(model, L->ln2_gamma);
1701  float *ln2_out = QWEN2_0_5B_DECODE_PTR(model, L->ln2_out);
1702  float *q = QWEN2_0_5B_DECODE_PTR(model, L->q);
1703  float *k = QWEN2_0_5B_DECODE_PTR(model, L->k);
1704  float *v = QWEN2_0_5B_DECODE_PTR(model, L->v);
1705  float *attn_out = QWEN2_0_5B_DECODE_PTR(model, L->attn_out);
1706  float *proj_tmp = QWEN2_0_5B_DECODE_PTR(model, L->proj_tmp);
1707  float *proj_scratch = QWEN2_0_5B_DECODE_PTR(model, L->proj_scratch);
1708  float *residual1 = QWEN2_0_5B_DECODE_PTR(model, L->residual1);
1709  float *fc1_out = QWEN2_0_5B_DECODE_PTR(model, L->fc1_out);
1710  float *swiglu_out = QWEN2_0_5B_DECODE_PTR(model, L->swiglu_out);
1711  float *mlp_out = QWEN2_0_5B_DECODE_PTR(model, L->mlp_out);
1712  float *output = QWEN2_0_5B_DECODE_PTR(model, L->output);
1713 
1714  const void *WQ = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wq);
1715  const void *WK = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wk);
1716  const void *WV = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wv);
1717  const void *WO = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->wo);
1718  const void *W1 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w1);
1719  const void *W2 = (const void *)QWEN2_0_5B_DECODE_PTR(model, L->w2);
1720  const float *BQ = NULL;
1721  const float *BK = NULL;
1722  const float *BV = NULL;
1723  const float *BO = NULL;
1724  const float *B1 = NULL;
1725  const float *B2 = NULL;
1726 
1729 
1730  const int H = QWEN2_0_5B_DECODE_NUM_HEADS;
1731  const int H_kv = QWEN2_0_5B_DECODE_NUM_KV_HEADS;
1732  const int head_dim = QWEN2_0_5B_DECODE_HEAD_DIM;
1733  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1734  const size_t q_head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
1735  const size_t kv_head_stride = (size_t)aligned_context_window * (size_t)aligned_head_dim;
1736 
1737  /* RMSNorm before attention */
1738  rmsnorm_forward(input,
1739  ln1_gamma,
1740  ln1_out,
1741  NULL,
1742  num_tokens,
1744  aligned_embed_dim,
1745  1e-06f);
1746 
1747  /* Q projection (head-major) */
1748  const size_t wq_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1749  const uint8_t *WQ_bytes = (const uint8_t *)WQ;
1750  for (int h = 0; h < H; ++h) {
1751  const void *wq_h = (const void *)(WQ_bytes + (size_t)h * wq_head_bytes);
1752  const float *bq_h = BQ ? (BQ + (size_t)h * (size_t)aligned_head_dim) : NULL;
1753  float *q_h = q + (size_t)h * q_head_stride;
1754  gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1755  }
1756 
1757  /* K projection (head-major) */
1758  const size_t wk_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1759  const uint8_t *WK_bytes = (const uint8_t *)WK;
1760  for (int h = 0; h < H_kv; ++h) {
1761  const void *wk_h = (const void *)(WK_bytes + (size_t)h * wk_head_bytes);
1762  const float *bk_h = BK ? (BK + (size_t)h * (size_t)aligned_head_dim) : NULL;
1763  float *k_h = k + (size_t)h * kv_head_stride;
1764  gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1765  }
1766 
1767  /* V projection (head-major) */
1768  const size_t wv_head_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1769  const uint8_t *WV_bytes = (const uint8_t *)WV;
1770  for (int h = 0; h < H_kv; ++h) {
1771  const void *wv_h = (const void *)(WV_bytes + (size_t)h * wv_head_bytes);
1772  const float *bv_h = BV ? (BV + (size_t)h * (size_t)aligned_head_dim) : NULL;
1773  float *v_h = v + (size_t)h * kv_head_stride;
1774  gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1775  }
1776 
1777  /* RoPE */
1779  k,
1780  rope_cos,
1781  rope_sin,
1782  H,
1783  H_kv,
1784  num_tokens,
1785  head_dim,
1786  aligned_head_dim,
1787  0,
1788  num_tokens,
1789  aligned_context_window);
1790 
1791  /* Attention (prefill, causal) */
1793  k,
1794  v,
1795  attn_out,
1796  H,
1797  H_kv,
1798  num_tokens,
1799  head_dim,
1800  aligned_head_dim,
1801  aligned_context_window);
1802 
1803  /* Output projection (flatten head-major to token-major) */
1804  const int K = H * aligned_head_dim;
1805  if (K != aligned_embed_dim) {
1806  return;
1807  }
1808  const float *proj_in = attn_out;
1809  if (H > 1) {
1810  if (!proj_scratch) {
1811  return;
1812  }
1813  for (int t = 0; t < num_tokens; ++t) {
1814  float *dst = proj_scratch + (size_t)t * (size_t)aligned_embed_dim;
1815  for (int h = 0; h < H; ++h) {
1816  const float *src = attn_out + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1817  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
1818  src,
1819  (size_t)aligned_head_dim * sizeof(float));
1820  }
1821  }
1822  proj_in = proj_scratch;
1823  }
1824  gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1825 
1826  /* Residual add */
1827  qwen2_0_5b_decode_residual_add_token_major(input, proj_tmp, residual1, num_tokens, aligned_embed_dim);
1828 
1829  /* RMSNorm before MLP */
1830  rmsnorm_forward(residual1,
1831  ln2_gamma,
1832  ln2_out,
1833  NULL,
1834  num_tokens,
1836  aligned_embed_dim,
1837  1e-06f);
1838 
1839  /* MLP (SwiGLU) */
1840  gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1841  swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1842  gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1843 
1844  /* Final residual add */
1845  qwen2_0_5b_decode_residual_add_token_major(residual1, mlp_out, output, num_tokens, aligned_embed_dim);
1846 }

References attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_Q4_K, ck_dtype_row_bytes(), gemm_nt_q4_k(), QWEN2_0_5B_DECODELayerOffsets::k, QWEN2_0_5B_DECODELayerOffsets::ln1_gamma, QWEN2_0_5B_DECODELayerOffsets::ln1_out, QWEN2_0_5B_DECODELayerOffsets::ln2_gamma, QWEN2_0_5B_DECODELayerOffsets::ln2_out, QWEN2_0_5B_DECODELayerOffsets::mlp_out, QWEN2_0_5B_DECODELayerOffsets::output, QWEN2_0_5B_DECODELayerOffsets::proj_scratch, QWEN2_0_5B_DECODELayerOffsets::proj_tmp, QWEN2_0_5B_DECODELayerOffsets::q, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_LAYERS, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_KV_HEADS, QWEN2_0_5B_DECODE_PTR, qwen2_0_5b_decode_residual_add_token_major(), QWEN2_0_5B_DECODELayerOffsets::residual1, rmsnorm_forward(), QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, rope_forward_qk_strided(), QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache, swiglu_forward(), QWEN2_0_5B_DECODELayerOffsets::v, QWEN2_0_5B_DECODELayerOffsets::w1, QWEN2_0_5B_DECODELayerOffsets::w2, QWEN2_0_5B_DECODELayerOffsets::wk, QWEN2_0_5B_DECODELayerOffsets::wo, QWEN2_0_5B_DECODELayerOffsets::wq, and QWEN2_0_5B_DECODELayerOffsets::wv.

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ qwen2_0_5b_decode_model_allocate()

int qwen2_0_5b_decode_model_allocate ( QWEN2_0_5B_DECODEModel model)

Definition at line 88 of file v6.6/test_generated/qwen2_int8.c.

88  {
89  size_t total = QWEN2_0_5B_DECODE_TOTAL_BYTES;
90 
91 #ifdef __linux__
92  model->base = mmap(NULL, total,
93  PROT_READ | PROT_WRITE,
94  MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
95  -1, 0);
96  if (model->base == MAP_FAILED) {
97  model->base = mmap(NULL, total,
98  PROT_READ | PROT_WRITE,
99  MAP_PRIVATE | MAP_ANONYMOUS,
100  -1, 0);
101  }
102  if (model->base == MAP_FAILED) {
103  perror("mmap failed");
104  return -1;
105  }
106 #else
107  model->base = aligned_alloc(64, total);
108  if (!model->base) {
109  perror("aligned_alloc failed");
110  return -1;
111  }
112 #endif
113 
114  model->total_bytes = total;
115 
116  /* Initialize magic header */
117  MagicHeader *header = (MagicHeader *)model->base;
118  header->magic = QWEN2_0_5B_DECODE_MAGIC;
119  header->version = 5;
120  header->total_bytes = QWEN2_0_5B_DECODE_TOTAL_BYTES;
121  header->weight_bytes = QWEN2_0_5B_DECODE_WEIGHT_BYTES;
122  header->activation_bytes = QWEN2_0_5B_DECODE_ACTIVATION_BYTES;
123  header->num_layers = QWEN2_0_5B_DECODE_NUM_LAYERS;
124  header->embed_dim = QWEN2_0_5B_DECODE_EMBED_DIM;
125  header->num_heads = QWEN2_0_5B_DECODE_NUM_HEADS;
126  header->vocab_size = QWEN2_0_5B_DECODE_VOCAB_SIZE;
127  header->max_seq_len = QWEN2_0_5B_DECODE_MAX_SEQ_LEN;
128  header->canary_count = QWEN2_0_5B_DECODE_CANARY_COUNT;
129 
130  /* Initialize canary guards */
131  for (int i = 0; i < QWEN2_0_5B_DECODE_CANARY_COUNT; i++) {
132  uint32_t *ptr = (uint32_t*)((char*)model->base + QWEN2_0_5B_DECODE_CANARIES[i].offset);
133  for (int j = 0; j < (QWEN2_0_5B_DECODE_CANARY_SIZE / 4); j++) {
135  }
136  }
137 
138  return 0;
139 }
#define QWEN2_0_5B_DECODE_TOTAL_BYTES
#define QWEN2_0_5B_DECODE_ACTIVATION_BYTES
#define QWEN2_0_5B_DECODE_MAX_SEQ_LEN
#define QWEN2_0_5B_DECODE_WEIGHT_BYTES
#define QWEN2_0_5B_DECODE_CANARY_VALUE
#define QWEN2_0_5B_DECODE_CANARY_SIZE
static const QWEN2_0_5B_DECODECanary QWEN2_0_5B_DECODE_CANARIES[]

References QWEN2_0_5B_DECODEModel::base, MagicHeader, QWEN2_0_5B_DECODECanary::offset, QWEN2_0_5B_DECODE_ACTIVATION_BYTES, QWEN2_0_5B_DECODE_CANARIES, QWEN2_0_5B_DECODE_CANARY_COUNT, QWEN2_0_5B_DECODE_CANARY_SIZE, QWEN2_0_5B_DECODE_CANARY_VALUE, QWEN2_0_5B_DECODE_EMBED_DIM, QWEN2_0_5B_DECODE_MAGIC, QWEN2_0_5B_DECODE_MAX_SEQ_LEN, QWEN2_0_5B_DECODE_NUM_HEADS, QWEN2_0_5B_DECODE_NUM_LAYERS, QWEN2_0_5B_DECODE_TOTAL_BYTES, QWEN2_0_5B_DECODE_VOCAB_SIZE, QWEN2_0_5B_DECODE_WEIGHT_BYTES, and QWEN2_0_5B_DECODEModel::total_bytes.

Referenced by ck_model_create().

◆ qwen2_0_5b_decode_model_free()

void qwen2_0_5b_decode_model_free ( QWEN2_0_5B_DECODEModel model)

Definition at line 141 of file v6.6/test_generated/qwen2_int8.c.

141  {
142  if (!model || !model->base) return;
143 #ifdef __linux__
144  munmap(model->base, model->total_bytes);
145 #else
146  free(model->base);
147 #endif
148  model->base = NULL;
149  model->total_bytes = 0;
150 }

References QWEN2_0_5B_DECODEModel::base, and QWEN2_0_5B_DECODEModel::total_bytes.

Referenced by ck_model_free().

◆ qwen2_0_5b_decode_precompute_rope()

void qwen2_0_5b_decode_precompute_rope ( QWEN2_0_5B_DECODEModel model)

Definition at line 186 of file v6.6/test_generated/qwen2_int8.c.

186  {
187  const int T = QWEN2_0_5B_DECODE_MAX_SEQ_LEN;
188  const int D = QWEN2_0_5B_DECODE_HEAD_DIM / 2;
189  const float theta = 1000000.0f;
190 
193 
194  for (int pos = 0; pos < T; pos++) {
195  for (int i = 0; i < D; i++) {
196  float freq = 1.0f / powf(theta, (float)(2 * i) / (float)(D * 2));
197  float angle = (float)pos * freq;
198  cos_ptr[pos * D + i] = cosf(angle);
199  sin_ptr[pos * D + i] = sinf(angle);
200  }
201  }
202 }

References QWEN2_0_5B_DECODE_GLOBALS, QWEN2_0_5B_DECODE_HEAD_DIM, QWEN2_0_5B_DECODE_MAX_SEQ_LEN, QWEN2_0_5B_DECODE_PTR, QWEN2_0_5B_DECODEGlobalOffsets::rope_cos_cache, and QWEN2_0_5B_DECODEGlobalOffsets::rope_sin_cache.

Referenced by ck_model_precompute_rope().

◆ qwen2_0_5b_decode_residual_add_token_major()

static void qwen2_0_5b_decode_residual_add_token_major ( const float *  a,
const float *  b,
float *  out,
int  tokens,
int  aligned_embed_dim 
)
static

Definition at line 43 of file v6.6/test_generated/qwen2_int8.c.

/* Element-wise residual add in token-major layout:
 * out[t][d] = a[t][d] + b[t][d] for `tokens` rows of `aligned_embed_dim`
 * floats (rows are padded to the aligned stride; padding is summed too,
 * which is harmless since both operands keep zeros there).
 * No-ops silently if any pointer is NULL.
 * (Signature reconstructed from the Doxygen prototype; the rendering dropped
 * the declaration lines.) */
static void qwen2_0_5b_decode_residual_add_token_major(const float *a,
                                                       const float *b,
                                                       float *out,
                                                       int tokens,
                                                       int aligned_embed_dim) {
    if (!a || !b || !out) {
        return;
    }
    for (int t = 0; t < tokens; ++t) {
        const float *pa = a + (size_t)t * (size_t)aligned_embed_dim;
        const float *pb = b + (size_t)t * (size_t)aligned_embed_dim;
        float *pc = out + (size_t)t * (size_t)aligned_embed_dim;
        for (int d = 0; d < aligned_embed_dim; ++d) {
            pc[d] = pa[d] + pb[d];
        }
    }
}

Referenced by qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ qwen2_0_5b_decode_verify_canaries()

int qwen2_0_5b_decode_verify_canaries ( QWEN2_0_5B_DECODEModel model)

Definition at line 152 of file v6.6/test_generated/qwen2_int8.c.

152  {
153  int errors = 0;
154  uint32_t *ptr;
155 
156  for (int i = 0; i < QWEN2_0_5B_DECODE_CANARY_COUNT; i++) {
157  ptr = (uint32_t*)((char*)model->base + QWEN2_0_5B_DECODE_CANARIES[i].offset);
158  for (int j = 0; j < 4; j++) {
159  if (ptr[j] != QWEN2_0_5B_DECODE_CANARY_VALUE) {
160  fprintf(stderr, "CANARY CORRUPTION: %s at offset 0x%lX\n",
162  QWEN2_0_5B_DECODE_CANARIES[i].offset);
163  errors++;
164  break;
165  }
166  }
167  }
168 
169  return errors;
170 }

References QWEN2_0_5B_DECODEModel::base, QWEN2_0_5B_DECODECanary::offset, QWEN2_0_5B_DECODE_CANARIES, QWEN2_0_5B_DECODE_CANARY_COUNT, and QWEN2_0_5B_DECODE_CANARY_VALUE.

Referenced by ck_model_verify_canaries().

Variable Documentation

◆ g_model_config

CKModelConfig g_model_config
static
Initial value:
= {
.embed_dim = 896 ,
.num_heads = 14 ,
.num_kv_heads = 2 ,
.head_dim = 64 ,
.intermediate_size = 4864 ,
.num_layers = 24 ,
.vocab_size = 151936 ,
.max_seq_len = 131072 ,
.total_bytes = 3573889600ULL ,
.weight_bytes = 317683328ULL ,
.activation_bytes = 3256169984ULL ,
.model_name = "qwen2_0.5b_decode",
.model_family = "qwen2",
}

Definition at line 8853 of file v6.6/test_generated/qwen2_int8.c.

Referenced by ck_model_get_config().

◆ MagicHeader

MagicHeader

Definition at line 80 of file v6.6/test_generated/qwen2_int8.c.

Referenced by qwen2_0_5b_decode_model_allocate().