diff --git a/cmd/dkl-dir2config/render-context.go b/cmd/dkl-dir2config/render-context.go index 5f3b2b7..414e340 100644 --- a/cmd/dkl-dir2config/render-context.go +++ b/cmd/dkl-dir2config/render-context.go @@ -301,8 +301,11 @@ func (ctx *renderContext) templateFuncs(ctxMap map[string]interface{}) map[strin }, "ssh_host_keys": func(dir string) (s string) { - return fmt.Sprintf("{{ ssh_host_keys %q %q %q}}", - dir, cluster, ctx.Host.Name) + return fmt.Sprintf("{{ ssh_host_keys %q %q \"\"}}", + dir, cluster) + }, + "host_download_token": func() (s string) { + return "{{ host_download_token }}" }, "hosts_of_group": func() (hosts []interface{}) { diff --git a/cmd/dkl-local-server/cluster-render-context.go b/cmd/dkl-local-server/cluster-render-context.go index 988115c..75ca8c6 100644 --- a/cmd/dkl-local-server/cluster-render-context.go +++ b/cmd/dkl-local-server/cluster-render-context.go @@ -133,33 +133,6 @@ func templateFuncs(sslCfg *cfsslconfig.Config) map[string]any { }, }) }, - - "ssh_host_keys": func(dir, cluster, host string) (s string, err error) { - pairs, err := getSSHKeyPairs(host) - if err != nil { - return - } - - files := make([]config.FileDef, 0, len(pairs)*2) - - for _, pair := range pairs { - basePath := path.Join(dir, "ssh_host_"+pair.Type+"_key") - files = append(files, []config.FileDef{ - { - Path: basePath, - Mode: 0600, - Content: pair.Private, - }, - { - Path: basePath + ".pub", - Mode: 0644, - Content: pair.Public, - }, - }...) - } - - return asYaml(files) - }, } } diff --git a/cmd/dkl-local-server/host-download-tokens.go b/cmd/dkl-local-server/host-download-tokens.go new file mode 100644 index 0000000..7f03ed5 --- /dev/null +++ b/cmd/dkl-local-server/host-download-tokens.go @@ -0,0 +1,3 @@ +package main + +var hostDownloadTokens = KVSecrets[string]{"hosts/download-tokens"} diff --git a/cmd/dkl-local-server/render-context.go b/cmd/dkl-local-server/render-context.go index 05ea51f..8219776 100644 --- a/cmd/dkl-local-server/render-context.go +++ b/cmd/dkl-local-server/render-context.go @@ -4,10 +4,12 @@ import ( "bytes" "crypto/sha256" "encoding/hex" + "fmt" "io" "log" "net/http" "net/url" + "path" "path/filepath" "text/template" @@ -115,7 +117,7 @@ func (ctx *renderContext) BootstrapConfig() (ba []byte, cfg *bsconfig.Config, er func (ctx *renderContext) render(templateText string) (ba []byte, err error) { tmpl, err := template.New(ctx.Host.Name + "/config"). - Funcs(templateFuncs(ctx.SSLConfig)). + Funcs(ctx.TemplateFuncs()). Parse(templateText) if err != nil { @@ -168,3 +170,69 @@ func asMap(v interface{}) map[string]interface{} { return result } + +func (ctx *renderContext) TemplateFuncs() map[string]any { + funcs := templateFuncs(ctx.SSLConfig) + + for name, method := range map[string]any{ + "ssh_host_keys": func(dir, cluster, host string) (s string, err error) { + if host == "" { + host = ctx.Host.Name + } + if host != ctx.Host.Name { + err = fmt.Errorf("wrong host name") + return + } + + pairs, err := getSSHKeyPairs(host) + if err != nil { + return + } + + files := make([]config.FileDef, 0, len(pairs)*2) + + for _, pair := range pairs { + basePath := path.Join(dir, "ssh_host_"+pair.Type+"_key") + files = append(files, []config.FileDef{ + { + Path: basePath, + Mode: 0600, + Content: pair.Private, + }, + { + Path: basePath + ".pub", + Mode: 0644, + Content: pair.Public, + }, + }...) + } + + return asYaml(files) + }, + "host_download_token": func() (token string, err error) { + key := ctx.Host.Name + token, found, err := hostDownloadTokens.Get(key) + if err != nil { + return + } + + if !found { + token, err = newToken(32) + if err != nil { + return + } + + err = hostDownloadTokens.Put(key, token) + if err != nil { + return + } + } + + return + }, + } { + funcs[name] = method + } + + return funcs +} diff --git a/cmd/dkl-local-server/ws-host.go b/cmd/dkl-local-server/ws-host.go index 89a5afc..3a0e147 100644 --- a/cmd/dkl-local-server/ws-host.go +++ b/cmd/dkl-local-server/ws-host.go @@ -18,7 +18,7 @@ var trustXFF = flag.Bool("trust-xff", true, "Trust the X-Forwarded-For header") type wsHost struct { prefix string hostDoc string - getHost func(req *restful.Request) string + getHost func(req *restful.Request) (hostName string, err error) } func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) { @@ -105,13 +105,17 @@ func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteB } func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) { - hostname := ws.getHost(req) + hostname, err := ws.getHost(req) + if err != nil { + wsError(resp, err) + return + } if hostname == "" { wsNotFound(req, resp) return } - cfg, err := readConfig() + cfg, err = readConfig() if err != nil { wsError(resp, err) return diff --git a/cmd/dkl-local-server/ws.go b/cmd/dkl-local-server/ws.go index b0a7d8f..0b02e2e 100644 --- a/cmd/dkl-local-server/ws.go +++ b/cmd/dkl-local-server/ws.go @@ -120,8 +120,8 @@ func registerWS(rest *restful.Container) { (&wsHost{ prefix: "/hosts/{host-name}", hostDoc: "given host", - getHost: func(req *restful.Request) string { - return req.PathParameter("host-name") + getHost: func(req *restful.Request) (string, error) { + return req.PathParameter("host-name"), nil }, }).register(ws, func(rb *restful.RouteBuilder) { }) @@ -146,10 +146,37 @@ func registerWS(rest *restful.Container) { rb.Notes("In this case, the host is detected from the remote IP") }) + // Hosts by token API + ws = &restful.WebService{} + ws.Path("/hosts-by-token/{host-token}") + + (&wsHost{ + hostDoc: "token's host", + getHost: func(req *restful.Request) (host string, err error) { + reqToken := req.PathParameter("host-name") + + data, err := hostDownloadTokens.Data() + if err != nil { + return + } + + for h, token := range data { + if token == reqToken { + host = h + return + } + } + + return + }, + }).register(ws, func(rb *restful.RouteBuilder) { + rb.Notes("In this case, the host is detected from the remote IP") + }) + rest.Add(ws) } -func detectHost(req *restful.Request) string { +func detectHost(req *restful.Request) (hostName string, err error) { r := req.Request remoteAddr := r.RemoteAddr @@ -167,17 +194,17 @@ func detectHost(req *restful.Request) string { cfg, err := readConfig() if err != nil { - return "" + return } host := cfg.HostByIP(hostIP) if host == nil { log.Print("no host found for IP ", hostIP) - return "" + return } - return host.Name + return host.Name, nil } func wsReadConfig(resp *restful.Response) *localconfig.Config {