build: move e2e dependencies into e2e/go.mod

Several packages are only used while running the e2e suite. These
packages are less important to update, as the they can not influence the
final executable that is part of the Ceph-CSI container-image.

By moving these dependencies out of the main Ceph-CSI go.mod, it is
easier to identify if a reported CVE affects Ceph-CSI, or only the
testing (like most of the Kubernetes CVEs).

Signed-off-by: Niels de Vos <ndevos@ibm.com>
This commit is contained in:
Niels de Vos
2025-03-04 08:57:28 +01:00
committed by mergify[bot]
parent 15da101b1b
commit bec6090996
8047 changed files with 1407827 additions and 3453 deletions

View File

@ -0,0 +1,177 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"strconv"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// Export is a zero-length named type that exists only to export a set of
// functions that we do not want to appear in godoc.
type Export struct{}
// NewError formats a string according to the format specifier and arguments and
// returns an error that has a "proto" prefix.
func (Export) NewError(f string, x ...any) error {
return errors.New(f, x...)
}
// enum is any enum type generated by protoc-gen-go
// and must be a named int32 type.
type enum = any
// EnumOf returns the protoreflect.Enum interface over e.
// It returns nil if e is nil.
func (Export) EnumOf(e enum) protoreflect.Enum {
switch e := e.(type) {
case nil:
return nil
case protoreflect.Enum:
return e
default:
return legacyWrapEnum(reflect.ValueOf(e))
}
}
// EnumDescriptorOf returns the protoreflect.EnumDescriptor for e.
// It returns nil if e is nil.
func (Export) EnumDescriptorOf(e enum) protoreflect.EnumDescriptor {
switch e := e.(type) {
case nil:
return nil
case protoreflect.Enum:
return e.Descriptor()
default:
return LegacyLoadEnumDesc(reflect.TypeOf(e))
}
}
// EnumTypeOf returns the protoreflect.EnumType for e.
// It returns nil if e is nil.
func (Export) EnumTypeOf(e enum) protoreflect.EnumType {
switch e := e.(type) {
case nil:
return nil
case protoreflect.Enum:
return e.Type()
default:
return legacyLoadEnumType(reflect.TypeOf(e))
}
}
// EnumStringOf returns the enum value as a string, either as the name if
// the number is resolvable, or the number formatted as a string.
func (Export) EnumStringOf(ed protoreflect.EnumDescriptor, n protoreflect.EnumNumber) string {
ev := ed.Values().ByNumber(n)
if ev != nil {
return string(ev.Name())
}
return strconv.Itoa(int(n))
}
// message is any message type generated by protoc-gen-go
// and must be a pointer to a named struct type.
type message = any
// legacyMessageWrapper wraps a v2 message as a v1 message.
type legacyMessageWrapper struct{ m protoreflect.ProtoMessage }
func (m legacyMessageWrapper) Reset() { proto.Reset(m.m) }
func (m legacyMessageWrapper) String() string { return Export{}.MessageStringOf(m.m) }
func (m legacyMessageWrapper) ProtoMessage() {}
// ProtoMessageV1Of converts either a v1 or v2 message to a v1 message.
// It returns nil if m is nil.
func (Export) ProtoMessageV1Of(m message) protoiface.MessageV1 {
switch mv := m.(type) {
case nil:
return nil
case protoiface.MessageV1:
return mv
case unwrapper:
return Export{}.ProtoMessageV1Of(mv.protoUnwrap())
case protoreflect.ProtoMessage:
return legacyMessageWrapper{mv}
default:
panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
}
}
func (Export) protoMessageV2Of(m message) protoreflect.ProtoMessage {
switch mv := m.(type) {
case nil:
return nil
case protoreflect.ProtoMessage:
return mv
case legacyMessageWrapper:
return mv.m
case protoiface.MessageV1:
return nil
default:
panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
}
}
// ProtoMessageV2Of converts either a v1 or v2 message to a v2 message.
// It returns nil if m is nil.
func (Export) ProtoMessageV2Of(m message) protoreflect.ProtoMessage {
if m == nil {
return nil
}
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv
}
return legacyWrapMessage(reflect.ValueOf(m)).Interface()
}
// MessageOf returns the protoreflect.Message interface over m.
// It returns nil if m is nil.
func (Export) MessageOf(m message) protoreflect.Message {
if m == nil {
return nil
}
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect()
}
return legacyWrapMessage(reflect.ValueOf(m))
}
// MessageDescriptorOf returns the protoreflect.MessageDescriptor for m.
// It returns nil if m is nil.
func (Export) MessageDescriptorOf(m message) protoreflect.MessageDescriptor {
if m == nil {
return nil
}
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Descriptor()
}
return LegacyLoadMessageDesc(reflect.TypeOf(m))
}
// MessageTypeOf returns the protoreflect.MessageType for m.
// It returns nil if m is nil.
func (Export) MessageTypeOf(m message) protoreflect.MessageType {
if m == nil {
return nil
}
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Type()
}
return legacyLoadMessageType(reflect.TypeOf(m), "")
}
// MessageStringOf returns the message value as a string,
// which is the message serialized in the protobuf text format.
func (Export) MessageStringOf(m protoreflect.ProtoMessage) string {
return prototext.MarshalOptions{Multiline: false}.Format(m)
}

View File

@ -0,0 +1,128 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"strconv"
"sync/atomic"
"unsafe"
"google.golang.org/protobuf/reflect/protoreflect"
)
func (Export) UnmarshalField(msg any, fieldNum int32) {
UnmarshalField(msg.(protoreflect.ProtoMessage).ProtoReflect(), protoreflect.FieldNumber(fieldNum))
}
// Present checks the presence set for a certain field number (zero
// based, ordered by appearance in original proto file). part is
// a pointer to the correct element in the bitmask array, num is the
// field number unaltered. Example (field number 70 -> part =
// &m.XXX_presence[1], num = 70)
func (Export) Present(part *uint32, num uint32) bool {
// This hook will read an unprotected shadow presence set if
// we're unning under the race detector
raceDetectHookPresent(part, num)
return atomic.LoadUint32(part)&(1<<(num%32)) > 0
}
// SetPresent adds a field to the presence set. part is a pointer to
// the relevant element in the array and num is the field number
// unaltered. size is the number of fields in the protocol
// buffer.
func (Export) SetPresent(part *uint32, num uint32, size uint32) {
// This hook will mutate an unprotected shadow presence set if
// we're running under the race detector
raceDetectHookSetPresent(part, num, presenceSize(size))
for {
old := atomic.LoadUint32(part)
if atomic.CompareAndSwapUint32(part, old, old|(1<<(num%32))) {
return
}
}
}
// SetPresentNonAtomic is like SetPresent, but operates non-atomically.
// It is meant for use by builder methods, where the message is known not
// to be accessible yet by other goroutines.
func (Export) SetPresentNonAtomic(part *uint32, num uint32, size uint32) {
// This hook will mutate an unprotected shadow presence set if
// we're running under the race detector
raceDetectHookSetPresent(part, num, presenceSize(size))
*part |= 1 << (num % 32)
}
// ClearPresence removes a field from the presence set. part is a
// pointer to the relevant element in the presence array and num is
// the field number unaltered.
func (Export) ClearPresent(part *uint32, num uint32) {
// This hook will mutate an unprotected shadow presence set if
// we're running under the race detector
raceDetectHookClearPresent(part, num)
for {
old := atomic.LoadUint32(part)
if atomic.CompareAndSwapUint32(part, old, old&^(1<<(num%32))) {
return
}
}
}
// interfaceToPointer takes a pointer to an empty interface whose value is a
// pointer type, and converts it into a "pointer" that points to the same
// target
func interfaceToPointer(i *any) pointer {
return pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
}
func (p pointer) atomicGetPointer() pointer {
return pointer{p: atomic.LoadPointer((*unsafe.Pointer)(p.p))}
}
func (p pointer) atomicSetPointer(q pointer) {
atomic.StorePointer((*unsafe.Pointer)(p.p), q.p)
}
// AtomicCheckPointerIsNil takes an interface (which is a pointer to a
// pointer) and returns true if the pointed-to pointer is nil (using an
// atomic load). This function is inlineable and, on x86, just becomes a
// simple load and compare.
func (Export) AtomicCheckPointerIsNil(ptr any) bool {
return interfaceToPointer(&ptr).atomicGetPointer().IsNil()
}
// AtomicSetPointer takes two interfaces (first is a pointer to a pointer,
// second is a pointer) and atomically sets the second pointer into location
// referenced by first pointer. Unfortunately, atomicSetPointer() does not inline
// (even on x86), so this does not become a simple store on x86.
func (Export) AtomicSetPointer(dstPtr, valPtr any) {
interfaceToPointer(&dstPtr).atomicSetPointer(interfaceToPointer(&valPtr))
}
// AtomicLoadPointer loads the pointer at the location pointed at by src,
// and stores that pointer value into the location pointed at by dst.
func (Export) AtomicLoadPointer(ptr Pointer, dst Pointer) {
*(*unsafe.Pointer)(unsafe.Pointer(dst)) = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(ptr)))
}
// AtomicInitializePointer makes ptr and dst point to the same value.
//
// If *ptr is a nil pointer, it sets *ptr = *dst.
//
// If *ptr is a non-nil pointer, it sets *dst = *ptr.
func (Export) AtomicInitializePointer(ptr Pointer, dst Pointer) {
if !atomic.CompareAndSwapPointer((*unsafe.Pointer)(ptr), unsafe.Pointer(nil), *(*unsafe.Pointer)(dst)) {
*(*unsafe.Pointer)(unsafe.Pointer(dst)) = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(ptr)))
}
}
// MessageFieldStringOf returns the field formatted as a string,
// either as the field name if resolvable otherwise as a decimal string.
func (Export) MessageFieldStringOf(md protoreflect.MessageDescriptor, n protoreflect.FieldNumber) string {
fd := md.Fields().ByNumber(n)
if fd != nil {
return string(fd.Name())
}
return strconv.Itoa(int(n))
}

View File

@ -0,0 +1,34 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !race
package impl
// There is no additional data as we're not running under race detector.
type RaceDetectHookData struct{}
// Empty stubs for when not using the race detector. Calls to these from index.go should be optimized away.
func (presence) raceDetectHookPresent(num uint32) {}
func (presence) raceDetectHookSetPresent(num uint32, size presenceSize) {}
func (presence) raceDetectHookClearPresent(num uint32) {}
func (presence) raceDetectHookAllocAndCopy(src presence) {}
// raceDetectHookPresent is called by the generated file interface
// (*proto.internalFuncs) Present to optionally read an unprotected
// shadow bitmap when race detection is enabled. In regular code it is
// a noop.
func raceDetectHookPresent(field *uint32, num uint32) {}
// raceDetectHookSetPresent is called by the generated file interface
// (*proto.internalFuncs) SetPresent to optionally write an unprotected
// shadow bitmap when race detection is enabled. In regular code it is
// a noop.
func raceDetectHookSetPresent(field *uint32, num uint32, size presenceSize) {}
// raceDetectHookClearPresent is called by the generated file interface
// (*proto.internalFuncs) ClearPresent to optionally write an unprotected
// shadow bitmap when race detection is enabled. In regular code it is
// a noop.
func raceDetectHookClearPresent(field *uint32, num uint32) {}

View File

@ -0,0 +1,126 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build race
package impl
// When running under race detector, we add a presence map of bytes, that we can access
// in the hook functions so that we trigger the race detection whenever we have concurrent
// Read-Writes or Write-Writes. The race detector does not otherwise detect invalid concurrent
// access to lazy fields as all updates of bitmaps and pointers are done using atomic operations.
type RaceDetectHookData struct {
shadowPresence *[]byte
}
// Hooks for presence bitmap operations that allocate, read and write the shadowPresence
// using non-atomic operations.
func (data *RaceDetectHookData) raceDetectHookAlloc(size presenceSize) {
sp := make([]byte, size)
atomicStoreShadowPresence(&data.shadowPresence, &sp)
}
func (p presence) raceDetectHookPresent(num uint32) {
data := p.toRaceDetectData()
if data == nil {
return
}
sp := atomicLoadShadowPresence(&data.shadowPresence)
if sp != nil {
_ = (*sp)[num]
}
}
func (p presence) raceDetectHookSetPresent(num uint32, size presenceSize) {
data := p.toRaceDetectData()
if data == nil {
return
}
sp := atomicLoadShadowPresence(&data.shadowPresence)
if sp == nil {
data.raceDetectHookAlloc(size)
sp = atomicLoadShadowPresence(&data.shadowPresence)
}
(*sp)[num] = 1
}
func (p presence) raceDetectHookClearPresent(num uint32) {
data := p.toRaceDetectData()
if data == nil {
return
}
sp := atomicLoadShadowPresence(&data.shadowPresence)
if sp != nil {
(*sp)[num] = 0
}
}
// raceDetectHookAllocAndCopy allocates a new shadowPresence slice at lazy and copies
// shadowPresence bytes from src to lazy.
func (p presence) raceDetectHookAllocAndCopy(q presence) {
sData := q.toRaceDetectData()
dData := p.toRaceDetectData()
if sData == nil {
return
}
srcSp := atomicLoadShadowPresence(&sData.shadowPresence)
if srcSp == nil {
atomicStoreShadowPresence(&dData.shadowPresence, nil)
return
}
n := len(*srcSp)
dSlice := make([]byte, n)
atomicStoreShadowPresence(&dData.shadowPresence, &dSlice)
for i := 0; i < n; i++ {
dSlice[i] = (*srcSp)[i]
}
}
// raceDetectHookPresent is called by the generated file interface
// (*proto.internalFuncs) Present to optionally read an unprotected
// shadow bitmap when race detection is enabled. In regular code it is
// a noop.
func raceDetectHookPresent(field *uint32, num uint32) {
data := findPointerToRaceDetectData(field, num)
if data == nil {
return
}
sp := atomicLoadShadowPresence(&data.shadowPresence)
if sp != nil {
_ = (*sp)[num]
}
}
// raceDetectHookSetPresent is called by the generated file interface
// (*proto.internalFuncs) SetPresent to optionally write an unprotected
// shadow bitmap when race detection is enabled. In regular code it is
// a noop.
func raceDetectHookSetPresent(field *uint32, num uint32, size presenceSize) {
data := findPointerToRaceDetectData(field, num)
if data == nil {
return
}
sp := atomicLoadShadowPresence(&data.shadowPresence)
if sp == nil {
data.raceDetectHookAlloc(size)
sp = atomicLoadShadowPresence(&data.shadowPresence)
}
(*sp)[num] = 1
}
// raceDetectHookClearPresent is called by the generated file interface
// (*proto.internalFuncs) ClearPresent to optionally write an unprotected
// shadow bitmap when race detection is enabled. In regular code it is
// a noop.
func raceDetectHookClearPresent(field *uint32, num uint32) {
data := findPointerToRaceDetectData(field, num)
if data == nil {
return
}
sp := atomicLoadShadowPresence(&data.shadowPresence)
if sp != nil {
(*sp)[num] = 0
}
}

View File

@ -0,0 +1,174 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"sync"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) {
var p pointer
if ms, ok := in.Message.(*messageState); ok {
p = ms.pointer()
} else {
p = in.Message.(*messageReflectWrapper).pointer()
}
return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
}
func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
mi.init()
if !mi.needsInitCheck {
return nil
}
if p.IsNil() {
for _, f := range mi.orderedCoderFields {
if f.isRequired {
return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
}
}
return nil
}
var presence presence
if mi.presenceOffset.IsValid() {
presence = p.Apply(mi.presenceOffset).PresenceInfo()
}
if mi.extensionOffset.IsValid() {
e := p.Apply(mi.extensionOffset).Extensions()
if err := mi.isInitExtensions(e); err != nil {
return err
}
}
for _, f := range mi.orderedCoderFields {
if !f.isRequired && f.funcs.isInit == nil {
continue
}
if f.presenceIndex != noPresence {
if !presence.Present(f.presenceIndex) {
if f.isRequired {
return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
}
continue
}
if f.funcs.isInit != nil {
f.mi.init()
if f.mi.needsInitCheck {
if f.isLazy && p.Apply(f.offset).AtomicGetPointer().IsNil() {
lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
if !lazy.AllowedPartial() {
// Nothing to see here, it was checked on unmarshal
continue
}
mi.lazyUnmarshal(p, f.num)
}
if err := f.funcs.isInit(p.Apply(f.offset), f); err != nil {
return err
}
}
}
continue
}
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
if f.isRequired {
return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
}
continue
}
if f.funcs.isInit == nil {
continue
}
if err := f.funcs.isInit(fptr, f); err != nil {
return err
}
}
return nil
}
func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
if ext == nil {
return nil
}
for _, x := range *ext {
ei := getExtensionFieldInfo(x.Type())
if ei.funcs.isInit == nil || x.isUnexpandedLazy() {
continue
}
v := x.Value()
if !v.IsValid() {
continue
}
if err := ei.funcs.isInit(v); err != nil {
return err
}
}
return nil
}
var (
needsInitCheckMu sync.Mutex
needsInitCheckMap sync.Map
)
// needsInitCheck reports whether a message needs to be checked for partial initialization.
//
// It returns true if the message transitively includes any required or extension fields.
func needsInitCheck(md protoreflect.MessageDescriptor) bool {
if v, ok := needsInitCheckMap.Load(md); ok {
if has, ok := v.(bool); ok {
return has
}
}
needsInitCheckMu.Lock()
defer needsInitCheckMu.Unlock()
return needsInitCheckLocked(md)
}
func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) {
if v, ok := needsInitCheckMap.Load(md); ok {
// If has is true, we've previously determined that this message
// needs init checks.
//
// If has is false, we've previously determined that it can never
// be uninitialized.
//
// If has is not a bool, we've just encountered a cycle in the
// message graph. In this case, it is safe to return false: If
// the message does have required fields, we'll detect them later
// in the graph traversal.
has, ok := v.(bool)
return ok && has
}
needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
defer func() {
needsInitCheckMap.Store(md, has)
}()
if md.RequiredNumbers().Len() > 0 {
return true
}
if md.ExtensionRanges().Len() > 0 {
return true
}
for i := 0; i < md.Fields().Len(); i++ {
fd := md.Fields().Get(i)
// Map keys are never messages, so just consider the map value.
if fd.IsMap() {
fd = fd.MapValue()
}
fmd := fd.Message()
if fmd != nil && needsInitCheckLocked(fmd) {
return true
}
}
return false
}

View File

@ -0,0 +1,228 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"sync"
"sync/atomic"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/reflect/protoreflect"
)
type extensionFieldInfo struct {
wiretag uint64
tagsize int
unmarshalNeedsValue bool
funcs valueCoderFuncs
validation validationInfo
}
func getExtensionFieldInfo(xt protoreflect.ExtensionType) *extensionFieldInfo {
if xi, ok := xt.(*ExtensionInfo); ok {
xi.lazyInit()
return xi.info
}
// Ideally we'd cache the resulting *extensionFieldInfo so we don't have to
// recompute this metadata repeatedly. But without support for something like
// weak references, such a cache would pin temporary values (like dynamic
// extension types, constructed for the duration of a user request) to the
// heap forever, causing memory usage of the cache to grow unbounded.
// See discussion in https://github.com/golang/protobuf/issues/1521.
return makeExtensionFieldInfo(xt.TypeDescriptor())
}
func makeExtensionFieldInfo(xd protoreflect.ExtensionDescriptor) *extensionFieldInfo {
var wiretag uint64
if !xd.IsPacked() {
wiretag = protowire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
} else {
wiretag = protowire.EncodeTag(xd.Number(), protowire.BytesType)
}
e := &extensionFieldInfo{
wiretag: wiretag,
tagsize: protowire.SizeVarint(wiretag),
funcs: encoderFuncsForValue(xd),
}
// Does the unmarshal function need a value passed to it?
// This is true for composite types, where we pass in a message, list, or map to fill in,
// and for enums, where we pass in a prototype value to specify the concrete enum type.
switch xd.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.EnumKind:
e.unmarshalNeedsValue = true
default:
if xd.Cardinality() == protoreflect.Repeated {
e.unmarshalNeedsValue = true
}
}
return e
}
type lazyExtensionValue struct {
atomicOnce uint32 // atomically set if value is valid
mu sync.Mutex
xi *extensionFieldInfo
value protoreflect.Value
b []byte
}
type ExtensionField struct {
typ protoreflect.ExtensionType
// value is either the value of GetValue,
// or a *lazyExtensionValue that then returns the value of GetValue.
value protoreflect.Value
lazy *lazyExtensionValue
}
func (f *ExtensionField) appendLazyBytes(xt protoreflect.ExtensionType, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, b []byte) {
if f.lazy == nil {
f.lazy = &lazyExtensionValue{xi: xi}
}
f.typ = xt
f.lazy.xi = xi
f.lazy.b = protowire.AppendTag(f.lazy.b, num, wtyp)
f.lazy.b = append(f.lazy.b, b...)
}
func (f *ExtensionField) canLazy(xt protoreflect.ExtensionType) bool {
if f.typ == nil {
return true
}
if f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
return true
}
return false
}
// isUnexpandedLazy returns true if the ExensionField is lazy and not
// yet expanded, which means it's present and already checked for
// initialized required fields.
func (f *ExtensionField) isUnexpandedLazy() bool {
return f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0
}
// lazyBuffer retrieves the buffer for a lazy extension if it's not yet expanded.
//
// The returned buffer has to be kept over whatever operation we're planning,
// as re-retrieving it will fail after the message is lazily decoded.
func (f *ExtensionField) lazyBuffer() []byte {
// This function might be in the critical path, so check the atomic without
// taking a look first, then only take the lock if needed.
if !f.isUnexpandedLazy() {
return nil
}
f.lazy.mu.Lock()
defer f.lazy.mu.Unlock()
return f.lazy.b
}
func (f *ExtensionField) lazyInit() {
f.lazy.mu.Lock()
defer f.lazy.mu.Unlock()
if atomic.LoadUint32(&f.lazy.atomicOnce) == 1 {
return
}
if f.lazy.xi != nil {
b := f.lazy.b
val := f.typ.New()
for len(b) > 0 {
var tag uint64
if b[0] < 0x80 {
tag = uint64(b[0])
b = b[1:]
} else if len(b) >= 2 && b[1] < 128 {
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
b = b[2:]
} else {
var n int
tag, n = protowire.ConsumeVarint(b)
if n < 0 {
panic(errors.New("bad tag in lazy extension decoding"))
}
b = b[n:]
}
num := protowire.Number(tag >> 3)
wtyp := protowire.Type(tag & 7)
var out unmarshalOutput
var err error
val, out, err = f.lazy.xi.funcs.unmarshal(b, val, num, wtyp, lazyUnmarshalOptions)
if err != nil {
panic(errors.New("decode failure in lazy extension decoding: %v", err))
}
b = b[out.n:]
}
f.lazy.value = val
} else {
panic("No support for lazy fns for ExtensionField")
}
f.lazy.xi = nil
f.lazy.b = nil
atomic.StoreUint32(&f.lazy.atomicOnce, 1)
}
// Set sets the type and value of the extension field.
// This must not be called concurrently.
func (f *ExtensionField) Set(t protoreflect.ExtensionType, v protoreflect.Value) {
f.typ = t
f.value = v
f.lazy = nil
}
// Value returns the value of the extension field.
// This may be called concurrently.
func (f *ExtensionField) Value() protoreflect.Value {
if f.lazy != nil {
if atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
f.lazyInit()
}
return f.lazy.value
}
return f.value
}
// Type returns the type of the extension field.
// This may be called concurrently.
func (f ExtensionField) Type() protoreflect.ExtensionType {
return f.typ
}
// IsSet returns whether the extension field is set.
// This may be called concurrently.
func (f ExtensionField) IsSet() bool {
return f.typ != nil
}
// IsLazy reports whether a field is lazily encoded.
// It is exported for testing.
func IsLazy(m protoreflect.Message, fd protoreflect.FieldDescriptor) bool {
var mi *MessageInfo
var p pointer
switch m := m.(type) {
case *messageState:
mi = m.messageInfo()
p = m.pointer()
case *messageReflectWrapper:
mi = m.messageInfo()
p = m.pointer()
default:
return false
}
xd, ok := fd.(protoreflect.ExtensionTypeDescriptor)
if !ok {
return false
}
xt := xd.Type()
ext := mi.extensionMap(p)
if ext == nil {
return false
}
f, ok := (*ext)[int32(fd.Number())]
if !ok {
return false
}
return f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0
}

View File

@ -0,0 +1,788 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
type errInvalidUTF8 struct{}
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
func (errInvalidUTF8) Unwrap() error { return errors.Error }
// initOneofFieldCoders initializes the fast-path functions for the fields in a oneof.
//
// For size, marshal, and isInit operations, functions are set only on the first field
// in the oneof. The functions are called when the oneof is non-nil, and will dispatch
// to the appropriate field-specific function as necessary.
//
// The unmarshal function is set on each field individually as usual.
func (mi *MessageInfo) initOneofFieldCoders(od protoreflect.OneofDescriptor, si structInfo) {
fs := si.oneofsByName[od.Name()]
ft := fs.Type
oneofFields := make(map[reflect.Type]*coderFieldInfo)
needIsInit := false
fields := od.Fields()
for i, lim := 0, fields.Len(); i < lim; i++ {
fd := od.Fields().Get(i)
num := fd.Number()
// Make a copy of the original coderFieldInfo for use in unmarshaling.
//
// oneofFields[oneofType].funcs.marshal is the field-specific marshal function.
//
// mi.coderFields[num].marshal is set on only the first field in the oneof,
// and dispatches to the field-specific marshaler in oneofFields.
cf := *mi.coderFields[num]
ot := si.oneofWrappersByNumber[num]
cf.ft = ot.Field(0).Type
cf.mi, cf.funcs = fieldCoder(fd, cf.ft)
oneofFields[ot] = &cf
if cf.funcs.isInit != nil {
needIsInit = true
}
mi.coderFields[num].funcs.unmarshal = func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
var vw reflect.Value // pointer to wrapper type
vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
vw = vi.Elem()
} else {
vw = reflect.New(ot)
}
out, err := cf.funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, &cf, opts)
if err != nil {
return out, err
}
if cf.funcs.isInit == nil {
out.initialized = true
}
vi.Set(vw)
return out, nil
}
}
getInfo := func(p pointer) (pointer, *coderFieldInfo) {
v := p.AsValueOf(ft).Elem()
if v.IsNil() {
return pointer{}, nil
}
v = v.Elem() // interface -> *struct
if v.IsNil() {
return pointer{}, nil
}
return pointerOfValue(v).Apply(zeroOffset), oneofFields[v.Elem().Type()]
}
first := mi.coderFields[od.Fields().Get(0).Number()]
first.funcs.size = func(p pointer, _ *coderFieldInfo, opts marshalOptions) int {
p, info := getInfo(p)
if info == nil || info.funcs.size == nil {
return 0
}
return info.funcs.size(p, info, opts)
}
first.funcs.marshal = func(b []byte, p pointer, _ *coderFieldInfo, opts marshalOptions) ([]byte, error) {
p, info := getInfo(p)
if info == nil || info.funcs.marshal == nil {
return b, nil
}
return info.funcs.marshal(b, p, info, opts)
}
first.funcs.merge = func(dst, src pointer, _ *coderFieldInfo, opts mergeOptions) {
srcp, srcinfo := getInfo(src)
if srcinfo == nil || srcinfo.funcs.merge == nil {
return
}
dstp, dstinfo := getInfo(dst)
if dstinfo != srcinfo {
dst.AsValueOf(ft).Elem().Set(reflect.New(src.AsValueOf(ft).Elem().Elem().Elem().Type()))
dstp = pointerOfValue(dst.AsValueOf(ft).Elem().Elem()).Apply(zeroOffset)
}
srcinfo.funcs.merge(dstp, srcp, srcinfo, opts)
}
if needIsInit {
first.funcs.isInit = func(p pointer, _ *coderFieldInfo) error {
p, info := getInfo(p)
if info == nil || info.funcs.isInit == nil {
return nil
}
return info.funcs.isInit(p, info)
}
}
}
func makeMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
if mi := getMessageInfo(ft); mi != nil {
funcs := pointerCoderFuncs{
size: sizeMessageInfo,
marshal: appendMessageInfo,
unmarshal: consumeMessageInfo,
merge: mergeMessage,
}
if needsInitCheck(mi.Desc) {
funcs.isInit = isInitMessageInfo
}
return funcs
} else {
return pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
m := asMessage(p.AsValueOf(ft).Elem())
return sizeMessage(m, f.tagsize, opts)
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
m := asMessage(p.AsValueOf(ft).Elem())
return appendMessage(b, m, f.wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
}
return consumeMessage(b, asMessage(mp), wtyp, opts)
},
isInit: func(p pointer, f *coderFieldInfo) error {
m := asMessage(p.AsValueOf(ft).Elem())
return proto.CheckInitialized(m)
},
merge: mergeMessage,
}
}
}
func sizeMessageInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return protowire.SizeBytes(f.mi.sizePointer(p.Elem(), opts)) + f.tagsize
}
func appendMessageInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
calculatedSize := f.mi.sizePointer(p.Elem(), opts)
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
}
o, err := f.mi.unmarshalPointer(v, p.Elem(), 0, opts)
if err != nil {
return out, err
}
out.n = n
out.initialized = o.initialized
return out, nil
}
func isInitMessageInfo(p pointer, f *coderFieldInfo) error {
return f.mi.checkInitializedPointer(p.Elem())
}
func sizeMessage(m proto.Message, tagsize int, opts marshalOptions) int {
return protowire.SizeBytes(opts.Options().Size(m)) + tagsize
}
func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
mopts := opts.Options()
calculatedSize := mopts.Size(m)
b = protowire.AppendVarint(b, wiretag)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := mopts.MarshalAppend(b, m)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
Buf: v,
Message: m.ProtoReflect(),
})
if err != nil {
return out, err
}
out.n = n
out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
return out, nil
}
func sizeMessageValue(v protoreflect.Value, tagsize int, opts marshalOptions) int {
m := v.Message().Interface()
return sizeMessage(m, tagsize, opts)
}
func appendMessageValue(b []byte, v protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
m := v.Message().Interface()
return appendMessage(b, m, wiretag, opts)
}
func consumeMessageValue(b []byte, v protoreflect.Value, _ protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (protoreflect.Value, unmarshalOutput, error) {
m := v.Message().Interface()
out, err := consumeMessage(b, m, wtyp, opts)
return v, out, err
}
func isInitMessageValue(v protoreflect.Value) error {
m := v.Message().Interface()
return proto.CheckInitialized(m)
}
var coderMessageValue = valueCoderFuncs{
size: sizeMessageValue,
marshal: appendMessageValue,
unmarshal: consumeMessageValue,
isInit: isInitMessageValue,
merge: mergeMessageValue,
}
func sizeGroupValue(v protoreflect.Value, tagsize int, opts marshalOptions) int {
m := v.Message().Interface()
return sizeGroup(m, tagsize, opts)
}
func appendGroupValue(b []byte, v protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
m := v.Message().Interface()
return appendGroup(b, m, wiretag, opts)
}
func consumeGroupValue(b []byte, v protoreflect.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (protoreflect.Value, unmarshalOutput, error) {
m := v.Message().Interface()
out, err := consumeGroup(b, m, num, wtyp, opts)
return v, out, err
}
var coderGroupValue = valueCoderFuncs{
size: sizeGroupValue,
marshal: appendGroupValue,
unmarshal: consumeGroupValue,
isInit: isInitMessageValue,
merge: mergeMessageValue,
}
func makeGroupFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
num := fd.Number()
if mi := getMessageInfo(ft); mi != nil {
funcs := pointerCoderFuncs{
size: sizeGroupType,
marshal: appendGroupType,
unmarshal: consumeGroupType,
merge: mergeMessage,
}
if needsInitCheck(mi.Desc) {
funcs.isInit = isInitMessageInfo
}
return funcs
} else {
return pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
m := asMessage(p.AsValueOf(ft).Elem())
return sizeGroup(m, f.tagsize, opts)
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
m := asMessage(p.AsValueOf(ft).Elem())
return appendGroup(b, m, f.wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
}
return consumeGroup(b, asMessage(mp), num, wtyp, opts)
},
isInit: func(p pointer, f *coderFieldInfo) error {
m := asMessage(p.AsValueOf(ft).Elem())
return proto.CheckInitialized(m)
},
merge: mergeMessage,
}
}
}
func sizeGroupType(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return 2*f.tagsize + f.mi.sizePointer(p.Elem(), opts)
}
func appendGroupType(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
b = protowire.AppendVarint(b, f.wiretag) // start group
b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
b = protowire.AppendVarint(b, f.wiretag+1) // end group
return b, err
}
func consumeGroupType(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.StartGroupType {
return out, errUnknown
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
}
return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts)
}
func sizeGroup(m proto.Message, tagsize int, opts marshalOptions) int {
return 2*tagsize + opts.Options().Size(m)
}
func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
b = protowire.AppendVarint(b, wiretag) // start group
b, err := opts.Options().MarshalAppend(b, m)
b = protowire.AppendVarint(b, wiretag+1) // end group
return b, err
}
func consumeGroup(b []byte, m proto.Message, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.StartGroupType {
return out, errUnknown
}
b, n := protowire.ConsumeGroup(num, b)
if n < 0 {
return out, errDecode
}
o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
Buf: b,
Message: m.ProtoReflect(),
})
if err != nil {
return out, err
}
out.n = n
out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
return out, nil
}
func makeMessageSliceFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
if mi := getMessageInfo(ft); mi != nil {
funcs := pointerCoderFuncs{
size: sizeMessageSliceInfo,
marshal: appendMessageSliceInfo,
unmarshal: consumeMessageSliceInfo,
merge: mergeMessageSlice,
}
if needsInitCheck(mi.Desc) {
funcs.isInit = isInitMessageSliceInfo
}
return funcs
}
return pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return sizeMessageSlice(p, ft, f.tagsize, opts)
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
return appendMessageSlice(b, p, f.wiretag, ft, opts)
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageSlice(b, p, ft, wtyp, opts)
},
isInit: func(p pointer, f *coderFieldInfo) error {
return isInitMessageSlice(p, ft)
},
merge: mergeMessageSlice,
}
}
func sizeMessageSliceInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
s := p.PointerSlice()
n := 0
for _, v := range s {
n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
}
return n
}
func appendMessageSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
s := p.PointerSlice()
var err error
for _, v := range s {
b = protowire.AppendVarint(b, f.wiretag)
siz := f.mi.sizePointer(v, opts)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
func consumeMessageSliceInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
m := reflect.New(f.mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
if err != nil {
return out, err
}
p.AppendPointerSlice(mp)
out.n = n
out.initialized = o.initialized
return out, nil
}
func isInitMessageSliceInfo(p pointer, f *coderFieldInfo) error {
s := p.PointerSlice()
for _, v := range s {
if err := f.mi.checkInitializedPointer(v); err != nil {
return err
}
}
return nil
}
func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) {
mopts := opts.Options()
s := p.PointerSlice()
var err error
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
b = protowire.AppendVarint(b, wiretag)
siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
mp := reflect.New(goType.Elem())
o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
Buf: v,
Message: asMessage(mp).ProtoReflect(),
})
if err != nil {
return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
out.n = n
out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
return out, nil
}
func isInitMessageSlice(p pointer, goType reflect.Type) error {
s := p.PointerSlice()
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
if err := proto.CheckInitialized(m); err != nil {
return err
}
}
return nil
}
// Slices of messages
func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
func appendMessageSliceValue(b []byte, listv protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
list := listv.List()
mopts := opts.Options()
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
b = protowire.AppendVarint(b, wiretag)
siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
var err error
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
func consumeMessageSliceValue(b []byte, listv protoreflect.Value, _ protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
list := listv.List()
if wtyp != protowire.BytesType {
return protoreflect.Value{}, out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return protoreflect.Value{}, out, errDecode
}
m := list.NewElement()
o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
Buf: v,
Message: m.Message(),
})
if err != nil {
return protoreflect.Value{}, out, err
}
list.Append(m)
out.n = n
out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
return listv, out, nil
}
func isInitMessageSliceValue(listv protoreflect.Value) error {
list := listv.List()
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
if err := proto.CheckInitialized(m); err != nil {
return err
}
}
return nil
}
var coderMessageSliceValue = valueCoderFuncs{
size: sizeMessageSliceValue,
marshal: appendMessageSliceValue,
unmarshal: consumeMessageSliceValue,
isInit: isInitMessageSliceValue,
merge: mergeMessageListValue,
}
func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
n += 2*tagsize + mopts.Size(m)
}
return n
}
func appendGroupSliceValue(b []byte, listv protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) {
list := listv.List()
mopts := opts.Options()
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
b = protowire.AppendVarint(b, wiretag) // start group
var err error
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
b = protowire.AppendVarint(b, wiretag+1) // end group
}
return b, nil
}
func consumeGroupSliceValue(b []byte, listv protoreflect.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
list := listv.List()
if wtyp != protowire.StartGroupType {
return protoreflect.Value{}, out, errUnknown
}
b, n := protowire.ConsumeGroup(num, b)
if n < 0 {
return protoreflect.Value{}, out, errDecode
}
m := list.NewElement()
o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
Buf: b,
Message: m.Message(),
})
if err != nil {
return protoreflect.Value{}, out, err
}
list.Append(m)
out.n = n
out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
return listv, out, nil
}
var coderGroupSliceValue = valueCoderFuncs{
size: sizeGroupSliceValue,
marshal: appendGroupSliceValue,
unmarshal: consumeGroupSliceValue,
isInit: isInitMessageSliceValue,
merge: mergeMessageListValue,
}
func makeGroupSliceFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
num := fd.Number()
if mi := getMessageInfo(ft); mi != nil {
funcs := pointerCoderFuncs{
size: sizeGroupSliceInfo,
marshal: appendGroupSliceInfo,
unmarshal: consumeGroupSliceInfo,
merge: mergeMessageSlice,
}
if needsInitCheck(mi.Desc) {
funcs.isInit = isInitMessageSliceInfo
}
return funcs
}
return pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return sizeGroupSlice(p, ft, f.tagsize, opts)
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
return appendGroupSlice(b, p, f.wiretag, ft, opts)
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupSlice(b, p, num, wtyp, ft, opts)
},
isInit: func(p pointer, f *coderFieldInfo) error {
return isInitMessageSlice(p, ft)
},
merge: mergeMessageSlice,
}
}
func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, opts marshalOptions) int {
mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(messageType.Elem()))
n += 2*tagsize + mopts.Size(m)
}
return n
}
func appendGroupSlice(b []byte, p pointer, wiretag uint64, messageType reflect.Type, opts marshalOptions) ([]byte, error) {
s := p.PointerSlice()
var err error
for _, v := range s {
m := asMessage(v.AsValueOf(messageType.Elem()))
b = protowire.AppendVarint(b, wiretag) // start group
b, err = opts.Options().MarshalAppend(b, m)
if err != nil {
return b, err
}
b = protowire.AppendVarint(b, wiretag+1) // end group
}
return b, nil
}
func consumeGroupSlice(b []byte, p pointer, num protowire.Number, wtyp protowire.Type, goType reflect.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.StartGroupType {
return out, errUnknown
}
b, n := protowire.ConsumeGroup(num, b)
if n < 0 {
return out, errDecode
}
mp := reflect.New(goType.Elem())
o, err := opts.Options().UnmarshalState(protoiface.UnmarshalInput{
Buf: b,
Message: asMessage(mp).ProtoReflect(),
})
if err != nil {
return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
out.n = n
out.initialized = o.Flags&protoiface.UnmarshalInitialized != 0
return out, nil
}
func sizeGroupSliceInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
s := p.PointerSlice()
n := 0
for _, v := range s {
n += 2*f.tagsize + f.mi.sizePointer(v, opts)
}
return n
}
func appendGroupSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
s := p.PointerSlice()
var err error
for _, v := range s {
b = protowire.AppendVarint(b, f.wiretag) // start group
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
b = protowire.AppendVarint(b, f.wiretag+1) // end group
}
return b, nil
}
func consumeGroupSliceInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
if wtyp != protowire.StartGroupType {
return unmarshalOutput{}, errUnknown
}
m := reflect.New(f.mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
out, err := f.mi.unmarshalPointer(b, mp, f.num, opts)
if err != nil {
return out, err
}
p.AppendPointerSlice(mp)
return out, nil
}
func asMessage(v reflect.Value) protoreflect.ProtoMessage {
if m, ok := v.Interface().(protoreflect.ProtoMessage); ok {
return m
}
return legacyWrapMessage(v).Interface()
}

View File

@ -0,0 +1,264 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/reflect/protoreflect"
)
func makeOpaqueMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
mi := getMessageInfo(ft)
if mi == nil {
panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), ft))
}
switch fd.Kind() {
case protoreflect.MessageKind:
return mi, pointerCoderFuncs{
size: sizeOpaqueMessage,
marshal: appendOpaqueMessage,
unmarshal: consumeOpaqueMessage,
isInit: isInitOpaqueMessage,
merge: mergeOpaqueMessage,
}
case protoreflect.GroupKind:
return mi, pointerCoderFuncs{
size: sizeOpaqueGroup,
marshal: appendOpaqueGroup,
unmarshal: consumeOpaqueGroup,
isInit: isInitOpaqueMessage,
merge: mergeOpaqueMessage,
}
}
panic("unexpected field kind")
}
func sizeOpaqueMessage(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
return protowire.SizeBytes(f.mi.sizePointer(p.AtomicGetPointer(), opts)) + f.tagsize
}
func appendOpaqueMessage(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
mp := p.AtomicGetPointer()
calculatedSize := f.mi.sizePointer(mp, opts)
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := f.mi.marshalAppendPointer(b, mp, opts)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeOpaqueMessage(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
mp := p.AtomicGetPointer()
if mp.IsNil() {
mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
}
o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
if err != nil {
return out, err
}
out.n = n
out.initialized = o.initialized
return out, nil
}
func isInitOpaqueMessage(p pointer, f *coderFieldInfo) error {
mp := p.AtomicGetPointer()
if mp.IsNil() {
return nil
}
return f.mi.checkInitializedPointer(mp)
}
func mergeOpaqueMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
dstmp := dst.AtomicGetPointer()
if dstmp.IsNil() {
dstmp = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
}
f.mi.mergePointer(dstmp, src.AtomicGetPointer(), opts)
}
func sizeOpaqueGroup(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
return 2*f.tagsize + f.mi.sizePointer(p.AtomicGetPointer(), opts)
}
func appendOpaqueGroup(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
b = protowire.AppendVarint(b, f.wiretag) // start group
b, err := f.mi.marshalAppendPointer(b, p.AtomicGetPointer(), opts)
b = protowire.AppendVarint(b, f.wiretag+1) // end group
return b, err
}
func consumeOpaqueGroup(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.StartGroupType {
return out, errUnknown
}
mp := p.AtomicGetPointer()
if mp.IsNil() {
mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
}
o, e := f.mi.unmarshalPointer(b, mp, f.num, opts)
return o, e
}
func makeOpaqueRepeatedMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
panic(fmt.Sprintf("invalid field: %v: unsupported type for opaque repeated message: %v", fd.FullName(), ft))
}
mt := ft.Elem().Elem() // *[]*T -> *T
mi := getMessageInfo(mt)
if mi == nil {
panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), mt))
}
switch fd.Kind() {
case protoreflect.MessageKind:
return mi, pointerCoderFuncs{
size: sizeOpaqueMessageSlice,
marshal: appendOpaqueMessageSlice,
unmarshal: consumeOpaqueMessageSlice,
isInit: isInitOpaqueMessageSlice,
merge: mergeOpaqueMessageSlice,
}
case protoreflect.GroupKind:
return mi, pointerCoderFuncs{
size: sizeOpaqueGroupSlice,
marshal: appendOpaqueGroupSlice,
unmarshal: consumeOpaqueGroupSlice,
isInit: isInitOpaqueMessageSlice,
merge: mergeOpaqueMessageSlice,
}
}
panic("unexpected field kind")
}
func sizeOpaqueMessageSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
s := p.AtomicGetPointer().PointerSlice()
n := 0
for _, v := range s {
n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
}
return n
}
func appendOpaqueMessageSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
s := p.AtomicGetPointer().PointerSlice()
var err error
for _, v := range s {
b = protowire.AppendVarint(b, f.wiretag)
siz := f.mi.sizePointer(v, opts)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
func consumeOpaqueMessageSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
if err != nil {
return out, err
}
sp := p.AtomicGetPointer()
if sp.IsNil() {
sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
}
sp.AppendPointerSlice(mp)
out.n = n
out.initialized = o.initialized
return out, nil
}
func isInitOpaqueMessageSlice(p pointer, f *coderFieldInfo) error {
sp := p.AtomicGetPointer()
if sp.IsNil() {
return nil
}
s := sp.PointerSlice()
for _, v := range s {
if err := f.mi.checkInitializedPointer(v); err != nil {
return err
}
}
return nil
}
func mergeOpaqueMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
ds := dst.AtomicGetPointer()
if ds.IsNil() {
ds = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
}
for _, sp := range src.AtomicGetPointer().PointerSlice() {
dm := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
f.mi.mergePointer(dm, sp, opts)
ds.AppendPointerSlice(dm)
}
}
func sizeOpaqueGroupSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
s := p.AtomicGetPointer().PointerSlice()
n := 0
for _, v := range s {
n += 2*f.tagsize + f.mi.sizePointer(v, opts)
}
return n
}
func appendOpaqueGroupSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
s := p.AtomicGetPointer().PointerSlice()
var err error
for _, v := range s {
b = protowire.AppendVarint(b, f.wiretag) // start group
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
b = protowire.AppendVarint(b, f.wiretag+1) // end group
}
return b, nil
}
func consumeOpaqueGroupSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.StartGroupType {
return out, errUnknown
}
mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
out, err = f.mi.unmarshalPointer(b, mp, f.num, opts)
if err != nil {
return out, err
}
sp := p.AtomicGetPointer()
if sp.IsNil() {
sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
}
sp.AppendPointerSlice(mp)
return out, err
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,399 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect"
)
type mapInfo struct {
goType reflect.Type
keyWiretag uint64
valWiretag uint64
keyFuncs valueCoderFuncs
valFuncs valueCoderFuncs
keyZero protoreflect.Value
keyKind protoreflect.Kind
conv *mapConverter
}
func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
// TODO: Consider generating specialized map coders.
keyField := fd.MapKey()
valField := fd.MapValue()
keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
keyFuncs := encoderFuncsForValue(keyField)
valFuncs := encoderFuncsForValue(valField)
conv := newMapConverter(ft, fd)
mapi := &mapInfo{
goType: ft,
keyWiretag: keyWiretag,
valWiretag: valWiretag,
keyFuncs: keyFuncs,
valFuncs: valFuncs,
keyZero: keyField.Default(),
keyKind: keyField.Kind(),
conv: conv,
}
if valField.Kind() == protoreflect.MessageKind {
valueMessage = getMessageInfo(ft.Elem())
}
funcs = pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft)
if mp.Elem().IsNil() {
mp.Elem().Set(reflect.MakeMap(mapi.goType))
}
if f.mi == nil {
return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
} else {
return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
}
},
}
switch valField.Kind() {
case protoreflect.MessageKind:
funcs.merge = mergeMapOfMessage
case protoreflect.BytesKind:
funcs.merge = mergeMapOfBytes
default:
funcs.merge = mergeMap
}
if valFuncs.isInit != nil {
funcs.isInit = func(p pointer, f *coderFieldInfo) error {
return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
}
}
return valueMessage, funcs
}
const (
mapKeyTagSize = 1 // field 1, tag size 1.
mapValTagSize = 1 // field 2, tag size 2.
)
func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
if mapv.Len() == 0 {
return 0
}
n := 0
iter := mapv.MapRange()
for iter.Next() {
key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
var valSize int
value := mapi.conv.valConv.PBValueOf(iter.Value())
if f.mi == nil {
valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
} else {
p := pointerOfValue(iter.Value())
valSize += mapValTagSize
valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
}
n += f.tagsize + protowire.SizeBytes(keySize+valSize)
}
return n
}
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
b, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
var (
key = mapi.keyZero
val = mapi.conv.valConv.New()
)
for len(b) > 0 {
num, wtyp, n := protowire.ConsumeTag(b)
if n < 0 {
return out, errDecode
}
if num > protowire.MaxValidNumber {
return out, errDecode
}
b = b[n:]
err := errUnknown
switch num {
case genid.MapEntry_Key_field_number:
var v protoreflect.Value
var o unmarshalOutput
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
n = o.n
case genid.MapEntry_Value_field_number:
var v protoreflect.Value
var o unmarshalOutput
v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
if err != nil {
break
}
val = v
n = o.n
}
if err == errUnknown {
n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return out, errDecode
}
} else if err != nil {
return out, err
}
b = b[n:]
}
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
out.n = n
return out, nil
}
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
b, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
var (
key = mapi.keyZero
val = reflect.New(f.mi.GoReflectType.Elem())
)
for len(b) > 0 {
num, wtyp, n := protowire.ConsumeTag(b)
if n < 0 {
return out, errDecode
}
if num > protowire.MaxValidNumber {
return out, errDecode
}
b = b[n:]
err := errUnknown
switch num {
case 1:
var v protoreflect.Value
var o unmarshalOutput
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
n = o.n
case 2:
if wtyp != protowire.BytesType {
break
}
var v []byte
v, n = protowire.ConsumeBytes(b)
if n < 0 {
return out, errDecode
}
var o unmarshalOutput
o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
if o.initialized {
// Consider this map item initialized so long as we see
// an initialized value.
out.initialized = true
}
}
if err == errUnknown {
n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return out, errDecode
}
} else if err != nil {
return out, err
}
b = b[n:]
}
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
out.n = n
return out, nil
}
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
if f.mi == nil {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := mapi.conv.valConv.PBValueOf(valrv)
size := 0
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapi.valFuncs.size(val, mapValTagSize, opts)
b = protowire.AppendVarint(b, uint64(size))
before := len(b)
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
return nil, err
}
b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
if measuredSize := len(b) - before; size != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(size, measuredSize)
}
return b, err
} else {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := pointerOfValue(valrv)
valSize := f.mi.sizePointer(val, opts)
size := 0
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapValTagSize + protowire.SizeBytes(valSize)
b = protowire.AppendVarint(b, uint64(size))
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
return nil, err
}
b = protowire.AppendVarint(b, mapi.valWiretag)
b = protowire.AppendVarint(b, uint64(valSize))
before := len(b)
b, err = f.mi.marshalAppendPointer(b, val, opts)
if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
}
return b, err
}
}
func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
if mapv.Len() == 0 {
return b, nil
}
if opts.Deterministic() {
return appendMapDeterministic(b, mapv, mapi, f, opts)
}
iter := mapv.MapRange()
for iter.Next() {
var err error
b = protowire.AppendVarint(b, f.wiretag)
b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
if err != nil {
return b, err
}
}
return b, nil
}
func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
keys := mapv.MapKeys()
sort.Slice(keys, func(i, j int) bool {
switch keys[i].Kind() {
case reflect.Bool:
return !keys[i].Bool() && keys[j].Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return keys[i].Int() < keys[j].Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return keys[i].Uint() < keys[j].Uint()
case reflect.Float32, reflect.Float64:
return keys[i].Float() < keys[j].Float()
case reflect.String:
return keys[i].String() < keys[j].String()
default:
panic("invalid kind: " + keys[i].Kind().String())
}
})
for _, key := range keys {
var err error
b = protowire.AppendVarint(b, f.wiretag)
b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
if err != nil {
return b, err
}
}
return b, nil
}
func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
if mi := f.mi; mi != nil {
mi.init()
if !mi.needsInitCheck {
return nil
}
iter := mapv.MapRange()
for iter.Next() {
val := pointerOfValue(iter.Value())
if err := mi.checkInitializedPointer(val); err != nil {
return err
}
}
} else {
iter := mapv.MapRange()
for iter.Next() {
val := mapi.conv.valConv.PBValueOf(iter.Value())
if err := mapi.valFuncs.isInit(val); err != nil {
return err
}
}
}
return nil
}
func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
dstm := dst.AsValueOf(f.ft).Elem()
srcm := src.AsValueOf(f.ft).Elem()
if srcm.Len() == 0 {
return
}
if dstm.IsNil() {
dstm.Set(reflect.MakeMap(f.ft))
}
iter := srcm.MapRange()
for iter.Next() {
dstm.SetMapIndex(iter.Key(), iter.Value())
}
}
func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
dstm := dst.AsValueOf(f.ft).Elem()
srcm := src.AsValueOf(f.ft).Elem()
if srcm.Len() == 0 {
return
}
if dstm.IsNil() {
dstm.Set(reflect.MakeMap(f.ft))
}
iter := srcm.MapRange()
for iter.Next() {
dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
}
}
func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
dstm := dst.AsValueOf(f.ft).Elem()
srcm := src.AsValueOf(f.ft).Elem()
if srcm.Len() == 0 {
return
}
if dstm.IsNil() {
dstm.Set(reflect.MakeMap(f.ft))
}
iter := srcm.MapRange()
for iter.Next() {
val := reflect.New(f.ft.Elem().Elem())
if f.mi != nil {
f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
} else {
opts.Merge(asMessage(val), asMessage(iter.Value()))
}
dstm.SetMapIndex(iter.Key(), val)
}
}

View File

@ -0,0 +1,230 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// coderMessageInfo contains per-message information used by the fast-path functions.
// This is a different type from MessageInfo to keep MessageInfo as general-purpose as
// possible.
type coderMessageInfo struct {
methods protoiface.Methods
orderedCoderFields []*coderFieldInfo
denseCoderFields []*coderFieldInfo
coderFields map[protowire.Number]*coderFieldInfo
sizecacheOffset offset
unknownOffset offset
unknownPtrKind bool
extensionOffset offset
needsInitCheck bool
isMessageSet bool
numRequiredFields uint8
lazyOffset offset
presenceOffset offset
presenceSize presenceSize
}
type coderFieldInfo struct {
funcs pointerCoderFuncs // fast-path per-field functions
mi *MessageInfo // field's message
ft reflect.Type
validation validationInfo // information used by message validation
num protoreflect.FieldNumber // field number
offset offset // struct field offset
wiretag uint64 // field tag (number + wire type)
tagsize int // size of the varint-encoded tag
isPointer bool // true if IsNil may be called on the struct field
isRequired bool // true if field is required
isLazy bool
presenceIndex uint32
}
const noPresence = 0xffffffff
func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
mi.sizecacheOffset = invalidOffset
mi.unknownOffset = invalidOffset
mi.extensionOffset = invalidOffset
mi.lazyOffset = invalidOffset
mi.presenceOffset = si.presenceOffset
if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
mi.sizecacheOffset = si.sizecacheOffset
}
if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
mi.unknownOffset = si.unknownOffset
mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
}
if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
mi.extensionOffset = si.extensionOffset
}
mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
fields := mi.Desc.Fields()
preallocFields := make([]coderFieldInfo, fields.Len())
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
fs := si.fieldsByNumber[fd.Number()]
isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
if isOneof {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
ft := fs.Type
var wiretag uint64
if !fd.IsPacked() {
wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
} else {
wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType)
}
var fieldOffset offset
var funcs pointerCoderFuncs
var childMessage *MessageInfo
switch {
case ft == nil:
// This never occurs for generated message types.
// It implies that a hand-crafted type has missing Go fields
// for specific protobuf message fields.
funcs = pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
return 0
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
return nil, nil
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
panic("missing Go struct field for " + string(fd.FullName()))
},
isInit: func(p pointer, f *coderFieldInfo) error {
panic("missing Go struct field for " + string(fd.FullName()))
},
merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
panic("missing Go struct field for " + string(fd.FullName()))
},
}
case isOneof:
fieldOffset = offsetOf(fs)
default:
fieldOffset = offsetOf(fs)
childMessage, funcs = fieldCoder(fd, ft)
}
cf := &preallocFields[i]
*cf = coderFieldInfo{
num: fd.Number(),
offset: fieldOffset,
wiretag: wiretag,
ft: ft,
tagsize: protowire.SizeVarint(wiretag),
funcs: funcs,
mi: childMessage,
validation: newFieldValidationInfo(mi, si, fd, ft),
isPointer: fd.Cardinality() == protoreflect.Repeated || fd.HasPresence(),
isRequired: fd.Cardinality() == protoreflect.Required,
presenceIndex: noPresence,
}
mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
mi.coderFields[cf.num] = cf
}
for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
if od := oneofs.Get(i); !od.IsSynthetic() {
mi.initOneofFieldCoders(od, si)
}
}
if messageset.IsMessageSet(mi.Desc) {
if !mi.extensionOffset.IsValid() {
panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
}
if !mi.unknownOffset.IsValid() {
panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName()))
}
mi.isMessageSet = true
}
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
})
var maxDense protoreflect.FieldNumber
for _, cf := range mi.orderedCoderFields {
if cf.num >= 16 && cf.num >= 2*maxDense {
break
}
maxDense = cf.num
}
mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
for _, cf := range mi.orderedCoderFields {
if int(cf.num) >= len(mi.denseCoderFields) {
break
}
mi.denseCoderFields[cf.num] = cf
}
// To preserve compatibility with historic wire output, marshal oneofs last.
if mi.Desc.Oneofs().Len() > 0 {
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
fi := fields.ByNumber(mi.orderedCoderFields[i].num)
fj := fields.ByNumber(mi.orderedCoderFields[j].num)
return order.LegacyFieldOrder(fi, fj)
})
}
mi.needsInitCheck = needsInitCheck(mi.Desc)
if mi.methods.Marshal == nil && mi.methods.Size == nil {
mi.methods.Flags |= protoiface.SupportMarshalDeterministic
mi.methods.Marshal = mi.marshal
mi.methods.Size = mi.size
}
if mi.methods.Unmarshal == nil {
mi.methods.Flags |= protoiface.SupportUnmarshalDiscardUnknown
mi.methods.Unmarshal = mi.unmarshal
}
if mi.methods.CheckInitialized == nil {
mi.methods.CheckInitialized = mi.checkInitialized
}
if mi.methods.Merge == nil {
mi.methods.Merge = mi.merge
}
if mi.methods.Equal == nil {
mi.methods.Equal = equal
}
}
// getUnknownBytes returns a *[]byte for the unknown fields.
// It is the caller's responsibility to check whether the pointer is nil.
// This function is specially designed to be inlineable.
func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
if mi.unknownPtrKind {
return *p.Apply(mi.unknownOffset).BytesPtr()
} else {
return p.Apply(mi.unknownOffset).Bytes()
}
}
// mutableUnknownBytes returns a *[]byte for the unknown fields.
// The returned pointer is guaranteed to not be nil.
func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
if mi.unknownPtrKind {
bp := p.Apply(mi.unknownOffset).BytesPtr()
if *bp == nil {
*bp = new([]byte)
}
return *bp
} else {
return p.Apply(mi.unknownOffset).Bytes()
}
}

View File

@ -0,0 +1,153 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/reflect/protoreflect"
piface "google.golang.org/protobuf/runtime/protoiface"
)
func (mi *MessageInfo) makeOpaqueCoderMethods(t reflect.Type, si opaqueStructInfo) {
mi.sizecacheOffset = si.sizecacheOffset
mi.unknownOffset = si.unknownOffset
mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
mi.extensionOffset = si.extensionOffset
mi.lazyOffset = si.lazyOffset
mi.presenceOffset = si.presenceOffset
mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
fields := mi.Desc.Fields()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
fs := si.fieldsByNumber[fd.Number()]
if fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic() {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
ft := fs.Type
var wiretag uint64
if !fd.IsPacked() {
wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
} else {
wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType)
}
var fieldOffset offset
var funcs pointerCoderFuncs
var childMessage *MessageInfo
switch {
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
fieldOffset = offsetOf(fs)
case fd.Message() != nil && !fd.IsMap():
fieldOffset = offsetOf(fs)
if fd.IsList() {
childMessage, funcs = makeOpaqueRepeatedMessageFieldCoder(fd, ft)
} else {
childMessage, funcs = makeOpaqueMessageFieldCoder(fd, ft)
}
default:
fieldOffset = offsetOf(fs)
childMessage, funcs = fieldCoder(fd, ft)
}
cf := &coderFieldInfo{
num: fd.Number(),
offset: fieldOffset,
wiretag: wiretag,
ft: ft,
tagsize: protowire.SizeVarint(wiretag),
funcs: funcs,
mi: childMessage,
validation: newFieldValidationInfo(mi, si.structInfo, fd, ft),
isPointer: (fd.Cardinality() == protoreflect.Repeated ||
fd.Kind() == protoreflect.MessageKind ||
fd.Kind() == protoreflect.GroupKind),
isRequired: fd.Cardinality() == protoreflect.Required,
presenceIndex: noPresence,
}
// TODO: Use presence for all fields.
//
// In some cases, such as maps, presence means only "might be set" rather
// than "is definitely set", but every field should have a presence bit to
// permit us to skip over definitely-unset fields at marshal time.
var hasPresence bool
hasPresence, cf.isLazy = usePresenceForField(si, fd)
if hasPresence {
cf.presenceIndex, mi.presenceSize = presenceIndex(mi.Desc, fd)
}
mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
mi.coderFields[cf.num] = cf
}
for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
if od := oneofs.Get(i); !od.IsSynthetic() {
mi.initOneofFieldCoders(od, si.structInfo)
}
}
if messageset.IsMessageSet(mi.Desc) {
if !mi.extensionOffset.IsValid() {
panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
}
if !mi.unknownOffset.IsValid() {
panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName()))
}
mi.isMessageSet = true
}
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
})
var maxDense protoreflect.FieldNumber
for _, cf := range mi.orderedCoderFields {
if cf.num >= 16 && cf.num >= 2*maxDense {
break
}
maxDense = cf.num
}
mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
for _, cf := range mi.orderedCoderFields {
if int(cf.num) > len(mi.denseCoderFields) {
break
}
mi.denseCoderFields[cf.num] = cf
}
// To preserve compatibility with historic wire output, marshal oneofs last.
if mi.Desc.Oneofs().Len() > 0 {
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
fi := fields.ByNumber(mi.orderedCoderFields[i].num)
fj := fields.ByNumber(mi.orderedCoderFields[j].num)
return order.LegacyFieldOrder(fi, fj)
})
}
mi.needsInitCheck = needsInitCheck(mi.Desc)
if mi.methods.Marshal == nil && mi.methods.Size == nil {
mi.methods.Flags |= piface.SupportMarshalDeterministic
mi.methods.Marshal = mi.marshal
mi.methods.Size = mi.size
}
if mi.methods.Unmarshal == nil {
mi.methods.Flags |= piface.SupportUnmarshalDiscardUnknown
mi.methods.Unmarshal = mi.unmarshal
}
if mi.methods.CheckInitialized == nil {
mi.methods.CheckInitialized = mi.checkInitialized
}
if mi.methods.Merge == nil {
mi.methods.Merge = mi.merge
}
if mi.methods.Equal == nil {
mi.methods.Equal = equal
}
}

View File

@ -0,0 +1,145 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
)
func sizeMessageSet(mi *MessageInfo, p pointer, opts marshalOptions) (size int) {
if !flags.ProtoLegacy {
return 0
}
ext := *p.Apply(mi.extensionOffset).Extensions()
for _, x := range ext {
xi := getExtensionFieldInfo(x.Type())
if xi.funcs.size == nil {
continue
}
num, _ := protowire.DecodeTag(xi.wiretag)
size += messageset.SizeField(num)
if fullyLazyExtensions(opts) {
// Don't expand the extension, instead use the buffer to calculate size
if lb := x.lazyBuffer(); lb != nil {
// We got hold of the buffer, so it's still lazy.
// Don't count the tag size in the extension buffer, it's already added.
size += protowire.SizeTag(messageset.FieldMessage) + len(lb) - xi.tagsize
continue
}
}
size += xi.funcs.size(x.Value(), protowire.SizeTag(messageset.FieldMessage), opts)
}
if u := mi.getUnknownBytes(p); u != nil {
size += messageset.SizeUnknown(*u)
}
return size
}
func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts marshalOptions) ([]byte, error) {
if !flags.ProtoLegacy {
return b, errors.New("no support for message_set_wire_format")
}
ext := *p.Apply(mi.extensionOffset).Extensions()
switch len(ext) {
case 0:
case 1:
// Fast-path for one extension: Don't bother sorting the keys.
for _, x := range ext {
var err error
b, err = marshalMessageSetField(mi, b, x, opts)
if err != nil {
return b, err
}
}
default:
// Sort the keys to provide a deterministic encoding.
// Not sure this is required, but the old code does it.
keys := make([]int, 0, len(ext))
for k := range ext {
keys = append(keys, int(k))
}
sort.Ints(keys)
for _, k := range keys {
var err error
b, err = marshalMessageSetField(mi, b, ext[int32(k)], opts)
if err != nil {
return b, err
}
}
}
if u := mi.getUnknownBytes(p); u != nil {
var err error
b, err = messageset.AppendUnknown(b, *u)
if err != nil {
return b, err
}
}
return b, nil
}
func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) {
xi := getExtensionFieldInfo(x.Type())
num, _ := protowire.DecodeTag(xi.wiretag)
b = messageset.AppendFieldStart(b, num)
if fullyLazyExtensions(opts) {
// Don't expand the extension if it's still in wire format, instead use the buffer content.
if lb := x.lazyBuffer(); lb != nil {
// The tag inside the lazy buffer is a different tag (the extension
// number), but what we need here is the tag for FieldMessage:
b = protowire.AppendVarint(b, protowire.EncodeTag(messageset.FieldMessage, protowire.BytesType))
b = append(b, lb[xi.tagsize:]...)
b = messageset.AppendFieldEnd(b)
return b, nil
}
}
b, err := xi.funcs.marshal(b, x.Value(), protowire.EncodeTag(messageset.FieldMessage, protowire.BytesType), opts)
if err != nil {
return b, err
}
b = messageset.AppendFieldEnd(b)
return b, nil
}
func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOptions) (out unmarshalOutput, err error) {
if !flags.ProtoLegacy {
return out, errors.New("no support for message_set_wire_format")
}
ep := p.Apply(mi.extensionOffset).Extensions()
if *ep == nil {
*ep = make(map[int32]ExtensionField)
}
ext := *ep
initialized := true
err = messageset.Unmarshal(b, true, func(num protowire.Number, v []byte) error {
o, err := mi.unmarshalExtension(v, num, protowire.BytesType, ext, opts)
if err == errUnknown {
u := mi.mutableUnknownBytes(p)
*u = protowire.AppendTag(*u, num, protowire.BytesType)
*u = append(*u, v...)
return nil
}
if !o.initialized {
initialized = false
}
return err
})
out.n = len(b)
out.initialized = initialized
return out, err
}

View File

@ -0,0 +1,557 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
)
// pointerCoderFuncs is a set of pointer encoding functions.
type pointerCoderFuncs struct {
mi *MessageInfo
size func(p pointer, f *coderFieldInfo, opts marshalOptions) int
marshal func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error)
unmarshal func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error)
isInit func(p pointer, f *coderFieldInfo) error
merge func(dst, src pointer, f *coderFieldInfo, opts mergeOptions)
}
// valueCoderFuncs is a set of protoreflect.Value encoding functions.
type valueCoderFuncs struct {
size func(v protoreflect.Value, tagsize int, opts marshalOptions) int
marshal func(b []byte, v protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error)
unmarshal func(b []byte, v protoreflect.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (protoreflect.Value, unmarshalOutput, error)
isInit func(v protoreflect.Value) error
merge func(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value
}
// fieldCoder returns pointer functions for a field, used for operating on
// struct fields.
func fieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
switch {
case fd.IsMap():
return encoderFuncsForMap(fd, ft)
case fd.Cardinality() == protoreflect.Repeated && !fd.IsPacked():
// Repeated fields (not packed).
if ft.Kind() != reflect.Slice {
break
}
ft := ft.Elem()
switch fd.Kind() {
case protoreflect.BoolKind:
if ft.Kind() == reflect.Bool {
return nil, coderBoolSlice
}
case protoreflect.EnumKind:
if ft.Kind() == reflect.Int32 {
return nil, coderEnumSlice
}
case protoreflect.Int32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderInt32Slice
}
case protoreflect.Sint32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSint32Slice
}
case protoreflect.Uint32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderUint32Slice
}
case protoreflect.Int64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderInt64Slice
}
case protoreflect.Sint64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSint64Slice
}
case protoreflect.Uint64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderUint64Slice
}
case protoreflect.Sfixed32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSfixed32Slice
}
case protoreflect.Fixed32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderFixed32Slice
}
case protoreflect.FloatKind:
if ft.Kind() == reflect.Float32 {
return nil, coderFloatSlice
}
case protoreflect.Sfixed64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSfixed64Slice
}
case protoreflect.Fixed64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderFixed64Slice
}
case protoreflect.DoubleKind:
if ft.Kind() == reflect.Float64 {
return nil, coderDoubleSlice
}
case protoreflect.StringKind:
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return nil, coderStringSliceValidateUTF8
}
if ft.Kind() == reflect.String {
return nil, coderStringSlice
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return nil, coderBytesSliceValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytesSlice
}
case protoreflect.BytesKind:
if ft.Kind() == reflect.String {
return nil, coderStringSlice
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytesSlice
}
case protoreflect.MessageKind:
return getMessageInfo(ft), makeMessageSliceFieldCoder(fd, ft)
case protoreflect.GroupKind:
return getMessageInfo(ft), makeGroupSliceFieldCoder(fd, ft)
}
case fd.Cardinality() == protoreflect.Repeated && fd.IsPacked():
// Packed repeated fields.
//
// Only repeated fields of primitive numeric types
// (Varint, Fixed32, or Fixed64 wire type) can be packed.
if ft.Kind() != reflect.Slice {
break
}
ft := ft.Elem()
switch fd.Kind() {
case protoreflect.BoolKind:
if ft.Kind() == reflect.Bool {
return nil, coderBoolPackedSlice
}
case protoreflect.EnumKind:
if ft.Kind() == reflect.Int32 {
return nil, coderEnumPackedSlice
}
case protoreflect.Int32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderInt32PackedSlice
}
case protoreflect.Sint32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSint32PackedSlice
}
case protoreflect.Uint32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderUint32PackedSlice
}
case protoreflect.Int64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderInt64PackedSlice
}
case protoreflect.Sint64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSint64PackedSlice
}
case protoreflect.Uint64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderUint64PackedSlice
}
case protoreflect.Sfixed32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSfixed32PackedSlice
}
case protoreflect.Fixed32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderFixed32PackedSlice
}
case protoreflect.FloatKind:
if ft.Kind() == reflect.Float32 {
return nil, coderFloatPackedSlice
}
case protoreflect.Sfixed64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSfixed64PackedSlice
}
case protoreflect.Fixed64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderFixed64PackedSlice
}
case protoreflect.DoubleKind:
if ft.Kind() == reflect.Float64 {
return nil, coderDoublePackedSlice
}
}
case fd.Kind() == protoreflect.MessageKind:
return getMessageInfo(ft), makeMessageFieldCoder(fd, ft)
case fd.Kind() == protoreflect.GroupKind:
return getMessageInfo(ft), makeGroupFieldCoder(fd, ft)
case !fd.HasPresence() && fd.ContainingOneof() == nil:
// Populated oneof fields always encode even if set to the zero value,
// which normally are not encoded in proto3.
switch fd.Kind() {
case protoreflect.BoolKind:
if ft.Kind() == reflect.Bool {
return nil, coderBoolNoZero
}
case protoreflect.EnumKind:
if ft.Kind() == reflect.Int32 {
return nil, coderEnumNoZero
}
case protoreflect.Int32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderInt32NoZero
}
case protoreflect.Sint32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSint32NoZero
}
case protoreflect.Uint32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderUint32NoZero
}
case protoreflect.Int64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderInt64NoZero
}
case protoreflect.Sint64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSint64NoZero
}
case protoreflect.Uint64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderUint64NoZero
}
case protoreflect.Sfixed32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSfixed32NoZero
}
case protoreflect.Fixed32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderFixed32NoZero
}
case protoreflect.FloatKind:
if ft.Kind() == reflect.Float32 {
return nil, coderFloatNoZero
}
case protoreflect.Sfixed64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSfixed64NoZero
}
case protoreflect.Fixed64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderFixed64NoZero
}
case protoreflect.DoubleKind:
if ft.Kind() == reflect.Float64 {
return nil, coderDoubleNoZero
}
case protoreflect.StringKind:
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return nil, coderStringNoZeroValidateUTF8
}
if ft.Kind() == reflect.String {
return nil, coderStringNoZero
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return nil, coderBytesNoZeroValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytesNoZero
}
case protoreflect.BytesKind:
if ft.Kind() == reflect.String {
return nil, coderStringNoZero
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytesNoZero
}
}
case ft.Kind() == reflect.Ptr:
ft := ft.Elem()
switch fd.Kind() {
case protoreflect.BoolKind:
if ft.Kind() == reflect.Bool {
return nil, coderBoolPtr
}
case protoreflect.EnumKind:
if ft.Kind() == reflect.Int32 {
return nil, coderEnumPtr
}
case protoreflect.Int32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderInt32Ptr
}
case protoreflect.Sint32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSint32Ptr
}
case protoreflect.Uint32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderUint32Ptr
}
case protoreflect.Int64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderInt64Ptr
}
case protoreflect.Sint64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSint64Ptr
}
case protoreflect.Uint64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderUint64Ptr
}
case protoreflect.Sfixed32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSfixed32Ptr
}
case protoreflect.Fixed32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderFixed32Ptr
}
case protoreflect.FloatKind:
if ft.Kind() == reflect.Float32 {
return nil, coderFloatPtr
}
case protoreflect.Sfixed64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSfixed64Ptr
}
case protoreflect.Fixed64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderFixed64Ptr
}
case protoreflect.DoubleKind:
if ft.Kind() == reflect.Float64 {
return nil, coderDoublePtr
}
case protoreflect.StringKind:
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return nil, coderStringPtrValidateUTF8
}
if ft.Kind() == reflect.String {
return nil, coderStringPtr
}
case protoreflect.BytesKind:
if ft.Kind() == reflect.String {
return nil, coderStringPtr
}
}
default:
switch fd.Kind() {
case protoreflect.BoolKind:
if ft.Kind() == reflect.Bool {
return nil, coderBool
}
case protoreflect.EnumKind:
if ft.Kind() == reflect.Int32 {
return nil, coderEnum
}
case protoreflect.Int32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderInt32
}
case protoreflect.Sint32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSint32
}
case protoreflect.Uint32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderUint32
}
case protoreflect.Int64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderInt64
}
case protoreflect.Sint64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSint64
}
case protoreflect.Uint64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderUint64
}
case protoreflect.Sfixed32Kind:
if ft.Kind() == reflect.Int32 {
return nil, coderSfixed32
}
case protoreflect.Fixed32Kind:
if ft.Kind() == reflect.Uint32 {
return nil, coderFixed32
}
case protoreflect.FloatKind:
if ft.Kind() == reflect.Float32 {
return nil, coderFloat
}
case protoreflect.Sfixed64Kind:
if ft.Kind() == reflect.Int64 {
return nil, coderSfixed64
}
case protoreflect.Fixed64Kind:
if ft.Kind() == reflect.Uint64 {
return nil, coderFixed64
}
case protoreflect.DoubleKind:
if ft.Kind() == reflect.Float64 {
return nil, coderDouble
}
case protoreflect.StringKind:
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return nil, coderStringValidateUTF8
}
if ft.Kind() == reflect.String {
return nil, coderString
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return nil, coderBytesValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytes
}
case protoreflect.BytesKind:
if ft.Kind() == reflect.String {
return nil, coderString
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytes
}
}
}
panic(fmt.Sprintf("invalid type: no encoder for %v %v %v/%v", fd.FullName(), fd.Cardinality(), fd.Kind(), ft))
}
// encoderFuncsForValue returns value functions for a field, used for
// extension values and map encoding.
func encoderFuncsForValue(fd protoreflect.FieldDescriptor) valueCoderFuncs {
switch {
case fd.Cardinality() == protoreflect.Repeated && !fd.IsPacked():
switch fd.Kind() {
case protoreflect.BoolKind:
return coderBoolSliceValue
case protoreflect.EnumKind:
return coderEnumSliceValue
case protoreflect.Int32Kind:
return coderInt32SliceValue
case protoreflect.Sint32Kind:
return coderSint32SliceValue
case protoreflect.Uint32Kind:
return coderUint32SliceValue
case protoreflect.Int64Kind:
return coderInt64SliceValue
case protoreflect.Sint64Kind:
return coderSint64SliceValue
case protoreflect.Uint64Kind:
return coderUint64SliceValue
case protoreflect.Sfixed32Kind:
return coderSfixed32SliceValue
case protoreflect.Fixed32Kind:
return coderFixed32SliceValue
case protoreflect.FloatKind:
return coderFloatSliceValue
case protoreflect.Sfixed64Kind:
return coderSfixed64SliceValue
case protoreflect.Fixed64Kind:
return coderFixed64SliceValue
case protoreflect.DoubleKind:
return coderDoubleSliceValue
case protoreflect.StringKind:
// We don't have a UTF-8 validating coder for repeated string fields.
// Value coders are used for extensions and maps.
// Extensions are never proto3, and maps never contain lists.
return coderStringSliceValue
case protoreflect.BytesKind:
return coderBytesSliceValue
case protoreflect.MessageKind:
return coderMessageSliceValue
case protoreflect.GroupKind:
return coderGroupSliceValue
}
case fd.Cardinality() == protoreflect.Repeated && fd.IsPacked():
switch fd.Kind() {
case protoreflect.BoolKind:
return coderBoolPackedSliceValue
case protoreflect.EnumKind:
return coderEnumPackedSliceValue
case protoreflect.Int32Kind:
return coderInt32PackedSliceValue
case protoreflect.Sint32Kind:
return coderSint32PackedSliceValue
case protoreflect.Uint32Kind:
return coderUint32PackedSliceValue
case protoreflect.Int64Kind:
return coderInt64PackedSliceValue
case protoreflect.Sint64Kind:
return coderSint64PackedSliceValue
case protoreflect.Uint64Kind:
return coderUint64PackedSliceValue
case protoreflect.Sfixed32Kind:
return coderSfixed32PackedSliceValue
case protoreflect.Fixed32Kind:
return coderFixed32PackedSliceValue
case protoreflect.FloatKind:
return coderFloatPackedSliceValue
case protoreflect.Sfixed64Kind:
return coderSfixed64PackedSliceValue
case protoreflect.Fixed64Kind:
return coderFixed64PackedSliceValue
case protoreflect.DoubleKind:
return coderDoublePackedSliceValue
}
default:
switch fd.Kind() {
default:
case protoreflect.BoolKind:
return coderBoolValue
case protoreflect.EnumKind:
return coderEnumValue
case protoreflect.Int32Kind:
return coderInt32Value
case protoreflect.Sint32Kind:
return coderSint32Value
case protoreflect.Uint32Kind:
return coderUint32Value
case protoreflect.Int64Kind:
return coderInt64Value
case protoreflect.Sint64Kind:
return coderSint64Value
case protoreflect.Uint64Kind:
return coderUint64Value
case protoreflect.Sfixed32Kind:
return coderSfixed32Value
case protoreflect.Fixed32Kind:
return coderFixed32Value
case protoreflect.FloatKind:
return coderFloatValue
case protoreflect.Sfixed64Kind:
return coderSfixed64Value
case protoreflect.Fixed64Kind:
return coderFixed64Value
case protoreflect.DoubleKind:
return coderDoubleValue
case protoreflect.StringKind:
if strs.EnforceUTF8(fd) {
return coderStringValueValidateUTF8
}
return coderStringValue
case protoreflect.BytesKind:
return coderBytesValue
case protoreflect.MessageKind:
return coderMessageValue
case protoreflect.GroupKind:
return coderGroupValue
}
}
panic(fmt.Sprintf("invalid field: no encoder for %v %v %v", fd.FullName(), fd.Cardinality(), fd.Kind()))
}

View File

@ -0,0 +1,15 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
// When using unsafe pointers, we can just treat enum values as int32s.
var (
coderEnumNoZero = coderInt32NoZero
coderEnum = coderInt32
coderEnumPtr = coderInt32Ptr
coderEnumSlice = coderInt32Slice
coderEnumPackedSlice = coderInt32PackedSlice
)

View File

@ -0,0 +1,495 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
// unwrapper unwraps the value to the underlying value.
// This is implemented by List and Map.
type unwrapper interface {
protoUnwrap() any
}
// A Converter coverts to/from Go reflect.Value types and protobuf protoreflect.Value types.
type Converter interface {
// PBValueOf converts a reflect.Value to a protoreflect.Value.
PBValueOf(reflect.Value) protoreflect.Value
// GoValueOf converts a protoreflect.Value to a reflect.Value.
GoValueOf(protoreflect.Value) reflect.Value
// IsValidPB returns whether a protoreflect.Value is compatible with this type.
IsValidPB(protoreflect.Value) bool
// IsValidGo returns whether a reflect.Value is compatible with this type.
IsValidGo(reflect.Value) bool
// New returns a new field value.
// For scalars, it returns the default value of the field.
// For composite types, it returns a new mutable value.
New() protoreflect.Value
// Zero returns a new field value.
// For scalars, it returns the default value of the field.
// For composite types, it returns an immutable, empty value.
Zero() protoreflect.Value
}
// NewConverter matches a Go type with a protobuf field and returns a Converter
// that converts between the two. Enums must be a named int32 kind that
// implements protoreflect.Enum, and messages must be pointer to a named
// struct type that implements protoreflect.ProtoMessage.
//
// This matcher deliberately supports a wider range of Go types than what
// protoc-gen-go historically generated to be able to automatically wrap some
// v1 messages generated by other forks of protoc-gen-go.
func NewConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Converter {
switch {
case fd.IsList():
return newListConverter(t, fd)
case fd.IsMap():
return newMapConverter(t, fd)
default:
return newSingularConverter(t, fd)
}
}
var (
boolType = reflect.TypeOf(bool(false))
int32Type = reflect.TypeOf(int32(0))
int64Type = reflect.TypeOf(int64(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))
stringType = reflect.TypeOf(string(""))
bytesType = reflect.TypeOf([]byte(nil))
byteType = reflect.TypeOf(byte(0))
)
var (
boolZero = protoreflect.ValueOfBool(false)
int32Zero = protoreflect.ValueOfInt32(0)
int64Zero = protoreflect.ValueOfInt64(0)
uint32Zero = protoreflect.ValueOfUint32(0)
uint64Zero = protoreflect.ValueOfUint64(0)
float32Zero = protoreflect.ValueOfFloat32(0)
float64Zero = protoreflect.ValueOfFloat64(0)
stringZero = protoreflect.ValueOfString("")
bytesZero = protoreflect.ValueOfBytes(nil)
)
func newSingularConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Converter {
defVal := func(fd protoreflect.FieldDescriptor, zero protoreflect.Value) protoreflect.Value {
if fd.Cardinality() == protoreflect.Repeated {
// Default isn't defined for repeated fields.
return zero
}
return fd.Default()
}
switch fd.Kind() {
case protoreflect.BoolKind:
if t.Kind() == reflect.Bool {
return &boolConverter{t, defVal(fd, boolZero)}
}
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
if t.Kind() == reflect.Int32 {
return &int32Converter{t, defVal(fd, int32Zero)}
}
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
if t.Kind() == reflect.Int64 {
return &int64Converter{t, defVal(fd, int64Zero)}
}
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
if t.Kind() == reflect.Uint32 {
return &uint32Converter{t, defVal(fd, uint32Zero)}
}
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
if t.Kind() == reflect.Uint64 {
return &uint64Converter{t, defVal(fd, uint64Zero)}
}
case protoreflect.FloatKind:
if t.Kind() == reflect.Float32 {
return &float32Converter{t, defVal(fd, float32Zero)}
}
case protoreflect.DoubleKind:
if t.Kind() == reflect.Float64 {
return &float64Converter{t, defVal(fd, float64Zero)}
}
case protoreflect.StringKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &stringConverter{t, defVal(fd, stringZero)}
}
case protoreflect.BytesKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &bytesConverter{t, defVal(fd, bytesZero)}
}
case protoreflect.EnumKind:
// Handle enums, which must be a named int32 type.
if t.Kind() == reflect.Int32 {
return newEnumConverter(t, fd)
}
case protoreflect.MessageKind, protoreflect.GroupKind:
return newMessageConverter(t)
}
panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
}
type boolConverter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *boolConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfBool(v.Bool())
}
func (c *boolConverter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(v.Bool()).Convert(c.goType)
}
func (c *boolConverter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(bool)
return ok
}
func (c *boolConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *boolConverter) New() protoreflect.Value { return c.def }
func (c *boolConverter) Zero() protoreflect.Value { return c.def }
type int32Converter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *int32Converter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfInt32(int32(v.Int()))
}
func (c *int32Converter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(int32(v.Int())).Convert(c.goType)
}
func (c *int32Converter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(int32)
return ok
}
func (c *int32Converter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *int32Converter) New() protoreflect.Value { return c.def }
func (c *int32Converter) Zero() protoreflect.Value { return c.def }
type int64Converter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *int64Converter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfInt64(int64(v.Int()))
}
func (c *int64Converter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(int64(v.Int())).Convert(c.goType)
}
func (c *int64Converter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(int64)
return ok
}
func (c *int64Converter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *int64Converter) New() protoreflect.Value { return c.def }
func (c *int64Converter) Zero() protoreflect.Value { return c.def }
type uint32Converter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *uint32Converter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfUint32(uint32(v.Uint()))
}
func (c *uint32Converter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(uint32(v.Uint())).Convert(c.goType)
}
func (c *uint32Converter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(uint32)
return ok
}
func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *uint32Converter) New() protoreflect.Value { return c.def }
func (c *uint32Converter) Zero() protoreflect.Value { return c.def }
type uint64Converter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *uint64Converter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfUint64(uint64(v.Uint()))
}
func (c *uint64Converter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(uint64(v.Uint())).Convert(c.goType)
}
func (c *uint64Converter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(uint64)
return ok
}
func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *uint64Converter) New() protoreflect.Value { return c.def }
func (c *uint64Converter) Zero() protoreflect.Value { return c.def }
type float32Converter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *float32Converter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfFloat32(float32(v.Float()))
}
func (c *float32Converter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(float32(v.Float())).Convert(c.goType)
}
func (c *float32Converter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(float32)
return ok
}
func (c *float32Converter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *float32Converter) New() protoreflect.Value { return c.def }
func (c *float32Converter) Zero() protoreflect.Value { return c.def }
type float64Converter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *float64Converter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfFloat64(float64(v.Float()))
}
func (c *float64Converter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(float64(v.Float())).Convert(c.goType)
}
func (c *float64Converter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(float64)
return ok
}
func (c *float64Converter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *float64Converter) New() protoreflect.Value { return c.def }
func (c *float64Converter) Zero() protoreflect.Value { return c.def }
type stringConverter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *stringConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfString(v.Convert(stringType).String())
}
func (c *stringConverter) GoValueOf(v protoreflect.Value) reflect.Value {
// protoreflect.Value.String never panics, so we go through an interface
// conversion here to check the type.
s := v.Interface().(string)
if c.goType.Kind() == reflect.Slice && s == "" {
return reflect.Zero(c.goType) // ensure empty string is []byte(nil)
}
return reflect.ValueOf(s).Convert(c.goType)
}
func (c *stringConverter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(string)
return ok
}
func (c *stringConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *stringConverter) New() protoreflect.Value { return c.def }
func (c *stringConverter) Zero() protoreflect.Value { return c.def }
type bytesConverter struct {
goType reflect.Type
def protoreflect.Value
}
func (c *bytesConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
if c.goType.Kind() == reflect.String && v.Len() == 0 {
return protoreflect.ValueOfBytes(nil) // ensure empty string is []byte(nil)
}
return protoreflect.ValueOfBytes(v.Convert(bytesType).Bytes())
}
func (c *bytesConverter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(v.Bytes()).Convert(c.goType)
}
func (c *bytesConverter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().([]byte)
return ok
}
func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *bytesConverter) New() protoreflect.Value { return c.def }
func (c *bytesConverter) Zero() protoreflect.Value { return c.def }
type enumConverter struct {
goType reflect.Type
def protoreflect.Value
}
func newEnumConverter(goType reflect.Type, fd protoreflect.FieldDescriptor) Converter {
var def protoreflect.Value
if fd.Cardinality() == protoreflect.Repeated {
def = protoreflect.ValueOfEnum(fd.Enum().Values().Get(0).Number())
} else {
def = fd.Default()
}
return &enumConverter{goType, def}
}
func (c *enumConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfEnum(protoreflect.EnumNumber(v.Int()))
}
func (c *enumConverter) GoValueOf(v protoreflect.Value) reflect.Value {
return reflect.ValueOf(v.Enum()).Convert(c.goType)
}
func (c *enumConverter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(protoreflect.EnumNumber)
return ok
}
func (c *enumConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *enumConverter) New() protoreflect.Value {
return c.def
}
func (c *enumConverter) Zero() protoreflect.Value {
return c.def
}
type messageConverter struct {
goType reflect.Type
}
func newMessageConverter(goType reflect.Type) Converter {
return &messageConverter{goType}
}
func (c *messageConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
if c.isNonPointer() {
if v.CanAddr() {
v = v.Addr() // T => *T
} else {
v = reflect.Zero(reflect.PtrTo(v.Type()))
}
}
if m, ok := v.Interface().(protoreflect.ProtoMessage); ok {
return protoreflect.ValueOfMessage(m.ProtoReflect())
}
return protoreflect.ValueOfMessage(legacyWrapMessage(v))
}
func (c *messageConverter) GoValueOf(v protoreflect.Value) reflect.Value {
m := v.Message()
var rv reflect.Value
if u, ok := m.(unwrapper); ok {
rv = reflect.ValueOf(u.protoUnwrap())
} else {
rv = reflect.ValueOf(m.Interface())
}
if c.isNonPointer() {
if rv.Type() != reflect.PtrTo(c.goType) {
panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), reflect.PtrTo(c.goType)))
}
if !rv.IsNil() {
rv = rv.Elem() // *T => T
} else {
rv = reflect.Zero(rv.Type().Elem())
}
}
if rv.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), c.goType))
}
return rv
}
func (c *messageConverter) IsValidPB(v protoreflect.Value) bool {
m := v.Message()
var rv reflect.Value
if u, ok := m.(unwrapper); ok {
rv = reflect.ValueOf(u.protoUnwrap())
} else {
rv = reflect.ValueOf(m.Interface())
}
if c.isNonPointer() {
return rv.Type() == reflect.PtrTo(c.goType)
}
return rv.Type() == c.goType
}
func (c *messageConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *messageConverter) New() protoreflect.Value {
if c.isNonPointer() {
return c.PBValueOf(reflect.New(c.goType).Elem())
}
return c.PBValueOf(reflect.New(c.goType.Elem()))
}
func (c *messageConverter) Zero() protoreflect.Value {
return c.PBValueOf(reflect.Zero(c.goType))
}
// isNonPointer reports whether the type is a non-pointer type.
// This never occurs for generated message types.
func (c *messageConverter) isNonPointer() bool {
return c.goType.Kind() != reflect.Ptr
}

View File

@ -0,0 +1,141 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
func newListConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Converter {
switch {
case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Slice:
return &listPtrConverter{t, newSingularConverter(t.Elem().Elem(), fd)}
case t.Kind() == reflect.Slice:
return &listConverter{t, newSingularConverter(t.Elem(), fd)}
}
panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
}
type listConverter struct {
goType reflect.Type // []T
c Converter
}
func (c *listConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
pv := reflect.New(c.goType)
pv.Elem().Set(v)
return protoreflect.ValueOfList(&listReflect{pv, c.c})
}
func (c *listConverter) GoValueOf(v protoreflect.Value) reflect.Value {
rv := v.List().(*listReflect).v
if rv.IsNil() {
return reflect.Zero(c.goType)
}
return rv.Elem()
}
func (c *listConverter) IsValidPB(v protoreflect.Value) bool {
list, ok := v.Interface().(*listReflect)
if !ok {
return false
}
return list.v.Type().Elem() == c.goType
}
func (c *listConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *listConverter) New() protoreflect.Value {
return protoreflect.ValueOfList(&listReflect{reflect.New(c.goType), c.c})
}
func (c *listConverter) Zero() protoreflect.Value {
return protoreflect.ValueOfList(&listReflect{reflect.Zero(reflect.PtrTo(c.goType)), c.c})
}
type listPtrConverter struct {
goType reflect.Type // *[]T
c Converter
}
func (c *listPtrConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfList(&listReflect{v, c.c})
}
func (c *listPtrConverter) GoValueOf(v protoreflect.Value) reflect.Value {
return v.List().(*listReflect).v
}
func (c *listPtrConverter) IsValidPB(v protoreflect.Value) bool {
list, ok := v.Interface().(*listReflect)
if !ok {
return false
}
return list.v.Type() == c.goType
}
func (c *listPtrConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *listPtrConverter) New() protoreflect.Value {
return c.PBValueOf(reflect.New(c.goType.Elem()))
}
func (c *listPtrConverter) Zero() protoreflect.Value {
return c.PBValueOf(reflect.Zero(c.goType))
}
type listReflect struct {
v reflect.Value // *[]T
conv Converter
}
func (ls *listReflect) Len() int {
if ls.v.IsNil() {
return 0
}
return ls.v.Elem().Len()
}
func (ls *listReflect) Get(i int) protoreflect.Value {
return ls.conv.PBValueOf(ls.v.Elem().Index(i))
}
func (ls *listReflect) Set(i int, v protoreflect.Value) {
ls.v.Elem().Index(i).Set(ls.conv.GoValueOf(v))
}
func (ls *listReflect) Append(v protoreflect.Value) {
ls.v.Elem().Set(reflect.Append(ls.v.Elem(), ls.conv.GoValueOf(v)))
}
func (ls *listReflect) AppendMutable() protoreflect.Value {
if _, ok := ls.conv.(*messageConverter); !ok {
panic("invalid AppendMutable on list with non-message type")
}
v := ls.NewElement()
ls.Append(v)
return v
}
func (ls *listReflect) Truncate(i int) {
ls.v.Elem().Set(ls.v.Elem().Slice(0, i))
}
func (ls *listReflect) NewElement() protoreflect.Value {
return ls.conv.New()
}
func (ls *listReflect) IsValid() bool {
return !ls.v.IsNil()
}
func (ls *listReflect) protoUnwrap() any {
return ls.v.Interface()
}

View File

@ -0,0 +1,121 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
type mapConverter struct {
goType reflect.Type // map[K]V
keyConv, valConv Converter
}
func newMapConverter(t reflect.Type, fd protoreflect.FieldDescriptor) *mapConverter {
if t.Kind() != reflect.Map {
panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
}
return &mapConverter{
goType: t,
keyConv: newSingularConverter(t.Key(), fd.MapKey()),
valConv: newSingularConverter(t.Elem(), fd.MapValue()),
}
}
func (c *mapConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
return protoreflect.ValueOfMap(&mapReflect{v, c.keyConv, c.valConv})
}
func (c *mapConverter) GoValueOf(v protoreflect.Value) reflect.Value {
return v.Map().(*mapReflect).v
}
func (c *mapConverter) IsValidPB(v protoreflect.Value) bool {
mapv, ok := v.Interface().(*mapReflect)
if !ok {
return false
}
return mapv.v.Type() == c.goType
}
func (c *mapConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *mapConverter) New() protoreflect.Value {
return c.PBValueOf(reflect.MakeMap(c.goType))
}
func (c *mapConverter) Zero() protoreflect.Value {
return c.PBValueOf(reflect.Zero(c.goType))
}
type mapReflect struct {
v reflect.Value // map[K]V
keyConv Converter
valConv Converter
}
func (ms *mapReflect) Len() int {
return ms.v.Len()
}
func (ms *mapReflect) Has(k protoreflect.MapKey) bool {
rk := ms.keyConv.GoValueOf(k.Value())
rv := ms.v.MapIndex(rk)
return rv.IsValid()
}
func (ms *mapReflect) Get(k protoreflect.MapKey) protoreflect.Value {
rk := ms.keyConv.GoValueOf(k.Value())
rv := ms.v.MapIndex(rk)
if !rv.IsValid() {
return protoreflect.Value{}
}
return ms.valConv.PBValueOf(rv)
}
func (ms *mapReflect) Set(k protoreflect.MapKey, v protoreflect.Value) {
rk := ms.keyConv.GoValueOf(k.Value())
rv := ms.valConv.GoValueOf(v)
ms.v.SetMapIndex(rk, rv)
}
func (ms *mapReflect) Clear(k protoreflect.MapKey) {
rk := ms.keyConv.GoValueOf(k.Value())
ms.v.SetMapIndex(rk, reflect.Value{})
}
func (ms *mapReflect) Mutable(k protoreflect.MapKey) protoreflect.Value {
if _, ok := ms.valConv.(*messageConverter); !ok {
panic("invalid Mutable on map with non-message value type")
}
v := ms.Get(k)
if !v.IsValid() {
v = ms.NewValue()
ms.Set(k, v)
}
return v
}
func (ms *mapReflect) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) {
iter := ms.v.MapRange()
for iter.Next() {
k := ms.keyConv.PBValueOf(iter.Key()).MapKey()
v := ms.valConv.PBValueOf(iter.Value())
if !f(k, v) {
return
}
}
}
func (ms *mapReflect) NewValue() protoreflect.Value {
return ms.valConv.New()
}
func (ms *mapReflect) IsValid() bool {
return !ms.v.IsNil()
}
func (ms *mapReflect) protoUnwrap() any {
return ms.v.Interface()
}

View File

@ -0,0 +1,333 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"math/bits"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
)
var errDecode = errors.New("cannot parse invalid wire-format data")
var errRecursionDepth = errors.New("exceeded maximum recursion depth")
type unmarshalOptions struct {
flags protoiface.UnmarshalInputFlags
resolver interface {
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
depth int
}
func (o unmarshalOptions) Options() proto.UnmarshalOptions {
return proto.UnmarshalOptions{
Merge: true,
AllowPartial: true,
DiscardUnknown: o.DiscardUnknown(),
Resolver: o.resolver,
NoLazyDecoding: o.NoLazyDecoding(),
}
}
func (o unmarshalOptions) DiscardUnknown() bool {
return o.flags&protoiface.UnmarshalDiscardUnknown != 0
}
func (o unmarshalOptions) AliasBuffer() bool { return o.flags&protoiface.UnmarshalAliasBuffer != 0 }
func (o unmarshalOptions) Validated() bool { return o.flags&protoiface.UnmarshalValidated != 0 }
func (o unmarshalOptions) NoLazyDecoding() bool {
return o.flags&protoiface.UnmarshalNoLazyDecoding != 0
}
func (o unmarshalOptions) CanBeLazy() bool {
if o.resolver != protoregistry.GlobalTypes {
return false
}
// We ignore the UnmarshalInvalidateSizeCache even though it's not in the default set
return (o.flags & ^(protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated | protoiface.UnmarshalCheckRequired)) == 0
}
var lazyUnmarshalOptions = unmarshalOptions{
resolver: protoregistry.GlobalTypes,
flags: protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated,
depth: protowire.DefaultRecursionLimit,
}
type unmarshalOutput struct {
n int // number of bytes consumed
initialized bool
}
// unmarshal is protoreflect.Methods.Unmarshal.
func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
var p pointer
if ms, ok := in.Message.(*messageState); ok {
p = ms.pointer()
} else {
p = in.Message.(*messageReflectWrapper).pointer()
}
out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
flags: in.Flags,
resolver: in.Resolver,
depth: in.Depth,
})
var flags protoiface.UnmarshalOutputFlags
if out.initialized {
flags |= protoiface.UnmarshalInitialized
}
return protoiface.UnmarshalOutput{
Flags: flags,
}, err
}
// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for example,
// when the wire type doesn't match) as opposed to the entire unmarshal operation
// failing (for example, when a field extends past the available input).
//
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
mi.init()
opts.depth--
if opts.depth < 0 {
return out, errRecursionDepth
}
if flags.ProtoLegacy && mi.isMessageSet {
return unmarshalMessageSet(mi, b, p, opts)
}
lazyDecoding := LazyEnabled() // default
if opts.NoLazyDecoding() {
lazyDecoding = false // explicitly disabled
}
if mi.lazyOffset.IsValid() && lazyDecoding {
return mi.unmarshalPointerLazy(b, p, groupTag, opts)
}
return mi.unmarshalPointerEager(b, p, groupTag, opts)
}
// unmarshalPointerEager is the message unmarshalling function for all messages that are not lazy.
// The corresponding function for Lazy is in google_lazy.go.
func (mi *MessageInfo) unmarshalPointerEager(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
initialized := true
var requiredMask uint64
var exts *map[int32]ExtensionField
var presence presence
if mi.presenceOffset.IsValid() {
presence = p.Apply(mi.presenceOffset).PresenceInfo()
}
start := len(b)
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
if b[0] < 0x80 {
tag = uint64(b[0])
b = b[1:]
} else if len(b) >= 2 && b[1] < 128 {
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
b = b[2:]
} else {
var n int
tag, n = protowire.ConsumeVarint(b)
if n < 0 {
return out, errDecode
}
b = b[n:]
}
var num protowire.Number
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
return out, errDecode
} else {
num = protowire.Number(n)
}
wtyp := protowire.Type(tag & 7)
if wtyp == protowire.EndGroupType {
if num != groupTag {
return out, errDecode
}
groupTag = 0
break
}
var f *coderFieldInfo
if int(num) < len(mi.denseCoderFields) {
f = mi.denseCoderFields[num]
} else {
f = mi.coderFields[num]
}
var n int
err := errUnknown
switch {
case f != nil:
if f.funcs.unmarshal == nil {
break
}
var o unmarshalOutput
o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
n = o.n
if err != nil {
break
}
requiredMask |= f.validation.requiredBit
if f.funcs.isInit != nil && !o.initialized {
initialized = false
}
if f.presenceIndex != noPresence {
presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
}
default:
// Possible extension.
if exts == nil && mi.extensionOffset.IsValid() {
exts = p.Apply(mi.extensionOffset).Extensions()
if *exts == nil {
*exts = make(map[int32]ExtensionField)
}
}
if exts == nil {
break
}
var o unmarshalOutput
o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
if err != nil {
break
}
n = o.n
if !o.initialized {
initialized = false
}
}
if err != nil {
if err != errUnknown {
return out, err
}
n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return out, errDecode
}
if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
u := mi.mutableUnknownBytes(p)
*u = protowire.AppendTag(*u, num, wtyp)
*u = append(*u, b[:n]...)
}
}
b = b[n:]
}
if groupTag != 0 {
return out, errDecode
}
if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
initialized = false
}
if initialized {
out.initialized = true
}
out.n = start - len(b)
return out, nil
}
func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
x := exts[int32(num)]
xt := x.Type()
if xt == nil {
var err error
xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
if err != nil {
if err == protoregistry.NotFound {
return out, errUnknown
}
return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
}
}
xi := getExtensionFieldInfo(xt)
if xi.funcs.unmarshal == nil {
return out, errUnknown
}
if flags.LazyUnmarshalExtensions {
if opts.CanBeLazy() && x.canLazy(xt) {
out, valid := skipExtension(b, xi, num, wtyp, opts)
switch valid {
case ValidationValid:
if out.initialized {
x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
exts[int32(num)] = x
return out, nil
}
case ValidationInvalid:
return out, errDecode
case ValidationUnknown:
}
}
}
ival := x.Value()
if !ival.IsValid() && xi.unmarshalNeedsValue {
// Create a new message, list, or map value to fill in.
// For enums, create a prototype value to let the unmarshal func know the
// concrete type.
ival = xt.New()
}
v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
if err != nil {
return out, err
}
if xi.funcs.isInit == nil {
out.initialized = true
}
x.Set(xt, v)
exts[int32(num)] = x
return out, nil
}
func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
if xi.validation.mi == nil {
return out, ValidationUnknown
}
xi.validation.mi.init()
switch xi.validation.typ {
case validationTypeMessage:
if wtyp != protowire.BytesType {
return out, ValidationUnknown
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, ValidationUnknown
}
if opts.Validated() {
out.initialized = true
out.n = n
return out, ValidationValid
}
out, st := xi.validation.mi.validate(v, 0, opts)
out.n = n
return out, st
case validationTypeGroup:
if wtyp != protowire.StartGroupType {
return out, ValidationUnknown
}
out, st := xi.validation.mi.validate(b, num, opts)
return out, st
default:
return out, ValidationUnknown
}
}

View File

@ -0,0 +1,315 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"math"
"sort"
"sync/atomic"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/protolazy"
"google.golang.org/protobuf/proto"
piface "google.golang.org/protobuf/runtime/protoiface"
)
type marshalOptions struct {
flags piface.MarshalInputFlags
}
func (o marshalOptions) Options() proto.MarshalOptions {
return proto.MarshalOptions{
AllowPartial: true,
Deterministic: o.Deterministic(),
UseCachedSize: o.UseCachedSize(),
}
}
func (o marshalOptions) Deterministic() bool { return o.flags&piface.MarshalDeterministic != 0 }
func (o marshalOptions) UseCachedSize() bool { return o.flags&piface.MarshalUseCachedSize != 0 }
// size is protoreflect.Methods.Size.
func (mi *MessageInfo) size(in piface.SizeInput) piface.SizeOutput {
var p pointer
if ms, ok := in.Message.(*messageState); ok {
p = ms.pointer()
} else {
p = in.Message.(*messageReflectWrapper).pointer()
}
size := mi.sizePointer(p, marshalOptions{
flags: in.Flags,
})
return piface.SizeOutput{Size: size}
}
func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) {
mi.init()
if p.IsNil() {
return 0
}
if opts.UseCachedSize() && mi.sizecacheOffset.IsValid() {
// The size cache contains the size + 1, to allow the
// zero value to be invalid, while also allowing for a
// 0 size to be cached.
if size := atomic.LoadInt32(p.Apply(mi.sizecacheOffset).Int32()); size > 0 {
return int(size - 1)
}
}
return mi.sizePointerSlow(p, opts)
}
func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int) {
if flags.ProtoLegacy && mi.isMessageSet {
size = sizeMessageSet(mi, p, opts)
if mi.sizecacheOffset.IsValid() {
atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size+1))
}
return size
}
if mi.extensionOffset.IsValid() {
e := p.Apply(mi.extensionOffset).Extensions()
size += mi.sizeExtensions(e, opts)
}
var lazy **protolazy.XXX_lazyUnmarshalInfo
var presence presence
if mi.presenceOffset.IsValid() {
presence = p.Apply(mi.presenceOffset).PresenceInfo()
if mi.lazyOffset.IsValid() {
lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
}
}
for _, f := range mi.orderedCoderFields {
if f.funcs.size == nil {
continue
}
fptr := p.Apply(f.offset)
if f.presenceIndex != noPresence {
if !presence.Present(f.presenceIndex) {
continue
}
if f.isLazy && fptr.AtomicGetPointer().IsNil() {
if lazyFields(opts) {
size += (*lazy).SizeField(uint32(f.num))
continue
} else {
mi.lazyUnmarshal(p, f.num)
}
}
size += f.funcs.size(fptr, f, opts)
continue
}
if f.isPointer && fptr.Elem().IsNil() {
continue
}
size += f.funcs.size(fptr, f, opts)
}
if mi.unknownOffset.IsValid() {
if u := mi.getUnknownBytes(p); u != nil {
size += len(*u)
}
}
if mi.sizecacheOffset.IsValid() {
if size > (math.MaxInt32 - 1) {
// The size is too large for the int32 sizecache field.
// We will need to recompute the size when encoding;
// unfortunately expensive, but better than invalid output.
atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), 0)
} else {
// The size cache contains the size + 1, to allow the
// zero value to be invalid, while also allowing for a
// 0 size to be cached.
atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size+1))
}
}
return size
}
// marshal is protoreflect.Methods.Marshal.
func (mi *MessageInfo) marshal(in piface.MarshalInput) (out piface.MarshalOutput, err error) {
var p pointer
if ms, ok := in.Message.(*messageState); ok {
p = ms.pointer()
} else {
p = in.Message.(*messageReflectWrapper).pointer()
}
b, err := mi.marshalAppendPointer(in.Buf, p, marshalOptions{
flags: in.Flags,
})
return piface.MarshalOutput{Buf: b}, err
}
func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOptions) ([]byte, error) {
mi.init()
if p.IsNil() {
return b, nil
}
if flags.ProtoLegacy && mi.isMessageSet {
return marshalMessageSet(mi, b, p, opts)
}
var err error
// The old marshaler encodes extensions at beginning.
if mi.extensionOffset.IsValid() {
e := p.Apply(mi.extensionOffset).Extensions()
// TODO: Special handling for MessageSet?
b, err = mi.appendExtensions(b, e, opts)
if err != nil {
return b, err
}
}
var lazy **protolazy.XXX_lazyUnmarshalInfo
var presence presence
if mi.presenceOffset.IsValid() {
presence = p.Apply(mi.presenceOffset).PresenceInfo()
if mi.lazyOffset.IsValid() {
lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
}
}
for _, f := range mi.orderedCoderFields {
if f.funcs.marshal == nil {
continue
}
fptr := p.Apply(f.offset)
if f.presenceIndex != noPresence {
if !presence.Present(f.presenceIndex) {
continue
}
if f.isLazy {
// Be careful, this field needs to be read atomically, like for a get
if f.isPointer && fptr.AtomicGetPointer().IsNil() {
if lazyFields(opts) {
b, _ = (*lazy).AppendField(b, uint32(f.num))
continue
} else {
mi.lazyUnmarshal(p, f.num)
}
}
b, err = f.funcs.marshal(b, fptr, f, opts)
if err != nil {
return b, err
}
continue
} else if f.isPointer && fptr.Elem().IsNil() {
continue
}
b, err = f.funcs.marshal(b, fptr, f, opts)
if err != nil {
return b, err
}
continue
}
if f.isPointer && fptr.Elem().IsNil() {
continue
}
b, err = f.funcs.marshal(b, fptr, f, opts)
if err != nil {
return b, err
}
}
if mi.unknownOffset.IsValid() && !mi.isMessageSet {
if u := mi.getUnknownBytes(p); u != nil {
b = append(b, (*u)...)
}
}
return b, nil
}
// fullyLazyExtensions returns true if we should attempt to keep extensions lazy over size and marshal.
func fullyLazyExtensions(opts marshalOptions) bool {
// When deterministic marshaling is requested, force an unmarshal for lazy
// extensions to produce a deterministic result, instead of passing through
// bytes lazily that may or may not match what Go Protobuf would produce.
return opts.flags&piface.MarshalDeterministic == 0
}
// lazyFields returns true if we should attempt to keep fields lazy over size and marshal.
func lazyFields(opts marshalOptions) bool {
// When deterministic marshaling is requested, force an unmarshal for lazy
// fields to produce a deterministic result, instead of passing through
// bytes lazily that may or may not match what Go Protobuf would produce.
return opts.flags&piface.MarshalDeterministic == 0
}
func (mi *MessageInfo) sizeExtensions(ext *map[int32]ExtensionField, opts marshalOptions) (n int) {
if ext == nil {
return 0
}
for _, x := range *ext {
xi := getExtensionFieldInfo(x.Type())
if xi.funcs.size == nil {
continue
}
if fullyLazyExtensions(opts) {
// Don't expand the extension, instead use the buffer to calculate size
if lb := x.lazyBuffer(); lb != nil {
// We got hold of the buffer, so it's still lazy.
n += len(lb)
continue
}
}
n += xi.funcs.size(x.Value(), xi.tagsize, opts)
}
return n
}
func (mi *MessageInfo) appendExtensions(b []byte, ext *map[int32]ExtensionField, opts marshalOptions) ([]byte, error) {
if ext == nil {
return b, nil
}
switch len(*ext) {
case 0:
return b, nil
case 1:
// Fast-path for one extension: Don't bother sorting the keys.
var err error
for _, x := range *ext {
xi := getExtensionFieldInfo(x.Type())
if fullyLazyExtensions(opts) {
// Don't expand the extension if it's still in wire format, instead use the buffer content.
if lb := x.lazyBuffer(); lb != nil {
b = append(b, lb...)
continue
}
}
b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
}
return b, err
default:
// Sort the keys to provide a deterministic encoding.
// Not sure this is required, but the old code does it.
keys := make([]int, 0, len(*ext))
for k := range *ext {
keys = append(keys, int(k))
}
sort.Ints(keys)
var err error
for _, k := range keys {
x := (*ext)[int32(k)]
xi := getExtensionFieldInfo(x.Type())
if fullyLazyExtensions(opts) {
// Don't expand the extension if it's still in wire format, instead use the buffer content.
if lb := x.lazyBuffer(); lb != nil {
b = append(b, lb...)
continue
}
}
b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
if err != nil {
return b, err
}
}
return b, nil
}
}

View File

@ -0,0 +1,21 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
type EnumInfo struct {
GoReflectType reflect.Type // int32 kind
Desc protoreflect.EnumDescriptor
}
func (t *EnumInfo) New(n protoreflect.EnumNumber) protoreflect.Enum {
return reflect.ValueOf(n).Convert(t.GoReflectType).Interface().(protoreflect.Enum)
}
func (t *EnumInfo) Descriptor() protoreflect.EnumDescriptor { return t.Desc }

View File

@ -0,0 +1,224 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"bytes"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
func equal(in protoiface.EqualInput) protoiface.EqualOutput {
return protoiface.EqualOutput{Equal: equalMessage(in.MessageA, in.MessageB)}
}
// equalMessage is a fast-path variant of protoreflect.equalMessage.
// It takes advantage of the internal messageState type to avoid
// unnecessary allocations, type assertions.
func equalMessage(mx, my protoreflect.Message) bool {
if mx == nil || my == nil {
return mx == my
}
if mx.Descriptor() != my.Descriptor() {
return false
}
msx, ok := mx.(*messageState)
if !ok {
return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
}
msy, ok := my.(*messageState)
if !ok {
return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
}
mi := msx.messageInfo()
miy := msy.messageInfo()
if mi != miy {
return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
}
mi.init()
// Compares regular fields
// Modified Message.Range code that compares two messages of the same type
// while going over the fields.
for _, ri := range mi.rangeInfos {
var fd protoreflect.FieldDescriptor
var vx, vy protoreflect.Value
switch ri := ri.(type) {
case *fieldInfo:
hx := ri.has(msx.pointer())
hy := ri.has(msy.pointer())
if hx != hy {
return false
}
if !hx {
continue
}
fd = ri.fieldDesc
vx = ri.get(msx.pointer())
vy = ri.get(msy.pointer())
case *oneofInfo:
fnx := ri.which(msx.pointer())
fny := ri.which(msy.pointer())
if fnx != fny {
return false
}
if fnx <= 0 {
continue
}
fi := mi.fields[fnx]
fd = fi.fieldDesc
vx = fi.get(msx.pointer())
vy = fi.get(msy.pointer())
}
if !equalValue(fd, vx, vy) {
return false
}
}
// Compare extensions.
// This is more complicated because mx or my could have empty/nil extension maps,
// however some populated extension map values are equal to nil extension maps.
emx := mi.extensionMap(msx.pointer())
emy := mi.extensionMap(msy.pointer())
if emx != nil {
for k, x := range *emx {
xd := x.Type().TypeDescriptor()
xv := x.Value()
var y ExtensionField
ok := false
if emy != nil {
y, ok = (*emy)[k]
}
// We need to treat empty lists as equal to nil values
if emy == nil || !ok {
if xd.IsList() && xv.List().Len() == 0 {
continue
}
return false
}
if !equalValue(xd, xv, y.Value()) {
return false
}
}
}
if emy != nil {
// emy may have extensions emx does not have, need to check them as well
for k, y := range *emy {
if emx != nil {
// emx has the field, so we already checked it
if _, ok := (*emx)[k]; ok {
continue
}
}
// Empty lists are equal to nil
if y.Type().TypeDescriptor().IsList() && y.Value().List().Len() == 0 {
continue
}
// Cant be equal if the extension is populated
return false
}
}
return equalUnknown(mx.GetUnknown(), my.GetUnknown())
}
func equalValue(fd protoreflect.FieldDescriptor, vx, vy protoreflect.Value) bool {
// slow path
if fd.Kind() != protoreflect.MessageKind {
return vx.Equal(vy)
}
// fast path special cases
if fd.IsMap() {
if fd.MapValue().Kind() == protoreflect.MessageKind {
return equalMessageMap(vx.Map(), vy.Map())
}
return vx.Equal(vy)
}
if fd.IsList() {
return equalMessageList(vx.List(), vy.List())
}
return equalMessage(vx.Message(), vy.Message())
}
// Mostly copied from protoreflect.equalMap.
// This variant only works for messages as map types.
// All other map types should be handled via Value.Equal.
func equalMessageMap(mx, my protoreflect.Map) bool {
if mx.Len() != my.Len() {
return false
}
equal := true
mx.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
if !my.Has(k) {
equal = false
return false
}
vy := my.Get(k)
equal = equalMessage(vx.Message(), vy.Message())
return equal
})
return equal
}
// Mostly copied from protoreflect.equalList.
// The only change is the usage of equalImpl instead of protoreflect.equalValue.
func equalMessageList(lx, ly protoreflect.List) bool {
if lx.Len() != ly.Len() {
return false
}
for i := 0; i < lx.Len(); i++ {
// We only operate on messages here since equalImpl will not call us in any other case.
if !equalMessage(lx.Get(i).Message(), ly.Get(i).Message()) {
return false
}
}
return true
}
// equalUnknown compares unknown fields by direct comparison on the raw bytes
// of each individual field number.
// Copied from protoreflect.equalUnknown.
func equalUnknown(x, y protoreflect.RawFields) bool {
if len(x) != len(y) {
return false
}
if bytes.Equal([]byte(x), []byte(y)) {
return true
}
mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
for len(x) > 0 {
fnum, _, n := protowire.ConsumeField(x)
mx[fnum] = append(mx[fnum], x[:n]...)
x = x[n:]
}
for len(y) > 0 {
fnum, _, n := protowire.ConsumeField(y)
my[fnum] = append(my[fnum], y[:n]...)
y = y[n:]
}
if len(mx) != len(my) {
return false
}
for k, v1 := range mx {
if v2, ok := my[k]; !ok || !bytes.Equal([]byte(v1), []byte(v2)) {
return false
}
}
return true
}

View File

@ -0,0 +1,156 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"sync"
"sync/atomic"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// ExtensionInfo implements ExtensionType.
//
// This type contains a number of exported fields for legacy compatibility.
// The only non-deprecated use of this type is through the methods of the
// ExtensionType interface.
type ExtensionInfo struct {
// An ExtensionInfo may exist in several stages of initialization.
//
// extensionInfoUninitialized: Some or all of the legacy exported
// fields may be set, but none of the unexported fields have been
// initialized. This is the starting state for an ExtensionInfo
// in legacy generated code.
//
// extensionInfoDescInit: The desc field is set, but other unexported fields
// may not be initialized. Legacy exported fields may or may not be set.
// This is the starting state for an ExtensionInfo in newly generated code.
//
// extensionInfoFullInit: The ExtensionInfo is fully initialized.
// This state is only entered after lazy initialization is complete.
init uint32
mu sync.Mutex
goType reflect.Type
desc extensionTypeDescriptor
conv Converter
info *extensionFieldInfo // for fast-path method implementations
// ExtendedType is a typed nil-pointer to the parent message type that
// is being extended. It is possible for this to be unpopulated in v2
// since the message may no longer implement the MessageV1 interface.
//
// Deprecated: Use the ExtendedType method instead.
ExtendedType protoiface.MessageV1
// ExtensionType is the zero value of the extension type.
//
// For historical reasons, reflect.TypeOf(ExtensionType) and the
// type returned by InterfaceOf may not be identical.
//
// Deprecated: Use InterfaceOf(xt.Zero()) instead.
ExtensionType any
// Field is the field number of the extension.
//
// Deprecated: Use the Descriptor().Number method instead.
Field int32
// Name is the fully qualified name of extension.
//
// Deprecated: Use the Descriptor().FullName method instead.
Name string
// Tag is the protobuf struct tag used in the v1 API.
//
// Deprecated: Do not use.
Tag string
// Filename is the proto filename in which the extension is defined.
//
// Deprecated: Use Descriptor().ParentFile().Path() instead.
Filename string
}
// Stages of initialization: See the ExtensionInfo.init field.
const (
extensionInfoUninitialized = 0
extensionInfoDescInit = 1
extensionInfoFullInit = 2
)
func InitExtensionInfo(xi *ExtensionInfo, xd protoreflect.ExtensionDescriptor, goType reflect.Type) {
xi.goType = goType
xi.desc = extensionTypeDescriptor{xd, xi}
xi.init = extensionInfoDescInit
}
func (xi *ExtensionInfo) New() protoreflect.Value {
return xi.lazyInit().New()
}
func (xi *ExtensionInfo) Zero() protoreflect.Value {
return xi.lazyInit().Zero()
}
func (xi *ExtensionInfo) ValueOf(v any) protoreflect.Value {
return xi.lazyInit().PBValueOf(reflect.ValueOf(v))
}
func (xi *ExtensionInfo) InterfaceOf(v protoreflect.Value) any {
return xi.lazyInit().GoValueOf(v).Interface()
}
func (xi *ExtensionInfo) IsValidValue(v protoreflect.Value) bool {
return xi.lazyInit().IsValidPB(v)
}
func (xi *ExtensionInfo) IsValidInterface(v any) bool {
return xi.lazyInit().IsValidGo(reflect.ValueOf(v))
}
func (xi *ExtensionInfo) TypeDescriptor() protoreflect.ExtensionTypeDescriptor {
if atomic.LoadUint32(&xi.init) < extensionInfoDescInit {
xi.lazyInitSlow()
}
return &xi.desc
}
func (xi *ExtensionInfo) lazyInit() Converter {
if atomic.LoadUint32(&xi.init) < extensionInfoFullInit {
xi.lazyInitSlow()
}
return xi.conv
}
func (xi *ExtensionInfo) lazyInitSlow() {
xi.mu.Lock()
defer xi.mu.Unlock()
if xi.init == extensionInfoFullInit {
return
}
defer atomic.StoreUint32(&xi.init, extensionInfoFullInit)
if xi.desc.ExtensionDescriptor == nil {
xi.initFromLegacy()
}
if !xi.desc.ExtensionDescriptor.IsPlaceholder() {
if xi.ExtensionType == nil {
xi.initToLegacy()
}
xi.conv = NewConverter(xi.goType, xi.desc.ExtensionDescriptor)
xi.info = makeExtensionFieldInfo(xi.desc.ExtensionDescriptor)
xi.info.validation = newValidationInfo(xi.desc.ExtensionDescriptor, xi.goType)
}
}
type extensionTypeDescriptor struct {
protoreflect.ExtensionDescriptor
xi *ExtensionInfo
}
func (xtd *extensionTypeDescriptor) Type() protoreflect.ExtensionType {
return xtd.xi
}
func (xtd *extensionTypeDescriptor) Descriptor() protoreflect.ExtensionDescriptor {
return xtd.ExtensionDescriptor
}

View File

@ -0,0 +1,433 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"math/bits"
"os"
"reflect"
"sort"
"sync/atomic"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/protolazy"
"google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
piface "google.golang.org/protobuf/runtime/protoiface"
)
var enableLazy int32 = func() int32 {
if os.Getenv("GOPROTODEBUG") == "nolazy" {
return 0
}
return 1
}()
// EnableLazyUnmarshal enables lazy unmarshaling.
func EnableLazyUnmarshal(enable bool) {
if enable {
atomic.StoreInt32(&enableLazy, 1)
return
}
atomic.StoreInt32(&enableLazy, 0)
}
// LazyEnabled reports whether lazy unmarshalling is currently enabled.
func LazyEnabled() bool {
return atomic.LoadInt32(&enableLazy) != 0
}
// UnmarshalField unmarshals a field in a message.
func UnmarshalField(m interface{}, num protowire.Number) {
switch m := m.(type) {
case *messageState:
m.messageInfo().lazyUnmarshal(m.pointer(), num)
case *messageReflectWrapper:
m.messageInfo().lazyUnmarshal(m.pointer(), num)
default:
panic(fmt.Sprintf("unsupported wrapper type %T", m))
}
}
func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
var f *coderFieldInfo
if int(num) < len(mi.denseCoderFields) {
f = mi.denseCoderFields[num]
} else {
f = mi.coderFields[num]
}
if f == nil {
panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num))
}
lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
if !found && multipleEntries == nil {
panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
}
// The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races.
// Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil.
fp := pointerOfValue(reflect.New(f.ft))
if multipleEntries != nil {
for _, entry := range multipleEntries {
mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
}
} else {
mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
}
p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
}
func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
opts := lazyUnmarshalOptions
opts.flags |= flags
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
if b[0] < 0x80 {
tag = uint64(b[0])
b = b[1:]
} else if len(b) >= 2 && b[1] < 128 {
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
b = b[2:]
} else {
var n int
tag, n = protowire.ConsumeVarint(b)
if n < 0 {
return errors.New("invalid wire data")
}
b = b[n:]
}
var num protowire.Number
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
return errors.New("invalid wire data")
} else {
num = protowire.Number(n)
}
wtyp := protowire.Type(tag & 7)
if num == f.num {
o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
if err == nil {
b = b[o.n:]
continue
}
if err != errUnknown {
return err
}
}
n := protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return errors.New("invalid wire data")
}
b = b[n:]
}
return nil
}
func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
fmi := f.validation.mi
if fmi == nil {
fd := mi.Desc.Fields().ByNumber(f.num)
if fd == nil {
return out, ValidationUnknown
}
messageName := fd.Message().FullName()
messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
if err != nil {
return out, ValidationUnknown
}
var ok bool
fmi, ok = messageType.(*MessageInfo)
if !ok {
return out, ValidationUnknown
}
}
fmi.init()
switch f.validation.typ {
case validationTypeMessage:
if wtyp != protowire.BytesType {
return out, ValidationWrongWireType
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return out, ValidationInvalid
}
out, st := fmi.validate(v, 0, opts)
out.n = n
return out, st
case validationTypeGroup:
if wtyp != protowire.StartGroupType {
return out, ValidationWrongWireType
}
out, st := fmi.validate(b, f.num, opts)
return out, st
default:
return out, ValidationUnknown
}
}
// unmarshalPointerLazy is similar to unmarshalPointerEager, but it
// specifically handles lazy unmarshalling. it expects lazyOffset and
// presenceOffset to both be valid.
func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
initialized := true
var requiredMask uint64
var lazy **protolazy.XXX_lazyUnmarshalInfo
var presence presence
var lazyIndex []protolazy.IndexEntry
var lastNum protowire.Number
outOfOrder := false
lazyDecode := false
presence = p.Apply(mi.presenceOffset).PresenceInfo()
lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
if !presence.AnyPresent(mi.presenceSize) {
if opts.CanBeLazy() {
// If the message contains existing data, we need to merge into it.
// Lazy unmarshaling doesn't merge, so only enable it when the
// message is empty (has no presence bitmap).
lazyDecode = true
if *lazy == nil {
*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
}
(*lazy).SetUnmarshalFlags(opts.flags)
if !opts.AliasBuffer() {
// Make a copy of the buffer for lazy unmarshaling.
// Set the AliasBuffer flag so recursive unmarshal
// operations reuse the copy.
b = append([]byte{}, b...)
opts.flags |= piface.UnmarshalAliasBuffer
}
(*lazy).SetBuffer(b)
}
}
// Track special handling of lazy fields.
//
// In the common case, all fields are lazyValidateOnly (and lazyFields remains nil).
// In the event that validation for a field fails, this map tracks handling of the field.
type lazyAction uint8
const (
lazyValidateOnly lazyAction = iota // validate the field only
lazyUnmarshalNow // eagerly unmarshal the field
lazyUnmarshalLater // unmarshal the field after the message is fully processed
)
var lazyFields map[*coderFieldInfo]lazyAction
var exts *map[int32]ExtensionField
start := len(b)
pos := 0
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
if b[0] < 0x80 {
tag = uint64(b[0])
b = b[1:]
} else if len(b) >= 2 && b[1] < 128 {
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
b = b[2:]
} else {
var n int
tag, n = protowire.ConsumeVarint(b)
if n < 0 {
return out, errDecode
}
b = b[n:]
}
var num protowire.Number
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
return out, errors.New("invalid field number")
} else {
num = protowire.Number(n)
}
wtyp := protowire.Type(tag & 7)
if wtyp == protowire.EndGroupType {
if num != groupTag {
return out, errors.New("mismatching end group marker")
}
groupTag = 0
break
}
var f *coderFieldInfo
if int(num) < len(mi.denseCoderFields) {
f = mi.denseCoderFields[num]
} else {
f = mi.coderFields[num]
}
var n int
err := errUnknown
discardUnknown := false
Field:
switch {
case f != nil:
if f.funcs.unmarshal == nil {
break
}
if f.isLazy && lazyDecode {
switch {
case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
// Attempt to validate this field and leave it for later lazy unmarshaling.
o, valid := mi.skipField(b, f, wtyp, opts)
switch valid {
case ValidationValid:
// Skip over the valid field and continue.
err = nil
presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
requiredMask |= f.validation.requiredBit
if !o.initialized {
initialized = false
}
n = o.n
break Field
case ValidationInvalid:
return out, errors.New("invalid proto wire format")
case ValidationWrongWireType:
break Field
case ValidationUnknown:
if lazyFields == nil {
lazyFields = make(map[*coderFieldInfo]lazyAction)
}
if presence.Present(f.presenceIndex) {
// We were unable to determine if the field is valid or not,
// and we've already skipped over at least one instance of this
// field. Clear the presence bit (so if we stop decoding early,
// we don't leave a partially-initialized field around) and flag
// the field for unmarshaling before we return.
presence.ClearPresent(f.presenceIndex)
lazyFields[f] = lazyUnmarshalLater
discardUnknown = true
break Field
} else {
// We were unable to determine if the field is valid or not,
// but this is the first time we've seen it. Flag it as needing
// eager unmarshaling and fall through to the eager unmarshal case below.
lazyFields[f] = lazyUnmarshalNow
}
}
case lazyFields[f] == lazyUnmarshalLater:
// This field will be unmarshaled in a separate pass below.
// Skip over it here.
discardUnknown = true
break Field
default:
// Eagerly unmarshal the field.
}
}
if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
if p.Apply(f.offset).AtomicGetPointer().IsNil() {
mi.lazyUnmarshal(p, f.num)
}
}
var o unmarshalOutput
o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
n = o.n
if err != nil {
break
}
requiredMask |= f.validation.requiredBit
if f.funcs.isInit != nil && !o.initialized {
initialized = false
}
if f.presenceIndex != noPresence {
presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
}
default:
// Possible extension.
if exts == nil && mi.extensionOffset.IsValid() {
exts = p.Apply(mi.extensionOffset).Extensions()
if *exts == nil {
*exts = make(map[int32]ExtensionField)
}
}
if exts == nil {
break
}
var o unmarshalOutput
o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
if err != nil {
break
}
n = o.n
if !o.initialized {
initialized = false
}
}
if err != nil {
if err != errUnknown {
return out, err
}
n = protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return out, errDecode
}
if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
u := mi.mutableUnknownBytes(p)
*u = protowire.AppendTag(*u, num, wtyp)
*u = append(*u, b[:n]...)
}
}
b = b[n:]
end := start - len(b)
if lazyDecode && f != nil && f.isLazy {
if num != lastNum {
lazyIndex = append(lazyIndex, protolazy.IndexEntry{
FieldNum: uint32(num),
Start: uint32(pos),
End: uint32(end),
})
} else {
i := len(lazyIndex) - 1
lazyIndex[i].End = uint32(end)
lazyIndex[i].MultipleContiguous = true
}
}
if num < lastNum {
outOfOrder = true
}
pos = end
lastNum = num
}
if groupTag != 0 {
return out, errors.New("missing end group marker")
}
if lazyFields != nil {
// Some fields failed validation, and now need to be unmarshaled.
for f, action := range lazyFields {
if action != lazyUnmarshalLater {
continue
}
initialized = false
if *lazy == nil {
*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
}
if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
return out, err
}
presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
}
}
if lazyDecode {
if outOfOrder {
sort.Slice(lazyIndex, func(i, j int) bool {
return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
(lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
lazyIndex[i].Start < lazyIndex[j].Start)
})
}
if *lazy == nil {
*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
}
(*lazy).SetIndex(lazyIndex)
}
if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
initialized = false
}
if initialized {
out.initialized = true
}
out.n = start - len(b)
return out, nil
}

View File

@ -0,0 +1,219 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"strings"
"sync"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
)
// legacyEnumName returns the name of enums used in legacy code.
// It is neither the protobuf full name nor the qualified Go name,
// but rather an odd hybrid of both.
func legacyEnumName(ed protoreflect.EnumDescriptor) string {
var protoPkg string
enumName := string(ed.FullName())
if fd := ed.ParentFile(); fd != nil {
protoPkg = string(fd.Package())
enumName = strings.TrimPrefix(enumName, protoPkg+".")
}
if protoPkg == "" {
return strs.GoCamelCase(enumName)
}
return protoPkg + "." + strs.GoCamelCase(enumName)
}
// legacyWrapEnum wraps v as a protoreflect.Enum,
// where v must be a int32 kind and not implement the v2 API already.
func legacyWrapEnum(v reflect.Value) protoreflect.Enum {
et := legacyLoadEnumType(v.Type())
return et.New(protoreflect.EnumNumber(v.Int()))
}
var legacyEnumTypeCache sync.Map // map[reflect.Type]protoreflect.EnumType
// legacyLoadEnumType dynamically loads a protoreflect.EnumType for t,
// where t must be an int32 kind and not implement the v2 API already.
func legacyLoadEnumType(t reflect.Type) protoreflect.EnumType {
// Fast-path: check if a EnumType is cached for this concrete type.
if et, ok := legacyEnumTypeCache.Load(t); ok {
return et.(protoreflect.EnumType)
}
// Slow-path: derive enum descriptor and initialize EnumType.
var et protoreflect.EnumType
ed := LegacyLoadEnumDesc(t)
et = &legacyEnumType{
desc: ed,
goType: t,
}
if et, ok := legacyEnumTypeCache.LoadOrStore(t, et); ok {
return et.(protoreflect.EnumType)
}
return et
}
type legacyEnumType struct {
desc protoreflect.EnumDescriptor
goType reflect.Type
m sync.Map // map[protoreflect.EnumNumber]proto.Enum
}
func (t *legacyEnumType) New(n protoreflect.EnumNumber) protoreflect.Enum {
if e, ok := t.m.Load(n); ok {
return e.(protoreflect.Enum)
}
e := &legacyEnumWrapper{num: n, pbTyp: t, goTyp: t.goType}
t.m.Store(n, e)
return e
}
func (t *legacyEnumType) Descriptor() protoreflect.EnumDescriptor {
return t.desc
}
type legacyEnumWrapper struct {
num protoreflect.EnumNumber
pbTyp protoreflect.EnumType
goTyp reflect.Type
}
func (e *legacyEnumWrapper) Descriptor() protoreflect.EnumDescriptor {
return e.pbTyp.Descriptor()
}
func (e *legacyEnumWrapper) Type() protoreflect.EnumType {
return e.pbTyp
}
func (e *legacyEnumWrapper) Number() protoreflect.EnumNumber {
return e.num
}
func (e *legacyEnumWrapper) ProtoReflect() protoreflect.Enum {
return e
}
func (e *legacyEnumWrapper) protoUnwrap() any {
v := reflect.New(e.goTyp).Elem()
v.SetInt(int64(e.num))
return v.Interface()
}
var (
_ protoreflect.Enum = (*legacyEnumWrapper)(nil)
_ unwrapper = (*legacyEnumWrapper)(nil)
)
var legacyEnumDescCache sync.Map // map[reflect.Type]protoreflect.EnumDescriptor
// LegacyLoadEnumDesc returns an EnumDescriptor derived from the Go type,
// which must be an int32 kind and not implement the v2 API already.
//
// This is exported for testing purposes.
func LegacyLoadEnumDesc(t reflect.Type) protoreflect.EnumDescriptor {
// Fast-path: check if an EnumDescriptor is cached for this concrete type.
if ed, ok := legacyEnumDescCache.Load(t); ok {
return ed.(protoreflect.EnumDescriptor)
}
// Slow-path: initialize EnumDescriptor from the raw descriptor.
ev := reflect.Zero(t).Interface()
if _, ok := ev.(protoreflect.Enum); ok {
panic(fmt.Sprintf("%v already implements proto.Enum", t))
}
edV1, ok := ev.(enumV1)
if !ok {
return aberrantLoadEnumDesc(t)
}
b, idxs := edV1.EnumDescriptor()
var ed protoreflect.EnumDescriptor
if len(idxs) == 1 {
ed = legacyLoadFileDesc(b).Enums().Get(idxs[0])
} else {
md := legacyLoadFileDesc(b).Messages().Get(idxs[0])
for _, i := range idxs[1 : len(idxs)-1] {
md = md.Messages().Get(i)
}
ed = md.Enums().Get(idxs[len(idxs)-1])
}
if ed, ok := legacyEnumDescCache.LoadOrStore(t, ed); ok {
return ed.(protoreflect.EnumDescriptor)
}
return ed
}
var aberrantEnumDescCache sync.Map // map[reflect.Type]protoreflect.EnumDescriptor
// aberrantLoadEnumDesc returns an EnumDescriptor derived from the Go type,
// which must not implement protoreflect.Enum or enumV1.
//
// If the type does not implement enumV1, then there is no reliable
// way to derive the original protobuf type information.
// We are unable to use the global enum registry since it is
// unfortunately keyed by the protobuf full name, which we also do not know.
// Thus, this produces some bogus enum descriptor based on the Go type name.
func aberrantLoadEnumDesc(t reflect.Type) protoreflect.EnumDescriptor {
// Fast-path: check if an EnumDescriptor is cached for this concrete type.
if ed, ok := aberrantEnumDescCache.Load(t); ok {
return ed.(protoreflect.EnumDescriptor)
}
// Slow-path: construct a bogus, but unique EnumDescriptor.
ed := &filedesc.Enum{L2: new(filedesc.EnumL2)}
ed.L0.FullName = AberrantDeriveFullName(t) // e.g., github_com.user.repo.MyEnum
ed.L0.ParentFile = filedesc.SurrogateProto3
ed.L1.EditionFeatures = ed.L0.ParentFile.L1.EditionFeatures
ed.L2.Values.List = append(ed.L2.Values.List, filedesc.EnumValue{})
// TODO: Use the presence of a UnmarshalJSON method to determine proto2?
vd := &ed.L2.Values.List[0]
vd.L0.FullName = ed.L0.FullName + "_UNKNOWN" // e.g., github_com.user.repo.MyEnum_UNKNOWN
vd.L0.ParentFile = ed.L0.ParentFile
vd.L0.Parent = ed
// TODO: We could use the String method to obtain some enum value names by
// starting at 0 and print the enum until it produces invalid identifiers.
// An exhaustive query is clearly impractical, but can be best-effort.
if ed, ok := aberrantEnumDescCache.LoadOrStore(t, ed); ok {
return ed.(protoreflect.EnumDescriptor)
}
return ed
}
// AberrantDeriveFullName derives a fully qualified protobuf name for the given Go type
// The provided name is not guaranteed to be stable nor universally unique.
// It should be sufficiently unique within a program.
//
// This is exported for testing purposes.
func AberrantDeriveFullName(t reflect.Type) protoreflect.FullName {
sanitize := func(r rune) rune {
switch {
case r == '/':
return '.'
case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
return r
default:
return '_'
}
}
prefix := strings.Map(sanitize, t.PkgPath())
suffix := strings.Map(sanitize, t.Name())
if suffix == "" {
suffix = fmt.Sprintf("UnknownX%X", reflect.ValueOf(t).Pointer())
}
ss := append(strings.Split(prefix, "."), suffix)
for i, s := range ss {
if s == "" || ('0' <= s[0] && s[0] <= '9') {
ss[i] = "x" + s
}
}
return protoreflect.FullName(strings.Join(ss, "."))
}

View File

@ -0,0 +1,92 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"encoding/binary"
"encoding/json"
"hash/crc32"
"math"
"reflect"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// These functions exist to support exported APIs in generated protobufs.
// While these are deprecated, they cannot be removed for compatibility reasons.
// LegacyEnumName returns the name of enums used in legacy code.
func (Export) LegacyEnumName(ed protoreflect.EnumDescriptor) string {
return legacyEnumName(ed)
}
// LegacyMessageTypeOf returns the protoreflect.MessageType for m,
// with name used as the message name if necessary.
func (Export) LegacyMessageTypeOf(m protoiface.MessageV1, name protoreflect.FullName) protoreflect.MessageType {
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Type()
}
return legacyLoadMessageType(reflect.TypeOf(m), name)
}
// UnmarshalJSONEnum unmarshals an enum from a JSON-encoded input.
// The input can either be a string representing the enum value by name,
// or a number representing the enum number itself.
func (Export) UnmarshalJSONEnum(ed protoreflect.EnumDescriptor, b []byte) (protoreflect.EnumNumber, error) {
if b[0] == '"' {
var name protoreflect.Name
if err := json.Unmarshal(b, &name); err != nil {
return 0, errors.New("invalid input for enum %v: %s", ed.FullName(), b)
}
ev := ed.Values().ByName(name)
if ev == nil {
return 0, errors.New("invalid value for enum %v: %s", ed.FullName(), name)
}
return ev.Number(), nil
} else {
var num protoreflect.EnumNumber
if err := json.Unmarshal(b, &num); err != nil {
return 0, errors.New("invalid input for enum %v: %s", ed.FullName(), b)
}
return num, nil
}
}
// CompressGZIP compresses the input as a GZIP-encoded file.
// The current implementation does no compression.
func (Export) CompressGZIP(in []byte) (out []byte) {
// RFC 1952, section 2.3.1.
var gzipHeader = [10]byte{0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff}
// RFC 1951, section 3.2.4.
var blockHeader [5]byte
const maxBlockSize = math.MaxUint16
numBlocks := 1 + len(in)/maxBlockSize
// RFC 1952, section 2.3.1.
var gzipFooter [8]byte
binary.LittleEndian.PutUint32(gzipFooter[0:4], crc32.ChecksumIEEE(in))
binary.LittleEndian.PutUint32(gzipFooter[4:8], uint32(len(in)))
// Encode the input without compression using raw DEFLATE blocks.
out = make([]byte, 0, len(gzipHeader)+len(blockHeader)*numBlocks+len(in)+len(gzipFooter))
out = append(out, gzipHeader[:]...)
for blockHeader[0] == 0 {
blockSize := maxBlockSize
if blockSize > len(in) {
blockHeader[0] = 0x01 // final bit per RFC 1951, section 3.2.3.
blockSize = len(in)
}
binary.LittleEndian.PutUint16(blockHeader[1:3], uint16(blockSize))
binary.LittleEndian.PutUint16(blockHeader[3:5], ^uint16(blockSize))
out = append(out, blockHeader[:]...)
out = append(out, in[:blockSize]...)
in = in[blockSize:]
}
out = append(out, gzipFooter[:]...)
return out
}

View File

@ -0,0 +1,177 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"google.golang.org/protobuf/internal/descopts"
"google.golang.org/protobuf/internal/encoding/messageset"
ptag "google.golang.org/protobuf/internal/encoding/tag"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
)
func (xi *ExtensionInfo) initToLegacy() {
xd := xi.desc
var parent protoiface.MessageV1
messageName := xd.ContainingMessage().FullName()
if mt, _ := protoregistry.GlobalTypes.FindMessageByName(messageName); mt != nil {
// Create a new parent message and unwrap it if possible.
mv := mt.New().Interface()
t := reflect.TypeOf(mv)
if mv, ok := mv.(unwrapper); ok {
t = reflect.TypeOf(mv.protoUnwrap())
}
// Check whether the message implements the legacy v1 Message interface.
mz := reflect.Zero(t).Interface()
if mz, ok := mz.(protoiface.MessageV1); ok {
parent = mz
}
}
// Determine the v1 extension type, which is unfortunately not the same as
// the v2 ExtensionType.GoType.
extType := xi.goType
switch extType.Kind() {
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
extType = reflect.PtrTo(extType) // T -> *T for singular scalar fields
}
// Reconstruct the legacy enum full name.
var enumName string
if xd.Kind() == protoreflect.EnumKind {
enumName = legacyEnumName(xd.Enum())
}
// Derive the proto file that the extension was declared within.
var filename string
if fd := xd.ParentFile(); fd != nil {
filename = fd.Path()
}
// For MessageSet extensions, the name used is the parent message.
name := xd.FullName()
if messageset.IsMessageSetExtension(xd) {
name = name.Parent()
}
xi.ExtendedType = parent
xi.ExtensionType = reflect.Zero(extType).Interface()
xi.Field = int32(xd.Number())
xi.Name = string(name)
xi.Tag = ptag.Marshal(xd, enumName)
xi.Filename = filename
}
// initFromLegacy initializes an ExtensionInfo from
// the contents of the deprecated exported fields of the type.
func (xi *ExtensionInfo) initFromLegacy() {
// The v1 API returns "type incomplete" descriptors where only the
// field number is specified. In such a case, use a placeholder.
if xi.ExtendedType == nil || xi.ExtensionType == nil {
xd := placeholderExtension{
name: protoreflect.FullName(xi.Name),
number: protoreflect.FieldNumber(xi.Field),
}
xi.desc = extensionTypeDescriptor{xd, xi}
return
}
// Resolve enum or message dependencies.
var ed protoreflect.EnumDescriptor
var md protoreflect.MessageDescriptor
t := reflect.TypeOf(xi.ExtensionType)
isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
if isOptional || isRepeated {
t = t.Elem()
}
switch v := reflect.Zero(t).Interface().(type) {
case protoreflect.Enum:
ed = v.Descriptor()
case enumV1:
ed = LegacyLoadEnumDesc(t)
case protoreflect.ProtoMessage:
md = v.ProtoReflect().Descriptor()
case messageV1:
md = LegacyLoadMessageDesc(t)
}
// Derive basic field information from the struct tag.
var evs protoreflect.EnumValueDescriptors
if ed != nil {
evs = ed.Values()
}
fd := ptag.Unmarshal(xi.Tag, t, evs).(*filedesc.Field)
// Construct a v2 ExtensionType.
xd := &filedesc.Extension{L2: new(filedesc.ExtensionL2)}
xd.L0.ParentFile = filedesc.SurrogateProto2
xd.L0.FullName = protoreflect.FullName(xi.Name)
xd.L1.Number = protoreflect.FieldNumber(xi.Field)
xd.L1.Cardinality = fd.L1.Cardinality
xd.L1.Kind = fd.L1.Kind
xd.L1.EditionFeatures = fd.L1.EditionFeatures
xd.L2.Default = fd.L1.Default
xd.L1.Extendee = Export{}.MessageDescriptorOf(xi.ExtendedType)
xd.L2.Enum = ed
xd.L2.Message = md
// Derive real extension field name for MessageSets.
if messageset.IsMessageSet(xd.L1.Extendee) && md.FullName() == xd.L0.FullName {
xd.L0.FullName = xd.L0.FullName.Append(messageset.ExtensionName)
}
tt := reflect.TypeOf(xi.ExtensionType)
if isOptional {
tt = tt.Elem()
}
xi.goType = tt
xi.desc = extensionTypeDescriptor{xd, xi}
}
type placeholderExtension struct {
name protoreflect.FullName
number protoreflect.FieldNumber
}
func (x placeholderExtension) ParentFile() protoreflect.FileDescriptor { return nil }
func (x placeholderExtension) Parent() protoreflect.Descriptor { return nil }
func (x placeholderExtension) Index() int { return 0 }
func (x placeholderExtension) Syntax() protoreflect.Syntax { return 0 }
func (x placeholderExtension) Name() protoreflect.Name { return x.name.Name() }
func (x placeholderExtension) FullName() protoreflect.FullName { return x.name }
func (x placeholderExtension) IsPlaceholder() bool { return true }
func (x placeholderExtension) Options() protoreflect.ProtoMessage { return descopts.Field }
func (x placeholderExtension) Number() protoreflect.FieldNumber { return x.number }
func (x placeholderExtension) Cardinality() protoreflect.Cardinality { return 0 }
func (x placeholderExtension) Kind() protoreflect.Kind { return 0 }
func (x placeholderExtension) HasJSONName() bool { return false }
func (x placeholderExtension) JSONName() string { return "[" + string(x.name) + "]" }
func (x placeholderExtension) TextName() string { return "[" + string(x.name) + "]" }
func (x placeholderExtension) HasPresence() bool { return false }
func (x placeholderExtension) HasOptionalKeyword() bool { return false }
func (x placeholderExtension) IsExtension() bool { return true }
func (x placeholderExtension) IsWeak() bool { return false }
func (x placeholderExtension) IsLazy() bool { return false }
func (x placeholderExtension) IsPacked() bool { return false }
func (x placeholderExtension) IsList() bool { return false }
func (x placeholderExtension) IsMap() bool { return false }
func (x placeholderExtension) MapKey() protoreflect.FieldDescriptor { return nil }
func (x placeholderExtension) MapValue() protoreflect.FieldDescriptor { return nil }
func (x placeholderExtension) HasDefault() bool { return false }
func (x placeholderExtension) Default() protoreflect.Value { return protoreflect.Value{} }
func (x placeholderExtension) DefaultEnumValue() protoreflect.EnumValueDescriptor { return nil }
func (x placeholderExtension) ContainingOneof() protoreflect.OneofDescriptor { return nil }
func (x placeholderExtension) ContainingMessage() protoreflect.MessageDescriptor { return nil }
func (x placeholderExtension) Enum() protoreflect.EnumDescriptor { return nil }
func (x placeholderExtension) Message() protoreflect.MessageDescriptor { return nil }
func (x placeholderExtension) ProtoType(protoreflect.FieldDescriptor) { return }
func (x placeholderExtension) ProtoInternal(pragma.DoNotImplement) { return }

View File

@ -0,0 +1,81 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"bytes"
"compress/gzip"
"io"
"sync"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// Every enum and message type generated by protoc-gen-go since commit 2fc053c5
// on February 25th, 2016 has had a method to get the raw descriptor.
// Types that were not generated by protoc-gen-go or were generated prior
// to that version are not supported.
//
// The []byte returned is the encoded form of a FileDescriptorProto message
// compressed using GZIP. The []int is the path from the top-level file
// to the specific message or enum declaration.
type (
enumV1 interface {
EnumDescriptor() ([]byte, []int)
}
messageV1 interface {
Descriptor() ([]byte, []int)
}
)
var legacyFileDescCache sync.Map // map[*byte]protoreflect.FileDescriptor
// legacyLoadFileDesc unmarshals b as a compressed FileDescriptorProto message.
//
// This assumes that b is immutable and that b does not refer to part of a
// concatenated series of GZIP files (which would require shenanigans that
// rely on the concatenation properties of both protobufs and GZIP).
// File descriptors generated by protoc-gen-go do not rely on that property.
func legacyLoadFileDesc(b []byte) protoreflect.FileDescriptor {
// Fast-path: check whether we already have a cached file descriptor.
if fd, ok := legacyFileDescCache.Load(&b[0]); ok {
return fd.(protoreflect.FileDescriptor)
}
// Slow-path: decompress and unmarshal the file descriptor proto.
zr, err := gzip.NewReader(bytes.NewReader(b))
if err != nil {
panic(err)
}
b2, err := io.ReadAll(zr)
if err != nil {
panic(err)
}
fd := filedesc.Builder{
RawDescriptor: b2,
FileRegistry: resolverOnly{protoregistry.GlobalFiles}, // do not register back to global registry
}.Build().File
if fd, ok := legacyFileDescCache.LoadOrStore(&b[0], fd); ok {
return fd.(protoreflect.FileDescriptor)
}
return fd
}
type resolverOnly struct {
reg *protoregistry.Files
}
func (r resolverOnly) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
return r.reg.FindFileByPath(path)
}
func (r resolverOnly) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
return r.reg.FindDescriptorByName(name)
}
func (resolverOnly) RegisterFile(protoreflect.FileDescriptor) error {
return nil
}

View File

@ -0,0 +1,569 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"strings"
"sync"
"google.golang.org/protobuf/internal/descopts"
ptag "google.golang.org/protobuf/internal/encoding/tag"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// legacyWrapMessage wraps v as a protoreflect.Message,
// where v must be a *struct kind and not implement the v2 API already.
func legacyWrapMessage(v reflect.Value) protoreflect.Message {
t := v.Type()
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
return aberrantMessage{v: v}
}
mt := legacyLoadMessageInfo(t, "")
return mt.MessageOf(v.Interface())
}
// legacyLoadMessageType dynamically loads a protoreflect.Type for t,
// where t must be not implement the v2 API already.
// The provided name is used if it cannot be determined from the message.
func legacyLoadMessageType(t reflect.Type, name protoreflect.FullName) protoreflect.MessageType {
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
return aberrantMessageType{t}
}
return legacyLoadMessageInfo(t, name)
}
var legacyMessageTypeCache sync.Map // map[reflect.Type]*MessageInfo
// legacyLoadMessageInfo dynamically loads a *MessageInfo for t,
// where t must be a *struct kind and not implement the v2 API already.
// The provided name is used if it cannot be determined from the message.
func legacyLoadMessageInfo(t reflect.Type, name protoreflect.FullName) *MessageInfo {
// Fast-path: check if a MessageInfo is cached for this concrete type.
if mt, ok := legacyMessageTypeCache.Load(t); ok {
return mt.(*MessageInfo)
}
// Slow-path: derive message descriptor and initialize MessageInfo.
mi := &MessageInfo{
Desc: legacyLoadMessageDesc(t, name),
GoReflectType: t,
}
var hasMarshal, hasUnmarshal bool
v := reflect.Zero(t).Interface()
if _, hasMarshal = v.(legacyMarshaler); hasMarshal {
mi.methods.Marshal = legacyMarshal
// We have no way to tell whether the type's Marshal method
// supports deterministic serialization or not, but this
// preserves the v1 implementation's behavior of always
// calling Marshal methods when present.
mi.methods.Flags |= protoiface.SupportMarshalDeterministic
}
if _, hasUnmarshal = v.(legacyUnmarshaler); hasUnmarshal {
mi.methods.Unmarshal = legacyUnmarshal
}
if _, hasMerge := v.(legacyMerger); hasMerge || (hasMarshal && hasUnmarshal) {
mi.methods.Merge = legacyMerge
}
if mi, ok := legacyMessageTypeCache.LoadOrStore(t, mi); ok {
return mi.(*MessageInfo)
}
return mi
}
var legacyMessageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
// LegacyLoadMessageDesc returns an MessageDescriptor derived from the Go type,
// which should be a *struct kind and must not implement the v2 API already.
//
// This is exported for testing purposes.
func LegacyLoadMessageDesc(t reflect.Type) protoreflect.MessageDescriptor {
return legacyLoadMessageDesc(t, "")
}
func legacyLoadMessageDesc(t reflect.Type, name protoreflect.FullName) protoreflect.MessageDescriptor {
// Fast-path: check if a MessageDescriptor is cached for this concrete type.
if mi, ok := legacyMessageDescCache.Load(t); ok {
return mi.(protoreflect.MessageDescriptor)
}
// Slow-path: initialize MessageDescriptor from the raw descriptor.
mv := reflect.Zero(t).Interface()
if _, ok := mv.(protoreflect.ProtoMessage); ok {
panic(fmt.Sprintf("%v already implements proto.Message", t))
}
mdV1, ok := mv.(messageV1)
if !ok {
return aberrantLoadMessageDesc(t, name)
}
// If this is a dynamic message type where there isn't a 1-1 mapping between
// Go and protobuf types, calling the Descriptor method on the zero value of
// the message type isn't likely to work. If it panics, swallow the panic and
// continue as if the Descriptor method wasn't present.
b, idxs := func() ([]byte, []int) {
defer func() {
recover()
}()
return mdV1.Descriptor()
}()
if b == nil {
return aberrantLoadMessageDesc(t, name)
}
// If the Go type has no fields, then this might be a proto3 empty message
// from before the size cache was added. If there are any fields, check to
// see that at least one of them looks like something we generated.
if t.Elem().Kind() == reflect.Struct {
if nfield := t.Elem().NumField(); nfield > 0 {
hasProtoField := false
for i := 0; i < nfield; i++ {
f := t.Elem().Field(i)
if f.Tag.Get("protobuf") != "" || f.Tag.Get("protobuf_oneof") != "" || strings.HasPrefix(f.Name, "XXX_") {
hasProtoField = true
break
}
}
if !hasProtoField {
return aberrantLoadMessageDesc(t, name)
}
}
}
md := legacyLoadFileDesc(b).Messages().Get(idxs[0])
for _, i := range idxs[1:] {
md = md.Messages().Get(i)
}
if name != "" && md.FullName() != name {
panic(fmt.Sprintf("mismatching message name: got %v, want %v", md.FullName(), name))
}
if md, ok := legacyMessageDescCache.LoadOrStore(t, md); ok {
return md.(protoreflect.MessageDescriptor)
}
return md
}
var (
aberrantMessageDescLock sync.Mutex
aberrantMessageDescCache map[reflect.Type]protoreflect.MessageDescriptor
)
// aberrantLoadMessageDesc returns an MessageDescriptor derived from the Go type,
// which must not implement protoreflect.ProtoMessage or messageV1.
//
// This is a best-effort derivation of the message descriptor using the protobuf
// tags on the struct fields.
func aberrantLoadMessageDesc(t reflect.Type, name protoreflect.FullName) protoreflect.MessageDescriptor {
aberrantMessageDescLock.Lock()
defer aberrantMessageDescLock.Unlock()
if aberrantMessageDescCache == nil {
aberrantMessageDescCache = make(map[reflect.Type]protoreflect.MessageDescriptor)
}
return aberrantLoadMessageDescReentrant(t, name)
}
func aberrantLoadMessageDescReentrant(t reflect.Type, name protoreflect.FullName) protoreflect.MessageDescriptor {
// Fast-path: check if an MessageDescriptor is cached for this concrete type.
if md, ok := aberrantMessageDescCache[t]; ok {
return md
}
// Slow-path: construct a descriptor from the Go struct type (best-effort).
// Cache the MessageDescriptor early on so that we can resolve internal
// cyclic references.
md := &filedesc.Message{L2: new(filedesc.MessageL2)}
md.L0.FullName = aberrantDeriveMessageName(t, name)
md.L0.ParentFile = filedesc.SurrogateProto2
aberrantMessageDescCache[t] = md
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
return md
}
// Try to determine if the message is using proto3 by checking scalars.
for i := 0; i < t.Elem().NumField(); i++ {
f := t.Elem().Field(i)
if tag := f.Tag.Get("protobuf"); tag != "" {
switch f.Type.Kind() {
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
md.L0.ParentFile = filedesc.SurrogateProto3
}
for _, s := range strings.Split(tag, ",") {
if s == "proto3" {
md.L0.ParentFile = filedesc.SurrogateProto3
}
}
}
}
md.L1.EditionFeatures = md.L0.ParentFile.L1.EditionFeatures
// Obtain a list of oneof wrapper types.
var oneofWrappers []reflect.Type
methods := make([]reflect.Method, 0, 2)
if m, ok := t.MethodByName("XXX_OneofFuncs"); ok {
methods = append(methods, m)
}
if m, ok := t.MethodByName("XXX_OneofWrappers"); ok {
methods = append(methods, m)
}
for _, fn := range methods {
for _, v := range fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))}) {
if vs, ok := v.Interface().([]any); ok {
for _, v := range vs {
oneofWrappers = append(oneofWrappers, reflect.TypeOf(v))
}
}
}
}
// Obtain a list of the extension ranges.
if fn, ok := t.MethodByName("ExtensionRangeArray"); ok {
vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[0]
for i := 0; i < vs.Len(); i++ {
v := vs.Index(i)
md.L2.ExtensionRanges.List = append(md.L2.ExtensionRanges.List, [2]protoreflect.FieldNumber{
protoreflect.FieldNumber(v.FieldByName("Start").Int()),
protoreflect.FieldNumber(v.FieldByName("End").Int() + 1),
})
md.L2.ExtensionRangeOptions = append(md.L2.ExtensionRangeOptions, nil)
}
}
// Derive the message fields by inspecting the struct fields.
for i := 0; i < t.Elem().NumField(); i++ {
f := t.Elem().Field(i)
if tag := f.Tag.Get("protobuf"); tag != "" {
tagKey := f.Tag.Get("protobuf_key")
tagVal := f.Tag.Get("protobuf_val")
aberrantAppendField(md, f.Type, tag, tagKey, tagVal)
}
if tag := f.Tag.Get("protobuf_oneof"); tag != "" {
n := len(md.L2.Oneofs.List)
md.L2.Oneofs.List = append(md.L2.Oneofs.List, filedesc.Oneof{})
od := &md.L2.Oneofs.List[n]
od.L0.FullName = md.FullName().Append(protoreflect.Name(tag))
od.L0.ParentFile = md.L0.ParentFile
od.L1.EditionFeatures = md.L1.EditionFeatures
od.L0.Parent = md
od.L0.Index = n
for _, t := range oneofWrappers {
if t.Implements(f.Type) {
f := t.Elem().Field(0)
if tag := f.Tag.Get("protobuf"); tag != "" {
aberrantAppendField(md, f.Type, tag, "", "")
fd := &md.L2.Fields.List[len(md.L2.Fields.List)-1]
fd.L1.ContainingOneof = od
fd.L1.EditionFeatures = od.L1.EditionFeatures
od.L1.Fields.List = append(od.L1.Fields.List, fd)
}
}
}
}
}
return md
}
func aberrantDeriveMessageName(t reflect.Type, name protoreflect.FullName) protoreflect.FullName {
if name.IsValid() {
return name
}
func() {
defer func() { recover() }() // swallow possible nil panics
if m, ok := reflect.Zero(t).Interface().(interface{ XXX_MessageName() string }); ok {
name = protoreflect.FullName(m.XXX_MessageName())
}
}()
if name.IsValid() {
return name
}
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return AberrantDeriveFullName(t)
}
func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey, tagVal string) {
t := goType
isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
if isOptional || isRepeated {
t = t.Elem()
}
fd := ptag.Unmarshal(tag, t, placeholderEnumValues{}).(*filedesc.Field)
// Append field descriptor to the message.
n := len(md.L2.Fields.List)
md.L2.Fields.List = append(md.L2.Fields.List, *fd)
fd = &md.L2.Fields.List[n]
fd.L0.FullName = md.FullName().Append(fd.Name())
fd.L0.ParentFile = md.L0.ParentFile
fd.L0.Parent = md
fd.L0.Index = n
if fd.L1.EditionFeatures.IsPacked {
fd.L1.Options = func() protoreflect.ProtoMessage {
opts := descopts.Field.ProtoReflect().New()
if fd.L1.EditionFeatures.IsPacked {
opts.Set(opts.Descriptor().Fields().ByName("packed"), protoreflect.ValueOfBool(fd.L1.EditionFeatures.IsPacked))
}
return opts.Interface()
}
}
// Populate Enum and Message.
if fd.Enum() == nil && fd.Kind() == protoreflect.EnumKind {
switch v := reflect.Zero(t).Interface().(type) {
case protoreflect.Enum:
fd.L1.Enum = v.Descriptor()
default:
fd.L1.Enum = LegacyLoadEnumDesc(t)
}
}
if fd.Message() == nil && (fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind) {
switch v := reflect.Zero(t).Interface().(type) {
case protoreflect.ProtoMessage:
fd.L1.Message = v.ProtoReflect().Descriptor()
case messageV1:
fd.L1.Message = LegacyLoadMessageDesc(t)
default:
if t.Kind() == reflect.Map {
n := len(md.L1.Messages.List)
md.L1.Messages.List = append(md.L1.Messages.List, filedesc.Message{L2: new(filedesc.MessageL2)})
md2 := &md.L1.Messages.List[n]
md2.L0.FullName = md.FullName().Append(protoreflect.Name(strs.MapEntryName(string(fd.Name()))))
md2.L0.ParentFile = md.L0.ParentFile
md2.L0.Parent = md
md2.L0.Index = n
md2.L1.EditionFeatures = md.L1.EditionFeatures
md2.L1.IsMapEntry = true
md2.L2.Options = func() protoreflect.ProtoMessage {
opts := descopts.Message.ProtoReflect().New()
opts.Set(opts.Descriptor().Fields().ByName("map_entry"), protoreflect.ValueOfBool(true))
return opts.Interface()
}
aberrantAppendField(md2, t.Key(), tagKey, "", "")
aberrantAppendField(md2, t.Elem(), tagVal, "", "")
fd.L1.Message = md2
break
}
fd.L1.Message = aberrantLoadMessageDescReentrant(t, "")
}
}
}
type placeholderEnumValues struct {
protoreflect.EnumValueDescriptors
}
func (placeholderEnumValues) ByNumber(n protoreflect.EnumNumber) protoreflect.EnumValueDescriptor {
return filedesc.PlaceholderEnumValue(protoreflect.FullName(fmt.Sprintf("UNKNOWN_%d", n)))
}
// legacyMarshaler is the proto.Marshaler interface superseded by protoiface.Methoder.
type legacyMarshaler interface {
Marshal() ([]byte, error)
}
// legacyUnmarshaler is the proto.Unmarshaler interface superseded by protoiface.Methoder.
type legacyUnmarshaler interface {
Unmarshal([]byte) error
}
// legacyMerger is the proto.Merger interface superseded by protoiface.Methoder.
type legacyMerger interface {
Merge(protoiface.MessageV1)
}
var aberrantProtoMethods = &protoiface.Methods{
Marshal: legacyMarshal,
Unmarshal: legacyUnmarshal,
Merge: legacyMerge,
// We have no way to tell whether the type's Marshal method
// supports deterministic serialization or not, but this
// preserves the v1 implementation's behavior of always
// calling Marshal methods when present.
Flags: protoiface.SupportMarshalDeterministic,
}
func legacyMarshal(in protoiface.MarshalInput) (protoiface.MarshalOutput, error) {
v := in.Message.(unwrapper).protoUnwrap()
marshaler, ok := v.(legacyMarshaler)
if !ok {
return protoiface.MarshalOutput{}, errors.New("%T does not implement Marshal", v)
}
out, err := marshaler.Marshal()
if in.Buf != nil {
out = append(in.Buf, out...)
}
return protoiface.MarshalOutput{
Buf: out,
}, err
}
func legacyUnmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
v := in.Message.(unwrapper).protoUnwrap()
unmarshaler, ok := v.(legacyUnmarshaler)
if !ok {
return protoiface.UnmarshalOutput{}, errors.New("%T does not implement Unmarshal", v)
}
return protoiface.UnmarshalOutput{}, unmarshaler.Unmarshal(in.Buf)
}
func legacyMerge(in protoiface.MergeInput) protoiface.MergeOutput {
// Check whether this supports the legacy merger.
dstv := in.Destination.(unwrapper).protoUnwrap()
merger, ok := dstv.(legacyMerger)
if ok {
merger.Merge(Export{}.ProtoMessageV1Of(in.Source))
return protoiface.MergeOutput{Flags: protoiface.MergeComplete}
}
// If legacy merger is unavailable, implement merge in terms of
// a marshal and unmarshal operation.
srcv := in.Source.(unwrapper).protoUnwrap()
marshaler, ok := srcv.(legacyMarshaler)
if !ok {
return protoiface.MergeOutput{}
}
dstv = in.Destination.(unwrapper).protoUnwrap()
unmarshaler, ok := dstv.(legacyUnmarshaler)
if !ok {
return protoiface.MergeOutput{}
}
if !in.Source.IsValid() {
// Legacy Marshal methods may not function on nil messages.
// Check for a typed nil source only after we confirm that
// legacy Marshal/Unmarshal methods are present, for
// consistency.
return protoiface.MergeOutput{Flags: protoiface.MergeComplete}
}
b, err := marshaler.Marshal()
if err != nil {
return protoiface.MergeOutput{}
}
err = unmarshaler.Unmarshal(b)
if err != nil {
return protoiface.MergeOutput{}
}
return protoiface.MergeOutput{Flags: protoiface.MergeComplete}
}
// aberrantMessageType implements MessageType for all types other than pointer-to-struct.
type aberrantMessageType struct {
t reflect.Type
}
func (mt aberrantMessageType) New() protoreflect.Message {
if mt.t.Kind() == reflect.Ptr {
return aberrantMessage{reflect.New(mt.t.Elem())}
}
return aberrantMessage{reflect.Zero(mt.t)}
}
func (mt aberrantMessageType) Zero() protoreflect.Message {
return aberrantMessage{reflect.Zero(mt.t)}
}
func (mt aberrantMessageType) GoType() reflect.Type {
return mt.t
}
func (mt aberrantMessageType) Descriptor() protoreflect.MessageDescriptor {
return LegacyLoadMessageDesc(mt.t)
}
// aberrantMessage implements Message for all types other than pointer-to-struct.
//
// When the underlying type implements legacyMarshaler or legacyUnmarshaler,
// the aberrant Message can be marshaled or unmarshaled. Otherwise, there is
// not much that can be done with values of this type.
type aberrantMessage struct {
v reflect.Value
}
// Reset implements the v1 proto.Message.Reset method.
func (m aberrantMessage) Reset() {
if mr, ok := m.v.Interface().(interface{ Reset() }); ok {
mr.Reset()
return
}
if m.v.Kind() == reflect.Ptr && !m.v.IsNil() {
m.v.Elem().Set(reflect.Zero(m.v.Type().Elem()))
}
}
func (m aberrantMessage) ProtoReflect() protoreflect.Message {
return m
}
func (m aberrantMessage) Descriptor() protoreflect.MessageDescriptor {
return LegacyLoadMessageDesc(m.v.Type())
}
func (m aberrantMessage) Type() protoreflect.MessageType {
return aberrantMessageType{m.v.Type()}
}
func (m aberrantMessage) New() protoreflect.Message {
if m.v.Type().Kind() == reflect.Ptr {
return aberrantMessage{reflect.New(m.v.Type().Elem())}
}
return aberrantMessage{reflect.Zero(m.v.Type())}
}
func (m aberrantMessage) Interface() protoreflect.ProtoMessage {
return m
}
func (m aberrantMessage) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
return
}
func (m aberrantMessage) Has(protoreflect.FieldDescriptor) bool {
return false
}
func (m aberrantMessage) Clear(protoreflect.FieldDescriptor) {
panic("invalid Message.Clear on " + string(m.Descriptor().FullName()))
}
func (m aberrantMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
if fd.Default().IsValid() {
return fd.Default()
}
panic("invalid Message.Get on " + string(m.Descriptor().FullName()))
}
func (m aberrantMessage) Set(protoreflect.FieldDescriptor, protoreflect.Value) {
panic("invalid Message.Set on " + string(m.Descriptor().FullName()))
}
func (m aberrantMessage) Mutable(protoreflect.FieldDescriptor) protoreflect.Value {
panic("invalid Message.Mutable on " + string(m.Descriptor().FullName()))
}
func (m aberrantMessage) NewField(protoreflect.FieldDescriptor) protoreflect.Value {
panic("invalid Message.NewField on " + string(m.Descriptor().FullName()))
}
func (m aberrantMessage) WhichOneof(protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
panic("invalid Message.WhichOneof descriptor on " + string(m.Descriptor().FullName()))
}
func (m aberrantMessage) GetUnknown() protoreflect.RawFields {
return nil
}
func (m aberrantMessage) SetUnknown(protoreflect.RawFields) {
// SetUnknown discards its input on messages which don't support unknown field storage.
}
func (m aberrantMessage) IsValid() bool {
if m.v.Kind() == reflect.Ptr {
return !m.v.IsNil()
}
return false
}
func (m aberrantMessage) ProtoMethods() *protoiface.Methods {
return aberrantProtoMethods
}
func (m aberrantMessage) protoUnwrap() any {
return m.v.Interface()
}

View File

@ -0,0 +1,203 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
type mergeOptions struct{}
func (o mergeOptions) Merge(dst, src proto.Message) {
proto.Merge(dst, src)
}
// merge is protoreflect.Methods.Merge.
func (mi *MessageInfo) merge(in protoiface.MergeInput) protoiface.MergeOutput {
dp, ok := mi.getPointer(in.Destination)
if !ok {
return protoiface.MergeOutput{}
}
sp, ok := mi.getPointer(in.Source)
if !ok {
return protoiface.MergeOutput{}
}
mi.mergePointer(dp, sp, mergeOptions{})
return protoiface.MergeOutput{Flags: protoiface.MergeComplete}
}
func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
mi.init()
if dst.IsNil() {
panic(fmt.Sprintf("invalid value: merging into nil message"))
}
if src.IsNil() {
return
}
var presenceSrc presence
var presenceDst presence
if mi.presenceOffset.IsValid() {
presenceSrc = src.Apply(mi.presenceOffset).PresenceInfo()
presenceDst = dst.Apply(mi.presenceOffset).PresenceInfo()
}
for _, f := range mi.orderedCoderFields {
if f.funcs.merge == nil {
continue
}
sfptr := src.Apply(f.offset)
if f.presenceIndex != noPresence {
if !presenceSrc.Present(f.presenceIndex) {
continue
}
dfptr := dst.Apply(f.offset)
if f.isLazy {
if sfptr.AtomicGetPointer().IsNil() {
mi.lazyUnmarshal(src, f.num)
}
if presenceDst.Present(f.presenceIndex) && dfptr.AtomicGetPointer().IsNil() {
mi.lazyUnmarshal(dst, f.num)
}
}
f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts)
presenceDst.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
continue
}
if f.isPointer && sfptr.Elem().IsNil() {
continue
}
f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts)
}
if mi.extensionOffset.IsValid() {
sext := src.Apply(mi.extensionOffset).Extensions()
dext := dst.Apply(mi.extensionOffset).Extensions()
if *dext == nil {
*dext = make(map[int32]ExtensionField)
}
for num, sx := range *sext {
xt := sx.Type()
xi := getExtensionFieldInfo(xt)
if xi.funcs.merge == nil {
continue
}
dx := (*dext)[num]
var dv protoreflect.Value
if dx.Type() == sx.Type() {
dv = dx.Value()
}
if !dv.IsValid() && xi.unmarshalNeedsValue {
dv = xt.New()
}
dv = xi.funcs.merge(dv, sx.Value(), opts)
dx.Set(sx.Type(), dv)
(*dext)[num] = dx
}
}
if mi.unknownOffset.IsValid() {
su := mi.getUnknownBytes(src)
if su != nil && len(*su) > 0 {
du := mi.mutableUnknownBytes(dst)
*du = append(*du, *su...)
}
}
}
func mergeScalarValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
return src
}
func mergeBytesValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
return protoreflect.ValueOfBytes(append(emptyBuf[:], src.Bytes()...))
}
func mergeListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
dstl := dst.List()
srcl := src.List()
for i, llen := 0, srcl.Len(); i < llen; i++ {
dstl.Append(srcl.Get(i))
}
return dst
}
func mergeBytesListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
dstl := dst.List()
srcl := src.List()
for i, llen := 0, srcl.Len(); i < llen; i++ {
sb := srcl.Get(i).Bytes()
db := append(emptyBuf[:], sb...)
dstl.Append(protoreflect.ValueOfBytes(db))
}
return dst
}
func mergeMessageListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
dstl := dst.List()
srcl := src.List()
for i, llen := 0, srcl.Len(); i < llen; i++ {
sm := srcl.Get(i).Message()
dm := proto.Clone(sm.Interface()).ProtoReflect()
dstl.Append(protoreflect.ValueOfMessage(dm))
}
return dst
}
func mergeMessageValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
opts.Merge(dst.Message().Interface(), src.Message().Interface())
return dst
}
func mergeMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
if f.mi != nil {
if dst.Elem().IsNil() {
dst.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
}
f.mi.mergePointer(dst.Elem(), src.Elem(), opts)
} else {
dm := dst.AsValueOf(f.ft).Elem()
sm := src.AsValueOf(f.ft).Elem()
if dm.IsNil() {
dm.Set(reflect.New(f.ft.Elem()))
}
opts.Merge(asMessage(dm), asMessage(sm))
}
}
func mergeMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
for _, sp := range src.PointerSlice() {
dm := reflect.New(f.ft.Elem().Elem())
if f.mi != nil {
f.mi.mergePointer(pointerOfValue(dm), sp, opts)
} else {
opts.Merge(asMessage(dm), asMessage(sp.AsValueOf(f.ft.Elem().Elem())))
}
dst.AppendPointerSlice(pointerOfValue(dm))
}
}
func mergeBytes(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Bytes() = append(emptyBuf[:], *src.Bytes()...)
}
func mergeBytesNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Bytes()
if len(v) > 0 {
*dst.Bytes() = append(emptyBuf[:], v...)
}
}
func mergeBytesSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.BytesSlice()
for _, v := range *src.BytesSlice() {
*ds = append(*ds, append(emptyBuf[:], v...))
}
}

View File

@ -0,0 +1,209 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate-types. DO NOT EDIT.
package impl
import ()
func mergeBool(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Bool() = *src.Bool()
}
func mergeBoolNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Bool()
if v != false {
*dst.Bool() = v
}
}
func mergeBoolPtr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.BoolPtr()
if p != nil {
v := *p
*dst.BoolPtr() = &v
}
}
func mergeBoolSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.BoolSlice()
ss := src.BoolSlice()
*ds = append(*ds, *ss...)
}
func mergeInt32(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Int32() = *src.Int32()
}
func mergeInt32NoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Int32()
if v != 0 {
*dst.Int32() = v
}
}
func mergeInt32Ptr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.Int32Ptr()
if p != nil {
v := *p
*dst.Int32Ptr() = &v
}
}
func mergeInt32Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.Int32Slice()
ss := src.Int32Slice()
*ds = append(*ds, *ss...)
}
func mergeUint32(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Uint32() = *src.Uint32()
}
func mergeUint32NoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Uint32()
if v != 0 {
*dst.Uint32() = v
}
}
func mergeUint32Ptr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.Uint32Ptr()
if p != nil {
v := *p
*dst.Uint32Ptr() = &v
}
}
func mergeUint32Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.Uint32Slice()
ss := src.Uint32Slice()
*ds = append(*ds, *ss...)
}
func mergeInt64(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Int64() = *src.Int64()
}
func mergeInt64NoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Int64()
if v != 0 {
*dst.Int64() = v
}
}
func mergeInt64Ptr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.Int64Ptr()
if p != nil {
v := *p
*dst.Int64Ptr() = &v
}
}
func mergeInt64Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.Int64Slice()
ss := src.Int64Slice()
*ds = append(*ds, *ss...)
}
func mergeUint64(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Uint64() = *src.Uint64()
}
func mergeUint64NoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Uint64()
if v != 0 {
*dst.Uint64() = v
}
}
func mergeUint64Ptr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.Uint64Ptr()
if p != nil {
v := *p
*dst.Uint64Ptr() = &v
}
}
func mergeUint64Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.Uint64Slice()
ss := src.Uint64Slice()
*ds = append(*ds, *ss...)
}
func mergeFloat32(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Float32() = *src.Float32()
}
func mergeFloat32NoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Float32()
if v != 0 {
*dst.Float32() = v
}
}
func mergeFloat32Ptr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.Float32Ptr()
if p != nil {
v := *p
*dst.Float32Ptr() = &v
}
}
func mergeFloat32Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.Float32Slice()
ss := src.Float32Slice()
*ds = append(*ds, *ss...)
}
func mergeFloat64(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.Float64() = *src.Float64()
}
func mergeFloat64NoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.Float64()
if v != 0 {
*dst.Float64() = v
}
}
func mergeFloat64Ptr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.Float64Ptr()
if p != nil {
v := *p
*dst.Float64Ptr() = &v
}
}
func mergeFloat64Slice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.Float64Slice()
ss := src.Float64Slice()
*ds = append(*ds, *ss...)
}
func mergeString(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
*dst.String() = *src.String()
}
func mergeStringNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
v := *src.String()
if v != "" {
*dst.String() = v
}
}
func mergeStringPtr(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
p := *src.StringPtr()
if p != nil {
v := *p
*dst.StringPtr() = &v
}
}
func mergeStringSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
ds := dst.StringSlice()
ss := src.StringSlice()
*ds = append(*ds, *ss...)
}

View File

@ -0,0 +1,283 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect"
)
// MessageInfo provides protobuf related functionality for a given Go type
// that represents a message. A given instance of MessageInfo is tied to
// exactly one Go type, which must be a pointer to a struct type.
//
// The exported fields must be populated before any methods are called
// and cannot be mutated after set.
type MessageInfo struct {
// GoReflectType is the underlying message Go type and must be populated.
GoReflectType reflect.Type // pointer to struct
// Desc is the underlying message descriptor type and must be populated.
Desc protoreflect.MessageDescriptor
// Deprecated: Exporter will be removed the next time we bump
// protoimpl.GenVersion. See https://github.com/golang/protobuf/issues/1640
Exporter exporter
// OneofWrappers is list of pointers to oneof wrapper struct types.
OneofWrappers []any
initMu sync.Mutex // protects all unexported fields
initDone uint32
reflectMessageInfo // for reflection implementation
coderMessageInfo // for fast-path method implementations
}
// exporter is a function that returns a reference to the ith field of v,
// where v is a pointer to a struct. It returns nil if it does not support
// exporting the requested field (e.g., already exported).
type exporter func(v any, i int) any
// getMessageInfo returns the MessageInfo for any message type that
// is generated by our implementation of protoc-gen-go (for v2 and on).
// If it is unable to obtain a MessageInfo, it returns nil.
func getMessageInfo(mt reflect.Type) *MessageInfo {
m, ok := reflect.Zero(mt).Interface().(protoreflect.ProtoMessage)
if !ok {
return nil
}
mr, ok := m.ProtoReflect().(interface{ ProtoMessageInfo() *MessageInfo })
if !ok {
return nil
}
return mr.ProtoMessageInfo()
}
func (mi *MessageInfo) init() {
// This function is called in the hot path. Inline the sync.Once logic,
// since allocating a closure for Once.Do is expensive.
// Keep init small to ensure that it can be inlined.
if atomic.LoadUint32(&mi.initDone) == 0 {
mi.initOnce()
}
}
func (mi *MessageInfo) initOnce() {
mi.initMu.Lock()
defer mi.initMu.Unlock()
if mi.initDone == 1 {
return
}
if opaqueInitHook(mi) {
return
}
t := mi.GoReflectType
if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
panic(fmt.Sprintf("got %v, want *struct kind", t))
}
t = t.Elem()
si := mi.makeStructInfo(t)
mi.makeReflectFuncs(t, si)
mi.makeCoderMethods(t, si)
atomic.StoreUint32(&mi.initDone, 1)
}
// getPointer returns the pointer for a message, which should be of
// the type of the MessageInfo. If the message is of a different type,
// it returns ok==false.
func (mi *MessageInfo) getPointer(m protoreflect.Message) (p pointer, ok bool) {
switch m := m.(type) {
case *messageState:
return m.pointer(), m.messageInfo() == mi
case *messageReflectWrapper:
return m.pointer(), m.messageInfo() == mi
}
return pointer{}, false
}
type (
SizeCache = int32
WeakFields = map[int32]protoreflect.ProtoMessage
UnknownFields = unknownFieldsA // TODO: switch to unknownFieldsB
unknownFieldsA = []byte
unknownFieldsB = *[]byte
ExtensionFields = map[int32]ExtensionField
)
var (
sizecacheType = reflect.TypeOf(SizeCache(0))
unknownFieldsAType = reflect.TypeOf(unknownFieldsA(nil))
unknownFieldsBType = reflect.TypeOf(unknownFieldsB(nil))
extensionFieldsType = reflect.TypeOf(ExtensionFields(nil))
)
type structInfo struct {
sizecacheOffset offset
sizecacheType reflect.Type
unknownOffset offset
unknownType reflect.Type
extensionOffset offset
extensionType reflect.Type
lazyOffset offset
presenceOffset offset
fieldsByNumber map[protoreflect.FieldNumber]reflect.StructField
oneofsByName map[protoreflect.Name]reflect.StructField
oneofWrappersByType map[reflect.Type]protoreflect.FieldNumber
oneofWrappersByNumber map[protoreflect.FieldNumber]reflect.Type
}
func (mi *MessageInfo) makeStructInfo(t reflect.Type) structInfo {
si := structInfo{
sizecacheOffset: invalidOffset,
unknownOffset: invalidOffset,
extensionOffset: invalidOffset,
lazyOffset: invalidOffset,
presenceOffset: invalidOffset,
fieldsByNumber: map[protoreflect.FieldNumber]reflect.StructField{},
oneofsByName: map[protoreflect.Name]reflect.StructField{},
oneofWrappersByType: map[reflect.Type]protoreflect.FieldNumber{},
oneofWrappersByNumber: map[protoreflect.FieldNumber]reflect.Type{},
}
fieldLoop:
for i := 0; i < t.NumField(); i++ {
switch f := t.Field(i); f.Name {
case genid.SizeCache_goname, genid.SizeCacheA_goname:
if f.Type == sizecacheType {
si.sizecacheOffset = offsetOf(f)
si.sizecacheType = f.Type
}
case genid.UnknownFields_goname, genid.UnknownFieldsA_goname:
if f.Type == unknownFieldsAType || f.Type == unknownFieldsBType {
si.unknownOffset = offsetOf(f)
si.unknownType = f.Type
}
case genid.ExtensionFields_goname, genid.ExtensionFieldsA_goname, genid.ExtensionFieldsB_goname:
if f.Type == extensionFieldsType {
si.extensionOffset = offsetOf(f)
si.extensionType = f.Type
}
case "lazyFields", "XXX_lazyUnmarshalInfo":
si.lazyOffset = offsetOf(f)
case "XXX_presence":
si.presenceOffset = offsetOf(f)
default:
for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
if len(s) > 0 && strings.Trim(s, "0123456789") == "" {
n, _ := strconv.ParseUint(s, 10, 64)
si.fieldsByNumber[protoreflect.FieldNumber(n)] = f
continue fieldLoop
}
}
if s := f.Tag.Get("protobuf_oneof"); len(s) > 0 {
si.oneofsByName[protoreflect.Name(s)] = f
continue fieldLoop
}
}
}
// Derive a mapping of oneof wrappers to fields.
oneofWrappers := mi.OneofWrappers
methods := make([]reflect.Method, 0, 2)
if m, ok := reflect.PtrTo(t).MethodByName("XXX_OneofFuncs"); ok {
methods = append(methods, m)
}
if m, ok := reflect.PtrTo(t).MethodByName("XXX_OneofWrappers"); ok {
methods = append(methods, m)
}
for _, fn := range methods {
for _, v := range fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))}) {
if vs, ok := v.Interface().([]any); ok {
oneofWrappers = vs
}
}
}
for _, v := range oneofWrappers {
tf := reflect.TypeOf(v).Elem()
f := tf.Field(0)
for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
if len(s) > 0 && strings.Trim(s, "0123456789") == "" {
n, _ := strconv.ParseUint(s, 10, 64)
si.oneofWrappersByType[tf] = protoreflect.FieldNumber(n)
si.oneofWrappersByNumber[protoreflect.FieldNumber(n)] = tf
break
}
}
}
return si
}
func (mi *MessageInfo) New() protoreflect.Message {
m := reflect.New(mi.GoReflectType.Elem()).Interface()
if r, ok := m.(protoreflect.ProtoMessage); ok {
return r.ProtoReflect()
}
return mi.MessageOf(m)
}
func (mi *MessageInfo) Zero() protoreflect.Message {
return mi.MessageOf(reflect.Zero(mi.GoReflectType).Interface())
}
func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor {
return mi.Desc
}
func (mi *MessageInfo) Enum(i int) protoreflect.EnumType {
mi.init()
fd := mi.Desc.Fields().Get(i)
return Export{}.EnumTypeOf(mi.fieldTypes[fd.Number()])
}
func (mi *MessageInfo) Message(i int) protoreflect.MessageType {
mi.init()
fd := mi.Desc.Fields().Get(i)
switch {
case fd.IsMap():
return mapEntryType{fd.Message(), mi.fieldTypes[fd.Number()]}
default:
return Export{}.MessageTypeOf(mi.fieldTypes[fd.Number()])
}
}
type mapEntryType struct {
desc protoreflect.MessageDescriptor
valType any // zero value of enum or message type
}
func (mt mapEntryType) New() protoreflect.Message {
return nil
}
func (mt mapEntryType) Zero() protoreflect.Message {
return nil
}
func (mt mapEntryType) Descriptor() protoreflect.MessageDescriptor {
return mt.desc
}
func (mt mapEntryType) Enum(i int) protoreflect.EnumType {
fd := mt.desc.Fields().Get(i)
if fd.Enum() == nil {
return nil
}
return Export{}.EnumTypeOf(mt.valType)
}
func (mt mapEntryType) Message(i int) protoreflect.MessageType {
fd := mt.desc.Fields().Get(i)
if fd.Message() == nil {
return nil
}
return Export{}.MessageTypeOf(mt.valType)
}

View File

@ -0,0 +1,627 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"math"
"reflect"
"strings"
"sync/atomic"
"google.golang.org/protobuf/reflect/protoreflect"
)
type opaqueStructInfo struct {
structInfo
}
// isOpaque determines whether a protobuf message type is on the Opaque API. It
// checks whether the type is a Go struct that protoc-gen-go would generate.
//
// This function only detects newly generated messages from the v2
// implementation of protoc-gen-go. It is unable to classify generated messages
// that are too old or those that are generated by a different generator
// such as protoc-gen-gogo.
func isOpaque(t reflect.Type) bool {
// The current detection mechanism is to simply check the first field
// for a struct tag with the "protogen" key.
if t.Kind() == reflect.Struct && t.NumField() > 0 {
pgt := t.Field(0).Tag.Get("protogen")
return strings.HasPrefix(pgt, "opaque.")
}
return false
}
func opaqueInitHook(mi *MessageInfo) bool {
mt := mi.GoReflectType.Elem()
si := opaqueStructInfo{
structInfo: mi.makeStructInfo(mt),
}
if !isOpaque(mt) {
return false
}
defer atomic.StoreUint32(&mi.initDone, 1)
mi.fields = map[protoreflect.FieldNumber]*fieldInfo{}
fds := mi.Desc.Fields()
for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()]
var fi fieldInfo
usePresence, _ := usePresenceForField(si, fd)
switch {
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
// Oneofs are no different for opaque.
fi = fieldInfoForOneof(fd, si.oneofsByName[fd.ContainingOneof().Name()], mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
case fd.IsMap():
fi = mi.fieldInfoForMapOpaque(si, fd, fs)
case fd.IsList() && fd.Message() == nil && usePresence:
fi = mi.fieldInfoForScalarListOpaque(si, fd, fs)
case fd.IsList() && fd.Message() == nil:
// Proto3 lists without presence can use same access methods as open
fi = fieldInfoForList(fd, fs, mi.Exporter)
case fd.IsList() && usePresence:
fi = mi.fieldInfoForMessageListOpaque(si, fd, fs)
case fd.IsList():
// Proto3 opaque messages that does not need presence bitmap.
// Different representation than open struct, but same logic
fi = mi.fieldInfoForMessageListOpaqueNoPresence(si, fd, fs)
case fd.Message() != nil && usePresence:
fi = mi.fieldInfoForMessageOpaque(si, fd, fs)
case fd.Message() != nil:
// Proto3 messages without presence can use same access methods as open
fi = fieldInfoForMessage(fd, fs, mi.Exporter)
default:
fi = mi.fieldInfoForScalarOpaque(si, fd, fs)
}
mi.fields[fd.Number()] = &fi
}
mi.oneofs = map[protoreflect.Name]*oneofInfo{}
for i := 0; i < mi.Desc.Oneofs().Len(); i++ {
od := mi.Desc.Oneofs().Get(i)
mi.oneofs[od.Name()] = makeOneofInfoOpaque(mi, od, si.structInfo, mi.Exporter)
}
mi.denseFields = make([]*fieldInfo, fds.Len()*2)
for i := 0; i < fds.Len(); i++ {
if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
}
}
for i := 0; i < fds.Len(); {
fd := fds.Get(i)
if od := fd.ContainingOneof(); od != nil && !fd.ContainingOneof().IsSynthetic() {
mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
i += od.Fields().Len()
} else {
mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
i++
}
}
mi.makeExtensionFieldsFunc(mt, si.structInfo)
mi.makeUnknownFieldsFunc(mt, si.structInfo)
mi.makeOpaqueCoderMethods(mt, si)
mi.makeFieldTypes(si.structInfo)
return true
}
func makeOneofInfoOpaque(mi *MessageInfo, od protoreflect.OneofDescriptor, si structInfo, x exporter) *oneofInfo {
oi := &oneofInfo{oneofDesc: od}
if od.IsSynthetic() {
fd := od.Fields().Get(0)
index, _ := presenceIndex(mi.Desc, fd)
oi.which = func(p pointer) protoreflect.FieldNumber {
if p.IsNil() {
return 0
}
if !mi.present(p, index) {
return 0
}
return od.Fields().Get(0).Number()
}
return oi
}
// Dispatch to non-opaque oneof implementation for non-synthetic oneofs.
return makeOneofInfo(od, si, x)
}
func (mi *MessageInfo) fieldInfoForMapOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Map {
panic(fmt.Sprintf("invalid type: got %v, want map kind", ft))
}
fieldOffset := offsetOf(fs)
conv := NewConverter(ft, fd)
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
// Don't bother checking presence bits, since we need to
// look at the map length even if the presence bit is set.
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return rv.Len() > 0
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.Len() == 0 {
return conv.Zero()
}
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
pv := conv.GoValueOf(v)
if pv.IsNil() {
panic(fmt.Sprintf("invalid value: setting map field to read-only value"))
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(pv)
},
mutable: func(p pointer) protoreflect.Value {
v := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if v.IsNil() {
v.Set(reflect.MakeMap(fs.Type))
}
return conv.PBValueOf(v)
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func (mi *MessageInfo) fieldInfoForScalarListOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Slice {
panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
}
conv := NewConverter(reflect.PtrTo(ft), fd)
fieldOffset := offsetOf(fs)
index, _ := presenceIndex(mi.Desc, fd)
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return rv.Len() > 0
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type)
if rv.Elem().Len() == 0 {
return conv.Zero()
}
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
pv := conv.GoValueOf(v)
if pv.IsNil() {
panic(fmt.Sprintf("invalid value: setting repeated field to read-only value"))
}
mi.setPresent(p, index)
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(pv.Elem())
},
mutable: func(p pointer) protoreflect.Value {
mi.setPresent(p, index)
return conv.PBValueOf(p.Apply(fieldOffset).AsValueOf(fs.Type))
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func (mi *MessageInfo) fieldInfoForMessageListOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
}
conv := NewConverter(ft, fd)
fieldOffset := offsetOf(fs)
index, _ := presenceIndex(mi.Desc, fd)
fieldNumber := fd.Number()
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
if !mi.present(p, index) {
return false
}
sp := p.Apply(fieldOffset).AtomicGetPointer()
if sp.IsNil() {
// Lazily unmarshal this field.
mi.lazyUnmarshal(p, fieldNumber)
sp = p.Apply(fieldOffset).AtomicGetPointer()
}
rv := sp.AsValueOf(fs.Type.Elem())
return rv.Elem().Len() > 0
},
clear: func(p pointer) {
fp := p.Apply(fieldOffset)
sp := fp.AtomicGetPointer()
if sp.IsNil() {
sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem())))
mi.setPresent(p, index)
}
rv := sp.AsValueOf(fs.Type.Elem())
rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
if !mi.present(p, index) {
return conv.Zero()
}
sp := p.Apply(fieldOffset).AtomicGetPointer()
if sp.IsNil() {
// Lazily unmarshal this field.
mi.lazyUnmarshal(p, fieldNumber)
sp = p.Apply(fieldOffset).AtomicGetPointer()
}
rv := sp.AsValueOf(fs.Type.Elem())
if rv.Elem().Len() == 0 {
return conv.Zero()
}
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
fp := p.Apply(fieldOffset)
sp := fp.AtomicGetPointer()
if sp.IsNil() {
sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem())))
mi.setPresent(p, index)
}
rv := sp.AsValueOf(fs.Type.Elem())
val := conv.GoValueOf(v)
if val.IsNil() {
panic(fmt.Sprintf("invalid value: setting repeated field to read-only value"))
} else {
rv.Elem().Set(val.Elem())
}
},
mutable: func(p pointer) protoreflect.Value {
fp := p.Apply(fieldOffset)
sp := fp.AtomicGetPointer()
if sp.IsNil() {
if mi.present(p, index) {
// Lazily unmarshal this field.
mi.lazyUnmarshal(p, fieldNumber)
sp = p.Apply(fieldOffset).AtomicGetPointer()
} else {
sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem())))
mi.setPresent(p, index)
}
}
rv := sp.AsValueOf(fs.Type.Elem())
return conv.PBValueOf(rv)
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func (mi *MessageInfo) fieldInfoForMessageListOpaqueNoPresence(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
}
conv := NewConverter(ft, fd)
fieldOffset := offsetOf(fs)
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
sp := p.Apply(fieldOffset).AtomicGetPointer()
if sp.IsNil() {
return false
}
rv := sp.AsValueOf(fs.Type.Elem())
return rv.Elem().Len() > 0
},
clear: func(p pointer) {
sp := p.Apply(fieldOffset).AtomicGetPointer()
if !sp.IsNil() {
rv := sp.AsValueOf(fs.Type.Elem())
rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
}
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
sp := p.Apply(fieldOffset).AtomicGetPointer()
if sp.IsNil() {
return conv.Zero()
}
rv := sp.AsValueOf(fs.Type.Elem())
if rv.Elem().Len() == 0 {
return conv.Zero()
}
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() {
rv.Set(reflect.New(fs.Type.Elem()))
}
val := conv.GoValueOf(v)
if val.IsNil() {
panic(fmt.Sprintf("invalid value: setting repeated field to read-only value"))
} else {
rv.Elem().Set(val.Elem())
}
},
mutable: func(p pointer) protoreflect.Value {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() {
rv.Set(reflect.New(fs.Type.Elem()))
}
return conv.PBValueOf(rv)
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func (mi *MessageInfo) fieldInfoForScalarOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
nullable := fd.HasPresence()
if oneof := fd.ContainingOneof(); oneof != nil && oneof.IsSynthetic() {
nullable = true
}
deref := false
if nullable && ft.Kind() == reflect.Ptr {
ft = ft.Elem()
deref = true
}
conv := NewConverter(ft, fd)
fieldOffset := offsetOf(fs)
index, _ := presenceIndex(mi.Desc, fd)
var getter func(p pointer) protoreflect.Value
if !nullable {
getter = getterForDirectScalar(fd, fs, conv, fieldOffset)
} else {
getter = getterForOpaqueNullableScalar(mi, index, fd, fs, conv, fieldOffset)
}
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
if nullable {
return mi.present(p, index)
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
switch rv.Kind() {
case reflect.Bool:
return rv.Bool()
case reflect.Int32, reflect.Int64:
return rv.Int() != 0
case reflect.Uint32, reflect.Uint64:
return rv.Uint() != 0
case reflect.Float32, reflect.Float64:
return rv.Float() != 0 || math.Signbit(rv.Float())
case reflect.String, reflect.Slice:
return rv.Len() > 0
default:
panic(fmt.Sprintf("invalid type: %v", rv.Type())) // should never happen
}
},
clear: func(p pointer) {
if nullable {
mi.clearPresent(p, index)
}
// This is only valuable for bytes and strings, but we do it unconditionally.
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: getter,
// TODO: Implement unsafe fast path for set?
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if deref {
if rv.IsNil() {
rv.Set(reflect.New(ft))
}
rv = rv.Elem()
}
rv.Set(conv.GoValueOf(v))
if nullable && rv.Kind() == reflect.Slice && rv.IsNil() {
rv.Set(emptyBytes)
}
if nullable {
mi.setPresent(p, index)
}
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func (mi *MessageInfo) fieldInfoForMessageOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
conv := NewConverter(ft, fd)
fieldOffset := offsetOf(fs)
index, _ := presenceIndex(mi.Desc, fd)
fieldNumber := fd.Number()
elemType := fs.Type.Elem()
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
return mi.present(p, index)
},
clear: func(p pointer) {
mi.clearPresent(p, index)
p.Apply(fieldOffset).AtomicSetNilPointer()
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
fp := p.Apply(fieldOffset)
mp := fp.AtomicGetPointer()
if mp.IsNil() {
// Lazily unmarshal this field.
mi.lazyUnmarshal(p, fieldNumber)
mp = fp.AtomicGetPointer()
}
rv := mp.AsValueOf(elemType)
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
val := pointerOfValue(conv.GoValueOf(v))
if val.IsNil() {
panic("invalid nil pointer")
}
p.Apply(fieldOffset).AtomicSetPointer(val)
mi.setPresent(p, index)
},
mutable: func(p pointer) protoreflect.Value {
fp := p.Apply(fieldOffset)
mp := fp.AtomicGetPointer()
if mp.IsNil() {
if mi.present(p, index) {
// Lazily unmarshal this field.
mi.lazyUnmarshal(p, fieldNumber)
mp = fp.AtomicGetPointer()
} else {
mp = pointerOfValue(conv.GoValueOf(conv.New()))
fp.AtomicSetPointer(mp)
mi.setPresent(p, index)
}
}
return conv.PBValueOf(mp.AsValueOf(fs.Type.Elem()))
},
newMessage: func() protoreflect.Message {
return conv.New().Message()
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
// A presenceList wraps a List, updating presence bits as necessary when the
// list contents change.
type presenceList struct {
pvalueList
setPresence func(bool)
}
type pvalueList interface {
protoreflect.List
//Unwrapper
}
func (list presenceList) Append(v protoreflect.Value) {
list.pvalueList.Append(v)
list.setPresence(true)
}
func (list presenceList) Truncate(i int) {
list.pvalueList.Truncate(i)
list.setPresence(i > 0)
}
// presenceIndex returns the index to pass to presence functions.
//
// TODO: field.Desc.Index() would be simpler, and would give space to record the presence of oneof fields.
func presenceIndex(md protoreflect.MessageDescriptor, fd protoreflect.FieldDescriptor) (uint32, presenceSize) {
found := false
var index, numIndices uint32
for i := 0; i < md.Fields().Len(); i++ {
f := md.Fields().Get(i)
if f == fd {
found = true
index = numIndices
}
if f.ContainingOneof() == nil || isLastOneofField(f) {
numIndices++
}
}
if !found {
panic(fmt.Sprintf("BUG: %v not in %v", fd.Name(), md.FullName()))
}
return index, presenceSize(numIndices)
}
func isLastOneofField(fd protoreflect.FieldDescriptor) bool {
fields := fd.ContainingOneof().Fields()
return fields.Get(fields.Len()-1) == fd
}
func (mi *MessageInfo) setPresent(p pointer, index uint32) {
p.Apply(mi.presenceOffset).PresenceInfo().SetPresent(index, mi.presenceSize)
}
func (mi *MessageInfo) clearPresent(p pointer, index uint32) {
p.Apply(mi.presenceOffset).PresenceInfo().ClearPresent(index)
}
func (mi *MessageInfo) present(p pointer, index uint32) bool {
return p.Apply(mi.presenceOffset).PresenceInfo().Present(index)
}
// usePresenceForField implements the somewhat intricate logic of when
// the presence bitmap is used for a field. The main logic is that a
// field that is optional or that can be lazy will use the presence
// bit, but for proto2, also maps have a presence bit. It also records
// if the field can ever be lazy, which is true if we have a
// lazyOffset and the field is a message or a slice of messages. A
// field that is lazy will always need a presence bit. Oneofs are not
// lazy and do not use presence, unless they are a synthetic oneof,
// which is a proto3 optional field. For proto3 optionals, we use the
// presence and they can also be lazy when applicable (a message).
func usePresenceForField(si opaqueStructInfo, fd protoreflect.FieldDescriptor) (usePresence, canBeLazy bool) {
hasLazyField := fd.(interface{ IsLazy() bool }).IsLazy()
// Non-oneof scalar fields with explicit field presence use the presence array.
usesPresenceArray := fd.HasPresence() && fd.Message() == nil && (fd.ContainingOneof() == nil || fd.ContainingOneof().IsSynthetic())
switch {
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
return false, false
case fd.IsMap():
return false, false
case fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind:
return hasLazyField, hasLazyField
default:
return usesPresenceArray || (hasLazyField && fd.HasPresence()), false
}
}

View File

@ -0,0 +1,132 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate-types. DO NOT EDIT.
package impl
import (
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
func getterForOpaqueNullableScalar(mi *MessageInfo, index uint32, fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value {
ft := fs.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if fd.Kind() == protoreflect.EnumKind {
// Enums for nullable opaque types.
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return conv.PBValueOf(rv)
}
}
switch ft.Kind() {
case reflect.Bool:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bool()
return protoreflect.ValueOfBool(*x)
}
case reflect.Int32:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Int32()
return protoreflect.ValueOfInt32(*x)
}
case reflect.Uint32:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Uint32()
return protoreflect.ValueOfUint32(*x)
}
case reflect.Int64:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Int64()
return protoreflect.ValueOfInt64(*x)
}
case reflect.Uint64:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Uint64()
return protoreflect.ValueOfUint64(*x)
}
case reflect.Float32:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Float32()
return protoreflect.ValueOfFloat32(*x)
}
case reflect.Float64:
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Float64()
return protoreflect.ValueOfFloat64(*x)
}
case reflect.String:
if fd.Kind() == protoreflect.BytesKind {
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).StringPtr()
if *x == nil {
return conv.Zero()
}
if len(**x) == 0 {
return protoreflect.ValueOfBytes(nil)
}
return protoreflect.ValueOfBytes([]byte(**x))
}
}
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).StringPtr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfString(**x)
}
case reflect.Slice:
if fd.Kind() == protoreflect.StringKind {
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bytes()
return protoreflect.ValueOfString(string(*x))
}
}
return func(p pointer) protoreflect.Value {
if p.IsNil() || !mi.present(p, index) {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bytes()
return protoreflect.ValueOfBytes(*x)
}
}
panic("unexpected protobuf kind: " + ft.Kind().String())
}

View File

@ -0,0 +1,462 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
)
type reflectMessageInfo struct {
fields map[protoreflect.FieldNumber]*fieldInfo
oneofs map[protoreflect.Name]*oneofInfo
// fieldTypes contains the zero value of an enum or message field.
// For lists, it contains the element type.
// For maps, it contains the entry value type.
fieldTypes map[protoreflect.FieldNumber]any
// denseFields is a subset of fields where:
// 0 < fieldDesc.Number() < len(denseFields)
// It provides faster access to the fieldInfo, but may be incomplete.
denseFields []*fieldInfo
// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
rangeInfos []any // either *fieldInfo or *oneofInfo
getUnknown func(pointer) protoreflect.RawFields
setUnknown func(pointer, protoreflect.RawFields)
extensionMap func(pointer) *extensionMap
nilMessage atomicNilMessage
}
// makeReflectFuncs generates the set of functions to support reflection.
func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
mi.makeKnownFieldsFunc(si)
mi.makeUnknownFieldsFunc(t, si)
mi.makeExtensionFieldsFunc(t, si)
mi.makeFieldTypes(si)
}
// makeKnownFieldsFunc generates functions for operations that can be performed
// on each protobuf message field. It takes in a reflect.Type representing the
// Go struct and matches message fields with struct fields.
//
// This code assumes that the struct is well-formed and panics if there are
// any discrepancies.
func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
mi.fields = map[protoreflect.FieldNumber]*fieldInfo{}
md := mi.Desc
fds := md.Fields()
for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()]
isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
if isOneof {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
var fi fieldInfo
switch {
case fs.Type == nil:
fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
case isOneof:
fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
case fd.IsMap():
fi = fieldInfoForMap(fd, fs, mi.Exporter)
case fd.IsList():
fi = fieldInfoForList(fd, fs, mi.Exporter)
case fd.Message() != nil:
fi = fieldInfoForMessage(fd, fs, mi.Exporter)
default:
fi = fieldInfoForScalar(fd, fs, mi.Exporter)
}
mi.fields[fd.Number()] = &fi
}
mi.oneofs = map[protoreflect.Name]*oneofInfo{}
for i := 0; i < md.Oneofs().Len(); i++ {
od := md.Oneofs().Get(i)
mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
}
mi.denseFields = make([]*fieldInfo, fds.Len()*2)
for i := 0; i < fds.Len(); i++ {
if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
}
}
for i := 0; i < fds.Len(); {
fd := fds.Get(i)
if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
i += od.Fields().Len()
} else {
mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
i++
}
}
// Introduce instability to iteration order, but keep it deterministic.
if len(mi.rangeInfos) > 1 && detrand.Bool() {
i := detrand.Intn(len(mi.rangeInfos) - 1)
mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
}
}
func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
switch {
case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
// Handle as []byte.
mi.getUnknown = func(p pointer) protoreflect.RawFields {
if p.IsNil() {
return nil
}
return *p.Apply(mi.unknownOffset).Bytes()
}
mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
if p.IsNil() {
panic("invalid SetUnknown on nil Message")
}
*p.Apply(mi.unknownOffset).Bytes() = b
}
case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
// Handle as *[]byte.
mi.getUnknown = func(p pointer) protoreflect.RawFields {
if p.IsNil() {
return nil
}
bp := p.Apply(mi.unknownOffset).BytesPtr()
if *bp == nil {
return nil
}
return **bp
}
mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
if p.IsNil() {
panic("invalid SetUnknown on nil Message")
}
bp := p.Apply(mi.unknownOffset).BytesPtr()
if *bp == nil {
*bp = new([]byte)
}
**bp = b
}
default:
mi.getUnknown = func(pointer) protoreflect.RawFields {
return nil
}
mi.setUnknown = func(p pointer, _ protoreflect.RawFields) {
if p.IsNil() {
panic("invalid SetUnknown on nil Message")
}
}
}
}
func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
if si.extensionOffset.IsValid() {
mi.extensionMap = func(p pointer) *extensionMap {
if p.IsNil() {
return (*extensionMap)(nil)
}
v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
}
} else {
mi.extensionMap = func(pointer) *extensionMap {
return (*extensionMap)(nil)
}
}
}
func (mi *MessageInfo) makeFieldTypes(si structInfo) {
md := mi.Desc
fds := md.Fields()
for i := 0; i < fds.Len(); i++ {
var ft reflect.Type
fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()]
isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
if isOneof {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
var isMessage bool
switch {
case fs.Type == nil:
continue // never occurs for officially generated message types
case isOneof:
if fd.Enum() != nil || fd.Message() != nil {
ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
}
case fd.IsMap():
if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
ft = fs.Type.Elem()
}
isMessage = fd.MapValue().Message() != nil
case fd.IsList():
if fd.Enum() != nil || fd.Message() != nil {
ft = fs.Type.Elem()
if ft.Kind() == reflect.Slice {
ft = ft.Elem()
}
}
isMessage = fd.Message() != nil
case fd.Enum() != nil:
ft = fs.Type
if fd.HasPresence() && ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
case fd.Message() != nil:
ft = fs.Type
isMessage = true
}
if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
ft = reflect.PtrTo(ft) // never occurs for officially generated message types
}
if ft != nil {
if mi.fieldTypes == nil {
mi.fieldTypes = make(map[protoreflect.FieldNumber]any)
}
mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
}
}
}
type extensionMap map[int32]ExtensionField
func (m *extensionMap) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
if m != nil {
for _, x := range *m {
xd := x.Type().TypeDescriptor()
v := x.Value()
if xd.IsList() && v.List().Len() == 0 {
continue
}
if !f(xd, v) {
return
}
}
}
}
func (m *extensionMap) Has(xd protoreflect.ExtensionTypeDescriptor) (ok bool) {
if m == nil {
return false
}
x, ok := (*m)[int32(xd.Number())]
if !ok {
return false
}
if x.isUnexpandedLazy() {
// Avoid calling x.Value(), which triggers a lazy unmarshal.
return true
}
switch {
case xd.IsList():
return x.Value().List().Len() > 0
case xd.IsMap():
return x.Value().Map().Len() > 0
}
return true
}
func (m *extensionMap) Clear(xd protoreflect.ExtensionTypeDescriptor) {
delete(*m, int32(xd.Number()))
}
func (m *extensionMap) Get(xd protoreflect.ExtensionTypeDescriptor) protoreflect.Value {
if m != nil {
if x, ok := (*m)[int32(xd.Number())]; ok {
return x.Value()
}
}
return xd.Type().Zero()
}
func (m *extensionMap) Set(xd protoreflect.ExtensionTypeDescriptor, v protoreflect.Value) {
xt := xd.Type()
isValid := true
switch {
case !xt.IsValidValue(v):
isValid = false
case xd.IsList():
isValid = v.List().IsValid()
case xd.IsMap():
isValid = v.Map().IsValid()
case xd.Message() != nil:
isValid = v.Message().IsValid()
}
if !isValid {
panic(fmt.Sprintf("%v: assigning invalid value", xd.FullName()))
}
if *m == nil {
*m = make(map[int32]ExtensionField)
}
var x ExtensionField
x.Set(xt, v)
(*m)[int32(xd.Number())] = x
}
func (m *extensionMap) Mutable(xd protoreflect.ExtensionTypeDescriptor) protoreflect.Value {
if xd.Kind() != protoreflect.MessageKind && xd.Kind() != protoreflect.GroupKind && !xd.IsList() && !xd.IsMap() {
panic("invalid Mutable on field with non-composite type")
}
if x, ok := (*m)[int32(xd.Number())]; ok {
return x.Value()
}
v := xd.Type().New()
m.Set(xd, v)
return v
}
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
// type M struct {
// state protoimpl.MessageState
//
// Field1 int32
// Field2 string
// Field3 *BarMessage
// ...
// }
//
// func (m *M) ProtoReflect() protoreflect.Message {
// mi := &file_fizz_buzz_proto_msgInfos[5]
// if protoimpl.UnsafeEnabled && m != nil {
// ms := protoimpl.X.MessageStateOf(Pointer(m))
// if ms.LoadMessageInfo() == nil {
// ms.StoreMessageInfo(mi)
// }
// return ms
// }
// return mi.MessageOf(m)
// }
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
// - The type M must implement protoreflect.ProtoMessage.
// - The address of m must not be nil.
// - The address of m and the address of m.state must be equal,
// even though they are different Go types.
type MessageState struct {
pragma.NoUnkeyedLiterals
pragma.DoNotCompare
pragma.DoNotCopy
atomicMessageInfo *MessageInfo
}
type messageState MessageState
var (
_ protoreflect.Message = (*messageState)(nil)
_ unwrapper = (*messageState)(nil)
)
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
type messageDataType struct {
p pointer
mi *MessageInfo
}
type (
messageReflectWrapper messageDataType
messageIfaceWrapper messageDataType
)
var (
_ protoreflect.Message = (*messageReflectWrapper)(nil)
_ unwrapper = (*messageReflectWrapper)(nil)
_ protoreflect.ProtoMessage = (*messageIfaceWrapper)(nil)
_ unwrapper = (*messageIfaceWrapper)(nil)
)
// MessageOf returns a reflective view over a message. The input must be a
// pointer to a named Go struct. If the provided type has a ProtoReflect method,
// it must be implemented by calling this method.
func (mi *MessageInfo) MessageOf(m any) protoreflect.Message {
if reflect.TypeOf(m) != mi.GoReflectType {
panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
}
p := pointerOfIface(m)
if p.IsNil() {
return mi.nilMessage.Init(mi)
}
return &messageReflectWrapper{p, mi}
}
func (m *messageReflectWrapper) pointer() pointer { return m.p }
func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
// Reset implements the v1 proto.Message.Reset method.
func (m *messageIfaceWrapper) Reset() {
if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
mr.Reset()
return
}
rv := reflect.ValueOf(m.protoUnwrap())
if rv.Kind() == reflect.Ptr && !rv.IsNil() {
rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
}
}
func (m *messageIfaceWrapper) ProtoReflect() protoreflect.Message {
return (*messageReflectWrapper)(m)
}
func (m *messageIfaceWrapper) protoUnwrap() any {
return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
}
// checkField verifies that the provided field descriptor is valid.
// Exactly one of the returned values is populated.
func (mi *MessageInfo) checkField(fd protoreflect.FieldDescriptor) (*fieldInfo, protoreflect.ExtensionTypeDescriptor) {
var fi *fieldInfo
if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
fi = mi.denseFields[n]
} else {
fi = mi.fields[n]
}
if fi != nil {
if fi.fieldDesc != fd {
if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
}
panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
}
return fi, nil
}
if fd.IsExtension() {
if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
// TODO: Should this be exact containing message descriptor match?
panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
}
if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
}
xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor)
if !ok {
panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
}
return nil, xtd
}
panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
}

View File

@ -0,0 +1,423 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"math"
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
type fieldInfo struct {
fieldDesc protoreflect.FieldDescriptor
// These fields are used for protobuf reflection support.
has func(pointer) bool
clear func(pointer)
get func(pointer) protoreflect.Value
set func(pointer, protoreflect.Value)
mutable func(pointer) protoreflect.Value
newMessage func() protoreflect.Message
newField func() protoreflect.Value
}
func fieldInfoForMissing(fd protoreflect.FieldDescriptor) fieldInfo {
// This never occurs for generated message types.
// It implies that a hand-crafted type has missing Go fields
// for specific protobuf message fields.
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
return false
},
clear: func(p pointer) {
panic("missing Go struct field for " + string(fd.FullName()))
},
get: func(p pointer) protoreflect.Value {
return fd.Default()
},
set: func(p pointer, v protoreflect.Value) {
panic("missing Go struct field for " + string(fd.FullName()))
},
mutable: func(p pointer) protoreflect.Value {
panic("missing Go struct field for " + string(fd.FullName()))
},
newMessage: func() protoreflect.Message {
panic("missing Go struct field for " + string(fd.FullName()))
},
newField: func() protoreflect.Value {
if v := fd.Default(); v.IsValid() {
return v
}
panic("missing Go struct field for " + string(fd.FullName()))
},
}
}
func fieldInfoForOneof(fd protoreflect.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Interface {
panic(fmt.Sprintf("field %v has invalid type: got %v, want interface kind", fd.FullName(), ft))
}
if ot.Kind() != reflect.Struct {
panic(fmt.Sprintf("field %v has invalid type: got %v, want struct kind", fd.FullName(), ot))
}
if !reflect.PtrTo(ot).Implements(ft) {
panic(fmt.Sprintf("field %v has invalid type: %v does not implement %v", fd.FullName(), ot, ft))
}
conv := NewConverter(ot.Field(0).Type, fd)
isMessage := fd.Message() != nil
// TODO: Implement unsafe fast path?
fieldOffset := offsetOf(fs)
return fieldInfo{
// NOTE: The logic below intentionally assumes that oneof fields are
// well-formatted. That is, the oneof interface never contains a
// typed nil pointer to one of the wrapper structs.
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
return false
}
return true
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() || rv.Elem().Type().Elem() != ot {
// NOTE: We intentionally don't check for rv.Elem().IsNil()
// so that (*OneofWrapperType)(nil) gets cleared to nil.
return
}
rv.Set(reflect.Zero(rv.Type()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
return conv.Zero()
}
rv = rv.Elem().Elem().Field(0)
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
rv.Set(reflect.New(ot))
}
rv = rv.Elem().Elem().Field(0)
rv.Set(conv.GoValueOf(v))
},
mutable: func(p pointer) protoreflect.Value {
if !isMessage {
panic(fmt.Sprintf("field %v with invalid Mutable call on field with non-composite type", fd.FullName()))
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
rv.Set(reflect.New(ot))
}
rv = rv.Elem().Elem().Field(0)
if rv.Kind() == reflect.Ptr && rv.IsNil() {
rv.Set(conv.GoValueOf(protoreflect.ValueOfMessage(conv.New().Message())))
}
return conv.PBValueOf(rv)
},
newMessage: func() protoreflect.Message {
return conv.New().Message()
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func fieldInfoForMap(fd protoreflect.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Map {
panic(fmt.Sprintf("field %v has invalid type: got %v, want map kind", fd.FullName(), ft))
}
conv := NewConverter(ft, fd)
// TODO: Implement unsafe fast path?
fieldOffset := offsetOf(fs)
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return rv.Len() > 0
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.Len() == 0 {
return conv.Zero()
}
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
pv := conv.GoValueOf(v)
if pv.IsNil() {
panic(fmt.Sprintf("map field %v cannot be set with read-only value", fd.FullName()))
}
rv.Set(pv)
},
mutable: func(p pointer) protoreflect.Value {
v := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if v.IsNil() {
v.Set(reflect.MakeMap(fs.Type))
}
return conv.PBValueOf(v)
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func fieldInfoForList(fd protoreflect.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Slice {
panic(fmt.Sprintf("field %v has invalid type: got %v, want slice kind", fd.FullName(), ft))
}
conv := NewConverter(reflect.PtrTo(ft), fd)
// TODO: Implement unsafe fast path?
fieldOffset := offsetOf(fs)
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return rv.Len() > 0
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type)
if rv.Elem().Len() == 0 {
return conv.Zero()
}
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
pv := conv.GoValueOf(v)
if pv.IsNil() {
panic(fmt.Sprintf("list field %v cannot be set with read-only value", fd.FullName()))
}
rv.Set(pv.Elem())
},
mutable: func(p pointer) protoreflect.Value {
v := p.Apply(fieldOffset).AsValueOf(fs.Type)
return conv.PBValueOf(v)
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
var (
nilBytes = reflect.ValueOf([]byte(nil))
emptyBytes = reflect.ValueOf([]byte{})
)
func fieldInfoForScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
ft := fs.Type
nullable := fd.HasPresence()
isBytes := ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8
var getter func(p pointer) protoreflect.Value
if nullable {
if ft.Kind() != reflect.Ptr && ft.Kind() != reflect.Slice {
// This never occurs for generated message types.
// Despite the protobuf type system specifying presence,
// the Go field type cannot represent it.
nullable = false
}
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
}
conv := NewConverter(ft, fd)
fieldOffset := offsetOf(fs)
// Generate specialized getter functions to avoid going through reflect.Value
if nullable {
getter = getterForNullableScalar(fd, fs, conv, fieldOffset)
} else {
getter = getterForDirectScalar(fd, fs, conv, fieldOffset)
}
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
if nullable {
return !p.Apply(fieldOffset).Elem().IsNil()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
switch rv.Kind() {
case reflect.Bool:
return rv.Bool()
case reflect.Int32, reflect.Int64:
return rv.Int() != 0
case reflect.Uint32, reflect.Uint64:
return rv.Uint() != 0
case reflect.Float32, reflect.Float64:
return rv.Float() != 0 || math.Signbit(rv.Float())
case reflect.String, reflect.Slice:
return rv.Len() > 0
default:
panic(fmt.Sprintf("field %v has invalid type: %v", fd.FullName(), rv.Type())) // should never happen
}
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: getter,
// TODO: Implement unsafe fast path for set?
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if nullable && rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(ft))
}
rv = rv.Elem()
}
rv.Set(conv.GoValueOf(v))
if isBytes && rv.Len() == 0 {
if nullable {
rv.Set(emptyBytes) // preserve presence
} else {
rv.Set(nilBytes) // do not preserve presence
}
}
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
func fieldInfoForMessage(fd protoreflect.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
ft := fs.Type
conv := NewConverter(ft, fd)
// TODO: Implement unsafe fast path?
fieldOffset := offsetOf(fs)
return fieldInfo{
fieldDesc: fd,
has: func(p pointer) bool {
if p.IsNil() {
return false
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if fs.Type.Kind() != reflect.Ptr {
return !rv.IsZero()
}
return !rv.IsNil()
},
clear: func(p pointer) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(reflect.Zero(rv.Type()))
},
get: func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return conv.PBValueOf(rv)
},
set: func(p pointer, v protoreflect.Value) {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
rv.Set(conv.GoValueOf(v))
if fs.Type.Kind() == reflect.Ptr && rv.IsNil() {
panic(fmt.Sprintf("field %v has invalid nil pointer", fd.FullName()))
}
},
mutable: func(p pointer) protoreflect.Value {
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if fs.Type.Kind() == reflect.Ptr && rv.IsNil() {
rv.Set(conv.GoValueOf(conv.New()))
}
return conv.PBValueOf(rv)
},
newMessage: func() protoreflect.Message {
return conv.New().Message()
},
newField: func() protoreflect.Value {
return conv.New()
},
}
}
type oneofInfo struct {
oneofDesc protoreflect.OneofDescriptor
which func(pointer) protoreflect.FieldNumber
}
func makeOneofInfo(od protoreflect.OneofDescriptor, si structInfo, x exporter) *oneofInfo {
oi := &oneofInfo{oneofDesc: od}
if od.IsSynthetic() {
fs := si.fieldsByNumber[od.Fields().Get(0).Number()]
fieldOffset := offsetOf(fs)
oi.which = func(p pointer) protoreflect.FieldNumber {
if p.IsNil() {
return 0
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() { // valid on either *T or []byte
return 0
}
return od.Fields().Get(0).Number()
}
} else {
fs := si.oneofsByName[od.Name()]
fieldOffset := offsetOf(fs)
oi.which = func(p pointer) protoreflect.FieldNumber {
if p.IsNil() {
return 0
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
if rv.IsNil() {
return 0
}
rv = rv.Elem()
if rv.IsNil() {
return 0
}
return si.oneofWrappersByType[rv.Type().Elem()]
}
}
return oi
}

View File

@ -0,0 +1,273 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate-types. DO NOT EDIT.
package impl
import (
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
)
func getterForNullableScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value {
ft := fs.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if fd.Kind() == protoreflect.EnumKind {
elemType := fs.Type.Elem()
// Enums for nullable types.
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).Elem().AsValueOf(elemType)
if rv.IsNil() {
return conv.Zero()
}
return conv.PBValueOf(rv.Elem())
}
}
switch ft.Kind() {
case reflect.Bool:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).BoolPtr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfBool(**x)
}
case reflect.Int32:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Int32Ptr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfInt32(**x)
}
case reflect.Uint32:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Uint32Ptr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfUint32(**x)
}
case reflect.Int64:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Int64Ptr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfInt64(**x)
}
case reflect.Uint64:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Uint64Ptr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfUint64(**x)
}
case reflect.Float32:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Float32Ptr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfFloat32(**x)
}
case reflect.Float64:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Float64Ptr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfFloat64(**x)
}
case reflect.String:
if fd.Kind() == protoreflect.BytesKind {
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).StringPtr()
if *x == nil {
return conv.Zero()
}
if len(**x) == 0 {
return protoreflect.ValueOfBytes(nil)
}
return protoreflect.ValueOfBytes([]byte(**x))
}
}
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).StringPtr()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfString(**x)
}
case reflect.Slice:
if fd.Kind() == protoreflect.StringKind {
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bytes()
if len(*x) == 0 {
return conv.Zero()
}
return protoreflect.ValueOfString(string(*x))
}
}
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bytes()
if *x == nil {
return conv.Zero()
}
return protoreflect.ValueOfBytes(*x)
}
}
panic("unexpected protobuf kind: " + ft.Kind().String())
}
func getterForDirectScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, conv Converter, fieldOffset offset) func(p pointer) protoreflect.Value {
ft := fs.Type
if fd.Kind() == protoreflect.EnumKind {
// Enums for non nullable types.
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
return conv.PBValueOf(rv)
}
}
switch ft.Kind() {
case reflect.Bool:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bool()
return protoreflect.ValueOfBool(*x)
}
case reflect.Int32:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Int32()
return protoreflect.ValueOfInt32(*x)
}
case reflect.Uint32:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Uint32()
return protoreflect.ValueOfUint32(*x)
}
case reflect.Int64:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Int64()
return protoreflect.ValueOfInt64(*x)
}
case reflect.Uint64:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Uint64()
return protoreflect.ValueOfUint64(*x)
}
case reflect.Float32:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Float32()
return protoreflect.ValueOfFloat32(*x)
}
case reflect.Float64:
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Float64()
return protoreflect.ValueOfFloat64(*x)
}
case reflect.String:
if fd.Kind() == protoreflect.BytesKind {
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).String()
if len(*x) == 0 {
return protoreflect.ValueOfBytes(nil)
}
return protoreflect.ValueOfBytes([]byte(*x))
}
}
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).String()
return protoreflect.ValueOfString(*x)
}
case reflect.Slice:
if fd.Kind() == protoreflect.StringKind {
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bytes()
return protoreflect.ValueOfString(string(*x))
}
}
return func(p pointer) protoreflect.Value {
if p.IsNil() {
return conv.Zero()
}
x := p.Apply(fieldOffset).Bytes()
return protoreflect.ValueOfBytes(*x)
}
}
panic("unexpected protobuf kind: " + ft.Kind().String())
}

View File

@ -0,0 +1,271 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate-types. DO NOT EDIT.
package impl
import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
func (m *messageState) Descriptor() protoreflect.MessageDescriptor {
return m.messageInfo().Desc
}
func (m *messageState) Type() protoreflect.MessageType {
return m.messageInfo()
}
func (m *messageState) New() protoreflect.Message {
return m.messageInfo().New()
}
func (m *messageState) Interface() protoreflect.ProtoMessage {
return m.protoUnwrap().(protoreflect.ProtoMessage)
}
func (m *messageState) protoUnwrap() any {
return m.pointer().AsIfaceOf(m.messageInfo().GoReflectType.Elem())
}
func (m *messageState) ProtoMethods() *protoiface.Methods {
mi := m.messageInfo()
mi.init()
return &mi.methods
}
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
// to be able to retrieve a v2 MessageInfo struct.
//
// WARNING: This method is exempt from the compatibility promise and
// may be removed in the future without warning.
func (m *messageState) ProtoMessageInfo() *MessageInfo {
return m.messageInfo()
}
func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
mi := m.messageInfo()
mi.init()
for _, ri := range mi.rangeInfos {
switch ri := ri.(type) {
case *fieldInfo:
if ri.has(m.pointer()) {
if !f(ri.fieldDesc, ri.get(m.pointer())) {
return
}
}
case *oneofInfo:
if n := ri.which(m.pointer()); n > 0 {
fi := mi.fields[n]
if !f(fi.fieldDesc, fi.get(m.pointer())) {
return
}
}
}
}
mi.extensionMap(m.pointer()).Range(f)
}
func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.has(m.pointer())
} else {
return mi.extensionMap(m.pointer()).Has(xd)
}
}
func (m *messageState) Clear(fd protoreflect.FieldDescriptor) {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
fi.clear(m.pointer())
} else {
mi.extensionMap(m.pointer()).Clear(xd)
}
}
func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.get(m.pointer())
} else {
return mi.extensionMap(m.pointer()).Get(xd)
}
}
func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
fi.set(m.pointer(), v)
} else {
mi.extensionMap(m.pointer()).Set(xd, v)
}
}
func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.mutable(m.pointer())
} else {
return mi.extensionMap(m.pointer()).Mutable(xd)
}
}
func (m *messageState) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.newField()
} else {
return xd.Type().New()
}
}
func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
mi := m.messageInfo()
mi.init()
if oi := mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
return od.Fields().ByNumber(oi.which(m.pointer()))
}
panic("invalid oneof descriptor " + string(od.FullName()) + " for message " + string(m.Descriptor().FullName()))
}
func (m *messageState) GetUnknown() protoreflect.RawFields {
mi := m.messageInfo()
mi.init()
return mi.getUnknown(m.pointer())
}
func (m *messageState) SetUnknown(b protoreflect.RawFields) {
mi := m.messageInfo()
mi.init()
mi.setUnknown(m.pointer(), b)
}
func (m *messageState) IsValid() bool {
return !m.pointer().IsNil()
}
func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor {
return m.messageInfo().Desc
}
func (m *messageReflectWrapper) Type() protoreflect.MessageType {
return m.messageInfo()
}
func (m *messageReflectWrapper) New() protoreflect.Message {
return m.messageInfo().New()
}
func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
if m, ok := m.protoUnwrap().(protoreflect.ProtoMessage); ok {
return m
}
return (*messageIfaceWrapper)(m)
}
func (m *messageReflectWrapper) protoUnwrap() any {
return m.pointer().AsIfaceOf(m.messageInfo().GoReflectType.Elem())
}
func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
mi := m.messageInfo()
mi.init()
return &mi.methods
}
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
// to be able to retrieve a v2 MessageInfo struct.
//
// WARNING: This method is exempt from the compatibility promise and
// may be removed in the future without warning.
func (m *messageReflectWrapper) ProtoMessageInfo() *MessageInfo {
return m.messageInfo()
}
func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
mi := m.messageInfo()
mi.init()
for _, ri := range mi.rangeInfos {
switch ri := ri.(type) {
case *fieldInfo:
if ri.has(m.pointer()) {
if !f(ri.fieldDesc, ri.get(m.pointer())) {
return
}
}
case *oneofInfo:
if n := ri.which(m.pointer()); n > 0 {
fi := mi.fields[n]
if !f(fi.fieldDesc, fi.get(m.pointer())) {
return
}
}
}
}
mi.extensionMap(m.pointer()).Range(f)
}
func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.has(m.pointer())
} else {
return mi.extensionMap(m.pointer()).Has(xd)
}
}
func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
fi.clear(m.pointer())
} else {
mi.extensionMap(m.pointer()).Clear(xd)
}
}
func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.get(m.pointer())
} else {
return mi.extensionMap(m.pointer()).Get(xd)
}
}
func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
fi.set(m.pointer(), v)
} else {
mi.extensionMap(m.pointer()).Set(xd, v)
}
}
func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.mutable(m.pointer())
} else {
return mi.extensionMap(m.pointer()).Mutable(xd)
}
}
func (m *messageReflectWrapper) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
mi := m.messageInfo()
mi.init()
if fi, xd := mi.checkField(fd); fi != nil {
return fi.newField()
} else {
return xd.Type().New()
}
}
func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
mi := m.messageInfo()
mi.init()
if oi := mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
return od.Fields().ByNumber(oi.which(m.pointer()))
}
panic("invalid oneof descriptor " + string(od.FullName()) + " for message " + string(m.Descriptor().FullName()))
}
func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields {
mi := m.messageInfo()
mi.init()
return mi.getUnknown(m.pointer())
}
func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) {
mi := m.messageInfo()
mi.init()
mi.setUnknown(m.pointer(), b)
}
func (m *messageReflectWrapper) IsValid() bool {
return !m.pointer().IsNil()
}

View File

@ -0,0 +1,220 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"sync/atomic"
"unsafe"
"google.golang.org/protobuf/internal/protolazy"
)
const UnsafeEnabled = true
// Pointer is an opaque pointer type.
type Pointer unsafe.Pointer
// offset represents the offset to a struct field, accessible from a pointer.
// The offset is the byte offset to the field from the start of the struct.
type offset uintptr
// offsetOf returns a field offset for the struct field.
func offsetOf(f reflect.StructField) offset {
return offset(f.Offset)
}
// IsValid reports whether the offset is valid.
func (f offset) IsValid() bool { return f != invalidOffset }
// invalidOffset is an invalid field offset.
var invalidOffset = ^offset(0)
// zeroOffset is a noop when calling pointer.Apply.
var zeroOffset = offset(0)
// pointer is a pointer to a message struct or field.
type pointer struct{ p unsafe.Pointer }
// pointerOf returns p as a pointer.
func pointerOf(p Pointer) pointer {
return pointer{p: unsafe.Pointer(p)}
}
// pointerOfValue returns v as a pointer.
func pointerOfValue(v reflect.Value) pointer {
return pointer{p: unsafe.Pointer(v.Pointer())}
}
// pointerOfIface returns the pointer portion of an interface.
func pointerOfIface(v any) pointer {
type ifaceHeader struct {
Type unsafe.Pointer
Data unsafe.Pointer
}
return pointer{p: (*ifaceHeader)(unsafe.Pointer(&v)).Data}
}
// IsNil reports whether the pointer is nil.
func (p pointer) IsNil() bool {
return p.p == nil
}
// Apply adds an offset to the pointer to derive a new pointer
// to a specified field. The pointer must be valid and pointing at a struct.
func (p pointer) Apply(f offset) pointer {
if p.IsNil() {
panic("invalid nil pointer")
}
return pointer{p: unsafe.Pointer(uintptr(p.p) + uintptr(f))}
}
// AsValueOf treats p as a pointer to an object of type t and returns the value.
// It is equivalent to reflect.ValueOf(p.AsIfaceOf(t))
func (p pointer) AsValueOf(t reflect.Type) reflect.Value {
return reflect.NewAt(t, p.p)
}
// AsIfaceOf treats p as a pointer to an object of type t and returns the value.
// It is equivalent to p.AsValueOf(t).Interface()
func (p pointer) AsIfaceOf(t reflect.Type) any {
// TODO: Use tricky unsafe magic to directly create ifaceHeader.
return p.AsValueOf(t).Interface()
}
func (p pointer) Bool() *bool { return (*bool)(p.p) }
func (p pointer) BoolPtr() **bool { return (**bool)(p.p) }
func (p pointer) BoolSlice() *[]bool { return (*[]bool)(p.p) }
func (p pointer) Int32() *int32 { return (*int32)(p.p) }
func (p pointer) Int32Ptr() **int32 { return (**int32)(p.p) }
func (p pointer) Int32Slice() *[]int32 { return (*[]int32)(p.p) }
func (p pointer) Int64() *int64 { return (*int64)(p.p) }
func (p pointer) Int64Ptr() **int64 { return (**int64)(p.p) }
func (p pointer) Int64Slice() *[]int64 { return (*[]int64)(p.p) }
func (p pointer) Uint32() *uint32 { return (*uint32)(p.p) }
func (p pointer) Uint32Ptr() **uint32 { return (**uint32)(p.p) }
func (p pointer) Uint32Slice() *[]uint32 { return (*[]uint32)(p.p) }
func (p pointer) Uint64() *uint64 { return (*uint64)(p.p) }
func (p pointer) Uint64Ptr() **uint64 { return (**uint64)(p.p) }
func (p pointer) Uint64Slice() *[]uint64 { return (*[]uint64)(p.p) }
func (p pointer) Float32() *float32 { return (*float32)(p.p) }
func (p pointer) Float32Ptr() **float32 { return (**float32)(p.p) }
func (p pointer) Float32Slice() *[]float32 { return (*[]float32)(p.p) }
func (p pointer) Float64() *float64 { return (*float64)(p.p) }
func (p pointer) Float64Ptr() **float64 { return (**float64)(p.p) }
func (p pointer) Float64Slice() *[]float64 { return (*[]float64)(p.p) }
func (p pointer) String() *string { return (*string)(p.p) }
func (p pointer) StringPtr() **string { return (**string)(p.p) }
func (p pointer) StringSlice() *[]string { return (*[]string)(p.p) }
func (p pointer) Bytes() *[]byte { return (*[]byte)(p.p) }
func (p pointer) BytesPtr() **[]byte { return (**[]byte)(p.p) }
func (p pointer) BytesSlice() *[][]byte { return (*[][]byte)(p.p) }
func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) }
func (p pointer) LazyInfoPtr() **protolazy.XXX_lazyUnmarshalInfo {
return (**protolazy.XXX_lazyUnmarshalInfo)(p.p)
}
func (p pointer) PresenceInfo() presence {
return presence{P: p.p}
}
func (p pointer) Elem() pointer {
return pointer{p: *(*unsafe.Pointer)(p.p)}
}
// PointerSlice loads []*T from p as a []pointer.
// The value returned is aliased with the original slice.
// This behavior differs from the implementation in pointer_reflect.go.
func (p pointer) PointerSlice() []pointer {
// Super-tricky - p should point to a []*T where T is a
// message type. We load it as []pointer.
return *(*[]pointer)(p.p)
}
// AppendPointerSlice appends v to p, which must be a []*T.
func (p pointer) AppendPointerSlice(v pointer) {
*(*[]pointer)(p.p) = append(*(*[]pointer)(p.p), v)
}
// SetPointer sets *p to v.
func (p pointer) SetPointer(v pointer) {
*(*unsafe.Pointer)(p.p) = (unsafe.Pointer)(v.p)
}
func (p pointer) growBoolSlice(addCap int) {
sp := p.BoolSlice()
s := make([]bool, 0, addCap+len(*sp))
s = s[:len(*sp)]
copy(s, *sp)
*sp = s
}
func (p pointer) growInt32Slice(addCap int) {
sp := p.Int32Slice()
s := make([]int32, 0, addCap+len(*sp))
s = s[:len(*sp)]
copy(s, *sp)
*sp = s
}
func (p pointer) growUint32Slice(addCap int) {
p.growInt32Slice(addCap)
}
func (p pointer) growFloat32Slice(addCap int) {
p.growInt32Slice(addCap)
}
func (p pointer) growInt64Slice(addCap int) {
sp := p.Int64Slice()
s := make([]int64, 0, addCap+len(*sp))
s = s[:len(*sp)]
copy(s, *sp)
*sp = s
}
func (p pointer) growUint64Slice(addCap int) {
p.growInt64Slice(addCap)
}
func (p pointer) growFloat64Slice(addCap int) {
p.growInt64Slice(addCap)
}
// Static check that MessageState does not exceed the size of a pointer.
const _ = uint(unsafe.Sizeof(unsafe.Pointer(nil)) - unsafe.Sizeof(MessageState{}))
func (Export) MessageStateOf(p Pointer) *messageState {
// Super-tricky - see documentation on MessageState.
return (*messageState)(unsafe.Pointer(p))
}
func (ms *messageState) pointer() pointer {
// Super-tricky - see documentation on MessageState.
return pointer{p: unsafe.Pointer(ms)}
}
func (ms *messageState) messageInfo() *MessageInfo {
mi := ms.LoadMessageInfo()
if mi == nil {
panic("invalid nil message info; this suggests memory corruption due to a race or shallow copy on the message struct")
}
return mi
}
func (ms *messageState) LoadMessageInfo() *MessageInfo {
return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.atomicMessageInfo))))
}
func (ms *messageState) StoreMessageInfo(mi *MessageInfo) {
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&ms.atomicMessageInfo)), unsafe.Pointer(mi))
}
type atomicNilMessage struct{ p unsafe.Pointer } // p is a *messageReflectWrapper
func (m *atomicNilMessage) Init(mi *MessageInfo) *messageReflectWrapper {
if p := atomic.LoadPointer(&m.p); p != nil {
return (*messageReflectWrapper)(p)
}
w := &messageReflectWrapper{mi: mi}
atomic.CompareAndSwapPointer(&m.p, nil, (unsafe.Pointer)(w))
return (*messageReflectWrapper)(atomic.LoadPointer(&m.p))
}

View File

@ -0,0 +1,42 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"sync/atomic"
"unsafe"
)
func (p pointer) AtomicGetPointer() pointer {
return pointer{p: atomic.LoadPointer((*unsafe.Pointer)(p.p))}
}
func (p pointer) AtomicSetPointer(v pointer) {
atomic.StorePointer((*unsafe.Pointer)(p.p), v.p)
}
func (p pointer) AtomicSetNilPointer() {
atomic.StorePointer((*unsafe.Pointer)(p.p), unsafe.Pointer(nil))
}
func (p pointer) AtomicSetPointerIfNil(v pointer) pointer {
if atomic.CompareAndSwapPointer((*unsafe.Pointer)(p.p), unsafe.Pointer(nil), v.p) {
return v
}
return pointer{p: atomic.LoadPointer((*unsafe.Pointer)(p.p))}
}
type atomicV1MessageInfo struct{ p Pointer }
func (mi *atomicV1MessageInfo) Get() Pointer {
return Pointer(atomic.LoadPointer((*unsafe.Pointer)(&mi.p)))
}
func (mi *atomicV1MessageInfo) SetIfNil(p Pointer) Pointer {
if atomic.CompareAndSwapPointer((*unsafe.Pointer)(&mi.p), nil, unsafe.Pointer(p)) {
return p
}
return mi.Get()
}

View File

@ -0,0 +1,142 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"sync/atomic"
"unsafe"
)
// presenceSize represents the size of a presence set, which should be the largest index of the set+1
type presenceSize uint32
// presence is the internal representation of the bitmap array in a generated protobuf
type presence struct {
// This is a pointer to the beginning of an array of uint32
P unsafe.Pointer
}
func (p presence) toElem(num uint32) (ret *uint32) {
const (
bitsPerByte = 8
siz = unsafe.Sizeof(*ret)
)
// p.P points to an array of uint32, num is the bit in this array that the
// caller wants to check/manipulate. Calculate the index in the array that
// contains this specific bit. E.g.: 76 / 32 = 2 (integer division).
offset := uintptr(num) / (siz * bitsPerByte) * siz
return (*uint32)(unsafe.Pointer(uintptr(p.P) + offset))
}
// Present checks for the presence of a specific field number in a presence set.
func (p presence) Present(num uint32) bool {
if p.P == nil {
return false
}
return Export{}.Present(p.toElem(num), num)
}
// SetPresent adds presence for a specific field number in a presence set.
func (p presence) SetPresent(num uint32, size presenceSize) {
Export{}.SetPresent(p.toElem(num), num, uint32(size))
}
// SetPresentUnatomic adds presence for a specific field number in a presence set without using
// atomic operations. Only to be called during unmarshaling.
func (p presence) SetPresentUnatomic(num uint32, size presenceSize) {
Export{}.SetPresentNonAtomic(p.toElem(num), num, uint32(size))
}
// ClearPresent removes presence for a specific field number in a presence set.
func (p presence) ClearPresent(num uint32) {
Export{}.ClearPresent(p.toElem(num), num)
}
// LoadPresenceCache (together with PresentInCache) allows for a
// cached version of checking for presence without re-reading the word
// for every field. It is optimized for efficiency and assumes no
// simltaneous mutation of the presence set (or at least does not have
// a problem with simultaneous mutation giving inconsistent results).
func (p presence) LoadPresenceCache() (current uint32) {
if p.P == nil {
return 0
}
return atomic.LoadUint32((*uint32)(p.P))
}
// PresentInCache reads presence from a cached word in the presence
// bitmap. It caches up a new word if the bit is outside the
// word. This is for really fast iteration through bitmaps in cases
// where we either know that the bitmap will not be altered, or we
// don't care about inconsistencies caused by simultaneous writes.
func (p presence) PresentInCache(num uint32, cachedElement *uint32, current *uint32) bool {
if num/32 != *cachedElement {
o := uintptr(num/32) * unsafe.Sizeof(uint32(0))
q := (*uint32)(unsafe.Pointer(uintptr(p.P) + o))
*current = atomic.LoadUint32(q)
*cachedElement = num / 32
}
return (*current & (1 << (num % 32))) > 0
}
// AnyPresent checks if any field is marked as present in the bitmap.
func (p presence) AnyPresent(size presenceSize) bool {
n := uintptr((size + 31) / 32)
for j := uintptr(0); j < n; j++ {
o := j * unsafe.Sizeof(uint32(0))
q := (*uint32)(unsafe.Pointer(uintptr(p.P) + o))
b := atomic.LoadUint32(q)
if b > 0 {
return true
}
}
return false
}
// toRaceDetectData finds the preceding RaceDetectHookData in a
// message by using pointer arithmetic. As the type of the presence
// set (bitmap) varies with the number of fields in the protobuf, we
// can not have a struct type containing the array and the
// RaceDetectHookData. instead the RaceDetectHookData is placed
// immediately before the bitmap array, and we find it by walking
// backwards in the struct.
//
// This method is only called from the race-detect version of the code,
// so RaceDetectHookData is never an empty struct.
func (p presence) toRaceDetectData() *RaceDetectHookData {
var template struct {
d RaceDetectHookData
a [1]uint32
}
o := (uintptr(unsafe.Pointer(&template.a)) - uintptr(unsafe.Pointer(&template.d)))
return (*RaceDetectHookData)(unsafe.Pointer(uintptr(p.P) - o))
}
func atomicLoadShadowPresence(p **[]byte) *[]byte {
return (*[]byte)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p))))
}
func atomicStoreShadowPresence(p **[]byte, v *[]byte) {
atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(p)), nil, unsafe.Pointer(v))
}
// findPointerToRaceDetectData finds the preceding RaceDetectHookData
// in a message by using pointer arithmetic. For the methods called
// directy from generated code, we don't have a pointer to the
// beginning of the presence set, but a pointer inside the array. As
// we know the index of the bit we're manipulating (num), we can
// calculate which element of the array ptr is pointing to. With that
// information we find the preceding RaceDetectHookData and can
// manipulate the shadow bitmap.
//
// This method is only called from the race-detect version of the
// code, so RaceDetectHookData is never an empty struct.
func findPointerToRaceDetectData(ptr *uint32, num uint32) *RaceDetectHookData {
var template struct {
d RaceDetectHookData
a [1]uint32
}
o := (uintptr(unsafe.Pointer(&template.a)) - uintptr(unsafe.Pointer(&template.d))) + uintptr(num/32)*unsafe.Sizeof(uint32(0))
return (*RaceDetectHookData)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) - o))
}

View File

@ -0,0 +1,570 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"math"
"math/bits"
"reflect"
"unicode/utf8"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
)
// ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int
const (
// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
// The validator was unable to render a judgement.
//
// The only causes of this status are an aberrant message type appearing somewhere
// in the message or a failure in the extension resolver.
ValidationUnknown ValidationStatus = iota + 1
// ValidationInvalid indicates that unmarshaling the message will fail.
ValidationInvalid
// ValidationValid indicates that unmarshaling the message will succeed.
ValidationValid
// ValidationWrongWireType indicates that a validated field does not have
// the expected wire type.
ValidationWrongWireType
)
func (v ValidationStatus) String() string {
switch v {
case ValidationUnknown:
return "ValidationUnknown"
case ValidationInvalid:
return "ValidationInvalid"
case ValidationValid:
return "ValidationValid"
default:
return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
}
// Validate determines whether the contents of the buffer are a valid wire encoding
// of the message type.
//
// This function is exposed for testing.
func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
mi, ok := mt.(*MessageInfo)
if !ok {
return out, ValidationUnknown
}
if in.Resolver == nil {
in.Resolver = protoregistry.GlobalTypes
}
o, st := mi.validate(in.Buf, 0, unmarshalOptions{
flags: in.Flags,
resolver: in.Resolver,
})
if o.initialized {
out.Flags |= protoiface.UnmarshalInitialized
}
return out, st
}
type validationInfo struct {
mi *MessageInfo
typ validationType
keyType, valType validationType
// For non-required fields, requiredBit is 0.
//
// For required fields, requiredBit's nth bit is set, where n is a
// unique index in the range [0, MessageInfo.numRequiredFields).
//
// If there are more than 64 required fields, requiredBit is 0.
requiredBit uint64
}
type validationType uint8
const (
validationTypeOther validationType = iota
validationTypeMessage
validationTypeGroup
validationTypeMap
validationTypeRepeatedVarint
validationTypeRepeatedFixed32
validationTypeRepeatedFixed64
validationTypeVarint
validationTypeFixed32
validationTypeFixed64
validationTypeBytes
validationTypeUTF8String
validationTypeMessageSetItem
)
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
var vi validationInfo
switch {
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
switch fd.Kind() {
case protoreflect.MessageKind:
vi.typ = validationTypeMessage
if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
vi.mi = getMessageInfo(ot.Field(0).Type)
}
case protoreflect.GroupKind:
vi.typ = validationTypeGroup
if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
vi.mi = getMessageInfo(ot.Field(0).Type)
}
case protoreflect.StringKind:
if strs.EnforceUTF8(fd) {
vi.typ = validationTypeUTF8String
}
}
default:
vi = newValidationInfo(fd, ft)
}
if fd.Cardinality() == protoreflect.Required {
// Avoid overflow. The required field check is done with a 64-bit mask, with
// any message containing more than 64 required fields always reported as
// potentially uninitialized, so it is not important to get a precise count
// of the required fields past 64.
if mi.numRequiredFields < math.MaxUint8 {
mi.numRequiredFields++
vi.requiredBit = 1 << (mi.numRequiredFields - 1)
}
}
return vi
}
func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
var vi validationInfo
switch {
case fd.IsList():
switch fd.Kind() {
case protoreflect.MessageKind:
vi.typ = validationTypeMessage
if ft.Kind() == reflect.Ptr {
// Repeated opaque message fields are *[]*T.
ft = ft.Elem()
}
if ft.Kind() == reflect.Slice {
vi.mi = getMessageInfo(ft.Elem())
}
case protoreflect.GroupKind:
vi.typ = validationTypeGroup
if ft.Kind() == reflect.Ptr {
// Repeated opaque message fields are *[]*T.
ft = ft.Elem()
}
if ft.Kind() == reflect.Slice {
vi.mi = getMessageInfo(ft.Elem())
}
case protoreflect.StringKind:
vi.typ = validationTypeBytes
if strs.EnforceUTF8(fd) {
vi.typ = validationTypeUTF8String
}
default:
switch wireTypes[fd.Kind()] {
case protowire.VarintType:
vi.typ = validationTypeRepeatedVarint
case protowire.Fixed32Type:
vi.typ = validationTypeRepeatedFixed32
case protowire.Fixed64Type:
vi.typ = validationTypeRepeatedFixed64
}
}
case fd.IsMap():
vi.typ = validationTypeMap
switch fd.MapKey().Kind() {
case protoreflect.StringKind:
if strs.EnforceUTF8(fd) {
vi.keyType = validationTypeUTF8String
}
}
switch fd.MapValue().Kind() {
case protoreflect.MessageKind:
vi.valType = validationTypeMessage
if ft.Kind() == reflect.Map {
vi.mi = getMessageInfo(ft.Elem())
}
case protoreflect.StringKind:
if strs.EnforceUTF8(fd) {
vi.valType = validationTypeUTF8String
}
}
default:
switch fd.Kind() {
case protoreflect.MessageKind:
vi.typ = validationTypeMessage
vi.mi = getMessageInfo(ft)
case protoreflect.GroupKind:
vi.typ = validationTypeGroup
vi.mi = getMessageInfo(ft)
case protoreflect.StringKind:
vi.typ = validationTypeBytes
if strs.EnforceUTF8(fd) {
vi.typ = validationTypeUTF8String
}
default:
switch wireTypes[fd.Kind()] {
case protowire.VarintType:
vi.typ = validationTypeVarint
case protowire.Fixed32Type:
vi.typ = validationTypeFixed32
case protowire.Fixed64Type:
vi.typ = validationTypeFixed64
case protowire.BytesType:
vi.typ = validationTypeBytes
}
}
}
return vi
}
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
mi.init()
type validationState struct {
typ validationType
keyType, valType validationType
endGroup protowire.Number
mi *MessageInfo
tail []byte
requiredMask uint64
}
// Pre-allocate some slots to avoid repeated slice reallocation.
states := make([]validationState, 0, 16)
states = append(states, validationState{
typ: validationTypeMessage,
mi: mi,
})
if groupTag > 0 {
states[0].typ = validationTypeGroup
states[0].endGroup = groupTag
}
initialized := true
start := len(b)
State:
for len(states) > 0 {
st := &states[len(states)-1]
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
if b[0] < 0x80 {
tag = uint64(b[0])
b = b[1:]
} else if len(b) >= 2 && b[1] < 128 {
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
b = b[2:]
} else {
var n int
tag, n = protowire.ConsumeVarint(b)
if n < 0 {
return out, ValidationInvalid
}
b = b[n:]
}
var num protowire.Number
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
return out, ValidationInvalid
} else {
num = protowire.Number(n)
}
wtyp := protowire.Type(tag & 7)
if wtyp == protowire.EndGroupType {
if st.endGroup == num {
goto PopState
}
return out, ValidationInvalid
}
var vi validationInfo
switch {
case st.typ == validationTypeMap:
switch num {
case genid.MapEntry_Key_field_number:
vi.typ = st.keyType
case genid.MapEntry_Value_field_number:
vi.typ = st.valType
vi.mi = st.mi
vi.requiredBit = 1
}
case flags.ProtoLegacy && st.mi.isMessageSet:
switch num {
case messageset.FieldItem:
vi.typ = validationTypeMessageSetItem
}
default:
var f *coderFieldInfo
if int(num) < len(st.mi.denseCoderFields) {
f = st.mi.denseCoderFields[num]
} else {
f = st.mi.coderFields[num]
}
if f != nil {
vi = f.validation
break
}
// Possible extension field.
//
// TODO: We should return ValidationUnknown when:
// 1. The resolver is not frozen. (More extensions may be added to it.)
// 2. The resolver returns preg.NotFound.
// In this case, a type added to the resolver in the future could cause
// unmarshaling to begin failing. Supporting this requires some way to
// determine if the resolver is frozen.
xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
if err != nil && err != protoregistry.NotFound {
return out, ValidationUnknown
}
if err == nil {
vi = getExtensionFieldInfo(xt).validation
}
}
if vi.requiredBit != 0 {
// Check that the field has a compatible wire type.
// We only need to consider non-repeated field types,
// since repeated fields (and maps) can never be required.
ok := false
switch vi.typ {
case validationTypeVarint:
ok = wtyp == protowire.VarintType
case validationTypeFixed32:
ok = wtyp == protowire.Fixed32Type
case validationTypeFixed64:
ok = wtyp == protowire.Fixed64Type
case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
ok = wtyp == protowire.BytesType
case validationTypeGroup:
ok = wtyp == protowire.StartGroupType
}
if ok {
st.requiredMask |= vi.requiredBit
}
}
switch wtyp {
case protowire.VarintType:
if len(b) >= 10 {
switch {
case b[0] < 0x80:
b = b[1:]
case b[1] < 0x80:
b = b[2:]
case b[2] < 0x80:
b = b[3:]
case b[3] < 0x80:
b = b[4:]
case b[4] < 0x80:
b = b[5:]
case b[5] < 0x80:
b = b[6:]
case b[6] < 0x80:
b = b[7:]
case b[7] < 0x80:
b = b[8:]
case b[8] < 0x80:
b = b[9:]
case b[9] < 0x80 && b[9] < 2:
b = b[10:]
default:
return out, ValidationInvalid
}
} else {
switch {
case len(b) > 0 && b[0] < 0x80:
b = b[1:]
case len(b) > 1 && b[1] < 0x80:
b = b[2:]
case len(b) > 2 && b[2] < 0x80:
b = b[3:]
case len(b) > 3 && b[3] < 0x80:
b = b[4:]
case len(b) > 4 && b[4] < 0x80:
b = b[5:]
case len(b) > 5 && b[5] < 0x80:
b = b[6:]
case len(b) > 6 && b[6] < 0x80:
b = b[7:]
case len(b) > 7 && b[7] < 0x80:
b = b[8:]
case len(b) > 8 && b[8] < 0x80:
b = b[9:]
case len(b) > 9 && b[9] < 2:
b = b[10:]
default:
return out, ValidationInvalid
}
}
continue State
case protowire.BytesType:
var size uint64
if len(b) >= 1 && b[0] < 0x80 {
size = uint64(b[0])
b = b[1:]
} else if len(b) >= 2 && b[1] < 128 {
size = uint64(b[0]&0x7f) + uint64(b[1])<<7
b = b[2:]
} else {
var n int
size, n = protowire.ConsumeVarint(b)
if n < 0 {
return out, ValidationInvalid
}
b = b[n:]
}
if size > uint64(len(b)) {
return out, ValidationInvalid
}
v := b[:size]
b = b[size:]
switch vi.typ {
case validationTypeMessage:
if vi.mi == nil {
return out, ValidationUnknown
}
vi.mi.init()
fallthrough
case validationTypeMap:
if vi.mi != nil {
vi.mi.init()
}
states = append(states, validationState{
typ: vi.typ,
keyType: vi.keyType,
valType: vi.valType,
mi: vi.mi,
tail: b,
})
b = v
continue State
case validationTypeRepeatedVarint:
// Packed field.
for len(v) > 0 {
_, n := protowire.ConsumeVarint(v)
if n < 0 {
return out, ValidationInvalid
}
v = v[n:]
}
case validationTypeRepeatedFixed32:
// Packed field.
if len(v)%4 != 0 {
return out, ValidationInvalid
}
case validationTypeRepeatedFixed64:
// Packed field.
if len(v)%8 != 0 {
return out, ValidationInvalid
}
case validationTypeUTF8String:
if !utf8.Valid(v) {
return out, ValidationInvalid
}
}
case protowire.Fixed32Type:
if len(b) < 4 {
return out, ValidationInvalid
}
b = b[4:]
case protowire.Fixed64Type:
if len(b) < 8 {
return out, ValidationInvalid
}
b = b[8:]
case protowire.StartGroupType:
switch {
case vi.typ == validationTypeGroup:
if vi.mi == nil {
return out, ValidationUnknown
}
vi.mi.init()
states = append(states, validationState{
typ: validationTypeGroup,
mi: vi.mi,
endGroup: num,
})
continue State
case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
if err != nil {
return out, ValidationInvalid
}
xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
switch {
case err == protoregistry.NotFound:
b = b[n:]
case err != nil:
return out, ValidationUnknown
default:
xvi := getExtensionFieldInfo(xt).validation
if xvi.mi != nil {
xvi.mi.init()
}
states = append(states, validationState{
typ: xvi.typ,
mi: xvi.mi,
tail: b[n:],
})
b = v
continue State
}
default:
n := protowire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return out, ValidationInvalid
}
b = b[n:]
}
default:
return out, ValidationInvalid
}
}
if st.endGroup != 0 {
return out, ValidationInvalid
}
if len(b) != 0 {
return out, ValidationInvalid
}
b = st.tail
PopState:
numRequiredFields := 0
switch st.typ {
case validationTypeMessage, validationTypeGroup:
numRequiredFields = int(st.mi.numRequiredFields)
case validationTypeMap:
// If this is a map field with a message value that contains
// required fields, require that the value be present.
if st.mi != nil && st.mi.numRequiredFields > 0 {
numRequiredFields = 1
}
}
// If there are more than 64 required fields, this check will
// always fail and we will report that the message is potentially
// uninitialized.
if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
initialized = false
}
states = states[:len(states)-1]
}
out.n = start - len(b)
if initialized {
out.initialized = true
}
return out, ValidationValid
}