cleanup hosts ws

This commit is contained in:
Mikaël Cluseau 2023-05-18 19:55:52 +02:00
parent 4ed50e3b78
commit b6e7c55704
8 changed files with 55 additions and 46 deletions

View File

@ -1,22 +1,14 @@
package main package main
import ( import (
"flag"
"log" "log"
"net/http" "net/http"
) )
var ( var adminToken string
hostsToken = flag.String("hosts-token", "", "Token to give to access /hosts (open is none)")
adminToken = flag.String("admin-token", "", "Token to give to access to admin actions (open is none)")
)
func authorizeHosts(r *http.Request) bool {
return authorizeToken(r, *hostsToken)
}
func authorizeAdmin(r *http.Request) bool { func authorizeAdmin(r *http.Request) bool {
return authorizeToken(r, *adminToken) return authorizeToken(r, adminToken)
} }
func authorizeToken(r *http.Request, token string) bool { func authorizeToken(r *http.Request, token string) bool {
@ -49,9 +41,5 @@ func requireToken(token string, handler http.Handler) http.Handler {
} }
func requireAdmin(handler http.Handler) http.Handler { func requireAdmin(handler http.Handler) http.Handler {
return requireToken(*adminToken, handler) return requireToken(adminToken, handler)
}
func requireHosts(handler http.Handler) http.Handler {
return requireToken(*hostsToken, handler)
} }

View File

@ -10,5 +10,7 @@ var (
ErrNotFound = httperr.NotFound ErrNotFound = httperr.NotFound
ErrUnauthorized = httperr.StdStatus(http.StatusUnauthorized) ErrUnauthorized = httperr.StdStatus(http.StatusUnauthorized)
ErrForbidden = httperr.StdStatus(http.StatusForbidden) ErrForbidden = httperr.StdStatus(http.StatusForbidden)
ErrInternal = httperr.StdStatus(http.StatusInternalServerError)
ErrInvalidToken = httperr.NewStd(1000, http.StatusForbidden, "invalid token") ErrInvalidToken = httperr.NewStd(1000, http.StatusForbidden, "invalid token")
ErrStoreLocked = httperr.NewStd(1001, http.StatusServiceUnavailable, "store is locked")
) )

View File

@ -57,7 +57,7 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
log.Print("store auto-unlocked, admin token is ", *adminToken) log.Print("store auto-unlocked, admin token is ", adminToken)
} }
os.Setenv("DLS_AUTO_UNLOCK", "") os.Setenv("DLS_AUTO_UNLOCK", "")

View File

@ -117,7 +117,7 @@ func unlockSecretStore(passphrase []byte) (err httperr.Error) {
log.Print("wrote new admin token") log.Print("wrote new admin token")
} }
*adminToken = token adminToken = token
{ {
token, err := newToken(16) token, err := newToken(16)

View File

@ -7,18 +7,14 @@ import (
) )
func adminAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { func adminAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
tokenAuth(req, resp, chain, *adminToken) tokenAuth(req, resp, chain, adminToken)
}
func hostsAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
tokenAuth(req, resp, chain, *hostsToken, *adminToken)
} }
func tokenAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, allowedTokens ...string) { func tokenAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, allowedTokens ...string) {
token := getToken(req) token := getToken(req)
for _, allowedToken := range allowedTokens { for _, allowedToken := range allowedTokens {
if allowedToken == "" || token == allowedToken { if allowedToken != "" && token == allowedToken {
chain.ProcessFilter(req, resp) chain.ProcessFilter(req, resp)
return return
} }

View File

@ -13,21 +13,23 @@ import (
"novit.tech/direktil/local-server/pkg/mime" "novit.tech/direktil/local-server/pkg/mime"
) )
var trustXFF = flag.Bool("trust-xff", true, "Trust the X-Forwarded-For header") var (
allowDetectedHost = flag.Bool("allow-detected-host", false, "Allow access to host assets from its IP (insecure but enables unattended netboot)")
trustXFF = flag.Bool("trust-xff", false, "Trust the X-Forwarded-For header")
)
type wsHost struct { type wsHost struct {
prefix string
hostDoc string hostDoc string
getHost func(req *restful.Request) (hostName string, err error) getHost func(req *restful.Request) (hostName string, err error)
} }
func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) { func (ws wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteBuilder)) {
b := func(what string) *restful.RouteBuilder { b := func(what string) *restful.RouteBuilder {
return rws.GET(ws.prefix + "/" + what).To(ws.render) return rws.GET("/" + what).To(ws.render)
} }
for _, rb := range []*restful.RouteBuilder{ for _, rb := range []*restful.RouteBuilder{
rws.GET(ws.prefix).To(ws.get). rws.GET("").To(ws.get).
Doc("Get the "+ws.hostDoc+"'s details"). Doc("Get the "+ws.hostDoc+"'s details").
Returns(200, "OK", localconfig.Host{}), Returns(200, "OK", localconfig.Host{}),
@ -104,7 +106,7 @@ func (ws *wsHost) register(rws *restful.WebService, alterRB func(*restful.RouteB
} }
} }
func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) { func (ws wsHost) host(req *restful.Request, resp *restful.Response) (host *localconfig.Host, cfg *localconfig.Config) {
hostname, err := ws.getHost(req) hostname, err := ws.getHost(req)
if err != nil { if err != nil {
wsError(resp, err) wsError(resp, err)
@ -130,7 +132,7 @@ func (ws *wsHost) host(req *restful.Request, resp *restful.Response) (host *loca
return return
} }
func (ws *wsHost) get(req *restful.Request, resp *restful.Response) { func (ws wsHost) get(req *restful.Request, resp *restful.Response) {
host, _ := ws.host(req, resp) host, _ := ws.host(req, resp)
if host == nil { if host == nil {
return return
@ -139,7 +141,7 @@ func (ws *wsHost) get(req *restful.Request, resp *restful.Response) {
resp.WriteEntity(host) resp.WriteEntity(host)
} }
func (ws *wsHost) render(req *restful.Request, resp *restful.Response) { func (ws wsHost) render(req *restful.Request, resp *restful.Response) {
host, cfg := ws.host(req, resp) host, cfg := ws.host(req, resp)
if host == nil { if host == nil {
return return

View File

@ -24,7 +24,7 @@ func wsUnlockStore(req *restful.Request, resp *restful.Response) {
return return
} }
resp.WriteEntity(*adminToken) resp.WriteEntity(adminToken)
} }
func wsStoreDownload(req *restful.Request, resp *restful.Response) { func wsStoreDownload(req *restful.Request, resp *restful.Response) {

View File

@ -43,10 +43,10 @@ func registerWS(rest *restful.Container) {
} }
// Admin-level APIs // Admin-level APIs
ws := &restful.WebService{} ws := (&restful.WebService{}).
ws. Filter(requireSecStore).
Filter(adminAuth). Filter(adminAuth).
Param(ws.HeaderParameter("Authorization", "Admin bearer token").Required(true)). Param(restful.HeaderParameter("Authorization", "Admin bearer token").Required(true)).
Produces(mime.JSON) Produces(mime.JSON)
// - store management // - store management
@ -118,8 +118,20 @@ func registerWS(rest *restful.Container) {
ws.Route(ws.GET("/hosts").To(wsListHosts). ws.Route(ws.GET("/hosts").To(wsListHosts).
Doc("List hosts")) Doc("List hosts"))
ws.Route(ws.GET("/ssh-acls").To(wsSSH_ACL_List))
ws.Route(ws.GET("/ssh-acls/{acl-name}").To(wsSSH_ACL_Get))
ws.Route(ws.PUT("/ssh-acls/{acl-name}").To(wsSSH_ACL_Set))
rest.Add(ws)
// Hosts API
ws = (&restful.WebService{}).
Filter(requireSecStore).
Filter(adminAuth).
Path("/hosts/{host-name}").
Param(ws.HeaderParameter("Authorization", "Host or admin bearer token"))
(&wsHost{ (&wsHost{
prefix: "/hosts/{host-name}",
hostDoc: "given host", hostDoc: "given host",
getHost: func(req *restful.Request) (string, error) { getHost: func(req *restful.Request) (string, error) {
return req.PathParameter("host-name"), nil return req.PathParameter("host-name"), nil
@ -128,17 +140,12 @@ func registerWS(rest *restful.Container) {
rb.Param(ws.PathParameter("host-name", "host's name")) rb.Param(ws.PathParameter("host-name", "host's name"))
}) })
ws.Route(ws.GET("/ssh-acls").To(wsSSH_ACL_List))
ws.Route(ws.GET("/ssh-acls/{acl-name}").To(wsSSH_ACL_Get))
ws.Route(ws.PUT("/ssh-acls/{acl-name}").To(wsSSH_ACL_Set))
rest.Add(ws) rest.Add(ws)
// Hosts API // Detected host API
ws = &restful.WebService{} ws = (&restful.WebService{}).
ws.Produces(mime.JSON). Filter(requireSecStore).
Path("/me"). Path("/me").
Filter(hostsAuth).
Param(ws.HeaderParameter("Authorization", "Host or admin bearer token")) Param(ws.HeaderParameter("Authorization", "Host or admin bearer token"))
(&wsHost{ (&wsHost{
@ -149,8 +156,10 @@ func registerWS(rest *restful.Container) {
}) })
// Hosts by token API // Hosts by token API
ws = &restful.WebService{} ws = (&restful.WebService{}).
ws.Path("/hosts-by-token/{host-token}").Param(ws.PathParameter("host-token", "host's download token")) Filter(requireSecStore).
Path("/hosts-by-token/{host-token}").
Param(ws.PathParameter("host-token", "host's download token"))
(&wsHost{ (&wsHost{
hostDoc: "token's host", hostDoc: "token's host",
@ -178,7 +187,19 @@ func registerWS(rest *restful.Container) {
rest.Add(ws) rest.Add(ws)
} }
func requireSecStore(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
if !secStore.Unlocked() {
wsError(resp, ErrStoreLocked)
return
}
chain.ProcessFilter(req, resp)
}
func detectHost(req *restful.Request) (hostName string, err error) { func detectHost(req *restful.Request) (hostName string, err error) {
if !*allowDetectedHost {
return
}
r := req.Request r := req.Request
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr