dialect_mysql.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package gorm
  2. import (
  3. "crypto/sha1"
  4. "fmt"
  5. "reflect"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "unicode/utf8"
  11. )
  12. type mysql struct {
  13. commonDialect
  14. }
  15. func init() {
  16. RegisterDialect("mysql", &mysql{})
  17. }
  18. func (mysql) GetName() string {
  19. return "mysql"
  20. }
  21. func (mysql) Quote(key string) string {
  22. return fmt.Sprintf("`%s`", key)
  23. }
  24. // Get Data Type for MySQL Dialect
  25. func (s *mysql) DataTypeOf(field *StructField) string {
  26. var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
  27. // MySQL allows only one auto increment column per table, and it must
  28. // be a KEY column.
  29. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
  30. if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
  31. delete(field.TagSettings, "AUTO_INCREMENT")
  32. }
  33. }
  34. if sqlType == "" {
  35. switch dataValue.Kind() {
  36. case reflect.Bool:
  37. sqlType = "boolean"
  38. case reflect.Int8:
  39. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  40. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  41. sqlType = "tinyint AUTO_INCREMENT"
  42. } else {
  43. sqlType = "tinyint"
  44. }
  45. case reflect.Int, reflect.Int16, reflect.Int32:
  46. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  47. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  48. sqlType = "int AUTO_INCREMENT"
  49. } else {
  50. sqlType = "int"
  51. }
  52. case reflect.Uint8:
  53. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  54. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  55. sqlType = "tinyint unsigned AUTO_INCREMENT"
  56. } else {
  57. sqlType = "tinyint unsigned"
  58. }
  59. case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  60. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  61. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  62. sqlType = "int unsigned AUTO_INCREMENT"
  63. } else {
  64. sqlType = "int unsigned"
  65. }
  66. case reflect.Int64:
  67. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  68. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  69. sqlType = "bigint AUTO_INCREMENT"
  70. } else {
  71. sqlType = "bigint"
  72. }
  73. case reflect.Uint64:
  74. if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
  75. field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
  76. sqlType = "bigint unsigned AUTO_INCREMENT"
  77. } else {
  78. sqlType = "bigint unsigned"
  79. }
  80. case reflect.Float32, reflect.Float64:
  81. sqlType = "double"
  82. case reflect.String:
  83. if size > 0 && size < 65532 {
  84. sqlType = fmt.Sprintf("varchar(%d)", size)
  85. } else {
  86. sqlType = "longtext"
  87. }
  88. case reflect.Struct:
  89. if _, ok := dataValue.Interface().(time.Time); ok {
  90. if _, ok := field.TagSettings["NOT NULL"]; ok {
  91. sqlType = "timestamp"
  92. } else {
  93. sqlType = "timestamp NULL"
  94. }
  95. }
  96. default:
  97. if IsByteArrayOrSlice(dataValue) {
  98. if size > 0 && size < 65532 {
  99. sqlType = fmt.Sprintf("varbinary(%d)", size)
  100. } else {
  101. sqlType = "longblob"
  102. }
  103. }
  104. }
  105. }
  106. if sqlType == "" {
  107. panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
  108. }
  109. if strings.TrimSpace(additionalType) == "" {
  110. return sqlType
  111. }
  112. return fmt.Sprintf("%v %v", sqlType, additionalType)
  113. }
  114. func (s mysql) RemoveIndex(tableName string, indexName string) error {
  115. _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
  116. return err
  117. }
  118. func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
  119. if limit != nil {
  120. if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
  121. sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
  122. if offset != nil {
  123. if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
  124. sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
  125. }
  126. }
  127. }
  128. }
  129. return
  130. }
  131. func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
  132. var count int
  133. s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
  134. return count > 0
  135. }
  136. func (s mysql) CurrentDatabase() (name string) {
  137. s.db.QueryRow("SELECT DATABASE()").Scan(&name)
  138. return
  139. }
  140. func (mysql) SelectFromDummyTable() string {
  141. return "FROM DUAL"
  142. }
  143. func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
  144. keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
  145. if utf8.RuneCountInString(keyName) <= 64 {
  146. return keyName
  147. }
  148. h := sha1.New()
  149. h.Write([]byte(keyName))
  150. bs := h.Sum(nil)
  151. // sha1 is 40 digits, keep first 24 characters of destination
  152. destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
  153. if len(destRunes) > 24 {
  154. destRunes = destRunes[:24]
  155. }
  156. return fmt.Sprintf("%s%x", string(destRunes), bs)
  157. }