package main import ( "bytes" "crypto/ed25519" "encoding/base32" "fmt" "io" "path/filepath" "slices" "strconv" "strings" "time" restful "github.com/emicklei/go-restful" "github.com/pierrec/lz4" "m.cluseau.fr/go/httperr" ) func globMatch(pattern, value string) bool { ok, _ := filepath.Match(pattern, value) return ok } type DownloadSet struct { Expiry time.Time Items []DownloadSetItem } func (s DownloadSet) Contains(kind, name, asset string) bool { for _, item := range s.Items { if item.Kind == kind && globMatch(item.Name, name) && slices.Contains(item.Assets, asset) { return true } } return false } func (s DownloadSet) Encode() string { buf := new(strings.Builder) s.EncodeTo(buf) return buf.String() } func (s DownloadSet) EncodeTo(buf *strings.Builder) { buf.WriteString(strconv.FormatInt(s.Expiry.Unix(), 16)) for _, item := range s.Items { buf.WriteByte('|') item.EncodeTo(buf) } } func (s *DownloadSet) Decode(encoded string) (err error) { exp, rem, _ := strings.Cut(encoded, "|") expUnix, err := strconv.ParseInt(exp, 16, 64) if err != nil { return } s.Expiry = time.Unix(expUnix, 0) if rem == "" { s.Items = nil } else { itemStrs := strings.Split(rem, "|") s.Items = make([]DownloadSetItem, len(itemStrs)) for i, itemStr := range itemStrs { s.Items[i].Decode(itemStr) } } return } func (s DownloadSet) Signed(privKey ed25519.PrivateKey) string { buf := new(bytes.Buffer) { setBytes := []byte(s.Encode()) w := lz4.NewWriter(buf) w.Write(setBytes) w.Close() } setBytes := buf.Bytes() sig := ed25519.Sign(privKey, setBytes) buf = bytes.NewBuffer(make([]byte, 0, 1+len(sig)+len(setBytes))) buf.WriteByte(byte(len(sig))) buf.Write(sig) buf.Write(setBytes) enc := base32.StdEncoding.WithPadding(base32.NoPadding) return enc.EncodeToString(buf.Bytes()) } type DownloadSetItem struct { Kind string Name string Assets []string } func (i DownloadSetItem) EncodeTo(buf *strings.Builder) { kind := i.Kind switch kind { case "host": kind = "h" case "cluster": kind = "c" } buf.WriteString(kind) buf.WriteByte(':') buf.WriteString(i.Name) for _, asset := range i.Assets { buf.WriteByte(':') buf.WriteString(asset) } } func (i *DownloadSetItem) Decode(encoded string) { rem := encoded i.Kind, rem, _ = strings.Cut(rem, ":") switch i.Kind { case "h": i.Kind = "host" case "c": i.Kind = "cluster" } i.Name, rem, _ = strings.Cut(rem, ":") if rem == "" { i.Assets = nil } else { i.Assets = strings.Split(rem, ":") } } type DownloadSetReq struct { Expiry string Items []DownloadSetItem } func wsSignDownloadSet(req *restful.Request, resp *restful.Response) { setReq := DownloadSetReq{} if err := req.ReadEntity(&setReq); err != nil { wsError(resp, err) return } exp, err := parseCertDuration(setReq.Expiry, time.Now()) if err != nil { wsError(resp, err) return } set := DownloadSet{ Expiry: exp, Items: setReq.Items, } privKey, _ := dlsSigningKeys() resp.WriteEntity(set.Signed(privKey)) } func getDlSet(req *restful.Request) (*DownloadSet, *httperr.Error) { setStr := req.QueryParameter("set") setBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(setStr) if err != nil { err := httperr.BadRequest("invalid set") return nil, &err } if len(setBytes) == 0 { err := httperr.BadRequest("invalid set") return nil, &err } sigLen := int(setBytes[0]) setBytes = setBytes[1:] if len(setBytes) < sigLen { err := httperr.BadRequest("invalid set") return nil, &err } sig := setBytes[:sigLen] setBytes = setBytes[sigLen:] _, pubkey := dlsSigningKeys() if !ed25519.Verify(pubkey, setBytes, sig) { err := httperr.BadRequest("invalid signature") return nil, &err } setBytes, err = io.ReadAll(lz4.NewReader(bytes.NewBuffer(setBytes))) if err != nil { err := httperr.BadRequest("invalid data") return nil, &err } fmt.Println(string(setBytes)) set := DownloadSet{} if err := set.Decode(string(setBytes)); err != nil { err := httperr.BadRequest("invalid set: " + err.Error()) return nil, &err } if time.Now().After(set.Expiry) { err := httperr.BadRequest("set expired") return nil, &err } return &set, nil } func wsDownloadSetAsset(req *restful.Request, resp *restful.Response) { set, err := getDlSet(req) if err != nil { wsError(resp, *err) return } kind := req.PathParameter("kind") name := req.PathParameter("name") asset := req.PathParameter("asset") if !set.Contains(kind, name, asset) { wsNotFound(resp) return } downloadAsset(req, resp, kind, name, asset) } func wsDownloadSet(req *restful.Request, resp *restful.Response) { setStr := req.QueryParameter("set") set, err := getDlSet(req) if err != nil { resp.WriteHeader(err.Status) resp.Write([]byte(htmlHeader(err.Error()))) resp.Write([]byte(htmlFooter)) return } buf := new(bytes.Buffer) buf.WriteString(htmlHeader("Download set")) cfg, err2 := readConfig() if err2 != nil { wsError(resp, err2) return } for _, item := range set.Items { names := make([]string, 0) switch item.Kind { case "cluster": for _, c := range cfg.Clusters { if globMatch(item.Name, c.Name) { names = append(names, c.Name) } } case "host": for _, h := range cfg.Hosts { if globMatch(item.Name, h.Name) { names = append(names, h.Name) } } } for _, name := range names { fmt.Fprintf(buf, "

%s %s

", strings.Title(item.Kind), name) fmt.Fprintf(buf, "

\n") for _, asset := range item.Assets { fmt.Fprintf(buf, " %s\n", item.Kind, name, asset, setStr, asset) } fmt.Fprintf(buf, `

`) } } buf.WriteString(htmlFooter) buf.WriteTo(resp) }