diff mbox series

[RFC,3/4] Drivers: hv: vmbus: Introduce vmbus_sendpacket_getid()

Message ID 20220328144244.100228-4-parri.andrea@gmail.com
State New
Headers show
Series PCI: hv: Miscellaneous changes | expand

Commit Message

Andrea Parri March 28, 2022, 2:42 p.m. UTC
The function can be used to send a VMbus packet and retrieve the
corresponding transaction ID.  It will be used by hv_pci.

No functional change.

Suggested-by: Michael Kelley <mikelley@microsoft.com>
Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
---
 drivers/hv/channel.c      | 38 ++++++++++++++++++++++++++++++++------
 drivers/hv/hyperv_vmbus.h |  2 +-
 drivers/hv/ring_buffer.c  |  4 +++-
 include/linux/hyperv.h    |  7 +++++++
 4 files changed, 43 insertions(+), 8 deletions(-)

Comments

Michael Kelley (LINUX) March 31, 2022, 7:47 p.m. UTC | #1
From: Andrea Parri (Microsoft) <parri.andrea@gmail.com> Sent: Monday, March 28, 2022 7:43 AM
> 
> The function can be used to send a VMbus packet and retrieve the
> corresponding transaction ID.  It will be used by hv_pci.
> 
> No functional change.
> 
> Suggested-by: Michael Kelley <mikelley@microsoft.com>
> Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
> ---
>  drivers/hv/channel.c      | 38 ++++++++++++++++++++++++++++++++------
>  drivers/hv/hyperv_vmbus.h |  2 +-
>  drivers/hv/ring_buffer.c  |  4 +++-
>  include/linux/hyperv.h    |  7 +++++++
>  4 files changed, 43 insertions(+), 8 deletions(-)
> 
> diff --git a/drivers/hv/channel.c b/drivers/hv/channel.c
> index a253eee3aeb1a..3eaa41c7ce15f 100644
> --- a/drivers/hv/channel.c
> +++ b/drivers/hv/channel.c
> @@ -1022,11 +1022,13 @@ void vmbus_close(struct vmbus_channel *channel)
>  EXPORT_SYMBOL_GPL(vmbus_close);
> 
>  /**
> - * vmbus_sendpacket() - Send the specified buffer on the given channel
> + * vmbus_sendpacket_getid() - Send the specified buffer on the given channel
>   * @channel: Pointer to vmbus_channel structure
>   * @buffer: Pointer to the buffer you want to send the data from.
>   * @bufferlen: Maximum size of what the buffer holds.
>   * @requestid: Identifier of the request
> + * @trans_id: Identifier of the transaction associated to this request, if
> + *            the send is successful; undefined, otherwise.
>   * @type: Type of packet that is being sent e.g. negotiate, time
>   *	  packet etc.
>   * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED
> @@ -1036,8 +1038,8 @@ EXPORT_SYMBOL_GPL(vmbus_close);
>   *
>   * Mainly used by Hyper-V drivers.
>   */
> -int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
> -			   u32 bufferlen, u64 requestid,
> +int vmbus_sendpacket_getid(struct vmbus_channel *channel, void *buffer,
> +			   u32 bufferlen, u64 requestid, u64 *trans_id,
>  			   enum vmbus_packet_type type, u32 flags)
>  {
>  	struct vmpacket_descriptor desc;
> @@ -1063,7 +1065,31 @@ int vmbus_sendpacket(struct vmbus_channel *channel,
> void *buffer,
>  	bufferlist[2].iov_base = &aligned_data;
>  	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
> 
> -	return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid);
> +	return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid, trans_id);
> +}
> +EXPORT_SYMBOL(vmbus_sendpacket_getid);
> +
> +/**
> + * vmbus_sendpacket() - Send the specified buffer on the given channel
> + * @channel: Pointer to vmbus_channel structure
> + * @buffer: Pointer to the buffer you want to send the data from.
> + * @bufferlen: Maximum size of what the buffer holds.
> + * @requestid: Identifier of the request
> + * @type: Type of packet that is being sent e.g. negotiate, time
> + *	  packet etc.
> + * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED
> + *
> + * Sends data in @buffer directly to Hyper-V via the vmbus.
> + * This will send the data unparsed to Hyper-V.
> + *
> + * Mainly used by Hyper-V drivers.
> + */
> +int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
> +		     u32 bufferlen, u64 requestid,
> +		     enum vmbus_packet_type type, u32 flags)
> +{
> +	return vmbus_sendpacket_getid(channel, buffer, bufferlen,
> +				      requestid, NULL, type, flags);
>  }
>  EXPORT_SYMBOL(vmbus_sendpacket);
> 
> @@ -1122,7 +1148,7 @@ int vmbus_sendpacket_pagebuffer(struct vmbus_channel
> *channel,
>  	bufferlist[2].iov_base = &aligned_data;
>  	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
> 
> -	return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
> +	return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL);
>  }
>  EXPORT_SYMBOL_GPL(vmbus_sendpacket_pagebuffer);
> 
> @@ -1160,7 +1186,7 @@ int vmbus_sendpacket_mpb_desc(struct vmbus_channel
> *channel,
>  	bufferlist[2].iov_base = &aligned_data;
>  	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
> 
> -	return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
> +	return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL);
>  }
>  EXPORT_SYMBOL_GPL(vmbus_sendpacket_mpb_desc);
> 
> diff --git a/drivers/hv/hyperv_vmbus.h b/drivers/hv/hyperv_vmbus.h
> index 3a1f007b678a0..64c0b9cbe183b 100644
> --- a/drivers/hv/hyperv_vmbus.h
> +++ b/drivers/hv/hyperv_vmbus.h
> @@ -181,7 +181,7 @@ void hv_ringbuffer_cleanup(struct hv_ring_buffer_info
> *ring_info);
> 
>  int hv_ringbuffer_write(struct vmbus_channel *channel,
>  			const struct kvec *kv_list, u32 kv_count,
> -			u64 requestid);
> +			u64 requestid, u64 *trans_id);
> 
>  int hv_ringbuffer_read(struct vmbus_channel *channel,
>  		       void *buffer, u32 buflen, u32 *buffer_actual_len,
> diff --git a/drivers/hv/ring_buffer.c b/drivers/hv/ring_buffer.c
> index 71efacb909659..c8561c80c460c 100644
> --- a/drivers/hv/ring_buffer.c
> +++ b/drivers/hv/ring_buffer.c
> @@ -283,7 +283,7 @@ void hv_ringbuffer_cleanup(struct hv_ring_buffer_info
> *ring_info)
>  /* Write to the ring buffer. */
>  int hv_ringbuffer_write(struct vmbus_channel *channel,
>  			const struct kvec *kv_list, u32 kv_count,
> -			u64 requestid)
> +			u64 requestid, u64 *trans_id)
>  {
>  	int i;
>  	u32 bytes_avail_towrite;
> @@ -354,6 +354,8 @@ int hv_ringbuffer_write(struct vmbus_channel *channel,
>  	}
>  	desc = hv_get_ring_buffer(outring_info) + old_write;
>  	desc->trans_id = (rqst_id == VMBUS_NO_RQSTOR) ? requestid : rqst_id;
> +	if (trans_id)
> +		*trans_id = desc->trans_id;

This line should *not* read the trans_id out of the ring buffer, since that
memory is shared with the Hyper-V host and subject to being maliciously
changed by the host.  Need to set *trans_id only from local variables, and
somehow ensure the compiler doesn't generate code that reads the value
from the ring buffer.  Maybe mark the desc->trans_id field as volatile, or cast
it as such?  Or does WRITE_ONCE() work when setting it?

Michael

> 
>  	/* Set previous packet start */
>  	prev_indices = hv_get_ring_bufferindices(outring_info);
> diff --git a/include/linux/hyperv.h b/include/linux/hyperv.h
> index fe2e0179ed51e..a7cb596d893b1 100644
> --- a/include/linux/hyperv.h
> +++ b/include/linux/hyperv.h
> @@ -1161,6 +1161,13 @@ extern int vmbus_open(struct vmbus_channel *channel,
> 
>  extern void vmbus_close(struct vmbus_channel *channel);
> 
> +extern int vmbus_sendpacket_getid(struct vmbus_channel *channel,
> +				  void *buffer,
> +				  u32 bufferLen,
> +				  u64 requestid,
> +				  u64 *trans_id,
> +				  enum vmbus_packet_type type,
> +				  u32 flags);
>  extern int vmbus_sendpacket(struct vmbus_channel *channel,
>  				  void *buffer,
>  				  u32 bufferLen,
> --
> 2.25.1
Andrea Parri April 1, 2022, 4:09 p.m. UTC | #2
> > @@ -354,6 +354,8 @@ int hv_ringbuffer_write(struct vmbus_channel *channel,
> >  	}
> >  	desc = hv_get_ring_buffer(outring_info) + old_write;
> >  	desc->trans_id = (rqst_id == VMBUS_NO_RQSTOR) ? requestid : rqst_id;
> > +	if (trans_id)
> > +		*trans_id = desc->trans_id;
> 
> This line should *not* read the trans_id out of the ring buffer, since that
> memory is shared with the Hyper-V host and subject to being maliciously
> changed by the host.  Need to set *trans_id only from local variables, and
> somehow ensure the compiler doesn't generate code that reads the value
> from the ring buffer.  Maybe mark the desc->trans_id field as volatile, or cast
> it as such?  Or does WRITE_ONCE() work when setting it?

I'd stick to WRITE_ONCE() (with a comment).

Good catch!

Thanks,
  Andrea
diff mbox series

Patch

diff --git a/drivers/hv/channel.c b/drivers/hv/channel.c
index a253eee3aeb1a..3eaa41c7ce15f 100644
--- a/drivers/hv/channel.c
+++ b/drivers/hv/channel.c
@@ -1022,11 +1022,13 @@  void vmbus_close(struct vmbus_channel *channel)
 EXPORT_SYMBOL_GPL(vmbus_close);
 
 /**
- * vmbus_sendpacket() - Send the specified buffer on the given channel
+ * vmbus_sendpacket_getid() - Send the specified buffer on the given channel
  * @channel: Pointer to vmbus_channel structure
  * @buffer: Pointer to the buffer you want to send the data from.
  * @bufferlen: Maximum size of what the buffer holds.
  * @requestid: Identifier of the request
+ * @trans_id: Identifier of the transaction associated to this request, if
+ *            the send is successful; undefined, otherwise.
  * @type: Type of packet that is being sent e.g. negotiate, time
  *	  packet etc.
  * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED
@@ -1036,8 +1038,8 @@  EXPORT_SYMBOL_GPL(vmbus_close);
  *
  * Mainly used by Hyper-V drivers.
  */
-int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
-			   u32 bufferlen, u64 requestid,
+int vmbus_sendpacket_getid(struct vmbus_channel *channel, void *buffer,
+			   u32 bufferlen, u64 requestid, u64 *trans_id,
 			   enum vmbus_packet_type type, u32 flags)
 {
 	struct vmpacket_descriptor desc;
@@ -1063,7 +1065,31 @@  int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
 	bufferlist[2].iov_base = &aligned_data;
 	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-	return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid);
+	return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid, trans_id);
+}
+EXPORT_SYMBOL(vmbus_sendpacket_getid);
+
+/**
+ * vmbus_sendpacket() - Send the specified buffer on the given channel
+ * @channel: Pointer to vmbus_channel structure
+ * @buffer: Pointer to the buffer you want to send the data from.
+ * @bufferlen: Maximum size of what the buffer holds.
+ * @requestid: Identifier of the request
+ * @type: Type of packet that is being sent e.g. negotiate, time
+ *	  packet etc.
+ * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED
+ *
+ * Sends data in @buffer directly to Hyper-V via the vmbus.
+ * This will send the data unparsed to Hyper-V.
+ *
+ * Mainly used by Hyper-V drivers.
+ */
+int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
+		     u32 bufferlen, u64 requestid,
+		     enum vmbus_packet_type type, u32 flags)
+{
+	return vmbus_sendpacket_getid(channel, buffer, bufferlen,
+				      requestid, NULL, type, flags);
 }
 EXPORT_SYMBOL(vmbus_sendpacket);
 
@@ -1122,7 +1148,7 @@  int vmbus_sendpacket_pagebuffer(struct vmbus_channel *channel,
 	bufferlist[2].iov_base = &aligned_data;
 	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-	return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
+	return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL);
 }
 EXPORT_SYMBOL_GPL(vmbus_sendpacket_pagebuffer);
 
@@ -1160,7 +1186,7 @@  int vmbus_sendpacket_mpb_desc(struct vmbus_channel *channel,
 	bufferlist[2].iov_base = &aligned_data;
 	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-	return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
+	return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL);
 }
 EXPORT_SYMBOL_GPL(vmbus_sendpacket_mpb_desc);
 
diff --git a/drivers/hv/hyperv_vmbus.h b/drivers/hv/hyperv_vmbus.h
index 3a1f007b678a0..64c0b9cbe183b 100644
--- a/drivers/hv/hyperv_vmbus.h
+++ b/drivers/hv/hyperv_vmbus.h
@@ -181,7 +181,7 @@  void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info);
 
 int hv_ringbuffer_write(struct vmbus_channel *channel,
 			const struct kvec *kv_list, u32 kv_count,
-			u64 requestid);
+			u64 requestid, u64 *trans_id);
 
 int hv_ringbuffer_read(struct vmbus_channel *channel,
 		       void *buffer, u32 buflen, u32 *buffer_actual_len,
diff --git a/drivers/hv/ring_buffer.c b/drivers/hv/ring_buffer.c
index 71efacb909659..c8561c80c460c 100644
--- a/drivers/hv/ring_buffer.c
+++ b/drivers/hv/ring_buffer.c
@@ -283,7 +283,7 @@  void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info)
 /* Write to the ring buffer. */
 int hv_ringbuffer_write(struct vmbus_channel *channel,
 			const struct kvec *kv_list, u32 kv_count,
-			u64 requestid)
+			u64 requestid, u64 *trans_id)
 {
 	int i;
 	u32 bytes_avail_towrite;
@@ -354,6 +354,8 @@  int hv_ringbuffer_write(struct vmbus_channel *channel,
 	}
 	desc = hv_get_ring_buffer(outring_info) + old_write;
 	desc->trans_id = (rqst_id == VMBUS_NO_RQSTOR) ? requestid : rqst_id;
+	if (trans_id)
+		*trans_id = desc->trans_id;
 
 	/* Set previous packet start */
 	prev_indices = hv_get_ring_bufferindices(outring_info);
diff --git a/include/linux/hyperv.h b/include/linux/hyperv.h
index fe2e0179ed51e..a7cb596d893b1 100644
--- a/include/linux/hyperv.h
+++ b/include/linux/hyperv.h
@@ -1161,6 +1161,13 @@  extern int vmbus_open(struct vmbus_channel *channel,
 
 extern void vmbus_close(struct vmbus_channel *channel);
 
+extern int vmbus_sendpacket_getid(struct vmbus_channel *channel,
+				  void *buffer,
+				  u32 bufferLen,
+				  u64 requestid,
+				  u64 *trans_id,
+				  enum vmbus_packet_type type,
+				  u32 flags);
 extern int vmbus_sendpacket(struct vmbus_channel *channel,
 				  void *buffer,
 				  u32 bufferLen,