package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/rs/zerolog/log"

	config "novit.tech/direktil/pkg/bootstrapconfig"
)

// bootstrap mounts the bootstrap device on /bootstrap, seeds the directory for
// the current bootVersion from cfg.Bootstrap.Seed when its config.yaml is
// missing, then applies the system configuration: kernel modules, VPNs,
// mounts, root credentials and CA certificates.
//
// Unrecoverable conditions are reported through fatalf (defined elsewhere in
// this package; presumably it does not return).
func bootstrap(cfg *config.Config) {
	if cfg.Bootstrap.Dev == "" {
		fatalf("bootstrap device not defined!")
	}

	const bsDir = "/bootstrap"
	if err := os.MkdirAll(bsDir, 0700); err != nil {
		// Only reported: the mount below is what actually needs the
		// directory and will fail on its own if it is unusable.
		log.Error().Err(err).Str("dir", bsDir).Msg("failed to create bootstrap mount point")
	}

	run("mount", cfg.Bootstrap.Dev, bsDir)

	baseDir := filepath.Join(bsDir, bootVersion)
	sysCfgPath := filepath.Join(baseDir, "config.yaml")

	// A missing config.yaml means this boot version was never seeded.
	if _, err := os.Stat(sysCfgPath); os.IsNotExist(err) {
		log.Warn().Msgf("bootstrap %q does not exist", bootVersion)

		seed := cfg.Bootstrap.Seed
		if seed == "" {
			fatalf("boostrap seed not defined, admin required")
		}

		log.Info().Str("from", seed).Msgf("seeding bootstrap")

		if err := os.MkdirAll(baseDir, 0700); err != nil {
			fatalf("failed to create bootstrap dir: %v", err)
		}

		bootstrapFile := filepath.Join(baseDir, "bootstrap.tar")

		if err := downloadBootstrapSeed(seed, bootstrapFile); err != nil {
			fatalf("seeding failed: %v", err)
		}

		log.Info().Msg("unpacking bootstrap file")
		run("tar", "xvf", bootstrapFile, "-C", baseDir)
	}

	layersDir = baseDir
	layersOverride["modules"] = "/modules.sqfs"
	sysCfg := applyConfig(sysCfgPath, false)

	// load requested modules
	for _, mod := range sysCfg.Modules {
		log.Info().Str("module", mod).Msg("loading module")
		run("modprobe", mod)
	}

	// locally-generated assets dir
	localGenDir := filepath.Join(bsDir, "local-gen")

	// vpns are v2+
	for _, vpn := range sysCfg.VPNs {
		setupVPN(vpn, localGenDir)
	}

	// mounts are v2+
	for _, mount := range sysCfg.Mounts {
		log.Info().Str("source", mount.Dev).Str("target", mount.Path).Msg("mount")

		// mount targets live under the /system tree
		path := filepath.Join("/system", mount.Path)

		if err := os.MkdirAll(path, 0755); err != nil {
			// Only reported: the mount command below decides the outcome.
			log.Error().Err(err).Str("dir", path).Msg("failed to create mount point")
		}

		args := []string{mount.Dev, path}
		if mount.Type != "" {
			args = append(args, "-t", mount.Type)
		}
		if mount.Options != "" {
			args = append(args, "-o", mount.Options)
		}

		run("mount", args...)
	}

	// setup root user
	if ph := sysCfg.RootUser.PasswordHash; ph != "" {
		log.Info().Msg("setting root's password")
		setUserPass("root", ph)
	}
	if ak := sysCfg.RootUser.AuthorizedKeys; len(ak) != 0 {
		log.Info().Msg("setting root's authorized keys")
		setAuthorizedKeys(ak)
	}

	// update-ca-certificates
	log.Info().Msg("updating CA certificates")
	run("chroot", "/system", "update-ca-certificates")
}

// downloadBootstrapSeed fetches the seed URL and stores the response body in
// dest. Request errors are retried once per second for up to a minute; a
// response with a status other than 200 is returned as an error without
// retrying.
func downloadBootstrapSeed(seed, dest string) (err error) {
	var resp *http.Response

	start := time.Now()
	for time.Since(start) <= time.Minute {
		resp, err = http.Get(seed)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		// keep the last request error so the operator can see the cause
		return fmt.Errorf("failed to fetch bootstrap: %w", err)
	}

	// Close the body on every path, including the bad-status one.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("bad HTTP status: %s", resp.Status)
	}

	var out *os.File
	out, err = os.Create(dest)
	if err != nil {
		return err
	}

	// A failed Close can mean the data never reached the file; report it
	// unless an earlier error is already being returned.
	defer func() {
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(out, resp.Body)
	return err
}

// setUserPass replaces the second ':'-separated field (the password hash) of
// user's entry in /system/etc/shadow with passwordHash and rewrites the file.
// Lines whose first field is not user, or that have fewer than two fields,
// are written back unchanged. If no line matches user the file content is
// rewritten as is.
//
// Read and write failures are reported through fatalf (defined elsewhere in
// this package).
func setUserPass(user, passwordHash string) {
	const fpath = "/system/etc/shadow"

	ba, err := os.ReadFile(fpath)
	if err != nil {
		fatalf("failed to read shadow: %v", err)
	}

	lines := bytes.Split(ba, []byte{'\n'})

	// A file ending in '\n' (or an empty file) yields a final empty element.
	// Drop it: every line is re-terminated below, so keeping it would append
	// one more blank line to the file on each call.
	if n := len(lines); n > 0 && len(lines[n-1]) == 0 {
		lines = lines[:n-1]
	}

	buf := new(bytes.Buffer)
	for _, line := range lines {
		line := string(line)
		p := strings.Split(line, ":")
		if len(p) >= 2 && p[0] == user {
			p[1] = passwordHash
			line = strings.Join(p, ":")
		}

		buf.WriteString(line)
		buf.WriteByte('\n')
	}

	err = os.WriteFile(fpath, buf.Bytes(), 0600)
	if err != nil {
		fatalf("failed to write shadow: %v", err)
	}
}

// setAuthorizedKeys writes the keys in ak, one per line, to the
// authorized_keys file in /system/root/.ssh, creating that directory (mode
// 0700) when needed. Failures are reported through fatalf (defined elsewhere
// in this package).
func setAuthorizedKeys(ak []string) {
	const sshDir = "/system/root/.ssh"

	if err := os.MkdirAll(sshDir, 0700); err != nil {
		fatalf("failed to create %s: %v", sshDir, err)
	}

	// each key is followed by a newline, including the last one
	var content bytes.Buffer
	for i := range ak {
		content.WriteString(ak[i])
		content.WriteByte('\n')
	}

	target := filepath.Join(sshDir, "authorized_keys")
	if err := os.WriteFile(target, content.Bytes(), 0600); err != nil {
		fatalf("failed to write authorized keys: %v", err)
	}
}