diff mbox series

[v1,14/15] tcg/riscv: Implement vector roti/v/x shi ops

Message ID 20240813113436.831-15-zhiwei_liu@linux.alibaba.com
State New
Headers show
Series tcg/riscv: Add support for vector | expand

Commit Message

LIU Zhiwei Aug. 13, 2024, 11:34 a.m. UTC
From: TANG Tiancheng <tangtiancheng.ttc@alibaba-inc.com>

Signed-off-by: TANG Tiancheng <tangtiancheng.ttc@alibaba-inc.com>
Reviewed-by: Liu Zhiwei <zhiwei_liu@linux.alibaba.com>
---
 tcg/riscv/tcg-target.c.inc | 107 ++++++++++++++++++++++++++++++++++++-
 tcg/riscv/tcg-target.h     |   8 +--
 tcg/riscv/tcg-target.opc.h |   3 ++
 3 files changed, 113 insertions(+), 5 deletions(-)

Comments

Richard Henderson Aug. 14, 2024, 9:55 a.m. UTC | #1
On 8/13/24 21:34, LIU Zhiwei wrote:
> +    case INDEX_op_shli_vec:
> +        if (a2 > 31) {
> +            t2 = tcg_temp_new_i32();
> +            tcg_gen_movi_i32(t2, (int32_t)a2);
> +            tcg_gen_shls_vec(vece, v0, v1, t2);

Drop the movi, just pass tcg_constant_i32(a2) as the second source.

> +    case INDEX_op_rotls_vec:
> +        t1 = tcg_temp_new_vec(type);
> +        t2 = tcg_temp_new_i32();
> +        tcg_gen_sub_i32(t2, tcg_constant_i32(8 << vece),
> +                        temp_tcgv_i32(arg_temp(a2)));
> +        tcg_gen_shrs_vec(vece, v0, v1, t2);

Only the low lg2(SEW) bits are used; you can just tcg_gen_neg_i32.

> +    case INDEX_op_rotlv_vec:
> +        v2 = temp_tcgv_vec(arg_temp(a2));
> +        t1 = tcg_temp_new_vec(type);
> +        c1 = tcg_constant_vec(type, vece, 8 << vece);
> +        tcg_gen_sub_vec(vece, t1, c1, v2);

Likewise tcg_gen_neg_vec.

> +    case INDEX_op_rotrv_vec:
> +        v2 = temp_tcgv_vec(arg_temp(a2));
> +        t1 = tcg_temp_new_vec(type);
> +        c1 = tcg_constant_vec(type, vece, 8 << vece);
> +        tcg_gen_sub_vec(vece, t1, c1, v2);

Likewise.


r~
LIU Zhiwei Aug. 27, 2024, 7:57 a.m. UTC | #2
On 2024/8/14 17:55, Richard Henderson wrote:
> On 8/13/24 21:34, LIU Zhiwei wrote:
>> +    case INDEX_op_shli_vec:
>> +        if (a2 > 31) {
>> +            t2 = tcg_temp_new_i32();
>> +            tcg_gen_movi_i32(t2, (int32_t)a2);
>> +            tcg_gen_shls_vec(vece, v0, v1, t2);
>
> Drop the movi, just pass tcg_constant_i32(a2) as the second source.
OK.
>
>> +    case INDEX_op_rotls_vec:
>> +        t1 = tcg_temp_new_vec(type);
>> +        t2 = tcg_temp_new_i32();
>> +        tcg_gen_sub_i32(t2, tcg_constant_i32(8 << vece),
>> +                        temp_tcgv_i32(arg_temp(a2)));
>> +        tcg_gen_shrs_vec(vece, v0, v1, t2);
>
> Only the low lg2(SEW) bits are used; you can just tcg_gen_neg_i32.
Good idea.
>
>> +    case INDEX_op_rotlv_vec:
>> +        v2 = temp_tcgv_vec(arg_temp(a2));
>> +        t1 = tcg_temp_new_vec(type);
>> +        c1 = tcg_constant_vec(type, vece, 8 << vece);
>> +        tcg_gen_sub_vec(vece, t1, c1, v2);
>
> Likewise tcg_gen_neg_vec.
>
>> +    case INDEX_op_rotrv_vec:
>> +        v2 = temp_tcgv_vec(arg_temp(a2));
>> +        t1 = tcg_temp_new_vec(type);
>> +        c1 = tcg_constant_vec(type, vece, 8 << vece);
>> +        tcg_gen_sub_vec(vece, t1, c1, v2);
>
> Likewise.
Thanks,
Zhiwei
>
>
> r~
diff mbox series

Patch

diff --git a/tcg/riscv/tcg-target.c.inc b/tcg/riscv/tcg-target.c.inc
index 467437e175..59d23ed622 100644
--- a/tcg/riscv/tcg-target.c.inc
+++ b/tcg/riscv/tcg-target.c.inc
@@ -345,10 +345,13 @@  typedef enum {
     OPC_VMSGT_VX = 0x7c000057 | V_OPIVX,
 
     OPC_VSLL_VV = 0x94000057 | V_OPIVV,
+    OPC_VSLL_VI = 0x94000057 | V_OPIVI,
     OPC_VSLL_VX = 0x94000057 | V_OPIVX,
     OPC_VSRL_VV = 0xa0000057 | V_OPIVV,
+    OPC_VSRL_VI = 0xa0000057 | V_OPIVI,
     OPC_VSRL_VX = 0xa0000057 | V_OPIVX,
     OPC_VSRA_VV = 0xa4000057 | V_OPIVV,
+    OPC_VSRA_VI = 0xa4000057 | V_OPIVI,
     OPC_VSRA_VX = 0xa4000057 | V_OPIVX,
 
     OPC_VMV_V_V = 0x5e000057 | V_OPIVV,
@@ -2384,6 +2387,15 @@  static void tcg_out_vec_op(TCGContext *s, TCGOpcode opc,
     case INDEX_op_sarv_vec:
         tcg_out_opc_vv(s, OPC_VSRA_VV, a0, a1, a2, true);
         break;
+    case INDEX_op_rvv_shli_vec:
+        tcg_out_opc_vi(s, OPC_VSLL_VI, a0, a1, a2, true);
+        break;
+    case INDEX_op_rvv_shri_vec:
+        tcg_out_opc_vi(s, OPC_VSRL_VI, a0, a1, a2, true);
+        break;
+    case INDEX_op_rvv_sari_vec:
+        tcg_out_opc_vi(s, OPC_VSRA_VI, a0, a1, a2, true);
+        break;
     case INDEX_op_rvv_cmpcond_vec:
         {
             RISCVInsn op;
@@ -2422,7 +2434,8 @@  void tcg_expand_vec_op(TCGOpcode opc, TCGType type, unsigned vece,
                        TCGArg a0, ...)
 {
     va_list va;
-    TCGv_vec v0, v1;
+    TCGv_vec v0, v1, v2, c1, t1;
+    TCGv_i32 t2;
     TCGArg a2, a3;
 
     va_start(va, a0);
@@ -2442,6 +2455,81 @@  void tcg_expand_vec_op(TCGOpcode opc, TCGType type, unsigned vece,
                       tcgv_i64_arg(tcg_constant_i64(-1)));
         }
         break;
+    case INDEX_op_shli_vec:
+        if (a2 > 31) {
+            t2 = tcg_temp_new_i32();
+            tcg_gen_movi_i32(t2, (int32_t)a2);
+            tcg_gen_shls_vec(vece, v0, v1, t2);
+            tcg_temp_free_i32(t2);
+        } else {
+            vec_gen_3(INDEX_op_rvv_shli_vec, type, vece, tcgv_vec_arg(v0),
+                      tcgv_vec_arg(v1), a2);
+        }
+        break;
+    case INDEX_op_shri_vec:
+        if (a2 > 31) {
+            t2 = tcg_temp_new_i32();
+            tcg_gen_movi_i32(t2, (int32_t)a2);
+            tcg_gen_shrs_vec(vece, v0, v1, t2);
+            tcg_temp_free_i32(t2);
+        } else {
+            vec_gen_3(INDEX_op_rvv_shri_vec, type, vece, tcgv_vec_arg(v0),
+                      tcgv_vec_arg(v1), a2);
+        }
+        break;
+    case INDEX_op_sari_vec:
+        if (a2 > 31) {
+            t2 = tcg_temp_new_i32();
+            tcg_gen_movi_i32(t2, (int32_t)a2);
+            tcg_gen_sars_vec(vece, v0, v1, t2);
+            tcg_temp_free_i32(t2);
+        } else {
+            vec_gen_3(INDEX_op_rvv_sari_vec, type, vece, tcgv_vec_arg(v0),
+                      tcgv_vec_arg(v1), a2);
+        }
+        break;
+    case INDEX_op_rotli_vec:
+        t1 = tcg_temp_new_vec(type);
+        tcg_gen_shli_vec(vece, t1, v1, a2);
+        tcg_gen_shri_vec(vece, v0, v1, (8 << vece) - a2);
+        tcg_gen_or_vec(vece, v0, v0, t1);
+        tcg_temp_free_vec(t1);
+        break;
+    case INDEX_op_rotls_vec:
+        t1 = tcg_temp_new_vec(type);
+        t2 = tcg_temp_new_i32();
+        tcg_gen_sub_i32(t2, tcg_constant_i32(8 << vece),
+                        temp_tcgv_i32(arg_temp(a2)));
+        tcg_gen_shrs_vec(vece, v0, v1, t2);
+        tcg_gen_shls_vec(vece, t1, v1, temp_tcgv_i32(arg_temp(a2)));
+        tcg_gen_or_vec(vece, v0, v0, t1);
+        tcg_temp_free_vec(t1);
+        tcg_temp_free_i32(t2);
+        break;
+    case INDEX_op_rotlv_vec:
+        v2 = temp_tcgv_vec(arg_temp(a2));
+        t1 = tcg_temp_new_vec(type);
+        c1 = tcg_constant_vec(type, vece, 8 << vece);
+        tcg_gen_sub_vec(vece, t1, c1, v2);
+        vec_gen_3(INDEX_op_shrv_vec, type, vece, tcgv_vec_arg(t1),
+                  tcgv_vec_arg(v1), tcgv_vec_arg(t1));
+        vec_gen_3(INDEX_op_shlv_vec, type, vece, tcgv_vec_arg(v0),
+                  tcgv_vec_arg(v1), tcgv_vec_arg(v2));
+        tcg_gen_or_vec(vece, v0, v0, t1);
+        tcg_temp_free_vec(t1);
+        break;
+    case INDEX_op_rotrv_vec:
+        v2 = temp_tcgv_vec(arg_temp(a2));
+        t1 = tcg_temp_new_vec(type);
+        c1 = tcg_constant_vec(type, vece, 8 << vece);
+        tcg_gen_sub_vec(vece, t1, c1, v2);
+        vec_gen_3(INDEX_op_shlv_vec, type, vece, tcgv_vec_arg(t1),
+                  tcgv_vec_arg(v1), tcgv_vec_arg(t1));
+        vec_gen_3(INDEX_op_shrv_vec, type, vece, tcgv_vec_arg(v0),
+                  tcgv_vec_arg(v1), tcgv_vec_arg(v2));
+        tcg_gen_or_vec(vece, v0, v0, t1);
+        tcg_temp_free_vec(t1);
+        break;
     default:
         g_assert_not_reached();
     }
@@ -2475,6 +2563,13 @@  int tcg_can_emit_vec_op(TCGOpcode opc, TCGType type, unsigned vece)
     case INDEX_op_sarv_vec:
         return 1;
     case INDEX_op_cmp_vec:
+    case INDEX_op_shri_vec:
+    case INDEX_op_shli_vec:
+    case INDEX_op_sari_vec:
+    case INDEX_op_rotls_vec:
+    case INDEX_op_rotlv_vec:
+    case INDEX_op_rotrv_vec:
+    case INDEX_op_rotli_vec:
         return -1;
     default:
         return 0;
@@ -2628,6 +2723,13 @@  static TCGConstraintSetIndex tcg_target_op_def(TCGOpcode op)
         return C_O1_I1(v, r);
     case INDEX_op_neg_vec:
     case INDEX_op_not_vec:
+    case INDEX_op_rotli_vec:
+    case INDEX_op_shli_vec:
+    case INDEX_op_shri_vec:
+    case INDEX_op_sari_vec:
+    case INDEX_op_rvv_shli_vec:
+    case INDEX_op_rvv_shri_vec:
+    case INDEX_op_rvv_sari_vec:
         return C_O1_I1(v, v);
     case INDEX_op_add_vec:
     case INDEX_op_sub_vec:
@@ -2646,10 +2748,13 @@  static TCGConstraintSetIndex tcg_target_op_def(TCGOpcode op)
     case INDEX_op_shlv_vec:
     case INDEX_op_shrv_vec:
     case INDEX_op_sarv_vec:
+    case INDEX_op_rotlv_vec:
+    case INDEX_op_rotrv_vec:
         return C_O1_I2(v, v, v);
     case INDEX_op_shls_vec:
     case INDEX_op_shrs_vec:
     case INDEX_op_sars_vec:
+    case INDEX_op_rotls_vec:
         return C_O1_I2(v, v, r);
     case INDEX_op_cmp_vec:
     case INDEX_op_rvv_merge_vec:
diff --git a/tcg/riscv/tcg-target.h b/tcg/riscv/tcg-target.h
index 41c6c446e8..eb5129a976 100644
--- a/tcg/riscv/tcg-target.h
+++ b/tcg/riscv/tcg-target.h
@@ -154,10 +154,10 @@  typedef enum {
 #define TCG_TARGET_HAS_not_vec          1
 #define TCG_TARGET_HAS_neg_vec          1
 #define TCG_TARGET_HAS_abs_vec          0
-#define TCG_TARGET_HAS_roti_vec         0
-#define TCG_TARGET_HAS_rots_vec         0
-#define TCG_TARGET_HAS_rotv_vec         0
-#define TCG_TARGET_HAS_shi_vec          0
+#define TCG_TARGET_HAS_roti_vec         -1
+#define TCG_TARGET_HAS_rots_vec         -1
+#define TCG_TARGET_HAS_rotv_vec         -1
+#define TCG_TARGET_HAS_shi_vec          -1
 #define TCG_TARGET_HAS_shs_vec          1
 #define TCG_TARGET_HAS_shv_vec          1
 #define TCG_TARGET_HAS_mul_vec          1
diff --git a/tcg/riscv/tcg-target.opc.h b/tcg/riscv/tcg-target.opc.h
index 2f23453c35..3a010e853e 100644
--- a/tcg/riscv/tcg-target.opc.h
+++ b/tcg/riscv/tcg-target.opc.h
@@ -13,3 +13,6 @@ 
 
 DEF(rvv_cmpcond_vec, 0, 2, 1, IMPLVEC)
 DEF(rvv_merge_vec, 1, 2, 0, IMPLVEC)
+DEF(rvv_shli_vec, 1, 1, 1, IMPLVEC)
+DEF(rvv_shri_vec, 1, 1, 1, IMPLVEC)
+DEF(rvv_sari_vec, 1, 1, 1, IMPLVEC)