Update app.py
app.py CHANGED
@@ -8,7 +8,6 @@ import requests
 import os
 from typing import Dict, List, Tuple
 import asyncio
-import aiohttp

 # Initialize body estimation model
 body_estimation = Body('model/body_pose_model.pth')
@@ -67,9 +66,9 @@ def pil2cv(image):
     new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
     return new_image

-
+def generate_pose_from_llm_sync(prompt: str) -> Dict:
     """
-    Generate pose data from a text prompt using an LLM
+    Generate pose data from a text prompt using an LLM (synchronous version)
     """
     system_prompt = """You are an expert in human pose generation. Given a description, generate precise OpenPose keypoint coordinates.

@@ -109,23 +108,21 @@ async def generate_pose_from_llm(prompt: str) -> Dict:
     }

     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            return generate_template_pose(prompt)
+        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload, timeout=30)
+        if response.status_code == 200:
+            data = response.json()
+            content = data['choices'][0]['message']['content']
+
+            # Extract JSON from response
+            import re
+            json_match = re.search(r'\{.*\}', content, re.DOTALL)
+            if json_match:
+                pose_data = js.loads(json_match.group())
+                return pose_data
+            else:
+                return generate_template_pose(prompt)
+        else:
+            return generate_template_pose(prompt)
     except Exception as e:
         print(f"LLM Error: {e}")
         return generate_template_pose(prompt)
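For reference, a toy illustration of the JSON-extraction step introduced above, using the standard-library json module in place of the file's js alias; the reply text is made up:

import json, re

# Pull the first {...} block out of a chatty LLM reply, as the new code does.
reply = 'Here is the pose:\n{"candidate": [[256, 100, 1.0]], "subset": [[0, 1.0, 1]]}'
match = re.search(r'\{.*\}', reply, re.DOTALL)
pose = json.loads(match.group()) if match else None
print(pose["candidate"])  # [[256, 100, 1.0]]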
@@ -149,18 +146,31 @@ def generate_template_pose(prompt: str) -> Dict:
     for i in range(18):
         if i == 0: # Nose
             candidate.append([256, 100, 1.0])
-        elif
-
+        elif i == 14: # REye
+            candidate.append([246, 90, 1.0])
+        elif i == 15: # LEye
+            candidate.append([266, 90, 1.0])
+        elif i == 16: # REar
+            candidate.append([236, 95, 1.0])
+        elif i == 17: # LEar
+            candidate.append([276, 95, 1.0])
+        else:
+            # Find part name for this index
+            part_name = None
+            for name, idx in BODY_PARTS.items():
+                if idx == i:
+                    part_name = name
+                    break
+
+            if part_name and part_name in template["keypoints"]:
                 x, y = template["keypoints"][part_name]
                 candidate.append([x, y, 1.0])
             else:
-                # Estimate position based on nearby keypoints
                 candidate.append([256, 256, 0.0])
-        else:
-            candidate.append([0, 0, 0.0])

     # Create subset (connection information)
-
+    valid_indices = [i for i in range(18) if candidate[i][2] > 0]
+    subset = [valid_indices + [float(len(valid_indices)), len(valid_indices)]]

     return {"candidate": candidate, "subset": subset}

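As a sanity check on the new subset construction, a minimal sketch of the simplified candidate/subset layout this template builder emits (three toy keypoints; each candidate row is [x, y, confidence]):

# A subset row here is: the valid keypoint indices, then a total score and a keypoint count.
candidate = [[256, 100, 1.0], [256, 150, 1.0], [0, 0, 0.0]]
valid_indices = [i for i, kp in enumerate(candidate) if kp[2] > 0]
subset = [valid_indices + [float(len(valid_indices)), len(valid_indices)]]
print(subset)  # [[0, 1, 2.0, 2]]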
@@ -190,7 +200,7 @@ def refine_pose_with_llm(current_pose: Dict, refinement_prompt: str) -> Dict:
     }

     try:
-        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload)
+        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload, timeout=30)
         if response.status_code == 200:
             data = response.json()
             content = data['choices'][0]['message']['content']
@@ -205,8 +215,11 @@ def refine_pose_with_llm(current_pose: Dict, refinement_prompt: str) -> Dict:
     return current_pose

 # FastAPI setup
-
-
+try:
+    with open("static/poseEditor.js", "r") as f:
+        file_contents = f.read()
+except:
+    file_contents = "console.log('PoseEditor.js not found');"

 app = FastAPI()

@@ -239,12 +252,23 @@ async def some_fastapi_middleware(request: Request, call_next):
     return response

 def candidate_to_json_string(arr):
-
-
+    if isinstance(arr, list):
+        a = []
+        for item in arr:
+            if len(item) >= 2:
+                x, y = item[0], item[1]
+                a.append(f'[{float(x):.2f}, {float(y):.2f}]')
+        return '[' + ', '.join(a) + ']'
+    return '[]'

 def subset_to_json_string(arr):
-
-
+    if isinstance(arr, np.ndarray):
+        arr_str = ','.join(['[' + ','.join([f'{num:.2f}' for num in row]) + ']' for row in arr])
+        return '[' + arr_str + ']'
+    elif isinstance(arr, list):
+        arr_str = ','.join(['[' + ','.join([f'{float(num):.2f}' for num in row]) + ']' for row in arr])
+        return '[' + arr_str + ']'
+    return '[]'

 def estimate_body(source):
     if source == None:
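For orientation, the strings the two serializers above should produce for small inputs (values are illustrative; the expected outputs are shown as comments rather than asserted):

import numpy as np

candidate = [[256.0, 100.0, 1.0], [246.0, 90.0, 1.0]]   # confidence column is dropped
subset = np.array([[0.0, 1.0, 2.0, 2.0]])
# candidate_to_json_string(candidate) -> '[[256.00, 100.00], [246.00, 90.00]]'
# subset_to_json_string(subset)       -> '[[0.00,1.00,2.00,2.00]]'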
@@ -269,25 +293,18 @@ def image_changed(image):
     jsonText = "{ \"candidate\": " + candidate_to_json_string(candidate) + ", \"subset\": " + subset_to_json_string(subset) + " }"
     return f"""{image.width}px x {image.height}px, {subset.shape[0]} individual(s)""", jsonText

-
+def generate_pose_from_text(prompt: str, use_llm: bool = True):
     """
-    Generate a pose from a text prompt
+    Generate a pose from a text prompt (synchronous version)
     """
     if use_llm and FIREWORKS_API_KEY != "YOUR_API_KEY_HERE":
-        pose_data =
+        pose_data = generate_pose_from_llm_sync(prompt)
     else:
         pose_data = generate_template_pose(prompt)

     # Format for the pose editor
-
-
-    else:
-        candidate_str = js.dumps(pose_data['candidate'])
-
-    if isinstance(pose_data['subset'], list):
-        subset_str = subset_to_json_string(pose_data['subset'])
-    else:
-        subset_str = js.dumps(pose_data['subset'])
+    candidate_str = candidate_to_json_string(pose_data['candidate'])
+    subset_str = subset_to_json_string(pose_data['subset'])

     return "{ \"candidate\": " + candidate_str + ", \"subset\": " + subset_str + " }"

@@ -298,9 +315,9 @@ html_text = f"""

 # Gradio interface
 with gr.Blocks(css="""
-    button { min-width: 80px; }
-    .prompt-box { border: 2px solid #667eea; border-radius: 8px; padding: 10px; }
-    .llm-status { color: #667eea; font-weight: bold; }
+    button {{ min-width: 80px; }}
+    .prompt-box {{ border: 2px solid #667eea; border-radius: 8px; padding: 10px; }}
+    .llm-status {{ color: #667eea; font-weight: bold; }}
 """) as demo:

     gr.Markdown("""
@@ -434,19 +451,16 @@ with gr.Blocks(css="""
     )

     # LLM generation events
-
+    def handle_generate(prompt, use_llm):
         if not prompt:
             return None, "⚠️ Please enter a pose description"

         try:
             status = "🔄 Generating pose with AI..." if use_llm else "🔄 Using template..."
-
-
-            pose_json = await generate_pose_from_text(prompt, use_llm)
-            yield pose_json, "✅ Pose generated successfully!"
-
+            pose_json = generate_pose_from_text(prompt, use_llm)
+            return pose_json, "✅ Pose generated successfully!"
         except Exception as e:
-
+            return None, f"❌ Error: {str(e)}"

     generate_btn.click(
         fn=handle_generate,
@@ -467,7 +481,7 @@ with gr.Blocks(css="""
         outputs=[refinement_prompt]
     )

-
+    def handle_refine(current_json, refinement):
         if not current_json or not refinement:
             return None, "⚠️ Need current pose and refinement instructions"

@@ -530,4 +544,10 @@ with gr.Blocks(css="""

     demo.load(fn=check_api_status, outputs=[llm_status])

-
+# Mount Gradio app to FastAPI
+gr.mount_gradio_app(app, demo, path="/")
+
+# For running the app
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)