local-server/cmd/dkl-local-server/render-context.go
2024-04-15 15:42:40 +02:00

251 lines
4.9 KiB
Go

package main
import (
"bytes"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"net/url"
"path"
"path/filepath"
"text/template"
cfsslconfig "github.com/cloudflare/cfssl/config"
restful "github.com/emicklei/go-restful"
yaml "gopkg.in/yaml.v2"
"novit.tech/direktil/pkg/config"
"novit.tech/direktil/pkg/localconfig"
bsconfig "novit.tech/direktil/pkg/bootstrapconfig"
)
// cmdlineParam declares the "cmdline" query parameter, whose value is an
// addition to the Linux kernel command line. renderCtx reads the parameter
// name from it when inspecting incoming requests.
var cmdlineParam = restful.QueryParameter("cmdline", "Linux kernel cmdline addition")
// renderContext carries the per-host state used to render templates and
// build artifacts. Tag YAML-encodes this struct into its hash, so the set and
// order of encoded fields influences the resulting tag.
type renderContext struct {
// Host is the host the templates are rendered for.
Host *localconfig.Host
// SSLConfig is the cfssl configuration handed to templateFuncs.
SSLConfig *cfsslconfig.Config
// CmdLine is the extra Linux kernel cmdline taken from the request.
// It is excluded from YAML encoding (and therefore from Tag).
CmdLine string `yaml:"-"`
}
// renderCtx serves the artifact named what for the host described by ctx,
// calling create to build it when casStore.GetOrCreate asks for it.
//
// When the request carries the "cmdline" query parameter, its value is stored
// in ctx.CmdLine and appended (query-escaped) to the key passed to the store,
// so that each distinct cmdline is stored under its own key.
//
// It returns the error from ctx.Tag or casStore.GetOrCreate, if any; otherwise
// the content is handed to http.ServeContent and nil is returned.
func renderCtx(w http.ResponseWriter, r *http.Request, ctx *renderContext, what string,
	create func(out io.Writer, ctx *renderContext) error) error {
	tag, err := ctx.Tag()
	if err != nil {
		return err
	}

	// The store key is the artifact name plus the requested cmdline, if any.
	key := what
	ctx.CmdLine = r.URL.Query().Get(cmdlineParam.Data().Name)
	if ctx.CmdLine != "" {
		key += "?cmdline=" + url.QueryEscape(ctx.CmdLine)
	}

	// get it or create it
	content, meta, err := casStore.GetOrCreate(tag, key, func(out io.Writer) error {
		log.Printf("building %s for %q", key, ctx.Host.Name)
		return create(out, ctx)
	})
	if err != nil {
		return err
	}

	// serve it
	log.Printf("sending %s for %q", key, ctx.Host.Name)

	// ServeContent derives the Content-Type from the name's file extension.
	// Pass the bare artifact name rather than the key: the key ends with the
	// client-supplied cmdline, which would otherwise decide the extension.
	http.ServeContent(w, r, what, meta.ModTime(), content)
	return nil
}
// sslConfigFromLocalConfig returns the cfssl configuration described by
// cfg.SSLConfig. An empty SSLConfig yields an empty cfssl configuration;
// otherwise the value is parsed with cfsslconfig.LoadConfig and its result
// (including any error) is returned as is.
func sslConfigFromLocalConfig(cfg *localconfig.Config) (sslCfg *cfsslconfig.Config, err error) {
	if len(cfg.SSLConfig) != 0 {
		return cfsslconfig.LoadConfig([]byte(cfg.SSLConfig))
	}
	return &cfsslconfig.Config{}, nil
}
// newRenderContext builds the render context for host, using the SSL
// configuration found in cfg. It fails only when that SSL configuration
// cannot be loaded, in which case the returned context is nil.
func newRenderContext(host *localconfig.Host, cfg *localconfig.Config) (ctx *renderContext, err error) {
	ssl, loadErr := sslConfigFromLocalConfig(cfg)
	if loadErr != nil {
		return nil, loadErr
	}

	ctx = &renderContext{Host: host, SSLConfig: ssl}
	return ctx, nil
}
// Config renders the host's config template and decodes the output as YAML.
//
// It returns the rendered bytes and the decoded configuration. When rendering
// fails, cfg is nil. When decoding fails, ba still holds the rendered bytes
// and cfg points to the (possibly partially filled) configuration.
func (ctx *renderContext) Config() (ba []byte, cfg *config.Config, err error) {
	if ba, err = ctx.render(ctx.Host.Config); err != nil {
		return ba, nil, err
	}

	cfg = new(config.Config)
	err = yaml.Unmarshal(ba, cfg)
	return ba, cfg, err
}
// BootstrapConfig renders the host's bootstrap config template and decodes
// the output as YAML.
//
// It returns the rendered bytes and the decoded bootstrap configuration. When
// rendering fails, cfg is nil. When decoding fails, ba still holds the
// rendered bytes and cfg points to the (possibly partially filled)
// configuration.
func (ctx *renderContext) BootstrapConfig() (ba []byte, cfg *bsconfig.Config, err error) {
	if ba, err = ctx.render(ctx.Host.BootstrapConfig); err != nil {
		return ba, nil, err
	}

	cfg = new(bsconfig.Config)
	err = yaml.Unmarshal(ba, cfg)
	return ba, cfg, err
}
// render parses templateText as a text/template named "<host name>/config",
// with the functions from ctx.TemplateFuncs available, and executes it with
// nil data. It returns the produced bytes, or a nil slice and the error when
// parsing or execution fails.
func (ctx *renderContext) render(templateText string) (ba []byte, err error) {
	// Initial capacity of the output buffer.
	const initialSize = 4096

	t := template.New(ctx.Host.Name + "/config")
	t.Funcs(ctx.TemplateFuncs())
	if _, err = t.Parse(templateText); err != nil {
		return nil, err
	}

	out := bytes.NewBuffer(make([]byte, 0, initialSize))
	if err = t.Execute(out, nil); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// distFilePath joins the given path elements below "<dataDir>/dist" and
// returns the resulting file path.
func (ctx *renderContext) distFilePath(path ...string) string {
	elems := make([]string, 0, len(path)+2)
	elems = append(elems, *dataDir, "dist")
	elems = append(elems, path...)
	return filepath.Join(elems...)
}
// Tag returns a hex-encoded SHA-256 digest identifying this context: the
// hash covers the YAML encoding of the host's decoded config followed by the
// YAML encoding of the context itself. It fails when the config cannot be
// produced or when either value cannot be encoded.
func (ctx *renderContext) Tag() (string, error) {
	_, cfg, err := ctx.Config()
	if err != nil {
		return "", err
	}

	digest := sha256.New()
	enc := yaml.NewEncoder(digest)

	// Order matters: the config document is hashed first, then the context.
	if err := enc.Encode(cfg); err != nil {
		return "", err
	}
	if err := enc.Encode(ctx); err != nil {
		return "", err
	}

	return hex.EncodeToString(digest.Sum(nil)), nil
}
// asMap converts v into a generic string-keyed map by marshalling it to YAML
// and unmarshalling the result. It panics if either step fails.
func asMap(v interface{}) map[string]interface{} {
	encoded, err := yaml.Marshal(v)
	if err != nil {
		panic(err) // shouldn't happen
	}

	m := map[string]interface{}{}
	if err = yaml.Unmarshal(encoded, m); err != nil {
		panic(err) // shouldn't happen
	}
	return m
}
// TemplateFuncs returns the function map used to render this host's
// templates: the functions from templateFuncs(ctx.SSLConfig), with the
// host-specific functions below added (replacing any same-named entry).
//
//   - host_ip: the first entry of the host's IPs, or an error if there is none.
//   - host_name: the host's name.
//   - machine_id: hex-encoded SHA-1 of "<cluster name>/<host name>".
//   - ssh_host_keys dir cluster host: YAML for the file definitions of the
//     host's SSH key pairs, placed under dir.
//   - host_download_token: the host's download token, generated and stored
//     when none is found.
func (ctx *renderContext) TemplateFuncs() map[string]any {
	funcs := templateFuncs(ctx.SSLConfig)

	for name, method := range map[string]any{
		"host_ip": func() (s string, err error) {
			// Report a clear template error instead of panicking on an
			// out-of-range index when the host has no IP configured.
			if len(ctx.Host.IPs) == 0 {
				return "", fmt.Errorf("host %q has no IP", ctx.Host.Name)
			}
			return ctx.Host.IPs[0], nil
		},

		"host_name": func() (s string) {
			return ctx.Host.Name
		},

		"machine_id": func() (s string) {
			ba := sha1.Sum([]byte(ctx.Host.ClusterName + "/" + ctx.Host.Name))
			return hex.EncodeToString(ba[:])
		},

		// The cluster argument is accepted but not used by this function.
		"ssh_host_keys": func(dir, cluster, host string) (s string, err error) {
			// An empty host means "this host"; any other host is refused.
			if host == "" {
				host = ctx.Host.Name
			}
			if host != ctx.Host.Name {
				err = fmt.Errorf("wrong host name")
				return
			}

			pairs, err := getSSHKeyPairs(host)
			if err != nil {
				return
			}

			// Two files per key pair: the private key and its ".pub" sibling.
			files := make([]config.FileDef, 0, len(pairs)*2)

			for _, pair := range pairs {
				basePath := path.Join(dir, "ssh_host_"+pair.Type+"_key")
				files = append(files, []config.FileDef{
					{
						Path:    basePath,
						Mode:    0600,
						Content: pair.Private,
					},
					{
						Path:    basePath + ".pub",
						Mode:    0644,
						Content: pair.Public,
					},
				}...)
			}

			return asYaml(files)
		},

		"host_download_token": func() (token string, err error) {
			key := ctx.Host.Name

			token, found, err := hostDownloadTokens.Get(key)
			if err != nil {
				return
			}

			// No token stored for this host yet: create one and store it.
			if !found {
				token, err = newToken(32)
				if err != nil {
					return
				}

				err = hostDownloadTokens.Put(key, token)
				if err != nil {
					return
				}
			}

			return
		},
	} {
		funcs[name] = method
	}

	return funcs
}