// Copyright 2019 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package impl import ( "fmt" "math" "math/bits" "reflect" "unicode/utf8" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/internal/genid" "google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/runtime/protoiface" ) // ValidationStatus is the result of validating the wire-format encoding of a message. type ValidationStatus int const ( // ValidationUnknown indicates that unmarshaling the message might succeed or fail. // The validator was unable to render a judgement. // // The only causes of this status are an aberrant message type appearing somewhere // in the message or a failure in the extension resolver. ValidationUnknown ValidationStatus = iota + 1 // ValidationInvalid indicates that unmarshaling the message will fail. ValidationInvalid // ValidationValid indicates that unmarshaling the message will succeed. ValidationValid // ValidationWrongWireType indicates that a validated field does not have // the expected wire type. ValidationWrongWireType ) func (v ValidationStatus) String() string { switch v { case ValidationUnknown: return "ValidationUnknown" case ValidationInvalid: return "ValidationInvalid" case ValidationValid: return "ValidationValid" default: return fmt.Sprintf("ValidationStatus(%d)", int(v)) } } // Validate determines whether the contents of the buffer are a valid wire encoding // of the message type. // // This function is exposed for testing. 
func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
	// Only message types implemented by *MessageInfo can be validated here;
	// any other implementation gets no judgement.
	mi, ok := mt.(*MessageInfo)
	if !ok {
		return out, ValidationUnknown
	}
	// Extension lookups fall back to the global registry.
	if in.Resolver == nil {
		in.Resolver = protoregistry.GlobalTypes
	}
	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
		flags:    in.Flags,
		resolver: in.Resolver,
	})
	// Propagate the required-field check performed by validate.
	if o.initialized {
		out.Flags |= protoiface.UnmarshalInitialized
	}
	return out, st
}

// validationInfo holds the per-field information that validate uses to check
// a field's wire encoding.
type validationInfo struct {
	// mi is the message type of a message, group, or map-value field. It may
	// be nil when the type is not known; see validate for how that is handled.
	mi *MessageInfo
	// typ selects which checks validate applies to the field's payload.
	typ validationType
	// keyType and valType describe map entry keys and values; they are only
	// consulted when typ is validationTypeMap.
	keyType, valType validationType

	// For non-required fields, requiredBit is 0.
	//
	// For required fields, requiredBit's nth bit is set, where n is a
	// unique index in the range [0, MessageInfo.numRequiredFields).
	//
	// If there are more than 64 required fields, requiredBit is 0.
	requiredBit uint64
}

// validationType classifies how validate checks a field's encoded value.
// validationTypeOther applies no type-specific checks: the value is only
// skipped according to its wire type.
type validationType uint8

const (
	validationTypeOther validationType = iota
	validationTypeMessage
	validationTypeGroup
	validationTypeMap
	// The repeated scalar types are checked as packed payloads when they
	// arrive with the bytes wire type.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes
	// validationTypeUTF8String is a bytes payload that must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem is a legacy MessageSet item group.
	validationTypeMessageSetItem
)

// newFieldValidationInfo returns the validationInfo for field fd of the message
// described by mi. ft is the reflect.Type used to locate the MessageInfo of
// message-typed fields, and si supplies the oneof wrapper types.
//
// Fields of a non-synthetic oneof are resolved through the oneof wrapper types
// in si; all other fields are delegated to newValidationInfo. Each required
// field is assigned the next bit of the required-field mask, which increments
// mi.numRequiredFields as a side effect.
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
	var vi validationInfo
	switch {
	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
		// For oneof members, the message type is taken from the first field
		// of the registered oneof wrapper type, when there is one.
		switch fd.Kind() {
		case protoreflect.MessageKind:
			vi.typ = validationTypeMessage
			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
				vi.mi = getMessageInfo(ot.Field(0).Type)
			}
		case protoreflect.GroupKind:
			vi.typ = validationTypeGroup
			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
				vi.mi = getMessageInfo(ot.Field(0).Type)
			}
		case protoreflect.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		}
	default:
		vi = newValidationInfo(fd, ft)
	}
	if fd.Cardinality() == protoreflect.Required {
		// Avoid overflow. The required field check is done with a 64-bit mask, with
		// any message containing more than 64 required fields always reported as
		// potentially uninitialized, so it is not important to get a precise count
		// of the required fields past 64.
		if mi.numRequiredFields < math.MaxUint8 {
			mi.numRequiredFields++
			// requiredBit is a uint64, so shifting by 64 or more yields 0:
			// fields past the 64th get no bit.
			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
		}
	}
	return vi
}

// newValidationInfo derives a validationInfo from fd's shape (list, map, or
// singular) and kind. ft is the reflect.Type of the field's Go representation;
// for message lists it is handled as a slice or a pointer to a slice, and for
// maps as a map type, in order to find the element's MessageInfo.
func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
	var vi validationInfo
	switch {
	case fd.IsList():
		switch fd.Kind() {
		case protoreflect.MessageKind:
			vi.typ = validationTypeMessage
			if ft.Kind() == reflect.Ptr {
				// Repeated opaque message fields are *[]*T.
				ft = ft.Elem()
			}
			if ft.Kind() == reflect.Slice {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case protoreflect.GroupKind:
			vi.typ = validationTypeGroup
			if ft.Kind() == reflect.Ptr {
				// Repeated opaque message fields are *[]*T.
				ft = ft.Elem()
			}
			if ft.Kind() == reflect.Slice {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case protoreflect.StringKind:
			vi.typ = validationTypeBytes
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		default:
			// Other repeated scalars may be packed; record their element
			// wire type so packed payloads can be checked.
			switch wireTypes[fd.Kind()] {
			case protowire.VarintType:
				vi.typ = validationTypeRepeatedVarint
			case protowire.Fixed32Type:
				vi.typ = validationTypeRepeatedFixed32
			case protowire.Fixed64Type:
				vi.typ = validationTypeRepeatedFixed64
			}
		}
	case fd.IsMap():
		vi.typ = validationTypeMap
		// Only string keys and values that need UTF-8 checking, and message
		// values, get a specific type; everything else stays validationTypeOther.
		switch fd.MapKey().Kind() {
		case protoreflect.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.keyType = validationTypeUTF8String
			}
		}
		switch fd.MapValue().Kind() {
		case protoreflect.MessageKind:
			vi.valType = validationTypeMessage
			if ft.Kind() == reflect.Map {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case protoreflect.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.valType = validationTypeUTF8String
			}
		}
	default:
		switch fd.Kind() {
		case protoreflect.MessageKind:
			vi.typ = validationTypeMessage
			// Weak fields leave mi nil; validate resolves them by name.
			if !fd.IsWeak() {
				vi.mi = getMessageInfo(ft)
			}
		case protoreflect.GroupKind:
			vi.typ = validationTypeGroup
			vi.mi = getMessageInfo(ft)
		case protoreflect.StringKind:
			vi.typ = validationTypeBytes
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		default:
			switch wireTypes[fd.Kind()] {
			case protowire.VarintType:
				vi.typ = validationTypeVarint
			case protowire.Fixed32Type:
				vi.typ = validationTypeFixed32
			case protowire.Fixed64Type:
				vi.typ = validationTypeFixed64
			case protowire.BytesType:
				vi.typ = validationTypeBytes
			}
		}
	}
	return vi
}

// validate reports whether b is a valid wire encoding of the message mi.
//
// If groupTag is positive, b is treated as the contents of a group that must be
// terminated by an END_GROUP tag with that field number.
//
// Nested messages, groups, and map entries are validated iteratively using an
// explicit stack (states) instead of recursion. On success, out.n is the number
// of bytes consumed, and out.initialized reports whether every required field
// was seen with a compatible wire type. Messages with more than 64 required
// fields are always reported as not initialized.
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
	mi.init()
	// validationState is one entry of the explicit parse stack.
	type validationState struct {
		typ              validationType
		keyType, valType validationType
		endGroup         protowire.Number // nonzero for groups: the field number of the expected END_GROUP tag
		mi               *MessageInfo
		tail             []byte // input that follows this nested message in its parent
		requiredMask     uint64 // required-field bits observed so far
	}

	// Pre-allocate some slots to avoid repeated slice reallocation.
	states := make([]validationState, 0, 16)
	states = append(states, validationState{
		typ: validationTypeMessage,
		mi:  mi,
	})
	if groupTag > 0 {
		states[0].typ = validationTypeGroup
		states[0].endGroup = groupTag
	}
	initialized := true
	start := len(b)
	// Each pass of the outer loop works on the innermost pending state. Every
	// push is followed by "continue State" so that st is re-read, because
	// append may have moved the states slice.
State:
	for len(states) > 0 {
		st := &states[len(states)-1]
		for len(b) > 0 {
			// Parse the tag (field number and wire type).
			// One- and two-byte tags are decoded inline; longer ones use ConsumeVarint.
			var tag uint64
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = protowire.ConsumeVarint(b)
				if n < 0 {
					return out, ValidationInvalid
				}
				b = b[n:]
			}
			var num protowire.Number
			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
				return out, ValidationInvalid
			} else {
				num = protowire.Number(n)
			}
			wtyp := protowire.Type(tag & 7)

			// An END_GROUP tag is only valid if it closes the group being parsed.
			if wtyp == protowire.EndGroupType {
				if st.endGroup == num {
					goto PopState
				}
				return out, ValidationInvalid
			}

			// Determine how field num is validated in the current context.
			var vi validationInfo
			switch {
			case st.typ == validationTypeMap:
				switch num {
				case genid.MapEntry_Key_field_number:
					vi.typ = st.keyType
				case genid.MapEntry_Value_field_number:
					vi.typ = st.valType
					vi.mi = st.mi
					// The value is tracked as a required field; PopState only
					// enforces it when the value message has required fields.
					vi.requiredBit = 1
				}
			case flags.ProtoLegacy && st.mi.isMessageSet:
				// Only item groups are recognized in a legacy MessageSet; other
				// fields keep validationTypeOther and are merely skipped.
				switch num {
				case messageset.FieldItem:
					vi.typ = validationTypeMessageSetItem
				}
			default:
				// Look the field up, using the dense table for small field numbers.
				var f *coderFieldInfo
				if int(num) < len(st.mi.denseCoderFields) {
					f = st.mi.denseCoderFields[num]
				} else {
					f = st.mi.coderFields[num]
				}
				if f != nil {
					vi = f.validation
					if vi.typ == validationTypeMessage && vi.mi == nil {
						// Probable weak field.
						//
						// TODO: Consider storing the results of this lookup somewhere
						// rather than recomputing it on every validation.
						fd := st.mi.Desc.Fields().ByNumber(num)
						if fd == nil || !fd.IsWeak() {
							break
						}
						messageName := fd.Message().FullName()
						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
						switch err {
						case nil:
							vi.mi, _ = messageType.(*MessageInfo)
						case protoregistry.NotFound:
							// Message type not registered: check the payload only as bytes.
							vi.typ = validationTypeBytes
						default:
							return out, ValidationUnknown
						}
					}
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns preg.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
				if err != nil && err != protoregistry.NotFound {
					return out, ValidationUnknown
				}
				// On NotFound the zero validationInfo is kept and the field is only skipped.
				if err == nil {
					vi = getExtensionFieldInfo(xt).validation
				}
			}
			if vi.requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				// A required field with an incompatible wire type does not
				// set its bit in the mask.
				ok := false
				switch vi.typ {
				case validationTypeVarint:
					ok = wtyp == protowire.VarintType
				case validationTypeFixed32:
					ok = wtyp == protowire.Fixed32Type
				case validationTypeFixed64:
					ok = wtyp == protowire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					ok = wtyp == protowire.BytesType
				case validationTypeGroup:
					ok = wtyp == protowire.StartGroupType
				}
				if ok {
					st.requiredMask |= vi.requiredBit
				}
			}

			switch wtyp {
			case protowire.VarintType:
				// Skip a varint of at most 10 bytes. The 10th byte only carries
				// the top bit of a 64-bit value, so it must be 0 or 1.
				if len(b) >= 10 {
					switch {
					case b[0] < 0x80:
						b = b[1:]
					case b[1] < 0x80:
						b = b[2:]
					case b[2] < 0x80:
						b = b[3:]
					case b[3] < 0x80:
						b = b[4:]
					case b[4] < 0x80:
						b = b[5:]
					case b[5] < 0x80:
						b = b[6:]
					case b[6] < 0x80:
						b = b[7:]
					case b[7] < 0x80:
						b = b[8:]
					case b[8] < 0x80:
						b = b[9:]
					case b[9] < 0x80 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				} else {
					// Fewer than 10 bytes remain: bounds-check each position.
					switch {
					case len(b) > 0 && b[0] < 0x80:
						b = b[1:]
					case len(b) > 1 && b[1] < 0x80:
						b = b[2:]
					case len(b) > 2 && b[2] < 0x80:
						b = b[3:]
					case len(b) > 3 && b[3] < 0x80:
						b = b[4:]
					case len(b) > 4 && b[4] < 0x80:
						b = b[5:]
					case len(b) > 5 && b[5] < 0x80:
						b = b[6:]
					case len(b) > 6 && b[6] < 0x80:
						b = b[7:]
					case len(b) > 7 && b[7] < 0x80:
						b = b[8:]
					case len(b) > 8 && b[8] < 0x80:
						b = b[9:]
					case len(b) > 9 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				}
				continue State
			case protowire.BytesType:
				// Read the length prefix (one- and two-byte lengths inline).
				var size uint64
				if len(b) >= 1 && b[0] < 0x80 {
					size = uint64(b[0])
					b = b[1:]
				} else if len(b) >= 2 && b[1] < 128 {
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
					b = b[2:]
				} else {
					var n int
					size, n = protowire.ConsumeVarint(b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
				if size > uint64(len(b)) {
					return out, ValidationInvalid
				}
				// v is the payload; b now points just past it.
				v := b[:size]
				b = b[size:]
				switch vi.typ {
				case validationTypeMessage:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					fallthrough
				case validationTypeMap:
					if vi.mi != nil {
						vi.mi.init()
					}
					// Descend into the nested message or map entry: validate v
					// next, then resume the parent at tail.
					states = append(states, validationState{
						typ:     vi.typ,
						keyType: vi.keyType,
						valType: vi.valType,
						mi:      vi.mi,
						tail:    b,
					})
					b = v
					continue State
				case validationTypeRepeatedVarint:
					// Packed field.
					for len(v) > 0 {
						_, n := protowire.ConsumeVarint(v)
						if n < 0 {
							return out, ValidationInvalid
						}
						v = v[n:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					if len(v)%4 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					if len(v)%8 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid(v) {
						return out, ValidationInvalid
					}
				}
			case protowire.Fixed32Type:
				if len(b) < 4 {
					return out, ValidationInvalid
				}
				b = b[4:]
			case protowire.Fixed64Type:
				if len(b) < 8 {
					return out, ValidationInvalid
				}
				b = b[8:]
			case protowire.StartGroupType:
				switch {
				case vi.typ == validationTypeGroup:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					// Descend into the group. It ends at the matching END_GROUP
					// tag, so no tail is recorded and b continues in place.
					states = append(states, validationState{
						typ:      validationTypeGroup,
						mi:       vi.mi,
						endGroup: num,
					})
					continue State
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
					// Legacy MessageSet item: the whole item group is consumed
					// (n bytes). Unknown type IDs are skipped; known ones have
					// their payload v validated as the extension's type.
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
					if err != nil {
						return out, ValidationInvalid
					}
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
					switch {
					case err == protoregistry.NotFound:
						b = b[n:]
					case err != nil:
						return out, ValidationUnknown
					default:
						xvi := getExtensionFieldInfo(xt).validation
						if xvi.mi != nil {
							xvi.mi.init()
						}
						states = append(states, validationState{
							typ:  xvi.typ,
							mi:   xvi.mi,
							tail: b[n:],
						})
						b = v
						continue State
					}
				default:
					// Any other field with the START_GROUP wire type is skipped as a whole.
					n := protowire.ConsumeFieldValue(num, wtyp, b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
			default:
				return out, ValidationInvalid
			}
		}
		// The input for this state is exhausted. A group must have ended with
		// END_GROUP (via goto PopState), so reaching here inside a group is an error.
		if st.endGroup != 0 {
			return out, ValidationInvalid
		}
		// Defensive: the loop above only exits normally once b is empty.
		if len(b) != 0 {
			return out, ValidationInvalid
		}
		// Resume the parent at the bytes following this nested message.
		b = st.tail
	PopState:
		// Groups jump here directly on their END_GROUP tag, leaving b just past
		// that tag. Check the required fields of the state being popped.
		numRequiredFields := 0
		switch st.typ {
		case validationTypeMessage, validationTypeGroup:
			numRequiredFields = int(st.mi.numRequiredFields)
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if st.mi != nil && st.mi.numRequiredFields > 0 {
				numRequiredFields = 1
			}
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
			initialized = false
		}
		states = states[:len(states)-1]
	}
	out.n = start - len(b)
	if initialized {
		out.initialized = true
	}
	return out, ValidationValid
}