host download tokens

This commit is contained in:
Mikaël Cluseau 2023-02-13 15:57:30 +01:00
parent 1e3ac9a0fb
commit bde41c9859
6 changed files with 117 additions and 39 deletions

View File

@ -301,8 +301,11 @@ func (ctx *renderContext) templateFuncs(ctxMap map[string]interface{}) map[strin
}, },
"ssh_host_keys": func(dir string) (s string) { "ssh_host_keys": func(dir string) (s string) {
return fmt.Sprintf("{{ ssh_host_keys %q %q %q}}", return fmt.Sprintf("{{ ssh_host_keys %q %q \"\"}}",
dir, cluster, ctx.Host.Name) dir, cluster)
},
"host_download_token": func() (s string) {
return "{{ host_download_token }}"
}, },
"hosts_of_group": func() (hosts []interface{}) { "hosts_of_group": func() (hosts []interface{}) {

View File

@ -133,33 +133,6 @@ func templateFuncs(sslCfg *cfsslconfig.Config) map[string]any {
}, },
}) })
}, },
"ssh_host_keys": func(dir, cluster, host string) (s string, err error) {
pairs, err := getSSHKeyPairs(host)
if err != nil {
return
}
files := make([]config.FileDef, 0, len(pairs)*2)
for _, pair := range pairs {
basePath := path.Join(dir, "ssh_host_"+pair.Type+"_key")
files = append(files, []config.FileDef{
{
Path: basePath,
Mode: 0600,
Content: pair.Private,
},
{
Path: basePath + ".pub",
Mode: 0644,
Content: pair.Public,
},
}...)
}
return asYaml(files)
},
} }
} }

View File

@ -0,0 +1,3 @@
package main
var hostDownloadTokens = KVSecrets[string]{"hosts/download-tokens"}

View File

@ -4,10 +4,12 @@ import (
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"path"
"path/filepath" "path/filepath"
"text/template" "text/template"
@ -115,7 +117,7 @@ func (ctx *renderContext) BootstrapConfig() (ba []byte, cfg *bsconfig.Config, er
func (ctx *renderContext) render(templateText string) (ba []byte, err error) { func (ctx *renderContext) render(templateText string) (ba []byte, err error) {
tmpl, err := template.New(ctx.Host.Name + "/config"). tmpl, err := template.New(ctx.Host.Name + "/config").
Funcs(templateFuncs(ctx.SSLConfig)). Funcs(ctx.TemplateFuncs()).
Parse(templateText) Parse(templateText)
if err != nil { if err != nil {
@ -168,3 +170,69 @@ func asMap(v interface{}) map[string]interface{} {
return result return result
} }
func (ctx *renderContext) TemplateFuncs() map[string]any {
funcs := templateFuncs(ctx.SSLConfig)
for name, method := range map[string]any{
"ssh_host_keys": func(dir, cluster, host string) (s string, err error) {
if host == "" {
host = ctx.Host.Name
}
if host != ctx.Host.Name {
err = fmt.Errorf("wrong host name")
return
}
pairs, err := getSSHKeyPairs(host)
if err != nil {
return
}
files := make([]config.FileDef, 0, len(pairs)*2)
for _, pair := range pairs {
basePath := path.Join(dir, "ssh_host_"+pair.Type+"_key")
files = append(files, []config.FileDef{
{
Path: basePath,
Mode: 0600,
Content: pair.Private,
},
{
Path: basePath + ".pub",
Mode: 0644,
Content: pair.Public,
},
}...)
}
return asYaml(files)
},
"host_download_token": func() (token string, err error) {
key := ctx.Host.Name
token, found, err := hostDownloadTokens.Get(key)
if err != nil {
return
}
if !found {
token, err = newToken(32)
if err != nil {
return
}
err = hostDownloadTokens.Put(key, token)
if err != nil {
return
}
}
return
},
} {
funcs[name] = method
}
return funcs
}

View File

@ -18,7 +18,7 @@ var trustXFF = flag.Bool("trust-xff", true, "Trust the X-Forwarded-For header")
type wsHost struct { type wsHost struct {
prefix string prefix string
hostDoc string hostDoc string
getHost func(req *restful.Request) string getHost func(req *restful.Request) (hostName string, err error)
} }
func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) { func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) {
@ -105,13 +105,17 @@ func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteB
} }
func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) { func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) {
hostname := ws.getHost(req) hostname, err := ws.getHost(req)
if err != nil {
wsError(resp, err)
return
}
if hostname == "" { if hostname == "" {
wsNotFound(req, resp) wsNotFound(req, resp)
return return
} }
cfg, err := readConfig() cfg, err = readConfig()
if err != nil { if err != nil {
wsError(resp, err) wsError(resp, err)
return return

View File

@ -120,8 +120,8 @@ func registerWS(rest *restful.Container) {
(&wsHost{ (&wsHost{
prefix: "/hosts/{host-name}", prefix: "/hosts/{host-name}",
hostDoc: "given host", hostDoc: "given host",
getHost: func(req *restful.Request) string { getHost: func(req *restful.Request) (string, error) {
return req.PathParameter("host-name") return req.PathParameter("host-name"), nil
}, },
}).register(ws, func(rb *restful.RouteBuilder) { }).register(ws, func(rb *restful.RouteBuilder) {
}) })
@ -146,10 +146,37 @@ func registerWS(rest *restful.Container) {
rb.Notes("In this case, the host is detected from the remote IP") rb.Notes("In this case, the host is detected from the remote IP")
}) })
// Hosts by token API
ws = &restful.WebService{}
ws.Path("/hosts-by-token/{host-token}")
(&wsHost{
hostDoc: "token's host",
getHost: func(req *restful.Request) (host string, err error) {
reqToken := req.PathParameter("host-name")
data, err := hostDownloadTokens.Data()
if err != nil {
return
}
for h, token := range data {
if token == reqToken {
host = h
return
}
}
return
},
}).register(ws, func(rb *restful.RouteBuilder) {
rb.Notes("In this case, the host is detected from the remote IP")
})
rest.Add(ws) rest.Add(ws)
} }
func detectHost(req *restful.Request) string { func detectHost(req *restful.Request) (hostName string, err error) {
r := req.Request r := req.Request
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr
@ -167,17 +194,17 @@ func detectHost(req *restful.Request) string {
cfg, err := readConfig() cfg, err := readConfig()
if err != nil { if err != nil {
return "" return
} }
host := cfg.HostByIP(hostIP) host := cfg.HostByIP(hostIP)
if host == nil { if host == nil {
log.Print("no host found for IP ", hostIP) log.Print("no host found for IP ", hostIP)
return "" return
} }
return host.Name return host.Name, nil
} }
func wsReadConfig(resp *restful.Response) *localconfig.Config { func wsReadConfig(resp *restful.Response) *localconfig.Config {