use log::{info, log_enabled, warn}; use std::convert::Infallible; use std::io; use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, Ordering::Relaxed}; use std::time::Duration; use thiserror::Error; use tokio::net::{TcpListener, TcpStream}; use tokio::time; pub struct Proxy { pub listen_addrs: Vec, pub targets: Vec, pub poll: Duration, pub timeout: Duration, } #[derive(Debug, Error)] pub enum Error { #[error("failed to listen on {0}: {1}")] ListenFailed(SocketAddr, std::io::Error), } pub type Result = std::result::Result; impl Proxy { pub async fn run(self) -> Result { let mut listeners = Vec::with_capacity(self.listen_addrs.len()); for addr in self.listen_addrs { listeners.push( TcpListener::bind(&addr) .await .map_err(|e| Error::ListenFailed(addr, e))?, ); info!("listening on {addr}"); } // all targets are initially ok (better land on a down one than just fail) let targets: Vec<_> = (self.targets.into_iter()) .map(|addr| TargetStatus { addr, up: AtomicBool::new(true), timeout: self.timeout, }) .collect(); // the proxy runs forever -> using 'static is not a leak let targets = targets.leak(); for listener in listeners { tokio::spawn(proxy_listener(listener, targets)); } check_targets(targets, self.poll).await } } struct TargetStatus { addr: SocketAddr, up: AtomicBool, timeout: Duration, } impl TargetStatus { fn is_up(&self) -> bool { self.up.load(Relaxed) } fn set_up(&self, is_up: bool) { let prev = self.up.swap(is_up, Relaxed); if prev != is_up { if is_up { info!("{} is up", self.addr); } else { warn!("{} is down", self.addr); } } } async fn connect(&self) -> io::Result { let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await { Ok(r) => r, Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), }; self.set_up(r.is_ok()); r } } async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! { use tokio::time; let mut poll_ticker = time::interval(poll); poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip); loop { poll_ticker.tick().await; let mut tasks = tokio::task::JoinSet::new(); for target in targets { tasks.spawn(target.connect()); } tasks.join_all().await; if log_enabled!(log::Level::Info) { let mut infos = String::new(); for ts in targets.iter() { infos.push_str(&format!("{} ", ts.addr)); infos.push_str(if ts.is_up() { "up " } else { "down " }); } info!("{infos}"); } } } async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) { let mut rng = fastrand::Rng::new(); loop { let mut active = Vec::with_capacity(targets.len()); let (mut src, _) = listener.accept().await.expect("listener.accept() failed"); active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i))); rng.shuffle(&mut active); tokio::spawn(async move { for i in active { if let Ok(mut dst) = targets[i].connect().await { let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await; break; } } }); } }