rule.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. package dao
  2. import (
  3. "context"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "go-common/app/service/main/antispam/util"
  8. "go-common/library/database/sql"
  9. "go-common/library/log"
  10. )
  11. const (
  12. columnRules = "id, area, limit_type, limit_scope, dur_sec, allowed_counts, ctime, mtime"
  13. selectRuleCountsSQL = `SELECT COUNT(1) FROM rate_limit_rules %s`
  14. selectRulesByCondSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules %s`
  15. selectRuleByIDsSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE id IN(%s)`
  16. selectRulesByAreaSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s`
  17. selectRulesByAreaAndTypeSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s AND limit_type = %s`
  18. selectRulesByAreaAndTypeAndScopeSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s AND limit_type = %s AND limit_scope = %s`
  19. insertRuleSQL = `INSERT INTO rate_limit_rules(area, limit_type, limit_scope, dur_sec, allowed_counts) VALUES(?, ?, ?, ?, ?)`
  20. updateRuleSQL = `UPDATE rate_limit_rules SET dur_sec = ?, allowed_counts = ?, mtime = ? WHERE area = ? AND limit_type = ? AND limit_scope = ?`
  21. )
  22. // Rule .
  23. type Rule struct {
  24. ID int64 `db:"id"`
  25. Area int `db:"area"`
  26. LimitType int `db:"limit_type"`
  27. LimitScope int `db:"limit_scope"`
  28. DurationSec int64 `db:"dur_sec"`
  29. AllowedCounts int64 `db:"allowed_counts"`
  30. CTime time.Time `db:"ctime"`
  31. MTime time.Time `db:"mtime"`
  32. }
  33. // RuleDaoImpl .
  34. type RuleDaoImpl struct{}
  35. const (
  36. // LimitTypeDefaultLimit .
  37. LimitTypeDefaultLimit int = iota
  38. // LimitTypeRestrictLimit .
  39. LimitTypeRestrictLimit
  40. // LimitTypeWhite .
  41. LimitTypeWhite
  42. // LimitTypeBlack .
  43. LimitTypeBlack
  44. )
  45. const (
  46. // LimitScopeGlobal .
  47. LimitScopeGlobal int = iota
  48. // LimitScopeLocal .
  49. LimitScopeLocal
  50. )
  51. // NewRuleDao .
  52. func NewRuleDao() *RuleDaoImpl {
  53. return &RuleDaoImpl{}
  54. }
  55. func updateRule(ctx context.Context, executer Executer, r *Rule) error {
  56. _, err := executer.Exec(ctx,
  57. updateRuleSQL,
  58. r.DurationSec,
  59. r.AllowedCounts,
  60. time.Now(),
  61. r.Area,
  62. r.LimitType,
  63. r.LimitScope,
  64. )
  65. if err != nil {
  66. log.Error("%v", err)
  67. return err
  68. }
  69. return nil
  70. }
  71. func insertRule(ctx context.Context, executer Executer, r *Rule) error {
  72. res, err := executer.Exec(ctx,
  73. insertRuleSQL,
  74. r.Area,
  75. r.LimitType,
  76. r.LimitScope,
  77. r.DurationSec,
  78. r.AllowedCounts,
  79. )
  80. if err != nil {
  81. log.Error("%v", err)
  82. return err
  83. }
  84. lastID, err := res.LastInsertId()
  85. if err != nil {
  86. log.Error("%v", err)
  87. return err
  88. }
  89. r.ID = lastID
  90. return nil
  91. }
  92. // GetByCond .
  93. func (*RuleDaoImpl) GetByCond(ctx context.Context, cond *Condition) (rules []*Rule, totalCounts int64, err error) {
  94. sqlConds := make([]string, 0)
  95. if cond.Area != "" {
  96. sqlConds = append(sqlConds, fmt.Sprintf("area = %s", cond.Area))
  97. }
  98. if cond.State != "" {
  99. sqlConds = append(sqlConds, fmt.Sprintf("state = %s", cond.State))
  100. }
  101. var optionSQL string
  102. if len(sqlConds) > 0 {
  103. optionSQL = fmt.Sprintf("WHERE %s", strings.Join(sqlConds, " AND "))
  104. }
  105. var limitSQL string
  106. if cond.Pagination != nil {
  107. queryCountsSQL := fmt.Sprintf(selectRuleCountsSQL, optionSQL)
  108. totalCounts, err = GetTotalCounts(ctx, db, queryCountsSQL)
  109. if err != nil {
  110. return nil, 0, err
  111. }
  112. offset, limit := cond.OffsetLimit(totalCounts)
  113. if limit == 0 {
  114. return nil, 0, ErrResourceNotExist
  115. }
  116. limitSQL = fmt.Sprintf("LIMIT %d, %d", offset, limit)
  117. }
  118. if cond.OrderBy != "" {
  119. optionSQL = fmt.Sprintf("%s ORDER BY %s %s", optionSQL, cond.OrderBy, cond.Order)
  120. }
  121. if limitSQL != "" {
  122. optionSQL = fmt.Sprintf("%s %s", optionSQL, limitSQL)
  123. }
  124. querySQL := fmt.Sprintf(selectRulesByCondSQL, optionSQL)
  125. log.Info("OptionSQL(%s), GetByCondSQL(%s)", optionSQL, querySQL)
  126. rules, err = queryRules(ctx, db, querySQL)
  127. if err != nil {
  128. return nil, totalCounts, err
  129. }
  130. return rules, totalCounts, nil
  131. }
  132. // Update .
  133. func (rdi *RuleDaoImpl) Update(ctx context.Context, r *Rule) (*Rule, error) {
  134. if err := updateRule(ctx, db, r); err != nil {
  135. return nil, err
  136. }
  137. return rdi.GetByAreaAndTypeAndScope(ctx, &Condition{
  138. Area: fmt.Sprintf("%d", r.Area),
  139. LimitType: fmt.Sprintf("%d", r.LimitType),
  140. LimitScope: fmt.Sprintf("%d", r.LimitScope),
  141. })
  142. }
  143. // Insert .
  144. func (rdi *RuleDaoImpl) Insert(ctx context.Context, r *Rule) (*Rule, error) {
  145. if err := insertRule(ctx, db, r); err != nil {
  146. return nil, err
  147. }
  148. return rdi.GetByID(ctx, r.ID)
  149. }
  150. // GetByID .
  151. func (rdi *RuleDaoImpl) GetByID(ctx context.Context, id int64) (*Rule, error) {
  152. rs, err := rdi.GetByIDs(ctx, []int64{id})
  153. if err != nil {
  154. return nil, err
  155. }
  156. if rs[0] == nil {
  157. return nil, ErrResourceNotExist
  158. }
  159. return rs[0], nil
  160. }
  161. // GetByIDs .
  162. func (*RuleDaoImpl) GetByIDs(ctx context.Context, ids []int64) ([]*Rule, error) {
  163. rs, err := queryRules(ctx, db, fmt.Sprintf(selectRuleByIDsSQL, util.IntSliToSQLVarchars(ids)))
  164. if err != nil {
  165. return nil, err
  166. }
  167. res := make([]*Rule, len(ids))
  168. for i, id := range ids {
  169. for _, r := range rs {
  170. if r.ID == id {
  171. res[i] = r
  172. }
  173. }
  174. }
  175. return res, nil
  176. }
  177. // GetByAreaAndLimitType .
  178. func (*RuleDaoImpl) GetByAreaAndLimitType(ctx context.Context, cond *Condition) ([]*Rule, error) {
  179. return queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaAndTypeSQL, cond.Area, cond.LimitType))
  180. }
  181. // GetByAreaAndTypeAndScope .
  182. func (*RuleDaoImpl) GetByAreaAndTypeAndScope(ctx context.Context, cond *Condition) (*Rule, error) {
  183. rs, err := queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaAndTypeAndScopeSQL,
  184. cond.Area,
  185. cond.LimitType,
  186. cond.LimitScope,
  187. ))
  188. if err != nil {
  189. return nil, err
  190. }
  191. return rs[0], nil
  192. }
  193. // GetByArea .
  194. func (*RuleDaoImpl) GetByArea(ctx context.Context, cond *Condition) ([]*Rule, error) {
  195. return queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaSQL, cond.Area))
  196. }
  197. func queryRules(ctx context.Context, q Querier, rawSQL string) ([]*Rule, error) {
  198. log.Info("Query sql: %q", rawSQL)
  199. rows, err := q.Query(ctx, rawSQL)
  200. if err == sql.ErrNoRows {
  201. err = ErrResourceNotExist
  202. }
  203. if err != nil {
  204. log.Error("Error: %v, RawSQL: %s", err, rawSQL)
  205. return nil, err
  206. }
  207. defer rows.Close()
  208. rs, err := mapRowToRules(rows)
  209. if err != nil {
  210. return nil, err
  211. }
  212. if len(rs) == 0 {
  213. return nil, ErrResourceNotExist
  214. }
  215. return rs, nil
  216. }
  217. func mapRowToRules(rows *sql.Rows) (rs []*Rule, err error) {
  218. for rows.Next() {
  219. r := Rule{}
  220. err = rows.Scan(
  221. &r.ID,
  222. &r.Area,
  223. &r.LimitType,
  224. &r.LimitScope,
  225. &r.DurationSec,
  226. &r.AllowedCounts,
  227. &r.CTime,
  228. &r.MTime,
  229. )
  230. if err != nil {
  231. log.Error("%v", err)
  232. return nil, err
  233. }
  234. rs = append(rs, &r)
  235. }
  236. if err = rows.Err(); err != nil {
  237. log.Error("%v", err)
  238. return nil, err
  239. }
  240. return rs, nil
  241. }