import spaces import logging import os import random import re import sys import warnings from PIL import Image from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer sys.path.append(os.path.dirname(os.path.abspath(__file__))) from diffusers import ZImagePipeline from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel # ==================== Environment Variables ================================== MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo") ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true" ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true" ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3") HF_TOKEN = os.environ.get("HF_TOKEN") # ============================================================================= os.environ["TOKENIZERS_PARALLELISM"] = "false" warnings.filterwarnings("ignore") logging.getLogger("transformers").setLevel(logging.ERROR) RES_CHOICES = { "1024": [ "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )", "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )", "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )", "1344x576 ( 21:9 )", "576x1344 ( 9:21 )", ], "1280": [ "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )", "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )", "1024x1536 ( 2:3 )", "1536x864 ( 16:9 )", "864x1536 ( 9:16 )", "1680x720 ( 21:9 )", "720x1680 ( 9:21 )", ], "1536": [ "1536x1536 ( 1:1 )", "1728x1344 ( 9:7 )", "1344x1728 ( 7:9 )", "1728x1296 ( 4:3 )", "1296x1728 ( 3:4 )", "1872x1248 ( 3:2 )", "1248x1872 ( 2:3 )", "2048x1152 ( 16:9 )", "1152x2048 ( 9:16 )", "2016x864 ( 21:9 )", "864x2016 ( 9:21 )", ], "2048": [ "2048x2048 ( 1:1 )", "2304x1792 ( 9:7 )", "1792x2304 ( 7:9 )", "2304x1728 ( 4:3 )", "1728x2304 ( 3:4 )", "2496x1664 ( 3:2 )", "1664x2496 ( 2:3 )", "2720x1536 ( 16:9 )", "1536x2720 ( 9:16 )", "2688x1152 ( 21:9 )", "1152x2688 ( 9:21 )", ], } RESOLUTION_SET = [] for resolutions in RES_CHOICES.values(): RESOLUTION_SET.extend(resolutions) EXAMPLE_PROMPTS = [ ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"] ] def get_resolution(resolution): match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution) if match: return int(match.group(1)), int(match.group(2)) return 1024, 1024 def load_models(model_path, enable_compile=False, attention_backend="flash_3"): print(f"Loading models from {model_path}...") use_auth_token = HF_TOKEN if HF_TOKEN else True if not os.path.exists(model_path): vae = AutoencoderKL.from_pretrained( f"{model_path}", subfolder="vae", torch_dtype=torch.bfloat16, device_map="cuda", use_auth_token=use_auth_token, ) text_encoder = AutoModelForCausalLM.from_pretrained( f"{model_path}", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda", use_auth_token=use_auth_token, ).eval() tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token) else: vae = AutoencoderKL.from_pretrained( os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda" ) text_encoder = AutoModelForCausalLM.from_pretrained( os.path.join(model_path, "text_encoder"), torch_dtype=torch.bfloat16, device_map="cuda", ).eval() tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer")) tokenizer.padding_side = "left" if enable_compile: print("Enabling torch.compile optimizations...") torch._inductor.config.conv_1x1_as_mm = True torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.max_autotune_gemm = True torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN" torch._inductor.config.triton.cudagraphs = False pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None) if enable_compile: pipe.vae.disable_tiling() if not os.path.exists(model_path): transformer = ZImageTransformer2DModel.from_pretrained( f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token ).to("cuda", torch.bfloat16) else: transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to( "cuda", torch.bfloat16 ) pipe.transformer = transformer pipe.transformer.set_attention_backend(attention_backend) if enable_compile: print("Compiling transformer...") pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False) pipe.to("cuda", torch.bfloat16) return pipe def generate_image( pipe, prompt, resolution="1024x1024", seed=42, guidance_scale=5.0, num_inference_steps=50, shift=3.0, max_sequence_length=512, progress=gr.Progress(track_tqdm=True), ): width, height = get_resolution(resolution) generator = torch.Generator("cuda").manual_seed(seed) scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift) pipe.scheduler = scheduler image = pipe( prompt=prompt, height=height, width=width, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, max_sequence_length=max_sequence_length, ).images[0] return image def warmup_model(pipe, resolutions): print("Starting warmup phase...") dummy_prompt = "warmup" for res_str in resolutions: print(f"Warming up for resolution: {res_str}") try: for i in range(3): generate_image( pipe, prompt=dummy_prompt, resolution=res_str, num_inference_steps=9, guidance_scale=0.0, seed=42 + i, ) except Exception as e: print(f"Warmup failed for {res_str}: {e}") print("Warmup completed.") pipe = None def init_app(): global pipe try: pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND) print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}") if ENABLE_WARMUP: all_resolutions = [] for cat in RES_CHOICES.values(): all_resolutions.extend(cat) warmup_model(pipe, all_resolutions) except Exception as e: print(f"Error loading model: {e}") pipe = None @spaces.GPU def generate( prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=9, shift=3.0, random_seed=True, gallery_images=None, progress=gr.Progress(track_tqdm=True), ): """ Generate an image using the Z-Image model based on the provided prompt and settings. Args: prompt (str): Text prompt describing the desired image content resolution (str): Output resolution seed (int): Seed for reproducible generation steps (int): Number of inference steps shift (float): Time shift parameter random_seed (bool): Whether to generate a new random seed gallery_images (list): List of previously generated images progress (gr.Progress): Gradio progress tracker Returns: tuple: (gallery_images, seed_str, seed_int) """ if random_seed: new_seed = random.randint(1, 1000000) else: new_seed = seed if seed != -1 else random.randint(1, 1000000) try: if pipe is None: raise gr.Error("Model not loaded.") final_prompt = prompt try: resolution_str = resolution.split(" ")[0] except: resolution_str = "1024x1024" image = generate_image( pipe=pipe, prompt=final_prompt, resolution=resolution_str, seed=new_seed, guidance_scale=0.0, num_inference_steps=int(steps), shift=shift, ) except Exception as e: print(f"Error generation: {e}") # Return empty/error image or re-raise # For now, just re-raising to let Gradio handle or user see error raise e if gallery_images is None: gallery_images = [] gallery_images = [image] + gallery_images return gallery_images, str(new_seed), int(new_seed) init_app() # ==================== AoTI (Ahead of Time Inductor compilation) ==================== # pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"] # spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3") with gr.Blocks(title="Z-Image Demo") as demo: gr.Markdown( """