jacob1576 committed on
Commit
7417a6a
·
1 Parent(s): 6c92aac

Add application file and dependencies

app.py ADDED
@@ -0,0 +1,393 @@
1
+ """
2
+ Gradio Demo for AudioTextHTDemucs - Text-Conditioned Stem Separation
3
+
4
+ Upload an audio file, enter a text prompt (e.g., "drums", "extract bass", "vocals"),
5
+ and the model will separate that stem from the mixture.
6
+ """
7
+
8
+ import os
9
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
10
+
11
+ import gradio as gr
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torchaudio
15
+ import numpy as np
16
+ import matplotlib.pyplot as plt
17
+ from pathlib import Path
18
+
19
+ from demucs import pretrained
20
+ from transformers import ClapModel, AutoTokenizer
21
+
22
+ from src.models.stem_separation.ATHTDemucs_v2 import AudioTextHTDemucs
23
+ from utils import load_config, plot_spectrogram
24
+
25
+ # ============================================================================
26
+ # Configuration
27
+ # ============================================================================
28
+
29
+ cfg = load_config("config.yaml")
30
+ CHECKPOINT_PATH = cfg["training"]["resume_from"] # Change as needed
31
+ SAMPLE_RATE = cfg["data"]["sample_rate"]
32
+ SEGMENT_SECONDS = cfg["data"]["segment_seconds"]
33
+ OVERLAP = cfg["data"]["overlap"]
34
+
35
+ # Auto-detect device
36
+ if torch.cuda.is_available():
37
+ DEVICE = "cuda"
38
+ elif torch.backends.mps.is_available():
39
+ DEVICE = "mps"
40
+ else:
41
+ DEVICE = "cpu"
42
+ # DEVICE = "cpu"
43
+
44
+
45
+ # ============================================================================
46
+ # Model Loading
47
+ # ============================================================================
48
+
49
+ print(f"Loading model on device: {DEVICE}")
50
+ print("Loading HTDemucs...")
51
+ htdemucs = pretrained.get_model('htdemucs').models[0]
52
+
53
+ print("Loading CLAP...")
54
+ clap = ClapModel.from_pretrained("laion/clap-htsat-unfused")
55
+ tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
56
+
57
+ print("Building AudioTextHTDemucs...")
58
+ model = AudioTextHTDemucs(htdemucs, clap, tokenizer)
59
+
60
+ print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
61
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
62
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
63
+ print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
64
+
65
+ model = model.to(DEVICE)
66
+ model.eval()
67
+ print("Model ready!")
68
+
69
+
70
+ # ============================================================================
71
+ # Helper Functions
72
+ # ============================================================================
73
+
74
+ def create_spectrogram(audio, sr=SAMPLE_RATE, title="Spectrogram"):
75
+ """Create a spectrogram visualization."""
76
+ fig, ax = plt.subplots(figsize=(10, 4))
77
+
78
+ # Convert to mono for visualization if stereo
79
+ if audio.shape[0] == 2:
80
+ audio_mono = audio.mean(dim=0)
81
+ else:
82
+ audio_mono = audio.squeeze()
83
+
84
+ # Compute spectrogram
85
+ n_fft = 2048
86
+ hop_length = 512
87
+ spec = torch.stft(
88
+ audio_mono,
89
+ n_fft=n_fft,
90
+ hop_length=hop_length,
91
+ return_complex=True
92
+ )
93
+ spec_mag = torch.abs(spec)
94
+ spec_db = 20 * torch.log10(spec_mag + 1e-8)
95
+
96
+ # Plot
97
+ im = ax.imshow(
98
+ spec_db.cpu().numpy(),
99
+ aspect='auto',
100
+ origin='lower',
101
+ cmap='viridis',
102
+ interpolation='nearest'
103
+ )
104
+ ax.set_xlabel('Time (frames)')
105
+ ax.set_ylabel('Frequency (bins)')
106
+ ax.set_title(title)
107
+ plt.colorbar(im, ax=ax, format='%+2.0f dB')
108
+ plt.tight_layout()
109
+
110
+ return fig
111
+
112
+
113
+ def load_audio(audio_path, target_sr=SAMPLE_RATE):
114
+ """Load audio file and resample if necessary."""
115
+ waveform, sr = torchaudio.load(audio_path)
116
+
117
+ # Resample if necessary
118
+ if sr != target_sr:
119
+ resampler = torchaudio.transforms.Resample(sr, target_sr)
120
+ waveform = resampler(waveform)
121
+
122
+ # Convert to stereo if mono
123
+ if waveform.shape[0] == 1:
124
+ waveform = waveform.repeat(2, 1)
125
+
126
+ return waveform, target_sr
127
+
128
+
129
+ def chunked_inference(mixture, prompt):
130
+ """Run chunked inference for a single stem."""
131
+ C, T = mixture.shape
132
+ chunk_len = int(SAMPLE_RATE * SEGMENT_SECONDS)
133
+ overlap_frames = int(OVERLAP * SAMPLE_RATE)
134
+
135
+ output = torch.zeros(C, T, device=DEVICE)
136
+ weight = torch.zeros(T, device=DEVICE)
137
+
138
+ start = 0
139
+ while start < T:
140
+ end = min(start + chunk_len, T)
141
+ chunk = mixture[:, start:end].unsqueeze(0).to(DEVICE) # (1, C, chunk_len)
142
+
143
+ # Pad if needed
144
+ if chunk.shape[-1] < chunk_len:
145
+ pad_amount = chunk_len - chunk.shape[-1]
146
+ chunk = F.pad(chunk, (0, pad_amount))
147
+
148
+ with torch.no_grad():
149
+ out = model(chunk, [prompt]) # (1, C, chunk_len)
150
+
151
+ # Ensure output is on the correct device
152
+ out = out.to(DEVICE).squeeze(0) # (C, chunk_len)
153
+
154
+ # Trim padding if we added any
155
+ actual_len = end - start
156
+ out = out[:, :actual_len]
157
+
158
+ # Create fade weights for overlap-add
159
+ fade_len = min(overlap_frames, actual_len // 2)
160
+ chunk_weight = torch.ones(actual_len, device=DEVICE)
161
+ if start > 0 and fade_len > 0:
162
+ # Fade in
163
+ chunk_weight[:fade_len] = torch.linspace(0, 1, fade_len, device=DEVICE)
164
+ if end < T and fade_len > 0:
165
+ # Fade out
166
+ chunk_weight[-fade_len:] = torch.linspace(1, 0, fade_len, device=DEVICE)
167
+
168
+ output[:, start:end] += out * chunk_weight
169
+ weight[start:end] += chunk_weight
170
+
171
+ # Move to next chunk with overlap
172
+ start += chunk_len - overlap_frames
173
+
174
+ # Normalize by weights
175
+ weight = weight.clamp(min=1e-8)
176
+ output = output / weight
177
+
178
+ return output
179
+
180
+ def download_youtube_audio(yt_link):
181
+ """Download audio from a YouTube link using yt-dlp."""
182
+ try:
183
+ import yt_dlp
184
+ if os.path.exists("temp/yt_audio.webm"):
+     os.remove("temp/yt_audio.webm")
185
+
186
+ ydl_opts = {
187
+ 'format': 'bestaudio/best',
188
+ 'quiet': True,
189
+ 'outtmpl': 'temp/yt_audio.webm',
190
+ }
191
+
192
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
193
+ ydl.download([yt_link])
194
+
195
+ mixture, sr = load_audio("temp/yt_audio.webm", target_sr=SAMPLE_RATE)
196
+ return (sr, mixture.T.numpy())
197
+ except Exception as e:
198
+ return f"Error downloading audio from YouTube: {str(e)}"
199
+
200
+
201
+ # ============================================================================
202
+ # Gradio Interface Functions
203
+ # ============================================================================
204
+
205
+ def process_audio(audio_file, yt_link, text_prompt):
206
+ """Main processing function for the Gradio interface."""
207
+ if audio_file is None and (yt_link is None or yt_link.strip() == ""):
208
+ return None, None, None, None, "Please upload an audio file or provide a YouTube link."
209
+
210
+ if not text_prompt or text_prompt.strip() == "":
211
+ return None, None, None, None, "Please enter a text prompt."
212
+
213
+ if yt_link and yt_link.strip() != "":
214
+ try:
215
+ download_youtube_audio(yt_link)
216
+ except Exception as e:
217
+ return None, None, None, None, str(e)
218
+
219
+ try:
220
+ # Load audio
221
+ mixture, sr = load_audio(audio_file if audio_file else "temp/yt_audio.webm", target_sr=SAMPLE_RATE)
222
+ print(f"Loaded audio: {mixture.shape}, sr={sr}")
223
+
224
+ # Create input spectrogram
225
+ input_spec_fig = create_spectrogram(mixture, sr, title="Input Mixture Spectrogram")
226
+ #input_spec_fig = plot_spectrogram(mixture, sr, title="Input Mixture Spectrogram")
227
+
228
+ # Run separation
229
+ print(f"Running separation with prompt: '{text_prompt}'")
230
+ separated = chunked_inference(mixture.to(DEVICE), text_prompt.strip())
231
+ separated = separated.cpu()
232
+
233
+ # Debug: Check if output is non-zero
234
+ print(f"Separated audio shape: {separated.shape}")
235
+ print(f"Separated audio range: [{separated.min():.4f}, {separated.max():.4f}]")
236
+ print(f"Separated audio mean abs: {separated.abs().mean():.4f}")
237
+
238
+ # Create output spectrogram
239
+ output_spec_fig = create_spectrogram(separated, sr, title=f"Separated: {text_prompt}")
240
+
241
+ # Convert to audio format for Gradio
242
+ # Gradio Audio expects tuple: (sample_rate, numpy_array)
243
+ # numpy_array shape should be (samples, channels) for stereo
244
+ input_audio = (sr, mixture.T.numpy()) # (sr, (T, 2))
245
+ output_audio = (sr, separated.T.numpy()) # (sr, (T, 2))
246
+
247
+ status = f"✓ Successfully separated '{text_prompt}' from the mixture!"
248
+
249
+ return input_audio, output_audio, input_spec_fig, output_spec_fig, status
250
+
251
+ except Exception as e:
252
+ error_msg = f"Error: {str(e)}"
253
+ print(error_msg)
254
+ import traceback
255
+ traceback.print_exc()
256
+ return None, None, None, None, error_msg
257
+
258
+
259
+ # ============================================================================
260
+ # Gradio Interface
261
+ # ============================================================================
262
+
263
+ def create_demo():
264
+ """Create the Gradio interface."""
265
+
266
+ with gr.Blocks(title="AudioTextHTDemucs Demo") as demo:
267
+ gr.Markdown(
268
+ """
269
+ # 🎵 AudioTextHTDemucs - Text-Conditioned Stem Separation
270
+
271
+ Upload an audio file and enter a text prompt to separate specific stems from the mixture.
272
+
273
+ **Example prompts:**
274
+ - `drums` - Extract drum sounds
275
+ - `bass` - Extract bass guitar
276
+ - `vocals` - Extract singing voice
277
+ - `other` - Extract other instruments
278
+ - Or any natural language description like "extract the guitar" or "piano sound"
279
+ """
280
+ )
281
+
282
+ with gr.Row():
283
+ with gr.Column():
284
+ gr.Markdown("### Input")
285
+ audio_input = gr.Audio(
286
+ label="Upload Audio File",
287
+ type="filepath",
288
+ sources=["upload"]
289
+ )
290
+ yt_link_input = gr.Textbox(
291
+ label="YouTube Video URL (optional)",
292
+ placeholder="Provide a YouTube link to fetch audio",
293
+ lines=1
294
+ )
295
+ text_input = gr.Textbox(
296
+ label="Text Prompt",
297
+ placeholder="Enter what you want to extract (e.g., 'drums', 'vocals', 'bass')",
298
+ lines=1
299
+ )
300
+ gr.Examples(
301
+ examples=[
302
+ ["drums"],
303
+ ["bass"],
304
+ ["vocals"],
305
+ ["other"],
306
+ ["extract the drums"],
307
+ ["guitar sound"],
308
+ ],
309
+ inputs=text_input,
310
+ label="Click to use example prompts"
311
+ )
312
+
313
+ with gr.Row():
314
+ clear_btn = gr.Button("Clear", variant="secondary")
315
+ submit_btn = gr.Button("Separate Audio", variant="primary")
316
+
317
+ status_output = gr.Textbox(label="Status", interactive=False)
318
+ yt_link_input.change(download_youtube_audio, inputs=[yt_link_input], outputs=[audio_input])
319
+
320
+ with gr.Row():
321
+ with gr.Column():
322
+ gr.Markdown("### Input Mixture")
323
+ input_audio_player = gr.Audio(
324
+ label="Input Audio (Original Mix)",
325
+ type="numpy",
326
+ interactive=False
327
+ )
328
+ input_spec_plot = gr.Plot(label="Input Spectrogram")
329
+
330
+ with gr.Column():
331
+ gr.Markdown("### Separated Output")
332
+ output_audio_player = gr.Audio(
333
+ label="Separated Audio",
334
+ type="numpy",
335
+ interactive=False
336
+ )
337
+ output_spec_plot = gr.Plot(label="Output Spectrogram")
338
+
339
+ # Button actions
340
+ submit_btn.click(
341
+ fn=process_audio,
342
+ inputs=[audio_input, yt_link_input, text_input],
343
+ outputs=[
344
+ input_audio_player,
345
+ output_audio_player,
346
+ input_spec_plot,
347
+ output_spec_plot,
348
+ status_output
349
+ ]
350
+ )
351
+
352
+ def clear_all():
353
+ return None, "", None, None, None, None, None, ""
354
+
355
+ clear_btn.click(
356
+ fn=clear_all,
357
+ outputs=[
358
+ audio_input,
359
+ text_input,
360
+ yt_link_input,
361
+ input_audio_player,
362
+ output_audio_player,
363
+ input_spec_plot,
364
+ output_spec_plot,
365
+ status_output
366
+ ]
367
+ )
368
+
369
+ gr.Markdown(
370
+ """
371
+ ---
372
+ ### Notes
373
+ - The model works best with music audio sampled at 44.1kHz
374
+ - Processing time depends on audio length (segments processed in 6-second chunks)
375
+ - The model was trained on stems: drums, bass, vocals, and other instruments
376
+ - You can use natural language descriptions thanks to CLAP text embeddings
377
+ """
378
+ )
379
+
380
+ return demo
381
+
382
+
383
+ # ============================================================================
384
+ # Launch
385
+ # ============================================================================
386
+
387
+ if __name__ == "__main__":
388
+ demo = create_demo()
389
+ demo.launch(
390
+ share=False, # Set to True to create a public link
391
+ server_name="0.0.0.0", # Allow external connections
392
+ server_port=7860
393
+ )
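The helpers above can also be used without the Gradio UI. A minimal sketch, assuming the config.yaml and checkpoint referenced at the top of app.py are present (importing app triggers the model loading); "my_song.wav" and "drums_estimate.wav" are placeholder filenames:

# Hypothetical offline use of the helpers defined in app.py above.
# Importing app runs the model/checkpoint loading at module import time.
import torchaudio
from app import load_audio, chunked_inference, SAMPLE_RATE

mixture, sr = load_audio("my_song.wav", target_sr=SAMPLE_RATE)  # (2, T) stereo tensor
separated = chunked_inference(mixture, "drums").cpu()           # (2, T) separated stem
torchaudio.save("drums_estimate.wav", separated, sr)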
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (149 Bytes)
src/__pycache__/dataloader.cpython-313.pyc ADDED
Binary file (8.6 kB)
src/__pycache__/loss.cpython-313.pyc ADDED
Binary file (5.93 kB)
src/__pycache__/train.cpython-313.pyc ADDED
Binary file (21.5 kB)
src/dataloader.py ADDED
@@ -0,0 +1,179 @@
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Dict, List, Tuple
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import stempeg
7
+ import soundfile as sf
8
+ import math
9
+ import numpy as np
10
+
11
+ # ============================================================================
12
+ # Data Loader
13
+ # ============================================================================
14
+
15
+ def get_random_prompt(stem_name: str) -> str:
16
+ """Get a random text prompt for a given stem."""
17
+ return random.choice(STEM_PROMPTS[stem_name])
18
+
19
+
20
+ # Text Prompt Templates
21
+ STEM_PROMPTS: Dict[str, List[str]] = {
22
+ "drums": ["drums", "drum kit", "percussion", "the drums"],
23
+ "bass": ["bass", "bass guitar", "the bass", "bass line"],
24
+ "other": ["other instruments", "accompaniment", "instruments"],
25
+ "vocals": ["vocals", "voice", "singing", "the vocals"],
26
+ }
27
+
28
+ PROMPT_TO_STEM: Dict[str, str] = {
29
+ prompt: stem
30
+ for stem, prompts in STEM_PROMPTS.items()
31
+ for prompt in prompts
32
+ }
33
+
34
+ STEM_NAME_TO_INDEX = {"drums": 0, "bass": 1, "other": 2, "vocals": 3}
35
+
36
+
37
+ class MusDBStemDataset(Dataset):
38
+ def __init__(
39
+ self,
40
+ root_dir: str,
41
+ segment_samples: int,
42
+ sample_rate: int = 44100,
43
+ channels: int = 2,
44
+ random_segments: bool = True,
45
+ augment: bool = True,
46
+ ):
47
+ self.root_dir = Path(root_dir)
48
+ self.segment_samples = segment_samples
49
+ self.sample_rate = sample_rate
50
+ self.channels = channels
51
+ self.random_segments = random_segments
52
+ self.augment = augment
53
+
54
+ self.stem_names = ["drums", "bass", "other", "vocals"]
55
+
56
+ self.files = list(self.root_dir.glob("*.stem.mp4"))
57
+ if not self.files:
58
+ raise ValueError(f"No .stem.mp4 files found in {root_dir}")
59
+
60
+ # Compute number of examples
61
+ self.index_map = [] # (file_idx, stem_idx, segment_idx)
62
+ #self.sample_lengths = [0] * len(self.files) # total samples per file
63
+ for file_idx, file in enumerate(self.files):
64
+ info = stempeg.Info(str(file))
65
+ total_samples = info.duration(0) * info.sample_rate(0) # 0 - using mixture stream as reference
66
+ #self.sample_lengths[file_idx] = int(total_samples)
67
+ num_segments = math.ceil(total_samples / segment_samples)
68
+
69
+ # Build index map: for each stem, each segment
70
+ for stem_idx in range(len(self.stem_names)):
71
+ for seg in range(num_segments):
72
+ self.index_map.append((file_idx, stem_idx, seg))
73
+
74
+ print(f"Found {len(self.files)} tracks, total dataset items: {len(self.index_map)}")
75
+
76
+ def __len__(self) -> int:
77
+ return len(self.index_map)
78
+
79
+ def _load_stems(self, filepath: Path) -> np.ndarray:
80
+ """Load all stems from a .stem.mp4 file."""
81
+ stems, rate = stempeg.read_stems(str(filepath))
82
+ # stems shape: (num_stems, samples, channels)
83
+ # [mix, drums, bass, other, vocals]
84
+ return stems
85
+
86
+ def _extract_random_segment(self, stems: np.ndarray) -> np.ndarray:
87
+ """Extract the same random segment from all stems."""
88
+ total_samples = stems.shape[1] # stems: (num_stems, samples, channels)
89
+
90
+ if total_samples <= self.segment_samples:
91
+ # Pad if too short
92
+ pad_amount = self.segment_samples - total_samples
93
+ stems = np.pad(stems, ((0, 0), (0, pad_amount), (0, 0)), mode='constant')
94
+ else:
95
+ # Random start position (same for all stems)
96
+ if self.random_segments:
97
+ start = random.randint(0, total_samples - self.segment_samples)
98
+ else:
99
+ start = 0
100
+ stems = stems[:, start:start + self.segment_samples, :]
101
+
102
+ return stems
103
+
104
+ def _extract_segment(self, stems: np.ndarray, seg_idx: int) -> np.ndarray:
105
+ total_samples = stems.shape[1]
106
+
107
+ if self.random_segments:
108
+ # fallback to random segment extractor
109
+ return self._extract_random_segment(stems)
110
+
111
+ start = seg_idx * self.segment_samples
112
+ end = start + self.segment_samples
113
+
114
+ if end <= total_samples:
115
+ return stems[:, start:end, :]
116
+ else:
117
+ # Last segment may need padding
118
+ pad_amount = end - total_samples
119
+ seg = stems[:, start:, :]
120
+ seg = np.pad(seg, ((0, 0),(0, pad_amount), (0, 0)), mode="constant")
121
+ return seg
122
+
123
+ def _augment(self, mixture: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
124
+ """Apply data augmentation."""
125
+ if random.random() < 0.5:
126
+ gain = random.uniform(0.7, 1.3)
127
+ mixture = mixture * gain
128
+ target = target * gain
129
+
130
+ if random.random() < 0.3 and mixture.shape[-1] == 2:
131
+ mixture = mixture[:, ::-1].copy()
132
+ target = target[:, ::-1].copy()
133
+
134
+ return mixture, target
135
+
136
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor | str]:
137
+ file_idx, stem_idx, seg_idx = self.index_map[idx]
138
+
139
+ filepath = self.files[file_idx]
140
+ stems = self._load_stems(filepath)
141
+
142
+ # deterministic segment selection
143
+ stems = self._extract_segment(stems, seg_idx)
144
+
145
+ mixture = stems[0] # (T, C)
146
+ target = stems[stem_idx+1] # (T, C)
147
+
148
+ if self.augment:
149
+ mixture, target = self._augment(mixture, target)
150
+
151
+ # -> (C, T)
152
+ mixture = torch.from_numpy(mixture.T).float()
153
+ target = torch.from_numpy(target.T).float()
154
+
155
+ # ensure stereo
156
+ if mixture.shape[0] == 1:
157
+ mixture = mixture.repeat(2, 1)
158
+ target = target.repeat(2, 1)
159
+
160
+ prompt = get_random_prompt(self.stem_names[stem_idx])
161
+
162
+ return {
163
+ "mixture": mixture,
164
+ "target": target,
165
+ "prompt": prompt,
166
+ "stem_name": self.stem_names[stem_idx],
167
+ "file_idx": file_idx,
168
+ "segment_idx": seg_idx,
169
+ }
170
+
171
+
172
+ def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor | List[str]]:
173
+ """Custom collate function."""
174
+ return {
175
+ "mixture": torch.stack([item["mixture"] for item in batch]),
176
+ "target": torch.stack([item["target"] for item in batch]),
177
+ "prompt": [item["prompt"] for item in batch],
178
+ "stem_name": [item["stem_name"] for item in batch],
179
+ }
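A minimal sketch of feeding MusDBStemDataset into a PyTorch DataLoader with the collate_fn above; the dataset path, batch size, and segment length are illustrative assumptions:

# Hypothetical usage of MusDBStemDataset / collate_fn defined above.
from torch.utils.data import DataLoader
from src.dataloader import MusDBStemDataset, collate_fn

dataset = MusDBStemDataset(
    root_dir="data/musdb18/train",  # assumed location of the .stem.mp4 files
    segment_samples=44100 * 6,      # 6-second segments at 44.1 kHz
    random_segments=True,
    augment=True,
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
# batch["mixture"] / batch["target"] are (B, 2, T); batch["prompt"] is a list of strings.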
src/loss.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import Dict, Tuple
2
+ import torch
3
+
4
+
5
+ # ============================================================================
6
+ # Loss Functions
7
+ # ============================================================================
8
+
9
+ def sdr_loss(estimated, target):
10
+ """
11
+ Compute negative SDR loss.
12
+ Based on the definition from Vincent et al. 2006.
13
+ """
14
+ # Flatten to [batch, -1] to ensure compatible shapes
15
+ est_flat = estimated.reshape(estimated.shape[0], -1)
16
+ tgt_flat = target.reshape(target.shape[0], -1)
17
+
18
+ # Compute SDR: 10 * log10(||target||^2 / ||target - estimated||^2)
19
+ delta = 1e-8 # Small constant for numerical stability
20
+
21
+ num = torch.sum(tgt_flat ** 2, dim=-1)
22
+ den = torch.sum((tgt_flat - est_flat) ** 2, dim=-1)
23
+
24
+ # Avoid division by zero
25
+ sdr = 10 * torch.log10((num + delta) / (den + delta))
26
+
27
+ # Clamp to reasonable range to avoid extreme values
28
+ sdr = torch.clamp(sdr, min=-30, max=30)
29
+
30
+ return -sdr.mean() # Return negative for minimization
31
+
32
+
33
+ def sisdr_loss(estimated, target):
34
+ """
35
+ Compute negative SI-SDR (Scale-Invariant SDR) loss.
36
+ This is more robust to scaling differences between estimate and target.
37
+ """
38
+ # Flatten to [batch, -1]
39
+ est_flat = estimated.reshape(estimated.shape[0], -1)
40
+ tgt_flat = target.reshape(target.shape[0], -1)
41
+
42
+ # Zero-mean normalization (critical for SI-SDR)
43
+ est_flat = est_flat - est_flat.mean(dim=-1, keepdim=True)
44
+ tgt_flat = tgt_flat - tgt_flat.mean(dim=-1, keepdim=True)
45
+
46
+ # SI-SDR calculation
47
+ # Project estimate onto target: s_target = <s', s> / ||s||^2 * s
48
+ delta = 1e-8
49
+
50
+ dot = torch.sum(est_flat * tgt_flat, dim=-1, keepdim=True)
51
+ s_target_norm_sq = torch.sum(tgt_flat ** 2, dim=-1, keepdim=True)
52
+
53
+ # Scaled target
54
+ s_target = (dot / (s_target_norm_sq + delta)) * tgt_flat
55
+
56
+ # Noise is the orthogonal component
57
+ e_noise = est_flat - s_target
58
+
59
+ # SI-SDR = 10 * log10(||s_target||^2 / ||e_noise||^2)
60
+ s_target_energy = torch.sum(s_target ** 2, dim=-1)
61
+ e_noise_energy = torch.sum(e_noise ** 2, dim=-1)
62
+
63
+ sisdr = 10 * torch.log10((s_target_energy + delta) / (e_noise_energy + delta))
64
+
65
+ # Clamp to reasonable range
66
+ sisdr = torch.clamp(sisdr, min=-30, max=30)
67
+
68
+ return -sisdr.mean() # Return negative for minimization
69
+
70
+
71
+ def new_sdr_metric(estimated: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
72
+ """
73
+ Compute the SDR according to the MDX challenge definition (positive values).
74
+ This is for evaluation/logging, not for loss.
75
+
76
+ Args:
77
+ estimated: (batch, channels, time)
78
+ target: (batch, channels, time)
79
+
80
+ Returns:
81
+ SDR scores per batch item (batch,)
82
+ """
83
+ delta = 1e-8
84
+ num = torch.sum(target ** 2, dim=(1, 2))
85
+ den = torch.sum((target - estimated) ** 2, dim=(1, 2))
86
+ scores = 10 * torch.log10((num + delta) / (den + delta))
87
+ return scores
88
+
89
+
90
+ def combined_loss(
91
+ estimated: torch.Tensor,
92
+ target: torch.Tensor,
93
+ sdr_weight: float = 0.9,
94
+ sisdr_weight: float = 0.1
95
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
96
+ """
97
+ Combined SDR and SI-SDR loss.
98
+
99
+ Args:
100
+ estimated: Estimated audio (batch, channels, time)
101
+ target: Target audio (batch, channels, time)
102
+ sdr_weight: Weight for SDR loss (default 0.9)
103
+ sisdr_weight: Weight for SI-SDR loss (default 0.1)
104
+
105
+ Returns:
106
+ total_loss: Combined loss for backpropagation
107
+ metrics: Dictionary of metrics for logging
108
+ """
109
+ sdr = sdr_loss(estimated, target)
110
+ sisdr = sisdr_loss(estimated, target)
111
+
112
+ total = sdr_weight * sdr + sisdr_weight * sisdr
113
+
114
+ # For logging, also compute positive SDR metric
115
+ with torch.no_grad():
116
+ pos_sdr = new_sdr_metric(estimated, target).mean()
117
+
118
+ metrics = {
119
+ "loss/total": total.item(),
120
+ "loss/sdr": sdr.item(),
121
+ "loss/sisdr": sisdr.item(),
122
+ "metrics/sdr": -sdr.item(), # Positive SDR for logging
123
+ "metrics/sisdr": -sisdr.item(), # Positive SI-SDR for logging
124
+ "metrics/new_sdr": pos_sdr.item(), # MDX-style SDR
125
+ }
126
+
127
+ return total, metrics
128
+
129
+
130
+ def combined_L1_sdr_loss(
131
+ estimated: torch.Tensor,
132
+ target: torch.Tensor,
133
+ sdr_weight: float = 1.0,
134
+ l1_weight: float = 0.05
135
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
136
+ """
137
+ Combined SDR and L1 loss.
138
+
139
+ Args:
140
+ estimated: Estimated audio (batch, channels, time)
141
+ target: Target audio (batch, channels, time)
142
+ sdr_weight: Weight for SDR loss (default 1.0)
143
+ l1_weight: Weight for L1 loss (default 0.05)
144
+ Returns:
145
+ total_loss: Combined loss for backpropagation
146
+ metrics: Dictionary of metrics for logging
147
+ """
148
+ sdr = sdr_loss(estimated, target)
149
+ sisdr = sisdr_loss(estimated, target)
150
+ l1 = torch.nn.functional.l1_loss(estimated, target)
151
+
152
+ total = sdr_weight * sdr + l1_weight * l1
153
+
154
+ metrics = {
155
+ "loss/total": total.item(),
156
+ "loss/sdr": sdr.item(),
157
+ "loss/sisdr": sisdr.item(),
158
+ "metrics/sdr": -sdr.item(), # Positive SDR for logging
159
+ "metrics/sisdr": -sisdr.item(), # Positive SI-SDR for logging
160
+ }
161
+
162
+ return total, metrics
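A quick sanity check of the losses above on random tensors (shapes are illustrative; assumes the repository root is on PYTHONPATH so src.loss is importable):

# Hypothetical sanity check of the loss functions defined above.
import torch
from src.loss import combined_loss, new_sdr_metric

target = torch.randn(4, 2, 44100)                    # (batch, channels, time)
estimate = target + 0.1 * torch.randn_like(target)   # slightly noisy "prediction"

loss, metrics = combined_loss(estimate, target)
print(f"loss={loss.item():.3f}, SDR={metrics['metrics/sdr']:.2f} dB")
print(new_sdr_metric(estimate, target))              # per-item MDX-style SDR, shape (4,)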
src/models/__init__.py ADDED
File without changes
src/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (156 Bytes)
src/models/stem_separation/ATHTDemucs_v2.py ADDED
@@ -0,0 +1,348 @@
1
+ """
2
+ AudioTextHTDemucs v2 - Text-conditioned source separation.
3
+
4
+ Changes from v1:
5
+ - Custom trainable decoder that outputs 1 source (not 4)
6
+ - HTDemucs encoder kept (frozen)
7
+ - CLAP text encoder (frozen)
8
+ - Cross-attention conditioning at bottleneck
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from typing import List, Any
15
+ from fractions import Fraction
16
+ from einops import rearrange
17
+
18
+ from demucs.htdemucs import HTDemucs
19
+ from transformers import ClapModel, ClapTextModelWithProjection, RobertaTokenizerFast
20
+
21
+ class TextCrossAttention(nn.Module):
22
+ """Cross-attention: audio features attend to text embeddings."""
23
+
24
+ def __init__(self, feat_dim, text_dim, n_heads=8, dropout=0.0):
25
+ super().__init__()
26
+ self.q_proj = nn.Linear(feat_dim, feat_dim)
27
+ self.k_proj = nn.Linear(text_dim, feat_dim)
28
+ self.v_proj = nn.Linear(text_dim, feat_dim)
29
+ self.attn = nn.MultiheadAttention(feat_dim, n_heads, batch_first=True, dropout=dropout)
30
+ self.out_mlp = nn.Sequential(
31
+ nn.Linear(feat_dim, feat_dim),
32
+ nn.GELU(),
33
+ nn.Linear(feat_dim, feat_dim),
34
+ )
35
+ self.norm_q = nn.LayerNorm(feat_dim)
36
+ self.norm_out = nn.LayerNorm(feat_dim)
37
+
38
+ def forward_attend(self, queries, text_emb):
39
+ q = self.norm_q(queries)
40
+ if text_emb.dim() == 2:
41
+ text_emb = text_emb.unsqueeze(1)
42
+ k = self.k_proj(text_emb)
43
+ v = self.v_proj(text_emb)
44
+ q_proj = self.q_proj(q)
45
+ attn_out, _ = self.attn(query=q_proj, key=k, value=v)
46
+ out = queries + attn_out
47
+ out = out + self.out_mlp(out)
48
+ return self.norm_out(out)
49
+
50
+ def forward(self, x, xt, text_emb):
51
+ B, C, Fq, T = x.shape  # use Fq so we don't shadow the torch.nn.functional alias F
52
+ x_seq = rearrange(x, "b c f t -> b (f t) c")
53
+ xt_seq = rearrange(xt, "b c t -> b t c")
54
+ x_seq = self.forward_attend(x_seq, text_emb)
55
+ xt_seq = self.forward_attend(xt_seq, text_emb)
56
+ x = rearrange(x_seq, "b (f t) c -> b c f t", f=Fq, t=T)
57
+ xt = rearrange(xt_seq, "b t c -> b c t")
58
+ return x, xt
59
+
60
+
61
+ class FreqDecoder(nn.Module):
62
+ """Frequency-domain decoder: mirrors HTDemucs encoder structure but outputs 1 source."""
63
+
64
+ def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4):
65
+ """
66
+ channels: List of channel dims from bottleneck to output, e.g. [384, 192, 96, 48, 2]
67
+ """
68
+ super().__init__()
69
+ self.layers = nn.ModuleList()
70
+
71
+ for i in range(len(channels) - 1):
72
+ in_ch = channels[i]
73
+ out_ch = channels[i + 1]
74
+ is_last = (i == len(channels) - 2)
75
+
76
+ self.layers.append(nn.Sequential(
77
+ nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(kernel_size, 1), stride=(stride, 1), padding=(kernel_size//4, 0)),
78
+ nn.GroupNorm(1, out_ch) if not is_last else nn.Identity(),
79
+ nn.GELU() if not is_last else nn.Identity(),
80
+ ))
81
+
82
+ def forward(self, x, skips: List[torch.Tensor], target_lengths: List[int]):
83
+ """
84
+ x: (B, C, F, T) bottleneck features
85
+ skips: encoder skip connections (reversed order)
86
+ target_lengths: target frequency dimensions for each layer
87
+ """
88
+ for i, layer in enumerate(self.layers):
89
+ x = layer(x)
90
+ # Match target size
91
+ if i < len(target_lengths):
92
+ target_f = target_lengths[i]
93
+ if x.shape[2] != target_f:
94
+ x = F.interpolate(x, size=(target_f, x.shape[3]), mode='bilinear', align_corners=False)
95
+ # Add skip connection if available
96
+ if i < len(skips):
97
+ skip = skips[i]
98
+ # Project skip to match channels if needed
99
+ if skip.shape[1] != x.shape[1]:
100
+ skip = skip[:, :x.shape[1]] # Simple channel truncation
101
+ if skip.shape[2:] != x.shape[2:]:
102
+ skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
103
+ x = x + skip * 0.1 # Scaled residual
104
+ return x
105
+
106
+
107
+ class TimeDecoder(nn.Module):
108
+ """Time-domain decoder: outputs 1 source waveform."""
109
+
110
+ def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4):
111
+ super().__init__()
112
+ self.layers = nn.ModuleList()
113
+
114
+ for i in range(len(channels) - 1):
115
+ in_ch = channels[i]
116
+ out_ch = channels[i + 1]
117
+ is_last = (i == len(channels) - 2)
118
+
119
+ self.layers.append(nn.Sequential(
120
+ nn.ConvTranspose1d(in_ch, out_ch, kernel_size, stride, padding=kernel_size//4),
121
+ nn.GroupNorm(1, out_ch) if not is_last else nn.Identity(),
122
+ nn.GELU() if not is_last else nn.Identity(),
123
+ ))
124
+
125
+ def forward(self, x, skips: List[torch.Tensor], target_lengths: List[int]):
126
+ for i, layer in enumerate(self.layers):
127
+ x = layer(x)
128
+ if i < len(target_lengths):
129
+ target_t = target_lengths[i]
130
+ if x.shape[2] != target_t:
131
+ x = F.interpolate(x, size=target_t, mode='linear', align_corners=False)
132
+ if i < len(skips):
133
+ skip = skips[i]
134
+ if skip.shape[1] != x.shape[1]:
135
+ skip = skip[:, :x.shape[1]]
136
+ if skip.shape[2] != x.shape[2]:
137
+ skip = F.interpolate(skip, size=x.shape[2], mode='linear', align_corners=False)
138
+ x = x + skip * 0.1
139
+ return x
140
+
141
+
142
+ class AudioTextHTDemucs(nn.Module):
143
+ """
144
+ Text-conditioned source separation.
145
+ - HTDemucs encoder (frozen): extracts multi-scale audio features
146
+ - CLAP (frozen): text embeddings
147
+ - Cross-attention: conditions audio on text at bottleneck
148
+ - Custom decoder (trainable): outputs single source
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ htdemucs_model: HTDemucs,
154
+ clap_encoder: ClapModel | ClapTextModelWithProjection,
155
+ clap_tokenizer: RobertaTokenizerFast,
156
+ model_dim: int = 384,
157
+ text_dim: int = 512,
158
+ num_heads: int = 8,
159
+ sample_rate: int = 44100,
160
+ segment: float = 7.8,
161
+ ):
162
+ super().__init__()
163
+
164
+ self.htdemucs = htdemucs_model
165
+ self.clap = clap_encoder
166
+ self.tokenizer = clap_tokenizer
167
+ self.sample_rate = sample_rate
168
+ self.segment = segment
169
+
170
+ # Freeze HTDemucs encoder
171
+ for param in self.htdemucs.parameters():
172
+ param.requires_grad = False
173
+
174
+ # Freeze CLAP
175
+ for param in self.clap.parameters():
176
+ param.requires_grad = False
177
+
178
+ # Text cross-attention at bottleneck
179
+ self.text_attn = TextCrossAttention(model_dim, text_dim, num_heads)
180
+
181
+ # Custom decoders (trainable) - output 1 source with 2 channels (stereo)
182
+ # Channel progression: 384 -> 192 -> 96 -> 48 -> 4 (will be reshaped to 2 channels)
183
+ self.freq_decoder = FreqDecoder([384, 192, 96, 48, 4])
184
+ self.time_decoder = TimeDecoder([384, 192, 96, 48, 4])
185
+
186
+ # Final projection to stereo
187
+ self.freq_out = nn.Conv2d(4, 2, 1)
188
+ self.time_out = nn.Conv1d(4, 2, 1)
189
+
190
+ def _encode(self, x, xt):
191
+ """Run HTDemucs encoder, save skip connections."""
192
+ saved = []
193
+ saved_t = []
194
+ lengths = []
195
+ lengths_t = []
196
+
197
+ for idx, encode in enumerate(self.htdemucs.encoder):
198
+ lengths.append(x.shape[-1])
199
+ inject = None
200
+
201
+ if idx < len(self.htdemucs.tencoder):
202
+ lengths_t.append(xt.shape[-1])
203
+ tenc = self.htdemucs.tencoder[idx]
204
+ xt = tenc(xt)
205
+ if not tenc.empty:
206
+ saved_t.append(xt)
207
+ else:
208
+ inject = xt
209
+
210
+ x = encode(x, inject)
211
+
212
+ if idx == 0 and self.htdemucs.freq_emb is not None:
213
+ frs = torch.arange(x.shape[-2], device=x.device)
214
+ emb = self.htdemucs.freq_emb(frs).t()[None, :, :, None].expand_as(x)
215
+ x = x + self.htdemucs.freq_emb_scale * emb
216
+
217
+ saved.append(x)
218
+
219
+ # Cross-transformer at bottleneck
220
+ if self.htdemucs.crosstransformer:
221
+ if self.htdemucs.bottom_channels:
222
+ b, c, f, t = x.shape
223
+ x = rearrange(x, "b c f t -> b c (f t)")
224
+ x = self.htdemucs.channel_upsampler(x)
225
+ x = rearrange(x, "b c (f t) -> b c f t", f=f)
226
+ xt = self.htdemucs.channel_upsampler_t(xt)
227
+
228
+ x, xt = self.htdemucs.crosstransformer(x, xt)
229
+
230
+ if self.htdemucs.bottom_channels:
231
+ x = rearrange(x, "b c f t -> b c (f t)")
232
+ x = self.htdemucs.channel_downsampler(x)
233
+ x = rearrange(x, "b c (f t) -> b c f t", f=f)
234
+ xt = self.htdemucs.channel_downsampler_t(xt)
235
+
236
+ return x, xt, saved, saved_t, lengths, lengths_t
237
+
238
+ def _get_clap_embeddings(self, text: List[str], device):
239
+ inputs = self.tokenizer(text, padding=True, return_tensors="pt")
240
+ inputs = {k: v.to(device) for k, v in inputs.items()}
241
+ if isinstance(self.clap, ClapModel):
242
+ # Use get_text_features for ClapModel
243
+ with torch.no_grad():
244
+ return self.clap.get_text_features(**inputs)
245
+ else:
246
+ # Use forward pass for ClapTextModelWithProjection
247
+ with torch.no_grad():
248
+ return self.clap.forward(**inputs).text_embeds
249
+
250
+ def forward(self, wav, text):
251
+ """
252
+ wav: (B, 2, T) stereo mixture
253
+ text: List[str] prompts
254
+ Returns: (B, 2, T) separated stereo source
255
+ """
256
+ device = wav.device
257
+ B = wav.shape[0]
258
+ original_length = wav.shape[-1]
259
+
260
+ # Compute spectrogram (ensure all on same device)
261
+ z = self.htdemucs._spec(wav).to(device)
262
+ mag = self.htdemucs._magnitude(z).to(device)
263
+ x = mag
264
+
265
+ B, C, Fq, T_spec = x.shape
266
+
267
+ # Normalize
268
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
269
+ std = x.std(dim=(1, 2, 3), keepdim=True)
270
+ x = (x - mean) / (1e-5 + std)
271
+
272
+ xt = wav
273
+ meant = xt.mean(dim=(1, 2), keepdim=True)
274
+ stdt = xt.std(dim=(1, 2), keepdim=True)
275
+ xt = (xt - meant) / (1e-5 + stdt)
276
+
277
+ # Encode (frozen)
278
+ with torch.no_grad():
279
+ x_enc, xt_enc, saved, saved_t, lengths, lengths_t = self._encode(x, xt)
280
+
281
+ # Text conditioning via cross-attention (trainable)
282
+ text_emb = self._get_clap_embeddings(text, device)
283
+ x_cond, xt_cond = self.text_attn(x_enc, xt_enc, text_emb)
284
+
285
+ # Decode with custom decoder (trainable)
286
+ # Reverse skips for decoder
287
+ saved_rev = saved[::-1]
288
+ saved_t_rev = saved_t[::-1]
289
+ lengths_rev = lengths[::-1]
290
+ lengths_t_rev = lengths_t[::-1]
291
+
292
+ # Frequency decoder
293
+ x_dec = self.freq_decoder(x_cond, saved_rev, lengths_rev)
294
+ x_dec = self.freq_out(x_dec) # (B, 2, F, T)
295
+
296
+ # Interpolate to match original spectrogram size
297
+ x_dec = F.interpolate(x_dec, size=(Fq, T_spec), mode='bilinear', align_corners=False)
298
+
299
+ # Apply as mask and invert spectrogram
300
+ mask = torch.sigmoid(x_dec) # (B, 2, F, T) in [0, 1]
301
+
302
+ # mag is (B, C, F, T) from htdemucs - take first 2 channels for stereo
303
+ mag_stereo = mag[:, :2, :, :] # (B, 2, F, T)
304
+ masked_spec = mag_stereo * mask
305
+
306
+ # z is complex (B, C, F, T) - take stereo channels
307
+ z_stereo = z[:, :2, :, :] # (B, 2, F, T)
308
+ phase = z_stereo / (mag_stereo + 1e-8) # Complex phase
309
+ masked_z = masked_spec * phase # Apply mask while preserving phase
310
+ freq_wav = self.htdemucs._ispec(masked_z, original_length).to(device)
311
+
312
+ # Time decoder
313
+ xt_dec = self.time_decoder(xt_cond, saved_t_rev, lengths_t_rev)
314
+ xt_dec = self.time_out(xt_dec) # (B, 2, T)
315
+
316
+ # Interpolate to original length
317
+ if xt_dec.shape[-1] != original_length:
318
+ xt_dec = F.interpolate(xt_dec, size=original_length, mode='linear', align_corners=False)
319
+
320
+ # Denormalize time output
321
+ xt_dec = xt_dec * stdt + meant
322
+
323
+ # Combine frequency and time branches
324
+ output = freq_wav + xt_dec
325
+
326
+ return output
327
+
328
+
329
+ if __name__ == "__main__":
330
+ from demucs import pretrained
331
+
332
+ htdemucs = pretrained.get_model('htdemucs').models[0]
333
+ clap = ClapModel.from_pretrained("laion/clap-htsat-unfused")
334
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
335
+
336
+ model = AudioTextHTDemucs(htdemucs, clap, tokenizer)
337
+
338
+ # Count params
339
+ total = sum(p.numel() for p in model.parameters())
340
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
341
+ print(f"Total params: {total:,}")
342
+ print(f"Trainable params: {trainable:,}")
343
+
344
+ # Test forward
345
+ wav = torch.randn(2, 2, 44100 * 3)
346
+ prompts = ["drums", "bass"]
347
+ out = model(wav, prompts)
348
+ print(f"Input: {wav.shape} -> Output: {out.shape}")
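A small standalone shape check for the TextCrossAttention block defined above, using random tensors shaped like an HTDemucs bottleneck (the exact dimensions are assumptions for illustration):

# Hypothetical shape check for TextCrossAttention (defined above).
import torch
from src.models.stem_separation.ATHTDemucs_v2 import TextCrossAttention

attn = TextCrossAttention(feat_dim=384, text_dim=512, n_heads=8)
x = torch.randn(2, 384, 8, 336)   # (B, C, F, T) spectral bottleneck features
xt = torch.randn(2, 384, 336)     # (B, C, T) time-branch bottleneck features
text_emb = torch.randn(2, 512)    # CLAP-style text embeddings, one per batch item

x_out, xt_out = attn(x, xt, text_emb)
print(x_out.shape, xt_out.shape)  # shapes are preserved by the conditioning block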
src/models/stem_separation/AudioTextDemucsV2.txt ADDED
@@ -0,0 +1,237 @@
1
+ Model:
2
+ TextConditionedSeparator(
3
+ (clap): ClapModel(
4
+ (text_model): ClapTextModel(
5
+ (embeddings): ClapTextEmbeddings(
6
+ (word_embeddings): Embedding(50265, 768, padding_idx=1)
7
+ (position_embeddings): Embedding(514, 768, padding_idx=1)
8
+ (token_type_embeddings): Embedding(1, 768)
9
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
10
+ (dropout): Dropout(p=0.1, inplace=False)
11
+ )
12
+ (encoder): ClapTextEncoder(
13
+ (layer): ModuleList(
14
+ (0-11): 12 x ClapTextLayer(
15
+ (attention): ClapTextAttention(
16
+ (self): ClapTextSelfAttention(
17
+ (query): Linear(in_features=768, out_features=768, bias=True)
18
+ (key): Linear(in_features=768, out_features=768, bias=True)
19
+ (value): Linear(in_features=768, out_features=768, bias=True)
20
+ (dropout): Dropout(p=0.1, inplace=False)
21
+ )
22
+ (output): ClapTextSelfOutput(
23
+ (dense): Linear(in_features=768, out_features=768, bias=True)
24
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
25
+ (dropout): Dropout(p=0.1, inplace=False)
26
+ )
27
+ )
28
+ (intermediate): ClapTextIntermediate(
29
+ (dense): Linear(in_features=768, out_features=3072, bias=True)
30
+ (intermediate_act_fn): GELUActivation()
31
+ )
32
+ (output): ClapTextOutput(
33
+ (dense): Linear(in_features=3072, out_features=768, bias=True)
34
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
35
+ (dropout): Dropout(p=0.1, inplace=False)
36
+ )
37
+ )
38
+ )
39
+ )
40
+ (pooler): ClapTextPooler(
41
+ (dense): Linear(in_features=768, out_features=768, bias=True)
42
+ (activation): Tanh()
43
+ )
44
+ )
45
+ (text_projection): ClapProjectionLayer(
46
+ (linear1): Linear(in_features=768, out_features=512, bias=True)
47
+ (activation): ReLU()
48
+ (linear2): Linear(in_features=512, out_features=512, bias=True)
49
+ )
50
+ (audio_model): ClapAudioModel(
51
+ (audio_encoder): ClapAudioEncoder(
52
+ (patch_embed): ClapAudioPatchEmbed(
53
+ (proj): Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))
54
+ (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
55
+ )
56
+ (layers): ModuleList(
57
+ (0): ClapAudioStage(
58
+ (blocks): ModuleList(
59
+ (0-1): 2 x ClapAudioLayer(
60
+ (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
61
+ (attention): ClapAudioAttention(
62
+ (self): ClapAudioSelfAttention(
63
+ (query): Linear(in_features=96, out_features=96, bias=True)
64
+ (key): Linear(in_features=96, out_features=96, bias=True)
65
+ (value): Linear(in_features=96, out_features=96, bias=True)
66
+ (dropout): Dropout(p=0.0, inplace=False)
67
+ )
68
+ (output): ClapAudioSelfOutput(
69
+ (dense): Linear(in_features=96, out_features=96, bias=True)
70
+ (dropout): Dropout(p=0.0, inplace=False)
71
+ )
72
+ )
73
+ (drop_path): Identity()
74
+ (layernorm_after): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
75
+ (intermediate): ClapAudioIntermediate(
76
+ (dense): Linear(in_features=96, out_features=384, bias=True)
77
+ (intermediate_act_fn): GELUActivation()
78
+ )
79
+ (output): ClapAudioOutput(
80
+ (dense): Linear(in_features=384, out_features=96, bias=True)
81
+ (dropout): Dropout(p=0.1, inplace=False)
82
+ )
83
+ )
84
+ )
85
+ (downsample): ClapAudioPatchMerging(
86
+ (reduction): Linear(in_features=384, out_features=192, bias=False)
87
+ (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
88
+ )
89
+ )
90
+ (1): ClapAudioStage(
91
+ (blocks): ModuleList(
92
+ (0-1): 2 x ClapAudioLayer(
93
+ (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
94
+ (attention): ClapAudioAttention(
95
+ (self): ClapAudioSelfAttention(
96
+ (query): Linear(in_features=192, out_features=192, bias=True)
97
+ (key): Linear(in_features=192, out_features=192, bias=True)
98
+ (value): Linear(in_features=192, out_features=192, bias=True)
99
+ (dropout): Dropout(p=0.0, inplace=False)
100
+ )
101
+ (output): ClapAudioSelfOutput(
102
+ (dense): Linear(in_features=192, out_features=192, bias=True)
103
+ (dropout): Dropout(p=0.0, inplace=False)
104
+ )
105
+ )
106
+ (drop_path): Identity()
107
+ (layernorm_after): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
108
+ (intermediate): ClapAudioIntermediate(
109
+ (dense): Linear(in_features=192, out_features=768, bias=True)
110
+ (intermediate_act_fn): GELUActivation()
111
+ )
112
+ (output): ClapAudioOutput(
113
+ (dense): Linear(in_features=768, out_features=192, bias=True)
114
+ (dropout): Dropout(p=0.1, inplace=False)
115
+ )
116
+ )
117
+ )
118
+ (downsample): ClapAudioPatchMerging(
119
+ (reduction): Linear(in_features=768, out_features=384, bias=False)
120
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
121
+ )
122
+ )
123
+ (2): ClapAudioStage(
124
+ (blocks): ModuleList(
125
+ (0-5): 6 x ClapAudioLayer(
126
+ (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
127
+ (attention): ClapAudioAttention(
128
+ (self): ClapAudioSelfAttention(
129
+ (query): Linear(in_features=384, out_features=384, bias=True)
130
+ (key): Linear(in_features=384, out_features=384, bias=True)
131
+ (value): Linear(in_features=384, out_features=384, bias=True)
132
+ (dropout): Dropout(p=0.0, inplace=False)
133
+ )
134
+ (output): ClapAudioSelfOutput(
135
+ (dense): Linear(in_features=384, out_features=384, bias=True)
136
+ (dropout): Dropout(p=0.0, inplace=False)
137
+ )
138
+ )
139
+ (drop_path): Identity()
140
+ (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
141
+ (intermediate): ClapAudioIntermediate(
142
+ (dense): Linear(in_features=384, out_features=1536, bias=True)
143
+ (intermediate_act_fn): GELUActivation()
144
+ )
145
+ (output): ClapAudioOutput(
146
+ (dense): Linear(in_features=1536, out_features=384, bias=True)
147
+ (dropout): Dropout(p=0.1, inplace=False)
148
+ )
149
+ )
150
+ )
151
+ (downsample): ClapAudioPatchMerging(
152
+ (reduction): Linear(in_features=1536, out_features=768, bias=False)
153
+ (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
154
+ )
155
+ )
156
+ (3): ClapAudioStage(
157
+ (blocks): ModuleList(
158
+ (0-1): 2 x ClapAudioLayer(
159
+ (layernorm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
160
+ (attention): ClapAudioAttention(
161
+ (self): ClapAudioSelfAttention(
162
+ (query): Linear(in_features=768, out_features=768, bias=True)
163
+ (key): Linear(in_features=768, out_features=768, bias=True)
164
+ (value): Linear(in_features=768, out_features=768, bias=True)
165
+ (dropout): Dropout(p=0.0, inplace=False)
166
+ )
167
+ (output): ClapAudioSelfOutput(
168
+ (dense): Linear(in_features=768, out_features=768, bias=True)
169
+ (dropout): Dropout(p=0.0, inplace=False)
170
+ )
171
+ )
172
+ (drop_path): Identity()
173
+ (layernorm_after): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
174
+ (intermediate): ClapAudioIntermediate(
175
+ (dense): Linear(in_features=768, out_features=3072, bias=True)
176
+ (intermediate_act_fn): GELUActivation()
177
+ )
178
+ (output): ClapAudioOutput(
179
+ (dense): Linear(in_features=3072, out_features=768, bias=True)
180
+ (dropout): Dropout(p=0.1, inplace=False)
181
+ )
182
+ )
183
+ )
184
+ )
185
+ )
186
+ (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
187
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
188
+ (avgpool): AdaptiveAvgPool1d(output_size=1)
189
+ )
190
+ )
191
+ (audio_projection): ClapProjectionLayer(
192
+ (linear1): Linear(in_features=768, out_features=512, bias=True)
193
+ (activation): ReLU()
194
+ (linear2): Linear(in_features=512, out_features=512, bias=True)
195
+ )
196
+ )
197
+ (z_encoder): PatchConv1d(
198
+ (conv): Conv1d(1, 256, kernel_size=(16,), stride=(8,))
199
+ )
200
+ (text_proj): Linear(in_features=512, out_features=256, bias=True)
201
+ (z_proj): Linear(in_features=256, out_features=256, bias=True)
202
+ (cross): CrossAttention(
203
+ (attn): MultiheadAttention(
204
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
205
+ )
206
+ (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
207
+ (ff): MLP(
208
+ (fc1): Linear(in_features=256, out_features=1024, bias=True)
209
+ (fc2): Linear(in_features=1024, out_features=256, bias=True)
210
+ (act): GELU(approximate='none')
211
+ )
212
+ (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
213
+ )
214
+ (transformer): TransformerEncoder(
215
+ (layers): ModuleList(
216
+ (0-5): 6 x TransformerEncoderLayer(
217
+ (self_attn): MultiheadAttention(
218
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
219
+ )
220
+ (linear1): Linear(in_features=256, out_features=1024, bias=True)
221
+ (dropout): Dropout(p=0.1, inplace=False)
222
+ (linear2): Linear(in_features=1024, out_features=256, bias=True)
223
+ (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
224
+ (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
225
+ (dropout1): Dropout(p=0.1, inplace=False)
226
+ (dropout2): Dropout(p=0.1, inplace=False)
227
+ )
228
+ )
229
+ )
230
+ (spec_decoder): Sequential(
231
+ (0): Linear(in_features=256, out_features=256, bias=True)
232
+ (1): GELU(approximate='none')
233
+ (2): Linear(in_features=256, out_features=2049, bias=True)
234
+ )
235
+ )
236
+ output waveform shape: torch.Size([2, 1, 48000])
237
+ output spectrogram shape: torch.Size([2, 12001, 2049])
src/models/stem_separation/AudioTextHTDemucs_Full.txt ADDED
@@ -0,0 +1,889 @@
1
+ Model Summary:
2
+ AudioTextHTDemucs(
3
+ (htdemucs): HTDemucs(
4
+ (encoder): ModuleList(
5
+ (0): HEncLayer(
6
+ (conv): Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
7
+ (norm1): Identity()
8
+ (rewrite): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))
9
+ (norm2): Identity()
10
+ (dconv): DConv(
11
+ (layers): ModuleList(
12
+ (0): Sequential(
13
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
14
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
15
+ (2): GELU(approximate='none')
16
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
17
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
18
+ (5): GLU(dim=1)
19
+ (6): LayerScale()
20
+ )
21
+ (1): Sequential(
22
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
23
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
24
+ (2): GELU(approximate='none')
25
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
26
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
27
+ (5): GLU(dim=1)
28
+ (6): LayerScale()
29
+ )
30
+ )
31
+ )
32
+ )
33
+ (1): HEncLayer(
34
+ (conv): Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
35
+ (norm1): Identity()
36
+ (rewrite): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))
37
+ (norm2): Identity()
38
+ (dconv): DConv(
39
+ (layers): ModuleList(
40
+ (0): Sequential(
41
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
42
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
43
+ (2): GELU(approximate='none')
44
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
45
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
46
+ (5): GLU(dim=1)
47
+ (6): LayerScale()
48
+ )
49
+ (1): Sequential(
50
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
51
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
52
+ (2): GELU(approximate='none')
53
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
54
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
55
+ (5): GLU(dim=1)
56
+ (6): LayerScale()
57
+ )
58
+ )
59
+ )
60
+ )
61
+ (2): HEncLayer(
62
+ (conv): Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
63
+ (norm1): Identity()
64
+ (rewrite): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))
65
+ (norm2): Identity()
66
+ (dconv): DConv(
67
+ (layers): ModuleList(
68
+ (0): Sequential(
69
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
70
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
71
+ (2): GELU(approximate='none')
72
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
73
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
74
+ (5): GLU(dim=1)
75
+ (6): LayerScale()
76
+ )
77
+ (1): Sequential(
78
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
79
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
80
+ (2): GELU(approximate='none')
81
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
82
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
83
+ (5): GLU(dim=1)
84
+ (6): LayerScale()
85
+ )
86
+ )
87
+ )
88
+ )
89
+ (3): HEncLayer(
90
+ (conv): Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
91
+ (norm1): Identity()
92
+ (rewrite): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
93
+ (norm2): Identity()
94
+ (dconv): DConv(
95
+ (layers): ModuleList(
96
+ (0): Sequential(
97
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
98
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
99
+ (2): GELU(approximate='none')
100
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
101
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
102
+ (5): GLU(dim=1)
103
+ (6): LayerScale()
104
+ )
105
+ (1): Sequential(
106
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
107
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
108
+ (2): GELU(approximate='none')
109
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
110
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
111
+ (5): GLU(dim=1)
112
+ (6): LayerScale()
113
+ )
114
+ )
115
+ )
116
+ )
117
+ )
118
+ (decoder): ModuleList(
119
+ (0): HDecLayer(
120
+ (conv_tr): ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1))
121
+ (norm2): Identity()
122
+ (rewrite): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
123
+ (norm1): Identity()
124
+ (dconv): DConv(
125
+ (layers): ModuleList(
126
+ (0): Sequential(
127
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
128
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
129
+ (2): GELU(approximate='none')
130
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
131
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
132
+ (5): GLU(dim=1)
133
+ (6): LayerScale()
134
+ )
135
+ (1): Sequential(
136
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
137
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
138
+ (2): GELU(approximate='none')
139
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
140
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
141
+ (5): GLU(dim=1)
142
+ (6): LayerScale()
143
+ )
144
+ )
145
+ )
146
+ )
147
+ (1): HDecLayer(
148
+ (conv_tr): ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1))
149
+ (norm2): Identity()
150
+ (rewrite): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
151
+ (norm1): Identity()
152
+ (dconv): DConv(
153
+ (layers): ModuleList(
154
+ (0): Sequential(
155
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
156
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
157
+ (2): GELU(approximate='none')
158
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
159
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
160
+ (5): GLU(dim=1)
161
+ (6): LayerScale()
162
+ )
163
+ (1): Sequential(
164
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
165
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
166
+ (2): GELU(approximate='none')
167
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
168
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
169
+ (5): GLU(dim=1)
170
+ (6): LayerScale()
171
+ )
172
+ )
173
+ )
174
+ )
175
+ (2): HDecLayer(
176
+ (conv_tr): ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1))
177
+ (norm2): Identity()
178
+ (rewrite): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
179
+ (norm1): Identity()
180
+ (dconv): DConv(
181
+ (layers): ModuleList(
182
+ (0): Sequential(
183
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
184
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
185
+ (2): GELU(approximate='none')
186
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
187
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
188
+ (5): GLU(dim=1)
189
+ (6): LayerScale()
190
+ )
191
+ (1): Sequential(
192
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
193
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
194
+ (2): GELU(approximate='none')
195
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
196
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
197
+ (5): GLU(dim=1)
198
+ (6): LayerScale()
199
+ )
200
+ )
201
+ )
202
+ )
203
+ (3): HDecLayer(
204
+ (conv_tr): ConvTranspose2d(48, 16, kernel_size=(8, 1), stride=(4, 1))
205
+ (norm2): Identity()
206
+ (rewrite): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
207
+ (norm1): Identity()
208
+ (dconv): DConv(
209
+ (layers): ModuleList(
210
+ (0): Sequential(
211
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
212
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
213
+ (2): GELU(approximate='none')
214
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
215
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
216
+ (5): GLU(dim=1)
217
+ (6): LayerScale()
218
+ )
219
+ (1): Sequential(
220
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
221
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
222
+ (2): GELU(approximate='none')
223
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
224
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
225
+ (5): GLU(dim=1)
226
+ (6): LayerScale()
227
+ )
228
+ )
229
+ )
230
+ )
231
+ )
232
+ (tencoder): ModuleList(
233
+ (0): HEncLayer(
234
+ (conv): Conv1d(2, 48, kernel_size=(8,), stride=(4,), padding=(2,))
235
+ (norm1): Identity()
236
+ (rewrite): Conv1d(48, 96, kernel_size=(1,), stride=(1,))
237
+ (norm2): Identity()
238
+ (dconv): DConv(
239
+ (layers): ModuleList(
240
+ (0): Sequential(
241
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
242
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
243
+ (2): GELU(approximate='none')
244
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
245
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
246
+ (5): GLU(dim=1)
247
+ (6): LayerScale()
248
+ )
249
+ (1): Sequential(
250
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
251
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
252
+ (2): GELU(approximate='none')
253
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
254
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
255
+ (5): GLU(dim=1)
256
+ (6): LayerScale()
257
+ )
258
+ )
259
+ )
260
+ )
261
+ (1): HEncLayer(
262
+ (conv): Conv1d(48, 96, kernel_size=(8,), stride=(4,), padding=(2,))
263
+ (norm1): Identity()
264
+ (rewrite): Conv1d(96, 192, kernel_size=(1,), stride=(1,))
265
+ (norm2): Identity()
266
+ (dconv): DConv(
267
+ (layers): ModuleList(
268
+ (0): Sequential(
269
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
270
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
271
+ (2): GELU(approximate='none')
272
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
273
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
274
+ (5): GLU(dim=1)
275
+ (6): LayerScale()
276
+ )
277
+ (1): Sequential(
278
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
279
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
280
+ (2): GELU(approximate='none')
281
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
282
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
283
+ (5): GLU(dim=1)
284
+ (6): LayerScale()
285
+ )
286
+ )
287
+ )
288
+ )
289
+ (2): HEncLayer(
290
+ (conv): Conv1d(96, 192, kernel_size=(8,), stride=(4,), padding=(2,))
291
+ (norm1): Identity()
292
+ (rewrite): Conv1d(192, 384, kernel_size=(1,), stride=(1,))
293
+ (norm2): Identity()
294
+ (dconv): DConv(
295
+ (layers): ModuleList(
296
+ (0): Sequential(
297
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
298
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
299
+ (2): GELU(approximate='none')
300
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
301
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
302
+ (5): GLU(dim=1)
303
+ (6): LayerScale()
304
+ )
305
+ (1): Sequential(
306
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
307
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
308
+ (2): GELU(approximate='none')
309
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
310
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
311
+ (5): GLU(dim=1)
312
+ (6): LayerScale()
313
+ )
314
+ )
315
+ )
316
+ )
317
+ (3): HEncLayer(
318
+ (conv): Conv1d(192, 384, kernel_size=(8,), stride=(4,), padding=(2,))
319
+ (norm1): Identity()
320
+ (rewrite): Conv1d(384, 768, kernel_size=(1,), stride=(1,))
321
+ (norm2): Identity()
322
+ (dconv): DConv(
323
+ (layers): ModuleList(
324
+ (0): Sequential(
325
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
326
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
327
+ (2): GELU(approximate='none')
328
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
329
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
330
+ (5): GLU(dim=1)
331
+ (6): LayerScale()
332
+ )
333
+ (1): Sequential(
334
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
335
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
336
+ (2): GELU(approximate='none')
337
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
338
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
339
+ (5): GLU(dim=1)
340
+ (6): LayerScale()
341
+ )
342
+ )
343
+ )
344
+ )
345
+ )
346
+ (tdecoder): ModuleList(
347
+ (0): HDecLayer(
348
+ (conv_tr): ConvTranspose1d(384, 192, kernel_size=(8,), stride=(4,))
349
+ (norm2): Identity()
350
+ (rewrite): Conv1d(384, 768, kernel_size=(3,), stride=(1,), padding=(1,))
351
+ (norm1): Identity()
352
+ (dconv): DConv(
353
+ (layers): ModuleList(
354
+ (0): Sequential(
355
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
356
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
357
+ (2): GELU(approximate='none')
358
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
359
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
360
+ (5): GLU(dim=1)
361
+ (6): LayerScale()
362
+ )
363
+ (1): Sequential(
364
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
365
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
366
+ (2): GELU(approximate='none')
367
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
368
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
369
+ (5): GLU(dim=1)
370
+ (6): LayerScale()
371
+ )
372
+ )
373
+ )
374
+ )
375
+ (1): HDecLayer(
376
+ (conv_tr): ConvTranspose1d(192, 96, kernel_size=(8,), stride=(4,))
377
+ (norm2): Identity()
378
+ (rewrite): Conv1d(192, 384, kernel_size=(3,), stride=(1,), padding=(1,))
379
+ (norm1): Identity()
380
+ (dconv): DConv(
381
+ (layers): ModuleList(
382
+ (0): Sequential(
383
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
384
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
385
+ (2): GELU(approximate='none')
386
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
387
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
388
+ (5): GLU(dim=1)
389
+ (6): LayerScale()
390
+ )
391
+ (1): Sequential(
392
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
393
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
394
+ (2): GELU(approximate='none')
395
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
396
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
397
+ (5): GLU(dim=1)
398
+ (6): LayerScale()
399
+ )
400
+ )
401
+ )
402
+ )
403
+ (2): HDecLayer(
404
+ (conv_tr): ConvTranspose1d(96, 48, kernel_size=(8,), stride=(4,))
405
+ (norm2): Identity()
406
+ (rewrite): Conv1d(96, 192, kernel_size=(3,), stride=(1,), padding=(1,))
407
+ (norm1): Identity()
408
+ (dconv): DConv(
409
+ (layers): ModuleList(
410
+ (0): Sequential(
411
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
412
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
413
+ (2): GELU(approximate='none')
414
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
415
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
416
+ (5): GLU(dim=1)
417
+ (6): LayerScale()
418
+ )
419
+ (1): Sequential(
420
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
421
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
422
+ (2): GELU(approximate='none')
423
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
424
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
425
+ (5): GLU(dim=1)
426
+ (6): LayerScale()
427
+ )
428
+ )
429
+ )
430
+ )
431
+ (3): HDecLayer(
432
+ (conv_tr): ConvTranspose1d(48, 8, kernel_size=(8,), stride=(4,))
433
+ (norm2): Identity()
434
+ (rewrite): Conv1d(48, 96, kernel_size=(3,), stride=(1,), padding=(1,))
435
+ (norm1): Identity()
436
+ (dconv): DConv(
437
+ (layers): ModuleList(
438
+ (0): Sequential(
439
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
440
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
441
+ (2): GELU(approximate='none')
442
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
443
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
444
+ (5): GLU(dim=1)
445
+ (6): LayerScale()
446
+ )
447
+ (1): Sequential(
448
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
449
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
450
+ (2): GELU(approximate='none')
451
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
452
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
453
+ (5): GLU(dim=1)
454
+ (6): LayerScale()
455
+ )
456
+ )
457
+ )
458
+ )
459
+ )
460
+ (freq_emb): ScaledEmbedding(
461
+ (embedding): Embedding(512, 48)
462
+ )
463
+ (channel_upsampler): Conv1d(384, 512, kernel_size=(1,), stride=(1,))
464
+ (channel_downsampler): Conv1d(512, 384, kernel_size=(1,), stride=(1,))
465
+ (channel_upsampler_t): Conv1d(384, 512, kernel_size=(1,), stride=(1,))
466
+ (channel_downsampler_t): Conv1d(512, 384, kernel_size=(1,), stride=(1,))
467
+ (crosstransformer): CrossTransformerEncoder(
468
+ (norm_in): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
469
+ (norm_in_t): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
470
+ (layers): ModuleList(
471
+ (0): MyTransformerEncoderLayer(
472
+ (self_attn): MultiheadAttention(
473
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
474
+ )
475
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
476
+ (dropout): Dropout(p=0.02, inplace=False)
477
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
478
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
479
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
480
+ (dropout1): Dropout(p=0.02, inplace=False)
481
+ (dropout2): Dropout(p=0.02, inplace=False)
482
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
483
+ (gamma_1): LayerScale()
484
+ (gamma_2): LayerScale()
485
+ )
486
+ (1): CrossTransformerEncoderLayer(
487
+ (cross_attn): MultiheadAttention(
488
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
489
+ )
490
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
491
+ (dropout): Dropout(p=0.02, inplace=False)
492
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
493
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
494
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
495
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
496
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
497
+ (gamma_1): LayerScale()
498
+ (gamma_2): LayerScale()
499
+ (dropout1): Dropout(p=0.02, inplace=False)
500
+ (dropout2): Dropout(p=0.02, inplace=False)
501
+ )
502
+ (2): MyTransformerEncoderLayer(
503
+ (self_attn): MultiheadAttention(
504
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
505
+ )
506
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
507
+ (dropout): Dropout(p=0.02, inplace=False)
508
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
509
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
510
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
511
+ (dropout1): Dropout(p=0.02, inplace=False)
512
+ (dropout2): Dropout(p=0.02, inplace=False)
513
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
514
+ (gamma_1): LayerScale()
515
+ (gamma_2): LayerScale()
516
+ )
517
+ (3): CrossTransformerEncoderLayer(
518
+ (cross_attn): MultiheadAttention(
519
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
520
+ )
521
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
522
+ (dropout): Dropout(p=0.02, inplace=False)
523
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
524
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
525
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
526
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
527
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
528
+ (gamma_1): LayerScale()
529
+ (gamma_2): LayerScale()
530
+ (dropout1): Dropout(p=0.02, inplace=False)
531
+ (dropout2): Dropout(p=0.02, inplace=False)
532
+ )
533
+ (4): MyTransformerEncoderLayer(
534
+ (self_attn): MultiheadAttention(
535
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
536
+ )
537
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
538
+ (dropout): Dropout(p=0.02, inplace=False)
539
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
540
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
541
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
542
+ (dropout1): Dropout(p=0.02, inplace=False)
543
+ (dropout2): Dropout(p=0.02, inplace=False)
544
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
545
+ (gamma_1): LayerScale()
546
+ (gamma_2): LayerScale()
547
+ )
548
+ )
549
+ (layers_t): ModuleList(
550
+ (0): MyTransformerEncoderLayer(
551
+ (self_attn): MultiheadAttention(
552
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
553
+ )
554
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
555
+ (dropout): Dropout(p=0.02, inplace=False)
556
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
557
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
558
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
559
+ (dropout1): Dropout(p=0.02, inplace=False)
560
+ (dropout2): Dropout(p=0.02, inplace=False)
561
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
562
+ (gamma_1): LayerScale()
563
+ (gamma_2): LayerScale()
564
+ )
565
+ (1): CrossTransformerEncoderLayer(
566
+ (cross_attn): MultiheadAttention(
567
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
568
+ )
569
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
570
+ (dropout): Dropout(p=0.02, inplace=False)
571
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
572
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
573
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
574
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
575
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
576
+ (gamma_1): LayerScale()
577
+ (gamma_2): LayerScale()
578
+ (dropout1): Dropout(p=0.02, inplace=False)
579
+ (dropout2): Dropout(p=0.02, inplace=False)
580
+ )
581
+ (2): MyTransformerEncoderLayer(
582
+ (self_attn): MultiheadAttention(
583
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
584
+ )
585
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
586
+ (dropout): Dropout(p=0.02, inplace=False)
587
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
588
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
589
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
590
+ (dropout1): Dropout(p=0.02, inplace=False)
591
+ (dropout2): Dropout(p=0.02, inplace=False)
592
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
593
+ (gamma_1): LayerScale()
594
+ (gamma_2): LayerScale()
595
+ )
596
+ (3): CrossTransformerEncoderLayer(
597
+ (cross_attn): MultiheadAttention(
598
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
599
+ )
600
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
601
+ (dropout): Dropout(p=0.02, inplace=False)
602
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
603
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
604
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
605
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
606
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
607
+ (gamma_1): LayerScale()
608
+ (gamma_2): LayerScale()
609
+ (dropout1): Dropout(p=0.02, inplace=False)
610
+ (dropout2): Dropout(p=0.02, inplace=False)
611
+ )
612
+ (4): MyTransformerEncoderLayer(
613
+ (self_attn): MultiheadAttention(
614
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
615
+ )
616
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
617
+ (dropout): Dropout(p=0.02, inplace=False)
618
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
619
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
620
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
621
+ (dropout1): Dropout(p=0.02, inplace=False)
622
+ (dropout2): Dropout(p=0.02, inplace=False)
623
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
624
+ (gamma_1): LayerScale()
625
+ (gamma_2): LayerScale()
626
+ )
627
+ )
628
+ )
629
+ )
630
+ (clap): ClapModel(
631
+ (text_model): ClapTextModel(
632
+ (embeddings): ClapTextEmbeddings(
633
+ (word_embeddings): Embedding(50265, 768, padding_idx=1)
634
+ (position_embeddings): Embedding(514, 768, padding_idx=1)
635
+ (token_type_embeddings): Embedding(1, 768)
636
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
637
+ (dropout): Dropout(p=0.1, inplace=False)
638
+ )
639
+ (encoder): ClapTextEncoder(
640
+ (layer): ModuleList(
641
+ (0-11): 12 x ClapTextLayer(
642
+ (attention): ClapTextAttention(
643
+ (self): ClapTextSelfAttention(
644
+ (query): Linear(in_features=768, out_features=768, bias=True)
645
+ (key): Linear(in_features=768, out_features=768, bias=True)
646
+ (value): Linear(in_features=768, out_features=768, bias=True)
647
+ (dropout): Dropout(p=0.1, inplace=False)
648
+ )
649
+ (output): ClapTextSelfOutput(
650
+ (dense): Linear(in_features=768, out_features=768, bias=True)
651
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
652
+ (dropout): Dropout(p=0.1, inplace=False)
653
+ )
654
+ )
655
+ (intermediate): ClapTextIntermediate(
656
+ (dense): Linear(in_features=768, out_features=3072, bias=True)
657
+ (intermediate_act_fn): GELUActivation()
658
+ )
659
+ (output): ClapTextOutput(
660
+ (dense): Linear(in_features=3072, out_features=768, bias=True)
661
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
662
+ (dropout): Dropout(p=0.1, inplace=False)
663
+ )
664
+ )
665
+ )
666
+ )
667
+ (pooler): ClapTextPooler(
668
+ (dense): Linear(in_features=768, out_features=768, bias=True)
669
+ (activation): Tanh()
670
+ )
671
+ )
672
+ (text_projection): ClapProjectionLayer(
673
+ (linear1): Linear(in_features=768, out_features=512, bias=True)
674
+ (activation): ReLU()
675
+ (linear2): Linear(in_features=512, out_features=512, bias=True)
676
+ )
677
+ (audio_model): ClapAudioModel(
678
+ (audio_encoder): ClapAudioEncoder(
679
+ (patch_embed): ClapAudioPatchEmbed(
680
+ (proj): Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))
681
+ (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
682
+ )
683
+ (layers): ModuleList(
684
+ (0): ClapAudioStage(
685
+ (blocks): ModuleList(
686
+ (0-1): 2 x ClapAudioLayer(
687
+ (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
688
+ (attention): ClapAudioAttention(
689
+ (self): ClapAudioSelfAttention(
690
+ (query): Linear(in_features=96, out_features=96, bias=True)
691
+ (key): Linear(in_features=96, out_features=96, bias=True)
692
+ (value): Linear(in_features=96, out_features=96, bias=True)
693
+ (dropout): Dropout(p=0.0, inplace=False)
694
+ )
695
+ (output): ClapAudioSelfOutput(
696
+ (dense): Linear(in_features=96, out_features=96, bias=True)
697
+ (dropout): Dropout(p=0.0, inplace=False)
698
+ )
699
+ )
700
+ (drop_path): Identity()
701
+ (layernorm_after): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
702
+ (intermediate): ClapAudioIntermediate(
703
+ (dense): Linear(in_features=96, out_features=384, bias=True)
704
+ (intermediate_act_fn): GELUActivation()
705
+ )
706
+ (output): ClapAudioOutput(
707
+ (dense): Linear(in_features=384, out_features=96, bias=True)
708
+ (dropout): Dropout(p=0.1, inplace=False)
709
+ )
710
+ )
711
+ )
712
+ (downsample): ClapAudioPatchMerging(
713
+ (reduction): Linear(in_features=384, out_features=192, bias=False)
714
+ (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
715
+ )
716
+ )
717
+ (1): ClapAudioStage(
718
+ (blocks): ModuleList(
719
+ (0-1): 2 x ClapAudioLayer(
720
+ (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
721
+ (attention): ClapAudioAttention(
722
+ (self): ClapAudioSelfAttention(
723
+ (query): Linear(in_features=192, out_features=192, bias=True)
724
+ (key): Linear(in_features=192, out_features=192, bias=True)
725
+ (value): Linear(in_features=192, out_features=192, bias=True)
726
+ (dropout): Dropout(p=0.0, inplace=False)
727
+ )
728
+ (output): ClapAudioSelfOutput(
729
+ (dense): Linear(in_features=192, out_features=192, bias=True)
730
+ (dropout): Dropout(p=0.0, inplace=False)
731
+ )
732
+ )
733
+ (drop_path): Identity()
734
+ (layernorm_after): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
735
+ (intermediate): ClapAudioIntermediate(
736
+ (dense): Linear(in_features=192, out_features=768, bias=True)
737
+ (intermediate_act_fn): GELUActivation()
738
+ )
739
+ (output): ClapAudioOutput(
740
+ (dense): Linear(in_features=768, out_features=192, bias=True)
741
+ (dropout): Dropout(p=0.1, inplace=False)
742
+ )
743
+ )
744
+ )
745
+ (downsample): ClapAudioPatchMerging(
746
+ (reduction): Linear(in_features=768, out_features=384, bias=False)
747
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
748
+ )
749
+ )
750
+ (2): ClapAudioStage(
751
+ (blocks): ModuleList(
752
+ (0-5): 6 x ClapAudioLayer(
753
+ (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
754
+ (attention): ClapAudioAttention(
755
+ (self): ClapAudioSelfAttention(
756
+ (query): Linear(in_features=384, out_features=384, bias=True)
757
+ (key): Linear(in_features=384, out_features=384, bias=True)
758
+ (value): Linear(in_features=384, out_features=384, bias=True)
759
+ (dropout): Dropout(p=0.0, inplace=False)
760
+ )
761
+ (output): ClapAudioSelfOutput(
762
+ (dense): Linear(in_features=384, out_features=384, bias=True)
763
+ (dropout): Dropout(p=0.0, inplace=False)
764
+ )
765
+ )
766
+ (drop_path): Identity()
767
+ (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
768
+ (intermediate): ClapAudioIntermediate(
769
+ (dense): Linear(in_features=384, out_features=1536, bias=True)
770
+ (intermediate_act_fn): GELUActivation()
771
+ )
772
+ (output): ClapAudioOutput(
773
+ (dense): Linear(in_features=1536, out_features=384, bias=True)
774
+ (dropout): Dropout(p=0.1, inplace=False)
775
+ )
776
+ )
777
+ )
778
+ (downsample): ClapAudioPatchMerging(
779
+ (reduction): Linear(in_features=1536, out_features=768, bias=False)
780
+ (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
781
+ )
782
+ )
783
+ (3): ClapAudioStage(
784
+ (blocks): ModuleList(
785
+ (0-1): 2 x ClapAudioLayer(
786
+ (layernorm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
787
+ (attention): ClapAudioAttention(
788
+ (self): ClapAudioSelfAttention(
789
+ (query): Linear(in_features=768, out_features=768, bias=True)
790
+ (key): Linear(in_features=768, out_features=768, bias=True)
791
+ (value): Linear(in_features=768, out_features=768, bias=True)
792
+ (dropout): Dropout(p=0.0, inplace=False)
793
+ )
794
+ (output): ClapAudioSelfOutput(
795
+ (dense): Linear(in_features=768, out_features=768, bias=True)
796
+ (dropout): Dropout(p=0.0, inplace=False)
797
+ )
798
+ )
799
+ (drop_path): Identity()
800
+ (layernorm_after): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
801
+ (intermediate): ClapAudioIntermediate(
802
+ (dense): Linear(in_features=768, out_features=3072, bias=True)
803
+ (intermediate_act_fn): GELUActivation()
804
+ )
805
+ (output): ClapAudioOutput(
806
+ (dense): Linear(in_features=3072, out_features=768, bias=True)
807
+ (dropout): Dropout(p=0.1, inplace=False)
808
+ )
809
+ )
810
+ )
811
+ )
812
+ )
813
+ (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
814
+ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
815
+ (avgpool): AdaptiveAvgPool1d(output_size=1)
816
+ )
817
+ )
818
+ (audio_projection): ClapProjectionLayer(
819
+ (linear1): Linear(in_features=768, out_features=512, bias=True)
820
+ (activation): ReLU()
821
+ (linear2): Linear(in_features=512, out_features=512, bias=True)
822
+ )
823
+ )
824
+ (text_attn): TextCrossAttention(
825
+ (q_proj): Linear(in_features=384, out_features=384, bias=True)
826
+ (k_proj): Linear(in_features=512, out_features=384, bias=True)
827
+ (v_proj): Linear(in_features=512, out_features=384, bias=True)
828
+ (attn): MultiheadAttention(
829
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
830
+ )
831
+ (out_mlp): Sequential(
832
+ (0): Linear(in_features=384, out_features=384, bias=True)
833
+ (1): GELU(approximate='none')
834
+ (2): Linear(in_features=384, out_features=384, bias=True)
835
+ )
836
+ (norm_q): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
837
+ (norm_out): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
838
+ )
839
+ (freq_decoder): FreqDecoder(
840
+ (layers): ModuleList(
841
+ (0): Sequential(
842
+ (0): ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
843
+ (1): GroupNorm(1, 192, eps=1e-05, affine=True)
844
+ (2): GELU(approximate='none')
845
+ )
846
+ (1): Sequential(
847
+ (0): ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
848
+ (1): GroupNorm(1, 96, eps=1e-05, affine=True)
849
+ (2): GELU(approximate='none')
850
+ )
851
+ (2): Sequential(
852
+ (0): ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
853
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
854
+ (2): GELU(approximate='none')
855
+ )
856
+ (3): Sequential(
857
+ (0): ConvTranspose2d(48, 4, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
858
+ (1): Identity()
859
+ (2): Identity()
860
+ )
861
+ )
862
+ )
863
+ (time_decoder): TimeDecoder(
864
+ (layers): ModuleList(
865
+ (0): Sequential(
866
+ (0): ConvTranspose1d(384, 192, kernel_size=(8,), stride=(4,), padding=(2,))
867
+ (1): GroupNorm(1, 192, eps=1e-05, affine=True)
868
+ (2): GELU(approximate='none')
869
+ )
870
+ (1): Sequential(
871
+ (0): ConvTranspose1d(192, 96, kernel_size=(8,), stride=(4,), padding=(2,))
872
+ (1): GroupNorm(1, 96, eps=1e-05, affine=True)
873
+ (2): GELU(approximate='none')
874
+ )
875
+ (2): Sequential(
876
+ (0): ConvTranspose1d(96, 48, kernel_size=(8,), stride=(4,), padding=(2,))
877
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
878
+ (2): GELU(approximate='none')
879
+ )
880
+ (3): Sequential(
881
+ (0): ConvTranspose1d(48, 4, kernel_size=(8,), stride=(4,), padding=(2,))
882
+ (1): Identity()
883
+ (2): Identity()
884
+ )
885
+ )
886
+ )
887
+ (freq_out): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1))
888
+ (time_out): Conv1d(4, 2, kernel_size=(1,), stride=(1,))
889
+ )
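The text_attn block printed above is where the text conditioning enters the separator: queries come from the 384-dim HTDemucs bottleneck features, while keys and values are projected down from the 512-dim CLAP text embedding. Below is a minimal sketch with the same printed shapes (q_proj 384->384, k_proj/v_proj 512->384, out_mlp 384->384->384); the head count, batch_first layout, and residual wiring are assumptions for illustration, not taken from the repository code.

# Hedged sketch of a cross-attention block matching the shapes of `text_attn` above.
# Head count, residual connections, and batch_first choice are assumptions.
import torch
import torch.nn as nn

class TextCrossAttentionSketch(nn.Module):
    def __init__(self, feat_dim=384, text_dim=512, num_heads=8):
        super().__init__()
        self.norm_q = nn.LayerNorm(feat_dim)
        self.q_proj = nn.Linear(feat_dim, feat_dim)
        self.k_proj = nn.Linear(text_dim, feat_dim)
        self.v_proj = nn.Linear(text_dim, feat_dim)
        self.attn = nn.MultiheadAttention(feat_dim, num_heads, batch_first=True)
        self.out_mlp = nn.Sequential(
            nn.Linear(feat_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, feat_dim)
        )
        self.norm_out = nn.LayerNorm(feat_dim)

    def forward(self, feats, text_emb):
        # feats: (B, T, 384) bottleneck features; text_emb: (B, 1, 512) CLAP text embedding
        q = self.q_proj(self.norm_q(feats))
        k, v = self.k_proj(text_emb), self.v_proj(text_emb)
        attended, _ = self.attn(q, k, v)
        return self.norm_out(feats + self.out_mlp(attended))

# usage: TextCrossAttentionSketch()(torch.randn(2, 336, 384), torch.randn(2, 1, 512))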
src/models/stem_separation/AudioTextHTDemucs_Text_Only.txt ADDED
@@ -0,0 +1,745 @@
1
+ Loading pretrained HTDemucs...
2
+ Loading CLAP model...
3
+ Model Summary:
4
+ AudioTextHTDemucs(
5
+ (htdemucs): HTDemucs(
6
+ (encoder): ModuleList(
7
+ (0): HEncLayer(
8
+ (conv): Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
9
+ (norm1): Identity()
10
+ (rewrite): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))
11
+ (norm2): Identity()
12
+ (dconv): DConv(
13
+ (layers): ModuleList(
14
+ (0): Sequential(
15
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
16
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
17
+ (2): GELU(approximate='none')
18
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
19
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
20
+ (5): GLU(dim=1)
21
+ (6): LayerScale()
22
+ )
23
+ (1): Sequential(
24
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
25
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
26
+ (2): GELU(approximate='none')
27
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
28
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
29
+ (5): GLU(dim=1)
30
+ (6): LayerScale()
31
+ )
32
+ )
33
+ )
34
+ )
35
+ (1): HEncLayer(
36
+ (conv): Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
37
+ (norm1): Identity()
38
+ (rewrite): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))
39
+ (norm2): Identity()
40
+ (dconv): DConv(
41
+ (layers): ModuleList(
42
+ (0): Sequential(
43
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
44
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
45
+ (2): GELU(approximate='none')
46
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
47
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
48
+ (5): GLU(dim=1)
49
+ (6): LayerScale()
50
+ )
51
+ (1): Sequential(
52
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
53
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
54
+ (2): GELU(approximate='none')
55
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
56
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
57
+ (5): GLU(dim=1)
58
+ (6): LayerScale()
59
+ )
60
+ )
61
+ )
62
+ )
63
+ (2): HEncLayer(
64
+ (conv): Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
65
+ (norm1): Identity()
66
+ (rewrite): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))
67
+ (norm2): Identity()
68
+ (dconv): DConv(
69
+ (layers): ModuleList(
70
+ (0): Sequential(
71
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
72
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
73
+ (2): GELU(approximate='none')
74
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
75
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
76
+ (5): GLU(dim=1)
77
+ (6): LayerScale()
78
+ )
79
+ (1): Sequential(
80
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
81
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
82
+ (2): GELU(approximate='none')
83
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
84
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
85
+ (5): GLU(dim=1)
86
+ (6): LayerScale()
87
+ )
88
+ )
89
+ )
90
+ )
91
+ (3): HEncLayer(
92
+ (conv): Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
93
+ (norm1): Identity()
94
+ (rewrite): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
95
+ (norm2): Identity()
96
+ (dconv): DConv(
97
+ (layers): ModuleList(
98
+ (0): Sequential(
99
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
100
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
101
+ (2): GELU(approximate='none')
102
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
103
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
104
+ (5): GLU(dim=1)
105
+ (6): LayerScale()
106
+ )
107
+ (1): Sequential(
108
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
109
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
110
+ (2): GELU(approximate='none')
111
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
112
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
113
+ (5): GLU(dim=1)
114
+ (6): LayerScale()
115
+ )
116
+ )
117
+ )
118
+ )
119
+ )
120
+ (decoder): ModuleList(
121
+ (0): HDecLayer(
122
+ (conv_tr): ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1))
123
+ (norm2): Identity()
124
+ (rewrite): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
125
+ (norm1): Identity()
126
+ (dconv): DConv(
127
+ (layers): ModuleList(
128
+ (0): Sequential(
129
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
130
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
131
+ (2): GELU(approximate='none')
132
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
133
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
134
+ (5): GLU(dim=1)
135
+ (6): LayerScale()
136
+ )
137
+ (1): Sequential(
138
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
139
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
140
+ (2): GELU(approximate='none')
141
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
142
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
143
+ (5): GLU(dim=1)
144
+ (6): LayerScale()
145
+ )
146
+ )
147
+ )
148
+ )
149
+ (1): HDecLayer(
150
+ (conv_tr): ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1))
151
+ (norm2): Identity()
152
+ (rewrite): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
153
+ (norm1): Identity()
154
+ (dconv): DConv(
155
+ (layers): ModuleList(
156
+ (0): Sequential(
157
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
158
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
159
+ (2): GELU(approximate='none')
160
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
161
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
162
+ (5): GLU(dim=1)
163
+ (6): LayerScale()
164
+ )
165
+ (1): Sequential(
166
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
167
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
168
+ (2): GELU(approximate='none')
169
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
170
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
171
+ (5): GLU(dim=1)
172
+ (6): LayerScale()
173
+ )
174
+ )
175
+ )
176
+ )
177
+ (2): HDecLayer(
178
+ (conv_tr): ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1))
179
+ (norm2): Identity()
180
+ (rewrite): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
181
+ (norm1): Identity()
182
+ (dconv): DConv(
183
+ (layers): ModuleList(
184
+ (0): Sequential(
185
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
186
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
187
+ (2): GELU(approximate='none')
188
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
189
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
190
+ (5): GLU(dim=1)
191
+ (6): LayerScale()
192
+ )
193
+ (1): Sequential(
194
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
195
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
196
+ (2): GELU(approximate='none')
197
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
198
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
199
+ (5): GLU(dim=1)
200
+ (6): LayerScale()
201
+ )
202
+ )
203
+ )
204
+ )
205
+ (3): HDecLayer(
206
+ (conv_tr): ConvTranspose2d(48, 16, kernel_size=(8, 1), stride=(4, 1))
207
+ (norm2): Identity()
208
+ (rewrite): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
209
+ (norm1): Identity()
210
+ (dconv): DConv(
211
+ (layers): ModuleList(
212
+ (0): Sequential(
213
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
214
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
215
+ (2): GELU(approximate='none')
216
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
217
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
218
+ (5): GLU(dim=1)
219
+ (6): LayerScale()
220
+ )
221
+ (1): Sequential(
222
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
223
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
224
+ (2): GELU(approximate='none')
225
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
226
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
227
+ (5): GLU(dim=1)
228
+ (6): LayerScale()
229
+ )
230
+ )
231
+ )
232
+ )
233
+ )
234
+ (tencoder): ModuleList(
235
+ (0): HEncLayer(
236
+ (conv): Conv1d(2, 48, kernel_size=(8,), stride=(4,), padding=(2,))
237
+ (norm1): Identity()
238
+ (rewrite): Conv1d(48, 96, kernel_size=(1,), stride=(1,))
239
+ (norm2): Identity()
240
+ (dconv): DConv(
241
+ (layers): ModuleList(
242
+ (0): Sequential(
243
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
244
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
245
+ (2): GELU(approximate='none')
246
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
247
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
248
+ (5): GLU(dim=1)
249
+ (6): LayerScale()
250
+ )
251
+ (1): Sequential(
252
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
253
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
254
+ (2): GELU(approximate='none')
255
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
256
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
257
+ (5): GLU(dim=1)
258
+ (6): LayerScale()
259
+ )
260
+ )
261
+ )
262
+ )
263
+ (1): HEncLayer(
264
+ (conv): Conv1d(48, 96, kernel_size=(8,), stride=(4,), padding=(2,))
265
+ (norm1): Identity()
266
+ (rewrite): Conv1d(96, 192, kernel_size=(1,), stride=(1,))
267
+ (norm2): Identity()
268
+ (dconv): DConv(
269
+ (layers): ModuleList(
270
+ (0): Sequential(
271
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
272
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
273
+ (2): GELU(approximate='none')
274
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
275
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
276
+ (5): GLU(dim=1)
277
+ (6): LayerScale()
278
+ )
279
+ (1): Sequential(
280
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
281
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
282
+ (2): GELU(approximate='none')
283
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
284
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
285
+ (5): GLU(dim=1)
286
+ (6): LayerScale()
287
+ )
288
+ )
289
+ )
290
+ )
291
+ (2): HEncLayer(
292
+ (conv): Conv1d(96, 192, kernel_size=(8,), stride=(4,), padding=(2,))
293
+ (norm1): Identity()
294
+ (rewrite): Conv1d(192, 384, kernel_size=(1,), stride=(1,))
295
+ (norm2): Identity()
296
+ (dconv): DConv(
297
+ (layers): ModuleList(
298
+ (0): Sequential(
299
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
300
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
301
+ (2): GELU(approximate='none')
302
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
303
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
304
+ (5): GLU(dim=1)
305
+ (6): LayerScale()
306
+ )
307
+ (1): Sequential(
308
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
309
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
310
+ (2): GELU(approximate='none')
311
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
312
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
313
+ (5): GLU(dim=1)
314
+ (6): LayerScale()
315
+ )
316
+ )
317
+ )
318
+ )
319
+ (3): HEncLayer(
320
+ (conv): Conv1d(192, 384, kernel_size=(8,), stride=(4,), padding=(2,))
321
+ (norm1): Identity()
322
+ (rewrite): Conv1d(384, 768, kernel_size=(1,), stride=(1,))
323
+ (norm2): Identity()
324
+ (dconv): DConv(
325
+ (layers): ModuleList(
326
+ (0): Sequential(
327
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
328
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
329
+ (2): GELU(approximate='none')
330
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
331
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
332
+ (5): GLU(dim=1)
333
+ (6): LayerScale()
334
+ )
335
+ (1): Sequential(
336
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
337
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
338
+ (2): GELU(approximate='none')
339
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
340
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
341
+ (5): GLU(dim=1)
342
+ (6): LayerScale()
343
+ )
344
+ )
345
+ )
346
+ )
347
+ )
348
+ (tdecoder): ModuleList(
349
+ (0): HDecLayer(
350
+ (conv_tr): ConvTranspose1d(384, 192, kernel_size=(8,), stride=(4,))
351
+ (norm2): Identity()
352
+ (rewrite): Conv1d(384, 768, kernel_size=(3,), stride=(1,), padding=(1,))
353
+ (norm1): Identity()
354
+ (dconv): DConv(
355
+ (layers): ModuleList(
356
+ (0): Sequential(
357
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
358
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
359
+ (2): GELU(approximate='none')
360
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
361
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
362
+ (5): GLU(dim=1)
363
+ (6): LayerScale()
364
+ )
365
+ (1): Sequential(
366
+ (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
367
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
368
+ (2): GELU(approximate='none')
369
+ (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
370
+ (4): GroupNorm(1, 768, eps=1e-05, affine=True)
371
+ (5): GLU(dim=1)
372
+ (6): LayerScale()
373
+ )
374
+ )
375
+ )
376
+ )
377
+ (1): HDecLayer(
378
+ (conv_tr): ConvTranspose1d(192, 96, kernel_size=(8,), stride=(4,))
379
+ (norm2): Identity()
380
+ (rewrite): Conv1d(192, 384, kernel_size=(3,), stride=(1,), padding=(1,))
381
+ (norm1): Identity()
382
+ (dconv): DConv(
383
+ (layers): ModuleList(
384
+ (0): Sequential(
385
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
386
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
387
+ (2): GELU(approximate='none')
388
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
389
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
390
+ (5): GLU(dim=1)
391
+ (6): LayerScale()
392
+ )
393
+ (1): Sequential(
394
+ (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
395
+ (1): GroupNorm(1, 24, eps=1e-05, affine=True)
396
+ (2): GELU(approximate='none')
397
+ (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
398
+ (4): GroupNorm(1, 384, eps=1e-05, affine=True)
399
+ (5): GLU(dim=1)
400
+ (6): LayerScale()
401
+ )
402
+ )
403
+ )
404
+ )
405
+ (2): HDecLayer(
406
+ (conv_tr): ConvTranspose1d(96, 48, kernel_size=(8,), stride=(4,))
407
+ (norm2): Identity()
408
+ (rewrite): Conv1d(96, 192, kernel_size=(3,), stride=(1,), padding=(1,))
409
+ (norm1): Identity()
410
+ (dconv): DConv(
411
+ (layers): ModuleList(
412
+ (0): Sequential(
413
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
414
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
415
+ (2): GELU(approximate='none')
416
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
417
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
418
+ (5): GLU(dim=1)
419
+ (6): LayerScale()
420
+ )
421
+ (1): Sequential(
422
+ (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
423
+ (1): GroupNorm(1, 12, eps=1e-05, affine=True)
424
+ (2): GELU(approximate='none')
425
+ (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
426
+ (4): GroupNorm(1, 192, eps=1e-05, affine=True)
427
+ (5): GLU(dim=1)
428
+ (6): LayerScale()
429
+ )
430
+ )
431
+ )
432
+ )
433
+ (3): HDecLayer(
434
+ (conv_tr): ConvTranspose1d(48, 8, kernel_size=(8,), stride=(4,))
435
+ (norm2): Identity()
436
+ (rewrite): Conv1d(48, 96, kernel_size=(3,), stride=(1,), padding=(1,))
437
+ (norm1): Identity()
438
+ (dconv): DConv(
439
+ (layers): ModuleList(
440
+ (0): Sequential(
441
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
442
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
443
+ (2): GELU(approximate='none')
444
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
445
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
446
+ (5): GLU(dim=1)
447
+ (6): LayerScale()
448
+ )
449
+ (1): Sequential(
450
+ (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
451
+ (1): GroupNorm(1, 6, eps=1e-05, affine=True)
452
+ (2): GELU(approximate='none')
453
+ (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
454
+ (4): GroupNorm(1, 96, eps=1e-05, affine=True)
455
+ (5): GLU(dim=1)
456
+ (6): LayerScale()
457
+ )
458
+ )
459
+ )
460
+ )
461
+ )
462
+ (freq_emb): ScaledEmbedding(
463
+ (embedding): Embedding(512, 48)
464
+ )
465
+ (channel_upsampler): Conv1d(384, 512, kernel_size=(1,), stride=(1,))
466
+ (channel_downsampler): Conv1d(512, 384, kernel_size=(1,), stride=(1,))
467
+ (channel_upsampler_t): Conv1d(384, 512, kernel_size=(1,), stride=(1,))
468
+ (channel_downsampler_t): Conv1d(512, 384, kernel_size=(1,), stride=(1,))
469
+ (crosstransformer): CrossTransformerEncoder(
470
+ (norm_in): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
471
+ (norm_in_t): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
472
+ (layers): ModuleList(
473
+ (0): MyTransformerEncoderLayer(
474
+ (self_attn): MultiheadAttention(
475
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
476
+ )
477
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
478
+ (dropout): Dropout(p=0.02, inplace=False)
479
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
480
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
481
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
482
+ (dropout1): Dropout(p=0.02, inplace=False)
483
+ (dropout2): Dropout(p=0.02, inplace=False)
484
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
485
+ (gamma_1): LayerScale()
486
+ (gamma_2): LayerScale()
487
+ )
488
+ (1): CrossTransformerEncoderLayer(
489
+ (cross_attn): MultiheadAttention(
490
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
491
+ )
492
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
493
+ (dropout): Dropout(p=0.02, inplace=False)
494
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
495
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
496
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
497
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
498
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
499
+ (gamma_1): LayerScale()
500
+ (gamma_2): LayerScale()
501
+ (dropout1): Dropout(p=0.02, inplace=False)
502
+ (dropout2): Dropout(p=0.02, inplace=False)
503
+ )
504
+ (2): MyTransformerEncoderLayer(
505
+ (self_attn): MultiheadAttention(
506
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
507
+ )
508
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
509
+ (dropout): Dropout(p=0.02, inplace=False)
510
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
511
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
512
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
513
+ (dropout1): Dropout(p=0.02, inplace=False)
514
+ (dropout2): Dropout(p=0.02, inplace=False)
515
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
516
+ (gamma_1): LayerScale()
517
+ (gamma_2): LayerScale()
518
+ )
519
+ (3): CrossTransformerEncoderLayer(
520
+ (cross_attn): MultiheadAttention(
521
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
522
+ )
523
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
524
+ (dropout): Dropout(p=0.02, inplace=False)
525
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
526
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
527
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
528
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
529
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
530
+ (gamma_1): LayerScale()
531
+ (gamma_2): LayerScale()
532
+ (dropout1): Dropout(p=0.02, inplace=False)
533
+ (dropout2): Dropout(p=0.02, inplace=False)
534
+ )
535
+ (4): MyTransformerEncoderLayer(
536
+ (self_attn): MultiheadAttention(
537
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
538
+ )
539
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
540
+ (dropout): Dropout(p=0.02, inplace=False)
541
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
542
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
543
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
544
+ (dropout1): Dropout(p=0.02, inplace=False)
545
+ (dropout2): Dropout(p=0.02, inplace=False)
546
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
547
+ (gamma_1): LayerScale()
548
+ (gamma_2): LayerScale()
549
+ )
550
+ )
551
+ (layers_t): ModuleList(
552
+ (0): MyTransformerEncoderLayer(
553
+ (self_attn): MultiheadAttention(
554
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
555
+ )
556
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
557
+ (dropout): Dropout(p=0.02, inplace=False)
558
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
559
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
560
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
561
+ (dropout1): Dropout(p=0.02, inplace=False)
562
+ (dropout2): Dropout(p=0.02, inplace=False)
563
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
564
+ (gamma_1): LayerScale()
565
+ (gamma_2): LayerScale()
566
+ )
567
+ (1): CrossTransformerEncoderLayer(
568
+ (cross_attn): MultiheadAttention(
569
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
570
+ )
571
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
572
+ (dropout): Dropout(p=0.02, inplace=False)
573
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
574
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
575
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
576
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
577
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
578
+ (gamma_1): LayerScale()
579
+ (gamma_2): LayerScale()
580
+ (dropout1): Dropout(p=0.02, inplace=False)
581
+ (dropout2): Dropout(p=0.02, inplace=False)
582
+ )
583
+ (2): MyTransformerEncoderLayer(
584
+ (self_attn): MultiheadAttention(
585
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
586
+ )
587
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
588
+ (dropout): Dropout(p=0.02, inplace=False)
589
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
590
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
591
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
592
+ (dropout1): Dropout(p=0.02, inplace=False)
593
+ (dropout2): Dropout(p=0.02, inplace=False)
594
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
595
+ (gamma_1): LayerScale()
596
+ (gamma_2): LayerScale()
597
+ )
598
+ (3): CrossTransformerEncoderLayer(
599
+ (cross_attn): MultiheadAttention(
600
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
601
+ )
602
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
603
+ (dropout): Dropout(p=0.02, inplace=False)
604
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
605
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
606
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
607
+ (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
608
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
609
+ (gamma_1): LayerScale()
610
+ (gamma_2): LayerScale()
611
+ (dropout1): Dropout(p=0.02, inplace=False)
612
+ (dropout2): Dropout(p=0.02, inplace=False)
613
+ )
614
+ (4): MyTransformerEncoderLayer(
615
+ (self_attn): MultiheadAttention(
616
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
617
+ )
618
+ (linear1): Linear(in_features=512, out_features=2048, bias=True)
619
+ (dropout): Dropout(p=0.02, inplace=False)
620
+ (linear2): Linear(in_features=2048, out_features=512, bias=True)
621
+ (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
622
+ (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
623
+ (dropout1): Dropout(p=0.02, inplace=False)
624
+ (dropout2): Dropout(p=0.02, inplace=False)
625
+ (norm_out): MyGroupNorm(1, 512, eps=1e-05, affine=True)
626
+ (gamma_1): LayerScale()
627
+ (gamma_2): LayerScale()
628
+ )
629
+ )
630
+ )
631
+ )
632
+ (clap): ClapTextModelWithProjection(
633
+ (text_model): ClapTextModel(
634
+ (embeddings): ClapTextEmbeddings(
635
+ (word_embeddings): Embedding(50265, 768, padding_idx=1)
636
+ (position_embeddings): Embedding(514, 768, padding_idx=1)
637
+ (token_type_embeddings): Embedding(1, 768)
638
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
639
+ (dropout): Dropout(p=0.1, inplace=False)
640
+ )
641
+ (encoder): ClapTextEncoder(
642
+ (layer): ModuleList(
643
+ (0-11): 12 x ClapTextLayer(
644
+ (attention): ClapTextAttention(
645
+ (self): ClapTextSelfAttention(
646
+ (query): Linear(in_features=768, out_features=768, bias=True)
647
+ (key): Linear(in_features=768, out_features=768, bias=True)
648
+ (value): Linear(in_features=768, out_features=768, bias=True)
649
+ (dropout): Dropout(p=0.1, inplace=False)
650
+ )
651
+ (output): ClapTextSelfOutput(
652
+ (dense): Linear(in_features=768, out_features=768, bias=True)
653
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
654
+ (dropout): Dropout(p=0.1, inplace=False)
655
+ )
656
+ )
657
+ (intermediate): ClapTextIntermediate(
658
+ (dense): Linear(in_features=768, out_features=3072, bias=True)
659
+ (intermediate_act_fn): GELUActivation()
660
+ )
661
+ (output): ClapTextOutput(
662
+ (dense): Linear(in_features=3072, out_features=768, bias=True)
663
+ (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
664
+ (dropout): Dropout(p=0.1, inplace=False)
665
+ )
666
+ )
667
+ )
668
+ )
669
+ (pooler): ClapTextPooler(
670
+ (dense): Linear(in_features=768, out_features=768, bias=True)
671
+ (activation): Tanh()
672
+ )
673
+ )
674
+ (text_projection): ClapProjectionLayer(
675
+ (linear1): Linear(in_features=768, out_features=512, bias=True)
676
+ (activation): ReLU()
677
+ (linear2): Linear(in_features=512, out_features=512, bias=True)
678
+ )
679
+ )
680
+ (text_attn): TextCrossAttention(
681
+ (q_proj): Linear(in_features=384, out_features=384, bias=True)
682
+ (k_proj): Linear(in_features=512, out_features=384, bias=True)
683
+ (v_proj): Linear(in_features=512, out_features=384, bias=True)
684
+ (attn): MultiheadAttention(
685
+ (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
686
+ )
687
+ (out_mlp): Sequential(
688
+ (0): Linear(in_features=384, out_features=384, bias=True)
689
+ (1): GELU(approximate='none')
690
+ (2): Linear(in_features=384, out_features=384, bias=True)
691
+ )
692
+ (norm_q): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
693
+ (norm_out): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
694
+ )
695
+ (freq_decoder): FreqDecoder(
696
+ (layers): ModuleList(
697
+ (0): Sequential(
698
+ (0): ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
699
+ (1): GroupNorm(1, 192, eps=1e-05, affine=True)
700
+ (2): GELU(approximate='none')
701
+ )
702
+ (1): Sequential(
703
+ (0): ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
704
+ (1): GroupNorm(1, 96, eps=1e-05, affine=True)
705
+ (2): GELU(approximate='none')
706
+ )
707
+ (2): Sequential(
708
+ (0): ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
709
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
710
+ (2): GELU(approximate='none')
711
+ )
712
+ (3): Sequential(
713
+ (0): ConvTranspose2d(48, 4, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
714
+ (1): Identity()
715
+ (2): Identity()
716
+ )
717
+ )
718
+ )
719
+ (time_decoder): TimeDecoder(
720
+ (layers): ModuleList(
721
+ (0): Sequential(
722
+ (0): ConvTranspose1d(384, 192, kernel_size=(8,), stride=(4,), padding=(2,))
723
+ (1): GroupNorm(1, 192, eps=1e-05, affine=True)
724
+ (2): GELU(approximate='none')
725
+ )
726
+ (1): Sequential(
727
+ (0): ConvTranspose1d(192, 96, kernel_size=(8,), stride=(4,), padding=(2,))
728
+ (1): GroupNorm(1, 96, eps=1e-05, affine=True)
729
+ (2): GELU(approximate='none')
730
+ )
731
+ (2): Sequential(
732
+ (0): ConvTranspose1d(96, 48, kernel_size=(8,), stride=(4,), padding=(2,))
733
+ (1): GroupNorm(1, 48, eps=1e-05, affine=True)
734
+ (2): GELU(approximate='none')
735
+ )
736
+ (3): Sequential(
737
+ (0): ConvTranspose1d(48, 4, kernel_size=(8,), stride=(4,), padding=(2,))
738
+ (1): Identity()
739
+ (2): Identity()
740
+ )
741
+ )
742
+ )
743
+ (freq_out): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1))
744
+ (time_out): Conv1d(4, 2, kernel_size=(1,), stride=(1,))
745
+ )
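The text-conditioning path in the dump above is the TextCrossAttention block (q_proj 384->384, k_proj/v_proj 512->384, 384-d MultiheadAttention, output MLP and norms). A minimal sketch consistent with those printed shapes follows; it is an illustration, not the actual ATHTDemucs_v2 code, and the residual wiring and head count (8, the config default) are assumptions.

import torch
import torch.nn as nn

class TextCrossAttentionSketch(nn.Module):
    """Audio bottleneck tokens (384-d) attend to CLAP text embeddings (512-d)."""
    def __init__(self, model_dim: int = 384, text_dim: int = 512, num_heads: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(model_dim, model_dim)
        self.k_proj = nn.Linear(text_dim, model_dim)
        self.v_proj = nn.Linear(text_dim, model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)
        self.out_mlp = nn.Sequential(
            nn.Linear(model_dim, model_dim), nn.GELU(), nn.Linear(model_dim, model_dim)
        )
        self.norm_q = nn.LayerNorm(model_dim)
        self.norm_out = nn.LayerNorm(model_dim)

    def forward(self, audio_tokens: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        # audio_tokens: (B, N, 384) bottleneck features; text_emb: (B, T_text, 512) CLAP tokens
        q = self.q_proj(self.norm_q(audio_tokens))
        k, v = self.k_proj(text_emb), self.v_proj(text_emb)
        attended, _ = self.attn(q, k, v)  # cross-attention over the text tokens
        return self.norm_out(audio_tokens + self.out_mlp(attended))  # residual + output MLP + norm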
src/models/stem_separation/CLAP_Text_Model_Fwd_Pass.txt ADDED
@@ -0,0 +1,40 @@
1
+ ** NOTE: get_text_features() method does a final projection to 512 dims and normalization after this forward pass
2
+
3
+ =========================================================================================================
4
+ Layer (type:depth-idx) Output Shape Param #
5
+ =========================================================================================================
6
+ ClapTextModel [1, 768] --
7
+ ├─ClapTextEmbeddings: 1-1 [1, 5, 768] --
8
+ │ └─Embedding: 2-1 [1, 5, 768] (38,603,520)
9
+ │ └─Embedding: 2-2 [1, 5, 768] (768)
10
+ │ └─Embedding: 2-3 [1, 5, 768] (394,752)
11
+ │ └─LayerNorm: 2-4 [1, 5, 768] (1,536)
12
+ │ └─Dropout: 2-5 [1, 5, 768] --
13
+ ├─ClapTextEncoder: 1-2 [1, 5, 768] --
14
+ │ └─ModuleList: 2-6 -- --
15
+ │ │ └─ClapTextLayer: 3-1 [1, 5, 768] (7,087,872)
16
+ │ │ └─ClapTextLayer: 3-2 [1, 5, 768] (7,087,872)
17
+ │ │ └─ClapTextLayer: 3-3 [1, 5, 768] (7,087,872)
18
+ │ │ └─ClapTextLayer: 3-4 [1, 5, 768] (7,087,872)
19
+ │ │ └─ClapTextLayer: 3-5 [1, 5, 768] (7,087,872)
20
+ │ │ └─ClapTextLayer: 3-6 [1, 5, 768] (7,087,872)
21
+ │ │ └─ClapTextLayer: 3-7 [1, 5, 768] (7,087,872)
22
+ │ │ └─ClapTextLayer: 3-8 [1, 5, 768] (7,087,872)
23
+ │ │ └─ClapTextLayer: 3-9 [1, 5, 768] (7,087,872)
24
+ │ │ └─ClapTextLayer: 3-10 [1, 5, 768] (7,087,872)
25
+ │ │ └─ClapTextLayer: 3-11 [1, 5, 768] (7,087,872)
26
+ │ │ └─ClapTextLayer: 3-12 [1, 5, 768] (7,087,872)
27
+ ├─ClapTextPooler: 1-3 [1, 768] --
28
+ │ └─Linear: 2-7 [1, 768] (590,592)
29
+ │ └─Tanh: 2-8 [1, 768] --
30
+ =========================================================================================================
31
+ Total params: 124,645,632
32
+ Trainable params: 0
33
+ Non-trainable params: 124,645,632
34
+ Total mult-adds (Units.MEGABYTES): 124.65
35
+ =========================================================================================================
36
+ Input size (MB): 0.00
37
+ Forward/backward pass size (MB): 4.18
38
+ Params size (MB): 498.58
39
+ Estimated Total Size (MB): 502.77
40
+ =========================================================================================================
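As the NOTE at the top of this summary says, the 512-dim conditioning embedding comes from a projection and normalization applied after this forward pass. A short sketch of how that embedding can be obtained (an illustration, not code from this commit; it assumes the same laion/clap-htsat-unfused checkpoint and the ClapTextModelWithProjection class used in src/train.py):

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, ClapTextModelWithProjection

tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
text_model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")

inputs = tokenizer(["extract the drums"], return_tensors="pt")
with torch.no_grad():
    out = text_model(**inputs)  # pooled 768-d text output, projected to 512-d text_embeds
text_embed = F.normalize(out.text_embeds, dim=-1)  # (1, 512); normalization mirrors get_text_features()
print(text_embed.shape)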
src/models/stem_separation/HTDemucs_Fwd_Pass.txt ADDED
@@ -0,0 +1,156 @@
1
+ ====================================================================================================
2
+ Layer (type:depth-idx) Output Shape Param #
3
+ ====================================================================================================
4
+ HTDemucs [1, 4, 2, 264600] --
5
+ ├─ModuleList: 1-8 -- (recursive)
6
+ │ └─HEncLayer: 2-1 [1, 48, 85995] --
7
+ │ │ └─Conv1d: 3-1 [1, 48, 85995] (816)
8
+ │ │ └─Identity: 3-2 [1, 48, 85995] --
9
+ │ │ └─DConv: 3-3 [1, 48, 85995] (3,588)
10
+ │ │ └─Conv1d: 3-4 [1, 96, 85995] (4,704)
11
+ │ │ └─Identity: 3-5 [1, 96, 85995] --
12
+ ├─ModuleList: 1-9 -- (recursive)
13
+ │ └─HEncLayer: 2-2 [1, 48, 512, 336] --
14
+ │ │ └─Conv2d: 3-6 [1, 48, 512, 336] (1,584)
15
+ │ │ └─Identity: 3-7 [1, 48, 512, 336] --
16
+ │ │ └─DConv: 3-8 [512, 48, 336] (3,588)
17
+ │ │ └─Conv2d: 3-9 [1, 96, 512, 336] (4,704)
18
+ │ │ └─Identity: 3-10 [1, 96, 512, 336] --
19
+ ├─ScaledEmbedding: 1-3 [512, 48] --
20
+ │ └─Embedding: 2-3 [512, 48] (24,576)
21
+ ├─ModuleList: 1-8 -- (recursive)
22
+ │ └─HEncLayer: 2-4 [1, 96, 21499] --
23
+ │ │ └─Conv1d: 3-11 [1, 96, 21499] (36,960)
24
+ │ │ └─Identity: 3-12 [1, 96, 21499] --
25
+ │ │ └─DConv: 3-13 [1, 96, 21499] (12,936)
26
+ │ │ └─Conv1d: 3-14 [1, 192, 21499] (18,624)
27
+ │ │ └─Identity: 3-15 [1, 192, 21499] --
28
+ ├─ModuleList: 1-9 -- (recursive)
29
+ │ └─HEncLayer: 2-5 [1, 96, 128, 336] --
30
+ │ │ └─Conv2d: 3-16 [1, 96, 128, 336] (36,960)
31
+ │ │ └─Identity: 3-17 [1, 96, 128, 336] --
32
+ │ │ └─DConv: 3-18 [128, 96, 336] (12,936)
33
+ │ │ └─Conv2d: 3-19 [1, 192, 128, 336] (18,624)
34
+ │ │ └─Identity: 3-20 [1, 192, 128, 336] --
35
+ ├─ModuleList: 1-8 -- (recursive)
36
+ │ └─HEncLayer: 2-6 [1, 192, 5375] --
37
+ │ │ └─Conv1d: 3-21 [1, 192, 5375] (147,648)
38
+ │ │ └─Identity: 3-22 [1, 192, 5375] --
39
+ │ │ └─DConv: 3-23 [1, 192, 5375] (48,912)
40
+ │ │ └─Conv1d: 3-24 [1, 384, 5375] (74,112)
41
+ │ │ └─Identity: 3-25 [1, 384, 5375] --
42
+ ├─ModuleList: 1-9 -- (recursive)
43
+ │ └─HEncLayer: 2-7 [1, 192, 32, 336] --
44
+ │ │ └─Conv2d: 3-26 [1, 192, 32, 336] (147,648)
45
+ │ │ └─Identity: 3-27 [1, 192, 32, 336] --
46
+ │ │ └─DConv: 3-28 [32, 192, 336] (48,912)
47
+ │ │ └─Conv2d: 3-29 [1, 384, 32, 336] (74,112)
48
+ │ │ └─Identity: 3-30 [1, 384, 32, 336] --
49
+ ├─ModuleList: 1-8 -- (recursive)
50
+ │ └─HEncLayer: 2-8 [1, 384, 1344] --
51
+ │ │ └─Conv1d: 3-31 [1, 384, 1344] (590,208)
52
+ │ │ └─Identity: 3-32 [1, 384, 1344] --
53
+ │ │ └─DConv: 3-33 [1, 384, 1344] (189,984)
54
+ │ │ └─Conv1d: 3-34 [1, 768, 1344] (295,680)
55
+ │ │ └─Identity: 3-35 [1, 768, 1344] --
56
+ ├─ModuleList: 1-9 -- (recursive)
57
+ │ └─HEncLayer: 2-9 [1, 384, 8, 336] --
58
+ │ │ └─Conv2d: 3-36 [1, 384, 8, 336] (590,208)
59
+ │ │ └─Identity: 3-37 [1, 384, 8, 336] --
60
+ │ │ └─DConv: 3-38 [8, 384, 336] (189,984)
61
+ │ │ └─Conv2d: 3-39 [1, 768, 8, 336] (295,680)
62
+ │ │ └─Identity: 3-40 [1, 768, 8, 336] --
63
+ ├─Conv1d: 1-10 [1, 512, 2688] (197,120)
64
+ ├─Conv1d: 1-11 [1, 512, 1344] (197,120)
65
+ ├─CrossTransformerEncoder: 1-12 [1, 512, 8, 336] --
66
+ │ └─LayerNorm: 2-10 [1, 2688, 512] (1,024)
67
+ │ └─LayerNorm: 2-11 [1, 1344, 512] (1,024)
68
+ │ └─ModuleList: 2-20 -- (recursive)
69
+ │ │ └─MyTransformerEncoderLayer: 3-41 [1, 2688, 512] (3,154,432)
70
+ │ └─ModuleList: 2-21 -- (recursive)
71
+ │ │ └─MyTransformerEncoderLayer: 3-42 [1, 1344, 512] (3,154,432)
72
+ │ └─ModuleList: 2-20 -- (recursive)
73
+ │ │ └─CrossTransformerEncoderLayer: 3-43 [1, 2688, 512] (3,155,456)
74
+ │ └─ModuleList: 2-21 -- (recursive)
75
+ │ │ └─CrossTransformerEncoderLayer: 3-44 [1, 1344, 512] (3,155,456)
76
+ │ └─ModuleList: 2-20 -- (recursive)
77
+ │ │ └─MyTransformerEncoderLayer: 3-45 [1, 2688, 512] (3,154,432)
78
+ │ └─ModuleList: 2-21 -- (recursive)
79
+ │ │ └─MyTransformerEncoderLayer: 3-46 [1, 1344, 512] (3,154,432)
80
+ │ └─ModuleList: 2-20 -- (recursive)
81
+ │ │ └─CrossTransformerEncoderLayer: 3-47 [1, 2688, 512] (3,155,456)
82
+ │ └─ModuleList: 2-21 -- (recursive)
83
+ │ │ └─CrossTransformerEncoderLayer: 3-48 [1, 1344, 512] (3,155,456)
84
+ │ └─ModuleList: 2-20 -- (recursive)
85
+ │ │ └─MyTransformerEncoderLayer: 3-49 [1, 2688, 512] (3,154,432)
86
+ │ └─ModuleList: 2-21 -- (recursive)
87
+ │ │ └─MyTransformerEncoderLayer: 3-50 [1, 1344, 512] (3,154,432)
88
+ ├─Conv1d: 1-13 [1, 384, 2688] (196,992)
89
+ ├─Conv1d: 1-14 [1, 384, 1344] (196,992)
90
+ ├─ModuleList: 1-21 -- (recursive)
91
+ │ └─HDecLayer: 2-22 [1, 192, 32, 336] --
92
+ │ │ └─Conv2d: 3-51 [1, 768, 8, 336] (2,654,976)
93
+ │ │ └─Identity: 3-52 [1, 768, 8, 336] --
94
+ │ │ └─DConv: 3-53 [8, 384, 336] (189,984)
95
+ │ │ └─ConvTranspose2d: 3-54 [1, 192, 36, 336] (590,016)
96
+ │ │ └─Identity: 3-55 [1, 192, 36, 336] --
97
+ ├─ModuleList: 1-22 -- (recursive)
98
+ │ └─HDecLayer: 2-23 [1, 192, 5375] --
99
+ │ │ └─Conv1d: 3-56 [1, 768, 1344] (885,504)
100
+ │ │ └─Identity: 3-57 [1, 768, 1344] --
101
+ │ │ └─DConv: 3-58 [1, 384, 1344] (189,984)
102
+ │ │ └─ConvTranspose1d: 3-59 [1, 192, 5380] (590,016)
103
+ │ │ └─Identity: 3-60 [1, 192, 5380] --
104
+ ├─ModuleList: 1-21 -- (recursive)
105
+ │ └─HDecLayer: 2-24 [1, 96, 128, 336] --
106
+ │ │ └─Conv2d: 3-61 [1, 384, 32, 336] (663,936)
107
+ │ │ └─Identity: 3-62 [1, 384, 32, 336] --
108
+ │ │ └─DConv: 3-63 [32, 192, 336] (48,912)
109
+ │ │ └─ConvTranspose2d: 3-64 [1, 96, 132, 336] (147,552)
110
+ │ │ └─Identity: 3-65 [1, 96, 132, 336] --
111
+ ���─ModuleList: 1-22 -- (recursive)
112
+ │ └─HDecLayer: 2-25 [1, 96, 21499] --
113
+ │ │ └─Conv1d: 3-66 [1, 384, 5375] (221,568)
114
+ │ │ └─Identity: 3-67 [1, 384, 5375] --
115
+ │ │ └─DConv: 3-68 [1, 192, 5375] (48,912)
116
+ │ │ └─ConvTranspose1d: 3-69 [1, 96, 21504] (147,552)
117
+ │ │ └─Identity: 3-70 [1, 96, 21504] --
118
+ ├─ModuleList: 1-21 -- (recursive)
119
+ │ └─HDecLayer: 2-26 [1, 48, 512, 336] --
120
+ │ │ └─Conv2d: 3-71 [1, 192, 128, 336] (166,080)
121
+ │ │ └─Identity: 3-72 [1, 192, 128, 336] --
122
+ │ │ └─DConv: 3-73 [128, 96, 336] (12,936)
123
+ │ │ └─ConvTranspose2d: 3-74 [1, 48, 516, 336] (36,912)
124
+ │ │ └─Identity: 3-75 [1, 48, 516, 336] --
125
+ ├─ModuleList: 1-22 -- (recursive)
126
+ │ └─HDecLayer: 2-27 [1, 48, 85995] --
127
+ │ │ └─Conv1d: 3-76 [1, 192, 21499] (55,488)
128
+ │ │ └─Identity: 3-77 [1, 192, 21499] --
129
+ │ │ └─DConv: 3-78 [1, 96, 21499] (12,936)
130
+ │ │ └─ConvTranspose1d: 3-79 [1, 48, 86000] (36,912)
131
+ │ │ └─Identity: 3-80 [1, 48, 86000] --
132
+ ├─ModuleList: 1-21 -- (recursive)
133
+ │ └─HDecLayer: 2-28 [1, 16, 2048, 336] --
134
+ │ │ └─Conv2d: 3-81 [1, 96, 512, 336] (41,568)
135
+ │ │ └─Identity: 3-82 [1, 96, 512, 336] --
136
+ │ │ └─DConv: 3-83 [512, 48, 336] (3,588)
137
+ │ │ └─ConvTranspose2d: 3-84 [1, 16, 2052, 336] (6,160)
138
+ │ │ └─Identity: 3-85 [1, 16, 2052, 336] --
139
+ ├─ModuleList: 1-22 -- (recursive)
140
+ │ └─HDecLayer: 2-29 [1, 8, 343980] --
141
+ │ │ └─Conv1d: 3-86 [1, 96, 85995] (13,920)
142
+ │ │ └─Identity: 3-87 [1, 96, 85995] --
143
+ │ │ └─DConv: 3-88 [1, 48, 85995] (3,588)
144
+ │ │ └─ConvTranspose1d: 3-89 [1, 8, 343984] (3,080)
145
+ │ │ └─Identity: 3-90 [1, 8, 343984] --
146
+ ====================================================================================================
147
+ Total params: 41,984,456
148
+ Trainable params: 0
149
+ Non-trainable params: 41,984,456
150
+ Total mult-adds (Units.GIGABYTES): 88.31
151
+ ====================================================================================================
152
+ Input size (MB): 2.12
153
+ Forward/backward pass size (MB): 6021.99
154
+ Params size (MB): 125.91
155
+ Estimated Total Size (MB): 6150.02
156
+ ====================================================================================================
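This table appears to be a torchinfo summary; a call along the following lines reproduces a table like the one above for a ~6 s stereo clip at 44.1 kHz (2 channels x 264600 samples, matching the [1, 4, 2, 264600] output shape), though the exact command used to generate this file is not part of the commit.

import torch
from torchinfo import summary
from demucs import pretrained

htdemucs = pretrained.get_model("htdemucs").models[0]
# batch of 1, stereo, 264600 samples (6 s at 44.1 kHz)
print(summary(htdemucs, input_data=torch.randn(1, 2, 264600), depth=3))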
src/models/stem_separation/__init__.py ADDED
File without changes
src/models/stem_separation/__pycache__/ATHTDemucs_v2.cpython-313.pyc ADDED
Binary file (18.6 kB).
 
src/models/stem_separation/__pycache__/AudioTextHTDemucs.cpython-313.pyc ADDED
Binary file (14.2 kB).
 
src/models/stem_separation/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (172 Bytes).
 
src/train.py ADDED
@@ -0,0 +1,610 @@
1
+ from pathlib import Path
2
+ from typing import Dict, Optional
3
+ import torch
4
+ from torch.utils.data import DataLoader, Subset
5
+ from torch.optim import AdamW
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+ from torch.cuda.amp import GradScaler, autocast
8
+ from tqdm import tqdm
9
+
10
+ from demucs import pretrained
11
+ from transformers import AutoTokenizer, ClapModel, ClapTextModelWithProjection
12
+
13
+ from src.models.stem_separation.ATHTDemucs_v2 import AudioTextHTDemucs
14
+ from src.loss import combined_loss, combined_L1_sdr_loss, sdr_loss
15
+ from src.dataloader import MusDBStemDataset, collate_fn, STEM_PROMPTS, PROMPT_TO_STEM
16
+ from utils import load_config, log_separation_spectrograms_to_wandb, log_audio_to_wandb
17
+
18
+
19
+ # ============================================================================
20
+ # Training Helper Functions
21
+ # ============================================================================
22
+
23
+ def train_epoch(
24
+ model: AudioTextHTDemucs,
25
+ dataloader: DataLoader,
26
+ optimizer: torch.optim.Optimizer,
27
+ scaler: Optional[GradScaler],
28
+ device: str,
29
+ use_amp: bool,
30
+ use_L1_cmb_loss: bool,
31
+ l1_sdr_weight: Optional[float],
32
+ l1_weight: Optional[float],
33
+ grad_clip: float,
34
+ sdr_weight: float,
35
+ sisdr_weight: float,
36
+ epoch: int,
37
+ log_every: int,
38
+ use_wandb: bool,
39
+ ) -> Dict[str, float]:
40
+ """Train for one epoch."""
41
+ model.train()
42
+
43
+ total_loss = 0.0
44
+ total_sdr = 0.0
45
+ total_sisdr = 0.0
46
+ num_batches = 0
47
+
48
+ # Set loss function
49
+ if use_L1_cmb_loss:
50
+ loss_function = combined_L1_sdr_loss
51
+ weight1 = l1_sdr_weight
52
+ if l1_weight is None:
53
+ raise ValueError("l1_weight must be provided when using L1 combination loss.")
54
+ weight2 = l1_weight
55
+ print("**Using L1 + SDR combination loss for training")
56
+ else:
57
+ loss_function = combined_loss
58
+ weight1 = sdr_weight
59
+ weight2 = sisdr_weight
60
+
61
+ pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
62
+
63
+ for batch_idx, batch in enumerate(pbar):
64
+ mixture = batch["mixture"].to(device)
65
+ target = batch["target"].to(device)
66
+ prompts = batch["prompt"]
67
+
68
+ optimizer.zero_grad()
69
+
70
+ # NOTE: the L1 + SDR combination loss option is selected above via use_L1_cmb_loss
71
+
72
+ if use_amp and device == "cuda":
73
+ with autocast():
74
+ estimated = model(mixture, prompts)
75
+ loss, metrics = loss_function(
76
+ estimated, target, weight1, weight2
77
+ )
78
+ scaler.scale(loss).backward()
79
+ scaler.unscale_(optimizer)
80
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
81
+ scaler.step(optimizer)
82
+ scaler.update()
83
+ else:
84
+ estimated = model(mixture, prompts)
85
+ loss, metrics = loss_function(
86
+ estimated, target, weight1, weight2
87
+ )
88
+ loss.backward()
89
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
90
+ optimizer.step()
91
+
92
+ total_loss += metrics["loss/total"]
93
+ total_sdr += metrics["metrics/sdr"]
94
+ total_sisdr += metrics["metrics/sisdr"]
95
+ num_batches += 1
96
+
97
+ pbar.set_postfix({
98
+ "loss": f"{metrics['loss/total']:.4f}",
99
+ "SDR": f"{metrics['metrics/sdr']:.2f}",
100
+ })
101
+
102
+ if use_wandb and batch_idx % log_every == 0:
103
+ import wandb
104
+ wandb.log({
105
+ "train/loss": metrics["loss/total"],
106
+ "train/sdr": metrics["metrics/sdr"],
107
+ "train/sisdr": metrics["metrics/sisdr"],
108
+ "train/step": epoch * len(dataloader) + batch_idx,
109
+ })
110
+ # Plot spectrograms for first sample in batch and log to wandb
111
+ # NOTE: For now, only 1 extracted stem is visualized (should be extended to all stems later)
112
+ stem_name_log = PROMPT_TO_STEM[prompts[0]]
113
+ log_separation_spectrograms_to_wandb(
114
+ mixture=mixture[0],
115
+ estimated=estimated[0],
116
+ reference=target[0],
117
+ stem_name=stem_name_log,
118
+ step=epoch * len(dataloader) + batch_idx,
119
+ )
120
+ # Log audio to wandb
121
+ log_audio_to_wandb(mixture[0], "mixture", is_gt=True)
122
+ log_audio_to_wandb(target[0], stem_name_log, is_gt=True)
123
+ log_audio_to_wandb(estimated[0], stem_name_log, is_gt=False)
124
+
125
+ return {
126
+ "loss": total_loss / num_batches,
127
+ "sdr": total_sdr / num_batches,
128
+ "sisdr": total_sisdr / num_batches,
129
+ }
130
+
131
+
132
+ @torch.no_grad()
133
+ def validate(
134
+ model: AudioTextHTDemucs,
135
+ dataloader: DataLoader,
136
+ device: str,
137
+ use_amp: bool,
138
+ use_L1_cmb_loss: bool,
139
+ l1_sdr_weight: Optional[float],
140
+ l1_weight: Optional[float],
141
+ sdr_weight: float = 0.9,
142
+ sisdr_weight: float = 0.1,
143
+ ) -> Dict[str, float]:
144
+ """Validate the model."""
145
+ model.eval()
146
+
147
+ total_loss = 0.0
148
+ total_sdr = 0.0
149
+ total_sisdr = 0.0
150
+ num_batches = 0
151
+
152
+ stem_metrics = {name: {"sdr": 0.0, "count": 0} for name in STEM_PROMPTS.keys()}
153
+
154
+ # Set loss function
155
+ if use_L1_cmb_loss:
156
+ loss_function = combined_L1_sdr_loss
157
+ weight1 = l1_sdr_weight
158
+ if l1_weight is None:
159
+ raise ValueError("l1_weight must be provided when using L1 combination loss.")
160
+ weight2 = l1_weight
161
+ else:
162
+ loss_function = combined_loss
163
+ weight1 = sdr_weight
164
+ weight2 = sisdr_weight
165
+
166
+ for batch in tqdm(dataloader, desc="Validating"):
167
+ mixture = batch["mixture"].to(device)
168
+ target = batch["target"].to(device)
169
+ prompts = batch["prompt"]
170
+ stem_names = batch["stem_name"]
171
+
172
+ if use_amp and device == "cuda":
173
+ with autocast():
174
+ estimated = model(mixture, prompts)
175
+ loss, metrics = loss_function(estimated, target, weight1, weight2)
176
+ else:
177
+ estimated = model(mixture, prompts)
178
+ loss, metrics = loss_function(estimated, target, weight1, weight2)
179
+
180
+ total_loss += metrics["loss/total"]
181
+ total_sdr += metrics["metrics/sdr"]
182
+ total_sisdr += metrics["metrics/sisdr"]
183
+ num_batches += 1
184
+
185
+ for i, stem_name in enumerate(stem_names):
186
+ est_i = estimated[i:i + 1]
187
+ tgt_i = target[i:i + 1]
188
+ sdr_i = -sdr_loss(est_i, tgt_i).item()
189
+ stem_metrics[stem_name]["sdr"] += sdr_i
190
+ stem_metrics[stem_name]["count"] += 1
191
+
192
+ avg_metrics = {
193
+ "loss": total_loss / num_batches,
194
+ "sdr": total_sdr / num_batches,
195
+ "sisdr": total_sisdr / num_batches,
196
+ }
197
+
198
+ for stem_name, data in stem_metrics.items():
199
+ if data["count"] > 0:
200
+ avg_metrics[f"sdr/{stem_name}"] = data["sdr"] / data["count"]
201
+
202
+ return avg_metrics
203
+
204
+
205
+ def save_checkpoint(
206
+ model: AudioTextHTDemucs,
207
+ optimizer: torch.optim.Optimizer,
208
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
209
+ epoch: int,
210
+ metrics: Dict[str, float],
211
+ checkpoint_dir: str,
212
+ is_best: bool = False,
213
+ ):
214
+ """Save a training checkpoint."""
215
+ checkpoint_path = Path(checkpoint_dir)
216
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
217
+
218
+ checkpoint = {
219
+ "epoch": epoch,
220
+ "model_state_dict": model.state_dict(),
221
+ "optimizer_state_dict": optimizer.state_dict(),
222
+ "scheduler_state_dict": scheduler.state_dict(),
223
+ "metrics": metrics,
224
+ }
225
+
226
+ path = checkpoint_path / f"checkpoint_epoch_{epoch}.pt"
227
+ torch.save(checkpoint, path)
228
+ print(f"Saved checkpoint to {path}")
229
+
230
+ if is_best:
231
+ best_path = checkpoint_path / "best_model.pt"
232
+ torch.save(checkpoint, best_path)
233
+ print(f"Saved best model to {best_path}")
234
+
235
+ latest_path = checkpoint_path / "latest.pt"
236
+ torch.save(checkpoint, latest_path)
237
+
238
+
239
+ def load_checkpoint(
240
+ model: AudioTextHTDemucs,
241
+ optimizer: Optional[torch.optim.Optimizer],
242
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
243
+ checkpoint_path: str,
244
+ ) -> int:
245
+ """
246
+ Load a checkpoint and return the epoch number.
247
+
248
+ Ignores any unused weights (e.g. if ClapTextModelWithProjection is being used but checkpoint has ClapModel with audio encoder weights).
249
+ Also applies to optimizer and scheduler.
250
+ """
251
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
252
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
253
+
254
+ # Try loading optimizer and scheduler state, but ignore mismatches (due to new CLAP model, etc)
255
+ try:
256
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
257
+ except Exception as e:
258
+ print("Skipping optimizer state...")
259
+
260
+ # Same idea for scheduler
261
+ try:
262
+ scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
263
+ except Exception:
264
+ print("Skipping scheduler state...")
265
+
266
+ print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
267
+ return checkpoint["epoch"]
268
+
269
+
270
+ # ============================================================================
271
+ # Main Training Function
272
+ # ============================================================================
273
+
274
+ def train(config_path):
275
+ """
276
+ Main training function for AudioTextHTDemucs.
277
+
278
+ Args (loaded from YAML config):
279
+ train_dir: Path to training data directory
280
+ test_dir: Path to test/validation data directory
281
+ checkpoint_dir: Path to save checkpoints
282
+ sample_rate: Audio sample rate
283
+ segment_seconds: Length of audio segments in seconds
284
+ batch_size: Training batch size
285
+ num_workers: Number of dataloader workers
286
+ epochs: Number of training epochs
287
+ learning_rate: Initial learning rate
288
+ weight_decay: AdamW weight decay
289
+ grad_clip: Gradient clipping value
290
+ sdr_weight: Weight for SDR loss component
291
+ sisdr_weight: Weight for SI-SDR loss component
292
+ model_dim: Model hidden dimension
293
+ text_dim: Text embedding dimension
294
+ n_heads: Number of attention heads
295
+ use_wandb: Whether to use Weights & Biases logging
296
+ wandb_project: W&B project name
297
+ wandb_run_name: W&B run name (optional)
298
+ log_every: Log training metrics every N batches
299
+ validate_every: Run validation every N epochs
300
+ save_every: Save checkpoint every N epochs
301
+ use_amp: Use automatic mixed precision
302
+ device: Device to train on (auto-detected if None)
303
+ resume_from: Path to checkpoint to resume from (optional)
304
+
305
+ Returns:
306
+ Dict containing final metrics and best SDR achieved
307
+ """
308
+ # Load configuration
309
+ cfg = load_config(config_path)
310
+ data_cfg = cfg["data"]
311
+ model_cfg = cfg["model"]
312
+ training_cfg = cfg["training"]
313
+ wandb_cfg = cfg["wandb"]
314
+ # Paths
315
+ train_dir = data_cfg.get("train_dir", "../data/train")
316
+ test_dir = data_cfg.get("test_dir", "../data/test")
317
+ checkpoint_dir = wandb_cfg.get("checkpoint_dir", "../checkpoints")
318
+ # Data splits
319
+ pct_train = data_cfg.get("pct_train", 1.0)
320
+ pct_test = data_cfg.get("pct_test", 1.0)
321
+ # Audio parameters
322
+ sample_rate = data_cfg.get("sample_rate", 44100)
323
+ segment_seconds = data_cfg.get("segment_seconds", 6.0)
324
+ # Training parameters
325
+ batch_size = training_cfg.get("batch_size", 4)
326
+ num_workers = training_cfg.get("num_workers", 0)
327
+ epochs = training_cfg.get("num_epochs", 10)
328
+ learning_rate = float(training_cfg["optimizer"].get("lr", 1e-4))
329
+ weight_decay = float(training_cfg["optimizer"].get("weight_decay", 1e-5))
330
+ grad_clip = training_cfg["optimizer"].get("grad_clip", 1.0)
331
+ use_L1_cmb_loss = training_cfg.get("use_L1_comb_loss", False)
332
+ l1_sdr_weight = training_cfg["L1_comb_loss"].get("sdr_weight", 1.0)
333
+ l1_weight = training_cfg["L1_comb_loss"].get("l1_weight", 0.05)
334
+ # Loss weights
335
+ sdr_weight = training_cfg["loss_weights"].get("sdr", 0.9)
336
+ sisdr_weight = training_cfg["loss_weights"].get("sisdr", 0.1)
337
+ # Model parameters
338
+ model_dim = model_cfg.get("model_dim", 384)
339
+ text_dim = model_cfg.get("text_dim", 512)
340
+ n_heads = model_cfg.get("n_heads", 8)
341
+ # Logging
342
+ use_wandb = wandb_cfg.get("use_wandb", True)
343
+ wandb_project = wandb_cfg.get("project", "audio-text-htdemucs")
344
+ wandb_run_name = wandb_cfg.get("run_name", None)
345
+ log_every = wandb_cfg.get("log_every", 50)
346
+ validate_every = wandb_cfg.get("validate_every", 1)
347
+ save_every = wandb_cfg.get("save_every", 1)
348
+ # Mixed precision
349
+ use_amp = training_cfg.get("use_amp", False)
350
+ # Device
351
+ device = model_cfg.get("device", None)
352
+ # Resume training
353
+ resume_from = training_cfg.get("resume_from", None)
354
+
355
+ # Auto-detect device
356
+ if device is None:
357
+ device = "cuda" if torch.cuda.is_available() else "cpu"
358
+
359
+ segment_samples = int(sample_rate * segment_seconds)
360
+
361
+ # Initialize wandb
362
+ if use_wandb:
363
+ import wandb
364
+ wandb.init(
365
+ project=wandb_project,
366
+ name=wandb_run_name,
367
+ config={
368
+ "train_dir": train_dir,
369
+ "test_dir": test_dir,
370
+ "sample_rate": sample_rate,
371
+ "segment_seconds": segment_seconds,
372
+ "batch_size": batch_size,
373
+ "epochs": epochs,
374
+ "learning_rate": learning_rate,
375
+ "weight_decay": weight_decay,
376
+ "grad_clip": grad_clip,
377
+ "sdr_weight": sdr_weight,
378
+ "sisdr_weight": sisdr_weight,
379
+ "model_dim": model_dim,
380
+ "text_dim": text_dim,
381
+ "n_heads": n_heads,
382
+ "use_amp": use_amp,
383
+ },
384
+ )
385
+
386
+ print("=" * 60)
387
+ print("Audio-Text HTDemucs Training")
388
+ print("=" * 60)
389
+ print(f"Device: {device}")
390
+ print(f"Train directory: {train_dir}")
391
+ print(f"Test directory: {test_dir}")
392
+ print(f"Segment length: {segment_seconds}s ({segment_samples} samples)")
393
+ print(f"Batch size: {batch_size}")
394
+ print(f"Epochs: {epochs}")
395
+ print(f"Learning rate: {learning_rate}")
396
+ print("=" * 60)
397
+
398
+ # Load pretrained models
399
+ print("Loading pretrained HTDemucs...")
400
+ htdemucs = pretrained.get_model('htdemucs').models[0]
401
+
402
+ print("Loading CLAP model...")
403
+ #clap = ClapModel.from_pretrained("laion/clap-htsat-unfused")
404
+ clap = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused") # More memory efficient than loading full ClapModel (text + audio)
405
+ tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
406
+
407
+ # Create model
408
+ print("Building AudioTextHTDemucs model...")
409
+ model = AudioTextHTDemucs(
410
+ htdemucs_model=htdemucs,
411
+ clap_encoder=clap,
412
+ clap_tokenizer=tokenizer,
413
+ model_dim=model_dim,
414
+ text_dim=text_dim,
415
+ num_heads=n_heads,
416
+ )
417
+ model = model.to(device)
418
+
419
+ # Count parameters
420
+ total_params = sum(p.numel() for p in model.parameters())
421
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
422
+ print(f"Total parameters: {total_params:,}")
423
+ print(f"Trainable parameters: {trainable_params:,}")
424
+
425
+ # Create datasets
426
+ print("Creating datasets...")
427
+ train_dataset = MusDBStemDataset(
428
+ root_dir=train_dir,
429
+ segment_samples=segment_samples,
430
+ sample_rate=sample_rate,
431
+ random_segments=True,
432
+ augment=True,
433
+ )
434
+
435
+ val_dataset = MusDBStemDataset(
436
+ root_dir=test_dir,
437
+ segment_samples=segment_samples,
438
+ sample_rate=sample_rate,
439
+ random_segments=False,
440
+ augment=False,
441
+ )
442
+
443
+ # Create subsets if specified
444
+ if 0.0 < pct_train < 1.0:
445
+ num_train = int(len(train_dataset) * pct_train)
446
+ train_idxs = torch.randperm(len(train_dataset))[:num_train]
447
+ train_subset = Subset(train_dataset, train_idxs)
448
+
449
+ if 0.0 < pct_test < 1.0:
450
+ num_val = int(len(val_dataset) * pct_test)
451
+ val_idxs = torch.randperm(len(val_dataset))[:num_val]
452
+ val_subset = Subset(val_dataset, val_idxs)
453
+
454
+
455
+ # Create dataloaders
456
+ train_loader = DataLoader(
457
+ train_dataset if pct_train >= 1.0 else train_subset,
458
+ batch_size=batch_size,
459
+ shuffle=True,
460
+ num_workers=num_workers,
461
+ collate_fn=collate_fn,
462
+ pin_memory=(device == "cuda"),
463
+ drop_last=True,
464
+ )
465
+
466
+ val_loader = DataLoader(
467
+ val_dataset if pct_test >= 1.0 else val_subset,
468
+ batch_size=batch_size,
469
+ shuffle=False,
470
+ num_workers=num_workers,
471
+ collate_fn=collate_fn,
472
+ pin_memory=(device == "cuda"),
473
+ )
474
+
475
+ # Optimizer and scheduler
476
+ optimizer = AdamW(
477
+ model.parameters(),
478
+ lr=learning_rate,
479
+ weight_decay=weight_decay,
480
+ betas=(0.9, 0.999),
481
+ )
482
+
483
+ scheduler = CosineAnnealingLR(
484
+ optimizer,
485
+ T_max=epochs,
486
+ eta_min=learning_rate * 0.01,
487
+ )
488
+
489
+ # Mixed precision scaler
490
+ scaler = GradScaler() if use_amp and device == "cuda" else None
491
+
492
+ # Resume from checkpoint
493
+ start_epoch = 0
494
+ best_sdr = -float("inf")
495
+
496
+ if resume_from is not None:
497
+ resume_path = Path(resume_from)
498
+ if resume_path.exists():
499
+ print(f"Resuming from {resume_path}")
500
+ start_epoch = load_checkpoint(model, optimizer, scheduler, str(resume_path))
501
+ start_epoch += 1
502
+ else:
503
+ # Check for latest checkpoint
504
+ latest_checkpoint = Path(checkpoint_dir) / "latest.pt"
505
+ if latest_checkpoint.exists():
506
+ print(f"Found latest checkpoint at {latest_checkpoint}")
507
+ start_epoch = load_checkpoint(model, optimizer, scheduler, str(latest_checkpoint))
508
+ start_epoch += 1
509
+
510
+ # Training loop
511
+ print("\nStarting training...")
512
+ for epoch in range(start_epoch, epochs):
513
+ print(f"\n{'=' * 60}")
514
+ print(f"Epoch {epoch + 1}/{epochs}")
515
+ print(f"Learning rate: {scheduler.get_last_lr()[0]:.2e}")
516
+ print(f"{'=' * 60}")
517
+
518
+ # Train
519
+ train_metrics = train_epoch(
520
+ model=model,
521
+ dataloader=train_loader,
522
+ optimizer=optimizer,
523
+ scaler=scaler,
524
+ device=device,
525
+ use_amp=use_amp,
526
+ use_L1_cmb_loss=use_L1_cmb_loss,
527
+ l1_sdr_weight=l1_sdr_weight,
528
+ l1_weight=l1_weight,
529
+ grad_clip=grad_clip,
530
+ sdr_weight=sdr_weight,
531
+ sisdr_weight=sisdr_weight,
532
+ epoch=epoch,
533
+ log_every=log_every,
534
+ use_wandb=use_wandb,
535
+ )
536
+ print(f"Train - Loss: {train_metrics['loss']:.4f}, SDR: {train_metrics['sdr']:.2f} dB")
537
+
538
+ # Step scheduler
539
+ scheduler.step()
540
+
541
+ # Validate
542
+ if (epoch + 1) % validate_every == 0:
543
+ val_metrics = validate(
544
+ model=model,
545
+ dataloader=val_loader,
546
+ device=device,
547
+ use_amp=use_amp,
548
+ use_L1_cmb_loss=use_L1_cmb_loss,
549
+ l1_sdr_weight=l1_sdr_weight,
550
+ l1_weight=l1_weight,
551
+ sdr_weight=sdr_weight,
552
+ sisdr_weight=sisdr_weight,
553
+ )
554
+ print(f"Val - Loss: {val_metrics['loss']:.4f}, SDR: {val_metrics['sdr']:.2f} dB")
555
+
556
+ for stem_name in STEM_PROMPTS.keys():
557
+ if f"sdr/{stem_name}" in val_metrics:
558
+ print(f" {stem_name}: {val_metrics[f'sdr/{stem_name}']:.2f} dB")
559
+
560
+ if use_wandb:
561
+ import wandb
562
+ wandb.log({
563
+ "val/loss": val_metrics["loss"],
564
+ "val/sdr": val_metrics["sdr"],
565
+ "val/sisdr": val_metrics["sisdr"],
566
+ **{f"val/{k}": v for k, v in val_metrics.items() if k.startswith("sdr/")},
567
+ "epoch": epoch + 1,
568
+ })
569
+
570
+ is_best = val_metrics["sdr"] > best_sdr
571
+ if is_best:
572
+ best_sdr = val_metrics["sdr"]
573
+ print(f"New best SDR: {best_sdr:.2f} dB")
574
+ else:
575
+ val_metrics = {}
576
+ is_best = False
577
+
578
+ # Save checkpoint
579
+ if (epoch + 1) % save_every == 0 or is_best:
580
+ save_checkpoint(
581
+ model, optimizer, scheduler, epoch + 1,
582
+ {**train_metrics, **val_metrics},
583
+ checkpoint_dir, is_best
584
+ )
585
+ else:
586
+ save_checkpoint(
587
+ model, optimizer, scheduler, epoch + 1,
588
+ {**train_metrics, **val_metrics},
589
+ checkpoint_dir, is_best=False
590
+ )
591
+
592
+ print("\n" + "=" * 60)
593
+ print("Training complete!")
594
+ print(f"Best validation SDR: {best_sdr:.2f} dB")
595
+ print("=" * 60)
596
+
597
+ if use_wandb:
598
+ import wandb
599
+ wandb.finish()
600
+
601
+ return {
602
+ "final_train_metrics": train_metrics,
603
+ "final_val_metrics": val_metrics,
604
+ "best_sdr": best_sdr,
605
+ }
606
+
607
+
608
+ if __name__ == "__main__":
609
+ # Example: run training; dataset paths (e.g. /home/jacob/datasets/musdb18/{train,test}) and other
+ # hyperparameters are read from the YAML config, since train() takes only a config path
610
+ train("config.yaml")
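config.yaml itself is not included in this diff; as a rough guide, the keys that train() above actually reads (using the same fallback defaults it applies) can be written out like this before calling train("config.yaml"). The exact values here are illustrative assumptions.

import yaml

cfg = {
    "data": {"train_dir": "../data/train", "test_dir": "../data/test",
             "sample_rate": 44100, "segment_seconds": 6.0, "pct_train": 1.0, "pct_test": 1.0},
    "model": {"model_dim": 384, "text_dim": 512, "n_heads": 8, "device": None},
    "training": {"batch_size": 4, "num_workers": 0, "num_epochs": 10, "use_amp": False,
                 "resume_from": None, "use_L1_comb_loss": False,
                 "optimizer": {"lr": 1.0e-4, "weight_decay": 1.0e-5, "grad_clip": 1.0},
                 "L1_comb_loss": {"sdr_weight": 1.0, "l1_weight": 0.05},
                 "loss_weights": {"sdr": 0.9, "sisdr": 0.1}},
    "wandb": {"use_wandb": False, "project": "audio-text-htdemucs", "run_name": None,
              "checkpoint_dir": "../checkpoints", "log_every": 50,
              "validate_every": 1, "save_every": 1},
}
with open("config.yaml", "w") as f:
    yaml.safe_dump(cfg, f)  # then: from src.train import train; train("config.yaml")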
utils.py ADDED
@@ -0,0 +1,968 @@
1
+
2
+ from typing import Union, Optional, Dict, List
3
+ from pathlib import Path
4
+ import yaml
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib
11
+ matplotlib.use('Agg') # Non-interactive backend for server/training use
12
+
13
+
14
+ # ============================================================================
15
+ # YAML Config
16
+ # ============================================================================
17
+
18
+ def load_config(file_path: Union[str, Path]) -> dict:
19
+ """Load a YAML configuration file."""
20
+ with open(file_path, 'r') as f:
21
+ config = yaml.safe_load(f)
22
+
23
+ return config
24
+
25
+
26
+ # ============================================================================
27
+ # Spectrogram Utilities
28
+ # ============================================================================
29
+
30
+ def compute_spectrogram(
31
+ waveform: torch.Tensor,
32
+ n_fft: int = 2048,
33
+ hop_length: int = 512,
34
+ power: float = 2.0,
35
+ to_db: bool = True,
36
+ top_db: float = 80.0,
37
+ ) -> torch.Tensor:
38
+ """
39
+ Compute spectrogram from waveform using STFT.
40
+
41
+ Args:
42
+ waveform: (C, T) or (T,) audio waveform
43
+ n_fft: FFT window size
44
+ hop_length: Hop length between frames
45
+ power: Exponent for magnitude (1.0 for magnitude, 2.0 for power)
46
+ to_db: Convert to decibel scale
47
+ top_db: Threshold for dynamic range in dB
48
+
49
+ Returns:
50
+ (F, T') spectrogram tensor
51
+ """
52
+ # Handle stereo by taking mean to mono
53
+ if waveform.dim() == 2:
54
+ waveform = waveform.mean(dim=0) # (T,)
55
+
56
+ # Move to CPU for STFT computation
57
+ waveform = waveform.cpu()
58
+
59
+ # Compute STFT
60
+ window = torch.hann_window(n_fft)
61
+ stft = torch.stft(
62
+ waveform,
63
+ n_fft=n_fft,
64
+ hop_length=hop_length,
65
+ win_length=n_fft,
66
+ window=window,
67
+ return_complex=True,
68
+ center=True,
69
+ pad_mode='reflect'
70
+ )
71
+
72
+ # Compute magnitude spectrogram
73
+ spec = torch.abs(stft).pow(power)
74
+
75
+ # Convert to dB
76
+ if to_db:
77
+ spec = amplitude_to_db(spec, top_db=top_db)
78
+
79
+ return spec
80
+
81
+
82
+ def amplitude_to_db(
83
+ spec: torch.Tensor,
84
+ ref: float = 1.0,
85
+ amin: float = 1e-10,
86
+ top_db: float = 80.0,
87
+ ) -> torch.Tensor:
88
+ """Convert amplitude/power spectrogram to decibel scale."""
89
+ spec_db = 10.0 * torch.log10(torch.clamp(spec, min=amin) / ref)
90
+
91
+ # Clip to top_db range
92
+ max_val = spec_db.max()
93
+ spec_db = torch.clamp(spec_db, min=max_val - top_db)
94
+
95
+ return spec_db
96
+
97
+
98
+ def plot_spectrogram(
99
+ spec: torch.Tensor,
100
+ sample_rate: int = 44100,
101
+ hop_length: int = 512,
102
+ title: str = "Spectrogram",
103
+ figsize: tuple = (10, 4),
104
+ cmap: str = "magma",
105
+ colorbar: bool = True,
106
+ ) -> plt.Figure:
107
+ """
108
+ Plot a single spectrogram.
109
+
110
+ Args:
111
+ spec: (F, T) spectrogram tensor (in dB scale)
112
+ sample_rate: Audio sample rate
113
+ hop_length: Hop length used for STFT
114
+ title: Plot title
115
+ figsize: Figure size
116
+ cmap: Colormap for spectrogram
117
+ colorbar: Whether to show colorbar
118
+
119
+ Returns:
120
+ matplotlib Figure object
121
+ """
122
+ spec_np = spec.detach().cpu().numpy() if isinstance(spec, torch.Tensor) else spec
123
+
124
+ fig, ax = plt.subplots(figsize=figsize)
125
+
126
+ # Compute time and frequency axes
127
+ n_frames = spec_np.shape[1]
128
+ n_freqs = spec_np.shape[0]
129
+ time_max = n_frames * hop_length / sample_rate
130
+ freq_max = sample_rate / 2 # Nyquist frequency
131
+
132
+ img = ax.imshow(
133
+ spec_np,
134
+ aspect='auto',
135
+ origin='lower',
136
+ cmap=cmap,
137
+ extent=[0, time_max, 0, freq_max / 1000] # freq in kHz
138
+ )
139
+
140
+ ax.set_xlabel('Time (s)')
141
+ ax.set_ylabel('Frequency (kHz)')
142
+ ax.set_title(title)
143
+
144
+ if colorbar:
145
+ cbar = fig.colorbar(img, ax=ax, format='%+2.0f dB')
146
+ cbar.set_label('Magnitude (dB)')
147
+
148
+ fig.tight_layout()
149
+ return fig
150
+
151
+
152
+ def plot_spectrogram_comparison(
153
+ spectrograms: Dict[str, torch.Tensor],
154
+ sample_rate: int = 44100,
155
+ hop_length: int = 512,
156
+ figsize: tuple = (14, 3),
157
+ cmap: str = "magma",
158
+ suptitle: Optional[str] = None,
159
+ ) -> plt.Figure:
160
+ """
161
+ Plot multiple spectrograms side by side for comparison.
162
+
163
+ Args:
164
+ spectrograms: Dict mapping names to spectrogram tensors
165
+ sample_rate: Audio sample rate
166
+ hop_length: Hop length used for STFT
167
+ figsize: Figure size (width, height per row)
168
+ cmap: Colormap for spectrograms
169
+ suptitle: Super title for the figure
170
+
171
+ Returns:
172
+ matplotlib Figure object
173
+ """
174
+ n_specs = len(spectrograms)
175
+ fig, axes = plt.subplots(
176
+ 1, n_specs,
177
+ figsize=(figsize[0], figsize[1]),
178
+ constrained_layout=True # Better layout handling with colorbars
179
+ )
180
+
181
+ if n_specs == 1:
182
+ axes = [axes]
183
+
184
+ # Find global min/max for consistent colorbar
185
+ all_specs = [s.detach().cpu().numpy() if isinstance(s, torch.Tensor) else s
186
+ for s in spectrograms.values()]
187
+ vmin = min(s.min() for s in all_specs)
188
+ vmax = max(s.max() for s in all_specs)
189
+
190
+ for ax, (name, spec) in zip(axes, spectrograms.items()):
191
+ spec_np = spec.detach().cpu().numpy() if isinstance(spec, torch.Tensor) else spec
192
+
193
+ n_frames = spec_np.shape[1]
194
+ time_max = n_frames * hop_length / sample_rate
195
+ freq_max = sample_rate / 2
196
+
197
+ img = ax.imshow(
198
+ spec_np,
199
+ aspect='auto',
200
+ origin='lower',
201
+ cmap=cmap,
202
+ extent=[0, time_max, 0, freq_max / 1000],
203
+ vmin=vmin,
204
+ vmax=vmax,
205
+ )
206
+
207
+ ax.set_xlabel('Time (s)')
208
+ ax.set_ylabel('Frequency (kHz)')
209
+ ax.set_title(name)
210
+
211
+ # Add single colorbar
212
+ fig.colorbar(img, ax=axes, format='%+2.0f dB', label='Magnitude (dB)')
213
+
214
+ if suptitle:
215
+ fig.suptitle(suptitle, fontsize=12)
216
+
217
+ return fig
218
+
219
+
220
+ def plot_separation_spectrograms(
221
+ mixture: torch.Tensor,
222
+ estimated: torch.Tensor,
223
+ reference: torch.Tensor,
224
+ stem_name: str = "stem",
225
+ sample_rate: int = 44100,
226
+ n_fft: int = 2048,
227
+ hop_length: int = 512,
228
+ ) -> plt.Figure:
229
+ """
230
+ Create a comparison spectrogram plot for stem separation.
231
+ Shows mixture, estimated, reference, and difference.
232
+
233
+ Args:
234
+ mixture: (C, T) mixture waveform
235
+ estimated: (C, T) estimated stem waveform
236
+ reference: (C, T) ground truth stem waveform
237
+ stem_name: Name of the stem for title
238
+ sample_rate: Audio sample rate
239
+ n_fft: FFT window size
240
+ hop_length: Hop length
241
+
242
+ Returns:
243
+ matplotlib Figure object
244
+ """
245
+ # Compute spectrograms
246
+ spec_mix = compute_spectrogram(mixture, n_fft=n_fft, hop_length=hop_length)
247
+ spec_est = compute_spectrogram(estimated, n_fft=n_fft, hop_length=hop_length)
248
+ spec_ref = compute_spectrogram(reference, n_fft=n_fft, hop_length=hop_length)
249
+
250
+ # Create comparison plot
251
+ spectrograms = {
252
+ "Mixture": spec_mix,
253
+ f"Estimated ({stem_name})": spec_est,
254
+ f"Ground Truth ({stem_name})": spec_ref,
255
+ }
256
+
257
+ fig = plot_spectrogram_comparison(
258
+ spectrograms,
259
+ sample_rate=sample_rate,
260
+ hop_length=hop_length,
261
+ suptitle=f"Stem Separation: {stem_name.capitalize()}"
262
+ )
263
+
264
+ return fig
265
+
266
+
267
+ def plot_all_stems_spectrograms(
268
+ mixture: torch.Tensor,
269
+ estimated_stems: Dict[str, torch.Tensor],
270
+ reference_stems: Dict[str, torch.Tensor],
271
+ sample_rate: int = 44100,
272
+ n_fft: int = 2048,
273
+ hop_length: int = 512,
274
+ figsize: tuple = (16, 12),
275
+ ) -> plt.Figure:
276
+ """
277
+ Create a grid of spectrograms for all stems.
278
+
279
+ Args:
280
+ mixture: (C, T) mixture waveform
281
+ estimated_stems: Dict mapping stem names to estimated (C, T) waveforms
282
+ reference_stems: Dict mapping stem names to reference (C, T) waveforms
283
+ sample_rate: Audio sample rate
284
+ n_fft: FFT window size
285
+ hop_length: Hop length
286
+ figsize: Figure size
287
+
288
+ Returns:
289
+ matplotlib Figure object
290
+ """
291
+ stem_names = list(estimated_stems.keys())
292
+ n_stems = len(stem_names)
293
+
294
+ # Create grid: rows = stems, cols = [Estimated, Ground Truth]
295
+ fig, axes = plt.subplots(
296
+ n_stems, 2,
297
+ figsize=figsize,
298
+ constrained_layout=True # Better layout handling with colorbars
299
+ )
300
+
301
+ if n_stems == 1:
302
+ axes = axes.reshape(1, -1)
303
+
304
+ # Compute all spectrograms and find global min/max for consistent colorbar
305
+ all_specs = []
306
+ spec_data = {}
307
+
308
+ for stem_name in stem_names:
309
+ spec_est = compute_spectrogram(
310
+ estimated_stems[stem_name], n_fft=n_fft, hop_length=hop_length
311
+ )
312
+ spec_ref = compute_spectrogram(
313
+ reference_stems[stem_name], n_fft=n_fft, hop_length=hop_length
314
+ )
315
+ spec_data[stem_name] = {'est': spec_est, 'ref': spec_ref}
316
+ all_specs.extend([spec_est.cpu().numpy(), spec_ref.cpu().numpy()])
317
+
318
+ vmin = min(s.min() for s in all_specs)
319
+ vmax = max(s.max() for s in all_specs)
320
+
321
+ for row, stem_name in enumerate(stem_names):
322
+ spec_est = spec_data[stem_name]['est']
323
+ spec_ref = spec_data[stem_name]['ref']
324
+
325
+ # Get time extent
326
+ n_frames = spec_est.shape[1]
327
+ time_max = n_frames * hop_length / sample_rate
328
+ freq_max = sample_rate / 2
329
+
330
+ # Plot estimated
331
+ spec_np = spec_est.detach().cpu().numpy()
332
+ axes[row, 0].imshow(
333
+ spec_np, aspect='auto', origin='lower', cmap='magma',
334
+ extent=[0, time_max, 0, freq_max / 1000],
335
+ vmin=vmin, vmax=vmax
336
+ )
337
+ axes[row, 0].set_title(f'{stem_name.capitalize()} - Estimated')
338
+ axes[row, 0].set_ylabel('Freq (kHz)')
339
+
340
+ # Plot reference
341
+ spec_np = spec_ref.detach().cpu().numpy()
342
+ img = axes[row, 1].imshow(
343
+ spec_np, aspect='auto', origin='lower', cmap='magma',
344
+ extent=[0, time_max, 0, freq_max / 1000],
345
+ vmin=vmin, vmax=vmax
346
+ )
347
+ axes[row, 1].set_title(f'{stem_name.capitalize()} - Ground Truth')
348
+
349
+ # Set x labels on bottom row
350
+ axes[-1, 0].set_xlabel('Time (s)')
351
+ axes[-1, 1].set_xlabel('Time (s)')
352
+
353
+ fig.colorbar(img, ax=axes, format='%+2.0f dB', label='Magnitude (dB)')
354
+ fig.suptitle('Stem Separation Results', fontsize=14)
355
+
356
+ return fig
357
+
358
+
359
+ # ============================================================================
360
+ # Weights & Biases Logging Utilities
361
+ # ============================================================================
362
+
363
+ def log_spectrogram_to_wandb(
364
+ fig: plt.Figure,
365
+ key: str = "spectrogram",
366
+ step: Optional[int] = None,
367
+ caption: Optional[str] = None,
368
+ ):
369
+ """
370
+ Log a matplotlib figure as an image to W&B.
371
+
372
+ Args:
373
+ fig: matplotlib Figure object
374
+ key: W&B log key
375
+ step: Training step (optional)
376
+ caption: Image caption
377
+ """
378
+ import wandb
379
+
380
+ # Convert figure to W&B Image
381
+ wandb_img = wandb.Image(fig, caption=caption)
382
+
383
+ log_dict = {key: wandb_img}
384
+ if step is not None:
385
+ wandb.log(log_dict, step=step)
386
+ else:
387
+ wandb.log(log_dict)
388
+
389
+ # Close the figure to free memory
390
+ plt.close(fig)
391
+
392
+ def log_audio_to_wandb(
393
+ audio: torch.Tensor,
394
+ stem_name: str,
395
+ is_gt: bool,
396
+ sample_rate: int = 44100
397
+ ):
398
+ """
399
+ Log audio waveform to W&B.
400
+
401
+ Args:
402
+ audio: (C, T) audio waveform tensor
403
+ stem_name: Name of the stem
404
+ is_gt: Whether this is ground truth audio (or extracted audio)
405
+ sample_rate: Audio sample rate
406
+ """
407
+ import wandb
408
+
409
+ # Convert to numpy
410
+ audio_np = audio.detach().cpu().numpy().T # (T, C)
411
+ title = f"true_{stem_name}" if is_gt else f"extracted_{stem_name}"
412
+ keyname = f"audio/{title}"
413
+ wandb.log({
414
+ keyname: wandb.Audio(
415
+ audio_np,
416
+ sample_rate=sample_rate,
417
+ caption=title
418
+ )
419
+ })
420
+
421
+ def log_separation_spectrograms_to_wandb(
422
+ mixture: torch.Tensor,
423
+ estimated: torch.Tensor,
424
+ reference: torch.Tensor,
425
+ stem_name: str,
426
+ step: Optional[int] = None,
427
+ sample_rate: int = 44100,
428
+ ):
429
+ """
430
+ Log stem separation spectrograms to W&B.
431
+
432
+ Args:
433
+ mixture: (C, T) mixture waveform
434
+ estimated: (C, T) estimated stem waveform
435
+ reference: (C, T) ground truth stem waveform
436
+ stem_name: Name of the stem
437
+ step: Training step (optional)
438
+ sample_rate: Audio sample rate
439
+ """
440
+ fig = plot_separation_spectrograms(
441
+ mixture=mixture,
442
+ estimated=estimated,
443
+ reference=reference,
444
+ stem_name=stem_name,
445
+ sample_rate=sample_rate,
446
+ )
447
+
448
+ log_spectrogram_to_wandb(
449
+ fig=fig,
450
+ key=f"spectrograms/{stem_name}",
451
+ step=step,
452
+ caption=f"Separation for {stem_name}"
453
+ )
454
+
455
+
456
+ def log_all_stems_to_wandb(
457
+ mixture: torch.Tensor,
458
+ estimated_stems: Dict[str, torch.Tensor],
459
+ reference_stems: Dict[str, torch.Tensor],
460
+ step: Optional[int] = None,
461
+ sample_rate: int = 44100,
462
+ log_individual: bool = True,
463
+ log_combined: bool = True,
464
+ ):
465
+ """
466
+ Log spectrograms for all stems to W&B.
467
+
468
+ Args:
469
+ mixture: (C, T) mixture waveform
470
+ estimated_stems: Dict mapping stem names to estimated (C, T) waveforms
471
+ reference_stems: Dict mapping stem names to reference (C, T) waveforms
472
+ step: Training step (optional)
473
+ sample_rate: Audio sample rate
474
+ log_individual: Log individual stem comparisons
475
+ log_combined: Log combined grid of all stems
476
+ """
477
+ if log_individual:
478
+ for stem_name in estimated_stems.keys():
479
+ log_separation_spectrograms_to_wandb(
480
+ mixture=mixture,
481
+ estimated=estimated_stems[stem_name],
482
+ reference=reference_stems[stem_name],
483
+ stem_name=stem_name,
484
+ step=step,
485
+ sample_rate=sample_rate,
486
+ )
487
+
488
+ if log_combined:
489
+ fig = plot_all_stems_spectrograms(
490
+ mixture=mixture,
491
+ estimated_stems=estimated_stems,
492
+ reference_stems=reference_stems,
493
+ sample_rate=sample_rate,
494
+ )
495
+ log_spectrogram_to_wandb(
496
+ fig=fig,
497
+ key="spectrograms/all_stems",
498
+ step=step,
499
+ caption="All stems separation comparison"
500
+ )
501
+
502
+ # --- Audio I/O ---
503
+
504
+ # def load_audio(
505
+ # file_path: Union[str, Path],
506
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
507
+ # max_len: int = 5,
508
+ # mono: bool = True
509
+ # ) -> Tuple[np.ndarray, int]:
510
+ # """
511
+ # Load an audio file into a numpy array.
512
+
513
+ # Parameters
514
+ # ----------
515
+ # file_path (str or Path): Path to the audio file
516
+ # max_len (int): Maximum length of audio in seconds
517
+ # sample_rate (int, optional): Target sample rate
518
+ # mono (bool, optional): Whether to convert audio to mono
519
+
520
+ # Returns
521
+ # -------
522
+ # tuple
523
+ # (audio_data, sample_rate)
524
+ # """
525
+ # try:
526
+ # audio_data, sr = librosa.load(file_path, sr=sample_rate, mono=mono)
527
+
528
+ # # Clip audio to max_len
529
+ # max_samples = int(sample_rate * max_len)
530
+ # if len(audio_data) > max_samples:
531
+ # audio_data = audio_data[:max_samples]
532
+ # else:
533
+ # padding = max_samples - len(audio_data)
534
+ # audio_data = np.pad(
535
+ # audio_data,
536
+ # (0, padding),
537
+ # 'constant'
538
+ # )
539
+
540
+ # return audio_data, sr
541
+ # except Exception as e:
542
+ # raise IOError(f"Error loading audio file {file_path}: {str(e)}")
543
+
544
+ # def save_audio(
545
+ # audio_data: np.ndarray,
546
+ # file_path: Union[str, Path],
547
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
548
+ # normalize: bool = True,
549
+ # file_format: str = 'flac'
550
+ # ) -> None:
551
+ # """
552
+ # Save audio data to a file.
553
+
554
+ # Parameters
555
+ # ----------
556
+ # audio_data (np.ndarray): Audio time series
557
+ # file_path (str or Path): Path to save the audio file
558
+ # sample_rate (int, optional): Sample rate of audio
559
+ # normalize (bool, optional): Whether to normalize audio before saving
560
+ # file_format (str, optional): Audio file format
561
+
562
+ # Returns
563
+ # -------
564
+ # None
565
+ # """
566
+ # output_dir = Path(file_path).parent
567
+ # if output_dir and not output_dir.exists():
568
+ # try:
569
+ # output_dir.mkdir(parents=True, exist_ok=True)
570
+ # except Exception as e:
571
+ # raise IOError(f"Error creating directory {output_dir}: {str(e)}")
572
+
573
+ # # Normalize audio before saving
574
+ # audio_data = librosa.util.normalize(audio_data) if normalize else audio_data
575
+
576
+ # try:
577
+ # sf.write(file_path, audio_data, sample_rate, format=file_format)
578
+ # except Exception as e:
579
+ # raise IOError(f"Error saving audio to {file_path}: {str(e)}")
580
+
581
+ # # --- Gap Processing ---
582
+
583
+ # def create_gap_mask(
584
+ # audio_len_samples: int,
585
+ # gap_len_s: float,
586
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
587
+ # gap_start_s: Optional[float] = None,
588
+ # ) -> Tuple[np.ndarray, Tuple[int, int]]:
589
+ # """
590
+ # Creates a binary mask with a single gap of zeros at a random location.
591
+
592
+ # Parameters
593
+ # ----------
594
+ # audio_len_samples : int
595
+ # Length of the target audio in samples.
596
+ # gap_len_s : float
597
+ # Desired gap length in seconds.
598
+ # sample_rate : int, optional
599
+ # Sample rate. Defaults to DEFAULT_SAMPLE_RATE.
600
+ # gap_start_s : float, optional
601
+ # Timestap in seconds where the gap starts. If None, a random position is chosen.
602
+
603
+ # Returns
604
+ # -------
605
+ # Tuple[np.ndarray, Tuple[int, int]]
606
+ # (mask, (gap_start_sample, gap_end_sample))
607
+ # Mask is 1.0 for signal, 0.0 for gap (float32).
608
+ # Interval is gap start/end indices in samples.
609
+ # """
610
+ # gap_len_samples = int(gap_len_s * sample_rate)
611
+
612
+ # if gap_len_samples <= 0:
613
+ # # No gap, return full mask and zero interval
614
+ # return np.ones(audio_len_samples, dtype=np.float32), (0, 0)
615
+
616
+ # if gap_len_samples >= audio_len_samples:
617
+ # # Gap covers everything
618
+ # print(f"Warning: Gap length ({gap_len_s}s) >= audio length. Returning all zeros mask.")
619
+ # return np.zeros(audio_len_samples, dtype=np.float32), (0, audio_len_samples)
620
+
621
+ # # Choose a random start position for the gap (inclusive range)
622
+ # max_start_sample = audio_len_samples - gap_len_samples
623
+ # if (gap_start_s is None):
624
+ # gap_start_sample = np.random.randint(0, max_start_sample + 1)
625
+ # else:
626
+ # gap_start_sample = int(gap_start_s * sample_rate)
627
+
628
+ # gap_end_sample = gap_start_sample + gap_len_samples
629
+
630
+ # # Create mask
631
+ # mask = np.ones(audio_len_samples, dtype=np.float32)
632
+ # mask[gap_start_sample:gap_end_sample] = 0.0
633
+
634
+ # return mask, (gap_start_sample, gap_end_sample)
635
+
636
+ # def add_random_gap(
637
+ # file_path: Union[str, Path],
638
+ # gap_len: int,
639
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
640
+ # mono: bool = True
641
+ # ) -> Tuple[np.ndarray, Tuple[float, float]]:
642
+ # """
643
+ # Insert a silent gap of gap_len seconds at a random valid position within the audio file and return the modified audio data.
644
+
645
+ # Parameters
646
+ # ----------
647
+ # file_path (str or Path): Path to the audio file
648
+ # gap_len (int): Gap length (seconds) to add at one location within the audio file
649
+ # sample_rate (int, optional): Target sample rate
650
+ # mono (bool, optional): Whether to convert audio to mono
651
+
652
+ # Returns
653
+ # -------
654
+ # tuple
655
+ # (modified_audio_data, gap_interval)
656
+ # gap_interval is a tuple of (start_time, end_time) in seconds
657
+ # """
658
+ # audio_data, sr = load_audio(file_path, sample_rate=sample_rate, mono=mono)
659
+
660
+ # # Convert gap length to samples
661
+ # gap_length = int(gap_len * sample_rate)
662
+ # audio_len = len(audio_data)
663
+
664
+ # # Handle case where gap is longer than audio
665
+ # if gap_length >= audio_len:
666
+ # raise ValueError(f"Gap length ({gap_len}s) must be shorter than the audio length ({audio_len / sample_rate:.2f}s)")
667
+
668
+ # # Get sample indices for gap placement
669
+ # gap_start_idx = np.random.randint(0, audio_len - gap_length + 1)
670
+ # silence = np.zeros(gap_length)
671
+
672
+ # # Add gap
673
+ # audio_new = np.concatenate([audio_data[:gap_start_idx], silence, audio_data[gap_start_idx + gap_length:]])
674
+
675
+ # # Return gap interval as a tuple
676
+ # gap_interval = (gap_start_idx / sample_rate, (gap_start_idx + gap_length) / sample_rate)
677
+
678
+ # return audio_new, gap_interval
679
+
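+ # # Example (illustrative sketch, not part of the original file; "mix.wav" is a hypothetical path):
+ # # gapped, (t0, t1) = add_random_gap("mix.wav", gap_len=2, sample_rate=DEFAULT_SAMPLE_RATE)
+ # # print(f"Silence inserted between {t0:.2f}s and {t1:.2f}s")
+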
680
+ # # --- STFT Processing ---
681
+
682
+ # def extract_spectrogram(
683
+ # audio_data: np.ndarray,
684
+ # n_fft: int = 2048,
685
+ # hop_length: int = 512,
686
+ # win_length: Optional[int] = None,
687
+ # window: str = 'hann',
688
+ # center: bool = True,
689
+ # power: float = 1.0
690
+ # ) -> np.ndarray:
691
+ # """
692
+ # Extract magnitude spectrogram from audio data.
693
+
694
+ # Parameters
695
+ # ----------
696
+ # audio_data (np.ndarray): Audio time series
697
+ # n_fft (int, optional): FFT window size
698
+ # hop_length (int, optional): Number of samples between successive frames
699
+ # win_length (int or None, optional): Window length. If None, defaults to n_fft
700
+ # window (str, optional): Window specification
701
+ # center (bool, optional): If True, pad signal on both sides
702
+ # power (float, optional): Exponent for the magnitude spectrogram (e.g. 1 for energy, 2 for power)
703
+
704
+ # Returns
705
+ # -------
706
+ # np.ndarray
707
+ # Magnitude spectrogram
708
+ # """
709
+ # if power < 0:
710
+ # raise ValueError("Power must be non-negative")
711
+
712
+ # if win_length is None:
713
+ # win_length = n_fft
714
+
715
+ # stft = librosa.stft(
716
+ # audio_data,
717
+ # n_fft=n_fft,
718
+ # hop_length=hop_length,
719
+ # win_length=win_length,
720
+ # window=window,
721
+ # center=center
722
+ # )
723
+
724
+ # # Return the magnitude spectrogram raised to the requested power, as documented above
+ # return np.abs(stft) ** power
725
+
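+ # # Example (illustrative sketch, not part of the original file; assumes `y` is a mono np.ndarray from load_audio):
+ # # mag = extract_spectrogram(y, n_fft=2048, hop_length=512, power=1.0)
+ # # print(mag.shape)  # (1 + n_fft // 2, n_frames) = (1025, n_frames)
+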
726
+ # def extract_mel_spectrogram(
727
+ # audio_data: np.ndarray,
728
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
729
+ # n_fft: int = 2048,
730
+ # hop_length: int = 512,
731
+ # n_mels: int = 128,
732
+ # fmin: float = 0.0,
733
+ # fmax: Optional[float] = None,
734
+ # power: float = 2.0
735
+ # ) -> np.ndarray:
736
+ # """
737
+ # Extract mel spectrogram from audio data.
738
+
739
+ # Parameters
740
+ # ----------
741
+ # audio_data (np.ndarray): Audio time series
742
+ # sample_rate (int, optional): Sample rate of audio
743
+ # n_fft (int, optional): FFT window size
744
+ # hop_length (int, optional): Number of samples between successive frames
745
+ # n_mels (int, optional): Number of mel bands
746
+ # fmin (float, optional): Minimum frequency
747
+ # fmax (float or None, optional): Maximum frequency. If None, use sample_rate/2
748
+ # power (float, optional): Exponent for the magnitude spectrogram (e.g. 1 for energy, 2 for power)
749
+
750
+ # Returns
751
+ # -------
752
+ # np.ndarray
753
+ # Mel spectrogram
754
+ # """
755
+ # if power < 0:
756
+ # raise ValueError("Power must be non-negative")
757
+
758
+ # return librosa.feature.melspectrogram(
759
+ # y=audio_data,
760
+ # sr=sample_rate,
761
+ # n_fft=n_fft,
762
+ # hop_length=hop_length,
763
+ # n_mels=n_mels,
764
+ # fmin=fmin,
765
+ # fmax=fmax,
766
+ # power=power
767
+ # )
768
+
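+ # # Example (illustrative sketch, not part of the original file; assumes `y` and `sr` come from load_audio):
+ # # mel = extract_mel_spectrogram(y, sample_rate=sr, n_mels=128, power=2.0)
+ # # mel_db = librosa.power_to_db(mel, ref=np.max)  # dB scale for plotting
+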
769
+ # def spectrogram_to_audio(
770
+ # spectrogram: np.ndarray,
771
+ # phase: Optional[np.ndarray] = None,
772
+ # phase_info: bool = False,
773
+ # n_fft=512,
774
+ # n_iter=64,
775
+ # window='hann',
776
+ # hop_length=512,
777
+ # win_length=None,
778
+ # center=True) -> np.ndarray:
779
+ # """
780
+ # Convert a spectrogram back to audio using either:
781
+ # 1. Original phase information (if provided)
782
+ # 2. Griffin-Lim algorithm to estimate phase (if no phase provided)
783
+
784
+ # Even with the original phase, the reconstruction is not truly lossless (the error is on the order of 1e-33 MSE).
785
+
786
+ # Parameters
787
+ # ----------
788
+ # spectrogram (np.ndarray): The magnitude spectrogram to convert back to audio
789
+ # phase (np.ndarray, optional): Phase information to use for reconstruction. If None, Griffin-Lim is used.
790
+ # phase_info (bool): If True, the input is assumed to be a complex spectrogram that already carries phase and is inverted directly with istft
791
+ # n_fft (int): FFT window size
792
+ # n_iter (int, optional): Number of iterations for Griffin-Lim algorithm
793
+ # window (str): Window function to use
794
+ # win_length (int or None): Window size. If None, defaults to n_fft
795
+ # hop_length (int, optional): Number of samples between successive frames
796
+ # center (bool, optional): Whether to pad the signal at the edges
797
+
798
+ # Returns
799
+ # -------
800
+ # y : np.ndarray
+ # The reconstructed audio signal
801
+ # """
802
+ # # If the input is a real-valued spectrogram in dB scale, convert back to amplitude
803
+ # if np.isrealobj(spectrogram) and np.max(spectrogram) < 0:
804
+ # spectrogram = librosa.db_to_amplitude(spectrogram)
805
+
806
+ # if phase_info:
807
+ # return librosa.istft(spectrogram, n_fft=n_fft, hop_length=hop_length,
808
+ # win_length=win_length, window=window, center=center)
809
+
810
+ # # If phase information is provided, use it for reconstruction
811
+ # if phase is not None:
812
+ # # Combine magnitude and phase to form complex spectrogram
813
+ # complex_spectrogram = spectrogram * np.exp(1j * phase)
814
+
815
+ # # Inverse STFT to get audio
816
+ # y = librosa.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length,
817
+ # win_length=win_length, window=window, center=center)
818
+ # else:
819
+ # # Use Griffin-Lim algorithm to estimate phase
820
+ # y = librosa.griffinlim(spectrogram, n_fft=n_fft, n_iter=n_iter,
821
+ # hop_length=hop_length, win_length=win_length,
822
+ # window=window, center=center)
823
+ # return y
824
+
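+ # # Example (illustrative sketch, not part of the original file; assumes `y` is a mono np.ndarray):
+ # # stft = librosa.stft(y, n_fft=2048, hop_length=512)
+ # # mag, phase = np.abs(stft), np.angle(stft)
+ # # y_exact = spectrogram_to_audio(mag, phase=phase, n_fft=2048, hop_length=512)  # near-lossless
+ # # y_glim = spectrogram_to_audio(mag, n_fft=2048, hop_length=512, n_iter=64)  # Griffin-Lim phase estimate
+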
825
+ # def mel_spectrogram_to_audio(
826
+ # mel_spectrogram: np.ndarray,
827
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
828
+ # n_fft: int = 2048,
829
+ # hop_length: int = 512,
830
+ # n_iter: int = 32,
831
+ # n_mels: int = 128,
832
+ # fmin: float = 0.0,
833
+ # fmax: Optional[float] = None,
834
+ # power: float = 2.0
835
+ # ) -> np.ndarray:
836
+ # """
837
+ # Convert a mel spectrogram to audio using inverse transformation and Griffin-Lim.
838
+
839
+ # Parameters
840
+ # ----------
841
+ # mel_spectrogram (np.ndarray): Mel spectrogram
842
+ # sample_rate (int, optional): Sample rate of audio
843
+ # n_fft (int, optional): FFT window size
844
+ # hop_length (int, optional): Number of samples between successive frames
845
+ # n_iter (int, optional): Number of iterations for Griffin-Lim
846
+ # n_mels (int, optional): Number of mel bands
847
+ # fmin (float, optional): Minimum frequency
848
+ # fmax (float or None, optional): Maximum frequency. If None, use sample_rate/2
849
+ # power (float, optional): Exponent for the magnitude spectrogram (e.g. 1 for energy, 2 for power)
850
+
851
+ # Returns
852
+ # -------
853
+ # np.ndarray
854
+ # Audio time series
855
+ # """
856
+ # # Create a mel filterbank
857
+ # mel_basis = librosa.filters.mel(
858
+ # sr=sample_rate,
859
+ # n_fft=n_fft,
860
+ # n_mels=n_mels,
861
+ # fmin=fmin,
862
+ # fmax=fmax
863
+ # )
864
+
865
+ # # Compute the pseudo-inverse of the mel filterbank
866
+ # mel_filterbank_inv = np.linalg.pinv(mel_basis)
867
+
868
+ # # Convert Mel spectrogram to linear spectrogram
869
+ # linear_spec = np.dot(mel_filterbank_inv, mel_spectrogram)
870
+
871
+ # # If the input was a power spectrogram, take the square root
872
+ # if power == 2.0:
873
+ # # Clip small negative values introduced by the pseudo-inverse before taking the square root
+ # linear_spec = np.sqrt(np.maximum(linear_spec, 0.0))
874
+
875
+ # # Perform Griffin-Lim to estimate the phase and convert to audio
876
+ # audio_data = librosa.griffinlim(
877
+ # linear_spec,
878
+ # hop_length=hop_length,
879
+ # n_fft=n_fft,
880
+ # n_iter=n_iter
881
+ # )
882
+
883
+ # return audio_data
884
+
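+ # # Example round trip (illustrative sketch, not part of the original file); the result is lossy
+ # # because the mel projection discards information and Griffin-Lim only estimates the phase:
+ # # mel = extract_mel_spectrogram(y, sample_rate=sr, n_fft=2048, hop_length=512, power=2.0)
+ # # y_hat = mel_spectrogram_to_audio(mel, sample_rate=sr, n_fft=2048, hop_length=512, power=2.0)
+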
885
+ # def visualize_spectrogram(
886
+ # spectrogram: np.ndarray,
887
+ # power: int = 1,
888
+ # sample_rate: int = DEFAULT_SAMPLE_RATE,
889
+ # n_fft: int = 512,
890
+ # hop_length: int = 192,
891
+ # win_length: int = 384,
892
+ # gap_int: Optional[Tuple[float, float]] = None,
893
+ # in_db: bool = False,
894
+ # y_axis: str = 'log',
895
+ # x_axis: str = 'time',
896
+ # title: str = 'Spectrogram',
897
+ # save_path: Optional[Union[str, Path]] = None
898
+ # ) -> Optional[plt.Figure]:
899
+ # """
900
+ # Visualize a spectrogram.
901
+
902
+ # Parameters
903
+ # ----------
904
+ # spectrogram (np.ndarray): Spectrogram to visualize
905
+ # power (int): Scale of the input spectrogram: 1 for energy (magnitude), 2 for power
906
+ # sample_rate (int, optional): Sample rate of audio
907
+ # n_fft (int, optional): FFT window size
+ # hop_length (int, optional): Number of samples between successive frames
+ # win_length (int, optional): Window length
908
+ # gap_int (float tuple, optional): Start and end time [s] of the gap (if given) to be plotted as vertical lines
909
+ # in_db (bool, optional): Whether the spectrogram is already in dB scale
910
+ # y_axis (str, optional): Scale for the y-axis ('linear', 'log', or 'mel')
911
+ # x_axis (str, optional): Scale for the x-axis ('time' or 'frames')
912
+ # title (str, optional): Title for the plot
913
+ # save_path (str or Path or None, optional): Path to save the visualization. If None, the figure is returned instead of being saved.
914
+
915
+ # Returns
916
+ # -------
917
+ # Figure or None
918
+ # The matplotlib Figure object if save_path is None, otherwise None
919
+ # """
920
+ # if power not in (1, 2):
921
+ # raise ValueError("Power must be 1 (energy) or 2 (power)")
922
+
923
+ # # Convert to dB scale if needed
924
+ # if in_db:
925
+ # spectrogram_data = np.array(spectrogram)
926
+ # elif power == 1:
927
+ # spectrogram_data = librosa.amplitude_to_db(spectrogram, ref=np.max, amin=1e-5, top_db=80)
928
+ # else: # power == 2
929
+ # spectrogram_data = librosa.power_to_db(spectrogram, ref=np.max, amin=1e-5, top_db=80)
930
+
931
+
932
+ # fig, ax = plt.subplots(figsize=(10, 4))
933
+ # img = librosa.display.specshow(
934
+ # spectrogram_data,
935
+ # sr=sample_rate,
936
+ # n_fft=n_fft,
937
+ # win_length=win_length,
938
+ # hop_length=hop_length,
939
+ # y_axis=y_axis,
940
+ # x_axis=x_axis,
941
+ # ax=ax
942
+ # )
943
+
944
+ # # Compute gap start and end indices and plot vertical lines
945
+ # if gap_int is not None:
946
+ # gap_start_s, gap_end_s = gap_int
947
+
948
+ # ax.axvline(x=gap_start_s, color='white', linestyle='--', label='Gap Start')
949
+ # ax.axvline(x=gap_end_s, color='white', linestyle='--', label='Gap End')
950
+ # ax.legend()
951
+
952
+ # # Add colorbar and title
953
+ # fig.colorbar(img, ax=ax, format='%+2.0f dB')
954
+ # ax.set_title(title)
955
+ # fig.tight_layout()
956
+
957
+ # # Save or return the figure
958
+ # if save_path is not None:
959
+ # save_path = Path(save_path)
960
+ # output_dir = save_path.parent
961
+ # if output_dir and not output_dir.exists():
962
+ # output_dir.mkdir(parents=True, exist_ok=True)
963
+
964
+ # fig.savefig(save_path)
965
+ # plt.close(fig)
966
+ # return None
967
+
968
+ # return fig
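+
+ # # Example (illustrative sketch, not part of the original file; the output path is hypothetical):
+ # # spec = extract_spectrogram(y, n_fft=512, hop_length=192, win_length=384, power=1.0)
+ # # visualize_spectrogram(spec, power=1, sample_rate=sr, n_fft=512, hop_length=192,
+ # #                       win_length=384, gap_int=(1.0, 2.0), title="Gapped mix",
+ # #                       save_path="plots/gapped_mix_spec.png")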