# Z-Image-Turbo / app.py
# Last commit: f9212b1 by knighjok — "feat: remove unused code"
import spaces
import logging
import os
import random
import re
import sys
import warnings
from PIL import Image
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from diffusers import ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
# ==================== Environment Variables ==================================
# Each setting can be overridden from the environment. The two boolean flags
# are switched on only when their value is "true" (case-insensitive).
MODEL_PATH = os.getenv("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ATTENTION_BACKEND = os.getenv("ATTENTION_BACKEND", "flash_3")
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the variable is not set
ENABLE_COMPILE = os.getenv("ENABLE_COMPILE", "true").lower() == "true"
ENABLE_WARMUP = os.getenv("ENABLE_WARMUP", "true").lower() == "true"
# =============================================================================
# Process-wide noise reduction: disable HF tokenizers parallelism, ignore all
# Python warnings, and only let the transformers logger emit errors.
os.environ.update(TOKENIZERS_PARALLELISM="false")
warnings.simplefilter("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
# Selectable output sizes, grouped by category. Each key is the edge length of
# that category's 1:1 entry; each value lists labels of the form
# "WIDTHxHEIGHT ( W:H )". The label format matters: generate() keeps only the
# text before the first space, and get_resolution() parses "WxH" from it.
# Every width and height listed here is a multiple of 16 (presumably a
# VAE/patch-size requirement — TODO confirm). Ratio labels are nominal, e.g.
# 2720x1536 is only approximately 16:9.
RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1536x864 ( 16:9 )",
        "864x1536 ( 9:16 )",
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
    "1536": [
        "1536x1536 ( 1:1 )",
        "1728x1344 ( 9:7 )",
        "1344x1728 ( 7:9 )",
        "1728x1296 ( 4:3 )",
        "1296x1728 ( 3:4 )",
        "1872x1248 ( 3:2 )",
        "1248x1872 ( 2:3 )",
        "2048x1152 ( 16:9 )",
        "1152x2048 ( 9:16 )",
        "2016x864 ( 21:9 )",
        "864x2016 ( 9:21 )",
    ],
    "2048": [
        "2048x2048 ( 1:1 )",
        "2304x1792 ( 9:7 )",
        "1792x2304 ( 7:9 )",
        "2304x1728 ( 4:3 )",
        "1728x2304 ( 3:4 )",
        "2496x1664 ( 3:2 )",
        "1664x2496 ( 2:3 )",
        "2720x1536 ( 16:9 )",
        "1536x2720 ( 9:16 )",
        "2688x1152 ( 21:9 )",
        "1152x2688 ( 9:21 )",
    ],
}
# Every resolution label from every category, flattened in category order.
RESOLUTION_SET = [label for group in RES_CHOICES.values() for label in group]

# Rows for gr.Examples; each inner list holds one value per input component.
# English gloss of the prompt: "A man and his poodle wear matching outfits at
# a dog show; indoor lighting, with an audience in the background."
EXAMPLE_PROMPTS = [
    ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
]
def get_resolution(resolution):
    """Extract ``(width, height)`` from a string such as ``"1152x896 ( 9:7 )"``.

    The first ``<digits> x <digits>`` pair is used; either ``x`` or ``×`` may
    separate the numbers, with optional whitespace around it. When no such
    pair is present, ``(1024, 1024)`` is returned.
    """
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if found is None:
        return 1024, 1024
    width_text, height_text = found.groups()
    return int(width_text), int(height_text)
def _pretrained_source(model_path, component):
    """Return ``(path_or_repo_id, extra_kwargs)`` for loading one pipeline component.

    A ``model_path`` that exists on disk is treated as a local checkout whose
    components live in ``<model_path>/<component>``. Anything else is treated
    as a Hugging Face Hub repo id whose components live in subfolders.
    """
    if os.path.exists(model_path):
        return os.path.join(model_path, component), {}
    # ``token=None`` lets huggingface_hub fall back to a cached login (or the
    # HF_TOKEN env var) and otherwise download anonymously. The previous
    # ``use_auth_token=True`` is deprecated, and with no token available it
    # made even public downloads fail.
    return model_path, {"subfolder": component, "token": HF_TOKEN or None}


def _configure_inductor():
    """Set the torch inductor options used when compiling the transformer."""
    inductor_cfg = torch._inductor.config
    inductor_cfg.conv_1x1_as_mm = True
    inductor_cfg.coordinate_descent_tuning = True
    inductor_cfg.epilogue_fusion = False
    inductor_cfg.coordinate_descent_check_all_directions = True
    inductor_cfg.max_autotune_gemm = True
    inductor_cfg.max_autotune_gemm_backends = "TRITON,ATEN"
    inductor_cfg.triton.cudagraphs = False


def load_models(model_path, enable_compile=False, attention_backend="flash_3"):
    """Load the Z-Image components and assemble a ``ZImagePipeline`` on CUDA.

    Args:
        model_path: Local directory (components in ``<model_path>/<name>``) or a
            Hugging Face Hub repo id (components in subfolders of the repo).
        enable_compile: If True, set inductor tuning flags, disable VAE tiling
            and wrap the transformer in ``torch.compile``.
        attention_backend: Name passed to ``transformer.set_attention_backend``.

    Returns:
        The assembled pipeline with all components moved to CUDA in bfloat16.
        Its scheduler is left as ``None``; ``generate_image`` attaches one per call.
    """
    print(f"Loading models from {model_path}...")

    vae_src, vae_kwargs = _pretrained_source(model_path, "vae")
    vae = AutoencoderKL.from_pretrained(
        vae_src, torch_dtype=torch.bfloat16, device_map="cuda", **vae_kwargs
    )

    te_src, te_kwargs = _pretrained_source(model_path, "text_encoder")
    text_encoder = AutoModelForCausalLM.from_pretrained(
        te_src, torch_dtype=torch.bfloat16, device_map="cuda", **te_kwargs
    ).eval()

    tok_src, tok_kwargs = _pretrained_source(model_path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tok_src, **tok_kwargs)
    # Left padding, presumably because the text encoder is a causal LM — confirm.
    tokenizer.padding_side = "left"

    if enable_compile:
        print("Enabling torch.compile optimizations...")
        _configure_inductor()

    pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
    if enable_compile:
        pipe.vae.disable_tiling()

    tr_src, tr_kwargs = _pretrained_source(model_path, "transformer")
    transformer = ZImageTransformer2DModel.from_pretrained(tr_src, **tr_kwargs).to("cuda", torch.bfloat16)
    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)
    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)

    pipe.to("cuda", torch.bfloat16)
    return pipe
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the pipeline once and return the first generated image.

    Args:
        pipe: Loaded pipeline. Its ``scheduler`` attribute is replaced on every call.
        prompt: Text prompt.
        resolution: String containing "WIDTHxHEIGHT", parsed by ``get_resolution``.
        seed: Seed for a CUDA ``torch.Generator``.
        guidance_scale: Passed through to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        shift: ``shift`` of the FlowMatchEulerDiscreteScheduler built for this call.
        max_sequence_length: Passed through to the pipeline.
        progress: Not used inside this function.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    width, height = get_resolution(resolution)
    cuda_rng = torch.Generator("cuda").manual_seed(seed)
    # A fresh scheduler per call so the requested ``shift`` takes effect.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
    result = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_sequence_length=max_sequence_length,
        generator=cuda_rng,
    )
    return result.images[0]
def warmup_model(pipe, resolutions):
    """Run three short dummy generations (9 steps, guidance 0.0) per resolution.

    Presumably meant to pay compilation/caching costs for each output shape
    before real requests arrive (TODO confirm). Failures are printed and do
    not stop the loop.
    """
    print("Starting warmup phase...")
    for res in resolutions:
        print(f"Warming up for resolution: {res}")
        try:
            for run in range(3):
                generate_image(
                    pipe,
                    prompt="warmup",
                    resolution=res,
                    num_inference_steps=9,
                    guidance_scale=0.0,
                    seed=42 + run,
                )
        except Exception as exc:
            # Best effort: a failure skips the remaining runs for this
            # resolution only; other resolutions are still warmed up.
            print(f"Warmup failed for {res}: {exc}")
    print("Warmup completed.")
# Module-level pipeline: set by init_app() and read by generate(); None means
# the model is not available.
pipe = None


def init_app():
    """Load the pipeline into the module-level ``pipe`` and optionally warm it up.

    Uses the MODEL_PATH / ENABLE_COMPILE / ATTENTION_BACKEND / ENABLE_WARMUP
    settings. Any exception is reported (message plus full traceback) and
    leaves ``pipe`` as ``None``, so ``generate`` reports "Model not loaded."
    instead of the app failing at import.
    """
    global pipe
    try:
        pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
        print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
        if ENABLE_WARMUP:
            # RESOLUTION_SET already holds every resolution from every category.
            warmup_model(pipe, RESOLUTION_SET)
    except Exception as e:
        import traceback

        print(f"Error loading model: {e}")
        # The one-line message alone hides where loading failed; keep the trace.
        traceback.print_exc()
        pipe = None
@spaces.GPU
def generate(
    prompt,
    resolution="1024x1024 ( 1:1 )",
    seed=42,
    steps=9,
    shift=3.0,
    random_seed=True,
    gallery_images=None,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate an image using the Z-Image model based on the provided prompt and settings.

    Args:
        prompt (str): Text prompt describing the desired image content
        resolution (str): Output resolution as "WIDTHxHEIGHT", optionally followed by a space and a ratio label, e.g. "1024x1024 ( 1:1 )"
        seed (int): Seed for reproducible generation; -1 or an empty value selects a random seed
        steps (int): Number of inference steps
        shift (float): Time shift parameter
        random_seed (bool): Whether to generate a new random seed
        gallery_images (list): List of previously generated images
        progress (gr.Progress): Gradio progress tracker

    Returns:
        tuple: (gallery_images, seed_str, seed_int)
    """
    # Random seeds come from [1, 1000000]. A cleared Seed field arrives as None;
    # it used to reach torch's manual_seed and crash, so treat it like -1.
    if random_seed or seed is None or seed == -1:
        new_seed = random.randint(1, 1000000)
    else:
        new_seed = int(seed)

    # Dropdown values look like "1024x1024 ( 1:1 )"; keep only the "WxH" token.
    # Non-string input (e.g. None) falls back to the default square size.
    resolution_str = resolution.split(" ")[0] if isinstance(resolution, str) else "1024x1024"

    try:
        if pipe is None:
            raise gr.Error("Model not loaded.")
        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            resolution=resolution_str,
            seed=new_seed,
            guidance_scale=0.0,  # fixed here; the UI exposes no guidance control
            num_inference_steps=int(steps),
            shift=shift,
        )
    except Exception as e:
        print(f"Error generation: {e}")
        # Re-raise unchanged so Gradio surfaces the error to the user.
        raise

    # Newest image first, followed by whatever the gallery already showed.
    gallery_images = [image] + list(gallery_images or [])
    return gallery_images, str(new_seed), int(new_seed)
# Load (and optionally warm up) the model at import time, before the UI below
# is built and launched.
init_app()
# ==================== AoTI (Ahead of Time Inductor compilation) ====================
# Disabled: presumably an alternative to the torch.compile path that loads
# pre-compiled transformer blocks from the "zerogpu-aoti/Z-Image" repo via the
# `spaces` package. Kept for reference — confirm before re-enabling.
# pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
# spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
# Gradio UI. Left column: prompt and generation controls; right column: output
# gallery and the seed that was actually used. Components appear in the order
# they are created inside the nested Row/Column contexts.
with gr.Blocks(title="Z-Image Demo") as demo:
    gr.Markdown(
        """<div align="center">
# Z-Image Generation Demo
</div>"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
            with gr.Row():
                # Category selector: the RES_CHOICES keys as ints (1024, 1280, ...).
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
                initial_res_choices = RES_CHOICES["1024"]
                # The initial choices are the full RESOLUTION_SET, not only the
                # 1024 category; update_res_choices narrows them when the
                # category changes. NOTE(review): presumably kept broad so that
                # API/MCP callers may pass any listed resolution — confirm
                # before narrowing this list.
                resolution = gr.Dropdown(
                    value=initial_res_choices[0], choices=RESOLUTION_SET, label="Width x Height (Ratio)"
                )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
            with gr.Row():
                # NOTE(review): the slider defaults to 8 steps while generate()
                # and warmup_model() use 9 — confirm which is intended.
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=True)
                shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            generate_btn = gr.Button("Generate", variant="primary")
            # Example prompts
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                rows=2,
                height=600,
                object_fit="contain",
                format="png",
                interactive=False,
            )
            used_seed = gr.Textbox(label="Seed Used", interactive=False)

    def update_res_choices(_res_cat):
        """Return a dropdown update listing the chosen category's resolutions.

        Unknown categories fall back to the "1024" list; the first entry of the
        list becomes the selected value.
        """
        if str(_res_cat) in RES_CHOICES:
            res_choices = RES_CHOICES[str(_res_cat)]
        else:
            res_choices = RES_CHOICES["1024"]
        return gr.update(value=res_choices[0], choices=res_choices)

    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
    # The gallery is both an input and an output, so each new image is prepended
    # to the ones already shown; the seed used is written back to the Seed field.
    generate_btn.click(
        generate,
        inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
        outputs=[output_gallery, used_seed, seed],
        api_visibility="public",
    )
# Extra CSS: cap the width of elements with the `.fillable` class at 1230px.
css = """
.fillable{max-width: 1230px !important}
"""

if __name__ == "__main__":
    # Serve the UI with the custom CSS and also expose it as an MCP server.
    demo.launch(mcp_server=True, css=css)