diff --git a/RNS/Cryptography/AES.py b/RNS/Cryptography/AES.py index 2689509..f3130b4 100644 --- a/RNS/Cryptography/AES.py +++ b/RNS/Cryptography/AES.py @@ -36,15 +36,48 @@ if cp.PROVIDER == cp.PROVIDER_INTERNAL: elif cp.PROVIDER == cp.PROVIDER_PYCA: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - - if pu.cryptography_old_api(): - from cryptography.hazmat.backends import default_backend + if pu.cryptography_old_api(): from cryptography.hazmat.backends import default_backend class AES_128_CBC: - @staticmethod def encrypt(plaintext, key, iv): + if len(key) != 16: raise ValueError(f"Invalid key length {len(key)*8} for {self}") + if cp.PROVIDER == cp.PROVIDER_INTERNAL: + cipher = AES(key) + return cipher.encrypt(plaintext, iv) + + elif cp.PROVIDER == cp.PROVIDER_PYCA: + if not pu.cryptography_old_api(): + cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) + else: + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + + encryptor = cipher.encryptor() + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + return ciphertext + + @staticmethod + def decrypt(ciphertext, key, iv): + if len(key) != 16: raise ValueError(f"Invalid key length {len(key)*8} for {self}") + if cp.PROVIDER == cp.PROVIDER_INTERNAL: + cipher = AES(key) + return cipher.decrypt(ciphertext, iv) + + elif cp.PROVIDER == cp.PROVIDER_PYCA: + if not pu.cryptography_old_api(): + cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) + else: + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + + decryptor = cipher.decryptor() + plaintext = decryptor.update(ciphertext) + decryptor.finalize() + return plaintext + +class AES_256_CBC: + @staticmethod + def encrypt(plaintext, key, iv): + if len(key) != 32: raise ValueError(f"Invalid key length {len(key)*8} for {self}") if cp.PROVIDER == cp.PROVIDER_INTERNAL: cipher = AES(key) return cipher.encrypt(plaintext, iv) @@ -61,6 +94,7 @@ class AES_128_CBC: @staticmethod def decrypt(ciphertext, key, iv): + if len(key) != 32: raise ValueError(f"Invalid key length {len(key)*8} for {self}") if cp.PROVIDER == cp.PROVIDER_INTERNAL: cipher = AES(key) return cipher.decrypt(ciphertext, iv)