@@ -13,8 +13,9 @@ extern spinlock_t reuseport_lock;
struct sock_reuseport {
struct rcu_head rcu;
-	u16			max_socks;	/* length of socks */
-	u16			num_socks;	/* elements in socks */
+	u16			max_socks;		/* length of socks */
+	u16			num_socks;		/* elements in socks */
+	u16			num_closed_socks;	/* closed elements in socks */
/* The last synq overflow event timestamp of this
* reuse->socks[] group.
*/
@@ -98,14 +98,15 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
return NULL;
more_reuse->num_socks = reuse->num_socks;
+ more_reuse->num_closed_socks = reuse->num_closed_socks;
more_reuse->prog = reuse->prog;
more_reuse->reuseport_id = reuse->reuseport_id;
more_reuse->bind_inany = reuse->bind_inany;
more_reuse->has_conns = reuse->has_conns;
+ more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
memcpy(more_reuse->socks, reuse->socks,
reuse->num_socks * sizeof(struct sock *));
- more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts);
for (i = 0; i < reuse->num_socks; ++i)
rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
@@ -152,8 +153,10 @@ int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
reuse = rcu_dereference_protected(sk2->sk_reuseport_cb,
lockdep_is_held(&reuseport_lock));
old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
- lockdep_is_held(&reuseport_lock));
- if (old_reuse && old_reuse->num_socks != 1) {
+ lockdep_is_held(&reuseport_lock));
+ if (old_reuse == reuse) {
+ reuse->num_closed_socks--;
+ } else if (old_reuse && old_reuse->num_socks != 1) {
spin_unlock_bh(&reuseport_lock);
return -EBUSY;
}
@@ -174,8 +177,9 @@ int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
spin_unlock_bh(&reuseport_lock);
- if (old_reuse)
+ if (old_reuse && old_reuse != reuse)
call_rcu(&old_reuse->rcu, reuseport_free_rcu);
+
return 0;
}
EXPORT_SYMBOL(reuseport_add_sock);
@@ -199,17 +203,28 @@ void reuseport_detach_sock(struct sock *sk)
*/
bpf_sk_reuseport_detach(sk);
- rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
-
- for (i = 0; i < reuse->num_socks; i++) {
- if (reuse->socks[i] == sk) {
- reuse->socks[i] = reuse->socks[reuse->num_socks - 1];
- reuse->num_socks--;
- if (reuse->num_socks == 0)
- call_rcu(&reuse->rcu, reuseport_free_rcu);
+ if (sk->sk_protocol == IPPROTO_TCP && sk->sk_state == TCP_CLOSE) {
+ reuse->num_closed_socks--;
+ rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
+ } else {
+ for (i = 0; i < reuse->num_socks; i++) {
+ if (reuse->socks[i] != sk)
+ continue;
break;
}
+
+ reuse->num_socks--;
+ reuse->socks[i] = reuse->socks[reuse->num_socks];
+
+ if (sk->sk_protocol == IPPROTO_TCP)
+ reuse->num_closed_socks++;
+ else
+ rcu_assign_pointer(sk->sk_reuseport_cb, NULL);
}
+
+ if (reuse->num_socks + reuse->num_closed_socks == 0)
+ call_rcu(&reuse->rcu, reuseport_free_rcu);
+
spin_unlock_bh(&reuseport_lock);
}
EXPORT_SYMBOL(reuseport_detach_sock);
@@ -138,6 +138,7 @@ static int inet_csk_bind_conflict(const struct sock *sk,
bool reuse = sk->sk_reuse;
bool reuseport = !!sk->sk_reuseport;
kuid_t uid = sock_i_uid((struct sock *)sk);
+ struct sock_reuseport *reuseport_cb = rcu_access_pointer(sk->sk_reuseport_cb);
/*
* Unlike other sk lookup places we do not check
@@ -156,14 +157,16 @@ static int inet_csk_bind_conflict(const struct sock *sk,
if ((!relax ||
(!reuseport_ok &&
reuseport && sk2->sk_reuseport &&
- !rcu_access_pointer(sk->sk_reuseport_cb) &&
+ (!reuseport_cb ||
+ reuseport_cb == rcu_access_pointer(sk2->sk_reuseport_cb)) &&
(sk2->sk_state == TCP_TIME_WAIT ||
uid_eq(uid, sock_i_uid(sk2))))) &&
inet_rcv_saddr_equal(sk, sk2, true))
break;
} else if (!reuseport_ok ||
!reuseport || !sk2->sk_reuseport ||
- rcu_access_pointer(sk->sk_reuseport_cb) ||
+ (reuseport_cb &&
+ reuseport_cb != rcu_access_pointer(sk2->sk_reuseport_cb)) ||
(sk2->sk_state != TCP_TIME_WAIT &&
!uid_eq(uid, sock_i_uid(sk2)))) {
if (inet_rcv_saddr_equal(sk, sk2, true))
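
Not part of the patch: for reference, a minimal userspace sketch of how sockets come to share an SO_REUSEPORT group in the first place, using a hypothetical helper open_reuseport_listener() (port number, backlog, and error handling are illustrative). Each such listener is what the kernel-side bookkeeping above tracks in reuse->socks[].

#include <arpa/inet.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

/* Open one TCP listener that joins (or creates) the SO_REUSEPORT group
 * for the given port.  Each successful call adds one socket to the
 * kernel's reuse->socks[] array for that address/port pair.
 */
static int open_reuseport_listener(unsigned short port)
{
	struct sockaddr_in addr;
	int fd, one = 1;

	fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0)
		return -1;

	/* SO_REUSEPORT must be set before bind() for the socket to be
	 * placed in the group.
	 */
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)) < 0)
		goto err;

	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = htons(port);

	if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
		goto err;
	if (listen(fd, 128) < 0)
		goto err;

	return fd;

err:
	close(fd);
	return -1;
}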