You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
354 lines
6.2 KiB
354 lines
6.2 KiB
package mysql |
|
|
|
import ( |
|
"crypto/rand" |
|
"crypto/sha1" |
|
"encoding/binary" |
|
"fmt" |
|
"io" |
|
"runtime" |
|
"strings" |
|
|
|
"github.com/juju/errors" |
|
"github.com/siddontang/go/hack" |
|
) |
|
|
|
// Pstack returns the stack trace of the calling goroutine as a string.
//
// The buffer starts at 1 KiB and is doubled until runtime.Stack reports that
// the whole trace fit (it returns n == len(buf) when it had to truncate), so
// deep stacks are no longer cut off at 1024 bytes.
func Pstack() string {
	buf := make([]byte, 1024)
	for {
		n := runtime.Stack(buf, false)
		if n < len(buf) {
			return string(buf[:n])
		}
		buf = make([]byte, 2*len(buf))
	}
}
|
|
|
// CalcPassword computes the scrambled password token:
//
//	stage1 = SHA1(password)
//	token  = SHA1(scramble + SHA1(stage1)) XOR stage1
//
// An empty password yields nil. The returned slice is freshly allocated;
// neither argument is modified.
func CalcPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	stage1 := sha1.Sum(password)
	stage2 := sha1.Sum(stage1[:])

	outer := sha1.New()
	outer.Write(scramble)
	outer.Write(stage2[:])
	token := outer.Sum(nil)

	for i, b := range stage1 {
		token[i] ^= b
	}
	return token
}
|
|
|
func RandomBuf(size int) ([]byte, error) { |
|
buf := make([]byte, size) |
|
|
|
if _, err := io.ReadFull(rand.Reader, buf); err != nil { |
|
return nil, errors.Trace(err) |
|
} |
|
|
|
// avoid to generate '\0' |
|
for i, b := range buf { |
|
if uint8(b) == 0 { |
|
buf[i] = '0' |
|
} |
|
} |
|
|
|
return buf, nil |
|
} |
|
|
|
// FixedLengthInt decodes buf as an unsigned little-endian integer (least
// significant byte first). Only the first eight bytes contribute; any bytes
// beyond them are shifted out of the 64-bit result.
func FixedLengthInt(buf []byte) uint64 {
	var v uint64
	for i := len(buf) - 1; i >= 0; i-- {
		v = v<<8 | uint64(buf[i])
	}
	return v
}
|
|
|
// BFixedLengthInt decodes buf as an unsigned big-endian integer (most
// significant byte first). When buf is longer than eight bytes, only the
// last eight determine the result; leading bytes are shifted out.
func BFixedLengthInt(buf []byte) uint64 {
	var v uint64
	for _, b := range buf {
		v = v<<8 | uint64(b)
	}
	return v
}
|
|
|
// LengthEncodedInt decodes a MySQL length-encoded integer from the start of b.
//
// It returns the decoded value, whether the encoding is the NULL marker, and
// the number of bytes consumed, chosen by the first byte:
//
//	0x00-0xfa  the byte itself is the value             (n = 1)
//	0xfb       NULL                                     (n = 1)
//	0xfc       2-byte little-endian value follows       (n = 3)
//	0xfd       3-byte little-endian value follows       (n = 4)
//	0xfe       8-byte little-endian value follows       (n = 9)
//
// Any other first byte (0xff) is returned as its own value with n = 1.
//
// If b is empty, or shorter than the header its first byte announces, the
// result is (0, true, 0) instead of an index-out-of-range panic. A successful
// decode always consumes at least one byte, so callers can detect incomplete
// input by n == 0.
func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
	if len(b) == 0 {
		return 0, true, 0
	}

	switch b[0] {
	case 0xfb:
		return 0, true, 1
	case 0xfc:
		n = 3
	case 0xfd:
		n = 4
	case 0xfe:
		n = 9
	default:
		return uint64(b[0]), false, 1
	}

	if len(b) < n {
		return 0, true, 0
	}

	// b[1:n] holds the value least-significant byte first.
	for i := n - 1; i > 0; i-- {
		num = num<<8 | uint64(b[i])
	}
	return num, false, n
}
|
|
|
// PutLengthEncodedInt encodes n as a MySQL length-encoded integer.
//
// Values up to 250 take a single byte. Larger values are written as a prefix
// byte followed by the value in little-endian order:
//
//	0xfc + 2 bytes, for n <= 0xffff
//	0xfd + 3 bytes, for n <= 0xffffff
//	0xfe + 8 bytes, for anything larger
//
// The single-byte form stops at 250, so the prefix bytes (and 0xfb) are never
// emitted as plain values.
func PutLengthEncodedInt(n uint64) []byte {
	switch {
	case n <= 250:
		return []byte{byte(n)}
	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}
	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	default:
		// Every remaining uint64 fits in 8 bytes; the original
		// `n <= 0xffffffffffffffff` case was always true and its
		// trailing `return nil` unreachable.
		return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
			byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
	}
}
|
|
|
func LengthEnodedString(b []byte) ([]byte, bool, int, error) { |
|
// Get length |
|
num, isNull, n := LengthEncodedInt(b) |
|
if num < 1 { |
|
return nil, isNull, n, nil |
|
} |
|
|
|
n += int(num) |
|
|
|
// Check data length |
|
if len(b) >= n { |
|
return b[n-int(num) : n], false, n, nil |
|
} |
|
return nil, false, n, io.EOF |
|
} |
|
|
|
func SkipLengthEnodedString(b []byte) (int, error) { |
|
// Get length |
|
num, _, n := LengthEncodedInt(b) |
|
if num < 1 { |
|
return n, nil |
|
} |
|
|
|
n += int(num) |
|
|
|
// Check data length |
|
if len(b) >= n { |
|
return n, nil |
|
} |
|
return n, io.EOF |
|
} |
|
|
|
func PutLengthEncodedString(b []byte) []byte { |
|
data := make([]byte, 0, len(b)+9) |
|
data = append(data, PutLengthEncodedInt(uint64(len(b)))...) |
|
data = append(data, b...) |
|
return data |
|
} |
|
|
|
// Uint16ToBytes returns n encoded as 2 little-endian bytes.
func Uint16ToBytes(n uint16) []byte {
	out := make([]byte, 2)
	binary.LittleEndian.PutUint16(out, n)
	return out
}
|
|
|
// Uint32ToBytes returns n encoded as 4 little-endian bytes.
func Uint32ToBytes(n uint32) []byte {
	out := make([]byte, 4)
	binary.LittleEndian.PutUint32(out, n)
	return out
}
|
|
|
// Uint64ToBytes returns n encoded as 8 little-endian bytes.
func Uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	binary.LittleEndian.PutUint64(out, n)
	return out
}
|
|
|
func FormatBinaryDate(n int, data []byte) ([]byte, error) { |
|
switch n { |
|
case 0: |
|
return []byte("0000-00-00"), nil |
|
case 4: |
|
return []byte(fmt.Sprintf("%04d-%02d-%02d", |
|
binary.LittleEndian.Uint16(data[:2]), |
|
data[2], |
|
data[3])), nil |
|
default: |
|
return nil, errors.Errorf("invalid date packet length %d", n) |
|
} |
|
} |
|
|
|
func FormatBinaryDateTime(n int, data []byte) ([]byte, error) { |
|
switch n { |
|
case 0: |
|
return []byte("0000-00-00 00:00:00"), nil |
|
case 4: |
|
return []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00", |
|
binary.LittleEndian.Uint16(data[:2]), |
|
data[2], |
|
data[3])), nil |
|
case 7: |
|
return []byte(fmt.Sprintf( |
|
"%04d-%02d-%02d %02d:%02d:%02d", |
|
binary.LittleEndian.Uint16(data[:2]), |
|
data[2], |
|
data[3], |
|
data[4], |
|
data[5], |
|
data[6])), nil |
|
case 11: |
|
return []byte(fmt.Sprintf( |
|
"%04d-%02d-%02d %02d:%02d:%02d.%06d", |
|
binary.LittleEndian.Uint16(data[:2]), |
|
data[2], |
|
data[3], |
|
data[4], |
|
data[5], |
|
data[6], |
|
binary.LittleEndian.Uint32(data[7:11]))), nil |
|
default: |
|
return nil, errors.Errorf("invalid datetime packet length %d", n) |
|
} |
|
} |
|
|
|
func FormatBinaryTime(n int, data []byte) ([]byte, error) { |
|
if n == 0 { |
|
return []byte("0000-00-00"), nil |
|
} |
|
|
|
var sign byte |
|
if data[0] == 1 { |
|
sign = byte('-') |
|
} |
|
|
|
switch n { |
|
case 8: |
|
return []byte(fmt.Sprintf( |
|
"%c%02d:%02d:%02d", |
|
sign, |
|
uint16(data[1])*24+uint16(data[5]), |
|
data[6], |
|
data[7], |
|
)), nil |
|
case 12: |
|
return []byte(fmt.Sprintf( |
|
"%c%02d:%02d:%02d.%06d", |
|
sign, |
|
uint16(data[1])*24+uint16(data[5]), |
|
data[6], |
|
data[7], |
|
binary.LittleEndian.Uint32(data[8:12]), |
|
)), nil |
|
default: |
|
return nil, errors.Errorf("invalid time packet length %d", n) |
|
} |
|
} |
|
|
|
var (
	// DONTESCAPE is the sentinel stored in EncodeMap for bytes that Escape
	// copies through unchanged.
	DONTESCAPE byte = 255

	// EncodeMap is indexed by input byte. An entry equal to DONTESCAPE means
	// the byte needs no escaping; any other entry is the character Escape
	// writes after a backslash. It is populated by init from encodeRef.
	EncodeMap [256]byte
)
|
|
|
// only support utf-8 |
|
func Escape(sql string) string { |
|
dest := make([]byte, 0, 2*len(sql)) |
|
|
|
for _, w := range hack.Slice(sql) { |
|
if c := EncodeMap[w]; c == DONTESCAPE { |
|
dest = append(dest, w) |
|
} else { |
|
dest = append(dest, '\\', c) |
|
} |
|
} |
|
|
|
return string(dest) |
|
} |
|
|
|
// GetNetProto returns the network name for addr: "unix" when addr contains a
// '/' (it is treated as a socket path), otherwise "tcp".
func GetNetProto(addr string) string {
	proto := "tcp"
	if strings.Contains(addr, "/") {
		proto = "unix"
	}
	return proto
}
|
|
|
// ErrorEqual returns a boolean indicating whether err1 is equal to err2. |
|
func ErrorEqual(err1, err2 error) bool { |
|
e1 := errors.Cause(err1) |
|
e2 := errors.Cause(err2) |
|
|
|
if e1 == e2 { |
|
return true |
|
} |
|
|
|
if e1 == nil || e2 == nil { |
|
return e1 == e2 |
|
} |
|
|
|
return e1.Error() == e2.Error() |
|
} |
|
|
|
// encodeRef lists the bytes Escape must escape, each mapped to the character
// written after the backslash (for example a newline becomes `\n`).
var encodeRef = map[byte]byte{
	0x00: '0',
	'\b': 'b',
	'\t': 't',
	'\n': 'n',
	'\r': 'r',
	0x1a: 'Z', // Ctrl-Z
	'"':  '"',
	'\'': '\'',
	'\\': '\\',
}
|
|
|
func init() { |
|
for i := range EncodeMap { |
|
EncodeMap[i] = DONTESCAPE |
|
} |
|
for i := range EncodeMap { |
|
if to, ok := encodeRef[byte(i)]; ok { |
|
EncodeMap[byte(i)] = to |
|
} |
|
} |
|
}
|
|
|