diff options
Diffstat (limited to 'src/lib/tlslite/TLSRecordLayer.py')
-rwxr-xr-x | src/lib/tlslite/TLSRecordLayer.py | 1131 |
1 files changed, 1131 insertions, 0 deletions
diff --git a/src/lib/tlslite/TLSRecordLayer.py b/src/lib/tlslite/TLSRecordLayer.py new file mode 100755 index 000000000..002a56862 --- /dev/null +++ b/src/lib/tlslite/TLSRecordLayer.py @@ -0,0 +1,1131 @@ +"""Helper class for TLSConnection.""" +from __future__ import generators + +from utils.compat import * +from utils.cryptomath import * +from utils.cipherfactory import createAES, createRC4, createTripleDES +from utils.codec import * +from errors import * +from messages import * +from mathtls import * +from constants import * +from utils.cryptomath import getRandomBytes +from utils import hmac +from FileObject import FileObject +import sha +import md5 +import socket +import errno +import traceback + +try: + GeneratorExit +except NameError: + class GeneratorExit(Exception): + pass + +class _ConnectionState: + def __init__(self): + self.macContext = None + self.encContext = None + self.seqnum = 0 + + def getSeqNumStr(self): + w = Writer(8) + w.add(self.seqnum, 8) + seqnumStr = bytesToString(w.bytes) + self.seqnum += 1 + return seqnumStr + + +class TLSRecordLayer: + """ + This class handles data transmission for a TLS connection. + + Its only subclass is L{tlslite.TLSConnection.TLSConnection}. We've + separated the code in this class from TLSConnection to make things + more readable. + + + @type sock: socket.socket + @ivar sock: The underlying socket object. + + @type session: L{tlslite.Session.Session} + @ivar session: The session corresponding to this connection. + + Due to TLS session resumption, multiple connections can correspond + to the same underlying session. + + @type version: tuple + @ivar version: The TLS version being used for this connection. + + (3,0) means SSL 3.0, and (3,1) means TLS 1.0. + + @type closed: bool + @ivar closed: If this connection is closed. + + @type resumed: bool + @ivar resumed: If this connection is based on a resumed session. + + @type allegedSharedKeyUsername: str or None + @ivar allegedSharedKeyUsername: This is set to the shared-key + username asserted by the client, whether the handshake succeeded or + not. If the handshake fails, this can be inspected to + determine if a guessing attack is in progress against a particular + user account. + + @type allegedSrpUsername: str or None + @ivar allegedSrpUsername: This is set to the SRP username + asserted by the client, whether the handshake succeeded or not. + If the handshake fails, this can be inspected to determine + if a guessing attack is in progress against a particular user + account. + + @type closeSocket: bool + @ivar closeSocket: If the socket should be closed when the + connection is closed (writable). + + If you set this to True, TLS Lite will assume the responsibility of + closing the socket when the TLS Connection is shutdown (either + through an error or through the user calling close()). The default + is False. + + @type ignoreAbruptClose: bool + @ivar ignoreAbruptClose: If an abrupt close of the socket should + raise an error (writable). + + If you set this to True, TLS Lite will not raise a + L{tlslite.errors.TLSAbruptCloseError} exception if the underlying + socket is unexpectedly closed. Such an unexpected closure could be + caused by an attacker. However, it also occurs with some incorrect + TLS implementations. + + You should set this to True only if you're not worried about an + attacker truncating the connection, and only if necessary to avoid + spurious errors. The default is False. + + @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync, + getCipherImplementation, getCipherName + """ + + def __init__(self, sock): + self.sock = sock + + #My session object (Session instance; read-only) + self.session = None + + #Am I a client or server? + self._client = None + + #Buffers for processing messages + self._handshakeBuffer = [] + self._readBuffer = "" + + #Handshake digests + self._handshake_md5 = md5.md5() + self._handshake_sha = sha.sha() + + #TLS Protocol Version + self.version = (0,0) #read-only + self._versionCheck = False #Once we choose a version, this is True + + #Current and Pending connection states + self._writeState = _ConnectionState() + self._readState = _ConnectionState() + self._pendingWriteState = _ConnectionState() + self._pendingReadState = _ConnectionState() + + #Is the connection open? + self.closed = True #read-only + self._refCount = 0 #Used to trigger closure + + #Is this a resumed (or shared-key) session? + self.resumed = False #read-only + + #What username did the client claim in his handshake? + self.allegedSharedKeyUsername = None + self.allegedSrpUsername = None + + #On a call to close(), do we close the socket? (writeable) + self.closeSocket = False + + #If the socket is abruptly closed, do we ignore it + #and pretend the connection was shut down properly? (writeable) + self.ignoreAbruptClose = False + + #Fault we will induce, for testing purposes + self.fault = None + + #********************************************************* + # Public Functions START + #********************************************************* + + def read(self, max=None, min=1): + """Read some data from the TLS connection. + + This function will block until at least 'min' bytes are + available (or the connection is closed). + + If an exception is raised, the connection will have been + automatically closed. + + @type max: int + @param max: The maximum number of bytes to return. + + @type min: int + @param min: The minimum number of bytes to return + + @rtype: str + @return: A string of no more than 'max' bytes, and no fewer + than 'min' (unless the connection has been closed, in which + case fewer than 'min' bytes may be returned). + + @raise socket.error: If a socket error occurs. + @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed + without a preceding alert. + @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. + """ + for result in self.readAsync(max, min): + pass + return result + + def readAsync(self, max=None, min=1): + """Start a read operation on the TLS connection. + + This function returns a generator which behaves similarly to + read(). Successive invocations of the generator will return 0 + if it is waiting to read from the socket, 1 if it is waiting + to write to the socket, or a string if the read operation has + completed. + + @rtype: iterable + @return: A generator; see above for details. + """ + try: + while len(self._readBuffer)<min and not self.closed: + try: + for result in self._getMsg(ContentType.application_data): + if result in (0,1): + yield result + applicationData = result + self._readBuffer += bytesToString(applicationData.write()) + except TLSRemoteAlert, alert: + if alert.description != AlertDescription.close_notify: + raise + except TLSAbruptCloseError: + if not self.ignoreAbruptClose: + raise + else: + self._shutdown(True) + + if max == None: + max = len(self._readBuffer) + + returnStr = self._readBuffer[:max] + self._readBuffer = self._readBuffer[max:] + yield returnStr + except GeneratorExit: + pass + except: + self._shutdown(False) + raise + + def write(self, s): + """Write some data to the TLS connection. + + This function will block until all the data has been sent. + + If an exception is raised, the connection will have been + automatically closed. + + @type s: str + @param s: The data to transmit to the other party. + + @raise socket.error: If a socket error occurs. + """ + for result in self.writeAsync(s): + pass + + def writeAsync(self, s): + """Start a write operation on the TLS connection. + + This function returns a generator which behaves similarly to + write(). Successive invocations of the generator will return + 1 if it is waiting to write to the socket, or will raise + StopIteration if the write operation has completed. + + @rtype: iterable + @return: A generator; see above for details. + """ + try: + if self.closed: + raise ValueError() + + index = 0 + blockSize = 16384 + skipEmptyFrag = False + while 1: + startIndex = index * blockSize + endIndex = startIndex + blockSize + if startIndex >= len(s): + break + if endIndex > len(s): + endIndex = len(s) + block = stringToBytes(s[startIndex : endIndex]) + applicationData = ApplicationData().create(block) + for result in self._sendMsg(applicationData, skipEmptyFrag): + yield result + skipEmptyFrag = True #only send an empy fragment on 1st message + index += 1 + except: + self._shutdown(False) + raise + + def close(self): + """Close the TLS connection. + + This function will block until it has exchanged close_notify + alerts with the other party. After doing so, it will shut down the + TLS connection. Further attempts to read through this connection + will return "". Further attempts to write through this connection + will raise ValueError. + + If makefile() has been called on this connection, the connection + will be not be closed until the connection object and all file + objects have been closed. + + Even if an exception is raised, the connection will have been + closed. + + @raise socket.error: If a socket error occurs. + @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed + without a preceding alert. + @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. + """ + if not self.closed: + for result in self._decrefAsync(): + pass + + def closeAsync(self): + """Start a close operation on the TLS connection. + + This function returns a generator which behaves similarly to + close(). Successive invocations of the generator will return 0 + if it is waiting to read from the socket, 1 if it is waiting + to write to the socket, or will raise StopIteration if the + close operation has completed. + + @rtype: iterable + @return: A generator; see above for details. + """ + if not self.closed: + for result in self._decrefAsync(): + yield result + + def _decrefAsync(self): + self._refCount -= 1 + if self._refCount == 0 and not self.closed: + try: + for result in self._sendMsg(Alert().create(\ + AlertDescription.close_notify, AlertLevel.warning)): + yield result + alert = None + while not alert: + for result in self._getMsg((ContentType.alert, \ + ContentType.application_data)): + if result in (0,1): + yield result + if result.contentType == ContentType.alert: + alert = result + if alert.description == AlertDescription.close_notify: + self._shutdown(True) + else: + raise TLSRemoteAlert(alert) + except (socket.error, TLSAbruptCloseError): + #If the other side closes the socket, that's okay + self._shutdown(True) + except: + self._shutdown(False) + raise + + def getCipherName(self): + """Get the name of the cipher used with this connection. + + @rtype: str + @return: The name of the cipher used with this connection. + Either 'aes128', 'aes256', 'rc4', or '3des'. + """ + if not self._writeState.encContext: + return None + return self._writeState.encContext.name + + def getCipherImplementation(self): + """Get the name of the cipher implementation used with + this connection. + + @rtype: str + @return: The name of the cipher implementation used with + this connection. Either 'python', 'cryptlib', 'openssl', + or 'pycrypto'. + """ + if not self._writeState.encContext: + return None + return self._writeState.encContext.implementation + + + + #Emulate a socket, somewhat - + def send(self, s): + """Send data to the TLS connection (socket emulation). + + @raise socket.error: If a socket error occurs. + """ + self.write(s) + return len(s) + + def sendall(self, s): + """Send data to the TLS connection (socket emulation). + + @raise socket.error: If a socket error occurs. + """ + self.write(s) + + def recv(self, bufsize): + """Get some data from the TLS connection (socket emulation). + + @raise socket.error: If a socket error occurs. + @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed + without a preceding alert. + @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. + """ + return self.read(bufsize) + + def makefile(self, mode='r', bufsize=-1): + """Create a file object for the TLS connection (socket emulation). + + @rtype: L{tlslite.FileObject.FileObject} + """ + self._refCount += 1 + return FileObject(self, mode, bufsize) + + def getsockname(self): + """Return the socket's own address (socket emulation).""" + return self.sock.getsockname() + + def getpeername(self): + """Return the remote address to which the socket is connected + (socket emulation).""" + return self.sock.getpeername() + + def settimeout(self, value): + """Set a timeout on blocking socket operations (socket emulation).""" + return self.sock.settimeout(value) + + def gettimeout(self): + """Return the timeout associated with socket operations (socket + emulation).""" + return self.sock.gettimeout() + + def setsockopt(self, level, optname, value): + """Set the value of the given socket option (socket emulation).""" + return self.sock.setsockopt(level, optname, value) + + + #********************************************************* + # Public Functions END + #********************************************************* + + def _shutdown(self, resumable): + self._writeState = _ConnectionState() + self._readState = _ConnectionState() + #Don't do this: self._readBuffer = "" + self.version = (0,0) + self._versionCheck = False + self.closed = True + if self.closeSocket: + self.sock.close() + + #Even if resumable is False, we'll never toggle this on + if not resumable and self.session: + self.session.resumable = False + + + def _sendError(self, alertDescription, errorStr=None): + alert = Alert().create(alertDescription, AlertLevel.fatal) + for result in self._sendMsg(alert): + yield result + self._shutdown(False) + raise TLSLocalAlert(alert, errorStr) + + def _sendMsgs(self, msgs): + skipEmptyFrag = False + for msg in msgs: + for result in self._sendMsg(msg, skipEmptyFrag): + yield result + skipEmptyFrag = True + + def _sendMsg(self, msg, skipEmptyFrag=False): + bytes = msg.write() + contentType = msg.contentType + + #Whenever we're connected and asked to send a message, + #we first send an empty Application Data message. This prevents + #an attacker from launching a chosen-plaintext attack based on + #knowing the next IV. + if not self.closed and not skipEmptyFrag and self.version == (3,1): + if self._writeState.encContext: + if self._writeState.encContext.isBlockCipher: + for result in self._sendMsg(ApplicationData(), + skipEmptyFrag=True): + yield result + + #Update handshake hashes + if contentType == ContentType.handshake: + bytesStr = bytesToString(bytes) + self._handshake_md5.update(bytesStr) + self._handshake_sha.update(bytesStr) + + #Calculate MAC + if self._writeState.macContext: + seqnumStr = self._writeState.getSeqNumStr() + bytesStr = bytesToString(bytes) + mac = self._writeState.macContext.copy() + mac.update(seqnumStr) + mac.update(chr(contentType)) + if self.version == (3,0): + mac.update( chr( int(len(bytes)/256) ) ) + mac.update( chr( int(len(bytes)%256) ) ) + elif self.version in ((3,1), (3,2)): + mac.update(chr(self.version[0])) + mac.update(chr(self.version[1])) + mac.update( chr( int(len(bytes)/256) ) ) + mac.update( chr( int(len(bytes)%256) ) ) + else: + raise AssertionError() + mac.update(bytesStr) + macString = mac.digest() + macBytes = stringToBytes(macString) + if self.fault == Fault.badMAC: + macBytes[0] = (macBytes[0]+1) % 256 + + #Encrypt for Block or Stream Cipher + if self._writeState.encContext: + #Add padding and encrypt (for Block Cipher): + if self._writeState.encContext.isBlockCipher: + + #Add TLS 1.1 fixed block + if self.version == (3,2): + bytes = self.fixedIVBlock + bytes + + #Add padding: bytes = bytes + (macBytes + paddingBytes) + currentLength = len(bytes) + len(macBytes) + 1 + blockLength = self._writeState.encContext.block_size + paddingLength = blockLength-(currentLength % blockLength) + + paddingBytes = createByteArraySequence([paddingLength] * \ + (paddingLength+1)) + if self.fault == Fault.badPadding: + paddingBytes[0] = (paddingBytes[0]+1) % 256 + endBytes = concatArrays(macBytes, paddingBytes) + bytes = concatArrays(bytes, endBytes) + #Encrypt + plaintext = stringToBytes(bytes) + ciphertext = self._writeState.encContext.encrypt(plaintext) + bytes = stringToBytes(ciphertext) + + #Encrypt (for Stream Cipher) + else: + bytes = concatArrays(bytes, macBytes) + plaintext = bytesToString(bytes) + ciphertext = self._writeState.encContext.encrypt(plaintext) + bytes = stringToBytes(ciphertext) + + #Add record header and send + r = RecordHeader3().create(self.version, contentType, len(bytes)) + s = bytesToString(concatArrays(r.write(), bytes)) + while 1: + try: + bytesSent = self.sock.send(s) #Might raise socket.error + except socket.error, why: + if why[0] == errno.EWOULDBLOCK: + yield 1 + continue + else: + raise + if bytesSent == len(s): + return + s = s[bytesSent:] + yield 1 + + + def _getMsg(self, expectedType, secondaryType=None, constructorType=None): + try: + if not isinstance(expectedType, tuple): + expectedType = (expectedType,) + + #Spin in a loop, until we've got a non-empty record of a type we + #expect. The loop will be repeated if: + # - we receive a renegotiation attempt; we send no_renegotiation, + # then try again + # - we receive an empty application-data fragment; we try again + while 1: + for result in self._getNextRecord(): + if result in (0,1): + yield result + recordHeader, p = result + + #If this is an empty application-data fragment, try again + if recordHeader.type == ContentType.application_data: + if p.index == len(p.bytes): + continue + + #If we received an unexpected record type... + if recordHeader.type not in expectedType: + + #If we received an alert... + if recordHeader.type == ContentType.alert: + alert = Alert().parse(p) + + #We either received a fatal error, a warning, or a + #close_notify. In any case, we're going to close the + #connection. In the latter two cases we respond with + #a close_notify, but ignore any socket errors, since + #the other side might have already closed the socket. + if alert.level == AlertLevel.warning or \ + alert.description == AlertDescription.close_notify: + + #If the sendMsg() call fails because the socket has + #already been closed, we will be forgiving and not + #report the error nor invalidate the "resumability" + #of the session. + try: + alertMsg = Alert() + alertMsg.create(AlertDescription.close_notify, + AlertLevel.warning) + for result in self._sendMsg(alertMsg): + yield result + except socket.error: + pass + + if alert.description == \ + AlertDescription.close_notify: + self._shutdown(True) + elif alert.level == AlertLevel.warning: + self._shutdown(False) + + else: #Fatal alert: + self._shutdown(False) + + #Raise the alert as an exception + raise TLSRemoteAlert(alert) + + #If we received a renegotiation attempt... + if recordHeader.type == ContentType.handshake: + subType = p.get(1) + reneg = False + if self._client: + if subType == HandshakeType.hello_request: + reneg = True + else: + if subType == HandshakeType.client_hello: + reneg = True + #Send no_renegotiation, then try again + if reneg: + alertMsg = Alert() + alertMsg.create(AlertDescription.no_renegotiation, + AlertLevel.warning) + for result in self._sendMsg(alertMsg): + yield result + continue + + #Otherwise: this is an unexpected record, but neither an + #alert nor renegotiation + for result in self._sendError(\ + AlertDescription.unexpected_message, + "received type=%d" % recordHeader.type): + yield result + + break + + #Parse based on content_type + if recordHeader.type == ContentType.change_cipher_spec: + yield ChangeCipherSpec().parse(p) + elif recordHeader.type == ContentType.alert: + yield Alert().parse(p) + elif recordHeader.type == ContentType.application_data: + yield ApplicationData().parse(p) + elif recordHeader.type == ContentType.handshake: + #Convert secondaryType to tuple, if it isn't already + if not isinstance(secondaryType, tuple): + secondaryType = (secondaryType,) + + #If it's a handshake message, check handshake header + if recordHeader.ssl2: + subType = p.get(1) + if subType != HandshakeType.client_hello: + for result in self._sendError(\ + AlertDescription.unexpected_message, + "Can only handle SSLv2 ClientHello messages"): + yield result + if HandshakeType.client_hello not in secondaryType: + for result in self._sendError(\ + AlertDescription.unexpected_message): + yield result + subType = HandshakeType.client_hello + else: + subType = p.get(1) + if subType not in secondaryType: + for result in self._sendError(\ + AlertDescription.unexpected_message, + "Expecting %s, got %s" % (str(secondaryType), subType)): + yield result + + #Update handshake hashes + sToHash = bytesToString(p.bytes) + self._handshake_md5.update(sToHash) + self._handshake_sha.update(sToHash) + + #Parse based on handshake type + if subType == HandshakeType.client_hello: + yield ClientHello(recordHeader.ssl2).parse(p) + elif subType == HandshakeType.server_hello: + yield ServerHello().parse(p) + elif subType == HandshakeType.certificate: + yield Certificate(constructorType).parse(p) + elif subType == HandshakeType.certificate_request: + yield CertificateRequest().parse(p) + elif subType == HandshakeType.certificate_verify: + yield CertificateVerify().parse(p) + elif subType == HandshakeType.server_key_exchange: + yield ServerKeyExchange(constructorType).parse(p) + elif subType == HandshakeType.server_hello_done: + yield ServerHelloDone().parse(p) + elif subType == HandshakeType.client_key_exchange: + yield ClientKeyExchange(constructorType, \ + self.version).parse(p) + elif subType == HandshakeType.finished: + yield Finished(self.version).parse(p) + else: + raise AssertionError() + + #If an exception was raised by a Parser or Message instance: + except SyntaxError, e: + for result in self._sendError(AlertDescription.decode_error, + formatExceptionTrace(e)): + yield result + + + #Returns next record or next handshake message + def _getNextRecord(self): + + #If there's a handshake message waiting, return it + if self._handshakeBuffer: + recordHeader, bytes = self._handshakeBuffer[0] + self._handshakeBuffer = self._handshakeBuffer[1:] + yield (recordHeader, Parser(bytes)) + return + + #Otherwise... + #Read the next record header + bytes = createByteArraySequence([]) + recordHeaderLength = 1 + ssl2 = False + while 1: + try: + s = self.sock.recv(recordHeaderLength-len(bytes)) + except socket.error, why: + if why[0] == errno.EWOULDBLOCK: + yield 0 + continue + else: + raise + + #If the connection was abruptly closed, raise an error + if len(s)==0: + raise TLSAbruptCloseError() + + bytes += stringToBytes(s) + if len(bytes)==1: + if bytes[0] in ContentType.all: + ssl2 = False + recordHeaderLength = 5 + elif bytes[0] == 128: + ssl2 = True + recordHeaderLength = 2 + else: + raise SyntaxError() + if len(bytes) == recordHeaderLength: + break + + #Parse the record header + if ssl2: + r = RecordHeader2().parse(Parser(bytes)) + else: + r = RecordHeader3().parse(Parser(bytes)) + + #Check the record header fields + if r.length > 18432: + for result in self._sendError(AlertDescription.record_overflow): + yield result + + #Read the record contents + bytes = createByteArraySequence([]) + while 1: + try: + s = self.sock.recv(r.length - len(bytes)) + except socket.error, why: + if why[0] == errno.EWOULDBLOCK: + yield 0 + continue + else: + raise + + #If the connection is closed, raise a socket error + if len(s)==0: + raise TLSAbruptCloseError() + + bytes += stringToBytes(s) + if len(bytes) == r.length: + break + + #Check the record header fields (2) + #We do this after reading the contents from the socket, so that + #if there's an error, we at least don't leave extra bytes in the + #socket.. + # + # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP. + # SO WE LEAVE IT OUT FOR NOW. + # + #if self._versionCheck and r.version != self.version: + # for result in self._sendError(AlertDescription.protocol_version, + # "Version in header field: %s, should be %s" % (str(r.version), + # str(self.version))): + # yield result + + #Decrypt the record + for result in self._decryptRecord(r.type, bytes): + if result in (0,1): + yield result + else: + break + bytes = result + p = Parser(bytes) + + #If it doesn't contain handshake messages, we can just return it + if r.type != ContentType.handshake: + yield (r, p) + #If it's an SSLv2 ClientHello, we can return it as well + elif r.ssl2: + yield (r, p) + else: + #Otherwise, we loop through and add the handshake messages to the + #handshake buffer + while 1: + if p.index == len(bytes): #If we're at the end + if not self._handshakeBuffer: + for result in self._sendError(\ + AlertDescription.decode_error, \ + "Received empty handshake record"): + yield result + break + #There needs to be at least 4 bytes to get a header + if p.index+4 > len(bytes): + for result in self._sendError(\ + AlertDescription.decode_error, + "A record has a partial handshake message (1)"): + yield result + p.get(1) # skip handshake type + msgLength = p.get(3) + if p.index+msgLength > len(bytes): + for result in self._sendError(\ + AlertDescription.decode_error, + "A record has a partial handshake message (2)"): + yield result + + handshakePair = (r, bytes[p.index-4 : p.index+msgLength]) + self._handshakeBuffer.append(handshakePair) + p.index += msgLength + + #We've moved at least one handshake message into the + #handshakeBuffer, return the first one + recordHeader, bytes = self._handshakeBuffer[0] + self._handshakeBuffer = self._handshakeBuffer[1:] + yield (recordHeader, Parser(bytes)) + + + def _decryptRecord(self, recordType, bytes): + if self._readState.encContext: + + #Decrypt if it's a block cipher + if self._readState.encContext.isBlockCipher: + blockLength = self._readState.encContext.block_size + if len(bytes) % blockLength != 0: + for result in self._sendError(\ + AlertDescription.decryption_failed, + "Encrypted data not a multiple of blocksize"): + yield result + ciphertext = bytesToString(bytes) + plaintext = self._readState.encContext.decrypt(ciphertext) + if self.version == (3,2): #For TLS 1.1, remove explicit IV + plaintext = plaintext[self._readState.encContext.block_size : ] + bytes = stringToBytes(plaintext) + + #Check padding + paddingGood = True + paddingLength = bytes[-1] + if (paddingLength+1) > len(bytes): + paddingGood=False + totalPaddingLength = 0 + else: + if self.version == (3,0): + totalPaddingLength = paddingLength+1 + elif self.version in ((3,1), (3,2)): + totalPaddingLength = paddingLength+1 + paddingBytes = bytes[-totalPaddingLength:-1] + for byte in paddingBytes: + if byte != paddingLength: + paddingGood = False + totalPaddingLength = 0 + else: + raise AssertionError() + + #Decrypt if it's a stream cipher + else: + paddingGood = True + ciphertext = bytesToString(bytes) + plaintext = self._readState.encContext.decrypt(ciphertext) + bytes = stringToBytes(plaintext) + totalPaddingLength = 0 + + #Check MAC + macGood = True + macLength = self._readState.macContext.digest_size + endLength = macLength + totalPaddingLength + if endLength > len(bytes): + macGood = False + else: + #Read MAC + startIndex = len(bytes) - endLength + endIndex = startIndex + macLength + checkBytes = bytes[startIndex : endIndex] + + #Calculate MAC + seqnumStr = self._readState.getSeqNumStr() + bytes = bytes[:-endLength] + bytesStr = bytesToString(bytes) + mac = self._readState.macContext.copy() + mac.update(seqnumStr) + mac.update(chr(recordType)) + if self.version == (3,0): + mac.update( chr( int(len(bytes)/256) ) ) + mac.update( chr( int(len(bytes)%256) ) ) + elif self.version in ((3,1), (3,2)): + mac.update(chr(self.version[0])) + mac.update(chr(self.version[1])) + mac.update( chr( int(len(bytes)/256) ) ) + mac.update( chr( int(len(bytes)%256) ) ) + else: + raise AssertionError() + mac.update(bytesStr) + macString = mac.digest() + macBytes = stringToBytes(macString) + + #Compare MACs + if macBytes != checkBytes: + macGood = False + + if not (paddingGood and macGood): + for result in self._sendError(AlertDescription.bad_record_mac, + "MAC failure (or padding failure)"): + yield result + + yield bytes + + def _handshakeStart(self, client): + self._client = client + self._handshake_md5 = md5.md5() + self._handshake_sha = sha.sha() + self._handshakeBuffer = [] + self.allegedSharedKeyUsername = None + self.allegedSrpUsername = None + self._refCount = 1 + + def _handshakeDone(self, resumed): + self.resumed = resumed + self.closed = False + + def _calcPendingStates(self, clientRandom, serverRandom, implementations): + if self.session.cipherSuite in CipherSuite.aes128Suites: + macLength = 20 + keyLength = 16 + ivLength = 16 + createCipherFunc = createAES + elif self.session.cipherSuite in CipherSuite.aes256Suites: + macLength = 20 + keyLength = 32 + ivLength = 16 + createCipherFunc = createAES + elif self.session.cipherSuite in CipherSuite.rc4Suites: + macLength = 20 + keyLength = 16 + ivLength = 0 + createCipherFunc = createRC4 + elif self.session.cipherSuite in CipherSuite.tripleDESSuites: + macLength = 20 + keyLength = 24 + ivLength = 8 + createCipherFunc = createTripleDES + else: + raise AssertionError() + + if self.version == (3,0): + createMACFunc = MAC_SSL + elif self.version in ((3,1), (3,2)): + createMACFunc = hmac.HMAC + + outputLength = (macLength*2) + (keyLength*2) + (ivLength*2) + + #Calculate Keying Material from Master Secret + if self.version == (3,0): + keyBlock = PRF_SSL(self.session.masterSecret, + concatArrays(serverRandom, clientRandom), + outputLength) + elif self.version in ((3,1), (3,2)): + keyBlock = PRF(self.session.masterSecret, + "key expansion", + concatArrays(serverRandom,clientRandom), + outputLength) + else: + raise AssertionError() + + #Slice up Keying Material + clientPendingState = _ConnectionState() + serverPendingState = _ConnectionState() + p = Parser(keyBlock) + clientMACBlock = bytesToString(p.getFixBytes(macLength)) + serverMACBlock = bytesToString(p.getFixBytes(macLength)) + clientKeyBlock = bytesToString(p.getFixBytes(keyLength)) + serverKeyBlock = bytesToString(p.getFixBytes(keyLength)) + clientIVBlock = bytesToString(p.getFixBytes(ivLength)) + serverIVBlock = bytesToString(p.getFixBytes(ivLength)) + clientPendingState.macContext = createMACFunc(clientMACBlock, + digestmod=sha) + serverPendingState.macContext = createMACFunc(serverMACBlock, + digestmod=sha) + clientPendingState.encContext = createCipherFunc(clientKeyBlock, + clientIVBlock, + implementations) + serverPendingState.encContext = createCipherFunc(serverKeyBlock, + serverIVBlock, + implementations) + + #Assign new connection states to pending states + if self._client: + self._pendingWriteState = clientPendingState + self._pendingReadState = serverPendingState + else: + self._pendingWriteState = serverPendingState + self._pendingReadState = clientPendingState + + if self.version == (3,2) and ivLength: + #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC + #residue to create the IV for each sent block) + self.fixedIVBlock = getRandomBytes(ivLength) + + def _changeWriteState(self): + self._writeState = self._pendingWriteState + self._pendingWriteState = _ConnectionState() + + def _changeReadState(self): + self._readState = self._pendingReadState + self._pendingReadState = _ConnectionState() + + def _sendFinished(self): + #Send ChangeCipherSpec + for result in self._sendMsg(ChangeCipherSpec()): + yield result + + #Switch to pending write state + self._changeWriteState() + + #Calculate verification data + verifyData = self._calcFinished(True) + if self.fault == Fault.badFinished: + verifyData[0] = (verifyData[0]+1)%256 + + #Send Finished message under new state + finished = Finished(self.version).create(verifyData) + for result in self._sendMsg(finished): + yield result + + def _getFinished(self): + #Get and check ChangeCipherSpec + for result in self._getMsg(ContentType.change_cipher_spec): + if result in (0,1): + yield result + changeCipherSpec = result + + if changeCipherSpec.type != 1: + for result in self._sendError(AlertDescription.illegal_parameter, + "ChangeCipherSpec type incorrect"): + yield result + + #Switch to pending read state + self._changeReadState() + + #Calculate verification data + verifyData = self._calcFinished(False) + + #Get and check Finished message under new state + for result in self._getMsg(ContentType.handshake, + HandshakeType.finished): + if result in (0,1): + yield result + finished = result + if finished.verify_data != verifyData: + for result in self._sendError(AlertDescription.decrypt_error, + "Finished message is incorrect"): + yield result + + def _calcFinished(self, send=True): + if self.version == (3,0): + if (self._client and send) or (not self._client and not send): + senderStr = "\x43\x4C\x4E\x54" + else: + senderStr = "\x53\x52\x56\x52" + + verifyData = self._calcSSLHandshakeHash(self.session.masterSecret, + senderStr) + return verifyData + + elif self.version in ((3,1), (3,2)): + if (self._client and send) or (not self._client and not send): + label = "client finished" + else: + label = "server finished" + + handshakeHashes = stringToBytes(self._handshake_md5.digest() + \ + self._handshake_sha.digest()) + verifyData = PRF(self.session.masterSecret, label, handshakeHashes, + 12) + return verifyData + else: + raise AssertionError() + + #Used for Finished messages and CertificateVerify messages in SSL v3 + def _calcSSLHandshakeHash(self, masterSecret, label): + masterSecretStr = bytesToString(masterSecret) + + imac_md5 = self._handshake_md5.copy() + imac_sha = self._handshake_sha.copy() + + imac_md5.update(label + masterSecretStr + '\x36'*48) + imac_sha.update(label + masterSecretStr + '\x36'*40) + + md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \ + imac_md5.digest()).digest() + shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \ + imac_sha.digest()).digest() + + return stringToBytes(md5Str + shaStr) + |