sql.go 17 KB


  1. package sql
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "sync/atomic"
  8. "time"
  9. "go-common/library/ecode"
  10. "go-common/library/net/netutil/breaker"
  11. "go-common/library/net/trace"
  12. "github.com/pkg/errors"
  13. )
  // _family is the trace family attached to every span this client forks.
const _family = "sql_client"

// Errors exposed by this package. ErrNoRows and ErrTxDone are the very same
// values as their database/sql counterparts, so callers may compare against
// either.
var (
	// ErrStmtNil is returned when a statement is nil or has not been
	// prepared successfully yet.
	ErrStmtNil = errors.New("sql: prepare failed and stmt nil")
	// ErrNoMaster is the panic value used by Master when the receiver has no
	// master view, e.g. when Master is called on the DB that Master returned.
	ErrNoMaster = errors.New("sql: no master instance")
	// ErrNoRows is returned by Row.Scan when the query matched no row; it is
	// returned unwrapped so it can be compared directly.
	ErrNoRows = sql.ErrNoRows
	// ErrTxDone reports an operation on a transaction that has already been
	// committed or rolled back.
	ErrTxDone = sql.ErrTxDone
)
  29. // DB database.
  30. type DB struct {
  31. write *conn
  32. read []*conn
  33. idx int64
  34. master *DB
  35. }
  36. // conn database connection
  37. type conn struct {
  38. *sql.DB
  39. breaker breaker.Breaker
  40. conf *Config
  41. }
  42. // Tx transaction.
  43. type Tx struct {
  44. db *conn
  45. tx *sql.Tx
  46. t trace.Trace
  47. c context.Context
  48. cancel func()
  49. }
  50. // Row row.
  51. type Row struct {
  52. err error
  53. *sql.Row
  54. db *conn
  55. query string
  56. args []interface{}
  57. t trace.Trace
  58. cancel func()
  59. }
  60. // Scan copies the columns from the matched row into the values pointed at by dest.
  61. func (r *Row) Scan(dest ...interface{}) (err error) {
  62. if r.t != nil {
  63. defer r.t.Finish(&err)
  64. }
  65. if r.err != nil {
  66. err = r.err
  67. } else if r.Row == nil {
  68. err = ErrStmtNil
  69. }
  70. if err != nil {
  71. return
  72. }
  73. err = r.Row.Scan(dest...)
  74. if r.cancel != nil {
  75. r.cancel()
  76. }
  77. r.db.onBreaker(&err)
  78. if err != ErrNoRows {
  79. err = errors.Wrapf(err, "query %s args %+v", r.query, r.args)
  80. }
  81. return
  82. }
  // Rows iterates over a query result. It embeds *sql.Rows and may carry the
// cancel func of the query's context, which Close releases.
type Rows struct {
	*sql.Rows
	cancel func() // nil for rows returned by Tx.Query
}
  88. // Close closes the Rows, preventing further enumeration. If Next is called
  89. // and returns false and there are no further result sets,
  90. // the Rows are closed automatically and it will suffice to check the
  91. // result of Err. Close is idempotent and does not affect the result of Err.
  92. func (rs *Rows) Close() (err error) {
  93. err = errors.WithStack(rs.Rows.Close())
  94. if rs.cancel != nil {
  95. rs.cancel()
  96. }
  97. return
  98. }
  99. // Stmt prepared stmt.
  100. type Stmt struct {
  101. db *conn
  102. tx bool
  103. query string
  104. stmt atomic.Value
  105. t trace.Trace
  106. }
  107. // Open opens a database specified by its database driver name and a
  108. // driver-specific data source name, usually consisting of at least a database
  109. // name and connection information.
  110. func Open(c *Config) (*DB, error) {
  111. db := new(DB)
  112. d, err := connect(c, c.DSN)
  113. if err != nil {
  114. return nil, err
  115. }
  116. brkGroup := breaker.NewGroup(c.Breaker)
  117. brk := brkGroup.Get(c.Addr)
  118. w := &conn{DB: d, breaker: brk, conf: c}
  119. rs := make([]*conn, 0, len(c.ReadDSN))
  120. for _, rd := range c.ReadDSN {
  121. d, err := connect(c, rd)
  122. if err != nil {
  123. return nil, err
  124. }
  125. brk := brkGroup.Get(parseDSNAddr(rd))
  126. r := &conn{DB: d, breaker: brk, conf: c}
  127. rs = append(rs, r)
  128. }
  129. db.write = w
  130. db.read = rs
  131. db.master = &DB{write: db.write}
  132. return db, nil
  133. }
  134. func connect(c *Config, dataSourceName string) (*sql.DB, error) {
  135. d, err := sql.Open("mysql", dataSourceName)
  136. if err != nil {
  137. err = errors.WithStack(err)
  138. return nil, err
  139. }
  140. d.SetMaxOpenConns(c.Active)
  141. d.SetMaxIdleConns(c.Idle)
  142. d.SetConnMaxLifetime(time.Duration(c.IdleTimeout))
  143. return d, nil
  144. }
  145. // Begin starts a transaction. The isolation level is dependent on the driver.
  146. func (db *DB) Begin(c context.Context) (tx *Tx, err error) {
  147. return db.write.begin(c)
  148. }
  149. // Exec executes a query without returning any rows.
  150. // The args are for any placeholder parameters in the query.
  151. func (db *DB) Exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
  152. return db.write.exec(c, query, args...)
  153. }
  154. // Prepare creates a prepared statement for later queries or executions.
  155. // Multiple queries or executions may be run concurrently from the returned
  156. // statement. The caller must call the statement's Close method when the
  157. // statement is no longer needed.
  158. func (db *DB) Prepare(query string) (*Stmt, error) {
  159. return db.write.prepare(query)
  160. }
  161. // Prepared creates a prepared statement for later queries or executions.
  162. // Multiple queries or executions may be run concurrently from the returned
  163. // statement. The caller must call the statement's Close method when the
  164. // statement is no longer needed.
  165. func (db *DB) Prepared(query string) (stmt *Stmt) {
  166. return db.write.prepared(query)
  167. }
  168. // Query executes a query that returns rows, typically a SELECT. The args are
  169. // for any placeholder parameters in the query.
  170. func (db *DB) Query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
  171. idx := db.readIndex()
  172. for i := range db.read {
  173. if rows, err = db.read[(idx+i)%len(db.read)].query(c, query, args...); !ecode.ServiceUnavailable.Equal(err) {
  174. return
  175. }
  176. }
  177. return db.write.query(c, query, args...)
  178. }
  179. // QueryRow executes a query that is expected to return at most one row.
  180. // QueryRow always returns a non-nil value. Errors are deferred until Row's
  181. // Scan method is called.
  182. func (db *DB) QueryRow(c context.Context, query string, args ...interface{}) *Row {
  183. idx := db.readIndex()
  184. for i := range db.read {
  185. if row := db.read[(idx+i)%len(db.read)].queryRow(c, query, args...); !ecode.ServiceUnavailable.Equal(row.err) {
  186. return row
  187. }
  188. }
  189. return db.write.queryRow(c, query, args...)
  190. }
  191. func (db *DB) readIndex() int {
  192. if len(db.read) == 0 {
  193. return 0
  194. }
  195. v := atomic.AddInt64(&db.idx, 1)
  196. return int(v) % len(db.read)
  197. }
  198. // Close closes the write and read database, releasing any open resources.
  199. func (db *DB) Close() (err error) {
  200. if e := db.write.Close(); e != nil {
  201. err = errors.WithStack(e)
  202. }
  203. for _, rd := range db.read {
  204. if e := rd.Close(); e != nil {
  205. err = errors.WithStack(e)
  206. }
  207. }
  208. return
  209. }
  210. // Ping verifies a connection to the database is still alive, establishing a
  211. // connection if necessary.
  212. func (db *DB) Ping(c context.Context) (err error) {
  213. if err = db.write.ping(c); err != nil {
  214. return
  215. }
  216. for _, rd := range db.read {
  217. if err = rd.ping(c); err != nil {
  218. return
  219. }
  220. }
  221. return
  222. }
  223. // Master return *DB instance direct use master conn
  224. // use this *DB instance only when you have some reason need to get result without any delay.
  225. func (db *DB) Master() *DB {
  226. if db.master == nil {
  227. panic(ErrNoMaster)
  228. }
  229. return db.master
  230. }
  231. func (db *conn) onBreaker(err *error) {
  232. if err != nil && *err != nil && *err != sql.ErrNoRows && *err != sql.ErrTxDone {
  233. db.breaker.MarkFailed()
  234. } else {
  235. db.breaker.MarkSuccess()
  236. }
  237. }
  238. func (db *conn) begin(c context.Context) (tx *Tx, err error) {
  239. now := time.Now()
  240. t, ok := trace.FromContext(c)
  241. if ok {
  242. t = t.Fork(_family, "begin")
  243. t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, ""))
  244. defer func() {
  245. if err != nil {
  246. t.Finish(&err)
  247. }
  248. }()
  249. }
  250. if err = db.breaker.Allow(); err != nil {
  251. stats.Incr("mysql:begin", "breaker")
  252. return
  253. }
  254. _, c, cancel := db.conf.TranTimeout.Shrink(c)
  255. rtx, err := db.BeginTx(c, nil)
  256. stats.Timing("mysql:begin", int64(time.Since(now)/time.Millisecond))
  257. if err != nil {
  258. err = errors.WithStack(err)
  259. cancel()
  260. return
  261. }
  262. tx = &Tx{tx: rtx, t: t, db: db, c: c, cancel: cancel}
  263. return
  264. }
  265. func (db *conn) exec(c context.Context, query string, args ...interface{}) (res sql.Result, err error) {
  266. now := time.Now()
  267. if t, ok := trace.FromContext(c); ok {
  268. t = t.Fork(_family, "exec")
  269. t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
  270. defer t.Finish(&err)
  271. }
  272. if err = db.breaker.Allow(); err != nil {
  273. stats.Incr("mysql:exec", "breaker")
  274. return
  275. }
  276. _, c, cancel := db.conf.ExecTimeout.Shrink(c)
  277. res, err = db.ExecContext(c, query, args...)
  278. cancel()
  279. db.onBreaker(&err)
  280. stats.Timing("mysql:exec", int64(time.Since(now)/time.Millisecond))
  281. if err != nil {
  282. err = errors.Wrapf(err, "exec:%s, args:%+v", query, args)
  283. }
  284. return
  285. }
  286. func (db *conn) ping(c context.Context) (err error) {
  287. now := time.Now()
  288. if t, ok := trace.FromContext(c); ok {
  289. t = t.Fork(_family, "ping")
  290. t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, ""))
  291. defer t.Finish(&err)
  292. }
  293. if err = db.breaker.Allow(); err != nil {
  294. stats.Incr("mysql:ping", "breaker")
  295. return
  296. }
  297. _, c, cancel := db.conf.ExecTimeout.Shrink(c)
  298. err = db.PingContext(c)
  299. cancel()
  300. db.onBreaker(&err)
  301. stats.Timing("mysql:ping", int64(time.Since(now)/time.Millisecond))
  302. if err != nil {
  303. err = errors.WithStack(err)
  304. }
  305. return
  306. }
  307. func (db *conn) prepare(query string) (*Stmt, error) {
  308. stmt, err := db.Prepare(query)
  309. if err != nil {
  310. err = errors.Wrapf(err, "prepare %s", query)
  311. return nil, err
  312. }
  313. st := &Stmt{query: query, db: db}
  314. st.stmt.Store(stmt)
  315. return st, nil
  316. }
  317. func (db *conn) prepared(query string) (stmt *Stmt) {
  318. stmt = &Stmt{query: query, db: db}
  319. s, err := db.Prepare(query)
  320. if err == nil {
  321. stmt.stmt.Store(s)
  322. return
  323. }
  324. go func() {
  325. for {
  326. s, err := db.Prepare(query)
  327. if err != nil {
  328. time.Sleep(time.Second)
  329. continue
  330. }
  331. stmt.stmt.Store(s)
  332. return
  333. }
  334. }()
  335. return
  336. }
  337. func (db *conn) query(c context.Context, query string, args ...interface{}) (rows *Rows, err error) {
  338. now := time.Now()
  339. if t, ok := trace.FromContext(c); ok {
  340. t = t.Fork(_family, "query")
  341. t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
  342. defer t.Finish(&err)
  343. }
  344. if err = db.breaker.Allow(); err != nil {
  345. stats.Incr("mysql:query", "breaker")
  346. return
  347. }
  348. _, c, cancel := db.conf.QueryTimeout.Shrink(c)
  349. rs, err := db.DB.QueryContext(c, query, args...)
  350. db.onBreaker(&err)
  351. stats.Timing("mysql:query", int64(time.Since(now)/time.Millisecond))
  352. if err != nil {
  353. err = errors.Wrapf(err, "query:%s, args:%+v", query, args)
  354. cancel()
  355. return
  356. }
  357. rows = &Rows{Rows: rs, cancel: cancel}
  358. return
  359. }
  360. func (db *conn) queryRow(c context.Context, query string, args ...interface{}) *Row {
  361. now := time.Now()
  362. t, ok := trace.FromContext(c)
  363. if ok {
  364. t = t.Fork(_family, "queryrow")
  365. t.SetTag(trace.String(trace.TagAddress, db.conf.Addr), trace.String(trace.TagComment, query))
  366. }
  367. if err := db.breaker.Allow(); err != nil {
  368. stats.Incr("mysql:queryrow", "breaker")
  369. return &Row{db: db, t: t, err: err}
  370. }
  371. _, c, cancel := db.conf.QueryTimeout.Shrink(c)
  372. r := db.DB.QueryRowContext(c, query, args...)
  373. stats.Timing("mysql:queryrow", int64(time.Since(now)/time.Millisecond))
  374. return &Row{db: db, Row: r, query: query, args: args, t: t, cancel: cancel}
  375. }
  376. // Close closes the statement.
  377. func (s *Stmt) Close() (err error) {
  378. if s == nil {
  379. err = ErrStmtNil
  380. return
  381. }
  382. stmt, ok := s.stmt.Load().(*sql.Stmt)
  383. if ok {
  384. err = errors.WithStack(stmt.Close())
  385. }
  386. return
  387. }
  388. // Exec executes a prepared statement with the given arguments and returns a
  389. // Result summarizing the effect of the statement.
  390. func (s *Stmt) Exec(c context.Context, args ...interface{}) (res sql.Result, err error) {
  391. if s == nil {
  392. err = ErrStmtNil
  393. return
  394. }
  395. now := time.Now()
  396. if s.tx {
  397. if s.t != nil {
  398. s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
  399. }
  400. } else if t, ok := trace.FromContext(c); ok {
  401. t = t.Fork(_family, "exec")
  402. t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
  403. defer t.Finish(&err)
  404. }
  405. if err = s.db.breaker.Allow(); err != nil {
  406. stats.Incr("mysql:stmt:exec", "breaker")
  407. return
  408. }
  409. stmt, ok := s.stmt.Load().(*sql.Stmt)
  410. if !ok {
  411. err = ErrStmtNil
  412. return
  413. }
  414. _, c, cancel := s.db.conf.ExecTimeout.Shrink(c)
  415. res, err = stmt.ExecContext(c, args...)
  416. cancel()
  417. s.db.onBreaker(&err)
  418. stats.Timing("mysql:stmt:exec", int64(time.Since(now)/time.Millisecond))
  419. if err != nil {
  420. err = errors.Wrapf(err, "exec:%s, args:%+v", s.query, args)
  421. }
  422. return
  423. }
  424. // Query executes a prepared query statement with the given arguments and
  425. // returns the query results as a *Rows.
  426. func (s *Stmt) Query(c context.Context, args ...interface{}) (rows *Rows, err error) {
  427. if s == nil {
  428. err = ErrStmtNil
  429. return
  430. }
  431. if s.tx {
  432. if s.t != nil {
  433. s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
  434. }
  435. } else if t, ok := trace.FromContext(c); ok {
  436. t = t.Fork(_family, "query")
  437. t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
  438. defer t.Finish(&err)
  439. }
  440. if err = s.db.breaker.Allow(); err != nil {
  441. stats.Incr("mysql:stmt:query", "breaker")
  442. return
  443. }
  444. stmt, ok := s.stmt.Load().(*sql.Stmt)
  445. if !ok {
  446. err = ErrStmtNil
  447. return
  448. }
  449. now := time.Now()
  450. _, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
  451. rs, err := stmt.QueryContext(c, args...)
  452. s.db.onBreaker(&err)
  453. stats.Timing("mysql:stmt:query", int64(time.Since(now)/time.Millisecond))
  454. if err != nil {
  455. err = errors.Wrapf(err, "query:%s, args:%+v", s.query, args)
  456. cancel()
  457. return
  458. }
  459. rows = &Rows{Rows: rs, cancel: cancel}
  460. return
  461. }
  462. // QueryRow executes a prepared query statement with the given arguments.
  463. // If an error occurs during the execution of the statement, that error will
  464. // be returned by a call to Scan on the returned *Row, which is always non-nil.
  465. // If the query selects no rows, the *Row's Scan will return ErrNoRows.
  466. // Otherwise, the *Row's Scan scans the first selected row and discards the rest.
  467. func (s *Stmt) QueryRow(c context.Context, args ...interface{}) (row *Row) {
  468. now := time.Now()
  469. row = &Row{db: s.db, query: s.query, args: args}
  470. if s == nil {
  471. row.err = ErrStmtNil
  472. return
  473. }
  474. if s.tx {
  475. if s.t != nil {
  476. s.t.SetTag(trace.String(trace.TagAnnotation, s.query))
  477. }
  478. } else if t, ok := trace.FromContext(c); ok {
  479. t = t.Fork(_family, "queryrow")
  480. t.SetTag(trace.String(trace.TagAddress, s.db.conf.Addr), trace.String(trace.TagComment, s.query))
  481. row.t = t
  482. }
  483. if row.err = s.db.breaker.Allow(); row.err != nil {
  484. stats.Incr("mysql:stmt:queryrow", "breaker")
  485. return
  486. }
  487. stmt, ok := s.stmt.Load().(*sql.Stmt)
  488. if !ok {
  489. return
  490. }
  491. _, c, cancel := s.db.conf.QueryTimeout.Shrink(c)
  492. row.Row = stmt.QueryRowContext(c, args...)
  493. row.cancel = cancel
  494. stats.Timing("mysql:stmt:queryrow", int64(time.Since(now)/time.Millisecond))
  495. return
  496. }
  497. // Commit commits the transaction.
  498. func (tx *Tx) Commit() (err error) {
  499. err = tx.tx.Commit()
  500. tx.cancel()
  501. tx.db.onBreaker(&err)
  502. if tx.t != nil {
  503. tx.t.Finish(&err)
  504. }
  505. if err != nil {
  506. err = errors.WithStack(err)
  507. }
  508. return
  509. }
  510. // Rollback aborts the transaction.
  511. func (tx *Tx) Rollback() (err error) {
  512. err = tx.tx.Rollback()
  513. tx.cancel()
  514. tx.db.onBreaker(&err)
  515. if tx.t != nil {
  516. tx.t.Finish(&err)
  517. }
  518. if err != nil {
  519. err = errors.WithStack(err)
  520. }
  521. return
  522. }
  523. // Exec executes a query that doesn't return rows. For example: an INSERT and
  524. // UPDATE.
  525. func (tx *Tx) Exec(query string, args ...interface{}) (res sql.Result, err error) {
  526. now := time.Now()
  527. if tx.t != nil {
  528. tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("exec %s", query)))
  529. }
  530. res, err = tx.tx.ExecContext(tx.c, query, args...)
  531. stats.Timing("mysql:tx:exec", int64(time.Since(now)/time.Millisecond))
  532. if err != nil {
  533. err = errors.Wrapf(err, "exec:%s, args:%+v", query, args)
  534. }
  535. return
  536. }
  537. // Query executes a query that returns rows, typically a SELECT.
  538. func (tx *Tx) Query(query string, args ...interface{}) (rows *Rows, err error) {
  539. if tx.t != nil {
  540. tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("query %s", query)))
  541. }
  542. now := time.Now()
  543. defer func() {
  544. stats.Timing("mysql:tx:query", int64(time.Since(now)/time.Millisecond))
  545. }()
  546. rs, err := tx.tx.QueryContext(tx.c, query, args...)
  547. if err == nil {
  548. rows = &Rows{Rows: rs}
  549. } else {
  550. err = errors.Wrapf(err, "query:%s, args:%+v", query, args)
  551. }
  552. return
  553. }
  554. // QueryRow executes a query that is expected to return at most one row.
  555. // QueryRow always returns a non-nil value. Errors are deferred until Row's
  556. // Scan method is called.
  557. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
  558. if tx.t != nil {
  559. tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("queryrow %s", query)))
  560. }
  561. now := time.Now()
  562. defer func() {
  563. stats.Timing("mysql:tx:queryrow", int64(time.Since(now)/time.Millisecond))
  564. }()
  565. r := tx.tx.QueryRowContext(tx.c, query, args...)
  566. return &Row{Row: r, db: tx.db, query: query, args: args}
  567. }
  568. // Stmt returns a transaction-specific prepared statement from an existing statement.
  569. func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  570. as, ok := stmt.stmt.Load().(*sql.Stmt)
  571. if !ok {
  572. return nil
  573. }
  574. ts := tx.tx.StmtContext(tx.c, as)
  575. st := &Stmt{query: stmt.query, tx: true, t: tx.t, db: tx.db}
  576. st.stmt.Store(ts)
  577. return st
  578. }
  579. // Prepare creates a prepared statement for use within a transaction.
  580. // The returned statement operates within the transaction and can no longer be
  581. // used once the transaction has been committed or rolled back.
  582. // To use an existing prepared statement on this transaction, see Tx.Stmt.
  583. func (tx *Tx) Prepare(query string) (*Stmt, error) {
  584. if tx.t != nil {
  585. tx.t.SetTag(trace.String(trace.TagAnnotation, fmt.Sprintf("prepare %s", query)))
  586. }
  587. stmt, err := tx.tx.Prepare(query)
  588. if err != nil {
  589. err = errors.Wrapf(err, "prepare %s", query)
  590. return nil, err
  591. }
  592. st := &Stmt{query: query, tx: true, t: tx.t, db: tx.db}
  593. st.stmt.Store(stmt)
  594. return st, nil
  595. }
  // parseDSNAddr returns the part of a MySQL DSN that follows the credentials,
// with any query parameters removed. For example,
// "user:pass@tcp(127.0.0.1:3306)/db?timeout=1s" yields
// "tcp(127.0.0.1:3306)/db". It returns "" when dsn is empty or has no
// credentials ('@') section. The result is used as a breaker key.
func parseDSNAddr(dsn string) (addr string) {
	if dsn == "" {
		return
	}
	// Like go-sql-driver/mysql, the credentials end at the last '@' before
	// the last '/'. A password may itself contain '@', and splitting on the
	// first '@' would give every such replica the same bogus key.
	head := dsn
	if slash := strings.LastIndex(dsn, "/"); slash >= 0 {
		head = dsn[:slash]
	}
	at := strings.LastIndex(head, "@")
	if at < 0 {
		return
	}
	addr = dsn[at+1:]
	if q := strings.IndexByte(addr, '?'); q >= 0 {
		addr = addr[:q]
	}
	return
}