aiqtech committed
Commit 3ecdd57 · verified · Parent: 502f9c7

Update app.py

Files changed (1): app.py (+77 −57)
app.py CHANGED
@@ -8,7 +8,6 @@ import requests
 import os
 from typing import Dict, List, Tuple
 import asyncio
-import aiohttp
 
 # Initialize body estimation model
 body_estimation = Body('model/body_pose_model.pth')
@@ -67,9 +66,9 @@ def pil2cv(image):
     new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
     return new_image
 
-async def generate_pose_from_llm(prompt: str) -> Dict:
+def generate_pose_from_llm_sync(prompt: str) -> Dict:
     """
-    Generate pose data from a text prompt using an LLM
+    Generate pose data from a text prompt using an LLM (synchronous version)
     """
     system_prompt = """You are an expert in human pose generation. Given a description, generate precise OpenPose keypoint coordinates.
 
@@ -109,23 +108,21 @@ async def generate_pose_from_llm(prompt: str) -> Dict:
     }
 
     try:
-        async with aiohttp.ClientSession() as session:
-            async with session.post(FIREWORKS_API_URL, headers=headers, json=payload) as response:
-                if response.status == 200:
-                    data = await response.json()
-                    content = data['choices'][0]['message']['content']
-
-                    # Extract JSON from response
-                    import re
-                    json_match = re.search(r'\{.*\}', content, re.DOTALL)
-                    if json_match:
-                        pose_data = js.loads(json_match.group())
-                        return pose_data
-                    else:
-                        # Fallback to template
-                        return generate_template_pose(prompt)
-                else:
-                    return generate_template_pose(prompt)
+        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload, timeout=30)
+        if response.status_code == 200:
+            data = response.json()
+            content = data['choices'][0]['message']['content']
+
+            # Extract JSON from response
+            import re
+            json_match = re.search(r'\{.*\}', content, re.DOTALL)
+            if json_match:
+                pose_data = js.loads(json_match.group())
+                return pose_data
+            else:
+                return generate_template_pose(prompt)
+        else:
+            return generate_template_pose(prompt)
     except Exception as e:
         print(f"LLM Error: {e}")
         return generate_template_pose(prompt)
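
Note: the replacement code treats the model reply as free text and pulls the first {...} block out with a DOTALL regex before parsing (js appears to be an alias for the standard json module elsewhere in app.py). A minimal standalone sketch of that extraction step; extract_pose_json is a hypothetical helper name used only for illustration:

    import json
    import re

    def extract_pose_json(content):
        # Grab the first {...} span, even across newlines, then parse it.
        match = re.search(r'\{.*\}', content, re.DOTALL)
        if not match:
            return None
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            return None

    # Chat models often wrap JSON in prose; the regex strips that wrapper.
    reply = 'Here is the pose: {"candidate": [[256, 100, 1.0]], "subset": [[0, 18.0, 1]]}'
    print(extract_pose_json(reply))
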
@@ -149,18 +146,31 @@ def generate_template_pose(prompt: str) -> Dict:
     for i in range(18):
         if i == 0: # Nose
             candidate.append([256, 100, 1.0])
-        elif part_name := next((k for k, v in BODY_PARTS.items() if v == i), None):
-            if part_name in template["keypoints"]:
+        elif i == 14: # REye
+            candidate.append([246, 90, 1.0])
+        elif i == 15: # LEye
+            candidate.append([266, 90, 1.0])
+        elif i == 16: # REar
+            candidate.append([236, 95, 1.0])
+        elif i == 17: # LEar
+            candidate.append([276, 95, 1.0])
+        else:
+            # Find part name for this index
+            part_name = None
+            for name, idx in BODY_PARTS.items():
+                if idx == i:
+                    part_name = name
+                    break
+
+            if part_name and part_name in template["keypoints"]:
                 x, y = template["keypoints"][part_name]
                 candidate.append([x, y, 1.0])
             else:
-                # Estimate position based on nearby keypoints
                 candidate.append([256, 256, 0.0])
-        else:
-            candidate.append([0, 0, 0.0])
 
     # Create subset (connection information)
-    subset = [[i for i in range(18) if candidate[i][2] > 0] + [18.0, 18]]
+    valid_indices = [i for i in range(18) if candidate[i][2] > 0]
+    subset = [valid_indices + [float(len(valid_indices)), len(valid_indices)]]
 
     return {"candidate": candidate, "subset": subset}
 
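
For reference, the subset row assembled above lists the indices of keypoints with non-zero confidence, followed by a score and the keypoint count. A small sketch of what the template path produces when all 18 keypoints are present; the values simply mirror the lines added in this hunk:

    # All 18 keypoints found -> indices 0..17, then score, then count.
    valid_indices = list(range(18))
    subset_row = valid_indices + [float(len(valid_indices)), len(valid_indices)]
    print(subset_row)   # [0, 1, 2, ..., 16, 17, 18.0, 18] (middle indices abbreviated here)
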
 
@@ -190,7 +200,7 @@ def refine_pose_with_llm(current_pose: Dict, refinement_prompt: str) -> Dict:
     }
 
     try:
-        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload)
+        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload, timeout=30)
         if response.status_code == 200:
             data = response.json()
             content = data['choices'][0]['message']['content']
@@ -205,8 +215,11 @@ def refine_pose_with_llm(current_pose: Dict, refinement_prompt: str) -> Dict:
     return current_pose
 
 # FastAPI setup
-with open("static/poseEditor.js", "r") as f:
-    file_contents = f.read()
+try:
+    with open("static/poseEditor.js", "r") as f:
+        file_contents = f.read()
+except:
+    file_contents = "console.log('PoseEditor.js not found');"
 
 app = FastAPI()
 
@@ -239,12 +252,23 @@ async def some_fastapi_middleware(request: Request, call_next):
     return response
 
 def candidate_to_json_string(arr):
-    a = [f'[{x:.2f}, {y:.2f}]' for x, y, *_ in arr]
-    return '[' + ', '.join(a) + ']'
+    if isinstance(arr, list):
+        a = []
+        for item in arr:
+            if len(item) >= 2:
+                x, y = item[0], item[1]
+                a.append(f'[{float(x):.2f}, {float(y):.2f}]')
+        return '[' + ', '.join(a) + ']'
+    return '[]'
 
 def subset_to_json_string(arr):
-    arr_str = ','.join(['[' + ','.join([f'{num:.2f}' for num in row]) + ']' for row in arr])
-    return '[' + arr_str + ']'
+    if isinstance(arr, np.ndarray):
+        arr_str = ','.join(['[' + ','.join([f'{num:.2f}' for num in row]) + ']' for row in arr])
+        return '[' + arr_str + ']'
+    elif isinstance(arr, list):
+        arr_str = ','.join(['[' + ','.join([f'{float(num):.2f}' for num in row]) + ']' for row in arr])
+        return '[' + arr_str + ']'
+    return '[]'
 
 def estimate_body(source):
     if source == None:
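
The reworked serializers accept plain Python lists as well as the numpy arrays coming back from the body estimator. A quick usage sketch, assuming the two helpers above are in scope; the sample coordinates are made up:

    import numpy as np

    candidate = [[256, 100, 1.0], [256, 150, 0.9]]   # [x, y, confidence] per keypoint
    subset = np.array([[0, 1, 2.0, 2]])              # one person referencing two keypoints

    print(candidate_to_json_string(candidate))       # '[[256.00, 100.00], [256.00, 150.00]]'
    print(subset_to_json_string(subset))             # '[[0.00,1.00,2.00,2.00]]'
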
@@ -269,25 +293,18 @@ def image_changed(image):
     jsonText = "{ \"candidate\": " + candidate_to_json_string(candidate) + ", \"subset\": " + subset_to_json_string(subset) + " }"
     return f"""{image.width}px x {image.height}px, {subset.shape[0]} individual(s)""", jsonText
 
-async def generate_pose_from_text(prompt: str, use_llm: bool = True):
+def generate_pose_from_text(prompt: str, use_llm: bool = True):
     """
-    Generate a pose from a text prompt
+    Generate a pose from a text prompt (synchronous version)
     """
     if use_llm and FIREWORKS_API_KEY != "YOUR_API_KEY_HERE":
-        pose_data = await generate_pose_from_llm(prompt)
+        pose_data = generate_pose_from_llm_sync(prompt)
     else:
         pose_data = generate_template_pose(prompt)
 
     # Format for the pose editor
-    if isinstance(pose_data['candidate'], list):
-        candidate_str = candidate_to_json_string(pose_data['candidate'])
-    else:
-        candidate_str = js.dumps(pose_data['candidate'])
-
-    if isinstance(pose_data['subset'], list):
-        subset_str = subset_to_json_string(pose_data['subset'])
-    else:
-        subset_str = js.dumps(pose_data['subset'])
+    candidate_str = candidate_to_json_string(pose_data['candidate'])
+    subset_str = subset_to_json_string(pose_data['subset'])
 
     return "{ \"candidate\": " + candidate_str + ", \"subset\": " + subset_str + " }"
 
@@ -298,9 +315,9 @@ html_text = f"""
 
 # Gradio interface
 with gr.Blocks(css="""
-button { min-width: 80px; }
-.prompt-box { border: 2px solid #667eea; border-radius: 8px; padding: 10px; }
-.llm-status { color: #667eea; font-weight: bold; }
+button {{ min-width: 80px; }}
+.prompt-box {{ border: 2px solid #667eea; border-radius: 8px; padding: 10px; }}
+.llm-status {{ color: #667eea; font-weight: bold; }}
 """) as demo:
 
     gr.Markdown("""
@@ -434,19 +451,16 @@ with gr.Blocks(css="""
     )
 
     # LLM generation events
-    async def handle_generate(prompt, use_llm):
+    def handle_generate(prompt, use_llm):
         if not prompt:
             return None, "⚠️ Please enter a pose description"
 
         try:
             status = "🔄 Generating pose with AI..." if use_llm else "🔄 Using template..."
-            yield None, status
-
-            pose_json = await generate_pose_from_text(prompt, use_llm)
-            yield pose_json, "✅ Pose generated successfully!"
-
+            pose_json = generate_pose_from_text(prompt, use_llm)
+            return pose_json, "✅ Pose generated successfully!"
         except Exception as e:
-            yield None, f"❌ Error: {str(e)}"
+            return None, f"❌ Error: {str(e)}"
 
     generate_btn.click(
         fn=handle_generate,
@@ -467,7 +481,7 @@ with gr.Blocks(css="""
         outputs=[refinement_prompt]
     )
 
-    async def handle_refine(current_json, refinement):
+    def handle_refine(current_json, refinement):
         if not current_json or not refinement:
             return None, "⚠️ Need current pose and refinement instructions"
 
@@ -530,4 +544,10 @@ with gr.Blocks(css="""
 
     demo.load(fn=check_api_status, outputs=[llm_status])
 
-gr.mount_gradio_app(app, demo, path="/")
+# Mount Gradio app to FastAPI
+gr.mount_gradio_app(app, demo, path="/")
+
+# For running the app
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
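
With the Gradio UI mounted on the FastAPI app and a __main__ guard added, the script can be started directly with python app.py, after which uvicorn serves everything on port 7860. A hedged sketch of checking the mounted root path from a client; the host and port simply follow the lines added in the final hunk:

    import requests

    # The Gradio interface is mounted at "/" by gr.mount_gradio_app(app, demo, path="/").
    resp = requests.get("http://localhost:7860/")
    print(resp.status_code)   # expected to be 200 once the app is running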