diff options
Diffstat (limited to 'vendor/google.golang.org/appengine/cmd/aefix')
-rw-r--r-- | vendor/google.golang.org/appengine/cmd/aefix/ae.go | 185 | ||||
-rw-r--r-- | vendor/google.golang.org/appengine/cmd/aefix/ae_test.go | 144 | ||||
-rw-r--r-- | vendor/google.golang.org/appengine/cmd/aefix/fix.go | 848 | ||||
-rw-r--r-- | vendor/google.golang.org/appengine/cmd/aefix/main.go | 258 | ||||
-rw-r--r-- | vendor/google.golang.org/appengine/cmd/aefix/main_test.go | 129 | ||||
-rw-r--r-- | vendor/google.golang.org/appengine/cmd/aefix/typecheck.go | 673 |
6 files changed, 2237 insertions, 0 deletions
diff --git a/vendor/google.golang.org/appengine/cmd/aefix/ae.go b/vendor/google.golang.org/appengine/cmd/aefix/ae.go new file mode 100644 index 000000000..0fe2d4ae9 --- /dev/null +++ b/vendor/google.golang.org/appengine/cmd/aefix/ae.go @@ -0,0 +1,185 @@ +// Copyright 2016 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "path" + "strconv" + "strings" +) + +const ( + ctxPackage = "golang.org/x/net/context" + + newPackageBase = "google.golang.org/" + stutterPackage = false +) + +func init() { + register(fix{ + "ae", + "2016-04-15", + aeFn, + `Update old App Engine APIs to new App Engine APIs`, + }) +} + +// logMethod is the set of methods on appengine.Context used for logging. +var logMethod = map[string]bool{ + "Debugf": true, + "Infof": true, + "Warningf": true, + "Errorf": true, + "Criticalf": true, +} + +// mapPackage turns "appengine" into "google.golang.org/appengine", etc. +func mapPackage(s string) string { + if stutterPackage { + s += "/" + path.Base(s) + } + return newPackageBase + s +} + +func aeFn(f *ast.File) bool { + // During the walk, we track the last thing seen that looks like + // an appengine.Context, and reset it once the walk leaves a func. + var lastContext *ast.Ident + + fixed := false + + // Update imports. + mainImp := "appengine" + for _, imp := range f.Imports { + pth, _ := strconv.Unquote(imp.Path.Value) + if pth == "appengine" || strings.HasPrefix(pth, "appengine/") { + newPth := mapPackage(pth) + imp.Path.Value = strconv.Quote(newPth) + fixed = true + + if pth == "appengine" { + mainImp = newPth + } + } + } + + // Update any API changes. + walk(f, func(n interface{}) { + if ft, ok := n.(*ast.FuncType); ok && ft.Params != nil { + // See if this func has an `appengine.Context arg`. + // If so, remember its identifier. + for _, param := range ft.Params.List { + if !isPkgDot(param.Type, "appengine", "Context") { + continue + } + if len(param.Names) == 1 { + lastContext = param.Names[0] + break + } + } + return + } + + if as, ok := n.(*ast.AssignStmt); ok { + if len(as.Lhs) == 1 && len(as.Rhs) == 1 { + // If this node is an assignment from an appengine.NewContext invocation, + // remember the identifier on the LHS. + if isCall(as.Rhs[0], "appengine", "NewContext") { + if ident, ok := as.Lhs[0].(*ast.Ident); ok { + lastContext = ident + return + } + } + // x (=|:=) appengine.Timeout(y, z) + // should become + // x, _ (=|:=) context.WithTimeout(y, z) + if isCall(as.Rhs[0], "appengine", "Timeout") { + addImport(f, ctxPackage) + as.Lhs = append(as.Lhs, ast.NewIdent("_")) + // isCall already did the type checking. + sel := as.Rhs[0].(*ast.CallExpr).Fun.(*ast.SelectorExpr) + sel.X = ast.NewIdent("context") + sel.Sel = ast.NewIdent("WithTimeout") + fixed = true + return + } + } + return + } + + // If this node is a FuncDecl, we've finished the function, so reset lastContext. + if _, ok := n.(*ast.FuncDecl); ok { + lastContext = nil + return + } + + if call, ok := n.(*ast.CallExpr); ok { + if isPkgDot(call.Fun, "appengine", "Datacenter") && len(call.Args) == 0 { + insertContext(f, call, lastContext) + fixed = true + return + } + if isPkgDot(call.Fun, "taskqueue", "QueueStats") && len(call.Args) == 3 { + call.Args = call.Args[:2] // drop last arg + fixed = true + return + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return + } + if lastContext != nil && refersTo(sel.X, lastContext) && logMethod[sel.Sel.Name] { + // c.Errorf(...) + // should become + // log.Errorf(c, ...) + addImport(f, mapPackage("appengine/log")) + sel.X = &ast.Ident{ // ast.NewIdent doesn't preserve the position. + NamePos: sel.X.Pos(), + Name: "log", + } + insertContext(f, call, lastContext) + fixed = true + return + } + } + }) + + // Change any `appengine.Context` to `context.Context`. + // Do this in a separate walk because the previous walk + // wants to identify "appengine.Context". + walk(f, func(n interface{}) { + expr, ok := n.(ast.Expr) + if ok && isPkgDot(expr, "appengine", "Context") { + addImport(f, ctxPackage) + // isPkgDot did the type checking. + n.(*ast.SelectorExpr).X.(*ast.Ident).Name = "context" + fixed = true + return + } + }) + + // The changes above might remove the need to import "appengine". + // Check if it's used, and drop it if it isn't. + if fixed && !usesImport(f, mainImp) { + deleteImport(f, mainImp) + } + + return fixed +} + +// ctx may be nil. +func insertContext(f *ast.File, call *ast.CallExpr, ctx *ast.Ident) { + if ctx == nil { + // context is unknown, so use a plain "ctx". + ctx = ast.NewIdent("ctx") + } else { + // Create a fresh *ast.Ident so we drop the position information. + ctx = ast.NewIdent(ctx.Name) + } + + call.Args = append([]ast.Expr{ctx}, call.Args...) +} diff --git a/vendor/google.golang.org/appengine/cmd/aefix/ae_test.go b/vendor/google.golang.org/appengine/cmd/aefix/ae_test.go new file mode 100644 index 000000000..21f5695b9 --- /dev/null +++ b/vendor/google.golang.org/appengine/cmd/aefix/ae_test.go @@ -0,0 +1,144 @@ +// Copyright 2016 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(aeTests, nil) +} + +var aeTests = []testCase{ + // Collection of fixes: + // - imports + // - appengine.Timeout -> context.WithTimeout + // - add ctx arg to appengine.Datacenter + // - logging API + { + Name: "ae.0", + In: `package foo + +import ( + "net/http" + "time" + + "appengine" + "appengine/datastore" +) + +func f(w http.ResponseWriter, r *http.Request) { + c := appengine.NewContext(r) + + c = appengine.Timeout(c, 5*time.Second) + err := datastore.ErrNoSuchEntity + c.Errorf("Something interesting happened: %v", err) + _ = appengine.Datacenter() +} +`, + Out: `package foo + +import ( + "net/http" + "time" + + "golang.org/x/net/context" + "google.golang.org/appengine" + "google.golang.org/appengine/datastore" + "google.golang.org/appengine/log" +) + +func f(w http.ResponseWriter, r *http.Request) { + c := appengine.NewContext(r) + + c, _ = context.WithTimeout(c, 5*time.Second) + err := datastore.ErrNoSuchEntity + log.Errorf(c, "Something interesting happened: %v", err) + _ = appengine.Datacenter(c) +} +`, + }, + + // Updating a function that takes an appengine.Context arg. + { + Name: "ae.1", + In: `package foo + +import ( + "appengine" +) + +func LogSomething(c2 appengine.Context) { + c2.Warningf("Stand back! I'm going to try science!") +} +`, + Out: `package foo + +import ( + "golang.org/x/net/context" + "google.golang.org/appengine/log" +) + +func LogSomething(c2 context.Context) { + log.Warningf(c2, "Stand back! I'm going to try science!") +} +`, + }, + + // Less widely used API changes: + // - drop maxTasks arg to taskqueue.QueueStats + { + Name: "ae.2", + In: `package foo + +import ( + "appengine" + "appengine/taskqueue" +) + +func f(ctx appengine.Context) { + stats, err := taskqueue.QueueStats(ctx, []string{"one", "two"}, 0) +} +`, + Out: `package foo + +import ( + "golang.org/x/net/context" + "google.golang.org/appengine/taskqueue" +) + +func f(ctx context.Context) { + stats, err := taskqueue.QueueStats(ctx, []string{"one", "two"}) +} +`, + }, + + // Check that the main "appengine" import will not be dropped + // if an appengine.Context -> context.Context change happens + // but the appengine package is still referenced. + { + Name: "ae.3", + In: `package foo + +import ( + "appengine" + "io" +) + +func f(ctx appengine.Context, w io.Writer) { + _ = appengine.IsDevAppServer() +} +`, + Out: `package foo + +import ( + "golang.org/x/net/context" + "google.golang.org/appengine" + "io" +) + +func f(ctx context.Context, w io.Writer) { + _ = appengine.IsDevAppServer() +} +`, + }, +} diff --git a/vendor/google.golang.org/appengine/cmd/aefix/fix.go b/vendor/google.golang.org/appengine/cmd/aefix/fix.go new file mode 100644 index 000000000..a100be794 --- /dev/null +++ b/vendor/google.golang.org/appengine/cmd/aefix/fix.go @@ -0,0 +1,848 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path" + "reflect" + "strconv" + "strings" +) + +type fix struct { + name string + date string // date that fix was introduced, in YYYY-MM-DD format + f func(*ast.File) bool + desc string +} + +// main runs sort.Sort(byName(fixes)) before printing list of fixes. +type byName []fix + +func (f byName) Len() int { return len(f) } +func (f byName) Swap(i, j int) { f[i], f[j] = f[j], f[i] } +func (f byName) Less(i, j int) bool { return f[i].name < f[j].name } + +// main runs sort.Sort(byDate(fixes)) before applying fixes. +type byDate []fix + +func (f byDate) Len() int { return len(f) } +func (f byDate) Swap(i, j int) { f[i], f[j] = f[j], f[i] } +func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date } + +var fixes []fix + +func register(f fix) { + fixes = append(fixes, f) +} + +// walk traverses the AST x, calling visit(y) for each node y in the tree but +// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt, +// in a bottom-up traversal. +func walk(x interface{}, visit func(interface{})) { + walkBeforeAfter(x, nop, visit) +} + +func nop(interface{}) {} + +// walkBeforeAfter is like walk but calls before(x) before traversing +// x's children and after(x) afterward. +func walkBeforeAfter(x interface{}, before, after func(interface{})) { + before(x) + + switch n := x.(type) { + default: + panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x)) + + case nil: + + // pointers to interfaces + case *ast.Decl: + walkBeforeAfter(*n, before, after) + case *ast.Expr: + walkBeforeAfter(*n, before, after) + case *ast.Spec: + walkBeforeAfter(*n, before, after) + case *ast.Stmt: + walkBeforeAfter(*n, before, after) + + // pointers to struct pointers + case **ast.BlockStmt: + walkBeforeAfter(*n, before, after) + case **ast.CallExpr: + walkBeforeAfter(*n, before, after) + case **ast.FieldList: + walkBeforeAfter(*n, before, after) + case **ast.FuncType: + walkBeforeAfter(*n, before, after) + case **ast.Ident: + walkBeforeAfter(*n, before, after) + case **ast.BasicLit: + walkBeforeAfter(*n, before, after) + + // pointers to slices + case *[]ast.Decl: + walkBeforeAfter(*n, before, after) + case *[]ast.Expr: + walkBeforeAfter(*n, before, after) + case *[]*ast.File: + walkBeforeAfter(*n, before, after) + case *[]*ast.Ident: + walkBeforeAfter(*n, before, after) + case *[]ast.Spec: + walkBeforeAfter(*n, before, after) + case *[]ast.Stmt: + walkBeforeAfter(*n, before, after) + + // These are ordered and grouped to match ../../pkg/go/ast/ast.go + case *ast.Field: + walkBeforeAfter(&n.Names, before, after) + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Tag, before, after) + case *ast.FieldList: + for _, field := range n.List { + walkBeforeAfter(field, before, after) + } + case *ast.BadExpr: + case *ast.Ident: + case *ast.Ellipsis: + walkBeforeAfter(&n.Elt, before, after) + case *ast.BasicLit: + case *ast.FuncLit: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.CompositeLit: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Elts, before, after) + case *ast.ParenExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.SelectorExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.IndexExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Index, before, after) + case *ast.SliceExpr: + walkBeforeAfter(&n.X, before, after) + if n.Low != nil { + walkBeforeAfter(&n.Low, before, after) + } + if n.High != nil { + walkBeforeAfter(&n.High, before, after) + } + case *ast.TypeAssertExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Type, before, after) + case *ast.CallExpr: + walkBeforeAfter(&n.Fun, before, after) + walkBeforeAfter(&n.Args, before, after) + case *ast.StarExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.UnaryExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.BinaryExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Y, before, after) + case *ast.KeyValueExpr: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + + case *ast.ArrayType: + walkBeforeAfter(&n.Len, before, after) + walkBeforeAfter(&n.Elt, before, after) + case *ast.StructType: + walkBeforeAfter(&n.Fields, before, after) + case *ast.FuncType: + walkBeforeAfter(&n.Params, before, after) + if n.Results != nil { + walkBeforeAfter(&n.Results, before, after) + } + case *ast.InterfaceType: + walkBeforeAfter(&n.Methods, before, after) + case *ast.MapType: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + case *ast.ChanType: + walkBeforeAfter(&n.Value, before, after) + + case *ast.BadStmt: + case *ast.DeclStmt: + walkBeforeAfter(&n.Decl, before, after) + case *ast.EmptyStmt: + case *ast.LabeledStmt: + walkBeforeAfter(&n.Stmt, before, after) + case *ast.ExprStmt: + walkBeforeAfter(&n.X, before, after) + case *ast.SendStmt: + walkBeforeAfter(&n.Chan, before, after) + walkBeforeAfter(&n.Value, before, after) + case *ast.IncDecStmt: + walkBeforeAfter(&n.X, before, after) + case *ast.AssignStmt: + walkBeforeAfter(&n.Lhs, before, after) + walkBeforeAfter(&n.Rhs, before, after) + case *ast.GoStmt: + walkBeforeAfter(&n.Call, before, after) + case *ast.DeferStmt: + walkBeforeAfter(&n.Call, before, after) + case *ast.ReturnStmt: + walkBeforeAfter(&n.Results, before, after) + case *ast.BranchStmt: + case *ast.BlockStmt: + walkBeforeAfter(&n.List, before, after) + case *ast.IfStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Body, before, after) + walkBeforeAfter(&n.Else, before, after) + case *ast.CaseClause: + walkBeforeAfter(&n.List, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.SwitchStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Tag, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.TypeSwitchStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Assign, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.CommClause: + walkBeforeAfter(&n.Comm, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.SelectStmt: + walkBeforeAfter(&n.Body, before, after) + case *ast.ForStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Post, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.RangeStmt: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Body, before, after) + + case *ast.ImportSpec: + case *ast.ValueSpec: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Values, before, after) + walkBeforeAfter(&n.Names, before, after) + case *ast.TypeSpec: + walkBeforeAfter(&n.Type, before, after) + + case *ast.BadDecl: + case *ast.GenDecl: + walkBeforeAfter(&n.Specs, before, after) + case *ast.FuncDecl: + if n.Recv != nil { + walkBeforeAfter(&n.Recv, before, after) + } + walkBeforeAfter(&n.Type, before, after) + if n.Body != nil { + walkBeforeAfter(&n.Body, before, after) + } + + case *ast.File: + walkBeforeAfter(&n.Decls, before, after) + + case *ast.Package: + walkBeforeAfter(&n.Files, before, after) + + case []*ast.File: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Decl: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Expr: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []*ast.Ident: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Stmt: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Spec: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + } + after(x) +} + +// imports returns true if f imports path. +func imports(f *ast.File, path string) bool { + return importSpec(f, path) != nil +} + +// importSpec returns the import spec if f imports path, +// or nil otherwise. +func importSpec(f *ast.File, path string) *ast.ImportSpec { + for _, s := range f.Imports { + if importPath(s) == path { + return s + } + } + return nil +} + +// importPath returns the unquoted import path of s, +// or "" if the path is not properly quoted. +func importPath(s *ast.ImportSpec) string { + t, err := strconv.Unquote(s.Path.Value) + if err == nil { + return t + } + return "" +} + +// declImports reports whether gen contains an import of path. +func declImports(gen *ast.GenDecl, path string) bool { + if gen.Tok != token.IMPORT { + return false + } + for _, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + if importPath(impspec) == path { + return true + } + } + return false +} + +// isPkgDot returns true if t is the expression "pkg.name" +// where pkg is an imported identifier. +func isPkgDot(t ast.Expr, pkg, name string) bool { + sel, ok := t.(*ast.SelectorExpr) + return ok && isTopName(sel.X, pkg) && sel.Sel.String() == name +} + +// isPtrPkgDot returns true if f is the expression "*pkg.name" +// where pkg is an imported identifier. +func isPtrPkgDot(t ast.Expr, pkg, name string) bool { + ptr, ok := t.(*ast.StarExpr) + return ok && isPkgDot(ptr.X, pkg, name) +} + +// isTopName returns true if n is a top-level unresolved identifier with the given name. +func isTopName(n ast.Expr, name string) bool { + id, ok := n.(*ast.Ident) + return ok && id.Name == name && id.Obj == nil +} + +// isName returns true if n is an identifier with the given name. +func isName(n ast.Expr, name string) bool { + id, ok := n.(*ast.Ident) + return ok && id.String() == name +} + +// isCall returns true if t is a call to pkg.name. +func isCall(t ast.Expr, pkg, name string) bool { + call, ok := t.(*ast.CallExpr) + return ok && isPkgDot(call.Fun, pkg, name) +} + +// If n is an *ast.Ident, isIdent returns it; otherwise isIdent returns nil. +func isIdent(n interface{}) *ast.Ident { + id, _ := n.(*ast.Ident) + return id +} + +// refersTo returns true if n is a reference to the same object as x. +func refersTo(n ast.Node, x *ast.Ident) bool { + id, ok := n.(*ast.Ident) + // The test of id.Name == x.Name handles top-level unresolved + // identifiers, which all have Obj == nil. + return ok && id.Obj == x.Obj && id.Name == x.Name +} + +// isBlank returns true if n is the blank identifier. +func isBlank(n ast.Expr) bool { + return isName(n, "_") +} + +// isEmptyString returns true if n is an empty string literal. +func isEmptyString(n ast.Expr) bool { + lit, ok := n.(*ast.BasicLit) + return ok && lit.Kind == token.STRING && len(lit.Value) == 2 +} + +func warn(pos token.Pos, msg string, args ...interface{}) { + if pos.IsValid() { + msg = "%s: " + msg + arg1 := []interface{}{fset.Position(pos).String()} + args = append(arg1, args...) + } + fmt.Fprintf(os.Stderr, msg+"\n", args...) +} + +// countUses returns the number of uses of the identifier x in scope. +func countUses(x *ast.Ident, scope []ast.Stmt) int { + count := 0 + ff := func(n interface{}) { + if n, ok := n.(ast.Node); ok && refersTo(n, x) { + count++ + } + } + for _, n := range scope { + walk(n, ff) + } + return count +} + +// rewriteUses replaces all uses of the identifier x and !x in scope +// with f(x.Pos()) and fnot(x.Pos()). +func rewriteUses(x *ast.Ident, f, fnot func(token.Pos) ast.Expr, scope []ast.Stmt) { + var lastF ast.Expr + ff := func(n interface{}) { + ptr, ok := n.(*ast.Expr) + if !ok { + return + } + nn := *ptr + + // The child node was just walked and possibly replaced. + // If it was replaced and this is a negation, replace with fnot(p). + not, ok := nn.(*ast.UnaryExpr) + if ok && not.Op == token.NOT && not.X == lastF { + *ptr = fnot(nn.Pos()) + return + } + if refersTo(nn, x) { + lastF = f(nn.Pos()) + *ptr = lastF + } + } + for _, n := range scope { + walk(n, ff) + } +} + +// assignsTo returns true if any of the code in scope assigns to or takes the address of x. +func assignsTo(x *ast.Ident, scope []ast.Stmt) bool { + assigned := false + ff := func(n interface{}) { + if assigned { + return + } + switch n := n.(type) { + case *ast.UnaryExpr: + // use of &x + if n.Op == token.AND && refersTo(n.X, x) { + assigned = true + return + } + case *ast.AssignStmt: + for _, l := range n.Lhs { + if refersTo(l, x) { + assigned = true + return + } + } + } + } + for _, n := range scope { + if assigned { + break + } + walk(n, ff) + } + return assigned +} + +// newPkgDot returns an ast.Expr referring to "pkg.name" at position pos. +func newPkgDot(pos token.Pos, pkg, name string) ast.Expr { + return &ast.SelectorExpr{ + X: &ast.Ident{ + NamePos: pos, + Name: pkg, + }, + Sel: &ast.Ident{ + NamePos: pos, + Name: name, + }, + } +} + +// renameTop renames all references to the top-level name old. +// It returns true if it makes any changes. +func renameTop(f *ast.File, old, new string) bool { + var fixed bool + + // Rename any conflicting imports + // (assuming package name is last element of path). + for _, s := range f.Imports { + if s.Name != nil { + if s.Name.Name == old { + s.Name.Name = new + fixed = true + } + } else { + _, thisName := path.Split(importPath(s)) + if thisName == old { + s.Name = ast.NewIdent(new) + fixed = true + } + } + } + + // Rename any top-level declarations. + for _, d := range f.Decls { + switch d := d.(type) { + case *ast.FuncDecl: + if d.Recv == nil && d.Name.Name == old { + d.Name.Name = new + d.Name.Obj.Name = new + fixed = true + } + case *ast.GenDecl: + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if s.Name.Name == old { + s.Name.Name = new + s.Name.Obj.Name = new + fixed = true + } + case *ast.ValueSpec: + for _, n := range s.Names { + if n.Name == old { + n.Name = new + n.Obj.Name = new + fixed = true + } + } + } + } + } + } + + // Rename top-level old to new, both unresolved names + // (probably defined in another file) and names that resolve + // to a declaration we renamed. + walk(f, func(n interface{}) { + id, ok := n.(*ast.Ident) + if ok && isTopName(id, old) { + id.Name = new + fixed = true + } + if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new { + id.Name = id.Obj.Name + fixed = true + } + }) + + return fixed +} + +// matchLen returns the length of the longest prefix shared by x and y. +func matchLen(x, y string) int { + i := 0 + for i < len(x) && i < len(y) && x[i] == y[i] { + i++ + } + return i +} + +// addImport adds the import path to the file f, if absent. +func addImport(f *ast.File, ipath string) (added bool) { + if imports(f, ipath) { + return false + } + + // Determine name of import. + // Assume added imports follow convention of using last element. + _, name := path.Split(ipath) + + // Rename any conflicting top-level references from name to name_. + renameTop(f, name, name+"_") + + newImport := &ast.ImportSpec{ + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(ipath), + }, + } + + // Find an import decl to add to. + var ( + bestMatch = -1 + lastImport = -1 + impDecl *ast.GenDecl + impIndex = -1 + ) + for i, decl := range f.Decls { + gen, ok := decl.(*ast.GenDecl) + if ok && gen.Tok == token.IMPORT { + lastImport = i + // Do not add to import "C", to avoid disrupting the + // association with its doc comment, breaking cgo. + if declImports(gen, "C") { + continue + } + + // Compute longest shared prefix with imports in this block. + for j, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + n := matchLen(importPath(impspec), ipath) + if n > bestMatch { + bestMatch = n + impDecl = gen + impIndex = j + } + } + } + } + + // If no import decl found, add one after the last import. + if impDecl == nil { + impDecl = &ast.GenDecl{ + Tok: token.IMPORT, + } + f.Decls = append(f.Decls, nil) + copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:]) + f.Decls[lastImport+1] = impDecl + } + + // Ensure the import decl has parentheses, if needed. + if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() { + impDecl.Lparen = impDecl.Pos() + } + + insertAt := impIndex + 1 + if insertAt == 0 { + insertAt = len(impDecl.Specs) + } + impDecl.Specs = append(impDecl.Specs, nil) + copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:]) + impDecl.Specs[insertAt] = newImport + if insertAt > 0 { + // Assign same position as the previous import, + // so that the sorter sees it as being in the same block. + prev := impDecl.Specs[insertAt-1] + newImport.Path.ValuePos = prev.Pos() + newImport.EndPos = prev.Pos() + } + + f.Imports = append(f.Imports, newImport) + return true +} + +// deleteImport deletes the import path from the file f, if present. +func deleteImport(f *ast.File, path string) (deleted bool) { + oldImport := importSpec(f, path) + + // Find the import node that imports path, if any. + for i, decl := range f.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { + continue + } + for j, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + if oldImport != impspec { + continue + } + + // We found an import spec that imports path. + // Delete it. + deleted = true + copy(gen.Specs[j:], gen.Specs[j+1:]) + gen.Specs = gen.Specs[:len(gen.Specs)-1] + + // If this was the last import spec in this decl, + // delete the decl, too. + if len(gen.Specs) == 0 { + copy(f.Decls[i:], f.Decls[i+1:]) + f.Decls = f.Decls[:len(f.Decls)-1] + } else if len(gen.Specs) == 1 { + gen.Lparen = token.NoPos // drop parens + } + if j > 0 { + // We deleted an entry but now there will be + // a blank line-sized hole where the import was. + // Close the hole by making the previous + // import appear to "end" where this one did. + gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End() + } + break + } + } + + // Delete it from f.Imports. + for i, imp := range f.Imports { + if imp == oldImport { + copy(f.Imports[i:], f.Imports[i+1:]) + f.Imports = f.Imports[:len(f.Imports)-1] + break + } + } + + return +} + +// rewriteImport rewrites any import of path oldPath to path newPath. +func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) { + for _, imp := range f.Imports { + if importPath(imp) == oldPath { + rewrote = true + // record old End, because the default is to compute + // it using the length of imp.Path.Value. + imp.EndPos = imp.End() + imp.Path.Value = strconv.Quote(newPath) + } + } + return +} + +func usesImport(f *ast.File, path string) (used bool) { + spec := importSpec(f, path) + if spec == nil { + return + } + + name := spec.Name.String() + switch name { + case "<nil>": + // If the package name is not explicitly specified, + // make an educated guess. This is not guaranteed to be correct. + lastSlash := strings.LastIndex(path, "/") + if lastSlash == -1 { + name = path + } else { + name = path[lastSlash+1:] + } + case "_", ".": + // Not sure if this import is used - err on the side of caution. + return true + } + + walk(f, func(n interface{}) { + sel, ok := n.(*ast.SelectorExpr) + if ok && isTopName(sel.X, name) { + used = true + } + }) + + return +} + +func expr(s string) ast.Expr { + x, err := parser.ParseExpr(s) + if err != nil { + panic("parsing " + s + ": " + err.Error()) + } + // Remove position information to avoid spurious newlines. + killPos(reflect.ValueOf(x)) + return x +} + +var posType = reflect.TypeOf(token.Pos(0)) + +func killPos(v reflect.Value) { + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + if !v.IsNil() { + killPos(v.Elem()) + } + case reflect.Slice: + n := v.Len() + for i := 0; i < n; i++ { + killPos(v.Index(i)) + } + case reflect.Struct: + n := v.NumField() + for i := 0; i < n; i++ { + f := v.Field(i) + if f.Type() == posType { + f.SetInt(0) + continue + } + killPos(f) + } + } +} + +// A Rename describes a single renaming. +type rename struct { + OldImport string // only apply rename if this import is present + NewImport string // add this import during rewrite + Old string // old name: p.T or *p.T + New string // new name: p.T or *p.T +} + +func renameFix(tab []rename) func(*ast.File) bool { + return func(f *ast.File) bool { + return renameFixTab(f, tab) + } +} + +func parseName(s string) (ptr bool, pkg, nam string) { + i := strings.Index(s, ".") + if i < 0 { + panic("parseName: invalid name " + s) + } + if strings.HasPrefix(s, "*") { + ptr = true + s = s[1:] + i-- + } + pkg = s[:i] + nam = s[i+1:] + return +} + +func renameFixTab(f *ast.File, tab []rename) bool { + fixed := false + added := map[string]bool{} + check := map[string]bool{} + for _, t := range tab { + if !imports(f, t.OldImport) { + continue + } + optr, opkg, onam := parseName(t.Old) + walk(f, func(n interface{}) { + np, ok := n.(*ast.Expr) + if !ok { + return + } + x := *np + if optr { + p, ok := x.(*ast.StarExpr) + if !ok { + return + } + x = p.X + } + if !isPkgDot(x, opkg, onam) { + return + } + if t.NewImport != "" && !added[t.NewImport] { + addImport(f, t.NewImport) + added[t.NewImport] = true + } + *np = expr(t.New) + check[t.OldImport] = true + fixed = true + }) + } + + for ipath := range check { + if !usesImport(f, ipath) { + deleteImport(f, ipath) + } + } + return fixed +} diff --git a/vendor/google.golang.org/appengine/cmd/aefix/main.go b/vendor/google.golang.org/appengine/cmd/aefix/main.go new file mode 100644 index 000000000..8e193a6ad --- /dev/null +++ b/vendor/google.golang.org/appengine/cmd/aefix/main.go @@ -0,0 +1,258 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/scanner" + "go/token" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" +) + +var ( + fset = token.NewFileSet() + exitCode = 0 +) + +var allowedRewrites = flag.String("r", "", + "restrict the rewrites to this comma-separated list") + +var forceRewrites = flag.String("force", "", + "force these fixes to run even if the code looks updated") + +var allowed, force map[string]bool + +var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") + +// enable for debugging fix failures +const debug = false // display incorrectly reformatted source and exit + +func usage() { + fmt.Fprintf(os.Stderr, "usage: aefix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") + sort.Sort(byName(fixes)) + for _, f := range fixes { + fmt.Fprintf(os.Stderr, "\n%s\n", f.name) + desc := strings.TrimSpace(f.desc) + desc = strings.Replace(desc, "\n", "\n\t", -1) + fmt.Fprintf(os.Stderr, "\t%s\n", desc) + } + os.Exit(2) +} + +func main() { + flag.Usage = usage + flag.Parse() + + sort.Sort(byDate(fixes)) + + if *allowedRewrites != "" { + allowed = make(map[string]bool) + for _, f := range strings.Split(*allowedRewrites, ",") { + allowed[f] = true + } + } + + if *forceRewrites != "" { + force = make(map[string]bool) + for _, f := range strings.Split(*forceRewrites, ",") { + force[f] = true + } + } + + if flag.NArg() == 0 { + if err := processFile("standard input", true); err != nil { + report(err) + } + os.Exit(exitCode) + } + + for i := 0; i < flag.NArg(); i++ { + path := flag.Arg(i) + switch dir, err := os.Stat(path); { + case err != nil: + report(err) + case dir.IsDir(): + walkDir(path) + default: + if err := processFile(path, false); err != nil { + report(err) + } + } + } + + os.Exit(exitCode) +} + +const parserMode = parser.ParseComments + +func gofmtFile(f *ast.File) ([]byte, error) { + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func processFile(filename string, useStdin bool) error { + var f *os.File + var err error + var fixlog bytes.Buffer + + if useStdin { + f = os.Stdin + } else { + f, err = os.Open(filename) + if err != nil { + return err + } + defer f.Close() + } + + src, err := ioutil.ReadAll(f) + if err != nil { + return err + } + + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err != nil { + return err + } + + // Apply all fixes to file. + newFile := file + fixed := false + for _, fix := range fixes { + if allowed != nil && !allowed[fix.name] { + continue + } + if fix.f(newFile) { + fixed = true + fmt.Fprintf(&fixlog, " %s", fix.name) + + // AST changed. + // Print and parse, to update any missing scoping + // or position information for subsequent fixers. + newSrc, err := gofmtFile(newFile) + if err != nil { + return err + } + newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) + if err != nil { + if debug { + fmt.Printf("%s", newSrc) + report(err) + os.Exit(exitCode) + } + return err + } + } + } + if !fixed { + return nil + } + fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) + + // Print AST. We did that after each fix, so this appears + // redundant, but it is necessary to generate gofmt-compatible + // source code in a few cases. The official gofmt style is the + // output of the printer run on a standard AST generated by the parser, + // but the source we generated inside the loop above is the + // output of the printer run on a mangled AST generated by a fixer. + newSrc, err := gofmtFile(newFile) + if err != nil { + return err + } + + if *doDiff { + data, err := diff(src, newSrc) + if err != nil { + return fmt.Errorf("computing diff: %s", err) + } + fmt.Printf("diff %s fixed/%s\n", filename, filename) + os.Stdout.Write(data) + return nil + } + + if useStdin { + os.Stdout.Write(newSrc) + return nil + } + + return ioutil.WriteFile(f.Name(), newSrc, 0) +} + +var gofmtBuf bytes.Buffer + +func gofmt(n interface{}) string { + gofmtBuf.Reset() + if err := format.Node(&gofmtBuf, fset, n); err != nil { + return "<" + err.Error() + ">" + } + return gofmtBuf.String() +} + +func report(err error) { + scanner.PrintError(os.Stderr, err) + exitCode = 2 +} + +func walkDir(path string) { + filepath.Walk(path, visitFile) +} + +func visitFile(path string, f os.FileInfo, err error) error { + if err == nil && isGoFile(f) { + err = processFile(path, false) + } + if err != nil { + report(err) + } + return nil +} + +func isGoFile(f os.FileInfo) bool { + // ignore non-Go files + name := f.Name() + return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") +} + +func diff(b1, b2 []byte) (data []byte, err error) { + f1, err := ioutil.TempFile("", "go-fix") + if err != nil { + return nil, err + } + defer os.Remove(f1.Name()) + defer f1.Close() + + f2, err := ioutil.TempFile("", "go-fix") + if err != nil { + return nil, err + } + defer os.Remove(f2.Name()) + defer f2.Close() + + f1.Write(b1) + f2.Write(b2) + + data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput() + if len(data) > 0 { + // diff exits with a non-zero status when the files don't match. + // Ignore that failure as long as we get output. + err = nil + } + return +} diff --git a/vendor/google.golang.org/appengine/cmd/aefix/main_test.go b/vendor/google.golang.org/appengine/cmd/aefix/main_test.go new file mode 100644 index 000000000..2151bf29e --- /dev/null +++ b/vendor/google.golang.org/appengine/cmd/aefix/main_test.go @@ -0,0 +1,129 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "go/parser" + "strings" + "testing" +) + +type testCase struct { + Name string + Fn func(*ast.File) bool + In string + Out string +} + +var testCases []testCase + +func addTestCases(t []testCase, fn func(*ast.File) bool) { + // Fill in fn to avoid repetition in definitions. + if fn != nil { + for i := range t { + if t[i].Fn == nil { + t[i].Fn = fn + } + } + } + testCases = append(testCases, t...) +} + +func fnop(*ast.File) bool { return false } + +func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) { + file, err := parser.ParseFile(fset, desc, in, parserMode) + if err != nil { + t.Errorf("%s: parsing: %v", desc, err) + return + } + + outb, err := gofmtFile(file) + if err != nil { + t.Errorf("%s: printing: %v", desc, err) + return + } + if s := string(outb); in != s && mustBeGofmt { + t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", + desc, desc, in, desc, s) + tdiff(t, in, s) + return + } + + if fn == nil { + for _, fix := range fixes { + if fix.f(file) { + fixed = true + } + } + } else { + fixed = fn(file) + } + + outb, err = gofmtFile(file) + if err != nil { + t.Errorf("%s: printing: %v", desc, err) + return + } + + return string(outb), fixed, true +} + +func TestRewrite(t *testing.T) { + for _, tt := range testCases { + // Apply fix: should get tt.Out. + out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true) + if !ok { + continue + } + + // reformat to get printing right + out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false) + if !ok { + continue + } + + if out != tt.Out { + t.Errorf("%s: incorrect output.\n", tt.Name) + if !strings.HasPrefix(tt.Name, "testdata/") { + t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out) + } + tdiff(t, out, tt.Out) + continue + } + + if changed := out != tt.In; changed != fixed { + t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed) + continue + } + + // Should not change if run again. + out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true) + if !ok { + continue + } + + if fixed2 { + t.Errorf("%s: applied fixes during second round", tt.Name) + continue + } + + if out2 != out { + t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s", + tt.Name, out, out2) + tdiff(t, out, out2) + } + } +} + +func tdiff(t *testing.T, a, b string) { + data, err := diff([]byte(a), []byte(b)) + if err != nil { + t.Error(err) + return + } + t.Error(string(data)) +} diff --git a/vendor/google.golang.org/appengine/cmd/aefix/typecheck.go b/vendor/google.golang.org/appengine/cmd/aefix/typecheck.go new file mode 100644 index 000000000..d54d37547 --- /dev/null +++ b/vendor/google.golang.org/appengine/cmd/aefix/typecheck.go @@ -0,0 +1,673 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "os" + "reflect" + "strings" +) + +// Partial type checker. +// +// The fact that it is partial is very important: the input is +// an AST and a description of some type information to +// assume about one or more packages, but not all the +// packages that the program imports. The checker is +// expected to do as much as it can with what it has been +// given. There is not enough information supplied to do +// a full type check, but the type checker is expected to +// apply information that can be derived from variable +// declarations, function and method returns, and type switches +// as far as it can, so that the caller can still tell the types +// of expression relevant to a particular fix. +// +// TODO(rsc,gri): Replace with go/typechecker. +// Doing that could be an interesting test case for go/typechecker: +// the constraints about working with partial information will +// likely exercise it in interesting ways. The ideal interface would +// be to pass typecheck a map from importpath to package API text +// (Go source code), but for now we use data structures (TypeConfig, Type). +// +// The strings mostly use gofmt form. +// +// A Field or FieldList has as its type a comma-separated list +// of the types of the fields. For example, the field list +// x, y, z int +// has type "int, int, int". + +// The prefix "type " is the type of a type. +// For example, given +// var x int +// type T int +// x's type is "int" but T's type is "type int". +// mkType inserts the "type " prefix. +// getType removes it. +// isType tests for it. + +func mkType(t string) string { + return "type " + t +} + +func getType(t string) string { + if !isType(t) { + return "" + } + return t[len("type "):] +} + +func isType(t string) bool { + return strings.HasPrefix(t, "type ") +} + +// TypeConfig describes the universe of relevant types. +// For ease of creation, the types are all referred to by string +// name (e.g., "reflect.Value"). TypeByName is the only place +// where the strings are resolved. + +type TypeConfig struct { + Type map[string]*Type + Var map[string]string + Func map[string]string +} + +// typeof returns the type of the given name, which may be of +// the form "x" or "p.X". +func (cfg *TypeConfig) typeof(name string) string { + if cfg.Var != nil { + if t := cfg.Var[name]; t != "" { + return t + } + } + if cfg.Func != nil { + if t := cfg.Func[name]; t != "" { + return "func()" + t + } + } + return "" +} + +// Type describes the Fields and Methods of a type. +// If the field or method cannot be found there, it is next +// looked for in the Embed list. +type Type struct { + Field map[string]string // map field name to type + Method map[string]string // map method name to comma-separated return types (should start with "func ") + Embed []string // list of types this type embeds (for extra methods) + Def string // definition of named type +} + +// dot returns the type of "typ.name", making its decision +// using the type information in cfg. +func (typ *Type) dot(cfg *TypeConfig, name string) string { + if typ.Field != nil { + if t := typ.Field[name]; t != "" { + return t + } + } + if typ.Method != nil { + if t := typ.Method[name]; t != "" { + return t + } + } + + for _, e := range typ.Embed { + etyp := cfg.Type[e] + if etyp != nil { + if t := etyp.dot(cfg, name); t != "" { + return t + } + } + } + + return "" +} + +// typecheck type checks the AST f assuming the information in cfg. +// It returns two maps with type information: +// typeof maps AST nodes to type information in gofmt string form. +// assign maps type strings to lists of expressions that were assigned +// to values of another type that were assigned to that type. +func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) { + typeof = make(map[interface{}]string) + assign = make(map[string][]interface{}) + cfg1 := &TypeConfig{} + *cfg1 = *cfg // make copy so we can add locally + copied := false + + // gather function declarations + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + typecheck1(cfg, fn.Type, typeof, assign) + t := typeof[fn.Type] + if fn.Recv != nil { + // The receiver must be a type. + rcvr := typeof[fn.Recv] + if !isType(rcvr) { + if len(fn.Recv.List) != 1 { + continue + } + rcvr = mkType(gofmt(fn.Recv.List[0].Type)) + typeof[fn.Recv.List[0].Type] = rcvr + } + rcvr = getType(rcvr) + if rcvr != "" && rcvr[0] == '*' { + rcvr = rcvr[1:] + } + typeof[rcvr+"."+fn.Name.Name] = t + } else { + if isType(t) { + t = getType(t) + } else { + t = gofmt(fn.Type) + } + typeof[fn.Name] = t + + // Record typeof[fn.Name.Obj] for future references to fn.Name. + typeof[fn.Name.Obj] = t + } + } + + // gather struct declarations + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if ok { + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if cfg1.Type[s.Name.Name] != nil { + break + } + if !copied { + copied = true + // Copy map lazily: it's time. + cfg1.Type = make(map[string]*Type) + for k, v := range cfg.Type { + cfg1.Type[k] = v + } + } + t := &Type{Field: map[string]string{}} + cfg1.Type[s.Name.Name] = t + switch st := s.Type.(type) { + case *ast.StructType: + for _, f := range st.Fields.List { + for _, n := range f.Names { + t.Field[n.Name] = gofmt(f.Type) + } + } + case *ast.ArrayType, *ast.StarExpr, *ast.MapType: + t.Def = gofmt(st) + } + } + } + } + } + + typecheck1(cfg1, f, typeof, assign) + return typeof, assign +} + +func makeExprList(a []*ast.Ident) []ast.Expr { + var b []ast.Expr + for _, x := range a { + b = append(b, x) + } + return b +} + +// Typecheck1 is the recursive form of typecheck. +// It is like typecheck but adds to the information in typeof +// instead of allocating a new map. +func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) { + // set sets the type of n to typ. + // If isDecl is true, n is being declared. + set := func(n ast.Expr, typ string, isDecl bool) { + if typeof[n] != "" || typ == "" { + if typeof[n] != typ { + assign[typ] = append(assign[typ], n) + } + return + } + typeof[n] = typ + + // If we obtained typ from the declaration of x + // propagate the type to all the uses. + // The !isDecl case is a cheat here, but it makes + // up in some cases for not paying attention to + // struct fields. The real type checker will be + // more accurate so we won't need the cheat. + if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") { + typeof[id.Obj] = typ + } + } + + // Type-check an assignment lhs = rhs. + // If isDecl is true, this is := so we can update + // the types of the objects that lhs refers to. + typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) { + if len(lhs) > 1 && len(rhs) == 1 { + if _, ok := rhs[0].(*ast.CallExpr); ok { + t := split(typeof[rhs[0]]) + // Lists should have same length but may not; pair what can be paired. + for i := 0; i < len(lhs) && i < len(t); i++ { + set(lhs[i], t[i], isDecl) + } + return + } + } + if len(lhs) == 1 && len(rhs) == 2 { + // x = y, ok + rhs = rhs[:1] + } else if len(lhs) == 2 && len(rhs) == 1 { + // x, ok = y + lhs = lhs[:1] + } + + // Match as much as we can. + for i := 0; i < len(lhs) && i < len(rhs); i++ { + x, y := lhs[i], rhs[i] + if typeof[y] != "" { + set(x, typeof[y], isDecl) + } else { + set(y, typeof[x], false) + } + } + } + + expand := func(s string) string { + typ := cfg.Type[s] + if typ != nil && typ.Def != "" { + return typ.Def + } + return s + } + + // The main type check is a recursive algorithm implemented + // by walkBeforeAfter(n, before, after). + // Most of it is bottom-up, but in a few places we need + // to know the type of the function we are checking. + // The before function records that information on + // the curfn stack. + var curfn []*ast.FuncType + + before := func(n interface{}) { + // push function type on stack + switch n := n.(type) { + case *ast.FuncDecl: + curfn = append(curfn, n.Type) + case *ast.FuncLit: + curfn = append(curfn, n.Type) + } + } + + // After is the real type checker. + after := func(n interface{}) { + if n == nil { + return + } + if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace + defer func() { + if t := typeof[n]; t != "" { + pos := fset.Position(n.(ast.Node).Pos()) + fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t) + } + }() + } + + switch n := n.(type) { + case *ast.FuncDecl, *ast.FuncLit: + // pop function type off stack + curfn = curfn[:len(curfn)-1] + + case *ast.FuncType: + typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results]))) + + case *ast.FieldList: + // Field list is concatenation of sub-lists. + t := "" + for _, field := range n.List { + if t != "" { + t += ", " + } + t += typeof[field] + } + typeof[n] = t + + case *ast.Field: + // Field is one instance of the type per name. + all := "" + t := typeof[n.Type] + if !isType(t) { + // Create a type, because it is typically *T or *p.T + // and we might care about that type. + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + if len(n.Names) == 0 { + all = t + } else { + for _, id := range n.Names { + if all != "" { + all += ", " + } + all += t + typeof[id.Obj] = t + typeof[id] = t + } + } + typeof[n] = all + + case *ast.ValueSpec: + // var declaration. Use type if present. + if n.Type != nil { + t := typeof[n.Type] + if !isType(t) { + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + for _, id := range n.Names { + set(id, t, true) + } + } + // Now treat same as assignment. + typecheckAssign(makeExprList(n.Names), n.Values, true) + + case *ast.AssignStmt: + typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE) + + case *ast.Ident: + // Identifier can take its type from underlying object. + if t := typeof[n.Obj]; t != "" { + typeof[n] = t + } + + case *ast.SelectorExpr: + // Field or method. + name := n.Sel.Name + if t := typeof[n.X]; t != "" { + if strings.HasPrefix(t, "*") { + t = t[1:] // implicit * + } + if typ := cfg.Type[t]; typ != nil { + if t := typ.dot(cfg, name); t != "" { + typeof[n] = t + return + } + } + tt := typeof[t+"."+name] + if isType(tt) { + typeof[n] = getType(tt) + return + } + } + // Package selector. + if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil { + str := x.Name + "." + name + if cfg.Type[str] != nil { + typeof[n] = mkType(str) + return + } + if t := cfg.typeof(x.Name + "." + name); t != "" { + typeof[n] = t + return + } + } + + case *ast.CallExpr: + // make(T) has type T. + if isTopName(n.Fun, "make") && len(n.Args) >= 1 { + typeof[n] = gofmt(n.Args[0]) + return + } + // new(T) has type *T + if isTopName(n.Fun, "new") && len(n.Args) == 1 { + typeof[n] = "*" + gofmt(n.Args[0]) + return + } + // Otherwise, use type of function to determine arguments. + t := typeof[n.Fun] + in, out := splitFunc(t) + if in == nil && out == nil { + return + } + typeof[n] = join(out) + for i, arg := range n.Args { + if i >= len(in) { + break + } + if typeof[arg] == "" { + typeof[arg] = in[i] + } + } + + case *ast.TypeAssertExpr: + // x.(type) has type of x. + if n.Type == nil { + typeof[n] = typeof[n.X] + return + } + // x.(T) has type T. + if t := typeof[n.Type]; isType(t) { + typeof[n] = getType(t) + } else { + typeof[n] = gofmt(n.Type) + } + + case *ast.SliceExpr: + // x[i:j] has type of x. + typeof[n] = typeof[n.X] + + case *ast.IndexExpr: + // x[i] has key type of x's type. + t := expand(typeof[n.X]) + if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") { + // Lazy: assume there are no nested [] in the array + // length or map key type. + if i := strings.Index(t, "]"); i >= 0 { + typeof[n] = t[i+1:] + } + } + + case *ast.StarExpr: + // *x for x of type *T has type T when x is an expr. + // We don't use the result when *x is a type, but + // compute it anyway. + t := expand(typeof[n.X]) + if isType(t) { + typeof[n] = "type *" + getType(t) + } else if strings.HasPrefix(t, "*") { + typeof[n] = t[len("*"):] + } + + case *ast.UnaryExpr: + // &x for x of type T has type *T. + t := typeof[n.X] + if t != "" && n.Op == token.AND { + typeof[n] = "*" + t + } + + case *ast.CompositeLit: + // T{...} has type T. + typeof[n] = gofmt(n.Type) + + case *ast.ParenExpr: + // (x) has type of x. + typeof[n] = typeof[n.X] + + case *ast.RangeStmt: + t := expand(typeof[n.X]) + if t == "" { + return + } + var key, value string + if t == "string" { + key, value = "int", "rune" + } else if strings.HasPrefix(t, "[") { + key = "int" + if i := strings.Index(t, "]"); i >= 0 { + value = t[i+1:] + } + } else if strings.HasPrefix(t, "map[") { + if i := strings.Index(t, "]"); i >= 0 { + key, value = t[4:i], t[i+1:] + } + } + changed := false + if n.Key != nil && key != "" { + changed = true + set(n.Key, key, n.Tok == token.DEFINE) + } + if n.Value != nil && value != "" { + changed = true + set(n.Value, value, n.Tok == token.DEFINE) + } + // Ugly failure of vision: already type-checked body. + // Do it again now that we have that type info. + if changed { + typecheck1(cfg, n.Body, typeof, assign) + } + + case *ast.TypeSwitchStmt: + // Type of variable changes for each case in type switch, + // but go/parser generates just one variable. + // Repeat type check for each case with more precise + // type information. + as, ok := n.Assign.(*ast.AssignStmt) + if !ok { + return + } + varx, ok := as.Lhs[0].(*ast.Ident) + if !ok { + return + } + t := typeof[varx] + for _, cas := range n.Body.List { + cas := cas.(*ast.CaseClause) + if len(cas.List) == 1 { + // Variable has specific type only when there is + // exactly one type in the case list. + if tt := typeof[cas.List[0]]; isType(tt) { + tt = getType(tt) + typeof[varx] = tt + typeof[varx.Obj] = tt + typecheck1(cfg, cas.Body, typeof, assign) + } + } + } + // Restore t. + typeof[varx] = t + typeof[varx.Obj] = t + + case *ast.ReturnStmt: + if len(curfn) == 0 { + // Probably can't happen. + return + } + f := curfn[len(curfn)-1] + res := n.Results + if f.Results != nil { + t := split(typeof[f.Results]) + for i := 0; i < len(res) && i < len(t); i++ { + set(res[i], t[i], false) + } + } + } + } + walkBeforeAfter(f, before, after) +} + +// Convert between function type strings and lists of types. +// Using strings makes this a little harder, but it makes +// a lot of the rest of the code easier. This will all go away +// when we can use go/typechecker directly. + +// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"]. +func splitFunc(s string) (in, out []string) { + if !strings.HasPrefix(s, "func(") { + return nil, nil + } + + i := len("func(") // index of beginning of 'in' arguments + nparen := 0 + for j := i; j < len(s); j++ { + switch s[j] { + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // found end of parameter list + out := strings.TrimSpace(s[j+1:]) + if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' { + out = out[1 : len(out)-1] + } + return split(s[i:j]), split(out) + } + } + } + return nil, nil +} + +// joinFunc is the inverse of splitFunc. +func joinFunc(in, out []string) string { + outs := "" + if len(out) == 1 { + outs = " " + out[0] + } else if len(out) > 1 { + outs = " (" + join(out) + ")" + } + return "func(" + join(in) + ")" + outs +} + +// split splits "int, float" into ["int", "float"] and splits "" into []. +func split(s string) []string { + out := []string{} + i := 0 // current type being scanned is s[i:j]. + nparen := 0 + for j := 0; j < len(s); j++ { + switch s[j] { + case ' ': + if i == j { + i++ + } + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // probably can't happen + return nil + } + case ',': + if nparen == 0 { + if i < j { + out = append(out, s[i:j]) + } + i = j + 1 + } + } + } + if nparen != 0 { + // probably can't happen + return nil + } + if i < len(s) { + out = append(out, s[i:]) + } + return out +} + +// join is the inverse of split. +func join(x []string) string { + return strings.Join(x, ", ") +} |