Implement and test trim_tree algo in memory

This commit is contained in:
j-berman 2024-07-08 20:01:14 -07:00
parent 36f1e1965f
commit 5ddca0ce11
9 changed files with 916 additions and 742 deletions

View file

@ -230,520 +230,311 @@ void CurveTreesGlobalTree::extend_tree(const CurveTreesV1::TreeExtension &tree_e
}
}
//----------------------------------------------------------------------------------------------------------------------
// If we reached the new root, then clear all remaining elements in the tree above the root. Otherwise continue
// Returns true when the layer containing `num_parents` elements is the root layer (exactly 1 element),
// after pruning every layer above it down to the expected layer counts; returns false when the caller
// should keep trimming upward.
template <typename C>
static bool handle_root_after_trim(const std::size_t num_parents,
const std::size_t c1_expected_n_layers,
const std::size_t c2_expected_n_layers,
CurveTreesGlobalTree::Layer<C> &parents_inout,
std::vector<CurveTreesGlobalTree::Layer<Helios>> &c1_layers_inout,
std::vector<CurveTreesGlobalTree::Layer<Selene>> &c2_layers_inout)
// NOTE(review): the following signature line appears to be a diff-render artifact interleaving the
// added CurveTreesGlobalTree::reduce_tree with the removed function above — confirm against the repo.
void CurveTreesGlobalTree::reduce_tree(const CurveTreesV1::TreeReduction &tree_reduction)
{
// We're at the root if there should only be 1 element in the layer
if (num_parents > 1)
return false;
MDEBUG("We have encountered the root, clearing remaining elements in the tree");
// Clear all parents after root
while (parents_inout.size() > 1)
parents_inout.pop_back();
// Clear all remaining layers, if any
while (c1_layers_inout.size() > c1_expected_n_layers)
c1_layers_inout.pop_back();
while (c2_layers_inout.size() > c2_expected_n_layers)
c2_layers_inout.pop_back();
return true;
}
//----------------------------------------------------------------------------------------------------------------------
// Trims the child layer and caches values needed to update and trim the child's parent layer
// TODO: work on consolidating this function with the leaf layer logic and simplifying edge case handling
//
// On input, `last_parent_idx_inout` is the index of the new last child (i.e. the last parent of the
// layer below); the layer is trimmed to `last_parent_idx_inout + 1` children. On output it holds the
// index of this layer's new last parent. `old_last_parent_hash_out` is set to the pre-trim hash of the
// parent that becomes the new last parent, so the caller can continue trimming up the tree.
// Returns the recomputed hash of the new last parent chunk.
template <typename C_CHILD, typename C_PARENT>
static typename C_PARENT::Point trim_children(const C_CHILD &c_child,
const C_PARENT &c_parent,
const std::size_t parent_width,
const CurveTreesGlobalTree::Layer<C_PARENT> &parents,
const typename C_CHILD::Point &old_last_child_hash,
CurveTreesGlobalTree::Layer<C_CHILD> &children_inout,
std::size_t &last_parent_idx_inout,
typename C_PARENT::Point &old_last_parent_hash_out)
{
const std::size_t old_num_children = children_inout.size();
const std::size_t old_last_parent_idx = (old_num_children - 1) / parent_width;
const std::size_t old_last_offset = old_num_children % parent_width;
const std::size_t new_num_children = last_parent_idx_inout + 1;
const std::size_t new_last_parent_idx = (new_num_children - 1) / parent_width;
const std::size_t new_last_offset = new_num_children % parent_width;
CHECK_AND_ASSERT_THROW_MES(old_num_children >= new_num_children, "unexpected new_num_children");
last_parent_idx_inout = new_last_parent_idx;
old_last_parent_hash_out = parents[new_last_parent_idx];
MDEBUG("old_num_children: " << old_num_children <<
" , old_last_parent_idx: " << old_last_parent_idx <<
" , old_last_offset: " << old_last_offset <<
" , old_last_parent_hash_out: " << c_parent.to_string(old_last_parent_hash_out) <<
" , new_num_children: " << new_num_children <<
" , new_last_parent_idx: " << new_last_parent_idx <<
" , new_last_offset: " << new_last_offset);
// TODO: consolidate logic handling this function with the edge case at the end of this function
if (old_num_children == new_num_children)
{
// No new children means we only updated the last child, so use it to get the new last parent
const auto new_last_child = c_child.point_to_cycle_scalar(children_inout.back());
std::vector<typename C_PARENT::Scalar> new_child_v{new_last_child};
const auto &chunk = typename C_PARENT::Chunk{new_child_v.data(), new_child_v.size()};
const auto new_last_parent = c_parent.hash_grow(
/*existing_hash*/ old_last_parent_hash_out,
/*offset*/ (new_num_children - 1) % parent_width,
/*first_child_after_offset*/ c_child.point_to_cycle_scalar(old_last_child_hash),
/*children*/ chunk);
MDEBUG("New last parent using updated last child " << c_parent.to_string(new_last_parent));
return new_last_parent;
}
// Get the number of existing children in what will become the new last chunk after trimming
const std::size_t new_last_chunk_old_num_children = (old_last_parent_idx > new_last_parent_idx
|| old_last_offset == 0)
? parent_width
: old_last_offset;
CHECK_AND_ASSERT_THROW_MES(new_last_chunk_old_num_children > new_last_offset,
"unexpected new_last_chunk_old_num_children");
// Get the number of children we'll be trimming from the new last chunk
const std::size_t trim_n_children_from_new_last_chunk = new_last_offset == 0
? 0 // it will remain full
: new_last_chunk_old_num_children - new_last_offset;
// We use hash trim if we're removing fewer elems in the last chunk than the number of elems remaining
const bool last_chunk_use_hash_trim = trim_n_children_from_new_last_chunk > 0
&& trim_n_children_from_new_last_chunk < new_last_offset;
MDEBUG("new_last_chunk_old_num_children: " << new_last_chunk_old_num_children <<
" , trim_n_children_from_new_last_chunk: " << trim_n_children_from_new_last_chunk <<
" , last_chunk_use_hash_trim: " << last_chunk_use_hash_trim);
// If we're using hash_trim for the last chunk, we'll need to collect the children we're removing
// TODO: use a separate function to handle last_chunk_use_hash_trim case
std::vector<typename C_CHILD::Point> new_last_chunk_children_to_trim;
if (last_chunk_use_hash_trim)
new_last_chunk_children_to_trim.reserve(trim_n_children_from_new_last_chunk);
// Trim the children starting at the back of the child layer
MDEBUG("Trimming " << (old_num_children - new_num_children) << " children");
while (children_inout.size() > new_num_children)
{
// If we're using hash_trim for the last chunk, collect children from the last chunk
if (last_chunk_use_hash_trim)
{
const std::size_t cur_last_parent_idx = (children_inout.size() - 1) / parent_width;
if (cur_last_parent_idx == new_last_parent_idx)
new_last_chunk_children_to_trim.emplace_back(std::move(children_inout.back()));
}
children_inout.pop_back();
}
CHECK_AND_ASSERT_THROW_MES(children_inout.size() == new_num_children, "unexpected new children");
// We're done trimming the children
// If we're not using hash_trim for the last chunk, and we will be trimming from the new last chunk, then
// we'll need to collect the new last chunk's remaining children for hash_grow
// TODO: use a separate function to handle last_chunk_remaining_children case
std::vector<typename C_CHILD::Point> last_chunk_remaining_children;
if (!last_chunk_use_hash_trim && new_last_offset > 0)
{
last_chunk_remaining_children.reserve(new_last_offset);
const std::size_t start_child_idx = new_last_parent_idx * parent_width;
CHECK_AND_ASSERT_THROW_MES((start_child_idx + new_last_offset) == children_inout.size(),
"unexpected start_child_idx");
for (std::size_t i = start_child_idx; i < children_inout.size(); ++i)
{
CHECK_AND_ASSERT_THROW_MES(i < children_inout.size(), "unexpected child idx");
last_chunk_remaining_children.push_back(children_inout[i]);
}
}
CHECK_AND_ASSERT_THROW_MES(!parents.empty(), "empty parent layer");
CHECK_AND_ASSERT_THROW_MES(new_last_parent_idx < parents.size(), "unexpected new_last_parent_idx");
// Set the new last chunk's parent hash
if (last_chunk_use_hash_trim)
{
CHECK_AND_ASSERT_THROW_MES(new_last_chunk_children_to_trim.size() == trim_n_children_from_new_last_chunk,
"unexpected size of last child chunk");
// We need to reverse the order in order to match the order the children were initially inserted into the tree
// (they were collected back-to-front while popping above)
std::reverse(new_last_chunk_children_to_trim.begin(), new_last_chunk_children_to_trim.end());
// Check if the last child changed
const auto &old_last_child = old_last_child_hash;
const auto &new_last_child = children_inout.back();
if (c_child.to_bytes(old_last_child) == c_child.to_bytes(new_last_child))
{
// If the last child didn't change, then simply trim the collected children
std::vector<typename C_PARENT::Scalar> child_scalars;
fcmp::tower_cycle::extend_scalars_from_cycle_points<C_CHILD, C_PARENT>(c_child,
new_last_chunk_children_to_trim,
child_scalars);
for (std::size_t i = 0; i < child_scalars.size(); ++i)
MDEBUG("Trimming child " << c_parent.to_string(child_scalars[i]));
const auto &chunk = typename C_PARENT::Chunk{child_scalars.data(), child_scalars.size()};
const auto new_last_parent = c_parent.hash_trim(
old_last_parent_hash_out,
new_last_offset,
chunk);
MDEBUG("New last parent using simple hash_trim " << c_parent.to_string(new_last_parent));
return new_last_parent;
}
// The last child changed, so trim the old child, then grow the chunk by 1 with the new child
// TODO: implement prior_child_at_offset in hash_trim
new_last_chunk_children_to_trim.insert(new_last_chunk_children_to_trim.begin(), old_last_child);
std::vector<typename C_PARENT::Scalar> child_scalars;
fcmp::tower_cycle::extend_scalars_from_cycle_points<C_CHILD, C_PARENT>(c_child,
new_last_chunk_children_to_trim,
child_scalars);
for (std::size_t i = 0; i < child_scalars.size(); ++i)
MDEBUG("Trimming child " << c_parent.to_string(child_scalars[i]));
const auto &chunk = typename C_PARENT::Chunk{child_scalars.data(), child_scalars.size()};
CHECK_AND_ASSERT_THROW_MES(new_last_offset > 0, "new_last_offset must be >0");
// First trim off the old last child (offset - 1) ...
auto new_last_parent = c_parent.hash_trim(
old_last_parent_hash_out,
new_last_offset - 1,
chunk);
// ... then grow back by 1 with the updated last child
std::vector<typename C_PARENT::Scalar> new_last_child_scalar{c_child.point_to_cycle_scalar(new_last_child)};
const auto &new_last_child_chunk = typename C_PARENT::Chunk{
new_last_child_scalar.data(),
new_last_child_scalar.size()};
MDEBUG("Growing with new child: " << c_parent.to_string(new_last_child_scalar[0]));
new_last_parent = c_parent.hash_grow(
new_last_parent,
new_last_offset - 1,
c_parent.zero_scalar(),
new_last_child_chunk);
MDEBUG("New last parent using hash_trim AND updated last child " << c_parent.to_string(new_last_parent));
return new_last_parent;
}
else if (!last_chunk_remaining_children.empty())
{
// If we have remaining children in the new last chunk, and some children were trimmed from the chunk, then
// use hash_grow to calculate the new hash (re-hash the chunk from scratch starting at the init point)
std::vector<typename C_PARENT::Scalar> child_scalars;
fcmp::tower_cycle::extend_scalars_from_cycle_points<C_CHILD, C_PARENT>(c_child,
last_chunk_remaining_children,
child_scalars);
const auto &chunk = typename C_PARENT::Chunk{child_scalars.data(), child_scalars.size()};
auto new_last_parent = c_parent.hash_grow(
/*existing_hash*/ c_parent.m_hash_init_point,
/*offset*/ 0,
/*first_child_after_offset*/ c_parent.zero_scalar(),
/*children*/ chunk);
MDEBUG("New last parent from re-growing last chunk " << c_parent.to_string(new_last_parent));
return new_last_parent;
}
// Check if the last child updated
const auto &old_last_child = old_last_child_hash;
const auto &new_last_child = children_inout.back();
const auto old_last_child_bytes = c_child.to_bytes(old_last_child);
const auto new_last_child_bytes = c_child.to_bytes(new_last_child);
if (old_last_child_bytes == new_last_child_bytes)
{
MDEBUG("The last child didn't update, nothing left to do");
return old_last_parent_hash_out;
}
// TODO: try to consolidate handling this edge case with the case of old_num_children == new_num_children
// The new last chunk stayed full, but its final child's hash changed: replace it in place via hash_grow
MDEBUG("The last child changed, updating last chunk parent hash");
CHECK_AND_ASSERT_THROW_MES(new_last_offset == 0, "unexpected new last offset");
const auto old_last_child_scalar = c_child.point_to_cycle_scalar(old_last_child);
auto new_last_child_scalar = c_child.point_to_cycle_scalar(new_last_child);
std::vector<typename C_PARENT::Scalar> child_scalars{std::move(new_last_child_scalar)};
const auto &chunk = typename C_PARENT::Chunk{child_scalars.data(), child_scalars.size()};
auto new_last_parent = c_parent.hash_grow(
/*existing_hash*/ old_last_parent_hash_out,
/*offset*/ parent_width - 1,
/*first_child_after_offset*/ old_last_child_scalar,
/*children*/ chunk);
MDEBUG("New last parent from updated last child " << c_parent.to_string(new_last_parent));
return new_last_parent;
}
//----------------------------------------------------------------------------------------------------------------------
void CurveTreesGlobalTree::trim_tree(const std::size_t new_num_leaves)
{
// TODO: consolidate below logic with trim_children above
CHECK_AND_ASSERT_THROW_MES(new_num_leaves >= CurveTreesV1::LEAF_TUPLE_SIZE,
"tree must have at least 1 leaf tuple in it");
CHECK_AND_ASSERT_THROW_MES(new_num_leaves % CurveTreesV1::LEAF_TUPLE_SIZE == 0,
"num leaves must be divisible by leaf tuple size");
auto &leaves_out = m_tree.leaves;
auto &c1_layers_out = m_tree.c1_layers;
auto &c2_layers_out = m_tree.c2_layers;
const std::size_t old_num_leaves = leaves_out.size() * CurveTreesV1::LEAF_TUPLE_SIZE;
CHECK_AND_ASSERT_THROW_MES(old_num_leaves > new_num_leaves, "unexpected new num leaves");
const std::size_t old_last_leaf_parent_idx = (old_num_leaves - CurveTreesV1::LEAF_TUPLE_SIZE)
/ m_curve_trees.m_leaf_layer_chunk_width;
const std::size_t old_last_leaf_offset = old_num_leaves % m_curve_trees.m_leaf_layer_chunk_width;
const std::size_t new_last_leaf_parent_idx = (new_num_leaves - CurveTreesV1::LEAF_TUPLE_SIZE)
/ m_curve_trees.m_leaf_layer_chunk_width;
const std::size_t new_last_leaf_offset = new_num_leaves % m_curve_trees.m_leaf_layer_chunk_width;
MDEBUG("old_num_leaves: " << old_num_leaves <<
", old_last_leaf_parent_idx: " << old_last_leaf_parent_idx <<
", old_last_leaf_offset: " << old_last_leaf_offset <<
", new_num_leaves: " << new_num_leaves <<
", new_last_leaf_parent_idx: " << new_last_leaf_parent_idx <<
", new_last_leaf_offset: " << new_last_leaf_offset);
// Get the number of existing leaves in what will become the new last chunk after trimming
const std::size_t new_last_chunk_old_num_leaves = (old_last_leaf_parent_idx > new_last_leaf_parent_idx
|| old_last_leaf_offset == 0)
? m_curve_trees.m_leaf_layer_chunk_width
: old_last_leaf_offset;
CHECK_AND_ASSERT_THROW_MES(new_last_chunk_old_num_leaves > new_last_leaf_offset,
"unexpected last_chunk_old_num_leaves");
// Get the number of leaves we'll be trimming from the new last chunk
const std::size_t n_leaves_trim_from_new_last_chunk = new_last_leaf_offset == 0
? 0 // the last chunk wil remain full
: new_last_chunk_old_num_leaves - new_last_leaf_offset;
// We use hash trim if we're removing fewer elems in the last chunk than the number of elems remaining
const bool last_chunk_use_hash_trim = n_leaves_trim_from_new_last_chunk > 0
&& n_leaves_trim_from_new_last_chunk < new_last_leaf_offset;
MDEBUG("new_last_chunk_old_num_leaves: " << new_last_chunk_old_num_leaves <<
", n_leaves_trim_from_new_last_chunk: " << n_leaves_trim_from_new_last_chunk <<
", last_chunk_use_hash_trim: " << last_chunk_use_hash_trim);
// If we're using hash_trim for the last chunk, we'll need to collect the leaves we're trimming from that chunk
std::vector<Selene::Scalar> new_last_chunk_leaves_to_trim;
if (last_chunk_use_hash_trim)
new_last_chunk_leaves_to_trim.reserve(n_leaves_trim_from_new_last_chunk);
// Trim the leaves starting at the back of the leaf layer
const std::size_t new_num_leaf_tuples = new_num_leaves / CurveTreesV1::LEAF_TUPLE_SIZE;
while (leaves_out.size() > new_num_leaf_tuples)
{
// If we're using hash_trim for the last chunk, collect leaves from the last chunk to use later
if (last_chunk_use_hash_trim)
{
// Check if we're now trimming leaves from what will be the new last chunk
const std::size_t num_leaves_remaining = (leaves_out.size() - 1) * CurveTreesV1::LEAF_TUPLE_SIZE;
const std::size_t cur_last_leaf_parent_idx = num_leaves_remaining / m_curve_trees.m_leaf_layer_chunk_width;
if (cur_last_leaf_parent_idx == new_last_leaf_parent_idx)
{
// Add leaves in reverse order, because we're going to reverse the entire vector later on to get the
// correct trim order
new_last_chunk_leaves_to_trim.emplace_back(std::move(leaves_out.back().C_x));
new_last_chunk_leaves_to_trim.emplace_back(std::move(leaves_out.back().I_x));
new_last_chunk_leaves_to_trim.emplace_back(std::move(leaves_out.back().O_x));
}
}
leaves_out.pop_back();
}
CHECK_AND_ASSERT_THROW_MES(leaves_out.size() == new_num_leaf_tuples, "unexpected size of new leaves");
const std::size_t cur_last_leaf_parent_idx = ((leaves_out.size() - 1) * CurveTreesV1::LEAF_TUPLE_SIZE)
/ m_curve_trees.m_leaf_layer_chunk_width;
CHECK_AND_ASSERT_THROW_MES(cur_last_leaf_parent_idx == new_last_leaf_parent_idx, "unexpected last leaf parent idx");
// If we're not using hash_trim for the last chunk, and the new last chunk is not full already, we'll need to
// collect the existing leaves to get the hash using hash_grow
std::vector<Selene::Scalar> last_chunk_remaining_leaves;
if (!last_chunk_use_hash_trim && new_last_leaf_offset > 0)
{
last_chunk_remaining_leaves.reserve(new_last_leaf_offset);
const std::size_t start_leaf_idx = new_last_leaf_parent_idx * m_curve_trees.m_leaf_layer_chunk_width;
MDEBUG("start_leaf_idx: " << start_leaf_idx << ", leaves_out.size(): " << leaves_out.size());
CHECK_AND_ASSERT_THROW_MES((start_leaf_idx + new_last_leaf_offset) == new_num_leaves,
"unexpected start_leaf_idx");
for (std::size_t i = (start_leaf_idx / CurveTreesV1::LEAF_TUPLE_SIZE); i < leaves_out.size(); ++i)
{
CHECK_AND_ASSERT_THROW_MES(i < leaves_out.size(), "unexpected leaf idx");
last_chunk_remaining_leaves.push_back(leaves_out[i].O_x);
last_chunk_remaining_leaves.push_back(leaves_out[i].I_x);
last_chunk_remaining_leaves.push_back(leaves_out[i].C_x);
}
}
CHECK_AND_ASSERT_THROW_MES(!c2_layers_out.empty(), "empty leaf parent layer");
CHECK_AND_ASSERT_THROW_MES(cur_last_leaf_parent_idx < c2_layers_out[0].size(),
"unexpected cur_last_leaf_parent_idx");
// Set the new last leaf parent
Selene::Point old_last_c2_hash = std::move(c2_layers_out[0][cur_last_leaf_parent_idx]);
if (last_chunk_use_hash_trim)
{
CHECK_AND_ASSERT_THROW_MES(new_last_chunk_leaves_to_trim.size() == n_leaves_trim_from_new_last_chunk,
"unexpected size of last leaf chunk");
// We need to reverse the order in order to match the order the leaves were initially inserted into the tree
std::reverse(new_last_chunk_leaves_to_trim.begin(), new_last_chunk_leaves_to_trim.end());
const Selene::Chunk trim_leaves{new_last_chunk_leaves_to_trim.data(), new_last_chunk_leaves_to_trim.size()};
for (std::size_t i = 0; i < new_last_chunk_leaves_to_trim.size(); ++i)
MDEBUG("Trimming leaf " << m_curve_trees.m_c2.to_string(new_last_chunk_leaves_to_trim[i]));
auto new_last_leaf_parent = m_curve_trees.m_c2.hash_trim(
old_last_c2_hash,
new_last_leaf_offset,
trim_leaves);
MDEBUG("New hash " << m_curve_trees.m_c2.to_string(new_last_leaf_parent));
c2_layers_out[0][cur_last_leaf_parent_idx] = std::move(new_last_leaf_parent);
}
else if (new_last_leaf_offset > 0)
{
for (std::size_t i = 0; i < last_chunk_remaining_leaves.size(); ++i)
MDEBUG("Hashing leaf " << m_curve_trees.m_c2.to_string(last_chunk_remaining_leaves[i]));
const auto &leaves = Selene::Chunk{last_chunk_remaining_leaves.data(), last_chunk_remaining_leaves.size()};
auto new_last_leaf_parent = m_curve_trees.m_c2.hash_grow(
/*existing_hash*/ m_curve_trees.m_c2.m_hash_init_point,
/*offset*/ 0,
/*first_child_after_offset*/ m_curve_trees.m_c2.zero_scalar(),
/*children*/ leaves);
MDEBUG("Result hash " << m_curve_trees.m_c2.to_string(new_last_leaf_parent));
c2_layers_out[0][cur_last_leaf_parent_idx] = std::move(new_last_leaf_parent);
}
if (handle_root_after_trim<Selene>(
/*num_parents*/ cur_last_leaf_parent_idx + 1,
/*c1_expected_n_layers*/ 0,
/*c2_expected_n_layers*/ 1,
/*parents_inout*/ c2_layers_out[0],
/*c1_layers_inout*/ c1_layers_out,
/*c2_layers_inout*/ c2_layers_out))
{
return;
}
// Go layer-by-layer starting by trimming the c2 layer we just set, and updating the parent layer hashes
bool trim_c1 = true;
std::size_t c1_idx = 0;
// Trim the leaves
const std::size_t init_num_leaves = m_tree.leaves.size() * m_curve_trees.LEAF_TUPLE_SIZE;
CHECK_AND_ASSERT_THROW_MES(init_num_leaves > tree_reduction.new_total_leaves, "expected fewer new total leaves");
CHECK_AND_ASSERT_THROW_MES((tree_reduction.new_total_leaves % m_curve_trees.LEAF_TUPLE_SIZE) == 0,
"unexpected new total leaves");
const std::size_t new_total_leaf_tuples = tree_reduction.new_total_leaves / m_curve_trees.LEAF_TUPLE_SIZE;
while (m_tree.leaves.size() > new_total_leaf_tuples)
m_tree.leaves.pop_back();
// Trim the layers
const auto &c2_layer_reductions = tree_reduction.c2_layer_reductions;
const auto &c1_layer_reductions = tree_reduction.c1_layer_reductions;
CHECK_AND_ASSERT_THROW_MES(!c2_layer_reductions.empty(), "empty c2 layer reductions");
bool use_c2 = true;
std::size_t c2_idx = 0;
std::size_t last_parent_idx = cur_last_leaf_parent_idx;
Helios::Point old_last_c1_hash;
for (std::size_t i = 0; i < (c1_layers_out.size() + c2_layers_out.size()); ++i)
std::size_t c1_idx = 0;
for (std::size_t i = 0; i < (c2_layer_reductions.size() + c1_layer_reductions.size()); ++i)
{
MDEBUG("Trimming layer " << i);
CHECK_AND_ASSERT_THROW_MES(c1_idx < c1_layers_out.size(), "unexpected c1 layer");
CHECK_AND_ASSERT_THROW_MES(c2_idx < c2_layers_out.size(), "unexpected c2 layer");
auto &c1_layer_out = c1_layers_out[c1_idx];
auto &c2_layer_out = c2_layers_out[c2_idx];
if (trim_c1)
// TODO: template below if statement
if (use_c2)
{
// TODO: fewer params
auto new_last_parent = trim_children(m_curve_trees.m_c2,
m_curve_trees.m_c1,
m_curve_trees.m_c1_width,
c1_layer_out,
old_last_c2_hash,
c2_layer_out,
last_parent_idx,
old_last_c1_hash);
CHECK_AND_ASSERT_THROW_MES(c2_idx < c2_layer_reductions.size(), "unexpected c2 layer reduction");
const auto &c2_reduction = c2_layer_reductions[c2_idx];
// Update the last parent
c1_layer_out[last_parent_idx] = std::move(new_last_parent);
CHECK_AND_ASSERT_THROW_MES(c2_idx < m_tree.c2_layers.size(), "missing c2 layer");
auto &c2_inout = m_tree.c2_layers[c2_idx];
if (handle_root_after_trim<Helios>(last_parent_idx + 1,
c1_idx + 1,
c2_idx + 1,
c1_layer_out,
c1_layers_out,
c2_layers_out))
CHECK_AND_ASSERT_THROW_MES(c2_reduction.new_total_parents <= c2_inout.size(),
"unexpected c2 new total parents");
c2_inout.resize(c2_reduction.new_total_parents);
c2_inout.shrink_to_fit();
// We updated the last hash
if (c2_reduction.update_existing_last_hash)
{
return;
c2_inout.back() = c2_reduction.new_last_hash;
}
++c2_idx;
}
else
{
// TODO: fewer params
auto new_last_parent = trim_children(m_curve_trees.m_c1,
m_curve_trees.m_c2,
m_curve_trees.m_c2_width,
c2_layer_out,
old_last_c1_hash,
c1_layer_out,
last_parent_idx,
old_last_c2_hash);
CHECK_AND_ASSERT_THROW_MES(c1_idx < c1_layer_reductions.size(), "unexpected c1 layer reduction");
const auto &c1_reduction = c1_layer_reductions[c1_idx];
// Update the last parent
c2_layer_out[last_parent_idx] = std::move(new_last_parent);
CHECK_AND_ASSERT_THROW_MES(c1_idx < m_tree.c1_layers.size(), "missing c1 layer");
auto &c1_inout = m_tree.c1_layers[c1_idx];
if (handle_root_after_trim<Selene>(last_parent_idx + 1,
c1_idx + 1,
c2_idx + 1,
c2_layer_out,
c1_layers_out,
c2_layers_out))
CHECK_AND_ASSERT_THROW_MES(c1_reduction.new_total_parents <= c1_inout.size(),
"unexpected c1 new total parents");
c1_inout.resize(c1_reduction.new_total_parents);
c1_inout.shrink_to_fit();
// We updated the last hash
if (c1_reduction.update_existing_last_hash)
{
return;
c1_inout.back() = c1_reduction.new_last_hash;
}
++c1_idx;
}
trim_c1 = !trim_c1;
use_c2 = !use_c2;
}
// Delete remaining layers
m_tree.c1_layers.resize(c1_layer_reductions.size());
m_tree.c2_layers.resize(c2_layer_reductions.size());
m_tree.c1_layers.shrink_to_fit();
m_tree.c2_layers.shrink_to_fit();
}
//----------------------------------------------------------------------------------------------------------------------
bool CurveTreesGlobalTree::audit_tree()
// Collects the child scalars a layer needs in order to recompute its new last parent hash when trimming.
// Two mutually exclusive cases, driven by the precomputed trim instructions:
//   - need_last_chunk_children_to_trim: gather the children being REMOVED from the new last chunk
//     (fed to hash_trim)
//   - need_last_chunk_remaining_children: gather the children REMAINING in the new last chunk
//     (fed to hash_grow)
// Returns an empty vector when neither case applies.
template<typename C_CHILD, typename C_PARENT>
static std::vector<typename C_PARENT::Scalar> get_last_chunk_children_to_trim(const C_CHILD &c_child,
    const fcmp::curve_trees::TrimLayerInstructions &trim_instructions,
    const CurveTreesGlobalTree::Layer<C_CHILD> &child_layer)
{
    std::vector<typename C_PARENT::Scalar> children_to_trim_out;

    const std::size_t new_total_children = trim_instructions.new_total_children;
    const std::size_t old_total_children = trim_instructions.old_total_children;
    const std::size_t new_total_parents  = trim_instructions.new_total_parents;
    const std::size_t parent_chunk_width = trim_instructions.parent_chunk_width;
    const std::size_t new_offset         = trim_instructions.new_offset;

    CHECK_AND_ASSERT_THROW_MES(new_total_children > 0, "expected some new children");
    CHECK_AND_ASSERT_THROW_MES(new_total_children >= new_offset, "expected more children than offset");
    CHECK_AND_ASSERT_THROW_MES(new_total_parents > 0, "expected some new parents");

    // Convert children [idx, ...) to parent-curve scalars until hitting end_idx or a parent chunk
    // boundary. Consolidates the previously duplicated do-while loops for the two cases below.
    const auto collect_chunk_children = [&](std::size_t idx, const std::size_t end_idx)
    {
        do
        {
            CHECK_AND_ASSERT_THROW_MES(child_layer.size() > idx, "idx too high");
            const auto &child_point = child_layer[idx];
            auto child_scalar = c_child.point_to_cycle_scalar(child_point);
            children_to_trim_out.push_back(std::move(child_scalar));
            ++idx;
        }
        while ((idx < end_idx) && (idx % parent_chunk_width != 0));
    };

    if (trim_instructions.need_last_chunk_children_to_trim)
    {
        // Children being trimmed: from the new end of the layer up to the old end / chunk boundary
        const std::size_t start_idx = ((new_total_parents - 1) * parent_chunk_width) + new_offset;
        MDEBUG("Start trim from idx: " << start_idx);
        collect_chunk_children(start_idx, old_total_children);
    }
    else if (trim_instructions.need_last_chunk_remaining_children && new_offset > 0)
    {
        // Children remaining in the new last chunk: from the chunk start up to the new end
        const std::size_t start_idx = new_total_children - new_offset;
        MDEBUG("Start grow remaining from idx: " << start_idx);
        collect_chunk_children(start_idx, new_total_children);
    }

    return children_to_trim_out;
}
//----------------------------------------------------------------------------------------------------------------------
// TODO: template
// Gathers, for every layer, the children needed to recompute that layer's new last parent hash when
// trimming: the leaf layer is handled inline (leaf tuples expand to 3 Selene scalars each), and every
// layer above delegates to get_last_chunk_children_to_trim. Layers alternate curves starting with the
// leaf layer, whose parents are c2.
CurveTreesV1::LastChunkChildrenToTrim CurveTreesGlobalTree::get_all_last_chunk_children_to_trim(
    const std::vector<fcmp::curve_trees::TrimLayerInstructions> &trim_instructions)
{
    CurveTreesV1::LastChunkChildrenToTrim all_children_to_trim;

    // Leaf layer
    CHECK_AND_ASSERT_THROW_MES(!trim_instructions.empty(), "no instructions");
    const auto &trim_leaf_layer_instructions = trim_instructions[0];

    const std::size_t new_total_children = trim_leaf_layer_instructions.new_total_children;
    const std::size_t old_total_children = trim_leaf_layer_instructions.old_total_children;
    const std::size_t new_total_parents  = trim_leaf_layer_instructions.new_total_parents;
    const std::size_t parent_chunk_width = trim_leaf_layer_instructions.parent_chunk_width;
    const std::size_t new_offset         = trim_leaf_layer_instructions.new_offset;

    CHECK_AND_ASSERT_THROW_MES(new_total_children >= CurveTreesV1::LEAF_TUPLE_SIZE, "expected some new leaves");
    CHECK_AND_ASSERT_THROW_MES(new_total_children >= new_offset, "expected more children than offset");
    CHECK_AND_ASSERT_THROW_MES(new_total_parents > 0, "expected some new parents");

    std::vector<Selene::Scalar> leaves_to_trim;

    // Collect leaf tuple members starting at leaf-member index idx until hitting end_idx or a parent
    // chunk boundary. Consolidates the previously duplicated do-while loops for the two cases below.
    const auto collect_leaves = [&](std::size_t idx, const std::size_t end_idx)
    {
        do
        {
            CHECK_AND_ASSERT_THROW_MES(idx % CurveTreesV1::LEAF_TUPLE_SIZE == 0, "expected divisible by leaf tuple size");
            const std::size_t leaf_tuple_idx = idx / CurveTreesV1::LEAF_TUPLE_SIZE;
            CHECK_AND_ASSERT_THROW_MES(m_tree.leaves.size() > leaf_tuple_idx, "leaf_tuple_idx too high");
            const auto &leaf_tuple = m_tree.leaves[leaf_tuple_idx];
            leaves_to_trim.push_back(leaf_tuple.O_x);
            leaves_to_trim.push_back(leaf_tuple.I_x);
            leaves_to_trim.push_back(leaf_tuple.C_x);
            idx += CurveTreesV1::LEAF_TUPLE_SIZE;
        }
        while ((idx < end_idx) && (idx % parent_chunk_width != 0));
    };

    // TODO: calculate starting indexes in trim instructions, perhaps calculate end indexes also
    if (trim_leaf_layer_instructions.need_last_chunk_children_to_trim)
    {
        // Leaves being trimmed from the new last chunk (for hash_trim)
        const std::size_t start_idx = ((new_total_parents - 1) * parent_chunk_width) + new_offset;
        MDEBUG("Start trim from idx: " << start_idx);
        collect_leaves(start_idx, old_total_children);
    }
    else if (trim_leaf_layer_instructions.need_last_chunk_remaining_children && new_offset > 0)
    {
        // Leaves remaining in the new last chunk (for hash_grow)
        collect_leaves(new_total_children - new_offset, new_total_children);
    }

    all_children_to_trim.c2_children.emplace_back(std::move(leaves_to_trim));

    // Layers above the leaf layer: children alternate between the two curves, starting with c2 points
    // (the leaf layer's parents), whose own parents are c1
    bool parent_is_c2 = false;
    std::size_t c1_idx = 0;
    std::size_t c2_idx = 0;
    for (std::size_t i = 1; i < trim_instructions.size(); ++i)
    {
        const auto &trim_layer_instructions = trim_instructions[i];

        if (parent_is_c2)
        {
            CHECK_AND_ASSERT_THROW_MES(m_tree.c1_layers.size() > c1_idx, "c1_idx too high");

            auto children_to_trim = get_last_chunk_children_to_trim<Helios, Selene>(
                m_curve_trees.m_c1,
                trim_layer_instructions,
                m_tree.c1_layers[c1_idx]);

            all_children_to_trim.c2_children.emplace_back(std::move(children_to_trim));
            ++c1_idx;
        }
        else
        {
            CHECK_AND_ASSERT_THROW_MES(m_tree.c2_layers.size() > c2_idx, "c2_idx too high");

            auto children_to_trim = get_last_chunk_children_to_trim<Selene, Helios>(
                m_curve_trees.m_c2,
                trim_layer_instructions,
                m_tree.c2_layers[c2_idx]);

            all_children_to_trim.c1_children.emplace_back(std::move(children_to_trim));
            ++c2_idx;
        }

        parent_is_c2 = !parent_is_c2;
    }

    return all_children_to_trim;
}
//----------------------------------------------------------------------------------------------------------------------
// Reads, for every layer, the hash that becomes the layer's last parent after trimming
// (element new_total_parents - 1). Layers alternate curves; the leaf parent layer (index 0) is c2,
// so even layer indexes read from c2 and odd indexes from c1.
CurveTreesV1::LastHashes CurveTreesGlobalTree::get_last_hashes_to_trim(
    const std::vector<fcmp::curve_trees::TrimLayerInstructions> &trim_instructions) const
{
    CurveTreesV1::LastHashes last_hashes;
    CHECK_AND_ASSERT_THROW_MES(!trim_instructions.empty(), "no instructions");

    std::size_t c1_idx = 0;
    std::size_t c2_idx = 0;
    for (std::size_t layer_i = 0; layer_i < trim_instructions.size(); ++layer_i)
    {
        const std::size_t new_total_parents = trim_instructions[layer_i].new_total_parents;
        CHECK_AND_ASSERT_THROW_MES(new_total_parents > 0, "no new parents");

        if ((layer_i % 2) == 0) // even layers are c2
        {
            CHECK_AND_ASSERT_THROW_MES(m_tree.c2_layers.size() > c2_idx, "c2_idx too high");
            const auto &layer = m_tree.c2_layers[c2_idx];

            CHECK_AND_ASSERT_THROW_MES(layer.size() >= new_total_parents, "not enough c2 parents");
            last_hashes.c2_last_hashes.push_back(layer[new_total_parents - 1]);
            ++c2_idx;
        }
        else // odd layers are c1
        {
            CHECK_AND_ASSERT_THROW_MES(m_tree.c1_layers.size() > c1_idx, "c1_idx too high");
            const auto &layer = m_tree.c1_layers[c1_idx];

            CHECK_AND_ASSERT_THROW_MES(layer.size() >= new_total_parents, "not enough c1 parents");
            last_hashes.c1_last_hashes.push_back(layer[new_total_parents - 1]);
            ++c1_idx;
        }
    }

    return last_hashes;
}
//----------------------------------------------------------------------------------------------------------------------
// Trims `trim_n_leaf_tuples` leaf tuples off the back of the in-memory tree: compute per-layer trim
// instructions, read the children/hashes those instructions require, derive the tree reduction, apply
// it, and sanity-check the resulting leaf count.
void CurveTreesGlobalTree::trim_tree(const std::size_t trim_n_leaf_tuples)
{
    const std::size_t n_leaf_tuples_before = this->get_num_leaf_tuples();
    MDEBUG(n_leaf_tuples_before << " leaves in the tree, trimming " << trim_n_leaf_tuples);

    // Work out exactly what each layer needs for the trim
    const auto trim_instructions = m_curve_trees.get_trim_instructions(n_leaf_tuples_before, trim_n_leaf_tuples);
    MDEBUG("Acquired trim instructions for " << trim_instructions.size() << " layers");

    // Read everything the reduction needs out of the current tree up front
    const auto children_to_trim = this->get_all_last_chunk_children_to_trim(trim_instructions);
    const auto last_hashes      = this->get_last_hashes_to_trim(trim_instructions);

    // Compute the new layer hashes, wrapped in a simple struct we can use to trim the tree
    const auto tree_reduction = m_curve_trees.get_tree_reduction(
        trim_instructions,
        children_to_trim,
        last_hashes);

    // Apply the reduction to the in-memory tree
    this->reduce_tree(tree_reduction);

    // Exactly trim_n_leaf_tuples must have been removed
    const std::size_t n_leaf_tuples_after = this->get_num_leaf_tuples();
    CHECK_AND_ASSERT_THROW_MES((n_leaf_tuples_after + trim_n_leaf_tuples) == n_leaf_tuples_before,
        "unexpected num leaves after trim");
}
//----------------------------------------------------------------------------------------------------------------------
bool CurveTreesGlobalTree::audit_tree(const std::size_t expected_n_leaf_tuples)
{
MDEBUG("Auditing global tree");
@ -752,6 +543,8 @@ bool CurveTreesGlobalTree::audit_tree()
const auto &c2_layers = m_tree.c2_layers;
CHECK_AND_ASSERT_MES(!leaves.empty(), false, "must have at least 1 leaf in tree");
CHECK_AND_ASSERT_MES(leaves.size() == expected_n_leaf_tuples, false, "unexpected num leaves");
CHECK_AND_ASSERT_MES(!c2_layers.empty(), false, "must have at least 1 c2 layer in tree");
CHECK_AND_ASSERT_MES(c2_layers.size() == c1_layers.size() || c2_layers.size() == (c1_layers.size() + 1),
false, "unexpected mismatch of c2 and c1 layers");
@ -983,7 +776,7 @@ void CurveTreesGlobalTree::log_tree()
//----------------------------------------------------------------------------------------------------------------------
// Test helpers
//----------------------------------------------------------------------------------------------------------------------
const std::vector<CurveTreesV1::LeafTuple> generate_random_leaves(const CurveTreesV1 &curve_trees,
static const std::vector<CurveTreesV1::LeafTuple> generate_random_leaves(const CurveTreesV1 &curve_trees,
const std::size_t num_leaves)
{
std::vector<CurveTreesV1::LeafTuple> tuples;
@ -1005,9 +798,18 @@ const std::vector<CurveTreesV1::LeafTuple> generate_random_leaves(const CurveTre
return tuples;
}
//----------------------------------------------------------------------------------------------------------------------
// Produce a random Selene scalar by generating a fresh ed25519 keypair and
// converting the public point via the tower-cycle helper.
static const Selene::Scalar generate_random_selene_scalar()
{
    crypto::secret_key sec;
    crypto::public_key pub;

    // false: generate a brand-new key rather than recovering from `sec`
    crypto::generate_keys(pub, sec, sec, false);

    return fcmp::tower_cycle::ed_25519_point_to_scalar(pub);
}
//----------------------------------------------------------------------------------------------------------------------
static bool grow_tree(CurveTreesV1 &curve_trees,
CurveTreesGlobalTree &global_tree,
const std::size_t num_leaves)
const std::size_t new_n_leaf_tuples)
{
// Do initial tree reads
const std::size_t old_n_leaf_tuples = global_tree.get_num_leaf_tuples();
@ -1019,7 +821,7 @@ static bool grow_tree(CurveTreesV1 &curve_trees,
// - The tree extension includes all elements we'll need to add to the existing tree when adding the new leaves
const auto tree_extension = curve_trees.get_tree_extension(old_n_leaf_tuples,
last_hashes,
generate_random_leaves(curve_trees, num_leaves));
generate_random_leaves(curve_trees, new_n_leaf_tuples));
global_tree.log_tree_extension(tree_extension);
@ -1029,7 +831,8 @@ static bool grow_tree(CurveTreesV1 &curve_trees,
global_tree.log_tree();
// Validate tree structure and all hashes
return global_tree.audit_tree();
const std::size_t expected_n_leaf_tuples = old_n_leaf_tuples + new_n_leaf_tuples;
return global_tree.audit_tree(expected_n_leaf_tuples);
}
//----------------------------------------------------------------------------------------------------------------------
static bool grow_tree_in_memory(const std::size_t init_leaves,
@ -1059,25 +862,27 @@ static bool grow_tree_in_memory(const std::size_t init_leaves,
return true;
}
//----------------------------------------------------------------------------------------------------------------------
static bool trim_tree_in_memory(const std::size_t init_leaves,
const std::size_t trim_leaves,
static bool trim_tree_in_memory(const std::size_t trim_n_leaf_tuples,
CurveTreesGlobalTree &&global_tree)
{
// Trim the global tree by `trim_leaves`
LOG_PRINT_L1("Trimming " << trim_leaves << " leaves from tree");
const std::size_t old_n_leaf_tuples = global_tree.get_num_leaf_tuples();
CHECK_AND_ASSERT_THROW_MES(old_n_leaf_tuples > trim_n_leaf_tuples, "cannot trim more leaves than exist");
CHECK_AND_ASSERT_THROW_MES(trim_n_leaf_tuples > 0, "must be trimming some leaves");
CHECK_AND_ASSERT_MES(init_leaves > trim_leaves, false, "trimming too many leaves");
const std::size_t new_num_leaves = init_leaves - trim_leaves;
global_tree.trim_tree(new_num_leaves * CurveTreesV1::LEAF_TUPLE_SIZE);
// Trim the global tree by `trim_n_leaf_tuples`
LOG_PRINT_L1("Trimming " << trim_n_leaf_tuples << " leaf tuples from tree");
MDEBUG("Finished trimming " << trim_leaves << " leaves from tree");
global_tree.trim_tree(trim_n_leaf_tuples);
MDEBUG("Finished trimming " << trim_n_leaf_tuples << " leaf tuples from tree");
global_tree.log_tree();
bool res = global_tree.audit_tree();
const std::size_t expected_n_leaf_tuples = old_n_leaf_tuples - trim_n_leaf_tuples;
bool res = global_tree.audit_tree(expected_n_leaf_tuples);
CHECK_AND_ASSERT_MES(res, false, "failed to trim tree in memory");
MDEBUG("Successfully trimmed " << trim_leaves << " leaves in memory");
MDEBUG("Successfully trimmed " << trim_n_leaf_tuples << " leaves in memory");
return true;
}
//----------------------------------------------------------------------------------------------------------------------
@ -1116,12 +921,9 @@ TEST(curve_trees, grow_tree)
Helios helios;
Selene selene;
// Constant for how deep we want the tree
const std::size_t TEST_N_LAYERS = 4;
// Use lower values for chunk width than prod so that we can quickly test a many-layer deep tree
const std::size_t helios_chunk_width = 3;
const std::size_t selene_chunk_width = 2;
static const std::size_t helios_chunk_width = 3;
static const std::size_t selene_chunk_width = 2;
static_assert(helios_chunk_width > 1, "helios width must be > 1");
static_assert(selene_chunk_width > 1, "selene width must be > 1");
@ -1129,6 +931,9 @@ TEST(curve_trees, grow_tree)
LOG_PRINT_L1("Test grow tree with helios chunk width " << helios_chunk_width
<< ", selene chunk width " << selene_chunk_width);
// Constant for how deep we want the tree
static const std::size_t TEST_N_LAYERS = 4;
// Number of leaves for which x number of layers is required
std::size_t leaves_needed_for_n_layers = selene_chunk_width;
for (std::size_t i = 1; i < TEST_N_LAYERS; ++i)
@ -1153,7 +958,7 @@ TEST(curve_trees, grow_tree)
{
// TODO: init tree once, then extend a copy of that tree
// Then extend the tree with ext_leaves
for (std::size_t ext_leaves = 1; (init_leaves + ext_leaves) < leaves_needed_for_n_layers; ++ext_leaves)
for (std::size_t ext_leaves = 1; (init_leaves + ext_leaves) <= leaves_needed_for_n_layers; ++ext_leaves)
{
ASSERT_TRUE(grow_tree_in_memory(init_leaves, ext_leaves, curve_trees));
ASSERT_TRUE(grow_tree_db(init_leaves, ext_leaves, curve_trees, test_db));