Files
initrd/src/input.rs
Mikaël Cluseau 2ab793c54a migrate to rust
2025-06-30 16:07:31 +02:00

344 lines
9.2 KiB
Rust

use log::warn;
use std::fmt::Display;
use std::sync::{Arc, LazyLock};
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net;
use tokio::sync::{oneshot, watch, Mutex};
/// Prompts with `prompt` and waits for one line of input, with echo left on.
pub async fn read_line(prompt: impl Display) -> String {
    let hide_input = false;
    read(prompt, hide_input).await
}
/// Prompts with `prompt` and waits for one line of input, marking the request
/// as hidden so answerers do not echo what is typed (e.g. passwords).
pub async fn read_password(prompt: impl Display) -> String {
    let hide_input = true;
    read(prompt, hide_input).await
}
/// Returns the shortcut character of a choice label: the character right after
/// the first `[` (e.g. `'b'` for `"re[b]oot"`).
///
/// # Panics
///
/// Panics if `s` contains no `[` or nothing follows it. Choice labels are
/// written by the programmer, so this signals a bug in the caller.
fn choice_char(s: &str) -> char {
    s.split_once('[')
        .and_then(|(_, rest)| rest.chars().next())
        .expect("choice label should contain `[` followed by its shortcut character")
}
/// Checks that the shortcut is the character following `[`, wherever it sits.
#[test]
fn test_choice_char() {
    let cases = [("[r]etry", 'r'), ("re[b]boot", 'b')];
    for (label, expected) in cases {
        assert_eq!(expected, choice_char(label), "label {label:?}");
    }
}
/// Prompts until the user enters one of `choices`, and returns its shortcut.
///
/// Each choice marks its shortcut character with a `[` right before it
/// (e.g. `"[r]etry"`, `"re[b]oot"`). The prompt is the choices joined with
/// `", "` followed by `"? "`. A line is accepted when its first
/// non-whitespace character equals one of the shortcuts (case-sensitive).
/// Otherwise, including on an empty line, the prompt is shown again.
///
/// # Panics
///
/// Panics if a choice has no `[` followed by a character.
///
/// # Examples
///
/// ```no_run
/// use init::input;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
///     tokio::spawn(input::answer_requests_from_stdin());
///     match input::read_choice(["[r]etry","[i]gnore","re[b]oot"]).await {
///         'r' => todo!(),
///         'i' => todo!(),
///         'b' => todo!(),
///         _ => unreachable!(),
///     }
/// }
/// ```
pub async fn read_choice<const N: usize>(choices: [&str; N]) -> char {
    let shortcuts = choices.map(choice_char);
    let prompt = format!("{}? ", choices.join(", "));
    loop {
        let line = read_line(&prompt).await;
        // Tolerate stray leading spaces before the shortcut character.
        if let Some(ch) = line.trim_start().chars().next() {
            if shortcuts.contains(&ch) {
                return ch;
            }
        }
    }
}
/// A pending request for user input. It is published to local answerers and
/// serialized (as JSON) to input-socket clients.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct InputRequest {
    /// Text shown to the user before the answer is read.
    prompt: String,
    /// Whether the answer should be typed without echo (e.g. passwords).
    hide: bool,
}
/// Shared slot through which an answerer delivers the reply to a request.
/// The first answerer to `take()` the sender is the one whose answer counts.
pub type Reply = Arc<Mutex<Option<oneshot::Sender<String>>>>;

/// The request currently awaiting an answer (`None` when nothing is pending).
///
/// `read` publishes and clears it; answerers `subscribe()` to follow it. The
/// initial receiver is dropped right away: subscribers are created on demand.
static REQ: LazyLock<Mutex<watch::Sender<Option<(InputRequest, Reply)>>>> =
    LazyLock::new(|| Mutex::new(watch::channel(None).0));

/// Lock meant to serialize `read` calls so that only one request is published
/// at a time (it only does so when `read` awaits `lock()`).
static READ_MUTEX: Mutex<()> = Mutex::const_new(());
/// Publishes an input request in `REQ` and waits for the first answer.
///
/// Answerers such as `answer_requests_from_stdin` and socket connections pick
/// the request up from `REQ`. Once an answer arrives, the request is cleared
/// again. `READ_MUTEX` is held for the whole call, so concurrent callers are
/// served one at a time instead of overwriting each other's request.
///
/// If this future is dropped before an answer arrives, the request stays
/// published until someone else replaces or clears `REQ`.
///
/// # Panics
///
/// Panics if every holder of the reply drops it without answering (for
/// example after `REQ` was overwritten by someone else).
async fn read(prompt: impl Display, hide_input: bool) -> String {
    // `lock()` only returns a future: without `.await` the guard is never
    // acquired and reads are not serialized at all.
    let _read_lock = READ_MUTEX.lock().await;
    let req = InputRequest {
        prompt: prompt.to_string(),
        hide: hide_input,
    };
    let (tx, rx) = oneshot::channel();
    let reply = Arc::new(Mutex::new(Some(tx)));
    REQ.lock().await.send_replace(Some((req, reply)));
    let input = rx.await.expect("reply sender should not be closed");
    REQ.lock().await.send_replace(None);
    input
}
/// Answers input requests from this process' stdin, printing prompts to stdout.
///
/// For every request published in `REQ`, prints its prompt and reads one line
/// from stdin as the answer. For hidden requests, terminal echo on fd 0 is
/// turned off while reading and restored afterwards. This includes the case
/// where stdin closes mid-read, so the terminal is never left without echo.
/// If the request is answered or cleared elsewhere first, `<answered>` is
/// printed and the pending read is abandoned.
///
/// Returns when reading stdin yields end-of-file or an error.
///
/// # Panics
///
/// Panics if writing to stdout fails.
pub async fn answer_requests_from_stdin() {
    let mut stdin = BufReader::new(io::stdin()).lines();
    let mut stdout = io::stdout();
    let mut current_req = REQ.lock().await.subscribe();
    // Handle a request that may already be pending before we subscribed.
    current_req.mark_changed();
    loop {
        // TODO check if stdin has been closed while no request is pending
        // (using C-c is enough for now)
        current_req
            .changed()
            .await
            .expect("input request should not close");
        let Some((req, reply)) = current_req.borrow_and_update().clone() else {
            continue;
        };

        // Turn echo off for hidden input, remembering the settings to restore.
        let mut saved_termios = None;
        if req.hide {
            match termios::Termios::from_fd(0) {
                Ok(mut tio) => {
                    saved_termios = Some(tio.clone());
                    tio.c_lflag &= !termios::ECHO;
                    if let Err(e) = termios::tcsetattr(0, termios::TCSAFLUSH, &tio) {
                        warn!("password may be echoed! {e}");
                    }
                }
                Err(e) => {
                    warn!("password may be echoed! {e}");
                }
            }
        }

        // Print the prompt and wait for user input, or for an answer from elsewhere.
        stdout.write_all(req.prompt.as_bytes()).await.unwrap();
        stdout.flush().await.unwrap();
        let stdin_closed = tokio::select! {
            r = stdin.next_line() => match r {
                Ok(Some(line)) => {
                    if let Some(tx) = reply.lock().await.take() {
                        let _ = tx.send(line);
                    }
                    if saved_termios.is_some() {
                        // The final '\n' was hidden too, so print it ourselves.
                        stdout.write_all(b"\n").await.unwrap();
                        stdout.flush().await.unwrap();
                    }
                    false
                }
                _ => {
                    warn!("stdin closed");
                    true
                }
            },
            _ = current_req.changed() => {
                // The reply came from somewhere else.
                stdout.write_all(b"<answered>\n").await.unwrap();
                stdout.flush().await.unwrap();
                // Re-arm so the next iteration processes the new state.
                current_req.mark_changed();
                false
            }
        };

        // Restore the terminal if input was hidden, even when stdin just closed.
        if let Some(tio) = saved_termios {
            if let Err(e) = termios::tcsetattr(0, termios::TCSAFLUSH, &tio) {
                warn!("failed to restore pty attrs: {e}");
            }
        }
        if stdin_closed {
            return;
        }
    }
}
/// Path of the Unix socket on which input requests are served.
const SOCKET_PATH: &str = "/run/init.sock";

/// Listens on [`SOCKET_PATH`] and serves each client with `handle_connection`
/// in its own spawned task.
///
/// Logs a warning and returns if binding the socket or accepting a
/// connection fails.
pub async fn answer_requests_from_socket() {
    let listener = match net::UnixListener::bind(SOCKET_PATH) {
        Ok(listener) => listener,
        Err(e) => {
            warn!("failed start input socket listener: {e}");
            return;
        }
    };
    loop {
        match listener.accept().await {
            Ok((conn, _peer)) => {
                tokio::spawn(handle_connection(conn));
            }
            Err(e) => {
                warn!("input socket listener failed: {e}");
                return;
            }
        }
    }
}
/// Messages exchanged over the input socket, serialized as one JSON value per
/// line.
///
/// The short serde names (`"l"`, `"r"`) are the variant tags on the wire;
/// renaming them breaks compatibility with peers using the current tags.
#[derive(serde::Deserialize, serde::Serialize)]
enum Message {
    /// Raw log bytes; the forwarding client copies them to its stderr.
    #[serde(rename = "l")]
    Log(Vec<u8>),
    /// The current input request, or `None` once no answer is awaited.
    #[serde(rename = "r")]
    Req(Option<InputRequest>),
}
/// Serves one input-socket client until it goes away.
///
/// Forwards records from `dklog::LOG` as `Message::Log`. Sends
/// `Message::Req` whenever the request in `REQ` changes (`Req(None)` once it
/// is cleared). While a request is pending, the first line read from the
/// client is used as its answer, unless the request was answered or changed
/// elsewhere first.
///
/// Returns when the client's side is closed or unreadable, a write to the
/// client fails, or the log stream ends.
async fn handle_connection(conn: net::UnixStream) {
    use crate::dklog;
    let mut log = dklog::LOG.subscribe();
    let mut current_req = REQ.lock().await.subscribe();
    // Mark as changed so the current request (or its absence) is sent to the
    // new client right away.
    current_req.mark_changed();
    let (rd, mut wr) = io::split(conn);
    let mut rd = BufReader::new(rd).lines();
    // Writes `$msg` as one JSON line. Note: the `return`s exit
    // `handle_connection` itself on any write error.
    macro_rules! wr {
        ($msg:expr) => {
            let mut buf = serde_json::to_vec(&$msg).unwrap();
            buf.push(b'\n');
            if wr.write_all(&buf).await.is_err() {
                return;
            }
            if wr.flush().await.is_err() {
                return;
            }
        };
    }
    loop {
        // Idle: forward log records until the current request changes.
        tokio::select!(
            r = current_req.changed() => {
                r.expect("input request should not close");
            },
            l = log.next() => {
                let Some(l) = l else { return; };
                wr!(Message::Log(l));
                continue;
            },
        );
        let Some((req, reply)) = current_req.borrow_and_update().clone() else {
            wr!(Message::Req(None));
            continue;
        };
        wr!(Message::Req(Some(req)));
        // A request is pending: keep forwarding logs (`continue`) until either
        // the client answers or the request changes, then `break` back to
        // the idle loop.
        loop {
            tokio::select!(
                r = rd.next_line() => {
                    let Ok(Some(line)) = r else {
                        return; // closed
                    };
                    // Only the first answerer can take the sender; if it is
                    // already gone, this line is dropped.
                    if let Some(tx) = reply.lock().await.take() {
                        let _ = tx.send(line);
                    }
                }
                _ = current_req.changed() => {
                    // reply came from somewhere else
                    // Re-arm so the idle loop picks up the new state at once.
                    current_req.mark_changed();
                }
                l = log.next() => {
                    let Some(l) = l else { return; };
                    wr!(Message::Log(l));
                    continue;
                },
            );
            break;
        }
    }
}
pub async fn forward_requests_from_socket() -> eyre::Result<()> {
let stream = net::UnixStream::connect(SOCKET_PATH).await?;
let (rd, mut wr) = io::split(stream);
let mut rd = BufReader::new(rd).lines();
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
tokio::spawn(async move {
loop {
let Ok(line) = (rd.next_line().await).inspect_err(|e| warn!("socket read error: {e}"))
else {
return;
};
let Some(line) = line else {
// end of stream
return;
};
let Ok(msg) = serde_json::from_str::<Message>(&line)
.inspect_err(|e| warn!("invalid message received: {e}"))
else {
continue;
};
match msg {
Message::Req(req) => {
if tx.send(req).await.is_err() {
// closed
return;
}
}
Message::Log(l) => {
let mut out = io::stderr();
let _ = out.write_all(&l).await;
let _ = out.flush().await;
}
};
}
});
// the recv request if any, otherwise wait for the next
let mut recv = None;
loop {
let req = match recv.take() {
// value already available
Some(v) => v,
// no value available, wait for the next
None => match rx.recv().await {
Some(v) => v,
// end of requests
None => return Ok(()),
},
};
let Some(req) = req else {
REQ.lock().await.send_replace(None);
continue;
};
tokio::select!(
mut r = read(req.prompt, req.hide) => {
r.push('\n');
wr.write_all(r.as_bytes()).await?;
wr.flush().await?;
}
r = rx.recv() => {
recv = r
}
);
}
}