diff mbox series

[1/2] middle-end Update the complex numbers auto-vec detection to the new format of the SLP tree.

Message ID patch-14986-tamar@arm.com
State New
Headers show
Series [1/2] middle-end Update the complex numbers auto-vec detection to the new format of the SLP tree. | expand

Commit Message

Tamar Christina Oct. 29, 2021, 11:06 a.m. UTC
Hi All,

The layout of the SLP tree changed in GCC 12, which
broke the detection of complex FMA and FMS.

This patch updates the detection to the new tree shape
and by necessity merges the complex MUL and FMA detection
into one.

This does not yet address the wrong code-gen PR, which I
will fix in a separate patch because that fix needs backporting.

Regtested on aarch64-none-linux-gnu and
x86_64-pc-linux-gnu with no regressions.

Ok for master?

Thanks,
Tamar

gcc/ChangeLog:

	PR tree-optimization/102977
	* tree-vect-slp-patterns.c (vect_match_call_p): Remove.
	(vect_detect_pair_op): Add crosslane check.
	(vect_match_call_complex_mla): Remove.
	(class complex_mul_pattern): Update comment.
	(complex_mul_pattern::matches): Update detection.
	(class complex_fma_pattern): Remove.
	(complex_fma_pattern::matches): Remove.
	(complex_fma_pattern::recognize): Remove.
	(complex_fma_pattern::build): Remove.
	(class complex_fms_pattern): Update comment.
	(complex_fms_pattern::matches): Remove.
	(complex_operations_pattern::recognize): Remove complex_fma_pattern.

--- inline copy of patch -- 
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index b8d09b7832e29689ede832d555e1b6af2c24ce1e..99dea82aba91a333500bb5ff35bf30b6416c09ca 100644


--

Comments

Richard Biener Oct. 29, 2021, 11:26 a.m. UTC | #1
On Fri, 29 Oct 2021, Tamar Christina wrote:

> Hi All,
> 
> The layout of the SLP tree has changed in GCC 12 which
> broke the detection of complex FMA and FMS.
> 
> This patch updates the detection to the new tree shape
> and by necessity merges the complex MUL and FMA detection
> into one.
> 
> This does not yet address the wrong code-gen PR which I
> will fix in a different patch as that needs backporting.
> 
> Regtested on aarch64-none-linux-gnu,
> x86_64-pc-linux-gnu and no regressions.
> 
> Ok for master?

OK.

Thanks,
Richard.

> Thanks,
> Tamar
> 
> gcc/ChangeLog:
> 
> 	PR tree-optimization/102977
> 	* tree-vect-slp-patterns.c (vect_match_call_p): Remove.
> 	(vect_detect_pair_op): Add crosslane check.
> 	(vect_match_call_complex_mla): Remove.
> 	(class complex_mul_pattern): Update comment.
> 	(complex_mul_pattern::matches): Update detection.
> 	(class complex_fma_pattern): Remove.
> 	(complex_fma_pattern::matches): Remove.
> 	(complex_fma_pattern::recognize): Remove.
> 	(complex_fma_pattern::build): Remove.
> 	(class complex_fms_pattern):  Update comment.
> 	(complex_fms_pattern::matches): Remove.
> 	(complex_operations_pattern::recognize): Remove complex_fma_pattern
> 
> --- inline copy of patch -- 
> diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
> index b8d09b7832e29689ede832d555e1b6af2c24ce1e..99dea82aba91a333500bb5ff35bf30b6416c09ca 100644
> --- a/gcc/tree-vect-slp-patterns.c
> +++ b/gcc/tree-vect-slp-patterns.c
> @@ -306,24 +306,6 @@ vect_match_expression_p (slp_tree node, tree_code code)
>    return true;
>  }
>  
> -/* Checks to see if the expression represented by NODE is a call to the internal
> -   function FN.  */
> -
> -static inline bool
> -vect_match_call_p (slp_tree node, internal_fn fn)
> -{
> -  if (!node
> -      || !SLP_TREE_REPRESENTATIVE (node))
> -    return false;
> -
> -  gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node));
> -  if (!expr
> -      || !gimple_call_internal_p (expr, fn))
> -    return false;
> -
> -   return true;
> -}
> -
>  /* Check if the given lane permute in PERMUTES matches an alternating sequence
>     of {even odd even odd ...}.  This to account for unrolled loops.  Further
>     mode there resulting permute must be linear.   */
> @@ -389,6 +371,16 @@ vect_detect_pair_op (slp_tree node1, slp_tree node2, lane_permutation_t &lanes,
>  
>    if (result != CMPLX_NONE && ops != NULL)
>      {
> +      if (two_operands)
> +	{
> +	  auto l0node = SLP_TREE_CHILDREN (node1);
> +	  auto l1node = SLP_TREE_CHILDREN (node2);
> +
> +	  /* Check if the tree is connected as we expect it.  */
> +	  if (!((l0node[0] == l1node[0] && l0node[1] == l1node[1])
> +	      || (l0node[0] == l1node[1] && l0node[1] == l1node[0])))
> +	    return CMPLX_NONE;
> +	}
>        ops->safe_push (node1);
>        ops->safe_push (node2);
>      }
> @@ -717,27 +709,6 @@ complex_add_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
>   * complex_mul_pattern
>   ******************************************************************************/
>  
> -/* Helper function of that looks for a match in the CHILDth child of NODE.  The
> -   child used is stored in RES.
> -
> -   If the match is successful then ARGS will contain the operands matched
> -   and the complex_operation_t type is returned.  If match is not successful
> -   then CMPLX_NONE is returned and ARGS is left unmodified.  */
> -
> -static inline complex_operation_t
> -vect_match_call_complex_mla (slp_tree node, unsigned child,
> -			     vec<slp_tree> *args = NULL, slp_tree *res = NULL)
> -{
> -  gcc_assert (child < SLP_TREE_CHILDREN (node).length ());
> -
> -  slp_tree data = SLP_TREE_CHILDREN (node)[child];
> -
> -  if (res)
> -    *res = data;
> -
> -  return vect_detect_pair_op (data, false, args);
> -}
> -
>  /* Check to see if either of the trees in ARGS are a NEGATE_EXPR.  If the first
>     child (args[0]) is a NEGATE_EXPR then NEG_FIRST_P is set to TRUE.
>  
> @@ -945,9 +916,10 @@ class complex_mul_pattern : public complex_pattern
>  
>  };
>  
> -/* Pattern matcher for trying to match complex multiply pattern in SLP tree
> -   If the operation matches then IFN is set to the operation it matched
> -   and the arguments to the two replacement statements are put in m_ops.
> +/* Pattern matcher for trying to match complex multiply and complex multiply
> +   and accumulate pattern in SLP tree.  If the operation matches then IFN
> +   is set to the operation it matched and the arguments to the two
> +   replacement statements are put in m_ops.
>  
>     If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
>  
> @@ -972,19 +944,43 @@ complex_mul_pattern::matches (complex_operation_t op,
>    if (op != MINUS_PLUS)
>      return IFN_LAST;
>  
> -  slp_tree root = *node;
> -  /* First two nodes must be a multiply.  */
> -  auto_vec<slp_tree> muls;
> -  if (vect_match_call_complex_mla (root, 0) != MULT_MULT
> -      || vect_match_call_complex_mla (root, 1, &muls) != MULT_MULT)
> +  auto childs = *ops;
> +  auto l0node = SLP_TREE_CHILDREN (childs[0]);
> +  auto l1node = SLP_TREE_CHILDREN (childs[1]);
> +
> +  bool mul0 = vect_match_expression_p (l0node[0], MULT_EXPR);
> +  bool mul1 = vect_match_expression_p (l0node[1], MULT_EXPR);
> +  if (!mul0 && !mul1)
>      return IFN_LAST;
>  
>    /* Now operand2+4 may lead to another expression.  */
>    auto_vec<slp_tree> left_op, right_op;
> -  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
> -  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
> +  slp_tree add0 = NULL;
> +
> +  /* Check if we may be a multiply add.  */
> +  if (!mul0
> +      && vect_match_expression_p (l0node[0], PLUS_EXPR))
> +    {
> +      auto vals = SLP_TREE_CHILDREN (l0node[0]);
> +      /* Check if it's a multiply, otherwise no idea what this is.  */
> +      if (!vect_match_expression_p (vals[1], MULT_EXPR))
> +	return IFN_LAST;
> +
> +      /* Check if the ADD is linear, otherwise it's not valid complex FMA.  */
> +      if (linear_loads_p (perm_cache, vals[0]) != PERM_EVENODD)
> +	return IFN_LAST;
>  
> -  if (linear_loads_p (perm_cache, left_op[1]) == PERM_ODDEVEN)
> +      left_op.safe_splice (SLP_TREE_CHILDREN (vals[1]));
> +      add0 = vals[0];
> +    }
> +  else
> +    left_op.safe_splice (SLP_TREE_CHILDREN (l0node[0]));
> +
> +  right_op.safe_splice (SLP_TREE_CHILDREN (l0node[1]));
> +
> +  if (left_op.length () != 2
> +      || right_op.length () != 2
> +      || linear_loads_p (perm_cache, left_op[1]) == PERM_ODDEVEN)
>      return IFN_LAST;
>  
>    bool neg_first = false;
> @@ -998,23 +994,32 @@ complex_mul_pattern::matches (complex_operation_t op,
>        if (!vect_validate_multiplication (perm_cache, left_op, PERM_EVENEVEN)
>  	  || vect_normalize_conj_loc (left_op))
>  	return IFN_LAST;
> -      ifn = IFN_COMPLEX_MUL;
> +      if (!mul0)
> +	ifn = IFN_COMPLEX_FMA;
> +      else
> +	ifn = IFN_COMPLEX_MUL;
>      }
> -  else if (is_neg)
> +  else
>      {
>        if (!vect_validate_multiplication (perm_cache, left_op, right_op,
>  					 neg_first, &conj_first_operand,
>  					 false))
>  	return IFN_LAST;
>  
> -      ifn = IFN_COMPLEX_MUL_CONJ;
> +      if(!mul0)
> +	ifn = IFN_COMPLEX_FMA_CONJ;
> +      else
> +	ifn = IFN_COMPLEX_MUL_CONJ;
>      }
>  
>    if (!vect_pattern_validate_optab (ifn, *node))
>      return IFN_LAST;
>  
>    ops->truncate (0);
> -  ops->create (3);
> +  ops->create (add0 ? 4 : 3);
> +
> +  if (add0)
> +    ops->quick_push (add0);
>  
>    complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]);
>    if (kind == PERM_EVENODD)
> @@ -1070,170 +1075,55 @@ complex_mul_pattern::build (vec_info *vinfo)
>  {
>    slp_tree node;
>    unsigned i;
> -  slp_tree newnode
> -    = vect_build_combine_node (this->m_ops[0], this->m_ops[1], *this->m_node);
> -  SLP_TREE_REF_COUNT (this->m_ops[2])++;
> -
> -  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
> -    vect_free_slp_tree (node);
> -
> -  /* First re-arrange the children.  */
> -  SLP_TREE_CHILDREN (*this->m_node).reserve_exact (2);
> -  SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[2];
> -  SLP_TREE_CHILDREN (*this->m_node)[1] = newnode;
> +  switch (this->m_ifn)
> +  {
> +    case IFN_COMPLEX_MUL:
> +    case IFN_COMPLEX_MUL_CONJ:
> +      {
> +	slp_tree newnode
> +	  = vect_build_combine_node (this->m_ops[0], this->m_ops[1],
> +				     *this->m_node);
> +	SLP_TREE_REF_COUNT (this->m_ops[2])++;
> +
> +	FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
> +	  vect_free_slp_tree (node);
> +
> +	/* First re-arrange the children.  */
> +	SLP_TREE_CHILDREN (*this->m_node).reserve_exact (2);
> +	SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[2];
> +	SLP_TREE_CHILDREN (*this->m_node)[1] = newnode;
> +	break;
> +      }
> +    case IFN_COMPLEX_FMA:
> +    case IFN_COMPLEX_FMA_CONJ:
> +      {
> +	SLP_TREE_REF_COUNT (this->m_ops[0])++;
> +	slp_tree newnode
> +	  = vect_build_combine_node (this->m_ops[1], this->m_ops[2],
> +				     *this->m_node);
> +	SLP_TREE_REF_COUNT (this->m_ops[3])++;
> +
> +	FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
> +	  vect_free_slp_tree (node);
> +
> +	/* First re-arrange the children.  */
> +	SLP_TREE_CHILDREN (*this->m_node).safe_grow (3);
> +	SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[0];
> +	SLP_TREE_CHILDREN (*this->m_node)[1] = this->m_ops[3];
> +	SLP_TREE_CHILDREN (*this->m_node)[2] = newnode;
> +
> +	/* Tell the builder to expect an extra argument.  */
> +	this->m_num_args++;
> +	break;
> +      }
> +    default:
> +      gcc_unreachable ();
> +  }
>  
>    /* And then rewrite the node itself.  */
>    complex_pattern::build (vinfo);
>  }
>  
> -/*******************************************************************************
> - * complex_fma_pattern class
> - ******************************************************************************/
> -
> -class complex_fma_pattern : public complex_pattern
> -{
> -  protected:
> -    complex_fma_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
> -      : complex_pattern (node, m_ops, ifn)
> -    {
> -      this->m_num_args = 3;
> -    }
> -
> -  public:
> -    void build (vec_info *);
> -    static internal_fn
> -    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
> -	     vec<slp_tree> *);
> -
> -    static vect_pattern*
> -    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
> -
> -    static vect_pattern*
> -    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
> -    {
> -      return new complex_fma_pattern (node, m_ops, ifn);
> -    }
> -};
> -
> -/* Pattern matcher for trying to match complex multiply and accumulate
> -   and multiply and subtract patterns in SLP tree.
> -   If the operation matches then IFN is set to the operation it matched and
> -   the arguments to the two replacement statements are put in m_ops.
> -
> -   If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
> -
> -   This function matches the patterns shaped as:
> -
> -   double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
> -   double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
> -
> -   c[i] = c[i] - ax;
> -   c[i+1] = c[i+1] + bx;
> -
> -   If a match occurred then TRUE is returned, else FALSE.  The match is
> -   performed after COMPLEX_MUL which would have done the majority of the work.
> -   This function merely matches an ADD with a COMPLEX_MUL IFN.  The initial
> -   match is expected to be in OP1 and the initial match operands in args0.  */
> -
> -internal_fn
> -complex_fma_pattern::matches (complex_operation_t op,
> -			      slp_tree_to_load_perm_map_t * /* perm_cache */,
> -			      slp_tree *ref_node, vec<slp_tree> *ops)
> -{
> -  internal_fn ifn = IFN_LAST;
> -
> -  /* Find the two components.  We match Complex MUL first which reduces the
> -     amount of work this pattern has to do.  After that we just match the
> -     head node and we're done.:
> -
> -     * FMA: + +.
> -
> -     We need to ignore the two_operands nodes that may also match.
> -     For that we can check if they have any scalar statements and also
> -     check that it's not a permute node as we're looking for a normal
> -     PLUS_EXPR operation.  */
> -  if (op != CMPLX_NONE)
> -    return IFN_LAST;
> -
> -  /* Find the two components.  We match Complex MUL first which reduces the
> -     amount of work this pattern has to do.  After that we just match the
> -     head node and we're done.:
> -
> -   * FMA: + + on a non-two_operands node.  */
> -  slp_tree vnode = *ref_node;
> -  if (SLP_TREE_LANE_PERMUTATION (vnode).exists ()
> -      || !SLP_TREE_CHILDREN (vnode).exists ()
> -      || !vect_match_expression_p (vnode, PLUS_EXPR))
> -    return IFN_LAST;
> -
> -  slp_tree node = SLP_TREE_CHILDREN (vnode)[1];
> -
> -  if (vect_match_call_p (node, IFN_COMPLEX_MUL))
> -    ifn = IFN_COMPLEX_FMA;
> -  else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ))
> -    ifn = IFN_COMPLEX_FMA_CONJ;
> -  else
> -    return IFN_LAST;
> -
> -  if (!vect_pattern_validate_optab (ifn, vnode))
> -    return IFN_LAST;
> -
> -  ops->truncate (0);
> -  ops->create (3);
> -
> -  if (ifn == IFN_COMPLEX_FMA)
> -    {
> -      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
> -      ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
> -      ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
> -    }
> -  else
> -    {
> -      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
> -      ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
> -      ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
> -    }
> -
> -  return ifn;
> -}
> -
> -/* Attempt to recognize a complex mul pattern.  */
> -
> -vect_pattern*
> -complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
> -				slp_tree *node)
> -{
> -  auto_vec<slp_tree> ops;
> -  complex_operation_t op
> -    = vect_detect_pair_op (*node, true, &ops);
> -  internal_fn ifn
> -    = complex_fma_pattern::matches (op, perm_cache, node, &ops);
> -  if (ifn == IFN_LAST)
> -    return NULL;
> -
> -  return new complex_fma_pattern (node, &ops, ifn);
> -}
> -
> -/* Perform a replacement of the detected complex mul pattern with the new
> -   instruction sequences.  */
> -
> -void
> -complex_fma_pattern::build (vec_info *vinfo)
> -{
> -  slp_tree node = SLP_TREE_CHILDREN (*this->m_node)[1];
> -
> -  SLP_TREE_CHILDREN (*this->m_node).release ();
> -  SLP_TREE_CHILDREN (*this->m_node).create (3);
> -  SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops);
> -
> -  SLP_TREE_REF_COUNT (this->m_ops[1])++;
> -  SLP_TREE_REF_COUNT (this->m_ops[2])++;
> -
> -  vect_free_slp_tree (node);
> -
> -  complex_pattern::build (vinfo);
> -}
> -
>  /*******************************************************************************
>   * complex_fms_pattern class
>   ******************************************************************************/
> @@ -1264,10 +1154,10 @@ class complex_fms_pattern : public complex_pattern
>  };
>  
>  
> -/* Pattern matcher for trying to match complex multiply and accumulate
> -   and multiply and subtract patterns in SLP tree.
> -   If the operation matches then IFN is set to the operation it matched and
> -   the arguments to the two replacement statements are put in m_ops.
> +/* Pattern matcher for trying to match complex multiply and subtract pattern
> +   in SLP tree.  If the operation matches then IFN is set to the operation
> +   it matched and the arguments to the two replacement statements are put in
> +   m_ops.
>  
>     If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
>  
> @@ -1289,38 +1179,33 @@ complex_fms_pattern::matches (complex_operation_t op,
>  {
>    internal_fn ifn = IFN_LAST;
>  
> -  /* Find the two components.  We match Complex MUL first which reduces the
> -     amount of work this pattern has to do.  After that we just match the
> -     head node and we're done.:
> -
> -     * FMS: - +.  */
> -  slp_tree child = NULL;
> -
>    /* We need to ignore the two_operands nodes that may also match,
>       for that we can check if they have any scalar statements and also
>       check that it's not a permute node as we're looking for a normal
> -     PLUS_EXPR operation.  */
> -  if (op != PLUS_MINUS)
> +     MINUS_EXPR operation.  */
> +  if (op != CMPLX_NONE)
>      return IFN_LAST;
>  
> -  child = SLP_TREE_CHILDREN ((*ops)[1])[1];
> -  if (vect_detect_pair_op (child) != MINUS_PLUS)
> +  slp_tree root = *ref_node;
> +  if (!vect_match_expression_p (root, MINUS_EXPR))
>      return IFN_LAST;
>  
> -  /* First two nodes must be a multiply.  */
> -  auto_vec<slp_tree> muls;
> -  if (vect_match_call_complex_mla (child, 0) != MULT_MULT
> -      || vect_match_call_complex_mla (child, 1, &muls) != MULT_MULT)
> +  auto nodes = SLP_TREE_CHILDREN (root);
> +  if (!vect_match_expression_p (nodes[1], MULT_EXPR)
> +      || vect_detect_pair_op (nodes[0]) != PLUS_MINUS)
>      return IFN_LAST;
>  
> +  auto childs = SLP_TREE_CHILDREN (nodes[0]);
> +  auto l0node = SLP_TREE_CHILDREN (childs[0]);
> +  auto l1node = SLP_TREE_CHILDREN (childs[1]);
> +
>    /* Now operand2+4 may lead to another expression.  */
>    auto_vec<slp_tree> left_op, right_op;
> -  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
> -  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
> +  left_op.safe_splice (SLP_TREE_CHILDREN (l0node[1]));
> +  right_op.safe_splice (SLP_TREE_CHILDREN (nodes[1]));
>  
>    bool is_neg = vect_normalize_conj_loc (left_op);
>  
> -  child = SLP_TREE_CHILDREN ((*ops)[1])[0];
>    bool conj_first_operand = false;
>    if (!vect_validate_multiplication (perm_cache, right_op, left_op, false,
>  				     &conj_first_operand, true))
> @@ -1340,28 +1225,28 @@ complex_fms_pattern::matches (complex_operation_t op,
>    complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]);
>    if (kind == PERM_EVENODD)
>      {
> -      ops->quick_push (child);
> +      ops->quick_push (l0node[0]);
>        ops->quick_push (right_op[0]);
>        ops->quick_push (right_op[1]);
>        ops->quick_push (left_op[1]);
>      }
>    else if (kind == PERM_TOP)
>      {
> -      ops->quick_push (child);
> +      ops->quick_push (l0node[0]);
>        ops->quick_push (right_op[1]);
>        ops->quick_push (right_op[0]);
>        ops->quick_push (left_op[0]);
>      }
>    else if (kind == PERM_EVENEVEN && !is_neg)
>      {
> -      ops->quick_push (child);
> +      ops->quick_push (l0node[0]);
>        ops->quick_push (right_op[1]);
>        ops->quick_push (right_op[0]);
>        ops->quick_push (left_op[0]);
>      }
>    else
>      {
> -      ops->quick_push (child);
> +      ops->quick_push (l0node[0]);
>        ops->quick_push (right_op[1]);
>        ops->quick_push (right_op[0]);
>        ops->quick_push (left_op[1]);
> @@ -1473,10 +1358,6 @@ complex_operations_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
>    if (ifn != IFN_LAST)
>      return complex_mul_pattern::mkInstance (node, &ops, ifn);
>  
> -  ifn  = complex_fma_pattern::matches (op, perm_cache, node, &ops);
> -  if (ifn != IFN_LAST)
> -    return complex_fma_pattern::mkInstance (node, &ops, ifn);
> -
>    ifn  = complex_add_pattern::matches (op, perm_cache, node, &ops);
>    if (ifn != IFN_LAST)
>      return complex_add_pattern::mkInstance (node, &ops, ifn);
> 
> 
>
diff mbox series

Patch

diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index b8d09b7832e29689ede832d555e1b6af2c24ce1e..99dea82aba91a333500bb5ff35bf30b6416c09ca 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -306,24 +306,6 @@  vect_match_expression_p (slp_tree node, tree_code code)
   return true;
 }
 
-/* Checks to see if the expression represented by NODE is a call to the internal
-   function FN.  */
-
-static inline bool
-vect_match_call_p (slp_tree node, internal_fn fn)
-{
-  if (!node
-      || !SLP_TREE_REPRESENTATIVE (node))
-    return false;
-
-  gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node));
-  if (!expr
-      || !gimple_call_internal_p (expr, fn))
-    return false;
-
-   return true;
-}
-
 /* Check if the given lane permute in PERMUTES matches an alternating sequence
    of {even odd even odd ...}.  This to account for unrolled loops.  Further
    mode there resulting permute must be linear.   */
@@ -389,6 +371,16 @@  vect_detect_pair_op (slp_tree node1, slp_tree node2, lane_permutation_t &lanes,
 
   if (result != CMPLX_NONE && ops != NULL)
     {
+      if (two_operands)
+	{
+	  auto l0node = SLP_TREE_CHILDREN (node1);
+	  auto l1node = SLP_TREE_CHILDREN (node2);
+
+	  /* Check if the tree is connected as we expect it.  */
+	  if (!((l0node[0] == l1node[0] && l0node[1] == l1node[1])
+	      || (l0node[0] == l1node[1] && l0node[1] == l1node[0])))
+	    return CMPLX_NONE;
+	}
       ops->safe_push (node1);
       ops->safe_push (node2);
     }
@@ -717,27 +709,6 @@  complex_add_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
  * complex_mul_pattern
  ******************************************************************************/
 
-/* Helper function of that looks for a match in the CHILDth child of NODE.  The
-   child used is stored in RES.
-
-   If the match is successful then ARGS will contain the operands matched
-   and the complex_operation_t type is returned.  If match is not successful
-   then CMPLX_NONE is returned and ARGS is left unmodified.  */
-
-static inline complex_operation_t
-vect_match_call_complex_mla (slp_tree node, unsigned child,
-			     vec<slp_tree> *args = NULL, slp_tree *res = NULL)
-{
-  gcc_assert (child < SLP_TREE_CHILDREN (node).length ());
-
-  slp_tree data = SLP_TREE_CHILDREN (node)[child];
-
-  if (res)
-    *res = data;
-
-  return vect_detect_pair_op (data, false, args);
-}
-
 /* Check to see if either of the trees in ARGS are a NEGATE_EXPR.  If the first
    child (args[0]) is a NEGATE_EXPR then NEG_FIRST_P is set to TRUE.
 
@@ -945,9 +916,10 @@  class complex_mul_pattern : public complex_pattern
 
 };
 
-/* Pattern matcher for trying to match complex multiply pattern in SLP tree
-   If the operation matches then IFN is set to the operation it matched
-   and the arguments to the two replacement statements are put in m_ops.
+/* Pattern matcher for trying to match complex multiply and complex multiply
+   and accumulate pattern in SLP tree.  If the operation matches then IFN
+   is set to the operation it matched and the arguments to the two
+   replacement statements are put in m_ops.
 
    If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
 
@@ -972,19 +944,43 @@  complex_mul_pattern::matches (complex_operation_t op,
   if (op != MINUS_PLUS)
     return IFN_LAST;
 
-  slp_tree root = *node;
-  /* First two nodes must be a multiply.  */
-  auto_vec<slp_tree> muls;
-  if (vect_match_call_complex_mla (root, 0) != MULT_MULT
-      || vect_match_call_complex_mla (root, 1, &muls) != MULT_MULT)
+  auto childs = *ops;
+  auto l0node = SLP_TREE_CHILDREN (childs[0]);
+  auto l1node = SLP_TREE_CHILDREN (childs[1]);
+
+  bool mul0 = vect_match_expression_p (l0node[0], MULT_EXPR);
+  bool mul1 = vect_match_expression_p (l0node[1], MULT_EXPR);
+  if (!mul0 && !mul1)
     return IFN_LAST;
 
   /* Now operand2+4 may lead to another expression.  */
   auto_vec<slp_tree> left_op, right_op;
-  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
-  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
+  slp_tree add0 = NULL;
+
+  /* Check if we may be a multiply add.  */
+  if (!mul0
+      && vect_match_expression_p (l0node[0], PLUS_EXPR))
+    {
+      auto vals = SLP_TREE_CHILDREN (l0node[0]);
+      /* Check if it's a multiply, otherwise no idea what this is.  */
+      if (!vect_match_expression_p (vals[1], MULT_EXPR))
+	return IFN_LAST;
+
+      /* Check if the ADD is linear, otherwise it's not valid complex FMA.  */
+      if (linear_loads_p (perm_cache, vals[0]) != PERM_EVENODD)
+	return IFN_LAST;
 
-  if (linear_loads_p (perm_cache, left_op[1]) == PERM_ODDEVEN)
+      left_op.safe_splice (SLP_TREE_CHILDREN (vals[1]));
+      add0 = vals[0];
+    }
+  else
+    left_op.safe_splice (SLP_TREE_CHILDREN (l0node[0]));
+
+  right_op.safe_splice (SLP_TREE_CHILDREN (l0node[1]));
+
+  if (left_op.length () != 2
+      || right_op.length () != 2
+      || linear_loads_p (perm_cache, left_op[1]) == PERM_ODDEVEN)
     return IFN_LAST;
 
   bool neg_first = false;
@@ -998,23 +994,32 @@  complex_mul_pattern::matches (complex_operation_t op,
       if (!vect_validate_multiplication (perm_cache, left_op, PERM_EVENEVEN)
 	  || vect_normalize_conj_loc (left_op))
 	return IFN_LAST;
-      ifn = IFN_COMPLEX_MUL;
+      if (!mul0)
+	ifn = IFN_COMPLEX_FMA;
+      else
+	ifn = IFN_COMPLEX_MUL;
     }
-  else if (is_neg)
+  else
     {
       if (!vect_validate_multiplication (perm_cache, left_op, right_op,
 					 neg_first, &conj_first_operand,
 					 false))
 	return IFN_LAST;
 
-      ifn = IFN_COMPLEX_MUL_CONJ;
+      if(!mul0)
+	ifn = IFN_COMPLEX_FMA_CONJ;
+      else
+	ifn = IFN_COMPLEX_MUL_CONJ;
     }
 
   if (!vect_pattern_validate_optab (ifn, *node))
     return IFN_LAST;
 
   ops->truncate (0);
-  ops->create (3);
+  ops->create (add0 ? 4 : 3);
+
+  if (add0)
+    ops->quick_push (add0);
 
   complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]);
   if (kind == PERM_EVENODD)
@@ -1070,170 +1075,55 @@  complex_mul_pattern::build (vec_info *vinfo)
 {
   slp_tree node;
   unsigned i;
-  slp_tree newnode
-    = vect_build_combine_node (this->m_ops[0], this->m_ops[1], *this->m_node);
-  SLP_TREE_REF_COUNT (this->m_ops[2])++;
-
-  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
-    vect_free_slp_tree (node);
-
-  /* First re-arrange the children.  */
-  SLP_TREE_CHILDREN (*this->m_node).reserve_exact (2);
-  SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[2];
-  SLP_TREE_CHILDREN (*this->m_node)[1] = newnode;
+  switch (this->m_ifn)
+  {
+    case IFN_COMPLEX_MUL:
+    case IFN_COMPLEX_MUL_CONJ:
+      {
+	slp_tree newnode
+	  = vect_build_combine_node (this->m_ops[0], this->m_ops[1],
+				     *this->m_node);
+	SLP_TREE_REF_COUNT (this->m_ops[2])++;
+
+	FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
+	  vect_free_slp_tree (node);
+
+	/* First re-arrange the children.  */
+	SLP_TREE_CHILDREN (*this->m_node).reserve_exact (2);
+	SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[2];
+	SLP_TREE_CHILDREN (*this->m_node)[1] = newnode;
+	break;
+      }
+    case IFN_COMPLEX_FMA:
+    case IFN_COMPLEX_FMA_CONJ:
+      {
+	SLP_TREE_REF_COUNT (this->m_ops[0])++;
+	slp_tree newnode
+	  = vect_build_combine_node (this->m_ops[1], this->m_ops[2],
+				     *this->m_node);
+	SLP_TREE_REF_COUNT (this->m_ops[3])++;
+
+	FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
+	  vect_free_slp_tree (node);
+
+	/* First re-arrange the children.  */
+	SLP_TREE_CHILDREN (*this->m_node).safe_grow (3);
+	SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[0];
+	SLP_TREE_CHILDREN (*this->m_node)[1] = this->m_ops[3];
+	SLP_TREE_CHILDREN (*this->m_node)[2] = newnode;
+
+	/* Tell the builder to expect an extra argument.  */
+	this->m_num_args++;
+	break;
+      }
+    default:
+      gcc_unreachable ();
+  }
 
   /* And then rewrite the node itself.  */
   complex_pattern::build (vinfo);
 }
 
-/*******************************************************************************
- * complex_fma_pattern class
- ******************************************************************************/
-
-class complex_fma_pattern : public complex_pattern
-{
-  protected:
-    complex_fma_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
-      : complex_pattern (node, m_ops, ifn)
-    {
-      this->m_num_args = 3;
-    }
-
-  public:
-    void build (vec_info *);
-    static internal_fn
-    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
-	     vec<slp_tree> *);
-
-    static vect_pattern*
-    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
-
-    static vect_pattern*
-    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
-    {
-      return new complex_fma_pattern (node, m_ops, ifn);
-    }
-};
-
-/* Pattern matcher for trying to match complex multiply and accumulate
-   and multiply and subtract patterns in SLP tree.
-   If the operation matches then IFN is set to the operation it matched and
-   the arguments to the two replacement statements are put in m_ops.
-
-   If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
-
-   This function matches the patterns shaped as:
-
-   double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
-   double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
-
-   c[i] = c[i] - ax;
-   c[i+1] = c[i+1] + bx;
-
-   If a match occurred then TRUE is returned, else FALSE.  The match is
-   performed after COMPLEX_MUL which would have done the majority of the work.
-   This function merely matches an ADD with a COMPLEX_MUL IFN.  The initial
-   match is expected to be in OP1 and the initial match operands in args0.  */
-
-internal_fn
-complex_fma_pattern::matches (complex_operation_t op,
-			      slp_tree_to_load_perm_map_t * /* perm_cache */,
-			      slp_tree *ref_node, vec<slp_tree> *ops)
-{
-  internal_fn ifn = IFN_LAST;
-
-  /* Find the two components.  We match Complex MUL first which reduces the
-     amount of work this pattern has to do.  After that we just match the
-     head node and we're done.:
-
-     * FMA: + +.
-
-     We need to ignore the two_operands nodes that may also match.
-     For that we can check if they have any scalar statements and also
-     check that it's not a permute node as we're looking for a normal
-     PLUS_EXPR operation.  */
-  if (op != CMPLX_NONE)
-    return IFN_LAST;
-
-  /* Find the two components.  We match Complex MUL first which reduces the
-     amount of work this pattern has to do.  After that we just match the
-     head node and we're done.:
-
-   * FMA: + + on a non-two_operands node.  */
-  slp_tree vnode = *ref_node;
-  if (SLP_TREE_LANE_PERMUTATION (vnode).exists ()
-      || !SLP_TREE_CHILDREN (vnode).exists ()
-      || !vect_match_expression_p (vnode, PLUS_EXPR))
-    return IFN_LAST;
-
-  slp_tree node = SLP_TREE_CHILDREN (vnode)[1];
-
-  if (vect_match_call_p (node, IFN_COMPLEX_MUL))
-    ifn = IFN_COMPLEX_FMA;
-  else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ))
-    ifn = IFN_COMPLEX_FMA_CONJ;
-  else
-    return IFN_LAST;
-
-  if (!vect_pattern_validate_optab (ifn, vnode))
-    return IFN_LAST;
-
-  ops->truncate (0);
-  ops->create (3);
-
-  if (ifn == IFN_COMPLEX_FMA)
-    {
-      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
-      ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
-      ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
-    }
-  else
-    {
-      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
-      ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
-      ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
-    }
-
-  return ifn;
-}
-
-/* Attempt to recognize a complex mul pattern.  */
-
-vect_pattern*
-complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
-				slp_tree *node)
-{
-  auto_vec<slp_tree> ops;
-  complex_operation_t op
-    = vect_detect_pair_op (*node, true, &ops);
-  internal_fn ifn
-    = complex_fma_pattern::matches (op, perm_cache, node, &ops);
-  if (ifn == IFN_LAST)
-    return NULL;
-
-  return new complex_fma_pattern (node, &ops, ifn);
-}
-
-/* Perform a replacement of the detected complex mul pattern with the new
-   instruction sequences.  */
-
-void
-complex_fma_pattern::build (vec_info *vinfo)
-{
-  slp_tree node = SLP_TREE_CHILDREN (*this->m_node)[1];
-
-  SLP_TREE_CHILDREN (*this->m_node).release ();
-  SLP_TREE_CHILDREN (*this->m_node).create (3);
-  SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops);
-
-  SLP_TREE_REF_COUNT (this->m_ops[1])++;
-  SLP_TREE_REF_COUNT (this->m_ops[2])++;
-
-  vect_free_slp_tree (node);
-
-  complex_pattern::build (vinfo);
-}
-
 /*******************************************************************************
  * complex_fms_pattern class
  ******************************************************************************/
@@ -1264,10 +1154,10 @@  class complex_fms_pattern : public complex_pattern
 };
 
 
-/* Pattern matcher for trying to match complex multiply and accumulate
-   and multiply and subtract patterns in SLP tree.
-   If the operation matches then IFN is set to the operation it matched and
-   the arguments to the two replacement statements are put in m_ops.
+/* Pattern matcher for trying to match a complex multiply and subtract pattern
+   in an SLP tree.  If the operation matches then IFN is set to the operation
+   it matched and the arguments to the two replacement statements are put in
+   m_ops.
 
    If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
 
@@ -1289,38 +1179,33 @@  complex_fms_pattern::matches (complex_operation_t op,
 {
   internal_fn ifn = IFN_LAST;
 
-  /* Find the two components.  We match Complex MUL first which reduces the
-     amount of work this pattern has to do.  After that we just match the
-     head node and we're done.:
-
-     * FMS: - +.  */
-  slp_tree child = NULL;
-
   /* We need to ignore the two_operands nodes that may also match,
      for that we can check if they have any scalar statements and also
      check that it's not a permute node as we're looking for a normal
-     PLUS_EXPR operation.  */
-  if (op != PLUS_MINUS)
+     MINUS_EXPR operation.  */
+  if (op != CMPLX_NONE)
     return IFN_LAST;
 
-  child = SLP_TREE_CHILDREN ((*ops)[1])[1];
-  if (vect_detect_pair_op (child) != MINUS_PLUS)
+  slp_tree root = *ref_node;
+  if (!vect_match_expression_p (root, MINUS_EXPR))
     return IFN_LAST;
 
-  /* First two nodes must be a multiply.  */
-  auto_vec<slp_tree> muls;
-  if (vect_match_call_complex_mla (child, 0) != MULT_MULT
-      || vect_match_call_complex_mla (child, 1, &muls) != MULT_MULT)
+  auto nodes = SLP_TREE_CHILDREN (root);
+  if (!vect_match_expression_p (nodes[1], MULT_EXPR)
+      || vect_detect_pair_op (nodes[0]) != PLUS_MINUS)
     return IFN_LAST;
 
+  auto childs = SLP_TREE_CHILDREN (nodes[0]);
+  auto l0node = SLP_TREE_CHILDREN (childs[0]);
+  auto l1node = SLP_TREE_CHILDREN (childs[1]);
+
   /* Now operand2+4 may lead to another expression.  */
   auto_vec<slp_tree> left_op, right_op;
-  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
-  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
+  left_op.safe_splice (SLP_TREE_CHILDREN (l0node[1]));
+  right_op.safe_splice (SLP_TREE_CHILDREN (nodes[1]));
 
   bool is_neg = vect_normalize_conj_loc (left_op);
 
-  child = SLP_TREE_CHILDREN ((*ops)[1])[0];
   bool conj_first_operand = false;
   if (!vect_validate_multiplication (perm_cache, right_op, left_op, false,
 				     &conj_first_operand, true))
@@ -1340,28 +1225,28 @@  complex_fms_pattern::matches (complex_operation_t op,
   complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]);
   if (kind == PERM_EVENODD)
     {
-      ops->quick_push (child);
+      ops->quick_push (l0node[0]);
       ops->quick_push (right_op[0]);
       ops->quick_push (right_op[1]);
       ops->quick_push (left_op[1]);
     }
   else if (kind == PERM_TOP)
     {
-      ops->quick_push (child);
+      ops->quick_push (l0node[0]);
       ops->quick_push (right_op[1]);
       ops->quick_push (right_op[0]);
       ops->quick_push (left_op[0]);
     }
   else if (kind == PERM_EVENEVEN && !is_neg)
     {
-      ops->quick_push (child);
+      ops->quick_push (l0node[0]);
       ops->quick_push (right_op[1]);
       ops->quick_push (right_op[0]);
       ops->quick_push (left_op[0]);
     }
   else
     {
-      ops->quick_push (child);
+      ops->quick_push (l0node[0]);
       ops->quick_push (right_op[1]);
       ops->quick_push (right_op[0]);
       ops->quick_push (left_op[1]);
@@ -1473,10 +1358,6 @@  complex_operations_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
   if (ifn != IFN_LAST)
     return complex_mul_pattern::mkInstance (node, &ops, ifn);
 
-  ifn  = complex_fma_pattern::matches (op, perm_cache, node, &ops);
-  if (ifn != IFN_LAST)
-    return complex_fma_pattern::mkInstance (node, &ops, ifn);
-
   ifn  = complex_add_pattern::matches (op, perm_cache, node, &ops);
   if (ifn != IFN_LAST)
     return complex_add_pattern::mkInstance (node, &ops, ifn);