diff mbox series

[v2,05/14] tcg/riscv: Implement vector load/store

Message ID 20240830061607.1940-6-zhiwei_liu@linux.alibaba.com
State New
Headers show
Series tcg/riscv: Add support for vector | expand

Commit Message

LIU Zhiwei Aug. 30, 2024, 6:15 a.m. UTC
From: TANG Tiancheng <tangtiancheng.ttc@alibaba-inc.com>

Signed-off-by: TANG Tiancheng <tangtiancheng.ttc@alibaba-inc.com>
Reviewed-by: Liu Zhiwei <zhiwei_liu@linux.alibaba.com>
---
 tcg/riscv/tcg-target-con-set.h |   2 +
 tcg/riscv/tcg-target.c.inc     | 169 ++++++++++++++++++++++++++++++++-
 2 files changed, 167 insertions(+), 4 deletions(-)

Comments

Richard Henderson Sept. 2, 2024, 1:31 a.m. UTC | #1
On 8/30/24 16:15, LIU Zhiwei wrote:
> @@ -799,6 +834,17 @@ static void tcg_out_ldst(TCGContext *s, RISCVInsn opc, TCGReg data,
>       case OPC_SD:
>           tcg_out_opc_store(s, opc, addr, data, imm12);
>           break;
> +    case OPC_VSE8_V:
> +    case OPC_VSE16_V:
> +    case OPC_VSE32_V:
> +    case OPC_VSE64_V:
> +    case OPC_VS1R_V:
> +    case OPC_VS2R_V:
> +    case OPC_VS4R_V:
> +    case OPC_VS8R_V:
> +        tcg_out_opc_imm(s, OPC_ADDI, TCG_REG_TMP0, addr, imm12);
> +        tcg_out_opc_ldst_vec(s, opc, data, TCG_REG_TMP0, true);
> +        break;
>       case OPC_LB:
>       case OPC_LBU:
>       case OPC_LH:

I think you shouldn't try to handle vector load/store in this same function.
You'll want something like

     if (offset != 0) {
         if (offset == sextreg(offset, 12)) {
             tcg_out_opc_imm(s, OPC_ADDI, TCG_REG_TMP0, addr, offset);
         } else {
             tcg_out_movi(s, TCG_REG_TMP0, offset);
             tcg_out_opc_reg(s, TCG_REG_TMP0, TCG_REG_TMP0, addr);
         }
         addr = TCG_REG_TMP0;
     }

at the top, instead of the imm12 split currently at the top of tcg_out_ldst.


r~
diff mbox series

Patch

diff --git a/tcg/riscv/tcg-target-con-set.h b/tcg/riscv/tcg-target-con-set.h
index aac5ceee2b..d73a62b0f2 100644
--- a/tcg/riscv/tcg-target-con-set.h
+++ b/tcg/riscv/tcg-target-con-set.h
@@ -21,3 +21,5 @@  C_O1_I2(r, rZ, rZ)
 C_N1_I2(r, r, rM)
 C_O1_I4(r, r, rI, rM, rM)
 C_O2_I4(r, r, rZ, rZ, rM, rM)
+C_O0_I2(v, r)
+C_O1_I1(v, r)
diff --git a/tcg/riscv/tcg-target.c.inc b/tcg/riscv/tcg-target.c.inc
index 49d01b8775..6f8814564a 100644
--- a/tcg/riscv/tcg-target.c.inc
+++ b/tcg/riscv/tcg-target.c.inc
@@ -178,8 +178,11 @@  static bool tcg_target_const_match(int64_t val, int ct,
 #define V_OPMVX (0x6 << 12)
 #define V_OPCFG (0x7 << 12)
 
-#define V_SUMOP (0x0 << 20)
-#define V_LUMOP (0x0 << 20)
+#define V_UNIT_STRIDE (0x0 << 20)
+#define V_UNIT_STRIDE_WHOLE_REG (0x8 << 20)
+
+/* NF <= 7 && BNF >= 0 */
+#define V_NF(x) (x << 29)
 
 typedef enum {
     OPC_ADD = 0x33,
@@ -281,6 +284,25 @@  typedef enum {
     OPC_VSETVLI  = 0x57 | V_OPCFG,
     OPC_VSETIVLI = 0xc0000057 | V_OPCFG,
     OPC_VSETVL   = 0x80000057 | V_OPCFG,
+
+    OPC_VLE8_V  = 0x7 | V_UNIT_STRIDE,
+    OPC_VLE16_V = 0x5007 | V_UNIT_STRIDE,
+    OPC_VLE32_V = 0x6007 | V_UNIT_STRIDE,
+    OPC_VLE64_V = 0x7007 | V_UNIT_STRIDE,
+    OPC_VSE8_V  = 0x27 | V_UNIT_STRIDE,
+    OPC_VSE16_V = 0x5027 | V_UNIT_STRIDE,
+    OPC_VSE32_V = 0x6027 | V_UNIT_STRIDE,
+    OPC_VSE64_V = 0x7027 | V_UNIT_STRIDE,
+
+    OPC_VL1RE64_V = 0x2007007 | V_UNIT_STRIDE_WHOLE_REG | V_NF(0),
+    OPC_VL2RE64_V = 0x2007007 | V_UNIT_STRIDE_WHOLE_REG | V_NF(1),
+    OPC_VL4RE64_V = 0x2007007 | V_UNIT_STRIDE_WHOLE_REG | V_NF(3),
+    OPC_VL8RE64_V = 0x2007007 | V_UNIT_STRIDE_WHOLE_REG | V_NF(7),
+
+    OPC_VS1R_V = 0x2000027 | V_UNIT_STRIDE_WHOLE_REG | V_NF(0),
+    OPC_VS2R_V = 0x2000027 | V_UNIT_STRIDE_WHOLE_REG | V_NF(1),
+    OPC_VS4R_V = 0x2000027 | V_UNIT_STRIDE_WHOLE_REG | V_NF(3),
+    OPC_VS8R_V = 0x2000027 | V_UNIT_STRIDE_WHOLE_REG | V_NF(7),
 } RISCVInsn;
 
 /*
@@ -607,6 +629,19 @@  static void tcg_target_set_vec_config(TCGContext *s, TCGType type,
     }
 }
 
+static int riscv_set_vec_config_vl(TCGContext *s, TCGType type)
+{
+    int prev_vsew = prev_vtypei < 0 ? MO_8 : ((prev_vtypei >> 3) & 0x7);
+    tcg_target_set_vec_config(s, type, prev_vsew);
+    return prev_vsew;
+}
+
+static void riscv_set_vec_config_vl_vece(TCGContext *s, TCGType type,
+                                         unsigned vece)
+{
+    tcg_target_set_vec_config(s, type, vece);
+}
+
 /*
  * TCG intrinsics
  */
@@ -799,6 +834,17 @@  static void tcg_out_ldst(TCGContext *s, RISCVInsn opc, TCGReg data,
     case OPC_SD:
         tcg_out_opc_store(s, opc, addr, data, imm12);
         break;
+    case OPC_VSE8_V:
+    case OPC_VSE16_V:
+    case OPC_VSE32_V:
+    case OPC_VSE64_V:
+    case OPC_VS1R_V:
+    case OPC_VS2R_V:
+    case OPC_VS4R_V:
+    case OPC_VS8R_V:
+        tcg_out_opc_imm(s, OPC_ADDI, TCG_REG_TMP0, addr, imm12);
+        tcg_out_opc_ldst_vec(s, opc, data, TCG_REG_TMP0, true);
+        break;
     case OPC_LB:
     case OPC_LBU:
     case OPC_LH:
@@ -808,6 +854,17 @@  static void tcg_out_ldst(TCGContext *s, RISCVInsn opc, TCGReg data,
     case OPC_LD:
         tcg_out_opc_imm(s, opc, data, addr, imm12);
         break;
+    case OPC_VLE8_V:
+    case OPC_VLE16_V:
+    case OPC_VLE32_V:
+    case OPC_VLE64_V:
+    case OPC_VL1RE64_V:
+    case OPC_VL2RE64_V:
+    case OPC_VL4RE64_V:
+    case OPC_VL8RE64_V:
+        tcg_out_opc_imm(s, OPC_ADDI, TCG_REG_TMP0, addr, imm12);
+        tcg_out_opc_ldst_vec(s, opc, data, TCG_REG_TMP0, true);
+        break;
     default:
         g_assert_not_reached();
     }
@@ -816,14 +873,101 @@  static void tcg_out_ldst(TCGContext *s, RISCVInsn opc, TCGReg data,
 static void tcg_out_ld(TCGContext *s, TCGType type, TCGReg arg,
                        TCGReg arg1, intptr_t arg2)
 {
-    RISCVInsn insn = type == TCG_TYPE_I32 ? OPC_LW : OPC_LD;
+    RISCVInsn insn;
+
+    if (type < TCG_TYPE_V64) {
+        insn = (type == TCG_TYPE_I32) ? OPC_LW : OPC_LD;
+    } else {
+        int nf = get_vec_type_bytes(type) / riscv_vlenb;
+
+        switch (nf) {
+        case 1:
+            insn = OPC_VL1RE64_V;
+            break;
+        case 2:
+            insn = OPC_VL2RE64_V;
+            break;
+        case 4:
+            insn = OPC_VL4RE64_V;
+            break;
+        case 8:
+            insn = OPC_VL8RE64_V;
+            break;
+        default:
+            {
+                int prev_vsew = riscv_set_vec_config_vl(s, type);
+
+                switch (prev_vsew) {
+                case MO_8:
+                    insn = OPC_VLE8_V;
+                    break;
+                case MO_16:
+                    insn = OPC_VLE16_V;
+                    break;
+                case MO_32:
+                    insn = OPC_VLE32_V;
+                    break;
+                case MO_64:
+                    insn = OPC_VLE64_V;
+                    break;
+                default:
+                    g_assert_not_reached();
+                }
+            }
+            break;
+        }
+    }
     tcg_out_ldst(s, insn, arg, arg1, arg2);
 }
 
 static void tcg_out_st(TCGContext *s, TCGType type, TCGReg arg,
                        TCGReg arg1, intptr_t arg2)
 {
-    RISCVInsn insn = type == TCG_TYPE_I32 ? OPC_SW : OPC_SD;
+    RISCVInsn insn;
+
+    if (type < TCG_TYPE_V64) {
+        insn = (type == TCG_TYPE_I32) ? OPC_SW : OPC_SD;
+        tcg_out_ldst(s, insn, arg, arg1, arg2);
+    } else {
+        int nf = get_vec_type_bytes(type) / riscv_vlenb;
+
+        switch (nf) {
+        case 1:
+            insn = OPC_VS1R_V;
+            break;
+        case 2:
+            insn = OPC_VS2R_V;
+            break;
+        case 4:
+            insn = OPC_VS4R_V;
+            break;
+        case 8:
+            insn = OPC_VS8R_V;
+            break;
+        default:
+            {
+                int prev_vsew = riscv_set_vec_config_vl(s, type);
+
+                switch (prev_vsew) {
+                case MO_8:
+                    insn = OPC_VSE8_V;
+                    break;
+                case MO_16:
+                    insn = OPC_VSE16_V;
+                    break;
+                case MO_32:
+                    insn = OPC_VSE32_V;
+                    break;
+                case MO_64:
+                    insn = OPC_VSE64_V;
+                    break;
+                default:
+                    g_assert_not_reached();
+                }
+            }
+            break;
+        }
+    }
     tcg_out_ldst(s, insn, arg, arg1, arg2);
 }
 
@@ -2018,7 +2162,20 @@  static void tcg_out_vec_op(TCGContext *s, TCGOpcode opc,
                            const TCGArg args[TCG_MAX_OP_ARGS],
                            const int const_args[TCG_MAX_OP_ARGS])
 {
+    TCGType type = vecl + TCG_TYPE_V64;
+    TCGArg a0, a1, a2;
+
+    a0 = args[0];
+    a1 = args[1];
+    a2 = args[2];
+
     switch (opc) {
+    case INDEX_op_ld_vec:
+        tcg_out_ld(s, type, a0, a1, a2);
+        break;
+    case INDEX_op_st_vec:
+        tcg_out_st(s, type, a0, a1, a2);
+        break;
     case INDEX_op_mov_vec: /* Always emitted via tcg_out_mov.  */
     case INDEX_op_dup_vec: /* Always emitted via tcg_out_dup_vec.  */
     default:
@@ -2182,6 +2339,10 @@  static TCGConstraintSetIndex tcg_target_op_def(TCGOpcode op)
     case INDEX_op_qemu_st_a64_i64:
         return C_O0_I2(rZ, r);
 
+    case INDEX_op_st_vec:
+        return C_O0_I2(v, r);
+    case INDEX_op_ld_vec:
+        return C_O1_I1(v, r);
     default:
         g_assert_not_reached();
     }