import gradio as gr
import torch
import spaces
import random
from diffusers import DiffusionPipeline
# --- Model Loading and Setup ---
model_name = "OPPOer/Qwen-Image-Pruning"

# AoT compilation below specializes the transformer to this fixed resolution.
COMPILATION_WIDTH = 1328
COMPILATION_HEIGHT = 1328

# Configure device and dtype
if torch.cuda.is_available():
    # Use bfloat16 for best performance on modern NVIDIA GPUs (A100/H200 recommended)
    torch_dtype = torch.bfloat16
    device = "cuda"
else:
    # CPU fallback; note that diffusion inference on CPU is extremely slow
    torch_dtype = torch.float32
    device = "cpu"
try:
    # Load the pipeline
    pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, trust_remote_code=True)
    pipe.to(device)
except Exception as e:
    # Handle environments where the chosen dtype is not fully supported, or other loading issues
    print(f"Failed to load model with {torch_dtype}: {e}. Trying float16/float32 fallback.")
    try:
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, trust_remote_code=True)
        pipe.to(device)
    except Exception as e2:
        print(f"Failed to load model even with fallback: {e2}")
        raise
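# Optional memory savings on smaller GPUs (a sketch; whether these helpers are
# exposed depends on the concrete pipeline class loaded via trust_remote_code):
#
#   pipe.enable_attention_slicing()    # trade some speed for lower peak VRAM
#   pipe.enable_model_cpu_offload()    # keep idle submodules on the CPU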
# Qwen-specific positive prompt suffix ("ultra-clear, 4K, cinematic composition")
positive_magic = ", 超清,4K,电影级构图。"
negative_prompt = "bad anatomy, blurry, disfigured, poorly drawn face, mutation, mutated, extra limb, missing limb, floating limbs, disconnected limbs, malformed hands, ugly, low-resolution, artifacts, text, watermark, signature"
# --- ZeroGPU Ahead-of-Time (AoT) Compilation ---
if device == "cuda":
    @spaces.GPU(duration=1500)  # compilation needs a long GPU slot on ZeroGPU
    def compile_transformer():
        print("Starting AoT compilation...")
        # Qwen-Image uses a DiT-style transformer as its denoiser.
        if not hasattr(pipe, "transformer"):
            raise AttributeError("Pipeline has no 'transformer' attribute to compile.")

        # 1. Capture example inputs by running a minimal one-step inference.
        #    true_cfg_scale=1.0 disables classifier-free guidance, matching the
        #    settings used at generation time below.
        with spaces.aoti_capture(pipe.transformer) as call:
            pipe(
                prompt="test prompt for compilation",
                negative_prompt=negative_prompt,
                width=COMPILATION_WIDTH,
                height=COMPILATION_HEIGHT,
                num_inference_steps=1,
                true_cfg_scale=1.0,
                generator=torch.Generator(device=device).manual_seed(42),
            )

        # 2. Export the transformer (static shapes derived from COMPILATION_WIDTH/HEIGHT)
        exported = torch.export.export(
            pipe.transformer,
            args=call.args,
            kwargs=call.kwargs,
        )

        # 3. Compile the exported graph
        print(f"Export successful. Compiling for {COMPILATION_WIDTH}x{COMPILATION_HEIGHT}...")
        return spaces.aoti_compile(exported)

    # 4. Compile once at startup and patch the result into the pipeline
    try:
        compiled_transformer = compile_transformer()
        spaces.aoti_apply(compiled_transformer, pipe.transformer)
        print("✅ AoT compilation succeeded and was applied.")
    except Exception as e:
        print(f"⚠️ AoT compilation failed; falling back to the eager model. Performance may be lower. Error: {e}")
# --- Inference Function ---
@spaces.GPU  # request a ZeroGPU slot for each generation
def generate_image(prompt: str, steps: int, width: int, height: int, seed: int):
    # Append the Qwen positive magic suffix
    full_prompt = prompt + positive_magic
    generator = torch.Generator(device=device).manual_seed(int(seed))
    if width % 8 != 0 or height % 8 != 0:
        gr.Warning("Width and height should be divisible by 8 for best results.")
    # true_cfg_scale=1 disables classifier-free guidance (the negative prompt
    # then has little to no effect), matching the AoT capture settings above.
    image = pipe(
        prompt=full_prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=steps,
        true_cfg_scale=1.0,
        generator=generator,
    ).images[0]
    return image
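# Quick smoke test, bypassing the UI (hypothetical values; the prompt means
# "a Shiba Inu running on a grass field"):
#
#   img = generate_image("一只在草地上奔跑的柴犬", steps=8,
#                        width=COMPILATION_WIDTH, height=COMPILATION_HEIGHT, seed=42)
#   img.save("smoke_test.png")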
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="Qwen-Image Text-to-Image Generation (AoT Optimized)") as demo:
    gr.HTML("""
    <div style="text-align: center; max-width: 800px; margin: 0 auto;">
        <h1>Qwen-Image Pruning Text-to-Image</h1>
        <p>Optimized for speed using ZeroGPU AoT compilation.</p>
        <p>🚨 Prompts should ideally be in Chinese for best results, due to the model's training data and the Chinese magic suffix.</p>
        <p style="margin-top: 10px;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Default prompt: a Chinese woman in a "QWEN" T-shirt holding a
            # black marker and smiling at the camera.
            prompt_input = gr.Textbox(
                label="Prompt (Chinese Recommended)",
                value='一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面向镜头微笑。',
                lines=3
            )
            with gr.Accordion("Generation Settings", open=True):
                steps_slider = gr.Slider(
                    minimum=4, maximum=50, value=8, step=1, label="Inference Steps"
                )
                with gr.Row():
                    # Lock the resolution when AoT is active, since the compiled
                    # kernels are specialized to COMPILATION_WIDTH x COMPILATION_HEIGHT.
                    width_input = gr.Slider(
                        minimum=512, maximum=1536, value=COMPILATION_WIDTH, step=8,
                        label="Width", interactive=(device != "cuda")
                    )
                    height_input = gr.Slider(
                        minimum=512, maximum=1536, value=COMPILATION_HEIGHT, step=8,
                        label="Height", interactive=(device != "cuda")
                    )
                if device == "cuda":
                    gr.Markdown(f"Note: for maximum AoT performance, the recommended resolution is {COMPILATION_WIDTH}x{COMPILATION_HEIGHT}.")
                seed_input = gr.Number(value=42, label="Seed", precision=0)
                random_seed_btn = gr.Button("🎲 Random Seed", scale=0)
            generate_btn = gr.Button("Generate Image", variant="primary")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image", show_share_button=True)
    # Example prompts (each row supplies all five inputs of generate_image,
    # so run_on_click can call it directly)
    gr.Examples(
        examples=[
            ['一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面向镜头微笑。', 8, COMPILATION_WIDTH, COMPILATION_HEIGHT, 42],
            ['海报,温馨家庭场景,柔和阳光洒在野餐布上,色彩温暖明亮。文字内容:“共享阳光,共享爱。”', 8, COMPILATION_WIDTH, COMPILATION_HEIGHT, 42],
            ['一个穿着校服的年轻女孩站在教室里,在黑板上写字。黑板中央用整洁的白粉笔写着“Introducing Qwen-Image”。', 8, COMPILATION_WIDTH, COMPILATION_HEIGHT, 42],
        ],
        inputs=[prompt_input, steps_slider, width_input, height_input, seed_input],
        outputs=output_image,
        fn=generate_image,
        cache_examples=False,
        run_on_click=True
    )
    # Event handlers
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, steps_slider, width_input, height_input, seed_input],
        outputs=output_image,
        show_progress="minimal"
    )
    random_seed_btn.click(
        fn=lambda: random.randint(0, 1_000_000),
        inputs=[],
        outputs=seed_input,
        queue=False,
        show_progress="hidden"
    )

demo.queue().launch()