You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
439 lines
12 KiB
439 lines
12 KiB
// Copyright 2010 Google Inc. |
|
// |
|
// Licensed under the Apache License, Version 2.0 (the "License"); |
|
// you may not use this file except in compliance with the License. |
|
// You may obtain a copy of the License at |
|
// |
|
// http://www.apache.org/licenses/LICENSE-2.0 |
|
// |
|
// Unless required by applicable law or agreed to in writing, software |
|
// distributed under the License is distributed on an "AS IS" BASIS, |
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
// See the License for the specific language governing permissions and |
|
// limitations under the License. |
|
|
|
// Package mockgen generates mock implementations of Go interfaces. |
|
package mockgen |
|
|
|
// TODO: This does not support recursive embedded interfaces. |
|
// TODO: This does not support embedding package-local interfaces in a separate file. |
|
|
|
import ( |
|
"bytes" |
|
"fmt" |
|
"go/format" |
|
"go/token" |
|
"log" |
|
"path" |
|
"sort" |
|
"strconv" |
|
"strings" |
|
"unicode" |
|
|
|
"github.com/otokaze/mock/mockgen/model" |
|
) |
|
|
|
// gomockImportPath is the import path of the gomock runtime package; Generate
// adds it to the import set of every file it produces.
const gomockImportPath = "github.com/otokaze/mock/gomock"
|
|
|
// Package-level settings. Nothing in the visible part of this file reads or
// assigns them; presumably they are populated elsewhere (for example from
// command-line flags) — TODO confirm.
var (
	imports    string
	auxFiles   string
	buildFlags string
	execOnly   string
	progOnly   bool
)
|
|
|
// Generator accumulates the source text of generated mocks in an internal
// buffer and carries the settings that influence what is emitted.
type Generator struct {
	// buf receives every emitted line; indent is the prefix that p puts in
	// front of each of them.
	buf    bytes.Buffer
	indent string

	// MockNames maps an interface name to the mock type name to use for it
	// instead of the default "Mock"+name. May be empty.
	MockNames map[string]string
	// Filename is named in the "Source:" header line. May be empty.
	Filename string
	// SrcPackage and SrcInterfaces are printed in the "Source:" header line
	// when Filename is empty. May be empty.
	SrcPackage    string
	SrcInterfaces string

	// packageMap maps an import path to the local package name chosen for it.
	packageMap map[string]string
}
|
|
|
func (g *Generator) p(format string, args ...interface{}) { |
|
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) |
|
} |
|
|
|
func (g *Generator) in() { |
|
g.indent += "\t" |
|
} |
|
|
|
func (g *Generator) out() { |
|
if len(g.indent) > 0 { |
|
g.indent = g.indent[0 : len(g.indent)-1] |
|
} |
|
} |
|
|
|
// removeDot returns s without its final byte when that byte is '.'; any other
// string, including the empty string, is returned unchanged. At most one dot
// is removed.
func removeDot(s string) string {
	last := len(s) - 1
	if last < 0 || s[last] != '.' {
		return s
	}
	return s[:last]
}
|
|
|
// sanitize cleans up a string to make a suitable package name.
//
// Every rune of s that is a letter or '_' is kept, digits are kept everywhere
// except in the first position, and every other rune is replaced by '_'. A
// result consisting of the single blank identifier "_" is turned into "x".
// The empty string is returned unchanged.
func sanitize(s string) string {
	// Build the result in a strings.Builder; repeated string concatenation
	// would copy the accumulated prefix on every rune.
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		// The builder is empty only while the first rune is being examined,
		// which is the one position where a digit is not acceptable.
		keep := unicode.IsLetter(r) || r == '_' || (b.Len() > 0 && unicode.IsDigit(r))
		if keep {
			b.WriteRune(r)
		} else {
			b.WriteByte('_')
		}
	}
	if name := b.String(); name != "_" {
		return name
	}
	return "x"
}
|
|
|
// Generate gen mock by pkg. |
|
func (g *Generator) Generate(pkg *model.Package, pkgName string, outputPackagePath string) error { |
|
g.p("// Code generated by MockGen. DO NOT EDIT.") |
|
if g.Filename != "" { |
|
g.p("// Source: %v", g.Filename) |
|
} else { |
|
g.p("// Source: %v (interfaces: %v)", g.SrcPackage, g.SrcInterfaces) |
|
} |
|
g.p("") |
|
|
|
// Get all required imports, and generate unique names for them all. |
|
im := pkg.Imports() |
|
im[gomockImportPath] = true |
|
|
|
// Only import reflect if it's used. We only use reflect in mocked methods |
|
// so only import if any of the mocked interfaces have methods. |
|
for _, intf := range pkg.Interfaces { |
|
if len(intf.Methods) > 0 { |
|
im["reflect"] = true |
|
break |
|
} |
|
} |
|
|
|
// Sort keys to make import alias generation predictable |
|
sortedPaths := make([]string, len(im)) |
|
x := 0 |
|
for pth := range im { |
|
sortedPaths[x] = pth |
|
x++ |
|
} |
|
sort.Strings(sortedPaths) |
|
|
|
g.packageMap = make(map[string]string, len(im)) |
|
localNames := make(map[string]bool, len(im)) |
|
for _, pth := range sortedPaths { |
|
base := sanitize(path.Base(pth)) |
|
|
|
// Local names for an imported package can usually be the basename of the import path. |
|
// A couple of situations don't permit that, such as duplicate local names |
|
// (e.g. importing "html/template" and "text/template"), or where the basename is |
|
// a keyword (e.g. "foo/case"). |
|
// try base0, base1, ... |
|
pkgName := base |
|
i := 0 |
|
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { |
|
pkgName = base + strconv.Itoa(i) |
|
i++ |
|
} |
|
|
|
g.packageMap[pth] = pkgName |
|
localNames[pkgName] = true |
|
} |
|
|
|
g.p("// Package %v is a generated GoMock package.", pkgName) |
|
g.p("package %v", pkgName) |
|
g.p("") |
|
g.p("import (") |
|
g.in() |
|
for path, pkg := range g.packageMap { |
|
if path == outputPackagePath { |
|
continue |
|
} |
|
g.p("%v %q", pkg, path) |
|
} |
|
for _, path := range pkg.DotImports { |
|
g.p(". %q", path) |
|
} |
|
g.out() |
|
g.p(")") |
|
|
|
for _, intf := range pkg.Interfaces { |
|
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// The name of the mock type to use for the given interface identifier. |
|
func (g *Generator) mockName(typeName string) string { |
|
if mockName, ok := g.MockNames[typeName]; ok { |
|
return mockName |
|
} |
|
|
|
return "Mock" + typeName |
|
} |
|
|
|
// GenerateMockInterface gen mock intf. |
|
func (g *Generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { |
|
mockType := g.mockName(intf.Name) |
|
|
|
g.p("") |
|
g.p("// %v is a mock of %v interface", mockType, intf.Name) |
|
g.p("type %v struct {", mockType) |
|
g.in() |
|
g.p("ctrl *gomock.Controller") |
|
g.p("recorder *%vMockRecorder", mockType) |
|
g.out() |
|
g.p("}") |
|
g.p("") |
|
|
|
g.p("// %vMockRecorder is the mock recorder for %v", mockType, mockType) |
|
g.p("type %vMockRecorder struct {", mockType) |
|
g.in() |
|
g.p("mock *%v", mockType) |
|
g.out() |
|
g.p("}") |
|
g.p("") |
|
|
|
// TODO: Re-enable this if we can import the interface reliably. |
|
//g.p("// Verify that the mock satisfies the interface at compile time.") |
|
//g.p("var _ %v = (*%v)(nil)", typeName, mockType) |
|
//g.p("") |
|
|
|
g.p("// New%v creates a new mock instance", mockType) |
|
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) |
|
g.in() |
|
g.p("mock := &%v{ctrl: ctrl}", mockType) |
|
g.p("mock.recorder = &%vMockRecorder{mock}", mockType) |
|
g.p("return mock") |
|
g.out() |
|
g.p("}") |
|
g.p("") |
|
|
|
// XXX: possible name collision here if someone has EXPECT in their interface. |
|
g.p("// EXPECT returns an object that allows the caller to indicate expected use") |
|
g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType) |
|
g.in() |
|
g.p("return m.recorder") |
|
g.out() |
|
g.p("}") |
|
|
|
g.GenerateMockMethods(mockType, intf, outputPackagePath) |
|
|
|
return nil |
|
} |
|
|
|
// GenerateMockMethods gen mock methods. |
|
func (g *Generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { |
|
for _, m := range intf.Methods { |
|
g.p("") |
|
g.GenerateMockMethod(mockType, m, pkgOverride) |
|
g.p("") |
|
g.GenerateMockRecorderMethod(mockType, m) |
|
} |
|
} |
|
|
|
// makeArgString renders a Go parameter list from parallel slices of names and
// types. When an argument has the same type as the one that follows it, only
// its name is written, so the type appears once per run of equal types.
func makeArgString(argNames, argTypes []string) string {
	parts := make([]string, 0, len(argNames))
	for i := range argNames {
		part := argNames[i]
		sharesTypeWithNext := i+1 < len(argTypes) && argTypes[i+1] == argTypes[i]
		if !sharesTypeWithNext {
			part += " " + argTypes[i]
		}
		parts = append(parts, part)
	}
	return strings.Join(parts, ", ")
}
|
|
|
// GenerateMockMethod generates a mock method implementation. |
|
// If non-empty, pkgOverride is the package in which unqualified types reside. |
|
func (g *Generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { |
|
argNames := g.getArgNames(m) |
|
argTypes := g.getArgTypes(m, pkgOverride) |
|
argString := makeArgString(argNames, argTypes) |
|
|
|
rets := make([]string, len(m.Out)) |
|
for i, p := range m.Out { |
|
rets[i] = p.Type.String(g.packageMap, pkgOverride) |
|
} |
|
retString := strings.Join(rets, ", ") |
|
if len(rets) > 1 { |
|
retString = "(" + retString + ")" |
|
} |
|
if retString != "" { |
|
retString = " " + retString |
|
} |
|
|
|
ia := newIdentifierAllocator(argNames) |
|
idRecv := ia.allocateIdentifier("m") |
|
|
|
g.p("// %v mocks base method", m.Name) |
|
g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) |
|
g.in() |
|
|
|
var callArgs string |
|
if m.Variadic == nil { |
|
if len(argNames) > 0 { |
|
callArgs = ", " + strings.Join(argNames, ", ") |
|
} |
|
} else { |
|
// Non-trivial. The generated code must build a []interface{}, |
|
// but the variadic argument may be any type. |
|
idVarArgs := ia.allocateIdentifier("varargs") |
|
idVArg := ia.allocateIdentifier("a") |
|
g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) |
|
g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) |
|
g.in() |
|
g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) |
|
g.out() |
|
g.p("}") |
|
callArgs = ", " + idVarArgs + "..." |
|
} |
|
if len(m.Out) == 0 { |
|
g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs) |
|
} else { |
|
idRet := ia.allocateIdentifier("ret") |
|
g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs) |
|
|
|
// Go does not allow "naked" type assertions on nil values, so we use the two-value form here. |
|
// The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T. |
|
// Happily, this coincides with the semantics we want here. |
|
retNames := make([]string, len(rets)) |
|
for i, t := range rets { |
|
retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i)) |
|
g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t) |
|
} |
|
g.p("return " + strings.Join(retNames, ", ")) |
|
} |
|
|
|
g.out() |
|
g.p("}") |
|
return nil |
|
} |
|
|
|
// GenerateMockRecorderMethod gen mock recorder method. |
|
func (g *Generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error { |
|
argNames := g.getArgNames(m) |
|
|
|
var argString string |
|
if m.Variadic == nil { |
|
argString = strings.Join(argNames, ", ") |
|
} else { |
|
argString = strings.Join(argNames[:len(argNames)-1], ", ") |
|
} |
|
if argString != "" { |
|
argString += " interface{}" |
|
} |
|
|
|
if m.Variadic != nil { |
|
if argString != "" { |
|
argString += ", " |
|
} |
|
argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1]) |
|
} |
|
|
|
ia := newIdentifierAllocator(argNames) |
|
idRecv := ia.allocateIdentifier("mr") |
|
|
|
g.p("// %v indicates an expected call of %v", m.Name, m.Name) |
|
g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString) |
|
g.in() |
|
|
|
var callArgs string |
|
if m.Variadic == nil { |
|
if len(argNames) > 0 { |
|
callArgs = ", " + strings.Join(argNames, ", ") |
|
} |
|
} else { |
|
if len(argNames) == 1 { |
|
// Easy: just use ... to push the arguments through. |
|
callArgs = ", " + argNames[0] + "..." |
|
} else { |
|
// Hard: create a temporary slice. |
|
idVarArgs := ia.allocateIdentifier("varargs") |
|
g.p("%s := append([]interface{}{%s}, %s...)", |
|
idVarArgs, |
|
strings.Join(argNames[:len(argNames)-1], ", "), |
|
argNames[len(argNames)-1]) |
|
callArgs = ", " + idVarArgs + "..." |
|
} |
|
} |
|
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs) |
|
|
|
g.out() |
|
g.p("}") |
|
return nil |
|
} |
|
|
|
func (g *Generator) getArgNames(m *model.Method) []string { |
|
argNames := make([]string, len(m.In)) |
|
for i, p := range m.In { |
|
name := p.Name |
|
if name == "" { |
|
name = fmt.Sprintf("arg%d", i) |
|
} |
|
argNames[i] = name |
|
} |
|
if m.Variadic != nil { |
|
name := m.Variadic.Name |
|
if name == "" { |
|
name = fmt.Sprintf("arg%d", len(m.In)) |
|
} |
|
argNames = append(argNames, name) |
|
} |
|
return argNames |
|
} |
|
|
|
func (g *Generator) getArgTypes(m *model.Method, pkgOverride string) []string { |
|
argTypes := make([]string, len(m.In)) |
|
for i, p := range m.In { |
|
argTypes[i] = p.Type.String(g.packageMap, pkgOverride) |
|
} |
|
if m.Variadic != nil { |
|
argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) |
|
} |
|
return argTypes |
|
} |
|
|
|
// identifierAllocator is the set of identifiers that are already in use
// within one generated function.
type identifierAllocator map[string]struct{}

// newIdentifierAllocator returns an allocator in which every name in taken is
// already marked as used.
func newIdentifierAllocator(taken []string) identifierAllocator {
	used := make(identifierAllocator, len(taken))
	for i := range taken {
		used[taken[i]] = struct{}{}
	}
	return used
}

// allocateIdentifier reserves and returns an unused identifier based on want:
// want itself when it is free, otherwise the first free name among want_2,
// want_3, and so on.
func (o identifierAllocator) allocateIdentifier(want string) string {
	candidate := want
	suffix := 2
	for {
		if _, inUse := o[candidate]; !inUse {
			break
		}
		candidate = want + "_" + strconv.Itoa(suffix)
		suffix++
	}
	o[candidate] = struct{}{}
	return candidate
}
|
|
|
// Output returns the generator's output, formatted in the standard Go style. |
|
func (g *Generator) Output() []byte { |
|
src, err := format.Source(g.buf.Bytes()) |
|
if err != nil { |
|
log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) |
|
} |
|
return src |
|
}
|
|
|