diff --git a/cmd/dkl-local-server/auth.go b/cmd/dkl-local-server/auth.go index 94817f0..5afd882 100644 --- a/cmd/dkl-local-server/auth.go +++ b/cmd/dkl-local-server/auth.go @@ -1,22 +1,14 @@ package main import ( - "flag" "log" "net/http" ) -var ( - hostsToken = flag.String("hosts-token", "", "Token to give to access /hosts (open is none)") - adminToken = flag.String("admin-token", "", "Token to give to access to admin actions (open is none)") -) - -func authorizeHosts(r *http.Request) bool { - return authorizeToken(r, *hostsToken) -} +var adminToken string func authorizeAdmin(r *http.Request) bool { - return authorizeToken(r, *adminToken) + return authorizeToken(r, adminToken) } func authorizeToken(r *http.Request, token string) bool { @@ -49,9 +41,5 @@ func requireToken(token string, handler http.Handler) http.Handler { } func requireAdmin(handler http.Handler) http.Handler { - return requireToken(*adminToken, handler) -} - -func requireHosts(handler http.Handler) http.Handler { - return requireToken(*hostsToken, handler) + return requireToken(adminToken, handler) } diff --git a/cmd/dkl-local-server/httperr.go b/cmd/dkl-local-server/httperr.go index 8b553d8..c9fdc54 100644 --- a/cmd/dkl-local-server/httperr.go +++ b/cmd/dkl-local-server/httperr.go @@ -10,5 +10,7 @@ var ( ErrNotFound = httperr.NotFound ErrUnauthorized = httperr.StdStatus(http.StatusUnauthorized) ErrForbidden = httperr.StdStatus(http.StatusForbidden) + ErrInternal = httperr.StdStatus(http.StatusInternalServerError) ErrInvalidToken = httperr.NewStd(1000, http.StatusForbidden, "invalid token") + ErrStoreLocked = httperr.NewStd(1001, http.StatusServiceUnavailable, "store is locked") ) diff --git a/cmd/dkl-local-server/main.go b/cmd/dkl-local-server/main.go index 82a7002..af329a2 100644 --- a/cmd/dkl-local-server/main.go +++ b/cmd/dkl-local-server/main.go @@ -57,7 +57,7 @@ func main() { log.Fatal(err) } - log.Print("store auto-unlocked, admin token is ", *adminToken) + log.Print("store auto-unlocked, admin token is ", adminToken) } os.Setenv("DLS_AUTO_UNLOCK", "") diff --git a/cmd/dkl-local-server/secret-store.go b/cmd/dkl-local-server/secret-store.go index 0b778c9..94b7950 100644 --- a/cmd/dkl-local-server/secret-store.go +++ b/cmd/dkl-local-server/secret-store.go @@ -117,7 +117,7 @@ func unlockSecretStore(passphrase []byte) (err httperr.Error) { log.Print("wrote new admin token") } - *adminToken = token + adminToken = token { token, err := newToken(16) diff --git a/cmd/dkl-local-server/ws-auth.go b/cmd/dkl-local-server/ws-auth.go index 0346016..4dee483 100644 --- a/cmd/dkl-local-server/ws-auth.go +++ b/cmd/dkl-local-server/ws-auth.go @@ -7,18 +7,14 @@ import ( ) func adminAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { - tokenAuth(req, resp, chain, *adminToken) -} - -func hostsAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { - tokenAuth(req, resp, chain, *hostsToken, *adminToken) + tokenAuth(req, resp, chain, adminToken) } func tokenAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, allowedTokens ...string) { token := getToken(req) for _, allowedToken := range allowedTokens { - if allowedToken == "" || token == allowedToken { + if allowedToken != "" && token == allowedToken { chain.ProcessFilter(req, resp) return } diff --git a/cmd/dkl-local-server/ws-host.go b/cmd/dkl-local-server/ws-host.go index d23c46a..122894c 100644 --- a/cmd/dkl-local-server/ws-host.go +++ b/cmd/dkl-local-server/ws-host.go @@ -13,21 +13,23 @@ import ( "novit.tech/direktil/local-server/pkg/mime" ) -var trustXFF = flag.Bool("trust-xff", true, "Trust the X-Forwarded-For header") +var ( + allowDetectedHost = flag.Bool("allow-detected-host", false, "Allow access to host assets from its IP (insecure but enables unattended netboot)") + trustXFF = flag.Bool("trust-xff", false, "Trust the X-Forwarded-For header") +) type wsHost struct { - prefix string hostDoc string getHost func(req *restful.Request) (hostName string, err error) } -func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) { +func (ws wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) { b := func(what string) *restful.RouteBuilder { - return rws.GET(ws.prefix + "/" + what).To(ws.render) + return rws.GET("/" + what).To(ws.render) } for _, rb := range []*restful.RouteBuilder{ - rws.GET(ws.prefix).To(ws.get). + rws.GET("").To(ws.get). Doc("Get the "+ws.hostDoc+"'s details"). Returns(200, "OK", localconfig.Host{}), @@ -104,7 +106,7 @@ func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteB } } -func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) { +func (ws wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) { hostname, err := ws.getHost(req) if err != nil { wsError(resp, err) @@ -130,7 +132,7 @@ func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *loca return } -func (ws *wsHost) get(req *restful.Request, resp *restful.Response) { +func (ws wsHost) get(req *restful.Request, resp *restful.Response) { host, _ := ws.host(req, resp) if host == nil { return @@ -139,7 +141,7 @@ func (ws *wsHost) get(req *restful.Request, resp *restful.Response) { resp.WriteEntity(host) } -func (ws *wsHost) render(req *restful.Request, resp *restful.Response) { +func (ws wsHost) render(req *restful.Request, resp *restful.Response) { host, cfg := ws.host(req, resp) if host == nil { return diff --git a/cmd/dkl-local-server/ws-public.go b/cmd/dkl-local-server/ws-public.go index 006428f..5e3195c 100644 --- a/cmd/dkl-local-server/ws-public.go +++ b/cmd/dkl-local-server/ws-public.go @@ -24,7 +24,7 @@ func wsUnlockStore(req *restful.Request, resp *restful.Response) { return } - resp.WriteEntity(*adminToken) + resp.WriteEntity(adminToken) } func wsStoreDownload(req *restful.Request, resp *restful.Response) { diff --git a/cmd/dkl-local-server/ws.go b/cmd/dkl-local-server/ws.go index e153bbc..1639322 100644 --- a/cmd/dkl-local-server/ws.go +++ b/cmd/dkl-local-server/ws.go @@ -43,10 +43,10 @@ func registerWS(rest *restful.Container) { } // Admin-level APIs - ws := &restful.WebService{} - ws. + ws := (&restful.WebService{}). + Filter(requireSecStore). Filter(adminAuth). - Param(ws.HeaderParameter("Authorization", "Admin bearer token").Required(true)). + Param(restful.HeaderParameter("Authorization", "Admin bearer token").Required(true)). Produces(mime.JSON) // - store management @@ -118,8 +118,20 @@ func registerWS(rest *restful.Container) { ws.Route(ws.GET("/hosts").To(wsListHosts). Doc("List hosts")) + ws.Route(ws.GET("/ssh-acls").To(wsSSH_ACL_List)) + ws.Route(ws.GET("/ssh-acls/{acl-name}").To(wsSSH_ACL_Get)) + ws.Route(ws.PUT("/ssh-acls/{acl-name}").To(wsSSH_ACL_Set)) + + rest.Add(ws) + + // Hosts API + ws = (&restful.WebService{}). + Filter(requireSecStore). + Filter(adminAuth). + Path("/hosts/{host-name}"). + Param(ws.HeaderParameter("Authorization", "Host or admin bearer token")) + (&wsHost{ - prefix: "/hosts/{host-name}", hostDoc: "given host", getHost: func(req *restful.Request) (string, error) { return req.PathParameter("host-name"), nil @@ -128,17 +140,12 @@ func registerWS(rest *restful.Container) { rb.Param(ws.PathParameter("host-name", "host's name")) }) - ws.Route(ws.GET("/ssh-acls").To(wsSSH_ACL_List)) - ws.Route(ws.GET("/ssh-acls/{acl-name}").To(wsSSH_ACL_Get)) - ws.Route(ws.PUT("/ssh-acls/{acl-name}").To(wsSSH_ACL_Set)) - rest.Add(ws) - // Hosts API - ws = &restful.WebService{} - ws.Produces(mime.JSON). + // Detected host API + ws = (&restful.WebService{}). + Filter(requireSecStore). Path("/me"). - Filter(hostsAuth). Param(ws.HeaderParameter("Authorization", "Host or admin bearer token")) (&wsHost{ @@ -149,8 +156,10 @@ func registerWS(rest *restful.Container) { }) // Hosts by token API - ws = &restful.WebService{} - ws.Path("/hosts-by-token/{host-token}").Param(ws.PathParameter("host-token", "host's download token")) + ws = (&restful.WebService{}). + Filter(requireSecStore). + Path("/hosts-by-token/{host-token}"). + Param(ws.PathParameter("host-token", "host's download token")) (&wsHost{ hostDoc: "token's host", @@ -178,7 +187,19 @@ func registerWS(rest *restful.Container) { rest.Add(ws) } +func requireSecStore(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { + if !secStore.Unlocked() { + wsError(resp, ErrStoreLocked) + return + } + chain.ProcessFilter(req, resp) +} + func detectHost(req *restful.Request) (hostName string, err error) { + if !*allowDetectedHost { + return + } + r := req.Request remoteAddr := r.RemoteAddr