package main import ( "crypto/sha1" "encoding/hex" "encoding/json" "flag" "fmt" "io" "log" "net/http" "os" "path" "path/filepath" "sort" "strings" "time" "github.com/coreos/go-semver/semver" ) var ( bind = flag.String("bind", ":8080", "Bind address") uploadToken = flag.String("upload-token", "", "Upload token (no uploads allowed if empty)") storeDir = flag.String("store-dir", "/srv/dkl-store", "Store directory") ) func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) flag.Parse() http.HandleFunc("/", handleHTTP) log.Print("listening on ", *bind) log.Fatal(http.ListenAndServe(*bind, nil)) } func handleHTTP(w http.ResponseWriter, req *http.Request) { filePath := filepath.Join(*storeDir, path.Clean(req.URL.Path)) l := fmt.Sprintf("%s %s", req.Method, filePath) log.Print(l) defer log.Print(l, " done") stat, err := os.Stat(filePath) if err != nil { if !os.IsNotExist(err) { writeErr(err, w) } else { http.NotFound(w, req) } return } if stat.Mode().IsDir() { entries, err := os.ReadDir(filePath) if err != nil { writeErr(err, w) return } w.Header().Set("Content-Type", "application/json") resp := struct { Versions map[string]*VersionInfo `json:",omitempty"` Names []string }{ Versions: make(map[string]*VersionInfo, len(entries)), Names: make([]string, 0, len(entries)), } for _, e := range entries { name := e.Name() if strings.HasSuffix(name, ".sha1") { continue } resp.Names = append(resp.Names, e.Name()) } resp.Versions = aggregateVersions(resp.Names) sort.Strings(resp.Names) json.NewEncoder(w).Encode(resp) return } switch req.Method { case "GET", "HEAD": sha1Hex, err := hashOf(filePath) if err != nil { writeErr(err, w) return } w.Header().Set("X-Content-SHA1", sha1Hex) http.ServeFile(w, req, filePath) case "POST": if req.Header.Get("Authorization") != ("Bearer " + *uploadToken) { http.Error(w, "unauthorized", http.StatusUnauthorized) return } tmpOut := filepath.Join(filepath.Dir(filePath), "."+filepath.Base(filePath)) if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { writeErr(err, w) return } out, err := os.Create(tmpOut) if err != nil { writeErr(err, w) return } h := sha1.New() mw := io.MultiWriter(out, h) _, err = io.Copy(mw, req.Body) out.Close() if err != nil { os.Remove(tmpOut) writeErr(err, w) return } sha1Hex := hex.EncodeToString(h.Sum(nil)) log.Print("upload SHA1: ", sha1Hex) reqSHA1 := req.Header.Get("X-Content-SHA1") if reqSHA1 != "" { if reqSHA1 != sha1Hex { err = fmt.Errorf("upload SHA1 does not match given SHA1: %s", reqSHA1) w.WriteHeader(http.StatusBadRequest) w.Write([]byte(err.Error() + "\n")) return } log.Print("upload SHA1 is as expected") } os.Rename(tmpOut, filePath) if err := os.WriteFile(filePath+".sha1", []byte(sha1Hex), 0644); err != nil { writeErr(err, w) return } w.WriteHeader(http.StatusCreated) default: http.NotFound(w, req) return } } func writeErr(err error, w http.ResponseWriter) { if os.IsNotExist(err) { w.WriteHeader(http.StatusNotFound) w.Write([]byte("Not found\n")) return } log.Output(2, fmt.Sprint("internal error: ", err)) w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Internal error\n")) } func hashOf(filePath string) (sha1Hex string, err error) { sha1Path := filePath + ".sha1" fileStat, err := os.Stat(filePath) if err != nil { return } sha1Stat, err := os.Stat(sha1Path) if err == nil { if sha1Stat.ModTime().After(fileStat.ModTime()) { // cached value is up-to-date sha1HexBytes, readErr := os.ReadFile(sha1Path) if readErr == nil { sha1Hex = string(sha1HexBytes) return } } } else if !os.IsNotExist(err) { // failed to stat cached value return } // no cached value could be read log.Print("hashing ", filePath) start := time.Now() // hash the input f, err := os.Open(filePath) if err != nil { return } defer f.Close() h := sha1.New() _, err = io.Copy(h, f) if err != nil { return } sha1Hex = hex.EncodeToString(h.Sum(nil)) log.Print("hashing ", filePath, " took ", time.Since(start).Truncate(time.Millisecond)) if writeErr := os.WriteFile(sha1Path, []byte(sha1Hex), 0644); writeErr != nil { log.Printf("WARNING: failed to cache SHA1: %v", writeErr) } return } func aggregateVersions(names []string) map[string]*VersionInfo { versions := make([]VersionName, 0, len(names)) for _, name := range names { rem := name segments := make([]semver.Version, 0, 5) for len(rem) != 0 { var s string s, rem, _ = strings.Cut(rem, "_") // remove non-number prefix chars s = strings.TrimFunc(s, func(c rune) bool { return !('0' <= c && c <= '9') }) ver, err := semver.NewVersion(s) if err != nil { continue } segments = append(segments, *ver) } if len(segments) == 0 { continue } versions = append(versions, VersionName{segments, name}) } sort.Slice(versions, func(i, j int) bool { return versions[i].LessThan(versions[j]) }) ret := make(map[string]*VersionInfo, len(versions)) for _, vi := range versions { v := vi.segments[0] name := vi.name for _, key := range []string{ fmt.Sprintf("%d", v.Major), fmt.Sprintf("%d.%d", v.Major, v.Minor), } { agg, ok := ret[key] if !ok { agg = &VersionInfo{} ret[key] = agg } agg.Latest = name agg.All = append(agg.All, name) } } return ret } type VersionInfo struct { Latest string All []string } type VersionName struct { segments []semver.Version name string } func (a VersionName) LessThan(b VersionName) bool { n := len(a.segments) if l := len(b.segments); l > n { n = l } for i := 0; i != n; i++ { if i >= len(a.segments) { return true } if i >= len(b.segments) { return false } va, vb := a.segments[i], b.segments[i] if !va.Equal(vb) { return va.LessThan(vb) } } return a.name < b.name }