File size: 4,467 Bytes
4f9093b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from PIL import Image
import torch
import numpy as np
from transformers import Qwen2_5_VLForConditionalGeneration
from diffusers import (
QwenImagePipeline,
QwenImageTransformer2DModel,
QwenImageInpaintPipeline,
)
from optimum.quanto import quantize, qint8, freeze
# --- Generation settings -----------------------------------------------------
# Text prompt for the panorama. The leading "equirectangular" keyword is kept
# exactly as written (presumably the LoRA's trigger phrase — TODO confirm
# against the LoRA model card).
prompt = (
    "equirectangular, a woman and a man sitting at a cafe, the woman has red hair "
    "and she's wearing purple sweater with a black scarf and a white hat, the man "
    "is sitting on the other side of the table and he's wearing a white shirt with "
    "a purple scarf and red hat, both of them are sipping their coffee while in the "
    "table there's some cake slices on their respective plates, each with forks and "
    "knives at each side."
)
negative_prompt = ""
output_filename = "qwen_int8.png"
# 2:1 canvas, the aspect ratio of an equirectangular panorama.
width = 2048
height = 1024
true_cfg_scale = 4.0
num_inference_steps = 25
seed = 42

# --- Model sources -------------------------------------------------------------
lora_model_id = "ProGamerGov/qwen-360-diffusion"
lora_filename = "qwen-360-diffusion-int8-bf16-v1.safetensors"
# Base (non-nf4) checkpoint; weights are loaded in ``torch_dtype`` and then
# quantized to INT8 at runtime.
model_id = "Qwen/Qwen-Image"
torch_dtype = torch.bfloat16
device = "cuda"

# --- Seam-fix pass -------------------------------------------------------------
fix_seam = True
inpaint_strength = 0.5  # strength passed to the inpainting pipeline
seam_width = 0.10  # width of the inpainted seam band, as a fraction of width
def shift_equirect(img, inverse=False):
    """Roll an equirectangular image horizontally by half its width.

    The horizontal wrap-around seam of the panorama (between the last and the
    first column) is moved to column ``width // 2``, so that it can be
    inpainted; applying the opposite shift afterwards restores the layout.

    Args:
        img: PIL image (or anything ``np.asarray`` accepts). Both 2-D
            (grayscale) and 3-D (H, W, C) pixel arrays are supported.
        inverse: Shift left instead of right. For even widths both directions
            are the same; for odd widths use ``inverse=True`` to exactly undo
            a previous forward shift.

    Returns:
        A new PIL image built from the rolled pixel array (same dtype/shape).
    """
    pixels = np.asarray(img)
    offset = pixels.shape[1] // 2
    if inverse:
        offset = -offset
    # np.roll wraps columns around the edge and works on the native dtype, so
    # there is no float round-trip and any trailing channel axis is preserved.
    return Image.fromarray(np.roll(pixels, offset, axis=1))
def create_seam_mask(w, h, frac=0.10):
    """Create a vertical inpainting mask centred on column ``w // 2``.

    Args:
        w: Mask width in pixels.
        h: Mask height in pixels.
        frac: Band width as a fraction of ``w``. The band is at least 1 pixel
            wide and at most ``w`` pixels wide.

    Returns:
        A grayscale ("L") PIL image of size (w, h): 255 inside the band,
        0 elsewhere.
    """
    band = min(w, max(1, int(w * frac)))
    # Start half a band left of the centre and take exactly `band` columns.
    # The previous `c - band//2 : c + band//2` slice lost a column for odd
    # widths and produced an empty mask when band == 1.
    start = max(0, w // 2 - band // 2)
    stop = min(w, start + band)
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[:, start:stop] = 255
    # A 2-D uint8 array maps to mode "L" automatically.
    return Image.fromarray(mask)
def load_pipeline(text_encoder, transformer, mode="t2i"):
    """Build a Qwen-Image pipeline around pre-loaded (quantized) components.

    Args:
        text_encoder: Text encoder used instead of the one stored under
            ``model_id``.
        transformer: Transformer used instead of the one stored under
            ``model_id``.
        mode: ``"t2i"`` selects ``QwenImagePipeline``; ``"i2i"`` (alias
            ``"inpaint"``) selects ``QwenImageInpaintPipeline``.

    Returns:
        The pipeline, with the LoRA ``lora_filename`` from ``lora_model_id``
        loaded, model CPU offload enabled and VAE tiling enabled.

    Raises:
        ValueError: If ``mode`` is not one of the recognised values.
            Previously any unknown value silently selected inpainting.
    """
    pipeline_classes = {
        "t2i": QwenImagePipeline,
        "i2i": QwenImageInpaintPipeline,
        "inpaint": QwenImageInpaintPipeline,
    }
    try:
        pipeline_cls = pipeline_classes[mode]
    except KeyError:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(pipeline_classes)}"
        ) from None

    pipe = pipeline_cls.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    # The LoRA is applied on top of the passed-in transformer. That transformer
    # is quantized by the caller; the original author noted this combination
    # works.
    pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
    # Per the diffusers docs: offload keeps sub-models on CPU until they run,
    # and VAE tiling decodes in tiles to reduce peak memory.
    pipe.enable_model_cpu_offload()
    pipe.enable_vae_tiling()
    return pipe
def _quantize_int8(model):
    """Quantize ``model``'s weights to INT8 in place with optimum-quanto, freeze them, and return ``model``."""
    quantize(model, weights=qint8)
    freeze(model)
    return model


def _seamfix_path(filename):
    """Return ``filename`` with ``_seamfix`` inserted before its extension.

    For example, "qwen_int8.png" becomes "qwen_int8_seamfix.png" and "out"
    becomes "out_seamfix". The old ``str.replace(".png", ...)`` returned the
    name unchanged when it had no ".png", which overwrote the first image.
    """
    from pathlib import Path

    path = Path(filename)
    return path.with_name(f"{path.stem}_seamfix{path.suffix}")


def main():
    """Render an INT8-quantized Qwen-Image panorama, optionally fixing its seam.

    Steps:
      1. Load the transformer and text encoder in ``torch_dtype`` and
         INT8-quantize their weights. Both pipelines below reuse them.
      2. Render with the text-to-image pipeline and save to
         ``output_filename``.
      3. If ``fix_seam`` is set, do the following:
         - Roll the panorama by half its width.
         - Inpaint a vertical band (``seam_width`` of the width) over the
           wrap-around seam.
         - Roll the result back and save it with a ``_seamfix`` suffix.
         The same ``generator`` is reused, so this pass continues the first
         pass's random stream.
    """
    import gc

    # 1) Load and INT8-quantize the transformer and text encoder.
    transformer = _quantize_int8(
        QwenImageTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
        )
    )
    text_encoder = _quantize_int8(
        Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            subfolder="text_encoder",
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            device_map={"": "cpu"},  # keep it on CPU; offload moves it as needed
        )
    )

    # 2) First pass: base panorama.
    generator = torch.Generator(device=device).manual_seed(seed)
    pipe = load_pipeline(text_encoder, transformer, mode="t2i")
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        true_cfg_scale=true_cfg_scale,
        generator=generator,
    ).images[0]
    image.save(output_filename)

    if not fix_seam:
        return

    # 3) Seam-fix pass using inpainting.
    # The LoRA was injected into the shared `transformer` object. Remove it so
    # that load_pipeline below loads it onto clean layers. Otherwise a second
    # adapter would be added alongside the first on the same transformer.
    pipe.unload_lora_weights()
    del pipe
    gc.collect()  # collect reference cycles before releasing cached GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    shifted = shift_equirect(image)  # roll 50% so the seam sits at the centre
    mask = create_seam_mask(width, height, frac=seam_width)
    pipe = load_pipeline(text_encoder, transformer, mode="i2i")
    image_fixed = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=shifted,
        mask_image=mask,
        strength=inpaint_strength,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        true_cfg_scale=true_cfg_scale,
        generator=generator,
    ).images[0]
    # Rolling by half again restores the original layout (exact for even widths).
    image_fixed = shift_equirect(image_fixed)
    image_fixed.save(_seamfix_path(output_filename))
if __name__ == "__main__":
    # Run the generation only when executed as a script, not on import.
    main()
|