i386: Utilize VCOMSBF16 for BF16 Comparisons with AVX10.2
Checks
Context |
Check |
Description |
linaro-tcwg-bot/tcwg_gcc_build--master-arm |
success
|
Build passed
|
linaro-tcwg-bot/tcwg_gcc_build--master-aarch64 |
success
|
Build passed
|
linaro-tcwg-bot/tcwg_gcc_check--master-aarch64 |
success
|
Test passed
|
linaro-tcwg-bot/tcwg_gcc_check--master-arm |
success
|
Test passed
|
Commit Message
From: Levy Hsu <admin@levyhsu.com>
This patch enables the use of the VCOMSBF16 instruction from AVX10.2 for
efficient BF16 comparisons.
Bootstrapped & regtested on x86-64-pc-linux-gnu.
Ok for trunk?
gcc/ChangeLog:
* config/i386/i386-expand.cc (ix86_expand_branch): Handle BFmode
when TARGET_AVX10_2_256 is enabled.
(ix86_prepare_fp_compare_args): Use SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P.
(ix86_expand_fp_movcc): Ditto.
(ix86_expand_fp_compare): Handle BFmode under IX86_FPCMP_COMI.
* config/i386/i386.cc (ix86_multiplication_cost): Use
SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P.
(ix86_division_cost): Ditto.
(ix86_rtx_costs): Ditto.
(ix86_vector_costs::add_stmt_cost): Ditto.
* config/i386/i386.h (SSE_FLOAT_MODE_SSEMATH_OR_HF_P): Rename to ...
(SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P): ...this, and add BFmode.
* config/i386/i386.md (*cmpibf): New define_insn.
gcc/testsuite/ChangeLog:
* gcc.target/i386/avx10_2-comibf-1.c: New test.
* gcc.target/i386/avx10_2-comibf-2.c: Ditto.
---
gcc/config/i386/i386-expand.cc | 22 ++--
gcc/config/i386/i386.cc | 22 ++--
gcc/config/i386/i386.h | 7 +-
gcc/config/i386/i386.md | 33 +++--
.../gcc.target/i386/avx10_2-comibf-1.c | 40 ++++++
.../gcc.target/i386/avx10_2-comibf-2.c | 118 ++++++++++++++++++
6 files changed, 214 insertions(+), 28 deletions(-)
create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
Comments
On Fri, Nov 1, 2024 at 8:33 AM Hongyu Wang <hongyu.wang@intel.com> wrote:
>
> From: Levy Hsu <admin@levyhsu.com>
>
> This patch enables the use of the VCOMSBF16 instruction from AVX10.2 for
> efficient BF16 comparisons.
>
> Bootstrapped & regtested on x86-64-pc-linux-gnu.
> Ok for trunk?
Ok.
>
> gcc/ChangeLog:
>
> * config/i386/i386-expand.cc (ix86_expand_branch): Handle BFmode
> when TARGET_AVX10_2_256 is enabled.
> (ix86_prepare_fp_compare_args): Use SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P.
> (ix86_expand_fp_movcc): Ditto.
> (ix86_expand_fp_compare): Handle BFmode under IX86_FPCMP_COMI.
> * config/i386/i386.cc (ix86_multiplication_cost): Use
> SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P.
> (ix86_division_cost): Ditto.
> (ix86_rtx_costs): Ditto.
> (ix86_vector_costs::add_stmt_cost): Ditto.
> * config/i386/i386.h (SSE_FLOAT_MODE_SSEMATH_OR_HF_P): Rename to ...
> (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P): ...this, and add BFmode.
> * config/i386/i386.md (*cmpibf): New define_insn.
>
> gcc/testsuite/ChangeLog:
>
> * gcc.target/i386/avx10_2-comibf-1.c: New test.
> * gcc.target/i386/avx10_2-comibf-2.c: Ditto.
> ---
> gcc/config/i386/i386-expand.cc | 22 ++--
> gcc/config/i386/i386.cc | 22 ++--
> gcc/config/i386/i386.h | 7 +-
> gcc/config/i386/i386.md | 33 +++--
> .../gcc.target/i386/avx10_2-comibf-1.c | 40 ++++++
> .../gcc.target/i386/avx10_2-comibf-2.c | 118 ++++++++++++++++++
> 6 files changed, 214 insertions(+), 28 deletions(-)
> create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
> create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
>
> diff --git a/gcc/config/i386/i386-expand.cc b/gcc/config/i386/i386-expand.cc
> index 0de0e842731..96e4659da10 100644
> --- a/gcc/config/i386/i386-expand.cc
> +++ b/gcc/config/i386/i386-expand.cc
> @@ -2531,6 +2531,10 @@ ix86_expand_branch (enum rtx_code code, rtx op0, rtx op1, rtx label)
> emit_jump_insn (gen_rtx_SET (pc_rtx, tmp));
> return;
>
> + case E_BFmode:
> + gcc_assert (TARGET_AVX10_2_256 && !flag_trapping_math);
> + goto simple;
> +
> case E_DImode:
> if (TARGET_64BIT)
> goto simple;
> @@ -2797,9 +2801,9 @@ ix86_prepare_fp_compare_args (enum rtx_code code, rtx *pop0, rtx *pop1)
> bool unordered_compare = ix86_unordered_fp_compare (code);
> rtx op0 = *pop0, op1 = *pop1;
> machine_mode op_mode = GET_MODE (op0);
> - bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HF_P (op_mode);
> + bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (op_mode);
>
> - if (op_mode == BFmode)
> + if (op_mode == BFmode && (!TARGET_AVX10_2_256 || flag_trapping_math))
> {
> rtx op = gen_lowpart (HImode, op0);
> if (CONST_INT_P (op))
> @@ -2918,10 +2922,14 @@ ix86_expand_fp_compare (enum rtx_code code, rtx op0, rtx op1)
> {
> case IX86_FPCMP_COMI:
> tmp = gen_rtx_COMPARE (CCFPmode, op0, op1);
> - if (TARGET_AVX10_2_256 && (code == EQ || code == NE))
> - tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX);
> - if (unordered_compare)
> - tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP);
> + /* We only have vcomsbf16, No vcomubf16 nor vcomxbf16 */
> + if (GET_MODE (op0) != E_BFmode)
> + {
> + if (TARGET_AVX10_2_256 && (code == EQ || code == NE))
> + tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX);
> + if (unordered_compare)
> + tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP);
> + }
> cmp_mode = CCFPmode;
> emit_insn (gen_rtx_SET (gen_rtx_REG (CCFPmode, FLAGS_REG), tmp));
> break;
> @@ -4636,7 +4644,7 @@ ix86_expand_fp_movcc (rtx operands[])
> && !ix86_fp_comparison_operator (operands[1], VOIDmode))
> return false;
>
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> {
> machine_mode cmode;
>
> diff --git a/gcc/config/i386/i386.cc b/gcc/config/i386/i386.cc
> index 473e4cbf10e..6ac3a5d55f2 100644
> --- a/gcc/config/i386/i386.cc
> +++ b/gcc/config/i386/i386.cc
> @@ -21324,7 +21324,7 @@ ix86_multiplication_cost (const struct processor_costs *cost,
> if (VECTOR_MODE_P (mode))
> inner_mode = GET_MODE_INNER (mode);
>
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> return inner_mode == DFmode ? cost->mulsd : cost->mulss;
> else if (X87_FLOAT_MODE_P (mode))
> return cost->fmul;
> @@ -21449,7 +21449,7 @@ ix86_division_cost (const struct processor_costs *cost,
> if (VECTOR_MODE_P (mode))
> inner_mode = GET_MODE_INNER (mode);
>
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> return inner_mode == DFmode ? cost->divsd : cost->divss;
> else if (X87_FLOAT_MODE_P (mode))
> return cost->fdiv;
> @@ -21991,7 +21991,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
> return true;
> }
>
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> *total = cost->addss;
> else if (X87_FLOAT_MODE_P (mode))
> *total = cost->fadd;
> @@ -22198,7 +22198,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
> return false;
>
> case NEG:
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> *total = cost->sse_op;
> else if (X87_FLOAT_MODE_P (mode))
> *total = cost->fchs;
> @@ -22306,14 +22306,14 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
> return false;
>
> case FLOAT_EXTEND:
> - if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> *total = 0;
> else
> *total = ix86_vec_cost (mode, cost->addss);
> return false;
>
> case FLOAT_TRUNCATE:
> - if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> *total = cost->fadd;
> else
> *total = ix86_vec_cost (mode, cost->addss);
> @@ -22323,7 +22323,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
> /* SSE requires memory load for the constant operand. It may make
> sense to account for this. Of course the constant operand may or
> may not be reused. */
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> *total = cost->sse_op;
> else if (X87_FLOAT_MODE_P (mode))
> *total = cost->fabs;
> @@ -22334,7 +22334,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
> return false;
>
> case SQRT:
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> *total = mode == SFmode ? cost->sqrtss : cost->sqrtsd;
> else if (X87_FLOAT_MODE_P (mode))
> *total = cost->fsqrt;
> @@ -25083,7 +25083,7 @@ ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
> case MINUS_EXPR:
> if (kind == scalar_stmt)
> {
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> stmt_cost = ix86_cost->addss;
> else if (X87_FLOAT_MODE_P (mode))
> stmt_cost = ix86_cost->fadd;
> @@ -25109,7 +25109,7 @@ ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
> break;
>
> case NEGATE_EXPR:
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> stmt_cost = ix86_cost->sse_op;
> else if (X87_FLOAT_MODE_P (mode))
> stmt_cost = ix86_cost->fchs;
> @@ -25165,7 +25165,7 @@ ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
> case BIT_XOR_EXPR:
> case BIT_AND_EXPR:
> case BIT_NOT_EXPR:
> - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
> + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
> stmt_cost = ix86_cost->sse_op;
> else if (VECTOR_MODE_P (mode))
> stmt_cost = ix86_vec_cost (mode, ix86_cost->sse_op);
> diff --git a/gcc/config/i386/i386.h b/gcc/config/i386/i386.h
> index 51934400951..a4874a46dc7 100644
> --- a/gcc/config/i386/i386.h
> +++ b/gcc/config/i386/i386.h
> @@ -1158,9 +1158,10 @@ extern const char *host_detect_local_cpu (int argc, const char **argv);
> #define SSE_FLOAT_MODE_P(MODE) \
> ((TARGET_SSE && (MODE) == SFmode) || (TARGET_SSE2 && (MODE) == DFmode))
>
> -#define SSE_FLOAT_MODE_SSEMATH_OR_HF_P(MODE) \
> - ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH) \
> - || (TARGET_AVX512FP16 && (MODE) == HFmode))
> +#define SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P(MODE) \
> + ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH) \
> + || (TARGET_AVX512FP16 && (MODE) == HFmode) \
> + || (TARGET_AVX10_2_256 && (MODE) == BFmode))
>
> #define FMA4_VEC_FLOAT_MODE_P(MODE) \
> (TARGET_FMA4 && ((MODE) == V4SFmode || (MODE) == V2DFmode \
> diff --git a/gcc/config/i386/i386.md b/gcc/config/i386/i386.md
> index fb6aaa81505..11855d793a4 100644
> --- a/gcc/config/i386/i386.md
> +++ b/gcc/config/i386/i386.md
> @@ -1814,13 +1814,21 @@ (define_expand "cbranchbf4"
> (pc)))]
> "TARGET_80387 || (SSE_FLOAT_MODE_P (SFmode) && TARGET_SSE_MATH)"
> {
> - rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]);
> - rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]);
> - do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0,
> - SFmode, NULL_RTX, NULL,
> - as_a <rtx_code_label *> (operands[3]),
> - /* Unfortunately this isn't propagated. */
> - profile_probability::even ());
> + if (TARGET_AVX10_2_256 && !flag_trapping_math)
> + {
> + ix86_expand_branch (GET_CODE (operands[0]),
> + operands[1], operands[2], operands[3]);
> + }
> + else
> + {
> + rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]);
> + rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]);
> + do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0,
> + SFmode, NULL_RTX, NULL,
> + as_a <rtx_code_label *> (operands[3]),
> + /* Unfortunately this isn't propagated. */
> + profile_probability::even ());
> + }
> DONE;
> })
>
> @@ -2096,6 +2104,17 @@ (define_insn "*cmpi<unord>hf"
> (set_attr "prefix" "evex")
> (set_attr "mode" "HF")])
>
> +(define_insn "*cmpibf"
> + [(set (reg:CCFP FLAGS_REG)
> + (compare:CCFP
> + (match_operand:BF 0 "register_operand" "v")
> + (match_operand:BF 1 "nonimmediate_operand" "vm")))]
> + "TARGET_AVX10_2_256"
> + "vcomsbf16\t{%1, %0|%0, %1}"
> + [(set_attr "type" "ssecomi")
> + (set_attr "prefix" "evex")
> + (set_attr "mode" "BF")])
> +
> ;; Set carry flag.
> (define_insn "x86_stc"
> [(set (reg:CCC FLAGS_REG) (unspec:CCC [(const_int 0)] UNSPEC_STC))]
> diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
> new file mode 100644
> index 00000000000..85b773b89f2
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
> @@ -0,0 +1,40 @@
> +/* { dg-do compile } */
> +/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */
> +/* { dg-final { scan-assembler-times "vcomsbf16\[ \\t\]+\[^{}\n\]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 6 } } */
> +/* { dg-final { scan-assembler-times {j[a-z]+\s} 6 } } */
> +
> +__bf16
> +foo_eq (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
> +{
> + return a == b ? c + d : c - d;
> +}
> +
> +__bf16
> +foo_ne (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
> +{
> + return a != b ? c + d : c - d;
> +}
> +
> +__bf16
> +foo_lt (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
> +{
> + return a < b ? c + d : c - d;
> +}
> +
> +__bf16
> +foo_le (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
> +{
> + return a <= b ? c + d : c - d;
> +}
> +
> +__bf16
> +foo_gt (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
> +{
> + return a > b ? c + d : c - d;
> +}
> +
> +__bf16
> +foo_ge (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
> +{
> + return a >= b ? c + d : c - d;
> +}
> diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
> new file mode 100644
> index 00000000000..126957bf272
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
> @@ -0,0 +1,118 @@
> + /* { dg-do run } */
> +/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */
> +
> +#include <stdlib.h>
> +#include <stdint.h>
> +#include <string.h>
> +
> +/* Fast shift conversion here for convenience */
> +static __bf16
> +float_to_bf16 (float f)
> +{
> + uint32_t float_bits;
> + uint16_t bf16_bits;
> +
> + memcpy (&float_bits, &f, sizeof (float_bits));
> + bf16_bits = (uint16_t) (float_bits >> 16);
> +
> + __bf16 bf;
> + memcpy (&bf, &bf16_bits, sizeof (bf));
> + return bf;
> +}
> +
> +static float
> +bf16_to_float (__bf16 bf)
> +{
> + uint32_t float_bits;
> + uint16_t bf16_bits;
> +
> + memcpy (&bf16_bits, &bf, sizeof (bf16_bits));
> + float_bits = ((uint32_t) bf16_bits) << 16;
> +
> + float f;
> + memcpy (&f, &float_bits, sizeof (f));
> + return f;
> +}
> +
> +static void
> +test_eq (__bf16 a, __bf16 b)
> +{
> + int result = (a == b);
> + int expected = (bf16_to_float (a) == bf16_to_float (b));
> + if (result != expected)
> + abort ();
> +}
> +
> +static void
> +test_ne (__bf16 a, __bf16 b)
> +{
> + int result = (a != b);
> + int expected = (bf16_to_float (a) != bf16_to_float (b));
> + if (result != expected)
> + abort ();
> +}
> +
> +static void
> +test_lt (__bf16 a, __bf16 b)
> +{
> + int result = (a < b);
> + int expected = (bf16_to_float (a) < bf16_to_float (b));
> + if (result != expected)
> + abort ();
> +}
> +
> +static void
> +test_le (__bf16 a, __bf16 b)
> +{
> + int result = (a <= b);
> + int expected = (bf16_to_float (a) <= bf16_to_float (b));
> + if (result != expected)
> + abort ();
> +}
> +
> +static void
> +test_gt (__bf16 a, __bf16 b)
> +{
> + int result = (a > b);
> + int expected = (bf16_to_float (a) > bf16_to_float (b));
> + if (result != expected)
> + abort ();
> +}
> +
> +static void
> +test_ge (__bf16 a, __bf16 b)
> +{
> + int result = (a >= b);
> + int expected = (bf16_to_float (a) >= bf16_to_float (b));
> + if (result != expected)
> + abort ();
> +}
> +
> +int
> +main (void)
> +{
> + if (!__builtin_cpu_supports ("avx10.2"))
> + return 0;
> +
> + float test_values[] = {
> + -10.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 10.0f, 100.0f, -100.0f
> + };
> +
> + size_t num_values = sizeof (test_values) / sizeof (test_values[0]);
> +
> + for (size_t i = 0; i < num_values; i++)
> + for (size_t j = 0; j < num_values; j++)
> + {
> + __bf16 a = float_to_bf16 (test_values[i]);
> + __bf16 b = float_to_bf16 (test_values[j]);
> +
> + test_eq (a, b);
> + test_ne (a, b);
> + test_lt (a, b);
> + test_le (a, b);
> + test_gt (a, b);
> + test_ge (a, b);
> + }
> +
> + return 0;
> +}
> --
> 2.31.1
>
@@ -2531,6 +2531,10 @@ ix86_expand_branch (enum rtx_code code, rtx op0, rtx op1, rtx label)
emit_jump_insn (gen_rtx_SET (pc_rtx, tmp));
return;
+ case E_BFmode:
+ gcc_assert (TARGET_AVX10_2_256 && !flag_trapping_math);
+ goto simple;
+
case E_DImode:
if (TARGET_64BIT)
goto simple;
@@ -2797,9 +2801,9 @@ ix86_prepare_fp_compare_args (enum rtx_code code, rtx *pop0, rtx *pop1)
bool unordered_compare = ix86_unordered_fp_compare (code);
rtx op0 = *pop0, op1 = *pop1;
machine_mode op_mode = GET_MODE (op0);
- bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HF_P (op_mode);
+ bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (op_mode);
- if (op_mode == BFmode)
+ if (op_mode == BFmode && (!TARGET_AVX10_2_256 || flag_trapping_math))
{
rtx op = gen_lowpart (HImode, op0);
if (CONST_INT_P (op))
@@ -2918,10 +2922,14 @@ ix86_expand_fp_compare (enum rtx_code code, rtx op0, rtx op1)
{
case IX86_FPCMP_COMI:
tmp = gen_rtx_COMPARE (CCFPmode, op0, op1);
- if (TARGET_AVX10_2_256 && (code == EQ || code == NE))
- tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX);
- if (unordered_compare)
- tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP);
+ /* We only have vcomsbf16, No vcomubf16 nor vcomxbf16 */
+ if (GET_MODE (op0) != E_BFmode)
+ {
+ if (TARGET_AVX10_2_256 && (code == EQ || code == NE))
+ tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX);
+ if (unordered_compare)
+ tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP);
+ }
cmp_mode = CCFPmode;
emit_insn (gen_rtx_SET (gen_rtx_REG (CCFPmode, FLAGS_REG), tmp));
break;
@@ -4636,7 +4644,7 @@ ix86_expand_fp_movcc (rtx operands[])
&& !ix86_fp_comparison_operator (operands[1], VOIDmode))
return false;
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
{
machine_mode cmode;
@@ -21324,7 +21324,7 @@ ix86_multiplication_cost (const struct processor_costs *cost,
if (VECTOR_MODE_P (mode))
inner_mode = GET_MODE_INNER (mode);
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
return inner_mode == DFmode ? cost->mulsd : cost->mulss;
else if (X87_FLOAT_MODE_P (mode))
return cost->fmul;
@@ -21449,7 +21449,7 @@ ix86_division_cost (const struct processor_costs *cost,
if (VECTOR_MODE_P (mode))
inner_mode = GET_MODE_INNER (mode);
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
return inner_mode == DFmode ? cost->divsd : cost->divss;
else if (X87_FLOAT_MODE_P (mode))
return cost->fdiv;
@@ -21991,7 +21991,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
return true;
}
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
*total = cost->addss;
else if (X87_FLOAT_MODE_P (mode))
*total = cost->fadd;
@@ -22198,7 +22198,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
return false;
case NEG:
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
*total = cost->sse_op;
else if (X87_FLOAT_MODE_P (mode))
*total = cost->fchs;
@@ -22306,14 +22306,14 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
return false;
case FLOAT_EXTEND:
- if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
*total = 0;
else
*total = ix86_vec_cost (mode, cost->addss);
return false;
case FLOAT_TRUNCATE:
- if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
*total = cost->fadd;
else
*total = ix86_vec_cost (mode, cost->addss);
@@ -22323,7 +22323,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
/* SSE requires memory load for the constant operand. It may make
sense to account for this. Of course the constant operand may or
may not be reused. */
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
*total = cost->sse_op;
else if (X87_FLOAT_MODE_P (mode))
*total = cost->fabs;
@@ -22334,7 +22334,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
return false;
case SQRT:
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
*total = mode == SFmode ? cost->sqrtss : cost->sqrtsd;
else if (X87_FLOAT_MODE_P (mode))
*total = cost->fsqrt;
@@ -25083,7 +25083,7 @@ ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
case MINUS_EXPR:
if (kind == scalar_stmt)
{
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
stmt_cost = ix86_cost->addss;
else if (X87_FLOAT_MODE_P (mode))
stmt_cost = ix86_cost->fadd;
@@ -25109,7 +25109,7 @@ ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
break;
case NEGATE_EXPR:
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
stmt_cost = ix86_cost->sse_op;
else if (X87_FLOAT_MODE_P (mode))
stmt_cost = ix86_cost->fchs;
@@ -25165,7 +25165,7 @@ ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
case BIT_XOR_EXPR:
case BIT_AND_EXPR:
case BIT_NOT_EXPR:
- if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+ if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
stmt_cost = ix86_cost->sse_op;
else if (VECTOR_MODE_P (mode))
stmt_cost = ix86_vec_cost (mode, ix86_cost->sse_op);
@@ -1158,9 +1158,10 @@ extern const char *host_detect_local_cpu (int argc, const char **argv);
#define SSE_FLOAT_MODE_P(MODE) \
((TARGET_SSE && (MODE) == SFmode) || (TARGET_SSE2 && (MODE) == DFmode))
-#define SSE_FLOAT_MODE_SSEMATH_OR_HF_P(MODE) \
- ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH) \
- || (TARGET_AVX512FP16 && (MODE) == HFmode))
+#define SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P(MODE) \
+ ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH) \
+ || (TARGET_AVX512FP16 && (MODE) == HFmode) \
+ || (TARGET_AVX10_2_256 && (MODE) == BFmode))
#define FMA4_VEC_FLOAT_MODE_P(MODE) \
(TARGET_FMA4 && ((MODE) == V4SFmode || (MODE) == V2DFmode \
@@ -1814,13 +1814,21 @@ (define_expand "cbranchbf4"
(pc)))]
"TARGET_80387 || (SSE_FLOAT_MODE_P (SFmode) && TARGET_SSE_MATH)"
{
- rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]);
- rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]);
- do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0,
- SFmode, NULL_RTX, NULL,
- as_a <rtx_code_label *> (operands[3]),
- /* Unfortunately this isn't propagated. */
- profile_probability::even ());
+ if (TARGET_AVX10_2_256 && !flag_trapping_math)
+ {
+ ix86_expand_branch (GET_CODE (operands[0]),
+ operands[1], operands[2], operands[3]);
+ }
+ else
+ {
+ rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]);
+ rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]);
+ do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0,
+ SFmode, NULL_RTX, NULL,
+ as_a <rtx_code_label *> (operands[3]),
+ /* Unfortunately this isn't propagated. */
+ profile_probability::even ());
+ }
DONE;
})
@@ -2096,6 +2104,17 @@ (define_insn "*cmpi<unord>hf"
(set_attr "prefix" "evex")
(set_attr "mode" "HF")])
+(define_insn "*cmpibf"
+ [(set (reg:CCFP FLAGS_REG)
+ (compare:CCFP
+ (match_operand:BF 0 "register_operand" "v")
+ (match_operand:BF 1 "nonimmediate_operand" "vm")))]
+ "TARGET_AVX10_2_256"
+ "vcomsbf16\t{%1, %0|%0, %1}"
+ [(set_attr "type" "ssecomi")
+ (set_attr "prefix" "evex")
+ (set_attr "mode" "BF")])
+
;; Set carry flag.
(define_insn "x86_stc"
[(set (reg:CCC FLAGS_REG) (unspec:CCC [(const_int 0)] UNSPEC_STC))]
new file mode 100644
@@ -0,0 +1,40 @@
+/* { dg-do compile } */
+/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */
+/* { dg-final { scan-assembler-times "vcomsbf16\[ \\t\]+\[^{}\n\]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 6 } } */
+/* { dg-final { scan-assembler-times {j[a-z]+\s} 6 } } */
+
+__bf16
+foo_eq (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+ return a == b ? c + d : c - d;
+}
+
+__bf16
+foo_ne (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+ return a != b ? c + d : c - d;
+}
+
+__bf16
+foo_lt (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+ return a < b ? c + d : c - d;
+}
+
+__bf16
+foo_le (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+ return a <= b ? c + d : c - d;
+}
+
+__bf16
+foo_gt (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+ return a > b ? c + d : c - d;
+}
+
+__bf16
+foo_ge (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+ return a >= b ? c + d : c - d;
+}
new file mode 100644
@@ -0,0 +1,118 @@
+ /* { dg-do run } */
+/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */
+
+#include <stdlib.h>
+#include <stdint.h>
+#include <string.h>
+
+/* Fast shift conversion here for convenience */
+static __bf16
+float_to_bf16 (float f)
+{
+ uint32_t float_bits;
+ uint16_t bf16_bits;
+
+ memcpy (&float_bits, &f, sizeof (float_bits));
+ bf16_bits = (uint16_t) (float_bits >> 16);
+
+ __bf16 bf;
+ memcpy (&bf, &bf16_bits, sizeof (bf));
+ return bf;
+}
+
+static float
+bf16_to_float (__bf16 bf)
+{
+ uint32_t float_bits;
+ uint16_t bf16_bits;
+
+ memcpy (&bf16_bits, &bf, sizeof (bf16_bits));
+ float_bits = ((uint32_t) bf16_bits) << 16;
+
+ float f;
+ memcpy (&f, &float_bits, sizeof (f));
+ return f;
+}
+
+static void
+test_eq (__bf16 a, __bf16 b)
+{
+ int result = (a == b);
+ int expected = (bf16_to_float (a) == bf16_to_float (b));
+ if (result != expected)
+ abort ();
+}
+
+static void
+test_ne (__bf16 a, __bf16 b)
+{
+ int result = (a != b);
+ int expected = (bf16_to_float (a) != bf16_to_float (b));
+ if (result != expected)
+ abort ();
+}
+
+static void
+test_lt (__bf16 a, __bf16 b)
+{
+ int result = (a < b);
+ int expected = (bf16_to_float (a) < bf16_to_float (b));
+ if (result != expected)
+ abort ();
+}
+
+static void
+test_le (__bf16 a, __bf16 b)
+{
+ int result = (a <= b);
+ int expected = (bf16_to_float (a) <= bf16_to_float (b));
+ if (result != expected)
+ abort ();
+}
+
+static void
+test_gt (__bf16 a, __bf16 b)
+{
+ int result = (a > b);
+ int expected = (bf16_to_float (a) > bf16_to_float (b));
+ if (result != expected)
+ abort ();
+}
+
+static void
+test_ge (__bf16 a, __bf16 b)
+{
+ int result = (a >= b);
+ int expected = (bf16_to_float (a) >= bf16_to_float (b));
+ if (result != expected)
+ abort ();
+}
+
+int
+main (void)
+{
+ if (!__builtin_cpu_supports ("avx10.2"))
+ return 0;
+
+ float test_values[] = {
+ -10.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 10.0f, 100.0f, -100.0f
+ };
+
+ size_t num_values = sizeof (test_values) / sizeof (test_values[0]);
+
+ for (size_t i = 0; i < num_values; i++)
+ for (size_t j = 0; j < num_values; j++)
+ {
+ __bf16 a = float_to_bf16 (test_values[i]);
+ __bf16 b = float_to_bf16 (test_values[j]);
+
+ test_eq (a, b);
+ test_ne (a, b);
+ test_lt (a, b);
+ test_le (a, b);
+ test_gt (a, b);
+ test_ge (a, b);
+ }
+
+ return 0;
+}