Program Listing for File bootstrapping.cu
↰ Return to documentation for file (src/lib/kernel/bootstrapping.cu)
// Copyright 2024-2026 Alişah Özcan
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
// Developer: Alişah Özcan
#include <heongpu/kernel/bootstrapping.cuh>
namespace heongpu
{
// Returns 5^index mod 4n via left-to-right square-and-multiply.
// The reduction "mod 4n" is done with a bit mask, so n must be a power of two
// (callers in this file pass n = 1 << n_power). For index == 0 the loop body
// never runs and the result is 1.
__device__ int exponent_calculation(int& index, int& n)
{
    const Data64 mask = (n << 2) - 1;
    Data64 acc = 1ULL;
    // Number of significant bits of index (0 when index == 0).
    const int bit_count = 32 - __clz(index);
    for (int bit = bit_count - 1; bit >= 0; --bit)
    {
        acc = (acc * acc) & mask;
        if ((index >> bit) & 1u)
        {
            acc = (acc * 5ULL) & mask;
        }
    }
    return acc;
}
// First output row of forward E matrix `index`: matrix 0 owns rows 0-1 and
// every later matrix i owns the three rows 3i-1 .. 3i+1.
__device__ int matrix_location(int& index)
{
    return (index != 0) ? (3 * index - 1) : 0;
}
// First output row of inverse E matrix `index` when the matrices are stored
// in reverse order. With Y = gridDim.y matrices, matrix i starts at row
// 3 * (Y - 1) - 3 * i; for i == 0 this is the last slot 3 * (Y - 1), which
// holds that matrix's two rows.
__device__ int matrix_reverse_location(int& index)
{
    const int last_row = (gridDim.y - 1) * 3;
    // index == 0 needs no special case: last_row - 0 == last_row.
    return last_row - (3 * index);
}
// Fills the diagonal rows of the forward E matrices, one matrix per
// blockIdx.y. Matrix 0 writes two rows and every later matrix writes three,
// starting at matrix_location(blockIdx.y). Each row holds 1 << n_power
// Complex64 values and this thread writes column idx of each row.
// NOTE(review): there is no bounds check, so this assumes
// gridDim.x * blockDim.x == 1 << n_power — confirm against the launcher.
__global__ void E_diagonal_generate_kernel(Complex64* output, int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int matrix_id = blockIdx.y;
    int first_row = matrix_location(matrix_id);
    int n = 1 << n_power;

    // Matrix m works on groups of 2 * v_size entries, v_size = n >> (m + 1).
    int half_shift = n_power - (matrix_id + 1);
    int v_size = 1 << half_shift;
    int local_index = idx & ((v_size << 1) - 1);
    bool upper_half = (local_index >> half_shift) == 1;

    // Twiddle omega^(5^local_index mod 4n), omega = exp(i * pi / (4 * v_size)).
    double angle = M_PI / (v_size << 2);
    Complex64 omega_4n(cos(angle), sin(angle));
    int expo = exponent_calculation(local_index, n);
    Complex64 twiddle = omega_4n.exp(expo);

    Complex64 diag0(1.0, 0.0);
    Complex64 diag1(0.0, 0.0);
    Complex64 diag2(0.0, 0.0);
    if (!upper_half)
    {
        diag1 = twiddle;
    }
    else
    {
        // NOTE(review): the upper half stores the twiddle itself, not its
        // negation — confirm this sign convention is intended.
        diag0 = twiddle;
        if (matrix_id == 0)
        {
            diag1 = Complex64(1.0, 0.0);
        }
        else
        {
            diag2 = Complex64(1.0, 0.0);
        }
    }

    output[(first_row << n_power) + idx] = diag0;
    output[((first_row + 1) << n_power) + idx] = diag1;
    if (matrix_id != 0)
    {
        output[((first_row + 2) << n_power) + idx] = diag2;
    }
}
// Fills the diagonal rows of the inverse E matrices, one matrix per
// blockIdx.y, stored in reverse order via matrix_reverse_location().
// Matrix 0 writes two rows and every later matrix writes three. The lower half
// of each group gets (0.5, 0.5[, 0]). The upper half gets
// (w/2, -w/2) for matrix 0 and (w/2, 0, -w/2) otherwise, where
// w = omega^-(5^j mod 4n) and omega = exp(i * pi / (4 * v_size)).
// NOTE(review): assumes gridDim.x * blockDim.x == 1 << n_power (no bounds check).
__global__ void E_diagonal_inverse_generate_kernel(Complex64* output,
                                                   int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int matrix_id = blockIdx.y;
    int first_row = matrix_reverse_location(matrix_id);
    int n = 1 << n_power;
    int half_shift = n_power - (matrix_id + 1);
    int v_size = 1 << half_shift;
    int local_index = idx & ((v_size << 1) - 1);
    bool upper_half = (local_index >> half_shift) == 1;
    bool is_first = (matrix_id == 0);

    Complex64 diag0(0.5, 0.0);
    Complex64 diag1(0.5, 0.0);
    Complex64 diag2(0.0, 0.0);
    if (upper_half)
    {
        double angle = M_PI / (v_size << 2);
        Complex64 omega_4n(cos(angle), sin(angle));
        int expo = exponent_calculation(local_index, n);
        diag0 = omega_4n.inverse();
        diag0 = diag0.exp(expo);
        diag0 = diag0 / Complex64(2.0, 0.0);
        if (is_first)
        {
            diag1 = diag0.negate();
        }
        else
        {
            diag1 = Complex64(0.0, 0.0);
            diag2 = diag0.negate();
        }
    }

    output[(first_row << n_power) + idx] = diag0;
    output[((first_row + 1) << n_power) + idx] = diag1;
    if (!is_first)
    {
        output[((first_row + 2) << n_power) + idx] = diag2;
    }
}
// Copies the first 2 (last == true) or 3 rows of length 1 << n_power from
// input to output; each thread copies column idx of every row.
__global__ void E_diagonal_inverse_matrix_mult_single_kernel(
    Complex64* input, Complex64* output, bool last, int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int row_count = last ? 2 : 3;
    for (int row = 0; row < row_count; ++row)
    {
        const int pos = idx + (row << n_power);
        output[pos] = input[pos];
    }
}
// Accumulates diagonal-matrix products: for each of iteration_count R rows and
// each of the 3 L rows, output[row] += rotate(L_j, diag_index[r]) * R.
// The L operand rows live in `input` starting at (first1 ? 2 : 3) + 3 * mul_index.
// The R row comes from `input` row i when first2 is set, and otherwise
// from temp row input_index[r - (first1 ? 2 : 3)].
// Destination rows are taken sequentially from output_index starting at
// output_index_counter. Each thread handles column idx.
__global__ void E_diagonal_matrix_mult_kernel(
    Complex64* input, Complex64* output, Complex64* temp, int* diag_index,
    int* input_index, int* output_index, int iteration_count,
    int R_matrix_counter, int output_index_counter, int mul_index,
    bool first1, bool first2, int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int r_offset = first1 ? 2 : 3;
    const int L_base = r_offset + (3 * mul_index);
    constexpr int L_rows = 3;
    int out_cursor = output_index_counter;
    for (int i = 0; i < iteration_count; ++i)
    {
        const int r_counter = R_matrix_counter + i;
        int diag_shift = diag_index[r_counter];
        Complex64 R_m =
            first2
                ? input[idx + (i << n_power)]
                : temp[idx + (input_index[r_counter - r_offset] << n_power)];
        for (int j = 0; j < L_rows; ++j)
        {
            Complex64 L_m =
                rotated_access(input + ((L_base + j) << n_power), diag_shift,
                               idx, n_power);
            const int row = output_index[out_cursor];
            ++out_cursor;
            const int pos = (row << n_power) + idx;
            output[pos] = output[pos] + (L_m * R_m);
        }
    }
}
// Inverse-transform counterpart of E_diagonal_matrix_mult_kernel.
// The L operand rows start at 3 + 3 * mul_index in `input`; there are 2 of them
// when `last` is set and 3 otherwise.
// The R row is `input` row i on the first pass, otherwise temp row
// input_index[r - 3]. For each product, output[row] += rotate(L_j, diag) * R,
// with destination rows read sequentially from output_index.
__global__ void E_diagonal_inverse_matrix_mult_kernel(
    Complex64* input, Complex64* output, Complex64* temp, int* diag_index,
    int* input_index, int* output_index, int iteration_count,
    int R_matrix_counter, int output_index_counter, int mul_index,
    bool first, bool last, int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int L_base = 3 + (3 * mul_index);
    const int L_rows = last ? 2 : 3;
    int out_cursor = output_index_counter;
    for (int i = 0; i < iteration_count; ++i)
    {
        const int r_counter = R_matrix_counter + i;
        int diag_shift = diag_index[r_counter];
        Complex64 R_m =
            first ? input[idx + (i << n_power)]
                  : temp[idx + (input_index[r_counter - 3] << n_power)];
        for (int j = 0; j < L_rows; ++j)
        {
            Complex64 L_m =
                rotated_access(input + ((L_base + j) << n_power), diag_shift,
                               idx, n_power);
            const int row = output_index[out_cursor];
            ++out_cursor;
            const int pos = (row << n_power) + idx;
            output[pos] = output[pos] + (L_m * R_m);
        }
    }
}
// Multiplies every element of a stack of vectors (each 1 << n_power long,
// one vector per blockIdx.y) by `scaling`, in place.
__global__ void complex_vector_scale_kernel(Complex64* data,
                                            Complex64 scaling, int n_power)
{
    int column = blockIdx.x * blockDim.x + threadIdx.x; // index within vector
    int vector_id = blockIdx.y;                          // which vector
    Complex64& element = data[(vector_id << n_power) + column];
    element = element * scaling;
}
// Writes the rotate_index-rotated view of `input` into `output`, as defined by
// rotated_access(); one element per thread.
__global__ void vector_rotate_kernel(Complex64* input, Complex64* output,
                                     int rotate_index, int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    output[idx] = rotated_access(input, rotate_index, idx, n_power);
}
// TODO: implement it for multiple RNS prime (currently it only works for
// single prime)
//
// Mod-raise: lifts each coefficient held modulo q0 (level 0) to the RNS basis
// {modulus[0], ..., modulus[gridDim.y - 1]} using its centered representative.
// Launch layout:
//   blockIdx.x/threadIdx.x -> coefficient, blockIdx.y -> RNS prime,
//   blockIdx.z -> ciphertext polynomial.
// The input is [poly][ring]; the output is [poly][gridDim.y][ring].
// Every output residue is fully reduced into [0, qi).
__global__ void mod_raise_kernel(Data64* input, Data64* output,
                                 Modulus64* modulus, int n_power)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x; // ring size
    int idy = blockIdx.y; // rns count
    int idz = blockIdx.z; // cipher count
    int location_input = idx + (idz << n_power);
    int location_output =
        idx + (idy << n_power) + ((gridDim.y * idz) << n_power);

    Data64 coeff = input[location_input];
    if (idy == 0)
    {
        // The residue modulo q0 is the input itself.
        output[location_output] = coeff;
        return;
    }

    Data64 q0 = modulus[0].value;
    Data64 qi = modulus[idy].value;
    // Centered lift: values strictly above q0/2 are treated as negative
    // (coeff - q0). Using '>' on q0 >> 1 keeps (q0 - 1) / 2 positive for odd q0,
    // so the lifted magnitude never exceeds q0 / 2.
    bool negative = coeff > (q0 >> 1);
    Data64 magnitude = negative ? (q0 - coeff) : coeff;
    Data64 reduced = OPERATOR_GPU_64::reduce_forced(magnitude, modulus[idy]);
    // -0 must map to 0, not qi, so that the residue stays in [0, qi).
    output[location_output] =
        (negative && reduced != 0) ? (qi - reduced) : reduced;
}
// NAND gate linear pre-combination of two LWE ciphertexts (one per block):
// output = (0, encoded) - c1 - c2, with a of length n and a scalar b.
// Arithmetic is done in uint32_t so wrap-around modulo 2^32 is well defined.
__global__ void
tfhe_nand_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                          int32_t* input1_a, int32_t* input1_b,
                          int32_t* input2_a, int32_t* input2_b,
                          int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t lhs = static_cast<uint32_t>(input1_a[base + j]);
        uint32_t rhs = static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = 0u - lhs - rhs;
    }
    if (threadIdx.x == 0)
    {
        uint32_t body = static_cast<uint32_t>(encoded);
        body -= static_cast<uint32_t>(input1_b[cipher]);
        body -= static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = body;
    }
}
// AND gate linear pre-combination: output = (0, encoded) + c1 + c2, computed
// with uint32_t wrap-around; one LWE ciphertext per block.
__global__ void
tfhe_and_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                         int32_t* input1_a, int32_t* input1_b,
                         int32_t* input2_a, int32_t* input2_b,
                         int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t lhs = static_cast<uint32_t>(input1_a[base + j]);
        uint32_t rhs = static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = lhs + rhs;
    }
    if (threadIdx.x == 0)
    {
        uint32_t body = static_cast<uint32_t>(encoded);
        body += static_cast<uint32_t>(input1_b[cipher]);
        body += static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = body;
    }
}
// AND(NOT c1, c2) linear pre-combination: output = (0, encoded) - c1 + c2,
// computed with uint32_t wrap-around; one LWE ciphertext per block.
__global__ void
tfhe_and_first_not_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                                   int32_t* input1_a, int32_t* input1_b,
                                   int32_t* input2_a, int32_t* input2_b,
                                   int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t lhs = static_cast<uint32_t>(input1_a[base + j]);
        uint32_t rhs = static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = rhs - lhs;
    }
    if (threadIdx.x == 0)
    {
        uint32_t body = static_cast<uint32_t>(encoded);
        body -= static_cast<uint32_t>(input1_b[cipher]);
        body += static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = body;
    }
}
// NOR gate linear pre-combination: output = (0, encoded) - c1 - c2, computed
// with uint32_t wrap-around; one LWE ciphertext per block.
__global__ void
tfhe_nor_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                         int32_t* input1_a, int32_t* input1_b,
                         int32_t* input2_a, int32_t* input2_b,
                         int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t sum = static_cast<uint32_t>(input1_a[base + j]) +
                       static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = 0u - sum;
    }
    if (threadIdx.x == 0)
    {
        uint32_t sum = static_cast<uint32_t>(input1_b[cipher]) +
                       static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = static_cast<uint32_t>(encoded) - sum;
    }
}
// OR gate linear pre-combination: output = (0, encoded) + c1 + c2, computed
// with uint32_t wrap-around; one LWE ciphertext per block.
__global__ void
tfhe_or_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                        int32_t* input1_a, int32_t* input1_b,
                        int32_t* input2_a, int32_t* input2_b,
                        int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t sum = static_cast<uint32_t>(input1_a[base + j]) +
                       static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = sum;
    }
    if (threadIdx.x == 0)
    {
        uint32_t sum = static_cast<uint32_t>(input1_b[cipher]) +
                       static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = static_cast<uint32_t>(encoded) + sum;
    }
}
// XNOR gate linear pre-combination: output = (0, encoded) - 2*c1 - 2*c2,
// computed with uint32_t wrap-around; one LWE ciphertext per block.
__global__ void
tfhe_xnor_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                          int32_t* input1_a, int32_t* input1_b,
                          int32_t* input2_a, int32_t* input2_b,
                          int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t sum = static_cast<uint32_t>(input1_a[base + j]) +
                       static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = 0u - (2u * sum);
    }
    if (threadIdx.x == 0)
    {
        uint32_t sum = static_cast<uint32_t>(input1_b[cipher]) +
                       static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = static_cast<uint32_t>(encoded) - (2u * sum);
    }
}
// XOR gate linear pre-combination: output = (0, encoded) + 2*c1 + 2*c2,
// computed with uint32_t wrap-around; one LWE ciphertext per block.
__global__ void
tfhe_xor_pre_comp_kernel(int32_t* output_a, int32_t* output_b,
                         int32_t* input1_a, int32_t* input1_b,
                         int32_t* input2_a, int32_t* input2_b,
                         int32_t encoded, int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        uint32_t sum = static_cast<uint32_t>(input1_a[base + j]) +
                       static_cast<uint32_t>(input2_a[base + j]);
        output_a[base + j] = 2u * sum;
    }
    if (threadIdx.x == 0)
    {
        uint32_t sum = static_cast<uint32_t>(input1_b[cipher]) +
                       static_cast<uint32_t>(input2_b[cipher]);
        output_b[cipher] = static_cast<uint32_t>(encoded) + (2u * sum);
    }
}
// NOT gate: negates every component of the LWE ciphertext (a and b) with
// uint32_t wrap-around; one ciphertext per block.
__global__ void tfhe_not_comp_kernel(int32_t* output_a, int32_t* output_b,
                                     int32_t* input1_a, int32_t* input1_b,
                                     int n)
{
    int cipher = blockIdx.x;
    int base = cipher * n;
    for (int j = threadIdx.x; j < n; j += blockDim.x)
    {
        output_a[base + j] = 0u - static_cast<uint32_t>(input1_a[base + j]);
    }
    if (threadIdx.x == 0)
    {
        output_b[cipher] = 0u - static_cast<uint32_t>(input1_b[cipher]);
    }
}
// Switches a 32-bit torus value (its raw bits read as unsigned u) to
// Z_{2^(modulus_log + 1)} with round-to-nearest. The result is
// round(u / 2^(31 - modulus_log)), computed in 64 bits as
// ((u << 32) + 2^(62 - modulus_log)) >> (63 - modulus_log).
// Callers pass modulus_log = N_power, giving a value modulo 2N.
__device__ int32_t torus_modulus_switch_log(int32_t& input,
                                            int& modulus_log)
{
    const uint64_t shift = 63 - modulus_log;
    const uint64_t widened =
        static_cast<uint64_t>(static_cast<uint32_t>(input)) << 32;
    const uint64_t rounding = 1ULL << (shift - 1);
    return static_cast<int32_t>((widened + rounding) >> shift);
}
// Complete TFHE blind rotation of one LWE ciphertext per block, in one kernel.
//
// Layout / launch assumptions (from the indexing below):
//  * gridDim.x = number of ciphertexts; input_a is [cipher][n] and input_b is
//    [cipher].
//  * Each thread owns the two coefficients idx and idx + blockDim.x, so
//    blockDim.x must equal N / 2 and N <= 1024 (shared buffer size).
//  * k <= 3, because the accumulator arrays hold at most 4 polynomials.
//  * boot_key is [n][k+1][bk_length][k+1][N]. It is multiplied pointwise with
//    forward-NTT outputs, so it must already be in the NTT domain.
//  * output is [cipher][k+1][N].
//
// Algorithm:
//  * The accumulator starts as (0, ..., 0, X^(2N - b~) * testv), where testv
//    has every coefficient equal to `encoded`.
//  * For every mask coefficient a_i it computes
//      ACC += BK_i (external product) digits((X^(a~_i) - 1) * ACC).
//  * The digits come from a bk_length-level gadget decomposition with bk_bit
//    bits per level, shifted by bk_offset and centered by bk_half.
// All torus arithmetic is done in uint32_t so that wrap-around is well defined.
__global__ void tfhe_bootstrapping_kernel(
    const int32_t* input_a, const int32_t* input_b, int32_t* output,
    const Data64* boot_key,
    const Root64* __restrict__ forward_root_of_unity_table,
    const Root64* __restrict__ inverse_root_of_unity_table,
    const Ninverse64 n_inverse, const Modulus64 modulus,
    const int32_t encoded, const int32_t bk_offset, const int32_t bk_mask,
    const int32_t bk_half, int n, int N, int N_power, int k, int bk_bit,
    int bk_length)
{
    __shared__ uint32_t sdata32[1024];
    __shared__ Data64 sdata64[1024];
    int idx = threadIdx.x;
    int idx2 = idx + blockDim.x; // second coefficient owned by this thread
    int block_x = blockIdx.x;
    int offset_lwe = block_x * n;
    const Modulus64 modulus_reg = modulus;
    const Data64 threshold = modulus_reg.value >> 1;
    const Ninverse64 n_inverse_reg = n_inverse;
    const uint32_t bk_offset_u = static_cast<uint32_t>(bk_offset);
    int32_t encoded_reg = encoded;

    // ---- Initial accumulator: X^(2N - b~) * testv in polynomial k -------
    int32_t input_b_i = input_b[block_x];
    int32_t input_b_i_N = torus_modulus_switch_log(input_b_i, N_power);
    input_b_i_N = static_cast<int32_t>(N << 1) - input_b_i_N;
    int32_t2 accum[4];
    int32_t2 accum2[4];
    // Mask polynomials start at zero.
    for (int i4 = 0; i4 < k; i4++)
    {
        accum[i4].value[0] = 0;
        accum[i4].value[1] = 0;
    }
    if (input_b_i_N < N)
    {
        accum[k].value[0] = (idx < input_b_i_N) ? (-encoded_reg) : encoded_reg;
        accum[k].value[1] =
            (idx2 < input_b_i_N) ? (-encoded_reg) : encoded_reg;
    }
    else
    {
        // X^m = -X^(m - N) for N <= m <= 2N: the sign pattern flips.
        int32_t input_b_i_N_minus = input_b_i_N - N;
        accum[k].value[0] =
            (idx < input_b_i_N_minus) ? encoded_reg : (-encoded_reg);
        accum[k].value[1] =
            (idx2 < input_b_i_N_minus) ? encoded_reg : (-encoded_reg);
    }

    // ---- Blind rotation over the n mask coefficients ---------------------
    for (int i = 0; i < n; i++)
    {
        int32_t input_a_i = input_a[offset_lwe + i];
        int32_t input_a_i_N = torus_modulus_switch_log(input_a_i, N_power);
        Data64 offset_i =
            i * (Data64) (k + 1) * ((Data64) bk_length * (k + 1) * N);

        // External-product accumulators (NTT domain) for this step.
        uint64_t2 accum3[4];
        for (int i4 = 0; i4 < (k + 1); i4++)
        {
            accum3[i4].value[0] = 0ULL;
            accum3[i4].value[1] = 0ULL;
        }

        for (int i2 = 0; i2 < (k + 1); i2++)
        {
            sdata32[idx] = accum[i2].value[0];
            sdata32[idx2] = accum[i2].value[1];
            __syncthreads();

            // Negacyclic rotation X^(a~_i) * ACC[i2] for both owned coefficients.
            uint32_t rot0;
            uint32_t rot1;
            if (input_a_i_N < N)
            {
                rot0 = (idx < input_a_i_N)
                           ? (0u - sdata32[N - input_a_i_N + idx])
                           : sdata32[idx - input_a_i_N];
                rot1 = (idx2 < input_a_i_N)
                           ? (0u - sdata32[N - input_a_i_N + idx2])
                           : sdata32[idx2 - input_a_i_N];
            }
            else
            {
                int32_t input_a_i_N_minus = input_a_i_N - N;
                rot0 = (idx < input_a_i_N_minus)
                           ? sdata32[N - input_a_i_N_minus + idx]
                           : (0u - sdata32[idx - input_a_i_N_minus]);
                rot1 = (idx2 < input_a_i_N_minus)
                           ? sdata32[N - input_a_i_N_minus + idx2]
                           : (0u - sdata32[idx2 - input_a_i_N_minus]);
            }
            // (X^(a~_i) - 1) * ACC[i2]
            accum2[i2].value[0] = static_cast<int32_t>(
                rot0 - static_cast<uint32_t>(accum[i2].value[0]));
            accum2[i2].value[1] = static_cast<int32_t>(
                rot1 - static_cast<uint32_t>(accum[i2].value[1]));

            Data64 offset_i2 = i2 * ((Data64) bk_length * (k + 1) * N);
            for (int i3 = 0; i3 < bk_length; i3++)
            {
                Data64 offset_i3 =
                    (offset_i + offset_i2) + i3 * ((Data64) (k + 1) * N);
                int shift = 32 - (bk_bit * (i3 + 1));
                // Gadget digit of level i3, centered into [-bk_half, ...].
                int32_t temp0 =
                    ((static_cast<int32_t>(
                          static_cast<uint32_t>(accum2[i2].value[0]) +
                          bk_offset_u) >>
                      shift) &
                     bk_mask) -
                    bk_half;
                int32_t temp1 =
                    ((static_cast<int32_t>(
                          static_cast<uint32_t>(accum2[i2].value[1]) +
                          bk_offset_u) >>
                      shift) &
                     bk_mask) -
                    bk_half;
                // PRE PROCESS: map the signed digit into [0, q); 0 stays 0.
                sdata64[idx] =
                    (temp0 < 0)
                        ? static_cast<Data64>(modulus_reg.value + temp0)
                        : static_cast<Data64>(temp0);
                sdata64[idx2] =
                    (temp1 < 0)
                        ? static_cast<Data64>(modulus_reg.value + temp1)
                        : static_cast<Data64>(temp1);
                __syncthreads();
                SmallForwardNTT(sdata64, forward_root_of_unity_table,
                                modulus_reg, false);
                Data64 value0 = sdata64[idx];
                Data64 value1 = sdata64[idx2];
                for (int i4 = 0; i4 < (k + 1); i4++)
                {
                    Data64 bk0 = boot_key[offset_i3 + (i4 * N) + idx];
                    Data64 bk1 = boot_key[offset_i3 + (i4 * N) + idx2];
                    Data64 mul0 =
                        OPERATOR_GPU_64::mult(value0, bk0, modulus_reg);
                    Data64 mul1 =
                        OPERATOR_GPU_64::mult(value1, bk1, modulus_reg);
                    accum3[i4].value[0] = OPERATOR_GPU_64::add(
                        accum3[i4].value[0], mul0, modulus_reg);
                    accum3[i4].value[1] = OPERATOR_GPU_64::add(
                        accum3[i4].value[1], mul1, modulus_reg);
                }
            }
        }

        // Back to the coefficient domain and ACC += result.
        for (int i4 = 0; i4 < (k + 1); i4++)
        {
            sdata64[idx] = accum3[i4].value[0];
            sdata64[idx2] = accum3[i4].value[1];
            __syncthreads();
            SmallInverseNTT(sdata64, inverse_root_of_unity_table,
                            modulus_reg, n_inverse_reg, false);
            Data64 res0 = sdata64[idx];
            Data64 res1 = sdata64[idx2];
            // POST PROCESS: centered lift from [0, q) to a signed value.
            int32_t post_accum0 =
                (res0 >= threshold)
                    ? static_cast<int32_t>(
                          static_cast<int64_t>(res0 - modulus_reg.value))
                    : static_cast<int32_t>(static_cast<int64_t>(res0));
            int32_t post_accum1 =
                (res1 >= threshold)
                    ? static_cast<int32_t>(
                          static_cast<int64_t>(res1 - modulus_reg.value))
                    : static_cast<int32_t>(static_cast<int64_t>(res1));
            accum[i4].value[0] =
                accum[i4].value[0] + static_cast<uint32_t>(post_accum0);
            accum[i4].value[1] =
                accum[i4].value[1] + static_cast<uint32_t>(post_accum1);
        }
    }

    Data64 global_location =
        (Data64) block_x * (Data64) (k + 1) * (Data64) N;
    for (int i4 = 0; i4 < (k + 1); i4++)
    {
        output[global_location + (i4 * N) + idx] = accum[i4].value[0];
        output[global_location + (i4 * N) + idx2] = accum[i4].value[1];
    }
}
// Split bootstrapping, first half of blind-rotation step 0 (mask coefficient
// a_0), when the accumulator is still the initial test vector.
//
// Launch layout:
//   blockIdx.x = ciphertext, blockIdx.y = accumulator polynomial (0..k),
//   blockIdx.z = gadget decomposition level (0..bk_length-1).
// Each thread owns coefficients idx_x and idx_x + blockDim.x, so blockDim.x
// must equal N / 2 and N <= 1024.
//
// Algorithm:
//  * The initial accumulator is (0, ..., 0, X^(2N - b~) * testv), where testv
//    has every coefficient equal to `encoded`.
//  * The block computes level blockIdx.z of the digits of
//    (X^(a~_0) - 1) * ACC[blockIdx.y] and forward-NTTs it.
//  * It writes the pointwise products with the k+1 boot_key polynomials to
//    output, laid out as [cipher][k+1][bk_length][k+1][N].
// Torus arithmetic is done in uint32_t so that wrap-around is well defined.
__global__ void tfhe_bootstrapping_kernel_unique_step1(
    const int32_t* input_a, const int32_t* input_b, Data64* output,
    const Data64* boot_key,
    const Root64* __restrict__ forward_root_of_unity_table,
    const Modulus64 modulus, const int32_t encoded, const int32_t bk_offset,
    const int32_t bk_mask, const int32_t bk_half, int n, int N, int N_power,
    int k, int bk_bit, int bk_length)
{
    __shared__ uint32_t shared_data32[1024];
    __shared__ Data64 shared_data64[1024];
    int idx_x = threadIdx.x;
    int idx_x2 = idx_x + blockDim.x; // second coefficient owned by this thread
    int block_x = blockIdx.x; // cipher size
    int block_y = blockIdx.y; // k
    int block_z = blockIdx.z; // l
    int32_t encoded_reg = encoded;
    const Modulus64 modulus_reg = modulus;
    int offset_lwe = block_x * n;
    Data64 offset_i = 0;
    Data64 offset_i2 = block_y * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_i3 =
        (offset_i + offset_i2) + (block_z * ((Data64) (k + 1) * N));
    Data64 offset_o =
        block_x * (Data64) (k + 1) * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_o2 = block_y * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_o3 =
        (offset_o + offset_o2) + (block_z * ((Data64) (k + 1) * N));

    int32_t input_b_reg = input_b[block_x];
    int32_t input_b_reg_N = torus_modulus_switch_log(input_b_reg, N_power);
    input_b_reg_N = static_cast<int32_t>(N << 1) - input_b_reg_N;

    // Accumulator polynomial block_y before this step. Only the body
    // polynomial (block_y == k) is non-zero. The mask polynomials are zeroed
    // explicitly instead of relying on int32_t2's default state.
    int32_t2 temp;
    temp.value[0] = 0;
    temp.value[1] = 0;
    if (block_y == k)
    {
        if (input_b_reg_N < N)
        {
            temp.value[0] =
                (idx_x < input_b_reg_N) ? (-encoded_reg) : encoded_reg;
            temp.value[1] =
                (idx_x2 < input_b_reg_N) ? (-encoded_reg) : encoded_reg;
        }
        else
        {
            int32_t input_b_reg_N_minus = input_b_reg_N - N;
            temp.value[0] =
                (idx_x < input_b_reg_N_minus) ? encoded_reg : (-encoded_reg);
            temp.value[1] =
                (idx_x2 < input_b_reg_N_minus) ? encoded_reg : (-encoded_reg);
        }
    }

    int32_t input_a_reg = input_a[offset_lwe]; // a_0
    int32_t input_a_reg_N = torus_modulus_switch_log(input_a_reg, N_power);
    shared_data32[idx_x] = temp.value[0];
    shared_data32[idx_x2] = temp.value[1];
    __syncthreads();

    // Negacyclic rotation X^(a~_0) * temp.
    uint32_t rot0;
    uint32_t rot1;
    if (input_a_reg_N < N)
    {
        rot0 = (idx_x < input_a_reg_N)
                   ? (0u - shared_data32[N - input_a_reg_N + idx_x])
                   : shared_data32[idx_x - input_a_reg_N];
        rot1 = (idx_x2 < input_a_reg_N)
                   ? (0u - shared_data32[N - input_a_reg_N + idx_x2])
                   : shared_data32[idx_x2 - input_a_reg_N];
    }
    else
    {
        int32_t input_a_reg_N_minus = input_a_reg_N - N;
        rot0 = (idx_x < input_a_reg_N_minus)
                   ? shared_data32[N - input_a_reg_N_minus + idx_x]
                   : (0u - shared_data32[idx_x - input_a_reg_N_minus]);
        rot1 = (idx_x2 < input_a_reg_N_minus)
                   ? shared_data32[N - input_a_reg_N_minus + idx_x2]
                   : (0u - shared_data32[idx_x2 - input_a_reg_N_minus]);
    }
    // (X^(a~_0) - 1) * temp
    int32_t2 temp2;
    temp2.value[0] = static_cast<int32_t>(
        rot0 - static_cast<uint32_t>(temp.value[0]));
    temp2.value[1] = static_cast<int32_t>(
        rot1 - static_cast<uint32_t>(temp.value[1]));

    // Gadget digit of level block_z, centered by bk_half.
    int shift = 32 - (bk_bit * (block_z + 1));
    const uint32_t bk_offset_u = static_cast<uint32_t>(bk_offset);
    temp2.value[0] =
        ((static_cast<int32_t>(static_cast<uint32_t>(temp2.value[0]) +
                               bk_offset_u) >>
          shift) &
         bk_mask) -
        bk_half;
    temp2.value[1] =
        ((static_cast<int32_t>(static_cast<uint32_t>(temp2.value[1]) +
                               bk_offset_u) >>
          shift) &
         bk_mask) -
        bk_half;

    // PRE PROCESS: signed digit -> [0, q).
    shared_data64[idx_x] =
        (temp2.value[0] < 0)
            ? static_cast<Data64>(modulus_reg.value + temp2.value[0])
            : static_cast<Data64>(temp2.value[0]);
    shared_data64[idx_x2] =
        (temp2.value[1] < 0)
            ? static_cast<Data64>(modulus_reg.value + temp2.value[1])
            : static_cast<Data64>(temp2.value[1]);
    __syncthreads();
    SmallForwardNTT(shared_data64, forward_root_of_unity_table, modulus_reg,
                    false);
    Data64 ntt_value0 = shared_data64[idx_x];
    Data64 ntt_value1 = shared_data64[idx_x2];
#pragma unroll
    for (int i = 0; i < (k + 1); i++)
    {
        Data64 bk0 = boot_key[offset_i3 + (i * N) + idx_x];
        Data64 bk1 = boot_key[offset_i3 + (i * N) + idx_x2];
        Data64 mul0 = OPERATOR_GPU_64::mult(ntt_value0, bk0, modulus_reg);
        Data64 mul1 = OPERATOR_GPU_64::mult(ntt_value1, bk1, modulus_reg);
        output[offset_o3 + (i * N) + idx_x] = mul0;
        output[offset_o3 + (i * N) + idx_x2] = mul1;
    }
}
// Split bootstrapping, first half of blind-rotation step `boot_index` (> 0).
// The current accumulator is read from input_c, laid out as [cipher][k+1][N].
//
// Launch layout:
//   blockIdx.x = ciphertext, blockIdx.y = accumulator polynomial (0..k),
//   blockIdx.z = gadget decomposition level.
// Each thread owns coefficients idx_x and idx_x + blockDim.x, so blockDim.x
// must equal N / 2 and N <= 1024.
//
// The block computes level blockIdx.z of the digits of
// (X^(a~_boot_index) - 1) * ACC[blockIdx.y], forward-NTTs it, and multiplies
// it with boot_key row boot_index. The products are written to output as
// [cipher][k+1][bk_length][k+1][N].
// Torus arithmetic is done in uint32_t so that wrap-around is well defined.
__global__ void tfhe_bootstrapping_kernel_regular_step1(
    const int32_t* input_a, const int32_t* input_b, const int32_t* input_c,
    Data64* output, const Data64* boot_key, int boot_index,
    const Root64* __restrict__ forward_root_of_unity_table,
    const Modulus64 modulus, const int32_t bk_offset, const int32_t bk_mask,
    const int32_t bk_half, int n, int N, int N_power, int k, int bk_bit,
    int bk_length)
{
    __shared__ uint32_t shared_data32[1024];
    __shared__ Data64 shared_data64[1024];
    int idx_x = threadIdx.x;
    int idx_x2 = idx_x + blockDim.x; // second coefficient owned by this thread
    int block_x = blockIdx.x; // cipher size
    int block_y = blockIdx.y; // k
    int block_z = blockIdx.z; // l
    const Modulus64 modulus_reg = modulus;
    int offset_lwe = block_x * n;
    Data64 offset_i =
        boot_index * (Data64) (k + 1) * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_i2 = block_y * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_i3 =
        (offset_i + offset_i2) + (block_z * ((Data64) (k + 1) * N));
    Data64 offset_o =
        block_x * (Data64) (k + 1) * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_o2 = block_y * ((Data64) bk_length * (k + 1) * N);
    Data64 offset_o3 =
        (offset_o + offset_o2) + (block_z * ((Data64) (k + 1) * N));

    int32_t input_a_reg = input_a[offset_lwe + boot_index];
    int32_t input_a_reg_N = torus_modulus_switch_log(input_a_reg, N_power);

    // Current accumulator polynomial block_y.
    int offset_acc = block_x * (k + 1) * N;
    int offset_acc2 = block_y * N;
    offset_acc = offset_acc + offset_acc2 + idx_x;
    int32_t2 temp;
    temp.value[0] = input_c[offset_acc];
    temp.value[1] = input_c[offset_acc + blockDim.x];
    shared_data32[idx_x] = temp.value[0];
    shared_data32[idx_x2] = temp.value[1];
    __syncthreads();

    // Negacyclic rotation X^(a~) * temp.
    uint32_t rot0;
    uint32_t rot1;
    if (input_a_reg_N < N)
    {
        rot0 = (idx_x < input_a_reg_N)
                   ? (0u - shared_data32[N - input_a_reg_N + idx_x])
                   : shared_data32[idx_x - input_a_reg_N];
        rot1 = (idx_x2 < input_a_reg_N)
                   ? (0u - shared_data32[N - input_a_reg_N + idx_x2])
                   : shared_data32[idx_x2 - input_a_reg_N];
    }
    else
    {
        int32_t input_a_reg_N_minus = input_a_reg_N - N;
        rot0 = (idx_x < input_a_reg_N_minus)
                   ? shared_data32[N - input_a_reg_N_minus + idx_x]
                   : (0u - shared_data32[idx_x - input_a_reg_N_minus]);
        rot1 = (idx_x2 < input_a_reg_N_minus)
                   ? shared_data32[N - input_a_reg_N_minus + idx_x2]
                   : (0u - shared_data32[idx_x2 - input_a_reg_N_minus]);
    }
    // (X^(a~) - 1) * temp
    int32_t2 temp2;
    temp2.value[0] = static_cast<int32_t>(
        rot0 - static_cast<uint32_t>(temp.value[0]));
    temp2.value[1] = static_cast<int32_t>(
        rot1 - static_cast<uint32_t>(temp.value[1]));

    // Gadget digit of level block_z, centered by bk_half.
    int shift = 32 - (bk_bit * (block_z + 1));
    const uint32_t bk_offset_u = static_cast<uint32_t>(bk_offset);
    temp2.value[0] =
        ((static_cast<int32_t>(static_cast<uint32_t>(temp2.value[0]) +
                               bk_offset_u) >>
          shift) &
         bk_mask) -
        bk_half;
    temp2.value[1] =
        ((static_cast<int32_t>(static_cast<uint32_t>(temp2.value[1]) +
                               bk_offset_u) >>
          shift) &
         bk_mask) -
        bk_half;

    // PRE PROCESS: signed digit -> [0, q).
    shared_data64[idx_x] =
        (temp2.value[0] < 0)
            ? static_cast<Data64>(modulus_reg.value + temp2.value[0])
            : static_cast<Data64>(temp2.value[0]);
    shared_data64[idx_x2] =
        (temp2.value[1] < 0)
            ? static_cast<Data64>(modulus_reg.value + temp2.value[1])
            : static_cast<Data64>(temp2.value[1]);
    __syncthreads();
    SmallForwardNTT(shared_data64, forward_root_of_unity_table, modulus_reg,
                    false);
    Data64 ntt_value0 = shared_data64[idx_x];
    Data64 ntt_value1 = shared_data64[idx_x2];
#pragma unroll
    for (int i = 0; i < (k + 1); i++)
    {
        Data64 bk0 = boot_key[offset_i3 + (i * N) + idx_x];
        Data64 bk1 = boot_key[offset_i3 + (i * N) + idx_x2];
        Data64 mul0 = OPERATOR_GPU_64::mult(ntt_value0, bk0, modulus_reg);
        Data64 mul1 = OPERATOR_GPU_64::mult(ntt_value1, bk1, modulus_reg);
        output[offset_o3 + (i * N) + idx_x] = mul0;
        output[offset_o3 + (i * N) + idx_x2] = mul1;
    }
}
// Split bootstrapping, second half of blind-rotation step 0.
// For output polynomial blockIdx.y it:
//  1. sums the (k+1) * bk_length NTT-domain products from step1, whose layout
//     is [cipher][k+1][bk_length][k+1][N];
//  2. inverse-NTTs the sum and lifts it to signed 32-bit values;
//  3. for the body polynomial (blockIdx.y == k), adds back the initial
//     accumulator X^(2N - b~) * testv, where testv is all `encoded`. The mask
//     polynomials of the initial accumulator are zero.
// The result is written to output as [cipher][k+1][N].
// Each thread owns coefficients idx_x and idx_x + blockDim.x, so
// blockDim.x == N / 2 and N <= 1024. The final addition uses uint32_t so that
// torus wrap-around is well defined.
__global__ void tfhe_bootstrapping_kernel_unique_step2(
    const Data64* input, const int32_t* input_b, int32_t* output,
    const Root64* __restrict__ inverse_root_of_unity_table,
    const Ninverse64 n_inverse, const Modulus64 modulus,
    const int32_t encoded, int n, int N, int N_power, int k, int bk_length)
{
    __shared__ Data64 shared_data64[1024];
    int idx_x = threadIdx.x;
    int idx_x2 = idx_x + blockDim.x; // second coefficient owned by this thread
    int block_x = blockIdx.x; // cipher size
    int block_y = blockIdx.y; // k
    Data64 offset_i =
        block_x * (Data64) (k + 1) * ((Data64) bk_length * (k + 1) * N);
    int32_t encoded_reg = encoded;
    const Modulus64 modulus_reg = modulus;
    const Data64 threshold = modulus_reg.value >> 1;
    const Ninverse64 n_inverse_reg = n_inverse;

    // Sum of all external-product contributions for polynomial block_y.
    Data64 accum0 = 0ULL;
    Data64 accum1 = 0ULL;
    for (int i = 0; i < (k + 1); i++)
    {
        Data64 offset_i2 = i * ((Data64) bk_length * (k + 1) * N);
#pragma unroll
        for (int j = 0; j < bk_length; j++)
        {
            Data64 offset_i3 =
                (offset_i + offset_i2) + (j * ((Data64) (k + 1) * N));
            offset_i3 = offset_i3 + (block_y * N);
            Data64 value0 = input[offset_i3 + idx_x];
            Data64 value1 = input[offset_i3 + idx_x2];
            accum0 = OPERATOR_GPU_64::add(accum0, value0, modulus_reg);
            accum1 = OPERATOR_GPU_64::add(accum1, value1, modulus_reg);
        }
    }
    shared_data64[idx_x] = accum0;
    shared_data64[idx_x2] = accum1;
    __syncthreads();
    SmallInverseNTT(shared_data64, inverse_root_of_unity_table, modulus_reg,
                    n_inverse_reg, false);
    accum0 = shared_data64[idx_x];
    accum1 = shared_data64[idx_x2];

    // POST PROCESS: centered lift from [0, q) to a signed value.
    int32_t post_accum0 =
        (accum0 >= threshold)
            ? static_cast<int32_t>(
                  static_cast<int64_t>(accum0 - modulus_reg.value))
            : static_cast<int32_t>(static_cast<int64_t>(accum0));
    int32_t post_accum1 =
        (accum1 >= threshold)
            ? static_cast<int32_t>(
                  static_cast<int64_t>(accum1 - modulus_reg.value))
            : static_cast<int32_t>(static_cast<int64_t>(accum1));

    Data64 offset_o = block_x * (Data64) (k + 1) * N;
    offset_o = offset_o + (block_y * N);
    int32_t input_b_reg = input_b[block_x];
    int32_t input_b_reg_N = torus_modulus_switch_log(input_b_reg, N_power);
    input_b_reg_N = static_cast<int32_t>(N << 1) - input_b_reg_N;
    if (block_y == k)
    {
        int32_t2 temp;
        if (input_b_reg_N < N)
        {
            temp.value[0] =
                (idx_x < input_b_reg_N) ? (-encoded_reg) : encoded_reg;
            temp.value[1] =
                (idx_x2 < input_b_reg_N) ? (-encoded_reg) : encoded_reg;
        }
        else
        {
            int32_t input_b_reg_N_minus = input_b_reg_N - N;
            temp.value[0] =
                (idx_x < input_b_reg_N_minus) ? encoded_reg : (-encoded_reg);
            temp.value[1] =
                (idx_x2 < input_b_reg_N_minus) ? encoded_reg : (-encoded_reg);
        }
        post_accum0 = static_cast<int32_t>(
            static_cast<uint32_t>(post_accum0) +
            static_cast<uint32_t>(temp.value[0]));
        post_accum1 = static_cast<int32_t>(
            static_cast<uint32_t>(post_accum1) +
            static_cast<uint32_t>(temp.value[1]));
    }
    output[offset_o + idx_x] = post_accum0;
    output[offset_o + idx_x2] = post_accum1;
}
// Split bootstrapping, second half of a regular blind-rotation step.
// For accumulator polynomial blockIdx.y it:
//  1. sums the NTT-domain products produced by regular_step1, whose layout is
//     [cipher][k+1][bk_length][k+1][N];
//  2. inverse-NTTs the sum and lifts it to signed 32-bit values;
//  3. adds the result in place to the accumulator `output`, laid out as
//     [cipher][k+1][N].
// Each thread owns coefficients idx_x and idx_x + blockDim.x, so
// blockDim.x == N / 2 and N <= 1024. The in-place addition uses uint32_t so
// that torus wrap-around is well defined.
__global__ void tfhe_bootstrapping_kernel_regular_step2(
    const Data64* input, int32_t* output,
    const Root64* __restrict__ inverse_root_of_unity_table,
    const Ninverse64 n_inverse, const Modulus64 modulus, int n, int N,
    int k, int bk_length)
{
    __shared__ Data64 shared_data64[1024];
    int idx_x = threadIdx.x;
    int idx_x2 = idx_x + blockDim.x; // second coefficient owned by this thread
    int block_x = blockIdx.x; // cipher size
    int block_y = blockIdx.y; // k
    Data64 offset_i =
        block_x * (Data64) (k + 1) * ((Data64) bk_length * (k + 1) * N);
    const Modulus64 modulus_reg = modulus;
    const Data64 threshold = modulus_reg.value >> 1;
    const Ninverse64 n_inverse_reg = n_inverse;

    Data64 accum0 = 0ULL;
    Data64 accum1 = 0ULL;
    for (int i = 0; i < (k + 1); i++)
    {
        Data64 offset_i2 = i * ((Data64) bk_length * (k + 1) * N);
#pragma unroll
        for (int j = 0; j < bk_length; j++)
        {
            Data64 offset_i3 =
                (offset_i + offset_i2) + (j * ((Data64) (k + 1) * N));
            offset_i3 = offset_i3 + (block_y * N);
            Data64 value0 = input[offset_i3 + idx_x];
            Data64 value1 = input[offset_i3 + idx_x2];
            accum0 = OPERATOR_GPU_64::add(accum0, value0, modulus_reg);
            accum1 = OPERATOR_GPU_64::add(accum1, value1, modulus_reg);
        }
    }
    shared_data64[idx_x] = accum0;
    shared_data64[idx_x2] = accum1;
    __syncthreads();
    SmallInverseNTT(shared_data64, inverse_root_of_unity_table, modulus_reg,
                    n_inverse_reg, false);
    accum0 = shared_data64[idx_x];
    accum1 = shared_data64[idx_x2];

    // POST PROCESS: centered lift from [0, q) to a signed value.
    int32_t post_accum0 =
        (accum0 >= threshold)
            ? static_cast<int32_t>(
                  static_cast<int64_t>(accum0 - modulus_reg.value))
            : static_cast<int32_t>(static_cast<int64_t>(accum0));
    int32_t post_accum1 =
        (accum1 >= threshold)
            ? static_cast<int32_t>(
                  static_cast<int64_t>(accum1 - modulus_reg.value))
            : static_cast<int32_t>(static_cast<int64_t>(accum1));

    Data64 offset_o = block_x * (Data64) (k + 1) * N;
    offset_o = offset_o + (block_y * N);
    output[offset_o + idx_x] =
        static_cast<int32_t>(static_cast<uint32_t>(output[offset_o + idx_x]) +
                             static_cast<uint32_t>(post_accum0));
    output[offset_o + idx_x2] =
        static_cast<int32_t>(static_cast<uint32_t>(output[offset_o + idx_x2]) +
                             static_cast<uint32_t>(post_accum1));
}
// Sample extraction: turns each GLWE accumulator (input, [cipher][k+1][N])
// into an LWE ciphertext of dimension k*N that encrypts coefficient `index`
// of the body polynomial.
//  * output_a is [cipher][k*N] and output_b is [cipher].
//  * Launch layout: blockIdx.x = ciphertext, blockIdx.y = mask polynomial
//    (0..k-1). Each thread handles coefficients idx_x and idx_x + blockDim.x,
//    so blockDim.x == N / 2.
//  * Precondition: 0 <= index < N.
// For mask polynomial a, the extracted coefficients are
//   a'_j = a[index - j]          for j <= index,
//   a'_j = -a[N + index - j]     for j >  index,
// and b' = b[index]. For index == 0 this reduces to
// (a[0], -a[N-1], ..., -a[1]) and b[0].
__global__ void tfhe_sample_extraction_kernel(const int32_t* input,
                                              int32_t* output_a,
                                              int32_t* output_b, int N,
                                              int k, int index)
{
    int idx_x = threadIdx.x;
    int idx_x2 = idx_x + blockDim.x;
    int block_x = blockIdx.x; // cipher size
    int block_y = blockIdx.y; // k
    Data64 offset_i = block_x * (Data64) (k + 1) * N;
    Data64 offset_i2 = offset_i + (block_y * N);

    // Negation is done in uint32_t so that INT32_MIN wraps instead of
    // overflowing.
    int32_t value0 =
        (idx_x <= index)
            ? input[offset_i2 + (index - idx_x)]
            : static_cast<int32_t>(
                  0u - static_cast<uint32_t>(
                           input[offset_i2 + N + index - idx_x]));
    int32_t value1 =
        (idx_x2 <= index)
            ? input[offset_i2 + (index - idx_x2)]
            : static_cast<int32_t>(
                  0u - static_cast<uint32_t>(
                           input[offset_i2 + N + index - idx_x2]));

    Data64 offset_o = block_x * (Data64) k * N;
    Data64 offset_o2 = offset_o + (block_y * N);
    output_a[offset_o2 + idx_x] = value0;
    output_a[offset_o2 + idx_x2] = value1;
    if ((idx_x == 0) && (block_y == 0))
    {
        // b' is coefficient `index` of the body polynomial (polynomial k).
        Data64 offset_i3 = offset_i + (k * N);
        output_b[block_x] = input[offset_i3 + index];
    }
}
// TODO: this kernel can be made more efficient.
//
// Key switching of (k*N)-dimensional LWE ciphertexts back to dimension n,
// one ciphertext per block.
//  * input_a is [cipher][k*N], input_b is [cipher]; output_a is [cipher][n]
//    and output_b is [cipher].
//  * With base = 2^ks_base_bit_, ks_key_a is [k*N][ks_length_][base-1][n] and
//    ks_key_b is [k*N][ks_length_][base-1]. A non-zero digit d selects entry
//    d - 1; a zero digit contributes nothing.
//  * Each coefficient a_i is rounded and split into ks_length_ unsigned digits
//    of ks_base_bit_ bits. The result is
//      a' = -sum KS_a[i][l][d_il] and b' = b - sum KS_b[i][l][d_il].
//  * Precondition: each thread keeps at most two output coefficients in
//    registers, so n must not exceed 2 * blockDim.x.
// Torus arithmetic is done in uint32_t so that wrap-around is well defined.
__global__ void tfhe_key_switching_kernel(
    const int32_t* input_a, const int32_t* input_b, int32_t* output_a,
    int32_t* output_b, const int32_t* ks_key_a, const int32_t* ks_key_b,
    int ks_base_bit_, int ks_length_, int n, int N, int k)
{
    int idx_x = threadIdx.x;
    int block_x = blockIdx.x; // cipher size
    int ks_base_bit_reg = ks_base_bit_;
    int ks_length_reg = ks_length_;
    int n_reg = n;
    int base = 1 << ks_base_bit_reg;
    // Rounding offset: half of the least significant kept digit.
    uint32_t precision_offset =
        1u << (32 - (1 + ks_base_bit_reg * ks_length_reg));
    int mask = base - 1; // also the number of non-zero digit values
    uint32_t mask_u = static_cast<uint32_t>(mask);
    int Nk_reg = N * k;
    int offset_i = block_x * Nk_reg;

    int32_t accum_a[2] = {0, 0};
    int32_t accum_b = 0;
    if (idx_x == 0)
    {
        accum_b = input_b[block_x];
    }

    for (int i = 0; i < Nk_reg; i++)
    {
        uint32_t rounded =
            static_cast<uint32_t>(input_a[offset_i + i]) + precision_offset;
        int offset_key_b_i = i * ks_length_reg * mask;
        int offset_key_a_i = offset_key_b_i * n_reg;
        for (int i2 = 0; i2 < ks_length_reg; i2++)
        {
            int shift = 32 - ((i2 + 1) * ks_base_bit_reg);
            int32_t digit = static_cast<int32_t>((rounded >> shift) & mask_u);
            if (digit == 0)
            {
                continue;
            }
            int offset_key_b_i2 = offset_key_b_i + (i2 * mask);
            int offset_key_a_i2 = offset_key_a_i + (i2 * mask * n_reg) +
                                  ((digit - 1) * n_reg);
            int count = 0;
            for (int i3 = idx_x; i3 < n; i3 += blockDim.x)
            {
                uint32_t ks_key_a_reg = ks_key_a[offset_key_a_i2 + i3];
                accum_a[count] = accum_a[count] - ks_key_a_reg;
                count++;
            }
            if (idx_x == 0)
            {
                uint32_t ks_key_b_reg =
                    ks_key_b[offset_key_b_i2 + (digit - 1)];
                accum_b = accum_b - ks_key_b_reg;
            }
        }
    }

    int offset_o = block_x * n_reg;
    int count = 0;
    for (int i = idx_x; i < n; i += blockDim.x)
    {
        output_a[offset_o + i] = accum_a[count];
        count++;
    }
    if (idx_x == 0)
    {
        output_b[block_x] = accum_b;
    }
}
} // namespace heongpu