rebase: update kubernetes to 1.28.0 in main

Update Kubernetes dependencies to 1.28.0 in the main repository.
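
The mechanical part of a bump like this (a hypothetical sketch, not necessarily the exact commands used for this commit) is to move every k8s.io pin in both the require and the replace blocks from the 1.27.4/0.27.4 releases to 1.28.0/0.28.0, then refresh the module graph and the vendor tree:

    # hypothetical sketch; assumes GNU sed and a vendored module layout
    sed -i -e 's/v0\.27\.4/v0.28.0/g' -e 's/v1\.27\.4/v1.28.0/g' go.mod
    go mod tidy
    go mod vendor    # the vendored-dependency churn further down comes from this step

The remaining version bumps visible in go.mod (etcd, cobra, cel-go, the antlr/v4 runtime, and friends) fall out of go mod tidy resolving the new Kubernetes requirements.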

Signed-off-by: Madhu Rajanna <madhupr007@gmail.com>
Madhu Rajanna 2023-08-17 07:15:28 +02:00 committed by mergify[bot]
parent b2fdc269c3
commit ff3e84ad67
706 changed files with 45252 additions and 16346 deletions

go.mod (114 changed lines)

@@ -35,15 +35,15 @@ require (
 	//
 	// when updating k8s.io/kubernetes, make sure to update the replace section too
 	//
-	k8s.io/api v0.27.4
+	k8s.io/api v0.28.0
-	k8s.io/apimachinery v0.27.4
+	k8s.io/apimachinery v0.28.0
 	k8s.io/client-go v12.0.0+incompatible
-	k8s.io/cloud-provider v0.27.4
+	k8s.io/cloud-provider v0.28.0
 	k8s.io/klog/v2 v2.100.1
-	k8s.io/kubernetes v1.27.4
+	k8s.io/kubernetes v1.28.0
-	k8s.io/mount-utils v0.27.4
+	k8s.io/mount-utils v0.28.0
 	k8s.io/pod-security-admission v0.0.0
-	k8s.io/utils v0.0.0-20230209194617-a36077c30491
+	k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
 	sigs.k8s.io/controller-runtime v0.15.1
 )
@@ -51,7 +51,7 @@ require (
 	github.com/NYTimes/gziphandler v1.1.1 // indirect
 	github.com/ansel1/merry v1.6.2 // indirect
 	github.com/ansel1/merry/v2 v2.0.1 // indirect
-	github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 // indirect
+	github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect
 	github.com/armon/go-metrics v0.3.10 // indirect
 	github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect
 	github.com/aws/aws-sdk-go-v2 v1.20.1 // indirect
@@ -62,14 +62,14 @@ require (
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/blang/semver/v4 v4.0.0 // indirect
 	github.com/cenkalti/backoff/v3 v3.2.2 // indirect
-	github.com/cenkalti/backoff/v4 v4.1.3 // indirect
+	github.com/cenkalti/backoff/v4 v4.2.1 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
-	github.com/coreos/go-semver v0.3.0 // indirect
+	github.com/coreos/go-semver v0.3.1 // indirect
-	github.com/coreos/go-systemd/v22 v22.4.0 // indirect
+	github.com/coreos/go-systemd/v22 v22.5.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/docker/distribution v2.8.2+incompatible // indirect
 	github.com/emicklei/go-restful/v3 v3.9.0 // indirect
-	github.com/evanphx/json-patch v4.12.0+incompatible // indirect
+	github.com/evanphx/json-patch v5.6.0+incompatible // indirect
 	github.com/evanphx/json-patch/v5 v5.6.0 // indirect
 	github.com/fatih/color v1.13.0 // indirect
 	github.com/felixge/httpsnoop v1.0.3 // indirect
@@ -81,14 +81,14 @@ require (
 	github.com/go-logr/logr v1.2.4 // indirect
 	github.com/go-logr/stdr v1.2.2 // indirect
 	github.com/go-openapi/jsonpointer v0.19.6 // indirect
-	github.com/go-openapi/jsonreference v0.20.1 // indirect
+	github.com/go-openapi/jsonreference v0.20.2 // indirect
 	github.com/go-openapi/swag v0.22.3 // indirect
 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 	github.com/golang/snappy v0.0.4 // indirect
-	github.com/google/cel-go v0.12.6 // indirect
+	github.com/google/cel-go v0.16.0 // indirect
-	github.com/google/gnostic v0.6.9 // indirect
+	github.com/google/gnostic-models v0.6.8 // indirect
 	github.com/google/go-cmp v0.5.9 // indirect
 	github.com/google/gofuzz v1.2.0 // indirect
 	github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
@@ -108,7 +108,7 @@ require (
 	github.com/hashicorp/vault v1.11.11 // indirect
 	github.com/hashicorp/vault/sdk v0.7.0 // indirect
 	github.com/imdario/mergo v0.3.13 // indirect
-	github.com/inconshreveable/mousetrap v1.0.1 // indirect
+	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/josharian/intern v1.0.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
@@ -131,15 +131,15 @@ require (
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/prometheus/client_model v0.4.0 // indirect
-	github.com/prometheus/common v0.42.0 // indirect
+	github.com/prometheus/common v0.44.0 // indirect
 	github.com/prometheus/procfs v0.10.1 // indirect
 	github.com/ryanuber/go-glob v1.0.0 // indirect
-	github.com/spf13/cobra v1.6.0 // indirect
+	github.com/spf13/cobra v1.7.0 // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/stoewer/go-strcase v1.2.0 // indirect
-	go.etcd.io/etcd/api/v3 v3.5.7 // indirect
+	go.etcd.io/etcd/api/v3 v3.5.9 // indirect
-	go.etcd.io/etcd/client/pkg/v3 v3.5.7 // indirect
+	go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect
-	go.etcd.io/etcd/client/v3 v3.5.7 // indirect
+	go.etcd.io/etcd/client/v3 v3.5.9 // indirect
 	go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.35.0 // indirect
 	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.35.1 // indirect
 	go.opentelemetry.io/otel v1.10.0 // indirect
@@ -151,31 +151,31 @@ require (
 	go.opentelemetry.io/otel/trace v1.10.0 // indirect
 	go.opentelemetry.io/proto/otlp v0.19.0 // indirect
 	go.uber.org/atomic v1.10.0 // indirect
-	go.uber.org/multierr v1.8.0 // indirect
+	go.uber.org/multierr v1.11.0 // indirect
 	go.uber.org/zap v1.24.0 // indirect
-	golang.org/x/oauth2 v0.7.0 // indirect
+	golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
+	golang.org/x/oauth2 v0.8.0 // indirect
 	golang.org/x/sync v0.2.0 // indirect
 	golang.org/x/term v0.11.0 // indirect
 	golang.org/x/text v0.12.0 // indirect
 	golang.org/x/time v0.3.0 // indirect
 	golang.org/x/tools v0.9.3 // indirect
 	gomodules.xyz/jsonpatch/v2 v2.3.0 // indirect
-	google.golang.org/api v0.110.0 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
 	google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
 	google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
 	gopkg.in/inf.v0 v0.9.1 // indirect
-	gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
+	gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
-	k8s.io/apiextensions-apiserver v0.27.4 // indirect
+	k8s.io/apiextensions-apiserver v0.28.0 // indirect
-	k8s.io/apiserver v0.27.4 // indirect
+	k8s.io/apiserver v0.28.0 // indirect
-	k8s.io/component-base v0.27.4 // indirect
+	k8s.io/component-base v0.28.0 // indirect
-	k8s.io/component-helpers v0.27.4 // indirect
+	k8s.io/component-helpers v0.28.0 // indirect
-	k8s.io/controller-manager v0.27.4 // indirect
+	k8s.io/controller-manager v0.28.0 // indirect
-	k8s.io/kms v0.27.4 // indirect
+	k8s.io/kms v0.28.0 // indirect
-	k8s.io/kube-openapi v0.0.0-20230501164219-8b0f38b5fd1f // indirect
+	k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 // indirect
 	k8s.io/kubectl v0.0.0 // indirect
 	k8s.io/kubelet v0.0.0 // indirect
 	sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.1.2 // indirect
@@ -194,32 +194,32 @@ replace (
 	//
 	// k8s.io/kubernetes depends on these k8s.io packages, but unversioned
 	//
-	k8s.io/api => k8s.io/api v0.27.4
+	k8s.io/api => k8s.io/api v0.28.0
-	k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.27.4
+	k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.28.0
-	k8s.io/apimachinery => k8s.io/apimachinery v0.27.4
+	k8s.io/apimachinery => k8s.io/apimachinery v0.28.0
-	k8s.io/apiserver => k8s.io/apiserver v0.27.4
+	k8s.io/apiserver => k8s.io/apiserver v0.28.0
-	k8s.io/cli-runtime => k8s.io/cli-runtime v0.27.4
+	k8s.io/cli-runtime => k8s.io/cli-runtime v0.28.0
-	k8s.io/client-go => k8s.io/client-go v0.27.4
+	k8s.io/client-go => k8s.io/client-go v0.28.0
-	k8s.io/cloud-provider => k8s.io/cloud-provider v0.27.4
+	k8s.io/cloud-provider => k8s.io/cloud-provider v0.28.0
-	k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.27.4
+	k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.28.0
-	k8s.io/code-generator => k8s.io/code-generator v0.27.4
+	k8s.io/code-generator => k8s.io/code-generator v0.28.0
-	k8s.io/component-base => k8s.io/component-base v0.27.4
+	k8s.io/component-base => k8s.io/component-base v0.28.0
-	k8s.io/component-helpers => k8s.io/component-helpers v0.27.4
+	k8s.io/component-helpers => k8s.io/component-helpers v0.28.0
-	k8s.io/controller-manager => k8s.io/controller-manager v0.27.4
+	k8s.io/controller-manager => k8s.io/controller-manager v0.28.0
-	k8s.io/cri-api => k8s.io/cri-api v0.27.4
+	k8s.io/cri-api => k8s.io/cri-api v0.28.0
-	k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.27.4
+	k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.28.0
-	k8s.io/dynamic-resource-allocation => k8s.io/dynamic-resource-allocation v0.27.4
+	k8s.io/dynamic-resource-allocation => k8s.io/dynamic-resource-allocation v0.28.0
-	k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.27.4
+	k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.28.0
-	k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.27.4
+	k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.28.0
-	k8s.io/kube-proxy => k8s.io/kube-proxy v0.27.4
+	k8s.io/kube-proxy => k8s.io/kube-proxy v0.28.0
-	k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.27.4
+	k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.28.0
-	k8s.io/kubectl => k8s.io/kubectl v0.27.4
+	k8s.io/kubectl => k8s.io/kubectl v0.28.0
-	k8s.io/kubelet => k8s.io/kubelet v0.27.4
+	k8s.io/kubelet => k8s.io/kubelet v0.28.0
-	k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.27.4
+	k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.28.0
-	k8s.io/metrics => k8s.io/metrics v0.27.4
+	k8s.io/metrics => k8s.io/metrics v0.28.0
-	k8s.io/mount-utils => k8s.io/mount-utils v0.27.4
+	k8s.io/mount-utils => k8s.io/mount-utils v0.28.0
-	k8s.io/pod-security-admission => k8s.io/pod-security-admission v0.27.4
+	k8s.io/pod-security-admission => k8s.io/pod-security-admission v0.28.0
-	k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.27.4
+	k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.28.0
 	// layeh.com seems to be misbehaving
 	layeh.com/radius => github.com/layeh/radius v0.0.0-20190322222518-890bc1058917
 )
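
A quick way to confirm that the require and replace sections ended up in sync is to list the effective module versions (a sketch, assuming it is run from the repository root after the bump):

    # every k8s.io module should now resolve to v0.28.0 (v1.28.0 for k8s.io/kubernetes)
    go list -m all | grep '^k8s.io/'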

go.sum (1095 changed lines)

File diff suppressed because it is too large.

==== next file ====

@@ -0,0 +1,68 @@
/*
Package antlr implements the Go version of the ANTLR 4 runtime.
# The ANTLR Tool
ANTLR (ANother Tool for Language Recognition) is a powerful parser generator for reading, processing, executing,
or translating structured text or binary files. It's widely used to build languages, tools, and frameworks.
From a grammar, ANTLR generates a parser that can build parse trees and also generates a listener interface
(or visitor) that makes it easy to respond to the recognition of phrases of interest.
# Code Generation
ANTLR supports the generation of code in a number of [target languages], and the generated code is supported by a
runtime library, written specifically to support the generated code in the target language. This library is the
runtime for the Go target.
To generate code for the go target, it is generally recommended to place the source grammar files in a package of
their own, and use the `.sh` script method of generating code, using the go generate directive. In that same directory
it is usual, though not required, to place the antlr tool that should be used to generate the code. That does mean
that the antlr tool JAR file will be checked in to your source code control though, so you are free to use any other
way of specifying the version of the ANTLR tool to use, such as aliasing in `.zshrc` or equivalent, or a profile in
your IDE, or configuration in your CI system.
Here is a general template for an ANTLR based recognizer in Go:
	.
	└── myproject
	    ├── parser
	    │   ├── mygrammar.g4
	    │   ├── antlr-4.12.0-complete.jar
	    │   ├── error_listeners.go
	    │   ├── generate.go
	    │   └── generate.sh
	    ├── go.mod
	    ├── go.sum
	    ├── main.go
	    └── main_test.go
Make sure that the package statement in your grammar file(s) reflects the go package they exist in.
The generate.go file then looks like this:
package parser
//go:generate ./generate.sh
And the generate.sh file will look similar to this:
#!/bin/sh
alias antlr4='java -Xmx500M -cp "./antlr4-4.12.0-complete.jar:$CLASSPATH" org.antlr.v4.Tool'
antlr4 -Dlanguage=Go -no-visitor -package parser *.g4
depending on whether you want visitors or listeners or any other ANTLR options.
From the command line at the root of your package myproject you can then simply issue the command:
go generate ./...
# Copyright Notice
Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
Use of this file is governed by the BSD 3-clause license, which can be found in the [LICENSE.txt] file in the project root.
[target languages]: https://github.com/antlr/antlr4/tree/master/runtime
[LICENSE.txt]: https://github.com/antlr/antlr4/blob/master/LICENSE.txt
*/
package antlr

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -6,11 +6,24 @@ package antlr
 import "sync"
+// ATNInvalidAltNumber is used to represent an ALT number that has yet to be calculated or
+// which is invalid for a particular struct such as [*antlr.BaseRuleContext]
 var ATNInvalidAltNumber int
+// ATN represents an “[Augmented Transition Network]”, though general in ANTLR the term
+// “Augmented Recursive Transition Network” though there are some descriptions of “[Recursive Transition Network]”
+// in existence.
+//
+// ATNs represent the main networks in the system and are serialized by the code generator and support [ALL(*)].
+//
+// [Augmented Transition Network]: https://en.wikipedia.org/wiki/Augmented_transition_network
+// [ALL(*)]: https://www.antlr.org/papers/allstar-techreport.pdf
+// [Recursive Transition Network]: https://en.wikipedia.org/wiki/Recursive_transition_network
 type ATN struct {
 // DecisionToState is the decision points for all rules, subrules, optional
-// blocks, ()+, ()*, etc. Used to build DFA predictors for them.
+// blocks, ()+, ()*, etc. Each subrule/rule is a decision point, and we must track them so we
+// can go back later and build DFA predictors for them. This includes
+// all the rules, subrules, optional blocks, ()+, ()* etc...
 DecisionToState []DecisionState
 // grammarType is the ATN type and is used for deserializing ATNs from strings.
@@ -45,6 +58,8 @@ type ATN struct {
 edgeMu sync.RWMutex
 }
+// NewATN returns a new ATN struct representing the given grammarType and is used
+// for runtime deserialization of ATNs from the code generated by the ANTLR tool
 func NewATN(grammarType int, maxTokenType int) *ATN {
 return &ATN{
 grammarType: grammarType,
@@ -53,7 +68,7 @@ func NewATN(grammarType int, maxTokenType int) *ATN {
 }
 }
-// NextTokensInContext computes the set of valid tokens that can occur starting
+// NextTokensInContext computes and returns the set of valid tokens that can occur starting
 // in state s. If ctx is nil, the set of tokens will not include what can follow
 // the rule surrounding s. In other words, the set will be restricted to tokens
 // reachable staying within the rule of s.
@@ -61,8 +76,8 @@ func (a *ATN) NextTokensInContext(s ATNState, ctx RuleContext) *IntervalSet {
 return NewLL1Analyzer(a).Look(s, nil, ctx)
 }
-// NextTokensNoContext computes the set of valid tokens that can occur starting
+// NextTokensNoContext computes and returns the set of valid tokens that can occur starting
-// in s and staying in same rule. Token.EPSILON is in set if we reach end of
+// in state s and staying in same rule. [antlr.Token.EPSILON] is in set if we reach end of
 // rule.
 func (a *ATN) NextTokensNoContext(s ATNState) *IntervalSet {
 a.mu.Lock()
@@ -76,6 +91,8 @@ func (a *ATN) NextTokensNoContext(s ATNState) *IntervalSet {
 return iset
 }
+// NextTokens computes and returns the set of valid tokens starting in state s, by
+// calling either [NextTokensNoContext] (ctx == nil) or [NextTokensInContext] (ctx != nil).
 func (a *ATN) NextTokens(s ATNState, ctx RuleContext) *IntervalSet {
 if ctx == nil {
 return a.NextTokensNoContext(s)

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -8,19 +8,14 @@ import (
 "fmt"
 )
-type comparable interface {
-equals(other interface{}) bool
-}
 // ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic
 // context). The syntactic context is a graph-structured stack node whose
 // path(s) to the root is the rule invocation(s) chain used to arrive at the
 // state. The semantic context is the tree of semantic predicates encountered
 // before reaching an ATN state.
 type ATNConfig interface {
-comparable
+Equals(o Collectable[ATNConfig]) bool
-hash() int
+Hash() int
 GetState() ATNState
 GetAlt() int
@@ -47,7 +42,7 @@ type BaseATNConfig struct {
 reachesIntoOuterContext int
 }
-func NewBaseATNConfig7(old *BaseATNConfig) *BaseATNConfig { // TODO: Dup
+func NewBaseATNConfig7(old *BaseATNConfig) ATNConfig { // TODO: Dup
 return &BaseATNConfig{
 state: old.state,
 alt: old.alt,
@@ -135,11 +130,16 @@ func (b *BaseATNConfig) SetReachesIntoOuterContext(v int) {
 b.reachesIntoOuterContext = v
 }
+// Equals is the default comparison function for an ATNConfig when no specialist implementation is required
+// for a collection.
+//
 // An ATN configuration is equal to another if both have the same state, they
 // predict the same alternative, and syntactic/semantic contexts are the same.
-func (b *BaseATNConfig) equals(o interface{}) bool {
+func (b *BaseATNConfig) Equals(o Collectable[ATNConfig]) bool {
 if b == o {
 return true
+} else if o == nil {
+return false
 }
 var other, ok = o.(*BaseATNConfig)
@@ -153,30 +153,32 @@ func (b *BaseATNConfig) equals(o interface{}) bool {
 if b.context == nil {
 equal = other.context == nil
 } else {
-equal = b.context.equals(other.context)
+equal = b.context.Equals(other.context)
 }
 var (
 nums = b.state.GetStateNumber() == other.state.GetStateNumber()
 alts = b.alt == other.alt
-cons = b.semanticContext.equals(other.semanticContext)
+cons = b.semanticContext.Equals(other.semanticContext)
 sups = b.precedenceFilterSuppressed == other.precedenceFilterSuppressed
 )
 return nums && alts && cons && sups && equal
 }
-func (b *BaseATNConfig) hash() int {
+// Hash is the default hash function for BaseATNConfig, when no specialist hash function
+// is required for a collection
+func (b *BaseATNConfig) Hash() int {
 var c int
 if b.context != nil {
-c = b.context.hash()
+c = b.context.Hash()
 }
 h := murmurInit(7)
 h = murmurUpdate(h, b.state.GetStateNumber())
 h = murmurUpdate(h, b.alt)
 h = murmurUpdate(h, c)
-h = murmurUpdate(h, b.semanticContext.hash())
+h = murmurUpdate(h, b.semanticContext.Hash())
 return murmurFinish(h, 4)
 }
@@ -243,7 +245,9 @@ func NewLexerATNConfig1(state ATNState, alt int, context PredictionContext) *Lex
 return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)}
 }
-func (l *LexerATNConfig) hash() int {
+// Hash is the default hash function for LexerATNConfig objects, it can be used directly or via
+// the default comparator [ObjEqComparator].
+func (l *LexerATNConfig) Hash() int {
 var f int
 if l.passedThroughNonGreedyDecision {
 f = 1
@@ -253,15 +257,20 @@ func (l *LexerATNConfig) hash() int {
 h := murmurInit(7)
 h = murmurUpdate(h, l.state.GetStateNumber())
 h = murmurUpdate(h, l.alt)
-h = murmurUpdate(h, l.context.hash())
+h = murmurUpdate(h, l.context.Hash())
-h = murmurUpdate(h, l.semanticContext.hash())
+h = murmurUpdate(h, l.semanticContext.Hash())
 h = murmurUpdate(h, f)
-h = murmurUpdate(h, l.lexerActionExecutor.hash())
+h = murmurUpdate(h, l.lexerActionExecutor.Hash())
 h = murmurFinish(h, 6)
 return h
 }
-func (l *LexerATNConfig) equals(other interface{}) bool {
+// Equals is the default comparison function for LexerATNConfig objects, it can be used directly or via
+// the default comparator [ObjEqComparator].
+func (l *LexerATNConfig) Equals(other Collectable[ATNConfig]) bool {
+if l == other {
+return true
+}
 var othert, ok = other.(*LexerATNConfig)
 if l == other {
@@ -275,7 +284,7 @@ func (l *LexerATNConfig) equals(other interface{}) bool {
 var b bool
 if l.lexerActionExecutor != nil {
-b = !l.lexerActionExecutor.equals(othert.lexerActionExecutor)
+b = !l.lexerActionExecutor.Equals(othert.lexerActionExecutor)
 } else {
 b = othert.lexerActionExecutor != nil
 }
@@ -284,10 +293,9 @@ func (l *LexerATNConfig) equals(other interface{}) bool {
 return false
 }
-return l.BaseATNConfig.equals(othert.BaseATNConfig)
+return l.BaseATNConfig.Equals(othert.BaseATNConfig)
 }
 func checkNonGreedyDecision(source *LexerATNConfig, target ATNState) bool {
 var ds, ok = target.(DecisionState)

==== next file ====

@@ -1,24 +1,25 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
 package antlr
-import "fmt"
+import (
+"fmt"
+)
 type ATNConfigSet interface {
-hash() int
+Hash() int
+Equals(o Collectable[ATNConfig]) bool
 Add(ATNConfig, *DoubleDict) bool
 AddAll([]ATNConfig) bool
-GetStates() Set
+GetStates() *JStore[ATNState, Comparator[ATNState]]
 GetPredicates() []SemanticContext
 GetItems() []ATNConfig
 OptimizeConfigs(interpreter *BaseATNSimulator)
-Equals(other interface{}) bool
 Length() int
 IsEmpty() bool
 Contains(ATNConfig) bool
@@ -57,7 +58,7 @@ type BaseATNConfigSet struct {
 // effectively doubles the number of objects associated with ATNConfigs. All
 // keys are hashed by (s, i, _, pi), not including the context. Wiped out when
 // read-only because a set becomes a DFA state.
-configLookup Set
+configLookup *JStore[ATNConfig, Comparator[ATNConfig]]
 // configs is the added elements.
 configs []ATNConfig
@@ -83,7 +84,7 @@ type BaseATNConfigSet struct {
 // readOnly is whether it is read-only. Do not
 // allow any code to manipulate the set if true because DFA states will point at
-// sets and those must not change. It not protect other fields; conflictingAlts
+// sets and those must not change. It not, protect other fields; conflictingAlts
 // in particular, which is assigned after readOnly.
 readOnly bool
@@ -104,7 +105,7 @@ func (b *BaseATNConfigSet) Alts() *BitSet {
 func NewBaseATNConfigSet(fullCtx bool) *BaseATNConfigSet {
 return &BaseATNConfigSet{
 cachedHash: -1,
-configLookup: newArray2DHashSetWithCap(hashATNConfig, equalATNConfigs, 16, 2),
+configLookup: NewJStore[ATNConfig, Comparator[ATNConfig]](aConfCompInst),
 fullCtx: fullCtx,
 }
 }
@@ -126,9 +127,11 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
 b.dipsIntoOuterContext = true
 }
-existing := b.configLookup.Add(config).(ATNConfig)
-if existing == config {
+existing, present := b.configLookup.Put(config)
+// The config was not already in the set
+//
+if !present {
 b.cachedHash = -1
 b.configs = append(b.configs, config) // Track order here
 return true
@@ -154,11 +157,14 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
 return true
 }
-func (b *BaseATNConfigSet) GetStates() Set {
+func (b *BaseATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] {
-states := newArray2DHashSet(nil, nil)
+// states uses the standard comparator provided by the ATNState instance
+//
+states := NewJStore[ATNState, Comparator[ATNState]](aStateEqInst)
 for i := 0; i < len(b.configs); i++ {
-states.Add(b.configs[i].GetState())
+states.Put(b.configs[i].GetState())
 }
 return states
@@ -214,7 +220,34 @@ func (b *BaseATNConfigSet) AddAll(coll []ATNConfig) bool {
 return false
 }
-func (b *BaseATNConfigSet) Equals(other interface{}) bool {
+// Compare is a hack function just to verify that adding DFAstares to the known
+// set works, so long as comparison of ATNConfigSet s works. For that to work, we
+// need to make sure that the set of ATNConfigs in two sets are equivalent. We can't
+// know the order, so we do this inefficient hack. If this proves the point, then
+// we can change the config set to a better structure.
+func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
+if len(b.configs) != len(bs.configs) {
+return false
+}
+for _, c := range b.configs {
+found := false
+for _, c2 := range bs.configs {
+if c.Equals(c2) {
+found = true
+break
+}
+}
+if !found {
+return false
+}
+}
+return true
+}
+func (b *BaseATNConfigSet) Equals(other Collectable[ATNConfig]) bool {
 if b == other {
 return true
 } else if _, ok := other.(*BaseATNConfigSet); !ok {
@@ -224,15 +257,15 @@ func (b *BaseATNConfigSet) Equals(other interface{}) bool {
 other2 := other.(*BaseATNConfigSet)
 return b.configs != nil &&
-// TODO: b.configs.equals(other2.configs) &&
+// TODO: Is b necessary?
 b.fullCtx == other2.fullCtx &&
 b.uniqueAlt == other2.uniqueAlt &&
 b.conflictingAlts == other2.conflictingAlts &&
 b.hasSemanticContext == other2.hasSemanticContext &&
-b.dipsIntoOuterContext == other2.dipsIntoOuterContext
+b.dipsIntoOuterContext == other2.dipsIntoOuterContext &&
+b.Compare(other2)
 }
-func (b *BaseATNConfigSet) hash() int {
+func (b *BaseATNConfigSet) Hash() int {
 if b.readOnly {
 if b.cachedHash == -1 {
 b.cachedHash = b.hashCodeConfigs()
@@ -247,7 +280,7 @@ func (b *BaseATNConfigSet) hash() int {
 func (b *BaseATNConfigSet) hashCodeConfigs() int {
 h := 1
 for _, config := range b.configs {
-h = 31*h + config.hash()
+h = 31*h + config.Hash()
 }
 return h
 }
@@ -283,7 +316,7 @@ func (b *BaseATNConfigSet) Clear() {
 b.configs = make([]ATNConfig, 0)
 b.cachedHash = -1
-b.configLookup = newArray2DHashSet(nil, equalATNConfigs)
+b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](atnConfCompInst)
 }
 func (b *BaseATNConfigSet) FullContext() bool {
@@ -365,7 +398,8 @@ type OrderedATNConfigSet struct {
 func NewOrderedATNConfigSet() *OrderedATNConfigSet {
 b := NewBaseATNConfigSet(false)
-b.configLookup = newArray2DHashSet(nil, nil)
+// This set uses the standard Hash() and Equals() from ATNConfig
+b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
 return &OrderedATNConfigSet{BaseATNConfigSet: b}
 }
@@ -375,7 +409,7 @@ func hashATNConfig(i interface{}) int {
 hash := 7
 hash = 31*hash + o.GetState().GetStateNumber()
 hash = 31*hash + o.GetAlt()
-hash = 31*hash + o.GetSemanticContext().hash()
+hash = 31*hash + o.GetSemanticContext().Hash()
 return hash
 }
@@ -403,5 +437,5 @@ func equalATNConfigs(a, b interface{}) bool {
 return false
 }
-return ai.GetSemanticContext().equals(bi.GetSemanticContext())
+return ai.GetSemanticContext().Equals(bi.GetSemanticContext())
 }

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -49,7 +49,8 @@ type ATNState interface {
 AddTransition(Transition, int)
 String() string
-hash() int
+Hash() int
+Equals(Collectable[ATNState]) bool
 }
 type BaseATNState struct {
@@ -123,7 +124,7 @@ func (as *BaseATNState) SetNextTokenWithinRule(v *IntervalSet) {
 as.NextTokenWithinRule = v
 }
-func (as *BaseATNState) hash() int {
+func (as *BaseATNState) Hash() int {
 return as.stateNumber
 }
@@ -131,7 +132,7 @@
 return strconv.Itoa(as.stateNumber)
 }
-func (as *BaseATNState) equals(other interface{}) bool {
+func (as *BaseATNState) Equals(other Collectable[ATNState]) bool {
 if ot, ok := other.(ATNState); ok {
 return as.stateNumber == ot.GetStateNumber()
 }

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -331,10 +331,12 @@ func (c *CommonTokenStream) GetTextFromRuleContext(interval RuleContext) string
 func (c *CommonTokenStream) GetTextFromInterval(interval *Interval) string {
 c.lazyInit()
-c.Fill()
 if interval == nil {
+c.Fill()
 interval = NewInterval(0, len(c.tokens)-1)
+} else {
+c.Sync(interval.Stop)
 }
 start := interval.Start

==== next file ====

@@ -0,0 +1,147 @@
package antlr
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
// This file contains all the implementations of custom comparators used for generic collections when the
// Hash() and Equals() funcs supplied by the struct objects themselves need to be overridden. Normally, we would
// put the comparators in the source file for the struct themselves, but given the organization of this code is
// sorta kinda based upon the Java code, I found it confusing trying to find out which comparator was where and used by
// which instantiation of a collection. For instance, an Array2DHashSet in the Java source, when used with ATNConfig
// collections requires three different comparators depending on what the collection is being used for. Collecting - pun intended -
// all the comparators here, makes it much easier to see which implementation of hash and equals is used by which collection.
// It also makes it easy to verify that the Hash() and Equals() functions marry up with the Java implementations.
// ObjEqComparator is the equivalent of the Java ObjectEqualityComparator, which is the default instance of
// Equality comparator. We do not have inheritance in Go, only interfaces, so we use generics to enforce some
// type safety and avoid having to implement this for every type that we want to perform comparison on.
//
// This comparator works by using the standard Hash() and Equals() methods of the type T that is being compared. Which
// allows us to use it in any collection instance that does nto require a special hash or equals implementation.
type ObjEqComparator[T Collectable[T]] struct{}
var (
aStateEqInst = &ObjEqComparator[ATNState]{}
aConfEqInst = &ObjEqComparator[ATNConfig]{}
aConfCompInst = &ATNConfigComparator[ATNConfig]{}
atnConfCompInst = &BaseATNConfigComparator[ATNConfig]{}
dfaStateEqInst = &ObjEqComparator[*DFAState]{}
semctxEqInst = &ObjEqComparator[SemanticContext]{}
atnAltCfgEqInst = &ATNAltConfigComparator[ATNConfig]{}
)
// Equals2 delegates to the Equals() method of type T
func (c *ObjEqComparator[T]) Equals2(o1, o2 T) bool {
return o1.Equals(o2)
}
// Hash1 delegates to the Hash() method of type T
func (c *ObjEqComparator[T]) Hash1(o T) int {
return o.Hash()
}
type SemCComparator[T Collectable[T]] struct{}
// ATNConfigComparator is used as the compartor for the configLookup field of an ATNConfigSet
// and has a custom Equals() and Hash() implementation, because equality is not based on the
// standard Hash() and Equals() methods of the ATNConfig type.
type ATNConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for configLookup
func (c *ATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
if o1 == o2 {
return true
}
// If either are nil, but not both, then the result is false
//
if o1 == nil || o2 == nil {
return false
}
return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() &&
o1.GetAlt() == o2.GetAlt() &&
o1.GetSemanticContext().Equals(o2.GetSemanticContext())
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup
func (c *ATNConfigComparator[T]) Hash1(o ATNConfig) int {
hash := 7
hash = 31*hash + o.GetState().GetStateNumber()
hash = 31*hash + o.GetAlt()
hash = 31*hash + o.GetSemanticContext().Hash()
return hash
}
// ATNAltConfigComparator is used as the comparator for mapping configs to Alt Bitsets
type ATNAltConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for configLookup
func (c *ATNAltConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
if o1 == o2 {
return true
}
// If either are nil, but not both, then the result is false
//
if o1 == nil || o2 == nil {
return false
}
return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() &&
o1.GetContext().Equals(o2.GetContext())
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup
func (c *ATNAltConfigComparator[T]) Hash1(o ATNConfig) int {
h := murmurInit(7)
h = murmurUpdate(h, o.GetState().GetStateNumber())
h = murmurUpdate(h, o.GetContext().Hash())
return murmurFinish(h, 2)
}
// BaseATNConfigComparator is used as the comparator for the configLookup field of a BaseATNConfigSet
// and has a custom Equals() and Hash() implementation, because equality is not based on the
// standard Hash() and Equals() methods of the ATNConfig type.
type BaseATNConfigComparator[T Collectable[T]] struct {
}
// Equals2 is a custom comparator for ATNConfigs specifically for baseATNConfigSet
func (c *BaseATNConfigComparator[T]) Equals2(o1, o2 ATNConfig) bool {
// Same pointer, must be equal, even if both nil
//
if o1 == o2 {
return true
}
// If either are nil, but not both, then the result is false
//
if o1 == nil || o2 == nil {
return false
}
return o1.GetState().GetStateNumber() == o2.GetState().GetStateNumber() &&
o1.GetAlt() == o2.GetAlt() &&
o1.GetSemanticContext().Equals(o2.GetSemanticContext())
}
// Hash1 is custom hash implementation for ATNConfigs specifically for configLookup, but in fact just
// delegates to the standard Hash() method of the ATNConfig type.
func (c *BaseATNConfigComparator[T]) Hash1(o ATNConfig) int {
return o.Hash()
}

==== next file ====

@@ -1,13 +1,9 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
 package antlr
-import (
-"sort"
-)
 type DFA struct {
 // atnStartState is the ATN state in which this was created
 atnStartState DecisionState
@@ -15,8 +11,15 @@ type DFA struct {
 decision int
 // states is all the DFA states. Use Map to get the old state back; Set can only
-// indicate whether it is there.
-states map[int]*DFAState
+// indicate whether it is there. Go maps implement key hash collisions and so on and are very
+// good, but the DFAState is an object and can't be used directly as the key as it can in say JAva
+// amd C#, whereby if the hashcode is the same for two objects, then Equals() is called against them
+// to see if they really are the same object.
+//
+//
+states *JStore[*DFAState, *ObjEqComparator[*DFAState]]
+numstates int
 s0 *DFAState
@@ -29,7 +32,7 @@ func NewDFA(atnStartState DecisionState, decision int) *DFA {
 dfa := &DFA{
 atnStartState: atnStartState,
 decision: decision,
-states: make(map[int]*DFAState),
+states: NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst),
 }
 if s, ok := atnStartState.(*StarLoopEntryState); ok && s.precedenceRuleDecision {
 dfa.precedenceDfa = true
@@ -92,7 +95,8 @@ func (d *DFA) getPrecedenceDfa() bool {
 // true or nil otherwise, and d.precedenceDfa is updated.
 func (d *DFA) setPrecedenceDfa(precedenceDfa bool) {
 if d.getPrecedenceDfa() != precedenceDfa {
-d.setStates(make(map[int]*DFAState))
+d.states = NewJStore[*DFAState, *ObjEqComparator[*DFAState]](dfaStateEqInst)
+d.numstates = 0
 if precedenceDfa {
 precedenceState := NewDFAState(-1, NewBaseATNConfigSet(false))
@@ -117,38 +121,12 @@ func (d *DFA) setS0(s *DFAState) {
 d.s0 = s
 }
-func (d *DFA) getState(hash int) (*DFAState, bool) {
-s, ok := d.states[hash]
-return s, ok
-}
-func (d *DFA) setStates(states map[int]*DFAState) {
-d.states = states
-}
-func (d *DFA) setState(hash int, state *DFAState) {
-d.states[hash] = state
-}
-func (d *DFA) numStates() int {
-return len(d.states)
-}
-type dfaStateList []*DFAState
-func (d dfaStateList) Len() int { return len(d) }
-func (d dfaStateList) Less(i, j int) bool { return d[i].stateNumber < d[j].stateNumber }
-func (d dfaStateList) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
 // sortedStates returns the states in d sorted by their state number.
 func (d *DFA) sortedStates() []*DFAState {
-vs := make([]*DFAState, 0, len(d.states))
-for _, v := range d.states {
-vs = append(vs, v)
-}
-sort.Sort(dfaStateList(vs))
+vs := d.states.SortedSlice(func(i, j *DFAState) bool {
+return i.stateNumber < j.stateNumber
+})
 return vs
 }

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -90,16 +90,16 @@ func NewDFAState(stateNumber int, configs ATNConfigSet) *DFAState {
 }
 // GetAltSet gets the set of all alts mentioned by all ATN configurations in d.
-func (d *DFAState) GetAltSet() Set {
+func (d *DFAState) GetAltSet() []int {
-alts := newArray2DHashSet(nil, nil)
+var alts []int
 if d.configs != nil {
 for _, c := range d.configs.GetItems() {
-alts.Add(c.GetAlt())
+alts = append(alts, c.GetAlt())
 }
 }
-if alts.Len() == 0 {
+if len(alts) == 0 {
 return nil
 }
@@ -130,27 +130,6 @@ func (d *DFAState) setPrediction(v int) {
 d.prediction = v
 }
-// equals returns whether d equals other. Two DFAStates are equal if their ATN
-// configuration sets are the same. This method is used to see if a state
-// already exists.
-//
-// Because the number of alternatives and number of ATN configurations are
-// finite, there is a finite number of DFA states that can be processed. This is
-// necessary to show that the algorithm terminates.
-//
-// Cannot test the DFA state numbers here because in
-// ParserATNSimulator.addDFAState we need to know if any other state exists that
-// has d exact set of ATN configurations. The stateNumber is irrelevant.
-func (d *DFAState) equals(other interface{}) bool {
-if d == other {
-return true
-} else if _, ok := other.(*DFAState); !ok {
-return false
-}
-return d.configs.Equals(other.(*DFAState).configs)
-}
 func (d *DFAState) String() string {
 var s string
 if d.isAcceptState {
@@ -164,8 +143,27 @@ func (d *DFAState) String() string {
 return fmt.Sprintf("%d:%s%s", d.stateNumber, fmt.Sprint(d.configs), s)
 }
-func (d *DFAState) hash() int {
+func (d *DFAState) Hash() int {
 h := murmurInit(7)
-h = murmurUpdate(h, d.configs.hash())
+h = murmurUpdate(h, d.configs.Hash())
 return murmurFinish(h, 1)
 }
+// Equals returns whether d equals other. Two DFAStates are equal if their ATN
+// configuration sets are the same. This method is used to see if a state
+// already exists.
+//
+// Because the number of alternatives and number of ATN configurations are
+// finite, there is a finite number of DFA states that can be processed. This is
+// necessary to show that the algorithm terminates.
+//
+// Cannot test the DFA state numbers here because in
+// ParserATNSimulator.addDFAState we need to know if any other state exists that
+// has d exact set of ATN configurations. The stateNumber is irrelevant.
+func (d *DFAState) Equals(o Collectable[*DFAState]) bool {
+if d == o {
+return true
+}
+return d.configs.Equals(o.(*DFAState).configs)
+}

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -87,7 +87,6 @@ func (d *DiagnosticErrorListener) getDecisionDescription(recognizer Parser, dfa
 return strconv.Itoa(decision) + " (" + ruleName + ")"
 }
-//
 // Computes the set of conflicting or ambiguous alternatives from a
 // configuration set, if that information was not already provided by the
 // parser.
@@ -97,7 +96,6 @@
 // @param configs The conflicting or ambiguous configuration set.
 // @return Returns {@code ReportedAlts} if it is not {@code nil}, otherwise
 // returns the set of alternatives represented in {@code configs}.
-//
 func (d *DiagnosticErrorListener) getConflictingAlts(ReportedAlts *BitSet, set ATNConfigSet) *BitSet {
 if ReportedAlts != nil {
 return ReportedAlts

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -48,12 +48,9 @@ func NewConsoleErrorListener() *ConsoleErrorListener {
 return new(ConsoleErrorListener)
 }
-//
 // Provides a default instance of {@link ConsoleErrorListener}.
-//
 var ConsoleErrorListenerINSTANCE = NewConsoleErrorListener()
-//
 // {@inheritDoc}
 //
 // <p>
@@ -64,7 +61,6 @@ var ConsoleErrorListenerINSTANCE = NewConsoleErrorListener()
 // <pre>
 // line <em>line</em>:<em>charPositionInLine</em> <em>msg</em>
 // </pre>
-//
 func (c *ConsoleErrorListener) SyntaxError(recognizer Recognizer, offendingSymbol interface{}, line, column int, msg string, e RecognitionException) {
 fmt.Fprintln(os.Stderr, "line "+strconv.Itoa(line)+":"+strconv.Itoa(column)+" "+msg)
 }

==== next file ====

@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
+// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
 // Use of this file is governed by the BSD 3-clause license that
 // can be found in the LICENSE.txt file in the project root.
@@ -23,7 +23,6 @@ type ErrorStrategy interface {
 // This is the default implementation of {@link ANTLRErrorStrategy} used for
 // error Reporting and recovery in ANTLR parsers.
-//
 type DefaultErrorStrategy struct {
 errorRecoveryMode bool
 lastErrorIndex int
@@ -61,12 +60,10 @@ func (d *DefaultErrorStrategy) reset(recognizer Parser) {
 d.endErrorCondition(recognizer)
 }
-//
 // This method is called to enter error recovery mode when a recognition
 // exception is Reported.
 //
 // @param recognizer the parser instance
-//
 func (d *DefaultErrorStrategy) beginErrorCondition(recognizer Parser) {
 d.errorRecoveryMode = true
 }
@@ -75,28 +72,23 @@ func (d *DefaultErrorStrategy) InErrorRecoveryMode(recognizer Parser) bool {
 return d.errorRecoveryMode
 }
-//
 // This method is called to leave error recovery mode after recovering from
 // a recognition exception.
 //
 // @param recognizer
-//
 func (d *DefaultErrorStrategy) endErrorCondition(recognizer Parser) {
 d.errorRecoveryMode = false
 d.lastErrorStates = nil
 d.lastErrorIndex = -1
 }
-//
 // {@inheritDoc}
 //
 // <p>The default implementation simply calls {@link //endErrorCondition}.</p>
-//
 func (d *DefaultErrorStrategy) ReportMatch(recognizer Parser) {
 d.endErrorCondition(recognizer)
 }
-//
 // {@inheritDoc}
 //
 // <p>The default implementation returns immediately if the handler is already
@@ -114,7 +106,6 @@ func (d *DefaultErrorStrategy) ReportMatch(recognizer Parser) {
 // <li>All other types: calls {@link Parser//NotifyErrorListeners} to Report
 // the exception</li>
 // </ul>
-//
 func (d *DefaultErrorStrategy) ReportError(recognizer Parser, e RecognitionException) {
 // if we've already Reported an error and have not Matched a token
 // yet successfully, don't Report any errors.
@@ -142,7 +133,6 @@ func (d *DefaultErrorStrategy) ReportError(recognizer Parser, e RecognitionExcep
 // <p>The default implementation reSynchronizes the parser by consuming tokens
 // until we find one in the reSynchronization set--loosely the set of tokens
 // that can follow the current rule.</p>
-//
 func (d *DefaultErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
 if d.lastErrorIndex == recognizer.GetInputStream().Index() &&
@@ -206,7 +196,6 @@ func (d *DefaultErrorStrategy) Recover(recognizer Parser, e RecognitionException
 // compare token set at the start of the loop and at each iteration. If for
 // some reason speed is suffering for you, you can turn off d
 // functionality by simply overriding d method as a blank { }.</p>
//
func (d *DefaultErrorStrategy) Sync(recognizer Parser) { func (d *DefaultErrorStrategy) Sync(recognizer Parser) {
// If already recovering, don't try to Sync // If already recovering, don't try to Sync
if d.InErrorRecoveryMode(recognizer) { if d.InErrorRecoveryMode(recognizer) {
@ -247,7 +236,6 @@ func (d *DefaultErrorStrategy) Sync(recognizer Parser) {
// //
// @param recognizer the parser instance // @param recognizer the parser instance
// @param e the recognition exception // @param e the recognition exception
//
func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *NoViableAltException) { func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *NoViableAltException) {
tokens := recognizer.GetTokenStream() tokens := recognizer.GetTokenStream()
var input string var input string
@ -264,7 +252,6 @@ func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *N
recognizer.NotifyErrorListeners(msg, e.offendingToken, e) recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
} }
//
// This is called by {@link //ReportError} when the exception is an // This is called by {@link //ReportError} when the exception is an
// {@link InputMisMatchException}. // {@link InputMisMatchException}.
// //
@ -272,14 +259,12 @@ func (d *DefaultErrorStrategy) ReportNoViableAlternative(recognizer Parser, e *N
// //
// @param recognizer the parser instance // @param recognizer the parser instance
// @param e the recognition exception // @param e the recognition exception
//
func (this *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *InputMisMatchException) { func (this *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *InputMisMatchException) {
msg := "mismatched input " + this.GetTokenErrorDisplay(e.offendingToken) + msg := "mismatched input " + this.GetTokenErrorDisplay(e.offendingToken) +
" expecting " + e.getExpectedTokens().StringVerbose(recognizer.GetLiteralNames(), recognizer.GetSymbolicNames(), false) " expecting " + e.getExpectedTokens().StringVerbose(recognizer.GetLiteralNames(), recognizer.GetSymbolicNames(), false)
recognizer.NotifyErrorListeners(msg, e.offendingToken, e) recognizer.NotifyErrorListeners(msg, e.offendingToken, e)
} }
//
// This is called by {@link //ReportError} when the exception is a // This is called by {@link //ReportError} when the exception is a
// {@link FailedPredicateException}. // {@link FailedPredicateException}.
// //
@ -287,7 +272,6 @@ func (this *DefaultErrorStrategy) ReportInputMisMatch(recognizer Parser, e *Inpu
// //
// @param recognizer the parser instance // @param recognizer the parser instance
// @param e the recognition exception // @param e the recognition exception
//
func (d *DefaultErrorStrategy) ReportFailedPredicate(recognizer Parser, e *FailedPredicateException) { func (d *DefaultErrorStrategy) ReportFailedPredicate(recognizer Parser, e *FailedPredicateException) {
ruleName := recognizer.GetRuleNames()[recognizer.GetParserRuleContext().GetRuleIndex()] ruleName := recognizer.GetRuleNames()[recognizer.GetParserRuleContext().GetRuleIndex()]
msg := "rule " + ruleName + " " + e.message msg := "rule " + ruleName + " " + e.message
@ -310,7 +294,6 @@ func (d *DefaultErrorStrategy) ReportFailedPredicate(recognizer Parser, e *Faile
// {@link Parser//NotifyErrorListeners}.</p> // {@link Parser//NotifyErrorListeners}.</p>
// //
// @param recognizer the parser instance // @param recognizer the parser instance
//
func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) { func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) {
if d.InErrorRecoveryMode(recognizer) { if d.InErrorRecoveryMode(recognizer) {
return return
@ -339,7 +322,6 @@ func (d *DefaultErrorStrategy) ReportUnwantedToken(recognizer Parser) {
// {@link Parser//NotifyErrorListeners}.</p> // {@link Parser//NotifyErrorListeners}.</p>
// //
// @param recognizer the parser instance // @param recognizer the parser instance
//
func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) { func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) {
if d.InErrorRecoveryMode(recognizer) { if d.InErrorRecoveryMode(recognizer) {
return return
@ -392,15 +374,14 @@ func (d *DefaultErrorStrategy) ReportMissingToken(recognizer Parser) {
// derivation: // derivation:
// //
// <pre> // <pre>
// =&gt ID '=' '(' INT ')' ('+' atom)* '' // =&gt ID '=' '(' INT ')' ('+' atom)*
// ^ // ^
// </pre> // </pre>
// //
// The attempt to Match {@code ')'} will fail when it sees {@code ''} and // The attempt to Match {@code ')'} will fail when it sees {@code } and
// call {@link //recoverInline}. To recover, it sees that {@code LA(1)==''} // call {@link //recoverInline}. To recover, it sees that {@code LA(1)==}
// is in the set of tokens that can follow the {@code ')'} token reference // is in the set of tokens that can follow the {@code ')'} token reference
// in rule {@code atom}. It can assume that you forgot the {@code ')'}. // in rule {@code atom}. It can assume that you forgot the {@code ')'}.
//
func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token { func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
// SINGLE TOKEN DELETION // SINGLE TOKEN DELETION
MatchedSymbol := d.SingleTokenDeletion(recognizer) MatchedSymbol := d.SingleTokenDeletion(recognizer)
@ -418,7 +399,6 @@ func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
panic(NewInputMisMatchException(recognizer)) panic(NewInputMisMatchException(recognizer))
} }
//
// This method implements the single-token insertion inline error recovery // This method implements the single-token insertion inline error recovery
// strategy. It is called by {@link //recoverInline} if the single-token // strategy. It is called by {@link //recoverInline} if the single-token
// deletion strategy fails to recover from the mismatched input. If this // deletion strategy fails to recover from the mismatched input. If this
@ -434,7 +414,6 @@ func (d *DefaultErrorStrategy) RecoverInline(recognizer Parser) Token {
// @param recognizer the parser instance // @param recognizer the parser instance
// @return {@code true} if single-token insertion is a viable recovery // @return {@code true} if single-token insertion is a viable recovery
// strategy for the current mismatched input, otherwise {@code false} // strategy for the current mismatched input, otherwise {@code false}
//
func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool { func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool {
currentSymbolType := recognizer.GetTokenStream().LA(1) currentSymbolType := recognizer.GetTokenStream().LA(1)
// if current token is consistent with what could come after current // if current token is consistent with what could come after current
@ -469,7 +448,6 @@ func (d *DefaultErrorStrategy) SingleTokenInsertion(recognizer Parser) bool {
// @return the successfully Matched {@link Token} instance if single-token // @return the successfully Matched {@link Token} instance if single-token
// deletion successfully recovers from the mismatched input, otherwise // deletion successfully recovers from the mismatched input, otherwise
// {@code nil} // {@code nil}
//
func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token { func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token {
NextTokenType := recognizer.GetTokenStream().LA(2) NextTokenType := recognizer.GetTokenStream().LA(2)
expecting := d.GetExpectedTokens(recognizer) expecting := d.GetExpectedTokens(recognizer)
@ -507,7 +485,6 @@ func (d *DefaultErrorStrategy) SingleTokenDeletion(recognizer Parser) Token {
// a CommonToken of the appropriate type. The text will be the token. // a CommonToken of the appropriate type. The text will be the token.
// If you change what tokens must be created by the lexer, // If you change what tokens must be created by the lexer,
// override d method to create the appropriate tokens. // override d method to create the appropriate tokens.
//
func (d *DefaultErrorStrategy) GetMissingSymbol(recognizer Parser) Token { func (d *DefaultErrorStrategy) GetMissingSymbol(recognizer Parser) Token {
currentSymbol := recognizer.GetCurrentToken() currentSymbol := recognizer.GetCurrentToken()
expecting := d.GetExpectedTokens(recognizer) expecting := d.GetExpectedTokens(recognizer)
@ -546,7 +523,6 @@ func (d *DefaultErrorStrategy) GetExpectedTokens(recognizer Parser) *IntervalSet
// the token). This is better than forcing you to override a method in // the token). This is better than forcing you to override a method in
// your token objects because you don't have to go modify your lexer // your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type. // so that it creates a NewJava type.
//
func (d *DefaultErrorStrategy) GetTokenErrorDisplay(t Token) string { func (d *DefaultErrorStrategy) GetTokenErrorDisplay(t Token) string {
if t == nil { if t == nil {
return "<no token>" return "<no token>"
@ -578,7 +554,7 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
// from within the rule i.e., the FIRST computation done by // from within the rule i.e., the FIRST computation done by
// ANTLR stops at the end of a rule. // ANTLR stops at the end of a rule.
// //
// EXAMPLE // # EXAMPLE
// //
// When you find a "no viable alt exception", the input is not // When you find a "no viable alt exception", the input is not
// consistent with any of the alternatives for rule r. The best // consistent with any of the alternatives for rule r. The best
@ -597,7 +573,6 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
// c : ID // c : ID
// | INT // | INT
// //
//
// At each rule invocation, the set of tokens that could follow // At each rule invocation, the set of tokens that could follow
// that rule is pushed on a stack. Here are the various // that rule is pushed on a stack. Here are the various
// context-sensitive follow sets: // context-sensitive follow sets:
@ -660,7 +635,6 @@ func (d *DefaultErrorStrategy) escapeWSAndQuote(s string) string {
// //
// Like Grosch I implement context-sensitive FOLLOW sets that are combined // Like Grosch I implement context-sensitive FOLLOW sets that are combined
// at run-time upon error to avoid overhead during parsing. // at run-time upon error to avoid overhead during parsing.
//
func (d *DefaultErrorStrategy) getErrorRecoverySet(recognizer Parser) *IntervalSet { func (d *DefaultErrorStrategy) getErrorRecoverySet(recognizer Parser) *IntervalSet {
atn := recognizer.GetInterpreter().atn atn := recognizer.GetInterpreter().atn
ctx := recognizer.GetParserRuleContext() ctx := recognizer.GetParserRuleContext()
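The long comment above (abridged in this hunk) describes combining the FOLLOW sets of every rule on the invocation stack into one resynchronization set. A toy version of that union, with a hypothetical grammar's follow sets:

package main

import "fmt"

// union of the follow sets of the rules currently on the call stack
func errorRecoverySet(follow map[string][]string, ruleStack []string) map[string]bool {
	recovery := make(map[string]bool)
	for _, rule := range ruleStack {
		for _, tok := range follow[rule] {
			recovery[tok] = true
		}
	}
	return recovery
}

func main() {
	follow := map[string][]string{
		"stat": {";"},
		"expr": {")", ";"},
		"atom": {"+", ")", ";"},
	}
	fmt.Println(errorRecoverySet(follow, []string{"stat", "expr", "atom"}))
}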
@ -733,7 +707,6 @@ func NewBailErrorStrategy() *BailErrorStrategy {
// in a {@link ParseCancellationException} so it is not caught by the // in a {@link ParseCancellationException} so it is not caught by the
// rule func catches. Use {@link Exception//getCause()} to get the // rule func catches. Use {@link Exception//getCause()} to get the
// original {@link RecognitionException}. // original {@link RecognitionException}.
//
func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) { func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
context := recognizer.GetParserRuleContext() context := recognizer.GetParserRuleContext()
for context != nil { for context != nil {
@ -749,7 +722,6 @@ func (b *BailErrorStrategy) Recover(recognizer Parser, e RecognitionException) {
// Make sure we don't attempt to recover inline if the parser // Make sure we don't attempt to recover inline if the parser
// successfully recovers, it won't panic an exception. // successfully recovers, it won't panic an exception.
//
func (b *BailErrorStrategy) RecoverInline(recognizer Parser) Token { func (b *BailErrorStrategy) RecoverInline(recognizer Parser) Token {
b.Recover(recognizer, NewInputMisMatchException(recognizer)) b.Recover(recognizer, NewInputMisMatchException(recognizer))

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -74,7 +74,6 @@ func (b *BaseRecognitionException) GetInputStream() IntStream {
// <p>If the state number is not known, b method returns -1.</p> // <p>If the state number is not known, b method returns -1.</p>
//
// Gets the set of input symbols which could potentially follow the // Gets the set of input symbols which could potentially follow the
// previously Matched symbol at the time b exception was panicn. // previously Matched symbol at the time b exception was panicn.
// //
@ -136,7 +135,6 @@ type NoViableAltException struct {
// to take based upon the remaining input. It tracks the starting token // to take based upon the remaining input. It tracks the starting token
// of the offending input and also knows where the parser was // of the offending input and also knows where the parser was
// in the various paths when the error. Reported by ReportNoViableAlternative() // in the various paths when the error. Reported by ReportNoViableAlternative()
//
func NewNoViableAltException(recognizer Parser, input TokenStream, startToken Token, offendingToken Token, deadEndConfigs ATNConfigSet, ctx ParserRuleContext) *NoViableAltException { func NewNoViableAltException(recognizer Parser, input TokenStream, startToken Token, offendingToken Token, deadEndConfigs ATNConfigSet, ctx ParserRuleContext) *NoViableAltException {
if ctx == nil { if ctx == nil {
@ -177,7 +175,6 @@ type InputMisMatchException struct {
// This signifies any kind of mismatched input exceptions such as // This signifies any kind of mismatched input exceptions such as
// when the current input does not Match the expected token. // when the current input does not Match the expected token.
//
func NewInputMisMatchException(recognizer Parser) *InputMisMatchException { func NewInputMisMatchException(recognizer Parser) *InputMisMatchException {
i := new(InputMisMatchException) i := new(InputMisMatchException)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -223,6 +223,10 @@ func (i *IntervalSet) StringVerbose(literalNames []string, symbolicNames []strin
return i.toIndexString() return i.toIndexString()
} }
func (i *IntervalSet) GetIntervals() []*Interval {
return i.intervals
}
func (i *IntervalSet) toCharString() string { func (i *IntervalSet) toCharString() string {
names := make([]string, len(i.intervals)) names := make([]string, len(i.intervals))

View File

@ -0,0 +1,198 @@
package antlr
// Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root.
import (
"sort"
)
// Collectable is an interface that a struct should implement if it is to be
// usable as a key in these collections.
type Collectable[T any] interface {
Hash() int
Equals(other Collectable[T]) bool
}
type Comparator[T any] interface {
Hash1(o T) int
Equals2(T, T) bool
}
// JStore implements a container that allows the use of a struct to calculate the key
// for a collection of values akin to map. This is not meant to be a full-blown HashMap but just
// serve the needs of the ANTLR Go runtime.
//
// For ease of porting the logic of the runtime from the master target (Java), this collection
// operates in a similar way to Java, in that it can use any struct that supplies a Hash() and Equals()
// function as the key. The values are stored in a standard go map which internally is a form of hashmap
// itself, the key for the go map is the hash supplied by the key object. The collection is able to deal with
// hash conflicts by using a simple slice of values associated with the hash code indexed bucket. That isn't
// particularly efficient, but it is simple, and it works. As this is specifically for the ANTLR runtime, and
// we understand the requirements, then this is fine - this is not a general purpose collection.
type JStore[T any, C Comparator[T]] struct {
store map[int][]T
len int
comparator Comparator[T]
}
func NewJStore[T any, C Comparator[T]](comparator Comparator[T]) *JStore[T, C] {
if comparator == nil {
panic("comparator cannot be nil")
}
s := &JStore[T, C]{
store: make(map[int][]T, 1),
comparator: comparator,
}
return s
}
// Put will store given value in the collection. Note that the key for storage is generated from
// the value itself - this is specifically because that is what ANTLR needs - this would not be useful
// as any kind of general collection.
//
// If the key has a hash conflict, then the value will be added to the slice of values associated with the
// hash, unless the value is already in the slice, in which case the existing value is returned. Value equivalence is
// tested by calling the equals() method on the key.
//
// # If the given value is already present in the store, then the existing value is returned as v and exists is set to true
//
// If the given value is not present in the store, then the value is added to the store and returned as v and exists is set to false.
func (s *JStore[T, C]) Put(value T) (v T, exists bool) { //nolint:ireturn
kh := s.comparator.Hash1(value)
for _, v1 := range s.store[kh] {
if s.comparator.Equals2(value, v1) {
return v1, true
}
}
s.store[kh] = append(s.store[kh], value)
s.len++
return value, false
}
// Get will return the value associated with the key - the type of the key is the same type as the value
// which would not generally be useful, but this is a specific thing for ANTLR where the key is
// generated using the object we are going to store.
func (s *JStore[T, C]) Get(key T) (T, bool) { //nolint:ireturn
kh := s.comparator.Hash1(key)
for _, v := range s.store[kh] {
if s.comparator.Equals2(key, v) {
return v, true
}
}
return key, false
}
// Contains returns true if the given key is present in the store
func (s *JStore[T, C]) Contains(key T) bool { //nolint:ireturn
_, present := s.Get(key)
return present
}
func (s *JStore[T, C]) SortedSlice(less func(i, j T) bool) []T {
vs := make([]T, 0, len(s.store))
for _, v := range s.store {
vs = append(vs, v...)
}
sort.Slice(vs, func(i, j int) bool {
return less(vs[i], vs[j])
})
return vs
}
func (s *JStore[T, C]) Each(f func(T) bool) {
for _, e := range s.store {
for _, v := range e {
f(v)
}
}
}
func (s *JStore[T, C]) Len() int {
return s.len
}
func (s *JStore[T, C]) Values() []T {
vs := make([]T, 0, len(s.store))
for _, e := range s.store {
for _, v := range e {
vs = append(vs, v)
}
}
return vs
}
type entry[K, V any] struct {
key K
val V
}
type JMap[K, V any, C Comparator[K]] struct {
store map[int][]*entry[K, V]
len int
comparator Comparator[K]
}
func NewJMap[K, V any, C Comparator[K]](comparator Comparator[K]) *JMap[K, V, C] {
return &JMap[K, V, C]{
store: make(map[int][]*entry[K, V], 1),
comparator: comparator,
}
}
func (m *JMap[K, V, C]) Put(key K, val V) {
kh := m.comparator.Hash1(key)
m.store[kh] = append(m.store[kh], &entry[K, V]{key, val})
m.len++
}
func (m *JMap[K, V, C]) Values() []V {
vs := make([]V, 0, len(m.store))
for _, e := range m.store {
for _, v := range e {
vs = append(vs, v.val)
}
}
return vs
}
func (m *JMap[K, V, C]) Get(key K) (V, bool) {
var none V
kh := m.comparator.Hash1(key)
for _, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
return e.val, true
}
}
return none, false
}
func (m *JMap[K, V, C]) Len() int {
return len(m.store)
}
func (m *JMap[K, V, C]) Delete(key K) {
kh := m.comparator.Hash1(key)
for i, e := range m.store[kh] {
if m.comparator.Equals2(e.key, key) {
m.store[kh] = append(m.store[kh][:i], m.store[kh][i+1:]...)
m.len--
return
}
}
}
func (m *JMap[K, V, C]) Clear() {
m.store = make(map[int][]*entry[K, V])
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -232,8 +232,6 @@ func (b *BaseLexer) NextToken() Token {
} }
return b.token return b.token
} }
return nil
} }
// Instruct the lexer to Skip creating a token for current lexer rule // Instruct the lexer to Skip creating a token for current lexer rule
@ -342,7 +340,7 @@ func (b *BaseLexer) GetCharIndex() int {
} }
// Return the text Matched so far for the current token or any text override. // Return the text Matched so far for the current token or any text override.
//Set the complete text of l token it wipes any previous changes to the text. // Set the complete text of l token it wipes any previous changes to the text.
func (b *BaseLexer) GetText() string { func (b *BaseLexer) GetText() string {
if b.text != "" { if b.text != "" {
return b.text return b.text

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -21,8 +21,8 @@ type LexerAction interface {
getActionType() int getActionType() int
getIsPositionDependent() bool getIsPositionDependent() bool
execute(lexer Lexer) execute(lexer Lexer)
hash() int Hash() int
equals(other LexerAction) bool Equals(other LexerAction) bool
} }
type BaseLexerAction struct { type BaseLexerAction struct {
@ -51,15 +51,14 @@ func (b *BaseLexerAction) getIsPositionDependent() bool {
return b.isPositionDependent return b.isPositionDependent
} }
func (b *BaseLexerAction) hash() int { func (b *BaseLexerAction) Hash() int {
return b.actionType return b.actionType
} }
func (b *BaseLexerAction) equals(other LexerAction) bool { func (b *BaseLexerAction) Equals(other LexerAction) bool {
return b == other return b == other
} }
//
// Implements the {@code Skip} lexer action by calling {@link Lexer//Skip}. // Implements the {@code Skip} lexer action by calling {@link Lexer//Skip}.
// //
// <p>The {@code Skip} command does not have any parameters, so l action is // <p>The {@code Skip} command does not have any parameters, so l action is
@ -85,7 +84,8 @@ func (l *LexerSkipAction) String() string {
return "skip" return "skip"
} }
// Implements the {@code type} lexer action by calling {@link Lexer//setType} // Implements the {@code type} lexer action by calling {@link Lexer//setType}
//
// with the assigned type. // with the assigned type.
type LexerTypeAction struct { type LexerTypeAction struct {
*BaseLexerAction *BaseLexerAction
@ -104,14 +104,14 @@ func (l *LexerTypeAction) execute(lexer Lexer) {
lexer.SetType(l.thetype) lexer.SetType(l.thetype)
} }
func (l *LexerTypeAction) hash() int { func (l *LexerTypeAction) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.thetype) h = murmurUpdate(h, l.thetype)
return murmurFinish(h, 2) return murmurFinish(h, 2)
} }
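The Hash implementations above all follow the same pattern: seed, fold in each field, finish. An equivalent standalone illustration using the standard library's FNV hash in place of the runtime's murmur helpers:

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// illustrative field-combining hash, standing in for murmurInit/murmurUpdate/murmurFinish
func combine(fields ...int) int {
	h := fnv.New32a()
	buf := make([]byte, 4)
	for _, f := range fields {
		binary.LittleEndian.PutUint32(buf, uint32(f))
		h.Write(buf)
	}
	return int(h.Sum32())
}

func main() {
	fmt.Println(combine(2, 7) == combine(2, 7)) // true: same inputs give the same hash
	fmt.Println(combine(2, 7) == combine(7, 2)) // false: fields are folded in order
}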
func (l *LexerTypeAction) equals(other LexerAction) bool { func (l *LexerTypeAction) Equals(other LexerAction) bool {
if l == other { if l == other {
return true return true
} else if _, ok := other.(*LexerTypeAction); !ok { } else if _, ok := other.(*LexerTypeAction); !ok {
@ -148,14 +148,14 @@ func (l *LexerPushModeAction) execute(lexer Lexer) {
lexer.PushMode(l.mode) lexer.PushMode(l.mode)
} }
func (l *LexerPushModeAction) hash() int { func (l *LexerPushModeAction) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.mode) h = murmurUpdate(h, l.mode)
return murmurFinish(h, 2) return murmurFinish(h, 2)
} }
func (l *LexerPushModeAction) equals(other LexerAction) bool { func (l *LexerPushModeAction) Equals(other LexerAction) bool {
if l == other { if l == other {
return true return true
} else if _, ok := other.(*LexerPushModeAction); !ok { } else if _, ok := other.(*LexerPushModeAction); !ok {
@ -245,14 +245,14 @@ func (l *LexerModeAction) execute(lexer Lexer) {
lexer.SetMode(l.mode) lexer.SetMode(l.mode)
} }
func (l *LexerModeAction) hash() int { func (l *LexerModeAction) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.mode) h = murmurUpdate(h, l.mode)
return murmurFinish(h, 2) return murmurFinish(h, 2)
} }
func (l *LexerModeAction) equals(other LexerAction) bool { func (l *LexerModeAction) Equals(other LexerAction) bool {
if l == other { if l == other {
return true return true
} else if _, ok := other.(*LexerModeAction); !ok { } else if _, ok := other.(*LexerModeAction); !ok {
@ -303,7 +303,7 @@ func (l *LexerCustomAction) execute(lexer Lexer) {
lexer.Action(nil, l.ruleIndex, l.actionIndex) lexer.Action(nil, l.ruleIndex, l.actionIndex)
} }
func (l *LexerCustomAction) hash() int { func (l *LexerCustomAction) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.ruleIndex) h = murmurUpdate(h, l.ruleIndex)
@ -311,13 +311,14 @@ func (l *LexerCustomAction) hash() int {
return murmurFinish(h, 3) return murmurFinish(h, 3)
} }
func (l *LexerCustomAction) equals(other LexerAction) bool { func (l *LexerCustomAction) Equals(other LexerAction) bool {
if l == other { if l == other {
return true return true
} else if _, ok := other.(*LexerCustomAction); !ok { } else if _, ok := other.(*LexerCustomAction); !ok {
return false return false
} else { } else {
return l.ruleIndex == other.(*LexerCustomAction).ruleIndex && l.actionIndex == other.(*LexerCustomAction).actionIndex return l.ruleIndex == other.(*LexerCustomAction).ruleIndex &&
l.actionIndex == other.(*LexerCustomAction).actionIndex
} }
} }
@ -344,14 +345,14 @@ func (l *LexerChannelAction) execute(lexer Lexer) {
lexer.SetChannel(l.channel) lexer.SetChannel(l.channel)
} }
func (l *LexerChannelAction) hash() int { func (l *LexerChannelAction) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, l.actionType) h = murmurUpdate(h, l.actionType)
h = murmurUpdate(h, l.channel) h = murmurUpdate(h, l.channel)
return murmurFinish(h, 2) return murmurFinish(h, 2)
} }
func (l *LexerChannelAction) equals(other LexerAction) bool { func (l *LexerChannelAction) Equals(other LexerAction) bool {
if l == other { if l == other {
return true return true
} else if _, ok := other.(*LexerChannelAction); !ok { } else if _, ok := other.(*LexerChannelAction); !ok {
@ -412,10 +413,10 @@ func (l *LexerIndexedCustomAction) execute(lexer Lexer) {
l.lexerAction.execute(lexer) l.lexerAction.execute(lexer)
} }
func (l *LexerIndexedCustomAction) hash() int { func (l *LexerIndexedCustomAction) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, l.offset) h = murmurUpdate(h, l.offset)
h = murmurUpdate(h, l.lexerAction.hash()) h = murmurUpdate(h, l.lexerAction.Hash())
return murmurFinish(h, 2) return murmurFinish(h, 2)
} }
@ -425,6 +426,7 @@ func (l *LexerIndexedCustomAction) equals(other LexerAction) bool {
} else if _, ok := other.(*LexerIndexedCustomAction); !ok { } else if _, ok := other.(*LexerIndexedCustomAction); !ok {
return false return false
} else { } else {
return l.offset == other.(*LexerIndexedCustomAction).offset && l.lexerAction == other.(*LexerIndexedCustomAction).lexerAction return l.offset == other.(*LexerIndexedCustomAction).offset &&
l.lexerAction.Equals(other.(*LexerIndexedCustomAction).lexerAction)
} }
} }

View File

@ -1,9 +1,11 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
package antlr package antlr
import "golang.org/x/exp/slices"
// Represents an executor for a sequence of lexer actions which traversed during // Represents an executor for a sequence of lexer actions which traversed during
// the Matching operation of a lexer rule (token). // the Matching operation of a lexer rule (token).
// //
@ -12,8 +14,8 @@ package antlr
// not cause bloating of the {@link DFA} created for the lexer.</p> // not cause bloating of the {@link DFA} created for the lexer.</p>
type LexerActionExecutor struct { type LexerActionExecutor struct {
lexerActions []LexerAction lexerActions []LexerAction
cachedHash int cachedHash int
} }
func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor { func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor {
@ -30,7 +32,7 @@ func NewLexerActionExecutor(lexerActions []LexerAction) *LexerActionExecutor {
// of the performance-critical {@link LexerATNConfig//hashCode} operation. // of the performance-critical {@link LexerATNConfig//hashCode} operation.
l.cachedHash = murmurInit(57) l.cachedHash = murmurInit(57)
for _, a := range lexerActions { for _, a := range lexerActions {
l.cachedHash = murmurUpdate(l.cachedHash, a.hash()) l.cachedHash = murmurUpdate(l.cachedHash, a.Hash())
} }
return l return l
@ -151,14 +153,17 @@ func (l *LexerActionExecutor) execute(lexer Lexer, input CharStream, startIndex
} }
} }
func (l *LexerActionExecutor) hash() int { func (l *LexerActionExecutor) Hash() int {
if l == nil { if l == nil {
// TODO: Why is this here? l should not be nil
return 61 return 61
} }
// TODO: This is created from the action itself when the struct is created - will this be an issue at some point? Java uses the runtime assign hashcode
return l.cachedHash return l.cachedHash
} }
func (l *LexerActionExecutor) equals(other interface{}) bool { func (l *LexerActionExecutor) Equals(other interface{}) bool {
if l == other { if l == other {
return true return true
} }
@ -169,5 +174,13 @@ func (l *LexerActionExecutor) equals(other interface{}) bool {
if othert == nil { if othert == nil {
return false return false
} }
return l.cachedHash == othert.cachedHash && &l.lexerActions == &othert.lexerActions if l.cachedHash != othert.cachedHash {
return false
}
if len(l.lexerActions) != len(othert.lexerActions) {
return false
}
return slices.EqualFunc(l.lexerActions, othert.lexerActions, func(i, j LexerAction) bool {
return i.Equals(j)
})
} }
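The rewritten Equals above relies on slices.EqualFunc from golang.org/x/exp (imported earlier in this file) to compare the two action slices element-wise. A standalone illustration of that call with a made-up action type:

package main

import (
	"fmt"

	"golang.org/x/exp/slices"
)

type action struct{ kind, arg int }

func main() {
	a := []action{{1, 0}, {2, 5}}
	b := []action{{1, 0}, {2, 5}}
	// element-wise comparison with a custom predicate, as in LexerActionExecutor.Equals
	equal := slices.EqualFunc(a, b, func(x, y action) bool {
		return x.kind == y.kind && x.arg == y.arg
	})
	fmt.Println(equal) // true
}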

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -591,19 +591,24 @@ func (l *LexerATNSimulator) addDFAState(configs ATNConfigSet, suppressEdge bool)
proposed.lexerActionExecutor = firstConfigWithRuleStopState.(*LexerATNConfig).lexerActionExecutor proposed.lexerActionExecutor = firstConfigWithRuleStopState.(*LexerATNConfig).lexerActionExecutor
proposed.setPrediction(l.atn.ruleToTokenType[firstConfigWithRuleStopState.GetState().GetRuleIndex()]) proposed.setPrediction(l.atn.ruleToTokenType[firstConfigWithRuleStopState.GetState().GetRuleIndex()])
} }
hash := proposed.hash()
dfa := l.decisionToDFA[l.mode] dfa := l.decisionToDFA[l.mode]
l.atn.stateMu.Lock() l.atn.stateMu.Lock()
defer l.atn.stateMu.Unlock() defer l.atn.stateMu.Unlock()
existing, ok := dfa.getState(hash) existing, present := dfa.states.Get(proposed)
if ok { if present {
// This state was already present, so just return it.
//
proposed = existing proposed = existing
} else { } else {
proposed.stateNumber = dfa.numStates()
// We need to add the new state
//
proposed.stateNumber = dfa.states.Len()
configs.SetReadOnly(true) configs.SetReadOnly(true)
proposed.configs = configs proposed.configs = configs
dfa.setState(hash, proposed) dfa.states.Put(proposed)
} }
if !suppressEdge { if !suppressEdge {
dfa.setS0(proposed) dfa.setS0(proposed)
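The updated addDFAState amounts to a get-or-add keyed by the state's configuration set: reuse an existing equivalent state, otherwise number and store the new one. A toy version of that flow, with a plain map standing in for dfa.states:

package main

import "fmt"

type dfaState struct {
	key    string // stands in for the hashed/equated configuration set
	number int
}

func addState(states map[string]*dfaState, proposed *dfaState) *dfaState {
	if existing, ok := states[proposed.key]; ok {
		return existing // an equivalent state already exists, reuse it
	}
	proposed.number = len(states) // number new states in insertion order
	states[proposed.key] = proposed
	return proposed
}

func main() {
	states := map[string]*dfaState{}
	a := addState(states, &dfaState{key: "configsA"})
	b := addState(states, &dfaState{key: "configsA"})
	fmt.Println(a == b, a.number) // true 0
}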

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -14,14 +14,15 @@ func NewLL1Analyzer(atn *ATN) *LL1Analyzer {
return la return la
} }
//* Special value added to the lookahead sets to indicate that we hit // - Special value added to the lookahead sets to indicate that we hit
// a predicate during analysis if {@code seeThruPreds==false}. // a predicate during analysis if {@code seeThruPreds==false}.
/// //
// /
const ( const (
LL1AnalyzerHitPred = TokenInvalidType LL1AnalyzerHitPred = TokenInvalidType
) )
//* // *
// Calculates the SLL(1) expected lookahead set for each outgoing transition // Calculates the SLL(1) expected lookahead set for each outgoing transition
// of an {@link ATNState}. The returned array has one element for each // of an {@link ATNState}. The returned array has one element for each
// outgoing transition in {@code s}. If the closure from transition // outgoing transition in {@code s}. If the closure from transition
@ -38,7 +39,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
look := make([]*IntervalSet, count) look := make([]*IntervalSet, count)
for alt := 0; alt < count; alt++ { for alt := 0; alt < count; alt++ {
look[alt] = NewIntervalSet() look[alt] = NewIntervalSet()
lookBusy := newArray2DHashSet(nil, nil) lookBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
seeThruPreds := false // fail to get lookahead upon pred seeThruPreds := false // fail to get lookahead upon pred
la.look1(s.GetTransitions()[alt].getTarget(), nil, BasePredictionContextEMPTY, look[alt], lookBusy, NewBitSet(), seeThruPreds, false) la.look1(s.GetTransitions()[alt].getTarget(), nil, BasePredictionContextEMPTY, look[alt], lookBusy, NewBitSet(), seeThruPreds, false)
// Wipe out lookahead for la alternative if we found nothing // Wipe out lookahead for la alternative if we found nothing
@ -50,7 +51,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
return look return look
} }
//* // *
// Compute set of tokens that can follow {@code s} in the ATN in the // Compute set of tokens that can follow {@code s} in the ATN in the
// specified {@code ctx}. // specified {@code ctx}.
// //
@ -67,7 +68,7 @@ func (la *LL1Analyzer) getDecisionLookahead(s ATNState) []*IntervalSet {
// //
// @return The set of tokens that can follow {@code s} in the ATN in the // @return The set of tokens that can follow {@code s} in the ATN in the
// specified {@code ctx}. // specified {@code ctx}.
/// // /
func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet { func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet {
r := NewIntervalSet() r := NewIntervalSet()
seeThruPreds := true // ignore preds get all lookahead seeThruPreds := true // ignore preds get all lookahead
@ -75,7 +76,7 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet
if ctx != nil { if ctx != nil {
lookContext = predictionContextFromRuleContext(s.GetATN(), ctx) lookContext = predictionContextFromRuleContext(s.GetATN(), ctx)
} }
la.look1(s, stopState, lookContext, r, newArray2DHashSet(nil, nil), NewBitSet(), seeThruPreds, true) la.look1(s, stopState, lookContext, r, NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst), NewBitSet(), seeThruPreds, true)
return r return r
} }
@ -109,14 +110,14 @@ func (la *LL1Analyzer) Look(s, stopState ATNState, ctx RuleContext) *IntervalSet
// outermost context is reached. This parameter has no effect if {@code ctx} // outermost context is reached. This parameter has no effect if {@code ctx}
// is {@code nil}. // is {@code nil}.
func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) { func (la *LL1Analyzer) look2(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, i int) {
returnState := la.atn.states[ctx.getReturnState(i)] returnState := la.atn.states[ctx.getReturnState(i)]
la.look1(returnState, stopState, ctx.GetParent(i), look, lookBusy, calledRuleStack, seeThruPreds, addEOF) la.look1(returnState, stopState, ctx.GetParent(i), look, lookBusy, calledRuleStack, seeThruPreds, addEOF)
} }
func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool) { func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool) {
c := NewBaseATNConfig6(s, 0, ctx) c := NewBaseATNConfig6(s, 0, ctx)
@ -124,8 +125,11 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look
return return
} }
lookBusy.Add(c) _, present := lookBusy.Put(c)
if present {
return
}
if s == stopState { if s == stopState {
if ctx == nil { if ctx == nil {
look.addOne(TokenEpsilon) look.addOne(TokenEpsilon)
@ -198,7 +202,7 @@ func (la *LL1Analyzer) look1(s, stopState ATNState, ctx PredictionContext, look
} }
} }
func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy Set, calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) { func (la *LL1Analyzer) look3(stopState ATNState, ctx PredictionContext, look *IntervalSet, lookBusy *JStore[ATNConfig, Comparator[ATNConfig]], calledRuleStack *BitSet, seeThruPreds, addEOF bool, t1 *RuleTransition) {
newContext := SingletonBasePredictionContextCreate(ctx, t1.followState.GetStateNumber()) newContext := SingletonBasePredictionContextCreate(ctx, t1.followState.GetStateNumber())

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -91,7 +91,6 @@ func NewBaseParser(input TokenStream) *BaseParser {
// bypass alternatives. // bypass alternatives.
// //
// @see ATNDeserializationOptions//isGenerateRuleBypassTransitions() // @see ATNDeserializationOptions//isGenerateRuleBypassTransitions()
//
var bypassAltsAtnCache = make(map[string]int) var bypassAltsAtnCache = make(map[string]int)
// reset the parser's state// // reset the parser's state//
@ -230,7 +229,6 @@ func (p *BaseParser) GetParseListeners() []ParseTreeListener {
// @param listener the listener to add // @param listener the listener to add
// //
// @panics nilPointerException if {@code} listener is {@code nil} // @panics nilPointerException if {@code} listener is {@code nil}
//
func (p *BaseParser) AddParseListener(listener ParseTreeListener) { func (p *BaseParser) AddParseListener(listener ParseTreeListener) {
if listener == nil { if listener == nil {
panic("listener") panic("listener")
@ -241,13 +239,11 @@ func (p *BaseParser) AddParseListener(listener ParseTreeListener) {
p.parseListeners = append(p.parseListeners, listener) p.parseListeners = append(p.parseListeners, listener)
} }
//
// Remove {@code listener} from the list of parse listeners. // Remove {@code listener} from the list of parse listeners.
// //
// <p>If {@code listener} is {@code nil} or has not been added as a parse // <p>If {@code listener} is {@code nil} or has not been added as a parse
// listener, p.method does nothing.</p> // listener, p.method does nothing.</p>
// @param listener the listener to remove // @param listener the listener to remove
//
func (p *BaseParser) RemoveParseListener(listener ParseTreeListener) { func (p *BaseParser) RemoveParseListener(listener ParseTreeListener) {
if p.parseListeners != nil { if p.parseListeners != nil {
@ -289,11 +285,9 @@ func (p *BaseParser) TriggerEnterRuleEvent() {
} }
} }
//
// Notify any parse listeners of an exit rule event. // Notify any parse listeners of an exit rule event.
// //
// @see //addParseListener // @see //addParseListener
//
func (p *BaseParser) TriggerExitRuleEvent() { func (p *BaseParser) TriggerExitRuleEvent() {
if p.parseListeners != nil { if p.parseListeners != nil {
// reverse order walk of listeners // reverse order walk of listeners
@ -330,7 +324,6 @@ func (p *BaseParser) setTokenFactory(factory TokenFactory) {
// //
// @panics UnsupportedOperationException if the current parser does not // @panics UnsupportedOperationException if the current parser does not
// implement the {@link //getSerializedATN()} method. // implement the {@link //getSerializedATN()} method.
//
func (p *BaseParser) GetATNWithBypassAlts() { func (p *BaseParser) GetATNWithBypassAlts() {
// TODO // TODO
@ -402,7 +395,6 @@ func (p *BaseParser) SetTokenStream(input TokenStream) {
// Match needs to return the current input symbol, which gets put // Match needs to return the current input symbol, which gets put
// into the label for the associated token ref e.g., x=ID. // into the label for the associated token ref e.g., x=ID.
//
func (p *BaseParser) GetCurrentToken() Token { func (p *BaseParser) GetCurrentToken() Token {
return p.input.LT(1) return p.input.LT(1)
} }
@ -624,7 +616,6 @@ func (p *BaseParser) IsExpectedToken(symbol int) bool {
// respectively. // respectively.
// //
// @see ATN//getExpectedTokens(int, RuleContext) // @see ATN//getExpectedTokens(int, RuleContext)
//
func (p *BaseParser) GetExpectedTokens() *IntervalSet { func (p *BaseParser) GetExpectedTokens() *IntervalSet {
return p.Interpreter.atn.getExpectedTokens(p.state, p.ctx) return p.Interpreter.atn.getExpectedTokens(p.state, p.ctx)
} }
@ -686,7 +677,7 @@ func (p *BaseParser) GetDFAStrings() string {
func (p *BaseParser) DumpDFA() { func (p *BaseParser) DumpDFA() {
seenOne := false seenOne := false
for _, dfa := range p.Interpreter.decisionToDFA { for _, dfa := range p.Interpreter.decisionToDFA {
if dfa.numStates() > 0 { if dfa.states.Len() > 0 {
if seenOne { if seenOne {
fmt.Println() fmt.Println()
} }
@ -703,7 +694,6 @@ func (p *BaseParser) GetSourceName() string {
// During a parse is sometimes useful to listen in on the rule entry and exit // During a parse is sometimes useful to listen in on the rule entry and exit
// events as well as token Matches. p.is for quick and dirty debugging. // events as well as token Matches. p.is for quick and dirty debugging.
//
func (p *BaseParser) SetTrace(trace *TraceListener) { func (p *BaseParser) SetTrace(trace *TraceListener) {
if trace == nil { if trace == nil {
p.RemoveParseListener(p.tracer) p.RemoveParseListener(p.tracer)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -11,11 +11,11 @@ import (
) )
var ( var (
ParserATNSimulatorDebug = false ParserATNSimulatorDebug = false
ParserATNSimulatorListATNDecisions = false ParserATNSimulatorTraceATNSim = false
ParserATNSimulatorDFADebug = false ParserATNSimulatorDFADebug = false
ParserATNSimulatorRetryDebug = false ParserATNSimulatorRetryDebug = false
TurnOffLRLoopEntryBranchOpt = false TurnOffLRLoopEntryBranchOpt = false
) )
type ParserATNSimulator struct { type ParserATNSimulator struct {
@ -70,8 +70,8 @@ func (p *ParserATNSimulator) reset() {
} }
func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, outerContext ParserRuleContext) int { func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, outerContext ParserRuleContext) int {
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions { if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("AdaptivePredict decision " + strconv.Itoa(decision) + fmt.Println("adaptivePredict decision " + strconv.Itoa(decision) +
" exec LA(1)==" + p.getLookaheadName(input) + " exec LA(1)==" + p.getLookaheadName(input) +
" line " + strconv.Itoa(input.LT(1).GetLine()) + ":" + " line " + strconv.Itoa(input.LT(1).GetLine()) + ":" +
strconv.Itoa(input.LT(1).GetColumn())) strconv.Itoa(input.LT(1).GetColumn()))
@ -111,15 +111,15 @@ func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, ou
if s0 == nil { if s0 == nil {
if outerContext == nil { if outerContext == nil {
outerContext = RuleContextEmpty outerContext = ParserRuleContextEmpty
} }
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions { if ParserATNSimulatorDebug {
fmt.Println("predictATN decision " + strconv.Itoa(dfa.decision) + fmt.Println("predictATN decision " + strconv.Itoa(dfa.decision) +
" exec LA(1)==" + p.getLookaheadName(input) + " exec LA(1)==" + p.getLookaheadName(input) +
", outerContext=" + outerContext.String(p.parser.GetRuleNames(), nil)) ", outerContext=" + outerContext.String(p.parser.GetRuleNames(), nil))
} }
fullCtx := false fullCtx := false
s0Closure := p.computeStartState(dfa.atnStartState, RuleContextEmpty, fullCtx) s0Closure := p.computeStartState(dfa.atnStartState, ParserRuleContextEmpty, fullCtx)
p.atn.stateMu.Lock() p.atn.stateMu.Lock()
if dfa.getPrecedenceDfa() { if dfa.getPrecedenceDfa() {
@ -174,17 +174,18 @@ func (p *ParserATNSimulator) AdaptivePredict(input TokenStream, decision int, ou
// Reporting insufficient predicates // Reporting insufficient predicates
// cover these cases: // cover these cases:
// dead end
// single alt
// single alt + preds
// conflict
// conflict + preds
// //
// dead end
// single alt
// single alt + preds
// conflict
// conflict + preds
func (p *ParserATNSimulator) execATN(dfa *DFA, s0 *DFAState, input TokenStream, startIndex int, outerContext ParserRuleContext) int { func (p *ParserATNSimulator) execATN(dfa *DFA, s0 *DFAState, input TokenStream, startIndex int, outerContext ParserRuleContext) int {
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions { if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("execATN decision " + strconv.Itoa(dfa.decision) + fmt.Println("execATN decision " + strconv.Itoa(dfa.decision) +
" exec LA(1)==" + p.getLookaheadName(input) + ", DFA state " + s0.String() +
", LA(1)==" + p.getLookaheadName(input) +
" line " + strconv.Itoa(input.LT(1).GetLine()) + ":" + strconv.Itoa(input.LT(1).GetColumn())) " line " + strconv.Itoa(input.LT(1).GetLine()) + ":" + strconv.Itoa(input.LT(1).GetColumn()))
} }
@ -277,8 +278,6 @@ func (p *ParserATNSimulator) execATN(dfa *DFA, s0 *DFAState, input TokenStream,
t = input.LA(1) t = input.LA(1)
} }
} }
panic("Should not have reached p state")
} }
// Get an existing target state for an edge in the DFA. If the target state // Get an existing target state for an edge in the DFA. If the target state
@ -384,7 +383,7 @@ func (p *ParserATNSimulator) predicateDFAState(dfaState *DFAState, decisionState
// comes back with reach.uniqueAlt set to a valid alt // comes back with reach.uniqueAlt set to a valid alt
func (p *ParserATNSimulator) execATNWithFullContext(dfa *DFA, D *DFAState, s0 ATNConfigSet, input TokenStream, startIndex int, outerContext ParserRuleContext) int { func (p *ParserATNSimulator) execATNWithFullContext(dfa *DFA, D *DFAState, s0 ATNConfigSet, input TokenStream, startIndex int, outerContext ParserRuleContext) int {
if ParserATNSimulatorDebug || ParserATNSimulatorListATNDecisions { if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("execATNWithFullContext " + s0.String()) fmt.Println("execATNWithFullContext " + s0.String())
} }
@ -492,9 +491,6 @@ func (p *ParserATNSimulator) execATNWithFullContext(dfa *DFA, D *DFAState, s0 AT
} }
func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCtx bool) ATNConfigSet { func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCtx bool) ATNConfigSet {
if ParserATNSimulatorDebug {
fmt.Println("in computeReachSet, starting closure: " + closure.String())
}
if p.mergeCache == nil { if p.mergeCache == nil {
p.mergeCache = NewDoubleDict() p.mergeCache = NewDoubleDict()
} }
@ -570,7 +566,7 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
// //
if reach == nil { if reach == nil {
reach = NewBaseATNConfigSet(fullCtx) reach = NewBaseATNConfigSet(fullCtx)
closureBusy := newArray2DHashSet(nil, nil) closureBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](aConfEqInst)
treatEOFAsEpsilon := t == TokenEOF treatEOFAsEpsilon := t == TokenEOF
amount := len(intermediate.configs) amount := len(intermediate.configs)
for k := 0; k < amount; k++ { for k := 0; k < amount; k++ {
@ -610,6 +606,11 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
reach.Add(skippedStopStates[l], p.mergeCache) reach.Add(skippedStopStates[l], p.mergeCache)
} }
} }
if ParserATNSimulatorTraceATNSim {
fmt.Println("computeReachSet " + closure.String() + " -> " + reach.String())
}
if len(reach.GetItems()) == 0 { if len(reach.GetItems()) == 0 {
return nil return nil
} }
@ -617,7 +618,6 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
return reach return reach
} }
//
// Return a configuration set containing only the configurations from // Return a configuration set containing only the configurations from
// {@code configs} which are in a {@link RuleStopState}. If all // {@code configs} which are in a {@link RuleStopState}. If all
// configurations in {@code configs} are already in a rule stop state, p // configurations in {@code configs} are already in a rule stop state, p
@ -636,7 +636,6 @@ func (p *ParserATNSimulator) computeReachSet(closure ATNConfigSet, t int, fullCt
// @return {@code configs} if all configurations in {@code configs} are in a // @return {@code configs} if all configurations in {@code configs} are in a
// rule stop state, otherwise return a Newconfiguration set containing only // rule stop state, otherwise return a Newconfiguration set containing only
// the configurations from {@code configs} which are in a rule stop state // the configurations from {@code configs} which are in a rule stop state
//
func (p *ParserATNSimulator) removeAllConfigsNotInRuleStopState(configs ATNConfigSet, lookToEndOfRule bool) ATNConfigSet { func (p *ParserATNSimulator) removeAllConfigsNotInRuleStopState(configs ATNConfigSet, lookToEndOfRule bool) ATNConfigSet {
if PredictionModeallConfigsInRuleStopStates(configs) { if PredictionModeallConfigsInRuleStopStates(configs) {
return configs return configs
@ -662,16 +661,20 @@ func (p *ParserATNSimulator) computeStartState(a ATNState, ctx RuleContext, full
// always at least the implicit call to start rule // always at least the implicit call to start rule
initialContext := predictionContextFromRuleContext(p.atn, ctx) initialContext := predictionContextFromRuleContext(p.atn, ctx)
configs := NewBaseATNConfigSet(fullCtx) configs := NewBaseATNConfigSet(fullCtx)
if ParserATNSimulatorDebug || ParserATNSimulatorTraceATNSim {
fmt.Println("computeStartState from ATN state " + a.String() +
" initialContext=" + initialContext.String())
}
for i := 0; i < len(a.GetTransitions()); i++ { for i := 0; i < len(a.GetTransitions()); i++ {
target := a.GetTransitions()[i].getTarget() target := a.GetTransitions()[i].getTarget()
c := NewBaseATNConfig6(target, i+1, initialContext) c := NewBaseATNConfig6(target, i+1, initialContext)
closureBusy := newArray2DHashSet(nil, nil) closureBusy := NewJStore[ATNConfig, Comparator[ATNConfig]](atnConfCompInst)
p.closure(c, configs, closureBusy, true, fullCtx, false) p.closure(c, configs, closureBusy, true, fullCtx, false)
} }
return configs return configs
} }
//
// This method transforms the start state computed by // This method transforms the start state computed by
// {@link //computeStartState} to the special start state used by a // {@link //computeStartState} to the special start state used by a
// precedence DFA for a particular precedence value. The transformation // precedence DFA for a particular precedence value. The transformation
@ -726,7 +729,6 @@ func (p *ParserATNSimulator) computeStartState(a ATNState, ctx RuleContext, full
// @return The transformed configuration set representing the start state // @return The transformed configuration set representing the start state
// for a precedence DFA at a particular precedence level (determined by // for a precedence DFA at a particular precedence level (determined by
// calling {@link Parser//getPrecedence}). // calling {@link Parser//getPrecedence}).
//
func (p *ParserATNSimulator) applyPrecedenceFilter(configs ATNConfigSet) ATNConfigSet { func (p *ParserATNSimulator) applyPrecedenceFilter(configs ATNConfigSet) ATNConfigSet {
statesFromAlt1 := make(map[int]PredictionContext) statesFromAlt1 := make(map[int]PredictionContext)
@ -760,7 +762,7 @@ func (p *ParserATNSimulator) applyPrecedenceFilter(configs ATNConfigSet) ATNConf
// (basically a graph subtraction algorithm). // (basically a graph subtraction algorithm).
if !config.getPrecedenceFilterSuppressed() { if !config.getPrecedenceFilterSuppressed() {
context := statesFromAlt1[config.GetState().GetStateNumber()] context := statesFromAlt1[config.GetState().GetStateNumber()]
if context != nil && context.equals(config.GetContext()) { if context != nil && context.Equals(config.GetContext()) {
// eliminated // eliminated
continue continue
} }
@ -824,7 +826,6 @@ func (p *ParserATNSimulator) getPredicatePredictions(ambigAlts *BitSet, altToPre
return pairs return pairs
} }
//
// This method is used to improve the localization of error messages by // This method is used to improve the localization of error messages by
// choosing an alternative rather than panicing a // choosing an alternative rather than panicing a
// {@link NoViableAltException} in particular prediction scenarios where the // {@link NoViableAltException} in particular prediction scenarios where the
@ -869,7 +870,6 @@ func (p *ParserATNSimulator) getPredicatePredictions(ambigAlts *BitSet, altToPre
// @return The value to return from {@link //AdaptivePredict}, or // @return The value to return from {@link //AdaptivePredict}, or
// {@link ATN//INVALID_ALT_NUMBER} if a suitable alternative was not // {@link ATN//INVALID_ALT_NUMBER} if a suitable alternative was not
// identified and {@link //AdaptivePredict} should Report an error instead. // identified and {@link //AdaptivePredict} should Report an error instead.
//
func (p *ParserATNSimulator) getSynValidOrSemInvalidAltThatFinishedDecisionEntryRule(configs ATNConfigSet, outerContext ParserRuleContext) int { func (p *ParserATNSimulator) getSynValidOrSemInvalidAltThatFinishedDecisionEntryRule(configs ATNConfigSet, outerContext ParserRuleContext) int {
cfgs := p.splitAccordingToSemanticValidity(configs, outerContext) cfgs := p.splitAccordingToSemanticValidity(configs, outerContext)
semValidConfigs := cfgs[0] semValidConfigs := cfgs[0]
@ -938,11 +938,11 @@ func (p *ParserATNSimulator) splitAccordingToSemanticValidity(configs ATNConfigS
} }
// Look through a list of predicate/alt pairs, returning alts for the // Look through a list of predicate/alt pairs, returning alts for the
// pairs that win. A {@code NONE} predicate indicates an alt containing an
// unpredicated config which behaves as "always true." If !complete
// then we stop at the first predicate that evaluates to true. This
// includes pairs with nil predicates.
// //
// pairs that win. A {@code NONE} predicate indicates an alt containing an
// unpredicated config which behaves as "always true." If !complete
// then we stop at the first predicate that evaluates to true. This
// includes pairs with nil predicates.
func (p *ParserATNSimulator) evalSemanticContext(predPredictions []*PredPrediction, outerContext ParserRuleContext, complete bool) *BitSet { func (p *ParserATNSimulator) evalSemanticContext(predPredictions []*PredPrediction, outerContext ParserRuleContext, complete bool) *BitSet {
predictions := NewBitSet() predictions := NewBitSet()
for i := 0; i < len(predPredictions); i++ { for i := 0; i < len(predPredictions); i++ {
@ -972,16 +972,16 @@ func (p *ParserATNSimulator) evalSemanticContext(predPredictions []*PredPredicti
return predictions return predictions
} }
func (p *ParserATNSimulator) closure(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx, treatEOFAsEpsilon bool) { func (p *ParserATNSimulator) closure(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx, treatEOFAsEpsilon bool) {
initialDepth := 0 initialDepth := 0
p.closureCheckingStopState(config, configs, closureBusy, collectPredicates, p.closureCheckingStopState(config, configs, closureBusy, collectPredicates,
fullCtx, initialDepth, treatEOFAsEpsilon) fullCtx, initialDepth, treatEOFAsEpsilon)
} }
func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) { func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) {
if ParserATNSimulatorDebug { if ParserATNSimulatorTraceATNSim {
fmt.Println("closure(" + config.String() + ")") fmt.Println("closure(" + config.String() + ")")
fmt.Println("configs(" + configs.String() + ")") //fmt.Println("configs(" + configs.String() + ")")
if config.GetReachesIntoOuterContext() > 50 { if config.GetReachesIntoOuterContext() > 50 {
panic("problem") panic("problem")
} }
@ -1031,7 +1031,7 @@ func (p *ParserATNSimulator) closureCheckingStopState(config ATNConfig, configs
} }
// Do the actual work of walking epsilon edges// // Do the actual work of walking epsilon edges//
func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, closureBusy Set, collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) { func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet, closureBusy *JStore[ATNConfig, Comparator[ATNConfig]], collectPredicates, fullCtx bool, depth int, treatEOFAsEpsilon bool) {
state := config.GetState() state := config.GetState()
// optimization // optimization
if !state.GetEpsilonOnlyTransitions() { if !state.GetEpsilonOnlyTransitions() {
@ -1066,7 +1066,8 @@ func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet,
c.SetReachesIntoOuterContext(c.GetReachesIntoOuterContext() + 1) c.SetReachesIntoOuterContext(c.GetReachesIntoOuterContext() + 1)
if closureBusy.Add(c) != c { _, present := closureBusy.Put(c)
if present {
// avoid infinite recursion for right-recursive rules // avoid infinite recursion for right-recursive rules
continue continue
} }
@ -1077,9 +1078,13 @@ func (p *ParserATNSimulator) closureWork(config ATNConfig, configs ATNConfigSet,
fmt.Println("dips into outer ctx: " + c.String()) fmt.Println("dips into outer ctx: " + c.String())
} }
} else { } else {
if !t.getIsEpsilon() && closureBusy.Add(c) != c {
// avoid infinite recursion for EOF* and EOF+ if !t.getIsEpsilon() {
continue _, present := closureBusy.Put(c)
if present {
// avoid infinite recursion for EOF* and EOF+
continue
}
} }
if _, ok := t.(*RuleTransition); ok { if _, ok := t.(*RuleTransition); ok {
// latch when newDepth goes negative - once we step out of the entry context we can't return // latch when newDepth goes negative - once we step out of the entry context we can't return
@ -1104,7 +1109,16 @@ func (p *ParserATNSimulator) canDropLoopEntryEdgeInLeftRecursiveRule(config ATNC
// left-recursion elimination. For efficiency, also check if // left-recursion elimination. For efficiency, also check if
// the context has an empty stack case. If so, it would mean // the context has an empty stack case. If so, it would mean
// global FOLLOW so we can't perform optimization // global FOLLOW so we can't perform optimization
if startLoop, ok := _p.(StarLoopEntryState); !ok || !startLoop.precedenceRuleDecision || config.GetContext().isEmpty() || config.GetContext().hasEmptyPath() { if _p.GetStateType() != ATNStateStarLoopEntry {
return false
}
startLoop, ok := _p.(*StarLoopEntryState)
if !ok {
return false
}
if !startLoop.precedenceRuleDecision ||
config.GetContext().isEmpty() ||
config.GetContext().hasEmptyPath() {
return false return false
} }
@ -1117,8 +1131,8 @@ func (p *ParserATNSimulator) canDropLoopEntryEdgeInLeftRecursiveRule(config ATNC
return false return false
} }
} }
x := _p.GetTransitions()[0].getTarget()
decisionStartState := _p.(BlockStartState).GetTransitions()[0].getTarget().(BlockStartState) decisionStartState := x.(BlockStartState)
blockEndStateNum := decisionStartState.getEndState().stateNumber blockEndStateNum := decisionStartState.getEndState().stateNumber
blockEndState := p.atn.states[blockEndStateNum].(*BlockEndState) blockEndState := p.atn.states[blockEndStateNum].(*BlockEndState)
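The rewritten guard above filters on the state-type constant before asserting to the concrete pointer type, so a value-versus-pointer mismatch in the assertion can no longer make the whole check fail silently. A small sketch of that two-step idiom with made-up types (node, loopEntry, and kindLoopEntry are illustrative, not the runtime's ATN state types):

package main

import "fmt"

// node stands in for an ATN state held behind an interface.
type node interface{ kind() int }

const kindLoopEntry = 10

// loopEntry is the concrete pointer type the caller actually needs.
type loopEntry struct{ precedenceDecision bool }

func (*loopEntry) kind() int { return kindLoopEntry }

type plain struct{}

func (*plain) kind() int { return 0 }

// canOptimize first checks the discriminator, then asserts to the
// pointer type; either failure bails out explicitly.
func canOptimize(n node) bool {
	if n.kind() != kindLoopEntry {
		return false
	}
	le, ok := n.(*loopEntry)
	if !ok {
		return false
	}
	return le.precedenceDecision
}

func main() {
	fmt.Println(canOptimize(&plain{}))                             // false
	fmt.Println(canOptimize(&loopEntry{precedenceDecision: true})) // true
}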
@ -1355,13 +1369,12 @@ func (p *ParserATNSimulator) GetTokenName(t int) string {
return "EOF" return "EOF"
} }
if p.parser != nil && p.parser.GetLiteralNames() != nil { if p.parser != nil && p.parser.GetLiteralNames() != nil && t < len(p.parser.GetLiteralNames()) {
if t >= len(p.parser.GetLiteralNames()) { return p.parser.GetLiteralNames()[t] + "<" + strconv.Itoa(t) + ">"
fmt.Println(strconv.Itoa(t) + " ttype out of range: " + strings.Join(p.parser.GetLiteralNames(), ",")) }
// fmt.Println(p.parser.GetInputStream().(TokenStream).GetAllText()) // p seems incorrect
} else { if p.parser != nil && p.parser.GetLiteralNames() != nil && t < len(p.parser.GetSymbolicNames()) {
return p.parser.GetLiteralNames()[t] + "<" + strconv.Itoa(t) + ">" return p.parser.GetSymbolicNames()[t] + "<" + strconv.Itoa(t) + ">"
}
} }
return strconv.Itoa(t) return strconv.Itoa(t)
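GetTokenName now guards each slice access with a length check and falls back from literal names to symbolic names to the raw token type. A short sketch of that fallback chain using hypothetical name tables rather than a real parser; the empty-string check is an extra illustration, not part of the change above:

package main

import (
	"fmt"
	"strconv"
)

// tokenName returns a literal name when the index is in range, otherwise
// a symbolic name when that is in range, otherwise just the number,
// never indexing past either slice.
func tokenName(t int, literals, symbolics []string) string {
	if t >= 0 && t < len(literals) && literals[t] != "" {
		return literals[t] + "<" + strconv.Itoa(t) + ">"
	}
	if t >= 0 && t < len(symbolics) && symbolics[t] != "" {
		return symbolics[t] + "<" + strconv.Itoa(t) + ">"
	}
	return strconv.Itoa(t)
}

func main() {
	literals := []string{"", "'('", "')'"}
	symbolics := []string{"", "LPAREN", "RPAREN", "ID"}
	fmt.Println(tokenName(1, literals, symbolics)) // '('<1>
	fmt.Println(tokenName(3, literals, symbolics)) // ID<3>
	fmt.Println(tokenName(9, literals, symbolics)) // 9
}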
@ -1372,9 +1385,9 @@ func (p *ParserATNSimulator) getLookaheadName(input TokenStream) string {
} }
// Used for debugging in AdaptivePredict around execATN but I cut // Used for debugging in AdaptivePredict around execATN but I cut
// it out for clarity now that alg. works well. We can leave p
// "dead" code for a bit.
// //
// it out for clarity now that alg. works well. We can leave p
// "dead" code for a bit.
func (p *ParserATNSimulator) dumpDeadEndConfigs(nvae *NoViableAltException) { func (p *ParserATNSimulator) dumpDeadEndConfigs(nvae *NoViableAltException) {
panic("Not implemented") panic("Not implemented")
@ -1421,7 +1434,6 @@ func (p *ParserATNSimulator) getUniqueAlt(configs ATNConfigSet) int {
return alt return alt
} }
//
// Add an edge to the DFA, if possible. This method calls // Add an edge to the DFA, if possible. This method calls
// {@link //addDFAState} to ensure the {@code to} state is present in the // {@link //addDFAState} to ensure the {@code to} state is present in the
// DFA. If {@code from} is {@code nil}, or if {@code t} is outside the // DFA. If {@code from} is {@code nil}, or if {@code t} is outside the
@ -1440,7 +1452,6 @@ func (p *ParserATNSimulator) getUniqueAlt(configs ATNConfigSet) int {
// @return If {@code to} is {@code nil}, p method returns {@code nil} // @return If {@code to} is {@code nil}, p method returns {@code nil}
// otherwise p method returns the result of calling {@link //addDFAState} // otherwise p method returns the result of calling {@link //addDFAState}
// on {@code to} // on {@code to}
//
func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFAState) *DFAState { func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFAState) *DFAState {
if ParserATNSimulatorDebug { if ParserATNSimulatorDebug {
fmt.Println("EDGE " + from.String() + " -> " + to.String() + " upon " + p.GetTokenName(t)) fmt.Println("EDGE " + from.String() + " -> " + to.String() + " upon " + p.GetTokenName(t))
@ -1472,7 +1483,6 @@ func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFA
return to return to
} }
//
// Add state {@code D} to the DFA if it is not already present, and return // Add state {@code D} to the DFA if it is not already present, and return
// the actual instance stored in the DFA. If a state equivalent to {@code D} // the actual instance stored in the DFA. If a state equivalent to {@code D}
// is already in the DFA, the existing state is returned. Otherwise p // is already in the DFA, the existing state is returned. Otherwise p
@ -1486,25 +1496,30 @@ func (p *ParserATNSimulator) addDFAEdge(dfa *DFA, from *DFAState, t int, to *DFA
// @return The state stored in the DFA. This will be either the existing // @return The state stored in the DFA. This will be either the existing
// state if {@code D} is already in the DFA, or {@code D} itself if the // state if {@code D} is already in the DFA, or {@code D} itself if the
// state was not already present. // state was not already present.
//
func (p *ParserATNSimulator) addDFAState(dfa *DFA, d *DFAState) *DFAState { func (p *ParserATNSimulator) addDFAState(dfa *DFA, d *DFAState) *DFAState {
if d == ATNSimulatorError { if d == ATNSimulatorError {
return d return d
} }
hash := d.hash() existing, present := dfa.states.Get(d)
existing, ok := dfa.getState(hash) if present {
if ok { if ParserATNSimulatorTraceATNSim {
fmt.Print("addDFAState " + d.String() + " exists")
}
return existing return existing
} }
d.stateNumber = dfa.numStates()
// The state was not present, so update it with configs
//
d.stateNumber = dfa.states.Len()
if !d.configs.ReadOnly() { if !d.configs.ReadOnly() {
d.configs.OptimizeConfigs(p.BaseATNSimulator) d.configs.OptimizeConfigs(p.BaseATNSimulator)
d.configs.SetReadOnly(true) d.configs.SetReadOnly(true)
} }
dfa.setState(hash, d) dfa.states.Put(d)
if ParserATNSimulatorDebug { if ParserATNSimulatorTraceATNSim {
fmt.Println("adding NewDFA state: " + d.String()) fmt.Println("addDFAState new " + d.String())
} }
return d return d
} }
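addDFAState above now interns states through the equality-keyed store: look the candidate up, reuse the stored instance if an equal one exists, otherwise number it from the store's length and insert it. A compact sketch of that intern pattern; the intern type and dfaState struct are illustrative, not the runtime's DFA or JStore:

package main

import "fmt"

// dfaState is a toy state whose identity is its configs fingerprint.
type dfaState struct {
	stateNumber int
	fingerprint string
}

// intern keeps one canonical *dfaState per fingerprint and hands back
// the stored instance whenever an equal state is added again.
type intern struct {
	byFingerprint map[string]*dfaState
}

func newIntern() *intern { return &intern{byFingerprint: map[string]*dfaState{}} }

func (in *intern) add(d *dfaState) *dfaState {
	if existing, ok := in.byFingerprint[d.fingerprint]; ok {
		return existing // equal state already present: reuse it
	}
	d.stateNumber = len(in.byFingerprint) // number new states densely
	in.byFingerprint[d.fingerprint] = d
	return d
}

func main() {
	states := newIntern()
	a := states.add(&dfaState{fingerprint: "(s,1,[$])"})
	b := states.add(&dfaState{fingerprint: "(s,2,[$])"})
	c := states.add(&dfaState{fingerprint: "(s,1,[$])"}) // duplicate of a
	fmt.Println(a.stateNumber, b.stateNumber, c == a)    // 0 1 true
}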


@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -340,7 +340,7 @@ func (prc *BaseParserRuleContext) String(ruleNames []string, stop RuleContext) s
return s return s
} }
var RuleContextEmpty = NewBaseParserRuleContext(nil, -1) var ParserRuleContextEmpty = NewBaseParserRuleContext(nil, -1)
type InterpreterRuleContext interface { type InterpreterRuleContext interface {
ParserRuleContext ParserRuleContext


@ -1,10 +1,12 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
package antlr package antlr
import ( import (
"fmt"
"golang.org/x/exp/slices"
"strconv" "strconv"
) )
@ -26,10 +28,10 @@ var (
) )
type PredictionContext interface { type PredictionContext interface {
hash() int Hash() int
Equals(interface{}) bool
GetParent(int) PredictionContext GetParent(int) PredictionContext
getReturnState(int) int getReturnState(int) int
equals(PredictionContext) bool
length() int length() int
isEmpty() bool isEmpty() bool
hasEmptyPath() bool hasEmptyPath() bool
@ -53,7 +55,7 @@ func (b *BasePredictionContext) isEmpty() bool {
func calculateHash(parent PredictionContext, returnState int) int { func calculateHash(parent PredictionContext, returnState int) int {
h := murmurInit(1) h := murmurInit(1)
h = murmurUpdate(h, parent.hash()) h = murmurUpdate(h, parent.Hash())
h = murmurUpdate(h, returnState) h = murmurUpdate(h, returnState)
return murmurFinish(h, 2) return murmurFinish(h, 2)
} }
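calculateHash above folds the parent's Hash and the return state into one value once, and the contexts then serve that number from a cachedHash field so Hash() never recomputes anything. A sketch of the same compute-once, serve-many pattern using the standard library's FNV hash in place of the runtime's murmur helpers; ctx and combine are hypothetical names:

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// ctx caches a hash computed from its parent's hash and its return state,
// so equality checks can reject mismatches with a cheap integer compare.
type ctx struct {
	parent      *ctx
	returnState int
	cachedHash  uint32
}

// combine folds the parent hash and return state into one 32-bit value.
func combine(parentHash uint32, returnState int) uint32 {
	h := fnv.New32a()
	var buf [8]byte
	binary.LittleEndian.PutUint32(buf[0:4], parentHash)
	binary.LittleEndian.PutUint32(buf[4:8], uint32(returnState))
	h.Write(buf[:])
	return h.Sum32()
}

func newCtx(parent *ctx, returnState int) *ctx {
	var ph uint32
	if parent != nil {
		ph = parent.Hash()
	}
	return &ctx{parent: parent, returnState: returnState, cachedHash: combine(ph, returnState)}
}

// Hash returns the precomputed value; nothing is recomputed per call.
func (c *ctx) Hash() uint32 { return c.cachedHash }

func main() {
	root := newCtx(nil, -1)
	child := newCtx(root, 42)
	fmt.Println(root.Hash(), child.Hash())
}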
@ -86,7 +88,6 @@ func NewPredictionContextCache() *PredictionContextCache {
// Add a context to the cache and return it. If the context already exists, // Add a context to the cache and return it. If the context already exists,
// return that one instead and do not add a Newcontext to the cache. // return that one instead and do not add a Newcontext to the cache.
// Protect shared cache from unsafe thread access. // Protect shared cache from unsafe thread access.
//
func (p *PredictionContextCache) add(ctx PredictionContext) PredictionContext { func (p *PredictionContextCache) add(ctx PredictionContext) PredictionContext {
if ctx == BasePredictionContextEMPTY { if ctx == BasePredictionContextEMPTY {
return BasePredictionContextEMPTY return BasePredictionContextEMPTY
@ -160,28 +161,28 @@ func (b *BaseSingletonPredictionContext) hasEmptyPath() bool {
return b.returnState == BasePredictionContextEmptyReturnState return b.returnState == BasePredictionContextEmptyReturnState
} }
func (b *BaseSingletonPredictionContext) equals(other PredictionContext) bool { func (b *BaseSingletonPredictionContext) Hash() int {
return b.cachedHash
}
func (b *BaseSingletonPredictionContext) Equals(other interface{}) bool {
if b == other { if b == other {
return true return true
} else if _, ok := other.(*BaseSingletonPredictionContext); !ok { }
if _, ok := other.(*BaseSingletonPredictionContext); !ok {
return false return false
} else if b.hash() != other.hash() {
return false // can't be same if hash is different
} }
otherP := other.(*BaseSingletonPredictionContext) otherP := other.(*BaseSingletonPredictionContext)
if b.returnState != other.getReturnState(0) { if b.returnState != otherP.getReturnState(0) {
return false return false
} else if b.parentCtx == nil { }
if b.parentCtx == nil {
return otherP.parentCtx == nil return otherP.parentCtx == nil
} }
return b.parentCtx.equals(otherP.parentCtx) return b.parentCtx.Equals(otherP.parentCtx)
}
func (b *BaseSingletonPredictionContext) hash() int {
return b.cachedHash
} }
func (b *BaseSingletonPredictionContext) String() string { func (b *BaseSingletonPredictionContext) String() string {
@ -215,7 +216,7 @@ func NewEmptyPredictionContext() *EmptyPredictionContext {
p := new(EmptyPredictionContext) p := new(EmptyPredictionContext)
p.BaseSingletonPredictionContext = NewBaseSingletonPredictionContext(nil, BasePredictionContextEmptyReturnState) p.BaseSingletonPredictionContext = NewBaseSingletonPredictionContext(nil, BasePredictionContextEmptyReturnState)
p.cachedHash = calculateEmptyHash()
return p return p
} }
@ -231,7 +232,11 @@ func (e *EmptyPredictionContext) getReturnState(index int) int {
return e.returnState return e.returnState
} }
func (e *EmptyPredictionContext) equals(other PredictionContext) bool { func (e *EmptyPredictionContext) Hash() int {
return e.cachedHash
}
func (e *EmptyPredictionContext) Equals(other interface{}) bool {
return e == other return e == other
} }
@ -254,7 +259,7 @@ func NewArrayPredictionContext(parents []PredictionContext, returnStates []int)
hash := murmurInit(1) hash := murmurInit(1)
for _, parent := range parents { for _, parent := range parents {
hash = murmurUpdate(hash, parent.hash()) hash = murmurUpdate(hash, parent.Hash())
} }
for _, returnState := range returnStates { for _, returnState := range returnStates {
@ -298,18 +303,31 @@ func (a *ArrayPredictionContext) getReturnState(index int) int {
return a.returnStates[index] return a.returnStates[index]
} }
func (a *ArrayPredictionContext) equals(other PredictionContext) bool { // Equals is the default comparison function for ArrayPredictionContext when no specialized
if _, ok := other.(*ArrayPredictionContext); !ok { // implementation is needed for a collection
return false func (a *ArrayPredictionContext) Equals(o interface{}) bool {
} else if a.cachedHash != other.hash() { if a == o {
return false // can't be same if hash is different return true
} else {
otherP := other.(*ArrayPredictionContext)
return &a.returnStates == &otherP.returnStates && &a.parents == &otherP.parents
} }
other, ok := o.(*ArrayPredictionContext)
if !ok {
return false
}
if a.cachedHash != other.Hash() {
return false // can't be same if hash is different
}
// Must compare the actual array elements and not just the array address
//
return slices.Equal(a.returnStates, other.returnStates) &&
slices.EqualFunc(a.parents, other.parents, func(x, y PredictionContext) bool {
return x.Equals(y)
})
} }
func (a *ArrayPredictionContext) hash() int { // Hash is the default hash function for ArrayPredictionContext when no specialized
// implementation is needed for a collection
func (a *ArrayPredictionContext) Hash() int {
return a.BasePredictionContext.cachedHash return a.BasePredictionContext.cachedHash
} }
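The new Equals above leans on golang.org/x/exp/slices (imported at the top of the file) to compare element by element instead of comparing slice header addresses, which is effectively all the old &a.returnStates == &otherP.returnStates test did. A small standalone illustration of those two helpers, assuming the x/exp module is available; the element types and equality function are placeholders:

package main

import (
	"fmt"

	"golang.org/x/exp/slices"
)

type parent struct{ id int }

func main() {
	aStates := []int{1, 2, 3}
	bStates := []int{1, 2, 3}

	// Equal compares lengths and then every element with ==.
	fmt.Println(slices.Equal(aStates, bStates)) // true

	aParents := []*parent{{1}, {2}}
	bParents := []*parent{{1}, {2}}

	// EqualFunc uses the supplied predicate, mirroring how the contexts
	// delegate to each element's own Equals method.
	same := slices.EqualFunc(aParents, bParents, func(x, y *parent) bool {
		return x.id == y.id
	})
	fmt.Println(same) // true

	// Slices themselves are not comparable with ==, and comparing their
	// addresses says nothing about the elements.
	fmt.Println(&aParents == &bParents) // false: different variables
}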
@ -343,11 +361,11 @@ func (a *ArrayPredictionContext) String() string {
// / // /
func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) PredictionContext { func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) PredictionContext {
if outerContext == nil { if outerContext == nil {
outerContext = RuleContextEmpty outerContext = ParserRuleContextEmpty
} }
// if we are in RuleContext of start rule, s, then BasePredictionContext // if we are in RuleContext of start rule, s, then BasePredictionContext
// is EMPTY. Nobody called us. (if we are empty, return empty) // is EMPTY. Nobody called us. (if we are empty, return empty)
if outerContext.GetParent() == nil || outerContext == RuleContextEmpty { if outerContext.GetParent() == nil || outerContext == ParserRuleContextEmpty {
return BasePredictionContextEMPTY return BasePredictionContextEMPTY
} }
// If we have a parent, convert it to a BasePredictionContext graph // If we have a parent, convert it to a BasePredictionContext graph
@ -359,11 +377,20 @@ func predictionContextFromRuleContext(a *ATN, outerContext RuleContext) Predicti
} }
func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext { func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
// share same graph if both same
if a == b { // Share same graph if both same
//
if a == b || a.Equals(b) {
return a return a
} }
// In Java, EmptyPredictionContext inherits from SingletonPredictionContext, and so the test
// in java for SingletonPredictionContext will succeed and a new ArrayPredictionContext will be created
// from it.
// In go, EmptyPredictionContext does not equate to SingletonPredictionContext and so that conversion
// will fail. We need to test for both Empty and Singleton and create an ArrayPredictionContext from
// either of them.
ac, ok1 := a.(*BaseSingletonPredictionContext) ac, ok1 := a.(*BaseSingletonPredictionContext)
bc, ok2 := b.(*BaseSingletonPredictionContext) bc, ok2 := b.(*BaseSingletonPredictionContext)
@ -380,17 +407,32 @@ func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict)
return b return b
} }
} }
// convert singleton so both are arrays to normalize
if _, ok := a.(*BaseSingletonPredictionContext); ok { // Convert Singleton or Empty so both are arrays to normalize - We should not use the existing parameters
a = NewArrayPredictionContext([]PredictionContext{a.GetParent(0)}, []int{a.getReturnState(0)}) // here.
//
// TODO: I think that maybe the Prediction Context structs should be redone as there is a chance we will see this mess again - maybe redo the logic here
var arp, arb *ArrayPredictionContext
var ok bool
if arp, ok = a.(*ArrayPredictionContext); ok {
} else if _, ok = a.(*BaseSingletonPredictionContext); ok {
arp = NewArrayPredictionContext([]PredictionContext{a.GetParent(0)}, []int{a.getReturnState(0)})
} else if _, ok = a.(*EmptyPredictionContext); ok {
arp = NewArrayPredictionContext([]PredictionContext{}, []int{})
} }
if _, ok := b.(*BaseSingletonPredictionContext); ok {
b = NewArrayPredictionContext([]PredictionContext{b.GetParent(0)}, []int{b.getReturnState(0)}) if arb, ok = b.(*ArrayPredictionContext); ok {
} else if _, ok = b.(*BaseSingletonPredictionContext); ok {
arb = NewArrayPredictionContext([]PredictionContext{b.GetParent(0)}, []int{b.getReturnState(0)})
} else if _, ok = b.(*EmptyPredictionContext); ok {
arb = NewArrayPredictionContext([]PredictionContext{}, []int{})
} }
return mergeArrays(a.(*ArrayPredictionContext), b.(*ArrayPredictionContext), rootIsWildcard, mergeCache)
// Both arp and arb
return mergeArrays(arp, arb, rootIsWildcard, mergeCache)
} }
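As the comments in the hunk above explain, merge now normalizes whichever representation it receives (array, singleton, or empty) into the array form before handing both sides to mergeArrays, because Go has no inheritance to let the empty context satisfy a singleton type assertion the way the Java runtime does. A sketch of that normalize-then-combine shape with toy types; arrayCtx, singleCtx, emptyCtx, and the trivial mergeArrays are illustrative only:

package main

import "fmt"

// prediction is the common interface; the three concrete forms mirror
// array, singleton, and empty contexts.
type prediction interface{ states() []int }

type arrayCtx struct{ returnStates []int }

func (a arrayCtx) states() []int { return a.returnStates }

type singleCtx struct{ returnState int }

func (s singleCtx) states() []int { return []int{s.returnState} }

type emptyCtx struct{}

func (emptyCtx) states() []int { return nil }

// toArray normalizes any of the three forms to arrayCtx; a type switch
// stands in for the inheritance the Java runtime relies on.
func toArray(p prediction) arrayCtx {
	switch v := p.(type) {
	case arrayCtx:
		return v
	case singleCtx:
		return arrayCtx{returnStates: []int{v.returnState}}
	case emptyCtx:
		return arrayCtx{returnStates: []int{}}
	default:
		panic("unknown prediction context form")
	}
}

// mergeArrays here is just a placeholder union of the two state lists.
func mergeArrays(a, b arrayCtx) arrayCtx {
	return arrayCtx{returnStates: append(append([]int{}, a.returnStates...), b.returnStates...)}
}

func main() {
	merged := mergeArrays(toArray(singleCtx{returnState: 5}), toArray(emptyCtx{}))
	fmt.Println(merged.states()) // [5]
}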
//
// Merge two {@link SingletonBasePredictionContext} instances. // Merge two {@link SingletonBasePredictionContext} instances.
// //
// <p>Stack tops equal, parents merge is same return left graph.<br> // <p>Stack tops equal, parents merge is same return left graph.<br>
@ -423,11 +465,11 @@ func merge(a, b PredictionContext, rootIsWildcard bool, mergeCache *DoubleDict)
// / // /
func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext { func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
if mergeCache != nil { if mergeCache != nil {
previous := mergeCache.Get(a.hash(), b.hash()) previous := mergeCache.Get(a.Hash(), b.Hash())
if previous != nil { if previous != nil {
return previous.(PredictionContext) return previous.(PredictionContext)
} }
previous = mergeCache.Get(b.hash(), a.hash()) previous = mergeCache.Get(b.Hash(), a.Hash())
if previous != nil { if previous != nil {
return previous.(PredictionContext) return previous.(PredictionContext)
} }
@ -436,7 +478,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
rootMerge := mergeRoot(a, b, rootIsWildcard) rootMerge := mergeRoot(a, b, rootIsWildcard)
if rootMerge != nil { if rootMerge != nil {
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), rootMerge) mergeCache.set(a.Hash(), b.Hash(), rootMerge)
} }
return rootMerge return rootMerge
} }
@ -456,7 +498,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
// Newjoined parent so create Newsingleton pointing to it, a' // Newjoined parent so create Newsingleton pointing to it, a'
spc := SingletonBasePredictionContextCreate(parent, a.returnState) spc := SingletonBasePredictionContextCreate(parent, a.returnState)
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), spc) mergeCache.set(a.Hash(), b.Hash(), spc)
} }
return spc return spc
} }
@ -478,7 +520,7 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
parents := []PredictionContext{singleParent, singleParent} parents := []PredictionContext{singleParent, singleParent}
apc := NewArrayPredictionContext(parents, payloads) apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), apc) mergeCache.set(a.Hash(), b.Hash(), apc)
} }
return apc return apc
} }
@ -494,12 +536,11 @@ func mergeSingletons(a, b *BaseSingletonPredictionContext, rootIsWildcard bool,
} }
apc := NewArrayPredictionContext(parents, payloads) apc := NewArrayPredictionContext(parents, payloads)
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), apc) mergeCache.set(a.Hash(), b.Hash(), apc)
} }
return apc return apc
} }
//
// Handle case where at least one of {@code a} or {@code b} is // Handle case where at least one of {@code a} or {@code b} is
// {@link //EMPTY}. In the following diagrams, the symbol {@code $} is used // {@link //EMPTY}. In the following diagrams, the symbol {@code $} is used
// to represent {@link //EMPTY}. // to represent {@link //EMPTY}.
@ -561,7 +602,6 @@ func mergeRoot(a, b SingletonPredictionContext, rootIsWildcard bool) PredictionC
return nil return nil
} }
//
// Merge two {@link ArrayBasePredictionContext} instances. // Merge two {@link ArrayBasePredictionContext} instances.
// //
// <p>Different tops, different parents.<br> // <p>Different tops, different parents.<br>
@ -583,12 +623,18 @@ func mergeRoot(a, b SingletonPredictionContext, rootIsWildcard bool) PredictionC
// / // /
func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext { func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *DoubleDict) PredictionContext {
if mergeCache != nil { if mergeCache != nil {
previous := mergeCache.Get(a.hash(), b.hash()) previous := mergeCache.Get(a.Hash(), b.Hash())
if previous != nil { if previous != nil {
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous.(PredictionContext) return previous.(PredictionContext)
} }
previous = mergeCache.Get(b.hash(), a.hash()) previous = mergeCache.Get(b.Hash(), a.Hash())
if previous != nil { if previous != nil {
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> previous")
}
return previous.(PredictionContext) return previous.(PredictionContext)
} }
} }
@ -608,7 +654,7 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *
payload := a.returnStates[i] payload := a.returnStates[i]
// $+$ = $ // $+$ = $
bothDollars := payload == BasePredictionContextEmptyReturnState && aParent == nil && bParent == nil bothDollars := payload == BasePredictionContextEmptyReturnState && aParent == nil && bParent == nil
axAX := (aParent != nil && bParent != nil && aParent == bParent) // ax+ax axAX := aParent != nil && bParent != nil && aParent == bParent // ax+ax
// -> // ->
// ax // ax
if bothDollars || axAX { if bothDollars || axAX {
@ -651,7 +697,7 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *
if k == 1 { // for just one merged element, return singleton top if k == 1 { // for just one merged element, return singleton top
pc := SingletonBasePredictionContextCreate(mergedParents[0], mergedReturnStates[0]) pc := SingletonBasePredictionContextCreate(mergedParents[0], mergedReturnStates[0])
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), pc) mergeCache.set(a.Hash(), b.Hash(), pc)
} }
return pc return pc
} }
@ -663,27 +709,36 @@ func mergeArrays(a, b *ArrayPredictionContext, rootIsWildcard bool, mergeCache *
// if we created same array as a or b, return that instead // if we created same array as a or b, return that instead
// TODO: track whether this is possible above during merge sort for speed // TODO: track whether this is possible above during merge sort for speed
// TODO: In go, I do not think we can just do M == xx as M is a brand new allocation. This could be causing allocation problems
if M == a { if M == a {
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), a) mergeCache.set(a.Hash(), b.Hash(), a)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> a")
} }
return a return a
} }
if M == b { if M == b {
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), b) mergeCache.set(a.Hash(), b.Hash(), b)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> b")
} }
return b return b
} }
combineCommonParents(mergedParents) combineCommonParents(mergedParents)
if mergeCache != nil { if mergeCache != nil {
mergeCache.set(a.hash(), b.hash(), M) mergeCache.set(a.Hash(), b.Hash(), M)
}
if ParserATNSimulatorTraceATNSim {
fmt.Println("mergeArrays a=" + a.String() + ",b=" + b.String() + " -> " + M.String())
} }
return M return M
} }
//
// Make pass over all <em>M</em> {@code parents} merge any {@code equals()} // Make pass over all <em>M</em> {@code parents} merge any {@code equals()}
// ones. // ones.
// / // /


@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -70,7 +70,6 @@ const (
PredictionModeLLExactAmbigDetection = 2 PredictionModeLLExactAmbigDetection = 2
) )
//
// Computes the SLL prediction termination condition. // Computes the SLL prediction termination condition.
// //
// <p> // <p>
@ -108,9 +107,9 @@ const (
// The single-alt-state thing lets prediction continue upon rules like // The single-alt-state thing lets prediction continue upon rules like
// (otherwise, it would admit defeat too soon):</p> // (otherwise, it would admit defeat too soon):</p>
// //
// <p>{@code [12|1|[], 6|2|[], 12|2|[]]. s : (ID | ID ID?) '' }</p> // <p>{@code [12|1|[], 6|2|[], 12|2|[]]. s : (ID | ID ID?) }</p>
// //
// <p>When the ATN simulation reaches the state before {@code ''}, it has a // <p>When the ATN simulation reaches the state before {@code }, it has a
// DFA state that looks like: {@code [12|1|[], 6|2|[], 12|2|[]]}. Naturally // DFA state that looks like: {@code [12|1|[], 6|2|[], 12|2|[]]}. Naturally
// {@code 12|1|[]} and {@code 12|2|[]} conflict, but we cannot stop // {@code 12|1|[]} and {@code 12|2|[]} conflict, but we cannot stop
// processing this node because alternative to has another way to continue, // processing this node because alternative to has another way to continue,
@ -152,16 +151,15 @@ const (
// //
// <p>Before testing these configurations against others, we have to merge // <p>Before testing these configurations against others, we have to merge
// {@code x} and {@code x'} (without modifying the existing configurations). // {@code x} and {@code x'} (without modifying the existing configurations).
// For example, we test {@code (x+x')==x''} when looking for conflicts in // For example, we test {@code (x+x')==x} when looking for conflicts in
// the following configurations.</p> // the following configurations.</p>
// //
// <p>{@code (s, 1, x, {}), (s, 1, x', {p}), (s, 2, x'', {})}</p> // <p>{@code (s, 1, x, {}), (s, 1, x', {p}), (s, 2, x, {})}</p>
// //
// <p>If the configuration set has predicates (as indicated by // <p>If the configuration set has predicates (as indicated by
// {@link ATNConfigSet//hasSemanticContext}), this algorithm makes a copy of // {@link ATNConfigSet//hasSemanticContext}), this algorithm makes a copy of
// the configurations to strip out all of the predicates so that a standard // the configurations to strip out all of the predicates so that a standard
// {@link ATNConfigSet} will merge everything ignoring predicates.</p> // {@link ATNConfigSet} will merge everything ignoring predicates.</p>
//
func PredictionModehasSLLConflictTerminatingPrediction(mode int, configs ATNConfigSet) bool { func PredictionModehasSLLConflictTerminatingPrediction(mode int, configs ATNConfigSet) bool {
// Configs in rule stop states indicate reaching the end of the decision // Configs in rule stop states indicate reaching the end of the decision
// rule (local context) or end of start rule (full context). If all // rule (local context) or end of start rule (full context). If all
@ -229,7 +227,6 @@ func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
return true return true
} }
//
// Full LL prediction termination. // Full LL prediction termination.
// //
// <p>Can we stop looking ahead during ATN simulation or is there some // <p>Can we stop looking ahead during ATN simulation or is there some
@ -334,7 +331,7 @@ func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
// </li> // </li>
// //
// <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 1, y)}, // <li>{@code (s, 1, x)}, {@code (s, 2, x)}, {@code (s', 1, y)},
// {@code (s', 2, y)}, {@code (s'', 1, z)} yields non-conflicting set // {@code (s', 2, y)}, {@code (s, 1, z)} yields non-conflicting set
// {@code {1}} U conflicting sets {@code min({1,2})} U {@code min({1,2})} = // {@code {1}} U conflicting sets {@code min({1,2})} U {@code min({1,2})} =
// {@code {1}} =&gt stop and predict 1</li> // {@code {1}} =&gt stop and predict 1</li>
// //
@ -369,31 +366,26 @@ func PredictionModeallConfigsInRuleStopStates(configs ATNConfigSet) bool {
// two or one and three so we keep going. We can only stop prediction when // two or one and three so we keep going. We can only stop prediction when
// we need exact ambiguity detection when the sets look like // we need exact ambiguity detection when the sets look like
// {@code A={{1,2}}} or {@code {{1,2},{1,2}}}, etc...</p> // {@code A={{1,2}}} or {@code {{1,2},{1,2}}}, etc...</p>
//
func PredictionModeresolvesToJustOneViableAlt(altsets []*BitSet) int { func PredictionModeresolvesToJustOneViableAlt(altsets []*BitSet) int {
return PredictionModegetSingleViableAlt(altsets) return PredictionModegetSingleViableAlt(altsets)
} }
//
// Determines if every alternative subset in {@code altsets} contains more // Determines if every alternative subset in {@code altsets} contains more
// than one alternative. // than one alternative.
// //
// @param altsets a collection of alternative subsets // @param altsets a collection of alternative subsets
// @return {@code true} if every {@link BitSet} in {@code altsets} has // @return {@code true} if every {@link BitSet} in {@code altsets} has
// {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false} // {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false}
//
func PredictionModeallSubsetsConflict(altsets []*BitSet) bool { func PredictionModeallSubsetsConflict(altsets []*BitSet) bool {
return !PredictionModehasNonConflictingAltSet(altsets) return !PredictionModehasNonConflictingAltSet(altsets)
} }
//
// Determines if any single alternative subset in {@code altsets} contains // Determines if any single alternative subset in {@code altsets} contains
// exactly one alternative. // exactly one alternative.
// //
// @param altsets a collection of alternative subsets // @param altsets a collection of alternative subsets
// @return {@code true} if {@code altsets} contains a {@link BitSet} with // @return {@code true} if {@code altsets} contains a {@link BitSet} with
// {@link BitSet//cardinality cardinality} 1, otherwise {@code false} // {@link BitSet//cardinality cardinality} 1, otherwise {@code false}
//
func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool { func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ { for i := 0; i < len(altsets); i++ {
alts := altsets[i] alts := altsets[i]
@ -404,14 +396,12 @@ func PredictionModehasNonConflictingAltSet(altsets []*BitSet) bool {
return false return false
} }
//
// Determines if any single alternative subset in {@code altsets} contains // Determines if any single alternative subset in {@code altsets} contains
// more than one alternative. // more than one alternative.
// //
// @param altsets a collection of alternative subsets // @param altsets a collection of alternative subsets
// @return {@code true} if {@code altsets} contains a {@link BitSet} with // @return {@code true} if {@code altsets} contains a {@link BitSet} with
// {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false} // {@link BitSet//cardinality cardinality} &gt 1, otherwise {@code false}
//
func PredictionModehasConflictingAltSet(altsets []*BitSet) bool { func PredictionModehasConflictingAltSet(altsets []*BitSet) bool {
for i := 0; i < len(altsets); i++ { for i := 0; i < len(altsets); i++ {
alts := altsets[i] alts := altsets[i]
@ -422,13 +412,11 @@ func PredictionModehasConflictingAltSet(altsets []*BitSet) bool {
return false return false
} }
//
// Determines if every alternative subset in {@code altsets} is equivalent. // Determines if every alternative subset in {@code altsets} is equivalent.
// //
// @param altsets a collection of alternative subsets // @param altsets a collection of alternative subsets
// @return {@code true} if every member of {@code altsets} is equal to the // @return {@code true} if every member of {@code altsets} is equal to the
// others, otherwise {@code false} // others, otherwise {@code false}
//
func PredictionModeallSubsetsEqual(altsets []*BitSet) bool { func PredictionModeallSubsetsEqual(altsets []*BitSet) bool {
var first *BitSet var first *BitSet
@ -444,13 +432,11 @@ func PredictionModeallSubsetsEqual(altsets []*BitSet) bool {
return true return true
} }
//
// Returns the unique alternative predicted by all alternative subsets in // Returns the unique alternative predicted by all alternative subsets in
// {@code altsets}. If no such alternative exists, this method returns // {@code altsets}. If no such alternative exists, this method returns
// {@link ATN//INVALID_ALT_NUMBER}. // {@link ATN//INVALID_ALT_NUMBER}.
// //
// @param altsets a collection of alternative subsets // @param altsets a collection of alternative subsets
//
func PredictionModegetUniqueAlt(altsets []*BitSet) int { func PredictionModegetUniqueAlt(altsets []*BitSet) int {
all := PredictionModeGetAlts(altsets) all := PredictionModeGetAlts(altsets)
if all.length() == 1 { if all.length() == 1 {
@ -466,7 +452,6 @@ func PredictionModegetUniqueAlt(altsets []*BitSet) int {
// //
// @param altsets a collection of alternative subsets // @param altsets a collection of alternative subsets
// @return the set of represented alternatives in {@code altsets} // @return the set of represented alternatives in {@code altsets}
//
func PredictionModeGetAlts(altsets []*BitSet) *BitSet { func PredictionModeGetAlts(altsets []*BitSet) *BitSet {
all := NewBitSet() all := NewBitSet()
for _, alts := range altsets { for _, alts := range altsets {
@ -475,44 +460,35 @@ func PredictionModeGetAlts(altsets []*BitSet) *BitSet {
return all return all
} }
// // PredictionModegetConflictingAltSubsets gets the conflicting alt subsets from a configuration set.
// This func gets the conflicting alt subsets from a configuration set.
// For each configuration {@code c} in {@code configs}: // For each configuration {@code c} in {@code configs}:
// //
// <pre> // <pre>
// map[c] U= c.{@link ATNConfig//alt alt} // map hash/equals uses s and x, not // map[c] U= c.{@link ATNConfig//alt alt} // map hash/equals uses s and x, not
// alt and not pred // alt and not pred
// </pre> // </pre>
//
func PredictionModegetConflictingAltSubsets(configs ATNConfigSet) []*BitSet { func PredictionModegetConflictingAltSubsets(configs ATNConfigSet) []*BitSet {
configToAlts := make(map[int]*BitSet) configToAlts := NewJMap[ATNConfig, *BitSet, *ATNAltConfigComparator[ATNConfig]](atnAltCfgEqInst)
for _, c := range configs.GetItems() { for _, c := range configs.GetItems() {
key := 31 * c.GetState().GetStateNumber() + c.GetContext().hash()
alts, ok := configToAlts[key] alts, ok := configToAlts.Get(c)
if !ok { if !ok {
alts = NewBitSet() alts = NewBitSet()
configToAlts[key] = alts configToAlts.Put(c, alts)
} }
alts.add(c.GetAlt()) alts.add(c.GetAlt())
} }
values := make([]*BitSet, 0, 10) return configToAlts.Values()
for _, v := range configToAlts {
values = append(values, v)
}
return values
} }
// // PredictionModeGetStateToAltMap gets a map from state to alt subset from a configuration set. For each
// Get a map from state to alt subset from a configuration set. For each
// configuration {@code c} in {@code configs}: // configuration {@code c} in {@code configs}:
// //
// <pre> // <pre>
// map[c.{@link ATNConfig//state state}] U= c.{@link ATNConfig//alt alt} // map[c.{@link ATNConfig//state state}] U= c.{@link ATNConfig//alt alt}
// </pre> // </pre>
//
func PredictionModeGetStateToAltMap(configs ATNConfigSet) *AltDict { func PredictionModeGetStateToAltMap(configs ATNConfigSet) *AltDict {
m := NewAltDict() m := NewAltDict()


@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -49,7 +49,7 @@ var tokenTypeMapCache = make(map[string]int)
var ruleIndexMapCache = make(map[string]int) var ruleIndexMapCache = make(map[string]int)
func (b *BaseRecognizer) checkVersion(toolVersion string) { func (b *BaseRecognizer) checkVersion(toolVersion string) {
runtimeVersion := "4.10.1" runtimeVersion := "4.12.0"
if runtimeVersion != toolVersion { if runtimeVersion != toolVersion {
fmt.Println("ANTLR runtime and generated code versions disagree: " + runtimeVersion + "!=" + toolVersion) fmt.Println("ANTLR runtime and generated code versions disagree: " + runtimeVersion + "!=" + toolVersion)
} }
@ -108,7 +108,6 @@ func (b *BaseRecognizer) SetState(v int) {
// Get a map from rule names to rule indexes. // Get a map from rule names to rule indexes.
// //
// <p>Used for XPath and tree pattern compilation.</p> // <p>Used for XPath and tree pattern compilation.</p>
//
func (b *BaseRecognizer) GetRuleIndexMap() map[string]int { func (b *BaseRecognizer) GetRuleIndexMap() map[string]int {
panic("Method not defined!") panic("Method not defined!")
@ -171,18 +170,18 @@ func (b *BaseRecognizer) GetErrorHeader(e RecognitionException) string {
} }
// How should a token be displayed in an error message? The default // How should a token be displayed in an error message? The default
// is to display just the text, but during development you might //
// want to have a lot of information spit out. Override in that case // is to display just the text, but during development you might
// to use t.String() (which, for CommonToken, dumps everything about // want to have a lot of information spit out. Override in that case
// the token). This is better than forcing you to override a method in // to use t.String() (which, for CommonToken, dumps everything about
// your token objects because you don't have to go modify your lexer // the token). This is better than forcing you to override a method in
// so that it creates a NewJava type. // your token objects because you don't have to go modify your lexer
// so that it creates a NewJava type.
// //
// @deprecated This method is not called by the ANTLR 4 Runtime. Specific // @deprecated This method is not called by the ANTLR 4 Runtime. Specific
// implementations of {@link ANTLRErrorStrategy} may provide a similar // implementations of {@link ANTLRErrorStrategy} may provide a similar
// feature when necessary. For example, see // feature when necessary. For example, see
// {@link DefaultErrorStrategy//GetTokenErrorDisplay}. // {@link DefaultErrorStrategy//GetTokenErrorDisplay}.
//
func (b *BaseRecognizer) GetTokenErrorDisplay(t Token) string { func (b *BaseRecognizer) GetTokenErrorDisplay(t Token) string {
if t == nil { if t == nil {
return "<no token>" return "<no token>"


@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.


@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -18,12 +18,12 @@ import (
// //
type SemanticContext interface { type SemanticContext interface {
comparable Equals(other Collectable[SemanticContext]) bool
Hash() int
evaluate(parser Recognizer, outerContext RuleContext) bool evaluate(parser Recognizer, outerContext RuleContext) bool
evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext
hash() int
String() string String() string
} }
@ -78,7 +78,7 @@ func NewPredicate(ruleIndex, predIndex int, isCtxDependent bool) *Predicate {
//The default {@link SemanticContext}, which is semantically equivalent to //The default {@link SemanticContext}, which is semantically equivalent to
//a predicate of the form {@code {true}?}. //a predicate of the form {@code {true}?}.
var SemanticContextNone SemanticContext = NewPredicate(-1, -1, false) var SemanticContextNone = NewPredicate(-1, -1, false)
func (p *Predicate) evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext { func (p *Predicate) evalPrecedence(parser Recognizer, outerContext RuleContext) SemanticContext {
return p return p
@ -95,7 +95,7 @@ func (p *Predicate) evaluate(parser Recognizer, outerContext RuleContext) bool {
return parser.Sempred(localctx, p.ruleIndex, p.predIndex) return parser.Sempred(localctx, p.ruleIndex, p.predIndex)
} }
func (p *Predicate) equals(other interface{}) bool { func (p *Predicate) Equals(other Collectable[SemanticContext]) bool {
if p == other { if p == other {
return true return true
} else if _, ok := other.(*Predicate); !ok { } else if _, ok := other.(*Predicate); !ok {
@ -107,7 +107,7 @@ func (p *Predicate) equals(other interface{}) bool {
} }
} }
func (p *Predicate) hash() int { func (p *Predicate) Hash() int {
h := murmurInit(0) h := murmurInit(0)
h = murmurUpdate(h, p.ruleIndex) h = murmurUpdate(h, p.ruleIndex)
h = murmurUpdate(h, p.predIndex) h = murmurUpdate(h, p.predIndex)
@ -151,17 +151,22 @@ func (p *PrecedencePredicate) compareTo(other *PrecedencePredicate) int {
return p.precedence - other.precedence return p.precedence - other.precedence
} }
func (p *PrecedencePredicate) equals(other interface{}) bool { func (p *PrecedencePredicate) Equals(other Collectable[SemanticContext]) bool {
if p == other {
return true var op *PrecedencePredicate
} else if _, ok := other.(*PrecedencePredicate); !ok { var ok bool
if op, ok = other.(*PrecedencePredicate); !ok {
return false return false
} else {
return p.precedence == other.(*PrecedencePredicate).precedence
} }
if p == op {
return true
}
return p.precedence == other.(*PrecedencePredicate).precedence
} }
func (p *PrecedencePredicate) hash() int { func (p *PrecedencePredicate) Hash() int {
h := uint32(1) h := uint32(1)
h = 31*h + uint32(p.precedence) h = 31*h + uint32(p.precedence)
return int(h) return int(h)
@ -171,10 +176,10 @@ func (p *PrecedencePredicate) String() string {
return "{" + strconv.Itoa(p.precedence) + ">=prec}?" return "{" + strconv.Itoa(p.precedence) + ">=prec}?"
} }
func PrecedencePredicatefilterPrecedencePredicates(set Set) []*PrecedencePredicate { func PrecedencePredicatefilterPrecedencePredicates(set *JStore[SemanticContext, Comparator[SemanticContext]]) []*PrecedencePredicate {
result := make([]*PrecedencePredicate, 0) result := make([]*PrecedencePredicate, 0)
set.Each(func(v interface{}) bool { set.Each(func(v SemanticContext) bool {
if c2, ok := v.(*PrecedencePredicate); ok { if c2, ok := v.(*PrecedencePredicate); ok {
result = append(result, c2) result = append(result, c2)
} }
@ -193,21 +198,21 @@ type AND struct {
func NewAND(a, b SemanticContext) *AND { func NewAND(a, b SemanticContext) *AND {
operands := newArray2DHashSet(nil, nil) operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst)
if aa, ok := a.(*AND); ok { if aa, ok := a.(*AND); ok {
for _, o := range aa.opnds { for _, o := range aa.opnds {
operands.Add(o) operands.Put(o)
} }
} else { } else {
operands.Add(a) operands.Put(a)
} }
if ba, ok := b.(*AND); ok { if ba, ok := b.(*AND); ok {
for _, o := range ba.opnds { for _, o := range ba.opnds {
operands.Add(o) operands.Put(o)
} }
} else { } else {
operands.Add(b) operands.Put(b)
} }
precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands) precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands)
if len(precedencePredicates) > 0 { if len(precedencePredicates) > 0 {
@ -220,7 +225,7 @@ func NewAND(a, b SemanticContext) *AND {
} }
} }
operands.Add(reduced) operands.Put(reduced)
} }
vs := operands.Values() vs := operands.Values()
@ -235,14 +240,15 @@ func NewAND(a, b SemanticContext) *AND {
return and return and
} }
func (a *AND) equals(other interface{}) bool { func (a *AND) Equals(other Collectable[SemanticContext]) bool {
if a == other { if a == other {
return true return true
} else if _, ok := other.(*AND); !ok { }
if _, ok := other.(*AND); !ok {
return false return false
} else { } else {
for i, v := range other.(*AND).opnds { for i, v := range other.(*AND).opnds {
if !a.opnds[i].equals(v) { if !a.opnds[i].Equals(v) {
return false return false
} }
} }
@ -250,13 +256,11 @@ func (a *AND) equals(other interface{}) bool {
} }
} }
//
// {@inheritDoc} // {@inheritDoc}
// //
// <p> // <p>
// The evaluation of predicates by a context is short-circuiting, but // The evaluation of predicates by a context is short-circuiting, but
// unordered.</p> // unordered.</p>
//
func (a *AND) evaluate(parser Recognizer, outerContext RuleContext) bool { func (a *AND) evaluate(parser Recognizer, outerContext RuleContext) bool {
for i := 0; i < len(a.opnds); i++ { for i := 0; i < len(a.opnds); i++ {
if !a.opnds[i].evaluate(parser, outerContext) { if !a.opnds[i].evaluate(parser, outerContext) {
@ -304,18 +308,18 @@ func (a *AND) evalPrecedence(parser Recognizer, outerContext RuleContext) Semant
return result return result
} }
func (a *AND) hash() int { func (a *AND) Hash() int {
h := murmurInit(37) // Init with a value different from OR h := murmurInit(37) // Init with a value different from OR
for _, op := range a.opnds { for _, op := range a.opnds {
h = murmurUpdate(h, op.hash()) h = murmurUpdate(h, op.Hash())
} }
return murmurFinish(h, len(a.opnds)) return murmurFinish(h, len(a.opnds))
} }
func (a *OR) hash() int { func (a *OR) Hash() int {
h := murmurInit(41) // Init with a value different from AND h := murmurInit(41) // Init with a value different from AND
for _, op := range a.opnds { for _, op := range a.opnds {
h = murmurUpdate(h, op.hash()) h = murmurUpdate(h, op.Hash())
} }
return murmurFinish(h, len(a.opnds)) return murmurFinish(h, len(a.opnds))
} }
@ -345,21 +349,21 @@ type OR struct {
func NewOR(a, b SemanticContext) *OR { func NewOR(a, b SemanticContext) *OR {
operands := newArray2DHashSet(nil, nil) operands := NewJStore[SemanticContext, Comparator[SemanticContext]](semctxEqInst)
if aa, ok := a.(*OR); ok { if aa, ok := a.(*OR); ok {
for _, o := range aa.opnds { for _, o := range aa.opnds {
operands.Add(o) operands.Put(o)
} }
} else { } else {
operands.Add(a) operands.Put(a)
} }
if ba, ok := b.(*OR); ok { if ba, ok := b.(*OR); ok {
for _, o := range ba.opnds { for _, o := range ba.opnds {
operands.Add(o) operands.Put(o)
} }
} else { } else {
operands.Add(b) operands.Put(b)
} }
precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands) precedencePredicates := PrecedencePredicatefilterPrecedencePredicates(operands)
if len(precedencePredicates) > 0 { if len(precedencePredicates) > 0 {
@ -372,7 +376,7 @@ func NewOR(a, b SemanticContext) *OR {
} }
} }
operands.Add(reduced) operands.Put(reduced)
} }
vs := operands.Values() vs := operands.Values()
@ -388,14 +392,14 @@ func NewOR(a, b SemanticContext) *OR {
return o return o
} }
func (o *OR) equals(other interface{}) bool { func (o *OR) Equals(other Collectable[SemanticContext]) bool {
if o == other { if o == other {
return true return true
} else if _, ok := other.(*OR); !ok { } else if _, ok := other.(*OR); !ok {
return false return false
} else { } else {
for i, v := range other.(*OR).opnds { for i, v := range other.(*OR).opnds {
if !o.opnds[i].equals(v) { if !o.opnds[i].Equals(v) {
return false return false
} }
} }
@ -406,7 +410,6 @@ func (o *OR) equals(other interface{}) bool {
// <p> // <p>
// The evaluation of predicates by o context is short-circuiting, but // The evaluation of predicates by o context is short-circuiting, but
// unordered.</p> // unordered.</p>
//
func (o *OR) evaluate(parser Recognizer, outerContext RuleContext) bool { func (o *OR) evaluate(parser Recognizer, outerContext RuleContext) bool {
for i := 0; i < len(o.opnds); i++ { for i := 0; i < len(o.opnds); i++ {
if o.opnds[i].evaluate(parser, outerContext) { if o.opnds[i].evaluate(parser, outerContext) {

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -158,7 +158,6 @@ func NewCommonToken(source *TokenSourceCharStreamPair, tokenType, channel, start
// {@link Token//GetInputStream}.</p> // {@link Token//GetInputStream}.</p>
// //
// @param oldToken The token to copy. // @param oldToken The token to copy.
//
func (c *CommonToken) clone() *CommonToken { func (c *CommonToken) clone() *CommonToken {
t := NewCommonToken(c.source, c.tokenType, c.channel, c.start, c.stop) t := NewCommonToken(c.source, c.tokenType, c.channel, c.start, c.stop)
t.tokenIndex = c.GetTokenIndex() t.tokenIndex = c.GetTokenIndex()

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

@ -1,15 +1,15 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
package antlr package antlr
import ( import (
"bytes" "bytes"
"fmt" "fmt"
) )
//
//
// Useful for rewriting out a buffered input token stream after doing some // Useful for rewriting out a buffered input token stream after doing some
// augmentation or other manipulations on it. // augmentation or other manipulations on it.
@ -85,12 +85,10 @@ import (
// If you don't use named rewrite streams, a "default" stream is used as the // If you don't use named rewrite streams, a "default" stream is used as the
// first example shows.</p> // first example shows.</p>
const (
const(
Default_Program_Name = "default" Default_Program_Name = "default"
Program_Init_Size = 100 Program_Init_Size = 100
Min_Token_Index = 0 Min_Token_Index = 0
) )
// Define the rewrite operation hierarchy // Define the rewrite operation hierarchy
@ -98,13 +96,13 @@ const(
type RewriteOperation interface { type RewriteOperation interface {
// Execute the rewrite operation by possibly adding to the buffer. // Execute the rewrite operation by possibly adding to the buffer.
// Return the index of the next token to operate on. // Return the index of the next token to operate on.
Execute(buffer *bytes.Buffer) int Execute(buffer *bytes.Buffer) int
String() string String() string
GetInstructionIndex() int GetInstructionIndex() int
GetIndex() int GetIndex() int
GetText() string GetText() string
GetOpName() string GetOpName() string
GetTokens() TokenStream GetTokens() TokenStream
SetInstructionIndex(val int) SetInstructionIndex(val int)
SetIndex(int) SetIndex(int)
SetText(string) SetText(string)
@ -114,63 +112,62 @@ type RewriteOperation interface {
type BaseRewriteOperation struct { type BaseRewriteOperation struct {
//Current index of rewrites list //Current index of rewrites list
instruction_index int instruction_index int
//Token buffer index //Token buffer index
index int index int
//Substitution text //Substitution text
text string text string
//Actual operation name //Actual operation name
op_name string op_name string
//Pointer to token steam //Pointer to token steam
tokens TokenStream tokens TokenStream
} }
func (op *BaseRewriteOperation)GetInstructionIndex() int{ func (op *BaseRewriteOperation) GetInstructionIndex() int {
return op.instruction_index return op.instruction_index
} }
func (op *BaseRewriteOperation)GetIndex() int{ func (op *BaseRewriteOperation) GetIndex() int {
return op.index return op.index
} }
func (op *BaseRewriteOperation)GetText() string{ func (op *BaseRewriteOperation) GetText() string {
return op.text return op.text
} }
func (op *BaseRewriteOperation)GetOpName() string{ func (op *BaseRewriteOperation) GetOpName() string {
return op.op_name return op.op_name
} }
func (op *BaseRewriteOperation)GetTokens() TokenStream{ func (op *BaseRewriteOperation) GetTokens() TokenStream {
return op.tokens return op.tokens
} }
func (op *BaseRewriteOperation)SetInstructionIndex(val int){ func (op *BaseRewriteOperation) SetInstructionIndex(val int) {
op.instruction_index = val op.instruction_index = val
} }
func (op *BaseRewriteOperation)SetIndex(val int) { func (op *BaseRewriteOperation) SetIndex(val int) {
op.index = val op.index = val
} }
func (op *BaseRewriteOperation)SetText(val string){ func (op *BaseRewriteOperation) SetText(val string) {
op.text = val op.text = val
} }
func (op *BaseRewriteOperation)SetOpName(val string){ func (op *BaseRewriteOperation) SetOpName(val string) {
op.op_name = val op.op_name = val
} }
func (op *BaseRewriteOperation)SetTokens(val TokenStream) { func (op *BaseRewriteOperation) SetTokens(val TokenStream) {
op.tokens = val op.tokens = val
} }
func (op *BaseRewriteOperation) Execute(buffer *bytes.Buffer) int {
func (op *BaseRewriteOperation) Execute(buffer *bytes.Buffer) int{
return op.index return op.index
} }
func (op *BaseRewriteOperation) String() string { func (op *BaseRewriteOperation) String() string {
return fmt.Sprintf("<%s@%d:\"%s\">", return fmt.Sprintf("<%s@%d:\"%s\">",
op.op_name, op.op_name,
op.tokens.Get(op.GetIndex()), op.tokens.Get(op.GetIndex()),
@ -179,26 +176,25 @@ func (op *BaseRewriteOperation) String() string {
} }
type InsertBeforeOp struct { type InsertBeforeOp struct {
BaseRewriteOperation BaseRewriteOperation
} }
func NewInsertBeforeOp(index int, text string, stream TokenStream) *InsertBeforeOp{ func NewInsertBeforeOp(index int, text string, stream TokenStream) *InsertBeforeOp {
return &InsertBeforeOp{BaseRewriteOperation:BaseRewriteOperation{ return &InsertBeforeOp{BaseRewriteOperation: BaseRewriteOperation{
index:index, index: index,
text:text, text: text,
op_name:"InsertBeforeOp", op_name: "InsertBeforeOp",
tokens:stream, tokens: stream,
}} }}
} }
func (op *InsertBeforeOp) Execute(buffer *bytes.Buffer) int{ func (op *InsertBeforeOp) Execute(buffer *bytes.Buffer) int {
buffer.WriteString(op.text) buffer.WriteString(op.text)
if op.tokens.Get(op.index).GetTokenType() != TokenEOF{ if op.tokens.Get(op.index).GetTokenType() != TokenEOF {
buffer.WriteString(op.tokens.Get(op.index).GetText()) buffer.WriteString(op.tokens.Get(op.index).GetText())
} }
return op.index+1 return op.index + 1
} }
func (op *InsertBeforeOp) String() string { func (op *InsertBeforeOp) String() string {
@ -213,20 +209,20 @@ type InsertAfterOp struct {
BaseRewriteOperation BaseRewriteOperation
} }
func NewInsertAfterOp(index int, text string, stream TokenStream) *InsertAfterOp{ func NewInsertAfterOp(index int, text string, stream TokenStream) *InsertAfterOp {
return &InsertAfterOp{BaseRewriteOperation:BaseRewriteOperation{ return &InsertAfterOp{BaseRewriteOperation: BaseRewriteOperation{
index:index+1, index: index + 1,
text:text, text: text,
tokens:stream, tokens: stream,
}} }}
} }
func (op *InsertAfterOp) Execute(buffer *bytes.Buffer) int { func (op *InsertAfterOp) Execute(buffer *bytes.Buffer) int {
buffer.WriteString(op.text) buffer.WriteString(op.text)
if op.tokens.Get(op.index).GetTokenType() != TokenEOF{ if op.tokens.Get(op.index).GetTokenType() != TokenEOF {
buffer.WriteString(op.tokens.Get(op.index).GetText()) buffer.WriteString(op.tokens.Get(op.index).GetText())
} }
return op.index+1 return op.index + 1
} }
func (op *InsertAfterOp) String() string { func (op *InsertAfterOp) String() string {
@ -235,28 +231,28 @@ func (op *InsertAfterOp) String() string {
// I'm going to try replacing range from x..y with (y-x)+1 ReplaceOp // I'm going to try replacing range from x..y with (y-x)+1 ReplaceOp
// instructions. // instructions.
type ReplaceOp struct{ type ReplaceOp struct {
BaseRewriteOperation BaseRewriteOperation
LastIndex int LastIndex int
} }
func NewReplaceOp(from, to int, text string, stream TokenStream)*ReplaceOp { func NewReplaceOp(from, to int, text string, stream TokenStream) *ReplaceOp {
return &ReplaceOp{ return &ReplaceOp{
BaseRewriteOperation:BaseRewriteOperation{ BaseRewriteOperation: BaseRewriteOperation{
index:from, index: from,
text:text, text: text,
op_name:"ReplaceOp", op_name: "ReplaceOp",
tokens:stream, tokens: stream,
}, },
LastIndex:to, LastIndex: to,
} }
} }
func (op *ReplaceOp)Execute(buffer *bytes.Buffer) int{ func (op *ReplaceOp) Execute(buffer *bytes.Buffer) int {
if op.text != ""{ if op.text != "" {
buffer.WriteString(op.text) buffer.WriteString(op.text)
} }
return op.LastIndex +1 return op.LastIndex + 1
} }
func (op *ReplaceOp) String() string { func (op *ReplaceOp) String() string {
@ -268,54 +264,54 @@ func (op *ReplaceOp) String() string {
op.tokens.Get(op.index), op.tokens.Get(op.LastIndex), op.text) op.tokens.Get(op.index), op.tokens.Get(op.LastIndex), op.text)
} }
type TokenStreamRewriter struct { type TokenStreamRewriter struct {
//Our source stream //Our source stream
tokens TokenStream tokens TokenStream
// You may have multiple, named streams of rewrite operations. // You may have multiple, named streams of rewrite operations.
// I'm calling these things "programs." // I'm calling these things "programs."
// Maps String (name) &rarr; rewrite (List) // Maps String (name) &rarr; rewrite (List)
programs map[string][]RewriteOperation programs map[string][]RewriteOperation
last_rewrite_token_indexes map[string]int last_rewrite_token_indexes map[string]int
} }
func NewTokenStreamRewriter(tokens TokenStream) *TokenStreamRewriter{ func NewTokenStreamRewriter(tokens TokenStream) *TokenStreamRewriter {
return &TokenStreamRewriter{ return &TokenStreamRewriter{
tokens: tokens, tokens: tokens,
programs: map[string][]RewriteOperation{ programs: map[string][]RewriteOperation{
Default_Program_Name:make([]RewriteOperation,0, Program_Init_Size), Default_Program_Name: make([]RewriteOperation, 0, Program_Init_Size),
}, },
last_rewrite_token_indexes: map[string]int{}, last_rewrite_token_indexes: map[string]int{},
} }
} }
func (tsr *TokenStreamRewriter) GetTokenStream() TokenStream{ func (tsr *TokenStreamRewriter) GetTokenStream() TokenStream {
return tsr.tokens return tsr.tokens
} }
// Rollback the instruction stream for a program so that // Rollback the instruction stream for a program so that
// the indicated instruction (via instructionIndex) is no // the indicated instruction (via instructionIndex) is no
// longer in the stream. UNTESTED! // longer in the stream. UNTESTED!
func (tsr *TokenStreamRewriter) Rollback(program_name string, instruction_index int){ func (tsr *TokenStreamRewriter) Rollback(program_name string, instruction_index int) {
is, ok := tsr.programs[program_name] is, ok := tsr.programs[program_name]
if ok{ if ok {
tsr.programs[program_name] = is[Min_Token_Index:instruction_index] tsr.programs[program_name] = is[Min_Token_Index:instruction_index]
} }
} }
func (tsr *TokenStreamRewriter) RollbackDefault(instruction_index int){ func (tsr *TokenStreamRewriter) RollbackDefault(instruction_index int) {
tsr.Rollback(Default_Program_Name, instruction_index) tsr.Rollback(Default_Program_Name, instruction_index)
} }
//Reset the program so that no instructions exist
func (tsr *TokenStreamRewriter) DeleteProgram(program_name string){ // Reset the program so that no instructions exist
func (tsr *TokenStreamRewriter) DeleteProgram(program_name string) {
tsr.Rollback(program_name, Min_Token_Index) //TODO: double test on that cause lower bound is not included tsr.Rollback(program_name, Min_Token_Index) //TODO: double test on that cause lower bound is not included
} }
func (tsr *TokenStreamRewriter) DeleteProgramDefault(){ func (tsr *TokenStreamRewriter) DeleteProgramDefault() {
tsr.DeleteProgram(Default_Program_Name) tsr.DeleteProgram(Default_Program_Name)
} }
func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text string){ func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text string) {
// to insert after, just insert before next index (even if past end) // to insert after, just insert before next index (even if past end)
var op RewriteOperation = NewInsertAfterOp(index, text, tsr.tokens) var op RewriteOperation = NewInsertAfterOp(index, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name) rewrites := tsr.GetProgram(program_name)
@ -323,31 +319,31 @@ func (tsr *TokenStreamRewriter) InsertAfter(program_name string, index int, text
tsr.AddToProgram(program_name, op) tsr.AddToProgram(program_name, op)
} }
func (tsr *TokenStreamRewriter) InsertAfterDefault(index int, text string){ func (tsr *TokenStreamRewriter) InsertAfterDefault(index int, text string) {
tsr.InsertAfter(Default_Program_Name, index, text) tsr.InsertAfter(Default_Program_Name, index, text)
} }
func (tsr *TokenStreamRewriter) InsertAfterToken(program_name string, token Token, text string){ func (tsr *TokenStreamRewriter) InsertAfterToken(program_name string, token Token, text string) {
tsr.InsertAfter(program_name, token.GetTokenIndex(), text) tsr.InsertAfter(program_name, token.GetTokenIndex(), text)
} }
func (tsr* TokenStreamRewriter) InsertBefore(program_name string, index int, text string){ func (tsr *TokenStreamRewriter) InsertBefore(program_name string, index int, text string) {
var op RewriteOperation = NewInsertBeforeOp(index, text, tsr.tokens) var op RewriteOperation = NewInsertBeforeOp(index, text, tsr.tokens)
rewrites := tsr.GetProgram(program_name) rewrites := tsr.GetProgram(program_name)
op.SetInstructionIndex(len(rewrites)) op.SetInstructionIndex(len(rewrites))
tsr.AddToProgram(program_name, op) tsr.AddToProgram(program_name, op)
} }
func (tsr *TokenStreamRewriter) InsertBeforeDefault(index int, text string){ func (tsr *TokenStreamRewriter) InsertBeforeDefault(index int, text string) {
tsr.InsertBefore(Default_Program_Name, index, text) tsr.InsertBefore(Default_Program_Name, index, text)
} }
func (tsr *TokenStreamRewriter) InsertBeforeToken(program_name string,token Token, text string){ func (tsr *TokenStreamRewriter) InsertBeforeToken(program_name string, token Token, text string) {
tsr.InsertBefore(program_name, token.GetTokenIndex(), text) tsr.InsertBefore(program_name, token.GetTokenIndex(), text)
} }
func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text string){ func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text string) {
if from > to || from < 0 || to < 0 || to >= tsr.tokens.Size(){ if from > to || from < 0 || to < 0 || to >= tsr.tokens.Size() {
panic(fmt.Sprintf("replace: range invalid: %d..%d(size=%d)", panic(fmt.Sprintf("replace: range invalid: %d..%d(size=%d)",
from, to, tsr.tokens.Size())) from, to, tsr.tokens.Size()))
} }
@ -357,207 +353,216 @@ func (tsr *TokenStreamRewriter) Replace(program_name string, from, to int, text
tsr.AddToProgram(program_name, op) tsr.AddToProgram(program_name, op)
} }
func (tsr *TokenStreamRewriter)ReplaceDefault(from, to int, text string) { func (tsr *TokenStreamRewriter) ReplaceDefault(from, to int, text string) {
tsr.Replace(Default_Program_Name, from, to, text) tsr.Replace(Default_Program_Name, from, to, text)
} }
func (tsr *TokenStreamRewriter)ReplaceDefaultPos(index int, text string){ func (tsr *TokenStreamRewriter) ReplaceDefaultPos(index int, text string) {
tsr.ReplaceDefault(index, index, text) tsr.ReplaceDefault(index, index, text)
} }
func (tsr *TokenStreamRewriter)ReplaceToken(program_name string, from, to Token, text string){ func (tsr *TokenStreamRewriter) ReplaceToken(program_name string, from, to Token, text string) {
tsr.Replace(program_name, from.GetTokenIndex(), to.GetTokenIndex(), text) tsr.Replace(program_name, from.GetTokenIndex(), to.GetTokenIndex(), text)
} }
func (tsr *TokenStreamRewriter)ReplaceTokenDefault(from, to Token, text string){ func (tsr *TokenStreamRewriter) ReplaceTokenDefault(from, to Token, text string) {
tsr.ReplaceToken(Default_Program_Name, from, to, text) tsr.ReplaceToken(Default_Program_Name, from, to, text)
} }
func (tsr *TokenStreamRewriter)ReplaceTokenDefaultPos(index Token, text string){ func (tsr *TokenStreamRewriter) ReplaceTokenDefaultPos(index Token, text string) {
tsr.ReplaceTokenDefault(index, index, text) tsr.ReplaceTokenDefault(index, index, text)
} }
func (tsr *TokenStreamRewriter)Delete(program_name string, from, to int){ func (tsr *TokenStreamRewriter) Delete(program_name string, from, to int) {
tsr.Replace(program_name, from, to, "" ) tsr.Replace(program_name, from, to, "")
} }
func (tsr *TokenStreamRewriter)DeleteDefault(from, to int){ func (tsr *TokenStreamRewriter) DeleteDefault(from, to int) {
tsr.Delete(Default_Program_Name, from, to) tsr.Delete(Default_Program_Name, from, to)
} }
func (tsr *TokenStreamRewriter)DeleteDefaultPos(index int){ func (tsr *TokenStreamRewriter) DeleteDefaultPos(index int) {
tsr.DeleteDefault(index,index) tsr.DeleteDefault(index, index)
} }
func (tsr *TokenStreamRewriter)DeleteToken(program_name string, from, to Token) { func (tsr *TokenStreamRewriter) DeleteToken(program_name string, from, to Token) {
tsr.ReplaceToken(program_name, from, to, "") tsr.ReplaceToken(program_name, from, to, "")
} }
func (tsr *TokenStreamRewriter)DeleteTokenDefault(from,to Token){ func (tsr *TokenStreamRewriter) DeleteTokenDefault(from, to Token) {
tsr.DeleteToken(Default_Program_Name, from, to) tsr.DeleteToken(Default_Program_Name, from, to)
} }
func (tsr *TokenStreamRewriter)GetLastRewriteTokenIndex(program_name string)int { func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndex(program_name string) int {
i, ok := tsr.last_rewrite_token_indexes[program_name] i, ok := tsr.last_rewrite_token_indexes[program_name]
if !ok{ if !ok {
return -1 return -1
} }
return i return i
} }
func (tsr *TokenStreamRewriter)GetLastRewriteTokenIndexDefault()int{ func (tsr *TokenStreamRewriter) GetLastRewriteTokenIndexDefault() int {
return tsr.GetLastRewriteTokenIndex(Default_Program_Name) return tsr.GetLastRewriteTokenIndex(Default_Program_Name)
} }
func (tsr *TokenStreamRewriter)SetLastRewriteTokenIndex(program_name string, i int){ func (tsr *TokenStreamRewriter) SetLastRewriteTokenIndex(program_name string, i int) {
tsr.last_rewrite_token_indexes[program_name] = i tsr.last_rewrite_token_indexes[program_name] = i
} }
func (tsr *TokenStreamRewriter)InitializeProgram(name string)[]RewriteOperation{ func (tsr *TokenStreamRewriter) InitializeProgram(name string) []RewriteOperation {
is := make([]RewriteOperation, 0, Program_Init_Size) is := make([]RewriteOperation, 0, Program_Init_Size)
tsr.programs[name] = is tsr.programs[name] = is
return is return is
} }
func (tsr *TokenStreamRewriter)AddToProgram(name string, op RewriteOperation){ func (tsr *TokenStreamRewriter) AddToProgram(name string, op RewriteOperation) {
is := tsr.GetProgram(name) is := tsr.GetProgram(name)
is = append(is, op) is = append(is, op)
tsr.programs[name] = is tsr.programs[name] = is
} }
func (tsr *TokenStreamRewriter)GetProgram(name string) []RewriteOperation { func (tsr *TokenStreamRewriter) GetProgram(name string) []RewriteOperation {
is, ok := tsr.programs[name] is, ok := tsr.programs[name]
if !ok{ if !ok {
is = tsr.InitializeProgram(name) is = tsr.InitializeProgram(name)
} }
return is return is
} }
// Return the text from the original tokens altered per the
// instructions given to this rewriter. // Return the text from the original tokens altered per the
func (tsr *TokenStreamRewriter)GetTextDefault() string{ // instructions given to this rewriter.
func (tsr *TokenStreamRewriter) GetTextDefault() string {
return tsr.GetText( return tsr.GetText(
Default_Program_Name, Default_Program_Name,
NewInterval(0, tsr.tokens.Size()-1)) NewInterval(0, tsr.tokens.Size()-1))
} }
// Return the text from the original tokens altered per the
// instructions given to this rewriter. // Return the text from the original tokens altered per the
func (tsr *TokenStreamRewriter)GetText(program_name string, interval *Interval) string { // instructions given to this rewriter.
func (tsr *TokenStreamRewriter) GetText(program_name string, interval *Interval) string {
rewrites := tsr.programs[program_name] rewrites := tsr.programs[program_name]
start := interval.Start start := interval.Start
stop := interval.Stop stop := interval.Stop
// ensure start/end are in range // ensure start/end are in range
stop = min(stop, tsr.tokens.Size()-1) stop = min(stop, tsr.tokens.Size()-1)
start = max(start,0) start = max(start, 0)
if rewrites == nil || len(rewrites) == 0{ if rewrites == nil || len(rewrites) == 0 {
return tsr.tokens.GetTextFromInterval(interval) // no instructions to execute return tsr.tokens.GetTextFromInterval(interval) // no instructions to execute
} }
buf := bytes.Buffer{} buf := bytes.Buffer{}
// First, optimize instruction stream // First, optimize instruction stream
indexToOp := reduceToSingleOperationPerIndex(rewrites) indexToOp := reduceToSingleOperationPerIndex(rewrites)
// Walk buffer, executing instructions and emitting tokens // Walk buffer, executing instructions and emitting tokens
for i:=start; i<=stop && i<tsr.tokens.Size();{ for i := start; i <= stop && i < tsr.tokens.Size(); {
op := indexToOp[i] op := indexToOp[i]
delete(indexToOp, i)// remove so any left have index size-1 delete(indexToOp, i) // remove so any left have index size-1
t := tsr.tokens.Get(i) t := tsr.tokens.Get(i)
if op == nil{ if op == nil {
// no operation at that index, just dump token // no operation at that index, just dump token
if t.GetTokenType() != TokenEOF {buf.WriteString(t.GetText())} if t.GetTokenType() != TokenEOF {
buf.WriteString(t.GetText())
}
i++ // move to next token i++ // move to next token
}else { } else {
i = op.Execute(&buf)// execute operation and skip i = op.Execute(&buf) // execute operation and skip
} }
} }
// include stuff after end if it's last index in buffer // include stuff after end if it's last index in buffer
// So, if they did an insertAfter(lastValidIndex, "foo"), include // So, if they did an insertAfter(lastValidIndex, "foo"), include
// foo if end==lastValidIndex. // foo if end==lastValidIndex.
if stop == tsr.tokens.Size()-1{ if stop == tsr.tokens.Size()-1 {
// Scan any remaining operations after last token // Scan any remaining operations after last token
// should be included (they will be inserts). // should be included (they will be inserts).
for _, op := range indexToOp{ for _, op := range indexToOp {
if op.GetIndex() >= tsr.tokens.Size()-1 {buf.WriteString(op.GetText())} if op.GetIndex() >= tsr.tokens.Size()-1 {
buf.WriteString(op.GetText())
}
} }
} }
return buf.String() return buf.String()
} }
// We need to combine operations and report invalid operations (like // We need to combine operations and report invalid operations (like
// overlapping replaces that are not completed nested). Inserts to // overlapping replaces that are not completed nested). Inserts to
// same index need to be combined etc... Here are the cases: // same index need to be combined etc... Here are the cases:
// //
// I.i.u I.j.v leave alone, nonoverlapping // I.i.u I.j.v leave alone, nonoverlapping
// I.i.u I.i.v combine: Iivu // I.i.u I.i.v combine: Iivu
// //
// R.i-j.u R.x-y.v | i-j in x-y delete first R // R.i-j.u R.x-y.v | i-j in x-y delete first R
// R.i-j.u R.i-j.v delete first R // R.i-j.u R.i-j.v delete first R
// R.i-j.u R.x-y.v | x-y in i-j ERROR // R.i-j.u R.x-y.v | x-y in i-j ERROR
// R.i-j.u R.x-y.v | boundaries overlap ERROR // R.i-j.u R.x-y.v | boundaries overlap ERROR
// //
// Delete special case of replace (text==null): // Delete special case of replace (text==null):
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right) // D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
// //
// I.i.u R.x-y.v | i in (x+1)-y delete I (since insert before // I.i.u R.x-y.v | i in (x+1)-y delete I (since insert before
// we're not deleting i) // we're not deleting i)
// I.i.u R.x-y.v | i not in (x+1)-y leave alone, nonoverlapping // I.i.u R.x-y.v | i not in (x+1)-y leave alone, nonoverlapping
// R.x-y.v I.i.u | i in x-y ERROR // R.x-y.v I.i.u | i in x-y ERROR
// R.x-y.v I.x.u R.x-y.uv (combine, delete I) // R.x-y.v I.x.u R.x-y.uv (combine, delete I)
// R.x-y.v I.i.u | i not in x-y leave alone, nonoverlapping // R.x-y.v I.i.u | i not in x-y leave alone, nonoverlapping
// //
// I.i.u = insert u before op @ index i // I.i.u = insert u before op @ index i
// R.x-y.u = replace x-y indexed tokens with u // R.x-y.u = replace x-y indexed tokens with u
// //
// First we need to examine replaces. For any replace op: // First we need to examine replaces. For any replace op:
// //
// 1. wipe out any insertions before op within that range. // 1. wipe out any insertions before op within that range.
// 2. Drop any replace op before that is contained completely within // 2. Drop any replace op before that is contained completely within
// that range. // that range.
// 3. Throw exception upon boundary overlap with any previous replace. // 3. Throw exception upon boundary overlap with any previous replace.
// //
// Then we can deal with inserts: // Then we can deal with inserts:
// //
// 1. for any inserts to same index, combine even if not adjacent. // 1. for any inserts to same index, combine even if not adjacent.
// 2. for any prior replace with same left boundary, combine this // 2. for any prior replace with same left boundary, combine this
// insert with replace and delete this replace. // insert with replace and delete this replace.
// 3. throw exception if index in same range as previous replace // 3. throw exception if index in same range as previous replace
// //
// Don't actually delete; make op null in list. Easier to walk list. // Don't actually delete; make op null in list. Easier to walk list.
// Later we can throw as we add to index &rarr; op map. // Later we can throw as we add to index &rarr; op map.
// //
// Note that I.2 R.2-2 will wipe out I.2 even though, technically, the // Note that I.2 R.2-2 will wipe out I.2 even though, technically, the
// inserted stuff would be before the replace range. But, if you // inserted stuff would be before the replace range. But, if you
// add tokens in front of a method body '{' and then delete the method // add tokens in front of a method body '{' and then delete the method
// body, I think the stuff before the '{' you added should disappear too. // body, I think the stuff before the '{' you added should disappear too.
// //
// Return a map from token index to operation. // Return a map from token index to operation.
// func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]RewriteOperation {
func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]RewriteOperation{
// WALK REPLACES // WALK REPLACES
for i:=0; i < len(rewrites); i++{ for i := 0; i < len(rewrites); i++ {
op := rewrites[i] op := rewrites[i]
if op == nil{continue} if op == nil {
continue
}
rop, ok := op.(*ReplaceOp) rop, ok := op.(*ReplaceOp)
if !ok{continue} if !ok {
continue
}
// Wipe prior inserts within range // Wipe prior inserts within range
for j:=0; j<i && j < len(rewrites); j++{ for j := 0; j < i && j < len(rewrites); j++ {
if iop, ok := rewrites[j].(*InsertBeforeOp);ok{ if iop, ok := rewrites[j].(*InsertBeforeOp); ok {
if iop.index == rop.index{ if iop.index == rop.index {
// E.g., insert before 2, delete 2..2; update replace // E.g., insert before 2, delete 2..2; update replace
// text to include insert before, kill insert // text to include insert before, kill insert
rewrites[iop.instruction_index] = nil rewrites[iop.instruction_index] = nil
if rop.text != ""{ if rop.text != "" {
rop.text = iop.text + rop.text rop.text = iop.text + rop.text
}else{ } else {
rop.text = iop.text rop.text = iop.text
} }
}else if iop.index > rop.index && iop.index <=rop.LastIndex{ } else if iop.index > rop.index && iop.index <= rop.LastIndex {
// delete insert as it's a no-op. // delete insert as it's a no-op.
rewrites[iop.instruction_index] = nil rewrites[iop.instruction_index] = nil
} }
} }
} }
// Drop any prior replaces contained within // Drop any prior replaces contained within
for j:=0; j<i && j < len(rewrites); j++{ for j := 0; j < i && j < len(rewrites); j++ {
if prevop, ok := rewrites[j].(*ReplaceOp);ok{ if prevop, ok := rewrites[j].(*ReplaceOp); ok {
if prevop.index>=rop.index && prevop.LastIndex <= rop.LastIndex{ if prevop.index >= rop.index && prevop.LastIndex <= rop.LastIndex {
// delete replace as it's a no-op. // delete replace as it's a no-op.
rewrites[prevop.instruction_index] = nil rewrites[prevop.instruction_index] = nil
continue continue
@ -566,61 +571,67 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
disjoint := prevop.LastIndex < rop.index || prevop.index > rop.LastIndex disjoint := prevop.LastIndex < rop.index || prevop.index > rop.LastIndex
// Delete special case of replace (text==null): // Delete special case of replace (text==null):
// D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right) // D.i-j.u D.x-y.v | boundaries overlap combine to max(min)..max(right)
if prevop.text == "" && rop.text == "" && !disjoint{ if prevop.text == "" && rop.text == "" && !disjoint {
rewrites[prevop.instruction_index] = nil rewrites[prevop.instruction_index] = nil
rop.index = min(prevop.index, rop.index) rop.index = min(prevop.index, rop.index)
rop.LastIndex = max(prevop.LastIndex, rop.LastIndex) rop.LastIndex = max(prevop.LastIndex, rop.LastIndex)
println("new rop" + rop.String()) //TODO: remove console write, taken from Java version println("new rop" + rop.String()) //TODO: remove console write, taken from Java version
}else if !disjoint{ } else if !disjoint {
panic("replace op boundaries of " + rop.String() + " overlap with previous " + prevop.String()) panic("replace op boundaries of " + rop.String() + " overlap with previous " + prevop.String())
} }
} }
} }
} }
// WALK INSERTS // WALK INSERTS
for i:=0; i < len(rewrites); i++ { for i := 0; i < len(rewrites); i++ {
op := rewrites[i] op := rewrites[i]
if op == nil{continue} if op == nil {
continue
}
//hack to replicate inheritance in composition //hack to replicate inheritance in composition
_, iok := rewrites[i].(*InsertBeforeOp) _, iok := rewrites[i].(*InsertBeforeOp)
_, aok := rewrites[i].(*InsertAfterOp) _, aok := rewrites[i].(*InsertAfterOp)
if !iok && !aok{continue} if !iok && !aok {
continue
}
iop := rewrites[i] iop := rewrites[i]
// combine current insert with prior if any at same index // combine current insert with prior if any at same index
// deviating a bit from TokenStreamRewriter.java - hard to incorporate inheritance logic // deviating a bit from TokenStreamRewriter.java - hard to incorporate inheritance logic
for j:=0; j<i && j < len(rewrites); j++{ for j := 0; j < i && j < len(rewrites); j++ {
if nextIop, ok := rewrites[j].(*InsertAfterOp); ok{ if nextIop, ok := rewrites[j].(*InsertAfterOp); ok {
if nextIop.index == iop.GetIndex(){ if nextIop.index == iop.GetIndex() {
iop.SetText(nextIop.text + iop.GetText()) iop.SetText(nextIop.text + iop.GetText())
rewrites[j] = nil rewrites[j] = nil
} }
} }
if prevIop, ok := rewrites[j].(*InsertBeforeOp); ok{ if prevIop, ok := rewrites[j].(*InsertBeforeOp); ok {
if prevIop.index == iop.GetIndex(){ if prevIop.index == iop.GetIndex() {
iop.SetText(iop.GetText() + prevIop.text) iop.SetText(iop.GetText() + prevIop.text)
rewrites[prevIop.instruction_index] = nil rewrites[prevIop.instruction_index] = nil
} }
} }
} }
// look for replaces where iop.index is in range; error // look for replaces where iop.index is in range; error
for j:=0; j<i && j < len(rewrites); j++{ for j := 0; j < i && j < len(rewrites); j++ {
if rop,ok := rewrites[j].(*ReplaceOp); ok{ if rop, ok := rewrites[j].(*ReplaceOp); ok {
if iop.GetIndex() == rop.index{ if iop.GetIndex() == rop.index {
rop.text = iop.GetText() + rop.text rop.text = iop.GetText() + rop.text
rewrites[i] = nil rewrites[i] = nil
continue continue
} }
if iop.GetIndex() >= rop.index && iop.GetIndex() <= rop.LastIndex{ if iop.GetIndex() >= rop.index && iop.GetIndex() <= rop.LastIndex {
panic("insert op "+iop.String()+" within boundaries of previous "+rop.String()) panic("insert op " + iop.String() + " within boundaries of previous " + rop.String())
} }
} }
} }
} }
m := map[int]RewriteOperation{} m := map[int]RewriteOperation{}
for i:=0; i < len(rewrites); i++{ for i := 0; i < len(rewrites); i++ {
op := rewrites[i] op := rewrites[i]
if op == nil {continue} if op == nil {
if _, ok := m[op.GetIndex()]; ok{ continue
}
if _, ok := m[op.GetIndex()]; ok {
panic("should only be one op per index") panic("should only be one op per index")
} }
m[op.GetIndex()] = op m[op.GetIndex()] = op
@ -628,22 +639,21 @@ func reduceToSingleOperationPerIndex(rewrites []RewriteOperation) map[int]Rewrit
return m return m
} }
/* /*
Quick fixing Go lack of overloads Quick fixing Go lack of overloads
*/ */
func max(a,b int)int{ func max(a, b int) int {
if a>b{ if a > b {
return a return a
}else { } else {
return b return b
} }
} }
func min(a,b int)int{ func min(a, b int) int {
if a<b{ if a < b {
return a return a
}else { } else {
return b return b
} }
} }
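The rewriter changes above are purely mechanical (gofmt spacing and receiver formatting), but the API they touch is easier to follow with a short usage sketch. This is a hedged illustration, not part of the commit: it assumes a TokenStream obtained elsewhere (for example from a generated lexer) and only uses methods visible in this diff (NewTokenStreamRewriter, InsertBeforeDefault, ReplaceDefaultPos, GetTextDefault).

package example

import (
	"fmt"

	"github.com/antlr/antlr4/runtime/Go/antlr/v4"
)

// rewriteSketch buffers two edit operations against the "default" program and
// then renders the rewritten text; the operations are only executed when
// GetTextDefault walks the token stream.
func rewriteSketch(stream antlr.TokenStream) string {
	rw := antlr.NewTokenStreamRewriter(stream)

	rw.InsertBeforeDefault(0, "/* header */ ") // queue an insert before token 0
	rw.ReplaceDefaultPos(2, "renamed")         // queue a replace of the token at index 2

	out := rw.GetTextDefault() // applies the queued operations over the whole stream
	fmt.Println(out)
	return out
}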

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -234,10 +234,8 @@ func (p *ParseTreeWalker) Walk(listener ParseTreeListener, t Tree) {
} }
} }
//
// Enters a grammar rule by first triggering the generic event {@link ParseTreeListener//EnterEveryRule} // Enters a grammar rule by first triggering the generic event {@link ParseTreeListener//EnterEveryRule}
// then by triggering the event specific to the given parse tree node // then by triggering the event specific to the given parse tree node
//
func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) { func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) {
ctx := r.GetRuleContext().(ParserRuleContext) ctx := r.GetRuleContext().(ParserRuleContext)
listener.EnterEveryRule(ctx) listener.EnterEveryRule(ctx)
@ -246,7 +244,6 @@ func (p *ParseTreeWalker) EnterRule(listener ParseTreeListener, r RuleNode) {
// Exits a grammar rule by first triggering the event specific to the given parse tree node // Exits a grammar rule by first triggering the event specific to the given parse tree node
// then by triggering the generic event {@link ParseTreeListener//ExitEveryRule} // then by triggering the generic event {@link ParseTreeListener//ExitEveryRule}
//
func (p *ParseTreeWalker) ExitRule(listener ParseTreeListener, r RuleNode) { func (p *ParseTreeWalker) ExitRule(listener ParseTreeListener, r RuleNode) {
ctx := r.GetRuleContext().(ParserRuleContext) ctx := r.GetRuleContext().(ParserRuleContext)
ctx.ExitRule(listener) ctx.ExitRule(listener)

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -9,8 +9,9 @@ import "fmt"
/** A set of utility routines useful for all kinds of ANTLR trees. */ /** A set of utility routines useful for all kinds of ANTLR trees. */
// Print out a whole tree in LISP form. {@link //getNodeText} is used on the // Print out a whole tree in LISP form. {@link //getNodeText} is used on the
// node payloads to get the text for the nodes. Detect //
// parse trees and extract data appropriately. // node payloads to get the text for the nodes. Detect
// parse trees and extract data appropriately.
func TreesStringTree(tree Tree, ruleNames []string, recog Recognizer) string { func TreesStringTree(tree Tree, ruleNames []string, recog Recognizer) string {
if recog != nil { if recog != nil {
@ -80,8 +81,8 @@ func TreesGetChildren(t Tree) []Tree {
} }
// Return a list of all ancestors of this node. The first node of // Return a list of all ancestors of this node. The first node of
// list is the root and the last is the parent of this node.
// //
// list is the root and the last is the parent of this node.
func TreesgetAncestors(t Tree) []Tree { func TreesgetAncestors(t Tree) []Tree {
ancestors := make([]Tree, 0) ancestors := make([]Tree, 0)
t = t.GetParent() t = t.GetParent()

@ -1,4 +1,4 @@
// Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. // Copyright (c) 2012-2022 The ANTLR Project. All rights reserved.
// Use of this file is governed by the BSD 3-clause license that // Use of this file is governed by the BSD 3-clause license that
// can be found in the LICENSE.txt file in the project root. // can be found in the LICENSE.txt file in the project root.
@ -47,28 +47,25 @@ func (s *IntStack) Push(e int) {
*s = append(*s, e) *s = append(*s, e)
} }
func standardEqualsFunction(a interface{}, b interface{}) bool { type comparable interface {
Equals(other Collectable[any]) bool
}
ac, oka := a.(comparable) func standardEqualsFunction(a Collectable[any], b Collectable[any]) bool {
bc, okb := b.(comparable)
if !oka || !okb { return a.Equals(b)
panic("Not Comparable")
}
return ac.equals(bc)
} }
func standardHashFunction(a interface{}) int { func standardHashFunction(a interface{}) int {
if h, ok := a.(hasher); ok { if h, ok := a.(hasher); ok {
return h.hash() return h.Hash()
} }
panic("Not Hasher") panic("Not Hasher")
} }
type hasher interface { type hasher interface {
hash() int Hash() int
} }
const bitsPerWord = 64 const bitsPerWord = 64
@ -171,7 +168,7 @@ func (b *BitSet) equals(other interface{}) bool {
// We only compare set bits, so we cannot rely on the two slices having the same size. Its // We only compare set bits, so we cannot rely on the two slices having the same size. Its
// possible for two BitSets to have different slice lengths but the same set bits. So we only // possible for two BitSets to have different slice lengths but the same set bits. So we only
// compare the relavent words and ignore the trailing zeros. // compare the relevant words and ignore the trailing zeros.
bLen := b.minLen() bLen := b.minLen()
otherLen := otherBitSet.minLen() otherLen := otherBitSet.minLen()

@ -8,8 +8,6 @@ const (
_loadFactor = 0.75 _loadFactor = 0.75
) )
var _ Set = (*array2DHashSet)(nil)
type Set interface { type Set interface {
Add(value interface{}) (added interface{}) Add(value interface{}) (added interface{})
Len() int Len() int
@ -20,9 +18,9 @@ type Set interface {
} }
type array2DHashSet struct { type array2DHashSet struct {
buckets [][]interface{} buckets [][]Collectable[any]
hashcodeFunction func(interface{}) int hashcodeFunction func(interface{}) int
equalsFunction func(interface{}, interface{}) bool equalsFunction func(Collectable[any], Collectable[any]) bool
n int // How many elements in set n int // How many elements in set
threshold int // when to expand threshold int // when to expand
@ -61,11 +59,11 @@ func (as *array2DHashSet) Values() []interface{} {
return values return values
} }
func (as *array2DHashSet) Contains(value interface{}) bool { func (as *array2DHashSet) Contains(value Collectable[any]) bool {
return as.Get(value) != nil return as.Get(value) != nil
} }
func (as *array2DHashSet) Add(value interface{}) interface{} { func (as *array2DHashSet) Add(value Collectable[any]) interface{} {
if as.n > as.threshold { if as.n > as.threshold {
as.expand() as.expand()
} }
@ -98,7 +96,7 @@ func (as *array2DHashSet) expand() {
b := as.getBuckets(o) b := as.getBuckets(o)
bucketLength := newBucketLengths[b] bucketLength := newBucketLengths[b]
var newBucket []interface{} var newBucket []Collectable[any]
if bucketLength == 0 { if bucketLength == 0 {
// new bucket // new bucket
newBucket = as.createBucket(as.initialBucketCapacity) newBucket = as.createBucket(as.initialBucketCapacity)
@ -107,7 +105,7 @@ func (as *array2DHashSet) expand() {
newBucket = newTable[b] newBucket = newTable[b]
if bucketLength == len(newBucket) { if bucketLength == len(newBucket) {
// expand // expand
newBucketCopy := make([]interface{}, len(newBucket)<<1) newBucketCopy := make([]Collectable[any], len(newBucket)<<1)
copy(newBucketCopy[:bucketLength], newBucket) copy(newBucketCopy[:bucketLength], newBucket)
newBucket = newBucketCopy newBucket = newBucketCopy
newTable[b] = newBucket newTable[b] = newBucket
@ -124,7 +122,7 @@ func (as *array2DHashSet) Len() int {
return as.n return as.n
} }
func (as *array2DHashSet) Get(o interface{}) interface{} { func (as *array2DHashSet) Get(o Collectable[any]) interface{} {
if o == nil { if o == nil {
return nil return nil
} }
@ -147,7 +145,7 @@ func (as *array2DHashSet) Get(o interface{}) interface{} {
return nil return nil
} }
func (as *array2DHashSet) innerAdd(o interface{}) interface{} { func (as *array2DHashSet) innerAdd(o Collectable[any]) interface{} {
b := as.getBuckets(o) b := as.getBuckets(o)
bucket := as.buckets[b] bucket := as.buckets[b]
@ -178,7 +176,7 @@ func (as *array2DHashSet) innerAdd(o interface{}) interface{} {
// full bucket, expand and add to end // full bucket, expand and add to end
oldLength := len(bucket) oldLength := len(bucket)
bucketCopy := make([]interface{}, oldLength<<1) bucketCopy := make([]Collectable[any], oldLength<<1)
copy(bucketCopy[:oldLength], bucket) copy(bucketCopy[:oldLength], bucket)
bucket = bucketCopy bucket = bucketCopy
as.buckets[b] = bucket as.buckets[b] = bucket
@ -187,22 +185,22 @@ func (as *array2DHashSet) innerAdd(o interface{}) interface{} {
return o return o
} }
func (as *array2DHashSet) getBuckets(value interface{}) int { func (as *array2DHashSet) getBuckets(value Collectable[any]) int {
hash := as.hashcodeFunction(value) hash := as.hashcodeFunction(value)
return hash & (len(as.buckets) - 1) return hash & (len(as.buckets) - 1)
} }
func (as *array2DHashSet) createBuckets(cap int) [][]interface{} { func (as *array2DHashSet) createBuckets(cap int) [][]Collectable[any] {
return make([][]interface{}, cap) return make([][]Collectable[any], cap)
} }
func (as *array2DHashSet) createBucket(cap int) []interface{} { func (as *array2DHashSet) createBucket(cap int) []Collectable[any] {
return make([]interface{}, cap) return make([]Collectable[any], cap)
} }
func newArray2DHashSetWithCap( func newArray2DHashSetWithCap(
hashcodeFunction func(interface{}) int, hashcodeFunction func(interface{}) int,
equalsFunction func(interface{}, interface{}) bool, equalsFunction func(Collectable[any], Collectable[any]) bool,
initCap int, initCap int,
initBucketCap int, initBucketCap int,
) *array2DHashSet { ) *array2DHashSet {
@ -231,7 +229,7 @@ func newArray2DHashSetWithCap(
func newArray2DHashSet( func newArray2DHashSet(
hashcodeFunction func(interface{}) int, hashcodeFunction func(interface{}) int,
equalsFunction func(interface{}, interface{}) bool, equalsFunction func(Collectable[any], Collectable[any]) bool,
) *array2DHashSet { ) *array2DHashSet {
return newArray2DHashSetWithCap(hashcodeFunction, equalsFunction, _initalCapacity, _initalBucketCapacity) return newArray2DHashSetWithCap(hashcodeFunction, equalsFunction, _initalCapacity, _initalBucketCapacity)
} }

@ -1,10 +0,0 @@
language: go
go:
- 1.13
- 1.x
- tip
before_install:
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
script:
- $HOME/gopath/bin/goveralls -service=travis-ci

@ -5,10 +5,20 @@ import (
"time" "time"
) )
// An OperationWithData is executing by RetryWithData() or RetryNotifyWithData().
// The operation will be retried using a backoff policy if it returns an error.
type OperationWithData[T any] func() (T, error)
// An Operation is executing by Retry() or RetryNotify(). // An Operation is executing by Retry() or RetryNotify().
// The operation will be retried using a backoff policy if it returns an error. // The operation will be retried using a backoff policy if it returns an error.
type Operation func() error type Operation func() error
func (o Operation) withEmptyData() OperationWithData[struct{}] {
return func() (struct{}, error) {
return struct{}{}, o()
}
}
// Notify is a notify-on-error function. It receives an operation error and // Notify is a notify-on-error function. It receives an operation error and
// backoff delay if the operation failed (with an error). // backoff delay if the operation failed (with an error).
// //
@ -28,18 +38,41 @@ func Retry(o Operation, b BackOff) error {
return RetryNotify(o, b, nil) return RetryNotify(o, b, nil)
} }
// RetryWithData is like Retry but returns data in the response too.
func RetryWithData[T any](o OperationWithData[T], b BackOff) (T, error) {
return RetryNotifyWithData(o, b, nil)
}
// RetryNotify calls notify function with the error and wait duration // RetryNotify calls notify function with the error and wait duration
// for each failed attempt before sleep. // for each failed attempt before sleep.
func RetryNotify(operation Operation, b BackOff, notify Notify) error { func RetryNotify(operation Operation, b BackOff, notify Notify) error {
return RetryNotifyWithTimer(operation, b, notify, nil) return RetryNotifyWithTimer(operation, b, notify, nil)
} }
// RetryNotifyWithData is like RetryNotify but returns data in the response too.
func RetryNotifyWithData[T any](operation OperationWithData[T], b BackOff, notify Notify) (T, error) {
return doRetryNotify(operation, b, notify, nil)
}
// RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer // RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer
// for each failed attempt before sleep. // for each failed attempt before sleep.
// A default timer that uses system timer is used when nil is passed. // A default timer that uses system timer is used when nil is passed.
func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error { func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error {
var err error _, err := doRetryNotify(operation.withEmptyData(), b, notify, t)
var next time.Duration return err
}
// RetryNotifyWithTimerAndData is like RetryNotifyWithTimer but returns data in the response too.
func RetryNotifyWithTimerAndData[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
return doRetryNotify(operation, b, notify, t)
}
func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
var (
err error
next time.Duration
res T
)
if t == nil { if t == nil {
t = &defaultTimer{} t = &defaultTimer{}
} }
@ -52,21 +85,22 @@ func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer
b.Reset() b.Reset()
for { for {
if err = operation(); err == nil { res, err = operation()
return nil if err == nil {
return res, nil
} }
var permanent *PermanentError var permanent *PermanentError
if errors.As(err, &permanent) { if errors.As(err, &permanent) {
return permanent.Err return res, permanent.Err
} }
if next = b.NextBackOff(); next == Stop { if next = b.NextBackOff(); next == Stop {
if cerr := ctx.Err(); cerr != nil { if cerr := ctx.Err(); cerr != nil {
return cerr return res, cerr
} }
return err return res, err
} }
if notify != nil { if notify != nil {
@ -77,7 +111,7 @@ func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return res, ctx.Err()
case <-t.C(): case <-t.C():
} }
} }
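The backoff update above introduces generics-based variants (OperationWithData, RetryWithData, RetryNotifyWithData, RetryNotifyWithTimerAndData) alongside the existing Retry helpers. A minimal sketch of the new RetryWithData call follows; the HTTP fetch and URL are chosen only for illustration, and the exponential policy comes from the library's usual constructor rather than from this hunk.

package example

import (
	"io"
	"net/http"

	"github.com/cenkalti/backoff/v4"
)

// fetchWithRetry retries the HTTP GET under an exponential backoff policy and
// returns the response body on success; RetryWithData threads the typed result
// ([]byte here) through the retries instead of requiring a captured variable.
func fetchWithRetry(url string) ([]byte, error) {
	operation := func() ([]byte, error) {
		resp, err := http.Get(url)
		if err != nil {
			return nil, err // a returned error triggers another attempt
		}
		defer resp.Body.Close()
		return io.ReadAll(resp.Body)
	}
	return backoff.RetryWithData(operation, backoff.NewExponentialBackOff())
}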

@ -85,7 +85,7 @@ func (v *Version) Set(version string) error {
return fmt.Errorf("failed to validate metadata: %v", err) return fmt.Errorf("failed to validate metadata: %v", err)
} }
parsed := make([]int64, 3, 3) parsed := make([]int64, 3)
for i, v := range dotParts[:3] { for i, v := range dotParts[:3] {
val, err := strconv.ParseInt(v, 10, 64) val, err := strconv.ParseInt(v, 10, 64)

@ -69,6 +69,58 @@ func Enabled() bool {
return true return true
} }
// StderrIsJournalStream returns whether the process stderr is connected
// to the Journal's stream transport.
//
// This can be used for automatic protocol upgrading described in [Journal Native Protocol].
//
// Returns true if JOURNAL_STREAM environment variable is present,
// and stderr's device and inode numbers match it.
//
// Error is returned if unexpected error occurs: e.g. if JOURNAL_STREAM environment variable
// is present, but malformed, fstat syscall fails, etc.
//
// [Journal Native Protocol]: https://systemd.io/JOURNAL_NATIVE_PROTOCOL/#automatic-protocol-upgrading
func StderrIsJournalStream() (bool, error) {
return fdIsJournalStream(syscall.Stderr)
}
// StdoutIsJournalStream returns whether the process stdout is connected
// to the Journal's stream transport.
//
// Returns true if JOURNAL_STREAM environment variable is present,
// and stdout's device and inode numbers match it.
//
// Error is returned if unexpected error occurs: e.g. if JOURNAL_STREAM environment variable
// is present, but malformed, fstat syscall fails, etc.
//
// Most users should probably use [StderrIsJournalStream].
func StdoutIsJournalStream() (bool, error) {
return fdIsJournalStream(syscall.Stdout)
}
func fdIsJournalStream(fd int) (bool, error) {
journalStream := os.Getenv("JOURNAL_STREAM")
if journalStream == "" {
return false, nil
}
var expectedStat syscall.Stat_t
_, err := fmt.Sscanf(journalStream, "%d:%d", &expectedStat.Dev, &expectedStat.Ino)
if err != nil {
return false, fmt.Errorf("failed to parse JOURNAL_STREAM=%q: %v", journalStream, err)
}
var stat syscall.Stat_t
err = syscall.Fstat(fd, &stat)
if err != nil {
return false, err
}
match := stat.Dev == expectedStat.Dev && stat.Ino == expectedStat.Ino
return match, nil
}
// Send a message to the local systemd journal. vars is a map of journald // Send a message to the local systemd journal. vars is a map of journald
// fields to values. Fields must be composed of uppercase letters, numbers, // fields to values. Fields must be composed of uppercase letters, numbers,
// and underscores, but must not start with an underscore. Within these // and underscores, but must not start with an underscore. Within these

@ -33,3 +33,11 @@ func Enabled() bool {
func Send(message string, priority Priority, vars map[string]string) error { func Send(message string, priority Priority, vars map[string]string) error {
return errors.New("could not initialize socket to journald") return errors.New("could not initialize socket to journald")
} }
func StderrIsJournalStream() (bool, error) {
return false, nil
}
func StdoutIsJournalStream() (bool, error) {
return false, nil
}
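Together with the non-Linux stubs above (which always report false), the new journal-stream helpers support the automatic protocol upgrade described in their doc comments. A hedged sketch of a caller follows; the unit field, priority constant, and message are purely illustrative, while Send and StderrIsJournalStream match the signatures shown in this diff.

package example

import (
	"log"

	"github.com/coreos/go-systemd/v22/journal"
)

// logStartup writes directly to journald when stderr is already connected to
// the journal's stream transport, and falls back to plain stderr otherwise.
func logStartup(msg string) {
	ok, err := journal.StderrIsJournalStream()
	if err != nil {
		log.Printf("journal detection failed: %v", err)
	}
	if ok {
		// Send(message, priority, fields) per the signature in this diff.
		_ = journal.Send(msg, journal.PriInfo, map[string]string{"UNIT": "example.service"})
		return
	}
	log.Print(msg) // stderr fallback
}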

@ -568,29 +568,6 @@ func (p Patch) replace(doc *container, op Operation) error {
return errors.Wrapf(err, "replace operation failed to decode path") return errors.Wrapf(err, "replace operation failed to decode path")
} }
if path == "" {
val := op.value()
if val.which == eRaw {
if !val.tryDoc() {
if !val.tryAry() {
return errors.Wrapf(err, "replace operation value must be object or array")
}
}
}
switch val.which {
case eAry:
*doc = &val.ary
case eDoc:
*doc = &val.doc
case eRaw:
return errors.Wrapf(err, "replace operation hit impossible case")
}
return nil
}
con, key := findObject(doc, path) con, key := findObject(doc, path)
if con == nil { if con == nil {
@ -657,25 +634,6 @@ func (p Patch) test(doc *container, op Operation) error {
return errors.Wrapf(err, "test operation failed to decode path") return errors.Wrapf(err, "test operation failed to decode path")
} }
if path == "" {
var self lazyNode
switch sv := (*doc).(type) {
case *partialDoc:
self.doc = *sv
self.which = eDoc
case *partialArray:
self.ary = *sv
self.which = eAry
}
if self.equal(op.value()) {
return nil
}
return errors.Wrapf(ErrTestFailed, "testing value %s failed", path)
}
con, key := findObject(doc, path) con, key := findObject(doc, path)
if con == nil { if con == nil {

@ -26,11 +26,16 @@ var rxDupSlashes = regexp.MustCompile(`/{2,}`)
// - FlagLowercaseHost // - FlagLowercaseHost
// - FlagRemoveDefaultPort // - FlagRemoveDefaultPort
// - FlagRemoveDuplicateSlashes (and this was mixed in with the |) // - FlagRemoveDuplicateSlashes (and this was mixed in with the |)
//
// This also normalizes the URL into its urlencoded form by removing RawPath and RawFragment.
func NormalizeURL(u *url.URL) { func NormalizeURL(u *url.URL) {
lowercaseScheme(u) lowercaseScheme(u)
lowercaseHost(u) lowercaseHost(u)
removeDefaultPort(u) removeDefaultPort(u)
removeDuplicateSlashes(u) removeDuplicateSlashes(u)
u.RawPath = ""
u.RawFragment = ""
} }
func lowercaseScheme(u *url.URL) { func lowercaseScheme(u *url.URL) {

@ -23,6 +23,7 @@ go_library(
"//checker/decls:go_default_library", "//checker/decls:go_default_library",
"//common:go_default_library", "//common:go_default_library",
"//common/containers:go_default_library", "//common/containers:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library", "//common/overloads:go_default_library",
"//common/types:go_default_library", "//common/types:go_default_library",
"//common/types/pb:go_default_library", "//common/types/pb:go_default_library",
@ -31,7 +32,7 @@ go_library(
"//interpreter:go_default_library", "//interpreter:go_default_library",
"//interpreter/functions:go_default_library", "//interpreter/functions:go_default_library",
"//parser:go_default_library", "//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protodesc:go_default_library", "@org_golang_google_protobuf//reflect/protodesc:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -69,7 +70,7 @@ go_test(
"//test/proto2pb:go_default_library", "//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library", "//test/proto3pb:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto", "@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library",
], ],

View File

@ -139,7 +139,7 @@ var (
kind: TypeKind, kind: TypeKind,
runtimeType: types.TypeType, runtimeType: types.TypeType,
} }
//UintType represents a uint type. // UintType represents a uint type.
UintType = &Type{ UintType = &Type{
kind: UintKind, kind: UintKind,
runtimeType: types.UintType, runtimeType: types.UintType,
@ -222,7 +222,8 @@ func (t *Type) equals(other *Type) bool {
// - The from types are the same instance // - The from types are the same instance
// - The target type is dynamic // - The target type is dynamic
// - The fromType has the same kind and type name as the target type, and all parameters of the target type // - The fromType has the same kind and type name as the target type, and all parameters of the target type
// are IsAssignableType() from the parameters of the fromType. //
// are IsAssignableType() from the parameters of the fromType.
func (t *Type) defaultIsAssignableType(fromType *Type) bool { func (t *Type) defaultIsAssignableType(fromType *Type) bool {
if t == fromType || t.isDyn() { if t == fromType || t.isDyn() {
return true return true
@ -312,6 +313,11 @@ func NullableType(wrapped *Type) *Type {
} }
} }
// OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
func OptionalType(param *Type) *Type {
return OpaqueType("optional", param)
}
// OpaqueType creates an abstract parameterized type with a given name. // OpaqueType creates an abstract parameterized type with a given name.
func OpaqueType(name string, params ...*Type) *Type { func OpaqueType(name string, params ...*Type) *Type {
return &Type{ return &Type{
@ -365,7 +371,9 @@ func Variable(name string, t *Type) EnvOption {
// //
// - Overloads are searched in the order they are declared // - Overloads are searched in the order they are declared
// - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents // - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents
// at runtime. Empty lists and maps will result in a 'default dispatch' //
// at runtime. Empty lists and maps will result in a 'default dispatch'
//
// - In the event that a default dispatch occurs, the first overload provided is the one invoked // - In the event that a default dispatch occurs, the first overload provided is the one invoked
// //
// If you intend to use overloads which differentiate based on the key or element type of a list or // If you intend to use overloads which differentiate based on the key or element type of a list or
@ -405,7 +413,7 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
// FunctionOpt defines a functional option for configuring a function declaration. // FunctionOpt defines a functional option for configuring a function declaration.
type FunctionOpt func(*functionDecl) (*functionDecl, error) type FunctionOpt func(*functionDecl) (*functionDecl, error)
// SingletonUnaryBinding creates a singleton function defintion to be used for all function overloads. // SingletonUnaryBinding creates a singleton function definition to be used for all function overloads.
// //
// Note, this approach works well if operand is expected to have a specific trait which it implements, // Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. // e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
@ -431,7 +439,17 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
// //
// Note, this approach works well if operand is expected to have a specific trait which it implements, // Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. // e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
//
// Deprecated: use SingletonBinaryBinding
func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt { func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt {
return SingletonBinaryBinding(fn, traits...)
}
// SingletonBinaryBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
trait := 0 trait := 0
for _, t := range traits { for _, t := range traits {
trait = trait | t trait = trait | t
@ -453,7 +471,17 @@ func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt {
// //
// Note, this approach works well if operand is expected to have a specific trait which it implements, // Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. // e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
//
// Deprecated: use SingletonFunctionBinding
func SingletonFunctionImpl(fn functions.FunctionOp, traits ...int) FunctionOpt { func SingletonFunctionImpl(fn functions.FunctionOp, traits ...int) FunctionOpt {
return SingletonFunctionBinding(fn, traits...)
}
// SingletonFunctionBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOpt {
trait := 0 trait := 0
for _, t := range traits { for _, t := range traits {
trait = trait | t trait = trait | t
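
The deprecated *Impl options now simply forward to the new *Binding options. A hedged sketch of a custom function that uses SingletonBinaryBinding as the single implementation shared by its overloads (the "concat" function and its overload ID are invented for illustration):

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/common/types/ref"
)

func main() {
	env, err := cel.NewEnv(
		cel.Function("concat",
			cel.Overload("concat_string_string",
				[]*cel.Type{cel.StringType, cel.StringType}, cel.StringType),
			// One binding for every overload of "concat"; this replaces the
			// deprecated SingletonBinaryImpl.
			cel.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
				return lhs.(types.String) + rhs.(types.String)
			}),
		),
	)
	if err != nil {
		panic(err)
	}
	ast, iss := env.Compile(`concat("foo", "bar")`)
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	prg, err := env.Program(ast)
	if err != nil {
		panic(err)
	}
	out, _, err := prg.Eval(cel.NoVars())
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // foobar
}
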
@ -720,9 +748,8 @@ func (f *functionDecl) addOverload(overload *overloadDecl) error {
// Allow redefinition of an overload implementation so long as the signatures match. // Allow redefinition of an overload implementation so long as the signatures match.
f.overloads[index] = overload f.overloads[index] = overload
return nil return nil
} else {
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
} }
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
} }
} }
f.overloads = append(f.overloads, overload) f.overloads = append(f.overloads, overload)
@ -1177,3 +1204,43 @@ func collectParamNames(paramNames map[string]struct{}, arg *Type) {
collectParamNames(paramNames, param) collectParamNames(paramNames, param)
} }
} }
func typeValueToKind(tv *types.TypeValue) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}

View File

@ -102,15 +102,18 @@ type Env struct {
provider ref.TypeProvider provider ref.TypeProvider
features map[int]bool features map[int]bool
appliedFeatures map[int]bool appliedFeatures map[int]bool
libraries map[string]bool
// Internal parser representation // Internal parser representation
prsr *parser.Parser prsr *parser.Parser
prsrOpts []parser.Option
// Internal checker representation // Internal checker representation
chk *checker.Env chkMutex sync.Mutex
chkErr error chk *checker.Env
chkOnce sync.Once chkErr error
chkOpts []checker.Option chkOnce sync.Once
chkOpts []checker.Option
// Program options tied to the environment // Program options tied to the environment
progOpts []ProgramOption progOpts []ProgramOption
@ -159,6 +162,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
provider: registry, provider: registry,
features: map[int]bool{}, features: map[int]bool{},
appliedFeatures: map[int]bool{}, appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
progOpts: []ProgramOption{}, progOpts: []ProgramOption{},
}).configure(opts) }).configure(opts)
} }
@ -175,14 +179,14 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
pe, _ := AstToParsedExpr(ast) pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations. // Construct the internal checker env, erroring if there is an issue adding the declarations.
err := e.initChecker() chk, err := e.initChecker()
if err != nil { if err != nil {
errs := common.NewErrors(ast.Source()) errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, e.chkErr.Error()) errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssues(errs) return nil, NewIssues(errs)
} }
res, errs := checker.Check(pe, ast.Source(), e.chk) res, errs := checker.Check(pe, ast.Source(), chk)
if len(errs.GetErrors()) > 0 { if len(errs.GetErrors()) > 0 {
return nil, NewIssues(errs) return nil, NewIssues(errs)
} }
@ -236,10 +240,14 @@ func (e *Env) CompileSource(src Source) (*Ast, *Issues) {
// TypeProvider are immutable, or that their underlying implementations are based on the // TypeProvider are immutable, or that their underlying implementations are based on the
// ref.TypeRegistry which provides a Copy method which will be invoked by this method. // ref.TypeRegistry which provides a Copy method which will be invoked by this method.
func (e *Env) Extend(opts ...EnvOption) (*Env, error) { func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
if e.chkErr != nil { chk, chkErr := e.getCheckerOrError()
return nil, e.chkErr if chkErr != nil {
return nil, chkErr
} }
prsrOptsCopy := make([]parser.Option, len(e.prsrOpts))
copy(prsrOptsCopy, e.prsrOpts)
// The type-checker is configured with Declarations. The declarations may either be provided // The type-checker is configured with Declarations. The declarations may either be provided
// as options which have not yet been validated, or may come from a previous checker instance // as options which have not yet been validated, or may come from a previous checker instance
// whose types have already been validated. // whose types have already been validated.
@ -248,10 +256,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
// Copy the declarations if needed. // Copy the declarations if needed.
decsCopy := []*exprpb.Decl{} decsCopy := []*exprpb.Decl{}
if e.chk != nil { if chk != nil {
// If the type-checker has already been instantiated, then the e.declarations have been // If the type-checker has already been instantiated, then the e.declarations have been
// valdiated within the chk instance. // validated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(e.chk)) chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(chk))
} else { } else {
// If the type-checker has not been instantiated, ensure the unvalidated declarations are // If the type-checker has not been instantiated, ensure the unvalidated declarations are
// provided to the extended Env instance. // provided to the extended Env instance.
@ -304,8 +312,11 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.functions { for k, v := range e.functions {
funcsCopy[k] = v funcsCopy[k] = v
} }
libsCopy := make(map[string]bool, len(e.libraries))
for k, v := range e.libraries {
libsCopy[k] = v
}
// TODO: functions copy needs to happen here.
ext := &Env{ ext := &Env{
Container: e.Container, Container: e.Container,
declarations: decsCopy, declarations: decsCopy,
@ -315,8 +326,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
adapter: adapter, adapter: adapter,
features: featuresCopy, features: featuresCopy,
appliedFeatures: appliedFeaturesCopy, appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
provider: provider, provider: provider,
chkOpts: chkOptsCopy, chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
} }
return ext.configure(opts) return ext.configure(opts)
} }
@ -328,6 +341,12 @@ func (e *Env) HasFeature(flag int) bool {
return has && enabled return has && enabled
} }
// HasLibrary returns whether a specific SingletonLibrary has been configured in the environment.
func (e *Env) HasLibrary(libName string) bool {
configured, exists := e.libraries[libName]
return exists && configured
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues. // Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
// //
// This form of Parse creates a Source value for the input `txt` and forwards to the // This form of Parse creates a Source value for the input `txt` and forwards to the
@ -422,8 +441,8 @@ func (e *Env) UnknownVars() interpreter.PartialActivation {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an // TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again. // Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), details.State()) pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(&exprpb.ParsedExpr{Expr: pruned})) expr, err := AstToString(ParsedExprToAst(pruned))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -443,12 +462,12 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and // EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator. // extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator) (checker.CostEstimate, error) { func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked, err := AstToCheckedExpr(ast) checked, err := AstToCheckedExpr(ast)
if err != nil { if err != nil {
return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err) return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err)
} }
return checker.Cost(checked, estimator), nil return checker.Cost(checked, estimator, opts...)
} }
// configure applies a series of EnvOptions to the current environment. // configure applies a series of EnvOptions to the current environment.
@ -464,17 +483,9 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
} }
// If the default UTC timezone fix has been enabled, make sure the library is configured // If the default UTC timezone fix has been enabled, make sure the library is configured
if e.HasFeature(featureDefaultUTCTimeZone) { e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{}))
if _, found := e.appliedFeatures[featureDefaultUTCTimeZone]; !found { if err != nil {
e, err = Lib(timeUTCLibrary{})(e) return nil, err
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[featureDefaultUTCTimeZone] = true
}
} }
// Initialize all of the functions configured within the environment. // Initialize all of the functions configured within the environment.
@ -486,7 +497,10 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
} }
// Configure the parser. // Configure the parser.
prsrOpts := []parser.Option{parser.Macros(e.macros...)} prsrOpts := []parser.Option{}
prsrOpts = append(prsrOpts, e.prsrOpts...)
prsrOpts = append(prsrOpts, parser.Macros(e.macros...))
if e.HasFeature(featureEnableMacroCallTracking) { if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true)) prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
} }
@ -497,7 +511,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
// Ensure that the checker init happens eagerly rather than lazily. // Ensure that the checker init happens eagerly rather than lazily.
if e.HasFeature(featureEagerlyValidateDeclarations) { if e.HasFeature(featureEagerlyValidateDeclarations) {
err := e.initChecker() _, err := e.initChecker()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -506,7 +520,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return e, nil return e, nil
} }
func (e *Env) initChecker() error { func (e *Env) initChecker() (*checker.Env, error) {
e.chkOnce.Do(func() { e.chkOnce.Do(func() {
chkOpts := []checker.Option{} chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...) chkOpts = append(chkOpts, e.chkOpts...)
@ -518,32 +532,68 @@ func (e *Env) initChecker() error {
ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...) ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...)
if err != nil { if err != nil {
e.chkErr = err e.setCheckerOrError(nil, err)
return return
} }
// Add the statically configured declarations. // Add the statically configured declarations.
err = ce.Add(e.declarations...) err = ce.Add(e.declarations...)
if err != nil { if err != nil {
e.chkErr = err e.setCheckerOrError(nil, err)
return return
} }
// Add the function declarations which are derived from the FunctionDecl instances. // Add the function declarations which are derived from the FunctionDecl instances.
for _, fn := range e.functions { for _, fn := range e.functions {
fnDecl, err := functionDeclToExprDecl(fn) fnDecl, err := functionDeclToExprDecl(fn)
if err != nil { if err != nil {
e.chkErr = err e.setCheckerOrError(nil, err)
return return
} }
err = ce.Add(fnDecl) err = ce.Add(fnDecl)
if err != nil { if err != nil {
e.chkErr = err e.setCheckerOrError(nil, err)
return return
} }
} }
// Add function declarations here separately. // Add function declarations here separately.
e.chk = ce e.setCheckerOrError(ce, nil)
}) })
return e.chkErr return e.getCheckerOrError()
}
// setCheckerOrError sets the checker.Env or error state in a concurrency-safe manner
func (e *Env) setCheckerOrError(chk *checker.Env, chkErr error) {
e.chkMutex.Lock()
e.chk = chk
e.chkErr = chkErr
e.chkMutex.Unlock()
}
// getCheckerOrError gets the checker.Env or error state in a concurrency-safe manner
func (e *Env) getCheckerOrError() (*checker.Env, error) {
e.chkMutex.Lock()
defer e.chkMutex.Unlock()
return e.chk, e.chkErr
}
// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies
// the feature if it has not already been enabled.
func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) {
if !e.HasFeature(feature) {
return e, nil
}
_, applied := e.appliedFeatures[feature]
if applied {
return e, nil
}
e, err := option(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[feature] = true
return e, nil
} }
// Issues defines methods for inspecting the error details of parse and check calls. // Issues defines methods for inspecting the error details of parse and check calls.
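
The checker state is now guarded by a mutex, and feature-driven libraries such as the default-UTC-timezone fix are recorded via maybeApplyFeature so they are registered only once even across Extend() calls. A small sketch of extending a configured environment (the variable names are illustrative):

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
)

func main() {
	base, err := cel.NewEnv(cel.Variable("x", cel.IntType))
	if err != nil {
		panic(err)
	}
	// Extend copies parser options, libraries, and previously validated
	// declarations from the parent environment.
	child, err := base.Extend(cel.Variable("y", cel.IntType))
	if err != nil {
		panic(err)
	}
	ast, iss := child.Compile("x + y")
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	fmt.Println("compiled and type-checked:", ast.IsChecked())
}
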

View File

@ -19,14 +19,14 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common" "github.com/google/cel-go/common"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser" "github.com/google/cel-go/parser"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb" anypb "google.golang.org/protobuf/types/known/anypb"
) )

View File

@ -20,10 +20,27 @@ import (
"time" "time"
"github.com/google/cel-go/checker" "github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions" "github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
optMapMacro = "optMap"
hasValueFunc = "hasValue"
optionalNoneFunc = "optional.none"
optionalOfFunc = "optional.of"
optionalOfNonZeroValueFunc = "optional.ofNonZeroValue"
valueFunc = "value"
unusedIterVar = "#unused"
) )
// Library provides a collection of EnvOption and ProgramOption values used to configure a CEL // Library provides a collection of EnvOption and ProgramOption values used to configure a CEL
@ -42,10 +59,27 @@ type Library interface {
ProgramOptions() []ProgramOption ProgramOptions() []ProgramOption
} }
// SingletonLibrary refines the Library interface to ensure that libraries in this format are only
// configured once within the environment.
type SingletonLibrary interface {
Library
// LibraryName provides a namespaced name which is used to check whether the library has already
// been configured in the environment.
LibraryName() string
}
// Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args, // Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args,
// and to be linked to each other. // and to be linked to each other.
func Lib(l Library) EnvOption { func Lib(l Library) EnvOption {
singleton, isSingleton := l.(SingletonLibrary)
return func(e *Env) (*Env, error) { return func(e *Env) (*Env, error) {
if isSingleton {
if e.HasLibrary(singleton.LibraryName()) {
return e, nil
}
e.libraries[singleton.LibraryName()] = true
}
var err error var err error
for _, opt := range l.CompileOptions() { for _, opt := range l.CompileOptions() {
e, err = opt(e) e, err = opt(e)
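
With the new SingletonLibrary interface, Lib() skips re-configuration when a library with the same name has already been applied. A hedged sketch of a custom library opting into that behaviour (the library name and variable are invented for illustration):

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
)

// greetLib is a hypothetical library; LibraryName makes it a SingletonLibrary.
type greetLib struct{}

func (greetLib) LibraryName() string { return "example.lib.greet" }

func (greetLib) CompileOptions() []cel.EnvOption {
	return []cel.EnvOption{cel.Variable("greeting", cel.StringType)}
}

func (greetLib) ProgramOptions() []cel.ProgramOption { return nil }

func main() {
	// Passing the library twice configures it only once, so the duplicate
	// "greeting" declaration is never registered a second time.
	env, err := cel.NewEnv(cel.Lib(greetLib{}), cel.Lib(greetLib{}))
	if err != nil {
		panic(err)
	}
	_, iss := env.Compile(`greeting + "!"`)
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	fmt.Println("library configured once; expression compiles")
}
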
@ -67,6 +101,11 @@ func StdLib() EnvOption {
// features documented in the specification. // features documented in the specification.
type stdLibrary struct{} type stdLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (stdLibrary) LibraryName() string {
return "cel.lib.std"
}
// EnvOptions returns options for the standard CEL function declarations and macros. // EnvOptions returns options for the standard CEL function declarations and macros.
func (stdLibrary) CompileOptions() []EnvOption { func (stdLibrary) CompileOptions() []EnvOption {
return []EnvOption{ return []EnvOption{
@ -82,6 +121,225 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
} }
} }
type optionalLibrary struct{}
// LibraryName implements the SingletonLibrary interface method.
func (optionalLibrary) LibraryName() string {
return "cel.lib.optional"
}
// CompileOptions implements the Library interface method.
func (optionalLibrary) CompileOptions() []EnvOption {
paramTypeK := TypeParamType("K")
paramTypeV := TypeParamType("V")
optionalTypeV := OptionalType(paramTypeV)
listTypeV := ListType(paramTypeV)
mapTypeKV := MapType(paramTypeK, paramTypeV)
return []EnvOption{
// Enable the optional syntax in the parser.
enableOptionalSyntax(),
// Introduce the optional type.
Types(types.OptionalType),
// Configure the optMap macro.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
Overload("optional_of", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
return types.OptionalOf(value)
}))),
Function(optionalOfNonZeroValueFunc,
Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
v, isZeroer := value.(traits.Zeroer)
if !isZeroer || !v.IsZeroValue() {
return types.OptionalOf(value)
}
return types.OptionalNone
}))),
Function(optionalNoneFunc,
Overload("optional_none", []*Type{}, optionalTypeV,
FunctionBinding(func(values ...ref.Val) ref.Val {
return types.OptionalNone
}))),
Function(valueFunc,
MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return opt.GetValue()
}))),
Function(hasValueFunc,
MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return types.Bool(opt.HasValue())
}))),
// Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the
// evaluation chain.
Function("or",
MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)),
Function("orValue",
MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)),
// OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the
// optput type.
// output type.
Function(operators.OptSelect,
Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)),
// OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use
// these signatures to determine type-agreement without any special handling.
Function(operators.OptIndex,
Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV),
Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV),
Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
// Index overloads to accommodate using an optional value as the operand.
Function(operators.Index,
Overload("optional_list_index_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("optional_map_index_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, &common.Error{
Message: "optMap() variable name must be a simple identifier",
Location: meh.OffsetLocation(varIdent.GetId()),
}
}
mapExpr := args[1]
return meh.GlobalCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
), nil
}
// ProgramOptions implements the Library interface method.
func (optionalLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
}
}
func enableOptionalSyntax() EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.EnableOptionalSyntax(true))
return e, nil
}
}
func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) {
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
}
args := call.Args()
if len(args) != 2 {
return i, nil
}
switch call.Function() {
case "or":
if call.OverloadID() != "" && call.OverloadID() != "optional_or_optional" {
return i, nil
}
return &evalOptionalOr{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
case "orValue":
if call.OverloadID() != "" && call.OverloadID() != "optional_orValue_value" {
return i, nil
}
return &evalOptionalOrValue{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
default:
return i, nil
}
}
// evalOptionalOr selects between two optional values, either the first if it has a value, or
// the second optional expression is evaluated and returned.
type evalOptionalOr struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOr) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal
}
return opt.rhs.Eval(ctx)
}
// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
// its value is returned, otherwise the alternative value expression is evaluated and returned.
type evalOptionalOrValue struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOrValue) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal.GetValue()
}
return opt.rhs.Eval(ctx)
}
type timeUTCLibrary struct{} type timeUTCLibrary struct{}
func (timeUTCLibrary) CompileOptions() []EnvOption { func (timeUTCLibrary) CompileOptions() []EnvOption {

View File

@ -17,6 +17,7 @@ package cel
import ( import (
"github.com/google/cel-go/common" "github.com/google/cel-go/common"
"github.com/google/cel-go/parser" "github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
) )
@ -26,8 +27,11 @@ import (
// a Macro should be created per arg-count or as a var arg macro. // a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree, or an error // MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
// if the input arguments are not suitable for the expansion requirements for the macro in question. //
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
// are not well-formed, the result of the expansion will be an error.
// //
// The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call // The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call
// and produces as output an Expr ast node. // and produces as output an Expr ast node.
@ -81,8 +85,10 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// input to produce an output list. // input to produce an output list.
// //
// There are two call patterns supported by map: // There are two call patterns supported by map:
// <iterRange>.map(<iterVar>, <transform>) //
// <iterRange>.map(<iterVar>, <predicate>, <transform>) // <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// In the second form only iterVar values which return true when provided to the predicate expression // In the second form only iterVar values which return true when provided to the predicate expression
// are transformed. // are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
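
Both call patterns expand through the same macro; the three-argument form filters with the predicate before transforming. A quick sketch (the list variable is invented for illustration):

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
)

func main() {
	env, err := cel.NewEnv(cel.Variable("nums", cel.ListType(cel.IntType)))
	if err != nil {
		panic(err)
	}
	// Three-argument form: keep even numbers, then square them.
	ast, iss := env.Compile("nums.map(n, n % 2 == 0, n * n)")
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	prg, err := env.Program(ast)
	if err != nil {
		panic(err)
	}
	out, _, err := prg.Eval(map[string]any{"nums": []int{1, 2, 3, 4}})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // squares of the even elements: 4 and 16
}
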

View File

@ -29,6 +29,7 @@ import (
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter" "github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions" "github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb" descpb "google.golang.org/protobuf/types/descriptorpb"
@ -61,6 +62,10 @@ const (
// on a CEL timestamp operation. This fixes the scenario where the input time // on a CEL timestamp operation. This fixes the scenario where the input time
// is not already in UTC. // is not already in UTC.
featureDefaultUTCTimeZone featureDefaultUTCTimeZone
// Enable the use of optional types in the syntax, type-system, type-checking,
// and runtime.
featureOptionalTypes
) )
// EnvOption is a functional interface for configuring the environment. // EnvOption is a functional interface for configuring the environment.
@ -163,19 +168,19 @@ func Container(name string) EnvOption {
// Abbreviations can be useful when working with variables, functions, and especially types from // Abbreviations can be useful when working with variables, functions, and especially types from
// multiple namespaces: // multiple namespaces:
// //
// // CEL object construction // // CEL object construction
// qual.pkg.version.ObjTypeName{ // qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...} // field: alt.container.ver.FieldTypeName{value: ...}
// } // }
// //
// Only one the qualified names above may be used as the CEL container, so at least one of these // Only one the qualified names above may be used as the CEL container, so at least one of these
// references must be a long qualified name within an otherwise short CEL program. Using the // references must be a long qualified name within an otherwise short CEL program. Using the
// following abbreviations, the program becomes much simpler: // following abbreviations, the program becomes much simpler:
// //
// // CEL Go option // // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName") // Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction // // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}} // ObjTypeName{field: FieldTypeName{value: ...}}
// //
// There are a few rules for the qualified names and the simple abbreviations generated from them: // There are a few rules for the qualified names and the simple abbreviations generated from them:
// - Qualified names must be dot-delimited, e.g. `package.subpkg.name`. // - Qualified names must be dot-delimited, e.g. `package.subpkg.name`.
@ -188,9 +193,12 @@ func Container(name string) EnvOption {
// - Expanded abbreviations do not participate in namespace resolution. // - Expanded abbreviations do not participate in namespace resolution.
// - Abbreviation expansion is done instead of the container search for a matching identifier. // - Abbreviation expansion is done instead of the container search for a matching identifier.
// - Containers follow C++ namespace resolution rules with searches from the most qualified name // - Containers follow C++ namespace resolution rules with searches from the most qualified name
// to the least qualified name. //
// to the least qualified name.
//
// - Container references within the CEL program may be relative, and are resolved to fully // - Container references within the CEL program may be relative, and are resolved to fully
// qualified names at either type-check time or program plan time, whichever comes first. //
// qualified names at either type-check time or program plan time, whichever comes first.
// //
// If there is ever a case where an identifier could be in both the container and as an // If there is ever a case where an identifier could be in both the container and as an
// abbreviation, the abbreviation wins as this will ensure that the meaning of a program is // abbreviation, the abbreviation wins as this will ensure that the meaning of a program is
@ -216,7 +224,7 @@ func Abbrevs(qualifiedNames ...string) EnvOption {
// environment by default. // environment by default.
// //
// Note: This option must be specified after the CustomTypeProvider option when used together. // Note: This option must be specified after the CustomTypeProvider option when used together.
func Types(addTypes ...interface{}) EnvOption { func Types(addTypes ...any) EnvOption {
return func(e *Env) (*Env, error) { return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry) reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg { if !isReg {
@ -253,7 +261,7 @@ func Types(addTypes ...interface{}) EnvOption {
// //
// TypeDescs are hermetic to a single Env object, but may be copied to other Env values via // TypeDescs are hermetic to a single Env object, but may be copied to other Env values via
// extension or by re-using the same EnvOption with another NewEnv() call. // extension or by re-using the same EnvOption with another NewEnv() call.
func TypeDescs(descs ...interface{}) EnvOption { func TypeDescs(descs ...any) EnvOption {
return func(e *Env) (*Env, error) { return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry) reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg { if !isReg {
@ -350,8 +358,8 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// variables with the same name provided to the Eval() call. If Globals is used in a Library with // variables with the same name provided to the Eval() call. If Globals is used in a Library with
// a Lib EnvOption, vars may shadow variables provided by previously added libraries. // a Lib EnvOption, vars may shadow variables provided by previously added libraries.
// //
// The vars value may either be an `interpreter.Activation` instance or a `map[string]interface{}`. // The vars value may either be an `interpreter.Activation` instance or a `map[string]any`.
func Globals(vars interface{}) ProgramOption { func Globals(vars any) ProgramOption {
return func(p *prog) (*prog, error) { return func(p *prog) (*prog, error) {
defaultVars, err := interpreter.NewActivation(vars) defaultVars, err := interpreter.NewActivation(vars)
if err != nil { if err != nil {
@ -404,6 +412,9 @@ const (
// OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails // OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails
// cost calculation is available via func ActualCost() // cost calculation is available via func ActualCost()
OptTrackCost EvalOption = 1 << iota OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
OptCheckStringFormat EvalOption = 1 << iota
) )
// EvalOptions sets one or more evaluation options which may affect the evaluation or Result. // EvalOptions sets one or more evaluation options which may affect the evaluation or Result.
@ -534,6 +545,13 @@ func DefaultUTCTimeZone(enabled bool) EnvOption {
return features(featureDefaultUTCTimeZone, enabled) return features(featureDefaultUTCTimeZone, enabled)
} }
// OptionalTypes enable support for optional syntax and types in CEL. The optional value type makes
// it possible to express whether variables have been provided, whether a result has been computed,
// and in the future whether an object field path, map key value, or list index has a value.
func OptionalTypes() EnvOption {
return Lib(optionalLibrary{})
}
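
A minimal sketch of enabling the new optional support; it relies only on functions declared by the optional library above:

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
)

func main() {
	env, err := cel.NewEnv(cel.OptionalTypes())
	if err != nil {
		panic(err)
	}
	// ofNonZeroValue("") yields an absent optional, so orValue supplies the fallback.
	ast, iss := env.Compile(`optional.ofNonZeroValue("").orValue("anonymous")`)
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	prg, err := env.Program(ast)
	if err != nil {
		panic(err)
	}
	out, _, err := prg.Eval(cel.NoVars())
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // anonymous
}
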
// features sets the given feature flags. See list of Feature constants above. // features sets the given feature flags. See list of Feature constants above.
func features(flag int, enabled bool) EnvOption { func features(flag int, enabled bool) EnvOption {
return func(e *Env) (*Env, error) { return func(e *Env) (*Env, error) {
@ -541,3 +559,21 @@ func features(flag int, enabled bool) EnvOption {
return e, nil return e, nil
} }
} }
// ParserRecursionLimit adjusts the AST depth the parser will tolerate.
// Defaults defined in the parser package.
func ParserRecursionLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.MaxRecursionDepth(limit))
return e, nil
}
}
// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse.
// Defaults defined in the parser package.
func ParserExpressionSizeLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.ExpressionSizeCodePointLimit(limit))
return e, nil
}
}
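
Both helpers simply forward to the corresponding parser options; a brief sketch of wiring them in (the limit values are arbitrary):

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
)

func main() {
	env, err := cel.NewEnv(
		// Cap how deeply nested an expression may be and how many code points it may contain.
		cel.ParserRecursionLimit(32),
		cel.ParserExpressionSizeLimit(4096),
	)
	if err != nil {
		panic(err)
	}
	_, iss := env.Compile("1 + 2 + 3")
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	fmt.Println("parsed within the configured limits")
}
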

View File

@ -17,21 +17,20 @@ package cel
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"sync" "sync"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter" "github.com/google/cel-go/interpreter"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
) )
// Program is an evaluable view of an Ast. // Program is an evaluable view of an Ast.
type Program interface { type Program interface {
// Eval returns the result of an evaluation of the Ast and environment against the input vars. // Eval returns the result of an evaluation of the Ast and environment against the input vars.
// //
// The vars value may either be an `interpreter.Activation` or a `map[string]interface{}`. // The vars value may either be an `interpreter.Activation` or a `map[string]any`.
// //
// If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will // If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will
// be non-nil. Given this caveat on `details`, the return state from evaluation will be: // be non-nil. Given this caveat on `details`, the return state from evaluation will be:
@ -43,16 +42,16 @@ type Program interface {
// An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption` // An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption`
// or `ProgramOption` values used in the creation of the evaluation environment or executable // or `ProgramOption` values used in the creation of the evaluation environment or executable
// program. // program.
Eval(interface{}) (ref.Val, *EvalDetails, error) Eval(any) (ref.Val, *EvalDetails, error)
// ContextEval evaluates the program with a set of input variables and a context object in order // ContextEval evaluates the program with a set of input variables and a context object in order
// to support cancellation and timeouts. This method must be used in conjunction with the // to support cancellation and timeouts. This method must be used in conjunction with the
// InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation. // InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation.
// //
// The vars value may either be an `interpreter.Activation` or `map[string]interface{}`. // The vars value may either be an `interpreter.Activation` or `map[string]any`.
// //
// The output contract for `ContextEval` is otherwise identical to the `Eval` method. // The output contract for `ContextEval` is otherwise identical to the `Eval` method.
ContextEval(context.Context, interface{}) (ref.Val, *EvalDetails, error) ContextEval(context.Context, any) (ref.Val, *EvalDetails, error)
} }
// NoVars returns an empty Activation. // NoVars returns an empty Activation.
@ -65,7 +64,7 @@ func NoVars() interpreter.Activation {
// //
// The `vars` value may either be an interpreter.Activation or any valid input to the // The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call. // interpreter.NewActivation call.
func PartialVars(vars interface{}, func PartialVars(vars any,
unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) { unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) {
return interpreter.NewPartialActivation(vars, unknowns...) return interpreter.NewPartialActivation(vars, unknowns...)
} }
@ -207,6 +206,37 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 { if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...)) decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
} }
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...*types.TypeValue) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
t, err := ExprTypeToType(ast.typeMap[id])
if err != nil {
return false, err
}
if t.kind == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if k == t.kind {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory. // Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 { if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
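
This block only runs when OptCheckStringFormat is set on the program. A hedged sketch of opting in, assuming the string.format function comes from the cel-go ext.Strings() extension library (which is not part of this diff):

package main

import (
	"fmt"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/ext"
)

func main() {
	env, err := cel.NewEnv(ext.Strings())
	if err != nil {
		panic(err)
	}
	ast, iss := env.Compile(`"replicas: %d".format([3])`)
	if iss != nil && iss.Err() != nil {
		panic(iss.Err())
	}
	// With OptCheckStringFormat the format string and argument list are
	// validated while the program is planned, not at eval time.
	prg, err := env.Program(ast, cel.EvalOptions(cel.OptCheckStringFormat))
	if err != nil {
		panic(err)
	}
	out, _, err := prg.Eval(cel.NoVars())
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // replicas: 3
}
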
@ -268,7 +298,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
} }
// Eval implements the Program interface method. // Eval implements the Program interface method.
func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error) { func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named // Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery // return values makes it possible to modify the error response during the recovery
// function. // function.
@ -287,11 +317,11 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
switch v := input.(type) { switch v := input.(type) {
case interpreter.Activation: case interpreter.Activation:
vars = v vars = v
case map[string]interface{}: case map[string]any:
vars = activationPool.Setup(v) vars = activationPool.Setup(v)
defer activationPool.Put(vars) defer activationPool.Put(vars)
default: default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input) return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
} }
if p.defaultVars != nil { if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars) vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
@ -307,7 +337,7 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
} }
// ContextEval implements the Program interface. // ContextEval implements the Program interface.
func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) { func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil { if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil") return nil, nil, fmt.Errorf("context can not be nil")
} }
@ -318,22 +348,17 @@ func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *Ev
case interpreter.Activation: case interpreter.Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency) vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars) defer ctxActivationPool.Put(vars)
case map[string]interface{}: case map[string]any:
rawVars := activationPool.Setup(v) rawVars := activationPool.Setup(v)
defer activationPool.Put(rawVars) defer activationPool.Put(rawVars)
vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency) vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars) defer ctxActivationPool.Put(vars)
default: default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input) return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
} }
return p.Eval(vars) return p.Eval(vars)
} }
// Cost implements the Coster interface method.
func (p *prog) Cost() (min, max int64) {
return estimateCost(p.interpretable)
}
// progFactory is a helper alias for marking a program creation factory function. // progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error) type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error)
@ -354,7 +379,7 @@ func newProgGen(factory progFactory) (Program, error) {
} }
// Eval implements the Program interface method. // Eval implements the Program interface method.
func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) { func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a // The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful // new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results. // results.
@ -379,7 +404,7 @@ func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
} }
// ContextEval implements the Program interface method. // ContextEval implements the Program interface method.
func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) { func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil { if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil") return nil, nil, fmt.Errorf("context can not be nil")
} }
@ -406,29 +431,6 @@ func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val
return v, det, nil return v, det, nil
} }
// Cost implements the Coster interface method.
func (gen *progGen) Cost() (min, max int64) {
// Use an empty state value since no evaluation is performed.
p, err := gen.factory(emptyEvalState, nil)
if err != nil {
return 0, math.MaxInt64
}
return estimateCost(p)
}
// EstimateCost returns the heuristic cost interval for the program.
func EstimateCost(p Program) (min, max int64) {
return estimateCost(p)
}
func estimateCost(i interface{}) (min, max int64) {
c, ok := i.(interpreter.Coster)
if !ok {
return 0, math.MaxInt64
}
return c.Cost()
}
type ctxEvalActivation struct { type ctxEvalActivation struct {
parent interpreter.Activation parent interpreter.Activation
interrupt <-chan struct{} interrupt <-chan struct{}
@ -438,7 +440,7 @@ type ctxEvalActivation struct {
// ResolveName implements the Activation interface method, but adds a special #interrupted variable // ResolveName implements the Activation interface method, but adds a special #interrupted variable
// which is capable of testing whether a 'done' signal is provided from a context.Context channel. // which is capable of testing whether a 'done' signal is provided from a context.Context channel.
func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) { func (a *ctxEvalActivation) ResolveName(name string) (any, bool) {
if name == "#interrupted" { if name == "#interrupted" {
a.interruptCheckCount++ a.interruptCheckCount++
if a.interruptCheckCount%a.interruptCheckFrequency == 0 { if a.interruptCheckCount%a.interruptCheckFrequency == 0 {
@ -461,7 +463,7 @@ func (a *ctxEvalActivation) Parent() interpreter.Activation {
func newCtxEvalActivationPool() *ctxEvalActivationPool { func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{ return &ctxEvalActivationPool{
Pool: sync.Pool{ Pool: sync.Pool{
New: func() interface{} { New: func() any {
return &ctxEvalActivation{} return &ctxEvalActivation{}
}, },
}, },
@ -483,21 +485,21 @@ func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan s
} }
type evalActivation struct { type evalActivation struct {
vars map[string]interface{} vars map[string]any
lazyVars map[string]interface{} lazyVars map[string]any
} }
// ResolveName looks up the value of the input variable name, if found. // ResolveName looks up the value of the input variable name, if found.
// //
// Lazy bindings may be supplied within the map-based input in either of the following forms: // Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{} // - func() any
// - func() ref.Val // - func() ref.Val
// //
// The lazy binding will only be invoked once per evaluation. // The lazy binding will only be invoked once per evaluation.
// //
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using // Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment. // the ref.TypeAdapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (interface{}, bool) { func (a *evalActivation) ResolveName(name string) (any, bool) {
v, found := a.vars[name] v, found := a.vars[name]
if !found { if !found {
return nil, false return nil, false
@ -510,7 +512,7 @@ func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
lazy := obj() lazy := obj()
a.lazyVars[name] = lazy a.lazyVars[name] = lazy
return lazy, true return lazy, true
case func() interface{}: case func() any:
if resolved, found := a.lazyVars[name]; found { if resolved, found := a.lazyVars[name]; found {
return resolved, true return resolved, true
} }
@ -530,8 +532,8 @@ func (a *evalActivation) Parent() interpreter.Activation {
func newEvalActivationPool() *evalActivationPool { func newEvalActivationPool() *evalActivationPool {
return &evalActivationPool{ return &evalActivationPool{
Pool: sync.Pool{ Pool: sync.Pool{
New: func() interface{} { New: func() any {
return &evalActivation{lazyVars: make(map[string]interface{})} return &evalActivation{lazyVars: make(map[string]any)}
}, },
}, },
} }
@ -542,13 +544,13 @@ type evalActivationPool struct {
} }
// Setup initializes a pooled Activation object with the map input. // Setup initializes a pooled Activation object with the map input.
func (p *evalActivationPool) Setup(vars map[string]interface{}) *evalActivation { func (p *evalActivationPool) Setup(vars map[string]any) *evalActivation {
a := p.Pool.Get().(*evalActivation) a := p.Pool.Get().(*evalActivation)
a.vars = vars a.vars = vars
return a return a
} }
func (p *evalActivationPool) Put(value interface{}) { func (p *evalActivationPool) Put(value any) {
a := value.(*evalActivation) a := value.(*evalActivation)
for k := range a.lazyVars { for k := range a.lazyVars {
delete(a.lazyVars, k) delete(a.lazyVars, k)
@ -559,7 +561,7 @@ func (p *evalActivationPool) Put(value interface{}) {
var ( var (
emptyEvalState = interpreter.NewEvalState() emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]interface{} inputs // activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool() activationPool = newEvalActivationPool()
// ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable // ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable
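
Note: the evalActivation comment above describes lazy bindings that are resolved at most once per evaluation. A minimal sketch of supplying one through the ordinary map input; the environment, variable name, and expression are hypothetical, and the top-level cel package API (cel.NewEnv, cel.Variable, Program.Eval) is assumed:

// assumes: import (
//	"github.com/google/cel-go/cel"
//	"github.com/google/cel-go/common/types"
//	"github.com/google/cel-go/common/types/ref"
// )
func evalWithLazyBinding() (ref.Val, error) {
	env, err := cel.NewEnv(cel.Variable("payload", cel.DynType))
	if err != nil {
		return nil, err
	}
	ast, iss := env.Compile(`payload == "expensive"`)
	if iss.Err() != nil {
		return nil, iss.Err()
	}
	prg, err := env.Program(ast)
	if err != nil {
		return nil, err
	}
	// The closure form (func() ref.Val) is only invoked if "payload" is actually
	// resolved, and its result is memoized for the rest of this Eval call.
	out, _, err := prg.Eval(map[string]any{
		"payload": func() ref.Val { return types.String("expensive") },
	})
	return out, err
}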


@ -30,7 +30,7 @@ go_library(
"//common/types/pb:go_default_library", "//common/types/pb:go_default_library",
"//common/types/ref:go_default_library", "//common/types/ref:go_default_library",
"//parser:go_default_library", "//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library", "@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library",
@ -54,7 +54,7 @@ go_test(
"//test:go_default_library", "//test:go_default_library",
"//test/proto2pb:go_default_library", "//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library", "//test/proto3pb:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library", "@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",
], ],
) )


@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/checker/decls" "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common" "github.com/google/cel-go/common"
"github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -173,8 +174,8 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Rewrite the node to be a variable reference to the resolved fully-qualified // Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name. // variable name.
c.setType(e, ident.GetIdent().Type) c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().Value)) c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
identName := ident.GetName() identName := ident.GetName()
e.ExprKind = &exprpb.Expr_IdentExpr{ e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{ IdentExpr: &exprpb.Expr_Ident{
@ -185,9 +186,37 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
} }
} }
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType = decls.Bool
}
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkOptSelect(e *exprpb.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.GetCallExpr()
operand := call.GetArgs()[0]
field := call.GetArgs()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.ReportError(c.location(field), "unsupported optional field selection: %v", field)
return
}
// Perform type-checking using the field selection logic.
resultType := c.checkSelectField(e, operand, fieldName, true)
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *exprpb.Type {
// Interpret as field selection, first traversing down the operand. // Interpret as field selection, first traversing down the operand.
c.check(sel.GetOperand()) c.check(operand)
targetType := substitute(c.mappings, c.getType(sel.GetOperand()), false) operandType := substitute(c.mappings, c.getType(operand), false)
// If the target type is 'optional', unwrap it for the sake of this check.
targetType, isOpt := maybeUnwrapOptional(operandType)
// Assume error type by default as most types do not support field selection. // Assume error type by default as most types do not support field selection.
resultType := decls.Error resultType := decls.Error
switch kindOf(targetType) { switch kindOf(targetType) {
@ -199,7 +228,7 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Objects yield their field type declaration as the selection result type, but only if // Objects yield their field type declaration as the selection result type, but only if
// the field is defined. // the field is defined.
messageType := targetType messageType := targetType
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), sel.GetField()); found { if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), field); found {
resultType = fieldType.Type resultType = fieldType.Type
} }
case kindTypeParam: case kindTypeParam:
@ -212,16 +241,17 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
default: default:
// Dynamic / error values are treated as DYN type. Errors are handled this way as well // Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check. // in order to allow forward progress on the check.
if isDynOrError(targetType) { if !isDynOrError(targetType) {
resultType = decls.Dyn
} else {
c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType) c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType)
} }
resultType = decls.Dyn
} }
if sel.TestOnly {
resultType = decls.Bool // If the target type was optional coming in, then the result must be optional going out.
if isOpt || optional {
return decls.NewOptionalType(resultType)
} }
c.setType(e, substitute(c.mappings, resultType, false)) return resultType
} }
func (c *checker) checkCall(e *exprpb.Expr) { func (c *checker) checkCall(e *exprpb.Expr) {
@ -229,15 +259,19 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// please consider the impact on planner.go and consolidate implementations or mirror code // please consider the impact on planner.go and consolidate implementations or mirror code
// as appropriate. // as appropriate.
call := e.GetCallExpr() call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
fnName := call.GetFunction() fnName := call.GetFunction()
if fnName == operators.OptSelect {
c.checkOptSelect(e)
return
}
args := call.GetArgs()
// Traverse arguments. // Traverse arguments.
for _, arg := range args { for _, arg := range args {
c.check(arg) c.check(arg)
} }
target := call.GetTarget()
// Regular static call with simple name. // Regular static call with simple name.
if target == nil { if target == nil {
// Check for the existence of the function. // Check for the existence of the function.
@ -359,6 +393,9 @@ func (c *checker) resolveOverload(
} }
if resultType == nil { if resultType == nil {
for i, arg := range argTypes {
argTypes[i] = substitute(c.mappings, arg, true)
}
c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil) c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil)
resultType = decls.Error resultType = decls.Error
return nil return nil
@ -369,16 +406,29 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) { func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr() create := e.GetListExpr()
var elemType *exprpb.Type var elemsType *exprpb.Type
for _, e := range create.GetElements() { optionalIndices := create.GetOptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
optionals[optInd] = true
}
for i, e := range create.GetElements() {
c.check(e) c.check(e)
elemType = c.joinTypes(c.location(e), elemType, c.getType(e)) elemType := c.getType(e)
if optionals[int32(i)] {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(c.location(e), decls.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(c.location(e), elemsType, elemType)
} }
if elemType == nil { if elemsType == nil {
// If the list is empty, assign free type var to elem type. // If the list is empty, assign free type var to elem type.
elemType = c.newTypeVar() elemsType = c.newTypeVar()
} }
c.setType(e, decls.NewListType(elemType)) c.setType(e, decls.NewListType(elemsType))
} }
func (c *checker) checkCreateStruct(e *exprpb.Expr) { func (c *checker) checkCreateStruct(e *exprpb.Expr) {
@ -392,22 +442,31 @@ func (c *checker) checkCreateStruct(e *exprpb.Expr) {
func (c *checker) checkCreateMap(e *exprpb.Expr) { func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr() mapVal := e.GetStructExpr()
var keyType *exprpb.Type var mapKeyType *exprpb.Type
var valueType *exprpb.Type var mapValueType *exprpb.Type
for _, ent := range mapVal.GetEntries() { for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey() key := ent.GetMapKey()
c.check(key) c.check(key)
keyType = c.joinTypes(c.location(key), keyType, c.getType(key)) mapKeyType = c.joinTypes(c.location(key), mapKeyType, c.getType(key))
c.check(ent.Value) val := ent.GetValue()
valueType = c.joinTypes(c.location(ent.Value), valueType, c.getType(ent.Value)) c.check(val)
valType := c.getType(val)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(val), decls.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(c.location(val), mapValueType, valType)
} }
if keyType == nil { if mapKeyType == nil {
// If the map is empty, assign free type variables to typeKey and value type. // If the map is empty, assign free type variables to typeKey and value type.
keyType = c.newTypeVar() mapKeyType = c.newTypeVar()
valueType = c.newTypeVar() mapValueType = c.newTypeVar()
} }
c.setType(e, decls.NewMapType(keyType, valueType)) c.setType(e, decls.NewMapType(mapKeyType, mapValueType))
} }
func (c *checker) checkCreateMessage(e *exprpb.Expr) { func (c *checker) checkCreateMessage(e *exprpb.Expr) {
@ -449,15 +508,21 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
c.check(value) c.check(value)
fieldType := decls.Error fieldType := decls.Error
if t, found := c.lookupFieldType( ft, found := c.lookupFieldType(c.locationByID(ent.GetId()), messageType.GetMessageType(), field)
c.locationByID(ent.GetId()), if found {
messageType.GetMessageType(), fieldType = ft.Type
field); found {
fieldType = t.Type
} }
if !c.isAssignable(fieldType, c.getType(value)) {
c.errors.fieldTypeMismatch( valType := c.getType(value)
c.locationByID(ent.Id), field, fieldType, c.getType(value)) if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(c.location(value), decls.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(c.locationByID(ent.Id), field, fieldType, valType)
} }
} }
} }
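
Note: the checkOptSelect/checkSelectField changes above introduce one typing rule worth calling out: if the operand is optional(T), or the selection uses the optional form, the field's type is wrapped back into optional(...). A simplified, hypothetical paraphrase of that rule (helper names are not the checker's actual identifiers):

// assumes: import (
//	"github.com/google/cel-go/checker/decls"
//	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
// )
func unwrapOptionalSketch(t *exprpb.Type) (*exprpb.Type, bool) {
	if at := t.GetAbstractType(); at != nil && at.GetName() == "optional" && len(at.GetParameterTypes()) == 1 {
		return at.GetParameterTypes()[0], true
	}
	return t, false
}

func selectResultTypeSketch(operandType, fieldType *exprpb.Type, optionalSelect bool) *exprpb.Type {
	_, wasOptional := unwrapOptionalSketch(operandType)
	// Field resolution against the unwrapped operand type is elided here.
	if wasOptional || optionalSelect {
		return decls.NewOptionalType(fieldType)
	}
	return fieldType
}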


@ -92,7 +92,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
case *exprpb.Expr_ConstExpr: case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) { switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue: case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue)) // converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue: case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue)) v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue, case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
@ -258,6 +261,8 @@ type coster struct {
computedSizes map[int64]SizeEstimate computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr checkedExpr *exprpb.CheckedExpr
estimator CostEstimator estimator CostEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
} }
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names. // Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
@ -280,16 +285,39 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
return 0, false return 0, false
} }
// Cost estimates the cost of the parsed and type checked CEL expression. // CostOption configures flags which affect cost computations.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator) CostEstimate { type CostOption func(*coster) error
c := coster{
checkedExpr: checker, // PresenceTestHasCost determines whether presence testing has a cost of one or zero.
estimator: estimator, // Defaults to presence test has a cost of one.
exprPath: map[int64][]string{}, func PresenceTestHasCost(hasCost bool) CostOption {
iterRanges: map[string][]int64{}, return func(c *coster) error {
computedSizes: map[int64]SizeEstimate{}, if hasCost {
c.presenceTestCost = selectAndIdentCost
return nil
}
c.presenceTestCost = CostEstimate{Min: 0, Max: 0}
return nil
} }
return c.cost(checker.GetExpr()) }
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
if err != nil {
return CostEstimate{}, err
}
}
return c.cost(checker.GetExpr()), nil
} }
func (c *coster) cost(e *exprpb.Expr) CostEstimate { func (c *coster) cost(e *exprpb.Expr) CostEstimate {
@ -340,6 +368,12 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr() sel := e.GetSelectExpr()
var sum CostEstimate var sum CostEstimate
if sel.GetTestOnly() { if sel.GetTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
return sum return sum
} }
sum = sum.Add(c.cost(sel.GetOperand())) sum = sum.Add(c.cost(sel.GetOperand()))
@ -503,7 +537,10 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
} }
switch overloadID { switch overloadID {
// O(n) functions // O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString: case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
if overloadID == overloads.ExtFormatString {
return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
if len(args) == 1 { if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
} }
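
Note: two behavioral points from the cost changes above: string constants are now sized by rune count (len([]rune(s))) to match the runtime, and has()-style presence tests cost one unit by default, configurable through the new CostOption. A sketch of the updated Cost call; the checked expression and estimator are placeholders:

// assumes: import "github.com/google/cel-go/checker"
// and exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
func estimateWithFreePresenceTests(checked *exprpb.CheckedExpr, est checker.CostEstimator) (checker.CostEstimate, error) {
	// PresenceTestHasCost(false) restores the older behavior where has() adds no cost.
	return checker.Cost(checked, est, checker.PresenceTestHasCost(false))
}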


@ -13,7 +13,7 @@ go_library(
], ],
importpath = "github.com/google/cel-go/checker/decls", importpath = "github.com/google/cel-go/checker/decls",
deps = [ deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library", "@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library",
], ],


@ -16,9 +16,9 @@
package decls package decls
import ( import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
emptypb "google.golang.org/protobuf/types/known/emptypb" emptypb "google.golang.org/protobuf/types/known/emptypb"
structpb "google.golang.org/protobuf/types/known/structpb" structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
) )
var ( var (
@ -64,6 +64,12 @@ func NewAbstractType(name string, paramTypes ...*exprpb.Type) *exprpb.Type {
ParameterTypes: paramTypes}}} ParameterTypes: paramTypes}}}
} }
// NewOptionalType constructs an abstract type indicating that the parameterized type
// may be contained within the object.
func NewOptionalType(paramType *exprpb.Type) *exprpb.Type {
return NewAbstractType("optional", paramType)
}
// NewFunctionType creates a function invocation contract, typically only used // NewFunctionType creates a function invocation contract, typically only used
// by type-checking steps after overload resolution. // by type-checking steps after overload resolution.
func NewFunctionType(resultType *exprpb.Type, func NewFunctionType(resultType *exprpb.Type,
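
Note: NewOptionalType above is just an abstract type named "optional" with a single parameter, so for example (assuming the existing decls helpers):

// assumes: import "github.com/google/cel-go/checker/decls"
var optInt = decls.NewOptionalType(decls.Int) // equivalent to decls.NewAbstractType("optional", decls.Int)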


@ -226,7 +226,7 @@ func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
newOverloads := []*exprpb.Decl_FunctionDecl_Overload{} newOverloads := []*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range overloads { for _, overload := range overloads {
existing, found := existingOverloads[overload.GetOverloadId()] existing, found := existingOverloads[overload.GetOverloadId()]
if !found || !proto.Equal(existing, overload) { if !found || !overloadsEqual(existing, overload) {
newOverloads = append(newOverloads, overload) newOverloads = append(newOverloads, overload)
} }
} }
@ -264,6 +264,31 @@ func (e *Env) isOverloadDisabled(overloadID string) bool {
return found return found
} }
// overloadsEqual returns whether two overloads have identical signatures.
//
// type parameter names are ignored as they may be specified in any order and have no bearing on overload
// equivalence
func overloadsEqual(o1, o2 *exprpb.Decl_FunctionDecl_Overload) bool {
return o1.GetOverloadId() == o2.GetOverloadId() &&
o1.GetIsInstanceFunction() == o2.GetIsInstanceFunction() &&
paramsEqual(o1.GetParams(), o2.GetParams()) &&
proto.Equal(o1.GetResultType(), o2.GetResultType())
}
// paramsEqual returns whether two lists have equal length and all types are equal
func paramsEqual(p1, p2 []*exprpb.Type) bool {
if len(p1) != len(p2) {
return false
}
for i, a := range p1 {
b := p2[i]
if !proto.Equal(a, b) {
return false
}
}
return true
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent // sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances. // CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl { func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {


@ -26,7 +26,7 @@ type semanticAdorner struct {
var _ debug.Adorner = &semanticAdorner{} var _ debug.Adorner = &semanticAdorner{}
func (a *semanticAdorner) GetMetadata(elem interface{}) string { func (a *semanticAdorner) GetMetadata(elem any) string {
result := "" result := ""
e, isExpr := elem.(*exprpb.Expr) e, isExpr := elem.(*exprpb.Expr)
if !isExpr { if !isExpr {


@ -287,6 +287,8 @@ func init() {
decls.NewInstanceOverload(overloads.EndsWithString, decls.NewInstanceOverload(overloads.EndsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)), []*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.Matches, decls.NewFunction(overloads.Matches,
decls.NewOverload(overloads.Matches,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewInstanceOverload(overloads.MatchesString, decls.NewInstanceOverload(overloads.MatchesString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)), []*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.StartsWith, decls.NewFunction(overloads.StartsWith,


@ -90,6 +90,14 @@ func FormatCheckedType(t *exprpb.Type) string {
return "!error!" return "!error!"
case kindTypeParam: case kindTypeParam:
return t.GetTypeParam() return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
} }
return t.String() return t.String()
} }
@ -110,12 +118,39 @@ func isDyn(t *exprpb.Type) bool {
// isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message. // isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message.
func isDynOrError(t *exprpb.Type) bool { func isDynOrError(t *exprpb.Type) bool {
switch kindOf(t) { return isError(t) || isDyn(t)
case kindError: }
return true
default: func isError(t *exprpb.Type) bool {
return isDyn(t) return kindOf(t) == kindError
}
func isOptional(t *exprpb.Type) bool {
if kindOf(t) == kindAbstract {
at := t.GetAbstractType()
return at.GetName() == "optional"
} }
return false
}
func maybeUnwrapOptional(t *exprpb.Type) (*exprpb.Type, bool) {
if isOptional(t) {
at := t.GetAbstractType()
return at.GetParameterTypes()[0], true
}
return t, false
}
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
return "", false
} }
// isEqualOrLessSpecific checks whether one type is equal or less specific than the other one. // isEqualOrLessSpecific checks whether one type is equal or less specific than the other one.
@ -236,7 +271,7 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// substitution for t1, and whether t2 has a type substitution in mapping m. // substitution for t1, and whether t2 has a type substitution in mapping m.
// //
// The type t2 is a valid substitution for t1 if any of the following statements is true // The type t2 is a valid substitution for t1 if any of the following statements is true
// - t2 has a type substitition (t2sub) equal to t1 // - t2 has a type substitution (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1 // - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1. // - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) { func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {
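
Note: with the new kindAbstract branch in FormatCheckedType above, parameterized abstract types print as name(params). Assuming the exported checker and decls packages and the standard primitive formatting, an optional string type would render as "optional(string)":

// assumes: import (
//	"github.com/google/cel-go/checker"
//	"github.com/google/cel-go/checker/decls"
// )
func formatOptionalString() string {
	// expected: "optional(string)"
	return checker.FormatCheckedType(decls.NewOptionalType(decls.String))
}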


@ -17,7 +17,7 @@ go_library(
importpath = "github.com/google/cel-go/common", importpath = "github.com/google/cel-go/common",
deps = [ deps = [
"//common/runes:go_default_library", "//common/runes:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_x_text//width:go_default_library", "@org_golang_x_text//width:go_default_library",
], ],
) )


@ -12,7 +12,7 @@ go_library(
], ],
importpath = "github.com/google/cel-go/common/containers", importpath = "github.com/google/cel-go/common/containers",
deps = [ deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
], ],
) )
@ -26,6 +26,6 @@ go_test(
":go_default_library", ":go_default_library",
], ],
deps = [ deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
], ],
) )


@ -13,6 +13,6 @@ go_library(
importpath = "github.com/google/cel-go/common/debug", importpath = "github.com/google/cel-go/common/debug",
deps = [ deps = [
"//common:go_default_library", "//common:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
], ],
) )


@ -29,7 +29,7 @@ import (
// representation of an expression. // representation of an expression.
type Adorner interface { type Adorner interface {
// GetMetadata for the input context. // GetMetadata for the input context.
GetMetadata(ctx interface{}) string GetMetadata(ctx any) string
} }
// Writer manages writing expressions to an internal string. // Writer manages writing expressions to an internal string.
@ -46,7 +46,7 @@ type emptyDebugAdorner struct {
var emptyAdorner Adorner = &emptyDebugAdorner{} var emptyAdorner Adorner = &emptyDebugAdorner{}
func (a *emptyDebugAdorner) GetMetadata(e interface{}) string { func (a *emptyDebugAdorner) GetMetadata(e any) string {
return "" return ""
} }
@ -170,6 +170,9 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(",") w.append(",")
w.appendLine() w.appendLine()
} }
if entry.GetOptionalEntry() {
w.append("?")
}
w.append(entry.GetFieldKey()) w.append(entry.GetFieldKey())
w.append(":") w.append(":")
w.Buffer(entry.GetValue()) w.Buffer(entry.GetValue())
@ -191,6 +194,9 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
w.append(",") w.append(",")
w.appendLine() w.appendLine()
} }
if entry.GetOptionalEntry() {
w.append("?")
}
w.Buffer(entry.GetMapKey()) w.Buffer(entry.GetMapKey())
w.append(":") w.append(":")
w.Buffer(entry.GetValue()) w.Buffer(entry.GetValue())
@ -269,7 +275,7 @@ func (w *debugWriter) append(s string) {
w.buffer.WriteString(s) w.buffer.WriteString(s)
} }
func (w *debugWriter) appendFormat(f string, args ...interface{}) { func (w *debugWriter) appendFormat(f string, args ...any) {
w.append(fmt.Sprintf(f, args...)) w.append(fmt.Sprintf(f, args...))
} }
@ -280,7 +286,7 @@ func (w *debugWriter) doIndent() {
} }
} }
func (w *debugWriter) adorn(e interface{}) { func (w *debugWriter) adorn(e any) {
w.append(w.adorner.GetMetadata(e)) w.append(w.adorner.GetMetadata(e))
} }


@ -38,7 +38,7 @@ func NewErrors(source Source) *Errors {
} }
// ReportError records an error at a source location. // ReportError records an error at a source location.
func (e *Errors) ReportError(l Location, format string, args ...interface{}) { func (e *Errors) ReportError(l Location, format string, args ...any) {
e.numErrors++ e.numErrors++
if e.numErrors > e.maxErrorsToReport { if e.numErrors > e.maxErrorsToReport {
return return


@ -37,6 +37,8 @@ const (
Modulo = "_%_" Modulo = "_%_"
Negate = "-_" Negate = "-_"
Index = "_[_]" Index = "_[_]"
OptIndex = "_[?_]"
OptSelect = "_?._"
// Macros, must have a valid identifier. // Macros, must have a valid identifier.
Has = "has" Has = "has"
@ -99,6 +101,8 @@ var (
LogicalNot: {displayName: "!", precedence: 2, arity: 1}, LogicalNot: {displayName: "!", precedence: 2, arity: 1},
Negate: {displayName: "-", precedence: 2, arity: 1}, Negate: {displayName: "-", precedence: 2, arity: 1},
Index: {displayName: "", precedence: 1, arity: 2}, Index: {displayName: "", precedence: 1, arity: 2},
OptIndex: {displayName: "", precedence: 1, arity: 2},
OptSelect: {displayName: "", precedence: 1, arity: 2},
} }
) )
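
Note: the two operator names added above back CEL's optional syntax forms; informally, msg.?field is associated with OptSelect and m[?key] with OptIndex (the parser changes themselves are not part of this hunk, so the mapping is described here as context only):

// assumes: import "github.com/google/cel-go/common/operators"
var (
	optSelectOp = operators.OptSelect // "_?._", e.g. msg.?field
	optIndexOp  = operators.OptIndex  // "_[?_]", e.g. m[?key]
)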


@ -148,6 +148,11 @@ const (
StartsWith = "startsWith" StartsWith = "startsWith"
) )
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtQuoteString = "strings_quote"
)
// String function overload names. // String function overload names.
const ( const (
ContainsString = "contains_string" ContainsString = "contains_string"
@ -156,6 +161,11 @@ const (
StartsWithString = "starts_with_string" StartsWithString = "starts_with_string"
) )
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtFormatString = "string_format"
)
// Time-based functions. // Time-based functions.
const ( const (
TimeGetFullYear = "getFullYear" TimeGetFullYear = "getFullYear"


@ -22,6 +22,7 @@ go_library(
"map.go", "map.go",
"null.go", "null.go",
"object.go", "object.go",
"optional.go",
"overflow.go", "overflow.go",
"provider.go", "provider.go",
"string.go", "string.go",
@ -38,10 +39,8 @@ go_library(
"//common/types/ref:go_default_library", "//common/types/ref:go_default_library",
"//common/types/traits:go_default_library", "//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library", "@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_genproto//googleapis/rpc/status:go_default_library", "@org_golang_google_genproto_googleapis_rpc//status:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@ -68,6 +67,7 @@ go_test(
"map_test.go", "map_test.go",
"null_test.go", "null_test.go",
"object_test.go", "object_test.go",
"optional_test.go",
"provider_test.go", "provider_test.go",
"string_test.go", "string_test.go",
"timestamp_test.go", "timestamp_test.go",
@ -80,7 +80,7 @@ go_test(
"//common/types/ref:go_default_library", "//common/types/ref:go_default_library",
"//test:go_default_library", "//test:go_default_library",
"//test/proto3pb:test_all_types_go_proto", "//test/proto3pb:test_all_types_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library", "@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library", "@org_golang_google_protobuf//types/known/durationpb:go_default_library",


@ -62,7 +62,7 @@ func (b Bool) Compare(other ref.Val) ref.Val {
} }
// ConvertToNative implements the ref.Val interface method. // ConvertToNative implements the ref.Val interface method.
func (b Bool) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (b Bool) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() { switch typeDesc.Kind() {
case reflect.Bool: case reflect.Bool:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@ -114,6 +114,11 @@ func (b Bool) Equal(other ref.Val) ref.Val {
return Bool(ok && b == otherBool) return Bool(ok && b == otherBool)
} }
// IsZeroValue returns true if the boolean value is false.
func (b Bool) IsZeroValue() bool {
return b == False
}
// Negate implements the traits.Negater interface method. // Negate implements the traits.Negater interface method.
func (b Bool) Negate() ref.Val { func (b Bool) Negate() ref.Val {
return !b return !b
@ -125,7 +130,7 @@ func (b Bool) Type() ref.Type {
} }
// Value implements the ref.Val interface method. // Value implements the ref.Val interface method.
func (b Bool) Value() interface{} { func (b Bool) Value() any {
return bool(b) return bool(b)
} }
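
Note: IsZeroValue methods like the one above are added across the value types in this update (false, empty bytes, 0.0, zero duration, 0, empty list). A small sketch of probing for it; the interface is declared locally rather than assuming which interface the library exports for this:

// assumes: import (
//	"github.com/google/cel-go/common/types"
//	"github.com/google/cel-go/common/types/ref"
// )

type zeroer interface{ IsZeroValue() bool }

func isZeroValue(v ref.Val) bool {
	if z, ok := v.(zeroer); ok {
		return z.IsZeroValue()
	}
	return false
}

func zeroExamples() (bool, bool) {
	return isZeroValue(types.Bool(false)), isZeroValue(types.Int(3)) // true, false
}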


@ -63,7 +63,7 @@ func (b Bytes) Compare(other ref.Val) ref.Val {
} }
// ConvertToNative implements the ref.Val interface method. // ConvertToNative implements the ref.Val interface method.
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() { switch typeDesc.Kind() {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@ -116,6 +116,11 @@ func (b Bytes) Equal(other ref.Val) ref.Val {
return Bool(ok && bytes.Equal(b, otherBytes)) return Bool(ok && bytes.Equal(b, otherBytes))
} }
// IsZeroValue returns true if the byte array is empty.
func (b Bytes) IsZeroValue() bool {
return len(b) == 0
}
// Size implements the traits.Sizer interface method. // Size implements the traits.Sizer interface method.
func (b Bytes) Size() ref.Val { func (b Bytes) Size() ref.Val {
return Int(len(b)) return Int(len(b))
@ -127,6 +132,6 @@ func (b Bytes) Type() ref.Type {
} }
// Value implements the ref.Val interface method. // Value implements the ref.Val interface method.
func (b Bytes) Value() interface{} { func (b Bytes) Value() any {
return []byte(b) return []byte(b)
} }


@ -78,7 +78,7 @@ func (d Double) Compare(other ref.Val) ref.Val {
} }
// ConvertToNative implements ref.Val.ConvertToNative. // ConvertToNative implements ref.Val.ConvertToNative.
func (d Double) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (d Double) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() { switch typeDesc.Kind() {
case reflect.Float32: case reflect.Float32:
v := float32(d) v := float32(d)
@ -134,13 +134,13 @@ func (d Double) ConvertToType(typeVal ref.Type) ref.Val {
case IntType: case IntType:
i, err := doubleToInt64Checked(float64(d)) i, err := doubleToInt64Checked(float64(d))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(i) return Int(i)
case UintType: case UintType:
i, err := doubleToUint64Checked(float64(d)) i, err := doubleToUint64Checked(float64(d))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Uint(i) return Uint(i)
case DoubleType: case DoubleType:
@ -182,6 +182,11 @@ func (d Double) Equal(other ref.Val) ref.Val {
} }
} }
// IsZeroValue returns true if double value is 0.0
func (d Double) IsZeroValue() bool {
return float64(d) == 0.0
}
// Multiply implements traits.Multiplier.Multiply. // Multiply implements traits.Multiplier.Multiply.
func (d Double) Multiply(other ref.Val) ref.Val { func (d Double) Multiply(other ref.Val) ref.Val {
otherDouble, ok := other.(Double) otherDouble, ok := other.(Double)
@ -211,6 +216,6 @@ func (d Double) Type() ref.Type {
} }
// Value implements ref.Val.Value. // Value implements ref.Val.Value.
func (d Double) Value() interface{} { func (d Double) Value() any {
return float64(d) return float64(d)
} }


@ -57,14 +57,14 @@ func (d Duration) Add(other ref.Val) ref.Val {
dur2 := other.(Duration) dur2 := other.(Duration)
val, err := addDurationChecked(d.Duration, dur2.Duration) val, err := addDurationChecked(d.Duration, dur2.Duration)
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return durationOf(val) return durationOf(val)
case TimestampType: case TimestampType:
ts := other.(Timestamp).Time ts := other.(Timestamp).Time
val, err := addTimeDurationChecked(ts, d.Duration) val, err := addTimeDurationChecked(ts, d.Duration)
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return timestampOf(val) return timestampOf(val)
} }
@ -90,7 +90,7 @@ func (d Duration) Compare(other ref.Val) ref.Val {
} }
// ConvertToNative implements ref.Val.ConvertToNative. // ConvertToNative implements ref.Val.ConvertToNative.
func (d Duration) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (d Duration) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the duration is already assignable to the desired type return it. // If the duration is already assignable to the desired type return it.
if reflect.TypeOf(d.Duration).AssignableTo(typeDesc) { if reflect.TypeOf(d.Duration).AssignableTo(typeDesc) {
return d.Duration, nil return d.Duration, nil
@ -138,11 +138,16 @@ func (d Duration) Equal(other ref.Val) ref.Val {
return Bool(ok && d.Duration == otherDur.Duration) return Bool(ok && d.Duration == otherDur.Duration)
} }
// IsZeroValue returns true if the duration value is zero
func (d Duration) IsZeroValue() bool {
return d.Duration == 0
}
// Negate implements traits.Negater.Negate. // Negate implements traits.Negater.Negate.
func (d Duration) Negate() ref.Val { func (d Duration) Negate() ref.Val {
val, err := negateDurationChecked(d.Duration) val, err := negateDurationChecked(d.Duration)
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return durationOf(val) return durationOf(val)
} }
@ -165,7 +170,7 @@ func (d Duration) Subtract(subtrahend ref.Val) ref.Val {
} }
val, err := subtractDurationChecked(d.Duration, subtraDur.Duration) val, err := subtractDurationChecked(d.Duration, subtraDur.Duration)
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return durationOf(val) return durationOf(val)
} }
@ -176,7 +181,7 @@ func (d Duration) Type() ref.Type {
} }
// Value implements ref.Val.Value. // Value implements ref.Val.Value.
func (d Duration) Value() interface{} { func (d Duration) Value() any {
return d.Duration return d.Duration
} }


@ -22,6 +22,12 @@ import (
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
) )
// Error interface which allows types types.Err values to be treated as error values.
type Error interface {
error
ref.Val
}
// Err type which extends the built-in go error and implements ref.Val. // Err type which extends the built-in go error and implements ref.Val.
type Err struct { type Err struct {
error error
@ -51,7 +57,7 @@ var (
// NewErr creates a new Err described by the format string and args. // NewErr creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes. // TODO: Audit the use of this function and standardize the error messages and codes.
func NewErr(format string, args ...interface{}) ref.Val { func NewErr(format string, args ...any) ref.Val {
return &Err{fmt.Errorf(format, args...)} return &Err{fmt.Errorf(format, args...)}
} }
@ -62,7 +68,7 @@ func NoSuchOverloadErr() ref.Val {
// UnsupportedRefValConversionErr returns a types.NewErr instance with a no such conversion // UnsupportedRefValConversionErr returns a types.NewErr instance with a no such conversion
// message that indicates that the native value could not be converted to a CEL ref.Val. // message that indicates that the native value could not be converted to a CEL ref.Val.
func UnsupportedRefValConversionErr(val interface{}) ref.Val { func UnsupportedRefValConversionErr(val any) ref.Val {
return NewErr("unsupported conversion to ref.Val: (%T)%v", val, val) return NewErr("unsupported conversion to ref.Val: (%T)%v", val, val)
} }
@ -74,20 +80,20 @@ func MaybeNoSuchOverloadErr(val ref.Val) ref.Val {
// ValOrErr either returns the existing error or creates a new one. // ValOrErr either returns the existing error or creates a new one.
// TODO: Audit the use of this function and standardize the error messages and codes. // TODO: Audit the use of this function and standardize the error messages and codes.
func ValOrErr(val ref.Val, format string, args ...interface{}) ref.Val { func ValOrErr(val ref.Val, format string, args ...any) ref.Val {
if val == nil || !IsUnknownOrError(val) { if val == nil || !IsUnknownOrError(val) {
return NewErr(format, args...) return NewErr(format, args...)
} }
return val return val
} }
// wrapErr wraps an existing Go error value into a CEL Err value. // WrapErr wraps an existing Go error value into a CEL Err value.
func wrapErr(err error) ref.Val { func WrapErr(err error) ref.Val {
return &Err{error: err} return &Err{error: err}
} }
// ConvertToNative implements ref.Val.ConvertToNative. // ConvertToNative implements ref.Val.ConvertToNative.
func (e *Err) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (e *Err) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, e.error return nil, e.error
} }
@ -114,10 +120,15 @@ func (e *Err) Type() ref.Type {
} }
// Value implements ref.Val.Value. // Value implements ref.Val.Value.
func (e *Err) Value() interface{} { func (e *Err) Value() any {
return e.error return e.error
} }
// Is implements errors.Is.
func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()
}
// IsError returns whether the input element ref.Type or ref.Val is equal to // IsError returns whether the input element ref.Type or ref.Val is equal to
// the ErrType singleton. // the ErrType singleton.
func IsError(val ref.Val) bool { func IsError(val ref.Val) bool {
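
Note: two of the err.go changes above are caller-visible: wrapErr is exported as WrapErr, and *Err gains an Is method that compares messages, so errors.Is can match a CEL error against a sentinel error with the same text. A minimal sketch:

// assumes: import (
//	"errors"
//	"github.com/google/cel-go/common/types"
// )
func wrapErrMatches() bool {
	wrapped := types.WrapErr(errors.New("index out of range"))
	celErr, ok := wrapped.(error) // *types.Err embeds error, so it satisfies error
	if !ok {
		return false
	}
	// true: (*Err).Is compares error messages.
	return errors.Is(celErr, errors.New("index out of range"))
}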


@ -66,7 +66,7 @@ func (i Int) Add(other ref.Val) ref.Val {
} }
val, err := addInt64Checked(int64(i), int64(otherInt)) val, err := addInt64Checked(int64(i), int64(otherInt))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(val) return Int(val)
} }
@ -89,7 +89,7 @@ func (i Int) Compare(other ref.Val) ref.Val {
} }
// ConvertToNative implements ref.Val.ConvertToNative. // ConvertToNative implements ref.Val.ConvertToNative.
func (i Int) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() { switch typeDesc.Kind() {
case reflect.Int, reflect.Int32: case reflect.Int, reflect.Int32:
// Enums are also mapped as int32 derivations. // Enums are also mapped as int32 derivations.
@ -176,7 +176,7 @@ func (i Int) ConvertToType(typeVal ref.Type) ref.Val {
case UintType: case UintType:
u, err := int64ToUint64Checked(int64(i)) u, err := int64ToUint64Checked(int64(i))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Uint(u) return Uint(u)
case DoubleType: case DoubleType:
@ -204,7 +204,7 @@ func (i Int) Divide(other ref.Val) ref.Val {
} }
val, err := divideInt64Checked(int64(i), int64(otherInt)) val, err := divideInt64Checked(int64(i), int64(otherInt))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(val) return Int(val)
} }
@ -226,6 +226,11 @@ func (i Int) Equal(other ref.Val) ref.Val {
} }
} }
// IsZeroValue returns true if integer is equal to 0
func (i Int) IsZeroValue() bool {
return i == IntZero
}
// Modulo implements traits.Modder.Modulo. // Modulo implements traits.Modder.Modulo.
func (i Int) Modulo(other ref.Val) ref.Val { func (i Int) Modulo(other ref.Val) ref.Val {
otherInt, ok := other.(Int) otherInt, ok := other.(Int)
@ -234,7 +239,7 @@ func (i Int) Modulo(other ref.Val) ref.Val {
} }
val, err := moduloInt64Checked(int64(i), int64(otherInt)) val, err := moduloInt64Checked(int64(i), int64(otherInt))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(val) return Int(val)
} }
@ -247,7 +252,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
} }
val, err := multiplyInt64Checked(int64(i), int64(otherInt)) val, err := multiplyInt64Checked(int64(i), int64(otherInt))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(val) return Int(val)
} }
@ -256,7 +261,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
func (i Int) Negate() ref.Val { func (i Int) Negate() ref.Val {
val, err := negateInt64Checked(int64(i)) val, err := negateInt64Checked(int64(i))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(val) return Int(val)
} }
@ -269,7 +274,7 @@ func (i Int) Subtract(subtrahend ref.Val) ref.Val {
} }
val, err := subtractInt64Checked(int64(i), int64(subtraInt)) val, err := subtractInt64Checked(int64(i), int64(subtraInt))
if err != nil { if err != nil {
return wrapErr(err) return WrapErr(err)
} }
return Int(val) return Int(val)
} }
@ -280,7 +285,7 @@ func (i Int) Type() ref.Type {
} }
// Value implements ref.Val.Value. // Value implements ref.Val.Value.
func (i Int) Value() interface{} { func (i Int) Value() any {
return int64(i) return int64(i)
} }


@ -34,7 +34,7 @@ var (
// interpreter. // interpreter.
type baseIterator struct{} type baseIterator struct{}
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion on iterators not supported") return nil, fmt.Errorf("type conversion on iterators not supported")
} }
@ -50,6 +50,6 @@ func (*baseIterator) Type() ref.Type {
return IteratorType return IteratorType
} }
func (*baseIterator) Value() interface{} { func (*baseIterator) Value() any {
return nil return nil
} }

View File

@ -25,4 +25,5 @@ var (
jsonValueType = reflect.TypeOf(&structpb.Value{}) jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonListValueType = reflect.TypeOf(&structpb.ListValue{}) jsonListValueType = reflect.TypeOf(&structpb.ListValue{})
jsonStructType = reflect.TypeOf(&structpb.Struct{}) jsonStructType = reflect.TypeOf(&structpb.Struct{})
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
) )


@ -17,11 +17,13 @@ package types
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb" anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb" structpb "google.golang.org/protobuf/types/known/structpb"
@ -40,13 +42,13 @@ var (
// NewDynamicList returns a traits.Lister with heterogenous elements. // NewDynamicList returns a traits.Lister with heterogenous elements.
// value should be an array of "native" types, i.e. any type that // value should be an array of "native" types, i.e. any type that
// NativeToValue() can convert to a ref.Val. // NativeToValue() can convert to a ref.Val.
func NewDynamicList(adapter ref.TypeAdapter, value interface{}) traits.Lister { func NewDynamicList(adapter ref.TypeAdapter, value any) traits.Lister {
refValue := reflect.ValueOf(value) refValue := reflect.ValueOf(value)
return &baseList{ return &baseList{
TypeAdapter: adapter, TypeAdapter: adapter,
value: value, value: value,
size: refValue.Len(), size: refValue.Len(),
get: func(i int) interface{} { get: func(i int) any {
return refValue.Index(i).Interface() return refValue.Index(i).Interface()
}, },
} }
@ -58,7 +60,7 @@ func NewStringList(adapter ref.TypeAdapter, elems []string) traits.Lister {
TypeAdapter: adapter, TypeAdapter: adapter,
value: elems, value: elems,
size: len(elems), size: len(elems),
get: func(i int) interface{} { return elems[i] }, get: func(i int) any { return elems[i] },
} }
} }
@ -70,7 +72,7 @@ func NewRefValList(adapter ref.TypeAdapter, elems []ref.Val) traits.Lister {
TypeAdapter: adapter, TypeAdapter: adapter,
value: elems, value: elems,
size: len(elems), size: len(elems),
get: func(i int) interface{} { return elems[i] }, get: func(i int) any { return elems[i] },
} }
} }
@ -80,7 +82,7 @@ func NewProtoList(adapter ref.TypeAdapter, list protoreflect.List) traits.Lister
TypeAdapter: adapter, TypeAdapter: adapter,
value: list, value: list,
size: list.Len(), size: list.Len(),
get: func(i int) interface{} { return list.Get(i).Interface() }, get: func(i int) any { return list.Get(i).Interface() },
} }
} }
@ -91,22 +93,25 @@ func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
TypeAdapter: adapter, TypeAdapter: adapter,
value: l, value: l,
size: len(vals), size: len(vals),
get: func(i int) interface{} { return vals[i] }, get: func(i int) any { return vals[i] },
} }
} }
// NewMutableList creates a new mutable list whose internal state can be modified. // NewMutableList creates a new mutable list whose internal state can be modified.
func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister { func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
var mutableValues []ref.Val var mutableValues []ref.Val
return &mutableList{ l := &mutableList{
baseList: &baseList{ baseList: &baseList{
TypeAdapter: adapter, TypeAdapter: adapter,
value: mutableValues, value: mutableValues,
size: 0, size: 0,
get: func(i int) interface{} { return mutableValues[i] },
}, },
mutableValues: mutableValues, mutableValues: mutableValues,
} }
l.get = func(i int) any {
return l.mutableValues[i]
}
return l
} }
// baseList points to a list containing elements of any type. // baseList points to a list containing elements of any type.
@ -114,7 +119,7 @@ func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
// The `ref.TypeAdapter` enables native type to CEL type conversions. // The `ref.TypeAdapter` enables native type to CEL type conversions.
type baseList struct { type baseList struct {
ref.TypeAdapter ref.TypeAdapter
value interface{} value any
// size indicates the number of elements within the list. // size indicates the number of elements within the list.
// Since objects are immutable the size of a list is static. // Since objects are immutable the size of a list is static.
@ -122,7 +127,7 @@ type baseList struct {
// get returns a value at the specified integer index. // get returns a value at the specified integer index.
// The index is guaranteed to be checked against the list index range. // The index is guaranteed to be checked against the list index range.
get func(int) interface{} get func(int) any
} }
// Add implements the traits.Adder interface method. // Add implements the traits.Adder interface method.
@ -157,7 +162,7 @@ func (l *baseList) Contains(elem ref.Val) ref.Val {
} }
// ConvertToNative implements the ref.Val interface method. // ConvertToNative implements the ref.Val interface method.
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the underlying list value is assignable to the reflected type return it. // If the underlying list value is assignable to the reflected type return it.
if reflect.TypeOf(l.value).AssignableTo(typeDesc) { if reflect.TypeOf(l.value).AssignableTo(typeDesc) {
return l.value, nil return l.value, nil
@ -240,7 +245,7 @@ func (l *baseList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method. // Get implements the traits.Indexer interface method.
func (l *baseList) Get(index ref.Val) ref.Val { func (l *baseList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index) ind, err := IndexOrError(index)
if err != nil { if err != nil {
return ValOrErr(index, err.Error()) return ValOrErr(index, err.Error())
} }
@ -250,6 +255,11 @@ func (l *baseList) Get(index ref.Val) ref.Val {
return l.NativeToValue(l.get(ind)) return l.NativeToValue(l.get(ind))
} }
// IsZeroValue returns true if the list is empty.
func (l *baseList) IsZeroValue() bool {
return l.size == 0
}
// Iterator implements the traits.Iterable interface method. // Iterator implements the traits.Iterable interface method.
func (l *baseList) Iterator() traits.Iterator { func (l *baseList) Iterator() traits.Iterator {
return newListIterator(l) return newListIterator(l)
@ -266,10 +276,24 @@ func (l *baseList) Type() ref.Type {
} }
// Value implements the ref.Val interface method. // Value implements the ref.Val interface method.
func (l *baseList) Value() interface{} { func (l *baseList) Value() any {
return l.value return l.value
} }
// String converts the list to a human readable string form.
func (l *baseList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := 0; i < l.size; i++ {
sb.WriteString(fmt.Sprintf("%v", l.get(i)))
if i != l.size-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only. // mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct { type mutableList struct {
*baseList *baseList
@ -305,7 +329,7 @@ func (l *mutableList) ToImmutableList() traits.Lister {
// The `ref.TypeAdapter` enables native type to CEL type conversions. // The `ref.TypeAdapter` enables native type to CEL type conversions.
type concatList struct { type concatList struct {
ref.TypeAdapter ref.TypeAdapter
value interface{} value any
prevList traits.Lister prevList traits.Lister
nextList traits.Lister nextList traits.Lister
} }
@ -351,8 +375,8 @@ func (l *concatList) Contains(elem ref.Val) ref.Val {
} }
// ConvertToNative implements the ref.Val interface method. // ConvertToNative implements the ref.Val interface method.
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (l *concatList) ConvertToNative(typeDesc reflect.Type) (any, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]interface{})) combined := NewDynamicList(l.TypeAdapter, l.Value().([]any))
return combined.ConvertToNative(typeDesc) return combined.ConvertToNative(typeDesc)
} }
@ -396,7 +420,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method. // Get implements the traits.Indexer interface method.
func (l *concatList) Get(index ref.Val) ref.Val { func (l *concatList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index) ind, err := IndexOrError(index)
if err != nil { if err != nil {
return ValOrErr(index, err.Error()) return ValOrErr(index, err.Error())
} }
@ -408,6 +432,11 @@ func (l *concatList) Get(index ref.Val) ref.Val {
return l.nextList.Get(offset) return l.nextList.Get(offset)
} }
// IsZeroValue returns true if the list is empty.
func (l *concatList) IsZeroValue() bool {
return l.Size().(Int) == 0
}
// Iterator implements the traits.Iterable interface method. // Iterator implements the traits.Iterable interface method.
func (l *concatList) Iterator() traits.Iterator { func (l *concatList) Iterator() traits.Iterator {
return newListIterator(l) return newListIterator(l)
@ -418,15 +447,29 @@ func (l *concatList) Size() ref.Val {
return l.prevList.Size().(Int).Add(l.nextList.Size()) return l.prevList.Size().(Int).Add(l.nextList.Size())
} }
// String converts the concatenated list to a human-readable string.
func (l *concatList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := Int(0); i < l.Size().(Int); i++ {
sb.WriteString(fmt.Sprintf("%v", l.Get(i)))
if i != l.Size().(Int)-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// Type implements the ref.Val interface method. // Type implements the ref.Val interface method.
func (l *concatList) Type() ref.Type { func (l *concatList) Type() ref.Type {
return ListType return ListType
} }
// Value implements the ref.Val interface method. // Value implements the ref.Val interface method.
func (l *concatList) Value() interface{} { func (l *concatList) Value() any {
if l.value == nil { if l.value == nil {
merged := make([]interface{}, l.Size().(Int)) merged := make([]any, l.Size().(Int))
prevLen := l.prevList.Size().(Int) prevLen := l.prevList.Size().(Int)
for i := Int(0); i < prevLen; i++ { for i := Int(0); i < prevLen; i++ {
merged[i] = l.prevList.Get(i).Value() merged[i] = l.prevList.Get(i).Value()
@ -469,7 +512,8 @@ func (it *listIterator) Next() ref.Val {
return nil return nil
} }
func indexOrError(index ref.Val) (int, error) { // IndexOrError converts an input index value into either a lossless integer index or an error.
func IndexOrError(index ref.Val) (int, error) {
switch iv := index.(type) { switch iv := index.(type) {
case Int: case Int:
return int(iv), nil return int(iv), nil
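
indexOrError is now exported as IndexOrError. A short hedged sketch of how a caller might use it; the behaviour for non-numeric index types is assumed from the surrounding switch:

package main

import (
	"fmt"

	"github.com/google/cel-go/common/types"
)

func main() {
	// A CEL integer index converts losslessly to a Go int.
	i, err := types.IndexOrError(types.Int(2))
	fmt.Println(i, err) // 2 <nil>

	// Non-numeric index types are assumed to return an error instead of panicking.
	_, err = types.IndexOrError(types.String("not-an-index"))
	fmt.Println(err != nil) // true
}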


@ -17,20 +17,22 @@ package types
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb" "github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/common/types/traits"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb" anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb" structpb "google.golang.org/protobuf/types/known/structpb"
) )
// NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs. // NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs.
func NewDynamicMap(adapter ref.TypeAdapter, value interface{}) traits.Mapper { func NewDynamicMap(adapter ref.TypeAdapter, value any) traits.Mapper {
refValue := reflect.ValueOf(value) refValue := reflect.ValueOf(value)
return &baseMap{ return &baseMap{
TypeAdapter: adapter, TypeAdapter: adapter,
@ -65,7 +67,7 @@ func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Map
} }
// NewStringInterfaceMap returns a specialized traits.Mapper with string keys and interface values. // NewStringInterfaceMap returns a specialized traits.Mapper with string keys and interface values.
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]interface{}) traits.Mapper { func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]any) traits.Mapper {
return &baseMap{ return &baseMap{
TypeAdapter: adapter, TypeAdapter: adapter,
mapAccessor: newStringIfaceMapAccessor(adapter, value), mapAccessor: newStringIfaceMapAccessor(adapter, value),
@ -125,7 +127,7 @@ type baseMap struct {
mapAccessor mapAccessor
// value is the native Go value upon which the map type operates. // value is the native Go value upon which the map type operates.
value interface{} value any
// size is the number of entries in the map. // size is the number of entries in the map.
size int size int
@ -138,7 +140,7 @@ func (m *baseMap) Contains(index ref.Val) ref.Val {
} }
// ConvertToNative implements the ref.Val interface method. // ConvertToNative implements the ref.Val interface method.
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and // If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types. // maps with the same key value types.
if reflect.TypeOf(m.value).AssignableTo(typeDesc) { if reflect.TypeOf(m.value).AssignableTo(typeDesc) {
@ -275,18 +277,42 @@ func (m *baseMap) Get(key ref.Val) ref.Val {
return v return v
} }
// IsZeroValue returns true if the map is empty.
func (m *baseMap) IsZeroValue() bool {
return m.size == 0
}
// Size implements the traits.Sizer interface method. // Size implements the traits.Sizer interface method.
func (m *baseMap) Size() ref.Val { func (m *baseMap) Size() ref.Val {
return Int(m.size) return Int(m.size)
} }
// String converts the map into a human-readable string.
func (m *baseMap) String() string {
var sb strings.Builder
sb.WriteString("{")
it := m.Iterator()
i := 0
for it.HasNext() == True {
k := it.Next()
v, _ := m.Find(k)
sb.WriteString(fmt.Sprintf("%v: %v", k, v))
if i != m.size-1 {
sb.WriteString(", ")
}
i++
}
sb.WriteString("}")
return sb.String()
}
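
The map values gain the same treatment. A hedged sketch using the exported string-keyed constructor; note the entry order in the printed form is not guaranteed:

package main

import (
	"fmt"

	"github.com/google/cel-go/common/types"
)

func main() {
	m := types.NewStringInterfaceMap(types.DefaultTypeAdapter, map[string]any{
		"name":     "csi-provisioner",
		"replicas": 3,
	})
	// baseMap now implements fmt.Stringer, e.g. "{name: csi-provisioner, replicas: 3}".
	fmt.Printf("%v\n", m)

	empty := types.NewStringInterfaceMap(types.DefaultTypeAdapter, map[string]any{})
	if z, ok := empty.(interface{ IsZeroValue() bool }); ok {
		fmt.Println("empty map is zero value:", z.IsZeroValue()) // true
	}
}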
// Type implements the ref.Val interface method. // Type implements the ref.Val interface method.
func (m *baseMap) Type() ref.Type { func (m *baseMap) Type() ref.Type {
return MapType return MapType
} }
// Value implements the ref.Val interface method. // Value implements the ref.Val interface method.
func (m *baseMap) Value() interface{} { func (m *baseMap) Value() any {
return m.value return m.value
} }
@ -498,7 +524,7 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
} }
} }
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interface{}) mapAccessor { func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{ return &stringIfaceMapAccessor{
TypeAdapter: adapter, TypeAdapter: adapter,
mapVal: mapVal, mapVal: mapVal,
@ -507,7 +533,7 @@ func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interf
type stringIfaceMapAccessor struct { type stringIfaceMapAccessor struct {
ref.TypeAdapter ref.TypeAdapter
mapVal map[string]interface{} mapVal map[string]any
} }
// Find uses native map accesses to find the key, returning (value, true) if present. // Find uses native map accesses to find the key, returning (value, true) if present.
@ -556,7 +582,7 @@ func (m *protoMap) Contains(key ref.Val) ref.Val {
// ConvertToNative implements the ref.Val interface method. // ConvertToNative implements the ref.Val interface method.
// //
// Note, assignment to Golang struct types is not yet supported. // Note, assignment to Golang struct types is not yet supported.
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and // If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types. // maps with the same key value types.
switch typeDesc { switch typeDesc {
@ -601,9 +627,9 @@ func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
m.value.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool { m.value.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
ntvKey := key.Interface() ntvKey := key.Interface()
ntvVal := val.Interface() ntvVal := val.Interface()
switch ntvVal.(type) { switch pv := ntvVal.(type) {
case protoreflect.Message: case protoreflect.Message:
ntvVal = ntvVal.(protoreflect.Message).Interface() ntvVal = pv.Interface()
} }
if keyType == otherKeyType && valType == otherValType { if keyType == otherKeyType && valType == otherValType {
mapVal.SetMapIndex(reflect.ValueOf(ntvKey), reflect.ValueOf(ntvVal)) mapVal.SetMapIndex(reflect.ValueOf(ntvKey), reflect.ValueOf(ntvVal))
@ -732,6 +758,11 @@ func (m *protoMap) Get(key ref.Val) ref.Val {
return v return v
} }
// IsZeroValue returns true if the map is empty.
func (m *protoMap) IsZeroValue() bool {
return m.value.Len() == 0
}
// Iterator implements the traits.Iterable interface method. // Iterator implements the traits.Iterable interface method.
func (m *protoMap) Iterator() traits.Iterator { func (m *protoMap) Iterator() traits.Iterator {
// Copy the keys to make their order stable. // Copy the keys to make their order stable.
@ -758,7 +789,7 @@ func (m *protoMap) Type() ref.Type {
} }
// Value implements the ref.Val interface method. // Value implements the ref.Val interface method.
func (m *protoMap) Value() interface{} { func (m *protoMap) Value() any {
return m.value return m.value
} }


@ -18,9 +18,10 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb" anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb" structpb "google.golang.org/protobuf/types/known/structpb"
) )
@ -34,14 +35,20 @@ var (
// NullValue singleton. // NullValue singleton.
NullValue = Null(structpb.NullValue_NULL_VALUE) NullValue = Null(structpb.NullValue_NULL_VALUE)
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE) // golang reflect type for Null values.
nullReflectType = reflect.TypeOf(NullValue)
) )
// ConvertToNative implements ref.Val.ConvertToNative. // ConvertToNative implements ref.Val.ConvertToNative.
func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() { switch typeDesc.Kind() {
case reflect.Int32: case reflect.Int32:
return reflect.ValueOf(n).Convert(typeDesc).Interface(), nil switch typeDesc {
case jsonNullType:
return structpb.NullValue_NULL_VALUE, nil
case nullReflectType:
return n, nil
}
case reflect.Ptr: case reflect.Ptr:
switch typeDesc { switch typeDesc {
case anyValueType: case anyValueType:
@ -54,6 +61,10 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
return anypb.New(pb.(proto.Message)) return anypb.New(pb.(proto.Message))
case jsonValueType: case jsonValueType:
return structpb.NewNullValue(), nil return structpb.NewNullValue(), nil
case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType,
int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType,
uint64WrapperType:
return nil, nil
} }
case reflect.Interface: case reflect.Interface:
nv := n.Value() nv := n.Value()
@ -86,12 +97,17 @@ func (n Null) Equal(other ref.Val) ref.Val {
return Bool(NullType == other.Type()) return Bool(NullType == other.Type())
} }
// IsZeroValue returns true as null always represents an absent value.
func (n Null) IsZeroValue() bool {
return true
}
// Type implements ref.Val.Type. // Type implements ref.Val.Type.
func (n Null) Type() ref.Type { func (n Null) Type() ref.Type {
return NullType return NullType
} }
// Value implements ref.Val.Value. // Value implements ref.Val.Value.
func (n Null) Value() interface{} { func (n Null) Value() any {
return structpb.NullValue_NULL_VALUE return structpb.NullValue_NULL_VALUE
} }
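
A hedged sketch of the updated Null conversions; the wrapper-type behaviour (typed nil, no error) is assumed from the new cases added above:

package main

import (
	"fmt"
	"reflect"

	"github.com/google/cel-go/common/types"
	structpb "google.golang.org/protobuf/types/known/structpb"
	wrapperspb "google.golang.org/protobuf/types/known/wrapperspb"
)

func main() {
	// Conversion to the protobuf NullValue enum type is now handled explicitly.
	v, err := types.NullValue.ConvertToNative(reflect.TypeOf(structpb.NullValue_NULL_VALUE))
	fmt.Println(v, err) // NULL_VALUE <nil>

	// Conversion to a wrapper type is assumed to return nil, nil per the new wrapper cases.
	w, err := types.NullValue.ConvertToNative(reflect.TypeOf(&wrapperspb.BoolValue{}))
	fmt.Println(w == nil, err)

	// Null always reports itself as a zero value.
	fmt.Println(types.NullValue.IsZeroValue()) // true
}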


@ -18,11 +18,12 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb" anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb" structpb "google.golang.org/protobuf/types/known/structpb"
) )
@ -52,7 +53,7 @@ func NewObject(adapter ref.TypeAdapter,
typeValue: typeValue} typeValue: typeValue}
} }
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
srcPB := o.value srcPB := o.value
if reflect.TypeOf(srcPB).AssignableTo(typeDesc) { if reflect.TypeOf(srcPB).AssignableTo(typeDesc) {
return srcPB, nil return srcPB, nil
@ -133,6 +134,11 @@ func (o *protoObj) IsSet(field ref.Val) ref.Val {
return False return False
} }
// IsZeroValue returns true if the protobuf object is empty.
func (o *protoObj) IsZeroValue() bool {
return proto.Equal(o.value, o.typeDesc.Zero())
}
func (o *protoObj) Get(index ref.Val) ref.Val { func (o *protoObj) Get(index ref.Val) ref.Val {
protoFieldName, ok := index.(String) protoFieldName, ok := index.(String)
if !ok { if !ok {
@ -154,6 +160,6 @@ func (o *protoObj) Type() ref.Type {
return o.typeValue return o.typeValue
} }
func (o *protoObj) Value() interface{} { func (o *protoObj) Value() any {
return o.value return o.value
} }


@ -0,0 +1,108 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"errors"
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
)
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewTypeValue("optional")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}
)
// OptionalOf returns an optional value which wraps a concrete CEL value.
func OptionalOf(value ref.Val) *Optional {
return &Optional{value: value}
}
// Optional value which points to a value if non-empty.
type Optional struct {
value ref.Val
}
// HasValue returns true if the optional has a value.
func (o *Optional) HasValue() bool {
return o.value != nil
}
// GetValue returns the wrapped value contained in the optional.
func (o *Optional) GetValue() ref.Val {
if !o.HasValue() {
return NewErr("optional.none() dereference")
}
return o.value
}
// ConvertToNative implements the ref.Val interface method.
func (o *Optional) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !o.HasValue() {
return nil, errors.New("optional.none() dereference")
}
return o.value.ConvertToNative(typeDesc)
}
// ConvertToType implements the ref.Val interface method.
func (o *Optional) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case OptionalType:
return o
case TypeType:
return OptionalType
}
return NewErr("type conversion error from '%s' to '%s'", OptionalType, typeVal)
}
// Equal determines whether the values contained by two optional values are equal.
func (o *Optional) Equal(other ref.Val) ref.Val {
otherOpt, isOpt := other.(*Optional)
if !isOpt {
return False
}
if !o.HasValue() {
return Bool(!otherOpt.HasValue())
}
if !otherOpt.HasValue() {
return False
}
return o.value.Equal(otherOpt.value)
}
func (o *Optional) String() string {
if o.HasValue() {
return fmt.Sprintf("optional(%v)", o.GetValue())
}
return "optional.none()"
}
// Type implements the ref.Val interface method.
func (o *Optional) Type() ref.Type {
return OptionalType
}
// Value returns the underlying 'Value()' of the wrapped value, if present.
func (o *Optional) Value() any {
if o.value == nil {
return nil
}
return o.value.Value()
}
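
A short usage sketch for the new Optional value, using only the names defined in the file above; the printed output comments are illustrative:

package main

import (
	"fmt"

	"github.com/google/cel-go/common/types"
)

func main() {
	some := types.OptionalOf(types.String("rbd"))
	fmt.Println(some.HasValue(), some) // true optional(rbd)

	none := types.OptionalNone
	fmt.Println(none.HasValue(), none) // false optional.none()

	// Dereferencing an empty optional yields a CEL error value rather than panicking.
	fmt.Println(none.GetValue()) // optional.none() dereference
}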


@ -17,7 +17,7 @@ go_library(
], ],
importpath = "github.com/google/cel-go/common/types/pb", importpath = "github.com/google/cel-go/common/types/pb",
deps = [ deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protowire:go_default_library", "@org_golang_google_protobuf//encoding/protowire:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//reflect/protoreflect:go_default_library",


@ -18,9 +18,9 @@ import (
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
) )
// NewEnumValueDescription produces an enum value description with the fully qualified enum value // newEnumValueDescription produces an enum value description with the fully qualified enum value
// name and the enum value descriptor. // name and the enum value descriptor.
func NewEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription { func newEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
return &EnumValueDescription{ return &EnumValueDescription{
enumValueName: name, enumValueName: name,
desc: desc, desc: desc,


@ -18,32 +18,66 @@ import (
"fmt" "fmt"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
) )
// NewFileDescription returns a FileDescription instance with a complete listing of all the message // newFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values declared within any scope in the file. // types and enum values, as well as a map of extensions declared within any scope in the file.
func NewFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) *FileDescription { func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDescription, extensionMap) {
metadata := collectFileMetadata(fileDesc) metadata := collectFileMetadata(fileDesc)
enums := make(map[string]*EnumValueDescription) enums := make(map[string]*EnumValueDescription)
for name, enumVal := range metadata.enumValues { for name, enumVal := range metadata.enumValues {
enums[name] = NewEnumValueDescription(name, enumVal) enums[name] = newEnumValueDescription(name, enumVal)
} }
types := make(map[string]*TypeDescription) types := make(map[string]*TypeDescription)
for name, msgType := range metadata.msgTypes { for name, msgType := range metadata.msgTypes {
types[name] = NewTypeDescription(name, msgType) types[name] = newTypeDescription(name, msgType, pbdb.extensions)
}
fileExtMap := make(extensionMap)
for typeName, extensions := range metadata.msgExtensionMap {
messageExtMap, found := fileExtMap[typeName]
if !found {
messageExtMap = make(map[string]*FieldDescription)
}
for _, ext := range extensions {
extDesc := dynamicpb.NewExtensionType(ext).TypeDescriptor()
messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc)
}
fileExtMap[typeName] = messageExtMap
} }
return &FileDescription{ return &FileDescription{
name: fileDesc.Path(),
types: types, types: types,
enums: enums, enums: enums,
} }, fileExtMap
} }
// FileDescription holds a map of all types and enum values declared within a proto file. // FileDescription holds a map of all types and enum values declared within a proto file.
type FileDescription struct { type FileDescription struct {
name string
types map[string]*TypeDescription types map[string]*TypeDescription
enums map[string]*EnumValueDescription enums map[string]*EnumValueDescription
} }
// Copy creates a copy of the FileDescription with updated Db references within its types.
func (fd *FileDescription) Copy(pbdb *Db) *FileDescription {
typesCopy := make(map[string]*TypeDescription, len(fd.types))
for k, v := range fd.types {
typesCopy[k] = v.Copy(pbdb)
}
return &FileDescription{
name: fd.name,
types: typesCopy,
enums: fd.enums,
}
}
// GetName returns the fully qualified file path for the file.
func (fd *FileDescription) GetName() string {
return fd.name
}
// GetEnumDescription returns an EnumDescription for a qualified enum value // GetEnumDescription returns an EnumDescription for a qualified enum value
// name declared within the .proto file. // name declared within the .proto file.
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) { func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) {
@ -94,6 +128,10 @@ type fileMetadata struct {
msgTypes map[string]protoreflect.MessageDescriptor msgTypes map[string]protoreflect.MessageDescriptor
// enumValues maps from fully-qualified enum value to enum value descriptor. // enumValues maps from fully-qualified enum value to enum value descriptor.
enumValues map[string]protoreflect.EnumValueDescriptor enumValues map[string]protoreflect.EnumValueDescriptor
// msgExtensionMap maps from the protobuf message name being extended to a set of extensions
// for the type.
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor
// TODO: support enum type definitions for use in future type-check enhancements. // TODO: support enum type definitions for use in future type-check enhancements.
} }
@ -102,28 +140,38 @@ type fileMetadata struct {
func collectFileMetadata(fileDesc protoreflect.FileDescriptor) *fileMetadata { func collectFileMetadata(fileDesc protoreflect.FileDescriptor) *fileMetadata {
msgTypes := make(map[string]protoreflect.MessageDescriptor) msgTypes := make(map[string]protoreflect.MessageDescriptor)
enumValues := make(map[string]protoreflect.EnumValueDescriptor) enumValues := make(map[string]protoreflect.EnumValueDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues) msgExtensionMap := make(map[string][]protoreflect.ExtensionDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues, msgExtensionMap)
collectEnumValues(fileDesc.Enums(), enumValues) collectEnumValues(fileDesc.Enums(), enumValues)
collectExtensions(fileDesc.Extensions(), msgExtensionMap)
return &fileMetadata{ return &fileMetadata{
msgTypes: msgTypes, msgTypes: msgTypes,
enumValues: enumValues, enumValues: enumValues,
msgExtensionMap: msgExtensionMap,
} }
} }
// collectMsgTypes recursively collects messages, nested messages, and nested enums into a map of // collectMsgTypes recursively collects messages, nested messages, and nested enums into a map of
// fully qualified protobuf names to descriptors. // fully qualified protobuf names to descriptors.
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors, msgTypeMap map[string]protoreflect.MessageDescriptor, enumValueMap map[string]protoreflect.EnumValueDescriptor) { func collectMsgTypes(msgTypes protoreflect.MessageDescriptors,
msgTypeMap map[string]protoreflect.MessageDescriptor,
enumValueMap map[string]protoreflect.EnumValueDescriptor,
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < msgTypes.Len(); i++ { for i := 0; i < msgTypes.Len(); i++ {
msgType := msgTypes.Get(i) msgType := msgTypes.Get(i)
msgTypeMap[string(msgType.FullName())] = msgType msgTypeMap[string(msgType.FullName())] = msgType
nestedMsgTypes := msgType.Messages() nestedMsgTypes := msgType.Messages()
if nestedMsgTypes.Len() != 0 { if nestedMsgTypes.Len() != 0 {
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap) collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap, msgExtensionMap)
} }
nestedEnumTypes := msgType.Enums() nestedEnumTypes := msgType.Enums()
if nestedEnumTypes.Len() != 0 { if nestedEnumTypes.Len() != 0 {
collectEnumValues(nestedEnumTypes, enumValueMap) collectEnumValues(nestedEnumTypes, enumValueMap)
} }
nestedExtensions := msgType.Extensions()
if nestedExtensions.Len() != 0 {
collectExtensions(nestedExtensions, msgExtensionMap)
}
} }
} }
@ -139,3 +187,16 @@ func collectEnumValues(enumTypes protoreflect.EnumDescriptors, enumValueMap map[
} }
} }
} }
func collectExtensions(extensions protoreflect.ExtensionDescriptors, msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < extensions.Len(); i++ {
ext := extensions.Get(i)
extendsMsg := string(ext.ContainingMessage().FullName())
msgExts, found := msgExtensionMap[extendsMsg]
if !found {
msgExts = []protoreflect.ExtensionDescriptor{}
}
msgExts = append(msgExts, ext)
msgExtensionMap[extendsMsg] = msgExts
}
}


@ -40,13 +40,19 @@ type Db struct {
revFileDescriptorMap map[string]*FileDescription revFileDescriptorMap map[string]*FileDescription
// files contains the deduped set of FileDescriptions whose types are contained in the pb.Db. // files contains the deduped set of FileDescriptions whose types are contained in the pb.Db.
files []*FileDescription files []*FileDescription
// extensions contains the mapping between a given type name, extension name and its FieldDescription
extensions map[string]map[string]*FieldDescription
} }
// extensionMap is a type alias to a map[typeName]map[extensionName]*FieldDescription
type extensionMap = map[string]map[string]*FieldDescription
var ( var (
// DefaultDb used at evaluation time or unless overridden at check time. // DefaultDb used at evaluation time or unless overridden at check time.
DefaultDb = &Db{ DefaultDb = &Db{
revFileDescriptorMap: make(map[string]*FileDescription), revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{}, files: []*FileDescription{},
extensions: make(extensionMap),
} }
) )
@ -80,6 +86,7 @@ func NewDb() *Db {
pbdb := &Db{ pbdb := &Db{
revFileDescriptorMap: make(map[string]*FileDescription), revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{}, files: []*FileDescription{},
extensions: make(extensionMap),
} }
// The FileDescription objects in the default db contain lazily initialized TypeDescription // The FileDescription objects in the default db contain lazily initialized TypeDescription
// values which may point to the state contained in the DefaultDb irrespective of this shallow // values which may point to the state contained in the DefaultDb irrespective of this shallow
@ -96,19 +103,34 @@ func NewDb() *Db {
// Copy creates a copy of the current database with its own internal descriptor mapping. // Copy creates a copy of the current database with its own internal descriptor mapping.
func (pbdb *Db) Copy() *Db { func (pbdb *Db) Copy() *Db {
copy := NewDb() copy := NewDb()
for k, v := range pbdb.revFileDescriptorMap { for _, fd := range pbdb.files {
copy.revFileDescriptorMap[k] = v
}
for _, f := range pbdb.files {
hasFile := false hasFile := false
for _, f2 := range copy.files { for _, fd2 := range copy.files {
if f2 == f { if fd2 == fd {
hasFile = true hasFile = true
} }
} }
if !hasFile { if !hasFile {
copy.files = append(copy.files, f) fd = fd.Copy(copy)
copy.files = append(copy.files, fd)
} }
for _, enumValName := range fd.GetEnumNames() {
copy.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
copy.revFileDescriptorMap[msgTypeName] = fd
}
copy.revFileDescriptorMap[fd.GetName()] = fd
}
for typeName, extFieldMap := range pbdb.extensions {
copyExtFieldMap, found := copy.extensions[typeName]
if !found {
copyExtFieldMap = make(map[string]*FieldDescription, len(extFieldMap))
}
for extFieldName, fd := range extFieldMap {
copyExtFieldMap[extFieldName] = fd
}
copy.extensions[typeName] = copyExtFieldMap
} }
return copy return copy
} }
@ -137,17 +159,30 @@ func (pbdb *Db) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) (*FileD
if err == nil { if err == nil {
fileDesc = globalFD fileDesc = globalFD
} }
fd = NewFileDescription(fileDesc, pbdb) var fileExtMap extensionMap
fd, fileExtMap = newFileDescription(fileDesc, pbdb)
for _, enumValName := range fd.GetEnumNames() { for _, enumValName := range fd.GetEnumNames() {
pbdb.revFileDescriptorMap[enumValName] = fd pbdb.revFileDescriptorMap[enumValName] = fd
} }
for _, msgTypeName := range fd.GetTypeNames() { for _, msgTypeName := range fd.GetTypeNames() {
pbdb.revFileDescriptorMap[msgTypeName] = fd pbdb.revFileDescriptorMap[msgTypeName] = fd
} }
pbdb.revFileDescriptorMap[fileDesc.Path()] = fd pbdb.revFileDescriptorMap[fd.GetName()] = fd
// Return the specific file descriptor registered. // Return the specific file descriptor registered.
pbdb.files = append(pbdb.files, fd) pbdb.files = append(pbdb.files, fd)
// Index the protobuf message extensions from the file into the pbdb
for typeName, extMap := range fileExtMap {
typeExtMap, found := pbdb.extensions[typeName]
if !found {
pbdb.extensions[typeName] = extMap
continue
}
for extName, field := range extMap {
typeExtMap[extName] = field
}
}
return fd, nil return fd, nil
} }
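
A hedged sketch of how descriptor registration and the reworked Copy interact from the caller's side; the descriptor used here is only an example, and the extension indexing shown above happens transparently inside RegisterDescriptor:

package main

import (
	"fmt"

	"github.com/google/cel-go/common/types/pb"
	structpb "google.golang.org/protobuf/types/known/structpb"
)

func main() {
	pbdb := pb.NewDb()

	// Register the file descriptor of a well-known type; its file, message, and enum
	// names all become resolvable through the Db afterwards.
	fileDesc := (&structpb.Value{}).ProtoReflect().Descriptor().ParentFile()
	fd, err := pbdb.RegisterDescriptor(fileDesc)
	if err != nil {
		panic(err)
	}
	fmt.Println(fd.GetName()) // e.g. google/protobuf/struct.proto

	// Copy now rebuilds the reverse descriptor map and extension index for the new Db,
	// so lookups keep working on the clone.
	clone := pbdb.Copy()
	_, found := clone.DescribeType("google.protobuf.Value")
	fmt.Println(found) // true
}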

Some files were not shown because too many files have changed in this diff.