137 lines
3.7 KiB
Rust
137 lines
3.7 KiB
Rust
|
|
use log::{info, log_enabled, warn};
|
||
|
|
use std::convert::Infallible;
|
||
|
|
use std::io;
|
||
|
|
use std::net::SocketAddr;
|
||
|
|
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
|
||
|
|
use std::time::Duration;
|
||
|
|
use thiserror::Error;
|
||
|
|
use tokio::net::{TcpListener, TcpStream};
|
||
|
|
use tokio::time;
|
||
|
|
|
||
|
|
/// Configuration for a TCP load-balancing proxy; start it with [`Proxy::run`].
///
/// Every connection accepted on one of the `listen_addrs` is forwarded to a target from
/// `targets` that is currently marked up. Targets are tried in a random order. Target
/// health is refreshed once per `poll` by attempting a TCP connect to each target, and
/// each attempt is bounded by `timeout`.
#[derive(Debug, Clone)]
pub struct Proxy {
    /// Local addresses on which client connections are accepted.
    pub listen_addrs: Vec<SocketAddr>,
    /// Upstream addresses that client connections are forwarded to.
    pub targets: Vec<SocketAddr>,
    /// Interval between health-check rounds.
    ///
    /// Must be non-zero: `tokio::time::interval` panics on a zero period.
    pub poll: Duration,
    /// Maximum time to wait for a TCP connect to a target.
    ///
    /// Applies to both health checks and forwarded client connections.
    pub timeout: Duration,
}
|
||
|
|
|
||
|
|
#[derive(Debug, Error)]
|
||
|
|
pub enum Error {
|
||
|
|
#[error("failed to listen on {0}: {1}")]
|
||
|
|
ListenFailed(SocketAddr, std::io::Error),
|
||
|
|
}
|
||
|
|
|
||
|
|
pub type Result<T> = std::result::Result<T, Error>;
|
||
|
|
|
||
|
|
impl Proxy {
|
||
|
|
pub async fn run(self) -> Result<Infallible> {
|
||
|
|
let mut listeners = Vec::with_capacity(self.listen_addrs.len());
|
||
|
|
for addr in self.listen_addrs {
|
||
|
|
listeners.push(
|
||
|
|
TcpListener::bind(&addr)
|
||
|
|
.await
|
||
|
|
.map_err(|e| Error::ListenFailed(addr, e))?,
|
||
|
|
);
|
||
|
|
info!("listening on {addr}");
|
||
|
|
}
|
||
|
|
|
||
|
|
// all targets are initially ok (better land on a down one than just fail)
|
||
|
|
let targets: Vec<_> = (self.targets.into_iter())
|
||
|
|
.map(|addr| TargetStatus {
|
||
|
|
addr,
|
||
|
|
up: AtomicBool::new(true),
|
||
|
|
timeout: self.timeout,
|
||
|
|
})
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
// the proxy runs forever -> using 'static is not a leak
|
||
|
|
let targets = targets.leak();
|
||
|
|
|
||
|
|
for listener in listeners {
|
||
|
|
tokio::spawn(proxy_listener(listener, targets));
|
||
|
|
}
|
||
|
|
|
||
|
|
check_targets(targets, self.poll).await
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Shared, lock-free state for one upstream target.
struct TargetStatus {
    /// Address that connections to this target are made to.
    addr: SocketAddr,
    /// Upper bound on a single TCP connect attempt to `addr`.
    timeout: Duration,
    /// Whether the most recent connect attempt succeeded; updated through `&self`.
    up: AtomicBool,
}
|
||
|
|
impl TargetStatus {
|
||
|
|
fn is_up(&self) -> bool {
|
||
|
|
self.up.load(Relaxed)
|
||
|
|
}
|
||
|
|
|
||
|
|
fn set_up(&self, is_up: bool) {
|
||
|
|
let prev = self.up.swap(is_up, Relaxed);
|
||
|
|
if prev != is_up {
|
||
|
|
if is_up {
|
||
|
|
info!("{} is up", self.addr);
|
||
|
|
} else {
|
||
|
|
warn!("{} is down", self.addr);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn connect(&self) -> io::Result<TcpStream> {
|
||
|
|
let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await {
|
||
|
|
Ok(r) => r,
|
||
|
|
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
|
||
|
|
};
|
||
|
|
|
||
|
|
self.set_up(r.is_ok());
|
||
|
|
r
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! {
|
||
|
|
use tokio::time;
|
||
|
|
let mut poll_ticker = time::interval(poll);
|
||
|
|
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
|
||
|
|
|
||
|
|
loop {
|
||
|
|
poll_ticker.tick().await;
|
||
|
|
|
||
|
|
let mut tasks = tokio::task::JoinSet::new();
|
||
|
|
|
||
|
|
for target in targets {
|
||
|
|
tasks.spawn(target.connect());
|
||
|
|
}
|
||
|
|
|
||
|
|
tasks.join_all().await;
|
||
|
|
|
||
|
|
if log_enabled!(log::Level::Info) {
|
||
|
|
let mut infos = String::new();
|
||
|
|
for ts in targets.iter() {
|
||
|
|
infos.push_str(&format!("{} ", ts.addr));
|
||
|
|
infos.push_str(if ts.is_up() { "up " } else { "down " });
|
||
|
|
}
|
||
|
|
info!("{infos}");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) {
|
||
|
|
let mut rng = fastrand::Rng::new();
|
||
|
|
|
||
|
|
loop {
|
||
|
|
let mut active = Vec::with_capacity(targets.len());
|
||
|
|
let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
|
||
|
|
|
||
|
|
active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i)));
|
||
|
|
rng.shuffle(&mut active);
|
||
|
|
|
||
|
|
tokio::spawn(async move {
|
||
|
|
for i in active {
|
||
|
|
if let Ok(mut dst) = targets[i].connect().await {
|
||
|
|
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
});
|
||
|
|
}
|
||
|
|
}
|