initrd/lvm.go
2024-11-04 17:33:23 +01:00

411 lines
8.3 KiB
Go

package main
import (
	"bytes"
	"io"
	"io/fs"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strconv"
	"syscall"

	udev "github.com/jochenvg/go-udev"
	"github.com/rs/zerolog/log"

	config "novit.tech/direktil/pkg/bootstrapconfig"
	"novit.tech/direktil/initrd/lvm"
)
// sortedKeys returns the keys of m in ascending lexical order.
// The returned slice is never nil, even for an empty or nil map.
func sortedKeys[T any](m map[string]T) []string {
	out := make([]string, len(m))
	i := 0
	for k := range m {
		out[i] = k
		i++
	}
	sort.Strings(out)
	return out
}
// setupLVM brings up the LVM layout described by cfg.LVM.
//
// It does nothing when no VG is configured. Otherwise it does the following, in order:
//   - scans PVs and VGs;
//   - makes sure every VG exists with enough PVs;
//   - creates the missing LVs;
//   - activates the VGs;
//   - sets up the crypt devices of cfg.Crypt;
//   - creates filesystems on every newly created device, in sorted device order.
func setupLVM(cfg *config.Config) {
	if len(cfg.LVM) == 0 {
		log.Info().Msg("no LVM VG configured.")
		return
	}

	// [dev] = filesystem
	// eg: [/dev/sda1] = ext4
	createdDevs := map[string]string{}

	run("pvscan")
	run("vgscan", "--mknodes")

	// Every VG is set up before any LV is created.
	for _, vg := range cfg.LVM {
		setupVG(vg)
	}
	for _, vg := range cfg.LVM {
		setupLVs(vg, createdDevs)
	}

	run("vgchange", "--sysinit", "-a", "ly")

	setupCrypt(cfg.Crypt, createdDevs)

	// Sorted iteration keeps the formatting order deterministic.
	for _, dev := range sortedKeys(createdDevs) {
		setupFS(dev, createdDevs[dev])
	}
}
// setupVG makes sure the LVM volume group vg.VG exists and has vg.PVs.N
// physical volumes.
//
// It counts the PVs that `pvs` reports for the VG. When some are missing, it
// selects block devices matching vg.PVs.Regexps and either creates the VG
// (vgcreate) or extends it (vgextend). It is fatal when no device matches,
// or when the VG still lacks devices afterwards.
func setupVG(vg config.LvmVG) {
	pvs := lvm.PVSReport{}
	if err := runJSON(&pvs, "pvs", "--reportformat", "json"); err != nil {
		fatalf("failed to list LVM PVs: %v", err)
	}

	vgExists := false
	devNeeded := vg.PVs.N
	for _, pv := range pvs.PVs() {
		if pv.VGName == vg.VG {
			vgExists = true
			devNeeded--
		}
	}

	log := log.With().Str("vg", vg.VG).Logger()

	if devNeeded <= 0 {
		log.Info().Msg("LVM VG has all its devices")
		return
	}

	if vgExists {
		log.Info().Msgf("LVM VG misses %d devices", devNeeded)
	} else {
		log.Info().Msg("LVM VG does not exists, creating")
	}

	devNames := enumerateBlockDevices()
	for _, dev := range devNames {
		log.Info().Str("name", dev.Name).Any("aliases", dev.Aliases).Msg("found block device")
	}

	// NOTE(review): vg.PVs.N (not devNeeded) is passed here, and devices that
	// are already PVs of this VG are not excluded. When extending, confirm
	// regexpSelectN cannot pick a device that is already in the VG.
	m := regexpSelectN(vg.PVs.N, vg.PVs.Regexps, devNames)
	if len(m) == 0 {
		log.Error().Strs("regexps", vg.PVs.Regexps).Msg("no device match the regexps")
		fatalf("failed to setup VG %s", vg.VG)
	}

	if vgExists {
		log.Info().Strs("devices", m).Msg("LVM VG: extending")
		run("vgextend", append([]string{vg.VG}, m...)...)
	} else {
		log.Info().Strs("devices", m).Msg("LVM VG: creating")
		run("vgcreate", append([]string{vg.VG}, m...)...)
	}
	devNeeded -= len(m)

	if devNeeded > 0 {
		fatalf("VG %s does not have enough devices (%d missing)", vg.VG, devNeeded)
	}
}

// enumerateBlockDevices lists the udev "block" devices by DEVNAME. Each device
// also gets, as an alias, every block device node under /dev with the same
// device number.
func enumerateBlockDevices() []NameAliases {
	enum := new(udev.Udev).NewEnumerate()
	enum.AddMatchSubsystem("block")
	devs, err := enum.Devices()
	if err != nil {
		fatalf("udev enumeration failed: %v", err)
	}

	devNames := make([]NameAliases, 0, len(devs))
	// Map device numbers to indexes, not pointers. A pointer into devNames
	// taken while it is still being appended to goes stale as soon as append
	// reallocates, and aliases added through it would be lost.
	devIdx := make(map[uint64]int, len(devs))
	for _, dev := range devs {
		num := dev.Devnum()
		devIdx[linuxMkdev(uint64(num.Major()), uint64(num.Minor()))] = len(devNames)
		devNames = append(devNames, nameAlias(dev.PropertyValue("DEVNAME")))
	}

	// NOTE(review): filepath.Walk uses Lstat and does not follow symlinks, so
	// symlinks such as /dev/disk/by-id/* never become aliases. Confirm this is
	// intended.
	err = filepath.Walk("/dev", func(n string, fi fs.FileInfo, err error) error {
		if err != nil {
			// fi may be nil when err is set; abort the walk as before.
			return err
		}
		// Type() is exactly ModeDevice for block devices; character devices
		// also carry ModeCharDevice and are skipped.
		if fi.Mode().Type() != os.ModeDevice {
			return nil
		}
		stat, ok := fi.Sys().(*syscall.Stat_t)
		if !ok {
			return nil
		}
		if idx, found := devIdx[uint64(stat.Rdev)]; found {
			devNames[idx].AddAlias(n)
		}
		return nil
	})
	if err != nil {
		fatalf("failed to walk /dev: %v", err)
	}

	return devNames
}

// linuxMkdev encodes a major/minor pair with the glibc makedev layout, so the
// result can be compared with st_rdev on Linux. The layout is:
//   - bits 0-7: low 8 bits of minor;
//   - bits 8-19: low 12 bits of major;
//   - bits 20 and up: remaining minor bits;
//   - bits 44 and up: remaining major bits.
func linuxMkdev(major, minor uint64) uint64 {
	return (minor & 0xff) |
		((major & 0xfff) << 8) |
		((minor &^ 0xff) << 12) |
		((major &^ 0xfff) << 32)
}
// setupLVs creates the logical volumes of vg that `lvs` does not report yet.
//
// Each LV needs a name and exactly one of Size (lvcreate -L) or Extents
// (lvcreate -l). Raid mirrors/stripes come from the LV, or from vg.Defaults
// when the LV sets none. The start of each new LV is zeroed. The LV is then
// recorded in createdDevs with its filesystem: lv.FS, falling back to
// vg.Defaults.FS.
func setupLVs(vg config.LvmVG, createdDevs map[string]string) {
	lvsRep := lvm.LVSReport{}
	if err := runJSON(&lvsRep, "lvs", "--reportformat", "json"); err != nil {
		fatalf("lvs failed: %v", err)
	}
	lvs := lvsRep.LVs()

	defaults := vg.Defaults

	for idx, lv := range vg.LVs {
		// Validate first: an unnamed LV can neither be looked up nor created.
		if lv.Name == "" {
			fatalf("LV[%d] has no name", idx)
		}

		log := log.With().Str("vg", vg.VG).Str("lv", lv.Name).Logger()

		if contains(lvs, func(v lvm.LV) bool {
			return v.VGName == vg.VG && v.Name == lv.Name
		}) {
			log.Info().Msg("LV exists")
			continue
		}

		log.Info().Msg("LV does not exist")

		// lv is a per-iteration copy, so filling in defaults does not touch vg.
		if lv.Raid == nil {
			lv.Raid = defaults.Raid
		}

		args := []string{vg.VG, "--name", lv.Name}

		switch {
		case lv.Size != "" && lv.Extents != "":
			fatalf("LV has both size and extents defined!")
		case lv.Size == "" && lv.Extents == "":
			fatalf("LV does not have size or extents defined!")
		case lv.Size != "":
			args = append(args, "-L", lv.Size)
		default:
			args = append(args, "-l", lv.Extents)
		}

		if raid := lv.Raid; raid != nil {
			if raid.Mirrors != 0 {
				args = append(args, "--mirrors", strconv.Itoa(raid.Mirrors))
			}
			if raid.Stripes != 0 {
				args = append(args, "--stripes", strconv.Itoa(raid.Stripes))
			}
		}

		log.Info().Strs("args", args).Msg("LV: creating")
		run("lvcreate", args...)

		dev := "/dev/" + vg.VG + "/" + lv.Name
		zeroDevStart(dev)

		// Named lvFS so it does not shadow the io/fs package.
		lvFS := lv.FS
		if lvFS == "" {
			lvFS = defaults.FS
		}
		createdDevs[dev] = lvFS
	}
}
// zeroDevStart overwrites the first 8 KiB of dev with zero bytes and flushes
// them to the device. Any failure (open, write, sync or close) is fatal.
func zeroDevStart(dev string) {
	f, err := os.OpenFile(dev, os.O_WRONLY, 0600)
	if err != nil {
		fatalf("failed to open %s: %v", dev, err)
	}

	if _, err = f.Write(make([]byte, 8192)); err != nil {
		f.Close()
		fatalf("failed to zero the beginning of %s: %v", dev, err)
	}

	// Sync and check Close explicitly. A write-back error may only surface
	// here, and an unchecked deferred Close would silently drop it.
	if err = f.Sync(); err != nil {
		f.Close()
		fatalf("failed to sync %s: %v", dev, err)
	}
	if err = f.Close(); err != nil {
		fatalf("failed to close %s: %v", dev, err)
	}
}
// cryptDevs holds every crypt mapping name already claimed by setupCrypt.
// A second device using a name in this set is fatal.
var cryptDevs = make(map[string]bool)
// setupCrypt opens the LUKS devices described by devSpecs as
// /dev/mapper/<name>. A device is luksFormat-ed first when its first 8 KiB
// are all zero.
//
// Password handling:
//   - One password is asked for and reused for every device.
//   - Before formatting, the password is confirmed by a second prompt, unless
//     it has already opened a device. On mismatch, both prompts are repeated.
//   - When "cryptsetup open" exits with status 2 (bad passphrase), the password
//     is asked for again. Any other failure is fatal.
//
// createdDevs maps devices to the filesystem they should get. A crypt device's
// own entry is removed. For a freshly formatted device, the entry is re-added
// under its /dev/mapper path, and the start of that mapping is zeroed.
func setupCrypt(devSpecs []config.CryptDev, createdDevs map[string]string) {
	var password []byte
	passwordVerified := false

	wipe := func(b []byte) {
		for i := range b {
			b[i] = 0
		}
	}

	askPassword := func() {
		password = askSecret("crypt password")
		if len(password) == 0 {
			fatalf("empty password given")
		}
	}

	// verifyPassword asks for the password a second time. On mismatch, the
	// first prompt is repeated too, since the typo may be in the first entry.
	verifyPassword := func() {
		for {
			p2 := askSecret("verify crypt password")
			eq := bytes.Equal(password, p2)
			wipe(p2)
			if eq {
				return
			}
			log.Error().Msg("passwords don't match")
			wipe(password)
			askPassword()
		}
	}

	for _, devName := range expandCryptDevSpecs(devSpecs) {
		name, dev := devName.Name, devName.Dev
		if name == "" {
			name = filepath.Base(dev)
		}
		if cryptDevs[name] {
			fatalf("duplicate crypt device name: %s", name)
		}
		cryptDevs[name] = true

		if len(password) == 0 {
			askPassword()
		}

		devFS := createdDevs[dev]
		delete(createdDevs, dev)
		tgtDev := "/dev/mapper/" + name

		// Decided once, before luksFormat. After formatting, the raw device is
		// no longer blank, but the new mapping must still be zeroed even when
		// opening needed a retry.
		needFormat := !devInitialized(dev)
		if needFormat {
			if !passwordVerified {
				verifyPassword()
			}
			log.Info().Str("dev", dev).Msg("formatting encrypted device")
			if err := runCryptsetup(password, "luksFormat", dev, "--key-file=-"); err != nil {
				fatalf("failed luksFormat: %v", err)
			}
			createdDevs[tgtDev] = devFS
		}

		for {
			log.Info().Str("name", name).Str("dev", dev).Msg("openning encrypted device")
			err := runCryptsetup(password, "open", dev, name, "--key-file=-")
			if err == nil {
				break
			}
			// cryptsetup exits with 2 on a bad passphrase; other errors
			// would not be fixed by asking for the password again.
			if exitErr, ok := err.(*exec.ExitError); !ok || exitErr.ExitCode() != 2 {
				fatalf("failed to open %s: %v", dev, err)
			}
			log.Error().Err(err).Str("dev", dev).Msg("failed to open encrypted device, wrong password?")
			wipe(password)
			passwordVerified = false
			askPassword()
		}

		if needFormat {
			zeroDevStart(tgtDev)
		}
		passwordVerified = true
	}

	wipe(password)
}

// expandCryptDevSpecs validates devSpecs and flattens them to one entry per
// device. Exactly one of Dev or Prefix must be set in each spec:
//   - a Dev spec is kept as-is;
//   - a Prefix spec yields one entry per path matching Prefix+"*", with Name
//     set to the spec's Name followed by the part of the path after Prefix.
func expandCryptDevSpecs(devSpecs []config.CryptDev) []config.CryptDev {
	devNames := make([]config.CryptDev, 0, len(devSpecs))

	for _, devSpec := range devSpecs {
		if devSpec.Dev == "" && devSpec.Prefix == "" {
			fatalf("crypt: name %q: no dev or match set", devSpec.Name)
		}
		if devSpec.Dev != "" && devSpec.Prefix != "" {
			fatalf("crypt: name %q: both dev (%q) and match (%q) are set", devSpec.Name, devSpec.Dev, devSpec.Prefix)
		}

		if devSpec.Dev != "" {
			// already flat
			devNames = append(devNames, devSpec)
			continue
		}

		matches, err := filepath.Glob(devSpec.Prefix + "*")
		if err != nil {
			fatalf("failed to search for device matches: %v", err)
		}
		for _, m := range matches {
			suffix := m[len(devSpec.Prefix):]
			devNames = append(devNames, config.CryptDev{Dev: m, Name: devSpec.Name + suffix})
		}
	}

	return devNames
}

// runCryptsetup runs cryptsetup with args. It feeds password on stdin (for
// --key-file=-) and forwards the command's output to stdout/stderr.
func runCryptsetup(password []byte, args ...string) error {
	cmd := exec.Command("cryptsetup", args...)
	cmd.Stdin = bytes.NewReader(password)
	cmd.Stdout = stdout
	cmd.Stderr = stderr
	return cmd.Run()
}
// devInitialized reports whether any of the first 8 KiB of dev is non-zero.
//
// A device shorter than 8 KiB is judged on the bytes it has. Failing to open
// it, or reading no byte at all, is fatal.
func devInitialized(dev string) bool {
	f, err := os.Open(dev)
	if err != nil {
		fatalf("failed to open %s: %v", dev, err)
	}
	defer f.Close()

	ba := make([]byte, 8192)
	// A single Read may legally return fewer bytes than requested. The unread
	// tail would stay zero and could make a used device look blank, and so get
	// it formatted. ReadFull keeps reading until the buffer is full.
	n, err := io.ReadFull(f, ba)
	if err != nil && err != io.ErrUnexpectedEOF {
		fatalf("failed to read %s: %v", dev, err)
	}

	for _, b := range ba[:n] {
		if b != 0 {
			return true
		}
	}
	return false
}
// setupFS runs mkfs.<fsType> on dev, with ext4 used when fsType is empty. It
// does nothing when devInitialized reports data at the start of dev.
//
// Only the first 8 KiB are checked, so older signatures may remain further in.
// btrfs, for example, keeps its superblock at 64 KiB. The tools that support
// it are therefore forced: -f for btrfs/xfs, -F for the ext family.
func setupFS(dev, fsType string) {
	if devInitialized(dev) {
		log.Info().Str("dev", dev).Msg("device already formatted")
		return
	}

	if fsType == "" {
		fsType = "ext4"
	}

	log.Info().Str("dev", dev).Str("fs", fsType).Msg("formatting device")

	var args []string
	switch fsType {
	case "btrfs", "xfs":
		args = append(args, "-f")
	case "ext2", "ext3", "ext4":
		args = append(args, "-F")
	}

	run("mkfs."+fsType, append(args, dev)...)
}