Compare commits

..

2 Commits

Author SHA1 Message Date
93e5570293 use human_units for Duration params 2025-12-20 08:52:34 +01:00
fb3f8942d4 feat: (tcp) proxy 2025-12-19 23:16:06 +01:00
4 changed files with 85 additions and 47 deletions

16
Cargo.lock generated
View File

@ -279,6 +279,7 @@ dependencies = [
"futures-util", "futures-util",
"glob", "glob",
"hex", "hex",
"human-units",
"log", "log",
"lz4", "lz4",
"nix", "nix",
@ -586,6 +587,15 @@ version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "human-units"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47cf34dbcbbb7f1f6589c18e26f15a5a72592750dd5472037eb78fc0f92020d4"
dependencies = [
"paste",
]
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.8.1" version = "1.8.1"
@ -1059,6 +1069,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.2" version = "2.3.2"

View File

@ -24,6 +24,7 @@ futures = "0.3.31"
futures-util = "0.3.31" futures-util = "0.3.31"
glob = "0.3.2" glob = "0.3.2"
hex = "0.4.3" hex = "0.4.3"
human-units = "0.5.3"
log = "0.4.27" log = "0.4.27"
lz4 = "1.28.1" lz4 = "1.28.1"
nix = { version = "0.30.1", features = ["user"] } nix = { version = "0.30.1", features = ["user"] }

View File

@ -1,8 +1,8 @@
use clap::{CommandFactory, Parser, Subcommand}; use clap::{CommandFactory, Parser, Subcommand};
use eyre::{format_err, Result}; use eyre::{format_err, Result};
use human_units::Duration;
use log::{debug, error}; use log::{debug, error};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration;
use tokio::fs; use tokio::fs;
#[derive(Parser)] #[derive(Parser)]
@ -71,12 +71,12 @@ enum Command {
#[arg(long, short = 'l')] #[arg(long, short = 'l')]
listen: Vec<SocketAddr>, listen: Vec<SocketAddr>,
targets: Vec<SocketAddr>, targets: Vec<SocketAddr>,
/// target polling interval, in seconds /// target polling interval
#[arg(long, default_value = "30")] #[arg(long, default_value = "30s")]
poll: u16, poll: Duration,
/// connect or check timeout, in seconds /// connect or check timeout
#[arg(long, default_value = "5")] #[arg(long, default_value = "5s")]
timeout: u16, timeout: Duration,
}, },
} }
@ -147,8 +147,8 @@ async fn main() -> Result<()> {
} => Ok(dkl::proxy::Proxy { } => Ok(dkl::proxy::Proxy {
listen_addrs: listen, listen_addrs: listen,
targets, targets,
poll_interval: Duration::from_secs(poll.into()), poll: poll.into(),
timeout: Duration::from_secs(timeout.into()), timeout: timeout.into(),
} }
.run() .run()
.await .await

View File

@ -1,16 +1,17 @@
use log::info; use log::{info, log_enabled, warn};
use std::convert::Infallible; use std::convert::Infallible;
use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::Ordering::Relaxed; use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::sync::{atomic::AtomicBool, Arc};
use std::time::Duration; use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::time;
pub struct Proxy { pub struct Proxy {
pub listen_addrs: Vec<SocketAddr>, pub listen_addrs: Vec<SocketAddr>,
pub targets: Vec<SocketAddr>, pub targets: Vec<SocketAddr>,
pub poll_interval: Duration, pub poll: Duration,
pub timeout: Duration, pub timeout: Duration,
} }
@ -36,76 +37,96 @@ impl Proxy {
// all targets are initially ok (better land on a down one than just fail) // all targets are initially ok (better land on a down one than just fail)
let targets: Vec<_> = (self.targets.into_iter()) let targets: Vec<_> = (self.targets.into_iter())
.map(|t| TargetStatus(t, AtomicBool::new(true))) .map(|addr| TargetStatus {
addr,
up: AtomicBool::new(true),
timeout: self.timeout,
})
.collect(); .collect();
let targets = Arc::new(targets);
// the proxy runs forever -> using 'static is not a leak
let targets = targets.leak();
for listener in listeners { for listener in listeners {
tokio::spawn(proxy_listener(listener, targets.clone())); tokio::spawn(proxy_listener(listener, targets));
} }
check_targets(targets.clone(), self.poll_interval, self.timeout).await check_targets(targets, self.poll).await
} }
} }
struct TargetStatus(SocketAddr, AtomicBool); struct TargetStatus {
addr: SocketAddr,
up: AtomicBool,
timeout: Duration,
}
impl TargetStatus { impl TargetStatus {
fn is_up(&self) -> bool { fn is_up(&self) -> bool {
self.1.load(Relaxed) self.up.load(Relaxed)
}
fn set_up(&self, is_up: bool) {
let prev = self.up.swap(is_up, Relaxed);
if prev != is_up {
if is_up {
info!("{} is up", self.addr);
} else {
warn!("{} is down", self.addr);
}
}
}
async fn connect(&self) -> io::Result<TcpStream> {
let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await {
Ok(r) => r,
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
};
self.set_up(r.is_ok());
r
} }
} }
async fn check_targets(targets: Arc<Vec<TargetStatus>>, poll: Duration, timeout: Duration) -> ! { async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! {
use tokio::time; use tokio::time;
let mut poll_ticker = time::interval(poll); let mut poll_ticker = time::interval(poll);
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip); poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
let mut first = true;
loop { loop {
poll_ticker.tick().await; poll_ticker.tick().await;
for i in 0..targets.len() { let mut tasks = tokio::task::JoinSet::new();
let targets = targets.clone();
tokio::spawn(async move { check_target(&targets[i], timeout, first).await }); for target in targets {
tasks.spawn(target.connect());
} }
first = false; tasks.join_all().await;
}
}
async fn check_target(ts: &TargetStatus, timeout: Duration, first: bool) { if log_enabled!(log::Level::Info) {
let target = ts.0; let mut infos = String::new();
for ts in targets.iter() {
let is_up = match tokio::time::timeout(timeout, TcpStream::connect(target)).await { infos.push_str(&format!("{} ", ts.addr));
Ok(Ok(_)) => true, infos.push_str(if ts.is_up() { "up " } else { "down " });
_ => false, }
}; info!("{infos}");
let prev = ts.1.swap(is_up, Relaxed);
if first || prev != is_up {
if is_up {
info!("{target} is up");
} else {
info!("{target} is down");
} }
} }
} }
async fn proxy_listener(listener: TcpListener, targets: Arc<Vec<TargetStatus>>) { async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) {
let mut rng = fastrand::Rng::new(); let mut rng = fastrand::Rng::new();
loop { loop {
let mut active = Vec::with_capacity(targets.len()); let mut active = Vec::with_capacity(targets.len());
let (mut src, _) = listener.accept().await.expect("listener.accept() failed"); let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
active.extend(targets.iter().filter(|ts| ts.is_up()).map(|ts| ts.0)); active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i)));
rng.shuffle(&mut active); rng.shuffle(&mut active);
tokio::spawn(async move { tokio::spawn(async move {
for target in active { for i in active {
if let Ok(mut dst) = TcpStream::connect(target).await { if let Ok(mut dst) = targets[i].connect().await {
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await; let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
break; break;
} }