rework ssh 'init' command

This commit is contained in:
Mikaël Cluseau 2024-01-21 15:26:00 +01:00
parent 2df68d3fca
commit d0904ba602
8 changed files with 92 additions and 64 deletions

View File

@ -16,9 +16,10 @@ from alpine:3.19.0 as initrd
run apk add --no-cache xz run apk add --no-cache xz
workdir /layer workdir /layer
run wget -O- https://dl-cdn.alpinelinux.org/alpine/v3.18/releases/x86_64/alpine-minirootfs-3.18.4-x86_64.tar.gz |tar zxv run . /etc/os-release \
&& wget -O- https://dl-cdn.alpinelinux.org/alpine/v${VERSION_ID%.*}/releases/x86_64/alpine-minirootfs-${VERSION_ID}-x86_64.tar.gz |tar zxv
run apk add --no-cache -p . musl lvm2 lvm2-dmeventd udev cryptsetup e2fsprogs btrfs-progs lsblk run apk add --no-cache -p . musl lvm2 lvm2-extra lvm2-dmeventd udev cryptsetup e2fsprogs btrfs-progs lsblk
run rm -rf usr/share/apk var/cache/apk run rm -rf usr/share/apk var/cache/apk
copy --from=build /go/bin/init . copy --from=build /go/bin/init .

View File

@ -6,41 +6,51 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"sync"
"sync/atomic"
) )
func askSecret(prompt string) []byte { var (
stdinTTY.EchoOff() inputTTYs = new(sync.Map)
askingSecret atomic.Bool
)
var ( func registerInput(in *tty) { inputTTYs.Store(in, in) }
in io.Reader = stdin func unregiterInput(in *tty) { inputTTYs.Delete(in) }
out io.Writer = stdout
)
if stdin == nil { func askSecret(prompt string) (s []byte) {
in = os.Stdin err := func() (err error) {
out = os.Stdout askingSecret.Store(true)
} defer askingSecret.Store(false)
out.Write([]byte(prompt + ": ")) inputTTYs.Range(func(k, v any) (con bool) { v.(*tty).EchoOff(); return true })
defer inputTTYs.Range(func(k, v any) (con bool) { v.(*tty).Restore(); return true })
if stdin != nil { var (
stdout.HideInput() in io.Reader = stdin
} out io.Writer = stdout
)
s, err := bufio.NewReader(in).ReadBytes('\n') if stdin == nil {
in = os.Stdin
out = os.Stdout
}
if stdin != nil { out.Write([]byte(prompt + ": "))
stdout.ShowInput()
}
stdinTTY.Restore() s, err = bufio.NewReader(in).ReadBytes('\n')
if err != nil {
return
}
fmt.Println()
s = bytes.TrimRight(s, "\r\n")
return
}()
if err != nil { if err != nil {
fatalf("failed to read from stdin: %v", err) fatalf("failed to read from stdin: %v", err)
} }
fmt.Println() return
s = bytes.TrimRight(s, "\r\n")
return s
} }

3
go.mod
View File

@ -1,8 +1,8 @@
module novit.tech/direktil/initrd module novit.tech/direktil/initrd
require ( require (
github.com/creack/pty v1.1.21
github.com/freddierice/go-losetup/v2 v2.0.1 github.com/freddierice/go-losetup/v2 v2.0.1
github.com/kr/pty v1.1.8
github.com/pkg/term v1.1.0 github.com/pkg/term v1.1.0
github.com/rs/zerolog v1.31.0 github.com/rs/zerolog v1.31.0
golang.org/x/crypto v0.18.0 golang.org/x/crypto v0.18.0
@ -16,7 +16,6 @@ require (
require ( require (
github.com/cavaliergopher/cpio v1.0.1 // indirect github.com/cavaliergopher/cpio v1.0.1 // indirect
github.com/creack/pty v1.1.21 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect

3
go.sum
View File

@ -1,7 +1,6 @@
github.com/cavaliergopher/cpio v1.0.1 h1:KQFSeKmZhv0cr+kawA3a0xTQCU4QxXF1vhU7P7av2KM= github.com/cavaliergopher/cpio v1.0.1 h1:KQFSeKmZhv0cr+kawA3a0xTQCU4QxXF1vhU7P7av2KM=
github.com/cavaliergopher/cpio v1.0.1/go.mod h1:pBdaqQjnvXxdS/6CvNDwIANIFSP0xRKI16PX4xejRQc= github.com/cavaliergopher/cpio v1.0.1/go.mod h1:pBdaqQjnvXxdS/6CvNDwIANIFSP0xRKI16PX4xejRQc=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0=
github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/freddierice/go-losetup/v2 v2.0.1 h1:wPDx/Elu9nDV8y/CvIbEDz5Xi5Zo80y4h7MKbi3XaAI= github.com/freddierice/go-losetup/v2 v2.0.1 h1:wPDx/Elu9nDV8y/CvIbEDz5Xi5Zo80y4h7MKbi3XaAI=
@ -11,8 +10,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=

View File

@ -44,6 +44,8 @@ func newPipe() (io.ReadCloser, io.WriteCloser) {
func main() { func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
registerInput(newTTY(os.Stdin.Fd()))
switch baseName := filepath.Base(os.Args[0]); baseName { switch baseName := filepath.Base(os.Args[0]); baseName {
case "init": case "init":
runInit() runInit()

58
ssh.go
View File

@ -9,9 +9,10 @@ import (
"os/exec" "os/exec"
"sync" "sync"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/kr/pty" "github.com/creack/pty"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -142,14 +143,35 @@ func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *s
command := string(req.Payload[4 : req.Payload[3]+4]) command := string(req.Payload[4 : req.Payload[3]+4])
switch command { switch command {
case "init": case "init":
go func() { if ptyF == nil {
io.Copy(channel, stdout.NewReader()) go func() {
once.Do(closeCh) channel.Stderr().Write([]byte("\033[5m\033[31;1m\n\nWARNING: no TTY requested, passwords will be echoed!\n\n\033[0m"))
}() time.Sleep(3 * time.Second)
go func() { io.Copy(channel, stdout.NewReader())
io.Copy(stdinPipe, channel) once.Do(closeCh)
once.Do(closeCh) }()
}() go func() {
io.Copy(stdinPipe, channel)
once.Do(closeCh)
}()
} else {
stdinTTY := newTTY(ptyF.Fd())
if askingSecret.Load() {
stdinTTY.EchoOff()
}
registerInput(stdinTTY)
defer unregiterInput(stdinTTY)
go func() {
io.Copy(ttyF, stdout.NewReader())
once.Do(closeCh)
}()
go func() {
io.Copy(stdinPipe, ttyF)
once.Do(closeCh)
}()
}
req.Reply(true, nil) req.Reply(true, nil)
@ -195,15 +217,6 @@ func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *s
ttyF = nil ttyF = nil
}() }()
go func() {
io.Copy(channel, ptyF)
once.Do(closeCh)
}()
go func() {
io.Copy(ptyF, channel)
once.Do(closeCh)
}()
req.Reply(true, nil) req.Reply(true, nil)
case "pty-req": case "pty-req":
@ -227,6 +240,15 @@ func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *s
req.Reply(true, nil) req.Reply(true, nil)
go func() {
io.Copy(channel, ptyF)
once.Do(closeCh)
}()
go func() {
io.Copy(ptyF, channel)
once.Do(closeCh)
}()
case "window-change": case "window-change":
w, h := sshParseDims(req.Payload) w, h := sshParseDims(req.Payload)
sshSetWinsize(ptyF.Fd(), w, h) sshSetWinsize(ptyF.Fd(), w, h)

View File

@ -26,9 +26,9 @@ networks:
- eno.* - eno.*
- enp.* - enp.*
script: | script: |
ip a add 2001:41d0:306:168f::1337:2eed/64 dev $iface
ip li set $iface up ip li set $iface up
#udhcpc $iface udhcpc -i $iface -b -t1 -T1 -A5 ||
ip a add 2001:41d0:306:168f::1337:2eed/64 dev $iface
pre_lvm_crypt: pre_lvm_crypt:
- dev: /dev/vda - dev: /dev/vda

25
tty.go
View File

@ -1,29 +1,25 @@
package main package main
import ( import (
"os"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var (
stdinTTY = &tty{int(os.Stdin.Fd()), nil}
)
type tty struct { type tty struct {
fd int fd int
termios *unix.Termios termios *unix.Termios
} }
func newTTY(fd uintptr) *tty {
termios, _ := unix.IoctlGetTermios(int(fd), unix.TCGETS)
return &tty{int(fd), termios}
}
func (t *tty) EchoOff() { func (t *tty) EchoOff() {
termios, err := unix.IoctlGetTermios(t.fd, unix.TCGETS) if t.termios == nil {
if err != nil {
return return
} }
t.termios = termios newState := *t.termios
newState := *termios
newState.Lflag &^= unix.ECHO newState.Lflag &^= unix.ECHO
newState.Lflag |= unix.ICANON | unix.ISIG newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= unix.ICRNL newState.Iflag |= unix.ICRNL
@ -31,8 +27,9 @@ func (t *tty) EchoOff() {
} }
func (t *tty) Restore() { func (t *tty) Restore() {
if t.termios != nil { if t.termios == nil {
unix.IoctlSetTermios(t.fd, unix.TCSETS, t.termios) return
t.termios = nil
} }
unix.IoctlSetTermios(t.fd, unix.TCSETS, t.termios)
} }