Files
local-server/cmd/dkl-local-server/ws-clusters.go
2025-07-02 22:07:12 +02:00

332 lines
6.5 KiB
Go

package main
import (
"errors"
"fmt"
"log"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/cloudflare/cfssl/config"
"github.com/cloudflare/cfssl/csr"
"github.com/cloudflare/cfssl/signer"
restful "github.com/emicklei/go-restful"
"novit.tech/direktil/local-server/pkg/mime"
"novit.tech/direktil/pkg/localconfig"
)
// clusterSecretKVs lists, in registration order, every name passed to
// newClusterSecretKV.
var clusterSecretKVs = []string{}

// newClusterSecretKV records name in clusterSecretKVs and returns a
// KVSecrets[T] built from "clusters/" + name (presumably the storage prefix
// of that secret family — confirm against the KVSecrets definition).
func newClusterSecretKV[T any](name string) KVSecrets[T] {
	kv := KVSecrets[T]{"clusters/" + name}
	clusterSecretKVs = append(clusterSecretKVs, name)
	return kv
}
// wsListClusters responds with the names of all clusters declared in the
// local configuration, in configuration order.
func wsListClusters(req *restful.Request, resp *restful.Response) {
	cfg := wsReadConfig(resp)
	if cfg == nil {
		// NOTE(review): presumably wsReadConfig has already written an error
		// response on resp — confirm.
		return
	}

	names := make([]string, 0, len(cfg.Clusters))
	for _, c := range cfg.Clusters {
		names = append(names, c.Name)
	}
	resp.WriteEntity(names)
}
// wsReadCluster loads the configuration and returns the cluster named by the
// "cluster-name" path parameter.
//
// It returns nil when the configuration cannot be read, or when no cluster
// has that name (in which case a not-found response is written). Callers
// should stop handling the request when the result is nil.
func wsReadCluster(req *restful.Request, resp *restful.Response) (cluster *localconfig.Cluster) {
	cfg := wsReadConfig(resp)
	if cfg == nil {
		return nil
	}

	if cluster = cfg.Cluster(req.PathParameter("cluster-name")); cluster == nil {
		wsNotFound(resp)
	}
	return cluster
}
// wsCluster responds with the full definition of the cluster named by the
// "cluster-name" path parameter. Error responses are left to wsReadCluster.
func wsCluster(req *restful.Request, resp *restful.Response) {
	if cluster := wsReadCluster(req, resp); cluster != nil {
		resp.WriteEntity(cluster)
	}
}
// wsClusterAddons passes the addons of the cluster named by the
// "cluster-name" path parameter to wsRender, together with the cluster itself
// and the SSL configuration derived from the local configuration.
//
// The configuration is read exactly once. The cluster definition and the SSL
// configuration therefore come from the same snapshot, and there is a single
// failure path for reading it. A cluster with no addons yields a not-found
// response.
func wsClusterAddons(req *restful.Request, resp *restful.Response) {
	cfg := wsReadConfig(resp)
	if cfg == nil {
		return
	}

	cluster := cfg.Cluster(req.PathParameter("cluster-name"))
	if cluster == nil {
		wsNotFound(resp)
		return
	}

	if len(cluster.Addons) == 0 {
		log.Printf("cluster %q has no addons defined", cluster.Name)
		wsNotFound(resp)
		return
	}

	sslCfg, err := sslConfigFromLocalConfig(cfg)
	if err != nil {
		wsError(resp, err)
		return
	}

	wsRender(resp, sslCfg, cluster.Addons, cluster)
}
// wsClusterCACert writes the certificate of the CA identified by the
// "cluster-name" and "ca-name" path parameters, served with the mime.CERT
// content type. Unknown CAs yield a not-found response.
func wsClusterCACert(req *restful.Request, resp *restful.Response) {
	key := req.PathParameter("cluster-name") + "/" + req.PathParameter("ca-name")

	ca, found, err := clusterCAs.Get(key)
	switch {
	case err != nil:
		wsError(resp, err)
	case !found:
		wsNotFound(resp)
	default:
		resp.Header().Set("Content-Type", mime.CERT)
		resp.Write(ca.Cert)
	}
}
// wsClusterSignedCert writes a certificate previously signed by one of the
// cluster's CAs, as a file download.
//
// The entry is looked up under "<cluster-name>/<ca-name>/<name>". The first
// two parts come from path parameters and name comes from the "name" query
// parameter. The suggested filename is
// "<cluster>_<ca>_<path-escaped name>.crt".
func wsClusterSignedCert(req *restful.Request, resp *restful.Response) {
	var (
		clusterName = req.PathParameter("cluster-name")
		caName      = req.PathParameter("ca-name")
		name        = req.QueryParameter("name")
	)

	kc, found, err := clusterCASignedKeys.Get(strings.Join([]string{clusterName, caName, name}, "/"))
	if err != nil {
		wsError(resp, err)
		return
	}
	if !found {
		wsNotFound(resp)
		return
	}

	fileName := clusterName + "_" + caName + "_" + url.PathEscape(name) + ".crt"
	resp.AddHeader("Content-Type", mime.CERT)
	resp.AddHeader("Content-Disposition", "attachment; filename="+strconv.Quote(fileName))
	resp.Write(kc.Cert)
}
type SSHSignReq struct {
PubKey string
Principal string
Validity string
Options []string
}
// wsClusterSSHUserCAPubKey writes the public key returned by sshCAPubKey for
// the cluster named by the "cluster-name" path parameter.
func wsClusterSSHUserCAPubKey(req *restful.Request, resp *restful.Response) {
	pubkey, err := sshCAPubKey(req.PathParameter("cluster-name"))
	if err != nil {
		wsError(resp, err)
		return
	}

	resp.Write(pubkey)
}
// wsClusterSSHUserCASign signs an SSH public key with the SSH CA of the
// cluster named by the "cluster-name" path parameter and writes the
// certificate returned by sshCASign.
//
// The request body is decoded into an SSHSignReq. Its Validity is parsed by
// parseCertDurationRange and converted into a "<start>:<end>" string for
// sshCASign. Both timestamps use the layout YYYYMMDDHHMMSS followed by a
// literal "Z", and a zero end time becomes "forever".
func wsClusterSSHUserCASign(req *restful.Request, resp *restful.Response) {
	clusterName := req.PathParameter("cluster-name")

	signReq := SSHSignReq{}
	if err := req.ReadEntity(&signReq); err != nil {
		wsError(resp, err)
		return
	}

	now := time.Now().Truncate(time.Second)
	notBefore, notAfter, err := parseCertDurationRange(signReq.Validity, now)
	if err != nil {
		wsError(resp, fmt.Errorf("invalid validity: %w", err))
		return
	}

	// Go treats a lone "Z" in a layout as literal text, not as a zone
	// directive, so the output always claims UTC. time.Now() is in the local
	// zone, so convert before formatting. Otherwise, local wall-clock time
	// would be mislabelled as UTC and the validity window shifted by the
	// host's UTC offset.
	const sshTimestamp = "20060102150405Z"
	validity := notBefore.UTC().Format(sshTimestamp) + ":"
	if notAfter.IsZero() {
		validity += "forever"
	} else {
		validity += notAfter.UTC().Format(sshTimestamp)
	}

	log.Printf("sign ssh public key, validity %s -> %s", signReq.Validity, validity)

	cert, err := sshCASign(clusterName, []byte(signReq.PubKey), signReq.Principal, validity, signReq.Options...)
	if err != nil {
		wsError(resp, err)
		return
	}

	resp.Write(cert)
}
type KubeSignReq struct {
CSR string
User string
Group string
Validity string
}
// wsClusterKubeCASign signs a certificate request with the "cluster" CA of the
// cluster named by the "cluster-name" path parameter, using a signing profile
// limited to "client auth" usage, and writes the resulting certificate.
//
// The request body is decoded into a KubeSignReq. User becomes the subject CN
// and Group, when non-empty, the subject O. Validity is parsed by
// parseCertDurationRange and must produce an end time after now.
func wsClusterKubeCASign(req *restful.Request, resp *restful.Response) {
	clusterName := req.PathParameter("cluster-name")

	signReq := KubeSignReq{}
	if err := req.ReadEntity(&signReq); err != nil {
		wsError(resp, err)
		return
	}

	now := time.Now().Truncate(time.Second)
	notBefore, notAfter, err := parseCertDurationRange(signReq.Validity, now)
	if err != nil {
		wsError(resp, fmt.Errorf("invalid validity: %w", err))
		return
	}

	// An X.509 certificate needs a concrete expiry. parseCertDurationRange
	// leaves notAfter zero when the spec gives no end, and an end at or before
	// now is already expired. With a zero notAfter, the profile Expiry below
	// (notAfter - now) would be hugely negative. Reject such requests instead
	// of issuing an unusable certificate.
	if !notAfter.After(now) {
		wsError(resp, errors.New("invalid validity: an end time in the future is required"))
		return
	}

	var names []csr.Name
	if signReq.Group != "" {
		names = []csr.Name{{O: signReq.Group}}
	}

	ca, err := getUsableClusterCA(clusterName, "cluster")
	if err != nil {
		wsError(resp, fmt.Errorf("get cluster CA failed: %w", err))
		return
	}

	caSigner, err := ca.Signer(&config.Signing{
		Default: &config.SigningProfile{
			Usage:  []string{"client auth"},
			Expiry: notAfter.Sub(now),
		},
	})
	if err != nil {
		wsError(resp, err)
		return
	}

	// Named certReq rather than csr so the imported csr package is not shadowed.
	certReq := signer.SignRequest{
		Request: signReq.CSR,
		Subject: &signer.Subject{
			CN:    signReq.User,
			Names: names,
		},
		NotBefore: notBefore,
		NotAfter:  notAfter,
	}

	cert, err := caSigner.Sign(certReq)
	if err != nil {
		wsError(resp, err)
		return
	}

	resp.Write(cert)
}
// parseCertDurationRange parses a certificate validity spec relative to now.
//
// The spec is either "<end>" or "<start>:<end>", each side being an
// expression accepted by parseCertDuration.
//
//   - An empty spec returns zero times and no error.
//   - Otherwise, when the start is absent or empty, notBefore defaults to five
//     minutes before now.
//   - An empty end leaves notAfter as the zero time.
//
// On error, the returned times are not meaningful.
func parseCertDurationRange(d string, now time.Time) (notBefore, notAfter time.Time, err error) {
	if d == "" {
		return notBefore, notAfter, nil
	}

	end := d
	if start, rest, hasStart := strings.Cut(d, ":"); hasStart {
		notBefore, err = parseCertDuration(start, now)
		if err != nil {
			return notBefore, notAfter, err
		}
		end = rest
	}

	notAfter, err = parseCertDuration(end, now)
	if err != nil {
		return notBefore, notAfter, err
	}

	// Backdate a missing start slightly (presumably to tolerate clock skew
	// between signer and verifier).
	if notBefore.IsZero() {
		notBefore = now.Add(-5 * time.Minute)
	}
	return notBefore, notAfter, nil
}
// durRegex matches one leading duration component: an optional sign, a
// decimal quantity and a single-letter unit (y=years, M=months, d=days,
// w=weeks, h=hours, m=minutes, s=seconds).
var durRegex = regexp.MustCompile("^([+-]?)([0-9]+)([yMdwhms])")

// parseCertDuration evaluates a relative duration expression against now and
// returns the resulting instant.
//
// The expression is a concatenation of components such as "1y", "-2d" or
// "1d12h". A sign applies to its own component and to every following one
// until another sign appears. For example, "-1d2h" is 1 day and 2 hours
// before now, while "-1d+2h" is 1 day before now plus 2 hours. Calendar units
// (y, M, d, w) are applied with time.Time.AddDate; clock units (h, m, s) with
// time.Time.Add.
//
// An empty expression returns the zero time and no error. Malformed input, or
// a quantity too large to represent, returns the zero time and an error.
func parseCertDuration(d string, now time.Time) (t time.Time, err error) {
	if d == "" {
		return time.Time{}, nil
	}

	const (
		maxInt      = int(^uint(0) >> 1)
		maxDuration = time.Duration(1<<63 - 1)
	)

	direction := 1
	t = now
	for d != "" {
		match := durRegex.FindStringSubmatch(d)
		if match == nil {
			return time.Time{}, errors.New("invalid duration: " + strconv.Quote(d))
		}
		d = d[len(match[0]):]

		// The sign is sticky: an unsigned component keeps the previous one.
		switch match[1] {
		case "+":
			direction = 1
		case "-":
			direction = -1
		}

		// The regex only lets digits through, so Atoi can fail only when the
		// value is out of range. Its clamped result must not be used.
		qty, convErr := strconv.Atoi(match[2])
		if convErr != nil {
			return time.Time{}, errors.New("duration out of range: " + strconv.Quote(match[0]))
		}
		n := qty * direction // qty >= 0, so negation cannot overflow

		var unit time.Duration
		switch match[3] {
		case "y":
			t = t.AddDate(n, 0, 0)
		case "M":
			t = t.AddDate(0, n, 0)
		case "d":
			t = t.AddDate(0, 0, n)
		case "w":
			if qty > maxInt/7 {
				return time.Time{}, errors.New("duration out of range: " + strconv.Quote(match[0]))
			}
			t = t.AddDate(0, 0, 7*n)
		case "h":
			unit = time.Hour
		case "m":
			unit = time.Minute
		case "s":
			unit = time.Second
		}

		if unit != 0 {
			// Multiplying a Duration wraps silently on overflow; reject
			// quantities that cannot be represented.
			if time.Duration(qty) > maxDuration/unit {
				return time.Time{}, errors.New("duration out of range: " + strconv.Quote(match[0]))
			}
			t = t.Add(time.Duration(n) * unit)
		}
	}
	return t, nil
}