AlterProgramming committed on
Commit
ccc57f0
·
verified ·
1 Parent(s): 7faaef2

add txt2img tab + infer_txt2img endpoint (SD 1.5)

Browse files
Files changed (2) hide show
  1. app.py +41 -13
  2. studio/backends/hf_space.py +47 -0
app.py CHANGED
@@ -38,12 +38,20 @@ def _get_adapter() -> AnimateDiffAdapter:
38
  return _adapter
39
 
40
 
 
 
 
 
 
 
 
41
  def _get_sd_pipe():
42
- """Lazily load SD 1.5 txt2img pipeline (runs inside @spaces.GPU context).
 
43
 
44
- Cached at module scope so warm calls skip re-loading. Cold start is ~30s
45
- including weight download on first call ever (cached in persistent storage
46
- for subsequent cold starts).
47
  """
48
  global _sd_pipe
49
  if _sd_pipe is not None:
@@ -59,6 +67,16 @@ def _get_sd_pipe():
59
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
60
  pipe = pipe.to("cuda")
61
  pipe.set_progress_bar_config(disable=True)
 
 
 
 
 
 
 
 
 
 
62
  _sd_pipe = pipe
63
  return pipe
64
 
@@ -106,15 +124,24 @@ def infer_txt2img(
106
  height: int,
107
  width: int,
108
  seed: int,
 
109
  ) -> str:
110
  """Generate a single sprite from a text prompt. Returns path to a PNG.
111
 
112
- Defaults tuned for pixel-art sprites: 512×512, 25 steps, guidance 7.5.
113
- Caller downscales to target res (256×256 is the Dicer sprite size).
 
 
 
 
114
  """
115
  import torch
116
 
117
  pipe = _get_sd_pipe()
 
 
 
 
118
  g = torch.Generator(device="cuda").manual_seed(int(seed))
119
  out = pipe(
120
  prompt=prompt,
@@ -169,17 +196,17 @@ with gr.Blocks(title="Venture-Studio") as demo:
169
  with gr.Column():
170
  t2i_prompt = gr.Textbox(
171
  value=(
172
- "pixel art, 16-bit fantasy goblin warrior, empty hands, "
173
- "no weapon, standing pose, full body, centered, "
174
- "white background, sharp pixels, blocky, retro game sprite"
175
  ),
176
- lines=3, label="prompt",
177
  )
178
  t2i_neg = gr.Textbox(
179
  value=(
180
  "sword, weapon, dagger, axe, staff, blurry, soft, "
181
- "anti-aliasing, smooth, photorealistic, 3d render, "
182
- "extra limbs, distorted"
183
  ),
184
  lines=2, label="negative_prompt",
185
  )
@@ -189,6 +216,7 @@ with gr.Blocks(title="Venture-Studio") as demo:
189
  t2i_height = gr.Slider(256, 768, value=512, step=64, label="height")
190
  t2i_width = gr.Slider(256, 768, value=512, step=64, label="width")
191
  t2i_seed = gr.Number(value=0, precision=0, label="seed (0 = random)")
 
192
  t2i_run = gr.Button("Generate sprite", variant="primary")
193
  with gr.Column():
194
  t2i_out = gr.Image(label="Generated sprite", height=512)
@@ -196,7 +224,7 @@ with gr.Blocks(title="Venture-Studio") as demo:
196
  t2i_run.click(
197
  infer_txt2img,
198
  inputs=[t2i_prompt, t2i_neg, t2i_steps, t2i_guidance,
199
- t2i_height, t2i_width, t2i_seed],
200
  outputs=t2i_out,
201
  api_name="infer_txt2img",
202
  )
 
38
  return _adapter
39
 
40
 
41
+ PIXEL_ART_LORA_REPO = "artificialguybr/pixelartredmond-1-5v-pixel-art-loras-for-sd-1-5"
42
+ PIXEL_ART_LORA_WEIGHT_FILE = "PixelArtRedmond15V-PixelArt-PIXARFK.safetensors"
43
+ PIXEL_ART_LORA_ADAPTER = "pixart"
44
+ # Trigger words: "pixel art, PixArFK" should appear in the prompt for the LoRA
45
+ # to engage. The probe + UI include them by default.
46
+
47
+
48
  def _get_sd_pipe():
49
+ """Lazily load SD 1.5 txt2img pipeline + PixelArtRedmond LoRA (runs inside
50
+ @spaces.GPU context).
51
 
52
+ Cached at module scope so warm calls skip re-loading. Cold start is ~30-60s
53
+ including weight downloads on first call ever (cached in persistent storage
54
+ for subsequent cold starts). LoRA weight is set per-call via set_adapters().
55
  """
56
  global _sd_pipe
57
  if _sd_pipe is not None:
 
67
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
68
  pipe = pipe.to("cuda")
69
  pipe.set_progress_bar_config(disable=True)
70
+ try:
71
+ pipe.load_lora_weights(
72
+ PIXEL_ART_LORA_REPO,
73
+ weight_name=PIXEL_ART_LORA_WEIGHT_FILE,
74
+ adapter_name=PIXEL_ART_LORA_ADAPTER,
75
+ )
76
+ pipe.set_adapters([PIXEL_ART_LORA_ADAPTER], adapter_weights=[0.9])
77
+ print(f"loaded pixel-art LoRA: {PIXEL_ART_LORA_REPO}", flush=True)
78
+ except Exception as e:
79
+ print(f"WARN: could not load LoRA {PIXEL_ART_LORA_REPO}: {e}", flush=True)
80
  _sd_pipe = pipe
81
  return pipe
82
 
 
124
  height: int,
125
  width: int,
126
  seed: int,
127
+ lora_weight: float,
128
  ) -> str:
129
  """Generate a single sprite from a text prompt. Returns path to a PNG.
130
 
131
+ Defaults tuned for pixel-art sprites: 512×512, 25 steps, guidance 7.5,
132
+ LoRA strength 0.9. Caller downscales to target res (256×256 is the Dicer
133
+ sprite size).
134
+
135
+ Prompt must include "pixel art, PixArFK" to engage the LoRA. The UI
136
+ pre-populates these tokens; the probe always includes them.
137
  """
138
  import torch
139
 
140
  pipe = _get_sd_pipe()
141
+ try:
142
+ pipe.set_adapters([PIXEL_ART_LORA_ADAPTER], adapter_weights=[float(lora_weight)])
143
+ except Exception as e:
144
+ print(f"WARN: set_adapters failed: {e}", flush=True)
145
  g = torch.Generator(device="cuda").manual_seed(int(seed))
146
  out = pipe(
147
  prompt=prompt,
 
196
  with gr.Column():
197
  t2i_prompt = gr.Textbox(
198
  value=(
199
+ "pixel art, PixArFK, fantasy goblin warrior, green skin, "
200
+ "leather armor, empty hands, unarmed, standing pose, "
201
+ "full body, centered, white background, retro game sprite"
202
  ),
203
+ lines=3, label="prompt (include 'pixel art, PixArFK' for LoRA)",
204
  )
205
  t2i_neg = gr.Textbox(
206
  value=(
207
  "sword, weapon, dagger, axe, staff, blurry, soft, "
208
+ "photorealistic, 3d render, extra limbs, distorted, "
209
+ "multiple characters"
210
  ),
211
  lines=2, label="negative_prompt",
212
  )
 
216
  t2i_height = gr.Slider(256, 768, value=512, step=64, label="height")
217
  t2i_width = gr.Slider(256, 768, value=512, step=64, label="width")
218
  t2i_seed = gr.Number(value=0, precision=0, label="seed (0 = random)")
219
+ t2i_lora = gr.Slider(0.0, 1.5, value=0.9, step=0.05, label="LoRA weight (PixelArtRedmond)")
220
  t2i_run = gr.Button("Generate sprite", variant="primary")
221
  with gr.Column():
222
  t2i_out = gr.Image(label="Generated sprite", height=512)
 
224
  t2i_run.click(
225
  infer_txt2img,
226
  inputs=[t2i_prompt, t2i_neg, t2i_steps, t2i_guidance,
227
+ t2i_height, t2i_width, t2i_seed, t2i_lora],
228
  outputs=t2i_out,
229
  api_name="infer_txt2img",
230
  )
studio/backends/hf_space.py CHANGED
@@ -172,3 +172,50 @@ def _gradio_client_available() -> bool:
172
  return True
173
  except ImportError:
174
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  return True
173
  except ImportError:
174
  return False
175
+
176
+
177
def generate_sprite_via_space(
    prompt: str,
    *,
    space_id: str,
    negative_prompt: str = "",
    num_inference_steps: int = 25,
    guidance_scale: float = 7.5,
    height: int = 512,
    width: int = 512,
    seed: int = 0,
    lora_weight: float = 0.9,
    hf_token: Optional[str] = None,
    api_name: str = "/infer_txt2img",
) -> PILImage.Image:
    """Call the Space's txt2img endpoint and return the generated sprite.

    This is a free-standing helper (not bound to the cursor/motion abstraction).
    A sprite is a fresh artifact, not a transform — wrapping it in a PixelCursor
    would add ceremony without benefit.

    Prompt should include the PixelArtRedmond trigger words "pixel art, PixArFK"
    to engage the loaded LoRA. lora_weight scales LoRA strength (0.0 disables,
    1.5 is the upper end; default 0.9).

    Args:
        prompt: Text prompt forwarded to the endpoint.
        space_id: Identifier of the Space to connect to.
        negative_prompt: Negative prompt forwarded to the endpoint.
        num_inference_steps: Sampler step count (coerced to int).
        guidance_scale: Guidance scale (coerced to float).
        height: Requested image height in pixels (coerced to int).
        width: Requested image width in pixels (coerced to int).
        seed: Seed forwarded to the endpoint (coerced to int).
        lora_weight: LoRA strength forwarded to the endpoint (coerced to float).
        hf_token: Optional token; resolved through _resolve_hf_token().
        api_name: Endpoint name on the Space.

    Returns:
        The generated image, fully loaded into memory and converted to RGB.

    Raises:
        ImportError: if gradio_client is missing locally.
    """
    try:
        from gradio_client import Client
    except ImportError as e:
        raise ImportError(INSTALL_HINT) from e

    token = _resolve_hf_token(hf_token)
    client = Client(space_id, token=token)
    # Positional order must match the endpoint's input list:
    # prompt, negative prompt, steps, guidance, height, width, seed, LoRA weight.
    png_path = client.predict(
        prompt,
        negative_prompt,
        int(num_inference_steps),
        float(guidance_scale),
        int(height),
        int(width),
        int(seed),
        float(lora_weight),
        api_name=api_name,
    )
    # Image.open() is lazy and keeps the file handle open; convert() loads the
    # pixel data and returns an independent in-memory image, so the handle can
    # be closed deterministically as soon as the conversion is done.
    with PILImage.open(png_path) as img:
        return img.convert("RGB")