package main import ( "encoding/binary" "fmt" "io" "log" "net" "os" "os/exec" "sync" "syscall" "unsafe" "github.com/kr/pty" "golang.org/x/crypto/ssh" config "novit.tech/direktil/pkg/bootstrapconfig" ) func startSSH(cfg *config.Config) { sshConfig := &ssh.ServerConfig{ PublicKeyCallback: sshCheckPubkey, } pkBytes, err := os.ReadFile("/id_rsa") // TODO configurable if err != nil { fatalf("ssh: failed to load private key: %v", err) } pk, err := ssh.ParsePrivateKey(pkBytes) if err != nil { fatalf("ssh: failed to parse private key: %v", err) } sshConfig.AddHostKey(pk) sshBind := ":22" // TODO configurable listener, err := net.Listen("tcp", sshBind) if err != nil { fatalf("ssh: failed to listen on %s: %v", sshBind, err) } log.Print("SSH server listening on ", sshBind) go func() { for { conn, err := listener.Accept() if err != nil { log.Print("ssh: accept conn failed: ", err) continue } go sshHandleConn(conn, sshConfig) } }() } func sshHandleConn(conn net.Conn, sshConfig *ssh.ServerConfig) { sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig) if err != nil { log.Print("ssh: handshake failed: ", err) return } remoteAddr := sshConn.User() + "@" + sshConn.RemoteAddr().String() log.Print("ssh: new connection from ", remoteAddr) go sshHandleReqs(reqs) go sshHandleChannels(remoteAddr, chans) } func sshHandleReqs(reqs <-chan *ssh.Request) { for req := range reqs { switch req.Type { case "keepalive@openssh.com": req.Reply(true, nil) default: log.Printf("ssh: discarding req: %+v", req) req.Reply(false, nil) } } } func sshHandleChannels(remoteAddr string, chans <-chan ssh.NewChannel) { for newChannel := range chans { if t := newChannel.ChannelType(); t != "session" { newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) continue } channel, requests, err := newChannel.Accept() if err != nil { log.Print("ssh: failed to accept channel: ", err) continue } go sshHandleChannel(remoteAddr, channel, requests) } } func sshHandleChannel(remoteAddr 
string, channel ssh.Channel, requests <-chan *ssh.Request) { var ( ptyF, ttyF *os.File termEnv string ) defer func() { if ptyF != nil { ptyF.Close() } if ttyF != nil { ttyF.Close() } }() var once sync.Once closeCh := func() { channel.Close() } for req := range requests { switch req.Type { case "exec": command := string(req.Payload[4 : req.Payload[3]+4]) switch command { case "init": go func() { io.Copy(channel, stdout.NewReader()) once.Do(closeCh) }() go func() { io.Copy(stdinPipe, channel) once.Do(closeCh) }() req.Reply(true, nil) case "bootstrap": // extract a new bootstrap package os.MkdirAll("/bootstrap/current", 0750) cmd := exec.Command("/bin/tar", "xv", "-C", "/bootstrap/current") cmd.Stdin = channel cmd.Stdout = channel cmd.Stderr = channel.Stderr() go func() { cmd.Run() closeCh() }() req.Reply(true, nil) default: req.Reply(false, nil) } case "shell": cmd := exec.Command("/bin/ash") cmd.Env = []string{"TERM=" + termEnv} cmd.Stdin = ttyF cmd.Stdout = ttyF cmd.Stderr = ttyF cmd.SysProcAttr = &syscall.SysProcAttr{ Setctty: true, Setsid: true, Pdeathsig: syscall.SIGKILL, } cmd.Start() go func() { cmd.Wait() ptyF.Close() ptyF = nil ttyF.Close() ttyF = nil }() go func() { io.Copy(channel, ptyF) once.Do(closeCh) }() go func() { io.Copy(ptyF, channel) once.Do(closeCh) }() req.Reply(true, nil) case "pty-req": if ptyF != nil || ttyF != nil { req.Reply(false, nil) continue } var err error ptyF, ttyF, err = pty.Open() if err != nil { log.Print("PTY err: ", err) req.Reply(false, nil) continue } termLen := req.Payload[3] termEnv = string(req.Payload[4 : termLen+4]) w, h := sshParseDims(req.Payload[termLen+4:]) sshSetWinsize(ptyF.Fd(), w, h) req.Reply(true, nil) case "window-change": w, h := sshParseDims(req.Payload) sshSetWinsize(ptyF.Fd(), w, h) // no response default: req.Reply(false, nil) } } } func sshParseDims(b []byte) (uint32, uint32) { w := binary.BigEndian.Uint32(b) h := binary.BigEndian.Uint32(b[4:]) return w, h } // SetWinsize sets the size of the given pty. 
func sshSetWinsize(fd uintptr, w, h uint32) {
	// winsize is the argument structure of the TIOCSWINSZ ioctl: rows,
	// columns, then two pixel-size fields that are left at zero here.
	type winsize struct {
		Height uint16
		Width  uint16
		XPixel uint16 // unused
		YPixel uint16 // unused
	}

	// The kernel fields are 16 bits wide; saturate instead of letting an
	// oversized request wrap around to a tiny (or zero) size.
	clamp := func(v uint32) uint16 {
		const maxDim = 0xFFFF
		if v > maxDim {
			return maxDim
		}
		return uint16(v)
	}

	ws := &winsize{Width: clamp(w), Height: clamp(h)}

	// The unsafe.Pointer conversion must stay inside the call expression so
	// that ws is kept alive for the duration of the syscall.
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(ws)))
	if errno != 0 {
		// Resizing is best effort: report the failure and carry on.
		log.Printf("ssh: failed to set window size to %dx%d: %v", w, h, errno)
	}
}