43 int aligned_embed_dim,
62 const float *q_gamma,
const float *k_gamma,
63 int num_heads,
int num_kv_heads,
64 int num_tokens,
int head_dim,
float eps)
69 num_heads * num_tokens, head_dim, head_dim, eps);
74 num_kv_heads * num_tokens, head_dim, head_dim, eps);
void qk_norm_forward(float *q, float *k, const float *q_gamma, const float *k_gamma, int num_heads, int num_kv_heads, int num_tokens, int head_dim, float eps)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)