21 #pragma GCC diagnostic push
22 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
29 const float *cos_cache,
30 const float *sin_cache,
40 size_t total = (size_t)num_heads * (
size_t)num_tokens * (size_t)aligned_head_dim;
44 num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
54 const float *cos_cache,
55 const float *sin_cache,
64 if (!scratch_d_out || !scratch_d_x)
return;
66 size_t total = (size_t)num_heads * (
size_t)num_tokens * (size_t)aligned_head_dim;
69 rope_backward(scratch_d_out, scratch_d_x, cos_cache, sin_cache,
70 num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
81 const float *cos_cache,
82 const float *sin_cache,
95 num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, scratch_q);
97 num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, scratch_k);
104 const uint16_t *d_k_out,
107 const float *cos_cache,
108 const float *sin_cache,
113 int aligned_head_dim,
115 float *scratch_dq_out,
117 float *scratch_dk_out,
120 if (!d_q_out || !d_k_out || !d_q || !d_k)
return;
123 num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset,
124 scratch_dq_out, scratch_dq);
126 num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset,
127 scratch_dk_out, scratch_dk);
130 #pragma GCC diagnostic pop
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void rope_backward_bf16(const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x)
void rope_backward_qk_bf16(const uint16_t *d_q_out, const uint16_t *d_k_out, uint16_t *d_q, uint16_t *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_dq_out, float *scratch_dq, float *scratch_dk_out, float *scratch_dk)
void rope_forward_bf16(uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch)
void rope_forward_qk_bf16(uint16_t *q, uint16_t *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_q, float *scratch_k)