// Copyright 2024 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package impl import ( "fmt" "math/bits" "os" "reflect" "sort" "sync/atomic" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/protolazy" "google.golang.org/protobuf/reflect/protoreflect" preg "google.golang.org/protobuf/reflect/protoregistry" piface "google.golang.org/protobuf/runtime/protoiface" ) var enableLazy int32 = func() int32 { if os.Getenv("GOPROTODEBUG") == "nolazy" { return 0 } return 1 }() // EnableLazyUnmarshal enables lazy unmarshaling. func EnableLazyUnmarshal(enable bool) { if enable { atomic.StoreInt32(&enableLazy, 1) return } atomic.StoreInt32(&enableLazy, 0) } // LazyEnabled reports whether lazy unmarshalling is currently enabled. func LazyEnabled() bool { return atomic.LoadInt32(&enableLazy) != 0 } // UnmarshalField unmarshals a field in a message. func UnmarshalField(m interface{}, num protowire.Number) { switch m := m.(type) { case *messageState: m.messageInfo().lazyUnmarshal(m.pointer(), num) case *messageReflectWrapper: m.messageInfo().lazyUnmarshal(m.pointer(), num) default: panic(fmt.Sprintf("unsupported wrapper type %T", m)) } } func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) { var f *coderFieldInfo if int(num) < len(mi.denseCoderFields) { f = mi.denseCoderFields[num] } else { f = mi.coderFields[num] } if f == nil { panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num)) } lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr() start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num)) if !found && multipleEntries == nil { panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num)) } // The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races. // Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil. fp := pointerOfValue(reflect.New(f.ft)) if multipleEntries != nil { for _, entry := range multipleEntries { mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags()) } } else { mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags()) } p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem()) } func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error { opts := lazyUnmarshalOptions opts.flags |= flags for len(b) > 0 { // Parse the tag (field number and wire type). var tag uint64 if b[0] < 0x80 { tag = uint64(b[0]) b = b[1:] } else if len(b) >= 2 && b[1] < 128 { tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 b = b[2:] } else { var n int tag, n = protowire.ConsumeVarint(b) if n < 0 { return errors.New("invalid wire data") } b = b[n:] } var num protowire.Number if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { return errors.New("invalid wire data") } else { num = protowire.Number(n) } wtyp := protowire.Type(tag & 7) if num == f.num { o, err := f.funcs.unmarshal(b, p, wtyp, f, opts) if err == nil { b = b[o.n:] continue } if err != errUnknown { return err } } n := protowire.ConsumeFieldValue(num, wtyp, b) if n < 0 { return errors.New("invalid wire data") } b = b[n:] } return nil } func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { fmi := f.validation.mi if fmi == nil { fd := mi.Desc.Fields().ByNumber(f.num) if fd == nil || !fd.IsWeak() { return out, ValidationUnknown } messageName := fd.Message().FullName() messageType, err := preg.GlobalTypes.FindMessageByName(messageName) if err != nil { return out, ValidationUnknown } var ok bool fmi, ok = messageType.(*MessageInfo) if !ok { return out, ValidationUnknown } } fmi.init() switch f.validation.typ { case validationTypeMessage: if wtyp != protowire.BytesType { return out, ValidationWrongWireType } v, n := protowire.ConsumeBytes(b) if n < 0 { return out, ValidationInvalid } out, st := fmi.validate(v, 0, opts) out.n = n return out, st case validationTypeGroup: if wtyp != protowire.StartGroupType { return out, ValidationWrongWireType } out, st := fmi.validate(b, f.num, opts) return out, st default: return out, ValidationUnknown } } // unmarshalPointerLazy is similar to unmarshalPointerEager, but it // specifically handles lazy unmarshalling. it expects lazyOffset and // presenceOffset to both be valid. func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { initialized := true var requiredMask uint64 var lazy **protolazy.XXX_lazyUnmarshalInfo var presence presence var lazyIndex []protolazy.IndexEntry var lastNum protowire.Number outOfOrder := false lazyDecode := false presence = p.Apply(mi.presenceOffset).PresenceInfo() lazy = p.Apply(mi.lazyOffset).LazyInfoPtr() if !presence.AnyPresent(mi.presenceSize) { if opts.CanBeLazy() { // If the message contains existing data, we need to merge into it. // Lazy unmarshaling doesn't merge, so only enable it when the // message is empty (has no presence bitmap). lazyDecode = true if *lazy == nil { *lazy = &protolazy.XXX_lazyUnmarshalInfo{} } (*lazy).SetUnmarshalFlags(opts.flags) if !opts.AliasBuffer() { // Make a copy of the buffer for lazy unmarshaling. // Set the AliasBuffer flag so recursive unmarshal // operations reuse the copy. b = append([]byte{}, b...) opts.flags |= piface.UnmarshalAliasBuffer } (*lazy).SetBuffer(b) } } // Track special handling of lazy fields. // // In the common case, all fields are lazyValidateOnly (and lazyFields remains nil). // In the event that validation for a field fails, this map tracks handling of the field. type lazyAction uint8 const ( lazyValidateOnly lazyAction = iota // validate the field only lazyUnmarshalNow // eagerly unmarshal the field lazyUnmarshalLater // unmarshal the field after the message is fully processed ) var lazyFields map[*coderFieldInfo]lazyAction var exts *map[int32]ExtensionField start := len(b) pos := 0 for len(b) > 0 { // Parse the tag (field number and wire type). var tag uint64 if b[0] < 0x80 { tag = uint64(b[0]) b = b[1:] } else if len(b) >= 2 && b[1] < 128 { tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 b = b[2:] } else { var n int tag, n = protowire.ConsumeVarint(b) if n < 0 { return out, errDecode } b = b[n:] } var num protowire.Number if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { return out, errors.New("invalid field number") } else { num = protowire.Number(n) } wtyp := protowire.Type(tag & 7) if wtyp == protowire.EndGroupType { if num != groupTag { return out, errors.New("mismatching end group marker") } groupTag = 0 break } var f *coderFieldInfo if int(num) < len(mi.denseCoderFields) { f = mi.denseCoderFields[num] } else { f = mi.coderFields[num] } var n int err := errUnknown discardUnknown := false Field: switch { case f != nil: if f.funcs.unmarshal == nil { break } if f.isLazy && lazyDecode { switch { case lazyFields == nil || lazyFields[f] == lazyValidateOnly: // Attempt to validate this field and leave it for later lazy unmarshaling. o, valid := mi.skipField(b, f, wtyp, opts) switch valid { case ValidationValid: // Skip over the valid field and continue. err = nil presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) requiredMask |= f.validation.requiredBit if !o.initialized { initialized = false } n = o.n break Field case ValidationInvalid: return out, errors.New("invalid proto wire format") case ValidationWrongWireType: break Field case ValidationUnknown: if lazyFields == nil { lazyFields = make(map[*coderFieldInfo]lazyAction) } if presence.Present(f.presenceIndex) { // We were unable to determine if the field is valid or not, // and we've already skipped over at least one instance of this // field. Clear the presence bit (so if we stop decoding early, // we don't leave a partially-initialized field around) and flag // the field for unmarshaling before we return. presence.ClearPresent(f.presenceIndex) lazyFields[f] = lazyUnmarshalLater discardUnknown = true break Field } else { // We were unable to determine if the field is valid or not, // but this is the first time we've seen it. Flag it as needing // eager unmarshaling and fall through to the eager unmarshal case below. lazyFields[f] = lazyUnmarshalNow } } case lazyFields[f] == lazyUnmarshalLater: // This field will be unmarshaled in a separate pass below. // Skip over it here. discardUnknown = true break Field default: // Eagerly unmarshal the field. } } if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) { if p.Apply(f.offset).AtomicGetPointer().IsNil() { mi.lazyUnmarshal(p, f.num) } } var o unmarshalOutput o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) n = o.n if err != nil { break } requiredMask |= f.validation.requiredBit if f.funcs.isInit != nil && !o.initialized { initialized = false } if f.presenceIndex != noPresence { presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) } default: // Possible extension. if exts == nil && mi.extensionOffset.IsValid() { exts = p.Apply(mi.extensionOffset).Extensions() if *exts == nil { *exts = make(map[int32]ExtensionField) } } if exts == nil { break } var o unmarshalOutput o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) if err != nil { break } n = o.n if !o.initialized { initialized = false } } if err != nil { if err != errUnknown { return out, err } n = protowire.ConsumeFieldValue(num, wtyp, b) if n < 0 { return out, errDecode } if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { u := mi.mutableUnknownBytes(p) *u = protowire.AppendTag(*u, num, wtyp) *u = append(*u, b[:n]...) } } b = b[n:] end := start - len(b) if lazyDecode && f != nil && f.isLazy { if num != lastNum { lazyIndex = append(lazyIndex, protolazy.IndexEntry{ FieldNum: uint32(num), Start: uint32(pos), End: uint32(end), }) } else { i := len(lazyIndex) - 1 lazyIndex[i].End = uint32(end) lazyIndex[i].MultipleContiguous = true } } if num < lastNum { outOfOrder = true } pos = end lastNum = num } if groupTag != 0 { return out, errors.New("missing end group marker") } if lazyFields != nil { // Some fields failed validation, and now need to be unmarshaled. for f, action := range lazyFields { if action != lazyUnmarshalLater { continue } initialized = false if *lazy == nil { *lazy = &protolazy.XXX_lazyUnmarshalInfo{} } if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil { return out, err } presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) } } if lazyDecode { if outOfOrder { sort.Slice(lazyIndex, func(i, j int) bool { return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum || (lazyIndex[i].FieldNum == lazyIndex[j].FieldNum && lazyIndex[i].Start < lazyIndex[j].Start) }) } if *lazy == nil { *lazy = &protolazy.XXX_lazyUnmarshalInfo{} } (*lazy).SetIndex(lazyIndex) } if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { initialized = false } if initialized { out.initialized = true } out.n = start - len(b) return out, nil }