From patchwork Fri Feb 23 07:28:13 2024 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: James Tirta Halim X-Patchwork-Id: 86250 Return-Path: X-Original-To: patchwork@sourceware.org Delivered-To: patchwork@sourceware.org Received: from server2.sourceware.org (localhost [IPv6:::1]) by sourceware.org (Postfix) with ESMTP id 0A378385828D for ; Fri, 23 Feb 2024 07:28:58 +0000 (GMT) X-Original-To: libc-alpha@sourceware.org Delivered-To: libc-alpha@sourceware.org Received: from mail-pl1-x629.google.com (mail-pl1-x629.google.com [IPv6:2607:f8b0:4864:20::629]) by sourceware.org (Postfix) with ESMTPS id 027D03858D1E for ; Fri, 23 Feb 2024 07:28:31 +0000 (GMT) DMARC-Filter: OpenDMARC Filter v1.4.2 sourceware.org 027D03858D1E Authentication-Results: sourceware.org; dmarc=pass (p=none dis=none) header.from=gmail.com Authentication-Results: sourceware.org; spf=pass smtp.mailfrom=gmail.com ARC-Filter: OpenARC Filter v1.0.0 sourceware.org 027D03858D1E Authentication-Results: server2.sourceware.org; arc=none smtp.remote-ip=2607:f8b0:4864:20::629 ARC-Seal: i=1; a=rsa-sha256; d=sourceware.org; s=key; t=1708673314; cv=none; b=hYyOX2QUKlZ6OxrO8MKsQHlgCtrTW/c7+gLfEzJididj6lRAOCxfUn7ZRU3IJ56FhUsClc584nZxeLurRket0gS7VTTz6kt/pm2ixJPI3WCaB49qSnHCr4fI+C4z9cRMYbJIht5cMuTI9y/bx6V3CIhKdQAeinC5omA3GUXk/QM= ARC-Message-Signature: i=1; a=rsa-sha256; d=sourceware.org; s=key; t=1708673314; c=relaxed/simple; bh=GPwL/Ump4PAfJTy2BDQPlcMatT04PWxiY5lwQEegQr8=; h=DKIM-Signature:From:To:Subject:Date:Message-ID:MIME-Version; b=VrJcnxHyqzyI2xNP2yLVHpnN3+IjjCk/CQIDMgVPyGEKz9n9DzZzPUHy5xYM+XXYmvtwVbiZSfgoX9LuuZ/oEfZ5nUSRTka6K5SeeOZA6ipOJeNXjDTUAMfFtmmX+XzXs0UDhud8PUzR6j7wJh1kStptXcZc/N9EfZds1PEQH/I= ARC-Authentication-Results: i=1; server2.sourceware.org Received: by mail-pl1-x629.google.com with SMTP id d9443c01a7336-1dc1ff697f9so4957175ad.0 for ; Thu, 22 Feb 2024 23:28:30 -0800 (PST) DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=gmail.com; s=20230601; t=1708673309; x=1709278109; darn=sourceware.org; h=content-transfer-encoding:mime-version:message-id:date:subject:cc :to:from:from:to:cc:subject:date:message-id:reply-to; bh=3moKBV0xvG8WWBuXmtqdOg4/kaAZ/KQwdBsjvja8cGs=; b=WdHa2hFIQ48HqcvIWffPaZwYKxPofdGRzf8E6NK4coBFKlXH7wc8/WG8NLPhFb2kGt YeUPN2AywrrXR0ZLVMQ26qwTzLn/THE72wnNnxBVwQTT7be6H+lbqdmwjiq0jmCE+oHm k50qH4pRV/H+CU1YCycLYnybeD5I91elJTFzL0h/TjTO5YABMzx9yJ+6ACUJBL0acJoZ jfcsnSMpgJltv4KySXvFXgaxhG4Okqy1YZ0o1Dc1Ktd0eDZg6eQ6ze6Ns4dGM85Yyybr y4T9T2GpUuCqy7SKxCv30ito0kCw+vELiovbGt7CfjjXNRpy0fHsEqSizVL1rnBEdCKQ vu6A== X-Google-DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; t=1708673309; x=1709278109; h=content-transfer-encoding:mime-version:message-id:date:subject:cc :to:from:x-gm-message-state:from:to:cc:subject:date:message-id :reply-to; bh=3moKBV0xvG8WWBuXmtqdOg4/kaAZ/KQwdBsjvja8cGs=; b=UyAfT6l0VLgkFLdV5AI/40l7t3xRYPw8rYbEF9+3Quv7uPjj9GLLvzOoWcNfuFBFvp QFfWMrEbZYByqIDglf7tbOc797UDo0rUvBbYWlMB0jFPhPzgz/QMsLBaAazQPZuQLVz3 4eFgOEk1Ikqnifb83mmPWpoArnEo3oor/G1cVRJotaYfcLoPXZSYTlZI+8dYrNxnQ6XU 0GIEv7rD/vfl2IlPfZPQidPK59GgD73AZehvOLAVHdWvKTsnZqld5A3nmDI7fqhNR+f2 grEOdzO3ISeOR6JXKdzEo/hV6ArmY0zEEkw/R9HibqUm802rt/7YP/Rb2Y4VQdj9tbR7 FPHg== X-Gm-Message-State: AOJu0YxwaG7Kr2CR7TuYqdo06KT6jm7e7UHso0ZCxWIpyEZUxnjVHJfP UWmuMRltvLuRXUgAYXmatloQLWjJf2ViRRJVWaX26pbUsynwhEELMrZKHvmQZpk= X-Google-Smtp-Source: AGHT+IHQojnd0bYqCf8gtVzxRWDJOBo509QLfL4sfkik3Fj6VSN1D2tDh4G1oMmDZuUka/oGOTyUIQ== X-Received: by 2002:a17:902:b947:b0:1db:e453:da81 with SMTP id h7-20020a170902b94700b001dbe453da81mr903188pls.29.1708673308415; Thu, 22 Feb 2024 23:28:28 -0800 (PST) Received: from localhost.localdomain ([2001:448a:20a0:5e8b:5b2b:1e14:d5af:d2fd]) by smtp.gmail.com with ESMTPSA id x11-20020a170902a38b00b001d8d1a2e5fesm11324120pla.196.2024.02.22.23.28.26 (version=TLS1_3 cipher=TLS_AES_256_GCM_SHA384 bits=256/256); Thu, 22 Feb 2024 23:28:27 -0800 (PST) From: James Tirta Halim To: libc-alpha@sourceware.org Cc: James Tirta Halim Subject: [PATCH] Add strstr-avx2 based on strstr-avx512 Date: Fri, 23 Feb 2024 14:28:13 +0700 Message-ID: <20240223072813.95327-1-tirtajames45@gmail.com> X-Mailer: git-send-email 2.43.2 MIME-Version: 1.0 X-Spam-Status: No, score=-10.9 required=5.0 tests=BAYES_00, DKIM_SIGNED, DKIM_VALID, DKIM_VALID_AU, DKIM_VALID_EF, FREEMAIL_ENVFROM_END_DIGIT, FREEMAIL_FROM, GIT_PATCH_0, KAM_NUMSUBJECT, KAM_SHORT, RCVD_IN_DNSWL_NONE, SPF_HELO_NONE, SPF_PASS, TXREP, T_SCC_BODY_TEXT_LINE autolearn=ham autolearn_force=no version=3.4.6 X-Spam-Checker-Version: SpamAssassin 3.4.6 (2021-04-09) on server2.sourceware.org X-BeenThere: libc-alpha@sourceware.org X-Mailman-Version: 2.1.30 Precedence: list List-Id: Libc-alpha mailing list List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: libc-alpha-bounces+patchwork=sourceware.org@sourceware.org Create a unified implementation for strstr-avx2 and strstr-avx512 based on the existing strstr-avx512 (no changes were made in the implementation). strstr-avx2 implements avx512 instructions using equivalent avx2 or generic instructions. basic_strstr twoway_strstr __strstr_avx2 __strstr_avx512 __strstr_sse2_unaligned __strstr_generic average: 211775 32055.7 4961.68 3725.31 14687 17268.7 total: 1.12876e+08 1.70857e+07 2.64458e+06 1.98559e+06 7.82818e+06 9.2042e+06 Passes test-strstr. --- sysdeps/x86_64/multiarch/Makefile | 2 + sysdeps/x86_64/multiarch/ifunc-impl-list.c | 4 + sysdeps/x86_64/multiarch/strstr-avx-base.h | 268 +++++++++++++++++++++ sysdeps/x86_64/multiarch/strstr-avx2.c | 19 ++ sysdeps/x86_64/multiarch/strstr-avx512.c | 221 ++--------------- 5 files changed, 312 insertions(+), 202 deletions(-) create mode 100644 sysdeps/x86_64/multiarch/strstr-avx-base.h create mode 100644 sysdeps/x86_64/multiarch/strstr-avx2.c diff --git a/sysdeps/x86_64/multiarch/Makefile b/sysdeps/x86_64/multiarch/Makefile index d3d2270394..1cf74a13a2 100644 --- a/sysdeps/x86_64/multiarch/Makefile +++ b/sysdeps/x86_64/multiarch/Makefile @@ -117,6 +117,7 @@ sysdep_routines += \ strrchr-evex512 \ strrchr-sse2 \ strspn-sse4 \ + strstr-avx2 \ strstr-avx512 \ strstr-sse2-unaligned \ varshift \ @@ -126,6 +127,7 @@ CFLAGS-strcspn-sse4.c += -msse4 CFLAGS-strpbrk-sse4.c += -msse4 CFLAGS-strspn-sse4.c += -msse4 +CFLAGS-strstr-avx2.c += -Wno-error=stringop-truncation -Wno-stringop-truncation -mavx2 -mbmi -mbmi2 -O3 CFLAGS-strstr-avx512.c += -mavx512f -mavx512vl -mavx512dq -mavx512bw -mbmi -mbmi2 -O3 endif diff --git a/sysdeps/x86_64/multiarch/ifunc-impl-list.c b/sysdeps/x86_64/multiarch/ifunc-impl-list.c index c4a21d4b7c..7b651c7a9c 100644 --- a/sysdeps/x86_64/multiarch/ifunc-impl-list.c +++ b/sysdeps/x86_64/multiarch/ifunc-impl-list.c @@ -790,6 +790,10 @@ __libc_ifunc_impl_list (const char *name, struct libc_ifunc_impl *array, /* Support sysdeps/x86_64/multiarch/strstr.c. */ IFUNC_IMPL (i, name, strstr, + IFUNC_IMPL_ADD (array, i, strstr, + (CPU_FEATURE_USABLE (AVX2) + && CPU_FEATURE_USABLE (BMI2)), + __strstr_avx2) IFUNC_IMPL_ADD (array, i, strstr, (CPU_FEATURE_USABLE (AVX512VL) && CPU_FEATURE_USABLE (AVX512BW) diff --git a/sysdeps/x86_64/multiarch/strstr-avx-base.h b/sysdeps/x86_64/multiarch/strstr-avx-base.h new file mode 100644 index 0000000000..e9f736606e --- /dev/null +++ b/sysdeps/x86_64/multiarch/strstr-avx-base.h @@ -0,0 +1,268 @@ +/* Copyright (C) 2022-2024 Free Software Foundation, Inc. + This file is part of the GNU C Library. + + The GNU C Library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + The GNU C Library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with the GNU C Library; if not, see + . */ + +#include +#include +#include +#include + +#define VEC_SIZE sizeof (VEC) +#define ONES ((MASK) -1) +#define ONE ((MASK) 0x1) +#define PAGE_SIZE 4096 +#define CVTMASK(...) (MASK) (__VA_ARGS__) +#define KSHIFTRI(x, y) ((x) >> (y)) +#define KAND_MASK(x, y) ((x) & (y)) + +#ifndef FUNC_NAME +# define FUNC_NAME __strstr_avx2 +#endif +#ifndef VEC +# define VEC __m256i +#endif +#ifndef MASK +# define MASK uint32_t +#endif +#ifndef LOAD +# define LOAD _mm256_load_si256 +#endif +#ifndef LOADU +# define LOADU _mm256_loadu_si256 +#endif +#ifndef CMPEQ8_MASK +# define CMPEQ8_MASK(x, y) \ + (MASK) _mm256_movemask_epi8 (_mm256_cmpeq_epi8 (x, y)) +#endif +#ifndef MASK_CMPNEQ8_MASK +# define MASK_CMPNEQ8_MASK(m, a, b) ((~CMPEQ8_MASK (a, b)) & (m)) +#endif +#ifndef SETONE8 +# define SETONE8 _mm256_set1_epi8 +#endif +#ifndef SETZERO +# define SETZERO() _mm256_setzero_si256 () +#endif +#ifndef TESTN8_MASK +# define TESTN8_MASK(x, y) \ + _mm256_movemask_epi8 (_mm256_cmpeq_epi8 (x, SETZERO ())) +#endif +#ifndef MASK_TESTN8_MASK +# define MASK_TESTN8_MASK(m, a, b) ((MASK) TESTN8_MASK (a, a) & (m)) +#endif +#ifndef TZCNT +# define TZCNT _tzcnt_u32 +#endif +#ifndef BLSR +# define BLSR _blsr_u32 +#endif +#ifndef BZHI +# define BZHI _bzhi_u32 +#endif +#ifndef MASKZ_LOADU8 +# define MASKZ_LOADU8(m, x) maskz_loadu8 (x) +static inline VEC __attribute__ ((always_inline)) maskz_loadu8 (const void *x) +{ + VEC ret; + strncpy ((char *) &ret, (const char *) x, VEC_SIZE); + return ret; +} +#endif + +/* + Returns the index of the first edge within the needle, returns 0 if no edge + is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg' + */ +static inline size_t __attribute__ ((always_inline)) +find_edge_in_needle (const char *ne) +{ + size_t ind = 0; + while (ne[ind + 1] != '\0') + { + if (ne[ind] != ne[ind + 1]) + return ind; + else + ind = ind + 1; + } + return 0; +} + +/* + Compare needle with hs byte by byte at specified location + */ +static inline bool __attribute__ ((always_inline)) +verify_string_match (const char *hay, const size_t hay_index, const char *ne, + size_t ind) +{ + while (ne[ind] != '\0') + { + if (ne[ind] != hay[hay_index + ind]) + return false; + ind = ind + 1; + } + return true; +} + +/* + Compare needle with hs at specified location. The first VEC_SIZE bytes are + compared using a ZMM register. + */ +static inline bool __attribute__ ((always_inline)) +verify_string_match_vector (const char *hay, const size_t hay_index, + const char *ne, const MASK ned_mask, + const VEC ned_zmm) +{ + /* check first VEC_SIZE bytes using zmm and then scalar */ + VEC hay_zmm = LOADU ((const VEC *) (hay + hay_index)); // safe to do so + MASK match = MASK_CMPNEQ8_MASK (ned_mask, hay_zmm, ned_zmm); + if (match != 0x0) // failed the first few chars + return false; + else if (ned_mask == ONES) + return verify_string_match (hay, hay_index, ne, VEC_SIZE); + return true; +} + +char * +FUNC_NAME (const char *hs, const char *ne) +{ + char first = ne[0]; + if (__glibc_unlikely (first == '\0')) + return (char *) hs; + if (ne[1] == '\0') + return (char *) strchr (hs, ne[0]); + + size_t edge = find_edge_in_needle (ne); + + /* ensure hs is as long as the pos of edge in needle */ + for (unsigned int ii = 0; ii < edge; ++ii) + { + if (__glibc_unlikely (hs[ii] == '\0')) + return NULL; + } + + /* + Load VEC_SIZE bytes of the needle and save it to a zmm register + Read one cache line at a time to avoid loading across a page boundary + */ + MASK ned_load_mask + = BZHI (ONES, VEC_SIZE - ((uintptr_t) (ne) & (VEC_SIZE - 1))); + VEC ned_zmm = MASKZ_LOADU8 (ned_load_mask, (const VEC *) ne); + MASK ned_nullmask = MASK_TESTN8_MASK (ned_load_mask, ned_zmm, ned_zmm); + + if (__glibc_unlikely (ned_nullmask == 0x0)) + { + ned_zmm = LOADU ((const VEC *) ne); + ned_nullmask = TESTN8_MASK (ned_zmm, ned_zmm); + ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE); + if (ned_nullmask != 0x0) + ned_load_mask = ned_load_mask >> 1; + } + else + { + ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE); + ned_load_mask = ned_load_mask >> 1; + } + const VEC ned0 = SETONE8 (ne[edge]); + const VEC ned1 = SETONE8 (ne[edge + 1]); + + /* + Read the bytes of hs in the current cache line + */ + size_t hay_index = edge; + MASK loadmask = BZHI ( + ONES, VEC_SIZE - ((uintptr_t) (hs + hay_index) & (VEC_SIZE - 1))); + /* First load is a partial cache line */ + VEC hay0 = MASKZ_LOADU8 (loadmask, (const VEC *) (hs + hay_index)); + /* Search for NULL and compare only till null char */ + MASK nullmask = CVTMASK (MASK_TESTN8_MASK (loadmask, hay0, hay0)); + MASK cmpmask = nullmask ^ (nullmask - ONE); + cmpmask = cmpmask & CVTMASK (loadmask); + /* Search for the 2 characters of needle */ + MASK k0 = CMPEQ8_MASK (hay0, ned0); + MASK k1 = CMPEQ8_MASK (hay0, ned1); + k1 = KSHIFTRI (k1, 1); + /* k2 masks tell us if both chars from needle match */ + MASK k2 = CVTMASK (KAND_MASK (k0, k1)) & cmpmask; + /* For every match, search for the entire needle for a full match */ + while (k2) + { + MASK bitcount = TZCNT (k2); + k2 = BLSR (k2); + size_t match_pos = hay_index + bitcount - edge; + if (((uintptr_t) (hs + match_pos) & (PAGE_SIZE - 1)) + < PAGE_SIZE - 1 - VEC_SIZE) + { + /* + * Use vector compare as long as you are not crossing a page + */ + if (verify_string_match_vector (hs, match_pos, ne, ned_load_mask, + ned_zmm)) + return (char *) hs + match_pos; + } + else + { + if (verify_string_match (hs, match_pos, ne, 0)) + return (char *) hs + match_pos; + } + } + /* We haven't checked for potential match at the last char yet */ + hs = (const char *) (((uintptr_t) (hs + hay_index) | (VEC_SIZE - 1))); + hay_index = 0; + + /* + Loop over one cache line at a time to prevent reading over page + boundary + */ + VEC hay1; + while (nullmask == 0) + { + hay0 = LOADU ((const VEC *) (hs + hay_index)); + hay1 = LOAD ( + (const VEC *) (hs + hay_index + 1)); // Always VEC_SIZE byte aligned + nullmask = CVTMASK (TESTN8_MASK (hay1, hay1)); + /* Compare only till null char */ + cmpmask = nullmask ^ (nullmask - ONE); + k0 = CMPEQ8_MASK (hay0, ned0); + k1 = CMPEQ8_MASK (hay1, ned1); + /* k2 masks tell us if both chars from needle match */ + k2 = CVTMASK (KAND_MASK (k0, k1)) & cmpmask; + /* For every match, compare full strings for potential match */ + while (k2) + { + MASK bitcount = TZCNT (k2); + k2 = BLSR (k2); + size_t match_pos = hay_index + bitcount - edge; + if (((uintptr_t) (hs + match_pos) & (PAGE_SIZE - 1)) + < PAGE_SIZE - 1 - VEC_SIZE) + { + /* + * Use vector compare as long as you are not crossing a page + */ + if (verify_string_match_vector (hs, match_pos, ne, ned_load_mask, + ned_zmm)) + return (char *) hs + match_pos; + } + else + { + /* Compare byte by byte */ + if (verify_string_match (hs, match_pos, ne, 0)) + return (char *) hs + match_pos; + } + } + hay_index += VEC_SIZE; + } + return NULL; +} diff --git a/sysdeps/x86_64/multiarch/strstr-avx2.c b/sysdeps/x86_64/multiarch/strstr-avx2.c new file mode 100644 index 0000000000..e86ffd160f --- /dev/null +++ b/sysdeps/x86_64/multiarch/strstr-avx2.c @@ -0,0 +1,19 @@ +/* Copyright (C) 2022-2024 Free Software Foundation, Inc. + This file is part of the GNU C Library. + + The GNU C Library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + The GNU C Library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with the GNU C Library; if not, see + . */ + +#define FUNC_NAME __strstr_avx2 +#include "strstr-avx-base.h" diff --git a/sysdeps/x86_64/multiarch/strstr-avx512.c b/sysdeps/x86_64/multiarch/strstr-avx512.c index 3ac53accbd..5eb69043b9 100644 --- a/sysdeps/x86_64/multiarch/strstr-avx512.c +++ b/sysdeps/x86_64/multiarch/strstr-avx512.c @@ -1,5 +1,4 @@ -/* strstr optimized with 512-bit AVX-512 instructions - Copyright (C) 2022-2024 Free Software Foundation, Inc. +/* Copyright (C) 2022-2024 Free Software Foundation, Inc. This file is part of the GNU C Library. The GNU C Library is free software; you can redistribute it and/or @@ -16,203 +15,21 @@ License along with the GNU C Library; if not, see . */ -#include -#include -#include -#include - -#define FULL_MMASK64 0xffffffffffffffff -#define ONE_64BIT 0x1ull -#define ZMM_SIZE_IN_BYTES 64 -#define PAGESIZE 4096 - -#define cvtmask64_u64(...) (uint64_t) (__VA_ARGS__) -#define kshiftri_mask64(x, y) ((x) >> (y)) -#define kand_mask64(x, y) ((x) & (y)) - -/* - Returns the index of the first edge within the needle, returns 0 if no edge - is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg' - */ -static inline size_t -find_edge_in_needle (const char *ned) -{ - size_t ind = 0; - while (ned[ind + 1] != '\0') - { - if (ned[ind] != ned[ind + 1]) - return ind; - else - ind = ind + 1; - } - return 0; -} - -/* - Compare needle with haystack byte by byte at specified location - */ -static inline bool -verify_string_match (const char *hay, const size_t hay_index, const char *ned, - size_t ind) -{ - while (ned[ind] != '\0') - { - if (ned[ind] != hay[hay_index + ind]) - return false; - ind = ind + 1; - } - return true; -} - -/* - Compare needle with haystack at specified location. The first 64 bytes are - compared using a ZMM register. - */ -static inline bool -verify_string_match_avx512 (const char *hay, const size_t hay_index, - const char *ned, const __mmask64 ned_mask, - const __m512i ned_zmm) -{ - /* check first 64 bytes using zmm and then scalar */ - __m512i hay_zmm = _mm512_loadu_si512 (hay + hay_index); // safe to do so - __mmask64 match = _mm512_mask_cmpneq_epi8_mask (ned_mask, hay_zmm, ned_zmm); - if (match != 0x0) // failed the first few chars - return false; - else if (ned_mask == FULL_MMASK64) - return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES); - return true; -} - -char * -__strstr_avx512 (const char *haystack, const char *ned) -{ - char first = ned[0]; - if (first == '\0') - return (char *)haystack; - if (ned[1] == '\0') - return (char *)strchr (haystack, ned[0]); - - size_t edge = find_edge_in_needle (ned); - - /* ensure haystack is as long as the pos of edge in needle */ - for (int ii = 0; ii < edge; ++ii) - { - if (haystack[ii] == '\0') - return NULL; - } - - /* - Load 64 bytes of the needle and save it to a zmm register - Read one cache line at a time to avoid loading across a page boundary - */ - __mmask64 ned_load_mask = _bzhi_u64 ( - FULL_MMASK64, 64 - ((uintptr_t) (ned) & 63)); - __m512i ned_zmm = _mm512_maskz_loadu_epi8 (ned_load_mask, ned); - __mmask64 ned_nullmask - = _mm512_mask_testn_epi8_mask (ned_load_mask, ned_zmm, ned_zmm); - - if (__glibc_unlikely (ned_nullmask == 0x0)) - { - ned_zmm = _mm512_loadu_si512 (ned); - ned_nullmask = _mm512_testn_epi8_mask (ned_zmm, ned_zmm); - ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT); - if (ned_nullmask != 0x0) - ned_load_mask = ned_load_mask >> 1; - } - else - { - ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT); - ned_load_mask = ned_load_mask >> 1; - } - const __m512i ned0 = _mm512_set1_epi8 (ned[edge]); - const __m512i ned1 = _mm512_set1_epi8 (ned[edge + 1]); - - /* - Read the bytes of haystack in the current cache line - */ - size_t hay_index = edge; - __mmask64 loadmask = _bzhi_u64 ( - FULL_MMASK64, 64 - ((uintptr_t) (haystack + hay_index) & 63)); - /* First load is a partial cache line */ - __m512i hay0 = _mm512_maskz_loadu_epi8 (loadmask, haystack + hay_index); - /* Search for NULL and compare only till null char */ - uint64_t nullmask - = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0)); - uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT); - cmpmask = cmpmask & cvtmask64_u64 (loadmask); - /* Search for the 2 characters of needle */ - __mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0); - __mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1); - k1 = kshiftri_mask64 (k1, 1); - /* k2 masks tell us if both chars from needle match */ - uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask; - /* For every match, search for the entire needle for a full match */ - while (k2) - { - uint64_t bitcount = _tzcnt_u64 (k2); - k2 = _blsr_u64 (k2); - size_t match_pos = hay_index + bitcount - edge; - if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1)) - < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES) - { - /* - * Use vector compare as long as you are not crossing a page - */ - if (verify_string_match_avx512 (haystack, match_pos, ned, - ned_load_mask, ned_zmm)) - return (char *)haystack + match_pos; - } - else - { - if (verify_string_match (haystack, match_pos, ned, 0)) - return (char *)haystack + match_pos; - } - } - /* We haven't checked for potential match at the last char yet */ - haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63)); - hay_index = 0; - - /* - Loop over one cache line at a time to prevent reading over page - boundary - */ - __m512i hay1; - while (nullmask == 0) - { - hay0 = _mm512_loadu_si512 (haystack + hay_index); - hay1 = _mm512_load_si512 (haystack + hay_index - + 1); // Always 64 byte aligned - nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1)); - /* Compare only till null char */ - cmpmask = nullmask ^ (nullmask - ONE_64BIT); - k0 = _mm512_cmpeq_epi8_mask (hay0, ned0); - k1 = _mm512_cmpeq_epi8_mask (hay1, ned1); - /* k2 masks tell us if both chars from needle match */ - k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask; - /* For every match, compare full strings for potential match */ - while (k2) - { - uint64_t bitcount = _tzcnt_u64 (k2); - k2 = _blsr_u64 (k2); - size_t match_pos = hay_index + bitcount - edge; - if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1)) - < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES) - { - /* - * Use vector compare as long as you are not crossing a page - */ - if (verify_string_match_avx512 (haystack, match_pos, ned, - ned_load_mask, ned_zmm)) - return (char *)haystack + match_pos; - } - else - { - /* Compare byte by byte */ - if (verify_string_match (haystack, match_pos, ned, 0)) - return (char *)haystack + match_pos; - } - } - hay_index += ZMM_SIZE_IN_BYTES; - } - return NULL; -} +#define FUNC_NAME __strstr_avx512 +#define VEC __m512i +#define MASK uint64_t +#define LOAD _mm512_load_si512 +#define LOADU _mm512_loadu_si512 +#define MOVEMASK8 _mm512_movemask_epi8 +#define CMPEQ8_MASK _mm512_cmpeq_epi8_mask +#define MASK_CMPNEQ8_MASK _mm512_mask_cmpneq_epi8_mask +#define SETONE8 _mm512_set1_epi8 +#define SETZERO _mm512_setzero_si512 +#define TESTN8_MASK _mm512_testn_epi8_mask +#define MASK_TESTN8_MASK _mm512_mask_testn_epi8_mask +#define TZCNT _tzcnt_u64 +#define BLSR _blsr_u64 +#define BZHI _bzhi_u64 +#define MASKZ_LOADU8 _mm512_maskz_loadu_epi8 + +#include "strstr-avx-base.h"