initrd/ssh.go
2024-01-20 16:41:54 +01:00

260 lines
5.2 KiB
Go

package main
import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/exec"
	"sync"
	"syscall"
	"time"
	"unsafe"

	"github.com/kr/pty"
	"github.com/rs/zerolog/log"
	"golang.org/x/crypto/ssh"

	config "novit.tech/direktil/pkg/bootstrapconfig"
)
// startSSH configures the embedded SSH server and starts serving it in the
// background.
//
// Host keys are read from /id_rsa, /id_dsa, /id_ecdsa and /id_ed25519. A
// format whose file is missing or cannot be parsed is logged and skipped, but
// fatalf is called when no host key at all could be loaded. Client public keys
// are checked by sshCheckPubkey. The server listens on TCP port 22, and
// fatalf is called if that listener cannot be created.
//
// cfg is currently unused.
func startSSH(cfg *config.Config) {
	sshConfig := &ssh.ServerConfig{
		PublicKeyCallback: sshCheckPubkey,
	}

	hostKeyLoaded := false
	for _, format := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
		log := log.With().Str("format", format).Logger()

		pkBytes, err := os.ReadFile("/id_" + format)
		if err != nil {
			log.Error().Err(err).Msg("ssh: failed to load host key")
			continue
		}

		pk, err := ssh.ParsePrivateKey(pkBytes)
		if err != nil {
			log.Error().Err(err).Msg("ssh: failed to parse host key")
			continue
		}

		sshConfig.AddHostKey(pk)
		hostKeyLoaded = true
		log.Info().Msg("ssh: loaded host key")
	}

	if !hostKeyLoaded {
		fatalf("ssh: failed to load any host key")
	}

	sshBind := ":22" // TODO configurable
	listener, err := net.Listen("tcp", sshBind)
	if err != nil {
		fatalf("ssh: failed to listen on %s: %v", sshBind, err)
	}

	log.Info().Str("bind-address", sshBind).Msg("SSH server listening")

	go sshAcceptLoop(listener, sshConfig)
}

// sshAcceptLoop accepts connections from listener and hands each one to
// sshHandleConn in its own goroutine. It returns once the listener is closed.
// On any other accept error (for instance running out of file descriptors)
// it retries after an exponentially growing delay capped at one second, as
// net/http's Server.Serve does, instead of spinning and flooding the log.
func sshAcceptLoop(listener net.Listener, sshConfig *ssh.ServerConfig) {
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)

	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				log.Info().Msg("ssh: listener closed, no longer accepting connections")
				return
			}

			if backoff == 0 {
				backoff = minBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}

			log.Info().Err(err).Dur("retry-in", backoff).Msg("ssh: accept conn failed")
			time.Sleep(backoff)
			continue
		}

		backoff = 0
		go sshHandleConn(conn, sshConfig)
	}
}
// sshHandleConn runs the SSH server handshake on a freshly accepted
// connection. If the handshake fails, the error is logged and the function
// returns.
//
// On success, the peer is identified as "<user>@<remote address>" for
// logging. Global requests are then served by sshHandleReqs and channel-open
// requests by sshHandleChannels, each in its own goroutine.
func sshHandleConn(conn net.Conn, sshConfig *ssh.ServerConfig) {
	srvConn, newChans, globalReqs, err := ssh.NewServerConn(conn, sshConfig)
	if err != nil {
		log.Error().Err(err).Msg("ssh: handshake failed")
		return
	}

	user := srvConn.User()
	peer := srvConn.RemoteAddr().String()
	remoteAddr := user + "@" + peer

	log.Info().Str("remote", remoteAddr).Msg("ssh: new connection")

	go sshHandleReqs(globalReqs)
	go sshHandleChannels(remoteAddr, newChans)
}
// sshHandleReqs answers a connection's global (out-of-band) requests until
// reqs is closed. OpenSSH "keepalive@openssh.com" requests are acknowledged.
// Every other request type is logged and refused.
func sshHandleReqs(reqs <-chan *ssh.Request) {
	for {
		req, open := <-reqs
		if !open {
			return
		}

		if req.Type == "keepalive@openssh.com" {
			req.Reply(true, nil)
			continue
		}

		log.Info().Str("type", req.Type).Msg("ssh: discarding request")
		req.Reply(false, nil)
	}
}
// sshHandleChannels processes channel-open requests until chans is closed.
// Only "session" channels are accepted, and each is served by
// sshHandleChannel in its own goroutine. Any other channel type is rejected
// with ssh.UnknownChannelType. A failed Accept is logged and the channel is
// skipped.
func sshHandleChannels(remoteAddr string, chans <-chan ssh.NewChannel) {
	for nc := range chans {
		switch chanType := nc.ChannelType(); chanType {
		case "session":
			ch, chReqs, err := nc.Accept()
			if err != nil {
				log.Error().Err(err).Msg("ssh: failed to accept channel")
				continue
			}
			go sshHandleChannel(remoteAddr, ch, chReqs)
		default:
			nc.Reject(ssh.UnknownChannelType, "unknown channel type: "+chanType)
		}
	}
}
// sshHandleChannel serves the requests sent on one accepted "session" channel
// until the client stops sending them (requests is closed).
//
// Handled request types:
//   - "pty-req": allocate a single PTY pair, remember the client's TERM value
//     and apply the initial window size;
//   - "window-change": resize the PTY allocated by "pty-req";
//   - "shell": run /bin/ash on the PTY (refused unless a PTY was allocated);
//   - "exec" with command "init": copy stdout.NewReader() to the channel and
//     the channel to stdinPipe;
//   - "exec" with command "bootstrap": extract a tar stream read from the
//     channel into /bootstrap/current.
//
// Every other request is refused. At most one "shell"/"exec" request
// succeeds per channel (RFC 4254 §6.5). Requests whose payload is truncated
// or malformed are refused instead of panicking, since the payload comes
// straight from the network. remoteAddr is only used in log messages.
func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *ssh.Request) {
	var (
		ptyF, ttyF *os.File
		termEnv    string
		started    bool // a shell or exec command is already attached to the channel
	)

	// Release the PTY pair when the request stream ends. If the shell
	// goroutine already closed them, Close just returns an error (ignored).
	defer func() {
		if ptyF != nil {
			ptyF.Close()
		}
		if ttyF != nil {
			ttyF.Close()
		}
	}()

	// closeCh closes the channel at most once, whichever copier ends first.
	var once sync.Once
	closeCh := func() { once.Do(func() { channel.Close() }) }

	for req := range requests {
		switch req.Type {
		case "exec":
			// Payload: string command (RFC 4254 §6.5).
			command, _, ok := sshParseString(req.Payload)
			if !ok || started {
				req.Reply(false, nil)
				continue
			}

			switch command {
			case "init":
				started = true
				go func() {
					io.Copy(channel, stdout.NewReader())
					closeCh()
				}()
				go func() {
					io.Copy(stdinPipe, channel)
					closeCh()
				}()
				req.Reply(true, nil)

			case "bootstrap":
				// extract a new bootstrap package
				if err := os.MkdirAll("/bootstrap/current", 0750); err != nil {
					log.Error().Err(err).Str("remote", remoteAddr).Msg("ssh: failed to create bootstrap directory")
					req.Reply(false, nil)
					continue
				}

				started = true
				cmd := exec.Command("/bin/tar", "xv", "-C", "/bootstrap/current")
				cmd.Stdin = channel
				cmd.Stdout = channel
				cmd.Stderr = channel.Stderr()
				go func() {
					if err := cmd.Run(); err != nil {
						log.Error().Err(err).Str("remote", remoteAddr).Msg("ssh: bootstrap extraction failed")
					}
					closeCh()
				}()
				req.Reply(true, nil)

			default:
				req.Reply(false, nil)
			}

		case "shell":
			// Only PTY-backed shells are supported: without a prior
			// "pty-req" there are no files to attach the shell to.
			if started || ptyF == nil {
				req.Reply(false, nil)
				continue
			}

			cmd := exec.Command("/bin/ash")
			cmd.Env = []string{"TERM=" + termEnv}
			cmd.Stdin, cmd.Stdout, cmd.Stderr = ttyF, ttyF, ttyF
			cmd.SysProcAttr = &syscall.SysProcAttr{
				Setctty:   true,
				Setsid:    true,
				Pdeathsig: syscall.SIGKILL,
			}

			if err := cmd.Start(); err != nil {
				log.Error().Err(err).Str("remote", remoteAddr).Msg("ssh: failed to start shell")
				req.Reply(false, nil)
				continue
			}
			started = true

			// The goroutines use local copies so they never write the
			// ptyF/ttyF variables read by this loop and the deferred cleanup.
			master, slave := ptyF, ttyF
			go func() {
				cmd.Wait()
				master.Close()
				slave.Close()
			}()
			go func() {
				io.Copy(channel, master)
				closeCh()
			}()
			go func() {
				io.Copy(master, channel)
				closeCh()
			}()
			req.Reply(true, nil)

		case "pty-req":
			if ptyF != nil || ttyF != nil {
				req.Reply(false, nil)
				continue
			}

			// Payload: string TERM, uint32 cols, uint32 rows, uint32 width px,
			// uint32 height px, string modes (RFC 4254 §6.2). Validate it
			// before allocating anything.
			term, rest, ok := sshParseString(req.Payload)
			if !ok || len(rest) < 8 {
				req.Reply(false, nil)
				continue
			}

			var err error
			ptyF, ttyF, err = pty.Open()
			if err != nil {
				log.Error().Err(err).Msg("ssh: PTY open failed")
				ptyF, ttyF = nil, nil
				req.Reply(false, nil)
				continue
			}

			termEnv = term
			w, h := sshParseDims(rest)
			sshResizePTY(ptyF, w, h)
			req.Reply(true, nil)

		case "window-change":
			// Payload: uint32 cols, uint32 rows, uint32 width px, uint32
			// height px (RFC 4254 §6.7). No response is sent.
			if ptyF != nil && len(req.Payload) >= 8 {
				w, h := sshParseDims(req.Payload)
				sshResizePTY(ptyF, w, h)
			}

		default:
			req.Reply(false, nil)
		}
	}
}

// sshParseString decodes an SSH wire-format string from the start of b: a
// big-endian uint32 length followed by that many bytes (RFC 4251 §5). It
// returns the string and the bytes that follow it. ok is false when b is too
// short to hold the declared length.
func sshParseString(b []byte) (s string, rest []byte, ok bool) {
	if len(b) < 4 {
		return "", nil, false
	}
	n := binary.BigEndian.Uint32(b)
	if uint64(n) > uint64(len(b)-4) {
		return "", nil, false
	}
	return string(b[4 : 4+n]), b[4+n:], true
}

// sshResizePTY applies a w×h (columns × rows) window size to the PTY master f.
// It goes through SyscallConn so that a reference on the descriptor is held
// for the duration of the ioctl. A concurrent Close (the shell goroutine
// closes the PTY when the shell exits) therefore cannot make the ioctl target
// a stale or reused descriptor. After Close, Control fails and nothing is
// done.
func sshResizePTY(f *os.File, w, h uint32) {
	rc, err := f.SyscallConn()
	if err != nil {
		return
	}
	rc.Control(func(fd uintptr) {
		sshSetWinsize(fd, w, h)
	})
}
// sshParseDims decodes a terminal width (columns) and height (rows) from the
// first 8 bytes of b. These are the two big-endian uint32 values that start
// the dimensions of the SSH "pty-req" and "window-change" requests (RFC 4254
// §6.2, §6.7). Any trailing bytes (e.g. pixel sizes) are ignored. When b is
// shorter than 8 bytes, which can happen with a malformed client payload,
// it returns 0, 0 instead of panicking.
func sshParseDims(b []byte) (uint32, uint32) {
	if len(b) < 8 {
		return 0, 0
	}
	w := binary.BigEndian.Uint32(b[0:4])
	h := binary.BigEndian.Uint32(b[4:8])
	return w, h
}
// sshSetWinsize sets the window size of the terminal referred to by fd to w
// columns by h rows, using the TIOCSWINSZ ioctl. Pixel sizes are set to 0.
// Both dimensions are truncated to 16 bits, and the ioctl result is ignored.
func sshSetWinsize(fd uintptr, w, h uint32) {
	// Same memory layout as the kernel's struct winsize:
	// ws_row, ws_col, ws_xpixel, ws_ypixel (all unsigned short).
	ws := [4]uint16{uint16(h), uint16(w), 0, 0}
	syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(&ws)))
}