""" Gradio Demo for AudioTextHTDemucs - Text-Conditioned Stem Separation Upload an audio file, enter a text prompt (e.g., "drums", "extract bass", "vocals"), and the model will separate that stem from the mixture. """ import os os.environ["TOKENIZERS_PARALLELISM"] = "false" import gradio as gr import torch import torch.nn.functional as F import torchaudio import numpy as np import matplotlib.pyplot as plt from pathlib import Path from demucs import pretrained from transformers import ClapModel, AutoTokenizer, AutoModel from huggingface_hub import hf_hub_download from src.models.stem_separation.ATHTDemucs_v2 import AudioTextHTDemucs from utils import load_config, plot_spectrogram # ============================================================================ # Configuration # ============================================================================ cfg = load_config("config.yaml") SAMPLE_RATE = cfg["data"]["sample_rate"] SEGMENT_SECONDS = cfg["data"]["segment_seconds"] OVERLAP = cfg["data"]["overlap"] # Auto-detect device if torch.cuda.is_available(): DEVICE = "cuda" elif torch.backends.mps.is_available(): DEVICE = "mps" else: DEVICE = "cpu" # DEVICE = "cpu" # Load model HuggingFace Hub # TODO: Add our model to the AutoModel inferface ckpt = hf_hub_download( repo_id="jacob1576/AudioTextHTDemucs", filename="best_model.pt" ) # ============================================================================ # Model Loading # ============================================================================ print(f"Loading model on device: {DEVICE}") print("Loading HTDemucs...") htdemucs = pretrained.get_model('htdemucs').models[0] print("Loading CLAP...") clap = ClapModel.from_pretrained("laion/clap-htsat-unfused") tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") print("Building AudioTextHTDemucs...") model = AudioTextHTDemucs(htdemucs, clap, tokenizer) print(f"Loading checkpoint from HuggingFace Hub...") checkpoint = torch.load(ckpt, map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"], strict=False) print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}") model = model.to(DEVICE) model.eval() print("Model ready!") # ============================================================================ # Helper Functions # ============================================================================ def create_spectrogram(audio, sr=SAMPLE_RATE, title="Spectrogram"): """Create a spectrogram visualization.""" fig, ax = plt.subplots(figsize=(10, 4)) # Convert to mono for visualization if stereo if audio.shape[0] == 2: audio_mono = audio.mean(dim=0) else: audio_mono = audio.squeeze() # Compute spectrogram n_fft = 2048 hop_length = 512 spec = torch.stft( audio_mono, n_fft=n_fft, hop_length=hop_length, return_complex=True ) spec_mag = torch.abs(spec) spec_db = 20 * torch.log10(spec_mag + 1e-8) # Plot im = ax.imshow( spec_db.cpu().numpy(), aspect='auto', origin='lower', cmap='viridis', interpolation='nearest' ) ax.set_xlabel('Time (frames)') ax.set_ylabel('Frequency (bins)') ax.set_title(title) plt.colorbar(im, ax=ax, format='%+2.0f dB') plt.tight_layout() return fig def load_audio(audio_path, target_sr=SAMPLE_RATE): """Load audio file and resample if necessary.""" waveform, sr = torchaudio.load(audio_path) # Resample if necessary if sr != target_sr: resampler = torchaudio.transforms.Resample(sr, target_sr) waveform = resampler(waveform) # Convert to stereo if mono if waveform.shape[0] == 1: waveform = waveform.repeat(2, 1) return waveform, target_sr def 
def chunked_inference(mixture, prompt):
    """Run chunked inference for a single stem."""
    C, T = mixture.shape
    chunk_len = int(SAMPLE_RATE * SEGMENT_SECONDS)
    overlap_frames = int(OVERLAP * SAMPLE_RATE)

    output = torch.zeros(C, T, device=DEVICE)
    weight = torch.zeros(T, device=DEVICE)

    start = 0
    while start < T:
        end = min(start + chunk_len, T)
        chunk = mixture[:, start:end].unsqueeze(0).to(DEVICE)  # (1, C, chunk_len)

        # Pad if needed
        if chunk.shape[-1] < chunk_len:
            pad_amount = chunk_len - chunk.shape[-1]
            chunk = F.pad(chunk, (0, pad_amount))

        with torch.no_grad():
            out = model(chunk, [prompt])  # (1, C, chunk_len)

        # Ensure output is on the correct device
        out = out.to(DEVICE).squeeze(0)  # (C, chunk_len)

        # Trim padding if we added any
        actual_len = end - start
        out = out[:, :actual_len]

        # Create fade weights for overlap-add
        fade_len = min(overlap_frames, actual_len // 2)
        chunk_weight = torch.ones(actual_len, device=DEVICE)
        if start > 0 and fade_len > 0:
            # Fade in
            chunk_weight[:fade_len] = torch.linspace(0, 1, fade_len, device=DEVICE)
        if end < T and fade_len > 0:
            # Fade out
            chunk_weight[-fade_len:] = torch.linspace(1, 0, fade_len, device=DEVICE)

        output[:, start:end] += out * chunk_weight
        weight[start:end] += chunk_weight

        # Move to next chunk with overlap
        start += chunk_len - overlap_frames

    # Normalize by accumulated weights
    weight = weight.clamp(min=1e-8)
    output = output / weight

    return output


def download_youtube_audio(yt_link):
    """Download audio from a YouTube link using yt-dlp."""
    try:
        import yt_dlp

        # Remove any leftover download from a previous request
        if os.path.exists("temp/yt_audio.webm"):
            os.remove("temp/yt_audio.webm")

        ydl_opts = {
            'format': 'bestaudio/best',
            'quiet': True,
            'outtmpl': 'temp/yt_audio.webm',
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([yt_link])

        mixture, sr = load_audio("temp/yt_audio.webm", target_sr=SAMPLE_RATE)
        return (sr, mixture.T.numpy())
    except Exception as e:
        return f"Error downloading audio from YouTube: {str(e)}"
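
# --- Example: standalone separation (commented out) -------------------------
# A minimal sketch of running the overlap-add inference outside of Gradio,
# assuming a local file "song.wav" exists. Overlapping chunks are cross-faded
# with the linear ramps above, and the accumulated weights renormalize the
# overlap regions back to unit gain.
#
# mixture, sr = load_audio("song.wav", target_sr=SAMPLE_RATE)
# stem = chunked_inference(mixture.to(DEVICE), "vocals").cpu()
# torchaudio.save("vocals.wav", stem, sr)
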
# ============================================================================
# Gradio Interface Functions
# ============================================================================

def process_audio(audio_file, yt_link, text_prompt):
    """Main processing function for the Gradio interface."""
    if audio_file is None and (yt_link is None or yt_link.strip() == ""):
        return None, None, None, None, "Please upload an audio file or provide a YouTube link."

    if not text_prompt or text_prompt.strip() == "":
        return None, None, None, None, "Please enter a text prompt."

    if yt_link and yt_link.strip() != "":
        try:
            download_youtube_audio(yt_link)
        except Exception as e:
            return None, None, None, None, str(e)

    try:
        # Load audio
        mixture, sr = load_audio(audio_file if audio_file else "temp/yt_audio.webm",
                                 target_sr=SAMPLE_RATE)
        print(f"Loaded audio: {mixture.shape}, sr={sr}")

        # Create input spectrogram
        input_spec_fig = create_spectrogram(mixture, sr, title="Input Mixture Spectrogram")
        # input_spec_fig = plot_spectrogram(mixture, sr, title="Input Mixture Spectrogram")

        # Run separation
        print(f"Running separation with prompt: '{text_prompt}'")
        separated = chunked_inference(mixture.to(DEVICE), text_prompt.strip())
        separated = separated.cpu()

        # Debug: Check if output is non-zero
        print(f"Separated audio shape: {separated.shape}")
        print(f"Separated audio range: [{separated.min():.4f}, {separated.max():.4f}]")
        print(f"Separated audio mean abs: {separated.abs().mean():.4f}")

        # Create output spectrogram
        output_spec_fig = create_spectrogram(separated, sr, title=f"Separated: {text_prompt}")

        # Convert to audio format for Gradio
        # Gradio Audio expects tuple: (sample_rate, numpy_array)
        # numpy_array shape should be (samples, channels) for stereo
        input_audio = (sr, mixture.T.numpy())     # (sr, (T, 2))
        output_audio = (sr, separated.T.numpy())  # (sr, (T, 2))

        status = f"✓ Successfully separated '{text_prompt}' from the mixture!"

        return input_audio, output_audio, input_spec_fig, output_spec_fig, status

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, None, None, error_msg
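
# --- Example: calling process_audio directly (commented out) ----------------
# A minimal sketch, assuming a local file "song.wav" exists. Gradio Audio
# components with type="numpy" exchange (sample_rate, array) tuples where the
# array is shaped (samples, channels), which is why the tensors are transposed
# above before being returned.
#
# in_audio, out_audio, in_fig, out_fig, msg = process_audio("song.wav", "", "drums")
# print(msg)
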
# ============================================================================
# Gradio Interface
# ============================================================================

def create_demo():
    """Create the Gradio interface."""
    with gr.Blocks(title="AudioTextHTDemucs Demo") as demo:
        gr.Markdown(
            """
            # 🎵 AudioTextHTDemucs - Text-Conditioned Stem Separation

            Upload an audio file and enter a text prompt to separate specific stems from the mixture.

            **Example prompts:**
            - `drums` - Extract drum sounds
            - `bass` - Extract bass guitar
            - `vocals` - Extract singing voice
            - `other` - Extract other instruments
            - Or any natural language description like "extract the guitar" or "piano sound"
            """
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload"]
                )
                yt_link_input = gr.Textbox(
                    label="YouTube Video URL (optional)",
                    placeholder="Provide a YouTube link to fetch audio",
                    lines=1
                )
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Enter what you want to extract (e.g., 'drums', 'vocals', 'bass')",
                    lines=1
                )

                gr.Examples(
                    examples=[
                        ["drums"],
                        ["bass"],
                        ["vocals"],
                        ["other"],
                        ["extract the drums"],
                        ["guitar sound"],
                    ],
                    inputs=text_input,
                    label="Click to use example prompts"
                )

                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    submit_btn = gr.Button("Separate Audio", variant="primary")

                status_output = gr.Textbox(label="Status", interactive=False)

        yt_link_input.change(
            download_youtube_audio,
            inputs=[yt_link_input],
            outputs=[audio_input]
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Mixture")
                input_audio_player = gr.Audio(
                    label="Input Audio (Original Mix)",
                    type="numpy",
                    interactive=False
                )
                input_spec_plot = gr.Plot(label="Input Spectrogram")

            with gr.Column():
                gr.Markdown("### Separated Output")
                output_audio_player = gr.Audio(
                    label="Separated Audio",
                    type="numpy",
                    interactive=False
                )
                output_spec_plot = gr.Plot(label="Output Spectrogram")

        # Button actions
        submit_btn.click(
            fn=process_audio,
            inputs=[audio_input, yt_link_input, text_input],
            outputs=[
                input_audio_player,
                output_audio_player,
                input_spec_plot,
                output_spec_plot,
                status_output
            ]
        )

        def clear_all():
            return None, "", "", None, None, None, None, ""

        clear_btn.click(
            fn=clear_all,
            outputs=[
                audio_input,
                text_input,
                yt_link_input,
                input_audio_player,
                output_audio_player,
                input_spec_plot,
                output_spec_plot,
                status_output
            ]
        )

        gr.Markdown(
            """
            ---
            ### Notes
            - The model works best with music audio sampled at 44.1kHz
            - Processing time depends on audio length (segments processed in 6-second chunks)
            - The model was trained on stems: drums, bass, vocals, and other instruments
            - You can use natural language descriptions thanks to CLAP text embeddings
            """
        )

    return demo


# ============================================================================
# Launch
# ============================================================================

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        share=False,            # Set to True to create a public link
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860
    )
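
# --- Launch options (commented out) ------------------------------------------
# A minimal sketch of alternative launch settings, should a public link or a
# different port be needed; share=True asks Gradio to create a temporary public
# URL, while server_name/server_port control where the local server binds.
#
# demo = create_demo()
# demo.launch(share=True, server_name="127.0.0.1", server_port=7861)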