package main

import (
	"archive/tar"
	"bytes"
	"crypto"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"slices"
	"strings"

	"github.com/klauspost/compress/zstd"
	"novit.tech/direktil/pkg/config"
	"novit.tech/direktil/pkg/cpiocat"
	"novit.tech/direktil/pkg/localconfig"
)

// renderBootstrapConfig writes the bootstrap configuration of the host held
// by ctx to the HTTP response.
func renderBootstrapConfig(w http.ResponseWriter, ctx *renderContext) (err error) {
	log.Printf("sending bootstrap config for %q", ctx.Host.Name)

	ba, err := ctx.BootstrapConfig()
	if err != nil {
		return err
	}

	_, err = w.Write(ba)
	return err
}

// buildInitrd writes the host's initrd to out as a zstd-compressed stream
// assembled through cpiocat. The archive is made of, in order:
//   - the distributed initrd of the host,
//   - the "modules" layer (if listed in the config) as modules.sqfs,
//   - the bootstrap config as config.yaml,
//   - the SSH host keys taken from the host config,
//   - the SSH user CA public key of the host's cluster as user_ca.pub.
func buildInitrd(out io.Writer, ctx *renderContext) (err error) {
	_, cfg, err := ctx.Config()
	if err != nil {
		return err
	}

	zout, err := zstd.NewWriter(out, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(12)))
	if err != nil {
		return fmt.Errorf("zstd writer setup failed: %w", err)
	}

	cat := cpiocat.New(zout)

	// initrd
	initrdPath, err := distFetch("initrd", ctx.Host.Initrd)
	if err != nil {
		return err
	}
	cat.AppendArchFile(initrdPath)

	// embedded layers (modules)
	for _, layer := range cfg.Layers {
		if layer != "modules" {
			continue
		}

		// without an explicit version, the modules follow the kernel version
		layerVersion := ctx.Host.Versions[layer]
		if layerVersion == "" {
			layerVersion = ctx.Host.Kernel
		}

		modulesPath, err := distFetch("layers", layer, layerVersion)
		if err != nil {
			return err
		}
		cat.AppendFile(modulesPath, "modules.sqfs")
	}

	// config
	cfgBytes, err := ctx.BootstrapConfig()
	if err != nil {
		return err
	}
	cat.AppendBytes(cfgBytes, "config.yaml", 0o600)

	// ssh keys
	cat.AppendDir("/etc", 0o755)
	cat.AppendDir("/etc/ssh", 0o700)

	// XXX do we want bootstrap-stage keys instead of the real host key?
	for _, format := range []string{"rsa", "ecdsa", "ed25519"} {
		keyPath := "/etc/ssh/ssh_host_" + format + "_key"
		cat.AppendBytes(cfg.FileContent(keyPath), keyPath, 0o600)
	}

	// ssh user CA
	userCA, err := sshCAPubKey(ctx.Host.ClusterName)
	if err != nil {
		return fmt.Errorf("failed to get SSH user CA: %w", err)
	}
	cat.AppendBytes(userCA, "user_ca.pub", 0o600)

	// close the archive first, then the compressor it writes into
	if err = cat.Close(); err != nil {
		return fmt.Errorf("cpio close failed: %w", err)
	}
	if err = zout.Close(); err != nil {
		return fmt.Errorf("zstd close failed: %w", err)
	}
	return nil
}

// buildBootstrap writes the host's bootstrap archive to out as a tar stream.
//
// Each payload entry is followed by a "<name>.sig" entry holding the
// signature of the payload's SHA-512 digest, made with the "boot-signer" key
// of the host's cluster. The archive contains config.yaml and then either a
// single "merged" entry (when every non-modules layer version ends with
// ".erofs") or one "<layer>.fs" entry per non-modules layer.
func buildBootstrap(out io.Writer, ctx *renderContext) (err error) {
	arch := tar.NewWriter(out)
	defer func() {
		// Close flushes the last entry and writes the tar trailer: a failure
		// here means a broken archive and must not be reported as a success.
		if closeErr := arch.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("tar close failed: %w", closeErr)
		}
	}()

	ca, err := getUsableClusterCA(ctx.Host.ClusterName, "boot-signer")
	if err != nil {
		return err
	}

	signer, err := ca.ParseKey()
	if err != nil {
		return err
	}

	hash := crypto.SHA512

	// sign appends a tar entry called name holding the signature of digest.
	sign := func(name string, digest []byte) error {
		sigBytes, err := signer.Sign(nil, digest, hash)
		if err != nil {
			return fmt.Errorf("signing to %s failed: %w", name, err)
		}

		err = arch.WriteHeader(&tar.Header{
			Name: name,
			Size: int64(len(sigBytes)),
			Mode: 0o644,
		})
		if err != nil {
			return err
		}

		_, err = arch.Write(sigBytes)
		return err
	}

	// config
	cfgBytes, cfg, err := ctx.Config()
	if err != nil {
		return err
	}

	err = arch.WriteHeader(&tar.Header{
		Name: "config.yaml",
		Size: int64(len(cfgBytes)),
		Mode: 0o600,
	})
	if err != nil {
		return err
	}

	if _, err = arch.Write(cfgBytes); err != nil {
		return err
	}

	cfgHash := hash.New()
	cfgHash.Write(cfgBytes)
	if err = sign("config.yaml.sig", cfgHash.Sum(nil)); err != nil {
		return err
	}

	// layers

	// appendSignedLayer streams the file at layerPath into the archive as an
	// entry called layer, hashing it on the way, then appends its signature.
	appendSignedLayer := func(layer, layerPath string) error {
		f, err := os.Open(layerPath)
		if err != nil {
			return err
		}
		defer f.Close()

		stat, err := f.Stat()
		if err != nil {
			return err
		}

		err = arch.WriteHeader(&tar.Header{
			Name: layer,
			Size: stat.Size(),
			Mode: 0o600,
		})
		if err != nil {
			return err
		}

		h := hash.New()
		if _, err = io.Copy(arch, io.TeeReader(f, h)); err != nil {
			return err
		}

		return sign(layer+".sig", h.Sum(nil))
	}

	allErofs := true
	for _, layer := range cfg.Layers {
		if layer == "modules" {
			continue // modules are in the initrd with boot v2
		}
		if !strings.HasSuffix(ctx.Host.Versions[layer], ".erofs") {
			allErofs = false
			break
		}
	}

	if allErofs {
		layerPath, err := layersCombo(ctx, cfg, signer)
		if err != nil {
			return err
		}
		return appendSignedLayer("merged", layerPath)
	}

	for _, layer := range cfg.Layers {
		if layer == "modules" {
			continue // modules are in the initrd with boot v2
		}

		layerPath, err := fetchHostLayer(ctx.Host, layer)
		if err != nil {
			return err
		}

		if err = appendSignedLayer(layer+".fs", layerPath); err != nil {
			return err
		}
	}

	return nil
}

// layersCombo returns the path of the cached, merged image of all the host's
// layers ("modules" excluded), building it when it is not in the cache yet.
//
// The build runs under opMutex for the cache key and relies on external
// tools: erofsfuse to mount each layer, tar to stack the layers (last config
// layer first) into a single archive, mkfs.erofs to turn that archive into a
// filesystem image, and veritysetup to compute its hash tree. The verity root
// hash is signed with signer and everything is packed into one file (see
// writeLayersCombo for the layout).
func layersCombo(ctx *renderContext, cfg *config.Config, signer crypto.Signer) (path string, err error) {
	key := layersComboKey(ctx.Host, cfg)

	return opMutex(key, func() (path string, err error) {
		cacheDir := filepath.Join(*dataDir, "cache")
		if err = os.MkdirAll(cacheDir, 0o700); err != nil {
			return cacheDir, err
		}

		path = filepath.Join(cacheDir, key) + ".fs"
		if _, statErr := os.Stat(path); statErr == nil {
			return path, nil // exists -> already done
		}

		workdir, err := os.MkdirTemp("/tmp", "layers")
		if err != nil {
			return path, err
		}
		// registered first so it runs last, after the deferred unmounts below
		defer os.RemoveAll(workdir)

		tmpTar := filepath.Join(workdir, "output.tar")

		layers := slices.Clone(cfg.Layers)
		slices.Reverse(layers)

		// run executes a command, capturing its standard output in cmdOut
		// (replacing the output of the previous command).
		cmdOut := new(bytes.Buffer)
		run := func(prog string, args ...string) error {
			cmdOut.Reset()
			cmd := exec.Command(prog, args...)
			cmd.Stdout = cmdOut
			cmd.Stderr = os.Stderr
			if e := cmd.Run(); e != nil {
				return fmt.Errorf("%s %q failed: %w", cmd.Path, cmd.Args, e)
			}
			return nil
		}

		// The archive is created by the first layer actually processed, which
		// is not necessarily the first of the list as "modules" is skipped.
		archiveCreated := false

		for _, layer := range layers {
			if layer == "modules" {
				continue // modules are in the initrd with boot v2
			}

			layerFile, e := fetchHostLayer(ctx.Host, layer)
			if e != nil {
				return path, e
			}

			mountPoint := filepath.Join(workdir, layer)
			if e := os.MkdirAll(mountPoint, 0o700); e != nil {
				return path, fmt.Errorf("creating mount point %s failed: %w", mountPoint, e)
			}

			if e := exec.Command("erofsfuse", layerFile, mountPoint).Run(); e != nil {
				return path, fmt.Errorf("erofsfuse %s %s failed: %w", layerFile, mountPoint, e)
			}
			// deferred on purpose: every layer is unmounted when the build
			// ends, whatever the outcome, before the workdir is removed
			defer func() {
				if e := exec.Command("umount", mountPoint).Run(); e != nil {
					log.Printf("umount %s failed: %v", mountPoint, e)
				}
			}()

			mode := "--append"
			if !archiveCreated {
				mode = "--create"
			}

			if err = run("tar", mode, "-p", "-f", tmpTar, "-C", mountPoint, "."); err != nil {
				return path, err
			}
			archiveCreated = true
		}

		if !archiveCreated {
			// nothing was archived: mkfs.erofs would only fail on a missing file
			return path, fmt.Errorf("no layers to merge")
		}

		fsOut := filepath.Join(workdir, "output.fs")
		err = run("mkfs.erofs", "-z", "lzma", "-C131072", "-Efragments,ztailpacking",
			"-T0", "--all-time", "--ignore-mtime", "--tar=f", fsOut, tmpTar)
		if err != nil {
			return path, err
		}

		hashOut := filepath.Join(workdir, "output.hash")
		if err = run("veritysetup", "format", fsOut, hashOut); err != nil {
			return path, err
		}

		// cmdOut now holds the output of veritysetup
		rootHash, err := parseVerityRootHash(cmdOut.String())
		if err != nil {
			return path, err
		}

		sigBytes, err := signer.Sign(nil, rootHash, crypto.SHA256)
		if err != nil {
			return path, fmt.Errorf("root hash signature failed: %w", err)
		}

		// write next to the final path, then rename, so the cache never
		// exposes a partially written file
		tmpPath := path + ".tmp"
		if err = writeLayersCombo(tmpPath, fsOut, hashOut, sigBytes, rootHash); err != nil {
			return path, fmt.Errorf("assembly failed: %w", err)
		}

		err = os.Rename(tmpPath, path)
		return path, err
	})
}

// parseVerityRootHash extracts the hex-encoded value of the first
// "Root hash:" line found in output and returns it decoded.
func parseVerityRootHash(output string) ([]byte, error) {
	for line := range strings.SplitSeq(output, "\n") {
		v, ok := strings.CutPrefix(line, "Root hash:")
		if !ok {
			continue
		}

		rootHash, err := hex.DecodeString(strings.TrimSpace(v))
		if err != nil {
			return nil, fmt.Errorf("invalid root hash: %w", err)
		}
		if len(rootHash) == 0 {
			break // an empty value is reported as a missing root hash
		}
		return rootHash, nil
	}

	return nil, fmt.Errorf("root hash not found in output")
}

// writeLayersCombo creates outPath as the concatenation of 4 sections: the
// signature, the root hash, the content of fsPath and the content of
// hashPath. Each section is prefixed by its size as a 64-bit big-endian
// integer. The file is removed if it cannot be written completely.
func writeLayersCombo(outPath, fsPath, hashPath string, sigBytes, rootHash []byte) (err error) {
	fsRd, err := os.Open(fsPath)
	if err != nil {
		return err
	}
	defer fsRd.Close()

	hashRd, err := os.Open(hashPath)
	if err != nil {
		return err
	}
	defer hashRd.Close()

	fsStat, err := fsRd.Stat()
	if err != nil {
		return err
	}

	hashStat, err := hashRd.Stat()
	if err != nil {
		return err
	}

	out, err := os.Create(outPath)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// best-effort cleanup, the error being returned is the one that matters
			_ = out.Close()
			_ = os.Remove(outPath)
		}
	}()

	writeSection := func(size uint64, rd io.Reader) error {
		var sizeBuf [8]byte
		binary.BigEndian.PutUint64(sizeBuf[:], size)
		if _, err := out.Write(sizeBuf[:]); err != nil {
			return err
		}
		_, err := io.Copy(out, rd)
		return err
	}

	sections := []struct {
		size uint64
		rd   io.Reader
	}{
		{uint64(len(sigBytes)), bytes.NewReader(sigBytes)},
		{uint64(len(rootHash)), bytes.NewReader(rootHash)},
		{uint64(fsStat.Size()), fsRd},
		{uint64(hashStat.Size()), hashRd},
	}

	for _, section := range sections {
		if err = writeSection(section.size, section.rd); err != nil {
			return err
		}
	}

	return out.Close()
}

// layersComboKey returns the cache key of the merged image of the host's
// layers: "layers" followed by ":<layer>" for each layer of the config
// ("modules" excluded), with "@<version>" appended when the host maps the
// layer to a version.
//
// NOTE(review): the key only covers layer names and versions, while the
// cached file also embeds a signature; confirm that every caller sharing a
// key is expected to use the same signer.
func layersComboKey(host *localconfig.Host, cfg *config.Config) string {
	var key strings.Builder
	key.WriteString("layers")

	for _, layer := range cfg.Layers {
		if layer == "modules" {
			continue
		}

		key.WriteByte(':')
		key.WriteString(layer)

		if version, ok := host.Versions[layer]; ok {
			key.WriteByte('@')
			key.WriteString(version)
		}
	}

	return key.String()
}

// fetchHostLayer fetches the given layer at the version the host maps it to
// and returns its local path. It fails if the host has no version for it.
func fetchHostLayer(host *localconfig.Host, layer string) (path string, err error) {
	layerVersion := host.Versions[layer]
	if layerVersion == "" {
		return "", fmt.Errorf("layer %q not mapped to a version", layer)
	}

	return distFetch("layers", layer, layerVersion)
}