FlashMemory DS-V4 Retriever
A lightweight retriever that sparsifies the DeepSeek-V4 Compressed Sparse Attention (CSA) KV cache. Given a decode-token hidden state, it predicts which compressed-K chunks the next ~64 tokens will attend to, so that only those chunks are kept on GPU and the rest are offloaded.
This model is the Neural Memory Indexer presented in FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention.
Detailed code and inference scripts can be found on GitHub: libertywing/FlashMemory-Deepseek-V4.
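The keep-on-GPU / offload split described above can be pictured with a small helper. This is a hypothetical sketch, not code from the repository: `partition_chunks` is an illustrative name, it assumes the cache holds one entry per compressed chunk along the first dimension, and it only moves tensors between devices (no pinned memory, prefetching, or fallback logic).

```python
import torch


def partition_chunks(kv: torch.Tensor, keep: torch.Tensor,
                     gpu: str = "cuda", host: str = "cpu"):
    """Split one sequence's per-chunk cache entries by a boolean keep mask.

    kv:   [N, ...] one entry per compressed chunk (hypothetical layout)
    keep: [N] bool, True = the retriever expects this chunk to be attended to
    Returns (kept_idx, kept_kv, offloaded_idx, offloaded_kv); the indices let a
    caller fetch an offloaded chunk back if it is needed later.
    """
    keep = keep.to(kv.device)
    kept_idx = keep.nonzero(as_tuple=True)[0]
    off_idx = (~keep).nonzero(as_tuple=True)[0]
    kept_kv = kv[kept_idx].to(gpu)        # stays resident on the accelerator
    offloaded_kv = kv[off_idx].to(host)   # moved out of GPU memory
    return kept_idx, kept_kv, off_idx, offloaded_kv
```

With a batched `[B, N]` keep mask, the helper would be called once per batch row, e.g. `partition_chunks(kv[b], keep[b])`.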
Performance
In downstream evaluation, the retriever matches or outperforms the full-attention baseline on reasoning-heavy long-context benchmarks (RULER, LongMemEval, LongBench V2) while reducing KV-cache usage by ~85–90%. Precise needle-retrieval tasks require an additional threshold-fallback mechanism, which is not included in this release.
Quick start
To use this model, you will need the retriever.py file from the official repository.
```shell
pip install torch safetensors
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
```
Usage
```python
from retriever import FlashMemoryRetriever  # retriever.py from the official repository

model = FlashMemoryRetriever.from_checkpoint(
    "weights/flashmemory_ds_v4.safetensors", device="cuda"
)

# Inputs, prepared by the caller:
#   hidden:    [B, 4096]     decode-token hidden state
#   comp_k:    [B, N, 132]   uint8 compressed CSA keys
#   positions: [B]           int64 token positions

# Per-layer sigmoid scores: {"l10": [B, N], "l12": [B, N], "l20": [B, N]}
per_layer = model(hidden, comp_k, positions)

# Cross-layer ensemble (mode="max" or "mean") -> [B, N]
scores = model.ensemble(hidden, comp_k, positions, mode="max")

# Boolean keep mask over chunks, selected either by top-K or by a score threshold
keep = model.select_topk(hidden, comp_k, positions, top_k=512)      # top-K
keep = model.select_topk(hidden, comp_k, positions, threshold=0.5)  # threshold
```
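The internals of `select_topk` are not documented here, so the following is only a hypothetical re-implementation showing how a top-K or threshold keep mask can be derived from `[B, N]` ensembled scores. `keep_mask` is an illustrative name; the `>=` threshold comparison and the clamping of `top_k` to `N` are assumptions, not confirmed behaviour of `retriever.py`.

```python
from typing import Optional

import torch


def keep_mask(scores: torch.Tensor, top_k: Optional[int] = None,
              threshold: Optional[float] = None) -> torch.Tensor:
    """Boolean keep mask [B, N] from chunk scores [B, N] in [0, 1].

    Exactly one of top_k / threshold must be given (hypothetical helper).
    """
    if (top_k is None) == (threshold is None):
        raise ValueError("pass exactly one of top_k or threshold")
    if threshold is not None:
        return scores >= threshold                 # keep every chunk above the cut
    k = min(top_k, scores.shape[-1])               # cannot keep more chunks than exist
    idx = scores.topk(k, dim=-1).indices           # [B, k] highest-scoring chunks per row
    mask = torch.zeros_like(scores, dtype=torch.bool)
    return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
```

Top-K keeps a fixed budget per row (a kept fraction of `k / N`), whereas a threshold keeps a variable number of chunks that depends on how confident the scores are.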
compressed_k format: each chunk is 132 bytes, made up of 128 bytes of float8_e4m3 values plus a 4-byte float32 scale (hence the last dimension of comp_k is 132). See make_mock_compressed_k() in demo.py.
Architecture
Joint 3-layer model (l10, l12, l20) with 128 heads and LoRA rank 2048. For each chunk, the per-layer sigmoid scores are ensembled by max or mean.
```text
hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
hidden [B,4096] → weights_proj → fused_w [B,128]
compressed_k → FP8 dequant → k [B,N,128]
score = sigmoid( Σ over heads ( relu(k @ qᵀ) · fused_w ) ) ∈ [0,1], shape [B,N]
```
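The scoring step in the diagram, plus the cross-layer ensemble, can be sketched in a few lines. This is an illustrative sketch, not the repository's implementation: `hadamard`, `chunk_scores` and `ensemble_scores` are hypothetical names, the Hadamard step is assumed to be an orthonormal Walsh-Hadamard transform over the head dimension (scaled by 1/√d), and q-proj, RoPE(YaRN) and weights_proj are not reproduced.

```python
import math
from typing import Dict

import torch


def hadamard(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal Walsh-Hadamard transform over the last dim (power-of-two size)."""
    d = x.shape[-1]
    assert d > 0 and d & (d - 1) == 0, "last dim must be a power of two"
    lead = x.shape[:-1]
    y, h = x, 1
    while h < d:
        # Butterfly: pair element j with element j + h inside each block of 2h.
        y = y.reshape(*lead, d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2).reshape(*lead, d)
        h *= 2
    return y / math.sqrt(d)  # 1/sqrt(d) makes the transform orthonormal


def chunk_scores(q: torch.Tensor, fused_w: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """score = sigmoid(sum_h fused_w[h] * relu(k . q[h])), shape [B, N], values in [0, 1].

    q: [B, H, D] per-head queries; fused_w: [B, H] head weights; k: [B, N, D] keys.
    """
    logits = torch.relu(k @ q.transpose(1, 2))                 # [B, N, H]: one dot product per head
    return torch.sigmoid((logits * fused_w[:, None, :]).sum(dim=-1))


def ensemble_scores(per_layer: Dict[str, torch.Tensor], mode: str = "max") -> torch.Tensor:
    """Combine per-layer [B, N] scores (e.g. keys "l10", "l12", "l20") chunk by chunk."""
    stacked = torch.stack(list(per_layer.values()))            # [L, B, N]
    if mode == "max":
        return stacked.amax(dim=0)
    if mode == "mean":
        return stacked.mean(dim=0)
    raise ValueError(f"unknown mode: {mode!r}")
```

Following the diagram, one layer's scores would be `chunk_scores(hadamard(q_rope), fused_w, k)`, where `q_rope` is a stand-in for the `[B,128,128]` query after q-proj and RoPE; the three layers' results are then passed to `ensemble_scores`. Max keeps a chunk if any layer finds it relevant, while mean favours chunks on which the layers agree.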
Citation
If you use FlashMemory in your research, please cite:
```bibtex
@article{wang2026flashmemory,
  title   = {FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention},
  author  = {Yan Wang and Qifan Zhang and Jiachen Yu and Tian Liang and Dongyang Ma and
             Xiang Hu and Zibo Lin and Chunyang Li and Zhichao Wang and Jia Li and
             Yujiu Yang and Haitao Mi and Dong Yu},
  year    = {2026},
  journal = {arXiv preprint arXiv:2606.09079},
  url     = {https://huggingface.co/papers/2606.09079},
}
```
License
MIT