548 lines
12 KiB
Go
548 lines
12 KiB
Go
package pattern
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/token"
|
|
"go/types"
|
|
"reflect"
|
|
)
|
|
|
|
var tokensByString = map[string]Token{
|
|
"INT": Token(token.INT),
|
|
"FLOAT": Token(token.FLOAT),
|
|
"IMAG": Token(token.IMAG),
|
|
"CHAR": Token(token.CHAR),
|
|
"STRING": Token(token.STRING),
|
|
"+": Token(token.ADD),
|
|
"-": Token(token.SUB),
|
|
"*": Token(token.MUL),
|
|
"/": Token(token.QUO),
|
|
"%": Token(token.REM),
|
|
"&": Token(token.AND),
|
|
"|": Token(token.OR),
|
|
"^": Token(token.XOR),
|
|
"<<": Token(token.SHL),
|
|
">>": Token(token.SHR),
|
|
"&^": Token(token.AND_NOT),
|
|
"+=": Token(token.ADD_ASSIGN),
|
|
"-=": Token(token.SUB_ASSIGN),
|
|
"*=": Token(token.MUL_ASSIGN),
|
|
"/=": Token(token.QUO_ASSIGN),
|
|
"%=": Token(token.REM_ASSIGN),
|
|
"&=": Token(token.AND_ASSIGN),
|
|
"|=": Token(token.OR_ASSIGN),
|
|
"^=": Token(token.XOR_ASSIGN),
|
|
"<<=": Token(token.SHL_ASSIGN),
|
|
">>=": Token(token.SHR_ASSIGN),
|
|
"&^=": Token(token.AND_NOT_ASSIGN),
|
|
"&&": Token(token.LAND),
|
|
"||": Token(token.LOR),
|
|
"<-": Token(token.ARROW),
|
|
"++": Token(token.INC),
|
|
"--": Token(token.DEC),
|
|
"==": Token(token.EQL),
|
|
"<": Token(token.LSS),
|
|
">": Token(token.GTR),
|
|
"=": Token(token.ASSIGN),
|
|
"!": Token(token.NOT),
|
|
"!=": Token(token.NEQ),
|
|
"<=": Token(token.LEQ),
|
|
">=": Token(token.GEQ),
|
|
":=": Token(token.DEFINE),
|
|
"...": Token(token.ELLIPSIS),
|
|
"IMPORT": Token(token.IMPORT),
|
|
"VAR": Token(token.VAR),
|
|
"TYPE": Token(token.TYPE),
|
|
"CONST": Token(token.CONST),
|
|
"BREAK": Token(token.BREAK),
|
|
"CONTINUE": Token(token.CONTINUE),
|
|
"GOTO": Token(token.GOTO),
|
|
"FALLTHROUGH": Token(token.FALLTHROUGH),
|
|
}
|
|
|
|
func maybeToken(node Node) (Node, bool) {
|
|
if node, ok := node.(String); ok {
|
|
if tok, ok := tokensByString[string(node)]; ok {
|
|
return tok, true
|
|
}
|
|
return node, false
|
|
}
|
|
return node, false
|
|
}
|
|
|
|
func isNil(v interface{}) bool {
|
|
if v == nil {
|
|
return true
|
|
}
|
|
if _, ok := v.(Nil); ok {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
type matcher interface {
|
|
Match(*Matcher, interface{}) (interface{}, bool)
|
|
}
|
|
|
|
// State maps binding names to the values captured for them during a match.
type State = map[string]interface{}
|
|
|
|
type Matcher struct {
|
|
TypesInfo *types.Info
|
|
State State
|
|
}
|
|
|
|
func (m *Matcher) fork() *Matcher {
|
|
state := make(State, len(m.State))
|
|
for k, v := range m.State {
|
|
state[k] = v
|
|
}
|
|
return &Matcher{
|
|
TypesInfo: m.TypesInfo,
|
|
State: state,
|
|
}
|
|
}
|
|
|
|
func (m *Matcher) merge(mc *Matcher) {
|
|
m.State = mc.State
|
|
}
|
|
|
|
func (m *Matcher) Match(a Node, b ast.Node) bool {
|
|
m.State = State{}
|
|
_, ok := match(m, a, b)
|
|
return ok
|
|
}
|
|
|
|
func Match(a Node, b ast.Node) (*Matcher, bool) {
|
|
m := &Matcher{}
|
|
ret := m.Match(a, b)
|
|
return m, ret
|
|
}
|
|
|
|
// Match two items, which may be (Node, AST) or (AST, AST)
|
|
func match(m *Matcher, l, r interface{}) (interface{}, bool) {
|
|
if _, ok := r.(Node); ok {
|
|
panic("Node mustn't be on right side of match")
|
|
}
|
|
|
|
switch l := l.(type) {
|
|
case *ast.ParenExpr:
|
|
return match(m, l.X, r)
|
|
case *ast.ExprStmt:
|
|
return match(m, l.X, r)
|
|
case *ast.DeclStmt:
|
|
return match(m, l.Decl, r)
|
|
case *ast.LabeledStmt:
|
|
return match(m, l.Stmt, r)
|
|
case *ast.BlockStmt:
|
|
return match(m, l.List, r)
|
|
case *ast.FieldList:
|
|
return match(m, l.List, r)
|
|
}
|
|
|
|
switch r := r.(type) {
|
|
case *ast.ParenExpr:
|
|
return match(m, l, r.X)
|
|
case *ast.ExprStmt:
|
|
return match(m, l, r.X)
|
|
case *ast.DeclStmt:
|
|
return match(m, l, r.Decl)
|
|
case *ast.LabeledStmt:
|
|
return match(m, l, r.Stmt)
|
|
case *ast.BlockStmt:
|
|
if r == nil {
|
|
return match(m, l, nil)
|
|
}
|
|
return match(m, l, r.List)
|
|
case *ast.FieldList:
|
|
if r == nil {
|
|
return match(m, l, nil)
|
|
}
|
|
return match(m, l, r.List)
|
|
case *ast.BasicLit:
|
|
if r == nil {
|
|
return match(m, l, nil)
|
|
}
|
|
}
|
|
|
|
if l, ok := l.(matcher); ok {
|
|
return l.Match(m, r)
|
|
}
|
|
|
|
if l, ok := l.(Node); ok {
|
|
// Matching of pattern with concrete value
|
|
return matchNodeAST(m, l, r)
|
|
}
|
|
|
|
if l == nil || r == nil {
|
|
return nil, l == r
|
|
}
|
|
|
|
{
|
|
ln, ok1 := l.(ast.Node)
|
|
rn, ok2 := r.(ast.Node)
|
|
if ok1 && ok2 {
|
|
return matchAST(m, ln, rn)
|
|
}
|
|
}
|
|
|
|
{
|
|
obj, ok := l.(types.Object)
|
|
if ok {
|
|
switch r := r.(type) {
|
|
case *ast.Ident:
|
|
return obj, obj == m.TypesInfo.ObjectOf(r)
|
|
case *ast.SelectorExpr:
|
|
return obj, obj == m.TypesInfo.ObjectOf(r.Sel)
|
|
default:
|
|
return obj, false
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
ln, ok1 := l.([]ast.Expr)
|
|
rn, ok2 := r.([]ast.Expr)
|
|
if ok1 || ok2 {
|
|
if ok1 && !ok2 {
|
|
rn = []ast.Expr{r.(ast.Expr)}
|
|
} else if !ok1 && ok2 {
|
|
ln = []ast.Expr{l.(ast.Expr)}
|
|
}
|
|
|
|
if len(ln) != len(rn) {
|
|
return nil, false
|
|
}
|
|
for i, ll := range ln {
|
|
if _, ok := match(m, ll, rn[i]); !ok {
|
|
return nil, false
|
|
}
|
|
}
|
|
return r, true
|
|
}
|
|
}
|
|
|
|
{
|
|
ln, ok1 := l.([]ast.Stmt)
|
|
rn, ok2 := r.([]ast.Stmt)
|
|
if ok1 || ok2 {
|
|
if ok1 && !ok2 {
|
|
rn = []ast.Stmt{r.(ast.Stmt)}
|
|
} else if !ok1 && ok2 {
|
|
ln = []ast.Stmt{l.(ast.Stmt)}
|
|
}
|
|
|
|
if len(ln) != len(rn) {
|
|
return nil, false
|
|
}
|
|
for i, ll := range ln {
|
|
if _, ok := match(m, ll, rn[i]); !ok {
|
|
return nil, false
|
|
}
|
|
}
|
|
return r, true
|
|
}
|
|
}
|
|
|
|
{
|
|
ln, ok1 := l.([]*ast.Field)
|
|
rn, ok2 := r.([]*ast.Field)
|
|
if ok1 || ok2 {
|
|
if ok1 && !ok2 {
|
|
rn = []*ast.Field{r.(*ast.Field)}
|
|
} else if !ok1 && ok2 {
|
|
ln = []*ast.Field{l.(*ast.Field)}
|
|
}
|
|
|
|
if len(ln) != len(rn) {
|
|
return nil, false
|
|
}
|
|
for i, ll := range ln {
|
|
if _, ok := match(m, ll, rn[i]); !ok {
|
|
return nil, false
|
|
}
|
|
}
|
|
return r, true
|
|
}
|
|
}
|
|
|
|
panic(fmt.Sprintf("unsupported comparison: %T and %T", l, r))
|
|
}
|
|
|
|
// Match a Node with an AST node
|
|
func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) {
|
|
switch b := b.(type) {
|
|
case []ast.Stmt:
|
|
// 'a' is not a List or we'd be using its Match
|
|
// implementation.
|
|
|
|
if len(b) != 1 {
|
|
return nil, false
|
|
}
|
|
return match(m, a, b[0])
|
|
case []ast.Expr:
|
|
// 'a' is not a List or we'd be using its Match
|
|
// implementation.
|
|
|
|
if len(b) != 1 {
|
|
return nil, false
|
|
}
|
|
return match(m, a, b[0])
|
|
case ast.Node:
|
|
ra := reflect.ValueOf(a)
|
|
rb := reflect.ValueOf(b).Elem()
|
|
|
|
if ra.Type().Name() != rb.Type().Name() {
|
|
return nil, false
|
|
}
|
|
|
|
for i := 0; i < ra.NumField(); i++ {
|
|
af := ra.Field(i)
|
|
fieldName := ra.Type().Field(i).Name
|
|
bf := rb.FieldByName(fieldName)
|
|
if (bf == reflect.Value{}) {
|
|
panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a))
|
|
}
|
|
ai := af.Interface()
|
|
bi := bf.Interface()
|
|
if ai == nil {
|
|
return b, bi == nil
|
|
}
|
|
if _, ok := match(m, ai.(Node), bi); !ok {
|
|
return b, false
|
|
}
|
|
}
|
|
return b, true
|
|
case nil:
|
|
return nil, a == Nil{}
|
|
default:
|
|
panic(fmt.Sprintf("unhandled type %T", b))
|
|
}
|
|
}
|
|
|
|
// Match two AST nodes
|
|
func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) {
|
|
ra := reflect.ValueOf(a)
|
|
rb := reflect.ValueOf(b)
|
|
|
|
if ra.Type() != rb.Type() {
|
|
return nil, false
|
|
}
|
|
if ra.IsNil() || rb.IsNil() {
|
|
return rb, ra.IsNil() == rb.IsNil()
|
|
}
|
|
|
|
ra = ra.Elem()
|
|
rb = rb.Elem()
|
|
for i := 0; i < ra.NumField(); i++ {
|
|
af := ra.Field(i)
|
|
bf := rb.Field(i)
|
|
if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup {
|
|
continue
|
|
}
|
|
|
|
switch af.Kind() {
|
|
case reflect.Slice:
|
|
if af.Len() != bf.Len() {
|
|
return nil, false
|
|
}
|
|
for j := 0; j < af.Len(); j++ {
|
|
if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok {
|
|
return nil, false
|
|
}
|
|
}
|
|
case reflect.String:
|
|
if af.String() != bf.String() {
|
|
return nil, false
|
|
}
|
|
case reflect.Int:
|
|
if af.Int() != bf.Int() {
|
|
return nil, false
|
|
}
|
|
case reflect.Bool:
|
|
if af.Bool() != bf.Bool() {
|
|
return nil, false
|
|
}
|
|
case reflect.Ptr, reflect.Interface:
|
|
if _, ok := match(m, af.Interface(), bf.Interface()); !ok {
|
|
return nil, false
|
|
}
|
|
default:
|
|
panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface()))
|
|
}
|
|
}
|
|
return b, true
|
|
}
|
|
|
|
func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
if isNil(b.Node) {
|
|
v, ok := m.State[b.Name]
|
|
if ok {
|
|
// Recall value
|
|
return match(m, v, node)
|
|
}
|
|
// Matching anything
|
|
b.Node = Any{}
|
|
}
|
|
|
|
// Store value
|
|
if _, ok := m.State[b.Name]; ok {
|
|
panic(fmt.Sprintf("binding already created: %s", b.Name))
|
|
}
|
|
new, ret := match(m, b.Node, node)
|
|
if ret {
|
|
m.State[b.Name] = new
|
|
}
|
|
return new, ret
|
|
}
|
|
|
|
func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
return node, true
|
|
}
|
|
|
|
func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
v := reflect.ValueOf(node)
|
|
if v.Kind() == reflect.Slice {
|
|
if isNil(l.Head) {
|
|
return node, v.Len() == 0
|
|
}
|
|
if v.Len() == 0 {
|
|
return nil, false
|
|
}
|
|
// OPT(dh): don't check the entire tail if head didn't match
|
|
_, ok1 := match(m, l.Head, v.Index(0).Interface())
|
|
_, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface())
|
|
return node, ok1 && ok2
|
|
}
|
|
// Our empty list does not equal an untyped Go nil. This way, we can
|
|
// tell apart an if with no else and an if with an empty else.
|
|
return nil, false
|
|
}
|
|
|
|
func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
switch o := node.(type) {
|
|
case token.Token:
|
|
if tok, ok := maybeToken(s); ok {
|
|
return match(m, tok, node)
|
|
}
|
|
return nil, false
|
|
case string:
|
|
return o, string(s) == o
|
|
default:
|
|
return nil, false
|
|
}
|
|
}
|
|
|
|
func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
o, ok := node.(token.Token)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return o, token.Token(tok) == o
|
|
}
|
|
|
|
func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
return nil, isNil(node) || reflect.ValueOf(node).IsNil()
|
|
}
|
|
|
|
func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
r, ok := match(m, Ident(builtin), node)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
ident := r.(*ast.Ident)
|
|
obj := m.TypesInfo.ObjectOf(ident)
|
|
if obj != types.Universe.Lookup(ident.Name) {
|
|
return nil, false
|
|
}
|
|
return ident, true
|
|
}
|
|
|
|
func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
r, ok := match(m, Ident(obj), node)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
ident := r.(*ast.Ident)
|
|
|
|
id := m.TypesInfo.ObjectOf(ident)
|
|
_, ok = match(m, obj.Name, ident.Name)
|
|
return id, ok
|
|
}
|
|
|
|
func (fn Function) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
var name string
|
|
var obj types.Object
|
|
|
|
r, ok := match(m, Or{Nodes: []Node{Ident{Any{}}, SelectorExpr{Any{}, Any{}}}}, node)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
switch r := r.(type) {
|
|
case *ast.Ident:
|
|
obj = m.TypesInfo.ObjectOf(r)
|
|
switch obj := obj.(type) {
|
|
case *types.Func:
|
|
// OPT(dh): optimize this similar to code.FuncName
|
|
name = obj.FullName()
|
|
case *types.Builtin:
|
|
name = obj.Name()
|
|
default:
|
|
return nil, false
|
|
}
|
|
case *ast.SelectorExpr:
|
|
var ok bool
|
|
obj, ok = m.TypesInfo.ObjectOf(r.Sel).(*types.Func)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
// OPT(dh): optimize this similar to code.FuncName
|
|
name = obj.(*types.Func).FullName()
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
_, ok = match(m, fn.Name, name)
|
|
return obj, ok
|
|
}
|
|
|
|
func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
for _, opt := range or.Nodes {
|
|
mc := m.fork()
|
|
if ret, ok := match(mc, opt, node); ok {
|
|
m.merge(mc)
|
|
return ret, true
|
|
}
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) {
|
|
_, ok := match(m, not.Node, node)
|
|
if ok {
|
|
return nil, false
|
|
}
|
|
return node, true
|
|
}
|
|
|
|
// Types of fields in go/ast structs that we want to skip when comparing two
// AST nodes structurally.
var (
	// rtTokPos is the type of source positions.
	rtTokPos = reflect.TypeOf(token.Pos(0))
	// rtObject is the type of *ast.Object fields.
	rtObject = reflect.TypeOf((*ast.Object)(nil))
	// rtCommentGroup is the type of *ast.CommentGroup fields.
	rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil))
)
|
|
|
|
var (
|
|
_ matcher = Binding{}
|
|
_ matcher = Any{}
|
|
_ matcher = List{}
|
|
_ matcher = String("")
|
|
_ matcher = Token(0)
|
|
_ matcher = Nil{}
|
|
_ matcher = Builtin{}
|
|
_ matcher = Object{}
|
|
_ matcher = Function{}
|
|
_ matcher = Or{}
|
|
_ matcher = Not{}
|
|
)
|