util.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. package mysql
  2. import (
  3. "crypto/rand"
  4. "crypto/sha1"
  5. "encoding/binary"
  6. "fmt"
  7. "io"
  8. "runtime"
  9. "strings"
  10. "github.com/juju/errors"
  11. "github.com/siddontang/go/hack"
  12. )
  // Pstack returns the formatted stack trace of the calling goroutine.
  //
  // runtime.Stack truncates its output to the size of the buffer it is
  // given, so the buffer is doubled until the whole trace fits. A fixed
  // 1 KiB buffer silently cut off everything but the innermost frames of
  // deeper stacks.
  func Pstack() string {
  	buf := make([]byte, 1024)
  	for {
  		n := runtime.Stack(buf, false)
  		if n < len(buf) {
  			return string(buf[:n])
  		}
  		// The output filled the buffer and may be truncated: retry larger.
  		buf = make([]byte, 2*len(buf))
  	}
  }
  // CalcPassword computes the mysql_native_password authentication token:
  //
  //	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
  //
  // It returns nil for an empty password. The result is a freshly allocated
  // 20-byte slice; neither input is modified.
  func CalcPassword(scramble, password []byte) []byte {
  	if len(password) == 0 {
  		return nil
  	}
  	pwHash := sha1.Sum(password)       // SHA1(password)
  	pwHashHash := sha1.Sum(pwHash[:]) // SHA1(SHA1(password))
  
  	outer := sha1.New()
  	outer.Write(scramble)
  	outer.Write(pwHashHash[:])
  	token := outer.Sum(nil)
  
  	for i := 0; i < len(token); i++ {
  		token[i] ^= pwHash[i]
  	}
  	return token
  }
  42. func RandomBuf(size int) ([]byte, error) {
  43. buf := make([]byte, size)
  44. if _, err := io.ReadFull(rand.Reader, buf); err != nil {
  45. return nil, errors.Trace(err)
  46. }
  47. // avoid to generate '\0'
  48. for i, b := range buf {
  49. if uint8(b) == 0 {
  50. buf[i] = '0'
  51. }
  52. }
  53. return buf, nil
  54. }
  // FixedLengthInt decodes buf as an unsigned little-endian integer, where
  // buf[0] is the least significant byte. An empty buf yields 0. Bytes past
  // the eighth do not fit in a uint64 and are ignored.
  func FixedLengthInt(buf []byte) uint64 {
  	var value uint64
  	// Walk from the most significant byte down, shifting earlier ones up.
  	for i := len(buf) - 1; i >= 0; i-- {
  		value = value<<8 | uint64(buf[i])
  	}
  	return value
  }
  // BFixedLengthInt decodes buf as an unsigned big-endian integer, where
  // buf[0] is the most significant byte. An empty buf yields 0. When buf is
  // longer than eight bytes, only the last eight contribute.
  func BFixedLengthInt(buf []byte) uint64 {
  	var value uint64
  	for _, octet := range buf {
  		value = value<<8 | uint64(octet)
  	}
  	return value
  }
  // LengthEncodedInt decodes a MySQL length-encoded integer from the start
  // of b. It returns the value, whether the encoding was the NULL marker, and
  // how many bytes of b were consumed. The first byte selects the form:
  //
  //	0x00-0xfa  the value itself             (1 byte consumed)
  //	0xfb       NULL                         (1 byte consumed)
  //	0xfc       2-byte little-endian value   (3 bytes consumed)
  //	0xfd       3-byte little-endian value   (4 bytes consumed)
  //	0xfe       8-byte little-endian value   (9 bytes consumed)
  //
  // An empty b is reported as NULL with 0 bytes consumed instead of
  // panicking. A 0xfc/0xfd/0xfe prefix whose payload is cut short is not
  // checked and still panics with an out-of-range index.
  //
  // NOTE(review): 0xff is not a valid prefix but falls through to the
  // single-byte case and decodes as 255; confirm callers never pass it.
  func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
  	if len(b) == 0 {
  		return 0, true, 0
  	}
  	switch b[0] {
  	case 0xfb:
  		return 0, true, 1
  	case 0xfc:
  		return uint64(binary.LittleEndian.Uint16(b[1:3])), false, 3
  	case 0xfd:
  		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
  	case 0xfe:
  		return binary.LittleEndian.Uint64(b[1:9]), false, 9
  	}
  	return uint64(b[0]), false, 1
  }
  // PutLengthEncodedInt encodes n as a MySQL length-encoded integer, using
  // the smallest form that fits:
  //
  //	n <= 250       1 byte: n itself
  //	n <= 0xffff    0xfc + 2 bytes little-endian
  //	n <= 0xffffff  0xfd + 3 bytes little-endian
  //	otherwise      0xfe + 8 bytes little-endian
  //
  // Values 251..255 cannot use the 1-byte form because those first-byte
  // values are reserved as prefixes. The result is never nil.
  func PutLengthEncodedInt(n uint64) []byte {
  	switch {
  	case n <= 250:
  		return []byte{byte(n)}
  	case n <= 0xffff:
  		return []byte{0xfc, byte(n), byte(n >> 8)}
  	case n <= 0xffffff:
  		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
  	default:
  		// Every remaining uint64 fits in 8 bytes; the old explicit
  		// "n <= 0xffffffffffffffff" test was always true, leaving an
  		// unreachable "return nil" behind it.
  		return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
  			byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
  	}
  }
  115. func LengthEnodedString(b []byte) ([]byte, bool, int, error) {
  116. // Get length
  117. num, isNull, n := LengthEncodedInt(b)
  118. if num < 1 {
  119. return nil, isNull, n, nil
  120. }
  121. n += int(num)
  122. // Check data length
  123. if len(b) >= n {
  124. return b[n-int(num) : n], false, n, nil
  125. }
  126. return nil, false, n, io.EOF
  127. }
  128. func SkipLengthEnodedString(b []byte) (int, error) {
  129. // Get length
  130. num, _, n := LengthEncodedInt(b)
  131. if num < 1 {
  132. return n, nil
  133. }
  134. n += int(num)
  135. // Check data length
  136. if len(b) >= n {
  137. return n, nil
  138. }
  139. return n, io.EOF
  140. }
  141. func PutLengthEncodedString(b []byte) []byte {
  142. data := make([]byte, 0, len(b)+9)
  143. data = append(data, PutLengthEncodedInt(uint64(len(b)))...)
  144. data = append(data, b...)
  145. return data
  146. }
  // Uint16ToBytes returns the 2-byte little-endian encoding of n.
  func Uint16ToBytes(n uint16) []byte {
  	out := make([]byte, 2)
  	binary.LittleEndian.PutUint16(out, n)
  	return out
  }
  // Uint32ToBytes returns the 4-byte little-endian encoding of n.
  func Uint32ToBytes(n uint32) []byte {
  	out := make([]byte, 4)
  	binary.LittleEndian.PutUint32(out, n)
  	return out
  }
  // Uint64ToBytes returns the 8-byte little-endian encoding of n.
  func Uint64ToBytes(n uint64) []byte {
  	out := make([]byte, 8)
  	binary.LittleEndian.PutUint64(out, n)
  	return out
  }
  173. func FormatBinaryDate(n int, data []byte) ([]byte, error) {
  174. switch n {
  175. case 0:
  176. return []byte("0000-00-00"), nil
  177. case 4:
  178. return []byte(fmt.Sprintf("%04d-%02d-%02d",
  179. binary.LittleEndian.Uint16(data[:2]),
  180. data[2],
  181. data[3])), nil
  182. default:
  183. return nil, errors.Errorf("invalid date packet length %d", n)
  184. }
  185. }
  186. func FormatBinaryDateTime(n int, data []byte) ([]byte, error) {
  187. switch n {
  188. case 0:
  189. return []byte("0000-00-00 00:00:00"), nil
  190. case 4:
  191. return []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
  192. binary.LittleEndian.Uint16(data[:2]),
  193. data[2],
  194. data[3])), nil
  195. case 7:
  196. return []byte(fmt.Sprintf(
  197. "%04d-%02d-%02d %02d:%02d:%02d",
  198. binary.LittleEndian.Uint16(data[:2]),
  199. data[2],
  200. data[3],
  201. data[4],
  202. data[5],
  203. data[6])), nil
  204. case 11:
  205. return []byte(fmt.Sprintf(
  206. "%04d-%02d-%02d %02d:%02d:%02d.%06d",
  207. binary.LittleEndian.Uint16(data[:2]),
  208. data[2],
  209. data[3],
  210. data[4],
  211. data[5],
  212. data[6],
  213. binary.LittleEndian.Uint32(data[7:11]))), nil
  214. default:
  215. return nil, errors.Errorf("invalid datetime packet length %d", n)
  216. }
  217. }
  218. func FormatBinaryTime(n int, data []byte) ([]byte, error) {
  219. if n == 0 {
  220. return []byte("0000-00-00"), nil
  221. }
  222. var sign byte
  223. if data[0] == 1 {
  224. sign = byte('-')
  225. }
  226. switch n {
  227. case 8:
  228. return []byte(fmt.Sprintf(
  229. "%c%02d:%02d:%02d",
  230. sign,
  231. uint16(data[1])*24+uint16(data[5]),
  232. data[6],
  233. data[7],
  234. )), nil
  235. case 12:
  236. return []byte(fmt.Sprintf(
  237. "%c%02d:%02d:%02d.%06d",
  238. sign,
  239. uint16(data[1])*24+uint16(data[5]),
  240. data[6],
  241. data[7],
  242. binary.LittleEndian.Uint32(data[8:12]),
  243. )), nil
  244. default:
  245. return nil, errors.Errorf("invalid time packet length %d", n)
  246. }
  247. }
  var (
  	// DONTESCAPE marks an EncodeMap entry whose byte is copied unchanged.
  	// 255 cannot collide with a real escape character: none of the
  	// replacement bytes in encodeRef is 255.
  	DONTESCAPE byte = 255
  
  	// EncodeMap maps each raw byte to the character written after a
  	// backslash when escaping it, or to DONTESCAPE. It is filled in by init.
  	EncodeMap [256]byte
  )
  252. // only support utf-8
  253. func Escape(sql string) string {
  254. dest := make([]byte, 0, 2*len(sql))
  255. for _, w := range hack.Slice(sql) {
  256. if c := EncodeMap[w]; c == DONTESCAPE {
  257. dest = append(dest, w)
  258. } else {
  259. dest = append(dest, '\\', c)
  260. }
  261. }
  262. return string(dest)
  263. }
  // GetNetProto chooses the network to dial for addr. It returns "unix" when
  // addr contains a '/', treating it as a Unix-domain socket path, and
  // "tcp" otherwise (including for an empty addr).
  func GetNetProto(addr string) string {
  	if strings.Contains(addr, "/") {
  		return "unix"
  	}
  	return "tcp"
  }
  271. // ErrorEqual returns a boolean indicating whether err1 is equal to err2.
  272. func ErrorEqual(err1, err2 error) bool {
  273. e1 := errors.Cause(err1)
  274. e2 := errors.Cause(err2)
  275. if e1 == e2 {
  276. return true
  277. }
  278. if e1 == nil || e2 == nil {
  279. return e1 == e2
  280. }
  281. return e1.Error() == e2.Error()
  282. }
  // encodeRef lists the bytes that Escape backslash-escapes. Each key is the
  // raw byte, and its value is the character written after the backslash.
  var encodeRef = map[byte]byte{
  	'\\':   '\\',
  	'\'':   '\'',
  	'"':    '"',
  	'\x00': '0',
  	'\b':   'b',
  	'\t':   't',
  	'\n':   'n',
  	'\r':   'r',
  	'\x1a': 'Z', // Ctrl-Z (decimal 26)
  }
  294. func init() {
  295. for i := range EncodeMap {
  296. EncodeMap[i] = DONTESCAPE
  297. }
  298. for i := range EncodeMap {
  299. if to, ok := encodeRef[byte(i)]; ok {
  300. EncodeMap[byte(i)] = to
  301. }
  302. }
  303. }