Instructions to use Overworld/Waypoint-1-Small with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Overworld/Waypoint-1-Small with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Overworld/Waypoint-1-Small", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| # Copyright (C) 2025 Hugging Face Team and Overworld | |
| # | |
| # This program is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU General Public License as published by | |
| # the Free Software Foundation, either version 3 of the License, or | |
| # (at your option) any later version. | |
| # | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU General Public License for more details. | |
| # | |
| # You should have received a copy of the GNU General Public License | |
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| """Before-denoise blocks for WorldEngine modular pipeline.""" | |
| from typing import List, Optional, Union | |
| import PIL.Image | |
| import torch | |
| from torch import nn, Tensor | |
| from tensordict import TensorDict | |
| from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask | |
| from diffusers.configuration_utils import FrozenDict | |
| from diffusers.image_processor import VaeImageProcessor | |
| from diffusers.utils import logging | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from diffusers.modular_pipelines import ( | |
| ModularPipelineBlocks, | |
| ModularPipeline, | |
| PipelineState, | |
| SequentialPipelineBlocks, | |
| ) | |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( | |
| ComponentSpec, | |
| ConfigSpec, | |
| InputParam, | |
| OutputParam, | |
| ) | |
# Module-level logger obtained through diffusers' logging utilities.
logger = logging.get_logger(__name__)
| def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask: | |
| """ | |
| Create a block mask for flex_attention. | |
| Args: | |
| T: Q length for this frame | |
| L: KV capacity == written.numel() | |
| written: [L] bool, True where there is valid KV data | |
| """ | |
| BS = _DEFAULT_SPARSE_BLOCK_SIZE | |
| KV_blocks = (L + BS - 1) // BS | |
| Q_blocks = (T + BS - 1) // BS | |
| # [KV_blocks, BS] | |
| written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view( | |
| KV_blocks, BS | |
| ) | |
| # Block-level occupancy | |
| block_any = written_blocks.any(-1) # block has at least one written token | |
| block_all = written_blocks.all(-1) # block is fully written | |
| # Every Q-block sees the same KV-block pattern | |
| nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks) # [Q_blocks, KV_blocks] | |
| full_bm = block_all[None, :].expand_as(nonzero_bm) # [Q_blocks, KV_blocks] | |
| partial_bm = nonzero_bm & ~full_bm # [Q_blocks, KV_blocks] | |
| def dense_to_ordered(dense_mask: torch.Tensor): | |
| # dense_mask: [Q_blocks, KV_blocks] bool | |
| # returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks] | |
| num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) # [Q_blocks] | |
| indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to( | |
| torch.int32 | |
| ) | |
| return num_blocks[None, None].contiguous(), indices[None, None].contiguous() | |
| # Partial blocks (need mask_mod) | |
| kv_num_blocks, kv_indices = dense_to_ordered(partial_bm) | |
| # Full blocks (mask_mod can be skipped entirely) | |
| full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm) | |
| def mask_mod(b, h, q, kv): | |
| return written[kv] | |
| bm = BlockMask.from_kv_blocks( | |
| kv_num_blocks, | |
| kv_indices, | |
| full_kv_num_blocks, | |
| full_kv_indices, | |
| BLOCK_SIZE=BS, | |
| mask_mod=mask_mod, | |
| seq_lengths=(T, L), | |
| compute_q_blocks=False, # no backward, avoids the transpose/_ordered_to_dense path | |
| ) | |
| return bm | |
| class LayerKVCache(nn.Module): | |
| """ | |
| Ring-buffer KV cache with fixed capacity L (tokens) for history plus | |
| one extra frame (tokens_per_frame) at the tail holding the current frame. | |
| """ | |
| def __init__( | |
| self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1 | |
| ): | |
| super().__init__() | |
| self.tpf = tokens_per_frame | |
| self.L = L | |
| # total KV capacity: ring (L) + tail frame (tpf) | |
| self.capacity = L + self.tpf | |
| self.pinned_dilation = pinned_dilation | |
| self.num_buckets = (L // self.tpf) // self.pinned_dilation | |
| assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0 | |
| # KV buffer: [2, B, H, capacity, Dh] | |
| self.kv = nn.Buffer( | |
| torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype), | |
| persistent=False, | |
| ) | |
| # which slots have ever been written | |
| # tail slice [L, L+tpf) always holds the current frame and is considered written | |
| written = torch.zeros(self.capacity, dtype=torch.bool) | |
| written[L:] = True | |
| self.written = nn.Buffer(written, persistent=False) | |
| # Precompute indices: | |
| # frame_offsets: [0, 1, ..., tpf-1] (for ring indexing) | |
| # current_idx: [L, L+1, ..., L+tpf-1] (tail slice) | |
| self.frame_offsets = nn.Buffer( | |
| torch.arange(self.tpf, dtype=torch.long), persistent=False | |
| ) | |
| self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False) | |
| def reset(self): | |
| self.kv.zero_() | |
| self.written.zero_() | |
| self.written[self.L :].fill_(True) | |
| def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): | |
| """ | |
| Args: | |
| kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame) | |
| pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1) | |
| """ | |
| T = self.tpf | |
| t_pos = pos_ids["t_pos"] | |
| if not torch.compiler.is_compiling(): | |
| torch._check( | |
| kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert" | |
| ) | |
| torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]") | |
| torch._check(self.tpf <= self.L, "frame longer than KV ring capacity") | |
| torch._check( | |
| self.L % self.tpf == 0, | |
| f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})", | |
| ) | |
| torch._check( | |
| self.kv.size(3) == self.capacity, | |
| "KV buffer has unexpected length (expected L + tokens_per_frame)", | |
| ) | |
| torch._check( | |
| (t_pos >= 0).all().item(), | |
| "t_pos must be non-negative during inference", | |
| ) | |
| torch._check( | |
| ((t_pos == t_pos[:, :1]).all()).item(), | |
| "t_pos must be constant within frame", | |
| ) | |
| frame_t = t_pos[0, 0] | |
| # map frame_t to a bucket, each bucket owns T contiguous slots | |
| bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation | |
| slot = bucket % self.num_buckets | |
| base = slot * T | |
| # indices in the ring for this frame: [T] in [0, L) | |
| ring_idx = self.frame_offsets + base | |
| # Always write current frame into the tail slice [L, L+T): | |
| # this is the "self-attention component" for the current frame. | |
| self.kv.index_copy_(3, self.current_idx, kv) | |
| write_step = frame_t.remainder(self.pinned_dilation) == 0 | |
| mask_written = self.written.clone() | |
| mask_written[ring_idx] = mask_written[ring_idx] & ~write_step | |
| bm = make_block_mask(T, self.capacity, mask_written) | |
| # Persist current frame into the ring for future queries when unfrozen. | |
| if not is_frozen: | |
| # Persist current frame into the ring for future queries. | |
| dst = torch.where(write_step, ring_idx, self.current_idx) | |
| self.kv.index_copy_(3, dst, kv) | |
| self.written[dst] = True | |
| k, v = self.kv.unbind(0) | |
| return k, v, bm | |
class StaticKVCache(nn.Module):
    """
    Stack of per-layer ``LayerKVCache`` objects sized for local or global attention.

    Layer ``i`` is treated as global when
    ``(i - global_attn_offset) % global_attn_period == 0``. Global layers get a
    ring of ``global_window`` frames and ``global_pinned_dilation``; all other
    layers get ``local_window`` frames and no dilation. The cache starts frozen.
    """

    def __init__(self, config, batch_size, dtype):
        super().__init__()
        self.tpf = config.tokens_per_frame
        ring_len_local = config.local_window * self.tpf
        ring_len_global = config.global_window * self.tpf
        period = config.global_attn_period
        offset = getattr(config, "global_attn_offset", 0) % period

        def build_layer(idx):
            uses_global = (idx - offset) % period == 0
            return LayerKVCache(
                batch_size,
                getattr(config, "n_kv_heads", config.n_heads),
                ring_len_global if uses_global else ring_len_local,
                config.d_model // config.n_heads,
                dtype,
                self.tpf,
                config.global_pinned_dilation if uses_global else 1,
            )

        self.layers = nn.ModuleList(
            [build_layer(idx) for idx in range(config.n_layers)]
        )
        self._is_frozen = True

    def reset(self):
        """Clear every layer's cache and return to the frozen state."""
        for layer_cache in self.layers:
            layer_cache.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        """Set the frozen flag passed to each layer's ``upsert``."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Stack ``k``/``v`` along a new leading axis and delegate to ``layers[layer]``."""
        stacked = torch.stack([k, v], dim=0)
        target = self.layers[layer]
        return target.upsert(stacked, pos_ids, self._is_frozen)
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising.

    Also normalizes ``frame_timestamp``: ``None`` becomes ``[[0]]`` and a
    Python ``int`` becomes a ``[[value]]`` long tensor on the execution device.
    """

    model_name = "world_engine"

    # The pipeline framework reads the following as properties (as declared on
    # ModularPipelineBlocks), so they must be decorated with @property.
    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                type_hint=List[float],
                description="Custom scheduler sigmas (overrides config)",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "scheduler_sigmas",
                type_hint=torch.Tensor,
                description="Tensor of scheduler sigmas for denoising",
            ),
            OutputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Use provided sigmas, falling back to the pipeline config.
        sigmas = block_state.scheduler_sigmas
        if sigmas is None:
            sigmas = components.config.scheduler_sigmas
        # as_tensor accepts both lists and tensors (torch.tensor would warn
        # and copy when handed an existing tensor).
        block_state.scheduler_sigmas = torch.as_tensor(
            sigmas, device=device, dtype=dtype
        )

        frame_ts = block_state.frame_timestamp
        if frame_ts is None:
            frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
        elif isinstance(frame_ts, int):
            frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)
        block_state.frame_timestamp = frame_ts

        self.set_block_state(state, block_state)
        return components, state
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation.

    When no ``kv_cache`` is supplied, a new ``StaticKVCache`` is built from the
    transformer config (batch size 1, transformer dtype) and moved to the
    execution device. An existing cache is reused, and is cleared first when
    ``reset_cache`` is True.
    """

    model_name = "world_engine"

    # The pipeline framework reads the following as properties (as declared on
    # ModularPipelineBlocks), so they must be decorated with @property.
    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "kv_cache",
                type_hint=Optional[StaticKVCache],
                description="Existing KV cache (will be reused if provided)",
            ),
            InputParam(
                "reset_cache",
                type_hint=bool,
                default=False,
                description="If True, reset the KV cache even if one exists",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=StaticKVCache,
                description="KV cache for transformer attention",
            ),
        ]

    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)

        if block_state.kv_cache is None:
            # A freshly built cache is already empty, so reset_cache is moot here.
            block_state.kv_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=components.transformer.dtype,
            ).to(components._execution_device)
        elif block_state.reset_cache:
            block_state.kv_cache.reset()

        self.set_block_state(state, block_state)
        return components, state
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image.

    If ``image`` is given, it is preprocessed, encoded with ``components.vae``
    and run once through the transformer at ``sigma = 0`` with the KV cache
    unfrozen, so the encoded frame is persisted as context. ``frame_timestamp``
    is then advanced by one in place. The denoising latents are then sampled
    from a standard normal distribution, honouring ``generator``. Sampling is
    skipped only when ``use_random_latents`` is False and ``latents`` were
    supplied.
    """

    model_name = "world_engine"

    # The pipeline framework reads the following as properties (as declared on
    # ModularPipelineBlocks), so they must be decorated with @property.
    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Run the transformer once at sigma=0 with the KV cache unfrozen.

        This persists the clean frame ``x`` into the cache; the transformer
        output is discarded. The cache is left unfrozen afterwards.

        Must be a staticmethod: it is called as ``self._cache_pass(transformer,
        ...)`` and takes no ``self``.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            # one zero sigma per (batch, frame); assumes x is [B, F, ...] — TODO confirm
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    # The cache pass runs a full transformer forward and copies its K/V into
    # persistent buffers; without no_grad this would build (and retain) an
    # autograd graph across frames.
    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Latent shape info from the pipeline config.
        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch
        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        # NOTE(review): latent H/W are vae_scale_factor * patch, not derived
        # from height/width — confirm this matches the model config.
        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            # Preprocess: PIL/tensor -> [B, C, H, W] float in [0, 1]
            # (do_normalize=False in the image processor config).
            image = components.image_processor.preprocess(
                block_state.image,
                height=height,
                width=width,
            )
            # Convert to [H, W, 3] uint8 for the VAE encoder. Round (a bare
            # .to(uint8) truncates, so float error just below an integer loses
            # a level) and clamp (resampling may overshoot [0, 1]).
            image = (
                (image[0].permute(1, 2, 0) * 255)
                .round()
                .clamp(0, 255)
                .to(torch.uint8)
            )

            latents = components.vae.encode(image).unsqueeze(1)

            # Run cache pass to persist the encoded frame.
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            # In-place so the advanced timestamp is visible to later blocks.
            block_state.frame_timestamp.add_(1)

        if block_state.use_random_latents or block_state.latents is None:
            # randn_tensor honours the user's generator (also when it lives on
            # a different device than the execution device).
            block_state.latents = randn_tensor(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising.

    Runs, in order: sigma/timestamp setup, KV-cache setup, and latent
    preparation (including the optional image cache pass).
    """

    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    # Declared as a property on the pipeline-blocks base classes.
    @property
    def description(self) -> str:
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )