use log::warn; use std::fmt::Display; use std::sync::{Arc, LazyLock}; use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net; use tokio::sync::{oneshot, watch, Mutex}; pub async fn read_line(prompt: impl Display) -> String { read(prompt, false).await } pub async fn read_password(prompt: impl Display) -> String { read(prompt, true).await } fn choice_char(s: &str) -> char { s.chars().skip_while(|c| *c != '[').skip(1).next().unwrap() } #[test] fn test_choice_char() { assert_eq!('r', choice_char("[r]etry")); assert_eq!('b', choice_char("re[b]boot")); } /// ```no_run /// use init::input; /// /// #[tokio::main(flavor = "current_thread")] /// async fn main() { /// tokio::spawn(input::answer_requests_from_stdin()); /// match input::read_choice(["[r]etry","[i]gnore","re[b]oot"]).await { /// 'r' => todo!(), /// 'i' => todo!(), /// 'b' => todo!(), /// _ => unreachable!(), /// } /// } /// ``` pub async fn read_choice(choices: [&str; N]) -> char { let chars = choices.map(choice_char); let mut prompt = String::new(); for s in choices { if !prompt.is_empty() { prompt.push_str(", "); } prompt.push_str(s); } prompt.push_str("? "); loop { let line = read_line(&prompt).await; let Some(ch) = line.chars().nth(0) else { continue; }; for choice in chars { if ch == choice { return choice; } } } } #[derive(Clone, serde::Deserialize, serde::Serialize)] pub struct InputRequest { prompt: String, hide: bool, } pub type Reply = Arc>>>; static REQ: LazyLock>>> = LazyLock::new(|| { let (tx, _) = watch::channel(None); Mutex::new(tx) }); static READ_MUTEX: Mutex<()> = Mutex::const_new(()); async fn read(prompt: impl Display, hide_input: bool) -> String { let _read_lock = READ_MUTEX.lock(); let req = InputRequest { prompt: prompt.to_string(), hide: hide_input, }; let (tx, rx) = oneshot::channel(); let reply = Arc::new(Mutex::new(Some(tx))); REQ.lock().await.send_replace(Some((req, reply))); let input = rx.await.expect("reply sender should not be closed"); REQ.lock().await.send_replace(None); input } pub async fn answer_requests_from_stdin() { let mut stdin = BufReader::new(io::stdin()).lines(); let mut stdout = io::stdout(); let mut current_req = REQ.lock().await.subscribe(); current_req.mark_changed(); loop { // TODO check is stdin has been closed (using C-c is enough for now) (current_req.changed().await).expect("input request should not close"); let Some((req, reply)) = current_req.borrow_and_update().clone() else { continue; }; // handle hide let mut saved_termios = None; if req.hide { match termios::Termios::from_fd(0) { Ok(mut tio) => { saved_termios = Some(tio.clone()); tio.c_lflag &= !termios::ECHO; if let Err(e) = termios::tcsetattr(0, termios::TCSAFLUSH, &tio) { warn!("password may be echoed! {e}"); } } Err(e) => { warn!("password may be echoed! {e}"); } } } // print the prompt and wait for user input stdout.write_all(req.prompt.as_bytes()).await.unwrap(); stdout.flush().await.unwrap(); tokio::select!( r = stdin.next_line() => { let Ok(Some(line)) = r else { warn!("stdin closed"); return; }; if let Some(tx) = reply.lock().await.take() { let _ = tx.send(line); } if saved_termios.is_some() { // final '\n' is hidden too so fix it stdout.write_all(b"\n").await.unwrap(); stdout.flush().await.unwrap(); } } _ = current_req.changed() => { // reply came from somewhere else stdout.write_all(b"\n").await.unwrap(); stdout.flush().await.unwrap(); current_req.mark_changed(); } ); // restore term if input was hidden if let Some(tio) = saved_termios { if let Err(e) = termios::tcsetattr(0, termios::TCSAFLUSH, &tio) { warn!("failed to restore pty attrs: {e}"); } } } } const SOCKET_PATH: &str = "/run/init.sock"; pub async fn answer_requests_from_socket() { let Ok(listener) = net::UnixListener::bind(SOCKET_PATH) .inspect_err(|e| warn!("failed start input socket listener: {e}")) else { return; }; loop { let Ok((conn, _)) = (listener.accept()) .await .inspect_err(|e| warn!("input socket listener failed: {e}")) else { return; }; tokio::spawn(handle_connection(conn)); } } #[derive(serde::Deserialize, serde::Serialize)] enum Message { #[serde(rename = "l")] Log(Vec), #[serde(rename = "r")] Req(Option), } async fn handle_connection(conn: net::UnixStream) { use crate::dklog; let mut log = dklog::LOG.subscribe(); let mut current_req = REQ.lock().await.subscribe(); current_req.mark_changed(); let (rd, mut wr) = io::split(conn); let mut rd = BufReader::new(rd).lines(); macro_rules! wr { ($msg:expr) => { let mut buf = serde_json::to_vec(&$msg).unwrap(); buf.push(b'\n'); if wr.write_all(&buf).await.is_err() { return; } if wr.flush().await.is_err() { return; } }; } loop { tokio::select!( r = current_req.changed() => { r.expect("input request should not close"); }, l = log.next() => { let Some(l) = l else { return; }; wr!(Message::Log(l)); continue; }, ); let Some((req, reply)) = current_req.borrow_and_update().clone() else { wr!(Message::Req(None)); continue; }; wr!(Message::Req(Some(req))); loop { tokio::select!( r = rd.next_line() => { let Ok(Some(line)) = r else { return; // closed }; if let Some(tx) = reply.lock().await.take() { let _ = tx.send(line); } } _ = current_req.changed() => { // reply came from somewhere else current_req.mark_changed(); } l = log.next() => { let Some(l) = l else { return; }; wr!(Message::Log(l)); continue; }, ); break; } } } pub async fn forward_requests_from_socket() -> eyre::Result<()> { let stream = net::UnixStream::connect(SOCKET_PATH).await?; let (rd, mut wr) = io::split(stream); let mut rd = BufReader::new(rd).lines(); let (tx, mut rx) = tokio::sync::mpsc::channel(1); tokio::spawn(async move { loop { let Ok(line) = (rd.next_line().await).inspect_err(|e| warn!("socket read error: {e}")) else { return; }; let Some(line) = line else { // end of stream return; }; let Ok(msg) = serde_json::from_str::(&line) .inspect_err(|e| warn!("invalid message received: {e}")) else { continue; }; match msg { Message::Req(req) => { if tx.send(req).await.is_err() { // closed return; } } Message::Log(l) => { let mut out = io::stderr(); let _ = out.write_all(&l).await; let _ = out.flush().await; } }; } }); // the recv request if any, otherwise wait for the next let mut recv = None; loop { let req = match recv.take() { // value already available Some(v) => v, // no value available, wait for the next None => match rx.recv().await { Some(v) => v, // end of requests None => return Ok(()), }, }; let Some(req) = req else { REQ.lock().await.send_replace(None); continue; }; tokio::select!( mut r = read(req.prompt, req.hide) => { r.push('\n'); wr.write_all(r.as_bytes()).await?; wr.flush().await?; } r = rx.recv() => { recv = r } ); } }