diff --git a/awkernel_lib/src/net.rs b/awkernel_lib/src/net.rs index e13b18c3f..d8a1dc39d 100644 --- a/awkernel_lib/src/net.rs +++ b/awkernel_lib/src/net.rs @@ -14,12 +14,6 @@ use self::{ net_device::{LinkStatus, NetCapabilities, NetDevice}, }; -#[cfg(not(feature = "std"))] -use self::tcp::TcpPort; - -#[cfg(not(feature = "std"))] -use alloc::collections::BTreeSet; - #[cfg(not(feature = "std"))] use alloc::{string::String, vec::Vec}; @@ -34,6 +28,7 @@ pub mod ip_addr; pub mod ipv6; pub mod multicast; pub mod net_device; +mod port_alloc; pub mod tcp; pub mod tcp_listener; pub mod tcp_stream; @@ -132,30 +127,6 @@ impl Display for IfStatus { static NET_MANAGER: RwLock = RwLock::new(NetManager { interfaces: BTreeMap::new(), interface_id: 0, - - #[cfg(not(feature = "std"))] - udp_ports_ipv4: BTreeSet::new(), - - #[cfg(not(feature = "std"))] - udp_port_ipv4_ephemeral: u16::MAX >> 2, - - #[cfg(not(feature = "std"))] - udp_ports_ipv6: BTreeSet::new(), - - #[cfg(not(feature = "std"))] - udp_port_ipv6_ephemeral: u16::MAX >> 2, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv4: BTreeMap::new(), - - #[cfg(not(feature = "std"))] - tcp_port_ipv4_ephemeral: u16::MAX >> 2, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv6: BTreeMap::new(), - - #[cfg(not(feature = "std"))] - tcp_port_ipv6_ephemeral: u16::MAX >> 2, }); static IRQ_WAKERS: Mutex> = Mutex::new(BTreeMap::new()); @@ -164,208 +135,6 @@ static POLL_WAKERS: Mutex> = Mutex::new(BTreeMap::new()) pub struct NetManager { interfaces: BTreeMap>, interface_id: u64, - - #[cfg(not(feature = "std"))] - udp_ports_ipv4: BTreeSet, - - #[cfg(not(feature = "std"))] - udp_port_ipv4_ephemeral: u16, - - #[cfg(not(feature = "std"))] - udp_ports_ipv6: BTreeSet, - - #[cfg(not(feature = "std"))] - udp_port_ipv6_ephemeral: u16, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv4: BTreeMap, - - #[cfg(not(feature = "std"))] - tcp_port_ipv4_ephemeral: u16, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv6: BTreeMap, - - #[cfg(not(feature = "std"))] - tcp_port_ipv6_ephemeral: u16, -} - -impl NetManager { - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_udp_ipv4(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.udp_port_ipv4_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - if !self.udp_ports_ipv4.contains(&port) { - self.udp_ports_ipv4.insert(port); - self.udp_port_ipv4_ephemeral = port; - ephemeral_port = Some(port); - break; - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn set_port_in_use_udp_ipv4(&mut self, port: u16) { - self.udp_ports_ipv4.insert(port); - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_udp_ipv4(&mut self, port: u16) -> bool { - self.udp_ports_ipv4.contains(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn free_port_udp_ipv4(&mut self, port: u16) { - self.udp_ports_ipv4.remove(&port); - } - - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_udp_ipv6(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.udp_port_ipv6_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - if !self.udp_ports_ipv6.contains(&port) { - self.udp_ports_ipv6.insert(port); - self.udp_port_ipv4_ephemeral = port; - ephemeral_port = Some(port); - break; - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn set_port_in_use_udp_ipv6(&mut self, port: u16) { - self.udp_ports_ipv6.insert(port); - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_udp_ipv6(&mut self, port: u16) -> bool { - self.udp_ports_ipv6.contains(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn free_port_udp_ipv6(&mut self, port: u16) { - self.udp_ports_ipv6.remove(&port); - } - - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_tcp_ipv4(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.tcp_port_ipv4_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - let entry = self.tcp_ports_ipv4.entry(i); - - match entry { - Entry::Occupied(_) => (), - Entry::Vacant(e) => { - e.insert(1); - ephemeral_port = Some(TcpPort::new(port, true)); - self.tcp_port_ipv4_ephemeral = port; - break; - } - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_tcp_ipv4(&mut self, port: u16) -> bool { - self.tcp_ports_ipv4.contains_key(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn port_in_use_tcp_ipv4(&mut self, port: u16) -> TcpPort { - if let Some(e) = self.tcp_ports_ipv4.get_mut(&port) { - *e += 1; - } else { - self.tcp_ports_ipv4.insert(port, 1); - } - - TcpPort::new(port, true) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn decrement_port_in_use_tcp_ipv4(&mut self, port: u16) { - if let Some(e) = self.tcp_ports_ipv4.get_mut(&port) { - *e -= 1; - if *e == 0 { - self.tcp_ports_ipv4.remove(&port); - } - } - } - - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_tcp_ipv6(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.tcp_port_ipv6_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - let entry = self.tcp_ports_ipv6.entry(i); - - match entry { - Entry::Occupied(_) => (), - Entry::Vacant(e) => { - e.insert(1); - ephemeral_port = Some(TcpPort::new(port, false)); - self.tcp_port_ipv6_ephemeral = port; - break; - } - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_tcp_ipv6(&mut self, port: u16) -> bool { - self.tcp_ports_ipv6.contains_key(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn port_in_use_tcp_ipv6(&mut self, port: u16) -> TcpPort { - if let Some(e) = self.tcp_ports_ipv6.get_mut(&port) { - *e += 1; - } else { - self.tcp_ports_ipv6.insert(port, 1); - } - - TcpPort::new(port, true) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn decrement_port_in_use_tcp_ipv6(&mut self, port: u16) { - if let Some(e) = self.tcp_ports_ipv6.get_mut(&port) { - *e -= 1; - if *e == 0 { - self.tcp_ports_ipv6.remove(&port); - } - } - } } pub fn get_interface(interface_id: u64) -> Result { diff --git a/awkernel_lib/src/net/port_alloc.rs b/awkernel_lib/src/net/port_alloc.rs new file mode 100644 index 000000000..c579a75c1 --- /dev/null +++ b/awkernel_lib/src/net/port_alloc.rs @@ -0,0 +1,267 @@ +#![cfg(not(feature = "std"))] + +use alloc::collections::{btree_map::Entry, BTreeMap, BTreeSet}; + +use crate::sync::{mcs::MCSNode, mutex::Mutex}; + +use super::tcp::TcpPort; + +/// RAII handle for a claimed UDP port. Frees the port from [`PORT_ALLOC`] on drop, +/// so the port is released on any error path between claiming it and constructing +/// the owning socket. +pub(super) struct UdpPort { + port: u16, + is_ipv4: bool, +} + +impl UdpPort { + pub(super) fn port(&self) -> u16 { + self.port + } +} + +impl Drop for UdpPort { + fn drop(&mut self) { + if self.is_ipv4 { + PORT_ALLOC.free_udp_ipv4(self.port); + } else { + PORT_ALLOC.free_udp_ipv6(self.port); + } + } +} + +struct TcpPortsInner { + map: BTreeMap, + cursor: u16, +} + +struct UdpPortsInner { + set: BTreeSet, + cursor: u16, +} + +pub(super) struct PortAllocator { + tcp_ipv4: Mutex, + tcp_ipv6: Mutex, + udp_ipv4: Mutex, + udp_ipv6: Mutex, +} + +pub(super) static PORT_ALLOC: PortAllocator = PortAllocator::new(); + +impl PortAllocator { + pub(super) const fn new() -> Self { + Self { + tcp_ipv4: Mutex::new(TcpPortsInner { + map: BTreeMap::new(), + cursor: u16::MAX >> 2, + }), + tcp_ipv6: Mutex::new(TcpPortsInner { + map: BTreeMap::new(), + cursor: u16::MAX >> 2, + }), + udp_ipv4: Mutex::new(UdpPortsInner { + set: BTreeSet::new(), + cursor: u16::MAX >> 2, + }), + udp_ipv6: Mutex::new(UdpPortsInner { + set: BTreeSet::new(), + cursor: u16::MAX >> 2, + }), + } + } + + /// Allocate an ephemeral TCP IPv4 port. + pub(super) fn get_ephemeral_tcp_ipv4(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + return Some(TcpPort::new(port, true)); + } + } + None + } + + /// Claim a specific TCP IPv4 port. Returns `None` if the port is already in use. + pub(super) fn try_claim_tcp_ipv4(&self, port: u16) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + Some(TcpPort::new(port, true)) + } else { + None + } + } + + /// Increment the reference count for a TCP IPv4 port (used by `TcpListener::accept`). + pub(super) fn increment_ref_tcp_ipv4(&self, port: u16) -> TcpPort { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e += 1; + } else { + ports.map.insert(port, 1); + } + TcpPort::new(port, true) + } + + /// Decrement the reference count for a TCP IPv4 port, freeing it when it reaches zero. + pub(super) fn decrement_ref_tcp_ipv4(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e -= 1; + if *e == 0 { + ports.map.remove(&port); + } + } + } + + /// Allocate an ephemeral TCP IPv6 port. + pub(super) fn get_ephemeral_tcp_ipv6(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + return Some(TcpPort::new(port, false)); + } + } + None + } + + /// Claim a specific TCP IPv6 port. Returns `None` if the port is already in use. + pub(super) fn try_claim_tcp_ipv6(&self, port: u16) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + Some(TcpPort::new(port, false)) + } else { + None + } + } + + /// Increment the reference count for a TCP IPv6 port. + pub(super) fn increment_ref_tcp_ipv6(&self, port: u16) -> TcpPort { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e += 1; + } else { + ports.map.insert(port, 1); + } + TcpPort::new(port, false) + } + + /// Decrement the reference count for a TCP IPv6 port, freeing it when it reaches zero. + pub(super) fn decrement_ref_tcp_ipv6(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e -= 1; + if *e == 0 { + ports.map.remove(&port); + } + } + } + + /// Allocate an ephemeral UDP IPv4 port. + pub(super) fn get_ephemeral_udp_ipv4(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv4.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if ports.set.insert(port) { + return Some(UdpPort { + port, + is_ipv4: true, + }); + } + } + None + } + + /// Claim a specific UDP IPv4 port. Returns `None` if the port is already in use. + pub(super) fn try_claim_udp_ipv4(&self, port: u16) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv4.lock(&mut node); + if ports.set.insert(port) { + Some(UdpPort { + port, + is_ipv4: true, + }) + } else { + None + } + } + + /// Free a UDP IPv4 port. + pub(super) fn free_udp_ipv4(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv4.lock(&mut node); + ports.set.remove(&port); + } + + /// Allocate an ephemeral UDP IPv6 port. + pub(super) fn get_ephemeral_udp_ipv6(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv6.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if ports.set.insert(port) { + return Some(UdpPort { + port, + is_ipv4: false, + }); + } + } + None + } + + /// Claim a specific UDP IPv6 port. Returns `None` if the port is already in use. + pub(super) fn try_claim_udp_ipv6(&self, port: u16) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv6.lock(&mut node); + if ports.set.insert(port) { + Some(UdpPort { + port, + is_ipv4: false, + }) + } else { + None + } + } + + /// Free a UDP IPv6 port. + pub(super) fn free_udp_ipv6(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv6.lock(&mut node); + ports.set.remove(&port); + } +} diff --git a/awkernel_lib/src/net/tcp.rs b/awkernel_lib/src/net/tcp.rs index 0fc0f5cff..2aed89cd4 100644 --- a/awkernel_lib/src/net/tcp.rs +++ b/awkernel_lib/src/net/tcp.rs @@ -19,13 +19,14 @@ impl TCPHdr { } } -#[allow(dead_code)] +#[cfg(not(feature = "std"))] #[derive(Debug)] pub struct TcpPort { port: u16, is_ipv4: bool, } +#[cfg(not(feature = "std"))] impl TcpPort { pub fn new(port: u16, is_ipv4: bool) -> Self { Self { port, is_ipv4 } @@ -37,16 +38,13 @@ impl TcpPort { } } +#[cfg(not(feature = "std"))] impl Drop for TcpPort { fn drop(&mut self) { - #[cfg(not(feature = "std"))] - { - let mut net_manager = super::NET_MANAGER.write(); - if self.is_ipv4 { - net_manager.decrement_port_in_use_tcp_ipv4(self.port); - } else { - net_manager.decrement_port_in_use_tcp_ipv6(self.port); - } + if self.is_ipv4 { + super::port_alloc::PORT_ALLOC.decrement_ref_tcp_ipv4(self.port); + } else { + super::port_alloc::PORT_ALLOC.decrement_ref_tcp_ipv6(self.port); } } } diff --git a/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs b/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs index 8d548e368..222f415be 100644 --- a/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs +++ b/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs @@ -6,7 +6,8 @@ use crate::sync::mcs::MCSNode; use alloc::{vec, vec::Vec}; use crate::net::{ - ip_addr::IpAddr, tcp::TcpPort, tcp_stream::TcpStream, NetManagerError, NET_MANAGER, + ip_addr::IpAddr, port_alloc::PORT_ALLOC, tcp::TcpPort, tcp_stream::TcpStream, NetManagerError, + NET_MANAGER, }; use super::SockTcpListener; @@ -30,14 +31,13 @@ impl SockTcpListener for TcpListener { tx_buffer_size: usize, backlogs: usize, ) -> Result { - let mut net_manager = NET_MANAGER.write(); - - // Find the interface that has the specified address. - let if_net = net_manager - .interfaces - .get(&interface_id) - .ok_or(NetManagerError::InvalidInterfaceID)? - .clone(); + // Validate the interface exists before claiming a port. + { + let net_manager = NET_MANAGER.read(); + if !net_manager.interfaces.contains_key(&interface_id) { + return Err(NetManagerError::InvalidInterfaceID); + } + } let port = if let Some(port) = port { if port == 0 { @@ -45,44 +45,45 @@ impl SockTcpListener for TcpListener { } if addr.is_ipv4() { - // Check if the specified port is available. - if net_manager.is_port_in_use_tcp_ipv4(port) { - return Err(NetManagerError::PortInUse); - } - - net_manager.port_in_use_tcp_ipv4(port) + PORT_ALLOC + .try_claim_tcp_ipv4(port) + .ok_or(NetManagerError::PortInUse)? } else { - // Check if the specified port is available. - if net_manager.is_port_in_use_tcp_ipv6(port) { - return Err(NetManagerError::PortInUse); - } - - net_manager.port_in_use_tcp_ipv6(port) + PORT_ALLOC + .try_claim_tcp_ipv6(port) + .ok_or(NetManagerError::PortInUse)? } } else if addr.is_ipv4() { // Find an ephemeral port. - net_manager - .get_ephemeral_port_tcp_ipv4() + PORT_ALLOC + .get_ephemeral_tcp_ipv4() .ok_or(NetManagerError::NoAvailablePort)? } else { // Find an ephemeral port. - net_manager - .get_ephemeral_port_tcp_ipv6() + PORT_ALLOC + .get_ephemeral_tcp_ipv6() .ok_or(NetManagerError::NoAvailablePort)? }; - drop(net_manager); - - let mut handles = Vec::new(); - - for _ in 0..backlogs { - // Create a TCP socket. - let socket = create_listen_socket(addr, port.port(), rx_buffer_size, tx_buffer_size); - - let handle = if_net.socket_set.write().add(socket); - - handles.push(handle); - } + // Add the listening sockets while holding the read lock, so that a concurrent + // interface removal (which takes the write lock) cannot orphan them. + // If the interface is gone, `port` (TcpPort) frees the port on the early return. + let handles = { + let net_manager = NET_MANAGER.read(); + let if_net = net_manager + .interfaces + .get(&interface_id) + .ok_or(NetManagerError::InvalidInterfaceID)?; + + let mut handles = Vec::new(); + for _ in 0..backlogs { + // Create a TCP socket. + let socket = + create_listen_socket(addr, port.port(), rx_buffer_size, tx_buffer_size); + handles.push(if_net.socket_set.write().add(socket)); + } + handles + }; Ok(TcpListener { handles, @@ -98,13 +99,10 @@ impl SockTcpListener for TcpListener { fn accept(&mut self, waker: &core::task::Waker) -> Result, NetManagerError> { // If there is a connected socket, return it. if let Some(handle) = self.connected_sockets.pop_front() { - let port = { - let mut net_manager = NET_MANAGER.write(); - if self.addr.is_ipv4() { - net_manager.port_in_use_tcp_ipv4(self.port.port()) - } else { - net_manager.port_in_use_tcp_ipv6(self.port.port()) - } + let port = if self.addr.is_ipv4() { + PORT_ALLOC.increment_ref_tcp_ipv4(self.port.port()) + } else { + PORT_ALLOC.increment_ref_tcp_ipv6(self.port.port()) }; return Ok(Some(TcpStream { handle, @@ -171,13 +169,10 @@ impl SockTcpListener for TcpListener { // If there is a connected socket, return it. if let Some(handle) = self.connected_sockets.pop_front() { - let port = { - let mut net_manager = NET_MANAGER.write(); - if self.addr.is_ipv4() { - net_manager.port_in_use_tcp_ipv4(self.port.port()) - } else { - net_manager.port_in_use_tcp_ipv6(self.port.port()) - } + let port = if self.addr.is_ipv4() { + PORT_ALLOC.increment_ref_tcp_ipv4(self.port.port()) + } else { + PORT_ALLOC.increment_ref_tcp_ipv6(self.port.port()) }; if_net.poll_tx_only(crate::cpu::raw_cpu_id() & (if_net.net_device.num_queues() - 1)); diff --git a/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs b/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs index f0e1dd6d3..bb301ff8a 100644 --- a/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs +++ b/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs @@ -1,4 +1,6 @@ -use crate::net::{ip_addr::IpAddr, tcp::TcpPort, NetManagerError, NET_MANAGER}; +use crate::net::{ + ip_addr::IpAddr, port_alloc::PORT_ALLOC, tcp::TcpPort, NetManagerError, NET_MANAGER, +}; use super::{SockTcpStream, TcpResult}; @@ -90,40 +92,47 @@ impl SockTcpStream for TcpStream { tx_buffer_size: usize, waker: &core::task::Waker, ) -> Result { - let mut net_manager = NET_MANAGER.write(); - - let if_net = net_manager - .interfaces - .get(&interface_id) - .ok_or(NetManagerError::InvalidInterfaceID)?; - let if_net = if_net.clone(); + // Validate the interface exists before claiming a port or allocating buffers. + { + let net_manager = NET_MANAGER.read(); + if !net_manager.interfaces.contains_key(&interface_id) { + return Err(NetManagerError::InvalidInterfaceID); + } + } let local_port = if remote_addr.is_ipv4() { - net_manager - .get_ephemeral_port_tcp_ipv4() + PORT_ALLOC + .get_ephemeral_tcp_ipv4() .ok_or(NetManagerError::NoAvailablePort)? } else { - net_manager - .get_ephemeral_port_tcp_ipv6() + PORT_ALLOC + .get_ephemeral_tcp_ipv6() .ok_or(NetManagerError::NoAvailablePort)? }; - drop(net_manager); - let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; rx_buffer_size]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; tx_buffer_size]); let socket = smoltcp::socket::tcp::Socket::new(rx_buffer, tx_buffer); - let handle; - { + // Add the socket and connect it while holding the read lock, so that a concurrent + // interface removal (which takes the write lock) cannot orphan it. If the interface + // is gone, `local_port` (TcpPort) frees the port on the early return. + let (handle, if_net) = { + let net_manager = NET_MANAGER.read(); + let if_net = net_manager + .interfaces + .get(&interface_id) + .ok_or(NetManagerError::InvalidInterfaceID)? + .clone(); + let mut node = MCSNode::new(); let mut inner = if_net.inner.lock(&mut node); let interface = inner.get_interface(); let mut socket_set = if_net.socket_set.write(); - handle = socket_set.add(socket); + let handle = socket_set.add(socket); let connect_is_err = { let mut node: MCSNode = MCSNode::new(); @@ -146,7 +155,11 @@ impl SockTcpStream for TcpStream { socket_set.remove(handle); return Err(NetManagerError::InvalidState); } - } + + drop(socket_set); + drop(inner); + (handle, if_net) + }; let que_id = crate::cpu::raw_cpu_id() & (if_net.net_device.num_queues() - 1); if_net.poll_tx_only(que_id); diff --git a/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs b/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs index bf332ac7d..06e2f6ea4 100644 --- a/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs +++ b/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs @@ -1,6 +1,10 @@ use core::net::Ipv4Addr; -use crate::net::{ip_addr::IpAddr, NET_MANAGER}; +use crate::net::{ + ip_addr::IpAddr, + port_alloc::{UdpPort, PORT_ALLOC}, + NET_MANAGER, +}; use awkernel_sync::{mcs::MCSNode, mutex::Mutex}; use super::{NetManagerError, SockUdp}; @@ -19,8 +23,9 @@ pub struct UdpSocket { handle: smoltcp::iface::SocketHandle, interface_id: u64, addr: IpAddr, - port: u16, - is_ipv4: bool, + // Held only to free the port via its `Drop` impl when the socket is dropped. + #[allow(dead_code)] + udp_port: UdpPort, joined_multicast_addr_v4: BTreeSet, } @@ -32,55 +37,43 @@ impl super::SockUdp for UdpSocket { rx_buffer_size: usize, tx_buffer_size: usize, ) -> Result { - let mut net_manager = NET_MANAGER.write(); + // Validate the interface exists before claiming a port or allocating buffers. + { + let net_manager = NET_MANAGER.read(); + if !net_manager.interfaces.contains_key(&interface_id) { + return Err(NetManagerError::InvalidInterfaceID); + } + } - let is_ipv4; - let port = if let Some(port) = port { + // Claim the port. `udp_port` is RAII: it frees the port on any early return below. + let udp_port = if let Some(port) = port { if port == 0 { return Err(NetManagerError::InvalidPort); } - // Check if the specified port is available. + // Check if the specified port is available and claim it atomically. if addr.is_ipv4() { - if net_manager.is_port_in_use_udp_ipv4(port) { - return Err(NetManagerError::PortInUse); - } - - is_ipv4 = true; - net_manager.set_port_in_use_udp_ipv4(port); - port + PORT_ALLOC + .try_claim_udp_ipv4(port) + .ok_or(NetManagerError::PortInUse)? } else { - if net_manager.is_port_in_use_udp_ipv6(port) { - return Err(NetManagerError::PortInUse); - } - - is_ipv4 = false; - net_manager.set_port_in_use_udp_ipv6(port); - port + PORT_ALLOC + .try_claim_udp_ipv6(port) + .ok_or(NetManagerError::PortInUse)? } } else { // Find an ephemeral port. if addr.is_ipv4() { - is_ipv4 = true; - net_manager - .get_ephemeral_port_udp_ipv4() + PORT_ALLOC + .get_ephemeral_udp_ipv4() .ok_or(NetManagerError::PortInUse)? } else { - is_ipv4 = false; - net_manager - .get_ephemeral_port_udp_ipv6() + PORT_ALLOC + .get_ephemeral_udp_ipv6() .ok_or(NetManagerError::PortInUse)? } }; - - // Find the interface that has the specified address. - let if_net = net_manager - .interfaces - .get(&interface_id) - .ok_or(NetManagerError::InvalidInterfaceID)? - .clone(); - - drop(net_manager); + let port = udp_port.port(); // Create a UDP socket. use smoltcp::socket::udp; @@ -115,15 +108,23 @@ impl super::SockUdp for UdpSocket { } } - // Add the socket to the interface. - let handle = if_net.socket_set.write().add(socket); + // Add the socket to the interface while holding the read lock, so that a + // concurrent interface removal (which takes the write lock) cannot orphan it. + let handle = { + let net_manager = NET_MANAGER.read(); + let if_net = net_manager + .interfaces + .get(&interface_id) + .ok_or(NetManagerError::InvalidInterfaceID)?; + let handle = if_net.socket_set.write().add(socket); + handle + }; Ok(UdpSocket { handle, interface_id, addr: addr.clone(), - port, - is_ipv4, + udp_port, joined_multicast_addr_v4: Default::default(), }) } @@ -331,16 +332,13 @@ impl Drop for UdpSocket { } } - let mut net_manager = NET_MANAGER.write(); - - if self.is_ipv4 { - net_manager.free_port_udp_ipv4(self.port); - } else { - net_manager.free_port_udp_ipv6(self.port); + { + let net_manager = NET_MANAGER.read(); + if let Some(if_net) = net_manager.interfaces.get(&self.interface_id) { + if_net.socket_set.write().remove(self.handle); + } } - if let Some(if_net) = net_manager.interfaces.get(&self.interface_id) { - if_net.socket_set.write().remove(self.handle); - } + // `self.udp_port` (UdpPort) frees the port via its Drop impl. } }