package main import ( "bytes" "fmt" "io" "log" "math/rand" "path" "reflect" "strings" "github.com/cespare/xxhash" yaml "gopkg.in/yaml.v2" "novit.tech/direktil/pkg/config" "novit.tech/direktil/local-server/pkg/clustersconfig" ) type renderContext struct { Labels map[string]string Annotations map[string]string Host *clustersconfig.Host Cluster *clustersconfig.Cluster Vars map[string]any BootstrapConfigTemplate *clustersconfig.Template ConfigTemplate *clustersconfig.Template StaticPodsTemplate *clustersconfig.Template clusterConfig *clustersconfig.Config } func newRenderContext(host *clustersconfig.Host, cfg *clustersconfig.Config) (ctx *renderContext, err error) { cluster := cfg.Cluster(host.Cluster) if cluster == nil { err = fmt.Errorf("no cluster named %q", host.Cluster) return } vars := make(map[string]any) for _, oVars := range []map[string]any{ cluster.Vars, host.Vars, } { mapMerge(vars, oVars) } return &renderContext{ Labels: mergeLabels(cluster.Labels, host.Labels), Annotations: mergeLabels(cluster.Annotations, host.Annotations), Host: host, Cluster: cluster, Vars: vars, BootstrapConfigTemplate: cfg.ConfigTemplate(host.BootstrapConfig), ConfigTemplate: cfg.ConfigTemplate(host.Config), clusterConfig: cfg, }, nil } func mergeLabels(sources ...map[string]string) map[string]string { ret := map[string]string{} for _, src := range sources { for k, v := range src { ret[k] = v } } return ret } func mapMerge(target, source map[string]interface{}) { for k, v := range source { target[k] = genericMerge(target[k], v) } } func genericMerge(target, source interface{}) (result interface{}) { srcV := reflect.ValueOf(source) tgtV := reflect.ValueOf(target) if srcV.Kind() == reflect.Map && tgtV.Kind() == reflect.Map { // XXX maybe more specific later result = map[interface{}]interface{}{} resultV := reflect.ValueOf(result) tgtIt := tgtV.MapRange() for tgtIt.Next() { sv := srcV.MapIndex(tgtIt.Key()) if sv.Kind() == 0 { resultV.SetMapIndex(tgtIt.Key(), tgtIt.Value()) continue } merged := genericMerge(tgtIt.Value().Interface(), sv.Interface()) resultV.SetMapIndex(tgtIt.Key(), reflect.ValueOf(merged)) } srcIt := srcV.MapRange() for srcIt.Next() { if resultV.MapIndex(srcIt.Key()).Kind() != 0 { continue // already done } resultV.SetMapIndex(srcIt.Key(), srcIt.Value()) } return } return source } func (ctx *renderContext) Name() string { switch { case ctx.Host != nil: return "host:" + ctx.Host.Name case ctx.Cluster != nil: return "cluster:" + ctx.Cluster.Name default: return "unknown" } } func (ctx *renderContext) BootstrapConfig() string { if ctx.BootstrapConfigTemplate == nil { log.Fatalf("no such (bootstrap) config: %q", ctx.Host.BootstrapConfig) } return ctx.renderConfig(ctx.BootstrapConfigTemplate) } func (ctx *renderContext) Config() string { if ctx.ConfigTemplate == nil { log.Fatalf("no such config: %q", ctx.Host.Config) } return ctx.renderConfig(ctx.ConfigTemplate) } func (ctx *renderContext) renderConfig(configTemplate *clustersconfig.Template) string { buf := new(strings.Builder) ctx.renderConfigTo(buf, configTemplate) return buf.String() } func (ctx *renderContext) renderConfigTo(buf io.Writer, configTemplate *clustersconfig.Template) { ctxName := ctx.Name() ctxMap := ctx.asMap() extraFuncs := ctx.templateFuncs(ctxMap) extraFuncs["static_pods_files"] = func(dir string) (string, error) { namePods := ctx.renderStaticPods() defs := make([]config.FileDef, 0) for _, namePod := range namePods { name := namePod.Namespace + "_" + namePod.Name ba, err := yaml.Marshal(namePod.Pod) if err != nil { return "", fmt.Errorf("static pod %s: failed to render: %v", name, err) } defs = append(defs, config.FileDef{ Path: path.Join(dir, name+".yaml"), Mode: 0640, Content: string(ba), }) } ba, err := yaml.Marshal(defs) return string(ba), err } extraFuncs["host_ip"] = func() string { if ctx.Host.Template { return "{{ host_ip }}" } return ctx.Host.IP } extraFuncs["host_name"] = func() string { if ctx.Host.Template { return "{{ host_name }}" } return ctx.Host.Name } extraFuncs["machine_id"] = func() string { return "{{ machine_id }}" } extraFuncs["version"] = func() string { return Version } if err := configTemplate.Execute(ctxName, "config", buf, ctxMap, extraFuncs); err != nil { log.Fatalf("failed to render config %q for host %q: %v", ctx.Host.Config, ctx.Host.Name, err) } } func (ctx *renderContext) templateFuncs(ctxMap map[string]any) map[string]interface{} { cluster := ctx.Cluster.Name getKeyCert := func(name, funcName string) (s string, err error) { req := ctx.clusterConfig.CSR(name) if req == nil { err = fmt.Errorf("no certificate request named %q", name) return } if req.CA == "" { err = fmt.Errorf("CA not defined in req %q", name) return } buf := &bytes.Buffer{} err = req.Execute(ctx.Name(), "req:"+name, buf, ctxMap, nil) if err != nil { return } key := name if req.PerHost { key += "/" + ctx.Host.Name } if funcName == "tls_dir" { // needs the dir name dir := "/etc/tls/" + name s = fmt.Sprintf("{{ %s %q %q %q %q %q %q %q }}", funcName, dir, cluster, req.CA, key, req.Profile, req.Label, buf.String()) } else { s = fmt.Sprintf("{{ %s %q %q %q %q %q %q }}", funcName, cluster, req.CA, key, req.Profile, req.Label, buf.String()) } return } funcs := clusterFuncs(ctx.Cluster) for k, v := range map[string]any{ "default": func(value, defaultValue any) any { switch v := value.(type) { case string: if v != "" { return v } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64: if v != 0 { return v } default: if v != nil { return v } } return defaultValue }, "tls_key": func(name string) (string, error) { return getKeyCert(name, "tls_key") }, "tls_crt": func(name string) (s string, err error) { return getKeyCert(name, "tls_crt") }, "tls_dir": func(name string) (s string, err error) { return getKeyCert(name, "tls_dir") }, "ssh_host_keys": func(dir string) (s string) { return fmt.Sprintf("{{ ssh_host_keys %q %q \"\"}}", dir, cluster) }, "host_download_token": func() (s string) { return "{{ host_download_token }}" }, "hosts_of_group": func() (hosts []any) { hosts = make([]any, 0) for _, host := range ctx.clusterConfig.Hosts { if host.Cluster == ctx.Cluster.Name && host.Group != ctx.Host.Group { continue } hosts = append(hosts, asMap(host)) } return hosts }, "hosts_of_group_count": func() (count int) { for _, host := range ctx.clusterConfig.Hosts { if host.Cluster == ctx.Cluster.Name && host.Group == ctx.Host.Group { count++ } } return }, "shuffled_hosts_by_group": func(group string) (hosts []any) { for _, host := range src.Hosts { if host.Cluster == ctx.Cluster.Name && host.Group == group { hosts = append(hosts, asMap(host)) } } if len(hosts) == 0 { log.Printf("WARNING: no hosts in group %q", group) return } seed := xxhash.Sum64String(ctx.Host.Name) rng := rand.New(rand.NewSource(int64(seed))) rng.Shuffle(len(hosts), func(i, j int) { hosts[i], hosts[j] = hosts[j], hosts[i] }) return }, } { funcs[k] = v } return funcs } func (ctx *renderContext) asMap() map[string]interface{} { result := asMap(ctx) // also expand cluster: cluster := result["cluster"].(map[interface{}]interface{}) cluster["kubernetes_svc_ip"] = ctx.Cluster.KubernetesSvcIP().String() cluster["dns_svc_ip"] = ctx.Cluster.DNSSvcIP().String() return result } func asMap(v interface{}) map[string]interface{} { ba, err := yaml.Marshal(v) if err != nil { panic(err) // shouldn't happen } result := make(map[string]interface{}) if err := yaml.Unmarshal(ba, result); err != nil { panic(err) // shouldn't happen } return result }