260 lines
5.2 KiB
Go
260 lines
5.2 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
"syscall"
|
|
"unsafe"
|
|
|
|
"github.com/kr/pty"
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
config "novit.tech/direktil/pkg/bootstrapconfig"
|
|
)
|
|
|
|
func startSSH(cfg *config.Config) {
|
|
sshConfig := &ssh.ServerConfig{
|
|
PublicKeyCallback: sshCheckPubkey,
|
|
}
|
|
|
|
hostKeyLoaded := false
|
|
|
|
for _, format := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
|
|
log := log.With().Str("format", format).Logger()
|
|
|
|
pkBytes, err := os.ReadFile("/id_" + format)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("ssh: failed to load host key")
|
|
continue
|
|
}
|
|
|
|
pk, err := ssh.ParsePrivateKey(pkBytes)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("ssh: failed to parse host key")
|
|
continue
|
|
}
|
|
|
|
sshConfig.AddHostKey(pk)
|
|
hostKeyLoaded = true
|
|
|
|
log.Info().Msg("ssh: loaded host key")
|
|
}
|
|
|
|
if !hostKeyLoaded {
|
|
fatalf("ssh: failed to load any host key")
|
|
}
|
|
|
|
sshBind := ":22" // TODO configurable
|
|
listener, err := net.Listen("tcp", sshBind)
|
|
if err != nil {
|
|
fatalf("ssh: failed to listen on %s: %v", sshBind, err)
|
|
}
|
|
|
|
log.Info().Str("bind-address", sshBind).Msg("SSH server listening")
|
|
|
|
go func() {
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Info().Err(err).Msg("ssh: accept conn failed")
|
|
continue
|
|
}
|
|
|
|
go sshHandleConn(conn, sshConfig)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func sshHandleConn(conn net.Conn, sshConfig *ssh.ServerConfig) {
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig)
|
|
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("ssh: handshake failed")
|
|
return
|
|
}
|
|
|
|
remoteAddr := sshConn.User() + "@" + sshConn.RemoteAddr().String()
|
|
log.Info().Str("remote", remoteAddr).Msg("ssh: new connection")
|
|
|
|
go sshHandleReqs(reqs)
|
|
go sshHandleChannels(remoteAddr, chans)
|
|
}
|
|
|
|
func sshHandleReqs(reqs <-chan *ssh.Request) {
|
|
for req := range reqs {
|
|
switch req.Type {
|
|
case "keepalive@openssh.com":
|
|
req.Reply(true, nil)
|
|
|
|
default:
|
|
log.Info().Str("type", req.Type).Msg("ssh: discarding request")
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|
|
|
|
func sshHandleChannels(remoteAddr string, chans <-chan ssh.NewChannel) {
|
|
for newChannel := range chans {
|
|
if t := newChannel.ChannelType(); t != "session" {
|
|
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
|
continue
|
|
}
|
|
|
|
channel, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("ssh: failed to accept channel")
|
|
continue
|
|
}
|
|
|
|
go sshHandleChannel(remoteAddr, channel, requests)
|
|
}
|
|
}
|
|
|
|
func sshHandleChannel(remoteAddr string, channel ssh.Channel, requests <-chan *ssh.Request) {
|
|
var (
|
|
ptyF, ttyF *os.File
|
|
termEnv string
|
|
)
|
|
|
|
defer func() {
|
|
if ptyF != nil {
|
|
ptyF.Close()
|
|
}
|
|
if ttyF != nil {
|
|
ttyF.Close()
|
|
}
|
|
}()
|
|
|
|
var once sync.Once
|
|
closeCh := func() {
|
|
channel.Close()
|
|
}
|
|
|
|
for req := range requests {
|
|
switch req.Type {
|
|
case "exec":
|
|
command := string(req.Payload[4 : req.Payload[3]+4])
|
|
switch command {
|
|
case "init":
|
|
go func() {
|
|
io.Copy(channel, stdout.NewReader())
|
|
once.Do(closeCh)
|
|
}()
|
|
go func() {
|
|
io.Copy(stdinPipe, channel)
|
|
once.Do(closeCh)
|
|
}()
|
|
|
|
req.Reply(true, nil)
|
|
|
|
case "bootstrap":
|
|
// extract a new bootstrap package
|
|
os.MkdirAll("/bootstrap/current", 0750)
|
|
|
|
cmd := exec.Command("/bin/tar", "xv", "-C", "/bootstrap/current")
|
|
cmd.Stdin = channel
|
|
cmd.Stdout = channel
|
|
cmd.Stderr = channel.Stderr()
|
|
|
|
go func() {
|
|
cmd.Run()
|
|
closeCh()
|
|
}()
|
|
|
|
req.Reply(true, nil)
|
|
|
|
default:
|
|
req.Reply(false, nil)
|
|
}
|
|
|
|
case "shell":
|
|
cmd := exec.Command("/bin/ash")
|
|
cmd.Env = []string{"TERM=" + termEnv}
|
|
cmd.Stdin = ttyF
|
|
cmd.Stdout = ttyF
|
|
cmd.Stderr = ttyF
|
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
|
Setctty: true,
|
|
Setsid: true,
|
|
Pdeathsig: syscall.SIGKILL,
|
|
}
|
|
|
|
cmd.Start()
|
|
|
|
go func() {
|
|
cmd.Wait()
|
|
ptyF.Close()
|
|
ptyF = nil
|
|
ttyF.Close()
|
|
ttyF = nil
|
|
}()
|
|
|
|
go func() {
|
|
io.Copy(channel, ptyF)
|
|
once.Do(closeCh)
|
|
}()
|
|
go func() {
|
|
io.Copy(ptyF, channel)
|
|
once.Do(closeCh)
|
|
}()
|
|
|
|
req.Reply(true, nil)
|
|
|
|
case "pty-req":
|
|
if ptyF != nil || ttyF != nil {
|
|
req.Reply(false, nil)
|
|
continue
|
|
}
|
|
|
|
var err error
|
|
ptyF, ttyF, err = pty.Open()
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("ssh: PTY open failed")
|
|
req.Reply(false, nil)
|
|
continue
|
|
}
|
|
|
|
termLen := req.Payload[3]
|
|
termEnv = string(req.Payload[4 : termLen+4])
|
|
w, h := sshParseDims(req.Payload[termLen+4:])
|
|
sshSetWinsize(ptyF.Fd(), w, h)
|
|
|
|
req.Reply(true, nil)
|
|
|
|
case "window-change":
|
|
w, h := sshParseDims(req.Payload)
|
|
sshSetWinsize(ptyF.Fd(), w, h)
|
|
// no response
|
|
|
|
default:
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|
|
|
|
// sshParseDims decodes a terminal size from the first eight bytes of b: two
// big-endian uint32 values, width first and height second. Bytes after the
// first eight are ignored.
//
// b comes from client-supplied request payloads, so a buffer shorter than
// eight bytes yields (0, 0) instead of panicking.
func sshParseDims(b []byte) (uint32, uint32) {
	if len(b) < 8 {
		return 0, 0
	}

	w := binary.BigEndian.Uint32(b[0:4])
	h := binary.BigEndian.Uint32(b[4:8])
	return w, h
}
|
|
|
|
// sshSetWinsize issues a TIOCSWINSZ ioctl on fd to set the terminal size to
// w columns by h rows. Both values are truncated to 16 bits, and the result
// of the ioctl is discarded.
func sshSetWinsize(fd uintptr, w, h uint32) {
	// winsize is the argument handed to the ioctl: height, width, then two
	// fields that are left at zero.
	type winsize struct {
		rows uint16
		cols uint16
		x    uint16 // unused
		y    uint16 // unused
	}

	dims := winsize{
		rows: uint16(h),
		cols: uint16(w),
	}

	syscall.Syscall(
		syscall.SYS_IOCTL,
		fd,
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&dims)),
	)
}
|