Program Listing for File memorypool.cu

Return to documentation for file (src/lib/util/memorypool.cu)

// Copyright 2024-2026 Alişah Özcan
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
// Developer: Alişah Özcan

#include <heongpu/util/memorypool.cuh>

namespace heongpu
{
    // Storage for MemoryPool's static members. Every resource handle starts
    // out empty: the base resources are created on demand by
    // ensure_base_resources(), and the pool / statistics adaptors are created
    // by initialize() when pooling is enabled.

    // Host side: upstream resource, pool built on top of it, and the
    // statistics adaptor wrapping the pool.
    std::shared_ptr<MemoryPool::HostResource> MemoryPool::host_base_{nullptr};
    std::shared_ptr<MemoryPool::HostPoolResource> MemoryPool::host_pool_{
        nullptr};
    std::shared_ptr<MemoryPool::HostStatsAdaptor>
        MemoryPool::host_stats_adaptor_{nullptr};

    // Device side: same three layers as the host side.
    std::shared_ptr<MemoryPool::DeviceResource> MemoryPool::device_base_{
        nullptr};
    std::shared_ptr<MemoryPool::DevicePoolResource> MemoryPool::device_pool_{
        nullptr};
    std::shared_ptr<MemoryPool::DeviceStatsAdaptor>
        MemoryPool::device_stats_adaptor_{nullptr};

    // Set by initialize(), cleared by clean_pool().
    bool MemoryPool::initialized_{false};
    // Serializes access to all of the static state above.
    std::mutex MemoryPool::mutex_;

    // Builds the configuration that the argument-less MemoryPool::initialize()
    // passes on. Only three fields are filled in here (initial and maximum
    // device fraction, maximum host fraction), taken from the
    // *_memorypool_size constants declared outside this file; every other
    // field keeps the value MemoryPoolConfig gives it on default construction.
    MemoryPoolConfig MemoryPoolConfig::Defaults()
    {
        MemoryPoolConfig defaults;
        defaults.max_host_fraction = max_host_memorypool_size;
        defaults.max_device_fraction = max_device_memorypool_size;
        defaults.initial_device_fraction = initial_device_memorypool_size;
        return defaults;
    }

    // Returns the process-wide MemoryPool singleton.
    //
    // The CUDA runtime is touched with a no-op cudaFree(nullptr) before the
    // function-local static is constructed, so that the runtime is set up
    // first and teardown order stays correct at process exit.
    //
    // The status returned by cudaFree is checked directly rather than through
    // cudaGetLastError(): the latter would also consume (and report from
    // here) an unrelated error left pending by earlier CUDA calls on this
    // thread.
    MemoryPool& MemoryPool::instance()
    {
        HEONGPU_CUDA_CHECK(cudaFree(nullptr));

        static MemoryPool pool;
        return pool;
    }

    // Creates the upstream host and device resources if they do not exist
    // yet; already existing resources are left untouched.
    //
    // This function does not take mutex_ itself. Every caller in this file
    // invokes it while already holding mutex_.
    void MemoryPool::ensure_base_resources()
    {
        if (host_base_ == nullptr)
        {
            host_base_ = std::make_shared<HostResource>();
        }

        if (device_base_ == nullptr)
        {
            device_base_ = std::make_shared<DeviceResource>();
        }
    }

    // Returns the amount of host RAM currently reported free by sysinfo(2),
    // in bytes (freeram scaled by mem_unit).
    //
    // Throws std::runtime_error if sysinfo() reports failure; without this
    // check the result would be computed from an unfilled structure.
    size_t MemoryPool::get_host_avaliable_memory() const
    {
        struct sysinfo mem_info = {};
        if (sysinfo(&mem_info) != 0)
        {
            throw std::runtime_error(
                "sysinfo() failed while querying free host memory");
        }

        // Widen both operands before multiplying so the product is computed
        // in size_t.
        return static_cast<size_t>(mem_info.freeram) *
               static_cast<size_t>(mem_info.mem_unit);
    }

    // Returns the number of free bytes reported by cudaMemGetInfo(). The
    // total is queried as well but is not used.
    //
    // The status of cudaMemGetInfo is passed straight to HEONGPU_CUDA_CHECK so
    // that a failure of this query is what gets reported; going through
    // cudaGetLastError() would also pick up (and clear) an unrelated error
    // left pending by earlier CUDA calls on this thread.
    size_t MemoryPool::get_decive_avaliable_memory() const
    {
        size_t free_mem = 0;
        size_t total_mem = 0;
        HEONGPU_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
        return free_mem; // total_mem is intentionally not returned
    }

    // Rounds `size` up to the next multiple of 256; values that already are
    // a multiple of 256 (including 0) are returned unchanged.
    size_t MemoryPool::roundup_256(size_t size) const
    {
        // 256 is a power of two, so clearing the low eight bits of
        // (size + 255) gives the same result as dividing and re-multiplying.
        const size_t low_bits = static_cast<size_t>(255);
        return (size + low_bits) & ~low_bits;
    }

    void MemoryPool::initialize()
    {
        initialize(MemoryPoolConfig::Defaults());
    }

    // Creates the host and device pool resources described by `config`.
    //
    // The call does nothing when the pool is already initialized. Pool sizes
    // can be given in bytes or as a fraction of the memory reported free at
    // the time of the call (get_host_avaliable_memory() /
    // get_decive_avaliable_memory()); an explicit byte count takes precedence
    // over a fraction. Fractions in (0, 1] are used as-is, values in (1, 100]
    // are interpreted as percentages. If neither bytes nor a fraction are
    // given for the initial host pool, a fixed 100 MiB is used. Every
    // resolved size is rounded up to a multiple of 256 bytes.
    //
    // When config.use_memory_pool is false only the base resources are
    // created; the object is still marked initialized.
    //
    // All sizes are validated and all resources are built into locals before
    // any pool member is assigned, so a failure part-way through leaves no
    // half-initialized pool behind and the call can be retried.
    //
    // Throws std::invalid_argument when a byte count is 0, a fraction is NaN
    // or outside (0, 100], a size resolves to 0 bytes or exceeds the free
    // memory, or a maximum is smaller than the matching initial size.
    // Exceptions from the memory queries and from the resource constructors
    // propagate unchanged.
    void MemoryPool::initialize(const MemoryPoolConfig& config)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        if (initialized_)
        {
            return;
        }

        ensure_base_resources();
        const size_t total_host_memory = get_host_avaliable_memory();
        const size_t total_device_memory = get_decive_avaliable_memory();

        // Maps a user supplied fraction or percentage to a value in (0, 1].
        auto normalize_fraction = [](float value, const char* label) -> float
        {
            // Negated comparison on purpose: it also rejects NaN, which
            // would otherwise slip through both range checks.
            if (!(value > 0.0f))
            {
                throw std::invalid_argument(std::string(label) +
                                            " must be > 0");
            }
            if (value > 1.0f)
            {
                if (value > 100.0f)
                {
                    throw std::invalid_argument(
                        std::string(label) + " must be in (0,1] or (0,100]");
                }
                return value / 100.0f;
            }
            return value;
        };

        // Computes fraction * total_memory without exceeding total_memory.
        // The product is formed in double: a float cannot represent typical
        // memory sizes exactly, so a fraction of 1.0 could round above the
        // total and then be rejected as "exceeds available memory".
        auto scale_by_fraction = [](size_t total_memory,
                                    float fraction) -> size_t
        {
            const double limit = static_cast<double>(total_memory);
            const double scaled = limit * static_cast<double>(fraction);
            if (scaled >= limit)
            {
                return total_memory;
            }
            return static_cast<size_t>(scaled);
        };

        // Picks the pool size from, in order of precedence: explicit bytes,
        // explicit fraction, default bytes (if non-zero), default fraction.
        auto resolve_pool_size =
            [&](const std::optional<size_t>& bytes,
                const std::optional<float>& fraction, size_t total_memory,
                size_t default_bytes, float default_fraction,
                const char* label) -> size_t
        {
            size_t resolved = 0;
            if (bytes.has_value())
            {
                if (*bytes == 0)
                {
                    throw std::invalid_argument(std::string(label) +
                                                " bytes must be > 0");
                }
                resolved = *bytes;
            }
            else if (fraction.has_value())
            {
                resolved = scale_by_fraction(
                    total_memory, normalize_fraction(*fraction, label));
            }
            else if (default_bytes != 0)
            {
                resolved = default_bytes;
            }
            else
            {
                resolved = scale_by_fraction(
                    total_memory, normalize_fraction(default_fraction, label));
            }

            if (resolved == 0)
            {
                throw std::invalid_argument(std::string(label) +
                                            " resolved to 0 bytes");
            }
            if (total_memory != 0 && resolved > total_memory)
            {
                throw std::invalid_argument(
                    std::string(label) +
                    " exceeds available memory at initialization");
            }

            return roundup_256(resolved);
        };

        if (config.use_memory_pool)
        {
            // Initial host pool size used when the configuration specifies
            // neither bytes nor a fraction: 100 MiB.
            const size_t default_initial_host_bytes =
                static_cast<size_t>(100) * 1024 * 1024;

            // ---- Resolve and validate every size first -------------------
            size_t initial_host_pool_size = 0;
            if (!config.initial_host_bytes.has_value() &&
                !config.initial_host_fraction.has_value())
            {
                initial_host_pool_size =
                    roundup_256(default_initial_host_bytes);
            }
            else
            {
                initial_host_pool_size = resolve_pool_size(
                    config.initial_host_bytes, config.initial_host_fraction,
                    total_host_memory, 0, initial_host_memorypool_size,
                    "host initial pool size");
            }

            const size_t max_host_pool_size = resolve_pool_size(
                config.max_host_bytes, config.max_host_fraction,
                total_host_memory, 0, max_host_memorypool_size,
                "host max pool size");

            if (max_host_pool_size < initial_host_pool_size)
            {
                throw std::invalid_argument(
                    "host max pool size must be >= host initial pool "
                    "size");
            }

            const size_t initial_device_pool_size = resolve_pool_size(
                config.initial_device_bytes, config.initial_device_fraction,
                total_device_memory, 0, initial_device_memorypool_size,
                "device initial pool size");
            const size_t max_device_pool_size = resolve_pool_size(
                config.max_device_bytes, config.max_device_fraction,
                total_device_memory, 0, max_device_memorypool_size,
                "device max pool size");

            if (max_device_pool_size < initial_device_pool_size)
            {
                throw std::invalid_argument(
                    "device max pool size must be >= device initial pool "
                    "size");
            }

            // ---- Build everything into locals ----------------------------
            // If any constructor throws, the locals are released and the
            // static members still hold their previous (empty) values.
            auto new_host_pool = std::make_shared<HostPoolResource>(
                host_base_.get(), initial_host_pool_size, max_host_pool_size);
            auto new_host_stats =
                std::make_shared<HostStatsAdaptor>(new_host_pool.get());

            auto new_device_pool = std::make_shared<DevicePoolResource>(
                device_base_.get(), initial_device_pool_size,
                max_device_pool_size);
            auto new_device_stats =
                std::make_shared<DeviceStatsAdaptor>(new_device_pool.get());

            // ---- Commit --------------------------------------------------
            host_pool_ = new_host_pool;
            host_stats_adaptor_ = new_host_stats;
            device_pool_ = new_device_pool;
            device_stats_adaptor_ = new_device_stats;
        }

        initialized_ = true;
    }

    // Selects which resource RMM uses as the current device resource.
    //
    // With `use == true` and an existing device statistics adaptor, the
    // adaptor (and therefore the pool underneath it) becomes the current
    // resource. In every other case - pooling not requested, or no adaptor
    // available - the base device resource is installed, creating it first
    // if necessary.
    void MemoryPool::use_memory_pool(bool use)
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (use && device_stats_adaptor_)
        {
            rmm::mr::set_current_device_resource(device_stats_adaptor_.get());
            return;
        }

        // Fallback shared by "pool disabled" and "pool not available".
        ensure_base_resources();
        rmm::mr::set_current_device_resource(device_base_.get());
    }

    // Allocates `size` bytes on `stream` from whatever resource RMM currently
    // reports as the current device resource. mutex_ is held for the
    // duration of the call.
    void* MemoryPool::allocate(size_t size, cudaStream_t stream)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        auto* current_resource = rmm::mr::get_current_device_resource();
        return current_resource->allocate(size, stream);
    }

    // Returns the `size`-byte block at `ptr` on `stream` to whatever resource
    // RMM currently reports as the current device resource. mutex_ is held
    // for the duration of the call.
    void MemoryPool::deallocate(void* ptr, size_t size, cudaStream_t stream)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        auto* current_resource = rmm::mr::get_current_device_resource();
        current_resource->deallocate(ptr, size, stream);
    }

    // Returns the device resource owned by this class: the statistics adaptor
    // when one exists, otherwise the base device resource (which is created
    // on the spot if it does not exist yet).
    rmm::mr::device_memory_resource* MemoryPool::get_device_resource() const
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (!device_stats_adaptor_)
        {
            // ensure_base_resources() is not const; the cast is needed to
            // call it from this const member function.
            const_cast<MemoryPool*>(this)->ensure_base_resources();
            return device_base_.get();
        }

        return device_stats_adaptor_.get();
    }

    // Returns the host statistics adaptor, or nullptr when none has been
    // created. Unlike get_device_resource() there is no fallback to the base
    // resource.
    MemoryPool::HostStatsAdaptor* MemoryPool::get_host_resource() const
    {
        std::lock_guard<std::mutex> guard(mutex_);
        HostStatsAdaptor* adaptor = host_stats_adaptor_.get();
        return adaptor;
    }

    // Allocates `size` bytes of host memory from the host pool when it
    // exists, otherwise directly from the base host resource (created on
    // demand).
    void* MemoryPool::host_allocate(size_t size)
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (!host_pool_)
        {
            ensure_base_resources();
            return host_base_->allocate(size);
        }

        return host_pool_->allocate(size);
    }

    // Releases the `size`-byte host block at `ptr` to the host pool when it
    // exists, otherwise to the base host resource (created on demand).
    //
    // NOTE(review): the target is chosen from the state at the time of the
    // call, not from where `ptr` was allocated - callers presumably must not
    // hold blocks across initialize()/clean_pool(); confirm.
    void MemoryPool::host_deallocate(void* ptr, size_t size)
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (host_pool_)
        {
            host_pool_->deallocate(ptr, size);
        }
        else
        {
            ensure_base_resources();
            host_base_->deallocate(ptr, size);
        }
    }

    // Prints total / used / free byte counts for the device and host pools to
    // std::cout. "Used" comes from the statistics adaptor's byte counter,
    // "total" from the pool's pool_size(), and "free" is their difference.
    // If any of the four pool / adaptor objects is missing, a single line
    // saying the pool is not initialized or disabled is printed instead.
    void MemoryPool::print_memory_pool_status() const
    {
        std::lock_guard<std::mutex> guard(mutex_);

        const bool pools_ready = device_stats_adaptor_ && device_pool_ &&
                                 host_stats_adaptor_ && host_pool_;
        if (!pools_ready)
        {
            std::cout
                << "[HEonGPU] Memory pool is not initialized or is disabled."
                << std::endl;
            return;
        }

        // Take all readings before anything is printed.
        const auto device_used =
            device_stats_adaptor_->get_bytes_counter().value;
        const auto device_total = device_pool_->pool_size();
        const auto device_free = device_total - device_used;

        const auto host_used = host_stats_adaptor_->get_bytes_counter().value;
        const auto host_total = host_pool_->pool_size();
        const auto host_free = host_total - host_used;

        // Emits the heading followed by the three counters of one pool.
        const auto print_section = [](const char* heading, auto total,
                                      auto used, auto free_bytes)
        {
            std::cout << heading << std::endl;
            std::cout << "    Total     : " << total << " bytes" << std::endl;
            std::cout << "    Used      : " << used << " bytes" << std::endl;
            std::cout << "    Free      : " << free_bytes << " bytes"
                      << std::endl;
        };

        std::cout << "[HEonGPU] Memory Pool Status" << std::endl;
        print_section("  Device Pool:", device_total, device_used,
                      device_free);
        print_section("  Host Pool:", host_total, host_used, host_free);
    }

    // Returns the current value of the device statistics adaptor's byte
    // counter, or 0 when no device statistics adaptor exists.
    size_t MemoryPool::get_current_device_pool_memory_usage() const
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (device_stats_adaptor_)
        {
            return device_stats_adaptor_->get_bytes_counter().value;
        }
        return 0;
    }

    // Returns the device pool's pool_size() minus the byte counter of the
    // device statistics adaptor, or 0 when either object is missing.
    size_t MemoryPool::get_free_device_pool_memory() const
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (device_stats_adaptor_ && device_pool_)
        {
            const auto used_bytes =
                device_stats_adaptor_->get_bytes_counter().value;
            return device_pool_->pool_size() - used_bytes;
        }
        return 0;
    }

    // Returns the current value of the host statistics adaptor's byte
    // counter, or 0 when no host statistics adaptor exists.
    size_t MemoryPool::get_current_host_pool_memory_usage() const
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (host_stats_adaptor_)
        {
            return host_stats_adaptor_->get_bytes_counter().value;
        }
        return 0;
    }

    // Returns the host pool's pool_size() minus the byte counter of the host
    // statistics adaptor, or 0 when either object is missing.
    size_t MemoryPool::get_free_host_pool_memory() const
    {
        std::lock_guard<std::mutex> guard(mutex_);

        if (host_stats_adaptor_ && host_pool_)
        {
            const auto used_bytes =
                host_stats_adaptor_->get_bytes_counter().value;
            return host_pool_->pool_size() - used_bytes;
        }
        return 0;
    }

    // Construction does no work: resources are created later by
    // ensure_base_resources() and initialize().
    MemoryPool::MemoryPool() = default;

    // Releases every resource still held by delegating to clean_pool().
    MemoryPool::~MemoryPool()
    {
        clean_pool();
    }

    // Drops every resource owned by this class and marks the pool as
    // uninitialized. Does nothing when the pool was never initialized and
    // no base or pool resource exists.
    void MemoryPool::clean_pool()
    {
        std::lock_guard<std::mutex> guard(mutex_);

        const bool nothing_to_release = !initialized_ && !host_base_ &&
                                        !device_base_ && !host_pool_ &&
                                        !device_pool_;
        if (nothing_to_release)
        {
            return;
        }

        // Detach RMM from our resources before they are destroyed.
        // NOTE(review): passing nullptr presumably makes RMM fall back to its
        // own default resource - confirm against the RMM version in use.
        rmm::mr::set_current_device_resource(nullptr);

        // Release each stack from the outermost wrapper down to the base
        // resource it was built on: adaptor, then pool, then base.
        host_stats_adaptor_.reset();
        host_pool_.reset();
        host_base_.reset();

        device_stats_adaptor_.reset();
        device_pool_.reset();
        device_base_.reset();

        initialized_ = false;
    }

} // namespace heongpu