initrd/main.go
2024-01-21 15:32:51 +01:00

246 lines
5.0 KiB
Go

package main
import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/pkg/term/termios"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/term"
"novit.tech/direktil/initrd/colorio"
"novit.tech/direktil/initrd/shio"
)
// VERSION is the current version of init.
const VERSION = "Direktil init v2.0"

// Mount flag sets built from syscall MS_* constants. By their names they are
// meant for the root, boot and layer mounts; their uses are outside this
// view — confirm.
const (
	// rootMountFlags carries no mount flags at all.
	rootMountFlags = 0
	// bootMountFlags is read-only with no setuid, no device nodes and no exec.
	bootMountFlags = syscall.MS_RDONLY | syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC
	// layerMountFlags is read-only.
	layerMountFlags = syscall.MS_RDONLY
)
// bootVersion holds the "version" boot parameter; runInit fills it in.
var bootVersion string

// stdin is the read end and stdinPipe the write end of an in-memory pipe.
// runInit copies os.Stdin into stdinPipe; die closes both.
var stdin, stdinPipe = newPipe()

// stdout is the shio buffer the logger writes into (via stderr); runInit
// relays its contents to os.Stdout.
var stdout = shio.New()

// stderr wraps stdout with colorio.Bold and is used as the log output once
// runInit redirects logging.
var stderr = colorio.NewWriter(colorio.Bold, stdout)
// newPipe creates a synchronous in-memory pipe (io.Pipe) and returns its
// ends as plain interfaces: the reader first, then the writer.
func newPipe() (io.ReadCloser, io.WriteCloser) {
	reader, writer := io.Pipe()
	return reader, writer
}
// main is the program entry point. It logs to os.Stderr, passes a newTTY
// wrapper for os.Stdin's file descriptor to registerInput (both defined
// elsewhere), and dispatches on the base name the binary was invoked as.
// Only "init" is supported; any other name is a fatal error.
func main() {
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
	registerInput(newTTY(os.Stdin.Fd()))

	if invokedAs := filepath.Base(os.Args[0]); invokedAs == "init" {
		runInit()
	} else {
		log.Fatal().Msgf("unknown sub-command: %q", invokedAs)
	}
}
func runInit() {
if len(os.Args) > 1 && os.Args[1] == "hello" {
fmt.Println("hello world!")
os.Exit(0)
}
runtime.LockOSThread()
// move log to shio
go io.Copy(os.Stdout, stdout.NewReader())
log.Logger = log.Output(zerolog.ConsoleWriter{Out: stderr})
// check the PID is 1
if pid := os.Getpid(); pid != 1 {
log.Fatal().Int("pid", pid).Msg("init must be PID 1")
}
// copy os.Stdin to my stdin pipe
go io.Copy(stdinPipe, os.Stdin)
log.Info().Msg("Welcome to " + VERSION)
// essential mounts
mount("none", "/proc", "proc", 0, "")
mount("none", "/sys", "sysfs", 0, "")
mount("none", "/dev", "devtmpfs", 0, "")
mount("none", "/dev/pts", "devpts", 0, "gid=5,mode=620")
// get the "boot version"
bootVersion = param("version", "current")
log.Info().Msgf("booting system %q", bootVersion)
os.Setenv("PATH", "/usr/bin:/bin:/usr/sbin:/sbin")
_, err := os.Stat("/config.yaml")
if err != nil {
log.Error().Err(err).Msg("config not found")
fatal()
}
bootV2()
}
// layersDir is the directory layerPath falls back to when a layer has no
// override.
var layersDir = "/boot/current/layers/"

// layersOverride maps a layer name to an explicit path that layerPath
// returns instead of the default location. It starts empty.
var layersOverride = make(map[string]string)
// layerPath returns the image path for the named layer. If the name has an
// entry in layersOverride, that path is returned. Otherwise the result is
// "<name>.fs" inside layersDir.
func layerPath(name string) string {
	path, overridden := layersOverride[name]
	if !overridden {
		path = filepath.Join(layersDir, name+".fs")
	}
	return path
}
func fatal(v ...interface{}) {
log.Error().Msg("*** FATAL ***")
log.Error().Msg(fmt.Sprint(v...))
die()
}
// fatalf reports an unrecoverable error: it logs a "*** FATAL ***" banner,
// then the message built from pattern and v (fmt.Sprintf rules), and enters
// die, which does not return.
func fatalf(pattern string, v ...interface{}) {
	logError := log.Error
	logError().Msg("*** FATAL ***")
	logError().Msgf(pattern, v...)
	die()
}
// die is the terminal failure handler reached through fatal/fatalf. It sends
// logging back to os.Stderr and closes the shio/pipe streams. It then loops
// forever, offering the operator a choice on the console:
//
//	o - sync and power off
//	r - sync and reboot
//	s - start an interactive shell (the first of bash, ash, sh or
//	    "busybox sh" found on PATH), only if localAuth() succeeds
//
// If reading the choice from os.Stdin fails, it logs the error, waits five
// seconds and reboots. die never returns.
func die() {
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})

	// die can be re-entered if something it calls ends up in fatal/fatalf
	// again. stdin is set to nil once the streams are closed, so it serves
	// as the "already torn down" marker. Calling Close on the nil interface
	// would panic and kill init.
	if stdin != nil {
		stdout.Close()
		stdin.Close()
		stdinPipe.Close()
		stdin = nil
	}

mainLoop:
	for {
		termios.Tcdrain(os.Stdin.Fd())
		termios.Tcdrain(os.Stdout.Fd())
		termios.Tcdrain(os.Stderr.Fd())

		fmt.Print("\nr to reboot, o to power off, s to get a shell: ")

		// TODO flush stdin (first char lost here?)
		deadline := time.Now().Add(time.Minute)
		os.Stdin.SetReadDeadline(deadline)
		// NOTE(review): the previous terminal state returned by MakeRaw is
		// discarded and never restored, so later output and any shell
		// started below run on a raw-mode terminal. Left unchanged because
		// localAuth (not visible here) may depend on it — confirm.
		term.MakeRaw(int(os.Stdin.Fd()))
		termios.Tcflush(os.Stdin.Fd(), termios.TCIFLUSH)

		b := make([]byte, 1)
		_, err := os.Stdin.Read(b)
		if err != nil {
			log.Error().Err(err).Msg("failed to read from stdin")
			time.Sleep(5 * time.Second)
			syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
		}

		fmt.Println()

		switch b[0] {
		case 'o':
			// Flush filesystem buffers with sync(2) directly. Exec'ing the
			// external "sync" command can fail (e.g. missing binary) and
			// would re-enter fatalf/die from inside this menu.
			syscall.Sync()
			syscall.Reboot(syscall.LINUX_REBOOT_CMD_POWER_OFF)

		case 'r':
			syscall.Sync()
			syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)

		case 's':
			for _, sh := range []string{"bash", "ash", "sh", "busybox"} {
				fullPath, err := exec.LookPath(sh)
				if err != nil {
					continue
				}

				var args []string
				if sh == "busybox" {
					args = append(args, "sh")
				}

				// Authentication failure sends the operator back to the menu.
				if !localAuth() {
					continue mainLoop
				}

				cmd := exec.Command(fullPath, args...)
				cmd.Env = os.Environ()
				cmd.Stdin = os.Stdin
				cmd.Stdout = os.Stdout
				cmd.Stderr = os.Stderr

				err = cmd.Run()
				if err != nil {
					fmt.Println("shell failed:", err)
				}

				continue mainLoop
			}

			log.Error().Msg("failed to find a shell!")

		default:
			log.Error().Msgf("unknown choice: %q", string(b))
		}
	}
}
// run executes cmd with args and waits for it to finish. If the command
// cannot be started or exits unsuccessfully, the error and the command's
// combined stdout/stderr output are reported through fatalf.
func run(cmd string, args ...string) {
	combined, err := exec.Command(cmd, args...).CombinedOutput()
	if err == nil {
		return
	}
	fatalf("command %s %q failed: %v\n%s", cmd, args, err, string(combined))
}
// mkdir creates dir and any missing parents with the given permission bits
// (os.MkdirAll). An already existing directory is not an error; any other
// failure is fatal.
func mkdir(dir string, mode os.FileMode) {
	err := os.MkdirAll(dir, mode)
	if err == nil {
		return
	}
	fatalf("mkdir %q failed: %v", dir, err)
}
// mount mounts source on target with the given filesystem type, mount(2)
// flags and option string. If target does not exist, it is first created
// (with parents, mode 0755). A failed mount is fatal; a successful one is
// logged.
func mount(source, target, fstype string, flags uintptr, data string) {
	_, statErr := os.Stat(target)
	if os.IsNotExist(statErr) {
		mkdir(target, 0o755)
	}

	mountErr := syscall.Mount(source, target, fstype, flags, data)
	if mountErr != nil {
		fatalf("mount %q %q -t %q -o %q failed: %v", source, target, fstype, data, mountErr)
	}

	log.Info().Str("target", target).Msg("mounted")
}
// cp copies the contents of the file at srcPath into dstPath, creating or
// truncating dstPath (os.Create). Any failure is fatal and reported through
// fatalf. This covers opening either file, copying, and the error returned
// when closing the destination.
func cp(srcPath, dstPath string) {
	copyErr := func() (err error) {
		src, err := os.Open(srcPath)
		if err != nil {
			return err
		}
		defer src.Close()

		dst, err := os.Create(dstPath)
		if err != nil {
			return err
		}
		// Close on a file that was written to can report errors from earlier
		// writes. Keep the first error, but do not drop the close error.
		defer func() {
			if closeErr := dst.Close(); err == nil {
				err = closeErr
			}
		}()

		_, err = io.Copy(dst, src)
		return err
	}()

	if copyErr != nil {
		fatalf("cp %s %s failed: %v", srcPath, dstPath, copyErr)
	}
}