diff --git a/Dockerfile b/Dockerfile index 1de6575..10de0bd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,9 +16,10 @@ from alpine:3.19.0 as initrd run apk add --no-cache xz workdir /layer -run wget -O- https://dl-cdn.alpinelinux.org/alpine/v3.18/releases/x86_64/alpine-minirootfs-3.18.4-x86_64.tar.gz |tar zxv +run . /etc/os-release \ + && wget -O- https://dl-cdn.alpinelinux.org/alpine/v${VERSION_ID%.*}/releases/x86_64/alpine-minirootfs-${VERSION_ID}-x86_64.tar.gz |tar zxv -run apk add --no-cache -p . musl lvm2 lvm2-dmeventd udev cryptsetup e2fsprogs btrfs-progs lsblk +run apk add --no-cache -p . musl lvm2 lvm2-extra lvm2-dmeventd udev cryptsetup e2fsprogs btrfs-progs lsblk run rm -rf usr/share/apk var/cache/apk copy --from=build /go/bin/init . diff --git a/ask-secret.go b/ask-secret.go index a7cc29e..6ca1923 100644 --- a/ask-secret.go +++ b/ask-secret.go @@ -6,41 +6,51 @@ import ( "fmt" "io" "os" + "sync" + "sync/atomic" ) -func askSecret(prompt string) []byte { - stdinTTY.EchoOff() +var ( + inputTTYs = new(sync.Map) + askingSecret atomic.Bool +) - var ( - in io.Reader = stdin - out io.Writer = stdout - ) +func registerInput(in *tty) { inputTTYs.Store(in, in) } +func unregiterInput(in *tty) { inputTTYs.Delete(in) } - if stdin == nil { - in = os.Stdin - out = os.Stdout - } +func askSecret(prompt string) (s []byte) { + err := func() (err error) { + askingSecret.Store(true) + defer askingSecret.Store(false) - out.Write([]byte(prompt + ": ")) + inputTTYs.Range(func(k, v any) (con bool) { v.(*tty).EchoOff(); return true }) + defer inputTTYs.Range(func(k, v any) (con bool) { v.(*tty).Restore(); return true }) - if stdin != nil { - stdout.HideInput() - } + var ( + in io.Reader = stdin + out io.Writer = stdout + ) - s, err := bufio.NewReader(in).ReadBytes('\n') + if stdin == nil { + in = os.Stdin + out = os.Stdout + } - if stdin != nil { - stdout.ShowInput() - } + out.Write([]byte(prompt + ": ")) - stdinTTY.Restore() + s, err = bufio.NewReader(in).ReadBytes('\n') + if err != nil { + return + } + + fmt.Println() + s = bytes.TrimRight(s, "\r\n") + return + }() if err != nil { fatalf("failed to read from stdin: %v", err) } - fmt.Println() - - s = bytes.TrimRight(s, "\r\n") - return s + return } diff --git a/bootstrap.go b/bootstrap.go index b57add0..fa41b6f 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -125,7 +125,7 @@ func bootstrap(cfg *config.Config) { } // update-ca-certificates - log.Info().Msg("updating CA certifices") + log.Info().Msg("updating CA certificates") run("chroot", "/system", "update-ca-certificates") } diff --git a/go.mod b/go.mod index a4ce7d5..cf60a08 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module novit.tech/direktil/initrd require ( + github.com/creack/pty v1.1.21 github.com/freddierice/go-losetup/v2 v2.0.1 - github.com/kr/pty v1.1.8 github.com/pkg/term v1.1.0 github.com/rs/zerolog v1.31.0 golang.org/x/crypto v0.18.0 @@ -16,7 +16,6 @@ require ( require ( github.com/cavaliergopher/cpio v1.0.1 // indirect - github.com/creack/pty v1.1.21 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 6a6ef8e..4a09ca8 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,6 @@ github.com/cavaliergopher/cpio v1.0.1 h1:KQFSeKmZhv0cr+kawA3a0xTQCU4QxXF1vhU7P7av2KM= github.com/cavaliergopher/cpio v1.0.1/go.mod h1:pBdaqQjnvXxdS/6CvNDwIANIFSP0xRKI16PX4xejRQc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/freddierice/go-losetup/v2 v2.0.1 h1:wPDx/Elu9nDV8y/CvIbEDz5Xi5Zo80y4h7MKbi3XaAI= @@ -11,8 +10,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= -github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/main.go b/main.go index ce2f299..79e4be1 100644 --- a/main.go +++ b/main.go @@ -44,6 +44,8 @@ func newPipe() (io.ReadCloser, io.WriteCloser) { func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + registerInput(newTTY(os.Stdin.Fd())) + switch baseName := filepath.Base(os.Args[0]); baseName { case "init": runInit() diff --git a/ssh.go b/ssh.go index f715c4f..5fe9eb8 100644 --- a/ssh.go +++ b/ssh.go @@ -9,9 +9,10 @@ import ( "os/exec" "sync" "syscall" + "time" "unsafe" - "github.com/kr/pty" + "github.com/creack/pty" "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" @@ -142,14 +143,35 @@ func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *s command := string(req.Payload[4 : req.Payload[3]+4]) switch command { case "init": - go func() { - io.Copy(channel, stdout.NewReader()) - once.Do(closeCh) - }() - go func() { - io.Copy(stdinPipe, channel) - once.Do(closeCh) - }() + if ptyF == nil { + go func() { + channel.Stderr().Write([]byte("\033[5m\033[31;1m\n\nWARNING: no TTY requested, passwords will be echoed!\n\n\033[0m")) + time.Sleep(3 * time.Second) + io.Copy(channel, stdout.NewReader()) + once.Do(closeCh) + }() + go func() { + io.Copy(stdinPipe, channel) + once.Do(closeCh) + }() + + } else { + stdinTTY := newTTY(ptyF.Fd()) + if askingSecret.Load() { + stdinTTY.EchoOff() + } + registerInput(stdinTTY) + defer unregiterInput(stdinTTY) + + go func() { + io.Copy(ttyF, stdout.NewReader()) + once.Do(closeCh) + }() + go func() { + io.Copy(stdinPipe, ttyF) + once.Do(closeCh) + }() + } req.Reply(true, nil) @@ -195,15 +217,6 @@ func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *s ttyF = nil }() - go func() { - io.Copy(channel, ptyF) - once.Do(closeCh) - }() - go func() { - io.Copy(ptyF, channel) - once.Do(closeCh) - }() - req.Reply(true, nil) case "pty-req": @@ -227,6 +240,15 @@ func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *s req.Reply(true, nil) + go func() { + io.Copy(channel, ptyF) + once.Do(closeCh) + }() + go func() { + io.Copy(ptyF, channel) + once.Do(closeCh) + }() + case "window-change": w, h := sshParseDims(req.Payload) sshSetWinsize(ptyF.Fd(), w, h) diff --git a/test-initrd/config.yaml b/test-initrd/config.yaml index 8919c7f..18fc337 100644 --- a/test-initrd/config.yaml +++ b/test-initrd/config.yaml @@ -26,9 +26,9 @@ networks: - eno.* - enp.* script: | - ip a add 2001:41d0:306:168f::1337:2eed/64 dev $iface ip li set $iface up - #udhcpc $iface + udhcpc -i $iface -b -t1 -T1 -A5 || + ip a add 2001:41d0:306:168f::1337:2eed/64 dev $iface pre_lvm_crypt: - dev: /dev/vda diff --git a/tty.go b/tty.go index ba419e8..f418227 100644 --- a/tty.go +++ b/tty.go @@ -1,29 +1,25 @@ package main import ( - "os" - "golang.org/x/sys/unix" ) -var ( - stdinTTY = &tty{int(os.Stdin.Fd()), nil} -) - type tty struct { fd int termios *unix.Termios } +func newTTY(fd uintptr) *tty { + termios, _ := unix.IoctlGetTermios(int(fd), unix.TCGETS) + return &tty{int(fd), termios} +} + func (t *tty) EchoOff() { - termios, err := unix.IoctlGetTermios(t.fd, unix.TCGETS) - if err != nil { + if t.termios == nil { return } - t.termios = termios - - newState := *termios + newState := *t.termios newState.Lflag &^= unix.ECHO newState.Lflag |= unix.ICANON | unix.ISIG newState.Iflag |= unix.ICRNL @@ -31,8 +27,9 @@ func (t *tty) EchoOff() { } func (t *tty) Restore() { - if t.termios != nil { - unix.IoctlSetTermios(t.fd, unix.TCSETS, t.termios) - t.termios = nil + if t.termios == nil { + return } + + unix.IoctlSetTermios(t.fd, unix.TCSETS, t.termios) }