Core

The flashdreams.core package collects the low-level kernels and process-group utilities that integrations share.

Attention

The attention package provides the kernels used by the transformer and the block-structured KV cache that backs streaming inference.

class NativeAttention(qkv_format: Literal['bhsd', 'bshd'] = 'bhsd', backend: Literal['math', 'efficient', 'cudnn', 'flash'] = 'cudnn')

Bases: Module

Native attention module with configurable QKV layout and SDPA backend.

set_context_parallel_group(cp_group: ProcessGroup | None) → None

Enable or disable context parallelism for ring attention.

Parameters:

cp_group – Process group used for context parallelism; pass None to disable it.

is_context_parallel_enabled() → bool

Return True if context parallelism is active.

context_parallel_size() → int

Return the context parallel world size, or 1 if disabled.

forward(query: Tensor, key: Tensor, value: Tensor) → Tensor

Run context-parallel SDPA (or single-rank SDPA when CP is disabled).

Parameters:
  • query – Query tensor in configured qkv_format.

  • key – Key tensor in configured qkv_format.

  • value – Value tensor in configured qkv_format.

Returns:

Attention output in the same format as inputs.
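The two qkv_format values differ only in axis order. The following is a minimal sketch, assuming the conventional reading of the letters (b = batch, h = heads, s = sequence, d = head dimension), which this page does not spell out. The helpers qkv_shape and seq_axis are hypothetical and are not part of flashdreams.

```python
SUPPORTED_FORMATS = ("bhsd", "bshd")  # the two layouts listed in the signature


def qkv_shape(batch, heads, seq, head_dim, qkv_format="bhsd"):
    """Return the tensor shape implied by a layout string such as 'bshd'.

    Assumes b=batch, h=heads, s=sequence, d=head_dim (an assumption, see above).
    """
    if qkv_format not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported qkv_format: {qkv_format!r}")
    sizes = {"b": batch, "h": heads, "s": seq, "d": head_dim}
    return tuple(sizes[axis] for axis in qkv_format)


def seq_axis(qkv_format):
    """Index of the sequence axis for a given layout string."""
    if qkv_format not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported qkv_format: {qkv_format!r}")
    return qkv_format.index("s")


print(qkv_shape(2, 8, 128, 64, "bhsd"))    # (2, 8, 128, 64)
print(qkv_shape(2, 8, 128, 64, "bshd"))    # (2, 128, 8, 64)
print(seq_axis("bhsd"), seq_axis("bshd"))  # 2 1
```

Under the same assumption, the sequence axis index is the kind of value a caller would pass as seq_dim when caching keys and values stored in that layout.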

class ContextParallelAttention(qkv_format: Literal['bhsd', 'bshd'] = 'bhsd', backend: Literal['cudnn', 'flash'] = 'cudnn', method: Literal['ring', 'ulysses'] = 'ring', convert_to_fp32: bool = True)

Bases: NativeAttention

Context-parallel attention with selectable method and SDPA backend.

class BlockKVCache(k_shape: tuple[int, ...], v_shape: tuple[int, ...], seq_dim: int, chunk_size: int, window_size: int, sink_size: int = 0, device: device | str = device(type='cuda'), dtype: dtype = torch.float16, _prev_chunk_idx: int = -1, _curr_chunk_idx: int | None = None, _n_cached: int = 0)

Bases: object

KV cache for causal attention with a fixed-size local window, CUDA-graph compatible.

Keys and values can have arbitrary shape [..., total_size, ...]; seq_dim gives the index of the sequence (rolling) dimension and may be negative. The layout along that dimension is [sink tokens | local window tokens]. Sink tokens are never evicted; once the cache is full, the local window rolls left each time a new chunk is added. Chunks are non-overlapping: each update adds one chunk of chunk_size tokens at the next logical position in the full sequence.

Note: Currently, total_size (sink_size + window_size) must be divisible by chunk_size.

Phases:
  • Filling – cache not yet full; tokens are written contiguously; cached_k() / cached_v() return only the valid prefix.

  • Steady-state – cache full; each new chunk triggers a left-roll of the local window and overwrites the rightmost positions; cached_k() / cached_v() return the full buffer.

The argument chunk_idx (0, 1, 2, …) is the index of the new chunk in the full sequence (not an index into the cache). If chunk_idx is greater than the previous one, the chunk is appended (or, in steady-state, written after the roll). If chunk_idx equals the previous one, the same cache positions are overwritten.
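To make the relation between chunk_idx and cache contents concrete, the sketch below computes which logical token positions are held once a given chunk has been written. It follows only from the layout described above ([sink | local window], chunks appended in order). The function name visible_tokens is hypothetical and is not part of BlockKVCache.

```python
def visible_tokens(chunk_idx, chunk_size, window_size, sink_size=0):
    """Logical token positions in the cache once chunk `chunk_idx` is written.

    Hypothetical helper derived from the documented layout, not library code.
    """
    total_size = sink_size + window_size
    if total_size % chunk_size != 0:
        raise ValueError("sink_size + window_size must be divisible by chunk_size")
    end = (chunk_idx + 1) * chunk_size  # one past the newest token
    if end <= total_size:
        # Filling phase: nothing has been evicted yet.
        return list(range(end))
    # Steady-state: sink tokens plus the most recent window_size tokens.
    return list(range(sink_size)) + list(range(end - window_size, end))


print(visible_tokens(1, chunk_size=2, window_size=4, sink_size=2))  # [0, 1, 2, 3]
print(visible_tokens(3, chunk_size=2, window_size=4, sink_size=2))  # [0, 1, 4, 5, 6, 7]
```

In the second call, tokens 2 and 3 have been evicted from the local window, while sink tokens 0 and 1 remain.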

Per-step usage:
  1. before_update(chunk_idx) — prepare (roll local window if steady-state).

  2. update(k, v) — write the new chunk’s keys/values into the cache.

  3. cached_k() / cached_v() — get cached keys/values for attention.

  4. after_update(chunk_idx) — update internal bookkeeping.
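The four-step protocol above can be illustrated with a toy model that stores plain Python lists instead of tensors and keeps a single buffer rather than separate keys and values. This is a sketch of the documented behaviour under those simplifications, not the library's implementation. ToyBlockCache, step, and the attribute names are hypothetical.

```python
class ToyBlockCache:
    """List-based stand-in for a sink + rolling-window cache (illustration only)."""

    def __init__(self, chunk_size, window_size, sink_size=0):
        self.total_size = sink_size + window_size
        if self.total_size % chunk_size != 0:
            raise ValueError("sink_size + window_size must be divisible by chunk_size")
        self.chunk_size = chunk_size
        self.sink_size = sink_size
        self.buf = [None] * self.total_size
        self.prev_chunk_idx = -1
        self.n_cached = 0     # tokens committed by after_update()
        self.write_start = 0  # left edge of the current chunk

    def is_steady_state(self):
        return self.n_cached == self.total_size

    def before_update(self, chunk_idx):
        if chunk_idx == self.prev_chunk_idx:
            return  # same chunk again: reuse the same slots
        if chunk_idx != self.prev_chunk_idx + 1:
            raise ValueError("chunk_idx must repeat or advance by one")
        if self.is_steady_state():
            s, c = self.sink_size, self.chunk_size
            # Shift the window left by one chunk; sink slots [0, s) stay put.
            self.buf[s:self.total_size - c] = self.buf[s + c:]
            self.write_start = self.total_size - c
        else:
            self.write_start = self.n_cached

    def update(self, tokens):
        if len(tokens) != self.chunk_size:
            raise ValueError("expected exactly chunk_size tokens")
        self.buf[self.write_start:self.write_start + self.chunk_size] = tokens

    def cached(self):
        # Valid prefix while filling; the whole buffer in steady-state.
        return self.buf[:self.write_start + self.chunk_size]

    def after_update(self, chunk_idx):
        self.prev_chunk_idx = chunk_idx
        if not self.is_steady_state():
            self.n_cached = self.write_start + self.chunk_size


def step(cache, chunk_idx, tokens):
    """Run one before_update / update / cached / after_update cycle."""
    cache.before_update(chunk_idx)
    cache.update(tokens)
    visible = list(cache.cached())
    cache.after_update(chunk_idx)
    return visible


cache = ToyBlockCache(chunk_size=2, window_size=4, sink_size=2)
history = [step(cache, i, [2 * i, 2 * i + 1]) for i in range(4)]
print(history[0])  # [0, 1]
print(history[2])  # [0, 1, 2, 3, 4, 5]
print(history[3])  # [0, 1, 4, 5, 6, 7]
```

Calling step again with the same chunk_idx (for example step(cache, 3, [60, 70])) overwrites the rightmost chunk without rolling, which mirrors the repeated-chunk_idx case described above.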

k_shape: tuple[int, ...]

Shape of the keys. Must match v_shape in every dimension except the last.

v_shape: tuple[int, ...]

Shape of the values. Must match k_shape in every dimension except the last.

seq_dim: int

Sequence dimension that will be rolled. Can be negative.

chunk_size: int

Number of tokens written per update (the length of one chunk along seq_dim).

window_size: int

Size of the local attention window (excluding sink tokens).

sink_size: int = 0

Number of sink tokens at the start of the cache that are never evicted. Defaults to 0.

device: device | str = device(type='cuda')

Device to store the cache on.

dtype: dtype = torch.float16

Data type to store the cache in.

property size: int

Number of valid cached tokens visible to attention.

property write_end: int

Right edge of the current chunk in the physical cache layout.

classmethod from_tensor(k: Tensor, v: Tensor, seq_dim: int) → Self

Build a single-chunk cache pre-filled with the given key and value tensors.

is_steady_state() → bool

Return True if the cache is full (steady-state phase).

before_update(chunk_idx: int) → None

Prepare the cache before writing new tokens.

If chunk_idx equals the previous chunk index, this is a no-op. Otherwise, chunk_idx is expected to be exactly one greater than the previous chunk index: the local window is rolled left if the cache is in steady-state, and nothing is done if it is still in the filling phase.

Parameters:

chunk_idx – Chunk index of the new chunk in the full sequence.

update(k: Tensor, v: Tensor) → None

Write the new chunk’s keys and values into the cache.

Must be called after before_update() and before after_update().

Parameters:
  • k – Keys; shape must match cached keys except at seq_dim, where length must be chunk_size.

  • v – Values; shape must match cached values except at seq_dim, where length must be chunk_size.

after_update(chunk_idx: int) → None

Finalize bookkeeping after writing new tokens.

Updates _prev_chunk_idx and, in filling phase, _n_cached.

Parameters:

chunk_idx – The index of the new chunk in the full sequence.

cached_k() → Tensor

Return cached keys for attention (valid prefix in filling phase, full buffer in steady-state).

cached_v() → Tensor

Return cached values for attention (valid prefix in filling phase, full buffer in steady-state).

reset() → None

Reset the cache to its initial empty state.

Distributed

Helpers for multi-GPU / multi-node inference. init boots the NCCL process group with sensible defaults (NVML-derived CPU affinity, heartbeat timeout, larger L2 fetch granularity) and is a drop-in replacement for the boilerplate at the top of the example launchers.

init() → int | None

Initialize the NCCL process group for distributed execution.

class Device(device_idx: int)

Bases: object

Lightweight wrapper around an NVML device handle for CPU-affinity queries.

get_name() → str

Return the marketing name reported by NVML for this device.

get_cpu_affinity() → list[int]

Return the indices of CPUs ideally affined to this GPU per NVML.
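NVML reports CPU affinity as an array of bitmask words, so a list of CPU indices like the one returned here has to be decoded from those words at some point. The sketch below shows one way to do that decoding. Whether Device.get_cpu_affinity works this way internally is an assumption, and the function name cpus_from_affinity_mask is hypothetical. The word width depends on the platform, so it is left as a parameter.

```python
def cpus_from_affinity_mask(words, bits_per_word=64):
    """Decode an array of bitmask words into a sorted list of CPU indices.

    Bit `b` of word `w` set means CPU `w * bits_per_word + b` is in the set.
    Hypothetical helper; the word width is platform dependent.
    """
    cpus = []
    for word_idx, word in enumerate(words):
        for bit in range(bits_per_word):
            if (word >> bit) & 1:
                cpus.append(word_idx * bits_per_word + bit)
    return cpus


print(cpus_from_affinity_mask([0b1011]))      # [0, 1, 3]
print(cpus_from_affinity_mask([0x0, 0b101]))  # [64, 66]
```

On Linux, a list of this form can be passed to os.sched_setaffinity(0, cpus) to pin the current process to those CPUs.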