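/* jm_simd.h (inc/jm_simd.h)
 *
 * Compile-time SIMD abstraction layer: selects an AVX-512F, AVX2+FMA or
 * scalar implementation of a common set of vector macros, based on the
 * target-feature macros predefined by the compiler.
 */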

#ifndef JM_SIMD_H
#define JM_SIMD_H

#include <stddef.h> /* size_t */

/* Reuse JM_RESTRICT from jm_perf.h if available; otherwise define.
 *
 * JM_RESTRICT expands to the compiler's spelling of the C99 `restrict`
 * qualifier.  GCC/Clang and MSVC accept their extension keywords in every
 * language mode (including C++).  Other compilers get the standard keyword
 * only when compiling C99 or later C; in C89 or C++ mode `restrict` is not
 * a keyword, so the macro expands to nothing.  That is always safe:
 * `restrict` only licenses optimisations and never changes the meaning of
 * a program that is valid with it. */
#ifndef JM_RESTRICT
#if defined(__GNUC__) || defined(__clang__)
#define JM_RESTRICT __restrict__
#elif defined(_MSC_VER)
#define JM_RESTRICT __restrict
#elif !defined(__cplusplus) && defined(__STDC_VERSION__)                      \
    && __STDC_VERSION__ >= 199901L
#define JM_RESTRICT restrict
#else
#define JM_RESTRICT /* restrict unsupported here: no-op */
#endif
#endif

/* Pull in x86 intrinsic headers. */
#if (defined(__x86_64__) || defined(_M_X64) || defined(__i386__)              \
     || defined(_M_IX86))
#ifndef _IMMINTRIN_H_INCLUDED
#include <immintrin.h>
#endif
#endif

/* ════════════════════════════════════════════════════════════════════
 * ISA tier selection.  Exactly one of the three tiers below is compiled,
 * picked from the target-feature macros the compiler predefines:
 *
 *   Tier 1  AVX-512F     16 float / 8 double lanes
 *   Tier 2  AVX2 + FMA    8 float / 4 double lanes
 *   Tier 3  scalar        1 lane (the compiler may still auto-vectorise)
 *
 * Every tier exposes the same vocabulary, with xx = F32 or F64:
 *
 *   JM_SIMD_WIDTH_F32, JM_SIMD_WIDTH_F64  lanes per vector
 *   JM_SIMD_WIDTH                         equal to JM_SIMD_WIDTH_F32
 *   JM_VEC_F32, JM_VEC_F64                vector (or scalar) value type
 *   JM_ZERO_xx(), JM_SPLAT_xx(s)          all-zero / broadcast value
 *   JM_LOAD_xx(p), JM_STORE_xx(p, v)      memory access (unaligned
 *                                         intrinsics in the SIMD tiers)
 *   JM_ADD_xx(a, b), JM_MUL_xx(a, b)      lane-wise arithmetic
 *   JM_FMA_xx(acc, a, b)                  acc = acc + a * b, in place
 *   JM_MAC_xx(acc, p, s)                  acc = acc + load(p) * splat(s)
 *   JM_HSUM_xx(v)                         sum of all lanes as a scalar
 * ════════════════════════════════════════════════════════════════════ */
#if defined(__AVX512F__)
/* ── Tier 1: AVX-512F ─────────────────────────────────────────────── */

#define JM_SIMD_WIDTH_F32 16
#define JM_SIMD_WIDTH_F64 8
#define JM_SIMD_WIDTH JM_SIMD_WIDTH_F32

typedef __m512 JM_VEC_F32;
typedef __m512d JM_VEC_F64;

/* single precision */
#define JM_ZERO_F32() _mm512_setzero_ps ()
#define JM_SPLAT_F32(scalar) _mm512_set1_ps (scalar)
#define JM_LOAD_F32(src) _mm512_loadu_ps (src)
#define JM_STORE_F32(dst, vec) _mm512_storeu_ps (dst, vec)
#define JM_ADD_F32(lhs, rhs) _mm512_add_ps (lhs, rhs)
#define JM_MUL_F32(lhs, rhs) _mm512_mul_ps (lhs, rhs)
#define JM_FMA_F32(acc, x, y) ((acc) = _mm512_fmadd_ps (x, y, acc))
#define JM_HSUM_F32(vec) _mm512_reduce_add_ps (vec)

/* double precision */
#define JM_ZERO_F64() _mm512_setzero_pd ()
#define JM_SPLAT_F64(scalar) _mm512_set1_pd (scalar)
#define JM_LOAD_F64(src) _mm512_loadu_pd (src)
#define JM_STORE_F64(dst, vec) _mm512_storeu_pd (dst, vec)
#define JM_ADD_F64(lhs, rhs) _mm512_add_pd (lhs, rhs)
#define JM_MUL_F64(lhs, rhs) _mm512_mul_pd (lhs, rhs)
#define JM_FMA_F64(acc, x, y) ((acc) = _mm512_fmadd_pd (x, y, acc))
#define JM_HSUM_F64(vec) _mm512_reduce_add_pd (vec)

#elif defined(__AVX2__) && defined(__FMA__)
/* ── Tier 2: AVX2 + FMA ───────────────────────────────────────────── */

#define JM_SIMD_WIDTH_F32 8
#define JM_SIMD_WIDTH_F64 4
#define JM_SIMD_WIDTH JM_SIMD_WIDTH_F32

typedef __m256 JM_VEC_F32;
typedef __m256d JM_VEC_F64;

/* Horizontal sums.  Both fold the upper 128-bit half onto the lower one
 * and finish inside an SSE register with SSE3/SSE2 shuffles (SSE3 is
 * implied by AVX2).  The additions happen in the same order and with the
 * same operand order as a two-step _mm_hadd reduction:
 *   float:  ((v0+v4) + (v1+v5)) + ((v2+v6) + (v3+v7))
 *   double: (v0+v2) + (v1+v3)                                          */
static inline float
_jm_hsum256_f32 (__m256 v)
{
  __m128 quad = _mm_add_ps (_mm256_castps256_ps128 (v),
                            _mm256_extractf128_ps (v, 1));
  __m128 odd = _mm_movehdup_ps (quad);       /* [q1, q1, q3, q3]       */
  __m128 pairs = _mm_add_ps (quad, odd);     /* [q0+q1, _, q2+q3, _]   */
  __m128 upper = _mm_movehl_ps (odd, pairs); /* [q2+q3, _, _, _]       */
  return _mm_cvtss_f32 (_mm_add_ss (pairs, upper));
}
static inline double
_jm_hsum256_f64 (__m256d v)
{
  __m128d pair = _mm_add_pd (_mm256_castpd256_pd128 (v),
                             _mm256_extractf128_pd (v, 1));
  __m128d high = _mm_unpackhi_pd (pair, pair); /* [p1, p1]             */
  return _mm_cvtsd_f64 (_mm_add_sd (pair, high));
}

/* single precision */
#define JM_ZERO_F32() _mm256_setzero_ps ()
#define JM_SPLAT_F32(scalar) _mm256_set1_ps (scalar)
#define JM_LOAD_F32(src) _mm256_loadu_ps (src)
#define JM_STORE_F32(dst, vec) _mm256_storeu_ps (dst, vec)
#define JM_ADD_F32(lhs, rhs) _mm256_add_ps (lhs, rhs)
#define JM_MUL_F32(lhs, rhs) _mm256_mul_ps (lhs, rhs)
#define JM_FMA_F32(acc, x, y) ((acc) = _mm256_fmadd_ps (x, y, acc))
#define JM_HSUM_F32(vec) _jm_hsum256_f32 (vec)

/* double precision */
#define JM_ZERO_F64() _mm256_setzero_pd ()
#define JM_SPLAT_F64(scalar) _mm256_set1_pd (scalar)
#define JM_LOAD_F64(src) _mm256_loadu_pd (src)
#define JM_STORE_F64(dst, vec) _mm256_storeu_pd (dst, vec)
#define JM_ADD_F64(lhs, rhs) _mm256_add_pd (lhs, rhs)
#define JM_MUL_F64(lhs, rhs) _mm256_mul_pd (lhs, rhs)
#define JM_FMA_F64(acc, x, y) ((acc) = _mm256_fmadd_pd (x, y, acc))
#define JM_HSUM_F64(vec) _jm_hsum256_f64 (vec)

#else
/* ── Tier 3: portable scalar ──────────────────────────────────────────
 * One lane per "vector"; JM_FMA_* is a plain multiply then add here,
 * not a fused operation. */

#define JM_SIMD_WIDTH_F32 1
#define JM_SIMD_WIDTH_F64 1
#define JM_SIMD_WIDTH 1

typedef float JM_VEC_F32;
typedef double JM_VEC_F64;

/* single precision */
#define JM_ZERO_F32() (0.0f)
#define JM_SPLAT_F32(scalar) ((float)(scalar))
#define JM_LOAD_F32(src) (*(const float *)(src))
#define JM_STORE_F32(dst, vec) (*(float *)(dst) = (vec))
#define JM_ADD_F32(lhs, rhs) ((lhs) + (rhs))
#define JM_MUL_F32(lhs, rhs) ((lhs) * (rhs))
#define JM_FMA_F32(acc, x, y) ((acc) += (x) * (y))
#define JM_HSUM_F32(vec) ((float)(vec))

/* double precision */
#define JM_ZERO_F64() (0.0)
#define JM_SPLAT_F64(scalar) ((double)(scalar))
#define JM_LOAD_F64(src) (*(const double *)(src))
#define JM_STORE_F64(dst, vec) (*(double *)(dst) = (vec))
#define JM_ADD_F64(lhs, rhs) ((lhs) + (rhs))
#define JM_MUL_F64(lhs, rhs) ((lhs) * (rhs))
#define JM_FMA_F64(acc, x, y) ((acc) += (x) * (y))
#define JM_HSUM_F64(vec) ((double)(vec))

#endif /* ISA tiers */

/* Multiply-accumulate straight from memory: acc += load(src) * splat(k).
 * Identical in every tier, so it is written once on top of the tier's
 * JM_FMA / JM_LOAD / JM_SPLAT. */
#define JM_MAC_F32(acc, src, k)                                               \
  JM_FMA_F32 (acc, JM_LOAD_F32 (src), JM_SPLAT_F32 (k))
#define JM_MAC_F64(acc, src, k)                                               \
  JM_FMA_F64 (acc, JM_LOAD_F64 (src), JM_SPLAT_F64 (k))

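/* ════════════════════════════════════════════════════════════════════
 * Composite reductions — built on the tier macros above, so they are
 * ISA-portable: the widest available tier vectorises them, and the
 * scalar tier still compiles (and may auto-vectorise). Results agree
 * across tiers only up to floating-point rounding: the SIMD tiers split
 * the sum across lanes (a different association order) and use fused
 * multiply-add, while the scalar tier accumulates sequentially.
 * ════════════════════════════════════════════════════════════════════ */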

/* JM_SUMSQ_F32 / JM_SUMSQ_F64 — sum of squares of a contiguous array.
 *
 *   dst  lvalue that receives the result (float / double respectively)
 *   ptr  pointer to the first element (const float * / const double *)
 *   n    element count, converted to size_t
 *
 * Each argument is evaluated exactly once.  The first n - n % width
 * elements are squared and accumulated JM_SIMD_WIDTH_xx lanes at a time
 * via JM_FMA_xx.  The lanes are then folded with JM_HSUM_xx, and the
 * remaining n % width elements are added in scalar code.  For n == 0,
 * dst receives 0 and ptr is never dereferenced.
 *
 * JM_SUMSQ_IMPL_ is the shared body: T is the element type and SFX the
 * tier-macro suffix (F32 or F64), pasted with ## onto the tier macro
 * names (JM_VEC_, JM_ZERO_, JM_LOAD_, JM_FMA_, JM_HSUM_, JM_SIMD_WIDTH_).
 */
#define JM_SUMSQ_IMPL_(dst, ptr, n, T, SFX)                                   \
  do                                                                          \
    {                                                                         \
      const T *jm__p = (ptr);                                                 \
      size_t jm__n = (size_t)(n);                                             \
      size_t jm__w = (size_t)JM_SIMD_WIDTH_##SFX;                             \
      /* largest multiple of the lane count that fits in n */                 \
      size_t jm__nv = jm__n - jm__n % jm__w;                                  \
      JM_VEC_##SFX jm__acc = JM_ZERO_##SFX ();                                \
      for (size_t jm__i = 0; jm__i < jm__nv; jm__i += jm__w)                  \
        {                                                                     \
          JM_VEC_##SFX jm__v = JM_LOAD_##SFX (jm__p + jm__i);                 \
          JM_FMA_##SFX (jm__acc, jm__v, jm__v);                               \
        }                                                                     \
      T jm__s = JM_HSUM_##SFX (jm__acc);                                      \
      /* scalar tail: the n % width elements the vector loop skipped */       \
      for (size_t jm__i = jm__nv; jm__i < jm__n; jm__i++)                     \
        jm__s += jm__p[jm__i] * jm__p[jm__i];                                 \
      (dst) = jm__s;                                                          \
    }                                                                         \
  while (0)

#define JM_SUMSQ_F32(dst, ptr, n) JM_SUMSQ_IMPL_ (dst, ptr, n, float, F32)
#define JM_SUMSQ_F64(dst, ptr, n) JM_SUMSQ_IMPL_ (dst, ptr, n, double, F64)

#endif /* JM_SIMD_H */