mockgen.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. // Copyright 2010 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package mockgen generates mock implementations of Go interfaces.
  15. package mockgen
  16. // TODO: This does not support recursive embedded interfaces.
  17. // TODO: This does not support embedding package-local interfaces in a separate file.
  18. import (
  19. "bytes"
  20. "fmt"
  21. "go/format"
  22. "go/token"
  23. "log"
  24. "path"
  25. "sort"
  26. "strconv"
  27. "strings"
  28. "unicode"
  29. "github.com/otokaze/mock/mockgen/model"
  30. )
  // gomockImportPath is the import path of the gomock runtime package; Generate
// adds it to the import set of every generated file.
const gomockImportPath = "github.com/otokaze/mock/gomock"

// NOTE(review): these package-level settings are not read anywhere in this
// file; presumably other files of the package set and consume them.
var imports, auxFiles, buildFlags, execOnly string

var progOnly bool
  // Generator accumulates generated mock source text in an in-memory buffer;
// Output formats and returns it.
type Generator struct {
	buf    bytes.Buffer // unformatted generated source
	indent string       // prefix written before each line; one "\t" per level

	MockNames     map[string]string // optional: interface name -> mock type name override
	Filename      string            // optional: source file shown in the header comment
	SrcPackage    string            // optional: source package shown in the header comment
	SrcInterfaces string            // optional: interface list shown in the header comment

	packageMap map[string]string // import path -> local package name
}
  47. func (g *Generator) p(format string, args ...interface{}) {
  48. fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
  49. }
  50. func (g *Generator) in() {
  51. g.indent += "\t"
  52. }
  53. func (g *Generator) out() {
  54. if len(g.indent) > 0 {
  55. g.indent = g.indent[0 : len(g.indent)-1]
  56. }
  57. }
  // removeDot returns s without its final '.', if it ends with one; only a
// single trailing dot is removed.
func removeDot(s string) string {
	last := len(s) - 1
	if last < 0 || s[last] != '.' {
		return s
	}
	return s[:last]
}
  // sanitize turns s (in Generate, the last element of an import path) into an
// identifier usable as a package name. Letters and '_' are always kept,
// digits are kept except in the first position, and every other rune
// (including a leading digit) becomes '_'. A result of "" or "_" cannot be
// used as a package name, so "x" is returned instead.
func sanitize(s string) string {
	// strings.Builder avoids the quadratic cost of repeated string +=.
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		keep := unicode.IsLetter(r) || r == '_' || (b.Len() > 0 && unicode.IsDigit(r))
		if keep {
			b.WriteRune(r)
		} else {
			b.WriteByte('_')
		}
	}
	if t := b.String(); t != "" && t != "_" {
		return t
	}
	return "x"
}
  86. // Generate gen mock by pkg.
  87. func (g *Generator) Generate(pkg *model.Package, pkgName string, outputPackagePath string) error {
  88. g.p("// Code generated by MockGen. DO NOT EDIT.")
  89. if g.Filename != "" {
  90. g.p("// Source: %v", g.Filename)
  91. } else {
  92. g.p("// Source: %v (interfaces: %v)", g.SrcPackage, g.SrcInterfaces)
  93. }
  94. g.p("")
  95. // Get all required imports, and generate unique names for them all.
  96. im := pkg.Imports()
  97. im[gomockImportPath] = true
  98. // Only import reflect if it's used. We only use reflect in mocked methods
  99. // so only import if any of the mocked interfaces have methods.
  100. for _, intf := range pkg.Interfaces {
  101. if len(intf.Methods) > 0 {
  102. im["reflect"] = true
  103. break
  104. }
  105. }
  106. // Sort keys to make import alias generation predictable
  107. sortedPaths := make([]string, len(im))
  108. x := 0
  109. for pth := range im {
  110. sortedPaths[x] = pth
  111. x++
  112. }
  113. sort.Strings(sortedPaths)
  114. g.packageMap = make(map[string]string, len(im))
  115. localNames := make(map[string]bool, len(im))
  116. for _, pth := range sortedPaths {
  117. base := sanitize(path.Base(pth))
  118. // Local names for an imported package can usually be the basename of the import path.
  119. // A couple of situations don't permit that, such as duplicate local names
  120. // (e.g. importing "html/template" and "text/template"), or where the basename is
  121. // a keyword (e.g. "foo/case").
  122. // try base0, base1, ...
  123. pkgName := base
  124. i := 0
  125. for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
  126. pkgName = base + strconv.Itoa(i)
  127. i++
  128. }
  129. g.packageMap[pth] = pkgName
  130. localNames[pkgName] = true
  131. }
  132. g.p("// Package %v is a generated GoMock package.", pkgName)
  133. g.p("package %v", pkgName)
  134. g.p("")
  135. g.p("import (")
  136. g.in()
  137. for path, pkg := range g.packageMap {
  138. if path == outputPackagePath {
  139. continue
  140. }
  141. g.p("%v %q", pkg, path)
  142. }
  143. for _, path := range pkg.DotImports {
  144. g.p(". %q", path)
  145. }
  146. g.out()
  147. g.p(")")
  148. for _, intf := range pkg.Interfaces {
  149. if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
  150. return err
  151. }
  152. }
  153. return nil
  154. }
  155. // The name of the mock type to use for the given interface identifier.
  156. func (g *Generator) mockName(typeName string) string {
  157. if mockName, ok := g.MockNames[typeName]; ok {
  158. return mockName
  159. }
  160. return "Mock" + typeName
  161. }
  162. // GenerateMockInterface gen mock intf.
  163. func (g *Generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
  164. mockType := g.mockName(intf.Name)
  165. g.p("")
  166. g.p("// %v is a mock of %v interface", mockType, intf.Name)
  167. g.p("type %v struct {", mockType)
  168. g.in()
  169. g.p("ctrl *gomock.Controller")
  170. g.p("recorder *%vMockRecorder", mockType)
  171. g.out()
  172. g.p("}")
  173. g.p("")
  174. g.p("// %vMockRecorder is the mock recorder for %v", mockType, mockType)
  175. g.p("type %vMockRecorder struct {", mockType)
  176. g.in()
  177. g.p("mock *%v", mockType)
  178. g.out()
  179. g.p("}")
  180. g.p("")
  181. // TODO: Re-enable this if we can import the interface reliably.
  182. //g.p("// Verify that the mock satisfies the interface at compile time.")
  183. //g.p("var _ %v = (*%v)(nil)", typeName, mockType)
  184. //g.p("")
  185. g.p("// New%v creates a new mock instance", mockType)
  186. g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
  187. g.in()
  188. g.p("mock := &%v{ctrl: ctrl}", mockType)
  189. g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
  190. g.p("return mock")
  191. g.out()
  192. g.p("}")
  193. g.p("")
  194. // XXX: possible name collision here if someone has EXPECT in their interface.
  195. g.p("// EXPECT returns an object that allows the caller to indicate expected use")
  196. g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
  197. g.in()
  198. g.p("return m.recorder")
  199. g.out()
  200. g.p("}")
  201. g.GenerateMockMethods(mockType, intf, outputPackagePath)
  202. return nil
  203. }
  204. // GenerateMockMethods gen mock methods.
  205. func (g *Generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
  206. for _, m := range intf.Methods {
  207. g.p("")
  208. g.GenerateMockMethod(mockType, m, pkgOverride)
  209. g.p("")
  210. g.GenerateMockRecorderMethod(mockType, m)
  211. }
  212. }
  // makeArgString renders a Go parameter list such as "a, b int, c string".
// The type is written only on the last of a run of consecutive parameters
// that share the same type.
func makeArgString(argNames, argTypes []string) string {
	parts := make([]string, 0, len(argNames))
	for i := range argNames {
		part := argNames[i]
		sharesTypeWithNext := i+1 < len(argTypes) && argTypes[i+1] == argTypes[i]
		if !sharesTypeWithNext {
			part += " " + argTypes[i]
		}
		parts = append(parts, part)
	}
	return strings.Join(parts, ", ")
}
  225. // GenerateMockMethod generates a mock method implementation.
  226. // If non-empty, pkgOverride is the package in which unqualified types reside.
  227. func (g *Generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
  228. argNames := g.getArgNames(m)
  229. argTypes := g.getArgTypes(m, pkgOverride)
  230. argString := makeArgString(argNames, argTypes)
  231. rets := make([]string, len(m.Out))
  232. for i, p := range m.Out {
  233. rets[i] = p.Type.String(g.packageMap, pkgOverride)
  234. }
  235. retString := strings.Join(rets, ", ")
  236. if len(rets) > 1 {
  237. retString = "(" + retString + ")"
  238. }
  239. if retString != "" {
  240. retString = " " + retString
  241. }
  242. ia := newIdentifierAllocator(argNames)
  243. idRecv := ia.allocateIdentifier("m")
  244. g.p("// %v mocks base method", m.Name)
  245. g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
  246. g.in()
  247. var callArgs string
  248. if m.Variadic == nil {
  249. if len(argNames) > 0 {
  250. callArgs = ", " + strings.Join(argNames, ", ")
  251. }
  252. } else {
  253. // Non-trivial. The generated code must build a []interface{},
  254. // but the variadic argument may be any type.
  255. idVarArgs := ia.allocateIdentifier("varargs")
  256. idVArg := ia.allocateIdentifier("a")
  257. g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
  258. g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
  259. g.in()
  260. g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
  261. g.out()
  262. g.p("}")
  263. callArgs = ", " + idVarArgs + "..."
  264. }
  265. if len(m.Out) == 0 {
  266. g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs)
  267. } else {
  268. idRet := ia.allocateIdentifier("ret")
  269. g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs)
  270. // Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
  271. // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
  272. // Happily, this coincides with the semantics we want here.
  273. retNames := make([]string, len(rets))
  274. for i, t := range rets {
  275. retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i))
  276. g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t)
  277. }
  278. g.p("return " + strings.Join(retNames, ", "))
  279. }
  280. g.out()
  281. g.p("}")
  282. return nil
  283. }
  284. // GenerateMockRecorderMethod gen mock recorder method.
  285. func (g *Generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
  286. argNames := g.getArgNames(m)
  287. var argString string
  288. if m.Variadic == nil {
  289. argString = strings.Join(argNames, ", ")
  290. } else {
  291. argString = strings.Join(argNames[:len(argNames)-1], ", ")
  292. }
  293. if argString != "" {
  294. argString += " interface{}"
  295. }
  296. if m.Variadic != nil {
  297. if argString != "" {
  298. argString += ", "
  299. }
  300. argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1])
  301. }
  302. ia := newIdentifierAllocator(argNames)
  303. idRecv := ia.allocateIdentifier("mr")
  304. g.p("// %v indicates an expected call of %v", m.Name, m.Name)
  305. g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
  306. g.in()
  307. var callArgs string
  308. if m.Variadic == nil {
  309. if len(argNames) > 0 {
  310. callArgs = ", " + strings.Join(argNames, ", ")
  311. }
  312. } else {
  313. if len(argNames) == 1 {
  314. // Easy: just use ... to push the arguments through.
  315. callArgs = ", " + argNames[0] + "..."
  316. } else {
  317. // Hard: create a temporary slice.
  318. idVarArgs := ia.allocateIdentifier("varargs")
  319. g.p("%s := append([]interface{}{%s}, %s...)",
  320. idVarArgs,
  321. strings.Join(argNames[:len(argNames)-1], ", "),
  322. argNames[len(argNames)-1])
  323. callArgs = ", " + idVarArgs + "..."
  324. }
  325. }
  326. g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)
  327. g.out()
  328. g.p("}")
  329. return nil
  330. }
  331. func (g *Generator) getArgNames(m *model.Method) []string {
  332. argNames := make([]string, len(m.In))
  333. for i, p := range m.In {
  334. name := p.Name
  335. if name == "" {
  336. name = fmt.Sprintf("arg%d", i)
  337. }
  338. argNames[i] = name
  339. }
  340. if m.Variadic != nil {
  341. name := m.Variadic.Name
  342. if name == "" {
  343. name = fmt.Sprintf("arg%d", len(m.In))
  344. }
  345. argNames = append(argNames, name)
  346. }
  347. return argNames
  348. }
  349. func (g *Generator) getArgTypes(m *model.Method, pkgOverride string) []string {
  350. argTypes := make([]string, len(m.In))
  351. for i, p := range m.In {
  352. argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
  353. }
  354. if m.Variadic != nil {
  355. argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
  356. }
  357. return argTypes
  358. }
  // identifierAllocator is the set of identifiers already in use in a generated
// function; it hands out fresh names that avoid every member.
type identifierAllocator map[string]struct{}

// newIdentifierAllocator returns an allocator that treats every name in taken
// as already in use.
func newIdentifierAllocator(taken []string) identifierAllocator {
	used := make(identifierAllocator, len(taken))
	for i := range taken {
		used[taken[i]] = struct{}{}
	}
	return used
}

// allocateIdentifier returns want if it is free, otherwise the first free name
// among want_2, want_3, ... The returned name is marked as used.
func (o identifierAllocator) allocateIdentifier(want string) string {
	id := want
	for n := 2; ; n++ {
		if _, inUse := o[id]; !inUse {
			break
		}
		id = want + "_" + strconv.Itoa(n)
	}
	o[id] = struct{}{}
	return id
}
  377. // Output returns the generator's output, formatted in the standard Go style.
  378. func (g *Generator) Output() []byte {
  379. src, err := format.Source(g.buf.Bytes())
  380. if err != nil {
  381. log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
  382. }
  383. return src
  384. }