i386: Support partial vectorized FMA for V2BF/V4BF

Message ID 20240904033121.1895231-1-admin@levyhsu.com
State New
Headers
Series i386: Support partial vectorized FMA for V2BF/V4BF |

Checks

Context Check Description
linaro-tcwg-bot/tcwg_gcc_build--master-aarch64 fail Patch failed to apply
linaro-tcwg-bot/tcwg_gcc_build--master-arm fail Patch failed to apply

Commit Message

Levy Hsu Sept. 4, 2024, 3:23 a.m. UTC
  Hi

Bootstrapped and tested on x86_64-pc-linux-gnu.
Ok for trunk?

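This patch introduces support for vectorized FMA operations for bf16 types in
V2BF and V4BF modes on the i386 architecture. The existing VBF_32_64 mode
iterator is moved earlier in mmx.md so it is defined before its first use, and
define_expand entries for fma, fnma, fms, and fnms are added. Each expander
views its operands as the low part of a V8BF register, emits the corresponding
V8BF pattern, and moves the low part of the result into the destination.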

gcc/ChangeLog:

	* config/i386/mmx.md (VBF_32_64): Move mode iterator before its first use.
	(fma<mode>4): New define_expand for V2BF/V4BF.
	(fnma<mode>4): New define_expand for V2BF/V4BF.
	(fms<mode>4): New define_expand for V2BF/V4BF.
	(fnms<mode>4): New define_expand for V2BF/V4BF.

gcc/testsuite/ChangeLog:

	* gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c: New test.
---
 gcc/config/i386/mmx.md                        | 84 ++++++++++++++++++-
 .../i386/avx10_2-partial-bf-vector-fma-1.c    | 57 +++++++++++++
 2 files changed, 139 insertions(+), 2 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
  

Comments

Hongtao Liu Sept. 5, 2024, 1:35 a.m. UTC | #1
On Wed, Sep 4, 2024 at 11:31 AM Levy Hsu <admin@levyhsu.com> wrote:
>
> Hi
>
> Bootstrapped and tested on x86-64-pc-linux-gnu.
> Ok for trunk?
Ok.
>
> This patch introduces support for vectorized FMA operations for bf16 types in
> V2BF and V4BF modes on the i386 architecture. New mode iterators and
> define_expand entries for fma, fnma, fms, and fnms operations are added in
> mmx.md, enhancing the i386 backend to handle these complex arithmetic operations.
>
> gcc/ChangeLog:
>
>         * config/i386/mmx.md (TARGET_MMX_WITH_SSE): New mode iterator VBF_32_64
>         (fma<mode>4): define_expand for V2BF/V4BF fma<mode>4.
>         (fnma<mode>4): define_expand for V2BF/V4BF fnma<mode>4.
>         (fms<mode>4): define_expand for V2BF/V4BF fms<mode>4.
>         (fnms<mode>4): define_expand for V2BF/V4BF fnms<mode>4.
>
> gcc/testsuite/ChangeLog:
>
>         * gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c: New test.
> ---
>  gcc/config/i386/mmx.md                        | 84 ++++++++++++++++++-
>  .../i386/avx10_2-partial-bf-vector-fma-1.c    | 57 +++++++++++++
>  2 files changed, 139 insertions(+), 2 deletions(-)
>  create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
>
> diff --git a/gcc/config/i386/mmx.md b/gcc/config/i386/mmx.md
> index 10fcd2beda6..22aeb43f436 100644
> --- a/gcc/config/i386/mmx.md
> +++ b/gcc/config/i386/mmx.md
> @@ -2636,6 +2636,88 @@
>    DONE;
>  })
>
> +(define_mode_iterator VBF_32_64 [V2BF (V4BF "TARGET_MMX_WITH_SSE")])
> +
> +(define_expand "fma<mode>4"
> +  [(set (match_operand:VBF_32_64 0 "register_operand")
> +       (fma:VBF_32_64
> +         (match_operand:VBF_32_64 1 "nonimmediate_operand")
> +         (match_operand:VBF_32_64 2 "nonimmediate_operand")
> +         (match_operand:VBF_32_64 3 "nonimmediate_operand")))]
> +  "TARGET_AVX10_2_256"
> +{
> +  rtx op0 = gen_reg_rtx (V8BFmode);
> +  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
> +  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
> +  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
> +
> +  emit_insn (gen_fmav8bf4 (op0, op1, op2, op3));
> +
> +  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
> +  DONE;
> +})
> +
> +(define_expand "fms<mode>4"
> +  [(set (match_operand:VBF_32_64 0 "register_operand")
> +       (fma:VBF_32_64
> +         (match_operand:VBF_32_64   1 "nonimmediate_operand")
> +         (match_operand:VBF_32_64   2 "nonimmediate_operand")
> +         (neg:VBF_32_64
> +           (match_operand:VBF_32_64 3 "nonimmediate_operand"))))]
> +  "TARGET_AVX10_2_256"
> +{
> +  rtx op0 = gen_reg_rtx (V8BFmode);
> +  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
> +  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
> +  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
> +
> +  emit_insn (gen_fmsv8bf4 (op0, op1, op2, op3));
> +
> +  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
> +  DONE;
> +})
> +
> +(define_expand "fnma<mode>4"
> +  [(set (match_operand:VBF_32_64 0 "register_operand")
> +       (fma:VBF_32_64
> +         (neg:VBF_32_64
> +           (match_operand:VBF_32_64 1 "nonimmediate_operand"))
> +         (match_operand:VBF_32_64   2 "nonimmediate_operand")
> +         (match_operand:VBF_32_64   3 "nonimmediate_operand")))]
> +  "TARGET_AVX10_2_256"
> +{
> +  rtx op0 = gen_reg_rtx (V8BFmode);
> +  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
> +  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
> +  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
> +
> +  emit_insn (gen_fnmav8bf4 (op0, op1, op2, op3));
> +
> +  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
> +  DONE;
> +})
> +
> +(define_expand "fnms<mode>4"
> +  [(set (match_operand:VBF_32_64 0 "register_operand")
> +       (fma:VBF_32_64
> +         (neg:VBF_32_64
> +           (match_operand:VBF_32_64 1 "nonimmediate_operand"))
> +         (match_operand:VBF_32_64   2 "nonimmediate_operand")
> +         (neg:VBF_32_64
> +           (match_operand:VBF_32_64 3 "nonimmediate_operand"))))]
> +  "TARGET_AVX10_2_256"
> +{
> +  rtx op0 = gen_reg_rtx (V8BFmode);
> +  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
> +  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
> +  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
> +
> +  emit_insn (gen_fnmsv8bf4 (op0, op1, op2, op3));
> +
> +  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
> +  DONE;
> +})
> +
>  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
>  ;;
>  ;; Parallel half-precision floating point complex type operations
> @@ -6670,8 +6752,6 @@
>     (set_attr "modrm" "0")
>     (set_attr "memory" "none")])
>
> -(define_mode_iterator VBF_32_64 [V2BF (V4BF "TARGET_MMX_WITH_SSE")])
> -
>  ;; VDIVNEPBF16 does not generate floating point exceptions.
>  (define_expand "<insn><mode>3"
>    [(set (match_operand:VBF_32_64 0 "register_operand")
> diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c b/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
> new file mode 100644
> index 00000000000..72e17e99603
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
> @@ -0,0 +1,57 @@
> +/* { dg-do compile } */
> +/* { dg-options "-mavx10.2 -O2" } */
> +/* { dg-final { scan-assembler-times "vfmadd132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
> +/* { dg-final { scan-assembler-times "vfmsub132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
> +/* { dg-final { scan-assembler-times "vfnmadd132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
> +/* { dg-final { scan-assembler-times "vfnmsub132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
> +
> +typedef __bf16 v4bf __attribute__ ((__vector_size__ (8)));
> +typedef __bf16 v2bf __attribute__ ((__vector_size__ (4)));
> +
> +v4bf
> +foo_madd_64 (v4bf a, v4bf b, v4bf c)
> +{
> +  return a * b + c;
> +}
> +
> +v4bf
> +foo_msub_64 (v4bf a, v4bf b, v4bf c)
> +{
> +  return a * b - c;
> +}
> +
> +v4bf
> +foo_nmadd_64 (v4bf a, v4bf b, v4bf c)
> +{
> +  return -a * b + c;
> +}
> +
> +v4bf
> +foo_nmsub_64 (v4bf a, v4bf b, v4bf c)
> +{
> +  return -a * b - c;
> +}
> +
> +v2bf
> +foo_madd_32 (v2bf a, v2bf b, v2bf c)
> +{
> +  return a * b + c;
> +}
> +
> +v2bf
> +foo_msub_32 (v2bf a, v2bf b, v2bf c)
> +{
> +  return a * b - c;
> +}
> +
> +v2bf
> +foo_nmadd_32 (v2bf a, v2bf b, v2bf c)
> +{
> +  return -a * b + c;
> +}
> +
> +v2bf
> +foo_nmsub_32 (v2bf a, v2bf b, v2bf c)
> +{
> +  return -a * b - c;
> +}
> --
> 2.31.1
>
  

Patch

diff --git a/gcc/config/i386/mmx.md b/gcc/config/i386/mmx.md
index 10fcd2beda6..22aeb43f436 100644
--- a/gcc/config/i386/mmx.md
+++ b/gcc/config/i386/mmx.md
@@ -2636,6 +2636,88 @@ 
   DONE;
 })
 
+(define_mode_iterator VBF_32_64 [V2BF (V4BF "TARGET_MMX_WITH_SSE")])
+
+(define_expand "fma<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+	(fma:VBF_32_64
+	  (match_operand:VBF_32_64 1 "nonimmediate_operand")
+	  (match_operand:VBF_32_64 2 "nonimmediate_operand")
+	  (match_operand:VBF_32_64 3 "nonimmediate_operand")))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fmav8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
+(define_expand "fms<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+	(fma:VBF_32_64
+	  (match_operand:VBF_32_64   1 "nonimmediate_operand")
+	  (match_operand:VBF_32_64   2 "nonimmediate_operand")
+	  (neg:VBF_32_64
+	    (match_operand:VBF_32_64 3 "nonimmediate_operand"))))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fmsv8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
+(define_expand "fnma<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+	(fma:VBF_32_64
+	  (neg:VBF_32_64
+	    (match_operand:VBF_32_64 1 "nonimmediate_operand"))
+	  (match_operand:VBF_32_64   2 "nonimmediate_operand")
+	  (match_operand:VBF_32_64   3 "nonimmediate_operand")))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fnmav8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
+(define_expand "fnms<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+	(fma:VBF_32_64
+	  (neg:VBF_32_64
+	    (match_operand:VBF_32_64 1 "nonimmediate_operand"))
+	  (match_operand:VBF_32_64   2 "nonimmediate_operand")
+	  (neg:VBF_32_64
+	    (match_operand:VBF_32_64 3 "nonimmediate_operand"))))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fnmsv8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;;
 ;; Parallel half-precision floating point complex type operations
@@ -6670,8 +6752,6 @@ 
    (set_attr "modrm" "0")
    (set_attr "memory" "none")])
 
-(define_mode_iterator VBF_32_64 [V2BF (V4BF "TARGET_MMX_WITH_SSE")])
-
 ;; VDIVNEPBF16 does not generate floating point exceptions.
 (define_expand "<insn><mode>3"
   [(set (match_operand:VBF_32_64 0 "register_operand")
diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c b/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
new file mode 100644
index 00000000000..72e17e99603
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
@@ -0,0 +1,57 @@ 
+/* { dg-do compile } */
+/* { dg-options "-mavx10.2 -O2" } */
+/* { dg-final { scan-assembler-times "vfmadd132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+/* { dg-final { scan-assembler-times "vfmsub132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+/* { dg-final { scan-assembler-times "vfnmadd132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+/* { dg-final { scan-assembler-times "vfnmsub132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+
+typedef __bf16 v4bf __attribute__ ((__vector_size__ (8)));
+typedef __bf16 v2bf __attribute__ ((__vector_size__ (4)));
+
+v4bf
+foo_madd_64 (v4bf a, v4bf b, v4bf c)
+{
+  return a * b + c;
+}
+
+v4bf
+foo_msub_64 (v4bf a, v4bf b, v4bf c)
+{
+  return a * b - c;
+}
+
+v4bf
+foo_nmadd_64 (v4bf a, v4bf b, v4bf c)
+{
+  return -a * b + c;
+}
+
+v4bf
+foo_nmsub_64 (v4bf a, v4bf b, v4bf c)
+{
+  return -a * b - c;
+}
+
+v2bf
+foo_madd_32 (v2bf a, v2bf b, v2bf c)
+{
+  return a * b + c;
+}
+
+v2bf
+foo_msub_32 (v2bf a, v2bf b, v2bf c)
+{
+  return a * b - c;
+}
+
+v2bf
+foo_nmadd_32 (v2bf a, v2bf b, v2bf c)
+{
+  return -a * b + c;
+}
+
+v2bf
+foo_nmsub_32 (v2bf a, v2bf b, v2bf c)
+{
+  return -a * b - c;
+}