32 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
33 #include <immintrin.h>
/*
 * Convert one FP32 value to FP16 (IEEE 754 binary16 bit pattern).
 *
 * Uses the F16C hardware instruction when available; otherwise a portable
 * bit-level conversion. NOTE(review): the software path truncates the
 * mantissa (round-toward-zero) while _cvtss_sh rounds to nearest-even, so
 * results may differ by 1 ulp between the two paths.
 */
static uint16_t ck_fp32_to_fp16_scalar(float f) {
#if defined(__F16C__) || (defined(__AVX__) && !defined(__clang__))
    return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
#else
    union { float f; uint32_t u; } u = { f };
    uint32_t x = u.u;
    uint32_t sign = (x >> 16) & 0x8000;            /* sign bit moved to h15 */
    int exp = (int)((x >> 23) & 0xFF) - 127 + 15;  /* rebias 8-bit exp to 5-bit */
    uint32_t mant = (x >> 13) & 0x3FF;             /* top 10 mantissa bits */

    if (exp <= 0) {
        /* Result is FP16-subnormal, or underflows to a signed zero. */
        if (exp < -10)
            return (uint16_t)sign;
        mant = (mant | 0x400) >> (1 - exp);  /* restore hidden bit, denormalize */
        return (uint16_t)(sign | mant);
    } else if (exp >= 31) {
        /* Source is Inf/NaN (raw exponent 0xFF) or overflows FP16 range.
         * BUG FIX: the NaN test previously compared the REBIASED exponent
         * against 128, but raw 0xFF rebiasses to 143, so every NaN was
         * silently converted to Inf. Test the raw exponent instead. */
        if (((x >> 23) & 0xFF) == 0xFF && (x & 0x7FFFFF))
            return (uint16_t)(sign | 0x7E00 | mant);  /* quiet NaN, keep top payload bits */
        return (uint16_t)(sign | 0x7C00);             /* +/-Inf (and finite overflow) */
    }
    return (uint16_t)(sign | ((uint32_t)exp << 10) | mant);
#endif
}
/*
 * Convert one FP16 bit pattern to FP32 (portable bit-level expansion).
 * Handles signed zero, subnormals (renormalized), Inf and NaN.
 */
static float ck_fp16_to_fp32_scalar(uint16_t h) {
    uint32_t sign = ((uint32_t)h & 0x8000) << 16;  /* sign to FP32 bit 31 */
    int exp = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x3FF;

    if (exp == 0) {
        if (mant == 0) {
            /* +/-0.0: only the sign bit survives. */
            union { uint32_t u; float f; } u = { sign };
            return u.f;
        }
        /* Subnormal: shift the mantissa until the hidden bit (0x400)
         * appears, decrementing the exponent for each shift. */
        exp = 1;
        while (!(mant & 0x400)) {
            mant <<= 1;
            exp--;
        }
        mant &= 0x3FF;  /* drop the now-implicit hidden bit */
        union { uint32_t u; float f; } u =
            { sign | ((uint32_t)(exp + 112) << 23) | (mant << 13) };
        return u.f;
    } else if (exp == 31) {
        /* Inf (mant == 0) or NaN (mant != 0): raw FP32 exponent 0xFF. */
        union { uint32_t u; float f; } u = { sign | 0x7F800000 | (mant << 13) };
        return u.f;
    }
    /* Normal number: rebias 5-bit exponent (bias 15) to 8-bit (bias 127). */
    union { uint32_t u; float f; } u =
        { sign | ((uint32_t)(exp + 112) << 23) | (mant << 13) };
    return u.f;
}
124 #if defined(__AVX512F__)
static uint16_t ck_fp32_to_fp16_scalar(float f);

/*
 * Bulk FP32 -> FP16 using AVX-512F: 16 lanes per iteration, scalar tail.
 * src and dst must not overlap.
 * NOTE(review): the vector loop's remainder handling was truncated in this
 * chunk; the scalar tail below restores coverage for n not a multiple of
 * 16 -- confirm against the full file.
 */
void ck_fp32_to_fp16_avx512(const float *src, uint16_t *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
    int i = 0;
    for (; i + 15 < n; i += 16) {
        __m512 x = _mm512_loadu_ps(src + i);  /* unaligned load permitted */
        __m256i y = _mm512_cvtps_ph(x, _MM_FROUND_TO_NEAREST_INT);
        _mm256_storeu_si256((__m256i*)(dst + i), y);
    }
    for (; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(src[i]);
}
static float ck_fp16_to_fp32_scalar(uint16_t h);

/*
 * Bulk FP16 -> FP32 using AVX-512F: 16 lanes per iteration, scalar tail.
 * src and dst must not overlap.
 * NOTE(review): remainder handling reconstructed (chunk was truncated) --
 * confirm against the full file.
 */
void ck_fp16_to_fp32_avx512(const uint16_t *src, float *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
    int i = 0;
    for (; i + 15 < n; i += 16) {
        __m256i x = _mm256_loadu_si256((const __m256i*)(src + i));
        __m512 y = _mm512_cvtph_ps(x);
        _mm512_storeu_ps(dst + i, y);
    }
    for (; i < n; i++)
        dst[i] = ck_fp16_to_fp32_scalar(src[i]);
}
static uint16_t ck_fp32_to_fp16_scalar(float f);

/*
 * Bulk FP32 -> FP16 using AVX/F16C: 8 lanes per iteration.
 * Without F16C the whole array is handled by the scalar loop, which also
 * serves as the remainder tail for the vector path.
 */
void ck_fp32_to_fp16_avx(const float *src, uint16_t *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
    int i = 0;
#if defined(__F16C__)
    for (; i + 7 < n; i += 8) {
        __m256 x = _mm256_loadu_ps(src + i);
        __m128i y = _mm256_cvtps_ph(x, _MM_FROUND_TO_NEAREST_INT);
        _mm_storeu_si128((__m128i*)(dst + i), y);
    }
#endif
    for (; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(src[i]);
}
static float ck_fp16_to_fp32_scalar(uint16_t h);

/*
 * Bulk FP16 -> FP32 using AVX/F16C: 8 lanes per iteration.
 * Without F16C the whole array is handled by the scalar loop, which also
 * serves as the remainder tail for the vector path.
 */
void ck_fp16_to_fp32_avx(const uint16_t *src, float *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
    int i = 0;
#if defined(__F16C__)
    for (; i + 7 < n; i += 8) {
        __m128i x = _mm_loadu_si128((const __m128i*)(src + i));
        __m256 y = _mm256_cvtph_ps(x);
        _mm256_storeu_ps(dst + i, y);
    }
#endif
    for (; i < n; i++)
        dst[i] = ck_fp16_to_fp32_scalar(src[i]);
}
void ck_fp32_to_fp16_avx512(const float *src, uint16_t *dst, int n);
void ck_fp32_to_fp16_avx(const float *src, uint16_t *dst, int n);
static uint16_t ck_fp32_to_fp16_scalar(float f);

/*
 * Convert an FP32 row to FP16, dispatching at compile time to the widest
 * SIMD implementation built in; falls back to the scalar converter.
 */
void ck_fp32_to_fp16_row(const float *src, uint16_t *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
#if defined(__AVX512F__)
    ck_fp32_to_fp16_avx512(src, dst, n);
#elif defined(__AVX__)
    ck_fp32_to_fp16_avx(src, dst, n);
#else
    for (int i = 0; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(src[i]);
#endif
}
void ck_fp16_to_fp32_avx512(const uint16_t *src, float *dst, int n);
void ck_fp16_to_fp32_avx(const uint16_t *src, float *dst, int n);
static float ck_fp16_to_fp32_scalar(uint16_t h);

/*
 * Convert an FP16 row to FP32, dispatching at compile time to the widest
 * SIMD implementation built in; falls back to the scalar converter.
 */
void ck_fp16_to_fp32_row(const uint16_t *src, float *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
#if defined(__AVX512F__)
    ck_fp16_to_fp32_avx512(src, dst, n);
#elif defined(__AVX__)
    ck_fp16_to_fp32_avx(src, dst, n);
#else
    for (int i = 0; i < n; i++)
        dst[i] = ck_fp16_to_fp32_scalar(src[i]);
#endif
}
void ck_fp32_to_fp16_row(const float *src, uint16_t *dst, int n);

/*
 * Convert a 2D FP32 matrix to FP16 row by row with independent row strides.
 * NOTE(review): strides appear to be in elements, not bytes -- inferred
 * from the (size_t)r * dst_stride indexing; confirm against callers.
 */
void ck_fp32_to_fp16_2d(const float *src, uint16_t *dst, int rows, int cols,
                        int src_stride, int dst_stride) {
    if (!src || !dst || rows <= 0 || cols <= 0)
        return;
    for (int r = 0; r < rows; r++) {
        ck_fp32_to_fp16_row(src + (size_t)r * src_stride,
                            dst + (size_t)r * dst_stride,
                            cols);
    }
}
void ck_fp16_to_fp32_row(const uint16_t *src, float *dst, int n);

/*
 * Convert a 2D FP16 matrix to FP32 row by row with independent row strides.
 * NOTE(review): strides appear to be in elements, not bytes -- inferred
 * from the (size_t)r * dst_stride indexing; confirm against callers.
 */
void ck_fp16_to_fp32_2d(const uint16_t *src, float *dst, int rows, int cols,
                        int src_stride, int dst_stride) {
    if (!src || !dst || rows <= 0 || cols <= 0)
        return;
    for (int r = 0; r < rows; r++) {
        ck_fp16_to_fp32_row(src + (size_t)r * src_stride,
                            dst + (size_t)r * dst_stride,
                            cols);
    }
}
void ck_fp32_to_fp16_row(const float *src, uint16_t *dst, int n);

/*
 * Convert n FP32 values in `data` to FP16 in place: results are packed as
 * uint16_t into the front of the same buffer.
 * `scratch` must provide at least n * sizeof(uint16_t) bytes and must not
 * alias `data`.
 * NOTE(review): the middle of this definition is missing from the chunk;
 * the conversion call below is reconstructed -- confirm against the full
 * file.
 */
void ck_fp32_to_fp16_inplace(float *data, void *scratch, int n) {
    if (!data || !scratch || n <= 0)
        return;
    uint16_t *tmp = (uint16_t*)scratch;
    /* Convert into scratch first so reads from `data` finish before the
     * narrower FP16 results overwrite the same storage. */
    ck_fp32_to_fp16_row(data, tmp, n);
    uint16_t *dst = (uint16_t*)data;  /* reuse the float storage for FP16 */
    for (int i = 0; i < n; i++)
        dst[i] = tmp[i];
}
static uint16_t ck_fp32_to_fp16_scalar(float f);

/*
 * Multiply-add in FP32, result stored as FP16: dst[i] = a[i] * b[i] + c[i].
 * NOTE(review): the FMA paths round once (fused) while the mul+add
 * fallback and the scalar tail may round twice, so low-order bits can
 * differ across builds.
 */
void ck_fma_f32_to_f16(const float *a, const float *b, const float *c,
                       uint16_t *dst, int n) {
    if (!a || !b || !c || !dst || n <= 0)
        return;
#if defined(__AVX512F__)
    int i = 0;
    for (; i + 15 < n; i += 16) {
        __m512 va = _mm512_loadu_ps(a + i);
        __m512 vb = _mm512_loadu_ps(b + i);
        __m512 vc = _mm512_loadu_ps(c + i);
        __m512 vr = _mm512_fmadd_ps(va, vb, vc);
        __m256i vh = _mm512_cvtps_ph(vr, _MM_FROUND_TO_NEAREST_INT);
        _mm256_storeu_si256((__m256i*)(dst + i), vh);
    }
    for (; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(a[i] * b[i] + c[i]);
#elif defined(__AVX__) && defined(__F16C__)
    int i = 0;
    for (; i + 7 < n; i += 8) {
        __m256 va = _mm256_loadu_ps(a + i);
        __m256 vb = _mm256_loadu_ps(b + i);
        __m256 vc = _mm256_loadu_ps(c + i);
#if defined(__FMA__)
        __m256 vr = _mm256_fmadd_ps(va, vb, vc);
#else
        __m256 vr = _mm256_add_ps(_mm256_mul_ps(va, vb), vc);
#endif
        __m128i vh = _mm256_cvtps_ph(vr, _MM_FROUND_TO_NEAREST_INT);
        _mm_storeu_si128((__m128i*)(dst + i), vh);
    }
    for (; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(a[i] * b[i] + c[i]);
#else
    for (int i = 0; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(a[i] * b[i] + c[i]);
#endif
}
static uint16_t ck_fp32_to_fp16_scalar(float f);

/*
 * Scale an FP32 array and store the result as FP16: dst[i] = scale * src[i].
 * SIMD paths broadcast `scale` once; scalar tail handles the remainder.
 */
void ck_scale_f32_to_f16(const float *src, float scale, uint16_t *dst, int n) {
    if (!src || !dst || n <= 0)
        return;
#if defined(__AVX512F__)
    __m512 vs = _mm512_set1_ps(scale);
    int i = 0;
    for (; i + 15 < n; i += 16) {
        __m512 vx = _mm512_loadu_ps(src + i);
        __m512 vr = _mm512_mul_ps(vx, vs);
        __m256i vh = _mm512_cvtps_ph(vr, _MM_FROUND_TO_NEAREST_INT);
        _mm256_storeu_si256((__m256i*)(dst + i), vh);
    }
    for (; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(src[i] * scale);
#elif defined(__AVX__) && defined(__F16C__)
    __m256 vs = _mm256_set1_ps(scale);
    int i = 0;
    for (; i + 7 < n; i += 8) {
        __m256 vx = _mm256_loadu_ps(src + i);
        __m256 vr = _mm256_mul_ps(vx, vs);
        __m128i vh = _mm256_cvtps_ph(vr, _MM_FROUND_TO_NEAREST_INT);
        _mm_storeu_si128((__m128i*)(dst + i), vh);
    }
    for (; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(src[i] * scale);
#else
    for (int i = 0; i < n; i++)
        dst[i] = ck_fp32_to_fp16_scalar(src[i] * scale);
#endif
}
void ck_scale_f32_to_f16(const float *src, float scale, uint16_t *dst, int n)
Scale FP32 array and store as FP16: dst = scale * src.
void ck_fp16_to_fp32_2d(const uint16_t *src, float *dst, int rows, int cols, int src_stride, int dst_stride)
Convert 2D FP16 matrix to FP32 with strided access.
void ck_fp32_to_fp16_2d(const float *src, uint16_t *dst, int rows, int cols, int src_stride, int dst_stride)
Convert 2D FP32 matrix to FP16 with strided access.
void ck_fma_f32_to_f16(const float *a, const float *b, const float *c, uint16_t *dst, int n)
FMA in FP32, store result as FP16: dst = a * b + c.
void ck_fp32_to_fp16_inplace(float *data, void *scratch, int n)
Convert FP32 to FP16 in-place using scratch buffer.
void ck_fp16_to_fp32_row(const uint16_t *src, float *dst, int n)
Convert FP16 row to FP32 (auto-selects the best available implementation).
static uint16_t ck_fp32_to_fp16_scalar(float f)
Convert a single FP32 value to FP16 (portable scalar fallback).
static float ck_fp16_to_fp32_scalar(uint16_t h)
Convert a single FP16 value to FP32 (portable scalar fallback).
void ck_fp32_to_fp16_row(const float *src, uint16_t *dst, int n)
Convert FP32 row to FP16 (auto-selects the best available implementation).