Program Listing for File evaluationkey.cu

Return to documentation for file (src/lib/host/ckks/evaluationkey.cu)

// Copyright 2024-2026 Alişah Özcan
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
// Developer: Alişah Özcan

#include <heongpu/host/ckks/evaluationkey.cuh>

namespace heongpu
{

    __host__ Relinkey<Scheme::CKKS>::Relinkey(HEContext<Scheme::CKKS>& context)
    {
        if (!context.context_generated_)
        {
            throw std::invalid_argument("HEContext is not generated!");
        }

        scheme_ = context.scheme_;
        key_type = context.keyswitching_type_;

        ring_size = context.n;
        Q_prime_size_ = context.Q_prime_size;
        Q_size_ = context.Q_size;

        switch (static_cast<int>(context.keyswitching_type_))
        {
            case 1: // KEYSWITCHING_METHOD_I
            {
                relinkey_size_ = 2 * Q_size_ * Q_prime_size_ * ring_size;
            }
            break;
            case 2: // KEYSWITCHING_METHOD_II
            {
                d_ = context.d_leveled->operator[](0);
                relinkey_size_ = 2 * d_ * Q_prime_size_ * ring_size;
            }
            break;
            case 3: // KEYSWITCHING_METHOD_III
            {
                d_ = context.d_leveled->operator[](0);
                d_tilda_ = context.d_tilda_leveled->operator[](0);
                r_prime_ = context.r_prime_leveled;

                int max_depth = Q_size_ - 1;
                for (int i = 0; i < max_depth; i++)
                {
                    relinkey_size_leveled_.push_back(
                        2 * context.d_leveled->operator[](i) *
                        context.d_tilda_leveled->operator[](i) * r_prime_ *
                        ring_size);
                }
            }
            break;
            default:
                break;
        }
    }

    void Relinkey<Scheme::CKKS>::store_in_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            if ((key_type == keyswitching_type::KEYSWITCHING_METHOD_III) &&
                scheme_ == scheme_type::ckks)
            {
                int max_depth = Q_size_ - 1;
                for (int i = 0; i < max_depth; i++)
                {
                    device_location_leveled_.push_back(
                        std::move(DeviceVector<Data64>(
                            host_location_leveled_[i], stream)));
                    host_location_leveled_[i].resize(0);
                    host_location_leveled_[i].shrink_to_fit();
                }
                host_location_leveled_.resize(0);
                host_location_leveled_.shrink_to_fit();
            }
            else
            {
                device_location_ = DeviceVector<Data64>(host_location_, stream);
                host_location_.resize(0);
                host_location_.shrink_to_fit();
            }

            storage_type_ = storage_type::DEVICE;
        }
    }

    void Relinkey<Scheme::CKKS>::store_in_host(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            if ((key_type == keyswitching_type::KEYSWITCHING_METHOD_III) &&
                scheme_ == scheme_type::ckks)
            {
                int max_depth = Q_size_ - 1;
                for (int i = 0; i < max_depth; i++)
                {
                    host_location_leveled_.push_back(
                        HostVector<Data64>(relinkey_size_leveled_[i]));

                    cudaMemcpyAsync(host_location_leveled_[i].data(),
                                    device_location_leveled_[i].data(),
                                    relinkey_size_leveled_[i] * sizeof(Data64),
                                    cudaMemcpyDeviceToHost, stream);
                    HEONGPU_CUDA_CHECK(cudaGetLastError());

                    device_location_leveled_[i].resize(0, stream);
                }
                device_location_leveled_.resize(0);
                device_location_leveled_.shrink_to_fit();
            }
            else
            {
                host_location_ = HostVector<Data64>(relinkey_size_);
                cudaMemcpyAsync(host_location_.data(), device_location_.data(),
                                relinkey_size_ * sizeof(Data64),
                                cudaMemcpyDeviceToHost, stream);
                HEONGPU_CUDA_CHECK(cudaGetLastError());

                device_location_.resize(0, stream);
            }

            storage_type_ = storage_type::HOST;
        }
        else
        {
            // pass
        }
    }

    Data64* Relinkey<Scheme::CKKS>::data()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            return device_location_.data();
        }
        else
        {
            return host_location_.data();
        }
    }
    Data64* Relinkey<Scheme::CKKS>::data(size_t i)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            return device_location_leveled_[i].data();
        }
        else
        {
            return host_location_leveled_[i].data();
        }
    }

    void Relinkey<Scheme::CKKS>::save(std::ostream& os) const
    {
        if (key_type == keyswitching_type::KEYSWITCHING_METHOD_III)
        {
            throw std::runtime_error(
                "Relinkey has not serialization for KEYSWITCHING_METHOD_III!");
        }

        if (relin_key_generated_)
        {
            os.write((char*) &scheme_, sizeof(scheme_));

            os.write((char*) &key_type, sizeof(key_type));

            os.write((char*) &ring_size, sizeof(ring_size));

            os.write((char*) &Q_prime_size_, sizeof(Q_prime_size_));

            os.write((char*) &Q_size_, sizeof(Q_size_));

            os.write((char*) &d_, sizeof(d_));

            os.write((char*) &d_tilda_, sizeof(d_tilda_));

            os.write((char*) &r_prime_, sizeof(r_prime_));

            os.write((char*) &storage_type_, sizeof(storage_type_));

            os.write((char*) &relin_key_generated_,
                     sizeof(relin_key_generated_));

            os.write((char*) &relinkey_size_, sizeof(relinkey_size_));

            if (storage_type_ == storage_type::DEVICE)
            {
                HostVector<Data64> host_locations_temp(relinkey_size_);
                cudaMemcpy(host_locations_temp.data(), device_location_.data(),
                           relinkey_size_ * sizeof(Data64),
                           cudaMemcpyDeviceToHost);
                HEONGPU_CUDA_CHECK(cudaGetLastError());
                cudaDeviceSynchronize();

                os.write((char*) host_locations_temp.data(),
                         sizeof(Data64) * relinkey_size_);
            }
            else
            {
                os.write((char*) host_location_.data(),
                         sizeof(Data64) * relinkey_size_);
            }
        }
        else
        {
            throw std::runtime_error(
                "Relinkey is not generated so can not be serialized!");
        }
    }

    void Relinkey<Scheme::CKKS>::load(std::istream& is)
    {
        if ((!relin_key_generated_))
        {
            is.read((char*) &scheme_, sizeof(scheme_));

            if (scheme_ != scheme_type::ckks)
            {
                throw std::runtime_error("Invalid scheme binary!");
            }

            is.read((char*) &key_type, sizeof(key_type));

            is.read((char*) &ring_size, sizeof(ring_size));

            is.read((char*) &Q_prime_size_, sizeof(Q_prime_size_));

            is.read((char*) &Q_size_, sizeof(Q_size_));

            is.read((char*) &d_, sizeof(d_));

            is.read((char*) &d_tilda_, sizeof(d_tilda_));

            is.read((char*) &r_prime_, sizeof(r_prime_));

            is.read((char*) &storage_type_, sizeof(storage_type_));

            is.read((char*) &relin_key_generated_,
                    sizeof(relin_key_generated_));

            is.read((char*) &relinkey_size_, sizeof(relinkey_size_));

            storage_type_ = storage_type::DEVICE;
            relin_key_generated_ = true;

            HostVector<Data64> host_locations_temp(relinkey_size_);
            is.read((char*) host_locations_temp.data(),
                    sizeof(Data64) * relinkey_size_);

            device_location_.resize(relinkey_size_);
            cudaMemcpy(device_location_.data(), host_locations_temp.data(),
                       relinkey_size_ * sizeof(Data64), cudaMemcpyHostToDevice);
            HEONGPU_CUDA_CHECK(cudaGetLastError());
            cudaDeviceSynchronize();
        }
        else
        {
            throw std::runtime_error("Relinkey has been already exist!");
        }
    }

    int Relinkey<Scheme::CKKS>::memory_size()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            if (relinkey_size_leveled_.size() == 0)
            {
                return device_location_.size();
            }
            else
            {
                return device_location_leveled_[0].size();
            }
        }
        else
        {
            if (relinkey_size_leveled_.size() == 0)
            {
                return host_location_.size();
            }
            else
            {
                return host_location_leveled_[0].size();
            }
        }
    }

    void Relinkey<Scheme::CKKS>::memory_clear(cudaStream_t stream)
    {
        if (device_location_.size() > 0)
        {
            device_location_.resize(0, stream);
            device_location_.shrink_to_fit(stream);
        }

        for (int i = 0; i < device_location_leveled_.size(); i++)
        {
            device_location_leveled_[i].resize(0, stream);
            device_location_leveled_[i].shrink_to_fit(stream);
        }

        device_location_leveled_.resize(0);
        device_location_leveled_.shrink_to_fit();

        if (host_location_.size() > 0)
        {
            host_location_.resize(0);
            host_location_.shrink_to_fit();
        }

        for (int i = 0; i < host_location_leveled_.size(); i++)
        {
            host_location_leveled_[i].resize(0);
            host_location_leveled_[i].shrink_to_fit();
        }

        host_location_leveled_.resize(0);
        host_location_leveled_.shrink_to_fit();
    }

    void
    Relinkey<Scheme::CKKS>::memory_set(DeviceVector<Data64>&& new_device_vector)
    {
        storage_type_ = storage_type::DEVICE;
        device_location_ = std::move(new_device_vector);

        if (host_location_.size() > 0)
        {
            host_location_.resize(0);
            host_location_.shrink_to_fit();
        }
    }

    void
    Relinkey<Scheme::CKKS>::memory_set(DeviceVector<Data64>&& new_device_vector,
                                       int i)
    {
        storage_type_ = storage_type::DEVICE;
        device_location_leveled_[i] = std::move(new_device_vector);

        for (int i = 0; i < host_location_leveled_.size(); i++)
        {
            host_location_leveled_[i].resize(0);
            host_location_leveled_[i].shrink_to_fit();
        }

        host_location_leveled_.resize(0);
        host_location_leveled_.shrink_to_fit();
    }

    void Relinkey<Scheme::CKKS>::copy_to_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            if (memory_size() == 0)
            {
                // pass
            }
            else
            {
                if (relinkey_size_leveled_.size() == 0)
                {
                    device_location_ =
                        DeviceVector<Data64>(host_location_, stream);
                }
                else
                {
                    device_location_leveled_.resize(
                        relinkey_size_leveled_.size());

                    for (int i = 0; i < device_location_leveled_.size(); i++)
                    {
                        device_location_leveled_[i] = DeviceVector<Data64>(
                            host_location_leveled_[i], stream);
                    }
                }
            }

            storage_type_ = storage_type::DEVICE;
        }
    }

    void Relinkey<Scheme::CKKS>::remove_from_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            if (relinkey_size_leveled_.size() == 0)
            {
                device_location_.resize(0, stream);
                device_location_.shrink_to_fit(stream);
            }
            else
            {
                for (int i = 0; i < device_location_leveled_.size(); i++)
                {
                    device_location_leveled_[i].resize(0, stream);
                    device_location_leveled_[i].shrink_to_fit(stream);
                }

                device_location_leveled_.resize(0);
                device_location_leveled_.shrink_to_fit();
            }
        }
        else
        {
            // pass
        }
    }

    void Relinkey<Scheme::CKKS>::remove_from_host()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            if (relinkey_size_leveled_.size() == 0)
            {
                host_location_.resize(0);
                host_location_.shrink_to_fit();
            }
            else
            {
                for (int i = 0; i < host_location_leveled_.size(); i++)
                {
                    host_location_leveled_[i].resize(0);
                    host_location_leveled_[i].shrink_to_fit();
                }

                host_location_leveled_.resize(0);
                host_location_leveled_.shrink_to_fit();
            }
        }
    }

    __host__ MultipartyRelinkey<Scheme::CKKS>::MultipartyRelinkey(
        HEContext<Scheme::CKKS>& context, const RNGSeed seed)
        : Relinkey(context), seed_(seed)
    {
    }

    __host__
    Galoiskey<Scheme::CKKS>::Galoiskey(HEContext<Scheme::CKKS>& context)
    {
        if (!context.context_generated_)
        {
            throw std::invalid_argument("HEContext is not generated!");
        }

        scheme_ = context.scheme_;
        key_type = context.keyswitching_type_;

        ring_size = context.n;
        Q_prime_size_ = context.Q_prime_size;
        Q_size_ = context.Q_size;

        customized = false;

        group_order_ = 5;

        switch (static_cast<int>(context.keyswitching_type_))
        {
            case 1: // KEYSWITCHING_METHOD_I
            {
                galoiskey_size_ = 2 * Q_size_ * Q_prime_size_ * ring_size;

                for (int i = 0; i < MAX_SHIFT; i++)
                {
                    int power = pow(2, i);
                    galois_elt[power] =
                        steps_to_galois_elt(power, ring_size, group_order_);
                    galois_elt[(-power)] =
                        steps_to_galois_elt((-power), ring_size, group_order_);
                }

                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);

                max_shift_ = MAX_SHIFT - 1;
                max_log_slot_ = int(std::log2(ring_size >> 1));
            }
            break;
            case 2: // KEYSWITCHING_METHOD_II
            {
                for (int i = 0; i < MAX_SHIFT; i++)
                {
                    int power = pow(2, i);
                    galois_elt[power] =
                        steps_to_galois_elt(power, ring_size, group_order_);
                    galois_elt[(-power)] =
                        steps_to_galois_elt((-power), ring_size, group_order_);
                }

                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);

                max_shift_ = MAX_SHIFT - 1;
                max_log_slot_ = int(std::log2(ring_size >> 1));

                d_ = context.d_leveled->operator[](0);
                galoiskey_size_ = 2 * d_ * Q_prime_size_ * ring_size;
            }
            break;
            case 3: // KEYSWITCHING_METHOD_III
                throw std::invalid_argument(
                    "Galoiskey does not support KEYSWITCHING_METHOD_III");
                break;
            default:
                throw std::invalid_argument("Invalid Key Switching Type");
                break;
        }
    }

    __host__
    Galoiskey<Scheme::CKKS>::Galoiskey(HEContext<Scheme::CKKS>& context,
                                       int max_shift)
    {
        if (!context.context_generated_)
        {
            throw std::invalid_argument("HEContext is not generated!");
        }

        scheme_ = context.scheme_;
        key_type = context.keyswitching_type_;

        ring_size = context.n;
        Q_prime_size_ = context.Q_prime_size;
        Q_size_ = context.Q_size;

        customized = false;

        group_order_ = 5;

        switch (static_cast<int>(context.keyswitching_type_))
        {
            case 1: // KEYSWITCHING_METHOD_I
            {
                galoiskey_size_ = 2 * Q_size_ * Q_prime_size_ * ring_size;

                for (int i = 0; i < max_shift + 1; i++)
                {
                    int power = pow(2, i);
                    galois_elt[power] =
                        steps_to_galois_elt(power, ring_size, group_order_);
                    galois_elt[(-power)] =
                        steps_to_galois_elt((-power), ring_size, group_order_);
                }

                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);

                max_shift_ = max_shift;
                max_log_slot_ = int(std::log2(ring_size >> 1));
            }
            break;
            case 2: // KEYSWITCHING_METHOD_II
            {
                for (int i = 0; i < max_shift + 1; i++)
                {
                    int power = pow(2, i);
                    galois_elt[power] =
                        steps_to_galois_elt(power, ring_size, group_order_);
                    galois_elt[(-power)] =
                        steps_to_galois_elt((-power), ring_size, group_order_);
                }

                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);

                max_shift_ = max_shift;
                max_log_slot_ = int(std::log2(ring_size >> 1));

                d_ = context.d_leveled->operator[](0);
                galoiskey_size_ = 2 * d_ * Q_prime_size_ * ring_size;
            }
            break;
            case 3: // KEYSWITCHING_METHOD_III
                throw std::invalid_argument(
                    "Galoiskey does not support KEYSWITCHING_METHOD_III");
                break;
            default:
                throw std::invalid_argument("Invalid Key Switching Type");
                break;
        }
    }

    __host__
    Galoiskey<Scheme::CKKS>::Galoiskey(HEContext<Scheme::CKKS>& context,
                                       std::vector<int>& shift_vec)
    {
        if (!context.context_generated_)
        {
            throw std::invalid_argument("HEContext is not generated!");
        }

        scheme_ = context.scheme_;
        key_type = context.keyswitching_type_;

        ring_size = context.n;
        Q_prime_size_ = context.Q_prime_size;
        Q_size_ = context.Q_size;

        customized = false;

        group_order_ = 5;

        switch (static_cast<int>(context.keyswitching_type_))
        {
            case 1: // KEYSWITCHING_METHOD_I
            {
                galoiskey_size_ = 2 * Q_size_ * Q_prime_size_ * ring_size;

                for (int shift : shift_vec)
                {
                    galois_elt[shift] =
                        steps_to_galois_elt(shift, ring_size, group_order_);
                }

                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);

                break;
            }
            case 2: // KEYSWITCHING_METHOD_II
            {
                for (int shift : shift_vec)
                {
                    galois_elt[shift] =
                        steps_to_galois_elt(shift, ring_size, group_order_);
                }

                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);

                d_ = context.d_leveled->operator[](0);
                galoiskey_size_ = 2 * d_ * Q_prime_size_ * ring_size;
            }
            break;
            case 3: // KEYSWITCHING_METHOD_III
                throw std::invalid_argument(
                    "Galoiskey does not support KEYSWITCHING_METHOD_III");
                break;
            default:
                throw std::invalid_argument("Invalid Key Switching Type");
                break;
        }
    }

    __host__
    Galoiskey<Scheme::CKKS>::Galoiskey(HEContext<Scheme::CKKS>& context,
                                       std::vector<uint32_t>& galois_elts)
    {
        if (!context.context_generated_)
        {
            throw std::invalid_argument("HEContext is not generated!");
        }

        scheme_ = context.scheme_;
        key_type = context.keyswitching_type_;

        ring_size = context.n;
        Q_prime_size_ = context.Q_prime_size;
        Q_size_ = context.Q_size;

        customized = true;

        group_order_ = 5;

        switch (static_cast<int>(context.keyswitching_type_))
        {
            case 1: // KEYSWITCHING_METHOD_I
            {
                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);
                galoiskey_size_ = 2 * Q_size_ * Q_prime_size_ * ring_size;
                custom_galois_elt = galois_elts;
            }
            break;
            case 2: // KEYSWITCHING_METHOD_II
            {
                d_ = context.d_leveled->operator[](0);
                galois_elt_zero =
                    steps_to_galois_elt(0, ring_size, group_order_);
                galoiskey_size_ = 2 * d_ * Q_prime_size_ * ring_size;
                custom_galois_elt = galois_elts;
            }
            break;
            case 3: // KEYSWITCHING_METHOD_III
                throw std::invalid_argument(
                    "Galoiskey does not support KEYSWITCHING_METHOD_III");
                break;
            default:
                throw std::invalid_argument("Invalid Key Switching Type");
                break;
        }
    }

    void Galoiskey<Scheme::CKKS>::store_in_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            for (const auto& galois_ : host_location_)
            {
                device_location_[galois_.first] =
                    DeviceVector<Data64>(galois_.second, stream);
            }

            zero_device_location_ =
                DeviceVector<Data64>(zero_host_location_, stream);

            host_location_.clear();
            zero_host_location_.resize(0);
            zero_host_location_.shrink_to_fit();

            storage_type_ = storage_type::DEVICE;
        }
    }

    void Galoiskey<Scheme::CKKS>::store_in_host(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            for (auto& galois_ : device_location_)
            {
                host_location_[galois_.first] =
                    HostVector<Data64>(galoiskey_size_);
                cudaMemcpyAsync(host_location_[galois_.first].data(),
                                galois_.second.data(),
                                galoiskey_size_ * sizeof(Data64),
                                cudaMemcpyDeviceToHost, stream);
                HEONGPU_CUDA_CHECK(cudaGetLastError());

                galois_.second.resize(0, stream);
            }

            zero_host_location_ = HostVector<Data64>(galoiskey_size_);
            cudaMemcpyAsync(zero_host_location_.data(),
                            zero_device_location_.data(),
                            galoiskey_size_ * sizeof(Data64),
                            cudaMemcpyDeviceToHost, stream);
            HEONGPU_CUDA_CHECK(cudaGetLastError());

            device_location_.clear();
            zero_device_location_.resize(0);

            storage_type_ = storage_type::HOST;
        }
        else
        {
            // pass
        }
    }

    Data64* Galoiskey<Scheme::CKKS>::data(size_t i)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            return device_location_[i].data();
        }
        else
        {
            return host_location_[i].data();
        }
    }

    Data64* Galoiskey<Scheme::CKKS>::c_data()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            return zero_device_location_.data();
        }
        else
        {
            return zero_host_location_.data();
        }
    }

    void Galoiskey<Scheme::CKKS>::save(std::ostream& os) const
    {
        if (galois_key_generated_)
        {
            os.write((char*) &scheme_, sizeof(scheme_));

            os.write((char*) &key_type, sizeof(key_type));

            os.write((char*) &ring_size, sizeof(ring_size));

            os.write((char*) &Q_prime_size_, sizeof(Q_prime_size_));

            os.write((char*) &Q_size_, sizeof(Q_size_));

            os.write((char*) &d_, sizeof(d_));

            os.write((char*) &customized, sizeof(customized));

            os.write((char*) &group_order_, sizeof(group_order_));

            os.write((char*) &storage_type_, sizeof(storage_type_));

            os.write((char*) &galois_key_generated_,
                     sizeof(galois_key_generated_));

            if (customized)
            {
                uint32_t custom_galois_elt_size = custom_galois_elt.size();
                os.write((char*) &custom_galois_elt_size,
                         sizeof(custom_galois_elt_size));
                os.write((char*) custom_galois_elt.data(),
                         sizeof(u_int32_t) * custom_galois_elt_size);
            }
            else
            {
                uint32_t galois_elt_size = galois_elt.size();
                os.write((char*) &galois_elt_size, sizeof(galois_elt_size));
                for (auto& galois : galois_elt)
                {
                    os.write((char*) &galois.first, sizeof(galois.first));
                    os.write((char*) &galois.second, sizeof(galois.second));
                }
            }

            os.write((char*) &galois_elt_zero, sizeof(galois_elt_zero));

            os.write((char*) &galoiskey_size_, sizeof(galoiskey_size_));

            if (storage_type_ == storage_type::DEVICE)
            {
                uint32_t key_count = device_location_.size();
                os.write((char*) &key_count, sizeof(key_count));

                for (auto& galois_key_mem : device_location_)
                {
                    HostVector<Data64> host_locations_temp(galoiskey_size_);
                    cudaMemcpy(host_locations_temp.data(),
                               galois_key_mem.second.data(),
                               galoiskey_size_ * sizeof(Data64),
                               cudaMemcpyDeviceToHost);
                    HEONGPU_CUDA_CHECK(cudaGetLastError());
                    cudaDeviceSynchronize();

                    os.write((char*) &galois_key_mem.first,
                             sizeof(galois_key_mem.first));
                    os.write((char*) host_locations_temp.data(),
                             sizeof(Data64) * galoiskey_size_);
                }

                HostVector<Data64> host_locations_temp(galoiskey_size_);
                cudaMemcpy(
                    host_locations_temp.data(), zero_device_location_.data(),
                    galoiskey_size_ * sizeof(Data64), cudaMemcpyDeviceToHost);
                HEONGPU_CUDA_CHECK(cudaGetLastError());
                cudaDeviceSynchronize();

                os.write((char*) host_locations_temp.data(),
                         sizeof(Data64) * galoiskey_size_);
            }
            else
            {
                uint32_t key_count = host_location_.size();
                os.write((char*) &key_count, sizeof(key_count));

                for (auto& galois_key_mem : host_location_)
                {
                    os.write((char*) &galois_key_mem.first,
                             sizeof(galois_key_mem.first));
                    os.write((char*) galois_key_mem.second.data(),
                             sizeof(Data64) * galoiskey_size_);
                }

                os.write((char*) zero_host_location_.data(),
                         sizeof(Data64) * galoiskey_size_);
            }
        }
        else
        {
            throw std::runtime_error(
                "Galoiskey is not generated so can not be serialized!");
        }
    }

    void Galoiskey<Scheme::CKKS>::load(std::istream& is)
    {
        if ((!galois_key_generated_))
        {
            is.read((char*) &scheme_, sizeof(scheme_));

            if (scheme_ != scheme_type::ckks)
            {
                throw std::runtime_error("Invalid scheme binary!");
            }

            is.read((char*) &key_type, sizeof(key_type));

            is.read((char*) &ring_size, sizeof(ring_size));

            is.read((char*) &Q_prime_size_, sizeof(Q_prime_size_));

            is.read((char*) &Q_size_, sizeof(Q_size_));

            is.read((char*) &d_, sizeof(d_));

            is.read((char*) &customized, sizeof(customized));

            is.read((char*) &group_order_, sizeof(group_order_));

            is.read((char*) &storage_type_, sizeof(storage_type_));

            is.read((char*) &galois_key_generated_,
                    sizeof(galois_key_generated_));

            storage_type_ = storage_type::DEVICE;
            galois_key_generated_ = true;

            if (customized)
            {
                uint32_t custom_galois_elt_size;
                is.read((char*) &custom_galois_elt_size,
                        custom_galois_elt_size);
                custom_galois_elt.resize(custom_galois_elt_size);
                is.read((char*) custom_galois_elt.data(),
                        sizeof(u_int32_t) * custom_galois_elt_size);
            }
            else
            {
                uint32_t galois_elt_size;
                is.read((char*) &galois_elt_size, sizeof(galois_elt_size));
                for (int i = 0; i < galois_elt_size; i++)
                {
                    int first;
                    int second;
                    is.read((char*) &first, sizeof(first));
                    is.read((char*) &second, sizeof(second));
                    galois_elt[first] = second;
                }
            }

            is.read((char*) &galois_elt_zero, sizeof(galois_elt_zero));

            is.read((char*) &galoiskey_size_, sizeof(galoiskey_size_));

            uint32_t key_count;
            is.read((char*) &key_count, sizeof(key_count));

            for (int i = 0; i < key_count; i++)
            {
                int first;
                is.read((char*) &first, sizeof(first));
                HostVector<Data64> host_locations_temp(galoiskey_size_);
                is.read((char*) host_locations_temp.data(),
                        sizeof(Data64) * galoiskey_size_);
                device_location_[first] =
                    DeviceVector<Data64>(host_locations_temp);
                cudaDeviceSynchronize();
            }

            HostVector<Data64> host_locations_temp(galoiskey_size_);
            is.read((char*) host_locations_temp.data(),
                    sizeof(Data64) * galoiskey_size_);

            zero_device_location_.resize(galoiskey_size_);
            cudaMemcpy(zero_device_location_.data(), host_locations_temp.data(),
                       galoiskey_size_ * sizeof(Data64),
                       cudaMemcpyHostToDevice);
            HEONGPU_CUDA_CHECK(cudaGetLastError());
            cudaDeviceSynchronize();
        }
        else
        {
            throw std::runtime_error("Galoiskey has been already exist!");
        }
    }

    __host__ MultipartyGaloiskey<Scheme::CKKS>::MultipartyGaloiskey(
        HEContext<Scheme::CKKS>& context, const RNGSeed seed)
        : Galoiskey(context), seed_(seed)
    {
    }

    __host__ MultipartyGaloiskey<Scheme::CKKS>::MultipartyGaloiskey(
        HEContext<Scheme::CKKS>& context, std::vector<int>& shift_vec,
        const RNGSeed seed)
        : Galoiskey(context, shift_vec), seed_(seed)
    {
    }

    __host__ MultipartyGaloiskey<Scheme::CKKS>::MultipartyGaloiskey(
        HEContext<Scheme::CKKS>& context, std::vector<uint32_t>& galois_elts,
        const RNGSeed seed)
        : Galoiskey(context, galois_elts), seed_(seed)
    {
    }

    __host__
    Switchkey<Scheme::CKKS>::Switchkey(HEContext<Scheme::CKKS>& context)
    {
        if (!context.context_generated_)
        {
            throw std::invalid_argument("HEContext is not generated!");
        }

        scheme_ = context.scheme_;
        key_type = context.keyswitching_type_;

        ring_size = context.n;
        Q_prime_size_ = context.Q_prime_size;
        Q_size_ = context.Q_size;

        switch (static_cast<int>(context.keyswitching_type_))
        {
            case 1: // KEYSWITCHING_METHOD_I
            {
                switchkey_size_ = 2 * Q_size_ * Q_prime_size_ * ring_size;
            }
            break;
            case 2: // KEYSWITCHING_METHOD_II
            {
                d_ = context.d_leveled->operator[](0);
                switchkey_size_ = 2 * d_ * Q_prime_size_ * ring_size;
            }
            break;
            case 3: // KEYSWITCHING_METHOD_III Galoiskey
                throw std::invalid_argument(
                    "Switchkey does not support KEYSWITCHING_METHOD_III");
                break;
            default:
                throw std::invalid_argument("Invalid Key Switching Type");
                break;
        }
    }

    void Switchkey<Scheme::CKKS>::store_in_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            device_location_ = DeviceVector<Data64>(host_location_, stream);
            host_location_.resize(0);
            host_location_.shrink_to_fit();

            storage_type_ = storage_type::DEVICE;
        }
    }

    void Switchkey<Scheme::CKKS>::store_in_host(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            host_location_ = HostVector<Data64>(switchkey_size_);
            cudaMemcpyAsync(host_location_.data(), device_location_.data(),
                            switchkey_size_ * sizeof(Data64),
                            cudaMemcpyDeviceToHost, stream);
            HEONGPU_CUDA_CHECK(cudaGetLastError());

            device_location_.resize(0, stream);

            storage_type_ = storage_type::HOST;
        }
        else
        {
            // pass
        }
    }

    Data64* Switchkey<Scheme::CKKS>::data()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            return device_location_.data();
        }
        else
        {
            return host_location_.data();
        }
    }

    void Switchkey<Scheme::CKKS>::save(std::ostream& os) const
    {
        if (switch_key_generated_)
        {
            os.write((char*) &scheme_, sizeof(scheme_));

            os.write((char*) &key_type, sizeof(key_type));

            os.write((char*) &ring_size, sizeof(ring_size));

            os.write((char*) &Q_prime_size_, sizeof(Q_prime_size_));

            os.write((char*) &Q_size_, sizeof(Q_size_));

            os.write((char*) &d_, sizeof(d_));

            os.write((char*) &storage_type_, sizeof(storage_type_));

            os.write((char*) &switch_key_generated_,
                     sizeof(switch_key_generated_));

            os.write((char*) &switchkey_size_, sizeof(switchkey_size_));

            if (storage_type_ == storage_type::DEVICE)
            {
                HostVector<Data64> host_locations_temp(switchkey_size_);
                cudaMemcpy(host_locations_temp.data(), device_location_.data(),
                           switchkey_size_ * sizeof(Data64),
                           cudaMemcpyDeviceToHost);
                HEONGPU_CUDA_CHECK(cudaGetLastError());
                cudaDeviceSynchronize();

                os.write((char*) host_locations_temp.data(),
                         sizeof(Data64) * switchkey_size_);
            }
            else
            {
                os.write((char*) host_location_.data(),
                         sizeof(Data64) * switchkey_size_);
            }
        }
        else
        {
            throw std::runtime_error(
                "Switchkey is not generated so can not be serialized!");
        }
    }

    void Switchkey<Scheme::CKKS>::load(std::istream& is)
    {
        if ((!switch_key_generated_))
        {
            is.read((char*) &scheme_, sizeof(scheme_));

            is.read((char*) &key_type, sizeof(key_type));

            is.read((char*) &ring_size, sizeof(ring_size));

            is.read((char*) &Q_prime_size_, sizeof(Q_prime_size_));

            is.read((char*) &Q_size_, sizeof(Q_size_));

            is.read((char*) &d_, sizeof(d_));

            is.read((char*) &storage_type_, sizeof(storage_type_));

            is.read((char*) &switch_key_generated_,
                    sizeof(switch_key_generated_));

            is.read((char*) &switchkey_size_, sizeof(switchkey_size_));

            storage_type_ = storage_type::DEVICE;
            switch_key_generated_ = true;

            HostVector<Data64> host_locations_temp(switchkey_size_);
            is.read((char*) host_locations_temp.data(),
                    sizeof(Data64) * switchkey_size_);

            cudaMemcpy(device_location_.data(), host_locations_temp.data(),
                       switchkey_size_ * sizeof(Data64),
                       cudaMemcpyHostToDevice);
            HEONGPU_CUDA_CHECK(cudaGetLastError());
            cudaDeviceSynchronize();
        }
        else
        {
            throw std::runtime_error("Switchkey has been already exist!");
        }
    }

    int Switchkey<Scheme::CKKS>::memory_size()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            return device_location_.size();
        }
        else
        {
            return host_location_.size();
        }
    }

    void Switchkey<Scheme::CKKS>::memory_clear(cudaStream_t stream)
    {
        if (device_location_.size() > 0)
        {
            device_location_.resize(0, stream);
            device_location_.shrink_to_fit(stream);
        }

        if (host_location_.size() > 0)
        {
            host_location_.resize(0);
            host_location_.shrink_to_fit();
        }
    }

    void Switchkey<Scheme::CKKS>::memory_set(
        DeviceVector<Data64>&& new_device_vector)
    {
        storage_type_ = storage_type::DEVICE;
        device_location_ = std::move(new_device_vector);

        if (host_location_.size() > 0)
        {
            host_location_.resize(0);
            host_location_.shrink_to_fit();
        }
    }

    void Switchkey<Scheme::CKKS>::copy_to_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            if (memory_size() == 0)
            {
                // pass
            }
            else
            {
                device_location_ = DeviceVector<Data64>(host_location_, stream);
            }

            storage_type_ = storage_type::DEVICE;
        }
    }

    void Switchkey<Scheme::CKKS>::remove_from_device(cudaStream_t stream)
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            device_location_.resize(0, stream);
            device_location_.shrink_to_fit(stream);

            storage_type_ = storage_type::HOST;
        }
        else
        {
            // pass
        }
    }

    void Switchkey<Scheme::CKKS>::remove_from_host()
    {
        if (storage_type_ == storage_type::DEVICE)
        {
            // pass
        }
        else
        {
            host_location_.resize(0);
            host_location_.shrink_to_fit();

            storage_type_ = storage_type::DEVICE;
        }
    }

} // namespace heongpu