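"""Generate a 360° equirectangular panorama with Qwen-Image, using an INT8-quantized
transformer and text encoder (via optimum-quanto), then optionally heal the
left/right wrap-around seam with an inpainting pass.

Steps:
  1. Load the transformer and text encoder on CPU and quantize both to INT8.
  2. Run a text-to-image pass with the qwen-360-diffusion LoRA to render the panorama.
  3. Roll the image 50% horizontally so the wrap seam sits at the center, inpaint
     a vertical strip over it, and roll back.
"""
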
from PIL import Image
import torch
import numpy as np

from transformers import Qwen2_5_VLForConditionalGeneration

from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
    QwenImageInpaintPipeline,
)

from optimum.quanto import quantize, qint8, freeze  # pip install optimum-quanto


prompt = (
    "equirectangular, a woman and a man sitting at a cafe, the woman has red hair "
    "and she's wearing purple sweater with a black scarf and a white hat, the man "
    "is sitting on the other side of the table and he's wearing a white shirt with "
    "a purple scarf and red hat, both of them are sipping their coffee while in the "
    "table there's some cake slices on their respective plates, each with forks and "
    "knives at each side."
)
negative_prompt = ""
output_filename = "qwen_int8.png"
width, height = 2048, 1024
true_cfg_scale = 4.0
num_inference_steps = 25
seed = 42

lora_model_id = "ProGamerGov/qwen-360-diffusion"
lora_filename = "qwen-360-diffusion-int8-bf16-v1.safetensors"

# Use the base fp16/bf16 model, not the nf4 variant
model_id = "Qwen/Qwen-Image"
torch_dtype = torch.bfloat16
device = "cuda"

fix_seam = True
inpaint_strength = 0.5
seam_width = 0.10  # fraction of image width masked over the seam


def shift_equirect(img):
    """Horizontal 50% shift using torch.roll (moves the wrap seam to the center)."""
    t = torch.from_numpy(np.array(img))  # (H, W, C) uint8
    t = torch.roll(t, shifts=t.shape[1] // 2, dims=1)  # roll along width only
    return Image.fromarray(t.numpy())


def create_seam_mask(w, h, frac=0.10):
    """Vertical seam mask as a PIL 'L' image: white strip ~frac*w px wide at the center."""
    mask = torch.zeros((h, w))
    seam_w = max(1, int(w * frac))
    start = w // 2 - seam_w // 2
    mask[:, start:start + seam_w] = 1.0  # exactly seam_w wide, even when seam_w == 1
    return Image.fromarray((mask.numpy() * 255).astype(np.uint8), "L")


def load_pipeline(text_encoder, transformer, mode="t2i"):
    pipe_class = QwenImagePipeline if mode == "t2i" else QwenImageInpaintPipeline
    pipe = pipe_class.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    # LoRA loading still works with the quantized transformer
    pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
    pipe.enable_model_cpu_offload()  # move submodules to the GPU only while in use
    pipe.enable_vae_tiling()  # decode the 2048x1024 latent in tiles to save VRAM
    return pipe


def main():
    # 1) Load and INT8-quantize transformer on CPU
    transformer = QwenImageTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
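    # quantize() replaces the Linear weights with INT8 qweights; freeze() then
    # materializes them so the bf16 originals can be released. Memory use should
    # drop to roughly half of bf16 (an estimate; exact savings vary by layer mix).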
    quantize(transformer, weights=qint8)
    freeze(transformer)

    # 2) Load and INT8-quantize text encoder on CPU
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        device_map={"": "cpu"},  # keep it on CPU; offload will move as needed
    )
    quantize(text_encoder, weights=qint8)
    freeze(text_encoder)

    # 3) Build T2I pipeline
    generator = torch.Generator(device=device).manual_seed(seed)
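    # The same generator is reused for the optional inpaint pass below, so the
    # seam-fix run continues this RNG stream instead of re-seeding.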
    pipe = load_pipeline(text_encoder, transformer, mode="t2i")

    # 4) First pass: base panorama
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        true_cfg_scale=true_cfg_scale,
        generator=generator,
    ).images[0]

    image.save(output_filename)

    # 5) Optional seam-fix: an equirectangular image wraps at x=0/x=W, so roll it
    #    50%, inpaint the now-centered seam, then roll back.
    if fix_seam:
        del pipe
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        shifted = shift_equirect(image)  # roll 50% to expose seam
        mask = create_seam_mask(width, height, frac=seam_width)

        pipe = load_pipeline(text_encoder, transformer, mode="i2i")
        image_fixed = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=shifted,
            mask_image=mask,
            strength=inpaint_strength,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            true_cfg_scale=true_cfg_scale,
            generator=generator,
        ).images[0]
        image_fixed = shift_equirect(image_fixed)  # second 50% roll restores orientation
        image_fixed.save(output_filename.replace(".png", "_seamfix.png"))


if __name__ == "__main__":
    main()