initrd/main.go

272 lines
5.7 KiB
Go
Raw Normal View History

2019-02-11 05:05:43 +00:00
package main
import (
"fmt"
2020-02-28 00:30:10 +00:00
"io"
2019-02-11 05:05:43 +00:00
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
2019-03-07 00:35:11 +00:00
"runtime"
2019-02-11 05:05:43 +00:00
"strings"
"syscall"
"time"
2020-11-19 20:22:17 +00:00
"golang.org/x/term"
2019-02-11 05:05:43 +00:00
yaml "gopkg.in/yaml.v2"
"novit.nc/direktil/pkg/sysfs"
)
// VERSION is the current version of init
const VERSION = "Direktil init v1.0"

// Mount flags used while assembling the system.
const (
	// rootMountFlags: the overlay root (/system) is mounted without extra flags.
	rootMountFlags = 0

	// bootMountFlags: /boot is read-only and allows no executables, device
	// nodes or setuid/setgid bits.
	bootMountFlags = syscall.MS_RDONLY | syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC

	// layerMountFlags: every squashfs layer is mounted read-only.
	layerMountFlags = syscall.MS_RDONLY
)

// bootVersion names the /boot/<bootVersion> directory the layers are read
// from; main sets it from param("version", "current").
var bootVersion string
func main() {
2019-03-07 00:35:11 +00:00
runtime.LockOSThread()
2019-02-11 05:05:43 +00:00
log.Print("Welcome to ", VERSION)
// essential mounts
mount("none", "/proc", "proc", 0, "")
mount("none", "/sys", "sysfs", 0, "")
mount("none", "/dev", "devtmpfs", 0, "")
// get the "boot version"
bootVersion = param("version", "current")
log.Printf("booting system %q", bootVersion)
// find and mount /boot
bootMatch := param("boot", "")
2020-02-28 00:30:10 +00:00
bootMounted := false
2019-02-11 05:05:43 +00:00
if bootMatch != "" {
bootFS := param("boot.fs", "vfat")
for i := 0; ; i++ {
devNames := sysfs.DeviceByProperty("block", bootMatch)
if len(devNames) == 0 {
if i > 30 {
fatal("boot partition not found after 30s")
}
log.Print("boot partition not found, retrying")
time.Sleep(1 * time.Second)
continue
}
devFile := filepath.Join("/dev", devNames[0])
log.Print("boot partition found: ", devFile)
mount(devFile, "/boot", bootFS, bootMountFlags, "")
2020-02-28 00:30:10 +00:00
bootMounted = true
2019-02-11 05:05:43 +00:00
break
}
} else {
log.Print("Assuming /boot is already populated.")
}
// load config
cfgPath := param("config", "/boot/config.yaml")
cfgBytes, err := ioutil.ReadFile(cfgPath)
if err != nil {
fatalf("failed to read %s: %v", cfgPath, err)
}
cfg := &config{}
if err := yaml.Unmarshal(cfgBytes, cfg); err != nil {
fatal("failed to load config: ", err)
}
// mount layers
if len(cfg.Layers) == 0 {
fatal("no layers configured!")
}
log.Printf("wanted layers: %q", cfg.Layers)
2020-02-28 00:30:10 +00:00
layersInMemory := paramBool("layers-in-mem", false)
const layersInMemDir = "/layers-in-mem"
if layersInMemory {
mkdir(layersInMemDir, 0700)
mount("layers-mem", layersInMemDir, "tmpfs", 0, "")
}
2019-02-11 05:05:43 +00:00
lowers := make([]string, len(cfg.Layers))
for i, layer := range cfg.Layers {
path := layerPath(layer)
info, err := os.Stat(path)
if err != nil {
fatal(err)
}
log.Printf("layer %s found (%d bytes)", layer, info.Size())
2020-02-28 00:30:10 +00:00
if layersInMemory {
log.Print(" copying to memory...")
targetPath := filepath.Join(layersInMemDir, layer)
cp(path, targetPath)
path = targetPath
}
2019-02-11 05:05:43 +00:00
dir := "/layers/" + layer
lowers[i] = dir
loopDev := fmt.Sprintf("/dev/loop%d", i)
losetup(loopDev, path)
mount(loopDev, dir, "squashfs", layerMountFlags, "")
}
// prepare system root
mount("mem", "/changes", "tmpfs", 0, "")
mkdir("/changes/workdir", 0755)
mkdir("/changes/upperdir", 0755)
mount("overlay", "/system", "overlay", rootMountFlags,
"lowerdir="+strings.Join(lowers, ":")+",upperdir=/changes/upperdir,workdir=/changes/workdir")
2020-02-28 00:30:10 +00:00
2020-11-19 16:04:45 +00:00
if bootMounted {
if layersInMemory {
if err := syscall.Unmount("/boot", 0); err != nil {
log.Print("WARNING: failed to unmount /boot: ", err)
time.Sleep(2 * time.Second)
}
} else {
mount("/boot", "/system/boot", "", syscall.MS_BIND, "")
2020-02-28 00:30:10 +00:00
}
}
2019-02-11 05:05:43 +00:00
// - write configuration
log.Print("writing /config.yaml")
if err := ioutil.WriteFile("/system/config.yaml", cfgBytes, 0600); err != nil {
fatal("failed: ", err)
}
2019-12-03 09:56:57 +00:00
// - write files
for _, fileDef := range cfg.Files {
log.Print("writing ", fileDef.Path)
filePath := filepath.Join("/system", fileDef.Path)
ioutil.WriteFile(filePath, []byte(fileDef.Content), fileDef.Mode)
}
2019-02-11 05:05:43 +00:00
// clean zombies
cleanZombies()
// switch root
log.Print("switching root")
2019-03-07 00:35:11 +00:00
err = syscall.Exec("/sbin/switch_root", []string{"switch_root",
"-c", "/dev/console", "/system", "/sbin/init"}, os.Environ())
fatal("switch_root failed: ", err)
2019-02-11 05:05:43 +00:00
}
func layerPath(name string) string {
return fmt.Sprintf("/boot/%s/layers/%s.fs", bootVersion, name)
}
func fatal(v ...interface{}) {
log.Print("*** FATAL ***")
log.Print(v...)
2020-11-19 20:22:17 +00:00
die()
2019-02-11 05:05:43 +00:00
}
func fatalf(pattern string, v ...interface{}) {
log.Print("*** FATAL ***")
log.Printf(pattern, v...)
2020-11-19 20:22:17 +00:00
die()
}
func die() {
fmt.Println("\nwill reboot in 1 minute; press r to reboot now, o to power off")
deadline := time.Now().Add(time.Minute)
term.MakeRaw(int(os.Stdin.Fd())) // disable line buffering
os.Stdin.SetReadDeadline(deadline)
b := []byte{0}
for {
_, err := os.Stdin.Read(b)
if err != nil {
break
}
switch b[0] {
case 'o':
syscall.Reboot(syscall.LINUX_REBOOT_CMD_POWER_OFF)
case 'r':
syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
}
}
syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
2019-02-11 05:05:43 +00:00
}
// losetup attaches file to the loop device dev by running /sbin/losetup.
// A failure is fatal (see run).
func losetup(dev, file string) {
	args := []string{dev, file}
	run("/sbin/losetup", args...)
}
// run executes cmd with args, capturing stdout and stderr together. If the
// command cannot be started or exits unsuccessfully, the error and the
// captured output are reported through fatalf.
func run(cmd string, args ...string) {
	command := exec.Command(cmd, args...)

	output, err := command.CombinedOutput()
	if err == nil {
		return
	}

	fatalf("command %s %q failed: %v\n%s", cmd, args, err, string(output))
}
// mkdir creates dir together with any missing parents, using mode for the
// directories it creates (os.MkdirAll semantics). A failure is fatal.
func mkdir(dir string, mode os.FileMode) {
	err := os.MkdirAll(dir, mode)
	if err == nil {
		return
	}

	fatalf("mkdir %q failed: %v", dir, err)
}
// mount creates target (mode 0755) when it does not exist yet, then mounts
// source on it with the given filesystem type, flags and data string via
// syscall.Mount. A mount failure is fatal; the mounted target is logged.
func mount(source, target, fstype string, flags uintptr, data string) {
	_, statErr := os.Stat(target)
	if os.IsNotExist(statErr) {
		mkdir(target, 0755)
	}

	mountErr := syscall.Mount(source, target, fstype, flags, data)
	if mountErr != nil {
		fatalf("mount %q %q -t %q -o %q failed: %v", source, target, fstype, data, mountErr)
	}

	log.Printf("mounted %q", target)
}
2020-02-28 00:30:10 +00:00
// cp copies the contents of srcPath into dstPath, creating or truncating
// dstPath. Any failure is fatal (via fatalf). This includes an error
// reported when closing the destination, where delayed write errors
// surface. fatalf runs only after both files have been closed.
func cp(srcPath, dstPath string) {
	copyErr := func() error {
		src, err := os.Open(srcPath)
		if err != nil {
			return err
		}
		defer src.Close()

		dst, err := os.Create(dstPath)
		if err != nil {
			return err
		}

		if _, err := io.Copy(dst, src); err != nil {
			dst.Close() // the copy error is the one worth reporting
			return err
		}

		// Do not discard this error: it may be the only report of a failed
		// write.
		return dst.Close()
	}()

	if copyErr != nil {
		fatalf("cp %s %s failed: %v", srcPath, dstPath, copyErr)
	}
}