19 Commits

Author SHA1 Message Date
Mikaël Cluseau aa7f15516c bump dockerfile 2026-03-16 11:20:48 +01:00
Mikaël Cluseau 4619899e65 dls: add password hash function 2026-03-16 11:06:19 +01:00
Mikaël Cluseau 4b1edb2a55 reqwest: enable socks 2026-02-25 09:45:59 +01:00
Mikaël Cluseau d449fc8dcf dkl apply-config --dry-run 2026-02-21 18:15:39 +01:00
Mikaël Cluseau ddc82199fb cargo update 2026-02-21 08:18:07 +01:00
Mikaël Cluseau 61d31bc22c dls::Config.extra_ca_certs 2026-02-21 08:17:42 +01:00
Mikaël Cluseau d2293df011 base64: be a tolerant reader 2026-02-10 21:23:11 +01:00
Mikaël Cluseau 723cecff1b fix vg name compat 2026-02-10 15:41:40 +01:00
Mikaël Cluseau e8c9ee9885 wow base64 w/ and wo/ padding are incompatible 2026-01-25 21:59:23 +01:00
Mikaël Cluseau 6a6536bdfb files: add content64 for base64 encoded values 2026-01-25 20:01:50 +01:00
Mikaël Cluseau a6dc420275 cargo update 2026-01-07 18:24:14 +01:00
Mikaël Cluseau d9fa31ec33 bootstrap: impl Default for NetworkInterface 2026-01-07 18:24:02 +01:00
Mikaël Cluseau 93e5570293 use human_units for Duration params 2025-12-20 08:52:34 +01:00
Mikaël Cluseau fb3f8942d4 feat: (tcp) proxy 2025-12-19 23:16:06 +01:00
Mikaël Cluseau 7acc9e9a3e bootstrap: impl default for config 2025-12-19 18:21:03 +01:00
Mikaël Cluseau ac90b35142 add dls::Config 2025-12-12 16:58:40 +01:00
Mikaël Cluseau 298366a0aa Config: query files 2025-12-03 12:54:45 +01:00
Mikaël Cluseau ecbbb82c7a more default in dls HostConfig 2025-12-03 09:47:57 +01:00
Mikaël Cluseau ebd2f21d42 adjust initrd_files 2025-11-20 11:58:38 +01:00
13 changed files with 880 additions and 501 deletions
Generated
+534 -411
View File
File diff suppressed because it is too large Load Diff
+7 -2
View File
@@ -13,22 +13,27 @@ codegen-units = 1
[dependencies] [dependencies]
async-compression = { version = "0.4.27", features = ["tokio", "zstd"] } async-compression = { version = "0.4.27", features = ["tokio", "zstd"] }
base32 = "0.5.1" base32 = "0.5.1"
base64 = "0.22.1"
bytes = "1.10.1" bytes = "1.10.1"
chrono = { version = "0.4.41", default-features = false, features = ["clock", "now"] } chrono = { version = "0.4.41", default-features = false, features = ["clock", "now"] }
clap = { version = "4.5.40", features = ["derive", "env"] } clap = { version = "4.5.40", features = ["derive", "env"] }
clap_complete = { version = "4.5.54", features = ["unstable-dynamic"] } clap_complete = { version = "4.5.54", features = ["unstable-dynamic"] }
env_logger = "0.11.8" env_logger = "0.11.8"
eyre = "0.6.12" eyre = "0.6.12"
fastrand = "2.3.0"
futures = "0.3.31" futures = "0.3.31"
futures-util = "0.3.31" futures-util = "0.3.31"
glob = "0.3.2" glob = "0.3.2"
hex = "0.4.3" hex = "0.4.3"
human-units = "0.5.3"
log = "0.4.27" log = "0.4.27"
lz4 = "1.28.1" lz4 = "1.28.1"
nix = { version = "0.30.1", features = ["user"] } nix = { version = "0.31.2", features = ["user"] }
openssl = "0.10.73" openssl = "0.10.73"
page_size = "0.6.0" page_size = "0.6.0"
reqwest = { version = "0.12.20", features = ["json", "stream", "native-tls"] } reqwest = { version = "0.13.1", features = ["json", "stream", "native-tls", "socks"], default-features = false }
rpassword = "7.4.0"
rust-argon2 = "3.0.0"
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140" serde_json = "1.0.140"
serde_yaml = "0.9.34" serde_yaml = "0.9.34"
+2 -2
View File
@@ -1,4 +1,4 @@
from mcluseau/rust:1.88.0 as build from mcluseau/rust:1.94.0 as build
workdir /app workdir /app
copy . . copy . .
@@ -10,6 +10,6 @@ run \
&& find target/release -maxdepth 1 -type f -executable -exec cp -v {} /dist/ + && find target/release -maxdepth 1 -type f -executable -exec cp -v {} /dist/ +
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
from alpine:3.22 from alpine:3.23
copy --from=build /dist/ /bin/ copy --from=build /dist/ /bin/
+45 -5
View File
@@ -3,21 +3,61 @@ use log::info;
use std::path::Path; use std::path::Path;
use tokio::fs; use tokio::fs;
pub async fn files(files: &[crate::File], root: &str) -> Result<()> { use crate::base64_decode;
pub async fn files(files: &[crate::File], root: &str, dry_run: bool) -> Result<()> {
for file in files { for file in files {
let path = chroot(root, &file.path); let path = chroot(root, &file.path);
let path = Path::new(&path); let path = Path::new(&path);
if let Some(parent) = path.parent() { if !dry_run && let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?; fs::create_dir_all(parent).await?;
} }
use crate::FileKind as K; use crate::FileKind as K;
match &file.kind { match &file.kind {
K::Content(content) => fs::write(path, content.as_bytes()).await?, K::Content(content) => {
K::Dir(true) => fs::create_dir(path).await?, if dry_run {
info!(
"would create {} ({} bytes from content)",
file.path,
content.len()
);
} else {
fs::write(path, content.as_bytes()).await?;
}
}
K::Content64(content) => {
let content = base64_decode(content)?;
if dry_run {
info!(
"would create {} ({} bytes from content64)",
file.path,
content.len()
);
} else {
fs::write(path, content).await?
}
}
K::Dir(true) => {
if dry_run {
info!("would create {} (directory)", file.path);
} else {
fs::create_dir(path).await?;
}
}
K::Dir(false) => {} // shouldn't happen, but semantic is to ignore K::Dir(false) => {} // shouldn't happen, but semantic is to ignore
K::Symlink(tgt) => fs::symlink(tgt, path).await?, K::Symlink(tgt) => {
if dry_run {
info!("would create {} (symlink to {})", file.path, tgt);
} else {
fs::symlink(tgt, path).await?;
}
}
}
if dry_run {
continue;
} }
match file.kind { match file.kind {
+39 -3
View File
@@ -1,6 +1,8 @@
use clap::{CommandFactory, Parser, Subcommand}; use clap::{CommandFactory, Parser, Subcommand};
use eyre::{format_err, Result}; use eyre::{format_err, Result};
use human_units::Duration;
use log::{debug, error}; use log::{debug, error};
use std::net::SocketAddr;
use tokio::fs; use tokio::fs;
#[derive(Parser)] #[derive(Parser)]
@@ -22,6 +24,9 @@ enum Command {
/// path prefix (aka chroot) /// path prefix (aka chroot)
#[arg(short = 'P', long, default_value = "/")] #[arg(short = 'P', long, default_value = "/")]
prefix: String, prefix: String,
/// don't really write files
#[arg(long)]
dry_run: bool,
}, },
Logger { Logger {
/// Path where the logs are stored /// Path where the logs are stored
@@ -65,6 +70,17 @@ enum Command {
#[arg(long, default_value = "/")] #[arg(long, default_value = "/")]
chroot: std::path::PathBuf, chroot: std::path::PathBuf,
}, },
Proxy {
#[arg(long, short = 'l')]
listen: Vec<SocketAddr>,
targets: Vec<SocketAddr>,
/// target polling interval
#[arg(long, default_value = "30s")]
poll: Duration,
/// connect or check timeout
#[arg(long, default_value = "5s")]
timeout: Duration,
},
} }
#[tokio::main(flavor = "current_thread")] #[tokio::main(flavor = "current_thread")]
@@ -84,9 +100,10 @@ async fn main() -> Result<()> {
config, config,
filters, filters,
prefix, prefix,
dry_run,
} => { } => {
let filters = parse_globs(&filters)?; let filters = parse_globs(&filters)?;
apply_config(&config, &filters, &prefix).await apply_config(&config, &filters, &prefix, dry_run).await
} }
C::Logger { C::Logger {
ref log_path, ref log_path,
@@ -126,10 +143,29 @@ async fn main() -> Result<()> {
.install(layer, version) .install(layer, version)
.await .await
} }
C::Proxy {
listen,
targets,
poll,
timeout,
} => Ok(dkl::proxy::Proxy {
listen_addrs: listen,
targets,
poll: poll.into(),
timeout: timeout.into(),
}
.run()
.await
.map(|_| ())?),
} }
} }
async fn apply_config(config_file: &str, filters: &[glob::Pattern], chroot: &str) -> Result<()> { async fn apply_config(
config_file: &str,
filters: &[glob::Pattern],
chroot: &str,
dry_run: bool,
) -> Result<()> {
let config = fs::read_to_string(config_file).await?; let config = fs::read_to_string(config_file).await?;
let config: dkl::Config = serde_yaml::from_str(&config)?; let config: dkl::Config = serde_yaml::from_str(&config)?;
@@ -141,7 +177,7 @@ async fn apply_config(config_file: &str, filters: &[glob::Pattern], chroot: &str
.collect() .collect()
}; };
dkl::apply::files(&files, chroot).await dkl::apply::files(&files, chroot, dry_run).await
} }
#[derive(Subcommand)] #[derive(Subcommand)]
+21 -6
View File
@@ -36,6 +36,10 @@ enum Command {
}, },
#[command(subcommand)] #[command(subcommand)]
DlSet(DlSet), DlSet(DlSet),
/// hash a password
Hash {
salt: String,
},
} }
#[derive(Subcommand)] #[derive(Subcommand)]
@@ -103,14 +107,16 @@ async fn main() -> eyre::Result<()> {
.parse_default_env() .parse_default_env()
.init(); .init();
let token = std::env::var("DLS_TOKEN").map_err(|_| format_err!("DLS_TOKEN should be set"))?; let dls = || {
let token = std::env::var("DLS_TOKEN").expect("DLS_TOKEN should be set");
let dls = dls::Client::new(cli.dls, token); dls::Client::new(cli.dls, token)
};
use Command as C; use Command as C;
match cli.command { match cli.command {
C::Clusters => write_json(&dls.clusters().await?), C::Clusters => write_json(&dls().clusters().await?),
C::Cluster { cluster, command } => { C::Cluster { cluster, command } => {
let dls = dls();
let cluster = dls.cluster(cluster); let cluster = dls.cluster(cluster);
use ClusterCommand as CC; use ClusterCommand as CC;
@@ -155,8 +161,9 @@ async fn main() -> eyre::Result<()> {
} }
} }
} }
C::Hosts => write_json(&dls.hosts().await?), C::Hosts => write_json(&dls().hosts().await?),
C::Host { out, host, asset } => { C::Host { out, host, asset } => {
let dls = dls();
let host_name = host.clone(); let host_name = host.clone();
let host = dls.host(host); let host = dls.host(host);
match asset { match asset {
@@ -171,7 +178,7 @@ async fn main() -> eyre::Result<()> {
C::DlSet(set) => match set { C::DlSet(set) => match set {
DlSet::Sign { expiry, items } => { DlSet::Sign { expiry, items } => {
let req = dls::DownloadSetReq { expiry, items }; let req = dls::DownloadSetReq { expiry, items };
let signed = dls.sign_dl_set(&req).await?; let signed = dls().sign_dl_set(&req).await?;
println!("{signed}"); println!("{signed}");
} }
DlSet::Show { signed_set } => { DlSet::Show { signed_set } => {
@@ -211,11 +218,19 @@ async fn main() -> eyre::Result<()> {
name, name,
asset, asset,
} => { } => {
let dls = dls();
let stream = dls.fetch_dl_set(&signed_set, &kind, &name, &asset).await?; let stream = dls.fetch_dl_set(&signed_set, &kind, &name, &asset).await?;
let mut out = create_asset_file(out, &kind, &name, &asset).await?; let mut out = create_asset_file(out, &kind, &name, &asset).await?;
copy_stream(stream, &mut out).await?; copy_stream(stream, &mut out).await?;
} }
}, },
C::Hash { salt } => {
let salt = dkl::base64_decode(&salt)?;
let passphrase = rpassword::prompt_password("password to hash: ")?;
let hash = dls::store::hash_password(&salt, &passphrase)?;
println!("hash (hex): {}", hex::encode(&hash));
println!("hash (base64): {}", dkl::base64_encode(&hash));
}
}; };
Ok(()) Ok(())
+16 -15
View File
@@ -2,7 +2,7 @@ use std::collections::BTreeMap as Map;
pub const TAKE_ALL: i16 = -1; pub const TAKE_ALL: i16 = -1;
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
pub struct Config { pub struct Config {
pub anti_phishing_code: String, pub anti_phishing_code: String,
@@ -42,21 +42,11 @@ impl Config {
pub fn new(bootstrap_dev: String) -> Self { pub fn new(bootstrap_dev: String) -> Self {
Self { Self {
anti_phishing_code: "Direktil<3".into(), anti_phishing_code: "Direktil<3".into(),
keymap: None,
modules: None,
resolv_conf: None,
vpns: Map::new(),
networks: vec![],
auths: vec![],
ssh: Default::default(),
pre_lvm_crypt: vec![],
lvm: vec![],
crypt: vec![],
signer_public_key: None,
bootstrap: Bootstrap { bootstrap: Bootstrap {
dev: bootstrap_dev, dev: bootstrap_dev,
seed: None, ..Default::default()
}, },
..Default::default()
} }
} }
} }
@@ -88,6 +78,17 @@ pub struct NetworkInterface {
pub udev: Option<UdevFilter>, pub udev: Option<UdevFilter>,
} }
impl Default for NetworkInterface {
fn default() -> Self {
Self {
var: "iface".into(),
n: 1,
regexps: Vec::new(),
udev: Some(UdevFilter::Eq("INTERFACE".into(), "eth0".into())),
}
}
}
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct SSHServer { pub struct SSHServer {
pub listen: String, pub listen: String,
@@ -104,7 +105,7 @@ impl Default for SSHServer {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct LvmVG { pub struct LvmVG {
#[serde(alias = "vg")] #[serde(rename = "vg", alias = "name")]
pub name: String, pub name: String,
pub pvs: LvmPV, pub pvs: LvmPV,
@@ -244,7 +245,7 @@ pub struct Raid {
pub stripes: Option<u8>, pub stripes: Option<u8>,
} }
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
pub struct Bootstrap { pub struct Bootstrap {
pub dev: String, pub dev: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
+34 -51
View File
@@ -5,7 +5,8 @@ use reqwest::Method;
use std::collections::BTreeMap as Map; use std::collections::BTreeMap as Map;
use std::fmt::Display; use std::fmt::Display;
use std::net::IpAddr; use std::net::IpAddr;
use std::time::Duration;
pub mod store;
pub struct Client { pub struct Client {
base_url: String, base_url: String,
@@ -160,6 +161,32 @@ impl<'t> Host<'t> {
} }
} }
#[derive(Default, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct Config {
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub clusters: Vec<ClusterConfig>,
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub hosts: Vec<HostConfig>,
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub host_templates: Vec<HostConfig>,
#[serde(default, rename = "SSLConfig")]
pub ssl_config: String,
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub extra_ca_certs: Map<String, String>,
}
// compensate for go's encoder pitfalls
use serde::{Deserialize, Deserializer};
fn deserialize_null_as_default<'de, D, T>(deserializer: D) -> std::result::Result<T, D::Error>
where
T: Default + Deserialize<'de>,
D: Deserializer<'de>,
{
let opt = Option::deserialize(deserializer)?;
Ok(opt.unwrap_or_default())
}
#[derive(serde::Deserialize, serde::Serialize)] #[derive(serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "PascalCase")] #[serde(rename_all = "PascalCase")]
pub struct ClusterConfig { pub struct ClusterConfig {
@@ -172,15 +199,15 @@ pub struct ClusterConfig {
#[serde(rename_all = "PascalCase")] #[serde(rename_all = "PascalCase")]
pub struct HostConfig { pub struct HostConfig {
pub name: String, pub name: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub cluster_name: Option<String>, pub cluster_name: Option<String>,
#[serde(rename = "IPs")] #[serde(rename = "IPs")]
pub ips: Vec<IpAddr>, pub ips: Vec<IpAddr>,
#[serde(skip_serializing_if = "Map::is_empty")] #[serde(default, skip_serializing_if = "Map::is_empty")]
pub labels: Map<String, String>, pub labels: Map<String, String>,
#[serde(skip_serializing_if = "Map::is_empty")] #[serde(default, skip_serializing_if = "Map::is_empty")]
pub annotations: Map<String, String>, pub annotations: Map<String, String>,
#[serde(rename = "IPXE", skip_serializing_if = "Option::is_none")] #[serde(rename = "IPXE", skip_serializing_if = "Option::is_none")]
@@ -190,10 +217,13 @@ pub struct HostConfig {
pub kernel: String, pub kernel: String,
pub versions: Map<String, String>, pub versions: Map<String, String>,
/// initrd config template
pub bootstrap_config: String, pub bootstrap_config: String,
/// files to add to the final initrd config, with rendering
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
pub initrd_files: Vec<crate::File>, pub initrd_files: Vec<crate::File>,
/// system config template
pub config: String, pub config: String,
} }
@@ -306,50 +336,3 @@ pub enum Error {
#[error("response parsing failed: {0}")] #[error("response parsing failed: {0}")]
Parse(serde_json::Error), Parse(serde_json::Error),
} }
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum File {
Static(crate::File),
Gen { path: String, from: ContentGen },
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContentGen {
CaCrt(CaRef),
TlsKey(TlsRef),
TlsCrt {
key: TlsRef,
ca: CaRef,
profile: CertProfile,
},
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CaRef {
Global(String),
Cluster(String, String),
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TlsRef {
Cluster(String, String),
Host(String, String),
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CertProfile {
Client,
Server,
/// basicaly Client+Server
Peer,
Kube {
user: String,
group: String,
duration: Duration,
},
}
+17
View File
@@ -0,0 +1,17 @@
pub fn hash_password(salt: &[u8], passphrase: &str) -> argon2::Result<[u8; 32]> {
let hash = argon2::hash_raw(
passphrase.as_bytes(),
salt,
&argon2::Config {
variant: argon2::Variant::Argon2id,
hash_length: 32,
time_cost: 1,
mem_cost: 65536,
thread_mode: argon2::ThreadMode::Parallel,
lanes: 4,
..Default::default()
},
)?;
unsafe { Ok(hash.try_into().unwrap_unchecked()) }
}
+1 -1
View File
@@ -1,4 +1,4 @@
use eyre::{format_err, Result}; use eyre::{Result, format_err};
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use std::path::PathBuf; use std::path::PathBuf;
use tokio::{fs, io::AsyncWriteExt, process::Command}; use tokio::{fs, io::AsyncWriteExt, process::Command};
+26 -3
View File
@@ -4,6 +4,7 @@ pub mod dls;
pub mod dynlay; pub mod dynlay;
pub mod fs; pub mod fs;
pub mod logger; pub mod logger;
pub mod proxy;
#[derive(Debug, Default, serde::Deserialize, serde::Serialize)] #[derive(Debug, Default, serde::Deserialize, serde::Serialize)]
pub struct Config { pub struct Config {
@@ -52,7 +53,7 @@ pub struct User {
pub gid: Option<u32>, pub gid: Option<u32>,
} }
#[derive(Debug, serde::Deserialize, serde::Serialize)] #[derive(Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct File { pub struct File {
pub path: String, pub path: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -61,10 +62,32 @@ pub struct File {
pub kind: FileKind, pub kind: FileKind,
} }
#[derive(Debug, serde::Deserialize, serde::Serialize)] #[derive(Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "lowercase")]
pub enum FileKind { pub enum FileKind {
Content(String), Content(String),
Content64(String),
Symlink(String), Symlink(String),
Dir(bool), Dir(bool),
} }
// ------------------------------------------------------------------------
impl Config {
pub fn has_file(&self, path: &str) -> bool {
self.files.iter().any(|f| f.path == path)
}
pub fn file(&self, path: &str) -> Option<&File> {
self.files.iter().find(|f| f.path == path)
}
}
pub fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
use base64::{prelude::BASE64_STANDARD_NO_PAD as B64, Engine as _};
B64.decode(s.trim_end_matches('='))
}
pub fn base64_encode(b: &[u8]) -> String {
use base64::{prelude::BASE64_STANDARD as B64, Engine as _};
B64.encode(b)
}
+2 -2
View File
@@ -1,6 +1,6 @@
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder}; use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
use chrono::{DurationRound, TimeDelta, Utc}; use chrono::{DurationRound, TimeDelta, Utc};
use eyre::{format_err, Result}; use eyre::{Result, format_err};
use log::{debug, error, warn}; use log::{debug, error, warn};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Stdio; use std::process::Stdio;
@@ -9,7 +9,7 @@ use tokio::{
io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter},
process, process,
sync::mpsc, sync::mpsc,
time::{sleep, Duration}, time::{Duration, sleep},
}; };
pub type Timestamp = chrono::DateTime<Utc>; pub type Timestamp = chrono::DateTime<Utc>;
+136
View File
@@ -0,0 +1,136 @@
use log::{info, log_enabled, warn};
use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::time::Duration;
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tokio::time;
pub struct Proxy {
pub listen_addrs: Vec<SocketAddr>,
pub targets: Vec<SocketAddr>,
pub poll: Duration,
pub timeout: Duration,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("failed to listen on {0}: {1}")]
ListenFailed(SocketAddr, std::io::Error),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Proxy {
pub async fn run(self) -> Result<Infallible> {
let mut listeners = Vec::with_capacity(self.listen_addrs.len());
for addr in self.listen_addrs {
listeners.push(
TcpListener::bind(&addr)
.await
.map_err(|e| Error::ListenFailed(addr, e))?,
);
info!("listening on {addr}");
}
// all targets are initially ok (better land on a down one than just fail)
let targets: Vec<_> = (self.targets.into_iter())
.map(|addr| TargetStatus {
addr,
up: AtomicBool::new(true),
timeout: self.timeout,
})
.collect();
// the proxy runs forever -> using 'static is not a leak
let targets = targets.leak();
for listener in listeners {
tokio::spawn(proxy_listener(listener, targets));
}
check_targets(targets, self.poll).await
}
}
struct TargetStatus {
addr: SocketAddr,
up: AtomicBool,
timeout: Duration,
}
impl TargetStatus {
fn is_up(&self) -> bool {
self.up.load(Relaxed)
}
fn set_up(&self, is_up: bool) {
let prev = self.up.swap(is_up, Relaxed);
if prev != is_up {
if is_up {
info!("{} is up", self.addr);
} else {
warn!("{} is down", self.addr);
}
}
}
async fn connect(&self) -> io::Result<TcpStream> {
let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await {
Ok(r) => r,
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
};
self.set_up(r.is_ok());
r
}
}
async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! {
use tokio::time;
let mut poll_ticker = time::interval(poll);
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
loop {
poll_ticker.tick().await;
let mut tasks = tokio::task::JoinSet::new();
for target in targets {
tasks.spawn(target.connect());
}
tasks.join_all().await;
if log_enabled!(log::Level::Info) {
let mut infos = String::new();
for ts in targets.iter() {
infos.push_str(&format!("{} ", ts.addr));
infos.push_str(if ts.is_up() { "up " } else { "down " });
}
info!("{infos}");
}
}
}
async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) {
let mut rng = fastrand::Rng::new();
loop {
let mut active = Vec::with_capacity(targets.len());
let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i)));
rng.shuffle(&mut active);
tokio::spawn(async move {
for i in active {
if let Ok(mut dst) = targets[i].connect().await {
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
break;
}
}
});
}
}