Skip to content

Code Sample: PQC — ML-KEM & ML-DSA

Complete, runnable ML-KEM and ML-DSA examples built on cuPQC-PK for GPU-accelerated post-quantum cryptography. ML-KEM (Module-Lattice-Based Key-Encapsulation Mechanism) is standardized in NIST FIPS 203 and provides key establishment that is designed to resist both classical and quantum attacks. ML-DSA (Module-Lattice-Based Digital Signature Algorithm) is standardized in NIST FIPS 204 and provides post-quantum secure digital signatures for authentication and non-repudiation.

ML-KEM (Key Encapsulation Mechanism)

#include <vector>
#include <iostream>
#include <pk.hpp>
using namespace cupqc;

// Descriptor type for ML-KEM-512 key generation, composed from the ML_KEM_512
// parameter set, the Keygen function selector, Block() and BlockDim<128>().
// Its BlockDim member is what the host code passes as the launch block size.
using MLKEM512Key = decltype(ML_KEM_512()
                            + Function<function::Keygen>()
                            + Block()
                            + BlockDim<128>());

// Descriptor type for ML-KEM-512 encapsulation; same composition as above with
// the Encaps function selector.
using MLKEM512Encaps = decltype(ML_KEM_512()
                               + Function<function::Encaps>()
                               + Block()
                               + BlockDim<128>());

// Descriptor type for ML-KEM-512 decapsulation; same composition as above with
// the Decaps function selector.
using MLKEM512Decaps = decltype(ML_KEM_512()
                               + Function<function::Decaps>()
                               + Block()
                               + BlockDim<128>());

// ML-KEM-512 key-generation kernel: thread block b fills slot b of every array.
//
// public_keys / secret_keys: output arrays, strided per block by
//     MLKEM512Key::public_key_size and MLKEM512Key::secret_key_size bytes.
// workspace / randombytes: per-block slices, strided by
//     MLKEM512Key::workspace_size and MLKEM512Key::entropy_size bytes.
//
// Statically reserves MLKEM512Key::shared_memory_size bytes of shared memory.
// The slot index is blockIdx.x only, so a 1-D grid with one block per key
// pair is expected.
__global__ void keygen_kernel(uint8_t* public_keys, uint8_t* secret_keys,
                               uint8_t* workspace, uint8_t* randombytes)
{
    __shared__ uint8_t block_scratch[MLKEM512Key::shared_memory_size];

    // Each pointer is advanced to the slot owned by this block before the call.
    const int slot = blockIdx.x;
    MLKEM512Key().execute(public_keys + slot * MLKEM512Key::public_key_size,
                          secret_keys + slot * MLKEM512Key::secret_key_size,
                          randombytes + slot * MLKEM512Key::entropy_size,
                          workspace + slot * MLKEM512Key::workspace_size,
                          block_scratch);
}

// ML-KEM-512 encapsulation kernel: thread block b reads public key b and
// writes ciphertext b and shared secret b.
//
// ciphertexts / shared_secrets: output arrays, strided per block by
//     MLKEM512Encaps::ciphertext_size and MLKEM512Encaps::shared_secret_size.
// public_keys: input array, strided by MLKEM512Encaps::public_key_size.
// workspace / randombytes: per-block slices, strided by
//     MLKEM512Encaps::workspace_size and MLKEM512Encaps::entropy_size.
//
// Statically reserves MLKEM512Encaps::shared_memory_size bytes of shared
// memory. The slot index is blockIdx.x only (1-D grid, one block per item).
__global__ void encaps_kernel(uint8_t* ciphertexts, uint8_t* shared_secrets,
                               const uint8_t* public_keys, uint8_t* workspace,
                               uint8_t* randombytes)
{
    __shared__ uint8_t block_scratch[MLKEM512Encaps::shared_memory_size];

    // Each pointer is advanced to the slot owned by this block before the call.
    const int slot = blockIdx.x;
    MLKEM512Encaps().execute(
        ciphertexts + slot * MLKEM512Encaps::ciphertext_size,
        shared_secrets + slot * MLKEM512Encaps::shared_secret_size,
        public_keys + slot * MLKEM512Encaps::public_key_size,
        randombytes + slot * MLKEM512Encaps::entropy_size,
        workspace + slot * MLKEM512Encaps::workspace_size,
        block_scratch);
}

// ML-KEM-512 decapsulation kernel: thread block b reads ciphertext b and
// secret key b and writes shared secret b.
//
// shared_secrets: output array, strided by MLKEM512Decaps::shared_secret_size.
// ciphertexts / secret_keys: input arrays, strided by
//     MLKEM512Decaps::ciphertext_size and MLKEM512Decaps::secret_key_size.
// workspace: per-block slices, strided by MLKEM512Decaps::workspace_size.
//
// Statically reserves MLKEM512Decaps::shared_memory_size bytes of shared
// memory. The slot index is blockIdx.x only (1-D grid, one block per item).
__global__ void decaps_kernel(uint8_t* shared_secrets,
                               const uint8_t* ciphertexts,
                               const uint8_t* secret_keys, uint8_t* workspace)
{
    __shared__ uint8_t block_scratch[MLKEM512Decaps::shared_memory_size];

    // Each pointer is advanced to the slot owned by this block before the call.
    const int slot = blockIdx.x;
    MLKEM512Decaps().execute(
        shared_secrets + slot * MLKEM512Decaps::shared_secret_size,
        ciphertexts + slot * MLKEM512Decaps::ciphertext_size,
        secret_keys + slot * MLKEM512Decaps::secret_key_size,
        workspace + slot * MLKEM512Decaps::workspace_size,
        block_scratch);
}

// Generates `batch` ML-KEM-512 key pairs on the GPU and copies them to the host.
//
// public_keys: host output buffer; must already hold at least
//     MLKEM512Key::public_key_size * batch bytes.
// secret_keys: host output buffer; must already hold at least
//     MLKEM512Key::secret_key_size * batch bytes.
// batch: number of key pairs; one thread block is launched per key pair.
//     A batch of 0 is a no-op.
//
// Undersized buffers and CUDA failures are reported on std::cerr. After a
// failure the remaining steps are skipped, so the output buffers may be left
// unmodified or only partly written; device resources are released either way.
void ml_kem_keygen(std::vector<uint8_t>& public_keys,
                   std::vector<uint8_t>& secret_keys,
                   const unsigned int batch)
{
    if (batch == 0) {
        return;  // a zero-sized grid is an invalid launch configuration
    }

    // Widen before multiplying so large batches cannot wrap a 32-bit product.
    const size_t public_bytes =
        static_cast<size_t>(MLKEM512Key::public_key_size) * batch;
    const size_t secret_bytes =
        static_cast<size_t>(MLKEM512Key::secret_key_size) * batch;

    // The device-to-host copies below write this many bytes into the vectors.
    if (public_keys.size() < public_bytes || secret_keys.size() < secret_bytes) {
        std::cerr << "ml_kem_keygen: output buffers are too small for batch "
                  << batch << std::endl;
        return;
    }

    // Logs a failed CUDA call and tells the caller whether to continue.
    const auto check = [](cudaError_t status, const char* step) {
        if (status == cudaSuccess) {
            return true;
        }
        std::cerr << "ml_kem_keygen: " << step << " failed: "
                  << cudaGetErrorString(status) << std::endl;
        return false;
    };

    auto workspace = make_workspace<MLKEM512Key>(batch);
    auto randombytes = get_entropy<MLKEM512Key>(batch);

    uint8_t* d_public_key = nullptr;
    uint8_t* d_secret_key = nullptr;

    bool ok = check(cudaMalloc(reinterpret_cast<void**>(&d_public_key),
                               public_bytes),
                    "cudaMalloc(public keys)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_secret_key),
                               secret_bytes),
                    "cudaMalloc(secret keys)");

    if (ok) {
        keygen_kernel<<<batch, MLKEM512Key::BlockDim>>>(
            d_public_key, d_secret_key, workspace, randombytes);
        // Launch-configuration errors are only visible through this call.
        ok = check(cudaGetLastError(), "keygen_kernel launch");
    }

    // The blocking copies also surface errors raised while the kernel ran.
    if (ok) {
        ok = check(cudaMemcpy(public_keys.data(), d_public_key, public_bytes,
                              cudaMemcpyDeviceToHost),
                   "copy of public keys to host");
    }
    if (ok) {
        check(cudaMemcpy(secret_keys.data(), d_secret_key, secret_bytes,
                         cudaMemcpyDeviceToHost),
              "copy of secret keys to host");
    }

    // cudaFree(nullptr) is a no-op, so this is safe after a failed allocation.
    cudaFree(d_public_key);
    cudaFree(d_secret_key);
    destroy_workspace(workspace);
    release_entropy(randombytes);
}

// Encapsulates against `batch` ML-KEM-512 public keys on the GPU.
//
// ciphertexts: host output buffer; must already hold at least
//     MLKEM512Encaps::ciphertext_size * batch bytes.
// sharedsecrets: host output buffer; must already hold at least
//     MLKEM512Encaps::shared_secret_size * batch bytes.
// public_keys: host input buffer holding at least
//     MLKEM512Encaps::public_key_size * batch bytes.
// batch: number of encapsulations; one thread block is launched per item.
//     A batch of 0 is a no-op.
//
// Undersized buffers and CUDA failures are reported on std::cerr. After a
// failure the remaining steps are skipped, so the output buffers may be left
// unmodified or only partly written; device resources are released either way.
void ml_kem_encaps(std::vector<uint8_t>& ciphertexts,
                   std::vector<uint8_t>& sharedsecrets,
                   const std::vector<uint8_t>& public_keys,
                   const unsigned int batch)
{
    if (batch == 0) {
        return;  // a zero-sized grid is an invalid launch configuration
    }

    // Widen before multiplying so large batches cannot wrap a 32-bit product.
    const size_t ciphertext_bytes =
        static_cast<size_t>(MLKEM512Encaps::ciphertext_size) * batch;
    const size_t public_bytes =
        static_cast<size_t>(MLKEM512Encaps::public_key_size) * batch;
    const size_t secret_bytes =
        static_cast<size_t>(MLKEM512Encaps::shared_secret_size) * batch;

    // Every cudaMemcpy below reads or writes exactly these many host bytes.
    if (ciphertexts.size() < ciphertext_bytes ||
        sharedsecrets.size() < secret_bytes ||
        public_keys.size() < public_bytes) {
        std::cerr << "ml_kem_encaps: host buffers are too small for batch "
                  << batch << std::endl;
        return;
    }

    // Logs a failed CUDA call and tells the caller whether to continue.
    const auto check = [](cudaError_t status, const char* step) {
        if (status == cudaSuccess) {
            return true;
        }
        std::cerr << "ml_kem_encaps: " << step << " failed: "
                  << cudaGetErrorString(status) << std::endl;
        return false;
    };

    auto workspace = make_workspace<MLKEM512Encaps>(batch);
    auto randombytes = get_entropy<MLKEM512Encaps>(batch);

    uint8_t* d_ciphertext = nullptr;
    uint8_t* d_public_key = nullptr;
    uint8_t* d_sharedsecret = nullptr;

    bool ok = check(cudaMalloc(reinterpret_cast<void**>(&d_ciphertext),
                               ciphertext_bytes),
                    "cudaMalloc(ciphertexts)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_public_key),
                               public_bytes),
                    "cudaMalloc(public keys)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_sharedsecret),
                               secret_bytes),
                    "cudaMalloc(shared secrets)")
           && check(cudaMemcpy(d_public_key, public_keys.data(), public_bytes,
                               cudaMemcpyHostToDevice),
                    "copy of public keys to device");

    if (ok) {
        encaps_kernel<<<batch, MLKEM512Encaps::BlockDim>>>(
            d_ciphertext, d_sharedsecret, d_public_key, workspace, randombytes);
        // Launch-configuration errors are only visible through this call.
        ok = check(cudaGetLastError(), "encaps_kernel launch");
    }

    // The blocking copies also surface errors raised while the kernel ran.
    if (ok) {
        ok = check(cudaMemcpy(ciphertexts.data(), d_ciphertext,
                              ciphertext_bytes, cudaMemcpyDeviceToHost),
                   "copy of ciphertexts to host");
    }
    if (ok) {
        check(cudaMemcpy(sharedsecrets.data(), d_sharedsecret, secret_bytes,
                         cudaMemcpyDeviceToHost),
              "copy of shared secrets to host");
    }

    // cudaFree(nullptr) is a no-op, so this is safe after a failed allocation.
    cudaFree(d_public_key);
    cudaFree(d_ciphertext);
    cudaFree(d_sharedsecret);
    destroy_workspace(workspace);
    release_entropy(randombytes);
}

// Decapsulates `batch` ML-KEM-512 ciphertexts on the GPU.
//
// sharedsecrets: host output buffer; must already hold at least
//     MLKEM512Decaps::shared_secret_size * batch bytes.
// ciphertexts: host input buffer holding at least
//     MLKEM512Decaps::ciphertext_size * batch bytes.
// secret_keys: host input buffer holding at least
//     MLKEM512Decaps::secret_key_size * batch bytes.
// batch: number of decapsulations; one thread block is launched per item.
//     A batch of 0 is a no-op.
//
// Undersized buffers and CUDA failures are reported on std::cerr. After a
// failure the remaining steps are skipped, so `sharedsecrets` may be left
// unmodified; device resources are released either way.
void ml_kem_decaps(std::vector<uint8_t>& sharedsecrets,
                   const std::vector<uint8_t>& ciphertexts,
                   const std::vector<uint8_t>& secret_keys,
                   const unsigned int batch)
{
    if (batch == 0) {
        return;  // a zero-sized grid is an invalid launch configuration
    }

    // Widen before multiplying so large batches cannot wrap a 32-bit product.
    const size_t secret_bytes =
        static_cast<size_t>(MLKEM512Decaps::shared_secret_size) * batch;
    const size_t ciphertext_bytes =
        static_cast<size_t>(MLKEM512Decaps::ciphertext_size) * batch;
    const size_t key_bytes =
        static_cast<size_t>(MLKEM512Decaps::secret_key_size) * batch;

    // Every cudaMemcpy below reads or writes exactly these many host bytes.
    if (sharedsecrets.size() < secret_bytes ||
        ciphertexts.size() < ciphertext_bytes ||
        secret_keys.size() < key_bytes) {
        std::cerr << "ml_kem_decaps: host buffers are too small for batch "
                  << batch << std::endl;
        return;
    }

    // Logs a failed CUDA call and tells the caller whether to continue.
    const auto check = [](cudaError_t status, const char* step) {
        if (status == cudaSuccess) {
            return true;
        }
        std::cerr << "ml_kem_decaps: " << step << " failed: "
                  << cudaGetErrorString(status) << std::endl;
        return false;
    };

    auto workspace = make_workspace<MLKEM512Decaps>(batch);

    uint8_t* d_sharedsecret = nullptr;
    uint8_t* d_ciphertext = nullptr;
    uint8_t* d_secret_key = nullptr;

    bool ok = check(cudaMalloc(reinterpret_cast<void**>(&d_sharedsecret),
                               secret_bytes),
                    "cudaMalloc(shared secrets)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_ciphertext),
                               ciphertext_bytes),
                    "cudaMalloc(ciphertexts)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_secret_key),
                               key_bytes),
                    "cudaMalloc(secret keys)")
           && check(cudaMemcpy(d_ciphertext, ciphertexts.data(),
                               ciphertext_bytes, cudaMemcpyHostToDevice),
                    "copy of ciphertexts to device")
           && check(cudaMemcpy(d_secret_key, secret_keys.data(), key_bytes,
                               cudaMemcpyHostToDevice),
                    "copy of secret keys to device");

    if (ok) {
        decaps_kernel<<<batch, MLKEM512Decaps::BlockDim>>>(
            d_sharedsecret, d_ciphertext, d_secret_key, workspace);
        // Launch-configuration errors are only visible through this call.
        ok = check(cudaGetLastError(), "decaps_kernel launch");
    }

    // The blocking copy also surfaces errors raised while the kernel ran.
    if (ok) {
        check(cudaMemcpy(sharedsecrets.data(), d_sharedsecret, secret_bytes,
                         cudaMemcpyDeviceToHost),
              "copy of shared secrets to host");
    }

    // cudaFree(nullptr) is a no-op, so this is safe after a failed allocation.
    cudaFree(d_sharedsecret);
    cudaFree(d_ciphertext);
    cudaFree(d_secret_key);
    destroy_workspace(workspace);
}

// Runs a batch of ML-KEM-512 keygen -> encaps -> decaps round trips and checks
// that decapsulation recovers the encapsulated shared secrets.
//
// Returns 0 when the secrets match, 1 otherwise.
int main() {
    unsigned int batch = 10;
    std::vector<uint8_t> public_keys(MLKEM512Key::public_key_size * batch);
    std::vector<uint8_t> secret_keys(MLKEM512Key::secret_key_size * batch);

    ml_kem_keygen(public_keys, secret_keys, batch);

    std::vector<uint8_t> ciphertexts(MLKEM512Encaps::ciphertext_size * batch);
    std::vector<uint8_t> sharedsecrets_encap(
        MLKEM512Encaps::shared_secret_size * batch);
    ml_kem_encaps(ciphertexts, sharedsecrets_encap, public_keys, batch);

    // Sized from the Decaps descriptor because ml_kem_decaps copies
    // MLKEM512Decaps::shared_secret_size * batch bytes into this buffer.
    std::vector<uint8_t> sharedsecrets_decap(
        MLKEM512Decaps::shared_secret_size * batch);
    ml_kem_decaps(sharedsecrets_decap, ciphertexts, secret_keys, batch);

    // Both vectors start zero-filled, so two untouched buffers would compare
    // equal; require at least one non-zero byte before trusting the match.
    bool produced_output = false;
    for (const uint8_t byte : sharedsecrets_encap) {
        if (byte != 0) {
            produced_output = true;
            break;
        }
    }

    if (!produced_output) {
        std::cout << "ML-KEM-512: encapsulation produced no output" << std::endl;
        return 1;
    }

    const bool secrets_match = (sharedsecrets_encap == sharedsecrets_decap);
    std::cout << "ML-KEM-512 batch of " << batch << ": shared secrets "
              << (secrets_match ? "match" : "DO NOT match") << std::endl;

    return secrets_match ? 0 : 1;
}

ML-DSA (Digital Signature Algorithm)

#include <vector>
#include <iostream>
#include <pk.hpp>
using namespace cupqc;

// Descriptor type for ML-DSA-44 key generation, composed from the ML_DSA_44
// parameter set, the Keygen function selector, Block() and BlockDim<128>().
// Its BlockDim member is what the host code passes as the launch block size.
using MLDSA44Key = decltype(ML_DSA_44()
                           + Function<function::Keygen>()
                           + Block()
                           + BlockDim<128>());

// Descriptor type for ML-DSA-44 signing; same composition as above with the
// Sign function selector.
using MLDSA44Sign = decltype(ML_DSA_44()
                            + Function<function::Sign>()
                            + Block()
                            + BlockDim<128>());

// Descriptor type for ML-DSA-44 verification; same composition as above with
// the Verify function selector.
using MLDSA44Verify = decltype(ML_DSA_44()
                              + Function<function::Verify>()
                              + Block()
                              + BlockDim<128>());

// ML-DSA-44 key-generation kernel: thread block b fills slot b of every array.
//
// public_keys / secret_keys: output arrays, strided per block by
//     MLDSA44Key::public_key_size and MLDSA44Key::secret_key_size bytes.
// randombytes / workspace: per-block slices, strided by
//     MLDSA44Key::entropy_size and MLDSA44Key::workspace_size bytes.
//
// Statically reserves MLDSA44Key::shared_memory_size bytes of shared memory.
// The slot index is blockIdx.x only, so a 1-D grid with one block per key
// pair is expected.
__global__ void keygen_kernel(uint8_t* public_keys, uint8_t* secret_keys,
                               uint8_t* randombytes, uint8_t* workspace)
{
    __shared__ uint8_t block_scratch[MLDSA44Key::shared_memory_size];

    // Each pointer is advanced to the slot owned by this block before the call.
    const int slot = blockIdx.x;
    MLDSA44Key().execute(public_keys + slot * MLDSA44Key::public_key_size,
                         secret_keys + slot * MLDSA44Key::secret_key_size,
                         randombytes + slot * MLDSA44Key::entropy_size,
                         workspace + slot * MLDSA44Key::workspace_size,
                         block_scratch);
}

// ML-DSA-44 signing kernel: thread block b signs message b with secret key b
// and writes signature b.
//
// signatures: output array; consecutive signatures are spaced by
//     MLDSA44Sign::signature_size rounded up to a multiple of 8 bytes.
// messages / message_size: input array of fixed-length messages; message b
//     starts at byte b * message_size.
// secret_keys: input array, strided by MLDSA44Sign::secret_key_size.
// randombytes / workspace: per-block slices, strided by
//     MLDSA44Sign::entropy_size and MLDSA44Sign::workspace_size.
//
// Statically reserves MLDSA44Sign::shared_memory_size bytes of shared memory.
// The slot index is blockIdx.x only (1-D grid, one block per message).
__global__ void sign_kernel(uint8_t* signatures, const uint8_t* messages,
                            const size_t message_size, const uint8_t* secret_keys,
                            uint8_t* randombytes, uint8_t* workspace)
{
    __shared__ uint8_t block_scratch[MLDSA44Sign::shared_memory_size];

    // Signature slots are padded up to the next multiple of 8 bytes.
    const auto signature_stride = (MLDSA44Sign::signature_size + 7) / 8 * 8;

    const int slot = blockIdx.x;
    MLDSA44Sign().execute(signatures + slot * signature_stride,
                          messages + slot * message_size,
                          message_size,
                          secret_keys + slot * MLDSA44Sign::secret_key_size,
                          randombytes + slot * MLDSA44Sign::entropy_size,
                          workspace + slot * MLDSA44Sign::workspace_size,
                          block_scratch);
}

// ML-DSA-44 verification kernel: thread block b checks signature b over
// message b against public key b and stores the result in valids[b].
//
// valids: output array with one bool per block.
// signatures: input array; consecutive signatures are spaced by
//     MLDSA44Sign::signature_size rounded up to a multiple of 8 bytes (the
//     same stride the signing kernel writes with).
// messages / message_size: input array of fixed-length messages; message b
//     starts at byte b * message_size.
// public_keys: input array, strided by MLDSA44Verify::public_key_size.
// workspace: per-block slices, strided by MLDSA44Verify::workspace_size.
//
// Statically reserves MLDSA44Verify::shared_memory_size bytes of shared
// memory. The slot index is blockIdx.x only (1-D grid, one block per item).
__global__ void verify_kernel(bool* valids, const uint8_t* signatures,
                              const uint8_t* messages, const size_t message_size,
                              const uint8_t* public_keys, uint8_t* workspace)
{
    __shared__ uint8_t block_scratch[MLDSA44Verify::shared_memory_size];

    // Signature slots are padded up to the next multiple of 8 bytes.
    const auto signature_stride = (MLDSA44Sign::signature_size + 7) / 8 * 8;

    const int slot = blockIdx.x;
    valids[slot] = MLDSA44Verify().execute(
        messages + slot * message_size,
        message_size,
        signatures + slot * signature_stride,
        public_keys + slot * MLDSA44Verify::public_key_size,
        workspace + slot * MLDSA44Verify::workspace_size,
        block_scratch);
}

// Generates `batch` ML-DSA-44 key pairs on the GPU and copies them to the host.
//
// public_keys: host output buffer; must already hold at least
//     MLDSA44Key::public_key_size * batch bytes.
// secret_keys: host output buffer; must already hold at least
//     MLDSA44Key::secret_key_size * batch bytes.
// batch: number of key pairs; one thread block is launched per key pair.
//     A batch of 0 is a no-op.
//
// Undersized buffers and CUDA failures are reported on std::cerr. After a
// failure the remaining steps are skipped, so the output buffers may be left
// unmodified or only partly written; device resources are released either way.
void ml_dsa_keygen(std::vector<uint8_t>& public_keys,
                   std::vector<uint8_t>& secret_keys,
                   const unsigned int batch)
{
    if (batch == 0) {
        return;  // a zero-sized grid is an invalid launch configuration
    }

    // Widen before multiplying so large batches cannot wrap a 32-bit product.
    const size_t public_bytes =
        static_cast<size_t>(MLDSA44Key::public_key_size) * batch;
    const size_t secret_bytes =
        static_cast<size_t>(MLDSA44Key::secret_key_size) * batch;

    // The device-to-host copies below write this many bytes into the vectors.
    if (public_keys.size() < public_bytes || secret_keys.size() < secret_bytes) {
        std::cerr << "ml_dsa_keygen: output buffers are too small for batch "
                  << batch << std::endl;
        return;
    }

    // Logs a failed CUDA call and tells the caller whether to continue.
    const auto check = [](cudaError_t status, const char* step) {
        if (status == cudaSuccess) {
            return true;
        }
        std::cerr << "ml_dsa_keygen: " << step << " failed: "
                  << cudaGetErrorString(status) << std::endl;
        return false;
    };

    auto workspace = make_workspace<MLDSA44Key>(batch);
    auto randombytes = get_entropy<MLDSA44Key>(batch);

    uint8_t* d_public_key = nullptr;
    uint8_t* d_secret_key = nullptr;

    bool ok = check(cudaMalloc(reinterpret_cast<void**>(&d_public_key),
                               public_bytes),
                    "cudaMalloc(public keys)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_secret_key),
                               secret_bytes),
                    "cudaMalloc(secret keys)");

    if (ok) {
        keygen_kernel<<<batch, MLDSA44Key::BlockDim>>>(
            d_public_key, d_secret_key, randombytes, workspace);
        // Launch-configuration errors are only visible through this call.
        ok = check(cudaGetLastError(), "keygen_kernel launch");
    }

    // The blocking copies also surface errors raised while the kernel ran.
    if (ok) {
        ok = check(cudaMemcpy(public_keys.data(), d_public_key, public_bytes,
                              cudaMemcpyDeviceToHost),
                   "copy of public keys to host");
    }
    if (ok) {
        check(cudaMemcpy(secret_keys.data(), d_secret_key, secret_bytes,
                         cudaMemcpyDeviceToHost),
              "copy of secret keys to host");
    }

    // cudaFree(nullptr) is a no-op, so this is safe after a failed allocation.
    cudaFree(d_public_key);
    cudaFree(d_secret_key);
    destroy_workspace(workspace);
    release_entropy(randombytes);
}

// Signs `batch` fixed-length messages with ML-DSA-44 on the GPU.
//
// signatures: host output buffer; must already hold at least `batch` slots of
//     MLDSA44Sign::signature_size rounded up to a multiple of 8 bytes.
// messages: host input buffer holding at least message_size * batch bytes;
//     message i starts at byte i * message_size.
// message_size: length in bytes of every message.
// secret_keys: host input buffer holding at least
//     MLDSA44Sign::secret_key_size * batch bytes.
// batch: number of messages; one thread block is launched per message.
//     A batch of 0 is a no-op.
//
// Undersized buffers and CUDA failures are reported on std::cerr. After a
// failure the remaining steps are skipped, so `signatures` may be left
// unmodified; device resources are released either way.
void ml_dsa_sign(std::vector<uint8_t>& signatures,
                std::vector<uint8_t>& messages, size_t message_size,
                const std::vector<uint8_t>& secret_keys,
                const unsigned int batch)
{
    if (batch == 0) {
        return;  // a zero-sized grid is an invalid launch configuration
    }

    // Same padded per-signature stride that sign_kernel indexes with.
    const size_t signature_stride =
        (static_cast<size_t>(MLDSA44Sign::signature_size) + 7) / 8 * 8;
    const size_t signature_bytes = signature_stride * batch;
    const size_t message_bytes = message_size * batch;
    const size_t key_bytes =
        static_cast<size_t>(MLDSA44Sign::secret_key_size) * batch;

    // Every cudaMemcpy below reads or writes exactly these many host bytes.
    if (signatures.size() < signature_bytes ||
        messages.size() < message_bytes ||
        secret_keys.size() < key_bytes) {
        std::cerr << "ml_dsa_sign: host buffers are too small for batch "
                  << batch << std::endl;
        return;
    }

    // Logs a failed CUDA call and tells the caller whether to continue.
    const auto check = [](cudaError_t status, const char* step) {
        if (status == cudaSuccess) {
            return true;
        }
        std::cerr << "ml_dsa_sign: " << step << " failed: "
                  << cudaGetErrorString(status) << std::endl;
        return false;
    };

    auto workspace = make_workspace<MLDSA44Sign>(batch);
    auto randombytes = get_entropy<MLDSA44Sign>(batch);

    uint8_t* d_signatures = nullptr;
    uint8_t* d_messages = nullptr;
    uint8_t* d_secret_key = nullptr;

    bool ok = check(cudaMalloc(reinterpret_cast<void**>(&d_signatures),
                               signature_bytes),
                    "cudaMalloc(signatures)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_messages),
                               message_bytes),
                    "cudaMalloc(messages)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_secret_key),
                               key_bytes),
                    "cudaMalloc(secret keys)")
           && check(cudaMemcpy(d_messages, messages.data(), message_bytes,
                               cudaMemcpyHostToDevice),
                    "copy of messages to device")
           && check(cudaMemcpy(d_secret_key, secret_keys.data(), key_bytes,
                               cudaMemcpyHostToDevice),
                    "copy of secret keys to device");

    if (ok) {
        sign_kernel<<<batch, MLDSA44Sign::BlockDim>>>(
            d_signatures, d_messages, message_size, d_secret_key,
            randombytes, workspace);
        // Launch-configuration errors are only visible through this call.
        ok = check(cudaGetLastError(), "sign_kernel launch");
    }

    // The blocking copy also surfaces errors raised while the kernel ran.
    if (ok) {
        check(cudaMemcpy(signatures.data(), d_signatures, signature_bytes,
                         cudaMemcpyDeviceToHost),
              "copy of signatures to host");
    }

    // cudaFree(nullptr) is a no-op, so this is safe after a failed allocation.
    cudaFree(d_signatures);
    cudaFree(d_messages);
    cudaFree(d_secret_key);
    destroy_workspace(workspace);
    release_entropy(randombytes);
}

// Verifies `batch` ML-DSA-44 signatures on the GPU.
//
// valids: host output array with room for `batch` bools; valids[i] receives
//     the verification result for item i. Must not be null.
// signatures: host input buffer holding at least `batch` slots of
//     MLDSA44Sign::signature_size rounded up to a multiple of 8 bytes.
// messages: host input buffer holding at least message_size * batch bytes;
//     message i starts at byte i * message_size.
// message_size: length in bytes of every message.
// public_keys: host input buffer holding at least
//     MLDSA44Verify::public_key_size * batch bytes.
// batch: number of signatures; one thread block is launched per signature.
//     A batch of 0 is a no-op.
//
// Invalid arguments and CUDA failures are reported on std::cerr. After a
// failure the remaining steps are skipped and `valids` is not written, so
// callers should initialise it to false; device resources are released either
// way.
void ml_dsa_verify(bool* valids, const std::vector<uint8_t>& signatures,
                  const std::vector<uint8_t>& messages, size_t message_size,
                  const std::vector<uint8_t>& public_keys,
                  const unsigned int batch)
{
    if (batch == 0) {
        return;  // a zero-sized grid is an invalid launch configuration
    }
    if (valids == nullptr) {
        std::cerr << "ml_dsa_verify: valids must not be null" << std::endl;
        return;
    }

    // Same padded per-signature stride that verify_kernel indexes with.
    const size_t signature_stride =
        (static_cast<size_t>(MLDSA44Sign::signature_size) + 7) / 8 * 8;
    const size_t signature_bytes = signature_stride * batch;
    const size_t message_bytes = message_size * batch;
    const size_t key_bytes =
        static_cast<size_t>(MLDSA44Verify::public_key_size) * batch;
    const size_t valid_bytes = static_cast<size_t>(batch) * sizeof(bool);

    // Every host-to-device copy below reads exactly these many host bytes.
    if (signatures.size() < signature_bytes ||
        messages.size() < message_bytes ||
        public_keys.size() < key_bytes) {
        std::cerr << "ml_dsa_verify: host buffers are too small for batch "
                  << batch << std::endl;
        return;
    }

    // Logs a failed CUDA call and tells the caller whether to continue.
    const auto check = [](cudaError_t status, const char* step) {
        if (status == cudaSuccess) {
            return true;
        }
        std::cerr << "ml_dsa_verify: " << step << " failed: "
                  << cudaGetErrorString(status) << std::endl;
        return false;
    };

    auto workspace = make_workspace<MLDSA44Verify>(batch);

    bool* d_valids = nullptr;
    uint8_t* d_signatures = nullptr;
    uint8_t* d_messages = nullptr;
    uint8_t* d_public_key = nullptr;

    bool ok = check(cudaMalloc(reinterpret_cast<void**>(&d_valids),
                               valid_bytes),
                    "cudaMalloc(valids)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_signatures),
                               signature_bytes),
                    "cudaMalloc(signatures)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_messages),
                               message_bytes),
                    "cudaMalloc(messages)")
           && check(cudaMalloc(reinterpret_cast<void**>(&d_public_key),
                               key_bytes),
                    "cudaMalloc(public keys)")
           && check(cudaMemcpy(d_signatures, signatures.data(),
                               signature_bytes, cudaMemcpyHostToDevice),
                    "copy of signatures to device")
           && check(cudaMemcpy(d_messages, messages.data(), message_bytes,
                               cudaMemcpyHostToDevice),
                    "copy of messages to device")
           && check(cudaMemcpy(d_public_key, public_keys.data(), key_bytes,
                               cudaMemcpyHostToDevice),
                    "copy of public keys to device");

    if (ok) {
        verify_kernel<<<batch, MLDSA44Verify::BlockDim>>>(
            d_valids, d_signatures, d_messages, message_size, d_public_key,
            workspace);
        // Launch-configuration errors are only visible through this call.
        ok = check(cudaGetLastError(), "verify_kernel launch");
    }

    // The blocking copy also surfaces errors raised while the kernel ran.
    if (ok) {
        check(cudaMemcpy(valids, d_valids, valid_bytes,
                         cudaMemcpyDeviceToHost),
              "copy of results to host");
    }

    // cudaFree(nullptr) is a no-op, so this is safe after a failed allocation.
    cudaFree(d_valids);
    cudaFree(d_signatures);
    cudaFree(d_messages);
    cudaFree(d_public_key);
    destroy_workspace(workspace);
}

// Runs a batch of ML-DSA-44 keygen -> sign -> verify round trips over
// zero-filled 1024-byte messages and reports how many signatures verified.
//
// Returns 0 when every signature in the batch verifies, 1 otherwise.
int main() {
    constexpr unsigned int batch = 10;
    constexpr size_t message_size = 1024;

    std::vector<uint8_t> public_keys(MLDSA44Key::public_key_size * batch);
    std::vector<uint8_t> secret_keys(MLDSA44Key::secret_key_size * batch);

    ml_dsa_keygen(public_keys, secret_keys, batch);

    // Each signature slot is padded up to a multiple of 8 bytes, matching the
    // stride the kernels index with.
    std::vector<uint8_t> signatures(
        ((MLDSA44Sign::signature_size + 7) / 8 * 8) * batch);
    std::vector<uint8_t> messages(message_size * batch);
    ml_dsa_sign(signatures, messages, message_size, secret_keys, batch);

    // Start from "invalid" so a verification pass that never writes its
    // results cannot be mistaken for success (and is never read uninitialised).
    bool is_valids[batch] = {};
    ml_dsa_verify(is_valids, signatures, messages, message_size,
                 public_keys, batch);

    unsigned int verified = 0;
    for (const bool valid : is_valids) {
        if (valid) {
            ++verified;
        }
    }

    std::cout << "ML-DSA-44: " << verified << "/" << batch
              << " signatures verified" << std::endl;

    return verified == batch ? 0 : 1;
}

Build and Run

git clone https://github.com/NVIDIA/cuPQC.git
cd cuPQC/examples/public_key
make
./example_ml_kem
./example_ml_dsa

Learn More