feat: (tcp) proxy

This commit is contained in:
Mikaël Cluseau
2025-12-19 21:33:17 +01:00
parent 7acc9e9a3e
commit fb3f8942d4
5 changed files with 253 additions and 112 deletions

View File

@@ -1,6 +1,8 @@
use clap::{CommandFactory, Parser, Subcommand};
use eyre::{format_err, Result};
use log::{debug, error};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::fs;
#[derive(Parser)]
@@ -65,6 +67,17 @@ enum Command {
#[arg(long, default_value = "/")]
chroot: std::path::PathBuf,
},
Proxy {
#[arg(long, short = 'l')]
listen: Vec<SocketAddr>,
targets: Vec<SocketAddr>,
/// target polling interval, in seconds
#[arg(long, default_value = "30")]
poll: u16,
/// connect or check timeout, in seconds
#[arg(long, default_value = "5")]
timeout: u16,
},
}
#[tokio::main(flavor = "current_thread")]
@@ -126,6 +139,20 @@ async fn main() -> Result<()> {
.install(layer, version)
.await
}
C::Proxy {
listen,
targets,
poll,
timeout,
} => Ok(dkl::proxy::Proxy {
listen_addrs: listen,
targets,
poll: Duration::from_secs(poll.into()),
timeout: Duration::from_secs(timeout.into()),
}
.run()
.await
.map(|_| ())?),
}
}

View File

@@ -1,4 +1,5 @@
pub mod apply;
pub mod proxy;
pub mod bootstrap;
pub mod dls;
pub mod dynlay;

136
src/proxy.rs Normal file
View File

@@ -0,0 +1,136 @@
use log::{info, log_enabled, warn};
use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::time::Duration;
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tokio::time;
pub struct Proxy {
pub listen_addrs: Vec<SocketAddr>,
pub targets: Vec<SocketAddr>,
pub poll: Duration,
pub timeout: Duration,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("failed to listen on {0}: {1}")]
ListenFailed(SocketAddr, std::io::Error),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Proxy {
pub async fn run(self) -> Result<Infallible> {
let mut listeners = Vec::with_capacity(self.listen_addrs.len());
for addr in self.listen_addrs {
listeners.push(
TcpListener::bind(&addr)
.await
.map_err(|e| Error::ListenFailed(addr, e))?,
);
info!("listening on {addr}");
}
// all targets are initially ok (better land on a down one than just fail)
let targets: Vec<_> = (self.targets.into_iter())
.map(|addr| TargetStatus {
addr,
up: AtomicBool::new(true),
timeout: self.timeout,
})
.collect();
// the proxy runs forever -> using 'static is not a leak
let targets = targets.leak();
for listener in listeners {
tokio::spawn(proxy_listener(listener, targets));
}
check_targets(targets, self.poll).await
}
}
struct TargetStatus {
addr: SocketAddr,
up: AtomicBool,
timeout: Duration,
}
impl TargetStatus {
fn is_up(&self) -> bool {
self.up.load(Relaxed)
}
fn set_up(&self, is_up: bool) {
let prev = self.up.swap(is_up, Relaxed);
if prev != is_up {
if is_up {
info!("{} is up", self.addr);
} else {
warn!("{} is down", self.addr);
}
}
}
async fn connect(&self) -> io::Result<TcpStream> {
let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await {
Ok(r) => r,
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
};
self.set_up(r.is_ok());
r
}
}
async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! {
use tokio::time;
let mut poll_ticker = time::interval(poll);
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
loop {
poll_ticker.tick().await;
let mut tasks = tokio::task::JoinSet::new();
for target in targets {
tasks.spawn(target.connect());
}
tasks.join_all().await;
if log_enabled!(log::Level::Info) {
let mut infos = String::new();
for ts in targets.iter() {
infos.push_str(&format!("{} ", ts.addr));
infos.push_str(if ts.is_up() { "up " } else { "down " });
}
info!("{infos}");
}
}
}
async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) {
let mut rng = fastrand::Rng::new();
loop {
let mut active = Vec::with_capacity(targets.len());
let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i)));
rng.shuffle(&mut active);
tokio::spawn(async move {
for i in active {
if let Ok(mut dst) = targets[i].connect().await {
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
break;
}
}
});
}
}