// Program listing for file switchkey.cuh
// (documentation for file src/include/heongpu/kernel/switchkey.cuh)
// Copyright 2024-2026 Alişah Özcan
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
// Developer: Alişah Özcan
#ifndef HEONGPU_SWITCHKEY_H
#define HEONGPU_SWITCHKEY_H
#include "cuda_runtime.h"
#include "gpuntt/common/modular_arith.cuh"
namespace heongpu
{
// Kernel entry point (declaration only; the body is not in this header).
// Receives an input buffer, an output buffer and a Modulus64 table; all
// pointers must be device-accessible because this is a __global__ function.
// NOTE(review): the name suggests the input is replicated ("broadcast")
// across rns_mod_count RNS moduli — confirm against the definition.
// n_power is presumably log2 of the polynomial degree — TODO confirm.
__global__ void cipher_broadcast_kernel(Data64* input, Data64* output,
Modulus64* modulus, int n_power,
int rns_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// Unlike cipher_broadcast_kernel it takes two modulus counts: the initial
// one (first_rns_mod_count) and the current one (current_rns_mod_count).
// NOTE(review): presumably the first count is used to index the full modulus
// table while only the current count is processed at this level — verify in
// the definition. n_power is presumably log2 of the ring degree.
__global__ void
cipher_broadcast_leveled_kernel(Data64* input, Data64* output,
Modulus64* modulus, int first_rns_mod_count,
int current_rns_mod_count, int n_power);
// Kernel entry point (declaration only; the body is not in this header).
// `relinkey` is read-only through this pointer and marked __restrict__, so
// the caller promises it does not alias memory written via the other
// pointers.
// NOTE(review): by its name this multiplies `input` by the key-switching key
// and accumulates into `output`; the meaning of Q_tilda_size and of the two
// iteration counts is not visible from the declaration — confirm in the
// definition.
__global__ void keyswitch_multiply_accumulate_kernel(
Data64* input, const Data64* __restrict__ relinkey, Data64* output,
Modulus64* modulus, int n_power, int Q_tilda_size, int iteration_count1,
int iteration_count2);
// Kernel entry point (declaration only; the body is not in this header).
// `relinkey` is read-only through this pointer and marked __restrict__.
// Note the parameter order: n_power is the LAST argument here, whereas the
// non-leveled keyswitch_multiply_accumulate_kernel takes it right after
// `modulus`.
// NOTE(review): presumably the leveled form of the key-switch
// multiply-accumulate, using first_rns_mod_count to address the full modulus
// table and current_decomp_mod_count for the active level — confirm.
__global__ void keyswitch_multiply_accumulate_leveled_kernel(
Data64* input, const Data64* __restrict__ relinkey, Data64* output,
Modulus64* modulus, int first_rns_mod_count,
int current_decomp_mod_count, int iteration_count1,
int iteration_count2, int n_power);
// Kernel entry point (declaration only; the body is not in this header).
// Same leading parameters as keyswitch_multiply_accumulate_leveled_kernel,
// plus current_rns_mod_count and level; n_power is again the last argument.
// `relinkey` is read-only through this pointer and marked __restrict__.
// NOTE(review): "method_II" presumably refers to the library's second
// key-switching method; how `level` is used is not visible here — confirm in
// the definition.
__global__ void keyswitch_multiply_accumulate_leveled_method_II_kernel(
Data64* input, const Data64* __restrict__ relinkey, Data64* output,
Modulus64* modulus, int first_rns_mod_count,
int current_decomp_mod_count, int current_rns_mod_count,
int iteration_count1, int iteration_count2, int level, int n_power);
// Kernel entry point (declaration only; the body is not in this header).
// NOTE(review): the name together with the half / half_mod / last_q_modinv
// arguments suggests a rounded division by the last RNS modulus (add half,
// then multiply by the modular inverse of the last q), with the result
// combined with `ct` into `output`. This is inferred from names only —
// confirm against the definition.
__global__ void
divide_round_lastq_kernel(Data64* input, Data64* ct, Data64* output,
Modulus64* modulus, Data64* half,
Data64* half_mod, Data64* last_q_modinv,
int n_power, int decomp_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// Takes exactly the same parameter list as divide_round_lastq_kernel; how
// the two kernels differ is not visible from the declarations.
// NOTE(review): presumably the variant used on the switch-key path, where
// `ct` is combined differently into `output` — confirm in the definition.
__global__ void divide_round_lastq_switchkey_kernel(
Data64* input, Data64* ct, Data64* output, Modulus64* modulus,
Data64* half, Data64* half_mod, Data64* last_q_modinv, int n_power,
int decomp_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// First of a two-kernel sequence (see the *_stage_two_* declarations in this
// header). It takes half / half_mod but neither `ct` nor last_q_modinv.
// NOTE(review): presumably performs the "add half and reduce" part of the
// rounded division for the last modulus at the current level — confirm.
__global__ void divide_round_lastq_leveled_stage_one_kernel(
Data64* input, Data64* output, Modulus64* modulus, Data64* half,
Data64* half_mod, int n_power, int first_decomp_count,
int current_decomp_count);
// Kernel entry point (declaration only; the body is not in this header).
// Takes two inputs (input_last and input), a ciphertext buffer `ct`, an
// output buffer, the modulus table and last_q_modinv.
// NOTE(review): presumably input_last is the stage-one result for the last
// modulus, which is subtracted from `input`, scaled by last_q_modinv and
// combined with `ct` — inferred from names only; confirm in the definition.
__global__ void divide_round_lastq_leveled_stage_two_kernel(
Data64* input_last, Data64* input, Data64* ct, Data64* output,
Modulus64* modulus, Data64* last_q_modinv, int n_power,
int current_decomp_count);
// Kernel entry point (declaration only; the body is not in this header).
// Takes exactly the same parameter list as
// divide_round_lastq_leveled_stage_two_kernel; the behavioral difference is
// not visible from the declarations.
// NOTE(review): presumably the switch-key flavor of stage two — confirm how
// `ct` is used in the definition.
__global__ void divide_round_lastq_leveled_stage_two_switchkey_kernel(
Data64* input_last, Data64* input, Data64* ct, Data64* output,
Modulus64* modulus, Data64* last_q_modinv, int n_power,
int current_decomp_count);
// Kernel entry point (declaration only; the body is not in this header).
// Takes no modulus table, only an input buffer, an output buffer and two
// sizes.
// NOTE(review): presumably a plain copy/re-layout of ciphertext data for
// current_decomp_count moduli with no modular arithmetic — confirm.
__global__ void move_cipher_leveled_kernel(Data64* input, Data64* output,
int n_power,
int current_decomp_count);
// Kernel entry point (declaration only; the body is not in this header).
// Same parameters as divide_round_lastq_leveled_stage_two_kernel except that
// there is no `ct` argument.
// NOTE(review): presumably the rescale variant that writes the divided
// result directly to `output` without adding a ciphertext — confirm.
__global__ void divide_round_lastq_rescale_kernel(
Data64* input_last, Data64* input, Data64* output, Modulus64* modulus,
Data64* last_q_modinv, int n_power, int current_decomp_count);
// Kernel entry point (declaration only; the body is not in this header).
// Takes two modulus tables (`modulus` and `B_base`), three precomputed
// Data64 tables (base_change_matrix_D_to_B, Mi_inv_D_to_B, prod_D_to_B) and
// two int index arrays (I_j_, I_location_).
// NOTE(review): presumably an RNS base conversion from base D to base B used
// during relinearization; the layout of the tables and the meaning of
// l, d_tilda, d and r_prime are not visible here — confirm in the definition.
__global__ void base_conversion_DtoB_relin_kernel(
Data64* ciphertext, Data64* output, Modulus64* modulus,
Modulus64* B_base, Data64* base_change_matrix_D_to_B,
Data64* Mi_inv_D_to_B, Data64* prod_D_to_B, int* I_j_, int* I_location_,
int n_power, int l, int d_tilda, int d, int r_prime);
// Kernel entry point (declaration only; the body is not in this header).
// Takes a single modulus table, three precomputed Data64 tables
// (base_change_matrix_D_to_Qtilda, Mi_inv_D_to_Qtilda, prod_D_to_Qtilda) and
// two int index arrays (I_j_, I_location_).
// NOTE(review): presumably an RNS base conversion from base D to the
// extended base Q-tilde used during relinearization. Note the mixed spelling
// "Qtilde" (function name) vs "Qtilda" (parameters). Confirm semantics of
// l, Q_tilda and d in the definition.
__global__ void base_conversion_DtoQtilde_relin_kernel(
Data64* ciphertext, Data64* output, Modulus64* modulus,
Data64* base_change_matrix_D_to_Qtilda, Data64* Mi_inv_D_to_Qtilda,
Data64* prod_D_to_Qtilda, int* I_j_, int* I_location_, int n_power,
int l, int Q_tilda, int d);
// Kernel entry point (declaration only; the body is not in this header).
// Compared with base_conversion_DtoB_relin_kernel it drops the `l` argument
// and adds a trailing int array `mod_index`.
// NOTE(review): presumably mod_index selects which moduli are active at the
// current level — confirm in the definition.
__global__ void base_conversion_DtoB_relin_leveled_kernel(
Data64* ciphertext, Data64* output, Modulus64* modulus,
Modulus64* B_base, Data64* base_change_matrix_D_to_B,
Data64* Mi_inv_D_to_B, Data64* prod_D_to_B, int* I_j_, int* I_location_,
int n_power, int d_tilda, int d, int r_prime, int* mod_index);
// Kernel entry point (declaration only; the body is not in this header).
// Compared with base_conversion_DtoQtilde_relin_kernel it replaces the
// (l, Q_tilda, d) sizes with (d, current_Qtilda_size, current_Q_size, level)
// and adds a trailing int array `mod_index`.
// NOTE(review): presumably the level-aware D -> Q-tilde base conversion,
// with mod_index selecting the active moduli — confirm in the definition.
__global__ void base_conversion_DtoQtilde_relin_leveled_kernel(
Data64* ciphertext, Data64* output, Modulus64* modulus,
Data64* base_change_matrix_D_to_Qtilda, Data64* Mi_inv_D_to_Qtilda,
Data64* prod_D_to_Qtilda, int* I_j_, int* I_location_, int n_power,
int d, int current_Qtilda_size, int current_Q_size, int level,
int* mod_index);
// Kernel entry point (declaration only; the body is not in this header).
// NOTE(review): unlike the keyswitch_multiply_accumulate_* declarations in
// this header, `relinkey` is a plain `Data64*` here (no const, no
// __restrict__). If the kernel only reads the key, qualifying it would be
// consistent, but the definition must be changed together with this
// prototype.
// NOTE(review): presumably multiplies `input` by the key in the extended
// base B_prime and accumulates into `output` — confirm in the definition.
__global__ void multiply_accumulate_extended_kernel(
Data64* input, Data64* relinkey, Data64* output, Modulus64* B_prime,
int n_power, int d_tilda, int d, int r_prime);
// Kernel entry point (declaration only; the body is not in this header).
// Takes two modulus tables (`modulus` and `B_base`), three precomputed
// Data64 tables (base_change_matrix_B_to_D, Mi_inv_B_to_D, prod_B_to_D) and
// two int index arrays (I_j_, I_location_).
// NOTE(review): presumably the inverse direction of the D -> B conversion,
// i.e. an RNS base conversion from base B back to base D; the meaning of
// l_tilda, d_tilda, d and r_prime is not visible here — confirm.
__global__ void base_conversion_BtoD_relin_kernel(
Data64* input, Data64* output, Modulus64* modulus, Modulus64* B_base,
Data64* base_change_matrix_B_to_D, Data64* Mi_inv_B_to_D,
Data64* prod_B_to_D, int* I_j_, int* I_location_, int n_power,
int l_tilda, int d_tilda, int d, int r_prime);
// Kernel entry point (declaration only; the body is not in this header).
// Same parameter list as base_conversion_BtoD_relin_kernel with one extra
// trailing int array, `mod_index`.
// NOTE(review): presumably mod_index selects the moduli active at the
// current level — confirm in the definition.
__global__ void base_conversion_BtoD_relin_leveled_kernel(
Data64* input, Data64* output, Modulus64* modulus, Modulus64* B_base,
Data64* base_change_matrix_B_to_D, Data64* Mi_inv_B_to_D,
Data64* prod_B_to_D, int* I_j_, int* I_location_, int n_power,
int l_tilda, int d_tilda, int d, int r_prime, int* mod_index);
// Kernel entry point (declaration only; the body is not in this header).
// Compared with divide_round_lastq_kernel, the single decomp_mod_count is
// replaced by three sizes: Q_prime_size, Q_size and P_size.
// NOTE(review): presumably a rounded division by the P_size auxiliary
// moduli of the extended base (Q_prime = Q plus P — assumption), with the
// result combined with `ct` into `output` — confirm in the definition.
__global__ void divide_round_lastq_extended_kernel(
Data64* input, Data64* ct, Data64* output, Modulus64* modulus,
Data64* half, Data64* half_mod, Data64* last_q_modinv, int n_power,
int Q_prime_size, int Q_size, int P_size);
// Kernel entry point (declaration only; the body is not in this header).
// Takes exactly the same parameter list as
// divide_round_lastq_extended_kernel; the behavioral difference is not
// visible from the declarations.
// NOTE(review): presumably the switch-key flavor — confirm how `ct` is used
// in the definition.
__global__ void divide_round_lastq_extended_switchkey_kernel(
Data64* input, Data64* ct, Data64* output, Modulus64* modulus,
Data64* half, Data64* half_mod, Data64* last_q_modinv, int n_power,
int Q_prime_size, int Q_size, int P_size);
// Kernel entry point (declaration only; the body is not in this header).
// Has no `ct` argument, and receives both current sizes (Q_prime_size,
// Q_size) and initial sizes (first_Q_prime_size, first_Q_size) in addition
// to P_size.
// NOTE(review): presumably the level-aware extended divide-and-round, using
// the "first" sizes to address full-size tables — confirm in the definition.
__global__ void divide_round_lastq_extended_leveled_kernel(
Data64* input, Data64* output, Modulus64* modulus, Data64* half,
Data64* half_mod, Data64* last_q_modinv, int n_power, int Q_prime_size,
int Q_size, int first_Q_prime_size, int first_Q_size, int P_size);
// TODO: Find a more efficient way!
// Kernel entry point (declaration only; the body is not in this header).
// Takes only an input buffer, an output buffer and n_power (no modulus
// table).
// NOTE(review): presumably a device-side copy from `input` to `output`; the
// TODO suggests it is a known stop-gap — confirm in the definition.
__global__ void global_memory_replace_kernel(Data64* input, Data64* output,
int n_power);
// Kernel entry point (declaration only; the body is not in this header).
// Same as global_memory_replace_kernel's parameter list plus
// current_decomposition_count, placed before n_power.
// NOTE(review): presumably a copy in which current_decomposition_count
// determines an offset/stride between source and destination — confirm.
__global__ void
global_memory_replace_offset_kernel(Data64* input, Data64* output,
int current_decomposition_count,
int n_power);
// Kernel entry point (declaration only; the body is not in this header).
// Takes one source buffer (`cipher`) and two destination buffers (out0,
// out1) plus the modulus table.
// NOTE(review): presumably splits the ciphertext so that one component is
// copied to out0 and the other is broadcast across decomp_mod_count moduli
// into out1 — inferred from names only; confirm in the definition.
__global__ void
cipher_broadcast_switchkey_kernel(Data64* cipher, Data64* out0,
Data64* out1, Modulus64* modulus,
int n_power, int decomp_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// Takes exactly the same parameter list as
// cipher_broadcast_switchkey_kernel; the behavioral difference is not
// visible from the declarations.
// NOTE(review): presumably the variant used by the library's second
// key-switching method — confirm in the definition.
__global__ void cipher_broadcast_switchkey_method_II_kernel(
Data64* cipher, Data64* out0, Data64* out1, Modulus64* modulus,
int n_power, int decomp_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// Note: unlike most kernels declared in this header, the name has no
// `_kernel` suffix.
// NOTE(review): presumably computes out = in1 + in2 element-wise modulo the
// entries of `modulus` — confirm in the definition.
__global__ void addition_switchkey(Data64* in1, Data64* in2, Data64* out,
Modulus64* modulus, int n_power);
// Kernel entry point (declaration only; the body is not in this header).
// Compared with cipher_broadcast_switchkey_kernel, the single
// decomp_mod_count is replaced by first_rns_mod_count,
// current_rns_mod_count and current_decomp_mod_count.
// NOTE(review): presumably the level-aware form of the switch-key broadcast
// — confirm how the three counts are used in the definition.
__global__ void cipher_broadcast_switchkey_leveled_kernel(
Data64* cipher, Data64* out0, Data64* out1, Modulus64* modulus,
int n_power, int first_rns_mod_count, int current_rns_mod_count,
int current_decomp_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// NOTE(review): by its name this applies a negacyclic shift by `shift`
// positions to polynomial coefficients (i.e. multiplication by X^shift
// modulo X^N + 1, with sign flips on wrap-around — assumption), reducing
// modulo the entries of `modulus`. The valid range and sign convention of
// `shift` are not visible here — confirm in the definition.
__global__ void negacyclic_shift_poly_coeffmod_kernel(Data64* cipher_in,
Data64* cipher_out,
Modulus64* modulus,
int shift,
int n_power);
// Double Hoisting Kernels
// NTT-domain Galois automorphism: permutes NTT slots directly,
// avoiding the INTT -> coefficient-permute -> NTT round trip.
// perm_table: precomputed gather indices of size N, where
// perm(j) = br((galois_elt * (2*br(j)+1) % 2N - 1) / 2)
// (br is presumably the bit-reversal of an index — TODO confirm).
// NOTE(review): the signature below takes `galois_elt` and has no
// perm_table pointer, so the perm_table description above may be stale;
// presumably perm(j) is computed inside the kernel — confirm against the
// definition.
// Launch configuration: grid = dim3((n >> 8), pql_count, 2), block = 256.
__global__ void galois_permute_ntt_pql_kernel(Data64* input, Data64* output,
int galois_elt, int n_power,
int pql_count);
// Kernel entry point (declaration only; the body is not in this header).
// NOTE(review): the name and the P_mod_q argument suggest that the NTT-form
// input `c_ntt` is multiplied by P (mod each q) and written across
// pql_count moduli of `pq_modulus` into `output`; the role of
// current_decomp_count is not visible here — confirm in the definition.
__global__ void broadcast_scale_P_kernel(Data64* c_ntt, Data64* output,
Data64* P_mod_q,
Modulus64* pq_modulus, int n_power,
int current_decomp_count,
int pql_count);
// Kernel entry point (declaration only; the body is not in this header).
// NOTE(review): presumably computes out = in1 + in2 element-wise modulo the
// entries of `pq_modulus`, over pql_count moduli — confirm in the
// definition.
__global__ void addition_pql_kernel(Data64* in1, Data64* in2, Data64* out,
Modulus64* pq_modulus, int n_power,
int pql_count);
// Optimized Hoisting-Rotations
// Kernel entry point (declaration only; the body is not in this header).
// Takes one source buffer and one destination buffer plus three modulus
// counts (initial RNS, current RNS, current decomposition).
// NOTE(review): presumably duplicates a CKKS ciphertext component across
// the decomposition moduli for hoisted rotations — inferred from the name
// only; confirm in the definition.
__global__ void ckks_duplicate_kernel(Data64* cipher, Data64* output,
Modulus64* modulus, int n_power,
int first_rns_mod_count,
int current_rns_mod_count,
int current_decomp_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// Unlike ckks_duplicate_kernel it writes to two destination buffers
// (output1, output2) and takes a single rns_mod_count.
// NOTE(review): presumably duplicates a BFV ciphertext for hoisted
// rotations; what goes into each output is not visible here — confirm.
__global__ void bfv_duplicate_kernel(Data64* cipher, Data64* output1,
Data64* output2, Modulus64* modulus,
int n_power, int rns_mod_count);
// Kernel entry point (declaration only; the body is not in this header).
// Combines the arguments of an extended leveled divide-and-round (half,
// half_mod, last_q_modinv, current and "first" Q / Q_prime sizes, P_size)
// with a Galois element and a second input buffer.
// NOTE(review): presumably fuses the rounded division with the Galois
// permutation for CKKS; the role of `input2` is not visible here — confirm
// in the definition.
__global__ void divide_round_lastq_permute_ckks_kernel(
Data64* input, Data64* input2, Data64* output, Modulus64* modulus,
Data64* half, Data64* half_mod, Data64* last_q_modinv, int galois_elt,
int n_power, int Q_prime_size, int Q_size, int first_Q_prime_size,
int first_Q_size, int P_size);
// Kernel entry point (declaration only; the body is not in this header).
// Same parameter list as divide_round_lastq_extended_kernel with an extra
// `galois_elt` inserted before n_power.
// NOTE(review): presumably fuses the rounded division with the Galois
// permutation for BFV and combines the result with `ct` — confirm in the
// definition.
__global__ void divide_round_lastq_permute_bfv_kernel(
Data64* input, Data64* ct, Data64* output, Modulus64* modulus,
Data64* half, Data64* half_mod, Data64* last_q_modinv, int galois_elt,
int n_power, int Q_prime_size, int Q_size, int P_size);
} // namespace heongpu
#endif // HEONGPU_SWITCHKEY_H