This commit is contained in:
Mark Qvist 2025-05-06 17:48:38 +02:00
parent 62ecc0549d
commit 4ae0f28aa0

View file

@ -1,17 +1,17 @@
# MIT License
#
# Copyright (c) 2024 BoppreH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@ -58,25 +58,21 @@ inv_s_box = (
0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D,
)
def sub_bytes(s):
for i in range(4):
for j in range(4):
s[i][j] = s_box[s[i][j]]
def inv_sub_bytes(s):
for i in range(4):
for j in range(4):
s[i][j] = inv_s_box[s[i][j]]
def shift_rows(s):
s[0][1], s[1][1], s[2][1], s[3][1] = s[1][1], s[2][1], s[3][1], s[0][1]
s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
s[0][3], s[1][3], s[2][3], s[3][3] = s[3][3], s[0][3], s[1][3], s[2][3]
def inv_shift_rows(s):
s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1]
s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
@ -87,11 +83,8 @@ def add_round_key(s, k):
for j in range(4):
s[i][j] ^= k[i][j]
# learned from https://web.archive.org/web/20100626212235/http://cs.ucsb.edu/~koc/cs178/projects/JT/aes.c
xtime = lambda a: (((a << 1) ^ 0x1B) & 0xFF) if (a & 0x80) else (a << 1)
def mix_single_column(a):
# see Sec 4.1.2 in The Design of Rijndael
t = a[0] ^ a[1] ^ a[2] ^ a[3]
@ -101,12 +94,10 @@ def mix_single_column(a):
a[2] ^= t ^ xtime(a[2] ^ a[3])
a[3] ^= t ^ xtime(a[3] ^ u)
def mix_columns(s):
for i in range(4):
mix_single_column(s[i])
def inv_mix_columns(s):
# see Sec 4.1.3 in The Design of Rijndael
for i in range(4):
@ -119,7 +110,6 @@ def inv_mix_columns(s):
mix_columns(s)
r_con = (
0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
@ -128,20 +118,11 @@ r_con = (
)
def bytes2matrix(text):
""" Converts a 16-byte array into a 4x4 matrix. """
return [list(text[i:i+4]) for i in range(0, len(text), 4)]
def matrix2bytes(matrix):
""" Converts a 4x4 matrix into a 16-byte array. """
return bytes(sum(matrix, []))
def xor_bytes(a, b):
""" Returns a new byte array with the elements xor'ed. """
return bytes(i^j for i, j in zip(a, b))
def bytes2matrix(text): return [list(text[i:i+4]) for i in range(0, len(text), 4)]
def matrix2bytes(matrix): return bytes(sum(matrix, []))
def xor_bytes(a, b): return bytes(i^j for i, j in zip(a, b))
def inc_bytes(a):
""" Returns a new byte array with the value increment by 1 """
out = list(a)
for i in reversed(range(len(out))):
if out[i] == 0xFF:
@ -151,52 +132,18 @@ def inc_bytes(a):
break
return bytes(out)
def pad(plaintext):
"""
Pads the given plaintext with PKCS#7 padding to a multiple of 16 bytes.
Note that if the plaintext size is a multiple of 16,
a whole block will be added.
"""
padding_len = 16 - (len(plaintext) % 16)
padding = bytes([padding_len] * padding_len)
return plaintext + padding
def unpad(plaintext):
"""
Removes a PKCS#7 padding, returning the unpadded text and ensuring the
padding was correct.
"""
padding_len = plaintext[-1]
assert padding_len > 0
message, padding = plaintext[:-padding_len], plaintext[-padding_len:]
assert all(p == padding_len for p in padding)
return message
def split_blocks(message, block_size=16, require_padding=True):
assert len(message) % block_size == 0 or not require_padding
return [message[i:i+16] for i in range(0, len(message), block_size)]
assert len(message) % block_size == 0 or not require_padding
return [message[i:i+16] for i in range(0, len(message), block_size)]
class AES256:
"""
Class for AES-128 encryption with CBC mode and PKCS#7.
This is a raw implementation of AES, without key stretching or IV
management. Unless you need that, please use `encrypt` and `decrypt`.
"""
rounds_by_key_size = {32: 14}
def __init__(self, master_key):
"""
Initializes the object with a given key.
"""
assert len(master_key) in AES256.rounds_by_key_size
self.n_rounds = AES256.rounds_by_key_size[len(master_key)]
self._key_matrices = self._expand_key(master_key)
def _expand_key(self, master_key):
"""
Expands and returns a list of key matrices for the given master_key.
"""
# Initialize round keys with raw key material.
key_columns = bytes2matrix(master_key)
iteration_size = len(master_key) // 4
@ -228,9 +175,6 @@ class AES256:
return [key_columns[4*i : 4*(i+1)] for i in range(len(key_columns) // 4)]
def encrypt_block(self, plaintext):
"""
Encrypts a single block of 16 byte long plaintext.
"""
assert len(plaintext) == 16
plain_state = bytes2matrix(plaintext)
@ -250,9 +194,6 @@ class AES256:
return matrix2bytes(plain_state)
def decrypt_block(self, ciphertext):
"""
Decrypts a single block of 16 byte long ciphertext.
"""
assert len(ciphertext) == 16
cipher_state = bytes2matrix(ciphertext)
@ -272,18 +213,10 @@ class AES256:
return matrix2bytes(cipher_state)
def encrypt_cbc(self, plaintext, iv):
"""
Encrypts `plaintext` using CBC mode and PKCS#7 padding, with the given
initialization vector (iv).
"""
assert len(iv) == 16
plaintext = pad(plaintext)
if len(iv) != 16: raise ValueError(f"Invalid IV length: {len(iv)}")
blocks = []
previous = iv
for plaintext_block in split_blocks(plaintext):
# CBC mode encrypt: encrypt(plaintext_block XOR previous)
block = self.encrypt_block(xor_bytes(plaintext_block, previous))
blocks.append(block)
previous = block
@ -291,157 +224,13 @@ class AES256:
return b''.join(blocks)
def decrypt_cbc(self, ciphertext, iv):
"""
Decrypts `ciphertext` using CBC mode and PKCS#7 padding, with the given
initialization vector (iv).
"""
assert len(iv) == 16
if len(iv) != 16: raise ValueError(f"Invalid IV length: {len(iv)}")
blocks = []
previous = iv
for ciphertext_block in split_blocks(ciphertext):
# CBC mode decrypt: previous XOR decrypt(ciphertext)
blocks.append(xor_bytes(previous, self.decrypt_block(ciphertext_block)))
previous = ciphertext_block
return unpad(b''.join(blocks))
def encrypt_pcbc(self, plaintext, iv):
"""
Encrypts `plaintext` using PCBC mode and PKCS#7 padding, with the given
initialization vector (iv).
"""
assert len(iv) == 16
plaintext = pad(plaintext)
blocks = []
prev_ciphertext = iv
prev_plaintext = bytes(16)
for plaintext_block in split_blocks(plaintext):
# PCBC mode encrypt: encrypt(plaintext_block XOR (prev_ciphertext XOR prev_plaintext))
ciphertext_block = self.encrypt_block(xor_bytes(plaintext_block, xor_bytes(prev_ciphertext, prev_plaintext)))
blocks.append(ciphertext_block)
prev_ciphertext = ciphertext_block
prev_plaintext = plaintext_block
return b''.join(blocks)
def decrypt_pcbc(self, ciphertext, iv):
"""
Decrypts `ciphertext` using PCBC mode and PKCS#7 padding, with the given
initialization vector (iv).
"""
assert len(iv) == 16
blocks = []
prev_ciphertext = iv
prev_plaintext = bytes(16)
for ciphertext_block in split_blocks(ciphertext):
# PCBC mode decrypt: (prev_plaintext XOR prev_ciphertext) XOR decrypt(ciphertext_block)
plaintext_block = xor_bytes(xor_bytes(prev_ciphertext, prev_plaintext), self.decrypt_block(ciphertext_block))
blocks.append(plaintext_block)
prev_ciphertext = ciphertext_block
prev_plaintext = plaintext_block
return unpad(b''.join(blocks))
def encrypt_cfb(self, plaintext, iv):
"""
Encrypts `plaintext` with the given initialization vector (iv).
"""
assert len(iv) == 16
blocks = []
prev_ciphertext = iv
for plaintext_block in split_blocks(plaintext, require_padding=False):
# CFB mode encrypt: plaintext_block XOR encrypt(prev_ciphertext)
ciphertext_block = xor_bytes(plaintext_block, self.encrypt_block(prev_ciphertext))
blocks.append(ciphertext_block)
prev_ciphertext = ciphertext_block
return b''.join(blocks)
def decrypt_cfb(self, ciphertext, iv):
"""
Decrypts `ciphertext` with the given initialization vector (iv).
"""
assert len(iv) == 16
blocks = []
prev_ciphertext = iv
for ciphertext_block in split_blocks(ciphertext, require_padding=False):
# CFB mode decrypt: ciphertext XOR decrypt(prev_ciphertext)
plaintext_block = xor_bytes(ciphertext_block, self.encrypt_block(prev_ciphertext))
blocks.append(plaintext_block)
prev_ciphertext = ciphertext_block
return b''.join(blocks)
def encrypt_ofb(self, plaintext, iv):
"""
Encrypts `plaintext` using OFB mode initialization vector (iv).
"""
assert len(iv) == 16
blocks = []
previous = iv
for plaintext_block in split_blocks(plaintext, require_padding=False):
# OFB mode encrypt: plaintext_block XOR encrypt(previous)
block = self.encrypt_block(previous)
ciphertext_block = xor_bytes(plaintext_block, block)
blocks.append(ciphertext_block)
previous = block
return b''.join(blocks)
def decrypt_ofb(self, ciphertext, iv):
"""
Decrypts `ciphertext` using OFB mode initialization vector (iv).
"""
assert len(iv) == 16
blocks = []
previous = iv
for ciphertext_block in split_blocks(ciphertext, require_padding=False):
# OFB mode decrypt: ciphertext XOR encrypt(previous)
block = self.encrypt_block(previous)
plaintext_block = xor_bytes(ciphertext_block, block)
blocks.append(plaintext_block)
previous = block
return b''.join(blocks)
def encrypt_ctr(self, plaintext, iv):
"""
Encrypts `plaintext` using CTR mode with the given nounce/IV.
"""
assert len(iv) == 16
blocks = []
nonce = iv
for plaintext_block in split_blocks(plaintext, require_padding=False):
# CTR mode encrypt: plaintext_block XOR encrypt(nonce)
block = xor_bytes(plaintext_block, self.encrypt_block(nonce))
blocks.append(block)
nonce = inc_bytes(nonce)
return b''.join(blocks)
def decrypt_ctr(self, ciphertext, iv):
"""
Decrypts `ciphertext` using CTR mode with the given nounce/IV.
"""
assert len(iv) == 16
blocks = []
nonce = iv
for ciphertext_block in split_blocks(ciphertext, require_padding=False):
# CTR mode decrypt: ciphertext XOR encrypt(nonce)
block = xor_bytes(ciphertext_block, self.encrypt_block(nonce))
blocks.append(block)
nonce = inc_bytes(nonce)
return b''.join(blocks)
__all__ = ["AES256"]