Program Listing for File context.cu
↰ Return to documentation for file (src/lib/host/ckks/context.cu)
// Copyright 2024-2026 Alişah Özcan
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
// Developer: Alişah Özcan
#include <heongpu/host/ckks/context.cuh>
namespace heongpu
{
// Creates an unconfigured CKKS context for the requested security level.
//
// Stores the scheme tag (ckks) and `sec_level`, and sets the key-switching
// method to NONE; the method is selected later by the coeff_modulus setters
// based on how many P primes are supplied.
//
// Throws std::logic_error if `coeff_modulus_specified_` is already true when
// the constructor body runs.
HEContextImpl<Scheme::CKKS>::HEContextImpl(const sec_level_type sec_level)
{
    if (coeff_modulus_specified_)
    {
        throw std::logic_error("Parameters cannot be changed after the "
                               "coeff_modulus is specified!");
    }

    scheme_ = scheme_type::ckks;
    sec_level_ = sec_level;
    keyswitching_type_ = keyswitching_type::NONE;
}
// Fixes the ring dimension `n` of the context.
//
// poly_modulus_degree: must be a power of two inside
//                      [MIN_POLY_DEGREE, MAX_POLY_DEGREE].
//
// On success stores `n`, its base-2 logarithm in `n_power`, and marks the
// degree as specified. May be called only once, and only before the
// coefficient modulus has been set.
//
// Throws std::logic_error when the degree was already set, the coefficient
// modulus was already set, the value is not a power of two, or the value is
// outside the supported range.
void HEContextImpl<Scheme::CKKS>::set_poly_modulus_degree(
    size_t poly_modulus_degree)
{
    const bool already_fixed =
        coeff_modulus_specified_ || poly_modulus_degree_specified_;
    if (already_fixed)
    {
        throw std::logic_error("Poly modulus degree cannot be changed "
                               "after the coeff_modulus is specified!");
    }

    if (!is_power_of_two(poly_modulus_degree))
    {
        throw std::logic_error(
            "Poly modulus degree have to be power of two");
    }

    const bool within_supported_range =
        (poly_modulus_degree <= MAX_POLY_DEGREE) &&
        (poly_modulus_degree >= MIN_POLY_DEGREE);
    if (!within_supported_range)
    {
        throw std::logic_error("Poly modulus degree is not supported");
    }

    n = poly_modulus_degree;
    // n is a power of two at this point, so the logarithm is integral.
    n_power = int(log2l(n));
    poly_modulus_degree_specified_ = true;
}
// Configures the coefficient modulus from per-prime bit sizes; the primes
// themselves are produced by generate_primes().
//
// log_Q_bases_bit_sizes: bit size of every prime in the Q base.
// log_P_bases_bit_sizes: bit size of every prime in the P base; must not be
//                        empty. One P prime selects KEYSWITCHING_METHOD_I,
//                        more than one selects KEYSWITCHING_METHOD_II.
//
// Preconditions: the poly modulus degree has been set, and neither the
// coefficient modulus nor the context has been set/generated yet.
//
// Throws std::logic_error if a precondition is violated, if the P list is
// empty, or if coefficient_validator() rejects the sizes.
// Throws std::runtime_error if the security level is unknown or the summed
// bit count exceeds the limit returned for the chosen security level.
void HEContextImpl<Scheme::CKKS>::set_coeff_modulus_bit_sizes(
    const std::vector<int>& log_Q_bases_bit_sizes,
    const std::vector<int>& log_P_bases_bit_sizes)
{
    if (coeff_modulus_specified_ || context_generated_)
    {
        throw std::logic_error("Coeff_modulus cannot be changed after the "
                               "context is generated!");
    }
    // Reported separately so that calling the setters in the wrong order is
    // not misdiagnosed as "context already generated".
    if (!poly_modulus_degree_specified_)
    {
        throw std::logic_error("Poly modulus degree must be specified "
                               "before the coeff_modulus!");
    }

    if (log_P_bases_bit_sizes.empty())
    {
        throw std::logic_error(
            "log_P_bases_bit_sizes cannot be empty!");
    }

    // NOTE(review): the key-switching type is assigned before validation, as
    // in the original code; coefficient_validator() may read it - confirm
    // before reordering.
    keyswitching_type_ =
        (log_P_bases_bit_sizes.size() == 1)
            ? keyswitching_type::KEYSWITCHING_METHOD_I
            : keyswitching_type::KEYSWITCHING_METHOD_II;

    if (!coefficient_validator(log_Q_bases_bit_sizes,
                               log_P_bases_bit_sizes))
    {
        throw std::logic_error("P should be bigger than Q pairs!");
    }

    // Q' = Q x P
    Qprime_mod_bit_sizes_ = log_Q_bases_bit_sizes;
    Qprime_mod_bit_sizes_.insert(Qprime_mod_bit_sizes_.end(),
                                 log_P_bases_bit_sizes.begin(),
                                 log_P_bases_bit_sizes.end());

    Q_mod_bit_sizes_ = log_Q_bases_bit_sizes;
    P_mod_bit_sizes_ = log_P_bases_bit_sizes;

    // Sum of the nominal bit sizes of all Q' primes.
    total_coeff_bit_count = 0; // TODO: calculate it with prod
    for (const int bit_size : Qprime_mod_bit_sizes_)
    {
        total_coeff_bit_count += bit_size;
    }

    // Upper bound on the total bit count for the selected security level.
    int max_coeff_bit_count = 0;
    switch (sec_level_)
    {
        case sec_level_type::none:
            break;
        case sec_level_type::sec128:
            max_coeff_bit_count = heongpu_128bit_std_parms(n);
            break;
        case sec_level_type::sec192:
            max_coeff_bit_count = heongpu_192bit_std_parms(n);
            break;
        case sec_level_type::sec256:
            max_coeff_bit_count = heongpu_256bit_std_parms(n);
            break;
        default:
            throw std::runtime_error("Invalid security level");
    }

    if ((max_coeff_bit_count < total_coeff_bit_count) &&
        (sec_level_ != sec_level_type::none))
    {
        throw std::runtime_error(
            "Parameters do not align with the security recommendations "
            "provided by the lattice-estimator");
    }

    // Q' bases size
    Q_prime_size = Qprime_mod_bit_sizes_.size();
    coeff_modulus = Q_prime_size; // not required

    // Q bases size
    Q_size = log_Q_bases_bit_sizes.size();

    // P bases size
    P_size = Q_prime_size - Q_size;

    prime_vector_ = generate_primes(n, Qprime_mod_bit_sizes_);

    // Mirror the raw prime values into base_q.
    for (const Modulus64& prime : prime_vector_)
    {
        base_q.push_back(prime.value);
    }

    coeff_modulus_specified_ = true;
}
// Configures the coefficient modulus from explicit modulus values.
//
// log_Q_bases: the moduli of the Q base (despite the name, these are the
//              values themselves; their bit sizes are derived here with
//              calculate_bit_size()).
// log_P_bases: the moduli of the P base; must not be empty. One P modulus
//              selects KEYSWITCHING_METHOD_I, more than one selects
//              KEYSWITCHING_METHOD_II.
//
// Preconditions: the poly modulus degree has been set, and neither the
// coefficient modulus nor the context has been set/generated yet.
//
// Throws std::logic_error if a precondition is violated, if the P list is
// empty, or if coefficient_validator() rejects the derived bit sizes.
// Throws std::runtime_error if the security level is unknown or the summed
// bit count exceeds the limit returned for the chosen security level.
void HEContextImpl<Scheme::CKKS>::set_coeff_modulus_values(
    const std::vector<Data64>& log_Q_bases,
    const std::vector<Data64>& log_P_bases)
{
    if (coeff_modulus_specified_ || context_generated_)
    {
        throw std::logic_error("coeff_modulus cannot be changed after the "
                               "context is generated!");
    }
    // Reported separately so that calling the setters in the wrong order is
    // not misdiagnosed as "context already generated".
    if (!poly_modulus_degree_specified_)
    {
        throw std::logic_error("Poly modulus degree must be specified "
                               "before the coeff_modulus!");
    }

    // Derive the per-modulus bit sizes used for validation and bookkeeping.
    std::vector<int> log_Q_bases_bit_sizes;
    log_Q_bases_bit_sizes.reserve(log_Q_bases.size());
    for (const Data64& q_value : log_Q_bases)
    {
        log_Q_bases_bit_sizes.push_back(calculate_bit_size(q_value));
    }

    std::vector<int> log_P_bases_bit_sizes;
    log_P_bases_bit_sizes.reserve(log_P_bases.size());
    for (const Data64& p_value : log_P_bases)
    {
        log_P_bases_bit_sizes.push_back(calculate_bit_size(p_value));
    }

    if (log_P_bases_bit_sizes.empty())
    {
        throw std::logic_error(
            "log_P_bases_bit_sizes cannot be empty!");
    }

    // NOTE(review): the key-switching type is assigned before validation, as
    // in the original code; coefficient_validator() may read it - confirm
    // before reordering.
    keyswitching_type_ =
        (log_P_bases_bit_sizes.size() == 1)
            ? keyswitching_type::KEYSWITCHING_METHOD_I
            : keyswitching_type::KEYSWITCHING_METHOD_II;

    if (!coefficient_validator(log_Q_bases_bit_sizes,
                               log_P_bases_bit_sizes))
    {
        throw std::logic_error(
            "Invalid parameters, P should be bigger than Q pairs!");
    }

    // Q' = Q x P
    Qprime_mod_bit_sizes_ = log_Q_bases_bit_sizes;
    Qprime_mod_bit_sizes_.insert(Qprime_mod_bit_sizes_.end(),
                                 log_P_bases_bit_sizes.begin(),
                                 log_P_bases_bit_sizes.end());

    Q_mod_bit_sizes_ = log_Q_bases_bit_sizes;
    P_mod_bit_sizes_ = log_P_bases_bit_sizes;

    // Sum of the bit sizes of all Q' moduli.
    total_coeff_bit_count = 0; // TODO: calculate it with prod
    for (const int bit_size : Qprime_mod_bit_sizes_)
    {
        total_coeff_bit_count += bit_size;
    }

    // Upper bound on the total bit count for the selected security level.
    int max_coeff_bit_count = 0;
    switch (sec_level_)
    {
        case sec_level_type::none:
            break;
        case sec_level_type::sec128:
            max_coeff_bit_count = heongpu_128bit_std_parms(n);
            break;
        case sec_level_type::sec192:
            max_coeff_bit_count = heongpu_192bit_std_parms(n);
            break;
        case sec_level_type::sec256:
            max_coeff_bit_count = heongpu_256bit_std_parms(n);
            break;
        default:
            throw std::runtime_error("Invalid security level");
    }

    if ((max_coeff_bit_count < total_coeff_bit_count) &&
        (sec_level_ != sec_level_type::none))
    {
        throw std::runtime_error(
            "Parameters do not align with the security recommendations "
            "provided by the lattice-estimator");
    }

    // Q' bases size
    Q_prime_size = Qprime_mod_bit_sizes_.size();
    coeff_modulus = Q_prime_size; // not required

    // Q bases size
    Q_size = log_Q_bases_bit_sizes.size();

    // P bases size
    P_size = Q_prime_size - Q_size;

    // Wrap every supplied value in a Modulus64: Q moduli first, then P.
    for (const Data64& q_value : log_Q_bases)
    {
        Modulus64 mod_in(q_value);
        prime_vector_.push_back(mod_in);
    }
    for (const Data64& p_value : log_P_bases)
    {
        Modulus64 mod_in(p_value);
        prime_vector_.push_back(mod_in);
    }

    // Mirror the raw modulus values into base_q.
    for (const Modulus64& prime : prime_vector_)
    {
        base_q.push_back(prime.value);
    }

    coeff_modulus_specified_ = true;
}
// Convenience overload: generates the context using the memory-pool
// configuration returned by MemoryPoolConfig::Defaults().
void HEContextImpl<Scheme::CKKS>::generate()
{
    const auto& default_pool_config = MemoryPoolConfig::Defaults();
    this->generate(default_pool_config);
}
// Builds every precomputed table of the CKKS context and uploads it to the
// device.
//
// pool_config: configuration handed to the MemoryPool singleton; its
//              `use_memory_pool` flag is forwarded to the pool as well.
//
// Steps, in order: memory pool setup, random number generator seeding
// (16 bytes of entropy and an 8-byte nonce from OpenSSL RAND_bytes), NTT/INTT
// tables, last-prime / half / factor tables, per-depth rescale tables,
// per-depth Mi / Mi_inv / upper_half_threshold / decryption_modulus tables,
// the leveled prime-location table and, for KEYSWITCHING_METHOD_II, the
// leveled key-switch parameters.
//
// Throws std::runtime_error if the context was already generated, if the
// poly modulus degree or coeff_modulus has not been specified, or if
// RAND_bytes fails.
// Throws std::invalid_argument if the key-switching type is neither method I
// nor method II.
void
HEContextImpl<Scheme::CKKS>::generate(const MemoryPoolConfig& pool_config)
{
    if (context_generated_)
    {
        throw std::runtime_error("Context is already generated!");
    }
    // Reported separately so that a missing setter call is not misdiagnosed
    // as "already generated".
    if (!poly_modulus_degree_specified_ || !coeff_modulus_specified_)
    {
        throw std::runtime_error(
            "Poly modulus degree and coeff_modulus must be specified "
            "before the context is generated!");
    }

    // Memory pool initialization
    // NOTE(review): the CUDA runtime return codes in this function are not
    // checked (same as before); consider routing them through the project's
    // CUDA error-check helper.
    MemoryPool::instance().initialize(pool_config);
    MemoryPool::instance().use_memory_pool(pool_config.use_memory_pool);
    cudaDeviceSynchronize();

    // DRNG initialization
    std::vector<unsigned char> generated_entropy(16); // for 128 bit
    if (1 !=
        RAND_bytes(generated_entropy.data(), generated_entropy.size()))
    {
        throw std::runtime_error("RAND_bytes failed");
    }

    // The nonce must receive its own random bytes. Previously the entropy
    // buffer was filled a second time and the nonce stayed all zeros.
    std::vector<unsigned char> generated_nonce(8); // for 128 bit
    if (1 != RAND_bytes(generated_nonce.data(), generated_nonce.size()))
    {
        throw std::runtime_error("RAND_bytes failed");
    }

    std::vector<unsigned char> personalization_string = {};
    RandomNumberGenerator::instance().initialize(
        generated_entropy, generated_nonce, personalization_string,
        rngongpu::SecurityLevel::AES128, false);
    cudaDeviceSynchronize();

    // For kernel stack size
    cudaDeviceSetLimit(cudaLimitStackSize, 2048);

    modulus_ = std::make_shared<DeviceVector<Modulus64>>(prime_vector_);

    // NTT / INTT tables over all Q' primes.
    std::vector<Data64> base_q_psi =
        generate_primitive_root_of_unity(n, prime_vector_);
    std::vector<Root64> Qprime_ntt_table =
        generate_ntt_table(base_q_psi, prime_vector_, n_power);
    std::vector<Root64> Qprime_intt_table =
        generate_intt_table(base_q_psi, prime_vector_, n_power);
    std::vector<Ninverse64> Qprime_n_inverse =
        generate_n_inverse(n, prime_vector_);

    ntt_table_ = std::make_shared<DeviceVector<Root64>>(Qprime_ntt_table);
    intt_table_ = std::make_shared<DeviceVector<Root64>>(Qprime_intt_table);
    n_inverse_ =
        std::make_shared<DeviceVector<Ninverse64>>(Qprime_n_inverse);

    std::vector<Data64> last_q_modinv =
        calculate_last_q_modinv(prime_vector_, Q_prime_size, P_size);
    std::vector<Data64> half = calculate_half(prime_vector_, P_size);
    std::vector<Data64> half_mod =
        calculate_half_mod(prime_vector_, half, Q_prime_size, P_size);
    std::vector<Data64> factor =
        calculate_factor(prime_vector_, Q_size, P_size);

    last_q_modinv_ = std::make_shared<DeviceVector<Data64>>(last_q_modinv);
    half_p_ = std::make_shared<DeviceVector<Data64>>(half);
    half_mod_ = std::make_shared<DeviceVector<Data64>>(half_mod);
    factor_ = std::make_shared<DeviceVector<Data64>>(factor);

    // For Rescale parameters for all depth
    std::vector<Data64> rescale_last_q_modinv;
    std::vector<Data64> rescaled_half_mod;
    std::vector<Data64> rescaled_half;
    for (int j = 0; j < (Q_size - 1); j++)
    {
        // Index of the last Q prime still present at depth j.
        int inner = (Q_size - 1) - j;
        rescaled_half.push_back(prime_vector_[inner].value >> 1);
        for (int i = 0; i < inner; i++)
        {
            // Reduce that last prime modulo q_i and pass it to
            // OPERATOR64::modinv; also store half of it reduced modulo q_i.
            Data64 temp_ =
                prime_vector_[inner].value % prime_vector_[i].value;
            rescale_last_q_modinv.push_back(
                OPERATOR64::modinv(temp_, prime_vector_[i]));
            rescaled_half_mod.push_back(rescaled_half[j] %
                                        prime_vector_[i].value);
        }
    }

    rescaled_last_q_modinv_ =
        std::make_shared<DeviceVector<Data64>>(rescale_last_q_modinv);
    rescaled_half_mod_ =
        std::make_shared<DeviceVector<Data64>>(rescaled_half_mod);
    rescaled_half_ = std::make_shared<DeviceVector<Data64>>(rescaled_half);

    // Per-depth tables, concatenated from depth 0 (all Q primes) down to a
    // single prime.
    std::vector<Data64> Mi;
    std::vector<Data64> Mi_inv;
    std::vector<Data64> upper_half_threshold;
    std::vector<Data64> decryption_modulus;
    for (int i = 0; i < Q_size; i++)
    {
        int depth_Q_size = Q_size - i;

        // Mi
        std::vector<Data64> Mi_inner =
            calculate_Mi(prime_vector_, depth_Q_size);
        for (int j = 0; j < depth_Q_size * depth_Q_size; j++)
        {
            Mi.push_back(Mi_inner[j]);
        }

        // Mi_inv
        std::vector<Data64> Mi_inv_inner =
            calculate_Mi_inv(prime_vector_, depth_Q_size);
        for (int j = 0; j < depth_Q_size; j++)
        {
            Mi_inv.push_back(Mi_inv_inner[j]);
        }

        // upper_half_threshold
        std::vector<Data64> upper_half_threshold_inner =
            calculate_upper_half_threshold(prime_vector_, depth_Q_size);
        for (int j = 0; j < depth_Q_size; j++)
        {
            upper_half_threshold.push_back(upper_half_threshold_inner[j]);
        }

        // decryption_modulus
        std::vector<Data64> M_inner =
            calculate_M(prime_vector_, depth_Q_size);
        for (int j = 0; j < depth_Q_size; j++)
        {
            decryption_modulus.push_back(M_inner[j]);
        }
    }

    Mi_ = std::make_shared<DeviceVector<Data64>>(Mi);
    Mi_inv_ = std::make_shared<DeviceVector<Data64>>(Mi_inv);
    upper_half_threshold_ =
        std::make_shared<DeviceVector<Data64>>(upper_half_threshold);
    decryption_modulus_ =
        std::make_shared<DeviceVector<Data64>>(decryption_modulus);

    // prime_location_leveled: for each of the Q_size - 1 levels, the indices
    // of the Q primes still in use followed by the indices of all P primes.
    std::vector<int> prime_loc;
    int counter = Q_size;
    for (int i = 0; i < Q_size - 1; i++)
    {
        for (int j = 0; j < counter; j++)
        {
            prime_loc.push_back(j);
        }
        counter--;
        for (int j = 0; j < P_size; j++)
        {
            prime_loc.push_back(Q_size + j);
        }
    }

    prime_location_leveled = std::make_shared<DeviceVector<int>>(prime_loc);

    switch (static_cast<int>(keyswitching_type_))
    {
        case 1: // KEYSWITCHING_METHOD_I
            // Default: no additional tables are built here.
            break;
        case 2: // KEYSWITCHING_METHOD_II
        {
            KeySwitchParameterGenerator pool_ckks(
                n, base_q, P_size, scheme_, keyswitching_type_);

            m_leveled = pool_ckks.m;
            l_leveled =
                std::make_shared<std::vector<int>>(pool_ckks.level_Q_);
            l_tilda_leveled =
                std::make_shared<std::vector<int>>(pool_ckks.level_Qtilda_);
            d_leveled =
                std::make_shared<std::vector<int>>(pool_ckks.level_d_);

            std::vector<std::vector<Data64>>
                base_change_matrix_D_to_Qtilda_vec =
                    pool_ckks.level_base_change_matrix_D_to_Qtilda();
            std::vector<std::vector<Data64>> Mi_inv_D_to_Qtilda_vec =
                pool_ckks.level_Mi_inv_D_to_Qtilda();
            std::vector<std::vector<Data64>> prod_D_to_Qtilda_vec =
                pool_ckks.level_prod_D_to_Qtilda();
            std::vector<std::vector<int>> I_j_vec = pool_ckks.level_I_j();
            std::vector<std::vector<int>> I_location_vec =
                pool_ckks.level_I_location();
            std::vector<std::vector<int>> Sk_pair_new_vec =
                pool_ckks.level_sk_pair();

            base_change_matrix_D_to_Qtilda_leveled =
                std::make_shared<std::vector<DeviceVector<Data64>>>();
            Mi_inv_D_to_Qtilda_leveled =
                std::make_shared<std::vector<DeviceVector<Data64>>>();
            prod_D_to_Qtilda_leveled =
                std::make_shared<std::vector<DeviceVector<Data64>>>();
            I_j_leveled =
                std::make_shared<std::vector<DeviceVector<int>>>();
            I_location_leveled =
                std::make_shared<std::vector<DeviceVector<int>>>();
            Sk_pair_leveled =
                std::make_shared<std::vector<DeviceVector<int>>>();

            // One DeviceVector per level for each table; all host-side
            // vectors are indexed with the same level index.
            for (std::size_t pool_lp = 0;
                 pool_lp < base_change_matrix_D_to_Qtilda_vec.size();
                 pool_lp++)
            {
                DeviceVector<Data64>
                    base_change_matrix_D_to_Qtilda_leveled_inner(
                        base_change_matrix_D_to_Qtilda_vec[pool_lp]);
                DeviceVector<Data64> Mi_inv_D_to_Qtilda_leveled_inner(
                    Mi_inv_D_to_Qtilda_vec[pool_lp]);
                DeviceVector<Data64> prod_D_to_Qtilda_leveled_inner(
                    prod_D_to_Qtilda_vec[pool_lp]);

                base_change_matrix_D_to_Qtilda_leveled->push_back(
                    std::move(base_change_matrix_D_to_Qtilda_leveled_inner));
                Mi_inv_D_to_Qtilda_leveled->push_back(
                    std::move(Mi_inv_D_to_Qtilda_leveled_inner));
                prod_D_to_Qtilda_leveled->push_back(
                    std::move(prod_D_to_Qtilda_leveled_inner));

                DeviceVector<int> I_j_vec_inner(I_j_vec[pool_lp]);
                DeviceVector<int> I_location_vec_inner(
                    I_location_vec[pool_lp]);
                DeviceVector<int> Sk_pair_new_vec_inner(
                    Sk_pair_new_vec[pool_lp]);

                I_j_leveled->push_back(std::move(I_j_vec_inner));
                I_location_leveled->push_back(
                    std::move(I_location_vec_inner));
                Sk_pair_leveled->push_back(
                    std::move(Sk_pair_new_vec_inner));
            }
        }
        break;
        default:
            throw std::invalid_argument("Invalid Key Switching Type");
    }

    context_generated_ = true;
}
// Prints a human-readable summary of the context to std::cout: scheme name,
// poly modulus degree and the bit sizes of the Q and P primes written as
// "a + b + ...".
//
// If the context has not been generated yet, only a notice is printed.
// Empty bit-size lists are printed as an empty sum instead of indexing past
// the end of the vector.
void HEContextImpl<Scheme::CKKS>::print_parameters()
{
    if (!context_generated_)
    {
        std::cout << "Parameters is not generated yet!" << std::endl;
        return;
    }

    // Writes the entries of `bit_sizes` separated by " + ". Safe for an empty
    // list, unlike a `size() - 1` loop bound followed by `back()`.
    const auto print_bit_size_sum = [](const std::vector<int>& bit_sizes)
    {
        for (std::size_t i = 0; i < bit_sizes.size(); i++)
        {
            if (i != 0)
            {
                std::cout << " + ";
            }
            std::cout << bit_sizes[i];
        }
    };

    std::string scheme_string = "CKKS";

    std::cout << "==== HEonGPU a GPU Based Homomorphic Encryption "
                 "Library ====\n"
              << std::endl;
    std::cout << "Encryption parameters:" << std::endl;
    std::cout << "--> scheme: " << scheme_string << std::endl;
    std::cout << "--> poly_modulus_degree: " << n << std::endl;

    std::cout << "--> Q_tilta size: Q( ";
    print_bit_size_sum(Q_mod_bit_sizes_);
    std::cout << " ) + P( ";
    print_bit_size_sum(P_mod_bit_sizes_);
    std::cout << " ) bits" << std::endl;
    std::cout << std::endl;
}
// Serializes the context parameters to `os` as raw, host-endian bytes.
//
// Layout, in order: scheme_, sec_level_, keyswitching_type_, n, n_power,
// coeff_modulus, total_coeff_bit_count, Q_prime_size, Q_size, P_size, then
// five length-prefixed arrays (uint32_t element count followed by the
// elements): prime_vector_, base_q, Qprime_mod_bit_sizes_, Q_mod_bit_sizes_
// and P_mod_bit_sizes_. load() reads the same layout back.
//
// Throws std::runtime_error unless both the poly modulus degree and the
// coefficient modulus have been specified.
void HEContextImpl<Scheme::CKKS>::save(std::ostream& os) const
{
    const bool serializable =
        poly_modulus_degree_specified_ && coeff_modulus_specified_;
    if (!serializable)
    {
        throw std::runtime_error(
            "Context has no enough parameters to serialize!");
    }

    // Dumps `byte_count` raw bytes starting at `src` into the stream.
    const auto put = [&os](const void* src, std::size_t byte_count)
    {
        os.write(static_cast<const char*>(src), byte_count);
    };

    // Fixed-size scalar header.
    put(&scheme_, sizeof(scheme_));
    put(&sec_level_, sizeof(sec_level_));
    put(&keyswitching_type_, sizeof(keyswitching_type_));
    put(&n, sizeof(n));
    put(&n_power, sizeof(n_power));
    put(&coeff_modulus, sizeof(coeff_modulus));
    put(&total_coeff_bit_count, sizeof(total_coeff_bit_count));
    put(&Q_prime_size, sizeof(Q_prime_size));
    put(&Q_size, sizeof(Q_size));
    put(&P_size, sizeof(P_size));

    // Length-prefixed arrays.
    uint32_t prime_vector_count = prime_vector_.size();
    put(&prime_vector_count, sizeof(prime_vector_count));
    put(prime_vector_.data(), sizeof(Modulus64) * prime_vector_count);

    uint32_t base_q_count = base_q.size();
    put(&base_q_count, sizeof(base_q_count));
    put(base_q.data(), sizeof(Data64) * base_q_count);

    uint32_t Qprime_mod_bit_sizes_count = Qprime_mod_bit_sizes_.size();
    put(&Qprime_mod_bit_sizes_count, sizeof(Qprime_mod_bit_sizes_count));
    put(Qprime_mod_bit_sizes_.data(),
        sizeof(int) * Qprime_mod_bit_sizes_count);

    uint32_t Q_mod_bit_sizes_count = Q_mod_bit_sizes_.size();
    put(&Q_mod_bit_sizes_count, sizeof(Q_mod_bit_sizes_count));
    put(Q_mod_bit_sizes_.data(), sizeof(int) * Q_mod_bit_sizes_count);

    uint32_t P_mod_bit_sizes_count = P_mod_bit_sizes_.size();
    put(&P_mod_bit_sizes_count, sizeof(P_mod_bit_sizes_count));
    put(P_mod_bit_sizes_.data(), sizeof(int) * P_mod_bit_sizes_count);
}
// Restores the context parameters from `is` (the raw layout written by
// save()) and then calls generate() with the default memory-pool settings.
//
// Every read is checked. A truncated or failing stream raises
// std::runtime_error instead of continuing with uninitialized counts and
// half-filled buffers.
//
// NOTE(review): element counts are taken from the stream as-is; if the input
// can be untrusted, consider bounding them before the resize calls.
//
// Throws std::runtime_error if the context has already been generated or if
// the stream ends / fails before all fields are read. Exceptions raised by
// generate() propagate unchanged.
void HEContextImpl<Scheme::CKKS>::load(std::istream& is)
{
    if (context_generated_)
    {
        throw std::runtime_error("Context has been already exist!");
    }

    // Reads exactly `byte_count` raw bytes into `dst` or throws.
    const auto read_exact = [&is](void* dst, std::size_t byte_count)
    {
        is.read(static_cast<char*>(dst), byte_count);
        if (!is)
        {
            throw std::runtime_error(
                "Context stream is truncated or unreadable!");
        }
    };

    // Fixed-size scalar header.
    read_exact(&scheme_, sizeof(scheme_));
    read_exact(&sec_level_, sizeof(sec_level_));
    read_exact(&keyswitching_type_, sizeof(keyswitching_type_));
    read_exact(&n, sizeof(n));
    read_exact(&n_power, sizeof(n_power));
    read_exact(&coeff_modulus, sizeof(coeff_modulus));
    read_exact(&total_coeff_bit_count, sizeof(total_coeff_bit_count));
    read_exact(&Q_prime_size, sizeof(Q_prime_size));
    read_exact(&Q_size, sizeof(Q_size));
    read_exact(&P_size, sizeof(P_size));

    // Length-prefixed arrays: uint32_t element count, then the elements.
    uint32_t prime_vector_count = 0;
    read_exact(&prime_vector_count, sizeof(prime_vector_count));
    prime_vector_.resize(prime_vector_count);
    read_exact(prime_vector_.data(),
               sizeof(Modulus64) * prime_vector_count);

    uint32_t base_q_count = 0;
    read_exact(&base_q_count, sizeof(base_q_count));
    base_q.resize(base_q_count);
    read_exact(base_q.data(), sizeof(Data64) * base_q_count);

    uint32_t Qprime_mod_bit_sizes_count = 0;
    read_exact(&Qprime_mod_bit_sizes_count,
               sizeof(Qprime_mod_bit_sizes_count));
    Qprime_mod_bit_sizes_.resize(Qprime_mod_bit_sizes_count);
    read_exact(Qprime_mod_bit_sizes_.data(),
               sizeof(int) * Qprime_mod_bit_sizes_count);

    uint32_t Q_mod_bit_sizes_count = 0;
    read_exact(&Q_mod_bit_sizes_count, sizeof(Q_mod_bit_sizes_count));
    Q_mod_bit_sizes_.resize(Q_mod_bit_sizes_count);
    read_exact(Q_mod_bit_sizes_.data(),
               sizeof(int) * Q_mod_bit_sizes_count);

    uint32_t P_mod_bit_sizes_count = 0;
    read_exact(&P_mod_bit_sizes_count, sizeof(P_mod_bit_sizes_count));
    P_mod_bit_sizes_.resize(P_mod_bit_sizes_count);
    read_exact(P_mod_bit_sizes_.data(),
               sizeof(int) * P_mod_bit_sizes_count);

    // Parameters are now complete; context_generated_ is known to be false
    // here, so generate() may run.
    poly_modulus_degree_specified_ = true;
    coeff_modulus_specified_ = true;

    this->generate();
}
// Explicit instantiation of the CKKS context implementation in this
// translation unit.
template class HEContextImpl<Scheme::CKKS>;
} // namespace heongpu