main.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. package main
  2. import (
  3. "bytes"
  4. "flag"
  5. "fmt"
  6. "go/ast"
  7. "io/ioutil"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "text/template"
  15. "go-common/app/tool/cache/common"
  16. )
  17. var (
  18. // arguments
  19. singleFlight = flag.Bool("singleflight", false, "enable singleflight")
  20. nullCache = flag.String("nullcache", "", "null cache")
  21. checkNullCode = flag.String("check_null_code", "", "check null code")
  22. batchSize = flag.Int("batch", 0, "batch size")
  23. batchErr = flag.String("batch_err", "break", "batch err to contine or break")
  24. maxGroup = flag.Int("max_group", 0, "max group size")
  25. sync = flag.Bool("sync", false, "add cache in sync way.")
  26. paging = flag.Bool("paging", false, "use paging in single template")
  27. ignores = flag.String("ignores", "", "ignore params")
  28. numberTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64"}
  29. simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
  30. optionNames = []string{"singleflight", "nullcache", "check_null_code", "batch", "max_group", "sync", "paging", "ignores", "batch_err"}
  31. optionNamesMap = map[string]bool{}
  32. )
  33. const (
  34. _interfaceName = "_cache"
  35. _multiTpl = 1
  36. _singleTpl = 2
  37. _noneTpl = 3
  38. )
  39. func resetFlag() {
  40. *singleFlight = false
  41. *nullCache = ""
  42. *checkNullCode = ""
  43. *batchSize = 0
  44. *maxGroup = 0
  45. *sync = false
  46. *paging = false
  47. *batchErr = "break"
  48. *ignores = ""
  49. }
  50. // options options
  51. type options struct {
  52. name string
  53. keyType string
  54. valueType string
  55. cacheFunc string
  56. rawFunc string
  57. addCacheFunc string
  58. template int
  59. SimpleValue bool
  60. NumberValue bool
  61. GoValue bool
  62. ZeroValue string
  63. ImportPackage string
  64. importPackages []string
  65. Args string
  66. PkgName string
  67. EnableSingleFlight bool
  68. NullCache string
  69. EnableNullCache bool
  70. GroupSize int
  71. MaxGroup int
  72. EnableBatch bool
  73. BatchErrBreak bool
  74. Sync bool
  75. CheckNullCode string
  76. ExtraArgsType string
  77. ExtraArgs string
  78. ExtraCacheArgs string
  79. ExtraRawArgs string
  80. ExtraAddCacheArgs string
  81. EnablePaging bool
  82. Comment string
  83. }
  84. // parse parse options
  85. func parse(s *common.Source) (opts []*options) {
  86. f := s.F
  87. fset := s.Fset
  88. src := s.Src
  89. c := f.Scope.Lookup(_interfaceName)
  90. if (c == nil) || (c.Kind != ast.Typ) {
  91. log.Fatalln("无法找到缓存声明")
  92. }
  93. lines := strings.Split(src, "\n")
  94. lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
  95. for _, list := range lists {
  96. opt := options{Args: s.GetDef(_interfaceName), importPackages: s.Packages(list)}
  97. // get comment
  98. line := fset.Position(list.Pos()).Line - 3
  99. if len(lines)-1 >= line {
  100. comment := lines[line]
  101. opt.Comment = common.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
  102. opt.Comment = strings.TrimSpace(opt.Comment)
  103. }
  104. // get options
  105. line = fset.Position(list.Pos()).Line - 2
  106. comment := lines[line]
  107. os.Args = []string{os.Args[0]}
  108. if regexp.MustCompile(`\s+//\s*cache:.+`).Match([]byte(comment)) {
  109. args := strings.Split(common.RegexpReplace(`//\s*cache:(?P<arg>.+)`, comment, "$arg"), " ")
  110. for _, arg := range args {
  111. arg = strings.TrimSpace(arg)
  112. if arg != "" {
  113. // validate option name
  114. argName := common.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
  115. if !optionNamesMap[argName] {
  116. log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
  117. }
  118. os.Args = append(os.Args, arg)
  119. }
  120. }
  121. }
  122. resetFlag()
  123. flag.Parse()
  124. opt.EnableSingleFlight = *singleFlight
  125. opt.NullCache = *nullCache
  126. opt.EnablePaging = *paging
  127. opt.EnableNullCache = *nullCache != ""
  128. opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
  129. opt.BatchErrBreak = *batchErr == "break"
  130. opt.Sync = *sync
  131. opt.CheckNullCode = *checkNullCode
  132. opt.GroupSize = *batchSize
  133. opt.MaxGroup = *maxGroup
  134. // get func
  135. opt.name = list.Names[0].Name
  136. params := list.Type.(*ast.FuncType).Params.List
  137. if len(params) == 0 {
  138. log.Fatalln(opt.name + "参数不足")
  139. }
  140. if s.ExprString(params[0].Type) != "context.Context" {
  141. log.Fatalln("第一个参数必须为context")
  142. }
  143. if len(params) == 1 {
  144. opt.template = _noneTpl
  145. } else {
  146. if _, ok := params[1].Type.(*ast.ArrayType); ok {
  147. opt.template = _multiTpl
  148. } else {
  149. opt.template = _singleTpl
  150. // get key
  151. opt.keyType = s.ExprString(params[1].Type)
  152. }
  153. }
  154. if len(params) > 2 {
  155. var args []string
  156. var allArgs []string
  157. for _, pa := range params[2:] {
  158. paType := s.ExprString(pa.Type)
  159. if len(pa.Names) == 0 {
  160. args = append(args, paType)
  161. allArgs = append(allArgs, paType)
  162. continue
  163. }
  164. var names []string
  165. for _, name := range pa.Names {
  166. names = append(names, name.Name)
  167. }
  168. allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
  169. args = append(args, names...)
  170. }
  171. opt.ExtraArgs = strings.Join(args, ",")
  172. opt.ExtraArgsType = strings.Join(allArgs, ",")
  173. argsMap := make(map[string]bool)
  174. for _, arg := range args {
  175. argsMap[arg] = true
  176. }
  177. ignoreCache := make(map[string]bool)
  178. ignoreRaw := make(map[string]bool)
  179. ignoreAddCache := make(map[string]bool)
  180. ignoreArray := [3]map[string]bool{ignoreCache, ignoreRaw, ignoreAddCache}
  181. if *ignores != "" {
  182. is := strings.Split(*ignores, "|")
  183. if len(is) > 3 {
  184. log.Fatalln("ignores参数错误")
  185. }
  186. for i := range is {
  187. if len(is) > i {
  188. for _, s := range strings.Split(is[i], ",") {
  189. ignoreArray[i][s] = true
  190. }
  191. }
  192. }
  193. }
  194. var as []string
  195. for _, arg := range args {
  196. if !ignoreCache[arg] {
  197. as = append(as, arg)
  198. }
  199. }
  200. opt.ExtraCacheArgs = strings.Join(as, ",")
  201. as = []string{}
  202. for _, arg := range args {
  203. if !ignoreRaw[arg] {
  204. as = append(as, arg)
  205. }
  206. }
  207. opt.ExtraRawArgs = strings.Join(as, ",")
  208. as = []string{}
  209. for _, arg := range args {
  210. if !ignoreAddCache[arg] {
  211. as = append(as, arg)
  212. }
  213. }
  214. opt.ExtraAddCacheArgs = strings.Join(as, ",")
  215. if opt.ExtraAddCacheArgs != "" {
  216. opt.ExtraAddCacheArgs = "," + opt.ExtraAddCacheArgs
  217. }
  218. if opt.ExtraRawArgs != "" {
  219. opt.ExtraRawArgs = "," + opt.ExtraRawArgs
  220. }
  221. if opt.ExtraCacheArgs != "" {
  222. opt.ExtraCacheArgs = "," + opt.ExtraCacheArgs
  223. }
  224. if opt.ExtraArgs != "" {
  225. opt.ExtraArgs = "," + opt.ExtraArgs
  226. }
  227. if opt.ExtraArgsType != "" {
  228. opt.ExtraArgsType = "," + opt.ExtraArgsType
  229. }
  230. }
  231. // get k v from results
  232. results := list.Type.(*ast.FuncType).Results.List
  233. if len(results) != 2 {
  234. log.Fatalln(opt.name + ": 参数个数不对")
  235. }
  236. if s.ExprString(results[1].Type) != "error" {
  237. log.Fatalln(opt.name + ": 最后返回值参数需为error")
  238. }
  239. if opt.template == _multiTpl {
  240. p, ok := results[0].Type.(*ast.MapType)
  241. if !ok {
  242. log.Fatalln(opt.name + ": 批量获取方法 返回值类型需为map类型")
  243. }
  244. opt.keyType = s.ExprString(p.Key)
  245. opt.valueType = s.ExprString(p.Value)
  246. } else {
  247. opt.valueType = s.ExprString(results[0].Type)
  248. }
  249. for _, t := range numberTypes {
  250. if t == opt.valueType {
  251. opt.NumberValue = true
  252. break
  253. }
  254. }
  255. opt.ZeroValue = "nil"
  256. for _, t := range simpleTypes {
  257. if t == opt.valueType {
  258. opt.SimpleValue = true
  259. opt.ZeroValue = zeroValue(t)
  260. break
  261. }
  262. }
  263. if !opt.SimpleValue {
  264. for _, t := range []string{"[]", "map"} {
  265. if strings.HasPrefix(opt.valueType, t) {
  266. opt.GoValue = true
  267. break
  268. }
  269. }
  270. }
  271. upperName := strings.ToUpper(opt.name[0:1]) + opt.name[1:]
  272. opt.cacheFunc = fmt.Sprintf("d.Cache%s", upperName)
  273. opt.rawFunc = fmt.Sprintf("d.Raw%s", upperName)
  274. opt.addCacheFunc = fmt.Sprintf("d.AddCache%s", upperName)
  275. opt.Check()
  276. opts = append(opts, &opt)
  277. }
  278. return
  279. }
  280. func (option *options) Check() {
  281. if !option.SimpleValue && !strings.Contains(option.valueType, "*") && !strings.Contains(option.valueType, "[]") && !strings.Contains(option.valueType, "map") {
  282. log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
  283. }
  284. if option.EnableSingleFlight && option.EnableBatch {
  285. log.Fatalf("%s: 单飞和批量获取不能同时开启\n", option.name)
  286. }
  287. if option.template != _singleTpl && option.EnablePaging {
  288. log.Fatalf("%s: 分页只能用在单key模板中\n", option.name)
  289. }
  290. if option.SimpleValue && !option.EnableNullCache {
  291. if !((option.template == _multiTpl) && option.NumberValue) {
  292. log.Fatalf("%s: 值为基本类型时需开启空缓存 防止缓存零值穿透\n", option.name)
  293. }
  294. }
  295. if option.EnableNullCache {
  296. if !option.SimpleValue && option.CheckNullCode == "" {
  297. log.Fatalf("%s: 缺少-check_null_code参数\n", option.name)
  298. }
  299. if option.SimpleValue && option.NullCache == option.ZeroValue {
  300. log.Fatalf("%s: %s 不能作为空缓存值 \n", option.name, option.NullCache)
  301. }
  302. if strings.Contains(option.NullCache, "{}") {
  303. // -nullcache=[]*model.OrderMain{} 这种无效
  304. log.Fatalf("%s: %s 不能作为空缓存值 会导致空缓存无效 \n", option.name, option.NullCache)
  305. }
  306. if strings.Contains(option.CheckNullCode, "len") && strings.Contains(strings.Replace(option.CheckNullCode, " ", "", -1), "==0") {
  307. // -check_null_code=len($)==0 这种无效
  308. log.Fatalf("%s: -check_null_code=%s 错误 会有无意义的赋值\n", option.name, option.CheckNullCode)
  309. }
  310. }
  311. }
  312. func genHeader(opts []*options) (src string) {
  313. option := options{PkgName: os.Getenv("GOPACKAGE")}
  314. var sfCount int
  315. var packages, sfInit []string
  316. packagesMap := map[string]bool{`"context"`: true}
  317. for _, opt := range opts {
  318. if opt.EnableSingleFlight {
  319. option.EnableSingleFlight = true
  320. sfCount++
  321. }
  322. if opt.EnableBatch {
  323. option.EnableBatch = true
  324. }
  325. if len(opt.importPackages) > 0 {
  326. for _, pkg := range opt.importPackages {
  327. if !packagesMap[pkg] {
  328. packages = append(packages, pkg)
  329. packagesMap[pkg] = true
  330. }
  331. }
  332. }
  333. if opt.Args != "" {
  334. option.Args = opt.Args
  335. }
  336. }
  337. option.ImportPackage = strings.Join(packages, "\n")
  338. for i := 0; i < sfCount; i++ {
  339. sfInit = append(sfInit, "{}")
  340. }
  341. src = _headerTemplate
  342. src = strings.Replace(src, "SFCOUNT", strconv.Itoa(sfCount), -1)
  343. t := template.Must(template.New("header").Parse(src))
  344. var buffer bytes.Buffer
  345. err := t.Execute(&buffer, option)
  346. if err != nil {
  347. log.Fatalf("execute template: %s", err)
  348. }
  349. // Format the output.
  350. src = strings.Replace(buffer.String(), "\t", "", -1)
  351. src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
  352. src = strings.Replace(src, "NEWLINE", "", -1)
  353. src = strings.Replace(src, "ARGS", option.Args, -1)
  354. src = strings.Replace(src, "SFINIT", strings.Join(sfInit, ","), -1)
  355. return
  356. }
  357. func genBody(opts []*options) (res string) {
  358. sfnum := -1
  359. for _, option := range opts {
  360. var nullCodeVar, src string
  361. if option.template == _multiTpl {
  362. src = _multiTemplate
  363. nullCodeVar = "v"
  364. } else if option.template == _singleTpl {
  365. src = _singleTemplate
  366. nullCodeVar = "res"
  367. } else {
  368. src = _noneTemplate
  369. nullCodeVar = "res"
  370. }
  371. if option.template != _noneTpl {
  372. src = strings.Replace(src, "KEY", option.keyType, -1)
  373. }
  374. if option.CheckNullCode != "" {
  375. option.CheckNullCode = strings.Replace(option.CheckNullCode, "$", nullCodeVar, -1)
  376. }
  377. if option.EnableSingleFlight {
  378. sfnum++
  379. }
  380. src = strings.Replace(src, "NAME", option.name, -1)
  381. src = strings.Replace(src, "VALUE", option.valueType, -1)
  382. src = strings.Replace(src, "ADDCACHEFUNC", option.addCacheFunc, -1)
  383. src = strings.Replace(src, "CACHEFUNC", option.cacheFunc, -1)
  384. src = strings.Replace(src, "RAWFUNC", option.rawFunc, -1)
  385. src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
  386. src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
  387. src = strings.Replace(src, "SFNUM", strconv.Itoa(sfnum), -1)
  388. t := template.Must(template.New("cache").Parse(src))
  389. var buffer bytes.Buffer
  390. err := t.Execute(&buffer, option)
  391. if err != nil {
  392. log.Fatalf("execute template: %s", err)
  393. }
  394. // Format the output.
  395. src = strings.Replace(buffer.String(), "\t", "", -1)
  396. src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
  397. res = res + "\n" + src
  398. }
  399. return
  400. }
  401. func zeroValue(t string) string {
  402. switch t {
  403. case "bool":
  404. return "false"
  405. case "string":
  406. return "\"\""
  407. case "[]byte":
  408. return "nil"
  409. default:
  410. return "0"
  411. }
  412. }
  413. func init() {
  414. for _, name := range optionNames {
  415. optionNamesMap[name] = true
  416. }
  417. }
  418. func main() {
  419. log.SetFlags(0)
  420. defer func() {
  421. if err := recover(); err != nil {
  422. log.Fatalf("程序解析失败, err: %+v 请企业微信联系 @wangxu01", err)
  423. }
  424. }()
  425. options := parse(common.NewSource(common.SourceText()))
  426. header := genHeader(options)
  427. body := genBody(options)
  428. code := common.FormatCode(header + "\n" + body)
  429. // Write to file.
  430. dir := filepath.Dir(".")
  431. outputName := filepath.Join(dir, "dao.cache.go")
  432. err := ioutil.WriteFile(outputName, []byte(code), 0644)
  433. if err != nil {
  434. log.Fatalf("写入文件失败: %s", err)
  435. }
  436. log.Println("dao.cache.go: 生成成功")
  437. }