diff mbox series

[v2,5/8] OpenMP: C++ front-end support for dispatch + adjust_args

Message ID 20240712141155.255186-6-parras@baylibre.com
State New
Headers show
Series OpenMP: dispatch + adjust_args support | expand

Commit Message

Paul-Antoine Arras July 12, 2024, 2:11 p.m. UTC
This patch adds C++ support for the `dispatch` construct and the `adjust_args`
clause. It relies on the c-family bits comprised in the corresponding C front
end patch for pragmas and attributes.

Additional C/C++ common testcases are provided in a subsequent patch in the
series.

gcc/cp/ChangeLog:

	* decl.cc (omp_declare_variant_finalize_one): Set adjust_args
	need_device_ptr attribute.
	* parser.cc (cp_parser_direct_declarator): Update call to
	cp_parser_late_return_type_opt.
	(cp_parser_late_return_type_opt): Add parameter. Update call to
	cp_parser_late_parsing_omp_declare_simd.
	(cp_parser_omp_clause_name): Handle nocontext and novariants clauses.
	(cp_parser_omp_clause_novariants): New function.
	(cp_parser_omp_clause_nocontext): Likewise.
	(cp_parser_omp_all_clauses): Handle PRAGMA_OMP_CLAUSE_NOVARIANTS and
	PRAGMA_OMP_CLAUSE_NOCONTEXT.
	(cp_parser_omp_dispatch_body): New function, inspired from
	cp_parser_assignment_expression and cp_parser_postfix_expression.
	(OMP_DISPATCH_CLAUSE_MASK): Define.
	(cp_parser_omp_dispatch): New function.
	(cp_finish_omp_declare_variant): Add parameter. Handle adjust_args
	clause.
	(cp_parser_late_parsing_omp_declare_simd): Add parameter. Update calls
	to cp_finish_omp_declare_variant and cp_finish_omp_declare_variant.
	(cp_parser_omp_construct): Handle PRAGMA_OMP_DISPATCH.
	(cp_parser_pragma): Likewise.
	* pt.cc (tsubst_attribute): Skip pseudo-TSS need_device_ptr.
	* semantics.cc (finish_omp_clauses): Handle OMP_CLAUSE_NOCONTEXT and
	OMP_CLAUSE_NOVARIANTS.

gcc/testsuite/ChangeLog:

	* g++.dg/gomp/adjust-args-1.C: New test.
	* g++.dg/gomp/adjust-args-2.C: New test.
	* g++.dg/gomp/dispatch-1.C: New test.
	* g++.dg/gomp/dispatch-2.C: New test.
---
 gcc/cp/decl.cc                            |  33 ++
 gcc/cp/parser.cc                          | 612 ++++++++++++++++++++--
 gcc/cp/pt.cc                              |   3 +
 gcc/cp/semantics.cc                       |  20 +
 gcc/testsuite/g++.dg/gomp/adjust-args-1.C |  39 ++
 gcc/testsuite/g++.dg/gomp/adjust-args-2.C |  51 ++
 gcc/testsuite/g++.dg/gomp/dispatch-1.C    |  53 ++
 gcc/testsuite/g++.dg/gomp/dispatch-2.C    |  62 +++
 8 files changed, 828 insertions(+), 45 deletions(-)
 create mode 100644 gcc/testsuite/g++.dg/gomp/adjust-args-1.C
 create mode 100644 gcc/testsuite/g++.dg/gomp/adjust-args-2.C
 create mode 100644 gcc/testsuite/g++.dg/gomp/dispatch-1.C
 create mode 100644 gcc/testsuite/g++.dg/gomp/dispatch-2.C
diff mbox series

Patch

diff --git a/gcc/cp/decl.cc b/gcc/cp/decl.cc
index edf4c155bf7..fae37d508b0 100644
--- a/gcc/cp/decl.cc
+++ b/gcc/cp/decl.cc
@@ -8375,6 +8375,39 @@  omp_declare_variant_finalize_one (tree decl, tree attr)
 	  if (!omp_context_selector_matches (ctx))
 	    return true;
 	  TREE_PURPOSE (TREE_VALUE (attr)) = variant;
+
+	  for (tree a = ctx; a != NULL_TREE; a = TREE_CHAIN (a))
+	    {
+	      if (OMP_TSS_CODE (a) == OMP_TRAIT_SET_NEED_DEVICE_PTR)
+		{
+		  for (tree need_device_ptr_list = TREE_VALUE (a);
+		       need_device_ptr_list != NULL_TREE;
+		       need_device_ptr_list = TREE_CHAIN (need_device_ptr_list))
+		    {
+		      tree parm_decl = TREE_VALUE (need_device_ptr_list);
+		      bool found_arg = false;
+		      for (tree arg = DECL_ARGUMENTS (variant); arg != NULL;
+			   arg = TREE_CHAIN (arg))
+			if (DECL_NAME (arg) == DECL_NAME (parm_decl))
+			  {
+			    DECL_ATTRIBUTES (arg)
+			      = tree_cons (get_identifier (
+					     "omp declare variant adjust_args "
+					     "need_device_ptr"),
+					   NULL_TREE, DECL_ATTRIBUTES (arg));
+			    found_arg = true;
+			    break;
+			  }
+		      if (!found_arg)
+			{
+			  error_at (varid_loc,
+				    "variant %qD does not have a parameter %qD",
+				    variant, parm_decl);
+			  return true;
+			}
+		    }
+		}
+	    }
 	}
     }
   else if (!processing_template_decl)
diff --git a/gcc/cp/parser.cc b/gcc/cp/parser.cc
index 6bf3f52a059..b85c9c387fb 100644
--- a/gcc/cp/parser.cc
+++ b/gcc/cp/parser.cc
@@ -19,6 +19,7 @@  along with GCC; see the file COPYING3.  If not see
 <http://www.gnu.org/licenses/>.  */
 
 #include "config.h"
+#include "omp-selectors.h"
 #define INCLUDE_MEMORY
 #include "system.h"
 #include "coretypes.h"
@@ -2587,7 +2588,7 @@  static cp_ref_qualifier cp_parser_ref_qualifier_opt
 static tree cp_parser_tx_qualifier_opt
   (cp_parser *);
 static tree cp_parser_late_return_type_opt
-  (cp_parser *, cp_declarator *, tree &);
+  (cp_parser *, cp_declarator *, tree &, tree);
 static tree cp_parser_declarator_id
   (cp_parser *, bool);
 static tree cp_parser_type_id
@@ -2622,7 +2623,7 @@  static void cp_parser_ctor_initializer_opt_and_function_body
   (cp_parser *, bool);
 
 static tree cp_parser_late_parsing_omp_declare_simd
-  (cp_parser *, tree);
+  (cp_parser *, tree, tree);
 
 static tree cp_parser_late_parsing_oacc_routine
   (cp_parser *, tree);
@@ -24154,7 +24155,7 @@  cp_parser_direct_declarator (cp_parser* parser,
 		  tree requires_clause = NULL_TREE;
 		  late_return
 		    = cp_parser_late_return_type_opt (parser, declarator,
-						      requires_clause);
+						      requires_clause, params);
 
 		  cp_finalize_omp_declare_simd (parser, &odsd);
 
@@ -25019,8 +25020,8 @@  parsing_function_declarator ()
    function.  */
 
 static tree
-cp_parser_late_return_type_opt (cp_parser* parser, cp_declarator *declarator,
-				tree& requires_clause)
+cp_parser_late_return_type_opt (cp_parser *parser, cp_declarator *declarator,
+				tree &requires_clause, tree parms)
 {
   cp_token *token;
   tree type = NULL_TREE;
@@ -25056,8 +25057,8 @@  cp_parser_late_return_type_opt (cp_parser* parser, cp_declarator *declarator,
 
   if (declare_simd_p)
     declarator->attributes
-      = cp_parser_late_parsing_omp_declare_simd (parser,
-						 declarator->attributes);
+      = cp_parser_late_parsing_omp_declare_simd (parser, declarator->attributes,
+						 parms);
   if (oacc_routine_p)
     declarator->attributes
       = cp_parser_late_parsing_oacc_routine (parser,
@@ -38129,6 +38130,8 @@  cp_parser_omp_clause_name (cp_parser *parser)
 	case 'n':
 	  if (!strcmp ("no_create", p))
 	    result = PRAGMA_OACC_CLAUSE_NO_CREATE;
+	  else if (!strcmp ("nocontext", p))
+	    result = PRAGMA_OMP_CLAUSE_NOCONTEXT;
 	  else if (!strcmp ("nogroup", p))
 	    result = PRAGMA_OMP_CLAUSE_NOGROUP;
 	  else if (!strcmp ("nohost", p))
@@ -38137,6 +38140,8 @@  cp_parser_omp_clause_name (cp_parser *parser)
 	    result = PRAGMA_OMP_CLAUSE_NONTEMPORAL;
 	  else if (!strcmp ("notinbranch", p))
 	    result = PRAGMA_OMP_CLAUSE_NOTINBRANCH;
+	  else if (!strcmp ("novariants", p))
+	    result = PRAGMA_OMP_CLAUSE_NOVARIANTS;
 	  else if (!strcmp ("nowait", p))
 	    result = PRAGMA_OMP_CLAUSE_NOWAIT;
 	  else if (!strcmp ("num_gangs", p))
@@ -40583,6 +40588,56 @@  cp_parser_omp_clause_partial (cp_parser *parser, tree list, location_t loc)
   return c;
 }
 
+/* OpenMP 5.1
+   novariants ( scalar-expression ) */
+
+static tree
+cp_parser_omp_clause_novariants (cp_parser *parser, tree list, location_t loc)
+{
+  matching_parens parens;
+  if (!parens.require_open (parser))
+    return list;
+
+  tree t = cp_parser_assignment_expression (parser);
+  if (t == error_mark_node || !parens.require_close (parser))
+    cp_parser_skip_to_closing_parenthesis (parser, /*recovering=*/true,
+					   /*or_comma=*/false,
+					   /*consume_paren=*/true);
+
+  check_no_duplicate_clause (list, OMP_CLAUSE_NOVARIANTS, "novariants", loc);
+
+  tree c = build_omp_clause (loc, OMP_CLAUSE_NOVARIANTS);
+  OMP_CLAUSE_NOVARIANTS_EXPR (c) = t;
+  OMP_CLAUSE_CHAIN (c) = list;
+
+  return c;
+}
+
+/* OpenMP 5.1
+   nocontext ( scalar-expression ) */
+
+static tree
+cp_parser_omp_clause_nocontext (cp_parser *parser, tree list, location_t loc)
+{
+  matching_parens parens;
+  if (!parens.require_open (parser))
+    return list;
+
+  tree t = cp_parser_assignment_expression (parser);
+  if (t == error_mark_node || !parens.require_close (parser))
+    cp_parser_skip_to_closing_parenthesis (parser, /*recovering=*/true,
+					   /*or_comma=*/false,
+					   /*consume_paren=*/true);
+
+  check_no_duplicate_clause (list, OMP_CLAUSE_NOCONTEXT, "nocontext", loc);
+
+  tree c = build_omp_clause (loc, OMP_CLAUSE_NOCONTEXT);
+  OMP_CLAUSE_NOCONTEXT_EXPR (c) = t;
+  OMP_CLAUSE_CHAIN (c) = list;
+
+  return c;
+}
+
 /* OpenMP 4.0:
    aligned ( variable-list )
    aligned ( variable-list : constant-expression )  */
@@ -42690,6 +42745,16 @@  cp_parser_omp_all_clauses (cp_parser *parser, omp_clause_mask mask,
 	  clauses = cp_parser_omp_clause_full (clauses, token->location);
 	  c_name = "full";
 	  break;
+	case PRAGMA_OMP_CLAUSE_NOVARIANTS:
+	  clauses = cp_parser_omp_clause_novariants (parser, clauses,
+						     token->location);
+	  c_name = "novariants";
+	  break;
+	case PRAGMA_OMP_CLAUSE_NOCONTEXT:
+	  clauses
+	    = cp_parser_omp_clause_nocontext (parser, clauses, token->location);
+	  c_name = "nocontext";
+	  break;
 	default:
 	  cp_parser_error (parser, "expected an OpenMP clause");
 	  goto saw_error;
@@ -48895,12 +48960,305 @@  cp_parser_omp_assumes (cp_parser *parser, cp_token *pragma_tok)
   return false;
 }
 
+/* Parse a function dispatch structured block:
+
+    lvalue-expression = target-call ( [expression-list] );
+    or
+    target-call ( [expression-list] );
+
+  Inspired from cp_parser_assignment_expression and
+  cp_parser_postfix_expression.
+*/
+
+static tree
+cp_parser_omp_dispatch_body (cp_parser *parser)
+{
+  cp_expr expr;
+  cp_id_kind idk = CP_ID_KIND_NONE;
+
+  /* Parse the binary expressions (lvalue-expression or target-call).  */
+  expr = cp_parser_binary_expression (parser, false, false, false,
+				      PREC_NOT_OPERATOR, NULL);
+  if (TREE_CODE (expr) == CALL_EXPR || TREE_CODE (expr) == ERROR_MARK)
+    return expr;
+
+  /* We have the lvalue, now deal with the assignment. */
+
+  if (!cp_parser_require (parser, CPP_EQ, RT_EQ))
+    return error_mark_node;
+
+  /* Peek at the next token.  */
+  cp_token *token = cp_lexer_peek_token (parser->lexer);
+  location_t loc = token->location;
+  location_t start_loc = get_range_from_loc (line_table, loc).m_start;
+
+  /* Parse function name as primary expression.  */
+  cp_expr rhs
+    = cp_parser_primary_expression (parser, false, false, false, false, &idk);
+  if (TREE_CODE (rhs) == ERROR_MARK)
+    return rhs;
+
+  /* Keep looping until the postfix-expression is complete.  */
+  bool parens_found = false;
+  while (true)
+    {
+      if (idk == CP_ID_KIND_UNQUALIFIED && identifier_p (rhs)
+	  && cp_lexer_next_token_is_not (parser->lexer, CPP_OPEN_PAREN))
+	/* It is not a Koenig lookup function call.  */
+	rhs = unqualified_name_lookup_error (rhs);
+
+      /* Peek at the next token.  */
+      token = cp_lexer_peek_token (parser->lexer);
+
+      switch (token->type)
+	{
+	case CPP_OPEN_PAREN:
+	  /* postfix-expression ( expression-list [opt] ) */
+	  {
+	    if (parens_found)
+	      {
+		cp_parser_error (
+		  parser,
+		  "only one function call is allowed in a dispatch construct");
+		return error_mark_node;
+	      }
+	    parens_found = true;
+
+	    bool koenig_p;
+	    tsubst_flags_t complain = complain_flags (false);
+	    vec<tree, va_gc> *args;
+	    location_t close_paren_loc = UNKNOWN_LOCATION;
+	    location_t combined_loc = UNKNOWN_LOCATION;
+
+	    args = (cp_parser_parenthesized_expression_list (
+	      parser, non_attr,
+	      /*cast_p=*/false, /*allow_expansion_p=*/true,
+	      /*non_constant_p=*/NULL,
+	      /*close_paren_loc=*/&close_paren_loc,
+	      /*wrap_locations_p=*/true));
+
+	    if (args == NULL)
+	      {
+		rhs = error_mark_node;
+		break;
+	      }
+
+	    koenig_p = false;
+	    if (idk == CP_ID_KIND_UNQUALIFIED || idk == CP_ID_KIND_TEMPLATE_ID)
+	      {
+		if (identifier_p (rhs)
+		    /* In C++20, we may need to perform ADL for a template
+		       name.  */
+		    || (TREE_CODE (rhs) == TEMPLATE_ID_EXPR
+			&& identifier_p (TREE_OPERAND (rhs, 0))))
+		  {
+		    if (!args->is_empty ())
+		      {
+			koenig_p = true;
+			if (!any_type_dependent_arguments_p (args))
+			  rhs = perform_koenig_lookup (rhs, args, complain);
+		      }
+		    else
+		      rhs = unqualified_fn_lookup_error (rhs);
+		  }
+		/* We do not perform argument-dependent lookup if
+		   normal lookup finds a non-function, in accordance
+		   with the expected resolution of DR 218.  */
+		else if (!args->is_empty () && is_overloaded_fn (rhs))
+		  {
+		    /* Do not do argument dependent lookup if regular
+		       lookup finds a member function or a block-scope
+		       function declaration.  [basic.lookup.argdep]/3  */
+		    bool do_adl_p = true;
+		    tree fns = get_fns (rhs);
+		    for (lkp_iterator iter (fns); iter; ++iter)
+		      {
+			tree fn = STRIP_TEMPLATE (*iter);
+			if ((TREE_CODE (fn) == USING_DECL
+			     && DECL_DEPENDENT_P (fn))
+			    || DECL_FUNCTION_MEMBER_P (fn)
+			    || DECL_LOCAL_DECL_P (fn))
+			  {
+			    do_adl_p = false;
+			    break;
+			  }
+		      }
+
+		    if (do_adl_p)
+		      {
+			koenig_p = true;
+			if (!any_type_dependent_arguments_p (args))
+			  rhs = perform_koenig_lookup (rhs, args, complain);
+		      }
+		  }
+	      }
+
+	    /* Temporarily set input_location to the combined location
+	       with call expression range, as e.g. build_out_target_exprs
+	       called from convert_default_arg relies on input_location,
+	       so updating it only when the call is fully built results
+	       in inconsistencies between location handling in templates
+	       and outside of templates.  */
+	    if (close_paren_loc != UNKNOWN_LOCATION)
+	      combined_loc
+		= make_location (token->location, start_loc, close_paren_loc);
+	    iloc_sentinel ils (combined_loc);
+
+	    if (TREE_CODE (rhs) == COMPONENT_REF)
+	      {
+		tree instance = TREE_OPERAND (rhs, 0);
+		tree fn = TREE_OPERAND (rhs, 1);
+
+		if (processing_template_decl
+		    && (type_dependent_object_expression_p (instance)
+			|| (!BASELINK_P (fn) && TREE_CODE (fn) != FIELD_DECL)
+			|| type_dependent_expression_p (fn)
+			|| any_type_dependent_arguments_p (args)))
+		  {
+		    maybe_generic_this_capture (instance, fn);
+		    rhs = build_min_nt_call_vec (rhs, args);
+		  }
+		else if (BASELINK_P (fn))
+		  {
+		    rhs
+		      = (build_new_method_call (instance, fn, &args, NULL_TREE,
+						(idk == CP_ID_KIND_QUALIFIED
+						   ? LOOKUP_NORMAL
+						       | LOOKUP_NONVIRTUAL
+						   : LOOKUP_NORMAL),
+						/*fn_p=*/NULL, complain));
+		  }
+		else
+		  rhs = finish_call_expr (rhs, &args,
+					  /*disallow_virtual=*/false,
+					  /*koenig_p=*/false, complain);
+	      }
+	    else if (TREE_CODE (rhs) == OFFSET_REF
+		     || TREE_CODE (rhs) == MEMBER_REF
+		     || TREE_CODE (rhs) == DOTSTAR_EXPR)
+	      rhs = (build_offset_ref_call_from_tree (rhs, &args, complain));
+	    else if (idk == CP_ID_KIND_QUALIFIED)
+	      /* A call to a static class member, or a namespace-scope
+		 function.  */
+	      rhs = finish_call_expr (rhs, &args,
+				      /*disallow_virtual=*/true, koenig_p,
+				      complain);
+	    else
+	      /* All other function calls.  */
+	      {
+		if (DECL_P (rhs) && parser->omp_for_parse_state
+		    && parser->omp_for_parse_state->in_intervening_code
+		    && omp_runtime_api_call (rhs))
+		  {
+		    error_at (loc, "calls to the OpenMP runtime API are "
+				   "not permitted in intervening code");
+		    parser->omp_for_parse_state->fail = true;
+		  }
+		rhs = finish_call_expr (rhs, &args,
+					/*disallow_virtual=*/false, koenig_p,
+					complain);
+	      }
+	    if (close_paren_loc != UNKNOWN_LOCATION)
+	      rhs.set_location (combined_loc);
+
+	    /* The expr is certainly no longer an id.  */
+	    idk = CP_ID_KIND_NONE;
+
+	    release_tree_vector (args);
+	  }
+	  break;
+
+	case CPP_DOT:
+	case CPP_DEREF:
+	  /* postfix-expression . template [opt] id-expression
+	     postfix-expression . pseudo-destructor-name
+	     postfix-expression -> template [opt] id-expression
+	     postfix-expression -> pseudo-destructor-name */
+
+	  /* Consume the `.' or `->' operator.  */
+	  cp_lexer_consume_token (parser->lexer);
+
+	  rhs = cp_parser_postfix_dot_deref_expression (parser, token->type,
+							rhs, false, &idk, loc);
+
+	  break;
+
+	default:
+	  goto finish;
+	}
+    }
+finish:
+  if (!parens_found)
+    {
+      cp_parser_error (parser, "expected %<(%>");
+      return error_mark_node;
+    }
+
+  /* Build the assignment expression.  Its default
+     location:
+       LHS = RHS
+       ~~~~^~~~~
+     is the location of the '=' token as the
+     caret, ranging from the start of the lhs to the
+     end of the rhs.  */
+  loc = make_location (loc, expr.get_start (), rhs.get_finish ());
+  expr
+    = cp_build_modify_expr (loc, expr, NOP_EXPR, rhs, complain_flags (false));
+
+  return expr;
+}
+
+/* OpenMP 5.1:
+   # pragma omp dispatch dispatch-clause[optseq] new-line
+     expression-stmt
+
+   LOC is the location of the #pragma.
+*/
+
+#define OMP_DISPATCH_CLAUSE_MASK                                               \
+  ((OMP_CLAUSE_MASK_1 << PRAGMA_OMP_CLAUSE_DEVICE)                             \
+   | (OMP_CLAUSE_MASK_1 << PRAGMA_OMP_CLAUSE_DEPEND)                           \
+   | (OMP_CLAUSE_MASK_1 << PRAGMA_OMP_CLAUSE_NOVARIANTS)                       \
+   | (OMP_CLAUSE_MASK_1 << PRAGMA_OMP_CLAUSE_NOCONTEXT)                        \
+   | (OMP_CLAUSE_MASK_1 << PRAGMA_OMP_CLAUSE_IS_DEVICE_PTR)                    \
+   | (OMP_CLAUSE_MASK_1 << PRAGMA_OMP_CLAUSE_NOWAIT))
+
+static tree
+cp_parser_omp_dispatch (cp_parser *parser, cp_token *pragma_tok)
+{
+  location_t loc = cp_lexer_peek_token (parser->lexer)->location;
+  tree stmt = make_node (OMP_DISPATCH);
+  SET_EXPR_LOCATION (stmt, loc);
+  TREE_TYPE (stmt) = void_type_node;
+
+  OMP_DISPATCH_CLAUSES (stmt)
+    = cp_parser_omp_all_clauses (parser, OMP_DISPATCH_CLAUSE_MASK,
+				 "#pragma omp dispatch", pragma_tok);
+
+  // Parse expression statement
+  loc = cp_lexer_peek_token (parser->lexer)->location;
+  tree dispatch_body = cp_parser_omp_dispatch_body (parser);
+  if (dispatch_body == error_mark_node)
+    {
+      inform (loc,
+	      "%<#pragma omp dispatch%> must be followed by a direct function "
+	      "call with optional assignment");
+      cp_parser_skip_to_end_of_block_or_statement (parser);
+      return NULL_TREE;
+    }
+
+  cp_parser_consume_semicolon_at_end_of_statement (parser);
+  OMP_DISPATCH_BODY (stmt) = dispatch_body;
+
+  return add_stmt (stmt);
+}
+
 /* Finalize #pragma omp declare variant after a fndecl has been parsed, and put
    that into "omp declare variant base" attribute.  */
 
 static tree
 cp_finish_omp_declare_variant (cp_parser *parser, cp_token *pragma_tok,
-			       tree attrs)
+			       tree attrs, tree parms)
 {
   matching_parens parens;
   if (!parens.require_open (parser))
@@ -48958,44 +49316,197 @@  cp_finish_omp_declare_variant (cp_parser *parser, cp_token *pragma_tok,
   location_t finish_loc = get_finish (varid.get_location ());
   location_t varid_loc = make_location (caret_loc, start_loc, finish_loc);
 
-  if (cp_lexer_next_token_is (parser->lexer, CPP_COMMA)
-      && cp_lexer_nth_token_is (parser->lexer, 2, CPP_NAME))
-    cp_lexer_consume_token (parser->lexer);
+  vec<tree> adjust_args_list = vNULL;
+  bool has_match = false, has_adjust_args = false;
+  location_t adjust_args_loc = UNKNOWN_LOCATION;
+  tree need_device_ptr_list = NULL_TREE, *need_device_ptr_chain_p = NULL;
 
-  const char *clause = "";
-  location_t match_loc = cp_lexer_peek_token (parser->lexer)->location;
-  if (cp_lexer_next_token_is (parser->lexer, CPP_NAME))
-    clause = IDENTIFIER_POINTER (cp_lexer_peek_token (parser->lexer)->u.value);
-  if (strcmp (clause, "match"))
+  do
     {
-      cp_parser_error (parser, "expected %<match%>");
-      goto fail;
+      if (cp_lexer_next_token_is (parser->lexer, CPP_COMMA)
+	  && cp_lexer_nth_token_is (parser->lexer, 2, CPP_NAME))
+	cp_lexer_consume_token (parser->lexer);
+
+      const char *clause = "";
+      location_t match_loc = cp_lexer_peek_token (parser->lexer)->location;
+      if (cp_lexer_next_token_is (parser->lexer, CPP_NAME))
+	clause
+	  = IDENTIFIER_POINTER (cp_lexer_peek_token (parser->lexer)->u.value);
+
+      enum clause
+      {
+	match,
+	adjust_args
+      } ccode;
+
+      if (strcmp (clause, "match") == 0)
+	ccode = match;
+      else if (strcmp (clause, "adjust_args") == 0)
+	{
+	  ccode = adjust_args;
+	  adjust_args_loc = match_loc;
+	}
+      else
+	{
+	  cp_parser_error (parser, "expected %<match%> or %<adjust_args%>");
+	  goto fail;
+	}
+
+      cp_lexer_consume_token (parser->lexer);
+
+      if (!parens.require_open (parser))
+	goto fail;
+
+      if (ccode == match)
+	{
+	  has_match = true;
+	  tree ctx
+	    = cp_parser_omp_context_selector_specification (parser, true);
+	  if (ctx == error_mark_node)
+	    goto fail;
+	  ctx = omp_check_context_selector (match_loc, ctx);
+	  if (ctx != error_mark_node && variant != error_mark_node)
+	    {
+	      tree match_loc_node
+		= maybe_wrap_with_location (integer_zero_node, match_loc);
+	      tree loc_node
+		= maybe_wrap_with_location (integer_zero_node, varid_loc);
+	      loc_node
+		= tree_cons (match_loc_node,
+			     build_int_cst (integer_type_node, idk),
+			     build_tree_list (loc_node, integer_zero_node));
+	      attrs = tree_cons (get_identifier ("omp declare variant base"),
+				 tree_cons (variant, ctx, loc_node), attrs);
+	      if (processing_template_decl)
+		ATTR_IS_DEPENDENT (attrs) = 1;
+	    }
+	  if (!parens.require_close (parser))
+	    goto fail;
+	}
+      else if (ccode == adjust_args)
+	{
+	  has_adjust_args = true;
+	  cp_token *adjust_op_tok = cp_lexer_peek_token (parser->lexer);
+	  if (cp_lexer_next_token_is (parser->lexer, CPP_NAME)
+	      && cp_lexer_nth_token_is (parser->lexer, 2, CPP_COLON))
+	    {
+	      const char *p = IDENTIFIER_POINTER (adjust_op_tok->u.value);
+	      if (strcmp (p, "need_device_ptr") == 0
+		  || strcmp (p, "nothing") == 0)
+		{
+		  cp_lexer_consume_token (parser->lexer); // need_device_ptr
+		  cp_lexer_consume_token (parser->lexer); // :
+		  location_t arg_loc
+		    = cp_lexer_peek_token (parser->lexer)->location;
+
+		  tree arg;
+		  tree list
+		    = cp_parser_omp_var_list_no_open (parser, OMP_CLAUSE_ERROR,
+						      NULL_TREE, NULL);
+
+		  for (tree c = list; c != NULL_TREE; c = TREE_CHAIN (c))
+		    {
+		      tree decl = TREE_PURPOSE (c);
+		      int idx;
+		      for (arg = parms, idx = 0; arg != NULL;
+			   arg = TREE_CHAIN (arg), idx++)
+			if (TREE_VALUE (arg) == decl)
+			  break;
+		      if (arg == NULL_TREE)
+			{
+			  error_at (arg_loc, "%qD is not a function argument",
+				    decl);
+			  continue;
+			}
+		      arg = TREE_VALUE (arg);
+		      if (adjust_args_list.contains (arg))
+			{
+			  error_at (arg_loc, "%qD is specified more than once",
+				    decl);
+			  continue;
+			}
+		      if (strcmp (p, "need_device_ptr") == 0)
+			{
+			  bool is_ptr_or_template
+			    = TEMPLATE_PARM_P (TREE_TYPE (arg))
+			      || POINTER_TYPE_P (TREE_TYPE (arg));
+			  if (!is_ptr_or_template)
+			    {
+			      error_at (arg_loc, "%qD is not a C pointer",
+					decl);
+			      continue;
+			    }
+			}
+		      adjust_args_list.safe_push (arg);
+		      if (strcmp (p, "need_device_ptr") == 0)
+			{
+			  tree attr = tree_cons (NULL_TREE, TREE_PURPOSE (c),
+						 NULL_TREE);
+			  if (need_device_ptr_list == NULL_TREE)
+			    {
+			      gcc_assert (need_device_ptr_chain_p == NULL);
+			      need_device_ptr_list = attr;
+			    }
+			  else
+			    *need_device_ptr_chain_p = attr;
+			  need_device_ptr_chain_p = &TREE_CHAIN (attr);
+			}
+		    }
+		}
+	      else
+		{
+		  error_at (adjust_op_tok->location,
+			    "expected %<nothing%> or %<need_device_ptr%>");
+		  goto fail;
+		}
+	    }
+	  else
+	    {
+	      error_at (adjust_op_tok->location,
+			"expected %<nothing%> or %<need_device_ptr%> followed "
+			"by %<:%>");
+	      goto fail;
+	    }
+	}
+  } while (cp_lexer_next_token_is_not (parser->lexer, CPP_PRAGMA_EOL));
+
+  if (has_adjust_args)
+    {
+      if (!has_match)
+	{
+	  error_at (
+	    adjust_args_loc,
+	    "an %<adjust_args%> clause can only be specified if the "
+	    "%<dispatch%> selector of the construct selector set appears "
+	    "in the %<match%> clause");
+	}
+      else
+	{
+	  tree ctx = TREE_VALUE (TREE_VALUE (attrs));
+	  if (!omp_get_context_selector (ctx, OMP_TRAIT_SET_CONSTRUCT,
+					 OMP_TRAIT_CONSTRUCT_DISPATCH))
+	    error_at (
+	      adjust_args_loc,
+	      "an %<adjust_args%> clause can only be specified if the "
+	      "%<dispatch%> selector of the construct selector set appears "
+	      "in the %<match%> clause");
+	}
     }
 
-  cp_lexer_consume_token (parser->lexer);
-
-  if (!parens.require_open (parser))
-    goto fail;
-
-  tree ctx = cp_parser_omp_context_selector_specification (parser, true);
-  if (ctx == error_mark_node)
-    goto fail;
-  ctx = omp_check_context_selector (match_loc, ctx);
-  if (ctx != error_mark_node && variant != error_mark_node)
+  if (need_device_ptr_list)
     {
-      tree match_loc_node = maybe_wrap_with_location (integer_zero_node,
-						      match_loc);
-      tree loc_node = maybe_wrap_with_location (integer_zero_node, varid_loc);
-      loc_node = tree_cons (match_loc_node,
-			    build_int_cst (integer_type_node, idk),
-			    build_tree_list (loc_node, integer_zero_node));
-      attrs = tree_cons (get_identifier ("omp declare variant base"),
-			 tree_cons (variant, ctx, loc_node), attrs);
-      if (processing_template_decl)
-	ATTR_IS_DEPENDENT (attrs) = 1;
+      // We might not have DECL_ARGUMENTS for the variant yet. So we store the
+      // need_device_ptr list in the base function attribute beside the context
+      // selector.
+      gcc_assert (TREE_PURPOSE (attrs)
+		  == get_identifier ("omp declare variant base"));
+      gcc_assert (TREE_PURPOSE (TREE_VALUE (attrs)) == variant);
+      TREE_VALUE (TREE_VALUE (attrs))
+	= make_trait_set_selector (OMP_TRAIT_SET_NEED_DEVICE_PTR,
+				   need_device_ptr_list,
+				   TREE_VALUE (TREE_VALUE (attrs)));
     }
 
-  parens.require_close (parser);
   cp_parser_skip_to_pragma_eol (parser, pragma_tok);
   return attrs;
 }
@@ -49005,7 +49516,8 @@  cp_finish_omp_declare_variant (cp_parser *parser, cp_token *pragma_tok,
    been parsed, and put that into "omp declare simd" attribute.  */
 
 static tree
-cp_parser_late_parsing_omp_declare_simd (cp_parser *parser, tree attrs)
+cp_parser_late_parsing_omp_declare_simd (cp_parser *parser, tree attrs,
+					 tree parms)
 {
   struct cp_token_cache *ce;
   cp_omp_declare_simd_data *data = parser->omp_declare_simd;
@@ -49049,7 +49561,7 @@  cp_parser_late_parsing_omp_declare_simd (cp_parser *parser, tree attrs)
 	{
 	  gcc_assert (strcmp (kind, "variant") == 0);
 	  attrs
-	    = cp_finish_omp_declare_variant (parser, pragma_tok, attrs);
+	    = cp_finish_omp_declare_variant (parser, pragma_tok, attrs, parms);
 	}
       cp_parser_pop_lexer (parser);
     }
@@ -49180,9 +49692,8 @@  cp_parser_late_parsing_omp_declare_simd (cp_parser *parser, tree attrs)
 		else
 		  {
 		    gcc_assert (strcmp (kind, "variant") == 0);
-		    attrs
-		      = cp_finish_omp_declare_variant (parser, pragma_tok,
-						       attrs);
+		    attrs = cp_finish_omp_declare_variant (parser, pragma_tok,
+							   attrs, parms);
 		  }
 		gcc_assert (parser->lexer != lexer);
 		vec_safe_truncate (lexer->buffer, 0);
@@ -50032,7 +50543,11 @@  cp_parser_omp_declare_reduction (cp_parser *parser, cp_token *pragma_tok,
    #pragma omp declare target new-line
 
    OpenMP 5.0
-   #pragma omp declare variant (identifier) match (context-selector)  */
+   #pragma omp declare variant (identifier) match (context-selector)
+
+   OpenMP 5.1
+   #pragma omp declare variant (identifier) match (context-selector) \
+      adjust_args (adjust-op:argument-list)  */
 
 static bool
 cp_parser_omp_declare (cp_parser *parser, cp_token *pragma_tok,
@@ -50893,6 +51408,9 @@  cp_parser_omp_construct (cp_parser *parser, cp_token *pragma_tok, bool *if_p)
     case PRAGMA_OMP_UNROLL:
       stmt = cp_parser_omp_unroll (parser, pragma_tok, if_p);
       break;
+    case PRAGMA_OMP_DISPATCH:
+      stmt = cp_parser_omp_dispatch (parser, pragma_tok);
+      break;
     default:
       gcc_unreachable ();
     }
@@ -51589,6 +52107,10 @@  cp_parser_pragma (cp_parser *parser, enum pragma_context context, bool *if_p)
 		"%<#pragma omp sections%> construct");
       break;
 
+    case PRAGMA_OMP_DISPATCH:
+      cp_parser_omp_dispatch (parser, pragma_tok);
+      return true;
+
     case PRAGMA_IVDEP:
     case PRAGMA_UNROLL:
     case PRAGMA_NOVECTOR:
diff --git a/gcc/cp/pt.cc b/gcc/cp/pt.cc
index e38e02488be..af6b112d463 100644
--- a/gcc/cp/pt.cc
+++ b/gcc/cp/pt.cc
@@ -11987,6 +11987,9 @@  tsubst_attribute (tree t, tree *decl_p, tree args,
       for (tree tss = ctx; tss; tss = TREE_CHAIN (tss))
 	{
 	  enum omp_tss_code set = OMP_TSS_CODE (tss);
+	  if (set == OMP_TRAIT_SET_NEED_DEVICE_PTR)
+	    continue;
+
 	  tree selectors = NULL_TREE;
 	  for (tree ts = OMP_TSS_TRAIT_SELECTORS (tss); ts;
 	       ts = TREE_CHAIN (ts))
diff --git a/gcc/cp/semantics.cc b/gcc/cp/semantics.cc
index cd3df13772d..76a5d3b23f3 100644
--- a/gcc/cp/semantics.cc
+++ b/gcc/cp/semantics.cc
@@ -7631,6 +7631,26 @@  finish_omp_clauses (tree clauses, enum c_omp_region_type ort)
 	  OMP_CLAUSE_FINAL_EXPR (c) = t;
 	  break;
 
+	case OMP_CLAUSE_NOCONTEXT:
+	  t = OMP_CLAUSE_NOCONTEXT_EXPR (c);
+	  t = maybe_convert_cond (t);
+	  if (t == error_mark_node)
+	    remove = true;
+	  else if (!processing_template_decl)
+	    t = fold_build_cleanup_point_expr (TREE_TYPE (t), t);
+	  OMP_CLAUSE_NOCONTEXT_EXPR (c) = t;
+	  break;
+
+	case OMP_CLAUSE_NOVARIANTS:
+	  t = OMP_CLAUSE_NOVARIANTS_EXPR (c);
+	  t = maybe_convert_cond (t);
+	  if (t == error_mark_node)
+	    remove = true;
+	  else if (!processing_template_decl)
+	    t = fold_build_cleanup_point_expr (TREE_TYPE (t), t);
+	  OMP_CLAUSE_NOVARIANTS_EXPR (c) = t;
+	  break;
+
 	case OMP_CLAUSE_GANG:
 	  /* Operand 1 is the gang static: argument.  */
 	  t = OMP_CLAUSE_OPERAND (c, 1);
diff --git a/gcc/testsuite/g++.dg/gomp/adjust-args-1.C b/gcc/testsuite/g++.dg/gomp/adjust-args-1.C
new file mode 100644
index 00000000000..1c6dd8ac97b
--- /dev/null
+++ b/gcc/testsuite/g++.dg/gomp/adjust-args-1.C
@@ -0,0 +1,39 @@ 
+/* Test parsing of OMP clause adjust_args */
+/* { dg-do compile } */
+
+int b;
+
+int f0 (void *a);
+int g (void *a);
+int f1 (int);
+
+#pragma omp declare variant (f0) match (construct={target}) adjust_args (nothing: a) /* { dg-error "an 'adjust_args' clause can only be specified if the 'dispatch' selector of the construct selector set appears in the 'match' clause" } */
+int f2 (void *a);
+#pragma omp declare variant (f0) match (construct={dispatch,target}) adjust_args (need_device_ptr: a) /* { dg-error "'int f0.void..' used as a variant with incompatible 'construct' selector sets" } */
+int f2a (void *a);
+#pragma omp declare variant (f0) match (construct={target,dispatch}) adjust_args (need_device_ptr: a) /* { dg-error "'int f0.void..' used as a variant with incompatible 'construct' selector sets" } */
+int f2b (void *a);
+#pragma omp declare variant (f0) match (construct={dispatch},device={arch(gcn)}) adjust_args (need_device_ptr: a) /* { dg-error "'int f0.void..' used as a variant with incompatible 'construct' selector sets" } */
+int f2c (void *a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args (other: a) /* { dg-error "expected 'nothing' or 'need_device_ptr'" } */
+int f3 (int a);
+#pragma omp declare variant (f0) adjust_args (nothing: a) /* { dg-error "an 'adjust_args' clause can only be specified if the 'dispatch' selector of the construct selector set appears in the 'match' clause" } */
+int f4 (void *a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args () /* { dg-error "expected 'nothing' or 'need_device_ptr' followed by ':'" } */
+int f5 (int a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args (nothing) /* { dg-error "expected 'nothing' or 'need_device_ptr' followed by ':'" } */
+int f6 (int a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args (nothing:) /* { dg-error "expected unqualified-id before '\\)' token" } */
+int f7 (int a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args (nothing: z) /* { dg-error "'z' has not been declared" } */
+int f8 (int a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args (need_device_ptr: a) /* { dg-error "'a' is not a C pointer" } */
+int f9 (int a);
+#pragma omp declare variant (f1) match (construct={dispatch}) adjust_args (nothing: a) adjust_args (nothing: a) /* { dg-error "'a' is specified more than once" } */
+int f10 (int a);
+#pragma omp declare variant (g) match (construct={dispatch}) adjust_args (nothing: a) adjust_args (need_device_ptr: a) /* { dg-error "'a' is specified more than once" } */
+int f11 (void *a);
+#pragma omp declare variant (g) match (construct={dispatch}) adjust_args (need_device_ptr: b) /* { dg-error "'b' is not a function argument" } */
+int f12 (void *a);
+#pragma omp declare variant (g) match (construct={dispatch}) adjust_args (need_device_ptr: this) /* { dg-error "expected unqualified-id before 'this'" } */
+int f13 (void *a);
diff --git a/gcc/testsuite/g++.dg/gomp/adjust-args-2.C b/gcc/testsuite/g++.dg/gomp/adjust-args-2.C
new file mode 100644
index 00000000000..a78f06ec193
--- /dev/null
+++ b/gcc/testsuite/g++.dg/gomp/adjust-args-2.C
@@ -0,0 +1,51 @@ 
+struct S { 
+  int a;
+  int g (const void *b);
+  #pragma omp declare variant (g) match (construct={dispatch}) adjust_args (need_device_ptr: b)
+  int f0(const void *b); 
+  int operator()() { return a; }
+  bool operator!() { return !a; }
+};
+
+template <typename T>
+T f0(T a, T *b);
+
+#pragma omp declare variant (f0) match (construct={dispatch}) adjust_args (need_device_ptr: a, b)
+template <typename T>
+T f1(T a, T *b);
+
+namespace N {
+  class C{
+    public:
+  void g(C *c);
+  #pragma omp declare variant (g) match (construct={dispatch}) adjust_args (need_device_ptr: c)
+  void f0(C *c);
+  };
+  void g(C *c);
+  #pragma omp declare variant (g) match (construct={dispatch}) adjust_args (need_device_ptr: c)
+  void f0(C *c);
+}
+
+#pragma omp declare variant (g) match (construct={dispatch}) adjust_args (need_device_ptr: c)
+void f3(N::C *c);
+void f4(S *&s);
+#pragma omp declare variant (f4) match (construct={dispatch}) adjust_args (need_device_ptr: s)
+void f5(S *&s);
+
+void test() {
+  S s, *sp;
+  N::C c;
+  int *a, b;
+  #pragma omp dispatch
+  s.f0(a);
+  #pragma omp dispatch
+  f1(b, a);
+  #pragma omp dispatch
+  c.f0(&c);
+  #pragma omp dispatch
+  N::f0(&c);
+  #pragma omp dispatch
+  f3(&c);
+  #pragma omp dispatch
+  f5(sp);
+}
diff --git a/gcc/testsuite/g++.dg/gomp/dispatch-1.C b/gcc/testsuite/g++.dg/gomp/dispatch-1.C
new file mode 100644
index 00000000000..fb467afcd85
--- /dev/null
+++ b/gcc/testsuite/g++.dg/gomp/dispatch-1.C
@@ -0,0 +1,53 @@ 
+struct S { 
+  int a; 
+  void f0(double); 
+  int operator()() { return a; }
+  bool operator!() { return !a; }
+};
+
+int f0(int);
+template <typename T>
+T f1(T a, T b);
+void (*f2)(void);
+
+namespace N {
+  class C{};
+  void f0(C);
+  int a;
+}
+
+int test() {
+  int result;
+  double d = 5.0;
+  N::C c;
+  S s;
+  S* sp = &s;
+  int &r = result;
+  #pragma omp dispatch
+  result = f0(5);
+  #pragma omp dispatch
+  r = f0(5);
+  #pragma omp dispatch
+  N::a = ::f0(5);
+  #pragma omp dispatch
+  sp->a = f1<int>(5, 10);
+  #pragma omp dispatch
+  s.a = f1(5, 10);
+  #pragma omp dispatch
+  f2();
+  #pragma omp dispatch
+  N::f0(c);
+  #pragma omp dispatch
+  f0(c);
+  #pragma omp dispatch
+  s.f0(d);
+  #pragma omp dispatch
+  sp->f0(d);
+  #pragma omp dispatch
+  sp->f0(d);
+  #pragma omp dispatch
+  s();
+  #pragma omp dispatch
+  !s;
+  return result; 
+}
diff --git a/gcc/testsuite/g++.dg/gomp/dispatch-2.C b/gcc/testsuite/g++.dg/gomp/dispatch-2.C
new file mode 100644
index 00000000000..1bc304e005e
--- /dev/null
+++ b/gcc/testsuite/g++.dg/gomp/dispatch-2.C
@@ -0,0 +1,62 @@ 
+/* Test parsing of #pragma omp dispatch */
+/* { dg-do compile } */
+
+struct S {
+  int a;
+  int b;
+  virtual int f (double); 
+};
+
+int f0 (int);
+
+void f1 (void)
+{
+  int a, b;
+  double x;
+  int arr[1];
+  S s;
+
+#pragma omp dispatch
+  int c = f0 (a);	/* { dg-error "expected primary-expression before 'int'" } */
+#pragma omp dispatch
+  int f2 (int d);	/* { dg-error "expected primary-expression before 'int'" } */
+#pragma omp dispatch
+  a = b;	/* { dg-error "expected '\\(' before ';' token" } */
+#pragma omp dispatch
+  s.a = f0(a) + b;	/* { dg-error "expected ';' before '\\+' token" } */
+#pragma omp dispatch
+  b = !f0(a);	/* { dg-error "expected primary-expression before '!' token" } */
+#pragma omp dispatch
+  s.b += f0(s.a);	/* { dg-error "expected '=' before '\\+=' token" } */
+#pragma omp dispatch
+#pragma omp threadprivate(a)	/* { dg-error "'#pragma' is not allowed here" } */
+  a = f0(b);
+#pragma omp dispatch
+  a = s.f(x);   /* { dg-error "'f' is a virtual function but only a direct call is allowed in a dispatch construct" } */
+  
+#pragma omp dispatch nocontext(s) /* { dg-error "could not convert 's' from 'S' to 'bool'" } */
+  f0(a);
+#pragma omp dispatch nocontext(a, b) /* { dg-error "expected '\\)' before ','" } */
+  f0(a);
+#pragma omp dispatch nocontext(a) nocontext(b) /* { dg-error "too many 'nocontext' clauses" } */
+  f0(a);
+#pragma omp dispatch novariants(s) /* { dg-error "could not convert 's' from 'S' to 'bool'" } */
+  f0(a);
+#pragma omp dispatch novariants(a, b) /* { dg-error "expected '\\)' before ','" } */
+  f0(a);
+#pragma omp dispatch novariants(a) novariants(b) /* { dg-error "too many 'novariants' clauses" } */
+  f0(a);
+#pragma omp dispatch nowait nowait /* { dg-error "too many 'nowait' clauses" } */
+  f0(a);
+#pragma omp dispatch device(x) /* { dg-error "'device' id must be integral" } */
+  f0(a);
+#pragma omp dispatch device(arr) /* { dg-error "'device' id must be integral" } */
+  f0(a);
+#pragma omp dispatch is_device_ptr(x) /* { dg-error "'is_device_ptr' variable is neither a pointer, nor an array nor reference to pointer" } */
+  f0(a);
+#pragma omp dispatch is_device_ptr(&x) /* { dg-error "expected unqualified-id before '&' token" } */
+  f0(a);
+#pragma omp dispatch depend(inout: s.f) /* { dg-error "'s.S::f' is not lvalue expression nor array section in 'depend' clause" } */
+  f0(a);
+
+}