|
- package main
- import (
- "encoding/json"
- "errors"
- "flag"
- "fmt"
- "go/ast"
- "go/parser"
- "go/token"
- "os"
- "path"
- "path/filepath"
- "reflect"
- "runtime"
- "strings"
- )
// ErrParams is the error parseFuncDoc returns for a malformed @router
// annotation and for @params/@response annotations that name an unknown model.
var ErrParams = errors.New("err params")

// _gopath lists the entries of the GOPATH environment variable (split with
// filepath.SplitList); parseModel looks for imported packages under each
// entry's src directory.
var _gopath = filepath.SplitList(os.Getenv("GOPATH"))
- var (
- dir string
- pkgs = make(map[string]*ast.Package)
- rlpkgs = make(map[string]*ast.Package)
- definitions = make(map[string]*Schema)
- swagger = Swagger{
- Definitions: make(map[string]*Schema),
- Paths: make(map[string]*Item),
- SwaggerVersion: "2.0",
- Infos: Information{
- Title: "go-common api",
- Description: "api",
- Version: "1.0",
- Contact: Contact{
- EMail: "lintanghui@bilibili.com",
- },
- License: &License{
- Name: "Apache 2.0",
- URL: "http://www.apache.org/licenses/LICENSE-2.0.html",
- },
- },
- }
- stdlibObject = map[string]string{
- "&{time Time}": "time.Time",
- }
- )
- // refer to builtin.go
- var basicTypes = map[string]string{
- "bool": "boolean:",
- "uint": "integer:int32",
- "uint8": "integer:int32",
- "uint16": "integer:int32",
- "uint32": "integer:int32",
- "uint64": "integer:int64",
- "int": "integer:int64",
- "int8": "integer:int32",
- "int16": "integer:int32",
- "int32": "integer:int32",
- "int64": "integer:int64",
- "uintptr": "integer:int64",
- "float32": "number:float",
- "float64": "number:double",
- "string": "string:",
- "complex64": "number:float",
- "complex128": "number:double",
- "byte": "string:byte",
- "rune": "string:byte",
- // builtin golang objects
- "time.Time": "string:string",
- }
- func main() {
- flag.StringVar(&dir, "d", "./", "specific project dir")
- flag.Parse()
- err := ParseFromDir(dir)
- if err != nil {
- panic(err)
- }
- parseModel(pkgs)
- parseModel(rlpkgs)
- parseRouter()
- fd, err := os.Create(path.Join(dir, "swagger.json"))
- if err != nil {
- panic(err)
- }
- b, _ := json.MarshalIndent(swagger, "", " ")
- fd.Write(b)
- }
- // ParseFromDir parse ast pkg from dir.
- func ParseFromDir(dir string) (err error) {
- filepath.Walk(dir, func(fpath string, fileInfo os.FileInfo, err error) error {
- if err != nil {
- return nil
- }
- if !fileInfo.IsDir() {
- return nil
- }
- err = parseFromDir(fpath)
- return err
- })
- return
- }
- func parseFromDir(dir string) (err error) {
- fset := token.NewFileSet()
- pkgFolder, err := parser.ParseDir(fset, dir, func(info os.FileInfo) bool {
- name := info.Name()
- return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
- }, parser.ParseComments)
- if err != nil {
- return
- }
- for k, p := range pkgFolder {
- pkgs[k] = p
- }
- return
- }
- func parseImport(dir string) (err error) {
- fset := token.NewFileSet()
- pkgFolder, err := parser.ParseDir(fset, dir, func(info os.FileInfo) bool {
- name := info.Name()
- return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
- }, parser.ParseComments)
- if err != nil {
- return
- }
- for k, p := range pkgFolder {
- rlpkgs[k] = p
- }
- return
- }
- func parseModel(pkgs map[string]*ast.Package) {
- for _, p := range pkgs {
- for _, f := range p.Files {
- for _, im := range f.Imports {
- if !isSystemPackage(im.Path.Value) {
- for _, gp := range _gopath {
- path := gp + "/src/" + strings.Trim(im.Path.Value, "\"")
- if isExist(path) {
- parseImport(path)
- }
- }
- }
- }
- scom := parseStructComment(f)
- for _, obj := range f.Scope.Objects {
- if obj.Kind == ast.Typ {
- objName := obj.Name
- schema := &Schema{
- Title: objName,
- Type: "object",
- }
- ts, ok := obj.Decl.(*ast.TypeSpec)
- if !ok {
- fmt.Printf("obj type error %v ", obj.Kind)
- }
- st, ok := ts.Type.(*ast.StructType)
- if !ok {
- continue
- }
- properites := make(map[string]*Propertie)
- for _, fd := range st.Fields.List {
- if len(fd.Names) == 0 {
- continue
- }
- name, required, omit, desc := parseFieldTag(fd)
- if omit {
- continue
- }
- isSlice, realType, sType := typeAnalyser(fd)
- if (isSlice && isBasicType(realType)) || sType == "object" {
- if len(strings.Split(realType, " ")) > 1 {
- realType = strings.Replace(realType, " ", ".", -1)
- realType = strings.Replace(realType, "&", "", -1)
- realType = strings.Replace(realType, "{", "", -1)
- realType = strings.Replace(realType, "}", "", -1)
- }
- }
- mp := &Propertie{}
- if isSlice {
- mp.Type = "array"
- if isBasicType(strings.Replace(realType, "[]", "", -1)) {
- typeFormat := strings.Split(sType, ":")
- mp.Items = &Propertie{
- Type: typeFormat[0],
- Format: typeFormat[1],
- }
- } else {
- ss := strings.Split(realType, ".")
- mp.RefImport = ss[len(ss)-1]
- mp.Type = "array"
- mp.Items = &Propertie{
- Ref: "#/definitions/" + mp.RefImport,
- Type: sType,
- }
- }
- } else {
- if sType == "object" {
- ss := strings.Split(realType, ".")
- mp.RefImport = ss[len(ss)-1]
- mp.Type = sType
- mp.Ref = "#/definitions/" + mp.RefImport
- } else if isBasicType(realType) {
- typeFormat := strings.Split(sType, ":")
- mp.Type = typeFormat[0]
- mp.Format = typeFormat[1]
- } else if realType == "map" {
- typeFormat := strings.Split(sType, ":")
- mp.AdditionalProperties = &Propertie{
- Type: typeFormat[0],
- Format: typeFormat[1],
- }
- }
- }
- if name == "" {
- name = fd.Names[0].Name
- }
- if required {
- schema.Required = append(schema.Required, name)
- }
- mp.Description = desc
- if scm, ok := scom[obj.Name]; ok {
- if cm, ok := scm.field[fd.Names[0].Name]; ok {
- mp.Description = cm + desc
- }
- }
- properites[name] = mp
- }
- if scm, ok := scom[obj.Name]; ok {
- schema.Description = scm.comment
- }
- schema.Properties = properites
- definitions[schema.Title] = schema
- }
- }
- }
- }
- }
// parseFieldTag extracts Swagger metadata from the struct tag of field.
//
// It returns:
//   - name: the first comma-separated element of the `form` tag ("" if absent);
//   - required: true when the `validate` tag contains the rule "required";
//   - omit: true when the `json` tag is exactly "-" (encoding/json still emits
//     a field tagged "-," under the name "-");
//   - tagDes: a description (in Chinese) assembled from the "split" form
//     option, the `default` value and the validate min=/max= bounds.
//
// A field without a tag yields all zero values.
func parseFieldTag(field *ast.Field) (name string, required, omit bool, tagDes string) {
	if field.Tag == nil {
		return
	}
	tag := reflect.StructTag(strings.Trim(field.Tag.Value, "`"))
	if form := tag.Get("form"); form != "" {
		opts := strings.Split(form, ",")
		name = opts[0]
		if len(opts) == 2 && opts[1] == "split" {
			tagDes = "数组,按逗号分隔"
		}
	}
	if def := tag.Get("default"); def != "" {
		tagDes = fmt.Sprintf("%s 默认值 %s", tagDes, def)
	}
	if validate := tag.Get("validate"); validate != "" {
		for _, rule := range strings.Split(validate, ",") {
			// A bound without "=value" (e.g. a bare "min") is ignored rather
			// than indexing past the end of the split result.
			switch bound := strings.Split(rule, "="); {
			case rule == "required":
				required = true
			case strings.HasPrefix(rule, "min") && len(bound) > 1:
				tagDes = fmt.Sprintf("%s 最小值 %s", tagDes, bound[1])
			case strings.HasPrefix(rule, "max") && len(bound) > 1:
				tagDes = fmt.Sprintf("%s 最大值 %s", tagDes, bound[1])
			}
		}
	}
	// The json tag is consulted only to decide whether the field is omitted.
	omit = tag.Get("json") == "-"
	return
}
- func parseRouter() {
- for _, p := range pkgs {
- if p.Name != "http" {
- continue
- }
- fmt.Printf("开始解析生成swagger文档\n")
- for _, f := range p.Files {
- for _, decl := range f.Decls {
- if fdecl, ok := decl.(*ast.FuncDecl); ok {
- if fdecl.Doc != nil {
- path, req, resp, item, err := parseFuncDoc(fdecl.Doc)
- if err != nil {
- fmt.Printf("解析失败 注解错误 %v\n", err)
- continue
- }
- if path != "" && err == nil {
- fmt.Printf("解析 %s 完成 请求参数为 %s 返回结构为 %s\n", path, req, resp)
- swagger.Paths[path] = item
- }
- }
- }
- }
- }
- }
- }
- func parseFuncDoc(f *ast.CommentGroup) (path, reqObj, respObj string, item *Item, err error) {
- item = new(Item)
- op := new(Operation)
- params := make([]*Parameter, 0)
- response := make(map[string]*Response)
- for _, d := range f.List {
- t := strings.TrimSpace(strings.TrimPrefix(d.Text, "//"))
- content := strings.Split(t, " ")
- switch content[0] {
- case "@params":
- if len(content) < 2 {
- err = fmt.Errorf("err params %s", content)
- return
- }
- reqObj = content[1]
- if model, ok := definitions[content[1]]; ok {
- for n, p := range model.Properties {
- param := &Parameter{
- In: "query",
- Name: n,
- Description: p.Description,
- Type: p.Type,
- Format: p.Format,
- }
- for _, p := range model.Required {
- if p == n {
- param.Required = true
- }
- }
- params = append(params, param)
- }
- } else {
- err = ErrParams
- return
- }
- case "@router":
- if len(content) != 3 {
- err = ErrParams
- return
- }
- switch content[1] {
- case "get":
- item.Get = op
- case "post":
- item.Post = op
- }
- path = content[2]
- op.OperationID = path
- case "@response":
- if len(content) < 2 {
- err = fmt.Errorf("err response %s", content)
- return
- }
- var (
- isarray bool
- ismap bool
- )
- if strings.HasPrefix(content[1], "[]") {
- isarray = true
- respObj = content[1][2:]
- } else if strings.HasPrefix(content[1], "map[]") {
- ismap = true
- respObj = content[1][5:]
- } else {
- respObj = content[1]
- }
- defini, ok := definitions[respObj]
- if !ok {
- err = ErrParams
- return
- }
- var resp *Propertie
- if isarray {
- resp = &Propertie{
- Type: "array",
- Items: &Propertie{
- Type: "object",
- Ref: "#/definitions/" + respObj,
- },
- }
- } else if ismap {
- resp = &Propertie{
- Type: "object",
- AdditionalProperties: &Propertie{
- Ref: "#/definitions/" + respObj,
- },
- }
- } else {
- resp = &Propertie{
- Type: "object",
- Ref: "#/definitions/" + respObj,
- }
- }
- response["200"] = &Response{
- Schema: &Schema{
- Type: "object",
- Properties: map[string]*Propertie{
- "code": &Propertie{
- Type: "integer",
- Description: "错误码描述",
- },
- "data": resp,
- "message": &Propertie{
- Type: "string",
- Description: "错误码文本描述",
- },
- "ttl": &Propertie{
- Type: "integer",
- Format: "int64",
- Description: "客户端限速时间",
- },
- },
- },
- Description: "服务成功响应内容",
- }
- op.Responses = response
- for _, rl := range defini.Properties {
- if rl.RefImport != "" {
- swagger.Definitions[rl.RefImport] = definitions[rl.RefImport]
- }
- }
- swagger.Definitions[respObj] = defini
- case "@description":
- op.Description = content[1]
- }
- }
- op.Parameters = params
- return
- }
// structComment holds the comments attached to one struct type declaration.
type structComment struct {
	comment string            // the type's doc comment, trailing newline removed
	field   map[string]string // field name -> field comment, trailing newline removed
}

// parseStructComment collects, for every struct type declared in f, its doc
// comment and the comment of each named field, keyed by type name.
//   - A type inside a grouped "type ( ... )" declaration uses its own doc
//     comment, falling back to the group's doc comment.
//   - A field uses its trailing line comment, falling back to the doc comment
//     above it.
//   - Only the first name of a multi-name field ("A, B int") is recorded.
func parseStructComment(f *ast.File) (scom map[string]structComment) {
	scom = make(map[string]structComment)
	for _, d := range f.Decls {
		gen, ok := d.(*ast.GenDecl)
		if !ok || gen.Tok != token.TYPE {
			continue
		}
		for _, s := range gen.Specs {
			ts := s.(*ast.TypeSpec) // specs of a TYPE declaration are TypeSpecs
			st, ok := ts.Type.(*ast.StructType)
			if !ok {
				continue
			}
			fcom := make(map[string]string)
			for _, fd := range st.Fields.List {
				if len(fd.Names) == 0 {
					continue
				}
				text := fd.Comment.Text()
				if text == "" {
					text = fd.Doc.Text()
				}
				if text != "" {
					fcom[fd.Names[0].Name] = strings.TrimSuffix(text, "\n")
				}
			}
			// In an ungrouped declaration the parser attaches the doc comment
			// to the GenDecl; in a group each TypeSpec carries its own.
			doc := ts.Doc
			if doc == nil {
				doc = gen.Doc
			}
			scom[ts.Name.String()] = structComment{
				comment: strings.TrimSuffix(doc.Text(), "\n"),
				field:   fcom,
			}
		}
	}
	return scom
}
- func isBasicType(Type string) bool {
- if _, ok := basicTypes[Type]; ok {
- return true
- }
- return false
- }
- func typeAnalyser(f *ast.Field) (isSlice bool, realType, swaggerType string) {
- if arr, ok := f.Type.(*ast.ArrayType); ok {
- if isBasicType(fmt.Sprint(arr.Elt)) {
- return true, fmt.Sprintf("[]%v", arr.Elt), basicTypes[fmt.Sprint(arr.Elt)]
- }
- if mp, ok := arr.Elt.(*ast.MapType); ok {
- return false, fmt.Sprintf("map[%v][%v]", mp.Key, mp.Value), "object"
- }
- if star, ok := arr.Elt.(*ast.StarExpr); ok {
- return true, fmt.Sprint(star.X), "object"
- }
- basicType := fmt.Sprint(arr.Elt)
- if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
- basicType = object
- }
- if k, ok := basicTypes[basicType]; ok {
- return true, basicType, k
- }
- return true, fmt.Sprint(arr.Elt), "object"
- }
- switch t := f.Type.(type) {
- case *ast.StarExpr:
- basicType := fmt.Sprint(t.X)
- if k, ok := basicTypes[basicType]; ok {
- return false, basicType, k
- }
- return false, basicType, "object"
- case *ast.MapType:
- val := fmt.Sprintf("%v", t.Value)
- if isBasicType(val) {
- return false, "map", basicTypes[val]
- }
- return false, val, "object"
- }
- basicType := fmt.Sprint(f.Type)
- if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
- basicType = object
- }
- if k, ok := basicTypes[basicType]; ok {
- return false, basicType, k
- }
- return false, basicType, "object"
- }
- func isSystemPackage(pkgpath string) bool {
- goroot := os.Getenv("GOROOT")
- if goroot == "" {
- goroot = runtime.GOROOT()
- }
- wg, _ := filepath.EvalSymlinks(filepath.Join(goroot, "src", "pkg", pkgpath))
- if isExist(wg) {
- return true
- }
- wg, _ = filepath.EvalSymlinks(filepath.Join(goroot, "src", pkgpath))
- return isExist(wg)
- }
// isExist reports whether os.Stat succeeds for path; a Stat error that
// satisfies os.IsExist is also treated as existing.
func isExist(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return os.IsExist(statErr)
	}
	return true
}
|