Compare commits

..

1 Commits

Author SHA1 Message Date
93f3af0ba8 dls::File + variants for TLS 2025-11-20 11:56:23 +01:00
7 changed files with 232 additions and 809 deletions

748
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -19,18 +19,16 @@ clap = { version = "4.5.40", features = ["derive", "env"] }
clap_complete = { version = "4.5.54", features = ["unstable-dynamic"] }
env_logger = "0.11.8"
eyre = "0.6.12"
fastrand = "2.3.0"
futures = "0.3.31"
futures-util = "0.3.31"
glob = "0.3.2"
hex = "0.4.3"
human-units = "0.5.3"
log = "0.4.27"
lz4 = "1.28.1"
nix = { version = "0.30.1", features = ["user"] }
openssl = "0.10.73"
page_size = "0.6.0"
reqwest = { version = "0.13.1", features = ["json", "stream", "native-tls"] }
reqwest = { version = "0.12.20", features = ["json", "stream", "native-tls"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
serde_yaml = "0.9.34"

View File

@ -1,8 +1,6 @@
use clap::{CommandFactory, Parser, Subcommand};
use eyre::{format_err, Result};
use human_units::Duration;
use log::{debug, error};
use std::net::SocketAddr;
use tokio::fs;
#[derive(Parser)]
@ -67,17 +65,6 @@ enum Command {
#[arg(long, default_value = "/")]
chroot: std::path::PathBuf,
},
Proxy {
#[arg(long, short = 'l')]
listen: Vec<SocketAddr>,
targets: Vec<SocketAddr>,
/// target polling interval
#[arg(long, default_value = "30s")]
poll: Duration,
/// connect or check timeout
#[arg(long, default_value = "5s")]
timeout: Duration,
},
}
#[tokio::main(flavor = "current_thread")]
@ -139,20 +126,6 @@ async fn main() -> Result<()> {
.install(layer, version)
.await
}
C::Proxy {
listen,
targets,
poll,
timeout,
} => Ok(dkl::proxy::Proxy {
listen_addrs: listen,
targets,
poll: poll.into(),
timeout: timeout.into(),
}
.run()
.await
.map(|_| ())?),
}
}

View File

@ -2,7 +2,7 @@ use std::collections::BTreeMap as Map;
pub const TAKE_ALL: i16 = -1;
#[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct Config {
pub anti_phishing_code: String,
@ -42,11 +42,21 @@ impl Config {
pub fn new(bootstrap_dev: String) -> Self {
Self {
anti_phishing_code: "Direktil<3".into(),
keymap: None,
modules: None,
resolv_conf: None,
vpns: Map::new(),
networks: vec![],
auths: vec![],
ssh: Default::default(),
pre_lvm_crypt: vec![],
lvm: vec![],
crypt: vec![],
signer_public_key: None,
bootstrap: Bootstrap {
dev: bootstrap_dev,
..Default::default()
seed: None,
},
..Default::default()
}
}
}
@ -78,17 +88,6 @@ pub struct NetworkInterface {
pub udev: Option<UdevFilter>,
}
impl Default for NetworkInterface {
fn default() -> Self {
Self {
var: "iface".into(),
n: 1,
regexps: Vec::new(),
udev: Some(UdevFilter::Eq("INTERFACE".into(), "eth0".into())),
}
}
}
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct SSHServer {
pub listen: String,
@ -245,7 +244,7 @@ pub struct Raid {
pub stripes: Option<u8>,
}
#[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct Bootstrap {
pub dev: String,
#[serde(skip_serializing_if = "Option::is_none")]

View File

@ -5,6 +5,7 @@ use reqwest::Method;
use std::collections::BTreeMap as Map;
use std::fmt::Display;
use std::net::IpAddr;
use std::time::Duration;
pub struct Client {
base_url: String,
@ -159,30 +160,6 @@ impl<'t> Host<'t> {
}
}
#[derive(Default, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct Config {
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub clusters: Vec<ClusterConfig>,
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub hosts: Vec<HostConfig>,
#[serde(default, deserialize_with = "deserialize_null_as_default")]
pub host_templates: Vec<HostConfig>,
#[serde(default, rename = "SSLConfig")]
pub ssl_config: String,
}
// compensate for go's encoder pitfalls
use serde::{Deserialize, Deserializer};
fn deserialize_null_as_default<'de, D, T>(deserializer: D) -> std::result::Result<T, D::Error>
where
T: Default + Deserialize<'de>,
D: Deserializer<'de>,
{
let opt = Option::deserialize(deserializer)?;
Ok(opt.unwrap_or_default())
}
#[derive(serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct ClusterConfig {
@ -195,15 +172,15 @@ pub struct ClusterConfig {
#[serde(rename_all = "PascalCase")]
pub struct HostConfig {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[serde(skip_serializing_if = "Option::is_none")]
pub cluster_name: Option<String>,
#[serde(rename = "IPs")]
pub ips: Vec<IpAddr>,
#[serde(default, skip_serializing_if = "Map::is_empty")]
#[serde(skip_serializing_if = "Map::is_empty")]
pub labels: Map<String, String>,
#[serde(default, skip_serializing_if = "Map::is_empty")]
#[serde(skip_serializing_if = "Map::is_empty")]
pub annotations: Map<String, String>,
#[serde(rename = "IPXE", skip_serializing_if = "Option::is_none")]
@ -213,13 +190,10 @@ pub struct HostConfig {
pub kernel: String,
pub versions: Map<String, String>,
/// initrd config template
pub bootstrap_config: String,
/// files to add to the final initrd config, with rendering
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub initrd_files: Vec<crate::File>,
/// system config template
pub config: String,
}
@ -332,3 +306,50 @@ pub enum Error {
#[error("response parsing failed: {0}")]
Parse(serde_json::Error),
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum File {
Static(crate::File),
Gen { path: String, from: ContentGen },
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContentGen {
CaCrt(CaRef),
TlsKey(TlsRef),
TlsCrt {
key: TlsRef,
ca: CaRef,
profile: CertProfile,
},
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CaRef {
Global(String),
Cluster(String, String),
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TlsRef {
Cluster(String, String),
Host(String, String),
}
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CertProfile {
Client,
Server,
/// basicaly Client+Server
Peer,
Kube {
user: String,
group: String,
duration: Duration,
},
}

View File

@ -1,5 +1,4 @@
pub mod apply;
pub mod proxy;
pub mod bootstrap;
pub mod dls;
pub mod dynlay;
@ -53,7 +52,7 @@ pub struct User {
pub gid: Option<u32>,
}
#[derive(Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct File {
pub path: String,
#[serde(skip_serializing_if = "Option::is_none")]
@ -62,21 +61,10 @@ pub struct File {
pub kind: FileKind,
}
#[derive(Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[derive(Debug, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FileKind {
Content(String),
Symlink(String),
Dir(bool),
}
// ------------------------------------------------------------------------
impl Config {
pub fn has_file(&self, path: &str) -> bool {
self.files.iter().any(|f| f.path == path)
}
pub fn file(&self, path: &str) -> Option<&File> {
self.files.iter().find(|f| f.path == path)
}
}

View File

@ -1,136 +0,0 @@
use log::{info, log_enabled, warn};
use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::time::Duration;
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tokio::time;
pub struct Proxy {
pub listen_addrs: Vec<SocketAddr>,
pub targets: Vec<SocketAddr>,
pub poll: Duration,
pub timeout: Duration,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("failed to listen on {0}: {1}")]
ListenFailed(SocketAddr, std::io::Error),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Proxy {
pub async fn run(self) -> Result<Infallible> {
let mut listeners = Vec::with_capacity(self.listen_addrs.len());
for addr in self.listen_addrs {
listeners.push(
TcpListener::bind(&addr)
.await
.map_err(|e| Error::ListenFailed(addr, e))?,
);
info!("listening on {addr}");
}
// all targets are initially ok (better land on a down one than just fail)
let targets: Vec<_> = (self.targets.into_iter())
.map(|addr| TargetStatus {
addr,
up: AtomicBool::new(true),
timeout: self.timeout,
})
.collect();
// the proxy runs forever -> using 'static is not a leak
let targets = targets.leak();
for listener in listeners {
tokio::spawn(proxy_listener(listener, targets));
}
check_targets(targets, self.poll).await
}
}
struct TargetStatus {
addr: SocketAddr,
up: AtomicBool,
timeout: Duration,
}
impl TargetStatus {
fn is_up(&self) -> bool {
self.up.load(Relaxed)
}
fn set_up(&self, is_up: bool) {
let prev = self.up.swap(is_up, Relaxed);
if prev != is_up {
if is_up {
info!("{} is up", self.addr);
} else {
warn!("{} is down", self.addr);
}
}
}
async fn connect(&self) -> io::Result<TcpStream> {
let r = match time::timeout(self.timeout, TcpStream::connect(self.addr)).await {
Ok(r) => r,
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
};
self.set_up(r.is_ok());
r
}
}
async fn check_targets(targets: &'static [TargetStatus], poll: Duration) -> ! {
use tokio::time;
let mut poll_ticker = time::interval(poll);
poll_ticker.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
loop {
poll_ticker.tick().await;
let mut tasks = tokio::task::JoinSet::new();
for target in targets {
tasks.spawn(target.connect());
}
tasks.join_all().await;
if log_enabled!(log::Level::Info) {
let mut infos = String::new();
for ts in targets.iter() {
infos.push_str(&format!("{} ", ts.addr));
infos.push_str(if ts.is_up() { "up " } else { "down " });
}
info!("{infos}");
}
}
}
async fn proxy_listener(listener: TcpListener, targets: &'static [TargetStatus]) {
let mut rng = fastrand::Rng::new();
loop {
let mut active = Vec::with_capacity(targets.len());
let (mut src, _) = listener.accept().await.expect("listener.accept() failed");
active.extend((targets.iter().enumerate()).filter_map(|(i, ts)| ts.is_up().then_some(i)));
rng.shuffle(&mut active);
tokio::spawn(async move {
for i in active {
if let Ok(mut dst) = targets[i].connect().await {
let _ = tokio::io::copy_bidirectional(&mut src, &mut dst).await;
break;
}
}
});
}
}