rebase: update kubernetes to 1.28.0 in main

updating kubernetes to 1.28.0
in the main repo.

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
This commit is contained in:
Madhu Rajanna 2023-08-17 07:15:28 +02:00 committed by mergify[bot]
parent b2fdc269c3
commit ff3e84ad67
706 changed files with 45252 additions and 16346 deletions

114
go.mod
View File

@ -35,15 +35,15 @@ require (
//
// when updating k8s.io/kubernetes, make sure to update the replace section too
//
k8s.io/api v0.27.4
k8s.io/apimachinery v0.27.4
k8s.io/api v0.28.0
k8s.io/apimachinery v0.28.0
k8s.io/client-go v12.0.0+incompatible
k8s.io/cloud-provider v0.27.4
k8s.io/cloud-provider v0.28.0
k8s.io/klog/v2 v2.100.1
k8s.io/kubernetes v1.27.4
k8s.io/mount-utils v0.27.4
k8s.io/kubernetes v1.28.0
k8s.io/mount-utils v0.28.0
k8s.io/pod-security-admission v0.0.0
k8s.io/utils v0.0.0-20230209194617-a36077c30491
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
sigs.k8s.io/controller-runtime v0.15.1
)
@ -51,7 +51,7 @@ require (
github.com/NYTimes/gziphandler v1.1.1 // indirect
github.com/ansel1/merry v1.6.2 // indirect
github.com/ansel1/merry/v2 v2.0.1 // indirect
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 // indirect
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect
github.com/armon/go-metrics v0.3.10 // indirect
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect
github.com/aws/aws-sdk-go-v2 v1.20.1 // indirect
@ -62,14 +62,14 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cenkalti/backoff/v3 v3.2.2 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/coreos/go-semver v0.3.0 // indirect
github.com/coreos/go-systemd/v22 v22.4.0 // indirect
github.com/coreos/go-semver v0.3.1 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/distribution v2.8.2+incompatible // indirect
github.com/emicklei/go-restful/v3 v3.9.0 // indirect
github.com/evanphx/json-patch v4.12.0+incompatible // indirect
github.com/evanphx/json-patch v5.6.0+incompatible // indirect
github.com/evanphx/json-patch/v5 v5.6.0 // indirect
github.com/fatih/color v1.13.0 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
@ -81,14 +81,14 @@ require (
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.1 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/swag v0.22.3 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/cel-go v0.12.6 // indirect
github.com/google/gnostic v0.6.9 // indirect
github.com/google/cel-go v0.16.0 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
@ -108,7 +108,7 @@ require (
github.com/hashicorp/vault v1.11.11 // indirect
github.com/hashicorp/vault/sdk v0.7.0 // indirect
github.com/imdario/mergo v0.3.13 // indirect
github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@ -131,15 +131,15 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.42.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.1 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/spf13/cobra v1.6.0 // indirect
github.com/spf13/cobra v1.7.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stoewer/go-strcase v1.2.0 // indirect
go.etcd.io/etcd/api/v3 v3.5.7 // indirect
go.etcd.io/etcd/client/pkg/v3 v3.5.7 // indirect
go.etcd.io/etcd/client/v3 v3.5.7 // indirect
go.etcd.io/etcd/api/v3 v3.5.9 // indirect
go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect
go.etcd.io/etcd/client/v3 v3.5.9 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.35.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.35.1 // indirect
go.opentelemetry.io/otel v1.10.0 // indirect
@ -151,31 +151,31 @@ require (
go.opentelemetry.io/otel/trace v1.10.0 // indirect
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/multierr v1.8.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/oauth2 v0.7.0 // indirect
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/oauth2 v0.8.0 // indirect
golang.org/x/sync v0.2.0 // indirect
golang.org/x/term v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.9.3 // indirect
gomodules.xyz/jsonpatch/v2 v2.3.0 // indirect
google.golang.org/api v0.110.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apiextensions-apiserver v0.27.4 // indirect
k8s.io/apiserver v0.27.4 // indirect
k8s.io/component-base v0.27.4 // indirect
k8s.io/component-helpers v0.27.4 // indirect
k8s.io/controller-manager v0.27.4 // indirect
k8s.io/kms v0.27.4 // indirect
k8s.io/kube-openapi v0.0.0-20230501164219-8b0f38b5fd1f // indirect
k8s.io/apiextensions-apiserver v0.28.0 // indirect
k8s.io/apiserver v0.28.0 // indirect
k8s.io/component-base v0.28.0 // indirect
k8s.io/component-helpers v0.28.0 // indirect
k8s.io/controller-manager v0.28.0 // indirect
k8s.io/kms v0.28.0 // indirect
k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 // indirect
k8s.io/kubectl v0.0.0 // indirect
k8s.io/kubelet v0.0.0 // indirect
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.1.2 // indirect
@ -194,32 +194,32 @@ replace (
//
// k8s.io/kubernetes depends on these k8s.io packages, but unversioned
//
k8s.io/api => k8s.io/api v0.27.4
k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.27.4
k8s.io/apimachinery => k8s.io/apimachinery v0.27.4
k8s.io/apiserver => k8s.io/apiserver v0.27.4
k8s.io/cli-runtime => k8s.io/cli-runtime v0.27.4
k8s.io/client-go => k8s.io/client-go v0.27.4
k8s.io/cloud-provider => k8s.io/cloud-provider v0.27.4
k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.27.4
k8s.io/code-generator => k8s.io/code-generator v0.27.4
k8s.io/component-base => k8s.io/component-base v0.27.4
k8s.io/component-helpers => k8s.io/component-helpers v0.27.4
k8s.io/controller-manager => k8s.io/controller-manager v0.27.4
k8s.io/cri-api => k8s.io/cri-api v0.27.4
k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.27.4
k8s.io/dynamic-resource-allocation => k8s.io/dynamic-resource-allocation v0.27.4
k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.27.4
k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.27.4
k8s.io/kube-proxy => k8s.io/kube-proxy v0.27.4
k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.27.4
k8s.io/kubectl => k8s.io/kubectl v0.27.4
k8s.io/kubelet => k8s.io/kubelet v0.27.4
k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.27.4
k8s.io/metrics => k8s.io/metrics v0.27.4
k8s.io/mount-utils => k8s.io/mount-utils v0.27.4
k8s.io/pod-security-admission => k8s.io/pod-security-admission v0.27.4
k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.27.4
k8s.io/api => k8s.io/api v0.28.0
k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.28.0
k8s.io/apimachinery => k8s.io/apimachinery v0.28.0
k8s.io/apiserver => k8s.io/apiserver v0.28.0
k8s.io/cli-runtime => k8s.io/cli-runtime v0.28.0
k8s.io/client-go => k8s.io/client-go v0.28.0
k8s.io/cloud-provider => k8s.io/cloud-provider v0.28.0
k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.28.0
k8s.io/code-generator => k8s.io/code-generator v0.28.0
k8s.io/component-base => k8s.io/component-base v0.28.0
k8s.io/component-helpers => k8s.io/component-helpers v0.28.0
k8s.io/controller-manager => k8s.io/controller-manager v0.28.0
k8s.io/cri-api => k8s.io/cri-api v0.28.0
k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.28.0
k8s.io/dynamic-resource-allocation => k8s.io/dynamic-resource-allocation v0.28.0
k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.28.0
k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.28.0
k8s.io/kube-proxy => k8s.io/kube-proxy v0.28.0
k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.28.0
k8s.io/kubectl => k8s.io/kubectl v0.28.0
k8s.io/kubelet => k8s.io/kubelet v0.28.0
k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.28.0
k8s.io/metrics => k8s.io/metrics v0.28.0
k8s.io/mount-utils => k8s.io/mount-utils v0.28.0
k8s.io/pod-security-admission => k8s.io/pod-security-admission v0.28.0
k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.28.0
// layeh.com seems to be misbehaving
layeh.com/radius => github.com/layeh/radius v0.0.0-20190322222518-890bc1058917
)

1095
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,68 @@
/*
Package antlr implements the Go version of the ANTLR 4 runtime.
# The ANTLR Tool
ANTLR (ANother Tool for Language Recognition) is a powerful parser generator for reading, processing, executing,
or translating structured text or binary files. It's widely used to build languages, tools, and frameworks.
From a grammar, ANTLR generates a parser that can build parse trees and also generates a listener interface
(or visitor) that makes it easy to respond to the recognition of phrases of interest.
# Code Generation
ANTLR supports the generation of code in a number of [target languages], and the generated code is supported by a
runtime library, written specifically to support the generated code in the target language. This library is the
runtime for the Go target.
To generate code for the go target, it is generally recommended to place the source grammar files in a package of
their own, and use the `.sh` script method of generating code, using the go generate directive. In that same directory
it is usual, though not required, to place the antlr tool that should be used to generate the code. That does mean
that the antlr tool JAR file will be checked in to your source code control though, so you are free to use any other
way of specifying the version of the ANTLR tool to use, such as aliasing in `.zshrc` or equivalent, or a profile in
your IDE, or configuration in your CI system.
Here is a general template for an ANTLR based recognizer in Go:
.
myproject
parser
mygrammar.g4
antlr-4.12.0-complete.jar
error_listeners.go
generate.go
generate.sh
go.mod
go.sum
main.go
main_test.go
Make sure that the package statement in your grammar file(s) reflects the go package they exist in.
The generate.go file then looks like this:
package parser
//go:generate ./generate.sh
And the generate.sh file will look similar to this:
#!/bin/sh
alias antlr4='java -Xmx500M -cp "./antlr4-4.12.0-complete.jar:$CLASSPATH" org.antlr.v4.Tool'
antlr4 -Dlanguage=Go -no-visitor -package parser *.g4
depending on whether you want visitors or listeners or any other ANTLR options.
From the command line at the root of your package myproject you can then simply issue the command:
go generate ./...
# Copyright Notice
Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
Use of this file is governed by the BSD 3-clause license, which can be found in the [LICENSE.txt] file in the project root.
[target languages]: https://github.com/antlr/antlr4/tree/master/runtime
[LICENSE.txt]: https://github.com/antlr/antlr4/blob/master/LICENSE.txt
*/
package antlr

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -6,11 +6,24 @@ package antlr
import "sync"
// ATNInvalidAltNumber is used to represent an ALT number that has yet to be calculated or
// which is invalid for a particular struct such as [*antlr.BaseRuleContext]
var ATNInvalidAltNumber int
// ATN represents an “[Augmented Transition Network]”, though general in ANTLR the term
// “Augmented Recursive Transition Network” though there are some descriptions of “[Recursive Transition Network]”
// in existence.
//
// ATNs represent the main networks in the system and are serialized by the code generator and support [ALL(*)].
//
// [Augmented Transition Network]: https://en.wikipedia.org/wiki/Augmented_transition_network
// [ALL(*)]: https://www.antlr.org/papers/allstar-techreport.pdf
// [Recursive Transition Network]: https://en.wikipedia.org/wiki/Recursive_transition_network
type ATN struct {
// DecisionToState is the decision points for all rules, subrules, optional
// blocks, ()+, ()*, etc. Used to build DFA predictors for them.
// blocks, ()+, ()*, etc. Each subrule/rule is a decision point, and we must track them so we
// can go back later and build DFA predictors for them. This includes
// all the rules, subrules, optional blocks, ()+, ()* etc...
DecisionToState []DecisionState
// grammarType is the ATN type and is used for deserializing ATNs from strings.
@ -45,6 +58,8 @@ type ATN struct {
edgeMu sync.RWMutex
}
// NewATN returns a new ATN struct representing the given grammarType and is used
// for runtime deserialization of ATNs from the code generated by the ANTLR tool
func NewATN(grammarType int, maxTokenType int) *ATN {
return &ATN{
grammarType: grammarType,
@ -53,7 +68,7 @@ func NewATN(grammarType int, maxTokenType int) *ATN {
}
}
// NextTokensInContext computes the set of valid tokens that can occur starting
// NextTokensInContext computes and returns the set of valid tokens that can occur starting
// in state s. If ctx is nil, the set of tokens will not include what can follow
// the rule surrounding s. In other words, the set will be restricted to tokens
// reachable staying within the rule of s.
@ -61,8 +76,8 @@ func (a *ATN) NextTokensInContext(s ATNState, ctx RuleContext) *IntervalSet {
return NewLL1Analyzer(a).Look(s, nil, ctx)
}
// NextTokensNoContext computes the set of valid tokens that can occur starting
// in s and staying in same rule. Token.EPSILON is in set if we reach end of
// NextTokensNoContext computes and returns the set of valid tokens that can occur starting
// in state s and staying in same rule. [antlr.Token.EPSILON] is in set if we reach end of
// rule.
func (a *ATN) NextTokensNoContext(s ATNState) *IntervalSet {
a.mu.Lock()
@ -76,6 +91,8 @@ func (a *ATN) NextTokensNoContext(s ATNState) *IntervalSet {
return iset
}
// NextTokens computes and returns the set of valid tokens starting in state s, by
// calling either [NextTokensNoContext] (ctx == nil) or [NextTokensInContext] (ctx != nil).
func (a *ATN) NextTokens(s ATNState, ctx RuleContext) *IntervalSet {
if ctx == nil {
return a.NextTokensNoContext(s)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -8,19 +8,14 @@ import (
"fmt"
)
type comparable interface {
equals(other interface{}) bool
}
// ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic
// context). The syntactic context is a graph-structured stack node whose
// path(s) to the root is the rule invocation(s) chain used to arrive at the
// state. The semantic context is the tree of semantic predicates encountered
// before reaching an ATN state.
type ATNConfig interface {
comparable
hash() int
Equals(o Collectable[ATNConfig]) bool
Hash() int
GetState() ATNState
GetAlt() int
@ -47,7 +42,7 @@ type BaseATNConfig struct {
reachesIntoOuterContext int
}
func NewBaseATNConfig7(old *BaseATNConfig) *BaseATNConfig { // TODO: Dup
func NewBaseATNConfig7(old *BaseATNConfig) ATNConfig { // TODO: Dup
return &BaseATNConfig{
state: old.state,
alt: old.alt,
@ -135,11 +130,16 @@ func (b *BaseATNConfig) SetReachesIntoOuterContext(v int) {
b.reachesIntoOuterContext = v
}
// Equals is the default comparison function for an ATNConfig when no specialist implementation is required
// for a collection.
//
// An ATN configuration is equal to another if both have the same state, they
// predict the same alternative, and syntactic/semantic contexts are the same.
func (b *BaseATNConfig) equals(o interface{}) bool {
func (b *BaseATNConfig) Equals(o Collectable[ATNConfig]) bool {
if b == o {
return true
} else if o == nil {
return false
}
var other, ok = o.(*BaseATNConfig)
@ -153,30 +153,32 @@ func (b *BaseATNConfig) equals(o interface{}) bool {
if b.context == nil {
equal = other.context == nil
} else {
equal = b.context.equals(other.context)
equal = b.context.Equals(other.context)
}
var (
nums = b.state.GetStateNumber() == other.state.GetStateNumber()
alts = b.alt == other.alt
cons = b.semanticContext.equals(other.semanticContext)
cons = b.semanticContext.Equals(other.semanticContext)
sups = b.precedenceFilterSuppressed == other.precedenceFilterSuppressed
)
return nums && alts && cons && sups && equal
}
func (b *BaseATNConfig) hash() int {
// Hash is the default hash function for BaseATNConfig, when no specialist hash function
// is required for a collection
func (b *BaseATNConfig) Hash() int {
var c int
if b.context != nil {
c = b.context.hash()
c = b.context.Hash()
}
h := murmurInit(7)
h = murmurUpdate(h, b.state.GetStateNumber())
h = murmurUpdate(h, b.alt)
h = murmurUpdate(h, c)
h = murmurUpdate(h, b.semanticContext.hash())
h = murmurUpdate(h, b.semanticContext.Hash())
return murmurFinish(h, 4)
}
@ -243,7 +245,9 @@ func NewLexerATNConfig1(state ATNState, alt int, context PredictionContext) *Lex
return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)}
}
func (l *LexerATNConfig) hash() int {
// Hash is the default hash function for LexerATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (l *LexerATNConfig) Hash() int {
var f int
if l.passedThroughNonGreedyDecision {
f = 1
@ -253,15 +257,20 @@ func (l *LexerATNConfig) hash() int {
h := murmurInit(7)
h = murmurUpdate(h, l.state.GetStateNumber())
h = murmurUpdate(h, l.alt)
h = murmurUpdate(h, l.context.hash())
h = murmurUpdate(h, l.semanticContext.hash())
h = murmurUpdate(h, l.context.Hash())
h = murmurUpdate(h, l.semanticContext.Hash())
h = murmurUpdate(h, f)
h = murmurUpdate(h, l.lexerActionExecutor.hash())
h = murmurUpdate(h, l.lexerActionExecutor.Hash())
h = murmurFinish(h, 6)
return h
}
func (l *LexerATNConfig) equals(other interface{}) bool {
// Equals is the default comparison function for LexerATNConfig objects, it can be used directly or via
// the default comparator [ObjEqComparator].
func (l *LexerATNConfig) Equals(other Collectable[ATNConfig]) bool {
if l == other {
return true
}
var othert, ok = other.(*LexerATNConfig)
if l == other {
@ -275,7 +284,7 @@ func (l *LexerATNConfig) equals(other interface{}) bool {
var b bool
if l.lexerActionExecutor != nil {
b = !l.lexerActionExecutor.equals(othert.lexerActionExecutor)
b = !l.lexerActionExecutor.Equals(othert.lexerActionExecutor)
} else {
b = othert.lexerActionExecutor != nil
}
@ -284,10 +293,9 @@ func (l *LexerATNConfig) equals(other interface{}) bool {
return false
}
return l.BaseATNConfig.equals(othert.BaseATNConfig)
return l.BaseATNConfig.Equals(othert.BaseATNConfig)
}
func checkNonGreedyDecision(source *LexerATNConfig, target ATNState) bool {
var ds, ok = target.(DecisionState)

View File

@ -1,24 +1,25 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import "fmt"
import (
"fmt"
)
type ATNConfigSet interface {
hash() int
Hash() int
Equals(o Collectable[ATNConfig]) bool
Add(ATNConfig, *DoubleDict) bool
AddAll([]ATNConfig) bool
GetStates() Set
GetStates() *JStore[ATNState, Comparator[ATNState]]
GetPredicates() []SemanticContext
GetItems() []ATNConfig
OptimizeConfigs(interpreter *BaseATNSimulator)
Equals(other interface{}) bool
Length() int
IsEmpty() bool
Contains(ATNConfig) bool
@ -57,7 +58,7 @@ type BaseATNConfigSet struct {
// effectively doubles the number of objects associated with ATNConfigs. All
// keys are hashed by (s, i, _, pi), not including the context. Wiped out when
// read-only because a set becomes a DFA state.
configLookup Set
configLookup *JStore[ATNConfig, Comparator[ATNConfig]]
// configs is the added elements.
configs []ATNConfig
@ -83,7 +84,7 @@ type BaseATNConfigSet struct {
// readOnly is whether it is read-only. Do not
// allow any code to manipulate the set if true because DFA states will point at
// sets and those must not change. It not protect other fields; conflictingAlts
// sets and those must not change. It not, protect other fields; conflictingAlts
// in particular, which is assigned after readOnly.
readOnly bool
@ -104,7 +105,7 @@ func (b *BaseATNConfigSet) Alts() *BitSet {
func NewBaseATNConfigSet(fullCtx bool) *BaseATNConfigSet {
return &BaseATNConfigSet{
cachedHash: -1,
configLookup: newArray2DHashSetWithCap(hashATNConfig, equalATNConfigs, 16, 2),
configLookup: NewJStore[ATNConfig, Comparator[ATNConfig]](aConfCompInst),
fullCtx: fullCtx,
}
}
@ -126,9 +127,11 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
b.dipsIntoOuterContext = true
}
existing := b.configLookup.Add(config).(ATNConfig)
existing, present := b.configLookup.Put(config)
if existing == config {
// The config was not already in the set
//
if !present {
b.cachedHash = -1
b.configs = append(b.configs, config) // Track order here
return true
@ -154,11 +157,14 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
return true
}
func (b *BaseATNConfigSet) GetStates() Set {
states := newArray2DHashSet(nil, nil)
func (b *BaseATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] {
// states uses the standard comparator provided by the ATNState instance
//
states := NewJStore[ATNState, Comparator[ATNState]](aStateEqInst)
for i := 0; i < len(b.configs); i++ {
states.Add(b.configs[i].GetState())
states.Put(b.configs[i].GetState())
}
return states
@ -214,7 +220,34 @@ func (b *BaseATNConfigSet) AddAll(coll []ATNConfig) bool {
return false
}
func (b *BaseATNConfigSet) Equals(other interface{}) bool {
// Compare is a hack function just to verify that adding DFAstares to the known
// set works, so long as comparison of ATNConfigSet s works. For that to work, we
// need to make sure that the set of ATNConfigs in two sets are equivalent. We can't
// know the order, so we do this inefficient hack. If this proves the point, then
// we can change the config set to a better structure.
func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
if len(b.configs) != len(bs.configs) {
return false
}
for _, c := range b.configs {
found := false
for _, c2 := range bs.configs {
if c.Equals(c2) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
func (b *BaseATNConfigSet) Equals(other Collectable[ATNConfig]) bool {
if b == other {
return true
} else if _, ok := other.(*BaseATNConfigSet); !ok {
@ -224,15 +257,15 @@ func (b *BaseATNConfigSet) Equals(other interface{}) bool {
other2 := other.(*BaseATNConfigSet)
return b.configs != nil &&
// TODO: b.configs.equals(other2.configs) && // TODO: Is b necessary?
b.fullCtx == other2.fullCtx &&
b.uniqueAlt == other2.uniqueAlt &&
b.conflictingAlts == other2.conflictingAlts &&
b.hasSemanticContext == other2.hasSemanticContext &&
b.dipsIntoOuterContext == other2.dipsIntoOuterContext
b.dipsIntoOuterContext == other2.dipsIntoOuterContext &&
b.Compare(other2)
}
func (b *BaseATNConfigSet) hash() int {
func (b *BaseATNConfigSet) Hash() int {
if b.readOnly {
if b.cachedHash == -1 {
b.cachedHash = b.hashCodeConfigs()
@ -247,7 +280,7 @@ func (b *BaseATNConfigSet) hash() int {
func (b *BaseATNConfigSet) hashCodeConfigs() int {
h := 1
for _, config := range b.configs {
h = 31*h + config.hash()
h = 31*h + config.Hash()
}
return h
}
@ -283,7 +316,7 @@ func (b *BaseATNConfigSet) Clear() {
b.configs = make([]ATNConfig, 0)
b.cachedHash = -1
b.configLookup = newArray2DHashSet(nil, equalATNConfigs)
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](atnConfCompInst)
}
func (b *BaseATNConfigSet) FullContext() bool {
@ -365,7 +398,8 @@ type OrderedATNConfigSet struct {
func NewOrderedATNConfigSet() *OrderedATNConfigSet {
b := NewBaseATNConfigSet(false)
b.configLookup = newArray2DHashSet(nil, nil)
// This set uses the standard Hash() and Equals() from ATNConfig
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
return &OrderedATNConfigSet{BaseATNConfigSet: b}
}
@ -375,7 +409,7 @@ func hashATNConfig(i interface{}) int {
hash := 7
hash = 31*hash + o.GetState().GetStateNumber()
hash = 31*hash + o.GetAlt()
hash = 31*hash + o.GetSemanticContext().hash()
hash = 31*hash + o.GetSemanticContext().Hash()
return hash
}
@ -403,5 +437,5 @@ func equalATNConfigs(a, b interface{}) bool {
return false
}
return ai.GetSemanticContext().equals(bi.GetSemanticContext())
return ai.GetSemanticContext().Equals(bi.GetSemanticContext())
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -49,7 +49,8 @@ type ATNState interface {
AddTransition(Transition, int)
String() string
hash() int
Hash() int
Equals(Collectable[ATNState]) bool
}
type BaseATNState struct {
@ -123,7 +124,7 @@ func (as *BaseATNState) SetNextTokenWithinRule(v *IntervalSet) {
as.NextTokenWithinRule = v
}
func (as *BaseATNState) hash() int {
func (as *BaseATNState) Hash() int {
return as.stateNumber
}
@ -131,7 +132,7 @@ func (as *BaseATNState) String() string {
return strconv.Itoa(as.stateNumber)
}
func (as *BaseATNState) equals(other interface{}) bool {
func (as *BaseATNState) Equals(other Collectable[ATNState]) bool {
if ot, ok := other.(ATNState); ok {
return as.stateNumber == ot.GetStateNumber()
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -331,10 +331,12 @@ func (c *CommonTokenStream) GetTextFromRuleContext(interval RuleContext) string
func (c *CommonTokenStream) GetTextFromInterval(interval *Interval) string {
c.lazyInit()
c.Fill()
if interval == nil {
c.Fill()
interval = NewInterval(0, len(c.tokens)-1)
} else {
c.Sync(interval.Stop)
}
start := interval.Start

View File

@ -0,0 +1,147 @@
package antlr
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
// This file contains all the implementations of custom comparators used for generic collections when the
// Hash() and Equals() funcs supplied by the struct objects themselves need to be overridden. Normally, we would
// put the comparators in the source file for the struct themselves, but given the organization of this code is
// sorta kinda based upon the Java code, I found it confusing trying to find out which comparator was where and used by
// which instantiation of a collection. For instance, an Array2DHashSet in the Java source, when used with ATNConfig
// collections requires three different comparators depending on what the collection is being used for. Collecting - pun intended -
// all the comparators here, makes it much easier to see which implementation of hash and equals is used by which collection.
// It also makes it easy to verify that the Hash() and Equals() functions marry up with the Java implementations.
// ObjEqComparator is the equivalent of the Java ObjectEqualityComparator, which is the default instance of
// Equality comparator. We do not have inheritance in Go, only interfaces, so we use generics to enforce some
// type safety and avoid having to implement this for every type that we want to perform comparison on.
//
// This comparator works by using the standard Hash() and Equals() methods of the type T that is being compared. Which
// allows us to use it in any collection instance that does nto require a special hash or equals implementation.
type ObjEqComparator[T Collectable[T]] struct{}
var (
aStateEqInst = &ObjEqComparator[ATNState]{}
aConfEqInst = &ObjEqComparator[ATNConfig]{}
aConfCompInst = &ATNConfigComparator[ATNConfig]{}
atnConfCompInst = &BaseATNConfigComparator[ATNConfig]{}
dfaStateEqInst = &ObjEqComparator[*DFAState]{}
semctxEqInst = &ObjEqComparator[SemanticContext]{}
atnAltCfgEqInst = &ATNAltConfigComparator[ATNConfig]{}
)
// Equals2 delegates to the Equals() method of type T
func (c *ObjEqComparator[T]) Equals2(o1, o2 T) bool {
return o1.Equals(o2)
}
// Hash1 delegates to the Hash() method of type T
func (c *ObjEqComparator[T]) Hash1(o T) int {
return o.Hash()
}
type SemCComparator[T Collectable[T]] struct{}
// ATNConfigComparator is used as the compartor for the configLookup field of an ATNConfigSet
// and has a custom Equals() and Hash() implementation, because equality is not based on the
// standard Hash() and Equals() methods of the ATNConfig type.
type ATNConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for configLookup
func (c *ATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
if o1 == o2 {
return true
}
// If either are nil, but not both, then the result is false
//
if o1 == nil || o2 == nil {
return false
}
return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() &&
o1.GetAlt() == o2.GetAlt() &&
o1.GetSemanticContext().Equals(o2.GetSemanticContext())
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup
func (c *ATNConfigComparator[T]) Hash1(o ATNConfig) int {
hash := 7
hash = 31*hash + o.GetState().GetStateNumber()
hash = 31*hash + o.GetAlt()
hash = 31*hash + o.GetSemanticContext().Hash()
return hash
}
// ATNAltConfigComparator is used as the comparator for mapping configs to Alt Bitsets
type ATNAltConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for configLookup
func (c *ATNAltConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
if o1 == o2 {
return true
}
// If either are nil, but not both, then the result is false
//
if o1 == nil || o2 == nil {
return false
}
return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() &&
o1.GetContext().Equals(o2.GetContext())
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup
func (c *ATNAltConfigComparator[T]) Hash1(o ATNConfig) int {
h := murmurInit(7)
h = murmurUpdate(h, o.GetState().GetStateNumber())
h = murmurUpdate(h, o.GetContext().Hash())
return murmurFinish(h, 2)
}
// BaseATNConfigComparator is used as the comparator for the configLookup field of a BaseATNConfigSet
// and has a custom Equals() and Hash() implementation, because equality is not based on the
// standard Hash() and Equals() methods of the ATNConfig type.
type BaseATNConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for baseATNConfigSet
func (c *BaseATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
if o1 == o2 {
return true
}
// If either are nil, but not both, then the result is false
//
if o1 == nil || o2 == nil {
return false
}
return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() &&
o1.GetAlt() == o2.GetAlt() &&
o1.GetSemanticContext().Equals(o2.GetSemanticContext())
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup, but in fact just
// delegates to the standard Hash() method of the ATNConfig type.
func (c *BaseATNConfigComparator[T]) Hash1(o ATNConfig) int {
return o.Hash()
}

View File

@ -1,13 +1,9 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"sort"
)
type DFA struct {
// atnStartState is the ATN state in which this was created
atnStartState DecisionState
@ -15,8 +11,15 @@ type DFA struct {
decision int
// states is all the DFA states. Use Map to get the old state back; Set can only
// indicate whether it is there.
states map[int]*DFAState
// indicate whether it is there. Go maps implement key hash collisions and so on and are very
// good, but the DFAState is an object and can't be used directly as the key as it can in say JAva
// amd C#, whereby if the hashcode is the same for two objects, then Equals() is called against them
// to see if they really are the same object.
//
//
states *JStore[*DFAState, *ObjEqComparator[*DFAState]]
numstates int
s0 *DFAState
@ -29,7 +32,7 @@ func NewDFA(atnStartState DecisionState, decision int) *DFA {
dfa := &DFA{
atnStartState: atnStartState,
decision: decision,
states: make(map[int]*DFAState),
states: NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst),
}
if s, ok := atnStartState.(*StarLoopEntryState); ok && s.precedenceRuleDecision {
dfa.precedenceDfa = true
@ -92,7 +95,8 @@ func (d *DFA) getPrecedenceDfa() bool {
// true or nil otherwise, and d.precedenceDfa is updated.
func (d *DFA) setPrecedenceDfa(precedenceDfa bool) {
if d.getPrecedenceDfa() != precedenceDfa {
d.setStates(make(map[int]*DFAState))
d.states = NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst)
d.numstates = 0
if precedenceDfa {
precedenceState := NewDFAState(-1, NewBaseATNConfigSet(false))
@ -117,38 +121,12 @@ func (d *DFA) setS0(s *DFAState) {
d.s0 = s
}
func (d *DFA) getState(hash int) (*DFAState, bool) {
s, ok := d.states[hash]
return s, ok
}
func (d *DFA) setStates(states map[int]*DFAState) {
d.states = states
}
func (d *DFA) setState(hash int, state *DFAState) {
d.states[hash] = state
}
func (d *DFA) numStates() int {
return len(d.states)
}
type dfaStateList []*DFAState
func (d dfaStateList) Len() int { return len(d) }
func (d dfaStateList) Less(i, j int) bool { return d[i].stateNumber < d[j].stateNumber }
func (d dfaStateList) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
// sortedStates returns the states in d sorted by their state number.
func (d *DFA) sortedStates() []*DFAState {
vs := make([]*DFAState, 0, len(d.states))
for _, v := range d.states {
vs = append(vs, v)
}
sort.Sort(dfaStateList(vs))
vs := d.states.SortedSlice(func(i, j *DFAState) bool {
return i.stateNumber < j.stateNumber
})
return vs
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -90,16 +90,16 @@ func NewDFAState(stateNumber int, configs ATNConfigSet) *DFAState {
}
// GetAltSet gets the set of all alts mentioned by all ATN configurations in d.
func (d *DFAState) GetAltSet() Set {
alts := newArray2DHashSet(nil, nil)
func (d *DFAState) GetAltSet() []int {
var alts []int
if d.configs != nil {
for _, c := range d.configs.GetItems() {
alts.Add(c.GetAlt())
alts = append(alts, c.GetAlt())
}
}
if alts.Len() == 0 {
if len(alts) == 0 {
return nil
}
@ -130,27 +130,6 @@ func (d *DFAState) setPrediction(v int) {
d.prediction = v
}
// equals returns whether d equals other. Two DFAStates are equal if their ATN
// configuration sets are the same. This method is used to see if a state
// already exists.
//
// Because the number of alternatives and number of ATN configurations are
// finite, there is a finite number of DFA states that can be processed. This is
// necessary to show that the algorithm terminates.
//
// Cannot test the DFA state numbers here because in
// ParserATNSimulator.addDFAState we need to know if any other state exists that
// has d exact set of ATN configurations. The stateNumber is irrelevant.
func (d *DFAState) equals(other interface{}) bool {
if d == other {
return true
} else if _, ok := other.(*DFAState); !ok {
return false
}
return d.configs.Equals(other.(*DFAState).configs)
}
func (d *DFAState) String() string {
var s string
if d.isAcceptState {
@ -164,8 +143,27 @@ func (d *DFAState) String() string {
return fmt.Sprintf("%d:%s%s", d.stateNumber, fmt.Sprint(d.configs), s)
}
func (d *DFAState) hash() int {
func (d *DFAState) Hash() int {
h := murmurInit(7)
h = murmurUpdate(h, d.configs.hash())
h = murmurUpdate(h, d.configs.Hash())
return murmurFinish(h, 1)
}
// Equals returns whether d equals other. Two DFAStates are equal if their ATN
// configuration sets are the same. This method is used to see if a state
// already exists.
//
// Because the number of alternatives and number of ATN configurations are
// finite, there is a finite number of DFA states that can be processed. This is
// necessary to show that the algorithm terminates.
//
// Cannot test the DFA state numbers here because in
// ParserATNSimulator.addDFAState we need to know if any other state exists that
// has d exact set of ATN configurations. The stateNumber is irrelevant.
func (d *DFAState) Equals(o Collectable[*DFAState]) bool {
if d == o {
return true
}
return d.configs.Equals(o.(*DFAState).configs)
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -87,7 +87,6 @@ func (d *DiagnosticErrorListener) getDecisionDescription(recognizer Parser, dfa
return strconv.Itoa(decision) + " (" + ruleName + ")"
}
//
// Computes the set of conflicting or ambiguous alternatives from a
// configuration set, if that information was not already provided by the
// parser.
@ -97,7 +96,6 @@ func (d *DiagnosticErrorListener) getDecisionDescription(recognizer Parser, dfa
// @param configs The conflicting or ambiguous configuration set.
// @return Returns {@code ReportedAlts} if it is not {@code nil}, otherwise
// returns the set of alternatives represented in {@code configs}.
//
func (d *DiagnosticErrorListener) getConflictingAlts(ReportedAlts *BitSet, set ATNConfigSet) *BitSet {
if ReportedAlts != nil {
return ReportedAlts

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -48,12 +48,9 @@ func NewConsoleErrorListener() *ConsoleErrorListener {
return new(ConsoleErrorListener)
}
//
// Provides a default instance of {@link ConsoleErrorListener}.
//
var ConsoleErrorListenerINSTANCE = NewConsoleErrorListener()
//
// {@inheritDoc}
//
// <p>
@ -64,7 +61,6 @@ var ConsoleErrorListenerINSTANCE = NewConsoleErrorListener()
// <pre>
// line <em>line</em>:<em>charPositionInLine</em> <em>msg</em>
// </pre>
//
func (c *ConsoleErrorListener) SyntaxError(recognizer Recognizer, offendingSymbol interface{}, line, column int, msg string, e RecognitionException) {
fmt.Fprintln(os.Stderr, "line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg)
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -23,7 +23,6 @@ type ErrorStrategy interface {
// This is the default implementation of {@link ANTLRErrorStrategy} used for
// error Reporting and recovery in ANTLR parsers.
//
type DefaultErrorStrategy struct {
errorRecoveryMode bool
lastErrorIndex int
@ -61,12 +60,10 @@ func (d *DefaultErrorStrategy) reset(recognizer Parser) {
d.endErrorCondition(recognizer)
}
//
// This method is called to enter error recovery mode when a recognition
// exception is Reported.
//
// @param recognizer the parser instance
//
func (d *DefaultErrorStrategy) beginErrorCondition(recognizer Parser) {
d.errorRecoveryMode = true
}
@ -75,28 +72,23 @@ func (d *DefaultErrorStrategy) InErrorRecoveryMode(recognizer Parser) bool {
return d.errorRecoveryMode
}
//
// This method is called to leave error recovery mode after recovering from
// a recognition exception.
//
// @param recognizer
//
func (d *DefaultErrorStrategy) endErrorCondition(recognizer Parser) {
d.errorRecoveryMode = false
d.lastErrorStates = nil
d.lastErrorIndex = -1
}
//
// {@inheritDoc}
//
// <p>The default implementation simply calls {@link //endErrorCondition}.</p>
//
func (d *DefaultErrorStrategy) ReportMatch(recognizer Parser) {
d.endErrorCondition(recognizer)
}
//
// {@inheritDoc}
//
// <p>The default implementation returns immediately if the handler is already
@ -114,7 +106,6 @@ func (d *DefaultErrorStrategy) ReportMatch(recognizer Parser) {
// <li>All other types: calls {@link Parser//NotifyErrorListeners} to Report
// the exception</li>
// </ul>
//
func (d *DefaultErrorStrategy) ReportError(recognizer Parser, e RecognitionException) {
// if we've already Reported an error and have not Matched a token
// yet successfully, don't Report any errors.
@ -142,7 +133,6 @@ func (d *DefaultErrorStrategy) ReportError(recognizer Parser, e RecognitionExcep
// <p>The default implementation reSynchronizes the parser by consuming tokens
// until we find one in the reSynchronization set--loosely the set of tokens
// that can follow the current rule.</p>
//
func (d *DefaultErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
if d.lastErrorIndex == recognizer.GetInputStream().Index() &&
@ -206,7 +196,6 @@ func (d *DefaultErrorStrategy) Recover(recognizer Parser, e RecognitionException
// compare token set at the start of the loop and at each iteration. If for
// some reason speed is suffering for you, you can turn off d
// functionality by simply overriding d method as a blank { }.</p>
//
func (d *DefaultErrorStrategy) Sync(recognizer Parser) {
// If already recovering, don't try to Sync
if d.InErrorRecoveryMode(recognizer) {
@ -247,7 +236,6 @@ func (d *DefaultErrorStrategy) Sync(recognizer Parser) {
//
// @param recognizer the parser instance
// @param e the recognition exception
//
func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *NoViableAltException) {
tokens := recognizer.GetTokenStream()
var input string
@ -264,7 +252,6 @@ func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *N
recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
}
//
// This is called by {@link //ReportError} when the exception is an
// {@link InputMisMatchException}.
//
@ -272,14 +259,12 @@ func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *N
//
// @param recognizer the parser instance
// @param e the recognition exception
//
func (this *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *InputMisMatchException) {
msg := "mismatched input " + this.GetTokenErrorDisplay(e.offendingToken) +
" expecting " + e.getExpectedTokens().StringVerbose(recognizer.GetLiteralNames(), recognizer.GetSymbolicNames(), false)
recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
}
//
// This is called by {@link //ReportError} when the exception is a
// {@link FailedPredicateException}.
//
@ -287,7 +272,6 @@ func (this *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *Inpu
//
// @param recognizer the parser instance
// @param e the recognition exception
//
func (d *DefaultErrorStrategy) ReportFailedPredicate(recognizer Parser, e *FailedPredicateException) {
ruleName := recognizer.GetRuleNames()[recognizer.GetParserRuleContext().GetRuleIndex()]
msg := "rule " + ruleName + " " + e.message
@ -310,7 +294,6 @@ func (d *DefaultErrorStrategy) ReportFailedPredicate(recognizer Parser, e *Faile
// {@link Parser//NotifyErrorListeners}.</p>
//
// @param recognizer the parser instance
//
func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) {
if d.InErrorRecoveryMode(recognizer) {
return
@ -339,7 +322,6 @@ func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) {
// {@link Parser//NotifyErrorListeners}.</p>
//
// @param recognizer the parser instance
//
func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) {
if d.InErrorRecoveryMode(recognizer) {
return
@ -392,15 +374,14 @@ func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) {
// derivation:
//
// <pre>
// =&gt ID '=' '(' INT ')' ('+' atom)* ''
// =&gt ID '=' '(' INT ')' ('+' atom)*
// ^
// </pre>
//
// The attempt to Match {@code ')'} will fail when it sees {@code ''} and
// call {@link //recoverInline}. To recover, it sees that {@code LA(1)==''}
// The attempt to Match {@code ')'} will fail when it sees {@code } and
// call {@link //recoverInline}. To recover, it sees that {@code LA(1)==}
// is in the set of tokens that can follow the {@code ')'} token reference
// in rule {@code atom}. It can assume that you forgot the {@code ')'}.
//
func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
// SINGLE TOKEN DELETION
MatchedSymbol := d.SingleTokenDeletion(recognizer)
@ -418,7 +399,6 @@ func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
panic(NewInputMisMatchException(recognizer))
}
//
// This method implements the single-token insertion inline error recovery
// strategy. It is called by {@link //recoverInline} if the single-token
// deletion strategy fails to recover from the mismatched input. If this
@ -434,7 +414,6 @@ func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
// @param recognizer the parser instance
// @return {@code true} if single-token insertion is a viable recovery
// strategy for the current mismatched input, otherwise {@code false}
//
func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool {
currentSymbolType := recognizer.GetTokenStream().LA(1)
// if current token is consistent with what could come after current
@ -469,7 +448,6 @@ func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool {
// @return the successfully Matched {@link Token} instance if single-token
// deletion successfully recovers from the mismatched input, otherwise
// {@code nil}
//
func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token {
NextTokenType := recognizer.GetTokenStream().LA(2)
expecting := d.GetExpectedTokens(recognizer)
@ -507,7 +485,6 @@ func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token {
// a CommonToken of the appropriate type. The text will be the token.
// If you change what tokens must be created by the lexer,
// override d method to create the appropriate tokens.
//
func (d *DefaultErrorStrategy) GetMissingSymbol(recognizer Parser) Token {
currentSymbol := recognizer.GetCurrentToken()
expecting := d.GetExpectedTokens(recognizer)
@ -546,7 +523,6 @@ func (d *DefaultErrorStrategy) GetExpectedTokens(recognizer Parser) *IntervalSet
// the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
//
func (d *DefaultErrorStrategy) GetTokenErrorDisplay(t Token) string {
if t == nil {
return "<no token>"
@ -578,7 +554,7 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
// from within the rule i.e., the FIRST computation done by
// ANTLR stops at the end of a rule.
//
// EXAMPLE
// # EXAMPLE
//
// When you find a "no viable alt exception", the input is not
// consistent with any of the alternatives for rule r. The best
@ -597,7 +573,6 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
// c : ID
// | INT
//
//
// At each rule invocation, the set of tokens that could follow
// that rule is pushed on a stack. Here are the various
// context-sensitive follow sets:
@ -660,7 +635,6 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
//
// Like Grosch I implement context-sensitive FOLLOW sets that are combined
// at run-time upon error to avoid overhead during parsing.
//
func (d *DefaultErrorStrategy) getErrorRecoverySet(recognizer Parser) *IntervalSet {
atn := recognizer.GetInterpreter().atn
ctx := recognizer.GetParserRuleContext()
@ -733,7 +707,6 @@ func NewBailErrorStrategy() *BailErrorStrategy {
// in a {@link ParseCancellationException} so it is not caught by the
// rule func catches. Use {@link Exception//getCause()} to get the
// original {@link RecognitionException}.
//
func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
context := recognizer.GetParserRuleContext()
for context != nil {
@ -749,7 +722,6 @@ func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
// Make sure we don't attempt to recover inline if the parser
// successfully recovers, it won't panic an exception.
//
func (b *BailErrorStrategy) RecoverInline(recognizer Parser) Token {
b.Recover(recognizer, NewInputMisMatchException(recognizer))

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -74,7 +74,6 @@ func (b *BaseRecognitionException) GetInputStream() IntStream {
// <p>If the state number is not known, b method returns -1.</p>
//
// Gets the set of input symbols which could potentially follow the
// previously Matched symbol at the time b exception was panicn.
//
@ -136,7 +135,6 @@ type NoViableAltException struct {
// to take based upon the remaining input. It tracks the starting token
// of the offending input and also knows where the parser was
// in the various paths when the error. Reported by ReportNoViableAlternative()
//
func NewNoViableAltException(recognizer Parser, input TokenStream, startToken Token, offendingToken Token, deadEndConfigs ATNConfigSet, ctx ParserRuleContext) *NoViableAltException {
if ctx == nil {
@ -177,7 +175,6 @@ type InputMisMatchException struct {
// This signifies any kind of mismatched input exceptions such as
// when the current input does not Match the expected token.
//
func NewInputMisMatchException(recognizer Parser) *InputMisMatchException {
i := new(InputMisMatchException)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -223,6 +223,10 @@ func (i *IntervalSet) StringVerbose(literalNames []string, symbolicNames []strin
return i.toIndexString()
}
func (i *IntervalSet) GetIntervals() []*Interval {
return i.intervals
}
func (i *IntervalSet) toCharString() string {
names := make([]string, len(i.intervals))

View File

@ -0,0 +1,198 @@
package antlr
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
import (
"sort"
)
// Collectable is an interface that a struct should implement if it is to be
// usable as a key in these collections.
type Collectable[T any] interface {
Hash() int
Equals(other Collectable[T]) bool
}
type Comparator[T any] interface {
Hash1(o T) int
Equals2(T, T) bool
}
// JStore implements a container that allows the use of a struct to calculate the key
// for a collection of values akin to map. This is not meant to be a full-blown HashMap but just
// serve the needs of the ANTLR Go runtime.
//
// For ease of porting the logic of the runtime from the master target (Java), this collection
// operates in a similar way to Java, in that it can use any struct that supplies a Hash() and Equals()
// function as the key. The values are stored in a standard go map which internally is a form of hashmap
// itself, the key for the go map is the hash supplied by the key object. The collection is able to deal with
// hash conflicts by using a simple slice of values associated with the hash code indexed bucket. That isn't
// particularly efficient, but it is simple, and it works. As this is specifically for the ANTLR runtime, and
// we understand the requirements, then this is fine - this is not a general purpose collection.
type JStore[T any, C Comparator[T]] struct {
store map[int][]T
len int
comparator Comparator[T]
}
func NewJStore[T any, C Comparator[T]](comparator Comparator[T]) *JStore[T, C] {
if comparator == nil {
panic("comparator cannot be nil")
}
s := &JStore[T, C]{
store: make(map[int][]T, 1),
comparator: comparator,
}
return s
}
// Put will store given value in the collection. Note that the key for storage is generated from
// the value itself - this is specifically because that is what ANTLR needs - this would not be useful
// as any kind of general collection.
//
// If the key has a hash conflict, then the value will be added to the slice of values associated with the
// hash, unless the value is already in the slice, in which case the existing value is returned. Value equivalence is
// tested by calling the equals() method on the key.
//
// # If the given value is already present in the store, then the existing value is returned as v and exists is set to true
//
// If the given value is not present in the store, then the value is added to the store and returned as v and exists is set to false.
func (s *JStore[T, C]) Put(value T) (v T, exists bool) { //nolint:ireturn
kh := s.comparator.Hash1(value)
for _, v1 := range s.store[kh] {
if s.comparator.Equals2(value, v1) {
return v1, true
}
}
s.store[kh] = append(s.store[kh], value)
s.len++
return value, false
}
// Get will return the value associated with the key - the type of the key is the same type as the value
// which would not generally be useful, but this is a specific thing for ANTLR where the key is
// generated using the object we are going to store.
func (s *JStore[T, C]) Get(key T) (T, bool) { //nolint:ireturn
kh := s.comparator.Hash1(key)
for _, v := range s.store[kh] {
if s.comparator.Equals2(key, v) {
return v, true
}
}
return key, false
}
// Contains returns true if the given key is present in the store
func (s *JStore[T, C]) Contains(key T) bool { //nolint:ireturn
_, present := s.Get(key)
return present
}
func (s *JStore[T, C]) SortedSlice(less func(i, j T) bool) []T {
vs := make([]T, 0, len(s.store))
for _, v := range s.store {
vs = append(vs, v...)
}
sort.Slice(vs, func(i, j int) bool {
return less(vs[i], vs[j])
})
return vs
}
func (s *JStore[T, C]) Each(f func(T) bool) {
for _, e := range s.store {
for _, v := range e {
f(v)
}
}
}
func (s *JStore[T, C]) Len() int {
return s.len
}
func (s *JStore[T, C]) Values() []T {
vs := make([]T, 0, len(s.store))
for _, e := range s.store {
for _, v := range e {
vs = append(vs, v)
}
}
return vs
}
type entry[K, V any] struct {
key K
val V
}
type JMap[K, V any, C Comparator[K]] struct {
store map[int][]*entry[K, V]
len int
comparator Comparator[K]
}
func NewJMap[K, V any, C Comparator[K]](comparator Comparator[K]) *JMap[K, V, C] {
return &JMap[K, V, C]{
store: make(map[int][]*entry[K, V], 1),
comparator: comparator,
}
}
func (m *JMap[K, V, C]) Put(key K, val V) {
kh := m.comparator.Hash1(key)
m.store[kh] = append(m.store[kh], &entry[K, V]{key, val})
m.len++
}
func (m *JMap[K, V, C]) Values() []V {
vs := make([]V, 0, len(m.store))
for _, e := range m.store {
for _, v := range e {
vs = append(vs, v.val)
}
}
return vs
}
func (m *JMap[K, V, C]) Get(key K) (V, bool) {
var none V
kh := m.comparator.Hash1(key)
for _, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
return e.val, true
}
}
return none, false
}
func (m *JMap[K, V, C]) Len() int {
return len(m.store)
}
func (m *JMap[K, V, C]) Delete(key K) {
kh := m.comparator.Hash1(key)
for i, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
m.store[kh] = append(m.store[kh][:i], m.store[kh][i+1:]...)
m.len--
return
}
}
}
func (m *JMap[K, V, C]) Clear() {
m.store = make(map[int][]*entry[K, V])
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -232,8 +232,6 @@ func (b *BaseLexer) NextToken() Token {
}
return b.token
}
return nil
}
// Instruct the lexer to Skip creating a token for current lexer rule
@ -342,7 +340,7 @@ func (b *BaseLexer) GetCharIndex() int {
}
// Return the text Matched so far for the current token or any text override.
//Set the complete text of l token it wipes any previous changes to the text.
// Set the complete text of l token it wipes any previous changes to the text.
func (b *BaseLexer) GetText() string {
if b.text != "" {
return b.text

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -21,8 +21,8 @@ type LexerAction interface {
getActionType() int
getIsPositionDependent() bool
execute(lexer Lexer)
hash() int
equals(other LexerAction) bool
Hash() int
Equals(other LexerAction) bool
}
type BaseLexerAction struct {
@ -51,15 +51,14 @@ func (b *BaseLexerAction) getIsPositionDependent() bool {
return b.isPositionDependent
}
func (b *BaseLexerAction) hash() int {
func (b *BaseLexerAction) Hash() int {
return b.actionType
}
func (b *BaseLexerAction) equals(other LexerAction) bool {
func (b *BaseLexerAction) Equals(other LexerAction) bool {
return b == other
}
//
// Implements the {@code Skip} lexer action by calling {@link Lexer//Skip}.
//
// <p>The {@code Skip} command does not have any parameters, so l action is
@ -85,7 +84,8 @@ func (l *LexerSkipAction) String() string {
return "skip"
}
// Implements the {@code type} lexer action by calling {@link Lexer//setType}
// Implements the {@code type} lexer action by calling {@link Lexer//setType}
//
// with the assigned type.
type LexerTypeAction struct {
*BaseLexerAction
@ -104,14 +104,14 @@ func (l *LexerTypeAction) execute(lexer Lexer) {
lexer.SetType(l.thetype)
}
func (l *LexerTypeAction) hash() int {
func (l *LexerTypeAction) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.thetype)
return murmurFinish(h, 2)
}
func (l *LexerTypeAction) equals(other LexerAction) bool {
func (l *LexerTypeAction) Equals(other LexerAction) bool {
if l == other {
return true
} else if _, ok := other.(*LexerTypeAction); !ok {
@ -148,14 +148,14 @@ func (l *LexerPushModeAction) execute(lexer Lexer) {
lexer.PushMode(l.mode)
}
func (l *LexerPushModeAction) hash() int {
func (l *LexerPushModeAction) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.mode)
return murmurFinish(h, 2)
}
func (l *LexerPushModeAction) equals(other LexerAction) bool {
func (l *LexerPushModeAction) Equals(other LexerAction) bool {
if l == other {
return true
} else if _, ok := other.(*LexerPushModeAction); !ok {
@ -245,14 +245,14 @@ func (l *LexerModeAction) execute(lexer Lexer) {
lexer.SetMode(l.mode)
}
func (l *LexerModeAction) hash() int {
func (l *LexerModeAction) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.mode)
return murmurFinish(h, 2)
}
func (l *LexerModeAction) equals(other LexerAction) bool {
func (l *LexerModeAction) Equals(other LexerAction) bool {
if l == other {
return true
} else if _, ok := other.(*LexerModeAction); !ok {
@ -303,7 +303,7 @@ func (l *LexerCustomAction) execute(lexer Lexer) {
lexer.Action(nil, l.ruleIndex, l.actionIndex)
}
func (l *LexerCustomAction) hash() int {
func (l *LexerCustomAction) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.ruleIndex)
@ -311,13 +311,14 @@ func (l *LexerCustomAction) hash() int {
return murmurFinish(h, 3)
}
func (l *LexerCustomAction) equals(other LexerAction) bool {
func (l *LexerCustomAction) Equals(other LexerAction) bool {
if l == other {
return true
} else if _, ok := other.(*LexerCustomAction); !ok {
return false
} else {
return l.ruleIndex == other.(*LexerCustomAction).ruleIndex && l.actionIndex == other.(*LexerCustomAction).actionIndex
return l.ruleIndex == other.(*LexerCustomAction).ruleIndex &&
l.actionIndex == other.(*LexerCustomAction).actionIndex
}
}
@ -344,14 +345,14 @@ func (l *LexerChannelAction) execute(lexer Lexer) {
lexer.SetChannel(l.channel)
}
func (l *LexerChannelAction) hash() int {
func (l *LexerChannelAction) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.channel)
return murmurFinish(h, 2)
}
func (l *LexerChannelAction) equals(other LexerAction) bool {
func (l *LexerChannelAction) Equals(other LexerAction) bool {
if l == other {
return true
} else if _, ok := other.(*LexerChannelAction); !ok {
@ -412,10 +413,10 @@ func (l *LexerIndexedCustomAction) execute(lexer Lexer) {
l.lexerAction.execute(lexer)
}
func (l *LexerIndexedCustomAction) hash() int {
func (l *LexerIndexedCustomAction) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, l.offset)
h = murmurUpdate(h, l.lexerAction.hash())
h = murmurUpdate(h, l.lexerAction.Hash())
return murmurFinish(h, 2)
}
@ -425,6 +426,7 @@ func (l *LexerIndexedCustomAction) equals(other LexerAction) bool {
} else if _, ok := other.(*LexerIndexedCustomAction); !ok {
return false
} else {
return l.offset == other.(*LexerIndexedCustomAction).offset && l.lexerAction == other.(*LexerIndexedCustomAction).lexerAction
return l.offset == other.(*LexerIndexedCustomAction).offset &&
l.lexerAction.Equals(other.(*LexerIndexedCustomAction).lexerAction)
}
}

View File

@ -1,9 +1,11 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import "golang.org/x/exp/slices"
// Represents an executor for a sequence of lexer actions which traversed during
// the Matching operation of a lexer rule (token).
//
@ -12,8 +14,8 @@ package antlr
// not cause bloating of the {@link DFA} created for the lexer.</p>
type LexerActionExecutor struct {
lexerActions []LexerAction
cachedHash int
lexerActions []LexerAction
cachedHash int
}
func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor {
@ -30,7 +32,7 @@ func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor {
// of the performance-critical {@link LexerATNConfig//hashCode} operation.
l.cachedHash = murmurInit(57)
for _, a := range lexerActions {
l.cachedHash = murmurUpdate(l.cachedHash, a.hash())
l.cachedHash = murmurUpdate(l.cachedHash, a.Hash())
}
return l
@ -151,14 +153,17 @@ func (l *LexerActionExecutor) execute(lexer Lexer, input CharStream, startIndex
}
}
func (l *LexerActionExecutor) hash() int {
func (l *LexerActionExecutor) Hash() int {
if l == nil {
// TODO: Why is this here? l should not be nil
return 61
}
// TODO: This is created from the action itself when the struct is created - will this be an issue at some point? Java uses the runtime assign hashcode
return l.cachedHash
}
func (l *LexerActionExecutor) equals(other interface{}) bool {
func (l *LexerActionExecutor) Equals(other interface{}) bool {
if l == other {
return true
}
@ -169,5 +174,13 @@ func (l *LexerActionExecutor) equals(other interface{}) bool {
if othert == nil {
return false
}
return l.cachedHash == othert.cachedHash && &l.lexerActions == &othert.lexerActions
if l.cachedHash != othert.cachedHash {
return false
}
if len(l.lexerActions) != len(othert.lexerActions) {
return false
}
return slices.EqualFunc(l.lexerActions, othert.lexerActions, func(i, j LexerAction) bool {
return i.Equals(j)
})
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -591,19 +591,24 @@ func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool)
proposed.lexerActionExecutor = firstConfigWithRuleStopState.(*LexerATNConfig).lexerActionExecutor
proposed.setPrediction(l.atn.ruleToTokenType[firstConfigWithRuleStopState.GetState().GetRuleIndex()])
}
hash := proposed.hash()
dfa := l.decisionToDFA[l.mode]
l.atn.stateMu.Lock()
defer l.atn.stateMu.Unlock()
existing, ok := dfa.getState(hash)
if ok {
existing, present := dfa.states.Get(proposed)
if present {
// This state was already present, so just return it.
//
proposed = existing
} else {
proposed.stateNumber = dfa.numStates()
// We need to add the new state
//
proposed.stateNumber = dfa.states.Len()
configs.SetReadOnly(true)
proposed.configs = configs
dfa.setState(hash, proposed)
dfa.states.Put(proposed)
}
if !suppressEdge {
dfa.setS0(proposed)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -14,14 +14,15 @@ func NewLL1Analyzer(atn *ATN) *LL1Analyzer {
return la
}
//* Special value added to the lookahead sets to indicate that we hit
// a predicate during analysis if {@code seeThruPreds==false}.
///
// - Special value added to the lookahead sets to indicate that we hit
// a predicate during analysis if {@code seeThruPreds==false}.
//
// /
const (
LL1AnalyzerHitPred = TokenInvalidType
)
//*
// *
// Calculates the SLL(1) expected lookahead set for each outgoing transition
// of an {@link ATNState}. The returned array has one element for each
// outgoing transition in {@code s}. If the closure from transition
@ -38,7 +39,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
look := make([]*IntervalSet, count)
for alt := 0; alt < count; alt++ {
look[alt] = NewIntervalSet()
lookBusy := newArray2DHashSet(nil, nil)
lookBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
seeThruPreds := false // fail to get lookahead upon pred
la.look1(s.GetTransitions()[alt].getTarget(), nil, BasePredictionContextEMPTY, look[alt], lookBusy, NewBitSet(), seeThruPreds, false)
// Wipe out lookahead for la alternative if we found nothing
@ -50,7 +51,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
return look
}
//*
// *
// Compute set of tokens that can follow {@code s} in the ATN in the
// specified {@code ctx}.
//
@ -67,7 +68,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
//
// @return The set of tokens that can follow {@code s} in the ATN in the
// specified {@code ctx}.
///
// /
func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet {
r := NewIntervalSet()
seeThruPreds := true // ignore preds get all lookahead
@ -75,7 +76,7 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet
if ctx != nil {
lookContext = predictionContextFromRuleContext(s.GetATN(), ctx)
}
la.look1(s, stopState, lookContext, r, newArray2DHashSet(nil, nil), NewBitSet(), seeThruPreds, true)
la.look1(s, stopState, lookContext, r, NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst), NewBitSet(), seeThruPreds, true)
return r
}
@ -109,14 +110,14 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet
// outermost context is reached. This parameter has no effect if {@code ctx}
// is {@code nil}.
func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) {
func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) {
returnState := la.atn.states[ctx.getReturnState(i)]
la.look1(returnState, stopState, ctx.GetParent(i), look, lookBusy, calledRuleStack, seeThruPreds, addEOF)
}
func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool) {
func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool) {
c := NewBaseATNConfig6(s, 0, ctx)
@ -124,8 +125,11 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look
return
}
lookBusy.Add(c)
_, present := lookBusy.Put(c)
if present {
return
}
if s == stopState {
if ctx == nil {
look.addOne(TokenEpsilon)
@ -198,7 +202,7 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look
}
}
func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) {
func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) {
newContext := SingletonBasePredictionContextCreate(ctx, t1.followState.GetStateNumber())

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -91,7 +91,6 @@ func NewBaseParser(input TokenStream) *BaseParser {
// bypass alternatives.
//
// @see ATNDeserializationOptions//isGenerateRuleBypassTransitions()
//
var bypassAltsAtnCache = make(map[string]int)
// reset the parser's state//
@ -230,7 +229,6 @@ func (p *BaseParser) GetParseListeners() []ParseTreeListener {
// @param listener the listener to add
//
// @panics nilPointerException if {@code} listener is {@code nil}
//
func (p *BaseParser) AddParseListener(listener ParseTreeListener) {
if listener == nil {
panic("listener")
@ -241,13 +239,11 @@ func (p *BaseParser) AddParseListener(listener ParseTreeListener) {
p.parseListeners = append(p.parseListeners, listener)
}
//
// Remove {@code listener} from the list of parse listeners.
//
// <p>If {@code listener} is {@code nil} or has not been added as a parse
// listener, p.method does nothing.</p>
// @param listener the listener to remove
//
func (p *BaseParser) RemoveParseListener(listener ParseTreeListener) {
if p.parseListeners != nil {
@ -289,11 +285,9 @@ func (p *BaseParser) TriggerEnterRuleEvent() {
}
}
//
// Notify any parse listeners of an exit rule event.
//
// @see //addParseListener
//
func (p *BaseParser) TriggerExitRuleEvent() {
if p.parseListeners != nil {
// reverse order walk of listeners
@ -330,7 +324,6 @@ func (p *BaseParser) setTokenFactory(factory TokenFactory) {
//
// @panics UnsupportedOperationException if the current parser does not
// implement the {@link //getSerializedATN()} method.
//
func (p *BaseParser) GetATNWithBypassAlts() {
// TODO
@ -402,7 +395,6 @@ func (p *BaseParser) SetTokenStream(input TokenStream) {
// Match needs to return the current input symbol, which gets put
// into the label for the associated token ref e.g., x=ID.
//
func (p *BaseParser) GetCurrentToken() Token {
return p.input.LT(1)
}
@ -624,7 +616,6 @@ func (p *BaseParser) IsExpectedToken(symbol int) bool {
// respectively.
//
// @see ATN//getExpectedTokens(int, RuleContext)
//
func (p *BaseParser) GetExpectedTokens() *IntervalSet {
return p.Interpreter.atn.getExpectedTokens(p.state, p.ctx)
}
@ -686,7 +677,7 @@ func (p *BaseParser) GetDFAStrings() string {
func (p *BaseParser) DumpDFA() {
seenOne := false
for _, dfa := range p.Interpreter.decisionToDFA {
if dfa.numStates() > 0 {
if dfa.states.Len() > 0 {
if seenOne {
fmt.Println()
}
@ -703,7 +694,6 @@ func (p *BaseParser) GetSourceName() string {
// During a parse is sometimes useful to listen in on the rule entry and exit
// events as well as token Matches. p.is for quick and dirty debugging.
//
func (p *BaseParser) SetTrace(trace *TraceListener) {
if trace == nil {
p.RemoveParseListener(p.tracer)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -11,11 +11,11 @@ import (
)
var (
ParserATNSimulatorDebug = false
ParserATNSimulatorListATNDecisions = false
ParserATNSimulatorDFADebug = false
ParserATNSimulatorRetryDebug = false
TurnOffLRLoopEntryBranchOpt = false
ParserATNSimulatorDebug = false
ParserATNSimulatorTraceATNSim = false
ParserATNSimulatorDFADebug = false
ParserATNSimulatorRetryDebug = false
TurnOffLRLoopEntryBranchOpt = false
)
type ParserATNSimulator struct {
@ -70,8 +70,8 @@ func (p *ParserATNSimulator) reset() {
}
func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, outerContext ParserRuleContext) int {
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions {
fmt.Println("AdaptivePredict decision " + strconv.Itoa(decision) +
if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("adaptivePredict decision " + strconv.Itoa(decision) +
" exec LA(1)==" + p.getLookaheadName(input) +
" line " + strconv.Itoa(input.LT(1).GetLine()) + ":" +
strconv.Itoa(input.LT(1).GetColumn()))
@ -111,15 +111,15 @@ func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, ou
if s0 == nil {
if outerContext == nil {
outerContext = RuleContextEmpty
outerContext = ParserRuleContextEmpty
}
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions {
if ParserATNSimulatorDebug {
fmt.Println("predictATN decision " + strconv.Itoa(dfa.decision) +
" exec LA(1)==" + p.getLookaheadName(input) +
", outerContext=" + outerContext.String(p.parser.GetRuleNames(), nil))
}
fullCtx := false
s0Closure := p.computeStartState(dfa.atnStartState, RuleContextEmpty, fullCtx)
s0Closure := p.computeStartState(dfa.atnStartState, ParserRuleContextEmpty, fullCtx)
p.atn.stateMu.Lock()
if dfa.getPrecedenceDfa() {
@ -174,17 +174,18 @@ func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, ou
// Reporting insufficient predicates
// cover these cases:
// dead end
// single alt
// single alt + preds
// conflict
// conflict + preds
//
// dead end
// single alt
// single alt + preds
// conflict
// conflict + preds
func (p *ParserATNSimulator) execATN(dfa *DFA, s0 *DFAState, input TokenStream, startIndex int, outerContext ParserRuleContext) int {
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions {
if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("execATN decision " + strconv.Itoa(dfa.decision) +
" exec LA(1)==" + p.getLookaheadName(input) +
", DFA state " + s0.String() +
", LA(1)==" + p.getLookaheadName(input) +
" line " + strconv.Itoa(input.LT(1).GetLine()) + ":" + strconv.Itoa(input.LT(1).GetColumn()))
}
@ -277,8 +278,6 @@ func (p *ParserATNSimulator) execATN(dfa *DFA, s0 *DFAState, input TokenStream,
t = input.LA(1)
}
}
panic("Should not have reached p state")
}
// Get an existing target state for an edge in the DFA. If the target state
@ -384,7 +383,7 @@ func (p *ParserATNSimulator) predicateDFAState(dfaState *DFAState, decisionState
// comes back with reach.uniqueAlt set to a valid alt
func (p *ParserATNSimulator) execATNWithFullContext(dfa *DFA, D *DFAState, s0 ATNConfigSet, input TokenStream, startIndex int, outerContext ParserRuleContext) int {
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions {
if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("execATNWithFullContext " + s0.String())
}
@ -492,9 +491,6 @@ func (p *ParserATNSimulator) execATNWithFullContext(dfa *DFA, D *DFAState, s0 AT
}
func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCtx bool) ATNConfigSet {
if ParserATNSimulatorDebug {
fmt.Println("in computeReachSet, starting closure: " + closure.String())
}
if p.mergeCache == nil {
p.mergeCache = NewDoubleDict()
}
@ -570,7 +566,7 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
//
if reach == nil {
reach = NewBaseATNConfigSet(fullCtx)
closureBusy := newArray2DHashSet(nil, nil)
closureBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
treatEOFAsEpsilon := t == TokenEOF
amount := len(intermediate.configs)
for k := 0; k < amount; k++ {
@ -610,6 +606,11 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
reach.Add(skippedStopStates[l], p.mergeCache)
}
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("computeReachSet " + closure.String() + " -> " + reach.String())
}
if len(reach.GetItems()) == 0 {
return nil
}
@ -617,7 +618,6 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
return reach
}
//
// Return a configuration set containing only the configurations from
// {@code configs} which are in a {@link RuleStopState}. If all
// configurations in {@code configs} are already in a rule stop state, p
@ -636,7 +636,6 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
// @return {@code configs} if all configurations in {@code configs} are in a
// rule stop state, otherwise return a Newconfiguration set containing only
// the configurations from {@code configs} which are in a rule stop state
//
func (p *ParserATNSimulator) removeAllConfigsNotInRuleStopState(configs ATNConfigSet, lookToEndOfRule bool) ATNConfigSet {
if PredictionModeallConfigsInRuleStopStates(configs) {
return configs
@ -662,16 +661,20 @@ func (p *ParserATNSimulator) computeStartState(a ATNState, ctx RuleContext, full
// always at least the implicit call to start rule
initialContext := predictionContextFromRuleContext(p.atn, ctx)
configs := NewBaseATNConfigSet(fullCtx)
if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("computeStartState from ATN state " + a.String() +
" initialContext=" + initialContext.String())
}
for i := 0; i < len(a.GetTransitions()); i++ {
target := a.GetTransitions()[i].getTarget()
c := NewBaseATNConfig6(target, i+1, initialContext)
closureBusy := newArray2DHashSet(nil, nil)
closureBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](atnConfCompInst)
p.closure(c, configs, closureBusy, true, fullCtx, false)
}
return configs
}
//
// This method transforms the start state computed by
// {@link //computeStartState} to the special start state used by a
// precedence DFA for a particular precedence value. The transformation
@ -726,7 +729,6 @@ func (p *ParserATNSimulator) computeStartState(a ATNState, ctx RuleContext, full
// @return The transformed configuration set representing the start state
// for a precedence DFA at a particular precedence level (determined by
// calling {@link Parser//getPrecedence}).
//
func (p *ParserATNSimulator) applyPrecedenceFilter(configs ATNConfigSet) ATNConfigSet {
statesFromAlt1 := make(map[int]PredictionContext)
@ -760,7 +762,7 @@ func (p *ParserATNSimulator) applyPrecedenceFilter(configs ATNConfigSet) ATNConf
// (basically a graph subtraction algorithm).
if !config.getPrecedenceFilterSuppressed() {
context := statesFromAlt1[config.GetState().GetStateNumber()]
if context != nil && context.equals(config.GetContext()) {
if context != nil && context.Equals(config.GetContext()) {
// eliminated
continue
}
@ -824,7 +826,6 @@ func (p *ParserATNSimulator) getPredicatePredictions(ambigAlts *BitSet, altToPre
return pairs
}
//
// This method is used to improve the localization of error messages by
// choosing an alternative rather than panicing a
// {@link NoViableAltException} in particular prediction scenarios where the
@ -869,7 +870,6 @@ func (p *ParserATNSimulator) getPredicatePredictions(ambigAlts *BitSet, altToPre
// @return The value to return from {@link //AdaptivePredict}, or
// {@link ATN//INVALID_ALT_NUMBER} if a suitable alternative was not
// identified and {@link //AdaptivePredict} should Report an error instead.
//
func (p *ParserATNSimulator) getSynValidOrSemInvalidAltThatFinishedDecisionEntryRule(configs ATNConfigSet, outerContext ParserRuleContext) int {
cfgs := p.splitAccordingToSemanticValidity(configs, outerContext)
semValidConfigs := cfgs[0]
@ -938,11 +938,11 @@ func (p *ParserATNSimulator) splitAccordingToSemanticValidity(configs ATNConfigS
}
// Look through a list of predicate/alt pairs, returning alts for the
// pairs that win. A {@code NONE} predicate indicates an alt containing an
// unpredicated config which behaves as "always true." If !complete
// then we stop at the first predicate that evaluates to true. This
// includes pairs with nil predicates.
//
// pairs that win. A {@code NONE} predicate indicates an alt containing an
// unpredicated config which behaves as "always true." If !complete
// then we stop at the first predicate that evaluates to true. This
// includes pairs with nil predicates.
func (p *ParserATNSimulator) evalSemanticContext(predPredictions []*PredPrediction, outerContext ParserRuleContext, complete bool) *BitSet {
predictions := NewBitSet()
for i := 0; i < len(predPredictions); i++ {
@ -972,16 +972,16 @@ func (p *ParserATNSimulator) evalSemanticContext(predPredictions []*PredPredicti
return predictions
}
func (p *ParserATNSimulator) closure(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx, treatEOFAsEpsilon bool) {
func (p *ParserATNSimulator) closure(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx, treatEOFAsEpsilon bool) {
initialDepth := 0
p.closureCheckingStopState(config, configs, closureBusy, collectPredicates,
fullCtx, initialDepth, treatEOFAsEpsilon)
}
func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) {
if ParserATNSimulatorDebug {
func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) {
if ParserATNSimulatorTraceATNSim {
fmt.Println("closure(" + config.String() + ")")
fmt.Println("configs(" + configs.String() + ")")
//fmt.Println("configs(" + configs.String() + ")")
if config.GetReachesIntoOuterContext() > 50 {
panic("problem")
}
@ -1031,7 +1031,7 @@ func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs
}
// Do the actual work of walking epsilon edges//
func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) {
func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) {
state := config.GetState()
// optimization
if !state.GetEpsilonOnlyTransitions() {
@ -1066,7 +1066,8 @@ func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet,
c.SetReachesIntoOuterContext(c.GetReachesIntoOuterContext() + 1)
if closureBusy.Add(c) != c {
_, present := closureBusy.Put(c)
if present {
// avoid infinite recursion for right-recursive rules
continue
}
@ -1077,9 +1078,13 @@ func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet,
fmt.Println("dips into outer ctx: " + c.String())
}
} else {
if !t.getIsEpsilon() && closureBusy.Add(c) != c {
// avoid infinite recursion for EOF* and EOF+
continue
if !t.getIsEpsilon() {
_, present := closureBusy.Put(c)
if present {
// avoid infinite recursion for EOF* and EOF+
continue
}
}
if _, ok := t.(*RuleTransition); ok {
// latch when newDepth goes negative - once we step out of the entry context we can't return
@ -1104,7 +1109,16 @@ func (p *ParserATNSimulator) canDropLoopEntryEdgeInLeftRecursiveRule(config ATNC
// left-recursion elimination. For efficiency, also check if
// the context has an empty stack case. If so, it would mean
// global FOLLOW so we can't perform optimization
if startLoop, ok := _p.(StarLoopEntryState); !ok || !startLoop.precedenceRuleDecision || config.GetContext().isEmpty() || config.GetContext().hasEmptyPath() {
if _p.GetStateType() != ATNStateStarLoopEntry {
return false
}
startLoop, ok := _p.(*StarLoopEntryState)
if !ok {
return false
}
if !startLoop.precedenceRuleDecision ||
config.GetContext().isEmpty() ||
config.GetContext().hasEmptyPath() {
return false
}
@ -1117,8 +1131,8 @@ func (p *ParserATNSimulator) canDropLoopEntryEdgeInLeftRecursiveRule(config ATNC
return false
}
}
decisionStartState := _p.(BlockStartState).GetTransitions()[0].getTarget().(BlockStartState)
x := _p.GetTransitions()[0].getTarget()
decisionStartState := x.(BlockStartState)
blockEndStateNum := decisionStartState.getEndState().stateNumber
blockEndState := p.atn.states[blockEndStateNum].(*BlockEndState)
@ -1355,13 +1369,12 @@ func (p *ParserATNSimulator) GetTokenName(t int) string {
return "EOF"
}
if p.parser != nil && p.parser.GetLiteralNames() != nil {
if t >= len(p.parser.GetLiteralNames()) {
fmt.Println(strconv.Itoa(t) + " ttype out of range: " + strings.Join(p.parser.GetLiteralNames(), ","))
// fmt.Println(p.parser.GetInputStream().(TokenStream).GetAllText()) // p seems incorrect
} else {
return p.parser.GetLiteralNames()[t] + "<" + strconv.Itoa(t) + ">"
}
if p.parser != nil && p.parser.GetLiteralNames() != nil && t < len(p.parser.GetLiteralNames()) {
return p.parser.GetLiteralNames()[t] + "<" + strconv.Itoa(t) + ">"
}
if p.parser != nil && p.parser.GetLiteralNames() != nil && t < len(p.parser.GetSymbolicNames()) {
return p.parser.GetSymbolicNames()[t] + "<" + strconv.Itoa(t) + ">"
}
return strconv.Itoa(t)
@ -1372,9 +1385,9 @@ func (p *ParserATNSimulator) getLookaheadName(input TokenStream) string {
}
// Used for debugging in AdaptivePredict around execATN but I cut
// it out for clarity now that alg. works well. We can leave p
// "dead" code for a bit.
//
// it out for clarity now that alg. works well. We can leave p
// "dead" code for a bit.
func (p *ParserATNSimulator) dumpDeadEndConfigs(nvae *NoViableAltException) {
panic("Not implemented")
@ -1421,7 +1434,6 @@ func (p *ParserATNSimulator) getUniqueAlt(configs ATNConfigSet) int {
return alt
}
//
// Add an edge to the DFA, if possible. This method calls
// {@link //addDFAState} to ensure the {@code to} state is present in the
// DFA. If {@code from} is {@code nil}, or if {@code t} is outside the
@ -1440,7 +1452,6 @@ func (p *ParserATNSimulator) getUniqueAlt(configs ATNConfigSet) int {
// @return If {@code to} is {@code nil}, p method returns {@code nil}
// otherwise p method returns the result of calling {@link //addDFAState}
// on {@code to}
//
func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFAState) *DFAState {
if ParserATNSimulatorDebug {
fmt.Println("EDGE " + from.String() + " -> " + to.String() + " upon " + p.GetTokenName(t))
@ -1472,7 +1483,6 @@ func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFA
return to
}
//
// Add state {@code D} to the DFA if it is not already present, and return
// the actual instance stored in the DFA. If a state equivalent to {@code D}
// is already in the DFA, the existing state is returned. Otherwise p
@ -1486,25 +1496,30 @@ func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFA
// @return The state stored in the DFA. This will be either the existing
// state if {@code D} is already in the DFA, or {@code D} itself if the
// state was not already present.
//
func (p *ParserATNSimulator) addDFAState(dfa *DFA, d *DFAState) *DFAState {
if d == ATNSimulatorError {
return d
}
hash := d.hash()
existing, ok := dfa.getState(hash)
if ok {
existing, present := dfa.states.Get(d)
if present {
if ParserATNSimulatorTraceATNSim {
fmt.Print("addDFAState " + d.String() + " exists")
}
return existing
}
d.stateNumber = dfa.numStates()
// The state was not present, so update it with configs
//
d.stateNumber = dfa.states.Len()
if !d.configs.ReadOnly() {
d.configs.OptimizeConfigs(p.BaseATNSimulator)
d.configs.SetReadOnly(true)
}
dfa.setState(hash, d)
if ParserATNSimulatorDebug {
fmt.Println("adding NewDFA state: " + d.String())
dfa.states.Put(d)
if ParserATNSimulatorTraceATNSim {
fmt.Println("addDFAState new " + d.String())
}
return d
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -340,7 +340,7 @@ func (prc *BaseParserRuleContext) String(ruleNames []string, stop RuleContext) s
return s
}
var RuleContextEmpty = NewBaseParserRuleContext(nil, -1)
var ParserRuleContextEmpty = NewBaseParserRuleContext(nil, -1)
type InterpreterRuleContext interface {
ParserRuleContext

View File

@ -1,10 +1,12 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"fmt"
"golang.org/x/exp/slices"
"strconv"
)
@ -26,10 +28,10 @@ var (
)
type PredictionContext interface {
hash() int
Hash() int
Equals(interface{}) bool
GetParent(int) PredictionContext
getReturnState(int) int
equals(PredictionContext) bool
length() int
isEmpty() bool
hasEmptyPath() bool
@ -53,7 +55,7 @@ func (b *BasePredictionContext) isEmpty() bool {
func calculateHash(parent PredictionContext, returnState int) int {
h := murmurInit(1)
h = murmurUpdate(h, parent.hash())
h = murmurUpdate(h, parent.Hash())
h = murmurUpdate(h, returnState)
return murmurFinish(h, 2)
}
@ -86,7 +88,6 @@ func NewPredictionContextCache() *PredictionContextCache {
// Add a context to the cache and return it. If the context already exists,
// return that one instead and do not add a Newcontext to the cache.
// Protect shared cache from unsafe thread access.
//
func (p *PredictionContextCache) add(ctx PredictionContext) PredictionContext {
if ctx == BasePredictionContextEMPTY {
return BasePredictionContextEMPTY
@ -160,28 +161,28 @@ func (b *BaseSingletonPredictionContext) hasEmptyPath() bool {
return b.returnState == BasePredictionContextEmptyReturnState
}
func (b *BaseSingletonPredictionContext) equals(other PredictionContext) bool {
func (b *BaseSingletonPredictionContext) Hash() int {
return b.cachedHash
}
func (b *BaseSingletonPredictionContext) Equals(other interface{}) bool {
if b == other {
return true
} else if _, ok := other.(*BaseSingletonPredictionContext); !ok {
}
if _, ok := other.(*BaseSingletonPredictionContext); !ok {
return false
} else if b.hash() != other.hash() {
return false // can't be same if hash is different
}
otherP := other.(*BaseSingletonPredictionContext)
if b.returnState != other.getReturnState(0) {
if b.returnState != otherP.getReturnState(0) {
return false
} else if b.parentCtx == nil {
}
if b.parentCtx == nil {
return otherP.parentCtx == nil
}
return b.parentCtx.equals(otherP.parentCtx)
}
func (b *BaseSingletonPredictionContext) hash() int {
return b.cachedHash
return b.parentCtx.Equals(otherP.parentCtx)
}
func (b *BaseSingletonPredictionContext) String() string {
@ -215,7 +216,7 @@ func NewEmptyPredictionContext() *EmptyPredictionContext {
p := new(EmptyPredictionContext)
p.BaseSingletonPredictionContext = NewBaseSingletonPredictionContext(nil, BasePredictionContextEmptyReturnState)
p.cachedHash = calculateEmptyHash()
return p
}
@ -231,7 +232,11 @@ func (e *EmptyPredictionContext) getReturnState(index int) int {
return e.returnState
}
func (e *EmptyPredictionContext) equals(other PredictionContext) bool {
func (e *EmptyPredictionContext) Hash() int {
return e.cachedHash
}
func (e *EmptyPredictionContext) Equals(other interface{}) bool {
return e == other
}
@ -254,7 +259,7 @@ func NewArrayPredictionContext(parents []PredictionContext, returnStates []int)
hash := murmurInit(1)
for _, parent := range parents {
hash = murmurUpdate(hash, parent.hash())
hash = murmurUpdate(hash, parent.Hash())
}
for _, returnState := range returnStates {
@ -298,18 +303,31 @@ func (a *ArrayPredictionContext) getReturnState(index int) int {
return a.returnStates[index]
}
func (a *ArrayPredictionContext) equals(other PredictionContext) bool {
if _, ok := other.(*ArrayPredictionContext); !ok {
return false
} else if a.cachedHash != other.hash() {
return false // can't be same if hash is different
} else {
otherP := other.(*ArrayPredictionContext)
return &a.returnStates == &otherP.returnStates && &a.parents == &otherP.parents
// Equals is the default comparison function for ArrayPredictionContext when no specialized
// implementation is needed for a collection
func (a *ArrayPredictionContext) Equals(o interface{}) bool {
if a == o {
return true
}
other, ok := o.(*ArrayPredictionContext)
if !ok {
return false
}
if a.cachedHash != other.Hash() {
return false // can't be same if hash is different
}
// Must compare the actual array elements and not just the array address
//
return slices.Equal(a.returnStates, other.returnStates) &&
slices.EqualFunc(a.parents, other.parents, func(x, y PredictionContext) bool {
return x.Equals(y)
})
}
func (a *ArrayPredictionContext) hash() int {
// Hash is the default hash function for ArrayPredictionContext when no specialized
// implementation is needed for a collection
func (a *ArrayPredictionContext) Hash() int {
return a.BasePredictionContext.cachedHash
}
@ -343,11 +361,11 @@ func (a *ArrayPredictionContext) String() string {
// /
func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) PredictionContext {
if outerContext == nil {
outerContext = RuleContextEmpty
outerContext = ParserRuleContextEmpty
}
// if we are in RuleContext of start rule, s, then BasePredictionContext
// is EMPTY. Nobody called us. (if we are empty, return empty)
if outerContext.GetParent() == nil || outerContext == RuleContextEmpty {
if outerContext.GetParent() == nil || outerContext == ParserRuleContextEmpty {
return BasePredictionContextEMPTY
}
// If we have a parent, convert it to a BasePredictionContext graph
@ -359,11 +377,20 @@ func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) Predicti
}
func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
// share same graph if both same
if a == b {
// Share same graph if both same
//
if a == b || a.Equals(b) {
return a
}
// In Java, EmptyPredictionContext inherits from SingletonPredictionContext, and so the test
// in java for SingletonPredictionContext will succeed and a new ArrayPredictionContext will be created
// from it.
// In go, EmptyPredictionContext does not equate to SingletonPredictionContext and so that conversion
// will fail. We need to test for both Empty and Singleton and create an ArrayPredictionContext from
// either of them.
ac, ok1 := a.(*BaseSingletonPredictionContext)
bc, ok2 := b.(*BaseSingletonPredictionContext)
@ -380,17 +407,32 @@ func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict)
return b
}
}
// convert singleton so both are arrays to normalize
if _, ok := a.(*BaseSingletonPredictionContext); ok {
a = NewArrayPredictionContext([]PredictionContext{a.GetParent(0)}, []int{a.getReturnState(0)})
// Convert Singleton or Empty so both are arrays to normalize - We should not use the existing parameters
// here.
//
// TODO: I think that maybe the Prediction Context structs should be redone as there is a chance we will see this mess again - maybe redo the logic here
var arp, arb *ArrayPredictionContext
var ok bool
if arp, ok = a.(*ArrayPredictionContext); ok {
} else if _, ok = a.(*BaseSingletonPredictionContext); ok {
arp = NewArrayPredictionContext([]PredictionContext{a.GetParent(0)}, []int{a.getReturnState(0)})
} else if _, ok = a.(*EmptyPredictionContext); ok {
arp = NewArrayPredictionContext([]PredictionContext{}, []int{})
}
if _, ok := b.(*BaseSingletonPredictionContext); ok {
b = NewArrayPredictionContext([]PredictionContext{b.GetParent(0)}, []int{b.getReturnState(0)})
if arb, ok = b.(*ArrayPredictionContext); ok {
} else if _, ok = b.(*BaseSingletonPredictionContext); ok {
arb = NewArrayPredictionContext([]PredictionContext{b.GetParent(0)}, []int{b.getReturnState(0)})
} else if _, ok = b.(*EmptyPredictionContext); ok {
arb = NewArrayPredictionContext([]PredictionContext{}, []int{})
}
return mergeArrays(a.(*ArrayPredictionContext), b.(*ArrayPredictionContext), rootIsWildcard, mergeCache)
// Both arp and arb
return mergeArrays(arp, arb, rootIsWildcard, mergeCache)
}
//
// Merge two {@link SingletonBasePredictionContext} instances.
//
// <p>Stack tops equal, parents merge is same return left graph.<br>
@ -423,11 +465,11 @@ func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict)
// /
func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
if mergeCache != nil {
previous := mergeCache.Get(a.hash(), b.hash())
previous := mergeCache.Get(a.Hash(), b.Hash())
if previous != nil {
return previous.(PredictionContext)
}
previous = mergeCache.Get(b.hash(), a.hash())
previous = mergeCache.Get(b.Hash(), a.Hash())
if previous != nil {
return previous.(PredictionContext)
}
@ -436,7 +478,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
rootMerge := mergeRoot(a, b, rootIsWildcard)
if rootMerge != nil {
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), rootMerge)
mergeCache.set(a.Hash(), b.Hash(), rootMerge)
}
return rootMerge
}
@ -456,7 +498,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
// Newjoined parent so create Newsingleton pointing to it, a'
spc := SingletonBasePredictionContextCreate(parent, a.returnState)
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), spc)
mergeCache.set(a.Hash(), b.Hash(), spc)
}
return spc
}
@ -478,7 +520,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
parents := []PredictionContext{singleParent, singleParent}
apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), apc)
mergeCache.set(a.Hash(), b.Hash(), apc)
}
return apc
}
@ -494,12 +536,11 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
}
apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), apc)
mergeCache.set(a.Hash(), b.Hash(), apc)
}
return apc
}
//
// Handle case where at least one of {@code a} or {@code b} is
// {@link //EMPTY}. In the following diagrams, the symbol {@code $} is used
// to represent {@link //EMPTY}.
@ -561,7 +602,6 @@ func mergeRoot(a, b SingletonPredictionContext, rootIsWildcard bool) PredictionC
return nil
}
//
// Merge two {@link ArrayBasePredictionContext} instances.
//
// <p>Different tops, different parents.<br>
@ -583,12 +623,18 @@ func mergeRoot(a, b SingletonPredictionContext, rootIsWildcard bool) PredictionC
// /
func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
if mergeCache != nil {
previous := mergeCache.Get(a.hash(), b.hash())
previous := mergeCache.Get(a.Hash(), b.Hash())
if previous != nil {
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous.(PredictionContext)
}
previous = mergeCache.Get(b.hash(), a.hash())
previous = mergeCache.Get(b.Hash(), a.Hash())
if previous != nil {
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous.(PredictionContext)
}
}
@ -608,7 +654,7 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *
payload := a.returnStates[i]
// $+$ = $
bothDollars := payload == BasePredictionContextEmptyReturnState && aParent == nil && bParent == nil
axAX := (aParent != nil && bParent != nil && aParent == bParent) // ax+ax
axAX := aParent != nil && bParent != nil && aParent == bParent // ax+ax
// ->
// ax
if bothDollars || axAX {
@ -651,7 +697,7 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *
if k == 1 { // for just one merged element, return singleton top
pc := SingletonBasePredictionContextCreate(mergedParents[0], mergedReturnStates[0])
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), pc)
mergeCache.set(a.Hash(), b.Hash(), pc)
}
return pc
}
@ -663,27 +709,36 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *
// if we created same array as a or b, return that instead
// TODO: track whether this is possible above during merge sort for speed
// TODO: In go, I do not think we can just do M == xx as M is a brand new allocation. This could be causing allocation problems
if M == a {
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), a)
mergeCache.set(a.Hash(), b.Hash(), a)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> a")
}
return a
}
if M == b {
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), b)
mergeCache.set(a.Hash(), b.Hash(), b)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> b")
}
return b
}
combineCommonParents(mergedParents)
if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), M)
mergeCache.set(a.Hash(), b.Hash(), M)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> " + M.String())
}
return M
}
//
// Make pass over all <em>M</em> {@code parents} merge any {@code equals()}
// ones.
// /

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -70,7 +70,6 @@ const (
PredictionModeLLExactAmbigDetection = 2
)
//
// Computes the SLL prediction termination condition.
//
// <p>
@ -108,9 +107,9 @@ const (
// The single-alt-state thing lets prediction continue upon rules like
// (otherwise, it would admit defeat too soon):</p>
//
// <p>{@code [12|1|[], 6|2|[], 12|2|[]]. s : (ID | ID ID?) '' }</p>
// <p>{@code [12|1|[], 6|2|[], 12|2|[]]. s : (ID | ID ID?) }</p>
//
// <p>When the ATN simulation reaches the state before {@code ''}, it has a
// <p>When the ATN simulation reaches the state before {@code }, it has a
// DFA state that looks like: {@code [12|1|[], 6|2|[], 12|2|[]]}. Naturally
// {@code 12|1|[]} and {@code 12|2|[]} conflict, but we cannot stop
// processing this node because alternative to has another way to continue,
@ -152,16 +151,15 @@ const (
//
// <p>Before testing these configurations against others, we have to merge
// {@code x} and {@code x'} (without modifying the existing configurations).
// For example, we test {@code (x+x')==x''} when looking for conflicts in
// For example, we test {@code (x+x')==x} when looking for conflicts in
// the following configurations.</p>
//
// <p>{@code (s, 1, x, {}), (s, 1, x', {p}), (s, 2, x'', {})}</p>
// <p>{@code (s, 1, x, {}), (s, 1, x', {p}), (s, 2, x, {})}</p>
//
// <p>If the configuration set has predicates (as indicated by
// {@link ATNConfigSet//hasSemanticContext}), this algorithm makes a copy of
// the configurations to strip out all of the predicates so that a standard
// {@link ATNConfigSet} will merge everything ignoring predicates.</p>
//
func PredictionModehasSLLConflictTerminatingPrediction(mode int, configs ATNConfigSet) bool {
// Configs in rule stop states indicate reaching the end of the decision
// rule (local context) or end of start rule (full context). If all
@ -229,7 +227,6 @@ func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
return true
}
//
// Full LL prediction termination.
//
// <p>Can we stop looking ahead during ATN simulation or is there some
@ -334,7 +331,7 @@ func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
// </li>
//
// <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 1, y)},
// {@code (s', 2, y)}, {@code (s'', 1, z)} yields non-conflicting set
// {@code (s', 2, y)}, {@code (s, 1, z)} yields non-conflicting set
// {@code {1}} U conflicting sets {@code min({1,2})} U {@code min({1,2})} =
// {@code {1}} =&gt stop and predict 1</li>
//
@ -369,31 +366,26 @@ func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
// two or one and three so we keep going. We can only stop prediction when
// we need exact ambiguity detection when the sets look like
// {@code A={{1,2}}} or {@code {{1,2},{1,2}}}, etc...</p>
//
func PredictionModeresolvesToJustOneViableAlt(altsets []*BitSet) int {
return PredictionModegetSingleViableAlt(altsets)
}
//
// Determines if every alternative subset in {@code altsets} contains more
// than one alternative.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if every {@link BitSet} in {@code altsets} has
// {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false}
//
func PredictionModeallSubsetsConflict(altsets []*BitSet) bool {
return !PredictionModehasNonConflictingAltSet(altsets)
}
//
// Determines if any single alternative subset in {@code altsets} contains
// exactly one alternative.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if {@code altsets} contains a {@link BitSet} with
// {@link BitSet//cardinality cardinality} 1, otherwise {@code false}
//
func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
@ -404,14 +396,12 @@ func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool {
return false
}
//
// Determines if any single alternative subset in {@code altsets} contains
// more than one alternative.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if {@code altsets} contains a {@link BitSet} with
// {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false}
//
func PredictionModehasConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ {
alts := altsets[i]
@ -422,13 +412,11 @@ func PredictionModehasConflictingAltSet(altsets []*BitSet) bool {
return false
}
//
// Determines if every alternative subset in {@code altsets} is equivalent.
//
// @param altsets a collection of alternative subsets
// @return {@code true} if every member of {@code altsets} is equal to the
// others, otherwise {@code false}
//
func PredictionModeallSubsetsEqual(altsets []*BitSet) bool {
var first *BitSet
@ -444,13 +432,11 @@ func PredictionModeallSubsetsEqual(altsets []*BitSet) bool {
return true
}
//
// Returns the unique alternative predicted by all alternative subsets in
// {@code altsets}. If no such alternative exists, this method returns
// {@link ATN//INVALID_ALT_NUMBER}.
//
// @param altsets a collection of alternative subsets
//
func PredictionModegetUniqueAlt(altsets []*BitSet) int {
all := PredictionModeGetAlts(altsets)
if all.length() == 1 {
@ -466,7 +452,6 @@ func PredictionModegetUniqueAlt(altsets []*BitSet) int {
//
// @param altsets a collection of alternative subsets
// @return the set of represented alternatives in {@code altsets}
//
func PredictionModeGetAlts(altsets []*BitSet) *BitSet {
all := NewBitSet()
for _, alts := range altsets {
@ -475,44 +460,35 @@ func PredictionModeGetAlts(altsets []*BitSet) *BitSet {
return all
}
//
// This func gets the conflicting alt subsets from a configuration set.
// PredictionModegetConflictingAltSubsets gets the conflicting alt subsets from a configuration set.
// For each configuration {@code c} in {@code configs}:
//
// <pre>
// map[c] U= c.{@link ATNConfig//alt alt} // map hash/equals uses s and x, not
// alt and not pred
// </pre>
//
func PredictionModegetConflictingAltSubsets(configs ATNConfigSet) []*BitSet {
configToAlts := make(map[int]*BitSet)
configToAlts := NewJMap[ATNConfig, *BitSet, *ATNAltConfigComparator[ATNConfig]](atnAltCfgEqInst)
for _, c := range configs.GetItems() {
key := 31 * c.GetState().GetStateNumber() + c.GetContext().hash()
alts, ok := configToAlts[key]
alts, ok := configToAlts.Get(c)
if !ok {
alts = NewBitSet()
configToAlts[key] = alts
configToAlts.Put(c, alts)
}
alts.add(c.GetAlt())
}
values := make([]*BitSet, 0, 10)
for _, v := range configToAlts {
values = append(values, v)
}
return values
return configToAlts.Values()
}
//
// Get a map from state to alt subset from a configuration set. For each
// PredictionModeGetStateToAltMap gets a map from state to alt subset from a configuration set. For each
// configuration {@code c} in {@code configs}:
//
// <pre>
// map[c.{@link ATNConfig//state state}] U= c.{@link ATNConfig//alt alt}
// </pre>
//
func PredictionModeGetStateToAltMap(configs ATNConfigSet) *AltDict {
m := NewAltDict()

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -49,7 +49,7 @@ var tokenTypeMapCache = make(map[string]int)
var ruleIndexMapCache = make(map[string]int)
func (b *BaseRecognizer) checkVersion(toolVersion string) {
runtimeVersion := "4.10.1"
runtimeVersion := "4.12.0"
if runtimeVersion != toolVersion {
fmt.Println("ANTLR runtime and generated code versions disagree: " + runtimeVersion + "!=" + toolVersion)
}
@ -108,7 +108,6 @@ func (b *BaseRecognizer) SetState(v int) {
// Get a map from rule names to rule indexes.
//
// <p>Used for XPath and tree pattern compilation.</p>
//
func (b *BaseRecognizer) GetRuleIndexMap() map[string]int {
panic("Method not defined!")
@ -171,18 +170,18 @@ func (b *BaseRecognizer) GetErrorHeader(e RecognitionException) string {
}
// How should a token be displayed in an error message? The default
// is to display just the text, but during development you might
// want to have a lot of information spit out. Override in that case
// to use t.String() (which, for CommonToken, dumps everything about
// the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
//
// is to display just the text, but during development you might
// want to have a lot of information spit out. Override in that case
// to use t.String() (which, for CommonToken, dumps everything about
// the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
//
// @deprecated This method is not called by the ANTLR 4 Runtime. Specific
// implementations of {@link ANTLRErrorStrategy} may provide a similar
// feature when necessary. For example, see
// {@link DefaultErrorStrategy//GetTokenErrorDisplay}.
//
func (b *BaseRecognizer) GetTokenErrorDisplay(t Token) string {
if t == nil {
return "<no token>"

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -18,12 +18,12 @@ import (
//
type SemanticContext interface {
comparable
Equals(other Collectable[SemanticContext]) bool
Hash() int
evaluate(parser Recognizer, outerContext RuleContext) bool
evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext
hash() int
String() string
}
@ -78,7 +78,7 @@ func NewPredicate(ruleIndex, predIndex int, isCtxDependent bool) *Predicate {
//The default {@link SemanticContext}, which is semantically equivalent to
//a predicate of the form {@code {true}?}.
var SemanticContextNone SemanticContext = NewPredicate(-1, -1, false)
var SemanticContextNone = NewPredicate(-1, -1, false)
func (p *Predicate) evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext {
return p
@ -95,7 +95,7 @@ func (p *Predicate) evaluate(parser Recognizer, outerContext RuleContext) bool {
return parser.Sempred(localctx, p.ruleIndex, p.predIndex)
}
func (p *Predicate) equals(other interface{}) bool {
func (p *Predicate) Equals(other Collectable[SemanticContext]) bool {
if p == other {
return true
} else if _, ok := other.(*Predicate); !ok {
@ -107,7 +107,7 @@ func (p *Predicate) equals(other interface{}) bool {
}
}
func (p *Predicate) hash() int {
func (p *Predicate) Hash() int {
h := murmurInit(0)
h = murmurUpdate(h, p.ruleIndex)
h = murmurUpdate(h, p.predIndex)
@ -151,17 +151,22 @@ func (p *PrecedencePredicate) compareTo(other *PrecedencePredicate) int {
return p.precedence - other.precedence
}
func (p *PrecedencePredicate) equals(other interface{}) bool {
if p == other {
return true
} else if _, ok := other.(*PrecedencePredicate); !ok {
func (p *PrecedencePredicate) Equals(other Collectable[SemanticContext]) bool {
var op *PrecedencePredicate
var ok bool
if op, ok = other.(*PrecedencePredicate); !ok {
return false
} else {
return p.precedence == other.(*PrecedencePredicate).precedence
}
if p == op {
return true
}
return p.precedence == other.(*PrecedencePredicate).precedence
}
func (p *PrecedencePredicate) hash() int {
func (p *PrecedencePredicate) Hash() int {
h := uint32(1)
h = 31*h + uint32(p.precedence)
return int(h)
@ -171,10 +176,10 @@ func (p *PrecedencePredicate) String() string {
return "{" + strconv.Itoa(p.precedence) + ">=prec}?"
}
func PrecedencePredicatefilterPrecedencePredicates(set Set) []*PrecedencePredicate {
func PrecedencePredicatefilterPrecedencePredicates(set *JStore[SemanticContext, Comparator[SemanticContext]]) []*PrecedencePredicate {
result := make([]*PrecedencePredicate, 0)
set.Each(func(v interface{}) bool {
set.Each(func(v SemanticContext) bool {
if c2, ok := v.(*PrecedencePredicate); ok {
result = append(result, c2)
}
@ -193,21 +198,21 @@ type AND struct {
func NewAND(a, b SemanticContext) *AND {
operands := newArray2DHashSet(nil, nil)
operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst)
if aa, ok := a.(*AND); ok {
for _, o := range aa.opnds {
operands.Add(o)
operands.Put(o)
}
} else {
operands.Add(a)
operands.Put(a)
}
if ba, ok := b.(*AND); ok {
for _, o := range ba.opnds {
operands.Add(o)
operands.Put(o)
}
} else {
operands.Add(b)
operands.Put(b)
}
precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands)
if len(precedencePredicates) > 0 {
@ -220,7 +225,7 @@ func NewAND(a, b SemanticContext) *AND {
}
}
operands.Add(reduced)
operands.Put(reduced)
}
vs := operands.Values()
@ -235,14 +240,15 @@ func NewAND(a, b SemanticContext) *AND {
return and
}
func (a *AND) equals(other interface{}) bool {
func (a *AND) Equals(other Collectable[SemanticContext]) bool {
if a == other {
return true
} else if _, ok := other.(*AND); !ok {
}
if _, ok := other.(*AND); !ok {
return false
} else {
for i, v := range other.(*AND).opnds {
if !a.opnds[i].equals(v) {
if !a.opnds[i].Equals(v) {
return false
}
}
@ -250,13 +256,11 @@ func (a *AND) equals(other interface{}) bool {
}
}
//
// {@inheritDoc}
//
// <p>
// The evaluation of predicates by a context is short-circuiting, but
// unordered.</p>
//
func (a *AND) evaluate(parser Recognizer, outerContext RuleContext) bool {
for i := 0; i < len(a.opnds); i++ {
if !a.opnds[i].evaluate(parser, outerContext) {
@ -304,18 +308,18 @@ func (a *AND) evalPrecedence(parser Recognizer, outerContext RuleContext) Semant
return result
}
func (a *AND) hash() int {
func (a *AND) Hash() int {
h := murmurInit(37) // Init with a value different from OR
for _, op := range a.opnds {
h = murmurUpdate(h, op.hash())
h = murmurUpdate(h, op.Hash())
}
return murmurFinish(h, len(a.opnds))
}
func (a *OR) hash() int {
func (a *OR) Hash() int {
h := murmurInit(41) // Init with a value different from AND
for _, op := range a.opnds {
h = murmurUpdate(h, op.hash())
h = murmurUpdate(h, op.Hash())
}
return murmurFinish(h, len(a.opnds))
}
@ -345,21 +349,21 @@ type OR struct {
func NewOR(a, b SemanticContext) *OR {
operands := newArray2DHashSet(nil, nil)
operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst)
if aa, ok := a.(*OR); ok {
for _, o := range aa.opnds {
operands.Add(o)
operands.Put(o)
}
} else {
operands.Add(a)
operands.Put(a)
}
if ba, ok := b.(*OR); ok {
for _, o := range ba.opnds {
operands.Add(o)
operands.Put(o)
}
} else {
operands.Add(b)
operands.Put(b)
}
precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands)
if len(precedencePredicates) > 0 {
@ -372,7 +376,7 @@ func NewOR(a, b SemanticContext) *OR {
}
}
operands.Add(reduced)
operands.Put(reduced)
}
vs := operands.Values()
@ -388,14 +392,14 @@ func NewOR(a, b SemanticContext) *OR {
return o
}
func (o *OR) equals(other interface{}) bool {
func (o *OR) Equals(other Collectable[SemanticContext]) bool {
if o == other {
return true
} else if _, ok := other.(*OR); !ok {
return false
} else {
for i, v := range other.(*OR).opnds {
if !o.opnds[i].equals(v) {
if !o.opnds[i].Equals(v) {
return false
}
}
@ -406,7 +410,6 @@ func (o *OR) equals(other interface{}) bool {
// <p>
// The evaluation of predicates by o context is short-circuiting, but
// unordered.</p>
//
func (o *OR) evaluate(parser Recognizer, outerContext RuleContext) bool {
for i := 0; i < len(o.opnds); i++ {
if o.opnds[i].evaluate(parser, outerContext) {

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -158,7 +158,6 @@ func NewCommonToken(source *TokenSourceCharStreamPair, tokenType, channel, start
// {@link Token//GetInputStream}.</p>
//
// @param oldToken The token to copy.
//
func (c *CommonToken) clone() *CommonToken {
t := NewCommonToken(c.source, c.tokenType, c.channel, c.start, c.stop)
t.tokenIndex = c.GetTokenIndex()

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,14 +1,14 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
package antlr
import (
"bytes"
"fmt"
"bytes"
"fmt"
)
//
// Useful for rewriting out a buffered input token stream after doing some
// augmentation or other manipulations on it.
@ -85,12 +85,10 @@ import (
// If you don't use named rewrite streams, a "default" stream is used as the
// first example shows.</p>
const(
const (
Default_Program_Name = "default"
Program_Init_Size = 100
Min_Token_Index = 0
Program_Init_Size = 100
Min_Token_Index = 0
)
// Define the rewrite operation hierarchy
@ -98,13 +96,13 @@ const(
type RewriteOperation interface {
// Execute the rewrite operation by possibly adding to the buffer.
// Return the index of the next token to operate on.
Execute(buffer *bytes.Buffer) int
String() string
GetInstructionIndex() int
GetIndex() int
GetText() string
GetOpName() string
GetTokens() TokenStream
Execute(buffer *bytes.Buffer) int
String() string
GetInstructionIndex() int
GetIndex() int
GetText() string
GetOpName() string
GetTokens() TokenStream
SetInstructionIndex(val int)
SetIndex(int)
SetText(string)
@ -114,63 +112,62 @@ type RewriteOperation interface {
type BaseRewriteOperation struct {
//Current index of rewrites list
instruction_index int
instruction_index int
//Token buffer index
index int
index int
//Substitution text
text string
text string
//Actual operation name
op_name string
op_name string
//Pointer to token steam
tokens TokenStream
tokens TokenStream
}
func (op *BaseRewriteOperation)GetInstructionIndex() int{
func (op *BaseRewriteOperation) GetInstructionIndex() int {
return op.instruction_index
}
func (op *BaseRewriteOperation)GetIndex() int{
func (op *BaseRewriteOperation) GetIndex() int {
return op.index
}
func (op *BaseRewriteOperation)GetText() string{
func (op *BaseRewriteOperation) GetText() string {
return op.text
}
func (op *BaseRewriteOperation)GetOpName() string{
func (op *BaseRewriteOperation) GetOpName() string {
return op.op_name
}
func (op *BaseRewriteOperation)GetTokens() TokenStream{
func (op *BaseRewriteOperation) GetTokens() TokenStream {
return op.tokens
}
func (op *BaseRewriteOperation)SetInstructionIndex(val int){
func (op *BaseRewriteOperation) SetInstructionIndex(val int) {
op.instruction_index = val
}
func (op *BaseRewriteOperation)SetIndex(val int) {
func (op *BaseRewriteOperation) SetIndex(val int) {
op.index = val
}
func (op *BaseRewriteOperation)SetText(val string){
func (op *BaseRewriteOperation) SetText(val string) {
op.text = val
}
func (op *BaseRewriteOperation)SetOpName(val string){
func (op *BaseRewriteOperation) SetOpName(val string) {
op.op_name = val
}
func (op *BaseRewriteOperation)SetTokens(val TokenStream) {
func (op *BaseRewriteOperation) SetTokens(val TokenStream) {
op.tokens = val
}
func (op *BaseRewriteOperation) Execute(buffer *bytes.Buffer) int{
func (op *BaseRewriteOperation) Execute(buffer *bytes.Buffer) int {
return op.index
}
func (op *BaseRewriteOperation) String() string {
func (op *BaseRewriteOperation) String() string {
return fmt.Sprintf("<%s@%d:\"%s\">",
op.op_name,
op.tokens.Get(op.GetIndex()),
@ -179,26 +176,25 @@ func (op *BaseRewriteOperation) String() string {
}
type InsertBeforeOp struct {
BaseRewriteOperation
}
func NewInsertBeforeOp(index int, text string, stream TokenStream) *InsertBeforeOp{
return &InsertBeforeOp{BaseRewriteOperation:BaseRewriteOperation{
index:index,
text:text,
op_name:"InsertBeforeOp",
tokens:stream,
func NewInsertBeforeOp(index int, text string, stream TokenStream) *InsertBeforeOp {
return &InsertBeforeOp{BaseRewriteOperation: BaseRewriteOperation{
index: index,
text: text,
op_name: "InsertBeforeOp",
tokens: stream,
}}
}
func (op *InsertBeforeOp) Execute(buffer *bytes.Buffer) int{
func (op *InsertBeforeOp) Execute(buffer *bytes.Buffer) int {
buffer.WriteString(op.text)
if op.tokens.Get(op.index).GetTokenType() != TokenEOF{
if op.tokens.Get(op.index).GetTokenType() != TokenEOF {
buffer.WriteString(op.tokens.Get(op.index).GetText())
}
return op.index+1
return op.index + 1
}
func (op *InsertBeforeOp) String() string {
@ -213,20 +209,20 @@ type InsertAfterOp struct {
BaseRewriteOperation
}
func NewInsertAfterOp(index int, text string, stream TokenStream) *InsertAfterOp{
return &InsertAfterOp{BaseRewriteOperation:BaseRewriteOperation{
index:index+1,
text:text,
tokens:stream,
func NewInsertAfterOp(index int, text string, stream TokenStream) *InsertAfterOp {
return &InsertAfterOp{BaseRewriteOperation: BaseRewriteOperation{
index: index + 1,
text: text,
tokens: stream,
}}
}
func (op *InsertAfterOp) Execute(buffer *bytes.Buffer) int {
buffer.WriteString(op.text)
if op.tokens.Get(op.index).GetTokenType() != TokenEOF{
if op.tokens.Get(op.index).GetTokenType() != TokenEOF {
buffer.WriteString(op.tokens.Get(op.index).GetText())
}
return op.index+1
return op.index + 1
}
func (op *InsertAfterOp) String() string {
@ -235,28 +231,28 @@ func (op *InsertAfterOp) String() string {
// I'm going to try replacing range from x..y with (y-x)+1 ReplaceOp
// instructions.
type ReplaceOp struct{
type ReplaceOp struct {
BaseRewriteOperation
LastIndex int
}
func NewReplaceOp(from, to int, text string, stream TokenStream)*ReplaceOp {
func NewReplaceOp(from, to int, text string, stream TokenStream) *ReplaceOp {
return &ReplaceOp{
BaseRewriteOperation:BaseRewriteOperation{
index:from,
text:text,
op_name:"ReplaceOp",
tokens:stream,
BaseRewriteOperation: BaseRewriteOperation{
index: from,
text: text,
op_name: "ReplaceOp",
tokens: stream,
},
LastIndex:to,
LastIndex: to,
}
}
func (op *ReplaceOp)Execute(buffer *bytes.Buffer) int{
if op.text != ""{
func (op *ReplaceOp) Execute(buffer *bytes.Buffer) int {
if op.text != "" {
buffer.WriteString(op.text)
}
return op.LastIndex +1
return op.LastIndex + 1
}
func (op *ReplaceOp) String() string {
@ -268,54 +264,54 @@ func (op *ReplaceOp) String() string {
op.tokens.Get(op.index), op.tokens.Get(op.LastIndex), op.text)
}
type TokenStreamRewriter struct {
//Our source stream
tokens TokenStream
tokens TokenStream
// You may have multiple, named streams of rewrite operations.
// I'm calling these things "programs."
// Maps String (name) &rarr; rewrite (List)
programs map[string][]RewriteOperation
last_rewrite_token_indexes map[string]int
programs map[string][]RewriteOperation
last_rewrite_token_indexes map[string]int
}
func NewTokenStreamRewriter(tokens TokenStream) *TokenStreamRewriter{
func NewTokenStreamRewriter(tokens TokenStream) *TokenStreamRewriter {
return &TokenStreamRewriter{
tokens: tokens,
programs: map[string][]RewriteOperation{
Default_Program_Name:make([]RewriteOperation,0, Program_Init_Size),
tokens: tokens,
programs: map[string][]RewriteOperation{
Default_Program_Name: make([]RewriteOperation, 0, Program_Init_Size),
},
last_rewrite_token_indexes: map[string]int{},
last_rewrite_token_indexes: map[string]int{},
}
}
func (tsr *TokenStreamRewriter) GetTokenStream() TokenStream{
func (tsr *TokenStreamRewriter) GetTokenStream() TokenStream {
return tsr.tokens
}
// Rollback the instruction stream for a program so that
// the indicated instruction (via instructionIndex) is no
// longer in the stream. UNTESTED!
func (tsr *TokenStreamRewriter) Rollback(program_name string, instruction_index int){
is, ok := tsr.programs[program_name]
if ok{
// Rollback the instruction stream for a program so that
// the indicated instruction (via instructionIndex) is no
// longer in the stream. UNTESTED!
func (tsr *TokenStreamRewriter) Rollback(program_name string, instruction_index int) {
is, ok := tsr.programs[program_name]
if ok {
tsr.programs[program_name] = is[Min_Token_Index:instruction_index]
}
}
func (tsr *TokenStreamRewriter) RollbackDefault(instruction_index int){
func (tsr *TokenStreamRewriter) RollbackDefault(instruction_index int) {
tsr.Rollback(Default_Program_Name, instruction_index)
}
//Reset the program so that no instructions exist
func (tsr *TokenStreamRewriter) DeleteProgram(program_name string){
// Reset the program so that no instructions exist
func (tsr *TokenStreamRewriter) DeleteProgram(program_name string) {
tsr.Rollback(program_name, Min_Token_Index) //TODO: double test on that cause lower bound is not included
}
func (tsr *TokenStreamRewriter) DeleteProgramDefault(){
func (tsr *TokenStreamRewriter) DeleteProgramDefault() {
tsr.DeleteProgram(Default_Program_Name)
}
func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text string){
func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text string) {
// to insert after, just insert before next index (even if past end)
var op RewriteOperation = NewInsertAfterOp(index, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name)
@ -323,31 +319,31 @@ func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text
tsr.AddToProgram(program_name, op)
}
func (tsr *TokenStreamRewriter) InsertAfterDefault(index int, text string){
func (tsr *TokenStreamRewriter) InsertAfterDefault(index int, text string) {
tsr.InsertAfter(Default_Program_Name, index, text)
}
func (tsr *TokenStreamRewriter) InsertAfterToken(program_name string, token Token, text string){
func (tsr *TokenStreamRewriter) InsertAfterToken(program_name string, token Token, text string) {
tsr.InsertAfter(program_name, token.GetTokenIndex(), text)
}
func (tsr* TokenStreamRewriter) InsertBefore(program_name string, index int, text string){
func (tsr *TokenStreamRewriter) InsertBefore(program_name string, index int, text string) {
var op RewriteOperation = NewInsertBeforeOp(index, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name)
op.SetInstructionIndex(len(rewrites))
tsr.AddToProgram(program_name, op)
}
func (tsr *TokenStreamRewriter) InsertBeforeDefault(index int, text string){
func (tsr *TokenStreamRewriter) InsertBeforeDefault(index int, text string) {
tsr.InsertBefore(Default_Program_Name, index, text)
}
func (tsr *TokenStreamRewriter) InsertBeforeToken(program_name string,token Token, text string){
func (tsr *TokenStreamRewriter) InsertBeforeToken(program_name string, token Token, text string) {
tsr.InsertBefore(program_name, token.GetTokenIndex(), text)
}
func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text string){
if from > to || from < 0 || to < 0 || to >= tsr.tokens.Size(){
func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text string) {
if from > to || from < 0 || to < 0 || to >= tsr.tokens.Size() {
panic(fmt.Sprintf("replace: range invalid: %d..%d(size=%d)",
from, to, tsr.tokens.Size()))
}
@ -357,207 +353,216 @@ func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text
tsr.AddToProgram(program_name, op)
}
func (tsr *TokenStreamRewriter)ReplaceDefault(from, to int, text string) {
func (tsr *TokenStreamRewriter) ReplaceDefault(from, to int, text string) {
tsr.Replace(Default_Program_Name, from, to, text)
}
func (tsr *TokenStreamRewriter)ReplaceDefaultPos(index int, text string){
func (tsr *TokenStreamRewriter) ReplaceDefaultPos(index int, text string) {
tsr.ReplaceDefault(index, index, text)
}
func (tsr *TokenStreamRewriter)ReplaceToken(program_name string, from, to Token, text string){
func (tsr *TokenStreamRewriter) ReplaceToken(program_name string, from, to Token, text string) {
tsr.Replace(program_name, from.GetTokenIndex(), to.GetTokenIndex(), text)
}
func (tsr *TokenStreamRewriter)ReplaceTokenDefault(from, to Token, text string){
func (tsr *TokenStreamRewriter) ReplaceTokenDefault(from, to Token, text string) {
tsr.ReplaceToken(Default_Program_Name, from, to, text)
}
func (tsr *TokenStreamRewriter)ReplaceTokenDefaultPos(index Token, text string){
func (tsr *TokenStreamRewriter) ReplaceTokenDefaultPos(index Token, text string) {
tsr.ReplaceTokenDefault(index, index, text)
}
func (tsr *TokenStreamRewriter)Delete(program_name string, from, to int){
tsr.Replace(program_name, from, to, "" )
func (tsr *TokenStreamRewriter) Delete(program_name string, from, to int) {
tsr.Replace(program_name, from, to, "")
}
func (tsr *TokenStreamRewriter)DeleteDefault(from, to int){
func (tsr *TokenStreamRewriter) DeleteDefault(from, to int) {
tsr.Delete(Default_Program_Name, from, to)
}
func (tsr *TokenStreamRewriter)DeleteDefaultPos(index int){
tsr.DeleteDefault(index,index)
func (tsr *TokenStreamRewriter) DeleteDefaultPos(index int) {
tsr.DeleteDefault(index, index)
}
func (tsr *TokenStreamRewriter)DeleteToken(program_name string, from, to Token) {
func (tsr *TokenStreamRewriter) DeleteToken(program_name string, from, to Token) {
tsr.ReplaceToken(program_name, from, to, "")
}
func (tsr *TokenStreamRewriter)DeleteTokenDefault(from,to Token){
func (tsr *TokenStreamRewriter) DeleteTokenDefault(from, to Token) {
tsr.DeleteToken(Default_Program_Name, from, to)
}
func (tsr *TokenStreamRewriter)GetLastRewriteTokenIndex(program_name string)int {
func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndex(program_name string) int {
i, ok := tsr.last_rewrite_token_indexes[program_name]
if !ok{
if !ok {
return -1
}
return i
}
func (tsr *TokenStreamRewriter)GetLastRewriteTokenIndexDefault()int{
func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndexDefault() int {
return tsr.GetLastRewriteTokenIndex(Default_Program_Name)
}
func (tsr *TokenStreamRewriter)SetLastRewriteTokenIndex(program_name string, i int){
func (tsr *TokenStreamRewriter) SetLastRewriteTokenIndex(program_name string, i int) {
tsr.last_rewrite_token_indexes[program_name] = i
}
func (tsr *TokenStreamRewriter)InitializeProgram(name string)[]RewriteOperation{
func (tsr *TokenStreamRewriter) InitializeProgram(name string) []RewriteOperation {
is := make([]RewriteOperation, 0, Program_Init_Size)
tsr.programs[name] = is
return is
}
func (tsr *TokenStreamRewriter)AddToProgram(name string, op RewriteOperation){
func (tsr *TokenStreamRewriter) AddToProgram(name string, op RewriteOperation) {
is := tsr.GetProgram(name)
is = append(is, op)
tsr.programs[name] = is
}
func (tsr *TokenStreamRewriter)GetProgram(name string) []RewriteOperation {
func (tsr *TokenStreamRewriter) GetProgram(name string) []RewriteOperation {
is, ok := tsr.programs[name]
if !ok{
if !ok {
is = tsr.InitializeProgram(name)
}
return is
}
// Return the text from the original tokens altered per the
// instructions given to this rewriter.
func (tsr *TokenStreamRewriter)GetTextDefault() string{
// Return the text from the original tokens altered per the
// instructions given to this rewriter.
func (tsr *TokenStreamRewriter) GetTextDefault() string {
return tsr.GetText(
Default_Program_Name,
NewInterval(0, tsr.tokens.Size()-1))
}
// Return the text from the original tokens altered per the
// instructions given to this rewriter.
func (tsr *TokenStreamRewriter)GetText(program_name string, interval *Interval) string {
// Return the text from the original tokens altered per the
// instructions given to this rewriter.
func (tsr *TokenStreamRewriter) GetText(program_name string, interval *Interval) string {
rewrites := tsr.programs[program_name]
start := interval.Start
stop := interval.Stop
stop := interval.Stop
// ensure start/end are in range
stop = min(stop, tsr.tokens.Size()-1)
start = max(start,0)
if rewrites == nil || len(rewrites) == 0{
start = max(start, 0)
if rewrites == nil || len(rewrites) == 0 {
return tsr.tokens.GetTextFromInterval(interval) // no instructions to execute
}
buf := bytes.Buffer{}
// First, optimize instruction stream
indexToOp := reduceToSingleOperationPerIndex(rewrites)
// Walk buffer, executing instructions and emitting tokens
for i:=start; i<=stop && i<tsr.tokens.Size();{
for i := start; i <= stop && i < tsr.tokens.Size(); {
op := indexToOp[i]
delete(indexToOp, i)// remove so any left have index size-1
delete(indexToOp, i) // remove so any left have index size-1
t := tsr.tokens.Get(i)
if op == nil{
if op == nil {
// no operation at that index, just dump token
if t.GetTokenType() != TokenEOF {buf.WriteString(t.GetText())}
if t.GetTokenType() != TokenEOF {
buf.WriteString(t.GetText())
}
i++ // move to next token
}else {
i = op.Execute(&buf)// execute operation and skip
} else {
i = op.Execute(&buf) // execute operation and skip
}
}
// include stuff after end if it's last index in buffer
// So, if they did an insertAfter(lastValidIndex, "foo"), include
// foo if end==lastValidIndex.
if stop == tsr.tokens.Size()-1{
if stop == tsr.tokens.Size()-1 {
// Scan any remaining operations after last token
// should be included (they will be inserts).
for _, op := range indexToOp{
if op.GetIndex() >= tsr.tokens.Size()-1 {buf.WriteString(op.GetText())}
for _, op := range indexToOp {
if op.GetIndex() >= tsr.tokens.Size()-1 {
buf.WriteString(op.GetText())
}
}
}
return buf.String()
}
// We need to combine operations and report invalid operations (like
// overlapping replaces that are not completed nested). Inserts to
// same index need to be combined etc... Here are the cases:
// We need to combine operations and report invalid operations (like
// overlapping replaces that are not completed nested). Inserts to
// same index need to be combined etc... Here are the cases:
//
// I.i.u I.j.v leave alone, nonoverlapping
// I.i.u I.i.v combine: Iivu
// I.i.u I.j.v leave alone, nonoverlapping
// I.i.u I.i.v combine: Iivu
//
// R.i-j.u R.x-y.v | i-j in x-y delete first R
// R.i-j.u R.i-j.v delete first R
// R.i-j.u R.x-y.v | x-y in i-j ERROR
// R.i-j.u R.x-y.v | boundaries overlap ERROR
// R.i-j.u R.x-y.v | i-j in x-y delete first R
// R.i-j.u R.i-j.v delete first R
// R.i-j.u R.x-y.v | x-y in i-j ERROR
// R.i-j.u R.x-y.v | boundaries overlap ERROR
//
// Delete special case of replace (text==null):
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
// Delete special case of replace (text==null):
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
//
// I.i.u R.x-y.v | i in (x+1)-y delete I (since insert before
// we're not deleting i)
// I.i.u R.x-y.v | i not in (x+1)-y leave alone, nonoverlapping
// R.x-y.v I.i.u | i in x-y ERROR
// R.x-y.v I.x.u R.x-y.uv (combine, delete I)
// R.x-y.v I.i.u | i not in x-y leave alone, nonoverlapping
// I.i.u R.x-y.v | i in (x+1)-y delete I (since insert before
// we're not deleting i)
// I.i.u R.x-y.v | i not in (x+1)-y leave alone, nonoverlapping
// R.x-y.v I.i.u | i in x-y ERROR
// R.x-y.v I.x.u R.x-y.uv (combine, delete I)
// R.x-y.v I.i.u | i not in x-y leave alone, nonoverlapping
//
// I.i.u = insert u before op @ index i
// R.x-y.u = replace x-y indexed tokens with u
// I.i.u = insert u before op @ index i
// R.x-y.u = replace x-y indexed tokens with u
//
// First we need to examine replaces. For any replace op:
// First we need to examine replaces. For any replace op:
//
// 1. wipe out any insertions before op within that range.
// 2. Drop any replace op before that is contained completely within
// that range.
// 3. Throw exception upon boundary overlap with any previous replace.
// 1. wipe out any insertions before op within that range.
// 2. Drop any replace op before that is contained completely within
// that range.
// 3. Throw exception upon boundary overlap with any previous replace.
//
// Then we can deal with inserts:
// Then we can deal with inserts:
//
// 1. for any inserts to same index, combine even if not adjacent.
// 2. for any prior replace with same left boundary, combine this
// insert with replace and delete this replace.
// 3. throw exception if index in same range as previous replace
// 1. for any inserts to same index, combine even if not adjacent.
// 2. for any prior replace with same left boundary, combine this
// insert with replace and delete this replace.
// 3. throw exception if index in same range as previous replace
//
// Don't actually delete; make op null in list. Easier to walk list.
// Later we can throw as we add to index &rarr; op map.
// Don't actually delete; make op null in list. Easier to walk list.
// Later we can throw as we add to index &rarr; op map.
//
// Note that I.2 R.2-2 will wipe out I.2 even though, technically, the
// inserted stuff would be before the replace range. But, if you
// add tokens in front of a method body '{' and then delete the method
// body, I think the stuff before the '{' you added should disappear too.
// Note that I.2 R.2-2 will wipe out I.2 even though, technically, the
// inserted stuff would be before the replace range. But, if you
// add tokens in front of a method body '{' and then delete the method
// body, I think the stuff before the '{' you added should disappear too.
//
// Return a map from token index to operation.
//
func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]RewriteOperation{
// Return a map from token index to operation.
func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]RewriteOperation {
// WALK REPLACES
for i:=0; i < len(rewrites); i++{
for i := 0; i < len(rewrites); i++ {
op := rewrites[i]
if op == nil{continue}
if op == nil {
continue
}
rop, ok := op.(*ReplaceOp)
if !ok{continue}
if !ok {
continue
}
// Wipe prior inserts within range
for j:=0; j<i && j < len(rewrites); j++{
if iop, ok := rewrites[j].(*InsertBeforeOp);ok{
if iop.index == rop.index{
for j := 0; j < i && j < len(rewrites); j++ {
if iop, ok := rewrites[j].(*InsertBeforeOp); ok {
if iop.index == rop.index {
// E.g., insert before 2, delete 2..2; update replace
// text to include insert before, kill insert
rewrites[iop.instruction_index] = nil
if rop.text != ""{
if rop.text != "" {
rop.text = iop.text + rop.text
}else{
} else {
rop.text = iop.text
}
}else if iop.index > rop.index && iop.index <=rop.LastIndex{
} else if iop.index > rop.index && iop.index <= rop.LastIndex {
// delete insert as it's a no-op.
rewrites[iop.instruction_index] = nil
}
}
}
// Drop any prior replaces contained within
for j:=0; j<i && j < len(rewrites); j++{
if prevop, ok := rewrites[j].(*ReplaceOp);ok{
if prevop.index>=rop.index && prevop.LastIndex <= rop.LastIndex{
for j := 0; j < i && j < len(rewrites); j++ {
if prevop, ok := rewrites[j].(*ReplaceOp); ok {
if prevop.index >= rop.index && prevop.LastIndex <= rop.LastIndex {
// delete replace as it's a no-op.
rewrites[prevop.instruction_index] = nil
continue
@ -566,61 +571,67 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
disjoint := prevop.LastIndex < rop.index || prevop.index > rop.LastIndex
// Delete special case of replace (text==null):
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
if prevop.text == "" && rop.text == "" && !disjoint{
if prevop.text == "" && rop.text == "" && !disjoint {
rewrites[prevop.instruction_index] = nil
rop.index = min(prevop.index, rop.index)
rop.LastIndex = max(prevop.LastIndex, rop.LastIndex)
println("new rop" + rop.String()) //TODO: remove console write, taken from Java version
}else if !disjoint{
} else if !disjoint {
panic("replace op boundaries of " + rop.String() + " overlap with previous " + prevop.String())
}
}
}
}
// WALK INSERTS
for i:=0; i < len(rewrites); i++ {
for i := 0; i < len(rewrites); i++ {
op := rewrites[i]
if op == nil{continue}
if op == nil {
continue
}
//hack to replicate inheritance in composition
_, iok := rewrites[i].(*InsertBeforeOp)
_, aok := rewrites[i].(*InsertAfterOp)
if !iok && !aok{continue}
if !iok && !aok {
continue
}
iop := rewrites[i]
// combine current insert with prior if any at same index
// deviating a bit from TokenStreamRewriter.java - hard to incorporate inheritance logic
for j:=0; j<i && j < len(rewrites); j++{
if nextIop, ok := rewrites[j].(*InsertAfterOp); ok{
if nextIop.index == iop.GetIndex(){
for j := 0; j < i && j < len(rewrites); j++ {
if nextIop, ok := rewrites[j].(*InsertAfterOp); ok {
if nextIop.index == iop.GetIndex() {
iop.SetText(nextIop.text + iop.GetText())
rewrites[j] = nil
}
}
if prevIop, ok := rewrites[j].(*InsertBeforeOp); ok{
if prevIop.index == iop.GetIndex(){
if prevIop, ok := rewrites[j].(*InsertBeforeOp); ok {
if prevIop.index == iop.GetIndex() {
iop.SetText(iop.GetText() + prevIop.text)
rewrites[prevIop.instruction_index] = nil
}
}
}
// look for replaces where iop.index is in range; error
for j:=0; j<i && j < len(rewrites); j++{
if rop,ok := rewrites[j].(*ReplaceOp); ok{
if iop.GetIndex() == rop.index{
for j := 0; j < i && j < len(rewrites); j++ {
if rop, ok := rewrites[j].(*ReplaceOp); ok {
if iop.GetIndex() == rop.index {
rop.text = iop.GetText() + rop.text
rewrites[i] = nil
continue
}
if iop.GetIndex() >= rop.index && iop.GetIndex() <= rop.LastIndex{
panic("insert op "+iop.String()+" within boundaries of previous "+rop.String())
if iop.GetIndex() >= rop.index && iop.GetIndex() <= rop.LastIndex {
panic("insert op " + iop.String() + " within boundaries of previous " + rop.String())
}
}
}
}
m := map[int]RewriteOperation{}
for i:=0; i < len(rewrites); i++{
for i := 0; i < len(rewrites); i++ {
op := rewrites[i]
if op == nil {continue}
if _, ok := m[op.GetIndex()]; ok{
if op == nil {
continue
}
if _, ok := m[op.GetIndex()]; ok {
panic("should only be one op per index")
}
m[op.GetIndex()] = op
@ -628,22 +639,21 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
return m
}
/*
Quick fixing Go lack of overloads
*/
*/
func max(a,b int)int{
if a>b{
func max(a, b int) int {
if a > b {
return a
}else {
} else {
return b
}
}
func min(a,b int)int{
if a<b{
func min(a, b int) int {
if a < b {
return a
}else {
} else {
return b
}
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -234,10 +234,8 @@ func (p *ParseTreeWalker) Walk(listener ParseTreeListener, t Tree) {
}
}
//
// Enters a grammar rule by first triggering the generic event {@link ParseTreeListener//EnterEveryRule}
// then by triggering the event specific to the given parse tree node
//
func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) {
ctx := r.GetRuleContext().(ParserRuleContext)
listener.EnterEveryRule(ctx)
@ -246,7 +244,6 @@ func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) {
// Exits a grammar rule by first triggering the event specific to the given parse tree node
// then by triggering the generic event {@link ParseTreeListener//ExitEveryRule}
//
func (p *ParseTreeWalker) ExitRule(listener ParseTreeListener, r RuleNode) {
ctx := r.GetRuleContext().(ParserRuleContext)
ctx.ExitRule(listener)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -9,8 +9,9 @@ import "fmt"
/** A set of utility routines useful for all kinds of ANTLR trees. */
// Print out a whole tree in LISP form. {@link //getNodeText} is used on the
// node payloads to get the text for the nodes. Detect
// parse trees and extract data appropriately.
//
// node payloads to get the text for the nodes. Detect
// parse trees and extract data appropriately.
func TreesStringTree(tree Tree, ruleNames []string, recog Recognizer) string {
if recog != nil {
@ -80,8 +81,8 @@ func TreesGetChildren(t Tree) []Tree {
}
// Return a list of all ancestors of this node. The first node of
// list is the root and the last is the parent of this node.
//
// list is the root and the last is the parent of this node.
func TreesgetAncestors(t Tree) []Tree {
ancestors := make([]Tree, 0)
t = t.GetParent()

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
@ -47,28 +47,25 @@ func (s *IntStack) Push(e int) {
*s = append(*s, e)
}
func standardEqualsFunction(a interface{}, b interface{}) bool {
type comparable interface {
Equals(other Collectable[any]) bool
}
ac, oka := a.(comparable)
bc, okb := b.(comparable)
func standardEqualsFunction(a Collectable[any], b Collectable[any]) bool {
if !oka || !okb {
panic("Not Comparable")
}
return ac.equals(bc)
return a.Equals(b)
}
func standardHashFunction(a interface{}) int {
if h, ok := a.(hasher); ok {
return h.hash()
return h.Hash()
}
panic("Not Hasher")
}
type hasher interface {
hash() int
Hash() int
}
const bitsPerWord = 64
@ -171,7 +168,7 @@ func (b *BitSet) equals(other interface{}) bool {
// We only compare set bits, so we cannot rely on the two slices having the same size. Its
// possible for two BitSets to have different slice lengths but the same set bits. So we only
// compare the relavent words and ignore the trailing zeros.
// compare the relevant words and ignore the trailing zeros.
bLen := b.minLen()
otherLen := otherBitSet.minLen()

View File

@ -8,8 +8,6 @@ const (
_loadFactor = 0.75
)
var _ Set = (*array2DHashSet)(nil)
type Set interface {
Add(value interface{}) (added interface{})
Len() int
@ -20,9 +18,9 @@ type Set interface {
}
type array2DHashSet struct {
buckets [][]interface{}
buckets [][]Collectable[any]
hashcodeFunction func(interface{}) int
equalsFunction func(interface{}, interface{}) bool
equalsFunction func(Collectable[any], Collectable[any]) bool
n int // How many elements in set
threshold int // when to expand
@ -61,11 +59,11 @@ func (as *array2DHashSet) Values() []interface{} {
return values
}
func (as *array2DHashSet) Contains(value interface{}) bool {
func (as *array2DHashSet) Contains(value Collectable[any]) bool {
return as.Get(value) != nil
}
func (as *array2DHashSet) Add(value interface{}) interface{} {
func (as *array2DHashSet) Add(value Collectable[any]) interface{} {
if as.n > as.threshold {
as.expand()
}
@ -98,7 +96,7 @@ func (as *array2DHashSet) expand() {
b := as.getBuckets(o)
bucketLength := newBucketLengths[b]
var newBucket []interface{}
var newBucket []Collectable[any]
if bucketLength == 0 {
// new bucket
newBucket = as.createBucket(as.initialBucketCapacity)
@ -107,7 +105,7 @@ func (as *array2DHashSet) expand() {
newBucket = newTable[b]
if bucketLength == len(newBucket) {
// expand
newBucketCopy := make([]interface{}, len(newBucket)<<1)
newBucketCopy := make([]Collectable[any], len(newBucket)<<1)
copy(newBucketCopy[:bucketLength], newBucket)
newBucket = newBucketCopy
newTable[b] = newBucket
@ -124,7 +122,7 @@ func (as *array2DHashSet) Len() int {
return as.n
}
func (as *array2DHashSet) Get(o interface{}) interface{} {
func (as *array2DHashSet) Get(o Collectable[any]) interface{} {
if o == nil {
return nil
}
@ -147,7 +145,7 @@ func (as *array2DHashSet) Get(o interface{}) interface{} {
return nil
}
func (as *array2DHashSet) innerAdd(o interface{}) interface{} {
func (as *array2DHashSet) innerAdd(o Collectable[any]) interface{} {
b := as.getBuckets(o)
bucket := as.buckets[b]
@ -178,7 +176,7 @@ func (as *array2DHashSet) innerAdd(o interface{}) interface{} {
// full bucket, expand and add to end
oldLength := len(bucket)
bucketCopy := make([]interface{}, oldLength<<1)
bucketCopy := make([]Collectable[any], oldLength<<1)
copy(bucketCopy[:oldLength], bucket)
bucket = bucketCopy
as.buckets[b] = bucket
@ -187,22 +185,22 @@ func (as *array2DHashSet) innerAdd(o interface{}) interface{} {
return o
}
func (as *array2DHashSet) getBuckets(value interface{}) int {
func (as *array2DHashSet) getBuckets(value Collectable[any]) int {
hash := as.hashcodeFunction(value)
return hash & (len(as.buckets) - 1)
}
func (as *array2DHashSet) createBuckets(cap int) [][]interface{} {
return make([][]interface{}, cap)
func (as *array2DHashSet) createBuckets(cap int) [][]Collectable[any] {
return make([][]Collectable[any], cap)
}
func (as *array2DHashSet) createBucket(cap int) []interface{} {
return make([]interface{}, cap)
func (as *array2DHashSet) createBucket(cap int) []Collectable[any] {
return make([]Collectable[any], cap)
}
func newArray2DHashSetWithCap(
hashcodeFunction func(interface{}) int,
equalsFunction func(interface{}, interface{}) bool,
equalsFunction func(Collectable[any], Collectable[any]) bool,
initCap int,
initBucketCap int,
) *array2DHashSet {
@ -231,7 +229,7 @@ func newArray2DHashSetWithCap(
func newArray2DHashSet(
hashcodeFunction func(interface{}) int,
equalsFunction func(interface{}, interface{}) bool,
equalsFunction func(Collectable[any], Collectable[any]) bool,
) *array2DHashSet {
return newArray2DHashSetWithCap(hashcodeFunction, equalsFunction, _initalCapacity, _initalBucketCapacity)
}

View File

@ -1,10 +0,0 @@
language: go
go:
- 1.13
- 1.x
- tip
before_install:
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
script:
- $HOME/gopath/bin/goveralls -service=travis-ci

View File

@ -5,10 +5,20 @@ import (
"time"
)
// An OperationWithData is executing by RetryWithData() or RetryNotifyWithData().
// The operation will be retried using a backoff policy if it returns an error.
type OperationWithData[T any] func() (T, error)
// An Operation is executing by Retry() or RetryNotify().
// The operation will be retried using a backoff policy if it returns an error.
type Operation func() error
func (o Operation) withEmptyData() OperationWithData[struct{}] {
return func() (struct{}, error) {
return struct{}{}, o()
}
}
// Notify is a notify-on-error function. It receives an operation error and
// backoff delay if the operation failed (with an error).
//
@ -28,18 +38,41 @@ func Retry(o Operation, b BackOff) error {
return RetryNotify(o, b, nil)
}
// RetryWithData is like Retry but returns data in the response too.
func RetryWithData[T any](o OperationWithData[T], b BackOff) (T, error) {
return RetryNotifyWithData(o, b, nil)
}
// RetryNotify calls notify function with the error and wait duration
// for each failed attempt before sleep.
func RetryNotify(operation Operation, b BackOff, notify Notify) error {
return RetryNotifyWithTimer(operation, b, notify, nil)
}
// RetryNotifyWithData is like RetryNotify but returns data in the response too.
func RetryNotifyWithData[T any](operation OperationWithData[T], b BackOff, notify Notify) (T, error) {
return doRetryNotify(operation, b, notify, nil)
}
// RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer
// for each failed attempt before sleep.
// A default timer that uses system timer is used when nil is passed.
func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error {
var err error
var next time.Duration
_, err := doRetryNotify(operation.withEmptyData(), b, notify, t)
return err
}
// RetryNotifyWithTimerAndData is like RetryNotifyWithTimer but returns data in the response too.
func RetryNotifyWithTimerAndData[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
return doRetryNotify(operation, b, notify, t)
}
func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
var (
err error
next time.Duration
res T
)
if t == nil {
t = &defaultTimer{}
}
@ -52,21 +85,22 @@ func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer
b.Reset()
for {
if err = operation(); err == nil {
return nil
res, err = operation()
if err == nil {
return res, nil
}
var permanent *PermanentError
if errors.As(err, &permanent) {
return permanent.Err
return res, permanent.Err
}
if next = b.NextBackOff(); next == Stop {
if cerr := ctx.Err(); cerr != nil {
return cerr
return res, cerr
}
return err
return res, err
}
if notify != nil {
@ -77,7 +111,7 @@ func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer
select {
case <-ctx.Done():
return ctx.Err()
return res, ctx.Err()
case <-t.C():
}
}

View File

@ -85,7 +85,7 @@ func (v *Version) Set(version string) error {
return fmt.Errorf("failed to validate metadata: %v", err)
}
parsed := make([]int64, 3, 3)
parsed := make([]int64, 3)
for i, v := range dotParts[:3] {
val, err := strconv.ParseInt(v, 10, 64)

View File

@ -69,6 +69,58 @@ func Enabled() bool {
return true
}
// StderrIsJournalStream returns whether the process stderr is connected
// to the Journal's stream transport.
//
// This can be used for automatic protocol upgrading described in [Journal Native Protocol].
//
// Returns true if JOURNAL_STREAM environment variable is present,
// and stderr's device and inode numbers match it.
//
// Error is returned if unexpected error occurs: e.g. if JOURNAL_STREAM environment variable
// is present, but malformed, fstat syscall fails, etc.
//
// [Journal Native Protocol]: https://systemd.io/JOURNAL_NATIVE_PROTOCOL/#automatic-protocol-upgrading
func StderrIsJournalStream() (bool, error) {
return fdIsJournalStream(syscall.Stderr)
}
// StdoutIsJournalStream returns whether the process stdout is connected
// to the Journal's stream transport.
//
// Returns true if JOURNAL_STREAM environment variable is present,
// and stdout's device and inode numbers match it.
//
// Error is returned if unexpected error occurs: e.g. if JOURNAL_STREAM environment variable
// is present, but malformed, fstat syscall fails, etc.
//
// Most users should probably use [StderrIsJournalStream].
func StdoutIsJournalStream() (bool, error) {
return fdIsJournalStream(syscall.Stdout)
}
func fdIsJournalStream(fd int) (bool, error) {
journalStream := os.Getenv("JOURNAL_STREAM")
if journalStream == "" {
return false, nil
}
var expectedStat syscall.Stat_t
_, err := fmt.Sscanf(journalStream, "%d:%d", &expectedStat.Dev, &expectedStat.Ino)
if err != nil {
return false, fmt.Errorf("failed to parse JOURNAL_STREAM=%q: %v", journalStream, err)
}
var stat syscall.Stat_t
err = syscall.Fstat(fd, &stat)
if err != nil {
return false, err
}
match := stat.Dev == expectedStat.Dev && stat.Ino == expectedStat.Ino
return match, nil
}
// Send a message to the local systemd journal. vars is a map of journald
// fields to values. Fields must be composed of uppercase letters, numbers,
// and underscores, but must not start with an underscore. Within these

View File

@ -33,3 +33,11 @@ func Enabled() bool {
func Send(message string, priority Priority, vars map[string]string) error {
return errors.New("could not initialize socket to journald")
}
func StderrIsJournalStream() (bool, error) {
return false, nil
}
func StdoutIsJournalStream() (bool, error) {
return false, nil
}

View File

@ -568,29 +568,6 @@ func (p Patch) replace(doc *container, op Operation) error {
return errors.Wrapf(err, "replace operation failed to decode path")
}
if path == "" {
val := op.value()
if val.which == eRaw {
if !val.tryDoc() {
if !val.tryAry() {
return errors.Wrapf(err, "replace operation value must be object or array")
}
}
}
switch val.which {
case eAry:
*doc = &val.ary
case eDoc:
*doc = &val.doc
case eRaw:
return errors.Wrapf(err, "replace operation hit impossible case")
}
return nil
}
con, key := findObject(doc, path)
if con == nil {
@ -657,25 +634,6 @@ func (p Patch) test(doc *container, op Operation) error {
return errors.Wrapf(err, "test operation failed to decode path")
}
if path == "" {
var self lazyNode
switch sv := (*doc).(type) {
case *partialDoc:
self.doc = *sv
self.which = eDoc
case *partialArray:
self.ary = *sv
self.which = eAry
}
if self.equal(op.value()) {
return nil
}
return errors.Wrapf(ErrTestFailed, "testing value %s failed", path)
}
con, key := findObject(doc, path)
if con == nil {

View File

@ -26,11 +26,16 @@ var rxDupSlashes = regexp.MustCompile(`/{2,}`)
// - FlagLowercaseHost
// - FlagRemoveDefaultPort
// - FlagRemoveDuplicateSlashes (and this was mixed in with the |)
//
// This also normalizes the URL into its urlencoded form by removing RawPath and RawFragment.
func NormalizeURL(u *url.URL) {
lowercaseScheme(u)
lowercaseHost(u)
removeDefaultPort(u)
removeDuplicateSlashes(u)
u.RawPath = ""
u.RawFragment = ""
}
func lowercaseScheme(u *url.URL) {

View File

@ -23,6 +23,7 @@ go_library(
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
@ -31,7 +32,7 @@ go_library(
"//interpreter:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protodesc:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -69,7 +70,7 @@ go_test(
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],

View File

@ -139,7 +139,7 @@ var (
kind: TypeKind,
runtimeType: types.TypeType,
}
//UintType represents a uint type.
// UintType represents a uint type.
UintType = &Type{
kind: UintKind,
runtimeType: types.UintType,
@ -222,7 +222,8 @@ func (t *Type) equals(other *Type) bool {
// - The from types are the same instance
// - The target type is dynamic
// - The fromType has the same kind and type name as the target type, and all parameters of the target type
// are IsAssignableType() from the parameters of the fromType.
//
// are IsAssignableType() from the parameters of the fromType.
func (t *Type) defaultIsAssignableType(fromType *Type) bool {
if t == fromType || t.isDyn() {
return true
@ -312,6 +313,11 @@ func NullableType(wrapped *Type) *Type {
}
}
// OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
func OptionalType(param *Type) *Type {
return OpaqueType("optional", param)
}
// OpaqueType creates an abstract parameterized type with a given name.
func OpaqueType(name string, params ...*Type) *Type {
return &Type{
@ -365,7 +371,9 @@ func Variable(name string, t *Type) EnvOption {
//
// - Overloads are searched in the order they are declared
// - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents
// at runtime. Empty lists and maps will result in a 'default dispatch'
//
// at runtime. Empty lists and maps will result in a 'default dispatch'
//
// - In the event that a default dispatch occurs, the first overload provided is the one invoked
//
// If you intend to use overloads which differentiate based on the key or element type of a list or
@ -405,7 +413,7 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
// FunctionOpt defines a functional option for configuring a function declaration.
type FunctionOpt func(*functionDecl) (*functionDecl, error)
// SingletonUnaryBinding creates a singleton function defintion to be used for all function overloads.
// SingletonUnaryBinding creates a singleton function definition to be used for all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
@ -431,7 +439,17 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
//
// Deprecated: use SingletonBinaryBinding
func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt {
return SingletonBinaryBinding(fn, traits...)
}
// SingletonBinaryBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
@ -453,7 +471,17 @@ func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt {
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
//
// Deprecated: use SingletonFunctionBinding
func SingletonFunctionImpl(fn functions.FunctionOp, traits ...int) FunctionOpt {
return SingletonFunctionBinding(fn, traits...)
}
// SingletonFunctionBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
@ -720,9 +748,8 @@ func (f *functionDecl) addOverload(overload *overloadDecl) error {
// Allow redefinition of an overload implementation so long as the signatures match.
f.overloads[index] = overload
return nil
} else {
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
}
}
f.overloads = append(f.overloads, overload)
@ -1177,3 +1204,43 @@ func collectParamNames(paramNames map[string]struct{}, arg *Type) {
collectParamNames(paramNames, param)
}
}
func typeValueToKind(tv *types.TypeValue) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}

View File

@ -102,15 +102,18 @@ type Env struct {
provider ref.TypeProvider
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
// Internal parser representation
prsr *parser.Parser
prsr *parser.Parser
prsrOpts []parser.Option
// Internal checker representation
chk *checker.Env
chkErr error
chkOnce sync.Once
chkOpts []checker.Option
chkMutex sync.Mutex
chk *checker.Env
chkErr error
chkOnce sync.Once
chkOpts []checker.Option
// Program options tied to the environment
progOpts []ProgramOption
@ -159,6 +162,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
provider: registry,
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
progOpts: []ProgramOption{},
}).configure(opts)
}
@ -175,14 +179,14 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
err := e.initChecker()
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, e.chkErr.Error())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssues(errs)
}
res, errs := checker.Check(pe, ast.Source(), e.chk)
res, errs := checker.Check(pe, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssues(errs)
}
@ -236,10 +240,14 @@ func (e *Env) CompileSource(src Source) (*Ast, *Issues) {
// TypeProvider are immutable, or that their underlying implementations are based on the
// ref.TypeRegistry which provides a Copy method which will be invoked by this method.
func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
if e.chkErr != nil {
return nil, e.chkErr
chk, chkErr := e.getCheckerOrError()
if chkErr != nil {
return nil, chkErr
}
prsrOptsCopy := make([]parser.Option, len(e.prsrOpts))
copy(prsrOptsCopy, e.prsrOpts)
// The type-checker is configured with Declarations. The declarations may either be provided
// as options which have not yet been validated, or may come from a previous checker instance
// whose types have already been validated.
@ -248,10 +256,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
// Copy the declarations if needed.
decsCopy := []*exprpb.Decl{}
if e.chk != nil {
if chk != nil {
// If the type-checker has already been instantiated, then the e.declarations have been
// valdiated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(e.chk))
// validated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(chk))
} else {
// If the type-checker has not been instantiated, ensure the unvalidated declarations are
// provided to the extended Env instance.
@ -304,8 +312,11 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.functions {
funcsCopy[k] = v
}
libsCopy := make(map[string]bool, len(e.libraries))
for k, v := range e.libraries {
libsCopy[k] = v
}
// TODO: functions copy needs to happen here.
ext := &Env{
Container: e.Container,
declarations: decsCopy,
@ -315,8 +326,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
adapter: adapter,
features: featuresCopy,
appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
}
return ext.configure(opts)
}
@ -328,6 +341,12 @@ func (e *Env) HasFeature(flag int) bool {
return has && enabled
}
// HasLibrary returns whether a specific SingletonLibrary has been configured in the environment.
func (e *Env) HasLibrary(libName string) bool {
configured, exists := e.libraries[libName]
return exists && configured
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
@ -422,8 +441,8 @@ func (e *Env) UnknownVars() interpreter.PartialActivation {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), details.State())
expr, err := AstToString(ParsedExprToAst(&exprpb.ParsedExpr{Expr: pruned}))
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
if err != nil {
return nil, err
}
@ -443,12 +462,12 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator) (checker.CostEstimate, error) {
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked, err := AstToCheckedExpr(ast)
if err != nil {
return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err)
}
return checker.Cost(checked, estimator), nil
return checker.Cost(checked, estimator, opts...)
}
// configure applies a series of EnvOptions to the current environment.
@ -464,17 +483,9 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
}
// If the default UTC timezone fix has been enabled, make sure the library is configured
if e.HasFeature(featureDefaultUTCTimeZone) {
if _, found := e.appliedFeatures[featureDefaultUTCTimeZone]; !found {
e, err = Lib(timeUTCLibrary{})(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[featureDefaultUTCTimeZone] = true
}
e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{}))
if err != nil {
return nil, err
}
// Initialize all of the functions configured within the environment.
@ -486,7 +497,10 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
}
// Configure the parser.
prsrOpts := []parser.Option{parser.Macros(e.macros...)}
prsrOpts := []parser.Option{}
prsrOpts = append(prsrOpts, e.prsrOpts...)
prsrOpts = append(prsrOpts, parser.Macros(e.macros...))
if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
}
@ -497,7 +511,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
// Ensure that the checker init happens eagerly rather than lazily.
if e.HasFeature(featureEagerlyValidateDeclarations) {
err := e.initChecker()
_, err := e.initChecker()
if err != nil {
return nil, err
}
@ -506,7 +520,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return e, nil
}
func (e *Env) initChecker() error {
func (e *Env) initChecker() (*checker.Env, error) {
e.chkOnce.Do(func() {
chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...)
@ -518,32 +532,68 @@ func (e *Env) initChecker() error {
ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
// Add the statically configured declarations.
err = ce.Add(e.declarations...)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
// Add the function declarations which are derived from the FunctionDecl instances.
for _, fn := range e.functions {
fnDecl, err := functionDeclToExprDecl(fn)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
err = ce.Add(fnDecl)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
}
// Add function declarations here separately.
e.chk = ce
e.setCheckerOrError(ce, nil)
})
return e.chkErr
return e.getCheckerOrError()
}
// setCheckerOrError sets the checker.Env or error state in a concurrency-safe manner
func (e *Env) setCheckerOrError(chk *checker.Env, chkErr error) {
e.chkMutex.Lock()
e.chk = chk
e.chkErr = chkErr
e.chkMutex.Unlock()
}
// getCheckerOrError gets the checker.Env or error state in a concurrency-safe manner
func (e *Env) getCheckerOrError() (*checker.Env, error) {
e.chkMutex.Lock()
defer e.chkMutex.Unlock()
return e.chk, e.chkErr
}
// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies
// the feature if it has not already been enabled.
func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) {
if !e.HasFeature(feature) {
return e, nil
}
_, applied := e.appliedFeatures[feature]
if applied {
return e, nil
}
e, err := option(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[feature] = true
return e, nil
}
// Issues defines methods for inspecting the error details of parse and check calls.

View File

@ -19,14 +19,14 @@ import (
"fmt"
"reflect"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
)

View File

@ -20,10 +20,27 @@ import (
"time"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
optMapMacro = "optMap"
hasValueFunc = "hasValue"
optionalNoneFunc = "optional.none"
optionalOfFunc = "optional.of"
optionalOfNonZeroValueFunc = "optional.ofNonZeroValue"
valueFunc = "value"
unusedIterVar = "#unused"
)
// Library provides a collection of EnvOption and ProgramOption values used to configure a CEL
@ -42,10 +59,27 @@ type Library interface {
ProgramOptions() []ProgramOption
}
// SingletonLibrary refines the Library interface to ensure that libraries in this format are only
// configured once within the environment.
type SingletonLibrary interface {
Library
// LibraryName provides a namespaced name which is used to check whether the library has already
// been configured in the environment.
LibraryName() string
}
// Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args,
// and to be linked to each other.
func Lib(l Library) EnvOption {
singleton, isSingleton := l.(SingletonLibrary)
return func(e *Env) (*Env, error) {
if isSingleton {
if e.HasLibrary(singleton.LibraryName()) {
return e, nil
}
e.libraries[singleton.LibraryName()] = true
}
var err error
for _, opt := range l.CompileOptions() {
e, err = opt(e)
@ -67,6 +101,11 @@ func StdLib() EnvOption {
// features documented in the specification.
type stdLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (stdLibrary) LibraryName() string {
return "cel.lib.std"
}
// EnvOptions returns options for the standard CEL function declarations and macros.
func (stdLibrary) CompileOptions() []EnvOption {
return []EnvOption{
@ -82,6 +121,225 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
}
}
type optionalLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (optionalLibrary) LibraryName() string {
return "cel.lib.optional"
}
// CompileOptions implements the Library interface method.
func (optionalLibrary) CompileOptions() []EnvOption {
paramTypeK := TypeParamType("K")
paramTypeV := TypeParamType("V")
optionalTypeV := OptionalType(paramTypeV)
listTypeV := ListType(paramTypeV)
mapTypeKV := MapType(paramTypeK, paramTypeV)
return []EnvOption{
// Enable the optional syntax in the parser.
enableOptionalSyntax(),
// Introduce the optional type.
Types(types.OptionalType),
// Configure the optMap macro.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
Overload("optional_of", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
return types.OptionalOf(value)
}))),
Function(optionalOfNonZeroValueFunc,
Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
v, isZeroer := value.(traits.Zeroer)
if !isZeroer || !v.IsZeroValue() {
return types.OptionalOf(value)
}
return types.OptionalNone
}))),
Function(optionalNoneFunc,
Overload("optional_none", []*Type{}, optionalTypeV,
FunctionBinding(func(values ...ref.Val) ref.Val {
return types.OptionalNone
}))),
Function(valueFunc,
MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return opt.GetValue()
}))),
Function(hasValueFunc,
MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return types.Bool(opt.HasValue())
}))),
// Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the
// evaluation chain.
Function("or",
MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)),
Function("orValue",
MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)),
// OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the
// optput type.
Function(operators.OptSelect,
Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)),
// OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use
// these signatures to determine type-agreement without any special handling.
Function(operators.OptIndex,
Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV),
Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV),
Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
// Index overloads to accommodate using an optional value as the operand.
Function(operators.Index,
Overload("optional_list_index_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("optional_map_index_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, &common.Error{
Message: "optMap() variable name must be a simple identifier",
Location: meh.OffsetLocation(varIdent.GetId()),
}
}
mapExpr := args[1]
return meh.GlobalCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
), nil
}
// ProgramOptions implements the Library interface method.
func (optionalLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
}
}
func enableOptionalSyntax() EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.EnableOptionalSyntax(true))
return e, nil
}
}
func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) {
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
}
args := call.Args()
if len(args) != 2 {
return i, nil
}
switch call.Function() {
case "or":
if call.OverloadID() != "" && call.OverloadID() != "optional_or_optional" {
return i, nil
}
return &evalOptionalOr{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
case "orValue":
if call.OverloadID() != "" && call.OverloadID() != "optional_orValue_value" {
return i, nil
}
return &evalOptionalOrValue{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
default:
return i, nil
}
}
// evalOptionalOr selects between two optional values, either the first if it has a value, or
// the second optional expression is evaluated and returned.
type evalOptionalOr struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOr) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal
}
return opt.rhs.Eval(ctx)
}
// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
// its value is returned, otherwise the alternative value expression is evaluated and returned.
type evalOptionalOrValue struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOrValue) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal.GetValue()
}
return opt.rhs.Eval(ctx)
}
type timeUTCLibrary struct{}
func (timeUTCLibrary) CompileOptions() []EnvOption {

View File

@ -17,6 +17,7 @@ package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -26,8 +27,11 @@ import (
// a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree, or an error
// if the input arguments are not suitable for the expansion requirements for the macro in question.
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
// are not well-formed, the result of the expansion will be an error.
//
// The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call
// and produces as output an Expr ast node.
@ -81,8 +85,10 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// input to produce an output list.
//
// There are two call patterns supported by map:
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {

View File

@ -29,6 +29,7 @@ import (
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
@ -61,6 +62,10 @@ const (
// on a CEL timestamp operation. This fixes the scenario where the input time
// is not already in UTC.
featureDefaultUTCTimeZone
// Enable the use of optional types in the syntax, type-system, type-checking,
// and runtime.
featureOptionalTypes
)
// EnvOption is a functional interface for configuring the environment.
@ -163,19 +168,19 @@ func Container(name string) EnvOption {
// Abbreviations can be useful when working with variables, functions, and especially types from
// multiple namespaces:
//
// // CEL object construction
// qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...}
// }
// // CEL object construction
// qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...}
// }
//
// Only one the qualified names above may be used as the CEL container, so at least one of these
// references must be a long qualified name within an otherwise short CEL program. Using the
// following abbreviations, the program becomes much simpler:
//
// // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}}
// // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}}
//
// There are a few rules for the qualified names and the simple abbreviations generated from them:
// - Qualified names must be dot-delimited, e.g. `package.subpkg.name`.
@ -188,9 +193,12 @@ func Container(name string) EnvOption {
// - Expanded abbreviations do not participate in namespace resolution.
// - Abbreviation expansion is done instead of the container search for a matching identifier.
// - Containers follow C++ namespace resolution rules with searches from the most qualified name
// to the least qualified name.
//
// to the least qualified name.
//
// - Container references within the CEL program may be relative, and are resolved to fully
// qualified names at either type-check time or program plan time, whichever comes first.
//
// qualified names at either type-check time or program plan time, whichever comes first.
//
// If there is ever a case where an identifier could be in both the container and as an
// abbreviation, the abbreviation wins as this will ensure that the meaning of a program is
@ -216,7 +224,7 @@ func Abbrevs(qualifiedNames ...string) EnvOption {
// environment by default.
//
// Note: This option must be specified after the CustomTypeProvider option when used together.
func Types(addTypes ...interface{}) EnvOption {
func Types(addTypes ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg {
@ -253,7 +261,7 @@ func Types(addTypes ...interface{}) EnvOption {
//
// TypeDescs are hermetic to a single Env object, but may be copied to other Env values via
// extension or by re-using the same EnvOption with another NewEnv() call.
func TypeDescs(descs ...interface{}) EnvOption {
func TypeDescs(descs ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg {
@ -350,8 +358,8 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// variables with the same name provided to the Eval() call. If Globals is used in a Library with
// a Lib EnvOption, vars may shadow variables provided by previously added libraries.
//
// The vars value may either be an `interpreter.Activation` instance or a `map[string]interface{}`.
func Globals(vars interface{}) ProgramOption {
// The vars value may either be an `interpreter.Activation` instance or a `map[string]any`.
func Globals(vars any) ProgramOption {
return func(p *prog) (*prog, error) {
defaultVars, err := interpreter.NewActivation(vars)
if err != nil {
@ -404,6 +412,9 @@ const (
// OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails
// cost calculation is available via func ActualCost()
OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
OptCheckStringFormat EvalOption = 1 << iota
)
// EvalOptions sets one or more evaluation options which may affect the evaluation or Result.
@ -534,6 +545,13 @@ func DefaultUTCTimeZone(enabled bool) EnvOption {
return features(featureDefaultUTCTimeZone, enabled)
}
// OptionalTypes enable support for optional syntax and types in CEL. The optional value type makes
// it possible to express whether variables have been provided, whether a result has been computed,
// and in the future whether an object field path, map key value, or list index has a value.
func OptionalTypes() EnvOption {
return Lib(optionalLibrary{})
}
// features sets the given feature flags. See list of Feature constants above.
func features(flag int, enabled bool) EnvOption {
return func(e *Env) (*Env, error) {
@ -541,3 +559,21 @@ func features(flag int, enabled bool) EnvOption {
return e, nil
}
}
// ParserRecursionLimit adjusts the AST depth the parser will tolerate.
// Defaults defined in the parser package.
func ParserRecursionLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.MaxRecursionDepth(limit))
return e, nil
}
}
// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse.
// Defaults defined in the parser package.
func ParserExpressionSizeLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.ExpressionSizeCodePointLimit(limit))
return e, nil
}
}

View File

@ -17,21 +17,20 @@ package cel
import (
"context"
"fmt"
"math"
"sync"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Program is an evaluable view of an Ast.
type Program interface {
// Eval returns the result of an evaluation of the Ast and environment against the input vars.
//
// The vars value may either be an `interpreter.Activation` or a `map[string]interface{}`.
// The vars value may either be an `interpreter.Activation` or a `map[string]any`.
//
// If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will
// be non-nil. Given this caveat on `details`, the return state from evaluation will be:
@ -43,16 +42,16 @@ type Program interface {
// An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption`
// or `ProgramOption` values used in the creation of the evaluation environment or executable
// program.
Eval(interface{}) (ref.Val, *EvalDetails, error)
Eval(any) (ref.Val, *EvalDetails, error)
// ContextEval evaluates the program with a set of input variables and a context object in order
// to support cancellation and timeouts. This method must be used in conjunction with the
// InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation.
//
// The vars value may either be an `interpreter.Activation` or `map[string]interface{}`.
// The vars value may either be an `interpreter.Activation` or `map[string]any`.
//
// The output contract for `ContextEval` is otherwise identical to the `Eval` method.
ContextEval(context.Context, interface{}) (ref.Val, *EvalDetails, error)
ContextEval(context.Context, any) (ref.Val, *EvalDetails, error)
}
// NoVars returns an empty Activation.
@ -65,7 +64,7 @@ func NoVars() interpreter.Activation {
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
func PartialVars(vars interface{},
func PartialVars(vars any,
unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) {
return interpreter.NewPartialActivation(vars, unknowns...)
}
@ -207,6 +206,37 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...*types.TypeValue) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
t, err := ExprTypeToType(ast.typeMap[id])
if err != nil {
return false, err
}
if t.kind == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if k == t.kind {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
@ -268,7 +298,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}
// Eval implements the Program interface method.
func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error) {
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
@ -287,11 +317,11 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
switch v := input.(type) {
case interpreter.Activation:
vars = v
case map[string]interface{}:
case map[string]any:
vars = activationPool.Setup(v)
defer activationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
@ -307,7 +337,7 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
}
// ContextEval implements the Program interface.
func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
@ -318,22 +348,17 @@ func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *Ev
case interpreter.Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
case map[string]interface{}:
case map[string]any:
rawVars := activationPool.Setup(v)
defer activationPool.Put(rawVars)
vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
return p.Eval(vars)
}
// Cost implements the Coster interface method.
func (p *prog) Cost() (min, max int64) {
return estimateCost(p.interpretable)
}
// progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error)
@ -354,7 +379,7 @@ func newProgGen(factory progFactory) (Program, error) {
}
// Eval implements the Program interface method.
func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
@ -379,7 +404,7 @@ func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
}
// ContextEval implements the Program interface method.
func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
@ -406,29 +431,6 @@ func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val
return v, det, nil
}
// Cost implements the Coster interface method.
func (gen *progGen) Cost() (min, max int64) {
// Use an empty state value since no evaluation is performed.
p, err := gen.factory(emptyEvalState, nil)
if err != nil {
return 0, math.MaxInt64
}
return estimateCost(p)
}
// EstimateCost returns the heuristic cost interval for the program.
func EstimateCost(p Program) (min, max int64) {
return estimateCost(p)
}
func estimateCost(i interface{}) (min, max int64) {
c, ok := i.(interpreter.Coster)
if !ok {
return 0, math.MaxInt64
}
return c.Cost()
}
type ctxEvalActivation struct {
parent interpreter.Activation
interrupt <-chan struct{}
@ -438,7 +440,7 @@ type ctxEvalActivation struct {
// ResolveName implements the Activation interface method, but adds a special #interrupted variable
// which is capable of testing whether a 'done' signal is provided from a context.Context channel.
func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) {
func (a *ctxEvalActivation) ResolveName(name string) (any, bool) {
if name == "#interrupted" {
a.interruptCheckCount++
if a.interruptCheckCount%a.interruptCheckFrequency == 0 {
@ -461,7 +463,7 @@ func (a *ctxEvalActivation) Parent() interpreter.Activation {
func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
New: func() any {
return &ctxEvalActivation{}
},
},
@ -483,21 +485,21 @@ func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan s
}
type evalActivation struct {
vars map[string]interface{}
lazyVars map[string]interface{}
vars map[string]any
lazyVars map[string]any
}
// ResolveName looks up the value of the input variable name, if found.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{}
// - func() any
// - func() ref.Val
//
// The lazy binding will only be invoked once per evaluation.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
func (a *evalActivation) ResolveName(name string) (any, bool) {
v, found := a.vars[name]
if !found {
return nil, false
@ -510,7 +512,7 @@ func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
lazy := obj()
a.lazyVars[name] = lazy
return lazy, true
case func() interface{}:
case func() any:
if resolved, found := a.lazyVars[name]; found {
return resolved, true
}
@ -530,8 +532,8 @@ func (a *evalActivation) Parent() interpreter.Activation {
func newEvalActivationPool() *evalActivationPool {
return &evalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &evalActivation{lazyVars: make(map[string]interface{})}
New: func() any {
return &evalActivation{lazyVars: make(map[string]any)}
},
},
}
@ -542,13 +544,13 @@ type evalActivationPool struct {
}
// Setup initializes a pooled Activation object with the map input.
func (p *evalActivationPool) Setup(vars map[string]interface{}) *evalActivation {
func (p *evalActivationPool) Setup(vars map[string]any) *evalActivation {
a := p.Pool.Get().(*evalActivation)
a.vars = vars
return a
}
func (p *evalActivationPool) Put(value interface{}) {
func (p *evalActivationPool) Put(value any) {
a := value.(*evalActivation)
for k := range a.lazyVars {
delete(a.lazyVars, k)
@ -559,7 +561,7 @@ func (p *evalActivationPool) Put(value interface{}) {
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]interface{} inputs
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()
// ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable

View File

@ -30,7 +30,7 @@ go_library(
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
@ -54,7 +54,7 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

View File

@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
@ -173,8 +174,8 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name.
c.setType(e, ident.GetIdent().Type)
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().Value))
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
identName := ident.GetName()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
@ -185,9 +186,37 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
}
}
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType = decls.Bool
}
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkOptSelect(e *exprpb.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.GetCallExpr()
operand := call.GetArgs()[0]
field := call.GetArgs()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.ReportError(c.location(field), "unsupported optional field selection: %v", field)
return
}
// Perform type-checking using the field selection logic.
resultType := c.checkSelectField(e, operand, fieldName, true)
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *exprpb.Type {
// Interpret as field selection, first traversing down the operand.
c.check(sel.GetOperand())
targetType := substitute(c.mappings, c.getType(sel.GetOperand()), false)
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
// If the target type is 'optional', unwrap it for the sake of this check.
targetType, isOpt := maybeUnwrapOptional(operandType)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
@ -199,7 +228,7 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), sel.GetField()); found {
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), field); found {
resultType = fieldType.Type
}
case kindTypeParam:
@ -212,16 +241,17 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
default:
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if isDynOrError(targetType) {
resultType = decls.Dyn
} else {
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType)
}
resultType = decls.Dyn
}
if sel.TestOnly {
resultType = decls.Bool
// If the target type was optional coming in, then the result must be optional going out.
if isOpt || optional {
return decls.NewOptionalType(resultType)
}
c.setType(e, substitute(c.mappings, resultType, false))
return resultType
}
func (c *checker) checkCall(e *exprpb.Expr) {
@ -229,15 +259,19 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// please consider the impact on planner.go and consolidate implementations or mirror code
// as appropriate.
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
fnName := call.GetFunction()
if fnName == operators.OptSelect {
c.checkOptSelect(e)
return
}
args := call.GetArgs()
// Traverse arguments.
for _, arg := range args {
c.check(arg)
}
target := call.GetTarget()
// Regular static call with simple name.
if target == nil {
// Check for the existence of the function.
@ -359,6 +393,9 @@ func (c *checker) resolveOverload(
}
if resultType == nil {
for i, arg := range argTypes {
argTypes[i] = substitute(c.mappings, arg, true)
}
c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil)
resultType = decls.Error
return nil
@ -369,16 +406,29 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemType *exprpb.Type
for _, e := range create.GetElements() {
var elemsType *exprpb.Type
optionalIndices := create.GetOptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
optionals[optInd] = true
}
for i, e := range create.GetElements() {
c.check(e)
elemType = c.joinTypes(c.location(e), elemType, c.getType(e))
elemType := c.getType(e)
if optionals[int32(i)] {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(c.location(e), decls.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(c.location(e), elemsType, elemType)
}
if elemType == nil {
if elemsType == nil {
// If the list is empty, assign free type var to elem type.
elemType = c.newTypeVar()
elemsType = c.newTypeVar()
}
c.setType(e, decls.NewListType(elemType))
c.setType(e, decls.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
@ -392,22 +442,31 @@ func (c *checker) checkCreateStruct(e *exprpb.Expr) {
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
var keyType *exprpb.Type
var valueType *exprpb.Type
var mapKeyType *exprpb.Type
var mapValueType *exprpb.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
c.check(key)
keyType = c.joinTypes(c.location(key), keyType, c.getType(key))
mapKeyType = c.joinTypes(c.location(key), mapKeyType, c.getType(key))
c.check(ent.Value)
valueType = c.joinTypes(c.location(ent.Value), valueType, c.getType(ent.Value))
val := ent.GetValue()
c.check(val)
valType := c.getType(val)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(val), decls.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(c.location(val), mapValueType, valType)
}
if keyType == nil {
if mapKeyType == nil {
// If the map is empty, assign free type variables to typeKey and value type.
keyType = c.newTypeVar()
valueType = c.newTypeVar()
mapKeyType = c.newTypeVar()
mapValueType = c.newTypeVar()
}
c.setType(e, decls.NewMapType(keyType, valueType))
c.setType(e, decls.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
@ -449,15 +508,21 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
c.check(value)
fieldType := decls.Error
if t, found := c.lookupFieldType(
c.locationByID(ent.GetId()),
messageType.GetMessageType(),
field); found {
fieldType = t.Type
ft, found := c.lookupFieldType(c.locationByID(ent.GetId()), messageType.GetMessageType(), field)
if found {
fieldType = ft.Type
}
if !c.isAssignable(fieldType, c.getType(value)) {
c.errors.fieldTypeMismatch(
c.locationByID(ent.Id), field, fieldType, c.getType(value))
valType := c.getType(value)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(value), decls.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(c.locationByID(ent.Id), field, fieldType, valType)
}
}
}

View File

@ -92,7 +92,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue))
// converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
@ -258,6 +261,8 @@ type coster struct {
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
@ -280,16 +285,39 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
return 0, false
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator) CostEstimate {
c := coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
// CostOption configures flags which affect cost computations.
type CostOption func(*coster) error
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
if hasCost {
c.presenceTestCost = selectAndIdentCost
return nil
}
c.presenceTestCost = CostEstimate{Min: 0, Max: 0}
return nil
}
return c.cost(checker.GetExpr())
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
if err != nil {
return CostEstimate{}, err
}
}
return c.cost(checker.GetExpr()), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
@ -340,6 +368,12 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
var sum CostEstimate
if sel.GetTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
@ -503,7 +537,10 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
switch overloadID {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
if overloadID == overloads.ExtFormatString {
return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}

View File

@ -13,7 +13,7 @@ go_library(
],
importpath = "github.com/google/cel-go/checker/decls",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],

View File

@ -16,9 +16,9 @@
package decls
import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
emptypb "google.golang.org/protobuf/types/known/emptypb"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
@ -64,6 +64,12 @@ func NewAbstractType(name string, paramTypes ...*exprpb.Type) *exprpb.Type {
ParameterTypes: paramTypes}}}
}
// NewOptionalType constructs an abstract type indicating that the parameterized type
// may be contained within the object.
func NewOptionalType(paramType *exprpb.Type) *exprpb.Type {
return NewAbstractType("optional", paramType)
}
// NewFunctionType creates a function invocation contract, typically only used
// by type-checking steps after overload resolution.
func NewFunctionType(resultType *exprpb.Type,

View File

@ -226,7 +226,7 @@ func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
newOverloads := []*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range overloads {
existing, found := existingOverloads[overload.GetOverloadId()]
if !found || !proto.Equal(existing, overload) {
if !found || !overloadsEqual(existing, overload) {
newOverloads = append(newOverloads, overload)
}
}
@ -264,6 +264,31 @@ func (e *Env) isOverloadDisabled(overloadID string) bool {
return found
}
// overloadsEqual returns whether two overloads have identical signatures.
//
// type parameter names are ignored as they may be specified in any order and have no bearing on overload
// equivalence
func overloadsEqual(o1, o2 *exprpb.Decl_FunctionDecl_Overload) bool {
return o1.GetOverloadId() == o2.GetOverloadId() &&
o1.GetIsInstanceFunction() == o2.GetIsInstanceFunction() &&
paramsEqual(o1.GetParams(), o2.GetParams()) &&
proto.Equal(o1.GetResultType(), o2.GetResultType())
}
// paramsEqual returns whether two lists have equal length and all types are equal
func paramsEqual(p1, p2 []*exprpb.Type) bool {
if len(p1) != len(p2) {
return false
}
for i, a := range p1 {
b := p2[i]
if !proto.Equal(a, b) {
return false
}
}
return true
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {

View File

@ -26,7 +26,7 @@ type semanticAdorner struct {
var _ debug.Adorner = &semanticAdorner{}
func (a *semanticAdorner) GetMetadata(elem interface{}) string {
func (a *semanticAdorner) GetMetadata(elem any) string {
result := ""
e, isExpr := elem.(*exprpb.Expr)
if !isExpr {

View File

@ -287,6 +287,8 @@ func init() {
decls.NewInstanceOverload(overloads.EndsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.Matches,
decls.NewOverload(overloads.Matches,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewInstanceOverload(overloads.MatchesString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.StartsWith,

View File

@ -90,6 +90,14 @@ func FormatCheckedType(t *exprpb.Type) string {
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
@ -110,12 +118,39 @@ func isDyn(t *exprpb.Type) bool {
// isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message.
func isDynOrError(t *exprpb.Type) bool {
switch kindOf(t) {
case kindError:
return true
default:
return isDyn(t)
return isError(t) || isDyn(t)
}
func isError(t *exprpb.Type) bool {
return kindOf(t) == kindError
}
func isOptional(t *exprpb.Type) bool {
if kindOf(t) == kindAbstract {
at := t.GetAbstractType()
return at.GetName() == "optional"
}
return false
}
func maybeUnwrapOptional(t *exprpb.Type) (*exprpb.Type, bool) {
if isOptional(t) {
at := t.GetAbstractType()
return at.GetParameterTypes()[0], true
}
return t, false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
return "", false
}
// isEqualOrLessSpecific checks whether one type is equal or less specific than the other one.
@ -236,7 +271,7 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// substitution for t1, and whether t2 has a type substitution in mapping m.
//
// The type t2 is a valid substitution for t1 if any of the following statements is true
// - t2 has a type substitition (t2sub) equal to t1
// - t2 has a type substitution (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {

View File

@ -17,7 +17,7 @@ go_library(
importpath = "github.com/google/cel-go/common",
deps = [
"//common/runes:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_x_text//width:go_default_library",
],
)

View File

@ -12,7 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/containers",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)
@ -26,6 +26,6 @@ go_test(
":go_default_library",
],
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@ -13,6 +13,6 @@ go_library(
importpath = "github.com/google/cel-go/common/debug",
deps = [
"//common:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@ -29,7 +29,7 @@ import (
// representation of an expression.
type Adorner interface {
// GetMetadata for the input context.
GetMetadata(ctx interface{}) string
GetMetadata(ctx any) string
}
// Writer manages writing expressions to an internal string.
@ -46,7 +46,7 @@ type emptyDebugAdorner struct {
var emptyAdorner Adorner = &emptyDebugAdorner{}
func (a *emptyDebugAdorner) GetMetadata(e interface{}) string {
func (a *emptyDebugAdorner) GetMetadata(e any) string {
return ""
}
@ -170,6 +170,9 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
w.append("?")
}
w.append(entry.GetFieldKey())
w.append(":")
w.Buffer(entry.GetValue())
@ -191,6 +194,9 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
w.append("?")
}
w.Buffer(entry.GetMapKey())
w.append(":")
w.Buffer(entry.GetValue())
@ -269,7 +275,7 @@ func (w *debugWriter) append(s string) {
w.buffer.WriteString(s)
}
func (w *debugWriter) appendFormat(f string, args ...interface{}) {
func (w *debugWriter) appendFormat(f string, args ...any) {
w.append(fmt.Sprintf(f, args...))
}
@ -280,7 +286,7 @@ func (w *debugWriter) doIndent() {
}
}
func (w *debugWriter) adorn(e interface{}) {
func (w *debugWriter) adorn(e any) {
w.append(w.adorner.GetMetadata(e))
}

View File

@ -38,7 +38,7 @@ func NewErrors(source Source) *Errors {
}
// ReportError records an error at a source location.
func (e *Errors) ReportError(l Location, format string, args ...interface{}) {
func (e *Errors) ReportError(l Location, format string, args ...any) {
e.numErrors++
if e.numErrors > e.maxErrorsToReport {
return

View File

@ -37,6 +37,8 @@ const (
Modulo = "_%_"
Negate = "-_"
Index = "_[_]"
OptIndex = "_[?_]"
OptSelect = "_?._"
// Macros, must have a valid identifier.
Has = "has"
@ -99,6 +101,8 @@ var (
LogicalNot: {displayName: "!", precedence: 2, arity: 1},
Negate: {displayName: "-", precedence: 2, arity: 1},
Index: {displayName: "", precedence: 1, arity: 2},
OptIndex: {displayName: "", precedence: 1, arity: 2},
OptSelect: {displayName: "", precedence: 1, arity: 2},
}
)

View File

@ -148,6 +148,11 @@ const (
StartsWith = "startsWith"
)
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtQuoteString = "strings_quote"
)
// String function overload names.
const (
ContainsString = "contains_string"
@ -156,6 +161,11 @@ const (
StartsWithString = "starts_with_string"
)
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtFormatString = "string_format"
)
// Time-based functions.
const (
TimeGetFullYear = "getFullYear"

View File

@ -22,6 +22,7 @@ go_library(
"map.go",
"null.go",
"object.go",
"optional.go",
"overflow.go",
"provider.go",
"string.go",
@ -38,10 +39,8 @@ go_library(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto//googleapis/rpc/status:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_rpc//status:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -68,6 +67,7 @@ go_test(
"map_test.go",
"null_test.go",
"object_test.go",
"optional_test.go",
"provider_test.go",
"string_test.go",
"timestamp_test.go",
@ -80,7 +80,7 @@ go_test(
"//common/types/ref:go_default_library",
"//test:go_default_library",
"//test/proto3pb:test_all_types_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",

View File

@ -62,7 +62,7 @@ func (b Bool) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (b Bool) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (b Bool) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Bool:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@ -114,6 +114,11 @@ func (b Bool) Equal(other ref.Val) ref.Val {
return Bool(ok && b == otherBool)
}
// IsZeroValue returns true if the boolean value is false.
func (b Bool) IsZeroValue() bool {
return b == False
}
// Negate implements the traits.Negater interface method.
func (b Bool) Negate() ref.Val {
return !b
@ -125,7 +130,7 @@ func (b Bool) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (b Bool) Value() interface{} {
func (b Bool) Value() any {
return bool(b)
}

View File

@ -63,7 +63,7 @@ func (b Bytes) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Array, reflect.Slice:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@ -116,6 +116,11 @@ func (b Bytes) Equal(other ref.Val) ref.Val {
return Bool(ok && bytes.Equal(b, otherBytes))
}
// IsZeroValue returns true if the byte array is empty.
func (b Bytes) IsZeroValue() bool {
return len(b) == 0
}
// Size implements the traits.Sizer interface method.
func (b Bytes) Size() ref.Val {
return Int(len(b))
@ -127,6 +132,6 @@ func (b Bytes) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (b Bytes) Value() interface{} {
func (b Bytes) Value() any {
return []byte(b)
}

View File

@ -78,7 +78,7 @@ func (d Double) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (d Double) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Double) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Float32:
v := float32(d)
@ -134,13 +134,13 @@ func (d Double) ConvertToType(typeVal ref.Type) ref.Val {
case IntType:
i, err := doubleToInt64Checked(float64(d))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(i)
case UintType:
i, err := doubleToUint64Checked(float64(d))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(i)
case DoubleType:
@ -182,6 +182,11 @@ func (d Double) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if double value is 0.0
func (d Double) IsZeroValue() bool {
return float64(d) == 0.0
}
// Multiply implements traits.Multiplier.Multiply.
func (d Double) Multiply(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
@ -211,6 +216,6 @@ func (d Double) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (d Double) Value() interface{} {
func (d Double) Value() any {
return float64(d)
}

View File

@ -57,14 +57,14 @@ func (d Duration) Add(other ref.Val) ref.Val {
dur2 := other.(Duration)
val, err := addDurationChecked(d.Duration, dur2.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
case TimestampType:
ts := other.(Timestamp).Time
val, err := addTimeDurationChecked(ts, d.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return timestampOf(val)
}
@ -90,7 +90,7 @@ func (d Duration) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (d Duration) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Duration) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the duration is already assignable to the desired type return it.
if reflect.TypeOf(d.Duration).AssignableTo(typeDesc) {
return d.Duration, nil
@ -138,11 +138,16 @@ func (d Duration) Equal(other ref.Val) ref.Val {
return Bool(ok && d.Duration == otherDur.Duration)
}
// IsZeroValue returns true if the duration value is zero
func (d Duration) IsZeroValue() bool {
return d.Duration == 0
}
// Negate implements traits.Negater.Negate.
func (d Duration) Negate() ref.Val {
val, err := negateDurationChecked(d.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@ -165,7 +170,7 @@ func (d Duration) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractDurationChecked(d.Duration, subtraDur.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@ -176,7 +181,7 @@ func (d Duration) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (d Duration) Value() interface{} {
func (d Duration) Value() any {
return d.Duration
}

View File

@ -22,6 +22,12 @@ import (
"github.com/google/cel-go/common/types/ref"
)
// Error interface which allows types types.Err values to be treated as error values.
type Error interface {
error
ref.Val
}
// Err type which extends the built-in go error and implements ref.Val.
type Err struct {
error
@ -51,7 +57,7 @@ var (
// NewErr creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErr(format string, args ...interface{}) ref.Val {
func NewErr(format string, args ...any) ref.Val {
return &Err{fmt.Errorf(format, args...)}
}
@ -62,7 +68,7 @@ func NoSuchOverloadErr() ref.Val {
// UnsupportedRefValConversionErr returns a types.NewErr instance with a no such conversion
// message that indicates that the native value could not be converted to a CEL ref.Val.
func UnsupportedRefValConversionErr(val interface{}) ref.Val {
func UnsupportedRefValConversionErr(val any) ref.Val {
return NewErr("unsupported conversion to ref.Val: (%T)%v", val, val)
}
@ -74,20 +80,20 @@ func MaybeNoSuchOverloadErr(val ref.Val) ref.Val {
// ValOrErr either returns the existing error or creates a new one.
// TODO: Audit the use of this function and standardize the error messages and codes.
func ValOrErr(val ref.Val, format string, args ...interface{}) ref.Val {
func ValOrErr(val ref.Val, format string, args ...any) ref.Val {
if val == nil || !IsUnknownOrError(val) {
return NewErr(format, args...)
}
return val
}
// wrapErr wraps an existing Go error value into a CEL Err value.
func wrapErr(err error) ref.Val {
// WrapErr wraps an existing Go error value into a CEL Err value.
func WrapErr(err error) ref.Val {
return &Err{error: err}
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (e *Err) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (e *Err) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, e.error
}
@ -114,10 +120,15 @@ func (e *Err) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (e *Err) Value() interface{} {
func (e *Err) Value() any {
return e.error
}
// Is implements errors.Is.
func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()
}
// IsError returns whether the input element ref.Type or ref.Val is equal to
// the ErrType singleton.
func IsError(val ref.Val) bool {

View File

@ -66,7 +66,7 @@ func (i Int) Add(other ref.Val) ref.Val {
}
val, err := addInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -89,7 +89,7 @@ func (i Int) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (i Int) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int, reflect.Int32:
// Enums are also mapped as int32 derivations.
@ -176,7 +176,7 @@ func (i Int) ConvertToType(typeVal ref.Type) ref.Val {
case UintType:
u, err := int64ToUint64Checked(int64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(u)
case DoubleType:
@ -204,7 +204,7 @@ func (i Int) Divide(other ref.Val) ref.Val {
}
val, err := divideInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -226,6 +226,11 @@ func (i Int) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if integer is equal to 0
func (i Int) IsZeroValue() bool {
return i == IntZero
}
// Modulo implements traits.Modder.Modulo.
func (i Int) Modulo(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
@ -234,7 +239,7 @@ func (i Int) Modulo(other ref.Val) ref.Val {
}
val, err := moduloInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -247,7 +252,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
}
val, err := multiplyInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -256,7 +261,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
func (i Int) Negate() ref.Val {
val, err := negateInt64Checked(int64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -269,7 +274,7 @@ func (i Int) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractInt64Checked(int64(i), int64(subtraInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@ -280,7 +285,7 @@ func (i Int) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (i Int) Value() interface{} {
func (i Int) Value() any {
return int64(i)
}

View File

@ -34,7 +34,7 @@ var (
// interpreter.
type baseIterator struct{}
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion on iterators not supported")
}
@ -50,6 +50,6 @@ func (*baseIterator) Type() ref.Type {
return IteratorType
}
func (*baseIterator) Value() interface{} {
func (*baseIterator) Value() any {
return nil
}

View File

@ -25,4 +25,5 @@ var (
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonListValueType = reflect.TypeOf(&structpb.ListValue{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
)

View File

@ -17,11 +17,13 @@ package types
import (
"fmt"
"reflect"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -40,13 +42,13 @@ var (
// NewDynamicList returns a traits.Lister with heterogenous elements.
// value should be an array of "native" types, i.e. any type that
// NativeToValue() can convert to a ref.Val.
func NewDynamicList(adapter ref.TypeAdapter, value interface{}) traits.Lister {
func NewDynamicList(adapter ref.TypeAdapter, value any) traits.Lister {
refValue := reflect.ValueOf(value)
return &baseList{
TypeAdapter: adapter,
value: value,
size: refValue.Len(),
get: func(i int) interface{} {
get: func(i int) any {
return refValue.Index(i).Interface()
},
}
@ -58,7 +60,7 @@ func NewStringList(adapter ref.TypeAdapter, elems []string) traits.Lister {
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) interface{} { return elems[i] },
get: func(i int) any { return elems[i] },
}
}
@ -70,7 +72,7 @@ func NewRefValList(adapter ref.TypeAdapter, elems []ref.Val) traits.Lister {
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) interface{} { return elems[i] },
get: func(i int) any { return elems[i] },
}
}
@ -80,7 +82,7 @@ func NewProtoList(adapter ref.TypeAdapter, list protoreflect.List) traits.Lister
TypeAdapter: adapter,
value: list,
size: list.Len(),
get: func(i int) interface{} { return list.Get(i).Interface() },
get: func(i int) any { return list.Get(i).Interface() },
}
}
@ -91,22 +93,25 @@ func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
TypeAdapter: adapter,
value: l,
size: len(vals),
get: func(i int) interface{} { return vals[i] },
get: func(i int) any { return vals[i] },
}
}
// NewMutableList creates a new mutable list whose internal state can be modified.
func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
var mutableValues []ref.Val
return &mutableList{
l := &mutableList{
baseList: &baseList{
TypeAdapter: adapter,
value: mutableValues,
size: 0,
get: func(i int) interface{} { return mutableValues[i] },
},
mutableValues: mutableValues,
}
l.get = func(i int) any {
return l.mutableValues[i]
}
return l
}
// baseList points to a list containing elements of any type.
@ -114,7 +119,7 @@ func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
// The `ref.TypeAdapter` enables native type to CEL type conversions.
type baseList struct {
ref.TypeAdapter
value interface{}
value any
// size indicates the number of elements within the list.
// Since objects are immutable the size of a list is static.
@ -122,7 +127,7 @@ type baseList struct {
// get returns a value at the specified integer index.
// The index is guaranteed to be checked against the list index range.
get func(int) interface{}
get func(int) any
}
// Add implements the traits.Adder interface method.
@ -157,7 +162,7 @@ func (l *baseList) Contains(elem ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the underlying list value is assignable to the reflected type return it.
if reflect.TypeOf(l.value).AssignableTo(typeDesc) {
return l.value, nil
@ -240,7 +245,7 @@ func (l *baseList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method.
func (l *baseList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index)
ind, err := IndexOrError(index)
if err != nil {
return ValOrErr(index, err.Error())
}
@ -250,6 +255,11 @@ func (l *baseList) Get(index ref.Val) ref.Val {
return l.NativeToValue(l.get(ind))
}
// IsZeroValue returns true if the list is empty.
func (l *baseList) IsZeroValue() bool {
return l.size == 0
}
// Iterator implements the traits.Iterable interface method.
func (l *baseList) Iterator() traits.Iterator {
return newListIterator(l)
@ -266,10 +276,24 @@ func (l *baseList) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (l *baseList) Value() interface{} {
func (l *baseList) Value() any {
return l.value
}
// String converts the list to a human readable string form.
func (l *baseList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := 0; i < l.size; i++ {
sb.WriteString(fmt.Sprintf("%v", l.get(i)))
if i != l.size-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct {
*baseList
@ -305,7 +329,7 @@ func (l *mutableList) ToImmutableList() traits.Lister {
// The `ref.TypeAdapter` enables native type to CEL type conversions.
type concatList struct {
ref.TypeAdapter
value interface{}
value any
prevList traits.Lister
nextList traits.Lister
}
@ -351,8 +375,8 @@ func (l *concatList) Contains(elem ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]interface{}))
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (any, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]any))
return combined.ConvertToNative(typeDesc)
}
@ -396,7 +420,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method.
func (l *concatList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index)
ind, err := IndexOrError(index)
if err != nil {
return ValOrErr(index, err.Error())
}
@ -408,6 +432,11 @@ func (l *concatList) Get(index ref.Val) ref.Val {
return l.nextList.Get(offset)
}
// IsZeroValue returns true if the list is empty.
func (l *concatList) IsZeroValue() bool {
return l.Size().(Int) == 0
}
// Iterator implements the traits.Iterable interface method.
func (l *concatList) Iterator() traits.Iterator {
return newListIterator(l)
@ -418,15 +447,29 @@ func (l *concatList) Size() ref.Val {
return l.prevList.Size().(Int).Add(l.nextList.Size())
}
// String converts the concatenated list to a human-readable string.
func (l *concatList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := Int(0); i < l.Size().(Int); i++ {
sb.WriteString(fmt.Sprintf("%v", l.Get(i)))
if i != l.Size().(Int)-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// Type implements the ref.Val interface method.
func (l *concatList) Type() ref.Type {
return ListType
}
// Value implements the ref.Val interface method.
func (l *concatList) Value() interface{} {
func (l *concatList) Value() any {
if l.value == nil {
merged := make([]interface{}, l.Size().(Int))
merged := make([]any, l.Size().(Int))
prevLen := l.prevList.Size().(Int)
for i := Int(0); i < prevLen; i++ {
merged[i] = l.prevList.Get(i).Value()
@ -469,7 +512,8 @@ func (it *listIterator) Next() ref.Val {
return nil
}
func indexOrError(index ref.Val) (int, error) {
// IndexOrError converts an input index value into either a lossless integer index or an error.
func IndexOrError(index ref.Val) (int, error) {
switch iv := index.(type) {
case Int:
return int(iv), nil

View File

@ -17,20 +17,22 @@ package types
import (
"fmt"
"reflect"
"strings"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
// NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs.
func NewDynamicMap(adapter ref.TypeAdapter, value interface{}) traits.Mapper {
func NewDynamicMap(adapter ref.TypeAdapter, value any) traits.Mapper {
refValue := reflect.ValueOf(value)
return &baseMap{
TypeAdapter: adapter,
@ -65,7 +67,7 @@ func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Map
}
// NewStringInterfaceMap returns a specialized traits.Mapper with string keys and interface values.
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]interface{}) traits.Mapper {
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]any) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
mapAccessor: newStringIfaceMapAccessor(adapter, value),
@ -125,7 +127,7 @@ type baseMap struct {
mapAccessor
// value is the native Go value upon which the map type operators.
value interface{}
value any
// size is the number of entries in the map.
size int
@ -138,7 +140,7 @@ func (m *baseMap) Contains(index ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types.
if reflect.TypeOf(m.value).AssignableTo(typeDesc) {
@ -275,18 +277,42 @@ func (m *baseMap) Get(key ref.Val) ref.Val {
return v
}
// IsZeroValue returns true if the map is empty.
func (m *baseMap) IsZeroValue() bool {
return m.size == 0
}
// Size implements the traits.Sizer interface method.
func (m *baseMap) Size() ref.Val {
return Int(m.size)
}
// String converts the map into a human-readable string.
func (m *baseMap) String() string {
var sb strings.Builder
sb.WriteString("{")
it := m.Iterator()
i := 0
for it.HasNext() == True {
k := it.Next()
v, _ := m.Find(k)
sb.WriteString(fmt.Sprintf("%v: %v", k, v))
if i != m.size-1 {
sb.WriteString(", ")
}
i++
}
sb.WriteString("}")
return sb.String()
}
// Type implements the ref.Val interface method.
func (m *baseMap) Type() ref.Type {
return MapType
}
// Value implements the ref.Val interface method.
func (m *baseMap) Value() interface{} {
func (m *baseMap) Value() any {
return m.value
}
@ -498,7 +524,7 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
}
}
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interface{}) mapAccessor {
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{
TypeAdapter: adapter,
mapVal: mapVal,
@ -507,7 +533,7 @@ func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interf
type stringIfaceMapAccessor struct {
ref.TypeAdapter
mapVal map[string]interface{}
mapVal map[string]any
}
// Find uses native map accesses to find the key, returning (value, true) if present.
@ -556,7 +582,7 @@ func (m *protoMap) Contains(key ref.Val) ref.Val {
// ConvertToNative implements the ref.Val interface method.
//
// Note, assignment to Golang struct types is not yet supported.
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types.
switch typeDesc {
@ -601,9 +627,9 @@ func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
m.value.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
ntvKey := key.Interface()
ntvVal := val.Interface()
switch ntvVal.(type) {
switch pv := ntvVal.(type) {
case protoreflect.Message:
ntvVal = ntvVal.(protoreflect.Message).Interface()
ntvVal = pv.Interface()
}
if keyType == otherKeyType && valType == otherValType {
mapVal.SetMapIndex(reflect.ValueOf(ntvKey), reflect.ValueOf(ntvVal))
@ -732,6 +758,11 @@ func (m *protoMap) Get(key ref.Val) ref.Val {
return v
}
// IsZeroValue returns true if the map is empty.
func (m *protoMap) IsZeroValue() bool {
return m.value.Len() == 0
}
// Iterator implements the traits.Iterable interface method.
func (m *protoMap) Iterator() traits.Iterator {
// Copy the keys to make their order stable.
@ -758,7 +789,7 @@ func (m *protoMap) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (m *protoMap) Value() interface{} {
func (m *protoMap) Value() any {
return m.value
}

View File

@ -18,9 +18,10 @@ import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
@ -34,14 +35,20 @@ var (
// NullValue singleton.
NullValue = Null(structpb.NullValue_NULL_VALUE)
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
// golang reflect type for Null values.
nullReflectType = reflect.TypeOf(NullValue)
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int32:
return reflect.ValueOf(n).Convert(typeDesc).Interface(), nil
switch typeDesc {
case jsonNullType:
return structpb.NullValue_NULL_VALUE, nil
case nullReflectType:
return n, nil
}
case reflect.Ptr:
switch typeDesc {
case anyValueType:
@ -54,6 +61,10 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
return anypb.New(pb.(proto.Message))
case jsonValueType:
return structpb.NewNullValue(), nil
case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType,
int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType,
uint64WrapperType:
return nil, nil
}
case reflect.Interface:
nv := n.Value()
@ -86,12 +97,17 @@ func (n Null) Equal(other ref.Val) ref.Val {
return Bool(NullType == other.Type())
}
// IsZeroValue returns true as null always represents an absent value.
func (n Null) IsZeroValue() bool {
return true
}
// Type implements ref.Val.Type.
func (n Null) Type() ref.Type {
return NullType
}
// Value implements ref.Val.Value.
func (n Null) Value() interface{} {
func (n Null) Value() any {
return structpb.NullValue_NULL_VALUE
}

View File

@ -18,11 +18,12 @@ import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
@ -52,7 +53,7 @@ func NewObject(adapter ref.TypeAdapter,
typeValue: typeValue}
}
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
srcPB := o.value
if reflect.TypeOf(srcPB).AssignableTo(typeDesc) {
return srcPB, nil
@ -133,6 +134,11 @@ func (o *protoObj) IsSet(field ref.Val) ref.Val {
return False
}
// IsZeroValue returns true if the protobuf object is empty.
func (o *protoObj) IsZeroValue() bool {
return proto.Equal(o.value, o.typeDesc.Zero())
}
func (o *protoObj) Get(index ref.Val) ref.Val {
protoFieldName, ok := index.(String)
if !ok {
@ -154,6 +160,6 @@ func (o *protoObj) Type() ref.Type {
return o.typeValue
}
func (o *protoObj) Value() interface{} {
func (o *protoObj) Value() any {
return o.value
}

View File

@ -0,0 +1,108 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"errors"
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
)
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewTypeValue("optional")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}
)
// OptionalOf returns an optional value which wraps a concrete CEL value.
func OptionalOf(value ref.Val) *Optional {
return &Optional{value: value}
}
// Optional value which points to a value if non-empty.
type Optional struct {
value ref.Val
}
// HasValue returns true if the optional has a value.
func (o *Optional) HasValue() bool {
return o.value != nil
}
// GetValue returns the wrapped value contained in the optional.
func (o *Optional) GetValue() ref.Val {
if !o.HasValue() {
return NewErr("optional.none() dereference")
}
return o.value
}
// ConvertToNative implements the ref.Val interface method.
func (o *Optional) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !o.HasValue() {
return nil, errors.New("optional.none() dereference")
}
return o.value.ConvertToNative(typeDesc)
}
// ConvertToType implements the ref.Val interface method.
func (o *Optional) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case OptionalType:
return o
case TypeType:
return OptionalType
}
return NewErr("type conversion error from '%s' to '%s'", OptionalType, typeVal)
}
// Equal determines whether the values contained by two optional values are equal.
func (o *Optional) Equal(other ref.Val) ref.Val {
otherOpt, isOpt := other.(*Optional)
if !isOpt {
return False
}
if !o.HasValue() {
return Bool(!otherOpt.HasValue())
}
if !otherOpt.HasValue() {
return False
}
return o.value.Equal(otherOpt.value)
}
func (o *Optional) String() string {
if o.HasValue() {
return fmt.Sprintf("optional(%v)", o.GetValue())
}
return "optional.none()"
}
// Type implements the ref.Val interface method.
func (o *Optional) Type() ref.Type {
return OptionalType
}
// Value returns the underlying 'Value()' of the wrapped value, if present.
func (o *Optional) Value() any {
if o.value == nil {
return nil
}
return o.value.Value()
}

View File

@ -17,7 +17,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/types/pb",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protowire:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",

View File

@ -18,9 +18,9 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)
// NewEnumValueDescription produces an enum value description with the fully qualified enum value
// newEnumValueDescription produces an enum value description with the fully qualified enum value
// name and the enum value descriptor.
func NewEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
func newEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
return &EnumValueDescription{
enumValueName: name,
desc: desc,

View File

@ -18,32 +18,66 @@ import (
"fmt"
"google.golang.org/protobuf/reflect/protoreflect"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
)
// NewFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values declared within any scope in the file.
func NewFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) *FileDescription {
// newFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values, as well as a map of extensions declared within any scope in the file.
func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDescription, extensionMap) {
metadata := collectFileMetadata(fileDesc)
enums := make(map[string]*EnumValueDescription)
for name, enumVal := range metadata.enumValues {
enums[name] = NewEnumValueDescription(name, enumVal)
enums[name] = newEnumValueDescription(name, enumVal)
}
types := make(map[string]*TypeDescription)
for name, msgType := range metadata.msgTypes {
types[name] = NewTypeDescription(name, msgType)
types[name] = newTypeDescription(name, msgType, pbdb.extensions)
}
fileExtMap := make(extensionMap)
for typeName, extensions := range metadata.msgExtensionMap {
messageExtMap, found := fileExtMap[typeName]
if !found {
messageExtMap = make(map[string]*FieldDescription)
}
for _, ext := range extensions {
extDesc := dynamicpb.NewExtensionType(ext).TypeDescriptor()
messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc)
}
fileExtMap[typeName] = messageExtMap
}
return &FileDescription{
name: fileDesc.Path(),
types: types,
enums: enums,
}
}, fileExtMap
}
// FileDescription holds a map of all types and enum values declared within a proto file.
type FileDescription struct {
name string
types map[string]*TypeDescription
enums map[string]*EnumValueDescription
}
// Copy creates a copy of the FileDescription with updated Db references within its types.
func (fd *FileDescription) Copy(pbdb *Db) *FileDescription {
typesCopy := make(map[string]*TypeDescription, len(fd.types))
for k, v := range fd.types {
typesCopy[k] = v.Copy(pbdb)
}
return &FileDescription{
name: fd.name,
types: typesCopy,
enums: fd.enums,
}
}
// GetName returns the fully qualified file path for the file.
func (fd *FileDescription) GetName() string {
return fd.name
}
// GetEnumDescription returns an EnumDescription for a qualified enum value
// name declared within the .proto file.
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) {
@ -94,6 +128,10 @@ type fileMetadata struct {
msgTypes map[string]protoreflect.MessageDescriptor
// enumValues maps from fully-qualified enum value to enum value descriptor.
enumValues map[string]protoreflect.EnumValueDescriptor
// msgExtensionMap maps from the protobuf message name being extended to a set of extensions
// for the type.
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor
// TODO: support enum type definitions for use in future type-check enhancements.
}
@ -102,28 +140,38 @@ type fileMetadata struct {
func collectFileMetadata(fileDesc protoreflect.FileDescriptor) *fileMetadata {
msgTypes := make(map[string]protoreflect.MessageDescriptor)
enumValues := make(map[string]protoreflect.EnumValueDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues)
msgExtensionMap := make(map[string][]protoreflect.ExtensionDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues, msgExtensionMap)
collectEnumValues(fileDesc.Enums(), enumValues)
collectExtensions(fileDesc.Extensions(), msgExtensionMap)
return &fileMetadata{
msgTypes: msgTypes,
enumValues: enumValues,
msgTypes: msgTypes,
enumValues: enumValues,
msgExtensionMap: msgExtensionMap,
}
}
// collectMsgTypes recursively collects messages, nested messages, and nested enums into a map of
// fully qualified protobuf names to descriptors.
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors, msgTypeMap map[string]protoreflect.MessageDescriptor, enumValueMap map[string]protoreflect.EnumValueDescriptor) {
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors,
msgTypeMap map[string]protoreflect.MessageDescriptor,
enumValueMap map[string]protoreflect.EnumValueDescriptor,
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < msgTypes.Len(); i++ {
msgType := msgTypes.Get(i)
msgTypeMap[string(msgType.FullName())] = msgType
nestedMsgTypes := msgType.Messages()
if nestedMsgTypes.Len() != 0 {
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap)
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap, msgExtensionMap)
}
nestedEnumTypes := msgType.Enums()
if nestedEnumTypes.Len() != 0 {
collectEnumValues(nestedEnumTypes, enumValueMap)
}
nestedExtensions := msgType.Extensions()
if nestedExtensions.Len() != 0 {
collectExtensions(nestedExtensions, msgExtensionMap)
}
}
}
@ -139,3 +187,16 @@ func collectEnumValues(enumTypes protoreflect.EnumDescriptors, enumValueMap map[
}
}
}
func collectExtensions(extensions protoreflect.ExtensionDescriptors, msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < extensions.Len(); i++ {
ext := extensions.Get(i)
extendsMsg := string(ext.ContainingMessage().FullName())
msgExts, found := msgExtensionMap[extendsMsg]
if !found {
msgExts = []protoreflect.ExtensionDescriptor{}
}
msgExts = append(msgExts, ext)
msgExtensionMap[extendsMsg] = msgExts
}
}

View File

@ -40,13 +40,19 @@ type Db struct {
revFileDescriptorMap map[string]*FileDescription
// files contains the deduped set of FileDescriptions whose types are contained in the pb.Db.
files []*FileDescription
// extensions contains the mapping between a given type name, extension name and its FieldDescription
extensions map[string]map[string]*FieldDescription
}
// extensionsMap is a type alias to a map[typeName]map[extensionName]*FieldDescription
type extensionMap = map[string]map[string]*FieldDescription
var (
// DefaultDb used at evaluation time or unless overridden at check time.
DefaultDb = &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
)
@ -80,6 +86,7 @@ func NewDb() *Db {
pbdb := &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
// The FileDescription objects in the default db contain lazily initialized TypeDescription
// values which may point to the state contained in the DefaultDb irrespective of this shallow
@ -96,19 +103,34 @@ func NewDb() *Db {
// Copy creates a copy of the current database with its own internal descriptor mapping.
func (pbdb *Db) Copy() *Db {
copy := NewDb()
for k, v := range pbdb.revFileDescriptorMap {
copy.revFileDescriptorMap[k] = v
}
for _, f := range pbdb.files {
for _, fd := range pbdb.files {
hasFile := false
for _, f2 := range copy.files {
if f2 == f {
for _, fd2 := range copy.files {
if fd2 == fd {
hasFile = true
}
}
if !hasFile {
copy.files = append(copy.files, f)
fd = fd.Copy(copy)
copy.files = append(copy.files, fd)
}
for _, enumValName := range fd.GetEnumNames() {
copy.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
copy.revFileDescriptorMap[msgTypeName] = fd
}
copy.revFileDescriptorMap[fd.GetName()] = fd
}
for typeName, extFieldMap := range pbdb.extensions {
copyExtFieldMap, found := copy.extensions[typeName]
if !found {
copyExtFieldMap = make(map[string]*FieldDescription, len(extFieldMap))
}
for extFieldName, fd := range extFieldMap {
copyExtFieldMap[extFieldName] = fd
}
copy.extensions[typeName] = copyExtFieldMap
}
return copy
}
@ -137,17 +159,30 @@ func (pbdb *Db) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) (*FileD
if err == nil {
fileDesc = globalFD
}
fd = NewFileDescription(fileDesc, pbdb)
var fileExtMap extensionMap
fd, fileExtMap = newFileDescription(fileDesc, pbdb)
for _, enumValName := range fd.GetEnumNames() {
pbdb.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
pbdb.revFileDescriptorMap[msgTypeName] = fd
}
pbdb.revFileDescriptorMap[fileDesc.Path()] = fd
pbdb.revFileDescriptorMap[fd.GetName()] = fd
// Return the specific file descriptor registered.
pbdb.files = append(pbdb.files, fd)
// Index the protobuf message extensions from the file into the pbdb
for typeName, extMap := range fileExtMap {
typeExtMap, found := pbdb.extensions[typeName]
if !found {
pbdb.extensions[typeName] = extMap
continue
}
for extName, field := range extMap {
typeExtMap[extName] = field
}
}
return fd, nil
}

Some files were not shown because too many files have changed in this diff Show More