24 #pragma GCC diagnostic push
25 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
34 int aligned_context_window,
37 if (!scores || num_heads <= 0 || num_tokens <= 0 || aligned_context_window <= 0)
return;
40 const size_t total = (size_t)num_heads *
41 (
size_t)aligned_context_window *
42 (size_t)aligned_context_window;
54 const uint16_t *weights,
57 int aligned_context_window,
58 float *scratch_d_scores,
59 float *scratch_weights)
61 if (!d_scores || !weights || num_heads <= 0 || num_tokens <= 0 || aligned_context_window <= 0)
return;
62 if (!scratch_d_scores || !scratch_weights)
return;
64 const size_t total = (size_t)num_heads *
65 (
size_t)aligned_context_window *
66 (size_t)aligned_context_window;
74 #pragma GCC diagnostic pop
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void backward_causal_softmax_head_major(float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void backward_causal_softmax_head_major_bf16(uint16_t *d_scores, const uint16_t *weights, int num_heads, int num_tokens, int aligned_context_window, float *scratch_d_scores, float *scratch_weights)
void causal_softmax_head_major_bf16(uint16_t *scores, int num_heads, int num_tokens, int aligned_context_window, float *scratch)