zlimit.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package zlimit
  2. import (
  3. "context"
  4. "time"
  5. "go-common/app/service/main/location/model"
  6. "go-common/library/cache/redis"
  7. "go-common/library/database/sql"
  8. "go-common/library/ecode"
  9. "go-common/library/log"
  10. xip "go-common/library/net/ip"
  11. xtime "go-common/library/time"
  12. "go-common/library/xstr"
  13. )
  14. const (
  15. _prefixBlackList = "zl_"
  16. )
  17. // Config default struct
  18. type Config struct {
  19. DB *sql.Config
  20. Redis *Redis
  21. IPFile string
  22. FlushTime xtime.Duration
  23. }
  24. // Redis redis struct
  25. type Redis struct {
  26. *redis.Config
  27. Expire xtime.Duration
  28. }
  29. // Service zlimit service struct
  30. type Service struct {
  31. // mysql
  32. db *sql.DB
  33. getPolicyStmt *sql.Stmt
  34. getRelationStmt *sql.Stmt
  35. getGroupPolicyStmt *sql.Stmt
  36. // redis
  37. redis *redis.Pool
  38. expire int32
  39. flushTime time.Duration
  40. // cache
  41. policy map[int64]map[int64]int64
  42. groupPolicy map[int64][]int64
  43. missch chan interface{}
  44. // xip
  45. list *xip.List
  46. }
  47. // New new zlimit service
  48. func New(c *Config) (s *Service) {
  49. var err error
  50. s = &Service{
  51. db: sql.NewMySQL(c.DB),
  52. redis: redis.NewPool(c.Redis.Config),
  53. expire: int32(time.Duration(c.Redis.Expire) / time.Second),
  54. missch: make(chan interface{}, 1024),
  55. policy: make(map[int64]map[int64]int64),
  56. groupPolicy: make(map[int64][]int64),
  57. flushTime: time.Duration(c.FlushTime),
  58. }
  59. s.getPolicyStmt = s.db.Prepared(_getPolicySQL)
  60. s.getRelationStmt = s.db.Prepared(_getRelationSQL)
  61. s.getGroupPolicyStmt = s.db.Prepared(_getGolbalPolicySQL)
  62. s.load()
  63. s.list, err = xip.New(c.IPFile)
  64. if err != nil {
  65. log.Error("xip.New(%s) error(%v)", c.IPFile, err)
  66. panic(err)
  67. }
  68. go s.reloadproc()
  69. go s.cacheproc()
  70. return
  71. }
  72. func (s *Service) load() {
  73. var (
  74. tmpPolicy map[int64]map[int64]int64
  75. tmpGroupPolicy map[int64][]int64
  76. err error
  77. )
  78. if tmpPolicy, err = s.policies(context.TODO()); err != nil {
  79. log.Error("s.policies error(%v)", err)
  80. } else if len(tmpPolicy) > 0 {
  81. s.policy = tmpPolicy
  82. }
  83. if tmpGroupPolicy, err = s.groupPolicies(context.TODO()); err != nil {
  84. log.Error("s.groupPolicies error(%v)", err)
  85. } else if len(tmpGroupPolicy) > 0 {
  86. s.groupPolicy = tmpGroupPolicy
  87. }
  88. }
  89. // reloadproc reload data from db
  90. func (s *Service) reloadproc() {
  91. for {
  92. s.load()
  93. time.Sleep(s.flushTime)
  94. }
  95. }
// Find returns the zone-limit result for aid as seen from ipaddr
// (aid is presumably an archive ID — TODO confirm).
//
// Results are packed auth values split into ret (auth>>8) and retdown
// (auth&0xff). Lookup order:
//  1. If the client location is unknown, return model.Allow/AllowDown.
//  2. If aid has its own rules (existsRule/rule), return the first rule
//     whose ret is non-zero, else Allow/AllowDown. Errors from existsRule or
//     rule are logged, cleared, and fall through to step 3.
//  3. If aid belongs to a cached group, return the first match among the
//     group's cached policies (else Allow/AllowDown), queuing each cached
//     policy's zone map for redis under aid.
//  4. Otherwise return Allow/AllowDown and queue {aid: {0: allow}} for redis.
//
// NOTE(review): a groupid error is returned unlogged with ret and retdown
// left at 0, not at the Allow defaults; callers must check err first.
func (s *Service) Find(c context.Context, aid int64, ipaddr, cdnip string) (ret, retdown int64, err error) {
	var (
		ok                  bool
		auth, pid, zid, gid int64
		rules, pids         []int64
		zids                map[int64]int64
		ipInfo              *xip.Zone
	)
	ipInfo = s.list.Zone(ipaddr)
	// The IP database labels some ranges "共享地址" ("shared address",
	// presumably carrier-grade NAT space), which carries no real location;
	// retry with the CDN-reported client IP when one was supplied.
	if (ipInfo != nil) && (ipInfo.Province == "共享地址" || ipInfo.City == "共享地址") && cdnip != "" {
		ipInfo = s.list.Zone(cdnip)
	}
	if ipInfo == nil {
		ret = model.Allow
		retdown = model.AllowDown
		return
	}
	uz := s.zoneids(ipInfo) // 0, country, country+province, country+province+city
	// Per-aid rules take precedence over group policies.
	if ok, err = s.existsRule(c, aid); err != nil {
		log.Error("s.existsRule error(%v)", err)
		err = nil
	} else if ok {
		if rules, err = s.rule(c, aid, uz); err != nil {
			log.Error("s.rule(%d) error(%v) ", aid, err)
			err = nil
		} else {
			// First rule with a non-zero ret wins.
			for _, auth = range rules {
				retdown = 0xff & auth
				ret = auth >> 8
				if ret != 0 {
					break
				}
			}
			if ret == 0 {
				ret = model.Allow
				retdown = model.AllowDown
			}
			return
		}
	}
	if gid, err = s.groupid(c, aid); err != nil {
		return
	} else if gid != 0 {
		if pids, ok = s.groupPolicy[gid]; ok {
			for _, pid = range pids {
				if zids, ok = s.policy[pid]; !ok {
					continue
				}
				// Only evaluate policies until a match sets ret; later
				// policies are still queued for caching below.
				if ret == 0 {
					// ret already set skip check
					// Zones are tried least-specific first (uz order), so a
					// broader zone hit wins over a narrower one.
					for _, zid = range uz {
						if auth, ok = zids[zid]; ok {
							// NOTE(review): always true here (outer check).
							if ret == 0 {
								retdown = 0xff & auth
								// Assumed non-zero; if it is 0 the next
								// policy is evaluated.
								ret = auth >> 8 // ret must not be zero
								break
							}
						}
					}
				}
				// Queue this policy's zone map for redis under aid.
				tmpZids := map[int64]map[int64]int64{
					aid: zids,
				}
				s.addCache(tmpZids)
			}
			if ret == 0 {
				ret = model.Allow
				retdown = model.AllowDown
			}
			return
		}
	}
	// No per-aid rule and no cached group policy: allow, and cache that
	// decision under zone 0 for aid.
	ret = model.Allow
	retdown = model.AllowDown
	zids = make(map[int64]int64)
	zids[0] = ret<<8 | retdown
	tmpZids := map[int64]map[int64]int64{
		aid: zids,
	}
	s.addCache(tmpZids)
	return
}
  179. // Forbid check ip is forbid or not.
  180. func (s *Service) Forbid(c context.Context, pstr string, ipaddr string) (err error) {
  181. if pstr == "" {
  182. return
  183. }
  184. var (
  185. ret int64
  186. pids []int64
  187. )
  188. if pids, err = xstr.SplitInts(pstr); err != nil {
  189. log.Error("xstr.SplitInts(%s) error(%v)", pstr, err)
  190. return
  191. }
  192. if ret, _ = s.FindByPid(c, pids, ipaddr); ret == model.Forbidden {
  193. err = ecode.ZlimitForbidden
  194. }
  195. return
  196. }
  197. // FindByPid redio rule by policy_id and ipaddr
  198. func (s *Service) FindByPid(c context.Context, pids []int64, ipaddr string) (ret, retdown int64) {
  199. var (
  200. ok bool
  201. auth int64
  202. zoneids []int64
  203. )
  204. ret = model.Allow
  205. retdown = model.AllowDown
  206. ipInfo := s.list.Zone(ipaddr)
  207. if ipInfo == nil {
  208. return
  209. }
  210. zoneids = s.zoneids(ipInfo)
  211. for _, pid := range pids {
  212. if _, ok = s.policy[pid]; !ok {
  213. continue
  214. }
  215. for _, zoneid := range zoneids {
  216. if auth, ok = s.policy[pid][zoneid]; ok {
  217. retdown = 0xff & auth
  218. ret = auth >> 8
  219. break
  220. }
  221. }
  222. }
  223. return
  224. }
  225. // FindByGid redio rule by group_id and ipaddr(or cdnip)
  226. func (s *Service) FindByGid(c context.Context, gid int64, ipaddr, cdnip string) (ret, retdown int64) {
  227. var ipInfo *xip.Zone
  228. ret = model.Allow
  229. retdown = model.AllowDown
  230. ipInfo = s.list.Zone(ipaddr)
  231. if (ipInfo != nil) && (ipInfo.Province == "共享地址" || ipInfo.City == "共享地址") && cdnip != "" {
  232. ipInfo = s.list.Zone(cdnip)
  233. }
  234. if ipInfo == nil {
  235. return
  236. }
  237. zoneids := s.zoneids(ipInfo)
  238. if pids, ok := s.groupPolicy[gid]; ok {
  239. for _, pid := range pids {
  240. if _, ok := s.policy[pid]; !ok {
  241. continue
  242. }
  243. for _, zoneid := range zoneids {
  244. if auth, ok := s.policy[pid][zoneid]; ok {
  245. retdown = 0xff & auth
  246. ret = auth >> 8
  247. break
  248. }
  249. }
  250. }
  251. }
  252. return
  253. }
  254. // zoneids make zoneids
  255. func (s *Service) zoneids(ipinfos *xip.Zone) []int64 {
  256. cZid := xip.ZoneID(ipinfos.Country, "", "")
  257. cpZid := xip.ZoneID(ipinfos.Country, ipinfos.Province, "")
  258. cpcZid := xip.ZoneID(ipinfos.Country, ipinfos.Province, ipinfos.City)
  259. zoneids := []int64{0, cZid, cpZid, cpcZid}
  260. return zoneids
  261. }
  262. func (s *Service) addCache(d interface{}) {
  263. // asynchronous add rules to redis
  264. select {
  265. case s.missch <- d:
  266. default:
  267. log.Warn("cacheproc chan full")
  268. }
  269. }
  270. // cacheproc is a routine for add rules into redis.
  271. func (s *Service) cacheproc() {
  272. for {
  273. d := <-s.missch
  274. switch d.(type) {
  275. case map[int64]map[int64]int64:
  276. v := d.(map[int64]map[int64]int64)
  277. if err := s.addRule(context.TODO(), v); err != nil {
  278. log.Error("s.addRule error(%v) error(%v)", v, err)
  279. }
  280. default:
  281. log.Warn("cacheproc can't process the type")
  282. }
  283. }
  284. }