You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
261 lines
5.4 KiB
261 lines
5.4 KiB
package tcp |
|
|
|
import ( |
|
"bufio" |
|
"errors" |
|
"fmt" |
|
"io" |
|
"net" |
|
"strconv" |
|
"strings" |
|
"time" |
|
) |
|
|
|
type conn struct { |
|
// conn |
|
conn net.Conn |
|
// Read |
|
readTimeout time.Duration |
|
br *bufio.Reader |
|
// Write |
|
writeTimeout time.Duration |
|
bw *bufio.Writer |
|
// Scratch space for formatting argument length. |
|
// '*' or '$', length, "\r\n" |
|
lenScratch [32]byte |
|
// Scratch space for formatting integers and floats. |
|
numScratch [40]byte |
|
} |
|
|
|
// newConn returns a new connection for the given net connection. |
|
func newConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) *conn { |
|
return &conn{ |
|
conn: netConn, |
|
readTimeout: readTimeout, |
|
writeTimeout: writeTimeout, |
|
br: bufio.NewReaderSize(netConn, _readBufSize), |
|
bw: bufio.NewWriterSize(netConn, _writeBufSize), |
|
} |
|
} |
|
|
|
// Read read data from connection |
|
func (c *conn) Read() (cmd string, args [][]byte, err error) { |
|
var ( |
|
ln, cn int |
|
bs []byte |
|
) |
|
if c.readTimeout > 0 { |
|
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) |
|
} |
|
// start read |
|
if bs, err = c.readLine(); err != nil { |
|
return |
|
} |
|
if len(bs) < 2 { |
|
err = fmt.Errorf("read error data(%s) from connection", bs) |
|
return |
|
} |
|
// maybe a cmd that without any params is received,such as: QUIT |
|
if strings.ToLower(string(bs)) == _quit { |
|
cmd = _quit |
|
return |
|
} |
|
// get param number |
|
if ln, err = parseLen(bs[1:]); err != nil { |
|
return |
|
} |
|
args = make([][]byte, 0, ln-1) |
|
for i := 0; i < ln; i++ { |
|
if cn, err = c.readLen(_protoBulk); err != nil { |
|
return |
|
} |
|
if bs, err = c.readData(cn); err != nil { |
|
return |
|
} |
|
if i == 0 { |
|
cmd = strings.ToLower(string(bs)) |
|
continue |
|
} |
|
args = append(args, bs) |
|
} |
|
return |
|
} |
|
|
|
// WriteError write error to connection and close connection |
|
func (c *conn) WriteError(err error) { |
|
if c.writeTimeout > 0 { |
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
|
} |
|
if err = c.Write(proto{prefix: _protoErr, message: err.Error()}); err != nil { |
|
c.Close() |
|
return |
|
} |
|
c.Flush() |
|
c.Close() |
|
} |
|
|
|
// Write write data to connection |
|
func (c *conn) Write(p proto) (err error) { |
|
if c.writeTimeout > 0 { |
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
|
} |
|
// start write |
|
switch p.prefix { |
|
case _protoStr: |
|
err = c.writeStatus(p.message) |
|
case _protoErr: |
|
err = c.writeError(p.message) |
|
case _protoInt: |
|
err = c.writeInt64(int64(p.integer)) |
|
case _protoBulk: |
|
// c.writeString(p.message) |
|
err = c.writeBytes([]byte(p.message)) |
|
case _protoArray: |
|
err = c.writeLen(p.prefix, p.integer) |
|
} |
|
return |
|
} |
|
|
|
// Flush flush connection |
|
func (c *conn) Flush() error { |
|
if c.writeTimeout > 0 { |
|
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) |
|
} |
|
return c.bw.Flush() |
|
} |
|
|
|
// Close close connection |
|
func (c *conn) Close() error { |
|
return c.conn.Close() |
|
} |
|
|
|
// parseLen parses bulk string and array lengths. |
|
func parseLen(p []byte) (int, error) { |
|
if len(p) == 0 { |
|
return -1, errors.New("malformed length") |
|
} |
|
if p[0] == '-' && len(p) == 2 && p[1] == '1' { |
|
// handle $-1 and $-1 null replies. |
|
return -1, nil |
|
} |
|
var n int |
|
for _, b := range p { |
|
n *= 10 |
|
if b < '0' || b > '9' { |
|
return -1, errors.New("illegal bytes in length") |
|
} |
|
n += int(b - '0') |
|
} |
|
return n, nil |
|
} |
|
|
|
func (c *conn) readLine() ([]byte, error) { |
|
p, err := c.br.ReadBytes('\n') |
|
if err == bufio.ErrBufferFull { |
|
return nil, errors.New("long response line") |
|
} |
|
if err != nil { |
|
return nil, err |
|
} |
|
i := len(p) - 2 |
|
if i < 0 || p[i] != '\r' { |
|
return nil, errors.New("bad response line terminator") |
|
} |
|
return p[:i], nil |
|
} |
|
|
|
func (c *conn) readLen(prefix byte) (int, error) { |
|
ls, err := c.readLine() |
|
if err != nil { |
|
return 0, err |
|
} |
|
if len(ls) < 2 { |
|
return 0, errors.New("illegal bytes in length") |
|
} |
|
if ls[0] != prefix { |
|
return 0, errors.New("illegal bytes in length") |
|
} |
|
return parseLen(ls[1:]) |
|
} |
|
|
|
func (c *conn) readData(n int) ([]byte, error) { |
|
if n > _maxValueSize { |
|
return nil, errors.New("exceeding max value limit") |
|
} |
|
buf := make([]byte, n+2) |
|
r, err := io.ReadFull(c.br, buf) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if n != r-2 { |
|
return nil, errors.New("invalid bytes in len") |
|
} |
|
return buf[:n], err |
|
} |
|
|
|
func (c *conn) writeLen(prefix byte, n int) error { |
|
c.lenScratch[len(c.lenScratch)-1] = '\n' |
|
c.lenScratch[len(c.lenScratch)-2] = '\r' |
|
i := len(c.lenScratch) - 3 |
|
for { |
|
c.lenScratch[i] = byte('0' + n%10) |
|
i-- |
|
n = n / 10 |
|
if n == 0 { |
|
break |
|
} |
|
} |
|
c.lenScratch[i] = prefix |
|
_, err := c.bw.Write(c.lenScratch[i:]) |
|
return err |
|
} |
|
|
|
func (c *conn) writeStatus(s string) (err error) { |
|
c.bw.WriteByte(_protoStr) |
|
c.bw.WriteString(s) |
|
_, err = c.bw.WriteString("\r\n") |
|
return |
|
} |
|
|
|
func (c *conn) writeError(s string) (err error) { |
|
c.bw.WriteByte(_protoErr) |
|
c.bw.WriteString(s) |
|
_, err = c.bw.WriteString("\r\n") |
|
return |
|
} |
|
|
|
func (c *conn) writeInt64(n int64) (err error) { |
|
c.bw.WriteByte(_protoInt) |
|
c.bw.Write(strconv.AppendInt(c.numScratch[:0], n, 10)) |
|
_, err = c.bw.WriteString("\r\n") |
|
return |
|
} |
|
|
|
func (c *conn) writeString(s string) (err error) { |
|
c.writeLen(_protoBulk, len(s)) |
|
c.bw.WriteString(s) |
|
_, err = c.bw.WriteString("\r\n") |
|
return |
|
} |
|
|
|
func (c *conn) writeBytes(s []byte) (err error) { |
|
if len(s) == 0 { |
|
c.bw.WriteByte('$') |
|
c.bw.Write(_nullBulk) |
|
} else { |
|
c.writeLen(_protoBulk, len(s)) |
|
c.bw.Write(s) |
|
} |
|
_, err = c.bw.WriteString("\r\n") |
|
return |
|
} |
|
|
|
func (c *conn) writeStrings(ss []string) (err error) { |
|
c.writeLen(_protoArray, len(ss)) |
|
for _, s := range ss { |
|
if err = c.writeString(s); err != nil { |
|
return |
|
} |
|
} |
|
return |
|
}
|
|
|