344 lines
9.2 KiB
Rust
344 lines
9.2 KiB
Rust
![]() |
use log::warn;
|
||
|
use std::fmt::Display;
|
||
|
use std::sync::{Arc, LazyLock};
|
||
|
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||
|
use tokio::net;
|
||
|
use tokio::sync::{oneshot, watch, Mutex};
|
||
|
|
||
|
pub async fn read_line(prompt: impl Display) -> String {
|
||
|
read(prompt, false).await
|
||
|
}
|
||
|
|
||
|
pub async fn read_password(prompt: impl Display) -> String {
|
||
|
read(prompt, true).await
|
||
|
}
|
||
|
|
||
|
/// Return the key character of a choice label: the character right after the
/// first `[` (e.g. `'b'` for `"re[b]oot"`).
///
/// # Panics
///
/// Panics if the label has no `[` or nothing follows it. Labels are supplied by
/// the program, not the user, so this is a programming error.
fn choice_char(s: &str) -> char {
    s.split_once('[')
        .and_then(|(_, rest)| rest.chars().next())
        .expect("choice label must contain '[' followed by the choice character")
}
|
||
|
|
||
|
#[test]
fn test_choice_char() {
    assert_eq!('r', choice_char("[r]etry"));
    assert_eq!('i', choice_char("[i]gnore"));
    // the bracketed key does not have to be the first character
    assert_eq!('b', choice_char("re[b]oot"));
}
|
||
|
|
||
|
/// ```no_run
|
||
|
/// use init::input;
|
||
|
///
|
||
|
/// #[tokio::main(flavor = "current_thread")]
|
||
|
/// async fn main() {
|
||
|
/// tokio::spawn(input::answer_requests_from_stdin());
|
||
|
/// match input::read_choice(["[r]etry","[i]gnore","re[b]oot"]).await {
|
||
|
/// 'r' => todo!(),
|
||
|
/// 'i' => todo!(),
|
||
|
/// 'b' => todo!(),
|
||
|
/// _ => unreachable!(),
|
||
|
/// }
|
||
|
/// }
|
||
|
/// ```
|
||
|
pub async fn read_choice<const N: usize>(choices: [&str; N]) -> char {
|
||
|
let chars = choices.map(choice_char);
|
||
|
|
||
|
let mut prompt = String::new();
|
||
|
for s in choices {
|
||
|
if !prompt.is_empty() {
|
||
|
prompt.push_str(", ");
|
||
|
}
|
||
|
prompt.push_str(s);
|
||
|
}
|
||
|
prompt.push_str("? ");
|
||
|
|
||
|
loop {
|
||
|
let line = read_line(&prompt).await;
|
||
|
let Some(ch) = line.chars().nth(0) else {
|
||
|
continue;
|
||
|
};
|
||
|
|
||
|
for choice in chars {
|
||
|
if ch == choice {
|
||
|
return choice;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, serde::Deserialize, serde::Serialize)]
|
||
|
pub struct InputRequest {
|
||
|
prompt: String,
|
||
|
hide: bool,
|
||
|
}
|
||
|
|
||
|
pub type Reply = Arc<Mutex<Option<oneshot::Sender<String>>>>;
|
||
|
|
||
|
static REQ: LazyLock<Mutex<watch::Sender<Option<(InputRequest, Reply)>>>> = LazyLock::new(|| {
|
||
|
let (tx, _) = watch::channel(None);
|
||
|
Mutex::new(tx)
|
||
|
});
|
||
|
static READ_MUTEX: Mutex<()> = Mutex::const_new(());
|
||
|
|
||
|
async fn read(prompt: impl Display, hide_input: bool) -> String {
|
||
|
let _read_lock = READ_MUTEX.lock();
|
||
|
|
||
|
let req = InputRequest {
|
||
|
prompt: prompt.to_string(),
|
||
|
hide: hide_input,
|
||
|
};
|
||
|
|
||
|
let (tx, rx) = oneshot::channel();
|
||
|
let reply = Arc::new(Mutex::new(Some(tx)));
|
||
|
|
||
|
REQ.lock().await.send_replace(Some((req, reply)));
|
||
|
|
||
|
let input = rx.await.expect("reply sender should not be closed");
|
||
|
|
||
|
REQ.lock().await.send_replace(None);
|
||
|
|
||
|
input
|
||
|
}
|
||
|
|
||
|
pub async fn answer_requests_from_stdin() {
|
||
|
let mut stdin = BufReader::new(io::stdin()).lines();
|
||
|
let mut stdout = io::stdout();
|
||
|
|
||
|
let mut current_req = REQ.lock().await.subscribe();
|
||
|
current_req.mark_changed();
|
||
|
|
||
|
loop {
|
||
|
// TODO check is stdin has been closed (using C-c is enough for now)
|
||
|
(current_req.changed().await).expect("input request should not close");
|
||
|
|
||
|
let Some((req, reply)) = current_req.borrow_and_update().clone() else {
|
||
|
continue;
|
||
|
};
|
||
|
|
||
|
// handle hide
|
||
|
let mut saved_termios = None;
|
||
|
if req.hide {
|
||
|
match termios::Termios::from_fd(0) {
|
||
|
Ok(mut tio) => {
|
||
|
saved_termios = Some(tio.clone());
|
||
|
tio.c_lflag &= !termios::ECHO;
|
||
|
if let Err(e) = termios::tcsetattr(0, termios::TCSAFLUSH, &tio) {
|
||
|
warn!("password may be echoed! {e}");
|
||
|
}
|
||
|
}
|
||
|
Err(e) => {
|
||
|
warn!("password may be echoed! {e}");
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// print the prompt and wait for user input
|
||
|
stdout.write_all(req.prompt.as_bytes()).await.unwrap();
|
||
|
stdout.flush().await.unwrap();
|
||
|
|
||
|
tokio::select!(
|
||
|
r = stdin.next_line() => {
|
||
|
let Ok(Some(line)) = r else {
|
||
|
warn!("stdin closed");
|
||
|
return;
|
||
|
};
|
||
|
|
||
|
if let Some(tx) = reply.lock().await.take() {
|
||
|
let _ = tx.send(line);
|
||
|
}
|
||
|
|
||
|
if saved_termios.is_some() {
|
||
|
// final '\n' is hidden too so fix it
|
||
|
stdout.write_all(b"\n").await.unwrap();
|
||
|
stdout.flush().await.unwrap();
|
||
|
}
|
||
|
}
|
||
|
_ = current_req.changed() => {
|
||
|
// reply came from somewhere else
|
||
|
stdout.write_all(b"<answered>\n").await.unwrap();
|
||
|
stdout.flush().await.unwrap();
|
||
|
|
||
|
current_req.mark_changed();
|
||
|
}
|
||
|
);
|
||
|
|
||
|
// restore term if input was hidden
|
||
|
if let Some(tio) = saved_termios {
|
||
|
if let Err(e) = termios::tcsetattr(0, termios::TCSAFLUSH, &tio) {
|
||
|
warn!("failed to restore pty attrs: {e}");
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
const SOCKET_PATH: &str = "/run/init.sock";
|
||
|
|
||
|
pub async fn answer_requests_from_socket() {
|
||
|
let Ok(listener) = net::UnixListener::bind(SOCKET_PATH)
|
||
|
.inspect_err(|e| warn!("failed start input socket listener: {e}"))
|
||
|
else {
|
||
|
return;
|
||
|
};
|
||
|
|
||
|
loop {
|
||
|
let Ok((conn, _)) = (listener.accept())
|
||
|
.await
|
||
|
.inspect_err(|e| warn!("input socket listener failed: {e}"))
|
||
|
else {
|
||
|
return;
|
||
|
};
|
||
|
|
||
|
tokio::spawn(handle_connection(conn));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(serde::Deserialize, serde::Serialize)]
|
||
|
enum Message {
|
||
|
#[serde(rename = "l")]
|
||
|
Log(Vec<u8>),
|
||
|
#[serde(rename = "r")]
|
||
|
Req(Option<InputRequest>),
|
||
|
}
|
||
|
|
||
|
async fn handle_connection(conn: net::UnixStream) {
|
||
|
use crate::dklog;
|
||
|
|
||
|
let mut log = dklog::LOG.subscribe();
|
||
|
|
||
|
let mut current_req = REQ.lock().await.subscribe();
|
||
|
current_req.mark_changed();
|
||
|
|
||
|
let (rd, mut wr) = io::split(conn);
|
||
|
let mut rd = BufReader::new(rd).lines();
|
||
|
|
||
|
macro_rules! wr {
|
||
|
($msg:expr) => {
|
||
|
let mut buf = serde_json::to_vec(&$msg).unwrap();
|
||
|
buf.push(b'\n');
|
||
|
|
||
|
if wr.write_all(&buf).await.is_err() {
|
||
|
return;
|
||
|
}
|
||
|
if wr.flush().await.is_err() {
|
||
|
return;
|
||
|
}
|
||
|
};
|
||
|
}
|
||
|
|
||
|
loop {
|
||
|
tokio::select!(
|
||
|
r = current_req.changed() => {
|
||
|
r.expect("input request should not close");
|
||
|
},
|
||
|
l = log.next() => {
|
||
|
let Some(l) = l else { return; };
|
||
|
wr!(Message::Log(l));
|
||
|
continue;
|
||
|
},
|
||
|
);
|
||
|
|
||
|
let Some((req, reply)) = current_req.borrow_and_update().clone() else {
|
||
|
wr!(Message::Req(None));
|
||
|
continue;
|
||
|
};
|
||
|
|
||
|
wr!(Message::Req(Some(req)));
|
||
|
|
||
|
loop {
|
||
|
tokio::select!(
|
||
|
r = rd.next_line() => {
|
||
|
let Ok(Some(line)) = r else {
|
||
|
return; // closed
|
||
|
};
|
||
|
|
||
|
if let Some(tx) = reply.lock().await.take() {
|
||
|
let _ = tx.send(line);
|
||
|
}
|
||
|
}
|
||
|
_ = current_req.changed() => {
|
||
|
// reply came from somewhere else
|
||
|
current_req.mark_changed();
|
||
|
}
|
||
|
l = log.next() => {
|
||
|
let Some(l) = l else { return; };
|
||
|
wr!(Message::Log(l));
|
||
|
continue;
|
||
|
},
|
||
|
);
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub async fn forward_requests_from_socket() -> eyre::Result<()> {
|
||
|
let stream = net::UnixStream::connect(SOCKET_PATH).await?;
|
||
|
|
||
|
let (rd, mut wr) = io::split(stream);
|
||
|
let mut rd = BufReader::new(rd).lines();
|
||
|
|
||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
|
||
|
|
||
|
tokio::spawn(async move {
|
||
|
loop {
|
||
|
let Ok(line) = (rd.next_line().await).inspect_err(|e| warn!("socket read error: {e}"))
|
||
|
else {
|
||
|
return;
|
||
|
};
|
||
|
|
||
|
let Some(line) = line else {
|
||
|
// end of stream
|
||
|
return;
|
||
|
};
|
||
|
|
||
|
let Ok(msg) = serde_json::from_str::<Message>(&line)
|
||
|
.inspect_err(|e| warn!("invalid message received: {e}"))
|
||
|
else {
|
||
|
continue;
|
||
|
};
|
||
|
|
||
|
match msg {
|
||
|
Message::Req(req) => {
|
||
|
if tx.send(req).await.is_err() {
|
||
|
// closed
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
Message::Log(l) => {
|
||
|
let mut out = io::stderr();
|
||
|
let _ = out.write_all(&l).await;
|
||
|
let _ = out.flush().await;
|
||
|
}
|
||
|
};
|
||
|
}
|
||
|
});
|
||
|
|
||
|
// the recv request if any, otherwise wait for the next
|
||
|
let mut recv = None;
|
||
|
|
||
|
loop {
|
||
|
let req = match recv.take() {
|
||
|
// value already available
|
||
|
Some(v) => v,
|
||
|
// no value available, wait for the next
|
||
|
None => match rx.recv().await {
|
||
|
Some(v) => v,
|
||
|
// end of requests
|
||
|
None => return Ok(()),
|
||
|
},
|
||
|
};
|
||
|
|
||
|
let Some(req) = req else {
|
||
|
REQ.lock().await.send_replace(None);
|
||
|
continue;
|
||
|
};
|
||
|
|
||
|
tokio::select!(
|
||
|
mut r = read(req.prompt, req.hide) => {
|
||
|
r.push('\n');
|
||
|
wr.write_all(r.as_bytes()).await?;
|
||
|
wr.flush().await?;
|
||
|
}
|
||
|
r = rx.recv() => {
|
||
|
recv = r
|
||
|
}
|
||
|
);
|
||
|
}
|
||
|
}
|