/* * phase1.hpp * * Created on: May 26, 2021 * Author: mad */ #ifndef INCLUDE_CHIA_PHASE1_HPP_ #define INCLUDE_CHIA_PHASE1_HPP_ #include #include #include #include #include "b3/blake3.h" #include "chacha8.h" namespace phase1 { static uint16_t L_targets[2][kBC][kExtraBitsPow]; static void load_tables() { for (uint8_t parity = 0; parity < 2; parity++) { for (uint16_t i = 0; i < kBC; i++) { uint16_t indJ = i / kC; for (uint16_t m = 0; m < kExtraBitsPow; m++) { uint16_t yr = ((indJ + m) % kB) * kC + (((2 * m + parity) * (2 * m + parity) + i) % kC); L_targets[parity][i][m] = yr; } } } } static void initialize() { load_tables(); } class F1Calculator { public: F1Calculator(const uint8_t* orig_key) { uint8_t enc_key[32] = {}; // First byte is 1, the index of this table enc_key[0] = 1; memcpy(enc_key + 1, orig_key, 31); // Setup ChaCha8 context with zero-filled IV chacha8_keysetup(&enc_ctx_, enc_key, 256, NULL); } /* * x = [index * 16 .. index * 16 + 15] * block = entry_1[16] */ void compute_block(const uint64_t index, entry_1* block) { uint8_t buf[64]; chacha8_get_keystream(&enc_ctx_, index, 1, buf); for(uint64_t i = 0; i < 16; ++i) { const uint64_t x = index * 16 + i; const uint64_t y = Util::SliceInt64FromBytes(buf, i * 32, 32); block[i].x = x; block[i].y = (y << kExtraBits) | (x >> (32 - kExtraBits)); } } private: chacha8_ctx enc_ctx_ {}; }; // Class to evaluate F2 .. F7. template class FxCalculator { public: static constexpr uint8_t k_ = 32; FxCalculator(int table_index) { table_index_ = table_index; } // Disable copying FxCalculator(const FxCalculator&) = delete; // Performs one evaluation of the f function. void evaluate(const T& L, const T& R, S& entry) const { Bits C; Bits input; uint8_t input_bytes[64]; uint8_t hash_bytes[32]; uint8_t L_meta[16]; uint8_t R_meta[16]; size_t L_meta_bytes = 0; size_t R_meta_bytes = 0; get_meta{}(L, L_meta, &L_meta_bytes); get_meta{}(R, R_meta, &R_meta_bytes); const Bits Y_1(L.y, k_ + kExtraBits); const Bits L_c(L_meta, L_meta_bytes, L_meta_bytes * 8); const Bits R_c(R_meta, R_meta_bytes, R_meta_bytes * 8); if (table_index_ < 4) { C = L_c + R_c; input = Y_1 + C; } else { input = Y_1 + L_c + R_c; } input.ToBytes(input_bytes); blake3_hasher hasher; blake3_hasher_init(&hasher); blake3_hasher_update(&hasher, input_bytes, cdiv(input.GetSize(), 8)); blake3_hasher_finalize(&hasher, hash_bytes, sizeof(hash_bytes)); entry.y = Util::EightBytesToInt(hash_bytes) >> (64 - (k_ + (table_index_ < 7 ? kExtraBits : 0))); if (table_index_ < 4) { // c is already computed } else if (table_index_ < 7) { uint8_t len = kVectorLens[table_index_ + 1]; uint8_t start_byte = (k_ + kExtraBits) / 8; uint8_t end_bit = k_ + kExtraBits + k_ * len; uint8_t end_byte = cdiv(end_bit, 8); // TODO: proper support for partial bytes in Bits ctor C = Bits(hash_bytes + start_byte, end_byte - start_byte, (end_byte - start_byte) * 8); C = C.Slice((k_ + kExtraBits) % 8, end_bit - start_byte * 8); } uint8_t C_bytes[16]; C.ToBytes(C_bytes); set_meta{}(entry, C_bytes, C.GetSize() / 8); } private: int table_index_ = 0; }; template class FxMatcher { public: struct rmap_item { uint16_t pos; uint16_t count; }; FxMatcher() { rmap.resize(kBC); } // Disable copying FxMatcher(const FxMatcher&) = delete; // Given two buckets with entries (y values), computes which y values match, and returns a list // of the pairs of indices into bucket_L and bucket_R. Indices l and r match iff: // let yl = bucket_L[l].y, yr = bucket_R[r].y // // For any 0 <= m < kExtraBitsPow: // yl / kBC + 1 = yR / kBC AND // (yr % kBC) / kC - (yl % kBC) / kC = m (mod kB) AND // (yr % kBC) % kC - (yl % kBC) % kC = (2m + (yl/kBC) % 2)^2 (mod kC) // // Instead of doing the naive algorithm, which is an O(kExtraBitsPow * N^2) comparisons on // bucket length, we can store all the R values and lookup each of our 32 candidates to see if // any R value matches. This function can be further optimized by removing the inner loop, and // being more careful with memory allocation. int find_matches_ex( const std::vector& bucket_L, const std::vector& bucket_R, uint16_t* idx_L, uint16_t* idx_R) { if(bucket_L.empty() || bucket_R.empty()) { return 0; } const uint16_t parity = (bucket_L[0].y / kBC) % 2; for (auto yl : rmap_clean) { rmap[yl].count = 0; } rmap_clean.clear(); const uint64_t offset = (bucket_R[0].y / kBC) * kBC; for (size_t pos_R = 0; pos_R < bucket_R.size(); pos_R++) { const uint64_t r_y = bucket_R[pos_R].y - offset; if (!rmap[r_y].count) { rmap[r_y].pos = pos_R; } rmap[r_y].count++; rmap_clean.push_back(r_y); } int idx_count = 0; const uint64_t offset_y = offset - kBC; for (size_t pos_L = 0; pos_L < bucket_L.size(); pos_L++) { const uint64_t r = bucket_L[pos_L].y - offset_y; for (int i = 0; i < kExtraBitsPow; i++) { const uint16_t r_target = L_targets[parity][r][i]; for (size_t j = 0; j < rmap[r_target].count; j++) { idx_L[idx_count] = pos_L; idx_R[idx_count] = rmap[r_target].pos + j; idx_count++; } } } return idx_count; } int find_matches( const uint64_t& L_pos_begin, const std::vector& bucket_L, const std::vector& bucket_R, std::vector>& out) { uint16_t idx_L[kBC]; uint16_t idx_R[kBC]; const int count = find_matches_ex(bucket_L, bucket_R, idx_L, idx_R); for(int i = 0; i < count; ++i) { const auto pos = L_pos_begin + idx_L[i]; if(pos < (uint64_t(1) << 32)) { match_t match; match.left = bucket_L[idx_L[i]]; match.right = bucket_R[idx_R[i]]; match.pos = pos; match.off = idx_R[i] + (bucket_L.size() - idx_L[i]); out.push_back(match); } } return count; } private: std::vector rmap; std::vector rmap_clean; }; /* * id = 32 bytes */ template void compute_f1(const uint8_t* id, int num_threads, DS* T1_sort) { static constexpr size_t M = 4096; // F1 block size const auto begin = get_wall_time_micros(); typedef typename DS::WriteCache WriteCache; ThreadPool, size_t, std::shared_ptr> output( [T1_sort](std::vector& input, size_t&, std::shared_ptr& cache) { if(!cache) { cache = T1_sort->add_cache(); } for(auto& entry : input) { cache->add(entry); } }, nullptr, std::max(num_threads / 2, 1), "phase1/add"); ThreadPool> pool( [id](uint64_t& block, std::vector& out, size_t&) { out.resize(M * 16); F1Calculator F1(id); for(size_t i = 0; i < M; ++i) { F1.compute_block(block * M + i, &out[i * 16]); } }, &output, num_threads, "phase1/F1"); for(uint64_t k = 0; k < (uint64_t(1) << 28) / M; ++k) { pool.take_copy(k); } pool.close(); output.close(); T1_sort->finish(); std::cout << "[P1] Table 1 took " << (get_wall_time_micros() - begin) / 1e6 << " sec" << std::endl; } template uint64_t compute_matches( int R_index, int num_threads, DS_L* L_sort, DS_R* R_sort, Processor>* L_tmp_out, Processor>* R_tmp_out) { std::atomic num_found {}; std::atomic num_written {}; std::array L_index = {}; std::array L_offset = {}; std::array>, 2> L_bucket; double avg_bucket_size = 0; struct match_input_t { std::array L_offset = {}; std::array>, 2> L_bucket; }; typedef typename DS_R::WriteCache WriteCache; ThreadPool, size_t, std::shared_ptr> R_add( [R_sort](std::vector& input, size_t&, std::shared_ptr& cache) { if(!cache) { cache = R_sort->add_cache(); } for(auto& entry : input) { cache->add(entry); } }, nullptr, std::max(num_threads / 2, 1), "phase1/add"); Processor>* R_out = &R_add; if(R_tmp_out) { R_out = R_tmp_out; } ThreadPool>, std::vector> eval_pool( [R_index](std::vector>& matches, std::vector& out, size_t&) { out.reserve(matches.size()); FxCalculator Fx(R_index); for(const auto& match : matches) { S entry; entry.pos = match.pos; entry.off = match.off; Fx.evaluate(match.left, match.right, entry); out.push_back(entry); } }, R_out, num_threads, "phase1/eval"); ThreadPool, std::vector>, FxMatcher> match_pool( [&num_found, &num_written] (std::vector& input, std::vector>& out, FxMatcher& Fx) { out.reserve(64 * 1024); for(const auto& pair : input) { num_found += Fx.find_matches(pair.L_offset[1], *pair.L_bucket[1], *pair.L_bucket[0], out); } num_written += out.size(); }, &eval_pool, num_threads, "phase1/match"); Thread> read_thread( [&L_index, &L_offset, &L_bucket, &avg_bucket_size, &match_pool, L_tmp_out] (std::vector& input) { std::vector out; out.reserve(1024); for(const auto& entry : input) { const uint64_t index = entry.y / kBC; if(index < L_index[0]) { throw std::logic_error("input not sorted"); } if(index > L_index[0]) { if(L_index[1] + 1 == L_index[0]) { match_input_t pair; pair.L_offset = L_offset; pair.L_bucket = L_bucket; out.push_back(pair); } L_index[1] = L_index[0]; L_index[0] = index; L_offset[1] = L_offset[0]; if(auto bucket = L_bucket[0]) { L_offset[0] += bucket->size(); avg_bucket_size = avg_bucket_size * 0.99 + bucket->size() * 0.01; } L_bucket[1] = L_bucket[0]; L_bucket[0] = nullptr; } if(!L_bucket[0]) { L_bucket[0] = std::make_shared>(); L_bucket[0]->reserve(avg_bucket_size * 1.2); } L_bucket[0]->push_back(entry); } match_pool.take(out); if(L_tmp_out) { L_tmp_out->take(input); } }, "phase1/slice"); L_sort->read(&read_thread, std::max(num_threads / 2, 2), std::max(num_threads / 4, 2)); read_thread.close(); match_pool.close(); if(L_index[1] + 1 == L_index[0]) { FxMatcher Fx; std::vector> matches; num_found += Fx.find_matches(L_offset[1], *L_bucket[1], *L_bucket[0], matches); num_written += matches.size(); eval_pool.take(matches); } eval_pool.close(); R_add.close(); if(R_sort) { R_sort->finish(); } if(num_written < num_found) { std::cout << "[P1] Lost " << num_found - num_written << " matches due to 32-bit overflow." << std::endl; } return num_written; } template uint64_t compute_table( int R_index, int num_threads, DS_L* L_sort, DS_R* R_sort, DiskTable* L_tmp, DiskTable* R_tmp = nullptr) { Thread> L_write( [L_tmp](std::vector& input) { for(const auto& entry : input) { R tmp; tmp.assign(entry); L_tmp->write(tmp); } }, "phase1/write/L"); Thread> R_write( [R_tmp](std::vector& input) { for(const auto& entry : input) { R_tmp->write(entry); } }, "phase1/write/R"); const auto begin = get_wall_time_micros(); const auto num_matches = phase1::compute_matches( R_index, num_threads, L_sort, R_sort, L_tmp ? &L_write : nullptr, R_tmp ? &R_write : nullptr); L_write.close(); R_write.close(); if(L_tmp) { L_tmp->close(); } if(R_tmp) { R_tmp->close(); } std::cout << "[P1] Table " << R_index << " took " << (get_wall_time_micros() - begin) / 1e6 << " sec" << ", found " << num_matches << " matches" << std::endl; return num_matches; } inline void compute( const input_t& input, output_t& out, const int num_threads, const int log_num_buckets, const std::string plot_name, const std::string tmp_dir, const std::string tmp_dir_2) { const auto total_begin = get_wall_time_micros(); initialize(); const std::string prefix = tmp_dir + plot_name + ".p1."; const std::string prefix_2 = tmp_dir_2 + plot_name + ".p1."; DiskSort1 sort_1(32 + kExtraBits, log_num_buckets, prefix_2 + "t1"); compute_f1(input.id.data(), num_threads, &sort_1); DiskTable tmp_1(prefix + "table1.tmp"); DiskSort2 sort_2(32 + kExtraBits, log_num_buckets, prefix_2 + "t2"); compute_table( 2, num_threads, &sort_1, &sort_2, &tmp_1); DiskTable tmp_2(prefix + "table2.tmp"); DiskSort3 sort_3(32 + kExtraBits, log_num_buckets, prefix_2 + "t3"); compute_table( 3, num_threads, &sort_2, &sort_3, &tmp_2); DiskTable tmp_3(prefix + "table3.tmp"); DiskSort4 sort_4(32 + kExtraBits, log_num_buckets, prefix_2 + "t4"); compute_table( 4, num_threads, &sort_3, &sort_4, &tmp_3); DiskTable tmp_4(prefix + "table4.tmp"); DiskSort5 sort_5(32 + kExtraBits, log_num_buckets, prefix_2 + "t5"); compute_table( 5, num_threads, &sort_4, &sort_5, &tmp_4); DiskTable tmp_5(prefix + "table5.tmp"); DiskSort6 sort_6(32 + kExtraBits, log_num_buckets, prefix_2 + "t6"); compute_table( 6, num_threads, &sort_5, &sort_6, &tmp_5); DiskTable tmp_6(prefix + "table6.tmp"); DiskTable tmp_7(prefix_2 + "table7.tmp"); compute_table( 7, num_threads, &sort_6, nullptr, &tmp_6, &tmp_7); out.params = input; out.table[0] = tmp_1.get_info(); out.table[1] = tmp_2.get_info(); out.table[2] = tmp_3.get_info(); out.table[3] = tmp_4.get_info(); out.table[4] = tmp_5.get_info(); out.table[5] = tmp_6.get_info(); out.table[6] = tmp_7.get_info(); std::cout << "Phase 1 took " << (get_wall_time_micros() - total_begin) / 1e6 << " sec" << std::endl; } } // phase1 #endif /* INCLUDE_CHIA_PHASE1_HPP_ */