package etcdb import ( "context" "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/sha512" "encoding/base64" "encoding/binary" "encoding/hex" "errors" "fmt" "io" "log" "strings" "sync" "syscall" "golang.org/x/crypto/argon2" ) // encrypted key hash -> master key func Encrypted(spec EtcdSpec) *EncryptedSpec { v := &EncryptedSpec{ keys: spec.Sub("keys/"), data: spec.Sub("data/"), } v.l.Lock() return v } type EncryptedSpec struct { l sync.Mutex keys, data EtcdSpec masterKey []byte // TODO syscall.MemSecret when kernels are up to date salt1, salt2 []byte } func (es *EncryptedSpec) keyPairFromPassword(password []byte) (key, pub []byte) { key = argon2.IDKey(password, es.salt1, 1, 64*1024, 4, 32) pub = argon2.IDKey(key, es.salt2, 1, 64*1024, 4, 32) return } func (es *EncryptedSpec) keysIn(ctx context.Context) DB { return HexDB{es.keys.In(ctx)} } func (es *EncryptedSpec) loadSalts(ctx context.Context) error { if es.salt1 != nil && es.salt2 != nil { return nil } keySet := es.keysIn(ctx) for _, i := range []struct { key string target *[]byte }{ {"salt1", &es.salt1}, {"salt2", &es.salt2}, } { v, ok, err := keySet.Get(i.key) if err != nil { return err } if !ok { v = make([]byte, 64) readRandom(v) keySet.Put(i.key, v) } *i.target = v } return nil } func (es *EncryptedSpec) IsOpen() bool { return len(es.masterKey) != 0 } func (es *EncryptedSpec) WaitOpen() { es.l.Lock() es.l.Unlock() } func (es *EncryptedSpec) SetMasterKey(key []byte) { if len(key) != 32 { panic("wrong key size") } defer memzero(key) masterKey := make([]byte, 32) if err := syscall.Mlock(masterKey); err != nil { panic(fmt.Errorf("mlock(masterKey) failed: %w", err)) } for i := range key { masterKey[i] = key[i] } es.l.Unlock() } func (es *EncryptedSpec) Unlock(ctx context.Context, password []byte) (err error) { err = es.loadSalts(ctx) if err != nil { return } key, pubKey := es.keyPairFromPassword(password) memzero(password) defer memzero(key) keySet := es.keysIn(ctx) hasKey := false keySet.ForEach(func(k string, _ []byte) (cont bool) { if strings.HasPrefix(k, "0x") { hasKey = true } return !hasKey }) masterKey := make([]byte, 32) if err = syscall.Mlock(masterKey); err != nil { return fmt.Errorf("mlock(masterKey) failed: %w", err) } if !hasKey { // no keys, initialize readRandom(masterKey) encryptedMaster := aesEncrypt(masterKey, key) err = keySet.Put("0x"+hex.EncodeToString(pubKey), encryptedMaster) if err != nil { return } es.masterKey = masterKey es.l.Unlock() return } // has keys, lookup the key encryptedMaster, ok, err := keySet.Get("0x" + hex.EncodeToString(pubKey)) if err != nil { return } if !ok { err = fmt.Errorf("no such key") return } aesDecryptTo(masterKey, encryptedMaster, key) es.masterKey = masterKey es.l.Unlock() return } func (es *EncryptedSpec) Lock() { if !es.l.TryLock() { panic("double lock") } memzero(es.masterKey) syscall.Munlock(es.masterKey) es.masterKey = nil } func (es *EncryptedSpec) AddKey(ctx context.Context, password []byte) error { if es.masterKey == nil { panic("no masterKey") } key, pubKey := es.keyPairFromPassword(password) memzero(password) encryptedMaster := func() []byte { defer memzero(key) return aesEncrypt(es.masterKey, key) }() return es.keysIn(ctx).Put("0x"+hex.EncodeToString(pubKey), encryptedMaster) } func (es *EncryptedSpec) In(ctx context.Context) DB { return EncryptedDB{ es: es, wrappedDB: es.data.In(ctx), } } func (es *EncryptedSpec) Sub(prefix string) EncryptedSubSpec { return EncryptedSubSpec{ es: es, data: es.data.Sub(prefix), } } type EncryptedSubSpec struct { es *EncryptedSpec data EtcdSpec } func (spec EncryptedSubSpec) In(ctx context.Context) DB { return EncryptedDB{ es: spec.es, wrappedDB: spec.data.In(ctx), } } type EncryptedDB struct { es *EncryptedSpec wrappedDB DB } func (e EncryptedDB) obscureKey(keyBytes []byte) string { h := sha512.Sum512(keyBytes) return base64.RawStdEncoding.EncodeToString(h[:]) } func (e EncryptedDB) encodeKV(key string, value []byte) (outKey string, outValue []byte, err error) { keyBytes := []byte(key) if len(keyBytes) > 0xffff { err = errors.New("key too long") return } outKey = e.obscureKey(keyBytes) outValue = make([]byte, 0, 2+len(key)+len(value)) outValue = binary.BigEndian.AppendUint16(outValue, uint16(len(key))) outValue = append(outValue, keyBytes...) outValue = append(outValue, value...) outValue = aesEncrypt(outValue, e.es.masterKey) return } func (e EncryptedDB) decodeKV(encryptedKV []byte) (key string, value []byte, err error) { kv := aesDecrypt(encryptedKV, e.es.masterKey) if len(kv) < 2 { err = errors.New("value too short for key len") return } keyLen := int(binary.BigEndian.Uint16(kv[:2])) kv = kv[2:] if len(kv) < keyLen { err = errors.New("value too short for key") return } key = string(kv[:keyLen]) value = kv[keyLen:] return } func (e EncryptedDB) Get(key string) (value []byte, ok bool, err error) { encryptedValue, ok, err := e.wrappedDB.Get(e.obscureKey([]byte(key))) if !ok || err != nil { return } _, value, err = e.decodeKV(encryptedValue) return } func (e EncryptedDB) Put(key string, value []byte) error { dbKey, encryptedValue, err := e.encodeKV(key, value) if err != nil { return err } return e.wrappedDB.Put(dbKey, encryptedValue) } func (e EncryptedDB) Del(key string) error { return e.wrappedDB.Del(e.obscureKey([]byte(key))) } func (e EncryptedDB) ForEach(callback func(key string, value []byte) (cont bool)) (err error) { forEachErr := e.wrappedDB.ForEach(func(key string, encryptedValue []byte) (cont bool) { realKey, value, decodeErr := e.decodeKV(encryptedValue) if decodeErr != nil { err = fmt.Errorf("failed to decode value under key %q: %w", key, decodeErr) return false } return callback(realKey, value) }) if err != nil { return } return forEachErr } func (e EncryptedDB) Watch(rev int64) <-chan WatchEvent { ch := make(chan WatchEvent, 1) go func() { defer close(ch) inCh := e.wrappedDB.Watch(rev) for evt := range inCh { encryptedValue := evt.Value evt.Value = nil if len(encryptedValue) == 0 { ch <- evt continue } realKey, value, err := e.decodeKV(encryptedValue) if err != nil { evt.Err = fmt.Errorf("failed to decode value: %w", err) ch <- evt continue } evt.Key = realKey evt.Value = value ch <- evt } }() return ch } func aesEncrypt(data, key []byte) []byte { c, err := aes.NewCipher(key) if err != nil { panic(fmt.Errorf("failed to init AES: %w", err)) } // output = initialization vector + encrypted data const ivSize = aes.BlockSize output := make([]byte, ivSize+len(data)) iv := output[:ivSize] readRandom(iv) cipher.NewCFBEncrypter(c, iv).XORKeyStream(output[ivSize:], data) return output } func aesDecrypt(data, key []byte) []byte { dst := make([]byte, len(data)-aes.BlockSize) aesDecryptTo(dst, data, key) return dst } func aesDecryptTo(dst, data, key []byte) { c, err := aes.NewCipher(key) if err != nil { panic(fmt.Errorf("failed to init AES: %w", err)) } const ivSize = aes.BlockSize iv := data[:ivSize] cipher.NewCFBDecrypter(c, iv).XORKeyStream(dst, data[ivSize:]) } func readRandom(dst []byte) { if _, err := io.ReadFull(rand.Reader, dst); err != nil { log.Panic("failed to read random bytes: ", err) } } func memzero(ba []byte) { for i := range ba { ba[i] = 0 } }