123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- // Copyright 2010 Google Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // Package mockgen generates mock implementations of Go interfaces.
- package mockgen
- // TODO: This does not support recursive embedded interfaces.
- // TODO: This does not support embedding package-local interfaces in a separate file.
- import (
- "bytes"
- "fmt"
- "go/format"
- "go/token"
- "log"
- "path"
- "sort"
- "strconv"
- "strings"
- "unicode"
- "github.com/otokaze/mock/mockgen/model"
- )
// gomockImportPath is the import path that Generate always adds to the
// import set of a generated mock file.
const gomockImportPath = "github.com/otokaze/mock/gomock"
// Package-level settings. Nothing in the visible part of this file reads or
// writes them.
// NOTE(review): presumably populated from command-line flags elsewhere in the
// package — confirm before removing or renaming.
var (
	imports    string
	auxFiles   string
	buildFlags string
	execOnly   string
	progOnly   bool
)
// Generator accumulates the source text of a generated mock file.
//
// The exported fields configure the generated header and the mock type
// names; the unexported fields hold the output state.
type Generator struct {
	// buf collects every line written through p.
	buf bytes.Buffer
	// indent is the prefix placed in front of each line written through p.
	indent string

	// MockNames maps an interface name to the mock type name to use for it.
	// May be empty, in which case "Mock"+name is used (see mockName).
	MockNames map[string]string
	// Filename is reported as the source in the generated header. May be empty.
	Filename string
	// SrcPackage and SrcInterfaces are reported in the generated header when
	// Filename is empty. May be empty.
	SrcPackage    string
	SrcInterfaces string

	// packageMap maps an import path to the local package name chosen for it.
	packageMap map[string]string
}

// p appends one formatted line to the output buffer, prefixed with the
// current indentation and terminated with a newline.
func (g *Generator) p(format string, args ...interface{}) {
	line := g.indent + format + "\n"
	fmt.Fprintf(&g.buf, line, args...)
}

// in increases the indentation of subsequent lines by one tab.
func (g *Generator) in() {
	g.indent = g.indent + "\t"
}

// out drops the last character of the indentation; it is a no-op when the
// indentation is already empty.
func (g *Generator) out() {
	if g.indent == "" {
		return
	}
	g.indent = g.indent[:len(g.indent)-1]
}
// removeDot returns s without its final character when that character is a
// dot; any other string is returned unchanged. At most one dot is removed.
func removeDot(s string) string {
	return strings.TrimSuffix(s, ".")
}
// sanitize cleans up a string to make a suitable package name.
//
// Letters and underscores are kept everywhere; digits are kept everywhere
// except in the first position. Every other rune is replaced with an
// underscore. A result consisting of exactly one underscore becomes "x".
func sanitize(s string) string {
	var b strings.Builder
	for _, r := range s {
		keep := unicode.IsLetter(r) || r == '_'
		if !keep && b.Len() > 0 {
			// Digits are only acceptable after the first character.
			keep = unicode.IsDigit(r)
		}
		if keep {
			b.WriteRune(r)
		} else {
			b.WriteByte('_')
		}
	}
	if name := b.String(); name != "_" {
		return name
	}
	return "x"
}
- // Generate gen mock by pkg.
- func (g *Generator) Generate(pkg *model.Package, pkgName string, outputPackagePath string) error {
- g.p("// Code generated by MockGen. DO NOT EDIT.")
- if g.Filename != "" {
- g.p("// Source: %v", g.Filename)
- } else {
- g.p("// Source: %v (interfaces: %v)", g.SrcPackage, g.SrcInterfaces)
- }
- g.p("")
- // Get all required imports, and generate unique names for them all.
- im := pkg.Imports()
- im[gomockImportPath] = true
- // Only import reflect if it's used. We only use reflect in mocked methods
- // so only import if any of the mocked interfaces have methods.
- for _, intf := range pkg.Interfaces {
- if len(intf.Methods) > 0 {
- im["reflect"] = true
- break
- }
- }
- // Sort keys to make import alias generation predictable
- sortedPaths := make([]string, len(im))
- x := 0
- for pth := range im {
- sortedPaths[x] = pth
- x++
- }
- sort.Strings(sortedPaths)
- g.packageMap = make(map[string]string, len(im))
- localNames := make(map[string]bool, len(im))
- for _, pth := range sortedPaths {
- base := sanitize(path.Base(pth))
- // Local names for an imported package can usually be the basename of the import path.
- // A couple of situations don't permit that, such as duplicate local names
- // (e.g. importing "html/template" and "text/template"), or where the basename is
- // a keyword (e.g. "foo/case").
- // try base0, base1, ...
- pkgName := base
- i := 0
- for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
- pkgName = base + strconv.Itoa(i)
- i++
- }
- g.packageMap[pth] = pkgName
- localNames[pkgName] = true
- }
- g.p("// Package %v is a generated GoMock package.", pkgName)
- g.p("package %v", pkgName)
- g.p("")
- g.p("import (")
- g.in()
- for path, pkg := range g.packageMap {
- if path == outputPackagePath {
- continue
- }
- g.p("%v %q", pkg, path)
- }
- for _, path := range pkg.DotImports {
- g.p(". %q", path)
- }
- g.out()
- g.p(")")
- for _, intf := range pkg.Interfaces {
- if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
- return err
- }
- }
- return nil
- }
- // The name of the mock type to use for the given interface identifier.
- func (g *Generator) mockName(typeName string) string {
- if mockName, ok := g.MockNames[typeName]; ok {
- return mockName
- }
- return "Mock" + typeName
- }
- // GenerateMockInterface gen mock intf.
- func (g *Generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
- mockType := g.mockName(intf.Name)
- g.p("")
- g.p("// %v is a mock of %v interface", mockType, intf.Name)
- g.p("type %v struct {", mockType)
- g.in()
- g.p("ctrl *gomock.Controller")
- g.p("recorder *%vMockRecorder", mockType)
- g.out()
- g.p("}")
- g.p("")
- g.p("// %vMockRecorder is the mock recorder for %v", mockType, mockType)
- g.p("type %vMockRecorder struct {", mockType)
- g.in()
- g.p("mock *%v", mockType)
- g.out()
- g.p("}")
- g.p("")
- // TODO: Re-enable this if we can import the interface reliably.
- //g.p("// Verify that the mock satisfies the interface at compile time.")
- //g.p("var _ %v = (*%v)(nil)", typeName, mockType)
- //g.p("")
- g.p("// New%v creates a new mock instance", mockType)
- g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
- g.in()
- g.p("mock := &%v{ctrl: ctrl}", mockType)
- g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
- g.p("return mock")
- g.out()
- g.p("}")
- g.p("")
- // XXX: possible name collision here if someone has EXPECT in their interface.
- g.p("// EXPECT returns an object that allows the caller to indicate expected use")
- g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
- g.in()
- g.p("return m.recorder")
- g.out()
- g.p("}")
- g.GenerateMockMethods(mockType, intf, outputPackagePath)
- return nil
- }
- // GenerateMockMethods gen mock methods.
- func (g *Generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
- for _, m := range intf.Methods {
- g.p("")
- g.GenerateMockMethod(mockType, m, pkgOverride)
- g.p("")
- g.GenerateMockRecorderMethod(mockType, m)
- }
- }
// makeArgString renders a Go parameter list from parallel slices of names
// and types. The type is written only once for a run of consecutive
// parameters that share it, e.g. "a, b int, c string".
func makeArgString(argNames, argTypes []string) string {
	parts := make([]string, 0, len(argNames))
	for idx := range argNames {
		part := argNames[idx]
		sameAsNext := idx+1 < len(argTypes) && argTypes[idx] == argTypes[idx+1]
		if !sameAsNext {
			// Last parameter of its run: spell out the type.
			part += " " + argTypes[idx]
		}
		parts = append(parts, part)
	}
	return strings.Join(parts, ", ")
}
- // GenerateMockMethod generates a mock method implementation.
- // If non-empty, pkgOverride is the package in which unqualified types reside.
- func (g *Generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
- argNames := g.getArgNames(m)
- argTypes := g.getArgTypes(m, pkgOverride)
- argString := makeArgString(argNames, argTypes)
- rets := make([]string, len(m.Out))
- for i, p := range m.Out {
- rets[i] = p.Type.String(g.packageMap, pkgOverride)
- }
- retString := strings.Join(rets, ", ")
- if len(rets) > 1 {
- retString = "(" + retString + ")"
- }
- if retString != "" {
- retString = " " + retString
- }
- ia := newIdentifierAllocator(argNames)
- idRecv := ia.allocateIdentifier("m")
- g.p("// %v mocks base method", m.Name)
- g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
- g.in()
- var callArgs string
- if m.Variadic == nil {
- if len(argNames) > 0 {
- callArgs = ", " + strings.Join(argNames, ", ")
- }
- } else {
- // Non-trivial. The generated code must build a []interface{},
- // but the variadic argument may be any type.
- idVarArgs := ia.allocateIdentifier("varargs")
- idVArg := ia.allocateIdentifier("a")
- g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
- g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
- g.in()
- g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
- g.out()
- g.p("}")
- callArgs = ", " + idVarArgs + "..."
- }
- if len(m.Out) == 0 {
- g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs)
- } else {
- idRet := ia.allocateIdentifier("ret")
- g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs)
- // Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
- // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
- // Happily, this coincides with the semantics we want here.
- retNames := make([]string, len(rets))
- for i, t := range rets {
- retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i))
- g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t)
- }
- g.p("return " + strings.Join(retNames, ", "))
- }
- g.out()
- g.p("}")
- return nil
- }
- // GenerateMockRecorderMethod gen mock recorder method.
- func (g *Generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
- argNames := g.getArgNames(m)
- var argString string
- if m.Variadic == nil {
- argString = strings.Join(argNames, ", ")
- } else {
- argString = strings.Join(argNames[:len(argNames)-1], ", ")
- }
- if argString != "" {
- argString += " interface{}"
- }
- if m.Variadic != nil {
- if argString != "" {
- argString += ", "
- }
- argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1])
- }
- ia := newIdentifierAllocator(argNames)
- idRecv := ia.allocateIdentifier("mr")
- g.p("// %v indicates an expected call of %v", m.Name, m.Name)
- g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
- g.in()
- var callArgs string
- if m.Variadic == nil {
- if len(argNames) > 0 {
- callArgs = ", " + strings.Join(argNames, ", ")
- }
- } else {
- if len(argNames) == 1 {
- // Easy: just use ... to push the arguments through.
- callArgs = ", " + argNames[0] + "..."
- } else {
- // Hard: create a temporary slice.
- idVarArgs := ia.allocateIdentifier("varargs")
- g.p("%s := append([]interface{}{%s}, %s...)",
- idVarArgs,
- strings.Join(argNames[:len(argNames)-1], ", "),
- argNames[len(argNames)-1])
- callArgs = ", " + idVarArgs + "..."
- }
- }
- g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)
- g.out()
- g.p("}")
- return nil
- }
- func (g *Generator) getArgNames(m *model.Method) []string {
- argNames := make([]string, len(m.In))
- for i, p := range m.In {
- name := p.Name
- if name == "" {
- name = fmt.Sprintf("arg%d", i)
- }
- argNames[i] = name
- }
- if m.Variadic != nil {
- name := m.Variadic.Name
- if name == "" {
- name = fmt.Sprintf("arg%d", len(m.In))
- }
- argNames = append(argNames, name)
- }
- return argNames
- }
- func (g *Generator) getArgTypes(m *model.Method, pkgOverride string) []string {
- argTypes := make([]string, len(m.In))
- for i, p := range m.In {
- argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
- }
- if m.Variadic != nil {
- argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
- }
- return argTypes
- }
// identifierAllocator is the set of identifiers already in use within one
// generated function.
type identifierAllocator map[string]struct{}

// newIdentifierAllocator returns an allocator in which every name in taken
// is already marked as used.
func newIdentifierAllocator(taken []string) identifierAllocator {
	alloc := make(identifierAllocator, len(taken))
	for idx := range taken {
		alloc[taken[idx]] = struct{}{}
	}
	return alloc
}

// allocateIdentifier reserves and returns an unused identifier based on
// want. If want itself is taken, the candidates want_2, want_3, ... are
// tried in order until a free one is found.
func (o identifierAllocator) allocateIdentifier(want string) string {
	candidate := want
	suffix := 2
	for {
		_, used := o[candidate]
		if !used {
			break
		}
		candidate = want + "_" + strconv.Itoa(suffix)
		suffix++
	}
	o[candidate] = struct{}{}
	return candidate
}
- // Output returns the generator's output, formatted in the standard Go style.
- func (g *Generator) Output() []byte {
- src, err := format.Source(g.buf.Bytes())
- if err != nil {
- log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
- }
- return src
- }
|