diff --git a/RNS/Cryptography/aes/aes256.py b/RNS/Cryptography/aes/aes256.py index aaa4be9..1529a66 100644 --- a/RNS/Cryptography/aes/aes256.py +++ b/RNS/Cryptography/aes/aes256.py @@ -1,17 +1,17 @@ # MIT License - +# # Copyright (c) 2024 BoppreH - +# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: - +# # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. - +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE @@ -58,25 +58,21 @@ inv_s_box = ( 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, ) - def sub_bytes(s): for i in range(4): for j in range(4): s[i][j] = s_box[s[i][j]] - def inv_sub_bytes(s): for i in range(4): for j in range(4): s[i][j] = inv_s_box[s[i][j]] - def shift_rows(s): s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1] s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3] - def inv_shift_rows(s): s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1] s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2] @@ -87,11 +83,8 @@ def add_round_key(s, k): for j in range(4): s[i][j] ^= k[i][j] - -# learned from https://web.archive.org/web/20100626212235/http://cs.ucsb.edu/~koc/cs178/projects/JT/aes.c xtime = lambda a: (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1) - def mix_single_column(a): # see Sec 4.1.2 in The Design of Rijndael t = a[0] ^ a[1] ^ a[2] ^ a[3] @@ -101,12 +94,10 @@ def mix_single_column(a): a[2] ^= t ^ xtime(a[2] ^ a[3]) a[3] ^= t ^ xtime(a[3] ^ u) - def mix_columns(s): for i in range(4): mix_single_column(s[i]) - def inv_mix_columns(s): # see Sec 4.1.3 in The Design of Rijndael for i in range(4): @@ -119,7 +110,6 @@ def inv_mix_columns(s): mix_columns(s) - r_con = ( 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, @@ -128,20 +118,11 @@ r_con = ( ) -def bytes2matrix(text): - """ Converts a 16-byte array into a 4x4 matrix. """ - return [list(text[i:i+4]) for i in range(0, len(text), 4)] - -def matrix2bytes(matrix): - """ Converts a 4x4 matrix into a 16-byte array. """ - return bytes(sum(matrix, [])) - -def xor_bytes(a, b): - """ Returns a new byte array with the elements xor'ed. """ - return bytes(i^j for i, j in zip(a, b)) +def bytes2matrix(text): return [list(text[i:i+4]) for i in range(0, len(text), 4)] +def matrix2bytes(matrix): return bytes(sum(matrix, [])) +def xor_bytes(a, b): return bytes(i^j for i, j in zip(a, b)) def inc_bytes(a): - """ Returns a new byte array with the value increment by 1 """ out = list(a) for i in reversed(range(len(out))): if out[i] == 0xFF: @@ -151,52 +132,18 @@ def inc_bytes(a): break return bytes(out) -def pad(plaintext): - """ - Pads the given plaintext with PKCS#7 padding to a multiple of 16 bytes. - Note that if the plaintext size is a multiple of 16, - a whole block will be added. - """ - padding_len = 16 - (len(plaintext) % 16) - padding = bytes([padding_len] * padding_len) - return plaintext + padding - -def unpad(plaintext): - """ - Removes a PKCS#7 padding, returning the unpadded text and ensuring the - padding was correct. - """ - padding_len = plaintext[-1] - assert padding_len > 0 - message, padding = plaintext[:-padding_len], plaintext[-padding_len:] - assert all(p == padding_len for p in padding) - return message - def split_blocks(message, block_size=16, require_padding=True): - assert len(message) % block_size == 0 or not require_padding - return [message[i:i+16] for i in range(0, len(message), block_size)] - + assert len(message) % block_size == 0 or not require_padding + return [message[i:i+16] for i in range(0, len(message), block_size)] class AES256: - """ - Class for AES-128 encryption with CBC mode and PKCS#7. - - This is a raw implementation of AES, without key stretching or IV - management. Unless you need that, please use `encrypt` and `decrypt`. - """ rounds_by_key_size = {32: 14} def __init__(self, master_key): - """ - Initializes the object with a given key. - """ assert len(master_key) in AES256.rounds_by_key_size self.n_rounds = AES256.rounds_by_key_size[len(master_key)] self._key_matrices = self._expand_key(master_key) def _expand_key(self, master_key): - """ - Expands and returns a list of key matrices for the given master_key. - """ # Initialize round keys with raw key material. key_columns = bytes2matrix(master_key) iteration_size = len(master_key) // 4 @@ -228,9 +175,6 @@ class AES256: return [key_columns[4*i : 4*(i+1)] for i in range(len(key_columns) // 4)] def encrypt_block(self, plaintext): - """ - Encrypts a single block of 16 byte long plaintext. - """ assert len(plaintext) == 16 plain_state = bytes2matrix(plaintext) @@ -250,9 +194,6 @@ class AES256: return matrix2bytes(plain_state) def decrypt_block(self, ciphertext): - """ - Decrypts a single block of 16 byte long ciphertext. - """ assert len(ciphertext) == 16 cipher_state = bytes2matrix(ciphertext) @@ -272,18 +213,10 @@ class AES256: return matrix2bytes(cipher_state) def encrypt_cbc(self, plaintext, iv): - """ - Encrypts `plaintext` using CBC mode and PKCS#7 padding, with the given - initialization vector (iv). - """ - assert len(iv) == 16 - - plaintext = pad(plaintext) - + if len(iv) != 16: raise ValueError(f"Invalid IV length: {len(iv)}") blocks = [] previous = iv for plaintext_block in split_blocks(plaintext): - # CBC mode encrypt: encrypt(plaintext_block XOR previous) block = self.encrypt_block(xor_bytes(plaintext_block, previous)) blocks.append(block) previous = block @@ -291,157 +224,13 @@ class AES256: return b''.join(blocks) def decrypt_cbc(self, ciphertext, iv): - """ - Decrypts `ciphertext` using CBC mode and PKCS#7 padding, with the given - initialization vector (iv). - """ - assert len(iv) == 16 - + if len(iv) != 16: raise ValueError(f"Invalid IV length: {len(iv)}") blocks = [] previous = iv for ciphertext_block in split_blocks(ciphertext): - # CBC mode decrypt: previous XOR decrypt(ciphertext) blocks.append(xor_bytes(previous, self.decrypt_block(ciphertext_block))) previous = ciphertext_block - return unpad(b''.join(blocks)) - - def encrypt_pcbc(self, plaintext, iv): - """ - Encrypts `plaintext` using PCBC mode and PKCS#7 padding, with the given - initialization vector (iv). - """ - assert len(iv) == 16 - - plaintext = pad(plaintext) - - blocks = [] - prev_ciphertext = iv - prev_plaintext = bytes(16) - for plaintext_block in split_blocks(plaintext): - # PCBC mode encrypt: encrypt(plaintext_block XOR (prev_ciphertext XOR prev_plaintext)) - ciphertext_block = self.encrypt_block(xor_bytes(plaintext_block, xor_bytes(prev_ciphertext, prev_plaintext))) - blocks.append(ciphertext_block) - prev_ciphertext = ciphertext_block - prev_plaintext = plaintext_block - - return b''.join(blocks) - - def decrypt_pcbc(self, ciphertext, iv): - """ - Decrypts `ciphertext` using PCBC mode and PKCS#7 padding, with the given - initialization vector (iv). - """ - assert len(iv) == 16 - - blocks = [] - prev_ciphertext = iv - prev_plaintext = bytes(16) - for ciphertext_block in split_blocks(ciphertext): - # PCBC mode decrypt: (prev_plaintext XOR prev_ciphertext) XOR decrypt(ciphertext_block) - plaintext_block = xor_bytes(xor_bytes(prev_ciphertext, prev_plaintext), self.decrypt_block(ciphertext_block)) - blocks.append(plaintext_block) - prev_ciphertext = ciphertext_block - prev_plaintext = plaintext_block - - return unpad(b''.join(blocks)) - - def encrypt_cfb(self, plaintext, iv): - """ - Encrypts `plaintext` with the given initialization vector (iv). - """ - assert len(iv) == 16 - - blocks = [] - prev_ciphertext = iv - for plaintext_block in split_blocks(plaintext, require_padding=False): - # CFB mode encrypt: plaintext_block XOR encrypt(prev_ciphertext) - ciphertext_block = xor_bytes(plaintext_block, self.encrypt_block(prev_ciphertext)) - blocks.append(ciphertext_block) - prev_ciphertext = ciphertext_block - - return b''.join(blocks) - - def decrypt_cfb(self, ciphertext, iv): - """ - Decrypts `ciphertext` with the given initialization vector (iv). - """ - assert len(iv) == 16 - - blocks = [] - prev_ciphertext = iv - for ciphertext_block in split_blocks(ciphertext, require_padding=False): - # CFB mode decrypt: ciphertext XOR decrypt(prev_ciphertext) - plaintext_block = xor_bytes(ciphertext_block, self.encrypt_block(prev_ciphertext)) - blocks.append(plaintext_block) - prev_ciphertext = ciphertext_block - - return b''.join(blocks) - - def encrypt_ofb(self, plaintext, iv): - """ - Encrypts `plaintext` using OFB mode initialization vector (iv). - """ - assert len(iv) == 16 - - blocks = [] - previous = iv - for plaintext_block in split_blocks(plaintext, require_padding=False): - # OFB mode encrypt: plaintext_block XOR encrypt(previous) - block = self.encrypt_block(previous) - ciphertext_block = xor_bytes(plaintext_block, block) - blocks.append(ciphertext_block) - previous = block - - return b''.join(blocks) - - def decrypt_ofb(self, ciphertext, iv): - """ - Decrypts `ciphertext` using OFB mode initialization vector (iv). - """ - assert len(iv) == 16 - - blocks = [] - previous = iv - for ciphertext_block in split_blocks(ciphertext, require_padding=False): - # OFB mode decrypt: ciphertext XOR encrypt(previous) - block = self.encrypt_block(previous) - plaintext_block = xor_bytes(ciphertext_block, block) - blocks.append(plaintext_block) - previous = block - - return b''.join(blocks) - - def encrypt_ctr(self, plaintext, iv): - """ - Encrypts `plaintext` using CTR mode with the given nounce/IV. - """ - assert len(iv) == 16 - - blocks = [] - nonce = iv - for plaintext_block in split_blocks(plaintext, require_padding=False): - # CTR mode encrypt: plaintext_block XOR encrypt(nonce) - block = xor_bytes(plaintext_block, self.encrypt_block(nonce)) - blocks.append(block) - nonce = inc_bytes(nonce) - - return b''.join(blocks) - - def decrypt_ctr(self, ciphertext, iv): - """ - Decrypts `ciphertext` using CTR mode with the given nounce/IV. - """ - assert len(iv) == 16 - - blocks = [] - nonce = iv - for ciphertext_block in split_blocks(ciphertext, require_padding=False): - # CTR mode decrypt: ciphertext XOR encrypt(nonce) - block = xor_bytes(ciphertext_block, self.encrypt_block(nonce)) - blocks.append(block) - nonce = inc_bytes(nonce) - return b''.join(blocks) __all__ = ["AES256"] \ No newline at end of file