= {
{
"attention", {
"attention_forward_causal_head_major_gqa",
"attention_forward_causal_head_major_gqa_bf16", NULL, NULL, NULL }, {
"attention_backward_causal_head_major_gqa",
"attention_backward_causal_head_major_gqa_bf16", NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32) |
CK_DT_MASK(
CK_DT_BF16),
CK_DT_FP32, {
"src/kernels/attention_kernels.c",
"src/kernels/softmax_kernels.c", NULL, NULL, NULL, NULL, NULL, NULL }},
{
"attn_proj", {
"ck_attention_project_head_major", NULL, NULL, NULL, NULL }, {
"ck_attention_project_head_major_backward", NULL, NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32),
CK_DT_FP32, {
"src/ckernel_orchestration.c",
"src/kernels/gemm_kernels.c",
"src/kernels/mlp_kernels.c",
"src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL }},
{
"mlp_down", {
"gemm_blocked_serial", NULL, NULL, NULL, NULL }, {
"fc2_backward_kernel", NULL, NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32),
CK_DT_FP32, {
"src/kernels/gemm_kernels.c",
"src/kernels/mlp_kernels.c",
"src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL, NULL }},
{
"mlp_up", {
"gemm_blocked_serial", NULL, NULL, NULL, NULL }, {
"fc1_backward_kernel", NULL, NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32),
CK_DT_FP32, {
"src/kernels/gemm_kernels.c",
"src/kernels/mlp_kernels.c",
"src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL, NULL }},
{
"qkv_project", {
"ck_qkv_project_head_major", NULL, NULL, NULL, NULL }, {
"ck_qkv_project_head_major_backward", NULL, NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32),
CK_DT_FP32, {
"src/ckernel_orchestration.c",
"src/kernels/gemm_kernels.c",
"src/kernels/mlp_kernels.c",
"src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL }},
{
"residual_add", {
"ck_residual_add_token_major", NULL, NULL, NULL, NULL }, {
"ck_residual_add_backward", NULL, NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32),
CK_DT_FP32, {
"src/ckernel_orchestration.c", NULL, NULL, NULL, NULL, NULL, NULL, NULL }},
{
"rmsnorm", {
"rmsnorm_forward",
"rmsnorm_forward_bf16", NULL,
"rmsnorm_forward_int8",
"rmsnorm_forward_int4" }, {
"rmsnorm_backward",
"rmsnorm_backward_bf16", NULL,
"rmsnorm_backward_int8",
"rmsnorm_backward_int4" },
CK_DT_MASK(
CK_DT_FP32) |
CK_DT_MASK(
CK_DT_BF16) |
CK_DT_MASK(
CK_DT_INT8) |
CK_DT_MASK(
CK_DT_INT4),
CK_DT_FP32, {
"src/kernels/rmsnorm_kernels.c",
"src/kernels/rmsnorm_kernels_bf16.c",
"src/kernels/rmsnorm_kernels_int8.c",
"src/kernels/rmsnorm_kernels_int4.c", NULL, NULL, NULL, NULL }},
{
"rope", {
"rope_forward_qk",
"rope_forward_qk_bf16", NULL, NULL, NULL }, {
"rope_backward_qk",
"rope_backward_qk_bf16", NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32) |
CK_DT_MASK(
CK_DT_BF16),
CK_DT_FP32, {
"src/kernels/rope_kernels.c",
"src/kernels/rope_kernels_bf16.c", NULL, NULL, NULL, NULL, NULL, NULL }},
{
"swiglu", {
"swiglu_forward",
"swiglu_forward_bf16", NULL, NULL, NULL }, {
"swiglu_backward",
"swiglu_backward_bf16", NULL, NULL, NULL },
CK_DT_MASK(
CK_DT_FP32) |
CK_DT_MASK(
CK_DT_BF16),
CK_DT_FP32, {
"src/kernels/swiglu_kernels.c",
"src/kernels/swiglu_kernels_bf16.c",
"src/kernels/sigmoid_kernels.c", NULL, NULL, NULL, NULL, NULL }},
}