Compare commits

..

1 Commits

Author SHA1 Message Date
a7ebd827fc feat: (tcp) proxy 2025-12-19 21:33:17 +01:00
10 changed files with 238 additions and 948 deletions

971
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,6 @@ codegen-units = 1
[dependencies] [dependencies]
async-compression = { version = "0.4.27", features = ["tokio", "zstd"] } async-compression = { version = "0.4.27", features = ["tokio", "zstd"] }
base32 = "0.5.1" base32 = "0.5.1"
base64 = "0.22.1"
bytes = "1.10.1" bytes = "1.10.1"
chrono = { version = "0.4.41", default-features = false, features = ["clock", "now"] } chrono = { version = "0.4.41", default-features = false, features = ["clock", "now"] }
clap = { version = "4.5.40", features = ["derive", "env"] } clap = { version = "4.5.40", features = ["derive", "env"] }
@ -25,13 +24,12 @@ futures = "0.3.31"
futures-util = "0.3.31" futures-util = "0.3.31"
glob = "0.3.2" glob = "0.3.2"
hex = "0.4.3" hex = "0.4.3"
human-units = "0.5.3"
log = "0.4.27" log = "0.4.27"
lz4 = "1.28.1" lz4 = "1.28.1"
nix = { version = "0.30.1", features = ["user"] } nix = { version = "0.30.1", features = ["user"] }
openssl = "0.10.73" openssl = "0.10.73"
page_size = "0.6.0" page_size = "0.6.0"
reqwest = { version = "0.13.1", features = ["json", "stream", "native-tls"] } reqwest = { version = "0.12.20", features = ["json", "stream", "native-tls"] }
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140" serde_json = "1.0.140"
serde_yaml = "0.9.34" serde_yaml = "0.9.34"

View File

@ -3,61 +3,21 @@ use log::info;
use std::path::Path; use std::path::Path;
use tokio::fs; use tokio::fs;
use crate::base64_decode; pub async fn files(files: &[crate::File], root: &str) -> Result<()> {
pub async fn files(files: &[crate::File], root: &str, dry_run: bool) -> Result<()> {
for file in files { for file in files {
let path = chroot(root, &file.path); let path = chroot(root, &file.path);
let path = Path::new(&path); let path = Path::new(&path);
if !dry_run && let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?; fs::create_dir_all(parent).await?;
} }
use crate::FileKind as K; use crate::FileKind as K;
match &file.kind { match &file.kind {
K::Content(content) => { K::Content(content) => fs::write(path, content.as_bytes()).await?,
if dry_run { K::Dir(true) => fs::create_dir(path).await?,
info!(
"would create {} ({} bytes from content)",
file.path,
content.len()
);
} else {
fs::write(path, content.as_bytes()).await?;
}
}
K::Content64(content) => {
let content = base64_decode(content)?;
if dry_run {
info!(
"would create {} ({} bytes from content64)",
file.path,
content.len()
);
} else {
fs::write(path, content).await?
}
}
K::Dir(true) => {
if dry_run {
info!("would create {} (directory)", file.path);
} else {
fs::create_dir(path).await?;
}
}
K::Dir(false) => {} // shouldn't happen, but semantic is to ignore K::Dir(false) => {} // shouldn't happen, but semantic is to ignore
K::Symlink(tgt) => { K::Symlink(tgt) => fs::symlink(tgt, path).await?,
if dry_run {
info!("would create {} (symlink to {})", file.path, tgt);
} else {
fs::symlink(tgt, path).await?;
}
}
}
if dry_run {
continue;
} }
match file.kind { match file.kind {

View File

@ -1,8 +1,8 @@
use clap::{CommandFactory, Parser, Subcommand}; use clap::{CommandFactory, Parser, Subcommand};
use eyre::{format_err, Result}; use eyre::{format_err, Result};
use human_units::Duration;
use log::{debug, error}; use log::{debug, error};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration;
use tokio::fs; use tokio::fs;
#[derive(Parser)] #[derive(Parser)]
@ -24,9 +24,6 @@ enum Command {
/// path prefix (aka chroot) /// path prefix (aka chroot)
#[arg(short = 'P', long, default_value = "/")] #[arg(short = 'P', long, default_value = "/")]
prefix: String, prefix: String,
/// don't really write files
#[arg(long)]
dry_run: bool,
}, },
Logger { Logger {
/// Path where the logs are stored /// Path where the logs are stored
@ -74,12 +71,12 @@ enum Command {
#[arg(long, short = 'l')] #[arg(long, short = 'l')]
listen: Vec<SocketAddr>, listen: Vec<SocketAddr>,
targets: Vec<SocketAddr>, targets: Vec<SocketAddr>,
/// target polling interval /// target polling interval, in seconds
#[arg(long, default_value = "30s")] #[arg(long, default_value = "30")]
poll: Duration, poll: u16,
/// connect or check timeout /// connect or check timeout, in seconds
#[arg(long, default_value = "5s")] #[arg(long, default_value = "5")]
timeout: Duration, timeout: u16,
}, },
} }
@ -100,10 +97,9 @@ async fn main() -> Result<()> {
config, config,
filters, filters,
prefix, prefix,
dry_run,
} => { } => {
let filters = parse_globs(&filters)?; let filters = parse_globs(&filters)?;
apply_config(&config, &filters, &prefix, dry_run).await apply_config(&config, &filters, &prefix).await
} }
C::Logger { C::Logger {
ref log_path, ref log_path,
@ -151,8 +147,8 @@ async fn main() -> Result<()> {
} => Ok(dkl::proxy::Proxy { } => Ok(dkl::proxy::Proxy {
listen_addrs: listen, listen_addrs: listen,
targets, targets,
poll: poll.into(), poll_interval: Duration::from_secs(poll.into()),
timeout: timeout.into(), timeout: Duration::from_secs(timeout.into()),
} }
.run() .run()
.await .await
@ -160,12 +156,7 @@ async fn main() -> Result<()> {
} }
} }
async fn apply_config( async fn apply_config(config_file: &str, filters: &[glob::Pattern], chroot: &str) -> Result<()> {
config_file: &str,
filters: &[glob::Pattern],
chroot: &str,
dry_run: bool,
) -> Result<()> {
let config = fs::read_to_string(config_file).await?; let config = fs::read_to_string(config_file).await?;
let config: dkl::Config = serde_yaml::from_str(&config)?; let config: dkl::Config = serde_yaml::from_str(&config)?;
@ -177,7 +168,7 @@ async fn apply_config(
.collect() .collect()
}; };
dkl::apply::files(&files, chroot, dry_run).await dkl::apply::files(&files, chroot).await
} }
#[derive(Subcommand)] #[derive(Subcommand)]

View File

@ -78,17 +78,6 @@ pub struct NetworkInterface {
pub udev: Option<UdevFilter>, pub udev: Option<UdevFilter>,
} }
impl Default for NetworkInterface {
fn default() -> Self {
Self {
var: "iface".into(),
n: 1,
regexps: Vec::new(),
udev: Some(UdevFilter::Eq("INTERFACE".into(), "eth0".into())),
}
}
}
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct SSHServer { pub struct SSHServer {
pub listen: String, pub listen: String,
@ -105,7 +94,7 @@ impl Default for SSHServer {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct LvmVG { pub struct LvmVG {
#[serde(rename = "vg", alias = "name")] #[serde(alias = "vg")]
pub name: String, pub name: String,
pub pvs: LvmPV, pub pvs: LvmPV,

View File

@ -170,8 +170,6 @@ pub struct Config {
pub host_templates: Vec<HostConfig>, pub host_templates: Vec<HostConfig>,
#[serde(default, rename = "SSLConfig")] #[serde(default, rename = "SSLConfig")]
pub ssl_config: String, pub ssl_config: String,
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub extra_ca_certs: Map<String, String>,
} }
// compensate for go's encoder pitfalls // compensate for go's encoder pitfalls

View File

@ -1,4 +1,4 @@
use eyre::{Result, format_err}; use eyre::{format_err, Result};
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use std::path::PathBuf; use std::path::PathBuf;
use tokio::{fs, io::AsyncWriteExt, process::Command}; use tokio::{fs, io::AsyncWriteExt, process::Command};

View File

@ -1,10 +1,10 @@
pub mod apply; pub mod apply;
pub mod proxy;
pub mod bootstrap; pub mod bootstrap;
pub mod dls; pub mod dls;
pub mod dynlay; pub mod dynlay;
pub mod fs; pub mod fs;
pub mod logger; pub mod logger;
pub mod proxy;
#[derive(Debug, Default, serde::Deserialize, serde::Serialize)] #[derive(Debug, Default, serde::Deserialize, serde::Serialize)]
pub struct Config { pub struct Config {
@ -63,10 +63,9 @@ pub struct File {
} }
#[derive(Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "snake_case")]
pub enum FileKind { pub enum FileKind {
Content(String), Content(String),
Content64(String),
Symlink(String), Symlink(String),
Dir(bool), Dir(bool),
} }
@ -81,8 +80,3 @@ impl Config {
self.files.iter().find(|f| f.path == path) self.files.iter().find(|f| f.path == path)
} }
} }
pub fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
use base64::{Engine, prelude::BASE64_STANDARD_NO_PAD as B64};
B64.decode(s.trim_end_matches('='))
}

View File

@ -1,6 +1,6 @@
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder}; use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
use chrono::{DurationRound, TimeDelta, Utc}; use chrono::{DurationRound, TimeDelta, Utc};
use eyre::{Result, format_err}; use eyre::{format_err, Result};
use log::{debug, error, warn}; use log::{debug, error, warn};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Stdio; use std::process::Stdio;
@ -9,7 +9,7 @@ use tokio::{
io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter},
process, process,
sync::mpsc, sync::mpsc,
time::{Duration, sleep}, time::{sleep, Duration},
}; };
pub type Timestamp = chrono::DateTime<Utc>; pub type Timestamp = chrono::DateTime<Utc>;

View File

@ -1,17 +1,16 @@
use log::{info, log_enabled, warn}; use log::info;
use std::convert::Infallible; use std::convert::Infallible;
use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed}; use std::sync::atomic::Ordering::Relaxed;
use std::sync::{atomic::AtomicBool, Arc};
use std::time::Duration; use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::time;
pub struct Proxy { pub struct Proxy {
pub listen_addrs: Vec<SocketAddr>, pub listen_addrs: Vec<SocketAddr>,
pub targets: Vec<SocketAddr>, pub targets: Vec<SocketAddr>,
pub poll: Duration, pub poll_interval: Duration,
pub timeout: Duration, pub timeout: Duration,
} }
@ -37,96 +36,76 @@ impl Proxy {
// all targets are initially ok (better land on a down one than just fail) // all targets are initially ok (better land on a down one than just fail)
let targets: Vec<_> = (self.targets.into_iter()) let targets: Vec<_> = (self.targets.into_iter())
.map(|addr| TargetStatus { .map(|t| TargetStatus(t, AtomicBool::new(true)))
addr,
up: AtomicBool::new(true),
timeout: self.timeout,
})
.collect(); .collect();
let targets = Arc::new(targets);
// the proxy runs forever -> using 'static is not a leak
let targets = targets.leak();
for listener in listeners { for listener in listeners {
tokio::spawn(proxy_listener(listener, targets)); tokio::spawn(proxy_listener(listener, targets.clone()));
} }
check_targets(targets, self.poll).await check_targets(targets.clone(), self.poll_interval, self.timeout).await
} }
} }
struct TargetStatus { struct TargetStatus(SocketAddr, AtomicBool);
addr: SocketAddr,
up: AtomicBool,
timeout: Duration,
}
impl TargetStatus { impl TargetStatus {
fn is_up(&self) -> bool { fn is_up(&self) -> bool {
self.up.load(Relaxed) self.1.load(Relaxed)
}
fn set_up(&self, is_up: bool) {
let prev = self.up.swap(is_up, Relaxed);
if prev != is_up {
if is_up {
info!("{} is up", self.addr);
} else {
warn!("{} is down", self.addr);
}
}
}
async fn connect(&self) -> io::Result<TcpStream> {
let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await {
Ok(r) => r,
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
};
self.set_up(r.is_ok());
r
} }
} }
async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! { async fn check_targets(targets: Arc<Vec<TargetStatus>>, poll: Duration, timeout: Duration) -> ! {
use tokio::time; use tokio::time;
let mut poll_ticker = time::interval(poll); let mut poll_ticker = time::interval(poll);
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip); poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
let mut first = true;
loop { loop {
poll_ticker.tick().await; poll_ticker.tick().await;
let mut tasks = tokio::task::JoinSet::new(); for i in 0..targets.len() {
let targets = targets.clone();
for target in targets { tokio::spawn(async move { check_target(&targets[i], timeout, first).await });
tasks.spawn(target.connect());
} }
tasks.join_all().await; first = false;
}
}
if log_enabled!(log::Level::Info) { async fn check_target(ts: &TargetStatus, timeout: Duration, first: bool) {
let mut infos = String::new(); let target = ts.0;
for ts in targets.iter() {
infos.push_str(&format!("{} ", ts.addr)); let is_up = match tokio::time::timeout(timeout, TcpStream::connect(target)).await {
infos.push_str(if ts.is_up() { "up " } else { "down " }); Ok(Ok(_)) => true,
} _ => false,
info!("{infos}"); };
let prev = ts.1.swap(is_up, Relaxed);
if first || prev != is_up {
if is_up {
info!("{target} is up");
} else {
info!("{target} is down");
} }
} }
} }
async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) { async fn proxy_listener(listener: TcpListener, targets: Arc<Vec<TargetStatus>>) {
let mut rng = fastrand::Rng::new(); let mut rng = fastrand::Rng::new();
loop { loop {
let mut active = Vec::with_capacity(targets.len()); let mut active = Vec::with_capacity(targets.len());
let (mut src, _) = listener.accept().await.expect("listener.accept() failed"); let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i))); active.extend(targets.iter().filter(|ts| ts.is_up()).map(|ts| ts.0));
rng.shuffle(&mut active); rng.shuffle(&mut active);
tokio::spawn(async move { tokio::spawn(async move {
for i in active { for target in active {
if let Ok(mut dst) = targets[i].connect().await { if let Ok(mut dst) = TcpStream::connect(target).await {
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await; let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
break; break;
} }