Add strstr-avx2 based on strstr-avx512

Message ID 20240223072813.95327-1-tirtajames45@gmail.com
State Changes Requested
Headers
Series Add strstr-avx2 based on strstr-avx512 |

Checks

Context Check Description
redhat-pt-bot/TryBot-apply_patch success Patch applied to master at the time it was sent
redhat-pt-bot/TryBot-32bit success Build for i686
linaro-tcwg-bot/tcwg_glibc_build--master-arm success Testing passed
linaro-tcwg-bot/tcwg_glibc_build--master-aarch64 success Testing passed
linaro-tcwg-bot/tcwg_glibc_check--master-arm success Testing passed
linaro-tcwg-bot/tcwg_glibc_check--master-aarch64 success Testing passed

Commit Message

James Tirta Halim Feb. 23, 2024, 7:28 a.m. UTC
  Create a unified implementation for strstr-avx2 and strstr-avx512 based
on the existing strstr-avx512 (no changes were made to the
implementation). strstr-avx2 emulates the AVX-512 instructions using
equivalent AVX2 or generic instructions.

basic_strstr twoway_strstr __strstr_avx2 __strstr_avx512 __strstr_sse2_unaligned __strstr_generic
average:
211775 32055.7 4961.68 3725.31 14687 17268.7
total:
1.12876e+08 1.70857e+07 2.64458e+06 1.98559e+06 7.82818e+06 9.2042e+06

Passes test-strstr.

---
 sysdeps/x86_64/multiarch/Makefile          |   2 +
 sysdeps/x86_64/multiarch/ifunc-impl-list.c |   4 +
 sysdeps/x86_64/multiarch/strstr-avx-base.h | 268 +++++++++++++++++++++
 sysdeps/x86_64/multiarch/strstr-avx2.c     |  19 ++
 sysdeps/x86_64/multiarch/strstr-avx512.c   | 221 ++---------------
 5 files changed, 312 insertions(+), 202 deletions(-)
 create mode 100644 sysdeps/x86_64/multiarch/strstr-avx-base.h
 create mode 100644 sysdeps/x86_64/multiarch/strstr-avx2.c
  

Comments

Noah Goldstein Feb. 23, 2024, 3:58 p.m. UTC | #1
On Fri, Feb 23, 2024 at 1:28 AM James Tirta Halim
<tirtajames45@gmail.com> wrote:
>
> Create a unified implementation for strstr-avx2 and strstr-avx512 based
> on the existing strstr-avx512 (no changes were made in the
> implementation). strstr-avx2 implements avx512 instructions using
> equivalent avx2 or generic instructions.
>
> basic_strstr twoway_strstr __strstr_avx2 __strstr_avx512 __strstr_sse2_unaligned __strstr_generic
> average:
> 211775 32055.7 4961.68 3725.31 14687 17268.7
> total:
> 1.12876e+08 1.70857e+07 2.64458e+06 1.98559e+06 7.82818e+06 9.2042e+06
>
> Passes test-strstr.
>
> ---
>  sysdeps/x86_64/multiarch/Makefile          |   2 +
>  sysdeps/x86_64/multiarch/ifunc-impl-list.c |   4 +
>  sysdeps/x86_64/multiarch/strstr-avx-base.h | 268 +++++++++++++++++++++
>  sysdeps/x86_64/multiarch/strstr-avx2.c     |  19 ++
>  sysdeps/x86_64/multiarch/strstr-avx512.c   | 221 ++---------------
>  5 files changed, 312 insertions(+), 202 deletions(-)
>  create mode 100644 sysdeps/x86_64/multiarch/strstr-avx-base.h
>  create mode 100644 sysdeps/x86_64/multiarch/strstr-avx2.c
>
> diff --git a/sysdeps/x86_64/multiarch/Makefile b/sysdeps/x86_64/multiarch/Makefile
> index d3d2270394..1cf74a13a2 100644
> --- a/sysdeps/x86_64/multiarch/Makefile
> +++ b/sysdeps/x86_64/multiarch/Makefile
> @@ -117,6 +117,7 @@ sysdep_routines += \
>    strrchr-evex512 \
>    strrchr-sse2 \
>    strspn-sse4 \
> +  strstr-avx2 \
>    strstr-avx512 \
>    strstr-sse2-unaligned \
>    varshift \
> @@ -126,6 +127,7 @@ CFLAGS-strcspn-sse4.c += -msse4
>  CFLAGS-strpbrk-sse4.c += -msse4
>  CFLAGS-strspn-sse4.c += -msse4
>
> +CFLAGS-strstr-avx2.c += -Wno-error=stringop-truncation -Wno-stringop-truncation -mavx2 -mbmi -mbmi2 -O3
>  CFLAGS-strstr-avx512.c += -mavx512f -mavx512vl -mavx512dq -mavx512bw -mbmi -mbmi2 -O3
>  endif
>
> diff --git a/sysdeps/x86_64/multiarch/ifunc-impl-list.c b/sysdeps/x86_64/multiarch/ifunc-impl-list.c
> index c4a21d4b7c..7b651c7a9c 100644
> --- a/sysdeps/x86_64/multiarch/ifunc-impl-list.c
> +++ b/sysdeps/x86_64/multiarch/ifunc-impl-list.c
> @@ -790,6 +790,10 @@ __libc_ifunc_impl_list (const char *name, struct libc_ifunc_impl *array,
>
>    /* Support sysdeps/x86_64/multiarch/strstr.c.  */
>    IFUNC_IMPL (i, name, strstr,
> +              IFUNC_IMPL_ADD (array, i, strstr,
> +                              (CPU_FEATURE_USABLE (AVX2)
> +                               && CPU_FEATURE_USABLE (BMI2)),
> +                              __strstr_avx2)
>                IFUNC_IMPL_ADD (array, i, strstr,
>                                (CPU_FEATURE_USABLE (AVX512VL)
>                                 && CPU_FEATURE_USABLE (AVX512BW)
> diff --git a/sysdeps/x86_64/multiarch/strstr-avx-base.h b/sysdeps/x86_64/multiarch/strstr-avx-base.h
> new file mode 100644
> index 0000000000..e9f736606e
> --- /dev/null
> +++ b/sysdeps/x86_64/multiarch/strstr-avx-base.h
> @@ -0,0 +1,268 @@
> +/* Copyright (C) 2022-2024 Free Software Foundation, Inc.
> +   This file is part of the GNU C Library.
> +
> +   The GNU C Library is free software; you can redistribute it and/or
> +   modify it under the terms of the GNU Lesser General Public
> +   License as published by the Free Software Foundation; either
> +   version 2.1 of the License, or (at your option) any later version.
> +
> +   The GNU C Library is distributed in the hope that it will be useful,
> +   but WITHOUT ANY WARRANTY; without even the implied warranty of
> +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
> +   Lesser General Public License for more details.
> +
> +   You should have received a copy of the GNU Lesser General Public
> +   License along with the GNU C Library; if not, see
> +   <https://www.gnu.org/licenses/>.  */
> +
> +#include <immintrin.h>
> +#include <inttypes.h>
> +#include <stdbool.h>
> +#include <string.h>
> +
> +#define VEC_SIZE sizeof (VEC)
> +#define ONES ((MASK) -1)
> +#define ONE ((MASK) 0x1)
> +#define PAGE_SIZE 4096
> +#define CVTMASK(...) (MASK) (__VA_ARGS__)
> +#define KSHIFTRI(x, y) ((x) >> (y))
> +#define KAND_MASK(x, y) ((x) & (y))
> +
> +#ifndef FUNC_NAME
> +#  define FUNC_NAME __strstr_avx2
> +#endif
> +#ifndef VEC
> +#  define VEC __m256i
> +#endif
> +#ifndef MASK
> +#  define MASK uint32_t
> +#endif
> +#ifndef LOAD
> +#  define LOAD _mm256_load_si256
> +#endif
> +#ifndef LOADU
> +#  define LOADU _mm256_loadu_si256
> +#endif
> +#ifndef CMPEQ8_MASK
> +#  define CMPEQ8_MASK(x, y)                                                   \
> +    (MASK) _mm256_movemask_epi8 (_mm256_cmpeq_epi8 (x, y))
> +#endif
> +#ifndef MASK_CMPNEQ8_MASK
> +#  define MASK_CMPNEQ8_MASK(m, a, b) ((~CMPEQ8_MASK (a, b)) & (m))
> +#endif
> +#ifndef SETONE8
> +#  define SETONE8 _mm256_set1_epi8
> +#endif
> +#ifndef SETZERO
> +#  define SETZERO() _mm256_setzero_si256 ()
> +#endif
> +#ifndef TESTN8_MASK
> +#  define TESTN8_MASK(x, y)                                                   \
> +    _mm256_movemask_epi8 (_mm256_cmpeq_epi8 (x, SETZERO ()))
> +#endif
> +#ifndef MASK_TESTN8_MASK
> +#  define MASK_TESTN8_MASK(m, a, b) ((MASK) TESTN8_MASK (a, a) & (m))
> +#endif
> +#ifndef TZCNT
> +#  define TZCNT _tzcnt_u32
> +#endif
> +#ifndef BLSR
> +#  define BLSR _blsr_u32
> +#endif
> +#ifndef BZHI
> +#  define BZHI _bzhi_u32
> +#endif
> +#ifndef MASKZ_LOADU8
> +#  define MASKZ_LOADU8(m, x) maskz_loadu8 (x)
> +static inline VEC __attribute__ ((always_inline)) maskz_loadu8 (const void *x)
> +{
> +  VEC ret;
> +  strncpy ((char *) &ret, (const char *) x, VEC_SIZE);
> +  return ret;
> +}
> +#endif
> +
> +/*
> + Returns the index of the first edge within the needle, returns 0 if no edge
> + is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg'
> + */
> +static inline size_t __attribute__ ((always_inline))
> +find_edge_in_needle (const char *ne)
> +{
> +  size_t ind = 0;
> +  while (ne[ind + 1] != '\0')
> +    {
> +      if (ne[ind] != ne[ind + 1])
> +       return ind;
> +      else
> +       ind = ind + 1;
> +    }
> +  return 0;
> +}
> +
> +/*
> + Compare needle with hs byte by byte at specified location
> + */
> +static inline bool __attribute__ ((always_inline))
> +verify_string_match (const char *hay, const size_t hay_index, const char *ne,
> +                    size_t ind)
> +{
> +  while (ne[ind] != '\0')
> +    {
> +      if (ne[ind] != hay[hay_index + ind])
> +       return false;
> +      ind = ind + 1;
> +    }
> +  return true;
> +}
> +
> +/*
> + Compare needle with hs at specified location. The first VEC_SIZE bytes are
> + compared using a ZMM register.
> + */
> +static inline bool __attribute__ ((always_inline))
> +verify_string_match_vector (const char *hay, const size_t hay_index,
> +                           const char *ne, const MASK ned_mask,
> +                           const VEC ned_zmm)
> +{
> +  /* check first VEC_SIZE bytes using zmm and then scalar */
> +  VEC hay_zmm = LOADU ((const VEC *) (hay + hay_index)); // safe to do so
> +  MASK match = MASK_CMPNEQ8_MASK (ned_mask, hay_zmm, ned_zmm);
> +  if (match != 0x0) // failed the first few chars
> +    return false;
> +  else if (ned_mask == ONES)
> +    return verify_string_match (hay, hay_index, ne, VEC_SIZE);
> +  return true;
> +}
> +
> +char *
> +FUNC_NAME (const char *hs, const char *ne)
> +{
> +  char first = ne[0];
> +  if (__glibc_unlikely (first == '\0'))
> +    return (char *) hs;
> +  if (ne[1] == '\0')
> +    return (char *) strchr (hs, ne[0]);
> +
> +  size_t edge = find_edge_in_needle (ne);
> +
> +  /* ensure hs is as long as the pos of edge in needle */
> +  for (unsigned int ii = 0; ii < edge; ++ii)
> +    {
> +      if (__glibc_unlikely (hs[ii] == '\0'))
> +       return NULL;
> +    }
> +
> +  /*
> +   Load VEC_SIZE bytes of the needle and save it to a zmm register
> +   Read one cache line at a time to avoid loading across a page boundary
> +   */
> +  MASK ned_load_mask
> +      = BZHI (ONES, VEC_SIZE - ((uintptr_t) (ne) & (VEC_SIZE - 1)));
> +  VEC ned_zmm = MASKZ_LOADU8 (ned_load_mask, (const VEC *) ne);
> +  MASK ned_nullmask = MASK_TESTN8_MASK (ned_load_mask, ned_zmm, ned_zmm);
> +
> +  if (__glibc_unlikely (ned_nullmask == 0x0))
> +    {
> +      ned_zmm = LOADU ((const VEC *) ne);
> +      ned_nullmask = TESTN8_MASK (ned_zmm, ned_zmm);
> +      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE);
> +      if (ned_nullmask != 0x0)
> +       ned_load_mask = ned_load_mask >> 1;
> +    }
> +  else
> +    {
> +      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE);
> +      ned_load_mask = ned_load_mask >> 1;
> +    }
> +  const VEC ned0 = SETONE8 (ne[edge]);
> +  const VEC ned1 = SETONE8 (ne[edge + 1]);
> +
> +  /*
> +   Read the bytes of hs in the current cache line
> +   */
> +  size_t hay_index = edge;
> +  MASK loadmask = BZHI (
> +      ONES, VEC_SIZE - ((uintptr_t) (hs + hay_index) & (VEC_SIZE - 1)));
> +  /* First load is a partial cache line */
> +  VEC hay0 = MASKZ_LOADU8 (loadmask, (const VEC *) (hs + hay_index));
> +  /* Search for NULL and compare only till null char */
> +  MASK nullmask = CVTMASK (MASK_TESTN8_MASK (loadmask, hay0, hay0));
> +  MASK cmpmask = nullmask ^ (nullmask - ONE);
> +  cmpmask = cmpmask & CVTMASK (loadmask);
> +  /* Search for the 2 characters of needle */
> +  MASK k0 = CMPEQ8_MASK (hay0, ned0);
> +  MASK k1 = CMPEQ8_MASK (hay0, ned1);
> +  k1 = KSHIFTRI (k1, 1);
> +  /* k2 masks tell us if both chars from needle match */
> +  MASK k2 = CVTMASK (KAND_MASK (k0, k1)) & cmpmask;
> +  /* For every match, search for the entire needle for a full match */
> +  while (k2)
> +    {
> +      MASK bitcount = TZCNT (k2);
> +      k2 = BLSR (k2);
> +      size_t match_pos = hay_index + bitcount - edge;
> +      if (((uintptr_t) (hs + match_pos) & (PAGE_SIZE - 1))
> +         < PAGE_SIZE - 1 - VEC_SIZE)
> +       {
> +         /*
> +          * Use vector compare as long as you are not crossing a page
> +          */
> +         if (verify_string_match_vector (hs, match_pos, ne, ned_load_mask,
> +                                         ned_zmm))
> +           return (char *) hs + match_pos;
> +       }
> +      else
> +       {
> +         if (verify_string_match (hs, match_pos, ne, 0))
> +           return (char *) hs + match_pos;
> +       }
> +    }
> +  /* We haven't checked for potential match at the last char yet */
> +  hs = (const char *) (((uintptr_t) (hs + hay_index) | (VEC_SIZE - 1)));
> +  hay_index = 0;
> +
> +  /*
> +   Loop over one cache line at a time to prevent reading over page
> +   boundary
> +   */
> +  VEC hay1;
> +  while (nullmask == 0)
> +    {
> +      hay0 = LOADU ((const VEC *) (hs + hay_index));
> +      hay1 = LOAD (
> +         (const VEC *) (hs + hay_index + 1)); // Always VEC_SIZE byte aligned
> +      nullmask = CVTMASK (TESTN8_MASK (hay1, hay1));
> +      /* Compare only till null char */
> +      cmpmask = nullmask ^ (nullmask - ONE);
> +      k0 = CMPEQ8_MASK (hay0, ned0);
> +      k1 = CMPEQ8_MASK (hay1, ned1);
> +      /* k2 masks tell us if both chars from needle match */
> +      k2 = CVTMASK (KAND_MASK (k0, k1)) & cmpmask;
> +      /* For every match, compare full strings for potential match */
> +      while (k2)
> +       {
> +         MASK bitcount = TZCNT (k2);
> +         k2 = BLSR (k2);
> +         size_t match_pos = hay_index + bitcount - edge;
> +         if (((uintptr_t) (hs + match_pos) & (PAGE_SIZE - 1))
> +             < PAGE_SIZE - 1 - VEC_SIZE)
> +           {
> +             /*
> +              * Use vector compare as long as you are not crossing a page
> +              */
> +             if (verify_string_match_vector (hs, match_pos, ne, ned_load_mask,
> +                                             ned_zmm))
> +               return (char *) hs + match_pos;
> +           }
> +         else
> +           {
> +             /* Compare byte by byte */
> +             if (verify_string_match (hs, match_pos, ne, 0))
> +               return (char *) hs + match_pos;
> +           }
> +       }
> +      hay_index += VEC_SIZE;
> +    }
> +  return NULL;
> +}
> diff --git a/sysdeps/x86_64/multiarch/strstr-avx2.c b/sysdeps/x86_64/multiarch/strstr-avx2.c
> new file mode 100644
> index 0000000000..e86ffd160f
> --- /dev/null
> +++ b/sysdeps/x86_64/multiarch/strstr-avx2.c
> @@ -0,0 +1,19 @@
> +/* Copyright (C) 2022-2024 Free Software Foundation, Inc.
> +   This file is part of the GNU C Library.
> +
> +   The GNU C Library is free software; you can redistribute it and/or
> +   modify it under the terms of the GNU Lesser General Public
> +   License as published by the Free Software Foundation; either
> +   version 2.1 of the License, or (at your option) any later version.
> +
> +   The GNU C Library is distributed in the hope that it will be useful,
> +   but WITHOUT ANY WARRANTY; without even the implied warranty of
> +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
> +   Lesser General Public License for more details.
> +
> +   You should have received a copy of the GNU Lesser General Public
> +   License along with the GNU C Library; if not, see
> +   <https://www.gnu.org/licenses/>.  */
> +
> +#define FUNC_NAME __strstr_avx2
> +#include "strstr-avx-base.h"
> diff --git a/sysdeps/x86_64/multiarch/strstr-avx512.c b/sysdeps/x86_64/multiarch/strstr-avx512.c
> index 3ac53accbd..5eb69043b9 100644
> --- a/sysdeps/x86_64/multiarch/strstr-avx512.c
> +++ b/sysdeps/x86_64/multiarch/strstr-avx512.c
> @@ -1,5 +1,4 @@
> -/* strstr optimized with 512-bit AVX-512 instructions
> -   Copyright (C) 2022-2024 Free Software Foundation, Inc.
> +/* Copyright (C) 2022-2024 Free Software Foundation, Inc.
>     This file is part of the GNU C Library.
>
>     The GNU C Library is free software; you can redistribute it and/or
> @@ -16,203 +15,21 @@
>     License along with the GNU C Library; if not, see
>     <https://www.gnu.org/licenses/>.  */
>
> -#include <immintrin.h>
> -#include <inttypes.h>
> -#include <stdbool.h>
> -#include <string.h>
> -
> -#define FULL_MMASK64 0xffffffffffffffff
> -#define ONE_64BIT 0x1ull
> -#define ZMM_SIZE_IN_BYTES 64
> -#define PAGESIZE 4096
> -
> -#define cvtmask64_u64(...) (uint64_t) (__VA_ARGS__)
> -#define kshiftri_mask64(x, y) ((x) >> (y))
> -#define kand_mask64(x, y) ((x) & (y))
> -
> -/*
> - Returns the index of the first edge within the needle, returns 0 if no edge
> - is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg'
> - */
> -static inline size_t
> -find_edge_in_needle (const char *ned)
> -{
> -  size_t ind = 0;
> -  while (ned[ind + 1] != '\0')
> -    {
> -      if (ned[ind] != ned[ind + 1])
> -        return ind;
> -      else
> -        ind = ind + 1;
> -    }
> -  return 0;
> -}
> -
> -/*
> - Compare needle with haystack byte by byte at specified location
> - */
> -static inline bool
> -verify_string_match (const char *hay, const size_t hay_index, const char *ned,
> -                     size_t ind)
> -{
> -  while (ned[ind] != '\0')
> -    {
> -      if (ned[ind] != hay[hay_index + ind])
> -        return false;
> -      ind = ind + 1;
> -    }
> -  return true;
> -}
> -
> -/*
> - Compare needle with haystack at specified location. The first 64 bytes are
> - compared using a ZMM register.
> - */
> -static inline bool
> -verify_string_match_avx512 (const char *hay, const size_t hay_index,
> -                            const char *ned, const __mmask64 ned_mask,
> -                            const __m512i ned_zmm)
> -{
> -  /* check first 64 bytes using zmm and then scalar */
> -  __m512i hay_zmm = _mm512_loadu_si512 (hay + hay_index); // safe to do so
> -  __mmask64 match = _mm512_mask_cmpneq_epi8_mask (ned_mask, hay_zmm, ned_zmm);
> -  if (match != 0x0) // failed the first few chars
> -    return false;
> -  else if (ned_mask == FULL_MMASK64)
> -    return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES);
> -  return true;
> -}
> -
> -char *
> -__strstr_avx512 (const char *haystack, const char *ned)
> -{
> -  char first = ned[0];
> -  if (first == '\0')
> -    return (char *)haystack;
> -  if (ned[1] == '\0')
> -    return (char *)strchr (haystack, ned[0]);
> -
> -  size_t edge = find_edge_in_needle (ned);
> -
> -  /* ensure haystack is as long as the pos of edge in needle */
> -  for (int ii = 0; ii < edge; ++ii)
> -    {
> -      if (haystack[ii] == '\0')
> -        return NULL;
> -    }
> -
> -  /*
> -   Load 64 bytes of the needle and save it to a zmm register
> -   Read one cache line at a time to avoid loading across a page boundary
> -   */
> -  __mmask64 ned_load_mask = _bzhi_u64 (
> -      FULL_MMASK64, 64 - ((uintptr_t) (ned) & 63));
> -  __m512i ned_zmm = _mm512_maskz_loadu_epi8 (ned_load_mask, ned);
> -  __mmask64 ned_nullmask
> -      = _mm512_mask_testn_epi8_mask (ned_load_mask, ned_zmm, ned_zmm);
> -
> -  if (__glibc_unlikely (ned_nullmask == 0x0))
> -    {
> -      ned_zmm = _mm512_loadu_si512 (ned);
> -      ned_nullmask = _mm512_testn_epi8_mask (ned_zmm, ned_zmm);
> -      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
> -      if (ned_nullmask != 0x0)
> -        ned_load_mask = ned_load_mask >> 1;
> -    }
> -  else
> -    {
> -      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
> -      ned_load_mask = ned_load_mask >> 1;
> -    }
> -  const __m512i ned0 = _mm512_set1_epi8 (ned[edge]);
> -  const __m512i ned1 = _mm512_set1_epi8 (ned[edge + 1]);
> -
> -  /*
> -   Read the bytes of haystack in the current cache line
> -   */
> -  size_t hay_index = edge;
> -  __mmask64 loadmask = _bzhi_u64 (
> -      FULL_MMASK64, 64 - ((uintptr_t) (haystack + hay_index) & 63));
> -  /* First load is a partial cache line */
> -  __m512i hay0 = _mm512_maskz_loadu_epi8 (loadmask, haystack + hay_index);
> -  /* Search for NULL and compare only till null char */
> -  uint64_t nullmask
> -      = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0));
> -  uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT);
> -  cmpmask = cmpmask & cvtmask64_u64 (loadmask);
> -  /* Search for the 2 characters of needle */
> -  __mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
> -  __mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1);
> -  k1 = kshiftri_mask64 (k1, 1);
> -  /* k2 masks tell us if both chars from needle match */
> -  uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
> -  /* For every match, search for the entire needle for a full match */
> -  while (k2)
> -    {
> -      uint64_t bitcount = _tzcnt_u64 (k2);
> -      k2 = _blsr_u64 (k2);
> -      size_t match_pos = hay_index + bitcount - edge;
> -      if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
> -          < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
> -        {
> -          /*
> -           * Use vector compare as long as you are not crossing a page
> -           */
> -          if (verify_string_match_avx512 (haystack, match_pos, ned,
> -                                          ned_load_mask, ned_zmm))
> -            return (char *)haystack + match_pos;
> -        }
> -      else
> -        {
> -          if (verify_string_match (haystack, match_pos, ned, 0))
> -            return (char *)haystack + match_pos;
> -        }
> -    }
> -  /* We haven't checked for potential match at the last char yet */
> -  haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63));
> -  hay_index = 0;
> -
> -  /*
> -   Loop over one cache line at a time to prevent reading over page
> -   boundary
> -   */
> -  __m512i hay1;
> -  while (nullmask == 0)
> -    {
> -      hay0 = _mm512_loadu_si512 (haystack + hay_index);
> -      hay1 = _mm512_load_si512 (haystack + hay_index
> -                                + 1); // Always 64 byte aligned
> -      nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1));
> -      /* Compare only till null char */
> -      cmpmask = nullmask ^ (nullmask - ONE_64BIT);
> -      k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
> -      k1 = _mm512_cmpeq_epi8_mask (hay1, ned1);
> -      /* k2 masks tell us if both chars from needle match */
> -      k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
> -      /* For every match, compare full strings for potential match */
> -      while (k2)
> -        {
> -          uint64_t bitcount = _tzcnt_u64 (k2);
> -          k2 = _blsr_u64 (k2);
> -          size_t match_pos = hay_index + bitcount - edge;
> -          if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
> -              < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
> -            {
> -              /*
> -               * Use vector compare as long as you are not crossing a page
> -               */
> -              if (verify_string_match_avx512 (haystack, match_pos, ned,
> -                                              ned_load_mask, ned_zmm))
> -                return (char *)haystack + match_pos;
> -            }
> -          else
> -            {
> -              /* Compare byte by byte */
> -              if (verify_string_match (haystack, match_pos, ned, 0))
> -                return (char *)haystack + match_pos;
> -            }
> -        }
> -      hay_index += ZMM_SIZE_IN_BYTES;
> -    }
> -  return NULL;
> -}
> +#define FUNC_NAME __strstr_avx512
> +#define VEC __m512i
> +#define MASK uint64_t
> +#define LOAD _mm512_load_si512
> +#define LOADU _mm512_loadu_si512
> +#define MOVEMASK8 _mm512_movemask_epi8
> +#define CMPEQ8_MASK _mm512_cmpeq_epi8_mask
> +#define MASK_CMPNEQ8_MASK _mm512_mask_cmpneq_epi8_mask
> +#define SETONE8 _mm512_set1_epi8
> +#define SETZERO _mm512_setzero_si512
> +#define TESTN8_MASK _mm512_testn_epi8_mask
> +#define MASK_TESTN8_MASK _mm512_mask_testn_epi8_mask
> +#define TZCNT _tzcnt_u64
> +#define BLSR _blsr_u64
> +#define BZHI _bzhi_u64
> +#define MASKZ_LOADU8 _mm512_maskz_loadu_epi8
> +
> +#include "strstr-avx-base.h"
> --
> 2.43.2
>

Thank you for the patch :)

Can you finish up your memmem patch before moving on to this one?
Some of the comments on your memmem patch also apply to this
one, so it would be nice to finish that one up first.
  
H.J. Lu Feb. 24, 2024, 4:31 p.m. UTC | #2
On Thu, Feb 22, 2024 at 11:28 PM James Tirta Halim
<tirtajames45@gmail.com> wrote:
>
> Create a unified implementation for strstr-avx2 and strstr-avx512 based
> on the existing strstr-avx512 (no changes were made in the
> implementation). strstr-avx2 implements avx512 instructions using
> equivalent avx2 or generic instructions.
>
> basic_strstr twoway_strstr __strstr_avx2 __strstr_avx512 __strstr_sse2_unaligned __strstr_generic
> average:
> 211775 32055.7 4961.68 3725.31 14687 17268.7
> total:
> 1.12876e+08 1.70857e+07 2.64458e+06 1.98559e+06 7.82818e+06 9.2042e+06
>
> Passes test-strstr.
>
> ---
>  sysdeps/x86_64/multiarch/Makefile          |   2 +
>  sysdeps/x86_64/multiarch/ifunc-impl-list.c |   4 +
>  sysdeps/x86_64/multiarch/strstr-avx-base.h | 268 +++++++++++++++++++++
>  sysdeps/x86_64/multiarch/strstr-avx2.c     |  19 ++
>  sysdeps/x86_64/multiarch/strstr-avx512.c   | 221 ++---------------
>  5 files changed, 312 insertions(+), 202 deletions(-)
>  create mode 100644 sysdeps/x86_64/multiarch/strstr-avx-base.h
>  create mode 100644 sysdeps/x86_64/multiarch/strstr-avx2.c
>
> diff --git a/sysdeps/x86_64/multiarch/Makefile b/sysdeps/x86_64/multiarch/Makefile
> index d3d2270394..1cf74a13a2 100644
> --- a/sysdeps/x86_64/multiarch/Makefile
> +++ b/sysdeps/x86_64/multiarch/Makefile
> @@ -117,6 +117,7 @@ sysdep_routines += \
>    strrchr-evex512 \
>    strrchr-sse2 \
>    strspn-sse4 \
> +  strstr-avx2 \
>    strstr-avx512 \
>    strstr-sse2-unaligned \
>    varshift \
> @@ -126,6 +127,7 @@ CFLAGS-strcspn-sse4.c += -msse4
>  CFLAGS-strpbrk-sse4.c += -msse4
>  CFLAGS-strspn-sse4.c += -msse4
>
> +CFLAGS-strstr-avx2.c += -Wno-error=stringop-truncation -Wno-stringop-truncation -mavx2 -mbmi -mbmi2 -O3
>  CFLAGS-strstr-avx512.c += -mavx512f -mavx512vl -mavx512dq -mavx512bw -mbmi -mbmi2 -O3
>  endif
>
> diff --git a/sysdeps/x86_64/multiarch/ifunc-impl-list.c b/sysdeps/x86_64/multiarch/ifunc-impl-list.c
> index c4a21d4b7c..7b651c7a9c 100644
> --- a/sysdeps/x86_64/multiarch/ifunc-impl-list.c
> +++ b/sysdeps/x86_64/multiarch/ifunc-impl-list.c
> @@ -790,6 +790,10 @@ __libc_ifunc_impl_list (const char *name, struct libc_ifunc_impl *array,
>
>    /* Support sysdeps/x86_64/multiarch/strstr.c.  */
>    IFUNC_IMPL (i, name, strstr,
> +              IFUNC_IMPL_ADD (array, i, strstr,
> +                              (CPU_FEATURE_USABLE (AVX2)
> +                               && CPU_FEATURE_USABLE (BMI2)),
> +                              __strstr_avx2)

Use X86_IFUNC_IMPL_ADD_V3

>                IFUNC_IMPL_ADD (array, i, strstr,
>                                (CPU_FEATURE_USABLE (AVX512VL)
>                                 && CPU_FEATURE_USABLE (AVX512BW)

Use X86_IFUNC_IMPL_ADD_V4

strstr.c also needs to be updated.
  

Patch

diff --git a/sysdeps/x86_64/multiarch/Makefile b/sysdeps/x86_64/multiarch/Makefile
index d3d2270394..1cf74a13a2 100644
--- a/sysdeps/x86_64/multiarch/Makefile
+++ b/sysdeps/x86_64/multiarch/Makefile
@@ -117,6 +117,7 @@  sysdep_routines += \
   strrchr-evex512 \
   strrchr-sse2 \
   strspn-sse4 \
+  strstr-avx2 \
   strstr-avx512 \
   strstr-sse2-unaligned \
   varshift \
@@ -126,6 +127,7 @@  CFLAGS-strcspn-sse4.c += -msse4
 CFLAGS-strpbrk-sse4.c += -msse4
 CFLAGS-strspn-sse4.c += -msse4
 
+CFLAGS-strstr-avx2.c += -Wno-error=stringop-truncation -Wno-stringop-truncation -mavx2 -mbmi -mbmi2 -O3
 CFLAGS-strstr-avx512.c += -mavx512f -mavx512vl -mavx512dq -mavx512bw -mbmi -mbmi2 -O3
 endif
 
diff --git a/sysdeps/x86_64/multiarch/ifunc-impl-list.c b/sysdeps/x86_64/multiarch/ifunc-impl-list.c
index c4a21d4b7c..7b651c7a9c 100644
--- a/sysdeps/x86_64/multiarch/ifunc-impl-list.c
+++ b/sysdeps/x86_64/multiarch/ifunc-impl-list.c
@@ -790,6 +790,10 @@  __libc_ifunc_impl_list (const char *name, struct libc_ifunc_impl *array,
 
   /* Support sysdeps/x86_64/multiarch/strstr.c.  */
   IFUNC_IMPL (i, name, strstr,
+              IFUNC_IMPL_ADD (array, i, strstr,
+                              (CPU_FEATURE_USABLE (AVX2)
+                               && CPU_FEATURE_USABLE (BMI2)),
+                              __strstr_avx2)
               IFUNC_IMPL_ADD (array, i, strstr,
                               (CPU_FEATURE_USABLE (AVX512VL)
                                && CPU_FEATURE_USABLE (AVX512BW)
diff --git a/sysdeps/x86_64/multiarch/strstr-avx-base.h b/sysdeps/x86_64/multiarch/strstr-avx-base.h
new file mode 100644
index 0000000000..e9f736606e
--- /dev/null
+++ b/sysdeps/x86_64/multiarch/strstr-avx-base.h
@@ -0,0 +1,268 @@ 
+/* Copyright (C) 2022-2024 Free Software Foundation, Inc.
+   This file is part of the GNU C Library.
+
+   The GNU C Library is free software; you can redistribute it and/or
+   modify it under the terms of the GNU Lesser General Public
+   License as published by the Free Software Foundation; either
+   version 2.1 of the License, or (at your option) any later version.
+
+   The GNU C Library is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+   Lesser General Public License for more details.
+
+   You should have received a copy of the GNU Lesser General Public
+   License along with the GNU C Library; if not, see
+   <https://www.gnu.org/licenses/>.  */
+
+#include <immintrin.h>
+#include <inttypes.h>
+#include <stdbool.h>
+#include <string.h>
+
+#define VEC_SIZE (sizeof (VEC)) /* Bytes per vector register.  */
+#define ONES ((MASK) -1) /* Mask with every byte lane selected.  */
+#define ONE ((MASK) 0x1)
+#define PAGE_SIZE 4096 /* Page size assumed by the page-cross checks.  */
+#define CVTMASK(...) ((MASK) (__VA_ARGS__))
+#define KSHIFTRI(x, y) ((x) >> (y))
+#define KAND_MASK(x, y) ((x) & (y))
+
+#ifndef FUNC_NAME /* Defaults below select the AVX2 variant.  */
+#  define FUNC_NAME __strstr_avx2
+#endif
+#ifndef VEC
+#  define VEC __m256i
+#endif
+#ifndef MASK /* One bit per byte lane of VEC.  */
+#  define MASK uint32_t
+#endif
+#ifndef LOAD /* Aligned vector load.  */
+#  define LOAD _mm256_load_si256
+#endif
+#ifndef LOADU /* Unaligned vector load.  */
+#  define LOADU _mm256_loadu_si256
+#endif
+#ifndef CMPEQ8_MASK /* Bit i set when byte i of X equals byte i of Y.  */
+#  define CMPEQ8_MASK(x, y)                                                   \
+    ((MASK) _mm256_movemask_epi8 (_mm256_cmpeq_epi8 (x, y)))
+#endif
+#ifndef MASK_CMPNEQ8_MASK /* Lanes of M where A and B differ.  */
+#  define MASK_CMPNEQ8_MASK(m, a, b) ((~CMPEQ8_MASK (a, b)) & (m))
+#endif
+#ifndef SETONE8
+#  define SETONE8 _mm256_set1_epi8
+#endif
+#ifndef SETZERO
+#  define SETZERO() _mm256_setzero_si256 ()
+#endif
+#ifndef TESTN8_MASK /* Bit i set when byte i of (X & Y) is zero.  */
+#  define TESTN8_MASK(x, y) ((MASK) _mm256_movemask_epi8 (_mm256_cmpeq_epi8 ( \
+    _mm256_and_si256 (x, y), SETZERO ())))
+#endif
+#ifndef MASK_TESTN8_MASK /* TESTN8_MASK restricted to the lanes in M.  */
+#  define MASK_TESTN8_MASK(m, a, b) (TESTN8_MASK (a, b) & (m))
+#endif
+#ifndef TZCNT /* Index of the lowest set bit.  */
+#  define TZCNT _tzcnt_u32
+#endif
+#ifndef BLSR /* Clear the lowest set bit.  */
+#  define BLSR _blsr_u32
+#endif
+#ifndef BZHI /* Keep only the low N bits.  */
+#  define BZHI _bzhi_u32
+#endif
+#ifndef MASKZ_LOADU8 /* Fallback when no masked-load intrinsic is supplied.  */
+#  define MASKZ_LOADU8(m, x) maskz_loadu8 (x) /* NOTE: mask M is ignored.  */
+static inline VEC __attribute__ ((always_inline)) maskz_loadu8 (const void *x)
+{
+  VEC ret; /* Filled entirely by strncpy: bytes up to NUL, zeros after.  */
+  strncpy ((char *) &ret, (const char *) x, VEC_SIZE);
+  return ret;
+}
+#endif
+
+/* Locate the first "edge" in NE: the smallest index I for which NE[I] and
+   NE[I + 1] differ and NE[I + 1] is not the terminator.  For example the
+   edge of "aaaaaaaaaabaarddg" is the "ab" pair at index 9.  Returns 0 when
+   no such pair exists.  NE[1] is read first, so NE must not be empty.  */
+static inline size_t __attribute__ ((always_inline))
+find_edge_in_needle (const char *ne)
+{
+  for (size_t pos = 0; ne[pos + 1] != '\0'; pos++)
+    {
+      /* Two adjacent bytes differ: this is the edge.  */
+      if (ne[pos + 1] != ne[pos])
+	return pos;
+    }
+  /* Hit the terminator without seeing two different neighbours; note that
+     0 is also what an edge at the very start of NE yields.  */
+  return 0;
+}
+
+/* Byte-wise check that NE, from offset IND onwards, equals the haystack
+   text at HAY + HAY_INDEX + IND.  Bytes of NE before IND are not examined.
+   Returns true once NE's terminator is reached without a mismatch.  */
+static inline bool __attribute__ ((always_inline))
+verify_string_match (const char *hay, const size_t hay_index, const char *ne,
+		     size_t ind)
+{
+  const char *candidate = hay + hay_index;
+  for (; ne[ind] != '\0'; ind++)
+    {
+      if (candidate[ind] != ne[ind])
+	return false;
+    }
+  return true;
+}
+
+/* Check for NE at HAY + HAY_INDEX, comparing the first VEC_SIZE bytes with
+   one vector compare.  NED_ZMM holds the leading bytes of NE and NED_MASK
+   selects which of its byte lanes take part in the comparison.  The load
+   below always reads a full VEC_SIZE bytes of the haystack.  */
+static inline bool __attribute__ ((always_inline))
+verify_string_match_vector (const char *hay, const size_t hay_index,
+			    const char *ne, const MASK ned_mask,
+			    const VEC ned_zmm)
+{
+  const VEC candidate = LOADU ((const VEC *) (hay + hay_index));
+  /* Any selected lane that differs rules this position out.  */
+  if (MASK_CMPNEQ8_MASK (ned_mask, candidate, ned_zmm) != 0x0)
+    return false;
+  /* Partial mask: nothing more to check.  Full mask: finish in scalar.  */
+  if (ned_mask != ONES)
+    return true;
+  return verify_string_match (hay, hay_index, ne, VEC_SIZE);
+}
+
+/* Return a pointer to the first occurrence of NE in HS (HS itself when NE
+   is empty), or NULL if there is none.  Candidates are found by searching
+   HS, a vector at a time, for the two adjacent bytes at NE's edge.  */
+char *
+FUNC_NAME (const char *hs, const char *ne)
+{
+  char first = ne[0];
+  if (__glibc_unlikely (first == '\0'))
+    return (char *) hs;
+  if (ne[1] == '\0')
+    return (char *) strchr (hs, ne[0]);
+
+  size_t edge = find_edge_in_needle (ne);
+
+  /* Make sure HS has at least EDGE bytes before its terminator.  */
+  for (size_t ii = 0; ii < edge; ++ii)
+    {
+      if (__glibc_unlikely (hs[ii] == '\0'))
+	return NULL;
+    }
+
+  /* Load up to VEC_SIZE bytes of the needle into a vector register.  The
+     mask selects only the bytes left in NE's VEC_SIZE-aligned block, so
+     this first load is not meant to cross a page boundary.  */
+  MASK ned_load_mask
+      = BZHI (ONES, VEC_SIZE - ((uintptr_t) (ne) & (VEC_SIZE - 1)));
+  VEC ned_zmm = MASKZ_LOADU8 (ned_load_mask, (const VEC *) ne);
+  MASK ned_nullmask = MASK_TESTN8_MASK (ned_load_mask, ned_zmm, ned_zmm);
+
+  /* Turn NED_LOAD_MASK into the lanes before NE's terminator (all if none).  */
+  if (__glibc_unlikely (ned_nullmask == 0x0))
+    {
+      /* No terminator in the partial block: NE continues, load it all.  */
+      ned_zmm = LOADU ((const VEC *) ne);
+      ned_nullmask = TESTN8_MASK (ned_zmm, ned_zmm);
+      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE);
+      if (ned_nullmask != 0x0)
+	ned_load_mask = ned_load_mask >> 1;
+    }
+  else
+    {
+      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE);
+      ned_load_mask = ned_load_mask >> 1;
+    }
+  const VEC ned0 = SETONE8 (ne[edge]);
+  const VEC ned1 = SETONE8 (ne[edge + 1]);
+
+  /* Read the bytes of HS left in the current VEC_SIZE-aligned block.  */
+  size_t hay_index = edge;
+  MASK loadmask = BZHI (
+      ONES, VEC_SIZE - ((uintptr_t) (hs + hay_index) & (VEC_SIZE - 1)));
+  /* First load is a partial block */
+  VEC hay0 = MASKZ_LOADU8 (loadmask, (const VEC *) (hs + hay_index));
+  /* Search for NULL and compare only till null char */
+  MASK nullmask = CVTMASK (MASK_TESTN8_MASK (loadmask, hay0, hay0));
+  MASK cmpmask = nullmask ^ (nullmask - ONE);
+  cmpmask = cmpmask & CVTMASK (loadmask);
+  /* Search for the 2 characters of needle */
+  MASK k0 = CMPEQ8_MASK (hay0, ned0);
+  MASK k1 = CMPEQ8_MASK (hay0, ned1);
+  k1 = KSHIFTRI (k1, 1);
+  /* k2 masks tell us if both chars from needle match */
+  MASK k2 = CVTMASK (KAND_MASK (k0, k1)) & cmpmask;
+  /* For every match, search for the entire needle for a full match */
+  while (k2)
+    {
+      MASK bitcount = TZCNT (k2);
+      k2 = BLSR (k2);
+      size_t match_pos = hay_index + bitcount - edge;
+      if (((uintptr_t) (hs + match_pos) & (PAGE_SIZE - 1))
+	  < PAGE_SIZE - 1 - VEC_SIZE)
+	{
+	  /* Use vector compare as long as you are not crossing a page.  */
+	  if (verify_string_match_vector (hs, match_pos, ne, ned_load_mask,
+					  ned_zmm))
+	    return (char *) hs + match_pos;
+	}
+      else
+	{
+	  /* Too close to the page end for a vector load: go byte by byte.  */
+	  if (verify_string_match (hs, match_pos, ne, 0))
+	    return (char *) hs + match_pos;
+	}
+    }
+  /* We haven't checked for potential match at the last char yet */
+  /* Point HS at the last byte of its aligned block, so HS + 1 is aligned.  */
+  hs = (const char *) (((uintptr_t) (hs + hay_index) | (VEC_SIZE - 1)));
+  hay_index = 0;
+
+  /* Loop over one VEC_SIZE block at a time to prevent reading over a page
+     boundary.  HAY1 is the aligned block, HAY0 the same data one byte
+     earlier, so lane i pairs byte i with its successor.  */
+  VEC hay1;
+  while (nullmask == 0)
+    {
+      hay0 = LOADU ((const VEC *) (hs + hay_index));
+      hay1 = LOAD (
+	  (const VEC *) (hs + hay_index + 1)); // Always VEC_SIZE byte aligned
+      nullmask = CVTMASK (TESTN8_MASK (hay1, hay1));
+      /* Compare only till null char */
+      cmpmask = nullmask ^ (nullmask - ONE);
+      k0 = CMPEQ8_MASK (hay0, ned0);
+      k1 = CMPEQ8_MASK (hay1, ned1);
+      /* k2 masks tell us if both chars from needle match */
+      k2 = CVTMASK (KAND_MASK (k0, k1)) & cmpmask;
+      /* For every match, compare full strings for potential match */
+      while (k2)
+	{
+	  MASK bitcount = TZCNT (k2);
+	  k2 = BLSR (k2);
+	  size_t match_pos = hay_index + bitcount - edge;
+	  if (((uintptr_t) (hs + match_pos) & (PAGE_SIZE - 1))
+	      < PAGE_SIZE - 1 - VEC_SIZE)
+	    {
+	      /* Use vector compare as long as you are not crossing a page.  */
+	      if (verify_string_match_vector (hs, match_pos, ne, ned_load_mask,
+					      ned_zmm))
+		return (char *) hs + match_pos;
+	    }
+	  else
+	    {
+	      /* Compare byte by byte */
+	      if (verify_string_match (hs, match_pos, ne, 0))
+		return (char *) hs + match_pos;
+	    }
+	}
+      hay_index += VEC_SIZE;
+    }
+  /* Reached the terminator of HS without a verified match.  */
+  return NULL;
+}
diff --git a/sysdeps/x86_64/multiarch/strstr-avx2.c b/sysdeps/x86_64/multiarch/strstr-avx2.c
new file mode 100644
index 0000000000..e86ffd160f
--- /dev/null
+++ b/sysdeps/x86_64/multiarch/strstr-avx2.c
@@ -0,0 +1,19 @@ 
+/* Copyright (C) 2022-2024 Free Software Foundation, Inc.
+   This file is part of the GNU C Library.
+
+   The GNU C Library is free software; you can redistribute it and/or
+   modify it under the terms of the GNU Lesser General Public
+   License as published by the Free Software Foundation; either
+   version 2.1 of the License, or (at your option) any later version.
+
+   The GNU C Library is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+   Lesser General Public License for more details.
+
+   You should have received a copy of the GNU Lesser General Public
+   License along with the GNU C Library; if not, see
+   <https://www.gnu.org/licenses/>.  */
+
+#define FUNC_NAME __strstr_avx2 /* Same as the base header's default.  */
+#include "strstr-avx-base.h" /* Every other knob keeps its AVX2 default.  */
diff --git a/sysdeps/x86_64/multiarch/strstr-avx512.c b/sysdeps/x86_64/multiarch/strstr-avx512.c
index 3ac53accbd..5eb69043b9 100644
--- a/sysdeps/x86_64/multiarch/strstr-avx512.c
+++ b/sysdeps/x86_64/multiarch/strstr-avx512.c
@@ -1,5 +1,4 @@ 
-/* strstr optimized with 512-bit AVX-512 instructions
-   Copyright (C) 2022-2024 Free Software Foundation, Inc.
+/* Copyright (C) 2022-2024 Free Software Foundation, Inc.
    This file is part of the GNU C Library.
 
    The GNU C Library is free software; you can redistribute it and/or
@@ -16,203 +15,21 @@ 
    License along with the GNU C Library; if not, see
    <https://www.gnu.org/licenses/>.  */
 
-#include <immintrin.h>
-#include <inttypes.h>
-#include <stdbool.h>
-#include <string.h>
-
-#define FULL_MMASK64 0xffffffffffffffff
-#define ONE_64BIT 0x1ull
-#define ZMM_SIZE_IN_BYTES 64
-#define PAGESIZE 4096
-
-#define cvtmask64_u64(...) (uint64_t) (__VA_ARGS__)
-#define kshiftri_mask64(x, y) ((x) >> (y))
-#define kand_mask64(x, y) ((x) & (y))
-
-/*
- Returns the index of the first edge within the needle, returns 0 if no edge
- is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg'
- */
-static inline size_t
-find_edge_in_needle (const char *ned)
-{
-  size_t ind = 0;
-  while (ned[ind + 1] != '\0')
-    {
-      if (ned[ind] != ned[ind + 1])
-        return ind;
-      else
-        ind = ind + 1;
-    }
-  return 0;
-}
-
-/*
- Compare needle with haystack byte by byte at specified location
- */
-static inline bool
-verify_string_match (const char *hay, const size_t hay_index, const char *ned,
-                     size_t ind)
-{
-  while (ned[ind] != '\0')
-    {
-      if (ned[ind] != hay[hay_index + ind])
-        return false;
-      ind = ind + 1;
-    }
-  return true;
-}
-
-/*
- Compare needle with haystack at specified location. The first 64 bytes are
- compared using a ZMM register.
- */
-static inline bool
-verify_string_match_avx512 (const char *hay, const size_t hay_index,
-                            const char *ned, const __mmask64 ned_mask,
-                            const __m512i ned_zmm)
-{
-  /* check first 64 bytes using zmm and then scalar */
-  __m512i hay_zmm = _mm512_loadu_si512 (hay + hay_index); // safe to do so
-  __mmask64 match = _mm512_mask_cmpneq_epi8_mask (ned_mask, hay_zmm, ned_zmm);
-  if (match != 0x0) // failed the first few chars
-    return false;
-  else if (ned_mask == FULL_MMASK64)
-    return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES);
-  return true;
-}
-
-char *
-__strstr_avx512 (const char *haystack, const char *ned)
-{
-  char first = ned[0];
-  if (first == '\0')
-    return (char *)haystack;
-  if (ned[1] == '\0')
-    return (char *)strchr (haystack, ned[0]);
-
-  size_t edge = find_edge_in_needle (ned);
-
-  /* ensure haystack is as long as the pos of edge in needle */
-  for (int ii = 0; ii < edge; ++ii)
-    {
-      if (haystack[ii] == '\0')
-        return NULL;
-    }
-
-  /*
-   Load 64 bytes of the needle and save it to a zmm register
-   Read one cache line at a time to avoid loading across a page boundary
-   */
-  __mmask64 ned_load_mask = _bzhi_u64 (
-      FULL_MMASK64, 64 - ((uintptr_t) (ned) & 63));
-  __m512i ned_zmm = _mm512_maskz_loadu_epi8 (ned_load_mask, ned);
-  __mmask64 ned_nullmask
-      = _mm512_mask_testn_epi8_mask (ned_load_mask, ned_zmm, ned_zmm);
-
-  if (__glibc_unlikely (ned_nullmask == 0x0))
-    {
-      ned_zmm = _mm512_loadu_si512 (ned);
-      ned_nullmask = _mm512_testn_epi8_mask (ned_zmm, ned_zmm);
-      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
-      if (ned_nullmask != 0x0)
-        ned_load_mask = ned_load_mask >> 1;
-    }
-  else
-    {
-      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
-      ned_load_mask = ned_load_mask >> 1;
-    }
-  const __m512i ned0 = _mm512_set1_epi8 (ned[edge]);
-  const __m512i ned1 = _mm512_set1_epi8 (ned[edge + 1]);
-
-  /*
-   Read the bytes of haystack in the current cache line
-   */
-  size_t hay_index = edge;
-  __mmask64 loadmask = _bzhi_u64 (
-      FULL_MMASK64, 64 - ((uintptr_t) (haystack + hay_index) & 63));
-  /* First load is a partial cache line */
-  __m512i hay0 = _mm512_maskz_loadu_epi8 (loadmask, haystack + hay_index);
-  /* Search for NULL and compare only till null char */
-  uint64_t nullmask
-      = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0));
-  uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT);
-  cmpmask = cmpmask & cvtmask64_u64 (loadmask);
-  /* Search for the 2 characters of needle */
-  __mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
-  __mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1);
-  k1 = kshiftri_mask64 (k1, 1);
-  /* k2 masks tell us if both chars from needle match */
-  uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
-  /* For every match, search for the entire needle for a full match */
-  while (k2)
-    {
-      uint64_t bitcount = _tzcnt_u64 (k2);
-      k2 = _blsr_u64 (k2);
-      size_t match_pos = hay_index + bitcount - edge;
-      if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
-          < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
-        {
-          /*
-           * Use vector compare as long as you are not crossing a page
-           */
-          if (verify_string_match_avx512 (haystack, match_pos, ned,
-                                          ned_load_mask, ned_zmm))
-            return (char *)haystack + match_pos;
-        }
-      else
-        {
-          if (verify_string_match (haystack, match_pos, ned, 0))
-            return (char *)haystack + match_pos;
-        }
-    }
-  /* We haven't checked for potential match at the last char yet */
-  haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63));
-  hay_index = 0;
-
-  /*
-   Loop over one cache line at a time to prevent reading over page
-   boundary
-   */
-  __m512i hay1;
-  while (nullmask == 0)
-    {
-      hay0 = _mm512_loadu_si512 (haystack + hay_index);
-      hay1 = _mm512_load_si512 (haystack + hay_index
-                                + 1); // Always 64 byte aligned
-      nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1));
-      /* Compare only till null char */
-      cmpmask = nullmask ^ (nullmask - ONE_64BIT);
-      k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
-      k1 = _mm512_cmpeq_epi8_mask (hay1, ned1);
-      /* k2 masks tell us if both chars from needle match */
-      k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
-      /* For every match, compare full strings for potential match */
-      while (k2)
-        {
-          uint64_t bitcount = _tzcnt_u64 (k2);
-          k2 = _blsr_u64 (k2);
-          size_t match_pos = hay_index + bitcount - edge;
-          if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
-              < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
-            {
-              /*
-               * Use vector compare as long as you are not crossing a page
-               */
-              if (verify_string_match_avx512 (haystack, match_pos, ned,
-                                              ned_load_mask, ned_zmm))
-                return (char *)haystack + match_pos;
-            }
-          else
-            {
-              /* Compare byte by byte */
-              if (verify_string_match (haystack, match_pos, ned, 0))
-                return (char *)haystack + match_pos;
-            }
-        }
-      hay_index += ZMM_SIZE_IN_BYTES;
-    }
-  return NULL;
-}
+#define FUNC_NAME __strstr_avx512
+#define VEC __m512i
+#define MASK uint64_t /* One mask bit per byte of a 64-byte vector.  */
+#define LOAD _mm512_load_si512
+#define LOADU _mm512_loadu_si512
+/* No MOVEMASK8: the AVX-512 compares below already return bit masks.  */
+#define CMPEQ8_MASK _mm512_cmpeq_epi8_mask
+#define MASK_CMPNEQ8_MASK _mm512_mask_cmpneq_epi8_mask
+#define SETONE8 _mm512_set1_epi8
+#define SETZERO _mm512_setzero_si512
+#define TESTN8_MASK _mm512_testn_epi8_mask
+#define MASK_TESTN8_MASK _mm512_mask_testn_epi8_mask
+#define TZCNT _tzcnt_u64
+#define BLSR _blsr_u64
+#define BZHI _bzhi_u64
+#define MASKZ_LOADU8 _mm512_maskz_loadu_epi8
+
+#include "strstr-avx-base.h"