etcdb/encrypted-set.go
2022-12-05 11:11:22 +01:00

388 lines
7.4 KiB
Go

package etcdb
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha512"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"strings"
"sync"
"syscall"
"golang.org/x/crypto/argon2"
)
// Encrypted wraps spec in an EncryptedSpec. Key material lives under the
// "keys/" sub-prefix: the salts plus one "0x…" entry per password, mapping a
// password-derived name to the master key encrypted with that password.
// Encrypted records live under "data/".
//
// The returned spec starts out closed: its internal mutex is already held and
// is only released by Unlock or SetMasterKey, so WaitOpen blocks until then.
func Encrypted(spec EtcdSpec) *EncryptedSpec {
	es := new(EncryptedSpec)
	es.keys = spec.Sub("keys/")
	es.data = spec.Sub("data/")
	es.l.Lock() // held for as long as the spec is closed
	return es
}
// EncryptedSpec stores records encrypted with a 32-byte AES master key. The
// master key itself is kept in the keys store, once per password, encrypted
// with a key derived from that password.
//
// The spec is closed while l is held (as it is right after Encrypted). It is
// open once Unlock or SetMasterKey has loaded the master key and released l.
type EncryptedSpec struct {
	// l is held while no master key is loaded; WaitOpen waits on it.
	l sync.Mutex

	keys EtcdSpec // salts and per-password encrypted master keys
	data EtcdSpec // encrypted records

	masterKey    []byte // TODO syscall.MemSecret when kernels are up to date
	salt1, salt2 []byte // Argon2id salts, populated by loadSalts
}
// keyPairFromPassword derives two 32-byte values with Argon2id (1 pass,
// 64 MiB, 4 lanes).
//   - key is derived from password and salt1. It encrypts the master key.
//   - pub is derived from key and salt2. It names the stored entry, so the
//     entry name is not the encryption key itself.
//
// The salts must already be loaded (see loadSalts).
func (es *EncryptedSpec) keyPairFromPassword(password []byte) (key, pub []byte) {
	const (
		passes    = 1
		memoryKiB = 64 * 1024
		lanes     = 4
		outLen    = 32
	)
	key = argon2.IDKey(password, es.salt1, passes, memoryKiB, lanes, outLen)
	pub = argon2.IDKey(key, es.salt2, passes, memoryKiB, lanes, outLen)
	return key, pub
}
// keysIn returns the key-material store bound to ctx, wrapped in HexDB.
func (es *EncryptedSpec) keysIn(ctx context.Context) DB {
	raw := es.keys.In(ctx)
	return HexDB{raw}
}
// loadSalts makes sure es.salt1 and es.salt2 are populated. Each salt is read
// from the key store under "salt1" / "salt2". A missing salt is generated
// (64 random bytes) and persisted first, so later calls derive the same keys.
// A salt is only cached on es once it is known to be stored.
//
// NOTE(review): two processes initializing an empty store at the same time
// could each write their own salts; nothing here guards against that.
func (es *EncryptedSpec) loadSalts(ctx context.Context) error {
	if es.salt1 != nil && es.salt2 != nil {
		return nil
	}
	keySet := es.keysIn(ctx)
	for _, s := range []struct {
		key    string
		target *[]byte
	}{
		{"salt1", &es.salt1},
		{"salt2", &es.salt2},
	} {
		v, ok, err := keySet.Get(s.key)
		if err != nil {
			return err
		}
		if !ok {
			v = make([]byte, 64)
			readRandom(v)
			// A salt that is used but not stored would make every key derived
			// from it unrecoverable, so a failed write must abort.
			if err := keySet.Put(s.key, v); err != nil {
				return fmt.Errorf("failed to store %s: %w", s.key, err)
			}
		}
		*s.target = v
	}
	return nil
}
// IsOpen reports whether a (non-empty) master key is currently loaded.
func (es *EncryptedSpec) IsOpen() bool {
	return len(es.masterKey) > 0
}
// WaitOpen blocks until the internal mutex is released (Unlock or
// SetMasterKey), then returns. It acquires the mutex and releases it again.
func (es *EncryptedSpec) WaitOpen() {
	es.l.Lock()
	defer es.l.Unlock()
}
// SetMasterKey opens the spec with an externally supplied 32-byte master key.
//
// The key is copied into an mlock'ed buffer owned by es, and the caller's
// slice is zeroed before returning. It panics if key is not 32 bytes long or
// if mlock fails.
//
// It releases the internal mutex, so (like Unlock) it must only be called
// while the spec is closed.
func (es *EncryptedSpec) SetMasterKey(key []byte) {
	if len(key) != 32 {
		panic("wrong key size")
	}
	defer memzero(key)
	masterKey := make([]byte, 32)
	if err := syscall.Mlock(masterKey); err != nil {
		panic(fmt.Errorf("mlock(masterKey) failed: %w", err))
	}
	copy(masterKey, key)
	// Install the copy. Without this the spec would be released with no key,
	// and every later encrypt/decrypt would panic.
	es.masterKey = masterKey
	es.l.Unlock()
}
func (es *EncryptedSpec) Unlock(ctx context.Context, password []byte) (err error) {
err = es.loadSalts(ctx)
if err != nil {
return
}
key, pubKey := es.keyPairFromPassword(password)
memzero(password)
defer memzero(key)
keySet := es.keysIn(ctx)
hasKey := false
keySet.ForEach(func(k string, _ []byte) (cont bool) {
if strings.HasPrefix(k, "0x") {
hasKey = true
}
return !hasKey
})
masterKey := make([]byte, 32)
if err = syscall.Mlock(masterKey); err != nil {
return fmt.Errorf("mlock(masterKey) failed: %w", err)
}
if !hasKey {
// no keys, initialize
readRandom(masterKey)
encryptedMaster := aesEncrypt(masterKey, key)
err = keySet.Put("0x"+hex.EncodeToString(pubKey), encryptedMaster)
if err != nil {
return
}
es.masterKey = masterKey
es.l.Unlock()
return
}
// has keys, lookup the key
encryptedMaster, ok, err := keySet.Get("0x" + hex.EncodeToString(pubKey))
if err != nil {
return
}
if !ok {
err = fmt.Errorf("no such key")
return
}
aesDecryptTo(masterKey, encryptedMaster, key)
es.masterKey = masterKey
es.l.Unlock()
return
}
// Lock closes the spec. It re-acquires the internal mutex, then wipes the
// master key, releases its mlock and drops the reference. It panics with
// "double lock" if the mutex is already held, i.e. the spec is already closed.
func (es *EncryptedSpec) Lock() {
	if acquired := es.l.TryLock(); !acquired {
		panic("double lock")
	}
	secret := es.masterKey
	memzero(secret)
	// Best effort: the buffer is already zeroed, so a failed munlock leaks
	// nothing sensitive.
	_ = syscall.Munlock(secret)
	es.masterKey = nil
}
// AddKey stores another password-protected copy of the current master key,
// so that password can be used with Unlock afterwards.
//
// It panics if the spec is not open. The password is zeroed before AddKey
// returns, except when it panics.
func (es *EncryptedSpec) AddKey(ctx context.Context, password []byte) error {
	if !es.IsOpen() {
		panic("no masterKey")
	}
	// In this file the salts are otherwise only loaded by Unlock. A spec
	// opened via SetMasterKey would derive this entry from empty salts, and
	// Unlock (which uses the stored salts) could never find it.
	if err := es.loadSalts(ctx); err != nil {
		memzero(password)
		return err
	}
	key, pubKey := es.keyPairFromPassword(password)
	memzero(password)
	defer memzero(key)
	encryptedMaster := aesEncrypt(es.masterKey, key)
	return es.keysIn(ctx).Put("0x"+hex.EncodeToString(pubKey), encryptedMaster)
}
// In returns a DB over the encrypted data area, bound to ctx. Records read or
// written through it use es's current master key.
func (es *EncryptedSpec) In(ctx context.Context) DB {
	inner := es.data.In(ctx)
	return EncryptedDB{wrappedDB: inner, es: es}
}
// Sub returns a spec for prefix below the encrypted data area. The returned
// spec shares es's master key.
func (es *EncryptedSpec) Sub(prefix string) EncryptedSubSpec {
	child := es.data.Sub(prefix)
	return EncryptedSubSpec{data: child, es: es}
}
// EncryptedSubSpec is a prefix inside an EncryptedSpec's data area that is
// encrypted with the parent's master key.
type EncryptedSubSpec struct {
	es   *EncryptedSpec // parent spec providing the master key
	data EtcdSpec       // prefix of the parent's data spec
}
// In returns a DB over this sub-prefix, bound to ctx and encrypted with the
// parent spec's master key.
func (spec EncryptedSubSpec) In(ctx context.Context) DB {
	inner := spec.data.In(ctx)
	return EncryptedDB{wrappedDB: inner, es: spec.es}
}
// EncryptedDB is a DB that encrypts each record with es's master key. The
// plaintext key is stored inside the ciphertext, and records are kept in
// wrappedDB under hashed names.
type EncryptedDB struct {
	es        *EncryptedSpec // source of the master key
	wrappedDB DB             // underlying store holding the ciphertext
}
// obscureKey maps a plaintext key to the name it is stored under: the
// unpadded standard base64 encoding of its SHA-512 digest.
//
// NOTE(review): the hash is unkeyed. Equal keys map to equal names, and
// guessable keys can be confirmed by hashing candidates. Changing this would
// require migrating stored data.
func (e EncryptedDB) obscureKey(keyBytes []byte) string {
	digest := sha512.Sum512(keyBytes)
	return base64.RawStdEncoding.EncodeToString(digest[:])
}
// encodeKV prepares a record for storage. It returns:
//   - outKey: the obscured storage name.
//   - outValue: the encryption of uint16-BE(len(key)) || key || value.
//
// Keys longer than 0xffff bytes are rejected because their length would not
// fit the 2-byte prefix.
func (e EncryptedDB) encodeKV(key string, value []byte) (outKey string, outValue []byte, err error) {
	raw := []byte(key)
	if len(raw) > 0xffff {
		return "", nil, errors.New("key too long")
	}
	plain := make([]byte, 2, 2+len(raw)+len(value))
	binary.BigEndian.PutUint16(plain, uint16(len(raw)))
	plain = append(plain, raw...)
	plain = append(plain, value...)
	return e.obscureKey(raw), aesEncrypt(plain, e.es.masterKey), nil
}
// decodeKV decrypts a stored record with the master key and splits it into
// the embedded plaintext key and the value. The plaintext layout is a 2-byte
// big-endian key length, the key bytes, then the value bytes.
//
// Structurally impossible records are returned as errors rather than panics.
//
// NOTE(review): CFB provides no integrity. Tampered ciphertext is not
// detected; only structurally impossible records are rejected here.
func (e EncryptedDB) decodeKV(encryptedKV []byte) (key string, value []byte, err error) {
	// aesDecrypt panics on input shorter than the IV, and stored data may be
	// corrupt, so check first.
	if len(encryptedKV) < aes.BlockSize {
		err = errors.New("value too short for IV")
		return
	}
	kv := aesDecrypt(encryptedKV, e.es.masterKey)
	if len(kv) < 2 {
		err = errors.New("value too short for key len")
		return
	}
	keyLen := int(binary.BigEndian.Uint16(kv[:2]))
	kv = kv[2:]
	if len(kv) < keyLen {
		err = errors.New("value too short for key")
		return
	}
	key = string(kv[:keyLen])
	value = kv[keyLen:]
	return
}
// Get looks up key and returns its decrypted value. ok reports whether a
// record exists under the key's obscured name.
//
// If that record cannot be decoded, or its embedded plaintext key differs
// from key, Get returns ok=true, a nil value and a non-nil error.
func (e EncryptedDB) Get(key string) (value []byte, ok bool, err error) {
	encryptedValue, ok, err := e.wrappedDB.Get(e.obscureKey([]byte(key)))
	if !ok || err != nil {
		return
	}
	realKey, decoded, err := e.decodeKV(encryptedValue)
	if err != nil {
		return nil, ok, err
	}
	// Every record embeds its plaintext key. A mismatch means this ciphertext
	// sits under a name that does not belong to it (e.g. copied from another
	// record), so it must not be returned as this key's value.
	if realKey != key {
		return nil, ok, fmt.Errorf("decoded key does not match requested key %q", key)
	}
	return decoded, ok, nil
}
// Put encrypts key and value together and writes the result to the wrapped
// DB under the obscured name of key.
func (e EncryptedDB) Put(key string, value []byte) error {
	name, sealed, err := e.encodeKV(key, value)
	if err == nil {
		return e.wrappedDB.Put(name, sealed)
	}
	return err
}
// Del removes the record stored under the obscured name of key.
func (e EncryptedDB) Del(key string) error {
	name := e.obscureKey([]byte(key))
	return e.wrappedDB.Del(name)
}
func (e EncryptedDB) ForEach(callback func(key string, value []byte) (cont bool)) (err error) {
forEachErr := e.wrappedDB.ForEach(func(key string, encryptedValue []byte) (cont bool) {
realKey, value, decodeErr := e.decodeKV(encryptedValue)
if decodeErr != nil {
err = fmt.Errorf("failed to decode value under key %q: %w", key, decodeErr)
return false
}
return callback(realKey, value)
})
if err != nil {
return
}
return forEachErr
}
// Watch relays events from the wrapped DB's Watch(rev), decrypting them on
// the way. The returned channel is closed when the wrapped channel closes.
//
// How each event is handled:
//   - Empty value: forwarded with Value=nil and the stored key untouched.
//   - Decodable value: Key and Value are replaced by the plaintext.
//   - Undecodable value: Err is set and Value is nil.
func (e EncryptedDB) Watch(rev int64) <-chan WatchEvent {
	out := make(chan WatchEvent, 1)
	go func() {
		defer close(out)
		for evt := range e.wrappedDB.Watch(rev) {
			sealed := evt.Value
			evt.Value = nil
			if len(sealed) > 0 {
				if realKey, value, err := e.decodeKV(sealed); err != nil {
					evt.Err = fmt.Errorf("failed to decode value: %w", err)
				} else {
					evt.Key, evt.Value = realKey, value
				}
			}
			out <- evt
		}
	}()
	return out
}
func aesEncrypt(data, key []byte) []byte {
c, err := aes.NewCipher(key)
if err != nil {
panic(fmt.Errorf("failed to init AES: %w", err))
}
// output = initialization vector + encrypted data
const ivSize = aes.BlockSize
output := make([]byte, ivSize+len(data))
iv := output[:ivSize]
readRandom(iv)
cipher.NewCFBEncrypter(c, iv).XORKeyStream(output[ivSize:], data)
return output
}
// aesDecrypt reverses aesEncrypt: data is IV || ciphertext, and the result is
// a new slice holding the decrypted payload. It panics if data is shorter
// than one AES block or key has an invalid length.
func aesDecrypt(data, key []byte) []byte {
	payloadLen := len(data) - aes.BlockSize
	out := make([]byte, payloadLen)
	aesDecryptTo(out, data, key)
	return out
}
func aesDecryptTo(dst, data, key []byte) {
c, err := aes.NewCipher(key)
if err != nil {
panic(fmt.Errorf("failed to init AES: %w", err))
}
const ivSize = aes.BlockSize
iv := data[:ivSize]
cipher.NewCFBDecrypter(c, iv).XORKeyStream(dst, data[ivSize:])
}
// readRandom fills dst completely from crypto/rand and panics (via log.Panic)
// if that fails. An empty dst is a no-op.
func readRandom(dst []byte) {
	_, err := io.ReadFull(rand.Reader, dst)
	if err == nil {
		return
	}
	log.Panic("failed to read random bytes: ", err)
}
// memzero overwrites every byte of ba with zero, in place. A nil or empty
// slice is left as is.
func memzero(ba []byte) {
	for i := 0; i < len(ba); i++ {
		ba[i] = 0
	}
}