21 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
22 #include <immintrin.h>
119 const float *fc2_input,
136 T, aligned_in, aligned_out);
145 aligned_out, aligned_in, T);
148 #pragma omp parallel for schedule(static)
149 for (
int out_idx = 0; out_idx < aligned_out; ++out_idx) {
150 float bias_grad = 0.0f;
151 for (
int t = 0; t < T; ++t) {
152 bias_grad += d_output[(size_t)t * aligned_out + out_idx];
154 d_b_fc2[out_idx] += bias_grad;
168 const float *fc1_input,
185 T, aligned_in, aligned_out);
192 aligned_out, aligned_in, T);
195 #pragma omp parallel for schedule(static)
196 for (
int out_idx = 0; out_idx < aligned_out; ++out_idx) {
197 float bias_grad = 0.0f;
198 for (
int t = 0; t < T; ++t) {
199 bias_grad += d_output[(size_t)t * aligned_out + out_idx];
201 d_b_fc1[out_idx] += bias_grad;
void gelu_exact_inplace(float *data, size_t n)
void gemm_nn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gelu_fast_inplace(float *data, size_t n)
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemm_tn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void mlp_token_parallel(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
void mlp_token_parallel_exact(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)