You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
494 lines
13 KiB
494 lines
13 KiB
// Copyright 2012 Google Inc. |
|
// |
|
// Licensed under the Apache License, Version 2.0 (the "License"); |
|
// you may not use this file except in compliance with the License. |
|
// You may obtain a copy of the License at |
|
// |
|
// http://www.apache.org/licenses/LICENSE-2.0 |
|
// |
|
// Unless required by applicable law or agreed to in writing, software |
|
// distributed under the License is distributed on an "AS IS" BASIS, |
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
// See the License for the specific language governing permissions and |
|
// limitations under the License. |
|
|
|
package mockgen |
|
|
|
// This file contains the model construction by parsing source files. |
|
|
|
import ( |
|
"fmt" |
|
"go/ast" |
|
"go/build" |
|
"go/parser" |
|
"go/token" |
|
"log" |
|
"path" |
|
"path/filepath" |
|
"strconv" |
|
"strings" |
|
|
|
"github.com/otokaze/mock/mockgen/model" |
|
) |
|
|
|
// ParseFile parse by a file |
|
// TODO: simplify error reporting |
|
func ParseFile(source string) (*model.Package, error) { |
|
srcDir, err := filepath.Abs(filepath.Dir(source)) |
|
if err != nil { |
|
return nil, fmt.Errorf("failed getting source directory: %v", err) |
|
} |
|
|
|
var packageImport string |
|
if p, err := build.ImportDir(srcDir, 0); err == nil { |
|
packageImport = p.ImportPath |
|
} // TODO: should we fail if this returns an error? |
|
|
|
fs := token.NewFileSet() |
|
file, err := parser.ParseFile(fs, source, nil, 0) |
|
if err != nil { |
|
return nil, fmt.Errorf("failed parsing source file %v: %v", source, err) |
|
} |
|
|
|
p := &fileParser{ |
|
fileSet: fs, |
|
imports: make(map[string]string), |
|
importedInterfaces: make(map[string]map[string]*ast.InterfaceType), |
|
auxInterfaces: make(map[string]map[string]*ast.InterfaceType), |
|
srcDir: srcDir, |
|
} |
|
|
|
// Handle -imports. |
|
dotImports := make(map[string]bool) |
|
if imports != "" { |
|
for _, kv := range strings.Split(imports, ",") { |
|
eq := strings.Index(kv, "=") |
|
k, v := kv[:eq], kv[eq+1:] |
|
if k == "." { |
|
// TODO: Catch dupes? |
|
dotImports[v] = true |
|
} else { |
|
// TODO: Catch dupes? |
|
p.imports[k] = v |
|
} |
|
} |
|
} |
|
|
|
// Handle -aux_files. |
|
if err := p.parseAuxFiles(auxFiles); err != nil { |
|
return nil, err |
|
} |
|
p.addAuxInterfacesFromFile(packageImport, file) // this file |
|
|
|
pkg, err := p.parseFile(packageImport, file) |
|
if err != nil { |
|
return nil, err |
|
} |
|
pkg.DotImports = make([]string, 0, len(dotImports)) |
|
for path := range dotImports { |
|
pkg.DotImports = append(pkg.DotImports, path) |
|
} |
|
return pkg, nil |
|
} |
|
|
|
// fileParser carries the state accumulated while building a model from one
// source file: the shared file set, the import table, and the interface
// declarations gathered from auxiliary files and imported packages.
type fileParser struct {
	// fileSet records positions for every file parsed through this parser,
	// so errorf can report file:line:col.
	fileSet *token.FileSet
	// imports maps a package name to its import path.
	imports map[string]string
	// importedInterfaces maps an imported package's import path (or "") to
	// interface name to declaration.
	importedInterfaces map[string]map[string]*ast.InterfaceType

	// auxFiles holds the parsed files named by -aux_files.
	auxFiles []*ast.File
	// auxInterfaces maps the package key given with an aux file (or the
	// source file's own import path, or "") to interface name to declaration.
	auxInterfaces map[string]map[string]*ast.InterfaceType

	// srcDir is the absolute directory of the source file; imported packages
	// are resolved relative to it.
	srcDir string
}
|
|
|
func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { |
|
ps := p.fileSet.Position(pos) |
|
format = "%s:%d:%d: " + format |
|
args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) |
|
return fmt.Errorf(format, args...) |
|
} |
|
|
|
func (p *fileParser) parseAuxFiles(auxFiles string) error { |
|
auxFiles = strings.TrimSpace(auxFiles) |
|
if auxFiles == "" { |
|
return nil |
|
} |
|
for _, kv := range strings.Split(auxFiles, ",") { |
|
parts := strings.SplitN(kv, "=", 2) |
|
if len(parts) != 2 { |
|
return fmt.Errorf("bad aux file spec: %v", kv) |
|
} |
|
pkg, fpath := parts[0], parts[1] |
|
|
|
file, err := parser.ParseFile(p.fileSet, fpath, nil, 0) |
|
if err != nil { |
|
return err |
|
} |
|
p.auxFiles = append(p.auxFiles, file) |
|
p.addAuxInterfacesFromFile(pkg, file) |
|
} |
|
return nil |
|
} |
|
|
|
func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { |
|
if _, ok := p.auxInterfaces[pkg]; !ok { |
|
p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType) |
|
} |
|
for ni := range iterInterfaces(file) { |
|
p.auxInterfaces[pkg][ni.name.Name] = ni.it |
|
} |
|
} |
|
|
|
// parseFile loads all file imports and auxiliary files import into the |
|
// fileParser, parses all file interfaces and returns package model. |
|
func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) { |
|
allImports := importsOfFile(file) |
|
// Don't stomp imports provided by -imports. Those should take precedence. |
|
for pkg, path := range allImports { |
|
if _, ok := p.imports[pkg]; !ok { |
|
p.imports[pkg] = path |
|
} |
|
} |
|
// Add imports from auxiliary files, which might be needed for embedded interfaces. |
|
// Don't stomp any other imports. |
|
for _, f := range p.auxFiles { |
|
for pkg, path := range importsOfFile(f) { |
|
if _, ok := p.imports[pkg]; !ok { |
|
p.imports[pkg] = path |
|
} |
|
} |
|
} |
|
|
|
var is []*model.Interface |
|
for ni := range iterInterfaces(file) { |
|
i, err := p.parseInterface(ni.name.String(), importPath, ni.it) |
|
if err != nil { |
|
return nil, err |
|
} |
|
is = append(is, i) |
|
} |
|
return &model.Package{ |
|
Name: file.Name.String(), |
|
SrcDir: p.srcDir, |
|
Interfaces: is, |
|
}, nil |
|
} |
|
|
|
// parsePackage loads package specified by path, parses it and populates |
|
// corresponding imports and importedInterfaces into the fileParser. |
|
func (p *fileParser) parsePackage(path string) error { |
|
var pkgs map[string]*ast.Package |
|
if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil { |
|
return err |
|
} else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil { |
|
return err |
|
} |
|
for _, pkg := range pkgs { |
|
file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) |
|
if _, ok := p.importedInterfaces[path]; !ok { |
|
p.importedInterfaces[path] = make(map[string]*ast.InterfaceType) |
|
} |
|
for ni := range iterInterfaces(file) { |
|
p.importedInterfaces[path][ni.name.Name] = ni.it |
|
} |
|
for pkgName, pkgPath := range importsOfFile(file) { |
|
if _, ok := p.imports[pkgName]; !ok { |
|
p.imports[pkgName] = pkgPath |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { |
|
intf := &model.Interface{Name: name} |
|
for _, field := range it.Methods.List { |
|
switch v := field.Type.(type) { |
|
case *ast.FuncType: |
|
if nn := len(field.Names); nn != 1 { |
|
return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn) |
|
} |
|
m := &model.Method{ |
|
Name: field.Names[0].String(), |
|
} |
|
var err error |
|
m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) |
|
if err != nil { |
|
return nil, err |
|
} |
|
intf.Methods = append(intf.Methods, m) |
|
case *ast.Ident: |
|
// Embedded interface in this package. |
|
ei := p.auxInterfaces[pkg][v.String()] |
|
if ei == nil { |
|
if ei = p.importedInterfaces[pkg][v.String()]; ei == nil { |
|
return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) |
|
} |
|
} |
|
eintf, err := p.parseInterface(v.String(), pkg, ei) |
|
if err != nil { |
|
return nil, err |
|
} |
|
// Copy the methods. |
|
// TODO: apply shadowing rules. |
|
intf.Methods = append(intf.Methods, eintf.Methods...) |
|
case *ast.SelectorExpr: |
|
// Embedded interface in another package. |
|
fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() |
|
epkg, ok := p.imports[fpkg] |
|
if !ok { |
|
return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg) |
|
} |
|
ei := p.auxInterfaces[fpkg][sel] |
|
if ei == nil { |
|
fpkg = epkg |
|
if _, ok = p.importedInterfaces[epkg]; !ok { |
|
if err := p.parsePackage(epkg); err != nil { |
|
return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err) |
|
} |
|
} |
|
if ei = p.importedInterfaces[epkg][sel]; ei == nil { |
|
return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel) |
|
} |
|
} |
|
eintf, err := p.parseInterface(sel, fpkg, ei) |
|
if err != nil { |
|
return nil, err |
|
} |
|
// Copy the methods. |
|
// TODO: apply shadowing rules. |
|
intf.Methods = append(intf.Methods, eintf.Methods...) |
|
default: |
|
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) |
|
} |
|
} |
|
return intf, nil |
|
} |
|
|
|
func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) { |
|
if f.Params != nil { |
|
regParams := f.Params.List |
|
if isVariadic(f) { |
|
n := len(regParams) |
|
varParams := regParams[n-1:] |
|
regParams = regParams[:n-1] |
|
vp, err := p.parseFieldList(pkg, varParams) |
|
if err != nil { |
|
return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) |
|
} |
|
variadic = vp[0] |
|
} |
|
in, err = p.parseFieldList(pkg, regParams) |
|
if err != nil { |
|
return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) |
|
} |
|
} |
|
if f.Results != nil { |
|
out, err = p.parseFieldList(pkg, f.Results.List) |
|
if err != nil { |
|
return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) |
|
} |
|
} |
|
return |
|
} |
|
|
|
func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { |
|
nf := 0 |
|
for _, f := range fields { |
|
nn := len(f.Names) |
|
if nn == 0 { |
|
nn = 1 // anonymous parameter |
|
} |
|
nf += nn |
|
} |
|
if nf == 0 { |
|
return nil, nil |
|
} |
|
ps := make([]*model.Parameter, nf) |
|
i := 0 // destination index |
|
for _, f := range fields { |
|
t, err := p.parseType(pkg, f.Type) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
if len(f.Names) == 0 { |
|
// anonymous arg |
|
ps[i] = &model.Parameter{Type: t} |
|
i++ |
|
continue |
|
} |
|
for _, name := range f.Names { |
|
ps[i] = &model.Parameter{Name: name.Name, Type: t} |
|
i++ |
|
} |
|
} |
|
return ps, nil |
|
} |
|
|
|
func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { |
|
switch v := typ.(type) { |
|
case *ast.ArrayType: |
|
ln := -1 |
|
if v.Len != nil { |
|
x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value) |
|
if err != nil { |
|
return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) |
|
} |
|
ln = x |
|
} |
|
t, err := p.parseType(pkg, v.Elt) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &model.ArrayType{Len: ln, Type: t}, nil |
|
case *ast.ChanType: |
|
t, err := p.parseType(pkg, v.Value) |
|
if err != nil { |
|
return nil, err |
|
} |
|
var dir model.ChanDir |
|
if v.Dir == ast.SEND { |
|
dir = model.SendDir |
|
} |
|
if v.Dir == ast.RECV { |
|
dir = model.RecvDir |
|
} |
|
return &model.ChanType{Dir: dir, Type: t}, nil |
|
case *ast.Ellipsis: |
|
// assume we're parsing a variadic argument |
|
return p.parseType(pkg, v.Elt) |
|
case *ast.FuncType: |
|
in, variadic, out, err := p.parseFunc(pkg, v) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil |
|
case *ast.Ident: |
|
if v.IsExported() { |
|
// `pkg` may be an aliased imported pkg |
|
// if so, patch the import w/ the fully qualified import |
|
maybeImportedPkg, ok := p.imports[pkg] |
|
if ok { |
|
pkg = maybeImportedPkg |
|
} |
|
// assume type in this package |
|
return &model.NamedType{Package: pkg, Type: v.Name}, nil |
|
} |
|
return model.PredeclaredType(v.Name), nil |
|
|
|
case *ast.InterfaceType: |
|
if v.Methods != nil && len(v.Methods.List) > 0 { |
|
return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") |
|
} |
|
return model.PredeclaredType("interface{}"), nil |
|
case *ast.MapType: |
|
key, err := p.parseType(pkg, v.Key) |
|
if err != nil { |
|
return nil, err |
|
} |
|
value, err := p.parseType(pkg, v.Value) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &model.MapType{Key: key, Value: value}, nil |
|
case *ast.SelectorExpr: |
|
pkgName := v.X.(*ast.Ident).String() |
|
pkg, ok := p.imports[pkgName] |
|
if !ok { |
|
return nil, p.errorf(v.Pos(), "unknown package %q", pkgName) |
|
} |
|
return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil |
|
case *ast.StarExpr: |
|
t, err := p.parseType(pkg, v.X) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &model.PointerType{Type: t}, nil |
|
case *ast.StructType: |
|
if v.Fields != nil && len(v.Fields.List) > 0 { |
|
return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types") |
|
} |
|
return model.PredeclaredType("struct{}"), nil |
|
} |
|
|
|
return nil, fmt.Errorf("don't know how to parse type %T", typ) |
|
} |
|
|
|
// importsOfFile returns a map of package name to import path |
|
// of the imports in file. |
|
func importsOfFile(file *ast.File) map[string]string { |
|
m := make(map[string]string) |
|
for _, is := range file.Imports { |
|
var pkgName string |
|
importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes |
|
|
|
if is.Name != nil { |
|
// Named imports are always certain. |
|
if is.Name.Name == "_" { |
|
continue |
|
} |
|
pkgName = removeDot(is.Name.Name) |
|
} else { |
|
pkg, err := build.Import(importPath, "", 0) |
|
if err != nil { |
|
// Fallback to import path suffix. Note that this is uncertain. |
|
_, last := path.Split(importPath) |
|
// If the last path component has dots, the first dot-delimited |
|
// field is used as the name. |
|
pkgName = strings.SplitN(last, ".", 2)[0] |
|
} else { |
|
pkgName = pkg.Name |
|
} |
|
} |
|
|
|
if _, ok := m[pkgName]; ok { |
|
log.Fatalf("imported package collision: %q imported twice", pkgName) |
|
} |
|
m[pkgName] = importPath |
|
} |
|
return m |
|
} |
|
|
|
// namedInterface pairs a top-level interface type declaration with its name.
type namedInterface struct {
	name *ast.Ident
	it   *ast.InterfaceType
}

// iterInterfaces returns a channel that yields every interface type declared
// in a top-level type declaration of file, in source order, and is then
// closed.
//
// All results are collected up front into a channel buffered to hold them
// exactly, so no goroutine is involved. Callers that stop ranging early
// (e.g. returning on the first error) therefore leak nothing. A producer
// goroutine sending on an unbuffered channel would block forever in that case.
func iterInterfaces(file *ast.File) <-chan namedInterface {
	var found []namedInterface
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			it, ok := ts.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			found = append(found, namedInterface{ts.Name, it})
		}
	}

	ch := make(chan namedInterface, len(found))
	for _, ni := range found {
		ch <- ni
	}
	close(ch)
	return ch
}
|
|
|
// isVariadic reports whether the last parameter of f has the form "...T".
// A signature with a nil or empty parameter list is not variadic.
func isVariadic(f *ast.FuncType) bool {
	// Guard against a nil Params so callers need not check it first.
	if f.Params == nil || len(f.Params.List) == 0 {
		return false
	}
	last := f.Params.List[len(f.Params.List)-1]
	_, ok := last.Type.(*ast.Ellipsis)
	return ok
}
|
|
|