Add support for stream cipher subclasses in SymmetricCipherStream

This commit is contained in:
angelsl 2017-11-12 21:02:23 +08:00 committed by Jonathan White
parent 23347b392f
commit 656836950e
No known key found for this signature in database
GPG Key ID: 440FC65F2E0C6E01
2 changed files with 47 additions and 39 deletions

View File

@ -24,7 +24,7 @@ SymmetricCipherStream::SymmetricCipherStream(QIODevice* baseDevice, SymmetricCip
, m_bufferPos(0) , m_bufferPos(0)
, m_bufferFilling(false) , m_bufferFilling(false)
, m_error(false) , m_error(false)
, m_isInitalized(false) , m_isInitialized(false)
, m_dataWritten(false) , m_dataWritten(false)
{ {
} }
@ -36,12 +36,12 @@ SymmetricCipherStream::~SymmetricCipherStream()
bool SymmetricCipherStream::init(const QByteArray& key, const QByteArray& iv) bool SymmetricCipherStream::init(const QByteArray& key, const QByteArray& iv)
{ {
m_isInitalized = m_cipher->init(key, iv); m_isInitialized = m_cipher->init(key, iv);
if (!m_isInitalized) { if (!m_isInitialized) {
setErrorString(m_cipher->errorString()); setErrorString(m_cipher->errorString());
} }
m_streamCipher = m_cipher->blockSize() == 1;
return m_isInitalized; return m_isInitialized;
} }
void SymmetricCipherStream::resetInternalState() void SymmetricCipherStream::resetInternalState()
@ -56,11 +56,8 @@ void SymmetricCipherStream::resetInternalState()
bool SymmetricCipherStream::open(QIODevice::OpenMode mode) bool SymmetricCipherStream::open(QIODevice::OpenMode mode)
{ {
if (!m_isInitalized) { return m_isInitialized && LayeredStream::open(mode);
return false;
}
return LayeredStream::open(mode);
} }
bool SymmetricCipherStream::reset() bool SymmetricCipherStream::reset()
@ -127,11 +124,11 @@ bool SymmetricCipherStream::readBlock()
QByteArray newData; QByteArray newData;
if (m_bufferFilling) { if (m_bufferFilling) {
newData.resize(m_cipher->blockSize() - m_buffer.size()); newData.resize(blockSize() - m_buffer.size());
} }
else { else {
m_buffer.clear(); m_buffer.clear();
newData.resize(m_cipher->blockSize()); newData.resize(blockSize());
} }
int readResult = m_baseDevice->read(newData.data(), newData.size()); int readResult = m_baseDevice->read(newData.data(), newData.size());
@ -140,12 +137,11 @@ bool SymmetricCipherStream::readBlock()
m_error = true; m_error = true;
setErrorString(m_baseDevice->errorString()); setErrorString(m_baseDevice->errorString());
return false; return false;
} } else {
else {
m_buffer.append(newData.left(readResult)); m_buffer.append(newData.left(readResult));
} }
if (m_buffer.size() != m_cipher->blockSize()) { if (!m_streamCipher && m_buffer.size() != blockSize()) {
m_bufferFilling = true; m_bufferFilling = true;
return false; return false;
} }
@ -159,27 +155,29 @@ bool SymmetricCipherStream::readBlock()
m_bufferFilling = false; m_bufferFilling = false;
if (m_baseDevice->atEnd()) { if (m_baseDevice->atEnd()) {
if (!m_streamCipher) {
// PKCS7 padding // PKCS7 padding
quint8 padLength = m_buffer.at(m_buffer.size() - 1); quint8 padLength = m_buffer.at(m_buffer.size() - 1);
if (padLength == m_cipher->blockSize()) { if (padLength == blockSize()) {
Q_ASSERT(m_buffer == QByteArray(m_cipher->blockSize(), m_cipher->blockSize())); Q_ASSERT(m_buffer == QByteArray(blockSize(), blockSize()));
// full block with just padding: discard // full block with just padding: discard
m_buffer.clear(); m_buffer.clear();
return false; return false;
} } else if (padLength > blockSize()) {
else if (padLength > m_cipher->blockSize()) {
// invalid padding // invalid padding
m_error = true; m_error = true;
setErrorString("Invalid padding."); setErrorString("Invalid padding.");
return false; return false;
} } else {
else {
Q_ASSERT(m_buffer.right(padLength) == QByteArray(padLength, padLength)); Q_ASSERT(m_buffer.right(padLength) == QByteArray(padLength, padLength));
// resize buffer to strip padding // resize buffer to strip padding
m_buffer.resize(m_cipher->blockSize() - padLength); m_buffer.resize(blockSize() - padLength);
return true; return true;
} }
} else {
return m_buffer.size() > 0;
}
} }
else { else {
return true; return true;
@ -200,14 +198,14 @@ qint64 SymmetricCipherStream::writeData(const char* data, qint64 maxSize)
qint64 offset = 0; qint64 offset = 0;
while (bytesRemaining > 0) { while (bytesRemaining > 0) {
int bytesToCopy = qMin(bytesRemaining, static_cast<qint64>(m_cipher->blockSize() - m_buffer.size())); int bytesToCopy = qMin(bytesRemaining, static_cast<qint64>(blockSize() - m_buffer.size()));
m_buffer.append(data + offset, bytesToCopy); m_buffer.append(data + offset, bytesToCopy);
offset += bytesToCopy; offset += bytesToCopy;
bytesRemaining -= bytesToCopy; bytesRemaining -= bytesToCopy;
if (m_buffer.size() == m_cipher->blockSize()) { if (m_buffer.size() == blockSize()) {
if (!writeBlock(false)) { if (!writeBlock(false)) {
if (m_error) { if (m_error) {
return -1; return -1;
@ -224,11 +222,11 @@ qint64 SymmetricCipherStream::writeData(const char* data, qint64 maxSize)
bool SymmetricCipherStream::writeBlock(bool lastBlock) bool SymmetricCipherStream::writeBlock(bool lastBlock)
{ {
Q_ASSERT(lastBlock || (m_buffer.size() == m_cipher->blockSize())); Q_ASSERT(m_streamCipher || lastBlock || (m_buffer.size() == blockSize()));
if (lastBlock) { if (lastBlock && !m_streamCipher) {
// PKCS7 padding // PKCS7 padding
int padLen = m_cipher->blockSize() - m_buffer.size(); int padLen = blockSize() - m_buffer.size();
for (int i = 0; i < padLen; i++) { for (int i = 0; i < padLen; i++) {
m_buffer.append(static_cast<char>(padLen)); m_buffer.append(static_cast<char>(padLen));
} }
@ -250,3 +248,11 @@ bool SymmetricCipherStream::writeBlock(bool lastBlock)
return true; return true;
} }
} }
int SymmetricCipherStream::blockSize() const {
if (m_streamCipher) {
return 1024;
} else {
return m_cipher->blockSize();
}
}

View File

@ -45,14 +45,16 @@ private:
void resetInternalState(); void resetInternalState();
bool readBlock(); bool readBlock();
bool writeBlock(bool lastBlock); bool writeBlock(bool lastBlock);
int blockSize() const;
const QScopedPointer<SymmetricCipher> m_cipher; const QScopedPointer<SymmetricCipher> m_cipher;
QByteArray m_buffer; QByteArray m_buffer;
int m_bufferPos; int m_bufferPos;
bool m_bufferFilling; bool m_bufferFilling;
bool m_error; bool m_error;
bool m_isInitalized; bool m_isInitialized;
bool m_dataWritten; bool m_dataWritten;
bool m_streamCipher;
}; };
#endif // KEEPASSX_SYMMETRICCIPHERSTREAM_H #endif // KEEPASSX_SYMMETRICCIPHERSTREAM_H