Directly load model on CUDA + display progress

#1
by cbensimon HF Staff - opened
Files changed (1) hide show
  1. app.py +4 -21
app.py CHANGED
@@ -8,9 +8,9 @@ from chatterbox.tts_turbo import ChatterboxTurboTTS
8
 
9
  # --- 1. FORCE CPU FOR GLOBAL LOADING ---
10
  # ZeroGPU forbids CUDA during startup. We only move to CUDA inside the decorated function.
11
- DEVICE = "cpu"
12
 
13
- MODEL = None
14
 
15
  EVENT_TAGS = [
16
  "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
@@ -75,13 +75,6 @@ def set_seed(seed: int):
75
  random.seed(seed)
76
  np.random.seed(seed)
77
 
78
-
79
- def load_model():
80
- global MODEL
81
- print(f"Loading Chatterbox-Turbo on {DEVICE}...")
82
- MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
83
- return MODEL
84
-
85
  @spaces.GPU
86
  def generate(
87
  text,
@@ -92,16 +85,9 @@ def generate(
92
  top_p,
93
  top_k,
94
  repetition_penalty,
95
- norm_loudness
 
96
  ):
97
- global MODEL
98
- # Reload if the worker lost the global state
99
- if MODEL is None:
100
- MODEL = ChatterboxTurboTTS.from_pretrained("cpu")
101
-
102
- # --- MOVE TO GPU HERE ---
103
- MODEL.to("cuda")
104
-
105
  if seed_num != 0:
106
  set_seed(int(seed_num))
107
 
@@ -162,9 +148,6 @@ with gr.Blocks(title="Chatterbox Turbo") as demo:
162
  min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
163
  norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
164
 
165
- # Load on startup (CPU)
166
- demo.load(fn=load_model, inputs=[], outputs=[])
167
-
168
  run_btn.click(
169
  fn=generate,
170
  inputs=[
 
8
 
9
  # --- 1. FORCE CPU FOR GLOBAL LOADING ---
10
  # ZeroGPU forbids CUDA during startup. We only move to CUDA inside the decorated function.
11
+ DEVICE = "cuda"
12
 
13
+ MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
14
 
15
  EVENT_TAGS = [
16
  "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
 
75
  random.seed(seed)
76
  np.random.seed(seed)
77
 
 
 
 
 
 
 
 
78
  @spaces.GPU
79
  def generate(
80
  text,
 
85
  top_p,
86
  top_k,
87
  repetition_penalty,
88
+ norm_loudness,
89
+ progress=gr.Progress(track_tqdm=True),
90
  ):
 
 
 
 
 
 
 
 
91
  if seed_num != 0:
92
  set_seed(int(seed_num))
93
 
 
148
  min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
149
  norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
150
 
 
 
 
151
  run_btn.click(
152
  fn=generate,
153
  inputs=[