Files
dkl/src/proxy.rs

137 lines
3.7 KiB
Rust
Raw Normal View History

2025-12-19 21:33:17 +01:00
use log::{info, log_enabled, warn};
use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::time::Duration;
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tokio::time;
/// Configuration of a TCP proxy that spreads client connections over a set of
/// backend targets.
///
/// Call [`Proxy::run`] to start it; it never returns on success.
pub struct Proxy {
    /// Local addresses on which client connections are accepted.
    pub listen_addrs: Vec<SocketAddr>,
    /// Backend addresses that accepted clients are forwarded to.
    pub targets: Vec<SocketAddr>,
    /// Interval between two health-check rounds over all targets.
    pub poll: Duration,
    /// Upper bound for a single connection attempt to a target.
    pub timeout: Duration,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("failed to listen on {0}: {1}")]
ListenFailed(SocketAddr, std::io::Error),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Proxy {
pub async fn run(self) -> Result<Infallible> {
let mut listeners = Vec::with_capacity(self.listen_addrs.len());
for addr in self.listen_addrs {
listeners.push(
TcpListener::bind(&addr)
.await
.map_err(|e| Error::ListenFailed(addr, e))?,
);
info!("listening on {addr}");
}
// all targets are initially ok (better land on a down one than just fail)
let targets: Vec<_> = (self.targets.into_iter())
.map(|addr| TargetStatus {
addr,
up: AtomicBool::new(true),
timeout: self.timeout,
})
.collect();
// the proxy runs forever -> using 'static is not a leak
let targets = targets.leak();
for listener in listeners {
tokio::spawn(proxy_listener(listener, targets));
}
check_targets(targets, self.poll).await
}
}
/// Shared, concurrently updated health state of one backend target.
struct TargetStatus {
    /// Address of the backend.
    addr: SocketAddr,
    /// Outcome of the most recent connection attempt (`true` = reachable).
    up: AtomicBool,
    /// Upper bound for one connection attempt to `addr`.
    timeout: Duration,
}
impl TargetStatus {
    /// Returns the last recorded health of this target.
    fn is_up(&self) -> bool {
        self.up.load(Relaxed)
    }

    /// Records the target's health and logs only when the state flips
    /// (`info` when it comes back up, `warn` when it goes down).
    fn set_up(&self, is_up: bool) {
        let was_up = self.up.swap(is_up, Relaxed);
        match (was_up, is_up) {
            (false, true) => info!("{} is up", self.addr),
            (true, false) => warn!("{} is down", self.addr),
            _ => {}
        }
    }

    /// Opens a TCP connection to the target, giving up after `self.timeout`.
    ///
    /// A timeout is reported as an `io::ErrorKind::TimedOut` error. Whatever the
    /// outcome, it is recorded via `set_up`, so every caller (health checks and
    /// forwarded client connections alike) updates the target's state.
    async fn connect(&self) -> io::Result<TcpStream> {
        let attempt = time::timeout(self.timeout, TcpStream::connect(self.addr)).await;
        let outcome = attempt
            .unwrap_or_else(|elapsed| Err(io::Error::new(io::ErrorKind::TimedOut, elapsed)));
        self.set_up(outcome.is_ok());
        outcome
    }
}
async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! {
use tokio::time;
let mut poll_ticker = time::interval(poll);
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
loop {
poll_ticker.tick().await;
let mut tasks = tokio::task::JoinSet::new();
for target in targets {
tasks.spawn(target.connect());
}
tasks.join_all().await;
if log_enabled!(log::Level::Info) {
let mut infos = String::new();
for ts in targets.iter() {
infos.push_str(&format!("{} ", ts.addr));
infos.push_str(if ts.is_up() { "up " } else { "down " });
}
info!("{infos}");
}
}
}
async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) {
let mut rng = fastrand::Rng::new();
loop {
let mut active = Vec::with_capacity(targets.len());
let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i)));
rng.shuffle(&mut active);
tokio::spawn(async move {
for i in active {
if let Ok(mut dst) = targets[i].connect().await {
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
break;
}
}
});
}
}