// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Helper code for parsing a protocol buffer

package protolazy

import (
	"errors"
	"fmt"
	"io"

	"google.golang.org/protobuf/encoding/protowire"
)

// BufferReader reads an encoded protobuf message held in Buf, tracking the
// current read position in Pos.
type BufferReader struct {
	// Buf is the encoded message being read.
	Buf []byte
	// Pos is the index in Buf of the next byte to read.
	Pos int
}

// NewBufferReader returns a BufferReader positioned at the start of buf.
// The reader refers to buf directly; the bytes are not copied.
func NewBufferReader(buf []byte) BufferReader {
	var r BufferReader
	r.Buf = buf
	return r
}

var (
	// errOutOfBounds is a sentinel whose message reports an out-of-bounds
	// access. NOTE(review): not referenced in this part of the file; confirm
	// where it is used.
	errOutOfBounds = errors.New("protobuf decoding: out of bounds")

	// errOverflow is returned by the varint decoders and SkipVarint when a
	// varint is not terminated within the maximum number of bytes they scan.
	errOverflow = errors.New("proto: integer overflow")
)

// DecodeVarintSlow decodes a base-128 varint starting at b.Pos, checking the
// buffer length before every byte. It is the fallback used by DecodeVarint,
// DecodeVarint32 and SkipVarint when too few bytes remain for their unrolled
// fast paths.
//
// On success it returns the decoded value and advances b.Pos past the varint.
// It returns io.ErrUnexpectedEOF if the buffer ends before a terminating byte
// (one without the 0x80 continuation bit), and errOverflow if no terminating
// byte appears within 10 bytes. On error the returned value is 0 and b.Pos is
// left unchanged.
//
// Only the lowest bit of a 10th byte contributes to the result; its higher
// bits are shifted out of the 64-bit value.
func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
	i := b.Pos
	l := len(b.Buf)

	for shift := uint(0); shift < 64; shift += 7 {
		if i >= l {
			// Report 0 rather than the partially accumulated bits, matching
			// the error returns of DecodeVarint's fast path.
			return 0, io.ErrUnexpectedEOF
		}
		v := b.Buf[i]
		i++
		x |= (uint64(v) & 0x7F) << shift
		if v < 0x80 {
			b.Pos = i
			return x, nil
		}
	}

	// The number is too large to represent in a 64-bit value.
	return 0, errOverflow
}

// DecodeVarint decodes a base-128 varint at the current position and advances
// b.Pos past it.
//
// It returns io.ErrUnexpectedEOF when no bytes remain. A single-byte varint
// (first byte < 0x80) is returned directly. If the first byte has its
// continuation bit set and fewer than 10 bytes remain, decoding is delegated
// to DecodeVarintSlow. Otherwise the bytes are decoded by a hand-unrolled
// sequence that needs no per-byte length checks, returning errOverflow (with
// b.Pos unchanged) if none of the 10 bytes terminates the varint. Only the
// lowest bit of a 10th byte contributes to the result.
func (b *BufferReader) DecodeVarint() (x uint64, err error) {
	p, data := b.Pos, b.Buf

	switch {
	case p >= len(data):
		return 0, io.ErrUnexpectedEOF
	case data[p] < 0x80:
		b.Pos++
		return uint64(data[p]), nil
	case len(data)-p < 10:
		return b.DecodeVarintSlow()
	}

	// At least 10 bytes remain, so every index read below is in range.
	var c uint64
	// The first byte is already known to carry the continuation bit.
	x = uint64(data[p]) & 0x7f
	p++

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 7
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 14
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 21
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 28
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 35
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 42
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 49
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 56
	if c < 0x80 {
		goto finish
	}

	c = uint64(data[p])
	p++
	x |= (c & 0x7f) << 63
	if c < 0x80 {
		goto finish
	}

	return 0, errOverflow

finish:
	b.Pos = p
	return x, nil
}

// DecodeVarint32 decodes a varint at the current position as a uint32 and
// advances b.Pos past it.
//
// It returns io.ErrUnexpectedEOF when no bytes remain. A single-byte varint is
// returned directly. If the first byte has its continuation bit set and fewer
// than 5 bytes remain, DecodeVarintSlow decodes the value and its result is
// converted to uint32. Otherwise at most 5 bytes are decoded by an unrolled
// sequence: if the 5th byte still has its continuation bit set, errOverflow is
// returned and b.Pos is unchanged. Bits of the 5th byte that land above bit 31
// are discarded by the uint32 arithmetic rather than reported.
func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
	p, data := b.Pos, b.Buf

	switch {
	case p >= len(data):
		return 0, io.ErrUnexpectedEOF
	case data[p] < 0x80:
		b.Pos++
		return uint32(data[p]), nil
	case len(data)-p < 5:
		wide, slowErr := b.DecodeVarintSlow()
		return uint32(wide), slowErr
	}

	// At least 5 bytes remain, so every index read below is in range.
	var c uint32
	// The first byte is already known to carry the continuation bit.
	x = uint32(data[p]) & 0x7f
	p++

	c = uint32(data[p])
	p++
	x |= (c & 0x7f) << 7
	if c < 0x80 {
		goto finish
	}

	c = uint32(data[p])
	p++
	x |= (c & 0x7f) << 14
	if c < 0x80 {
		goto finish
	}

	c = uint32(data[p])
	p++
	x |= (c & 0x7f) << 21
	if c < 0x80 {
		goto finish
	}

	c = uint32(data[p])
	p++
	x |= (c & 0x7f) << 28
	if c < 0x80 {
		goto finish
	}

	return 0, errOverflow

finish:
	b.Pos = p
	return x, nil
}

// SkipValue skips the value of a field whose tag has already been read. The
// wire type in the low three bits of tag selects how the value is skipped.
// For a start-group tag, the whole group, including nested groups, is skipped
// via SkipGroup. Any other wire type, including end-group, yields an error.
func (b *BufferReader) SkipValue(tag uint32) (err error) {
	wt := tag & 0x7
	switch protowire.Type(wt) {
	case protowire.VarintType:
		return b.SkipVarint()
	case protowire.Fixed32Type:
		return b.SkipFixed32()
	case protowire.Fixed64Type:
		return b.SkipFixed64()
	case protowire.BytesType:
		// Length-delimited: a varint byte count followed by that many bytes.
		length, lenErr := b.DecodeVarint32()
		if lenErr != nil {
			return lenErr
		}
		return b.Skip(int(length))
	case protowire.StartGroupType:
		return b.SkipGroup(tag)
	}
	return fmt.Errorf("Unexpected wire type (%d)", wt)
}

// SkipGroup skips the remainder of a group whose start-group tag has already
// been consumed; tag is that start-group tag. Nested groups are tracked with
// an explicit stack of start tags instead of recursion. Each end-group tag
// must match the innermost open start tag (as compared by protoFieldNumber),
// and the function returns once the group identified by tag is closed.
//
// It returns any error from the underlying decode/skip helpers, an error if
// an end-group tag does not match the open group, and an error for a wire
// type that is not one of the six handled below.
func (b *BufferReader) SkipGroup(tag uint32) (err error) {
	tagStack := make([]uint32, 0, 16)
	tagStack = append(tagStack, tag)
	var n uint32
	for len(tagStack) > 0 {
		tag, err = b.DecodeVarint32()
		if err != nil {
			return err
		}
		switch protowire.Type(tag & 0x7) {
		case protowire.VarintType:
			err = b.SkipVarint()
		case protowire.Fixed64Type:
			err = b.SkipFixed64()
		case protowire.BytesType:
			n, err = b.DecodeVarint32()
			if err == nil {
				err = b.Skip(int(n))
			}
		case protowire.StartGroupType:
			tagStack = append(tagStack, tag)
		case protowire.Fixed32Type:
			err = b.SkipFixed32()
		case protowire.EndGroupType:
			if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
				tagStack = tagStack[:len(tagStack)-1]
			} else {
				err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
					protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
			}
		default:
			// Wire types 6 and 7 are undefined. Ignoring them would leave the
			// field's payload to be misread as the next tag, so reject them
			// the same way SkipValue does.
			err = fmt.Errorf("Unexpected wire type (%d)", tag&0x7)
		}
		if err != nil {
			return err
		}
	}
	return nil
}

// SkipVarint efficiently advances past one varint at the current position
// without decoding its value.
//
// When fewer than 10 bytes remain, it defers to DecodeVarintSlow, which also
// detects truncation, and discards the decoded value. Otherwise it inspects
// the next 10 bytes directly. It returns errOverflow if none of them
// terminates the varint. b.Pos is left unchanged on any error.
func (b *BufferReader) SkipVarint() (err error) {
	start := b.Pos

	if len(b.Buf)-start < 10 {
		_, err = b.DecodeVarintSlow()
		return err
	}

	// The varint's length is the index of the first byte without the 0x80
	// continuation bit, plus one; at most 10 bytes are examined.
	window := b.Buf[start : start+10]
	var size int
	switch {
	case window[0] < 0x80:
		size = 1
	case window[1] < 0x80:
		size = 2
	case window[2] < 0x80:
		size = 3
	case window[3] < 0x80:
		size = 4
	case window[4] < 0x80:
		size = 5
	case window[5] < 0x80:
		size = 6
	case window[6] < 0x80:
		size = 7
	case window[7] < 0x80:
		size = 8
	case window[8] < 0x80:
		size = 9
	case window[9] < 0x80:
		size = 10
	default:
		return errOverflow
	}
	b.Pos = start + size
	return nil
}

// Skip advances the read position by n bytes. If n is negative or fewer than
// n bytes remain, it returns io.ErrUnexpectedEOF and leaves the position
// unchanged.
func (b *BufferReader) Skip(n int) (err error) {
	// Compare against the remaining length instead of computing b.Pos+n,
	// which can overflow for huge n. Negative n must be rejected explicitly:
	// a uint32 length converted to int is negative on 32-bit platforms and
	// would otherwise move the position backwards.
	if n < 0 || n > len(b.Buf)-b.Pos {
		return io.ErrUnexpectedEOF
	}
	b.Pos += n
	return nil
}

// SkipFixed64 advances past an 8-byte fixed-width value, failing like Skip
// if fewer than 8 bytes remain.
func (b *BufferReader) SkipFixed64() (err error) {
	const fixed64Size = 8
	err = b.Skip(fixed64Size)
	return err
}

// SkipFixed32 advances past a 4-byte fixed-width value, failing like Skip
// if fewer than 4 bytes remain.
func (b *BufferReader) SkipFixed32() (err error) {
	const fixed32Size = 4
	err = b.Skip(fixed32Size)
	return err
}

// SkipBytes reads a varint length prefix at the current position and then
// skips that many bytes. Errors from decoding the length or from Skip are
// returned unchanged.
func (b *BufferReader) SkipBytes() (err error) {
	var length uint32
	if length, err = b.DecodeVarint32(); err != nil {
		return err
	}
	return b.Skip(int(length))
}

// Done reports whether the read position is exactly at the end of the buffer.
func (b *BufferReader) Done() bool {
	remaining := len(b.Buf) - b.Pos
	return remaining == 0
}

// Remaining returns the number of bytes between the read position and the
// end of the buffer.
func (b *BufferReader) Remaining() (n int) {
	n = len(b.Buf)
	n -= b.Pos
	return n
}