[v4,1/2] libstdc++: Refactor _ScalarAbi<N> into _Abi<N, N>

Message ID bmm.hj4ei233v2.gcc.gcc-TEST.mkretz.145.4.1@forge-stage.sourceware.org
State New
Headers
Series std-simd-complex |

Checks

Context Check Description
linaro-tcwg-bot/tcwg_gcc_build--master-arm success Build passed

Commit Message

Matthias Kretz via Sourceware Forge June 3, 2026, 5:42 a.m. UTC
  From: Matthias Kretz <m.kretz@gsi.de>

Before this change, _Ap::_S_is_bitmask would pick up false from
_ScalarAbi<N>. Now that __scalar_abi_tag identifies any _Abi<N, N, V>,
where V can also identify bit-masks, the short-cut of setting
_S_use_bitmask to _Ap::_S_is_bitmask is wrong. It would be correct to
have it say _Ap::_S_is_bitmask && !__scalar_abi_tag<_Ap>. I decided to
implement the latter only in the _S_nreg == 1 specialization and have
the higher-level specializations inherit the value from their vec/mask
member. The _S_is_bitmask bit is not erased for __scalar_abi_tag since
it makes a difference for __abi_rebind.

Signed-off-by: Matthias Kretz <m.kretz@gsi.de>

libstdc++-v3/ChangeLog:

	* include/bits/simd_details.h (_ScalarAbi): Remove.
	(__scalar_abi_tag): Identify _Abi<N, N> as scalar now.
	(__native_abi): Replace _ScalarAbi<1> with _Abi_t<1, 1, ...>.
	(__abi_rebind): Refactor rebinding from/to __scalar_abi_tag.
	* include/bits/simd_mask.h (_S_use_bitmask): Only true if
	!_S_is_scalar.
	(_M_and_neighbors, _M_or_neighbors): Add case for _S_is_scalar
	where the and/or must be executed one step earlier.
	(_M_reduce_min_index, _M_reduce_max_index): Delete dead code.
	* include/bits/simd_vec.h (_S_use_bitmask): Inherit the value
	from the first data member.
	* testsuite/std/simd/traits_impl.cc: Adjust for the removal of
	_ScalarAbi.
---
 libstdc++-v3/include/bits/simd_details.h      | 86 +++++++------------
 libstdc++-v3/include/bits/simd_mask.h         | 38 +++++---
 libstdc++-v3/include/bits/simd_vec.h          |  2 +-
 .../testsuite/std/simd/traits_impl.cc         | 11 +--
 4 files changed, 58 insertions(+), 79 deletions(-)
  

Patch

diff --git a/libstdc++-v3/include/bits/simd_details.h b/libstdc++-v3/include/bits/simd_details.h
index a1acc5bd9464..e6185fac64a3 100644
--- a/libstdc++-v3/include/bits/simd_details.h
+++ b/libstdc++-v3/include/bits/simd_details.h
@@ -241,46 +241,21 @@  namespace simd
 #endif
 
   /** @internal
-   * This ABI tag describes basic_vec objects that store one element per data member and basic_mask
-   * objects that store one bool data members.
+   * @brief This ABI tag determines the data member(s) of basic_vec and basic_mask.
    *
-   * @tparam _Np   The number of elements, which also matches the number of data members in
-   *               basic_vec and basic_mask.
-   */
-  template <int _Np = 1>
-    struct _ScalarAbi
-    {
-      static constexpr int _S_size = _Np;
-
-      static constexpr int _S_nreg = _Np;
-
-      static constexpr _AbiVariant _S_variant = {};
-
-      template <typename _Tp>
-	using _DataType = __canonical_vec_type_t<_Tp>;
-
-      static constexpr bool _S_is_vecmask = false;
-
-      // in principle a bool is a 1-bit bitmask, but this is asking for an AVX512 bitmask
-      static constexpr bool _S_is_bitmask = false;
-
-      template <size_t>
-	using _MaskDataType = bool;
-
-      template <int _N2, int _Nreg2 = _N2>
-	static consteval _ScalarAbi<_N2>
-	_S_resize()
-	{
-	  static_assert(_N2 == _Nreg2);
-	  return {};
-	}
-    };
-
-  /** @internal
-   * This ABI tag describes basic_vec objects that store one or more objects declared with the
-   * [[gnu::vector_size(N)]] attribute.
-   * Applied to basic_mask objects, this ABI tag either describes corresponding vector-mask objects
-   * or bit-mask objects. Which one is used is determined via @p _Var.
+   * @p _Nreg determines the number of (recursively nested) basic_vec/basic_mask data members that
+   * themselves have @p _Nreg equal to 1. With @p _Nreg equal to 1, the basic_vec/basic_mask holds
+   * one vector builtin (@p _Np greater than 1) or a scalar (@p _Np equal to 1).
+   * @f$\lceil\frac{\mathtt{Np}}{\mathtt{Nreg}}\rceil@f$ therefore determines the number of elements
+   * in a register (except for a remainder where it can be smaller). If @p _Np equals @p _Nreg (i.e.
+   * the aforementioned quotient is 1), then basic_vec (recursively) holds non-vector data members
+   * and basic_mask holds bools.
+   *
+   * The @p _Var parameter determines details about the data member in the one register case. Masks
+   * can be represented as vector masks (the default comparison result of GNU vector builtins),
+   * bit-masks as used by AVX-512, bit-masks as used by ARM SVE (not yet implemented), or a single
+   * bool (for the @p _Np equals 1 case). For basic_mask it determines the actual data layout and
+   * for basic_vec it determines the result of compares.
    *
    * @tparam _Np    The number of elements.
    * @tparam _Nreg  The number of registers needed to store @p _Np elements.
@@ -391,9 +366,13 @@  namespace simd
 	    { __x.template _S_resize<_Tp::_S_size, _Tp::_S_nreg>() } -> same_as<_Tp>;
 	  };
 
+  /** @internal
+   * Satisfied if @p _Tp is a valid simd ABI tag and one element is stored per register (number of
+   * registers equals size).
+   */
   template <typename _Tp>
     concept __scalar_abi_tag
-      = same_as<_Tp, _ScalarAbi<_Tp::_S_size>> && __abi_tag<_Tp>;
+      = same_as<_Tp, _Abi_t<_Tp::_S_size, _Tp::_S_size, _Tp::_S_variant>> && __abi_tag<_Tp>;
 
   // Determine if math functions must *raise* floating-point exceptions.
   // math_errhandling may expand to an extern symbol, in which case we must assume fp exceptions
@@ -760,7 +739,7 @@  namespace simd
       else if constexpr (_Traits._M_have_avx512f())
 	return _Abi_t<64 / __adj_sizeof, 1, _AbiVariant::_BitMask>();
       else if constexpr (is_same_v<_Tp, _Float16> && !_Traits._M_have_f16c())
-	return _ScalarAbi<1>();
+	return _Abi_t<1, 1>();
       else if constexpr (_Traits._M_have_avx2())
 	return _Abi_t<32 / __adj_sizeof, 1>();
       else if constexpr (_Traits._M_have_avx() && is_floating_point_v<_Tp>)
@@ -772,7 +751,7 @@  namespace simd
 	return _Abi_t<16 / __adj_sizeof, 1>();
       // no MMX: we can't emit EMMS where it would be necessary
       else
-	return _ScalarAbi<1>();
+	return _Abi_t<1, 1>();
     }
 
 #else
@@ -794,7 +773,7 @@  namespace simd
       if constexpr (!__vectorizable<_Tp>)
 	return _InvalidAbi();
       else
-	return _ScalarAbi<1>();
+	return _Abi_t<1, 1>();
     }
 
 #endif
@@ -850,17 +829,19 @@  namespace simd
       if constexpr (_Np <= 0 || !__vectorizable<_Tp>)
 	return _InvalidAbi();
 
-      else if constexpr (__scalar_abi_tag<_A0>)
-	return _A0::template _S_resize<_Np>();
-
       else
 	{
 	  using _Native = remove_const_t<decltype(std::simd::__native_abi<_Tp>())>;
 	  static_assert(0 != _Native::_S_size);
 	  constexpr int __nreg = __div_ceil(_Np, _Native::_S_size);
 
-	  if constexpr (__scalar_abi_tag<_Native>)
-	    return _Native::template _S_resize<_Np>();
+	  // __scalar_abi_tag is sticky (unless we reach size 1, where we can't know whether it was
+	  // an explicit __scalar_abi_tag before some resize_t)
+	  if constexpr (__scalar_abi_tag<_Native> || (__scalar_abi_tag<_A0> && _A0::_S_size >= 2))
+	    {
+		return _A0::template _S_resize<_Np, _Np>();
+	    }
+
 	  else
 	    return _Abi_t<_Native::_S_size, 1, __filter_abi_variant(_A0::_S_variant,
 								    _AbiVariant::_MaskVariants)
@@ -885,9 +866,6 @@  namespace simd
       if constexpr (_Bytes == 0 || _Np <= 0)
 	return _InvalidAbi();
 
-      else if constexpr (__scalar_abi_tag<_A0>)
-	return _A0::template _S_resize<_Np>();
-
 #if _GLIBCXX_X86
       // AVX w/o AVX2:
       // e.g. resize_t<8, mask<float, Whatever>> needs to be _Abi<8, 1> not _Abi<8, 2>
@@ -939,12 +917,6 @@  namespace simd
       if (__b0 != __b1)
 	return true;
 
-      // everything is better than _ScalarAbi, except when converting to a single bool
-      if constexpr (__scalar_abi_tag<_To>)
-	return __n > 1;
-      else if constexpr (__scalar_abi_tag<_From>)
-	return true;
-
       // converting to a bit-mask is better
       else if constexpr (_To::_S_is_vecmask != _From::_S_is_vecmask)
 	return _To::_S_is_vecmask; // to vector-mask is explicit
diff --git a/libstdc++-v3/include/bits/simd_mask.h b/libstdc++-v3/include/bits/simd_mask.h
index 0a7cfa03cedd..81a0825ec6ce 100644
--- a/libstdc++-v3/include/bits/simd_mask.h
+++ b/libstdc++-v3/include/bits/simd_mask.h
@@ -543,7 +543,7 @@  namespace simd
 
       static constexpr bool _S_is_scalar = _S_has_bool_member;
 
-      static constexpr bool _S_use_bitmask = _Ap::_S_is_bitmask;
+      static constexpr bool _S_use_bitmask = _Ap::_S_is_bitmask && !_S_is_scalar;
 
       static constexpr int _S_full_size = [] {
 	if constexpr (_S_is_scalar)
@@ -1519,8 +1519,16 @@  namespace simd
       constexpr basic_mask&
       _M_and_neighbors()
       {
-	_M_data0._M_and_neighbors();
-	_M_data1._M_and_neighbors();
+	if constexpr (_S_size == 2)
+	  {
+	    static_assert(_S_is_scalar);
+	    _M_data0 = _M_data1 = _M_data0 && _M_data1;
+	  }
+	else
+	  {
+	    _M_data0._M_and_neighbors();
+	    _M_data1._M_and_neighbors();
+	  }
 	return *this;
       }
 
@@ -1528,8 +1536,16 @@  namespace simd
       constexpr basic_mask&
       _M_or_neighbors()
       {
-	_M_data0._M_or_neighbors();
-	_M_data1._M_or_neighbors();
+	if constexpr (_S_size == 2)
+	  {
+	    static_assert(_S_is_scalar);
+	    _M_data0 = _M_data1 = _M_data0 || _M_data1;
+	  }
+	else
+	  {
+	    _M_data0._M_or_neighbors();
+	    _M_data1._M_or_neighbors();
+	  }
 	return *this;
       }
 
@@ -1650,7 +1666,7 @@  namespace simd
 	else if constexpr (_M_data1._S_has_bool_member)
 	  // in some cases the last element can be 'bool' instead of bit-/vector-mask;
 	  // e.g. mask<short, 17> is {mask<short, 16>, mask<short, 1>}, where the latter uses
-	  // _ScalarAbi<1>, which is stored as 'bool'
+	  // _Abi<1, 1>, which is stored as 'bool'
 	  return __i < _N0 ? _M_data0[__i] : _M_data1[__i - _N0];
 	else if constexpr (abi_type::_S_is_bitmask)
 	  {
@@ -1929,10 +1945,7 @@  namespace simd
 	  {
 	    const auto __bits = _M_to_uint();
 	    __glibcxx_simd_precondition(__bits, "An empty mask does not have a min_index.");
-	    if constexpr (_S_size == 1)
-	      return 0;
-	    else
-	      return __countr_zero(_M_to_uint());
+	    return __countr_zero(_M_to_uint());
 	  }
 	else if (_M_data0._M_none_of())
 	  return _M_data1._M_reduce_min_index() + _N0;
@@ -1948,10 +1961,7 @@  namespace simd
 	  {
 	    const auto __bits = _M_to_uint();
 	    __glibcxx_simd_precondition(__bits, "An empty mask does not have a max_index.");
-	    if constexpr (_S_size == 1)
-	      return 0;
-	    else
-	      return __highest_bit(_M_to_uint());
+	    return __highest_bit(_M_to_uint());
 	  }
 	else if (_M_data1._M_none_of())
 	  return _M_data0._M_reduce_max_index();
diff --git a/libstdc++-v3/include/bits/simd_vec.h b/libstdc++-v3/include/bits/simd_vec.h
index 5f3bd7fd2f61..5624ec781426 100644
--- a/libstdc++-v3/include/bits/simd_vec.h
+++ b/libstdc++-v3/include/bits/simd_vec.h
@@ -1776,7 +1776,7 @@  namespace simd
 
       _DataType1 _M_data1;
 
-      static constexpr bool _S_use_bitmask = _Ap::_S_is_bitmask;
+      static constexpr bool _S_use_bitmask = _DataType0::_S_use_bitmask;
 
       static constexpr bool _S_is_partial = _DataType1::_S_is_partial;
 
diff --git a/libstdc++-v3/testsuite/std/simd/traits_impl.cc b/libstdc++-v3/testsuite/std/simd/traits_impl.cc
index 2f705c7df2f7..94c6843b6228 100644
--- a/libstdc++-v3/testsuite/std/simd/traits_impl.cc
+++ b/libstdc++-v3/testsuite/std/simd/traits_impl.cc
@@ -49,24 +49,21 @@  void test()
   static_assert(sizeof(_Bitmask<3>) == 1);
   static_assert(sizeof(_Bitmask<30>) == 4);
 
-  static_assert(__scalar_abi_tag<_ScalarAbi<1>>);
-  static_assert(__scalar_abi_tag<_ScalarAbi<2>>);
-  static_assert(!__scalar_abi_tag<_Abi_t<1, 1>>);
-
-  static_assert(__abi_tag<_ScalarAbi<1>>);
-  static_assert(__abi_tag<_ScalarAbi<2>>);
+  static_assert(__scalar_abi_tag<_Abi_t<1, 1>>);
+  static_assert(__scalar_abi_tag<_Abi_t<2, 2>>);
+  static_assert(!__scalar_abi_tag<_Abi_t<2, 1>>);
 
   using AN = decltype(__native_abi<float>());
   using A1 = decltype(__native_abi<float>()._S_resize<1>());
   static_assert(A1::_S_size == 1);
   static_assert(A1::_S_nreg == 1);
   static_assert(A1::_S_variant == AN::_S_variant);
-  static_assert(__scalar_abi_tag<A1> == __scalar_abi_tag<AN>);
   static_assert(std::is_same_v<decltype(__abi_rebind<float, AN::_S_size, A1>()), AN>);
   if constexpr (AN::_S_size >= 2) // the target has SIMD support for float
     {
       {
 	using A2 = decltype(__abi_rebind<float, 2, AN>());
+	static_assert(__scalar_abi_tag<A2> == __scalar_abi_tag<AN>);
 	static_assert(A2::_S_size == 2);
 	static_assert(A2::_S_nreg == 1);
 	static_assert(A2::_S_variant == AN::_S_variant);