improved efficiency of AEAD

This commit is contained in:
mr-alice 2016-10-26 18:15:47 +02:00
parent 88298b997e
commit c87ca67120

View File

@ -34,6 +34,7 @@
#include "crypto/chacha20.h"
#include "util/rsprint.h"
#include "util/rsscopetimer.h"
#define rotl(x,n) { x = (x << n) | (x >> (-n & 31)) ;}
@ -95,16 +96,26 @@ struct uint256_32
b[6] += u.b[6] + (b[5]>>32);
b[7] += u.b[7] + (b[6]>>32);
b[0] &= 0xffffffff;
b[1] &= 0xffffffff;
b[2] &= 0xffffffff;
b[3] &= 0xffffffff;
b[4] &= 0xffffffff;
b[5] &= 0xffffffff;
b[6] &= 0xffffffff;
b[7] &= 0xffffffff;
b[0] = (uint32_t) b[0];
b[1] = (uint32_t) b[1];
b[2] = (uint32_t) b[2];
b[3] = (uint32_t) b[3];
b[4] = (uint32_t) b[4];
b[5] = (uint32_t) b[5];
b[6] = (uint32_t) b[6];
b[7] = (uint32_t) b[7];
}
void operator -=(const uint256_32& u)
{
*this += ~u ;
++(*this) ;
}
void operator++()
{
for(int i=0;i<8;++i)
if( (++b[i]) &= 0xffffffff)
break ;
}
void operator -=(const uint256_32& u) { *this += ~u ; *this += uint256_32(0,0,0,0,0,0,0,1); }
bool operator<(const uint256_32& u) const
{
@ -191,37 +202,61 @@ struct uint256_32
for(int c=7;c>=0;--c)
if(b[c] != 0)
{
if( (b[c] & 0xff000000) != 0) return c*32 + 3*8 + max_non_zero_of_height_bits(b[c] >> 24) ;
if( (b[c] & 0x00ff0000) != 0) return c*32 + 2*8 + max_non_zero_of_height_bits(b[c] >> 16) ;
if( (b[c] & 0x0000ff00) != 0) return c*32 + 1*8 + max_non_zero_of_height_bits(b[c] >> 8) ;
if( (b[c] & 0xff000000) != 0) return (c<<5) + 3*8 + max_non_zero_of_height_bits(b[c] >> 24) ;
if( (b[c] & 0x00ff0000) != 0) return (c<<5) + 2*8 + max_non_zero_of_height_bits(b[c] >> 16) ;
if( (b[c] & 0x0000ff00) != 0) return (c<<5) + 1*8 + max_non_zero_of_height_bits(b[c] >> 8) ;
return c*32 + 0*8 + max_non_zero_of_height_bits(b[c]) ;
}
return -1;
}
void lshift()
void lshift(uint32_t n)
{
int r = 0 ;
uint32_t p = n >> 5; // n/32
uint32_t u = n & 0x1f ; // n%32
if(p > 0)
for(int i=7;i>=0;--i)
b[i] = (i>=p)?b[i-p]:0 ;
uint32_t r = 0 ;
if(u>0)
for(int i=0;i<8;++i)
{
uint32_t r1 = (b[i] >> 31) ;
b[i] = (b[i] << 1) & 0xffffffff;
uint32_t r1 = (b[i] >> (31-u+1)) ;
b[i] = (b[i] << u) & 0xffffffff;
b[i] += r ;
r = r1 ;
}
}
void lshift()
{
uint32_t r ;
uint32_t r1 ;
r1 = (b[0] >> 31) ; b[0] = (b[0] << 1) & 0xffffffff; r = r1 ;
r1 = (b[1] >> 31) ; b[1] = (b[1] << 1) & 0xffffffff; b[1] += r ; r = r1 ;
r1 = (b[2] >> 31) ; b[2] = (b[2] << 1) & 0xffffffff; b[2] += r ; r = r1 ;
r1 = (b[3] >> 31) ; b[3] = (b[3] << 1) & 0xffffffff; b[3] += r ; r = r1 ;
r1 = (b[4] >> 31) ; b[4] = (b[4] << 1) & 0xffffffff; b[4] += r ; r = r1 ;
r1 = (b[5] >> 31) ; b[5] = (b[5] << 1) & 0xffffffff; b[5] += r ; r = r1 ;
r1 = (b[6] >> 31) ; b[6] = (b[6] << 1) & 0xffffffff; b[6] += r ; r = r1 ;
b[7] = (b[7] << 1) & 0xffffffff; b[7] += r ;
}
void rshift()
{
uint32_t r = 0 ;
uint32_t r ;
uint32_t r1 ;
for(int i=7;i>=0;--i)
{
uint32_t r1 = b[i] & 0x1;
b[i] >>= 1 ;
b[i] += r << 31;
r = r1 ;
}
r1 = b[7] & 0x1; b[7] >>= 1 ; r = r1 ;
r1 = b[6] & 0x1; b[6] >>= 1 ; if(r) b[6] += 0x80000000 ; r = r1 ;
r1 = b[5] & 0x1; b[5] >>= 1 ; if(r) b[5] += 0x80000000 ; r = r1 ;
r1 = b[4] & 0x1; b[4] >>= 1 ; if(r) b[4] += 0x80000000 ; r = r1 ;
r1 = b[3] & 0x1; b[3] >>= 1 ; if(r) b[3] += 0x80000000 ; r = r1 ;
r1 = b[2] & 0x1; b[2] >>= 1 ; if(r) b[2] += 0x80000000 ; r = r1 ;
r1 = b[1] & 0x1; b[1] >>= 1 ; if(r) b[1] += 0x80000000 ; r = r1 ;
b[0] >>= 1 ; if(r) b[0] += 0x80000000 ;
}
};
@ -236,11 +271,12 @@ static void quotient(const uint256_32& n,const uint256_32& p,uint256_32& q,uint2
int bmax = n.max_non_zero_bit() - p.max_non_zero_bit();
uint256_32 m(0,0,0,0,0,0,0,1) ;
uint256_32 m(0,0,0,0,0,0,0,0) ;
uint256_32 d = p ;
for(int i=0;i<bmax;++i)
m.lshift(), d.lshift() ;
m.b[bmax/32] = (1u << (bmax%32)) ; // set m to be 2^bmax
d.lshift(bmax);
for(int b=bmax;b>=0;--b,d.rshift(),m.rshift())
if(! (r < d))
@ -249,6 +285,20 @@ static void quotient(const uint256_32& n,const uint256_32& p,uint256_32& q,uint2
q += m ;
}
}
static void remainder(const uint256_32& n,const uint256_32& p,uint256_32& r)
{
// simple algorithm: add up multiples of u while keeping below *this. Once done, substract.
r = n ;
int bmax = n.max_non_zero_bit() - p.max_non_zero_bit();
uint256_32 d = p ;
d.lshift(bmax);
for(int b=bmax;b>=0;--b,d.rshift())
if(! (r < d))
r -= d ;
}
class chacha20_state
{
@ -308,6 +358,7 @@ static void apply_20_rounds(chacha20_state& s)
add(s,t) ;
}
#ifdef DEBUG_CHACHA20
static void print(const chacha20_state& s)
{
fprintf(stdout,"%08x %08x %08x %08x\n",s.c[0 ],s.c[1 ],s.c[2 ],s.c[3 ]) ;
@ -315,61 +366,7 @@ static void print(const chacha20_state& s)
fprintf(stdout,"%08x %08x %08x %08x\n",s.c[8 ],s.c[9 ],s.c[10],s.c[11]) ;
fprintf(stdout,"%08x %08x %08x %08x\n",s.c[12],s.c[13],s.c[14],s.c[15]) ;
}
// static uint8_t read16bits(char s)
// {
// if(s >= '0' && s <= '9')
// return s - '0' ;
// else if(s >= 'a' && s <= 'f')
// return s - 'a' + 10 ;
// else if(s >= 'A' && s <= 'F')
// return s - 'A' + 10 ;
// else
// throw std::runtime_error("Not an hex string!") ;
// }
//
// static uint256_32 create_256bit_int(const std::string& s)
// {
// uint256_32 r(0,0,0,0,0,0,0,0) ;
//
// fprintf(stdout,"Scanning %s\n",s.c_str()) ;
//
// for(int i=0;i<(int)s.length();++i)
// {
// uint32_t byte = (s.length() -1 - i)/2 ;
// uint32_t p = byte/4 ;
// uint32_t val;
//
// if(p >= 8)
// continue ;
//
// val = read16bits(s[i]) ;
//
// r.b[p] |= (( (val << (( (s.length()-i+1)%2)*4))) << (8*byte)) ;
// }
//
// return r;
// }
// static uint256_32 create_256bit_int_from_serialized(const std::string& s)
// {
// uint256_32 r(0,0,0,0,0,0,0,0) ;
//
// fprintf(stdout,"Scanning %s\n",s.c_str()) ;
//
// for(int i=0;i<(int)s.length();i+=3)
// {
// int byte = i/3 ;
// int p = byte/4 ;
// int sub_byte = byte - 4*p ;
//
// uint8_t b1 = read16bits(s[i+0]) ;
// uint8_t b2 = read16bits(s[i+1]) ;
// uint32_t b = (b1 << 4) + b2 ;
//
// r.b[p] |= ( b << (8*sub_byte)) ;
// }
// return r ;
// }
#endif
void chacha20_encrypt(uint8_t key[32], uint32_t block_counter, uint8_t nonce[12], uint8_t *data, uint32_t size)
{
@ -449,7 +446,7 @@ static void poly1305_add(poly1305_state& s,uint8_t *message,uint32_t size,bool p
s.a *= s.r ;
uint256_32 q,rst;
quotient(s.a,s.p,q,rst) ;
remainder(s.a,s.p,rst) ;
s.a = rst ;
}
}
@ -646,6 +643,13 @@ bool perform_tests()
std::cerr << " OK" << std::endl;
// operators
{ uint256_32 uu(0,0,0,0,0,0,0,0 ) ; ++uu ; if(!(uu == uint256_32(0,0,0,0,0,0,0,1))) return false ; }
{ uint256_32 uu(0,0,0,0,0,0,0,0xffffffff) ; ++uu ; if(!(uu == uint256_32(0,0,0,0,0,0,1,0))) return false ; }
std::cerr << " operator++ on 256bits numbers OK" << std::endl;
// sums/diffs of numbers
for(uint32_t i=0;i<100;++i)
@ -1144,6 +1148,37 @@ bool perform_tests()
}
std::cerr << " RFC7539 AEAD test vector #1 OK" << std::endl;
// bandwidth test
//
{
uint32_t SIZE = 1*1024*1024 ;
uint8_t *ten_megabyte_data = (uint8_t*)malloc(SIZE) ;
uint8_t key[32] = { 0x1c,0x92,0x40,0xa5,0xeb,0x55,0xd3,0x8a,0xf3,0x33,0x88,0x86,0x04,0xf6,0xb5,0xf0,
0x47,0x39,0x17,0xc1,0x40,0x2b,0x80,0x09,0x9d,0xca,0x5c,0xbc,0x20,0x70,0x75,0xc0 };
uint8_t nonce[12] = { 0x00,0x00,0x00,0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,0x08 };
uint8_t aad[12] = { 0xf3,0x33,0x88,0x86,0x00,0x00,0x00,0x00,0x00,0x00,0x4e,0x91 };
uint8_t received_tag[16] ;
{
RsScopeTimer s("AEAD") ;
chacha20_encrypt(key, 1, nonce, ten_megabyte_data,SIZE) ;
std::cerr << " Chacha20 encryption speed: " << SIZE / (1024.0*1024.0) / s.duration() << " MB/s" << std::endl;
}
{
RsScopeTimer s("AEAD") ;
AEAD_chacha20_poly1305(key,nonce,ten_megabyte_data,SIZE,aad,12,received_tag,true) ;
std::cerr << " AEAD encryption speed: " << SIZE / (1024.0*1024.0) / s.duration() << " MB/s" << std::endl;
}
free(ten_megabyte_data) ;
}
return true;
}