@@ -1987,7 +1987,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
mss = skb_is_gso(skb) ? skb_shinfo(skb)->gso_size : data_len;
tls_ctx = tls_get_ctx(skb->sk);
- if (unlikely(tls_ctx->netdev != dev))
+ if (unlikely(tls_ctx->real_dev != dev))
goto out;
tx_ctx = chcr_get_ktls_tx_context(tls_ctx);
@@ -273,7 +273,7 @@ bool mlx5e_tls_handle_tx_skb(struct net_device *netdev, struct mlx5e_txqsq *sq,
mlx5e_tx_mpwqe_ensure_complete(sq);
tls_ctx = tls_get_ctx(skb->sk);
- if (WARN_ON_ONCE(tls_ctx->netdev != netdev))
+ if (WARN_ON_ONCE(tls_ctx->real_dev != netdev))
goto err_out;
if (mlx5_accel_is_ktls_tx(sq->channel->mdev))
@@ -241,6 +241,7 @@ struct tls_context {
void *priv_ctx_rx;
struct net_device *netdev;
+ struct net_device *real_dev;
/* rw cache line */
struct cipher_context tx;
@@ -970,6 +970,8 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
refcount_set(&ctx->refcount, 1);
dev_hold(netdev);
ctx->netdev = netdev;
+ if (!ctx->real_dev)
+ ctx->real_dev = netdev;
spin_lock_irq(&tls_device_lock);
list_add_tail(&ctx->list, &tls_device_list);
spin_unlock_irq(&tls_device_lock);
@@ -423,7 +423,7 @@ struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
struct net_device *dev,
struct sk_buff *skb)
{
- if (dev == tls_get_ctx(sk)->netdev)
+ if (dev == tls_get_ctx(sk)->netdev || dev == tls_get_ctx(sk)->real_dev)
return skb;
return tls_sw_fallback(sk, skb);