diff --git a/random.c b/random.c index f04aef0..079a753 100644 --- a/random.c +++ b/random.c @@ -48,33 +48,24 @@ void random_state_init(struct random_state *state) { state->reseed = 0; } -void get_random_bytes(struct random_state *state, void *buf, size_t size) { - if (size > RANDOM_CACHE_SIZE / 2) { +static void refill(struct random_state *state) { + if (state->reseed < RANDOM_RESEED_SIZE) { chacha_keystream_bytes(&state->ctx, state->cache, RANDOM_CACHE_SIZE); - return; - } - - while (size) { - if (state->index == RANDOM_CACHE_SIZE) { - if (state->reseed >= RANDOM_RESEED_SIZE) { - random_state_init(state); - } else { - chacha_keystream_bytes(&state->ctx, state->cache, RANDOM_CACHE_SIZE); - state->index = 0; - state->reseed += RANDOM_CACHE_SIZE; - } - } - size_t remaining = RANDOM_CACHE_SIZE - state->index; - size_t copy_size = size < remaining ? size : remaining; - memcpy(buf, state->cache + state->index, copy_size); - state->index += copy_size; - size -= copy_size; + state->index = 0; + state->reseed += RANDOM_CACHE_SIZE; + } else { + random_state_init(state); } } uint16_t get_random_u16(struct random_state *state) { uint16_t value; - get_random_bytes(state, &value, sizeof(value)); + size_t remaining = RANDOM_CACHE_SIZE - state->index; + if (remaining < sizeof(value)) { + refill(state); + } + memcpy(&value, state->cache + state->index, sizeof(value)); + state->index += sizeof(value); return value; } @@ -89,14 +80,19 @@ uint16_t get_random_u16_uniform(struct random_state *state, uint16_t bound) { uint16_t r; do { r = get_random_u16(state); - } while (r < min); + } while (unlikely(r < min)); return r % bound; } uint64_t get_random_u64(struct random_state *state) { uint64_t value; - get_random_bytes(state, &value, sizeof(value)); + size_t remaining = RANDOM_CACHE_SIZE - state->index; + if (remaining < sizeof(value)) { + refill(state); + } + memcpy(&value, state->cache + state->index, sizeof(value)); + state->index += sizeof(value); return value; } @@ -111,7 +107,7 @@ uint64_t get_random_u64_uniform(struct random_state *state, uint64_t bound) { uint64_t r; do { r = get_random_u64(state); - } while (r < min); + } while (unlikely(r < min)); return r % bound; } diff --git a/random.h b/random.h index 2174758..5c2d326 100644 --- a/random.h +++ b/random.h @@ -16,7 +16,6 @@ struct random_state { }; void random_state_init(struct random_state *state); -void get_random_bytes(struct random_state *state, void *buf, size_t size); uint16_t get_random_u16(struct random_state *state); uint16_t get_random_u16_uniform(struct random_state *state, uint16_t bound); uint64_t get_random_u64(struct random_state *state);