diff options
author | Christopher Speller <crspeller@gmail.com> | 2016-09-23 10:17:51 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-23 10:17:51 -0400 |
commit | 2ca0e8f9a0f9863555a26e984cde15efff9ef8f8 (patch) | |
tree | daae1ee67b14a3d0a84424f2a304885d9e75ce2b /vendor/github.com/go-sql-driver/mysql/packets.go | |
parent | 6d62d65b2dc85855aabea036cbd44f6059e19d13 (diff) | |
download | chat-2ca0e8f9a0f9863555a26e984cde15efff9ef8f8.tar.gz chat-2ca0e8f9a0f9863555a26e984cde15efff9ef8f8.tar.bz2 chat-2ca0e8f9a0f9863555a26e984cde15efff9ef8f8.zip |
Updating golang dependancies (#4075)
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go')
-rw-r--r-- | vendor/github.com/go-sql-driver/mysql/packets.go | 337 |
1 files changed, 125 insertions, 212 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index 8d9166578..618098146 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -13,7 +13,6 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" - "errors" "fmt" "io" "math" @@ -48,8 +47,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul + } else { + return nil, ErrPktSync } - return nil, ErrPktSync } mc.sequence++ @@ -100,12 +100,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - return err - } - } - n, err := mc.netConn.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ @@ -146,7 +140,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // protocol version [1 byte] if data[0] < minProtocolVersion { return nil, fmt.Errorf( - "unsupported protocol version %d. Version %d or higher is required", + "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", data[0], minProtocolVersion, ) @@ -202,11 +196,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // return //} //return ErrMalformPkt - - // make a memory safe copy of the cipher slice - var b [20]byte - copy(b[:], cipher) - return b[:], nil + return cipher, nil } // make a memory safe copy of the cipher slice @@ -224,11 +214,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientLongPassword | clientTransactions | clientLocalFiles | - clientPluginAuth | - clientMultiResults | mc.flags&clientLongFlag - if mc.cfg.ClientFoundRows { + if mc.cfg.clientFoundRows { clientFlags |= clientFoundRows } @@ -237,17 +225,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientSSL } - if mc.cfg.MultiStatements { - clientFlags |= clientMultiStatements - } - // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd)) - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { + if n := len(mc.cfg.dbname); n > 0 { clientFlags |= clientConnectWithDB pktLen += n + 1 } @@ -273,14 +257,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[11] = 0x00 // Charset [1 byte] - var found bool - data[12], found = collations[mc.cfg.Collation] - if !found { - // Note possibility for false negatives: - // could be triggered although the collation is valid if the - // collations map does not contain entries the server supports. - return errors.New("unknown collation") - } + data[12] = mc.cfg.collation // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -296,18 +273,15 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn + mc.buf.rd = tlsConn } // Filler [23 bytes] (all 0x00) - pos := 13 - for ; pos < 13+23; pos++ { - data[pos] = 0 - } + pos := 13 + 23 // User [null terminated string] - if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + if len(mc.cfg.user) > 0 { + pos += copy(data[pos:], mc.cfg.user) } data[pos] = 0x00 pos++ @@ -317,16 +291,11 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pos += 1 + copy(data[pos+1:], scrambleBuff) // Databasename [null terminated string] - if len(mc.cfg.DBName) > 0 { - pos += copy(data[pos:], mc.cfg.DBName) + if len(mc.cfg.dbname) > 0 { + pos += copy(data[pos:], mc.cfg.dbname) data[pos] = 0x00 - pos++ } - // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") - data[pos] = 0x00 - // Send Auth packet return mc.writePacket(data) } @@ -335,9 +304,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) + scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) - // Calculate the packet length and add a tailing 0 + // Calculate the packet lenght and add a tailing 0 pktLen := len(scrambleBuff) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) if data == nil { @@ -353,25 +322,6 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { return mc.writePacket(data) } -// Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { - // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn - } - - // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.Passwd) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - /****************************************************************************** * Command Packets * ******************************************************************************/ @@ -455,20 +405,8 @@ func (mc *mysqlConn) readResultOK() error { return mc.handleOkPacket(data) case iEOF: - if len(data) > 1 { - plugin := string(data[1:bytes.IndexByte(data, 0x00)]) - if plugin == "mysql_old_password" { - // using old_passwords - return ErrOldPassword - } else if plugin == "mysql_clear_password" { - // using clear text password - return ErrCleartextPassword - } else { - return ErrUnknownPlugin - } - } else { - return ErrOldPassword - } + // someone is using old_passwords + return ErrOldPassword default: // Error otherwise return mc.handleErrorPacket(data) @@ -532,10 +470,6 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } } -func readStatus(b []byte) statusFlag { - return statusFlag(b[0]) | statusFlag(b[1])<<8 -} - // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { @@ -550,21 +484,17 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] - mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if err := mc.discardResults(); err != nil { - return err - } // warning count [2 bytes] if !mc.strict { return nil + } else { + pos := 1 + n + m + 2 + if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { + return mc.getWarnings() + } + return nil } - - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } - return nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -583,7 +513,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if i == count { return columns, nil } - return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) + return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) } // Catalog @@ -600,20 +530,11 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { pos += n // Table [len coded string] - if mc.cfg.ColumnsWithAlias { - tableName, _, n, err := readLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n - columns[i].tableName = string(tableName) - } else { - n, err = skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err } + pos += n // Original table [len coded string] n, err = skipLengthEncodedString(data[pos:]) @@ -636,21 +557,20 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { return nil, err } - // Filler [uint8] - // Charset [charset, collation uint8] - // Length [uint32] + // Filler [1 byte] + // Charset [16 bit uint] + // Length [32 bit uint] pos += n + 1 + 2 + 4 - // Field type [uint8] + // Field type [byte] columns[i].fieldType = data[pos] pos++ - // Flags [uint16] + // Flags [16 bit uint] columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - pos += 2 + //pos += 2 - // Decimals [uint8] - columns[i].decimals = data[pos] + // Decimals [8 bit uint] //pos++ // Default value [len coded binary] @@ -672,18 +592,8 @@ func (rows *textRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err - } - rows.mc = nil return io.EOF } - if data[0] == iERR { - rows.mc = nil - return mc.handleErrorPacket(data) - } // RowSet Packet var n int @@ -704,7 +614,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( string(dest[i].([]byte)), - mc.cfg.Loc, + mc.cfg.loc, ) if err == nil { continue @@ -734,10 +644,6 @@ func (mc *mysqlConn) readUntilEOF() error { if err == nil && data[0] != iEOF { continue } - if err == nil && data[0] == iEOF && len(data) == 5 { - mc.status = readStatus(data[3:]) - } - return err // Err or EOF } } @@ -770,13 +676,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Warning count [16 bit uint] if !stmt.mc.strict { return columnCount, nil + } else { + // Check for warnings count > 0, only available in MySQL > 4.1 + if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { + return columnCount, stmt.mc.getWarnings() + } + return columnCount, nil } - - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } - return columnCount, nil } return 0, err } @@ -838,7 +744,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( - "argument count mismatch (got: %d; has: %d)", + "Arguments count mismatch (Got: %d Has: %d)", len(args), stmt.paramCount, ) @@ -1015,7 +921,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { val = []byte("0000-00-00") } else { - val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) + val = []byte(v.In(mc.cfg.loc).Format(timeFormat)) } paramValues = appendLengthEncodedInteger(paramValues, @@ -1024,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramValues = append(paramValues, val...) default: - return fmt.Errorf("can not convert type: %T", arg) + return fmt.Errorf("Can't convert type: %T", arg) } } @@ -1042,28 +948,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return mc.writePacket(data) } -func (mc *mysqlConn) discardResults() error { - for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() - if err != nil { - return err - } - if resLen > 0 { - // columns - if err := mc.readUntilEOF(); err != nil { - return err - } - // rows - if err := mc.readUntilEOF(); err != nil { - return err - } - } else { - mc.status &^= statusMoreResultsExists - } - } - return nil -} - // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() @@ -1075,14 +959,8 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { if data[0] != iOK { // EOF Packet if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err - } - rows.mc = nil return io.EOF } - rows.mc = nil // Error otherwise return rows.mc.handleErrorPacket(data) @@ -1149,7 +1027,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) pos += 4 continue @@ -1162,7 +1040,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, - fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: + fieldTypeVarString, fieldTypeString, fieldTypeGeometry: var isNull bool var n int dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) @@ -1177,53 +1055,88 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } return err - case - fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD - fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] - fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] - + // Date YYYY-MM-DD + case fieldTypeDate, fieldTypeNewDate: num, isNull, n := readLengthEncodedInteger(data[pos:]) pos += n - switch { - case isNull: + if isNull { dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: - // database/sql does not support an equivalent to TIME, return a string - var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { - case 0x00, 0x1f: - dstlen = 8 - case 1, 2, 3, 4, 5, 6: - dstlen = 8 + 1 + decimals - default: - return fmt.Errorf( - "protocol error, illegal decimals value %d", - rows.columns[i].decimals, - ) - } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) - case rows.mc.parseTime: - dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) - default: - var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { - dstlen = 10 + } + + if rows.mc.parseTime { + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) + } else { + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], false) + } + + if err == nil { + pos += int(num) + continue + } else { + return err + } + + // Time [-][H]HH:MM:SS[.fractal] + case fieldTypeTime: + num, isNull, n := readLengthEncodedInteger(data[pos:]) + pos += n + + if num == 0 { + if isNull { + dest[i] = nil + continue } else { - switch decimals := rows.columns[i].decimals; decimals { - case 0x00, 0x1f: - dstlen = 19 - case 1, 2, 3, 4, 5, 6: - dstlen = 19 + 1 + decimals - default: - return fmt.Errorf( - "protocol error, illegal decimals value %d", - rows.columns[i].decimals, - ) - } + dest[i] = []byte("00:00:00") + continue } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + } + + var sign string + if data[pos] == 1 { + sign = "-" + } + + switch num { + case 8: + dest[i] = []byte(fmt.Sprintf( + sign+"%02d:%02d:%02d", + uint16(data[pos+1])*24+uint16(data[pos+5]), + data[pos+6], + data[pos+7], + )) + pos += 8 + continue + case 12: + dest[i] = []byte(fmt.Sprintf( + sign+"%02d:%02d:%02d.%06d", + uint16(data[pos+1])*24+uint16(data[pos+5]), + data[pos+6], + data[pos+7], + binary.LittleEndian.Uint32(data[pos+8:pos+12]), + )) + pos += 12 + continue + default: + return fmt.Errorf("Invalid TIME-packet length %d", num) + } + + // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] + case fieldTypeTimestamp, fieldTypeDateTime: + num, isNull, n := readLengthEncodedInteger(data[pos:]) + + pos += n + + if isNull { + dest[i] = nil + continue + } + + if rows.mc.parseTime { + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) + } else { + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], true) } if err == nil { @@ -1235,7 +1148,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType) } } |