24 #pragma GCC diagnostic push
25 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
47 const uint16_t *d_output,
51 float *scratch_d_output,
52 float *scratch_d_input)
54 if (!scratch_input || !scratch_d_output || !scratch_d_input)
return;
70 const uint16_t *d_output,
74 float *scratch_d_output,
75 float *scratch_d_input)
77 if (!scratch_input || !scratch_d_output || !scratch_d_input)
return;
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void gelu_exact_inplace(float *data, size_t n)
void gelu_backward_scalar(const float *input, const float *d_output, float *d_input, size_t n)
void gelu_backward_fast(const float *input, const float *d_output, float *d_input, size_t n)
void gelu_backward_fast_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void gelu_backward_exact_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void gelu_fast_inplace_bf16(uint16_t *data, size_t n, float *scratch)