FlashMemory DS-V4 Retriever

A lightweight retriever that sparsifies the DeepSeek-V4 Compressed-Sparse-Attention (CSA) KV-cache. Given a decode-token hidden state, it predicts which compressed-K chunks the next ~64 tokens will attend to, keeping only those on GPU and offloading the rest.

This model is the Neural Memory Indexer presented in FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention.

Detailed code and inference scripts can be found on GitHub: libertywing/FlashMemory-Deepseek-V4.

Performance

In downstream evaluation, the retriever matches or beats the full-attention baseline on reasoning-heavy long-context tasks (RULER, LongMemEval, LongBench V2) while reducing KV-cache usage by ~85–90%. Precise needle-retrieval tasks require an additional threshold-fallback mechanism, which is not included in this release.

Quick start

To use this model, you will need the retriever.py file from the official repository.

pip install torch safetensors
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors

Usage

from retriever import FlashMemoryRetriever

model = FlashMemoryRetriever.from_checkpoint(
    "weights/flashmemory_ds_v4.safetensors", device="cuda"
)

# hidden:     [B, 4096] decode-token hidden state
# comp_k:     [B, N, 132] uint8 compressed CSA keys
# positions:  [B] int64 token positions

# Per-layer sigmoid scores: {"l10": [B,N], "l12": [B,N], "l20": [B,N]}
per_layer = model(hidden, comp_k, positions)

# Cross-layer ensemble (mode="max" or "mean")
scores = model.ensemble(hidden, comp_k, positions, mode="max")       # [B, N]

# Boolean keep mask
keep = model.select_topk(hidden, comp_k, positions, top_k=512)       # top-K
keep = model.select_topk(hidden, comp_k, positions, threshold=0.5)   # threshold

compressed_k format (the comp_k argument above): each chunk is 132 bytes, made up of 128 bytes of float8_e4m3 values plus a 4-byte float32 scale. See make_mock_compressed_k() in demo.py.
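
As a rough illustration of this layout, the sketch below decodes one 132-byte chunk using only the Python standard library. It assumes the 128 FP8 bytes come first and the scale follows as a little-endian float32, that float8_e4m3 refers to the OCP E4M3 "fn" variant (no infinities, as in torch.float8_e4m3fn), and that dequantization multiplies each value by the scale. None of these details are confirmed by this card, so make_mock_compressed_k() in demo.py remains the reference. The helper names fp8_e4m3_to_float and dequantize_chunk are hypothetical.

```python
import math
import struct

CHUNK_BYTES = 132  # 128 FP8 bytes + 4-byte float32 scale, per the format note
FP8_VALUES = 128


def fp8_e4m3_to_float(byte: int) -> float:
    """Decode one byte as FP8 E4M3 (assumed 'fn' variant: bias 7, no infinities)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan  # the only NaN encoding in the 'fn' variant
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent of 2**-6.
        return sign * mantissa * 2.0 ** -9
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize_chunk(chunk: bytes) -> list[float]:
    """Turn one 132-byte chunk into 128 floats (assumed layout: values, then scale)."""
    if len(chunk) != CHUNK_BYTES:
        raise ValueError(f"expected {CHUNK_BYTES} bytes, got {len(chunk)}")
    # Assumption: the scale is stored as a little-endian float32 after the values.
    (scale,) = struct.unpack("<f", chunk[FP8_VALUES:])
    return [fp8_e4m3_to_float(b) * scale for b in chunk[:FP8_VALUES]]
```

Applied to every chunk of a [N, 132] uint8 array, this would give the [N, 128] key matrix that the Architecture section calls k.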

Architecture

3-layer joint model (l10, l12, l20) with 128 heads and LoRA rank 2048. Per-layer sigmoid scores are ensembled (max or mean) per chunk.

hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
                → weights_proj → fused_w [B,128]
compressed_k    → FP8 dequant → k [B,N,128]

score = sigmoid( Σ( relu(k @ qᵀ) · fused_w ) )  ∈ [0,1]
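
One way to read the score formula, for a single batch element and a single layer, is sketched below in plain Python. It assumes the sum runs over the 128 heads, which is what the shapes suggest (k @ qᵀ gives one value per chunk and head, and fused_w holds one weight per head). The function names are hypothetical, and the projections, RoPE and Hadamard steps that produce q are not modeled.

```python
import math


def stable_sigmoid(x: float) -> float:
    """Sigmoid that avoids overflow in exp() for large negative inputs."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def chunk_scores(q: list[list[float]], k: list[list[float]],
                 fused_w: list[float]) -> list[float]:
    """q: [H][D] per-head queries, k: [N][D] dequantized keys, fused_w: [H]."""
    scores = []
    for key in k:
        logit = 0.0
        for head_q, weight in zip(q, fused_w):
            dot = sum(a * b for a, b in zip(key, head_q))
            # relu on the per-head similarity, then the per-head weight.
            logit += max(dot, 0.0) * weight
        scores.append(stable_sigmoid(logit))
    return scores
```

In the released model H and D are both 128 and N is the number of compressed chunks. A batched tensor implementation would replace the loops with one matrix product.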

Citation

If you use FlashMemory in your research, please cite:

@article{wang2026flashmemory,
  title   = {FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention},
  author  = {Yan Wang and Qifan Zhang and Jiachen Yu and Tian Liang and Dongyang Ma and
             Xiang Hu and Zibo Lin and Chunyang Li and Zhichao Wang and Jia Li and
             Yujiu Yang and Haitao Mi and Dong Yu},
  year    = {2026},
  journal = {arXiv preprint arXiv:2606.09079},
  url     = {https://huggingface.co/papers/2606.09079},
}

License

MIT
