package main import ( "bytes" "io/fs" "os" "os/exec" "path/filepath" "sort" "strconv" "syscall" udev "github.com/jochenvg/go-udev" "github.com/rs/zerolog/log" config "novit.tech/direktil/pkg/bootstrapconfig" "novit.tech/direktil/initrd/lvm" ) func sortedKeys[T any](m map[string]T) (keys []string) { keys = make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Strings(keys) return } func setupLVM(cfg *config.Config) { if len(cfg.LVM) == 0 { log.Info().Msg("no LVM VG configured.") return } // [dev] = filesystem // eg: [/dev/sda1] = ext4 createdDevs := map[string]string{} run("pvscan") run("vgscan", "--mknodes") for _, vg := range cfg.LVM { setupVG(vg) } for _, vg := range cfg.LVM { setupLVs(vg, createdDevs) } run("vgchange", "--sysinit", "-a", "ly") setupCrypt(cfg.Crypt, createdDevs) devs := make([]string, 0, len(createdDevs)) for k := range createdDevs { devs = append(devs, k) } sort.Strings(devs) for _, dev := range devs { setupFS(dev, createdDevs[dev]) } } func setupVG(vg config.LvmVG) { pvs := lvm.PVSReport{} err := runJSON(&pvs, "pvs", "--reportformat", "json") if err != nil { fatalf("failed to list LVM PVs: %v", err) } vgExists := false devNeeded := vg.PVs.N for _, pv := range pvs.PVs() { if pv.VGName == vg.VG { vgExists = true devNeeded-- } } log := log.With().Str("vg", vg.VG).Logger() if devNeeded <= 0 { log.Info().Msg("LVM VG has all its devices") return } if vgExists { log.Info().Msgf("LVM VG misses %d devices", devNeeded) } else { log.Info().Msg("LVM VG does not exists, creating") } devNames := make([]NameAliases, 0) { devRefs := map[uint64]*NameAliases{} enum := new(udev.Udev).NewEnumerate() enum.AddMatchSubsystem("block") devs, err := enum.Devices() if err != nil { fatal("udev enumeration failed") } for _, dev := range devs { num := dev.Devnum() n := dev.PropertyValue("DEVNAME") idx := len(devNames) devNames = append(devNames, nameAlias(n)) ref := uint64(num.Major())<<8 | uint64(num.Minor()) devRefs[ref] = &devNames[idx] } err = filepath.Walk("/dev", func(n string, fi fs.FileInfo, err error) error { if fi.Mode().Type() == os.ModeDevice { stat := fi.Sys().(*syscall.Stat_t) ref := stat.Rdev if na := devRefs[ref]; na != nil { na.AddAlias(n) } } return err }) if err != nil { fatalf("failed to walk /dev: %v", err) } } for _, dev := range devNames { log.Info().Str("name", dev.Name).Any("aliases", dev.Aliases).Msg("found block device") } m := regexpSelectN(vg.PVs.N, vg.PVs.Regexps, devNames) if len(m) == 0 { log.Error().Strs("regexps", vg.PVs.Regexps).Msg("no device match the regexps") fatalf("failed to setup VG %s", vg.VG) } if vgExists { log.Info().Strs("devices", m).Msg("LVM VG: extending") run("vgextend", append([]string{vg.VG}, m...)...) devNeeded -= len(m) } else { log.Info().Strs("devices", m).Msg("LVM VG: creating") run("vgcreate", append([]string{vg.VG}, m...)...) devNeeded -= len(m) } if devNeeded > 0 { fatalf("VG %s does not have enough devices (%d missing)", vg.VG, devNeeded) } } func setupLVs(vg config.LvmVG, createdDevs map[string]string) { lvsRep := lvm.LVSReport{} err := runJSON(&lvsRep, "lvs", "--reportformat", "json") if err != nil { fatalf("lvs failed: %v", err) } lvs := lvsRep.LVs() defaults := vg.Defaults for idx, lv := range vg.LVs { log := log.With().Str("vg", vg.VG).Str("lv", lv.Name).Logger() if contains(lvs, func(v lvm.LV) bool { return v.VGName == vg.VG && v.Name == lv.Name }) { log.Info().Msg("LV exists") continue } log.Info().Msg("LV does not exist") if lv.Raid == nil { lv.Raid = defaults.Raid } args := make([]string, 0) if lv.Name == "" { fatalf("LV[%d] has no name", idx) } args = append(args, vg.VG, "--name", lv.Name) if lv.Size != "" && lv.Extents != "" { fatalf("LV has both size and extents defined!") } else if lv.Size == "" && lv.Extents == "" { fatalf("LV does not have size or extents defined!") } else if lv.Size != "" { args = append(args, "-L", lv.Size) } else /* if lv.Extents != "" */ { args = append(args, "-l", lv.Extents) } if raid := lv.Raid; raid != nil { if raid.Mirrors != 0 { args = append(args, "--mirrors", strconv.Itoa(raid.Mirrors)) } if raid.Stripes != 0 { args = append(args, "--stripes", strconv.Itoa(raid.Stripes)) } } log.Info().Strs("args", args).Msg("LV: creating") run("lvcreate", args...) dev := "/dev/" + vg.VG + "/" + lv.Name zeroDevStart(dev) fs := lv.FS if fs == "" { fs = vg.Defaults.FS } createdDevs[dev] = fs } } func zeroDevStart(dev string) { f, err := os.OpenFile(dev, os.O_WRONLY, 0600) if err != nil { fatalf("failed to open %s: %v", dev, err) } defer f.Close() _, err = f.Write(make([]byte, 8192)) if err != nil { fatalf("failed to zero the beginning of %s: %v", dev, err) } } var cryptDevs = map[string]bool{} func setupCrypt(devSpecs []config.CryptDev, createdDevs map[string]string) { var password []byte passwordVerified := false // flat, expanded devices to open devNames := make([]config.CryptDev, 0, len(devSpecs)) for _, devSpec := range devSpecs { if devSpec.Dev == "" && devSpec.Prefix == "" { fatalf("crypt: name %q: no dev or match set", devSpec.Name) } if devSpec.Dev != "" && devSpec.Prefix != "" { fatalf("crypt: name %q: both dev (%q) and match (%q) are set", devSpec.Name, devSpec.Dev, devSpec.Prefix) } if devSpec.Dev != "" { // already flat devNames = append(devNames, devSpec) continue } matches, err := filepath.Glob(devSpec.Prefix + "*") if err != nil { fatalf("failed to search for device matches: %v", err) } for _, m := range matches { suffix := m[len(devSpec.Prefix):] devNames = append(devNames, config.CryptDev{Dev: m, Name: devSpec.Name + suffix}) } } for _, devName := range devNames { name, dev := devName.Name, devName.Dev if name == "" { name = filepath.Base(dev) } if cryptDevs[name] { fatalf("duplicate crypt device name: %s", name) } cryptDevs[name] = true retryOpen: if len(password) == 0 { password = askSecret("crypt password") if len(password) == 0 { fatalf("empty password given") } } fs := createdDevs[dev] delete(createdDevs, dev) tgtDev := "/dev/mapper/" + name needFormat := !devInitialized(dev) if needFormat { if !passwordVerified { retry: p2 := askSecret("verify crypt password") eq := bytes.Equal(password, p2) for i := range p2 { p2[i] = 0 } if !eq { log.Error().Msg("passwords don't match") goto retry } } log.Info().Str("dev", dev).Msg("formatting encrypted device") cmd := exec.Command("cryptsetup", "luksFormat", dev, "--key-file=-") cmd.Stdin = bytes.NewBuffer(password) cmd.Stdout = stdout cmd.Stderr = stderr err := cmd.Run() if err != nil { fatalf("failed luksFormat: %v", err) } createdDevs[tgtDev] = fs } if len(password) == 0 { password = askSecret("crypt password") if len(password) == 0 { fatalf("empty password given") } } log.Info().Str("name", name).Str("dev", dev).Msg("openning encrypted device") cmd := exec.Command("cryptsetup", "open", dev, name, "--key-file=-") cmd.Stdin = bytes.NewBuffer(password) cmd.Stdout = stdout cmd.Stderr = stderr err := cmd.Run() if err != nil { // maybe the password is wrong for i := range password { password[i] = 0 } password = password[:0] passwordVerified = false goto retryOpen } if needFormat { zeroDevStart(tgtDev) } passwordVerified = true } for i := range password { password[i] = 0 } } func devInitialized(dev string) bool { f, err := os.Open(dev) if err != nil { fatalf("failed to open %s: %v", dev, err) } defer f.Close() ba := make([]byte, 8192) _, err = f.Read(ba) if err != nil { fatalf("failed to read %s: %v", dev, err) } for _, b := range ba { if b != 0 { return true } } return false } func setupFS(dev, fs string) { if devInitialized(dev) { log.Info().Str("dev", dev).Msg("device already formatted") return } if fs == "" { fs = "ext4" } log.Info().Str("dev", dev).Str("fs", fs).Msg("formatting device") args := make([]string, 0) switch fs { case "btrfs": args = append(args, "-f") case "ext4": args = append(args, "-F") } run("mkfs."+fs, append(args, dev)...) }