from typing import Union, Optional, Dict
from pathlib import Path

import yaml
import torch
import numpy as np

import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for server/training use; set before importing pyplot
import matplotlib.pyplot as plt

# ============================================================================
# YAML Config
# ============================================================================

def load_config(file_path: Union[str, Path]) -> dict:
    """Load a YAML configuration file."""
    with open(file_path, 'r') as f:
        config = yaml.safe_load(f)
    return config
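
# Usage sketch for load_config (the path and keys below are hypothetical,
# shown only to illustrate the expected YAML structure):
#
#     cfg = load_config("configs/train.yaml")
#     sr = cfg["audio"]["sample_rate"]   # e.g. 44100, given a config like:
#     #   audio:
#     #     sample_rate: 44100
#     #     n_fft: 2048
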
# ============================================================================
# Spectrogram Utilities
# ============================================================================

def compute_spectrogram(
    waveform: torch.Tensor,
    n_fft: int = 2048,
    hop_length: int = 512,
    power: float = 2.0,
    to_db: bool = True,
    top_db: float = 80.0,
) -> torch.Tensor:
    """
    Compute a spectrogram from a waveform using the STFT.

    Args:
        waveform: (C, T) or (T,) audio waveform
        n_fft: FFT window size
        hop_length: Hop length between frames
        power: Exponent for the magnitude (1.0 for magnitude, 2.0 for power)
        to_db: Convert to decibel scale
        top_db: Threshold for dynamic range in dB

    Returns:
        (F, T') spectrogram tensor
    """
    # Downmix stereo to mono by averaging channels
    if waveform.dim() == 2:
        waveform = waveform.mean(dim=0)  # (T,)

    # Move to CPU for STFT computation
    waveform = waveform.cpu()

    # Compute STFT
    window = torch.hann_window(n_fft)
    stft = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        return_complex=True,
        center=True,
        pad_mode='reflect'
    )

    # Compute magnitude (or power) spectrogram
    spec = torch.abs(stft).pow(power)

    # Convert to dB; note that amplitude_to_db uses the power-spectrogram
    # convention (10 * log10), matching the default power=2.0
    if to_db:
        spec = amplitude_to_db(spec, top_db=top_db)
    return spec
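
# Shape sketch for compute_spectrogram: with center=True the STFT yields
# n_fft // 2 + 1 frequency bins and T // hop_length + 1 frames, e.g. for
# one second of 44.1 kHz audio (synthetic input, for illustration only):
#
#     wav = torch.randn(2, 44100)        # fake stereo waveform
#     spec = compute_spectrogram(wav)    # -> (1025, 87), dB-scaled
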
def amplitude_to_db(
    spec: torch.Tensor,
    ref: float = 1.0,
    amin: float = 1e-10,
    top_db: float = 80.0,
) -> torch.Tensor:
    """
    Convert a spectrogram to the decibel scale.

    Uses 10 * log10, the convention for power spectrograms (power=2.0);
    for pure magnitude spectrograms (power=1.0) the usual convention is
    20 * log10, so the returned dB values would be halved.
    """
    spec_db = 10.0 * torch.log10(torch.clamp(spec, min=amin) / ref)

    # Clip to top_db below the maximum to bound the dynamic range
    max_val = spec_db.max()
    spec_db = torch.clamp(spec_db, min=max_val - top_db)
    return spec_db
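
# Worked example of the top_db clamp (illustrative numbers): if the loudest
# bin maps to +20 dB and top_db=80, every value below 20 - 80 = -60 dB is
# clipped up to -60 dB, bounding the dynamic range to 80 dB.
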
def plot_spectrogram(
    spec: torch.Tensor,
    sample_rate: int = 44100,
    hop_length: int = 512,
    title: str = "Spectrogram",
    figsize: tuple = (10, 4),
    cmap: str = "magma",
    colorbar: bool = True,
) -> plt.Figure:
    """
    Plot a single spectrogram.

    Args:
        spec: (F, T) spectrogram tensor (in dB scale)
        sample_rate: Audio sample rate
        hop_length: Hop length used for STFT
        title: Plot title
        figsize: Figure size
        cmap: Colormap for spectrogram
        colorbar: Whether to show colorbar

    Returns:
        matplotlib Figure object
    """
    spec_np = spec.detach().cpu().numpy() if isinstance(spec, torch.Tensor) else spec
    fig, ax = plt.subplots(figsize=figsize)

    # Compute time and frequency axes
    n_frames = spec_np.shape[1]
    time_max = n_frames * hop_length / sample_rate
    freq_max = sample_rate / 2  # Nyquist frequency

    img = ax.imshow(
        spec_np,
        aspect='auto',
        origin='lower',
        cmap=cmap,
        extent=[0, time_max, 0, freq_max / 1000]  # freq in kHz
    )
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (kHz)')
    ax.set_title(title)

    if colorbar:
        cbar = fig.colorbar(img, ax=ax, format='%+2.0f dB')
        cbar.set_label('Magnitude (dB)')

    fig.tight_layout()
    return fig
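
# Usage sketch (synthetic spectrogram; the output path is hypothetical):
#
#     spec = compute_spectrogram(torch.randn(44100))
#     fig = plot_spectrogram(spec, title="Demo")
#     fig.savefig("demo_spec.png")
#     plt.close(fig)
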
def plot_spectrogram_comparison(
    spectrograms: Dict[str, torch.Tensor],
    sample_rate: int = 44100,
    hop_length: int = 512,
    figsize: tuple = (14, 3),
    cmap: str = "magma",
    suptitle: Optional[str] = None,
) -> plt.Figure:
    """
    Plot multiple spectrograms side by side for comparison.

    Args:
        spectrograms: Dict mapping names to spectrogram tensors
        sample_rate: Audio sample rate
        hop_length: Hop length used for STFT
        figsize: Figure size (width, height)
        cmap: Colormap for spectrograms
        suptitle: Super title for the figure

    Returns:
        matplotlib Figure object
    """
    n_specs = len(spectrograms)
    fig, axes = plt.subplots(
        1, n_specs,
        figsize=figsize,
        constrained_layout=True  # Better layout handling with colorbars
    )
    if n_specs == 1:
        axes = [axes]

    # Find global min/max for a consistent color scale
    all_specs = [s.detach().cpu().numpy() if isinstance(s, torch.Tensor) else s
                 for s in spectrograms.values()]
    vmin = min(s.min() for s in all_specs)
    vmax = max(s.max() for s in all_specs)

    for ax, (name, spec) in zip(axes, spectrograms.items()):
        spec_np = spec.detach().cpu().numpy() if isinstance(spec, torch.Tensor) else spec
        n_frames = spec_np.shape[1]
        time_max = n_frames * hop_length / sample_rate
        freq_max = sample_rate / 2
        img = ax.imshow(
            spec_np,
            aspect='auto',
            origin='lower',
            cmap=cmap,
            extent=[0, time_max, 0, freq_max / 1000],
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Frequency (kHz)')
        ax.set_title(name)

    # Add a single shared colorbar
    fig.colorbar(img, ax=axes, format='%+2.0f dB', label='Magnitude (dB)')

    if suptitle:
        fig.suptitle(suptitle, fontsize=12)
    return fig
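
# Usage sketch (dict keys become panel titles; names here are illustrative):
#
#     fig = plot_spectrogram_comparison(
#         {"Input": spec_in, "Output": spec_out},
#         suptitle="Before vs. After",
#     )
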
def plot_separation_spectrograms(
    mixture: torch.Tensor,
    estimated: torch.Tensor,
    reference: torch.Tensor,
    stem_name: str = "stem",
    sample_rate: int = 44100,
    n_fft: int = 2048,
    hop_length: int = 512,
) -> plt.Figure:
    """
    Create a comparison spectrogram plot for stem separation, showing the
    mixture, the estimated stem, and the ground-truth reference side by side.

    Args:
        mixture: (C, T) mixture waveform
        estimated: (C, T) estimated stem waveform
        reference: (C, T) ground truth stem waveform
        stem_name: Name of the stem for title
        sample_rate: Audio sample rate
        n_fft: FFT window size
        hop_length: Hop length

    Returns:
        matplotlib Figure object
    """
    # Compute spectrograms
    spec_mix = compute_spectrogram(mixture, n_fft=n_fft, hop_length=hop_length)
    spec_est = compute_spectrogram(estimated, n_fft=n_fft, hop_length=hop_length)
    spec_ref = compute_spectrogram(reference, n_fft=n_fft, hop_length=hop_length)

    # Create comparison plot
    spectrograms = {
        "Mixture": spec_mix,
        f"Estimated ({stem_name})": spec_est,
        f"Ground Truth ({stem_name})": spec_ref,
    }
    fig = plot_spectrogram_comparison(
        spectrograms,
        sample_rate=sample_rate,
        hop_length=hop_length,
        suptitle=f"Stem Separation: {stem_name.capitalize()}"
    )
    return fig
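
# Usage sketch (the waveforms are hypothetical (C, T) tensors from a model):
#
#     fig = plot_separation_spectrograms(mix_wav, est_wav, ref_wav,
#                                        stem_name="vocals")
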
def plot_all_stems_spectrograms(
    mixture: torch.Tensor,
    estimated_stems: Dict[str, torch.Tensor],
    reference_stems: Dict[str, torch.Tensor],
    sample_rate: int = 44100,
    n_fft: int = 2048,
    hop_length: int = 512,
    figsize: tuple = (16, 12),
) -> plt.Figure:
    """
    Create a grid of spectrograms for all stems
    (rows = stems, columns = estimated vs. ground truth).

    Args:
        mixture: (C, T) mixture waveform (accepted for API symmetry; not plotted)
        estimated_stems: Dict mapping stem names to estimated (C, T) waveforms
        reference_stems: Dict mapping stem names to reference (C, T) waveforms
        sample_rate: Audio sample rate
        n_fft: FFT window size
        hop_length: Hop length
        figsize: Figure size

    Returns:
        matplotlib Figure object
    """
    stem_names = list(estimated_stems.keys())
    n_stems = len(stem_names)

    # Create grid: rows = stems, cols = [Estimated, Ground Truth]
    fig, axes = plt.subplots(
        n_stems, 2,
        figsize=figsize,
        constrained_layout=True  # Better layout handling with colorbars
    )
    if n_stems == 1:
        axes = axes.reshape(1, -1)

    # Compute all spectrograms and find global min/max for a consistent color scale
    all_specs = []
    spec_data = {}
    for stem_name in stem_names:
        spec_est = compute_spectrogram(
            estimated_stems[stem_name], n_fft=n_fft, hop_length=hop_length
        )
        spec_ref = compute_spectrogram(
            reference_stems[stem_name], n_fft=n_fft, hop_length=hop_length
        )
        spec_data[stem_name] = {'est': spec_est, 'ref': spec_ref}
        all_specs.extend([spec_est.detach().cpu().numpy(), spec_ref.detach().cpu().numpy()])
    vmin = min(s.min() for s in all_specs)
    vmax = max(s.max() for s in all_specs)

    for row, stem_name in enumerate(stem_names):
        spec_est = spec_data[stem_name]['est']
        spec_ref = spec_data[stem_name]['ref']

        # Get time extent
        n_frames = spec_est.shape[1]
        time_max = n_frames * hop_length / sample_rate
        freq_max = sample_rate / 2

        # Plot estimated
        spec_np = spec_est.detach().cpu().numpy()
        axes[row, 0].imshow(
            spec_np, aspect='auto', origin='lower', cmap='magma',
            extent=[0, time_max, 0, freq_max / 1000],
            vmin=vmin, vmax=vmax
        )
        axes[row, 0].set_title(f'{stem_name.capitalize()} - Estimated')
        axes[row, 0].set_ylabel('Freq (kHz)')

        # Plot reference
        spec_np = spec_ref.detach().cpu().numpy()
        img = axes[row, 1].imshow(
            spec_np, aspect='auto', origin='lower', cmap='magma',
            extent=[0, time_max, 0, freq_max / 1000],
            vmin=vmin, vmax=vmax
        )
        axes[row, 1].set_title(f'{stem_name.capitalize()} - Ground Truth')

    # Set x labels on the bottom row only
    axes[-1, 0].set_xlabel('Time (s)')
    axes[-1, 1].set_xlabel('Time (s)')

    fig.colorbar(img, ax=axes, format='%+2.0f dB', label='Magnitude (dB)')
    fig.suptitle('Stem Separation Results', fontsize=14)
    return fig
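
# Usage sketch for a four-stem setup (stem names are illustrative):
#
#     stems = ["vocals", "drums", "bass", "other"]
#     fig = plot_all_stems_spectrograms(
#         mix_wav,
#         {s: est[s] for s in stems},
#         {s: ref[s] for s in stems},
#     )
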
# ============================================================================
# Weights & Biases Logging Utilities
# ============================================================================

def log_spectrogram_to_wandb(
    fig: plt.Figure,
    key: str = "spectrogram",
    step: Optional[int] = None,
    caption: Optional[str] = None,
):
    """
    Log a matplotlib figure as an image to W&B.

    Args:
        fig: matplotlib Figure object
        key: W&B log key
        step: Training step (optional)
        caption: Image caption
    """
    import wandb

    # Convert figure to W&B Image
    wandb_img = wandb.Image(fig, caption=caption)
    log_dict = {key: wandb_img}
    if step is not None:
        wandb.log(log_dict, step=step)
    else:
        wandb.log(log_dict)

    # Close the figure to free memory
    plt.close(fig)
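
# Usage sketch (assumes wandb.init() was already called by the training
# script; key and step are illustrative). Note that the helper closes the
# figure, so it must not be reused afterwards:
#
#     fig = plot_spectrogram(spec)
#     log_spectrogram_to_wandb(fig, key="val/spectrogram", step=global_step)
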
def log_audio_to_wandb(
    audio: torch.Tensor,
    stem_name: str,
    is_gt: bool,
    sample_rate: int = 44100
):
    """
    Log an audio waveform to W&B.

    Args:
        audio: (C, T) audio waveform tensor
        stem_name: Name of the stem
        is_gt: Whether this is ground truth audio (as opposed to extracted audio)
        sample_rate: Audio sample rate
    """
    import wandb

    # Convert to numpy, channels last
    audio_np = audio.detach().cpu().numpy().T  # (T, C)
    title = f"true_{stem_name}" if is_gt else f"extracted_{stem_name}"
    keyname = f"audio/{title}"
    wandb.log({
        keyname: wandb.Audio(
            audio_np,
            sample_rate=sample_rate,
            caption=title
        )
    })
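
# Usage sketch: log reference and estimate under distinct keys
# (tensor names are illustrative):
#
#     log_audio_to_wandb(ref_wav, "vocals", is_gt=True)    # audio/true_vocals
#     log_audio_to_wandb(est_wav, "vocals", is_gt=False)   # audio/extracted_vocals
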
def log_separation_spectrograms_to_wandb(
    mixture: torch.Tensor,
    estimated: torch.Tensor,
    reference: torch.Tensor,
    stem_name: str,
    step: Optional[int] = None,
    sample_rate: int = 44100,
):
    """
    Log stem separation spectrograms to W&B.

    Args:
        mixture: (C, T) mixture waveform
        estimated: (C, T) estimated stem waveform
        reference: (C, T) ground truth stem waveform
        stem_name: Name of the stem
        step: Training step (optional)
        sample_rate: Audio sample rate
    """
    fig = plot_separation_spectrograms(
        mixture=mixture,
        estimated=estimated,
        reference=reference,
        stem_name=stem_name,
        sample_rate=sample_rate,
    )
    log_spectrogram_to_wandb(
        fig=fig,
        key=f"spectrograms/{stem_name}",
        step=step,
        caption=f"Separation for {stem_name}"
    )

def log_all_stems_to_wandb(
    mixture: torch.Tensor,
    estimated_stems: Dict[str, torch.Tensor],
    reference_stems: Dict[str, torch.Tensor],
    step: Optional[int] = None,
    sample_rate: int = 44100,
    log_individual: bool = True,
    log_combined: bool = True,
):
    """
    Log spectrograms for all stems to W&B.

    Args:
        mixture: (C, T) mixture waveform
        estimated_stems: Dict mapping stem names to estimated (C, T) waveforms
        reference_stems: Dict mapping stem names to reference (C, T) waveforms
        step: Training step (optional)
        sample_rate: Audio sample rate
        log_individual: Log individual stem comparisons
        log_combined: Log combined grid of all stems
    """
    if log_individual:
        for stem_name in estimated_stems.keys():
            log_separation_spectrograms_to_wandb(
                mixture=mixture,
                estimated=estimated_stems[stem_name],
                reference=reference_stems[stem_name],
                stem_name=stem_name,
                step=step,
                sample_rate=sample_rate,
            )
    if log_combined:
        fig = plot_all_stems_spectrograms(
            mixture=mixture,
            estimated_stems=estimated_stems,
            reference_stems=reference_stems,
            sample_rate=sample_rate,
        )
        log_spectrogram_to_wandb(
            fig=fig,
            key="spectrograms/all_stems",
            step=step,
            caption="All stems separation comparison"
        )
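
# Typical validation-loop call (variable names are illustrative):
#
#     log_all_stems_to_wandb(
#         mixture=mix_wav,
#         estimated_stems={"vocals": est_v, "drums": est_d},
#         reference_stems={"vocals": ref_v, "drums": ref_d},
#         step=global_step,
#     )
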
# ============================================================================
# Legacy librosa-based helpers (kept for reference, currently disabled).
# They depend on librosa, soundfile (sf), typing.Tuple, and a
# DEFAULT_SAMPLE_RATE constant, none of which are imported in this module.
# ============================================================================

# --- Audio I/O ---
# def load_audio(
#     file_path: Union[str, Path],
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     max_len: int = 5,
#     mono: bool = True
# ) -> Tuple[np.ndarray, int]:
#     """
#     Load an audio file into a numpy array.
#
#     Parameters
#     ----------
#     file_path (str or Path): Path to the audio file
#     max_len (int): Maximum length of audio in seconds
#     sample_rate (int, optional): Target sample rate
#     mono (bool, optional): Whether to convert audio to mono
#
#     Returns
#     -------
#     tuple
#         (audio_data, sample_rate)
#     """
#     try:
#         audio_data, sr = librosa.load(file_path, sr=sample_rate, mono=mono)
#         # Clip or zero-pad audio to exactly max_len seconds
#         max_samples = int(sample_rate * max_len)
#         if len(audio_data) > max_samples:
#             audio_data = audio_data[:max_samples]
#         else:
#             padding = max_samples - len(audio_data)
#             audio_data = np.pad(
#                 audio_data,
#                 (0, padding),
#                 'constant'
#             )
#         return audio_data, sr
#     except Exception as e:
#         raise IOError(f"Error loading audio file {file_path}: {str(e)}")
# def save_audio(
#     audio_data: np.ndarray,
#     file_path: Union[str, Path],
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     normalize: bool = True,
#     file_format: str = 'flac'
# ) -> None:
#     """
#     Save audio data to a file.
#
#     Parameters
#     ----------
#     audio_data (np.ndarray): Audio time series
#     file_path (str or Path): Path to save the audio file
#     sample_rate (int, optional): Sample rate of audio
#     normalize (bool, optional): Whether to normalize audio before saving
#     file_format (str, optional): Audio file format
#
#     Returns
#     -------
#     None
#     """
#     output_dir = Path(file_path).parent
#     if output_dir and not output_dir.exists():
#         try:
#             output_dir.mkdir(parents=True, exist_ok=True)
#         except Exception as e:
#             raise IOError(f"Error creating directory {output_dir}: {str(e)}")
#
#     # Normalize audio before saving
#     audio_data = librosa.util.normalize(audio_data) if normalize else audio_data
#     try:
#         sf.write(file_path, audio_data, sample_rate, format=file_format)
#     except Exception as e:
#         raise IOError(f"Error saving audio to {file_path}: {str(e)}")
# --- Gap Processing ---
# def create_gap_mask(
#     audio_len_samples: int,
#     gap_len_s: float,
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     gap_start_s: Optional[float] = None,
# ) -> Tuple[np.ndarray, Tuple[int, int]]:
#     """
#     Create a binary mask with a single gap of zeros at a random location.
#
#     Parameters
#     ----------
#     audio_len_samples : int
#         Length of the target audio in samples.
#     gap_len_s : float
#         Desired gap length in seconds.
#     sample_rate : int, optional
#         Sample rate. Defaults to DEFAULT_SAMPLE_RATE.
#     gap_start_s : float, optional
#         Timestamp in seconds where the gap starts. If None, a random position is chosen.
#
#     Returns
#     -------
#     Tuple[np.ndarray, Tuple[int, int]]
#         (mask, (gap_start_sample, gap_end_sample))
#         Mask is 1.0 for signal, 0.0 for gap (float32).
#         Interval is gap start/end indices in samples.
#     """
#     gap_len_samples = int(gap_len_s * sample_rate)
#     if gap_len_samples <= 0:
#         # No gap: return a full mask and a zero-length interval
#         return np.ones(audio_len_samples, dtype=np.float32), (0, 0)
#     if gap_len_samples >= audio_len_samples:
#         # Gap covers everything
#         print(f"Warning: Gap length ({gap_len_s}s) >= audio length. Returning all-zeros mask.")
#         return np.zeros(audio_len_samples, dtype=np.float32), (0, audio_len_samples)
#
#     # Choose the gap start position: random over the valid (inclusive) range,
#     # unless an explicit start time was given
#     max_start_sample = audio_len_samples - gap_len_samples
#     if gap_start_s is None:
#         gap_start_sample = np.random.randint(0, max_start_sample + 1)
#     else:
#         gap_start_sample = int(gap_start_s * sample_rate)
#     gap_end_sample = gap_start_sample + gap_len_samples
#
#     # Create mask
#     mask = np.ones(audio_len_samples, dtype=np.float32)
#     mask[gap_start_sample:gap_end_sample] = 0.0
#     return mask, (gap_start_sample, gap_end_sample)
# def add_random_gap(
#     file_path: Union[str, Path],
#     gap_len: int,
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     mono: bool = True
# ) -> Tuple[np.ndarray, Tuple[float, float]]:
#     """
#     Add a gap of length gap_len at a random valid position within the
#     audio file and return the modified audio data.
#
#     Parameters
#     ----------
#     file_path (str or Path): Path to the audio file
#     gap_len (int): Gap length (seconds) to insert at one location within the audio file
#     sample_rate (int, optional): Target sample rate
#     mono (bool, optional): Whether to convert audio to mono
#
#     Returns
#     -------
#     tuple
#         (modified_audio_data, gap_interval)
#         gap_interval is a tuple of (start_time, end_time) in seconds
#     """
#     audio_data, sr = load_audio(file_path, sample_rate=sample_rate, mono=mono)
#
#     # Convert gap length to samples
#     gap_length = int(gap_len * sample_rate)
#     audio_len = len(audio_data)
#
#     # Handle the case where the gap is at least as long as the audio
#     if gap_length >= audio_len:
#         raise ValueError(f"Gap length ({gap_len}s) exceeds audio length ({audio_len / sample_rate}s)")
#
#     # Choose a random sample index for gap placement
#     gap_start_idx = np.random.randint(0, audio_len - gap_length)
#     silence = np.zeros(gap_length)
#
#     # Insert the gap
#     audio_new = np.concatenate([audio_data[:gap_start_idx], silence, audio_data[gap_start_idx + gap_length:]])
#
#     # Return the gap interval as times in seconds
#     gap_interval = (gap_start_idx / sample_rate, (gap_start_idx + gap_length) / sample_rate)
#     return audio_new, gap_interval
# --- STFT Processing ---
# def extract_spectrogram(
#     audio_data: np.ndarray,
#     n_fft: int = 2048,
#     hop_length: int = 512,
#     win_length: Optional[int] = None,
#     window: str = 'hann',
#     center: bool = True,
#     power: float = 1.0
# ) -> np.ndarray:
#     """
#     Extract a magnitude spectrogram from audio data.
#
#     Parameters
#     ----------
#     audio_data (np.ndarray): Audio time series
#     n_fft (int, optional): FFT window size
#     hop_length (int, optional): Number of samples between successive frames
#     win_length (int or None, optional): Window length. If None, defaults to n_fft
#     window (str, optional): Window specification
#     center (bool, optional): If True, pad signal on both sides
#     power (float, optional): Exponent for the magnitude spectrogram (1 for magnitude, 2 for power)
#
#     Returns
#     -------
#     np.ndarray
#         Magnitude spectrogram
#     """
#     if power < 0:
#         raise ValueError("Power must be non-negative")
#     if win_length is None:
#         win_length = n_fft
#     stft = librosa.stft(
#         audio_data,
#         n_fft=n_fft,
#         hop_length=hop_length,
#         win_length=win_length,
#         window=window,
#         center=center
#     )
#     # Take the magnitude and apply the power exponent
#     return np.abs(stft) ** power
# def extract_mel_spectrogram(
#     audio_data: np.ndarray,
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     n_fft: int = 2048,
#     hop_length: int = 512,
#     n_mels: int = 128,
#     fmin: float = 0.0,
#     fmax: Optional[float] = None,
#     power: float = 2.0
# ) -> np.ndarray:
#     """
#     Extract a mel spectrogram from audio data.
#
#     Parameters
#     ----------
#     audio_data (np.ndarray): Audio time series
#     sample_rate (int, optional): Sample rate of audio
#     n_fft (int, optional): FFT window size
#     hop_length (int, optional): Number of samples between successive frames
#     n_mels (int, optional): Number of mel bands
#     fmin (float, optional): Minimum frequency
#     fmax (float or None, optional): Maximum frequency. If None, use sample_rate/2
#     power (float, optional): Exponent for the magnitude spectrogram (1 for magnitude, 2 for power)
#
#     Returns
#     -------
#     np.ndarray
#         Mel spectrogram
#     """
#     if power < 0:
#         raise ValueError("Power must be non-negative")
#     return librosa.feature.melspectrogram(
#         y=audio_data,
#         sr=sample_rate,
#         n_fft=n_fft,
#         hop_length=hop_length,
#         n_mels=n_mels,
#         fmin=fmin,
#         fmax=fmax,
#         power=power
#     )
# def spectrogram_to_audio(
#     spectrogram: np.ndarray,
#     phase: Optional[np.ndarray] = None,
#     phase_info: bool = False,
#     n_fft=512,
#     n_iter=64,
#     window='hann',
#     hop_length=512,
#     win_length=None,
#     center=True) -> np.ndarray:
#     """
#     Convert a spectrogram back to audio using either:
#     1. Original phase information (if provided)
#     2. The Griffin-Lim algorithm to estimate phase (if no phase is provided)
#
#     Even with the original phase, the reconstruction is not truly lossless
#     (roughly 1e-33 MSE in practice).
#
#     Parameters
#     ----------
#     spectrogram (np.ndarray): The magnitude spectrogram to convert back to audio
#     phase (np.ndarray, optional): Phase information to use for reconstruction. If None, Griffin-Lim is used.
#     phase_info (bool): If True, the input is assumed to be a complex spectrogram that already carries phase
#     n_fft (int): FFT window size
#     n_iter (int, optional): Number of iterations for the Griffin-Lim algorithm
#     window (str): Window function to use
#     win_length (int or None): Window size. If None, defaults to n_fft
#     hop_length (int, optional): Number of samples between successive frames
#     center (bool, optional): Whether to pad the signal at the edges
#
#     Returns
#     -------
#     np.ndarray
#         The reconstructed audio signal
#     """
#     # If the input looks like a dB-scale spectrogram, convert back to amplitude
#     if np.max(spectrogram) < 0 and np.mean(spectrogram) < 0:
#         spectrogram = librosa.db_to_amplitude(spectrogram)
#     if phase_info:
#         return librosa.istft(spectrogram, n_fft=n_fft, hop_length=hop_length,
#                              win_length=win_length, window=window, center=center)
#
#     # If phase information is provided, use it for reconstruction
#     if phase is not None:
#         # Combine magnitude and phase to form a complex spectrogram
#         complex_spectrogram = spectrogram * np.exp(1j * phase)
#         # Inverse STFT to get audio
#         y = librosa.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length,
#                           win_length=win_length, window=window, center=center)
#     else:
#         # Use the Griffin-Lim algorithm to estimate phase
#         y = librosa.griffinlim(spectrogram, n_fft=n_fft, n_iter=n_iter,
#                                hop_length=hop_length, win_length=win_length,
#                                window=window, center=center)
#     return y
# def mel_spectrogram_to_audio(
#     mel_spectrogram: np.ndarray,
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     n_fft: int = 2048,
#     hop_length: int = 512,
#     n_iter: int = 32,
#     n_mels: int = 128,
#     fmin: float = 0.0,
#     fmax: Optional[float] = None,
#     power: float = 2.0
# ) -> np.ndarray:
#     """
#     Convert a mel spectrogram to audio using an approximate inverse mel
#     transform followed by Griffin-Lim phase estimation.
#
#     Parameters
#     ----------
#     mel_spectrogram (np.ndarray): Mel spectrogram
#     sample_rate (int, optional): Sample rate of audio
#     n_fft (int, optional): FFT window size
#     hop_length (int, optional): Number of samples between successive frames
#     n_iter (int, optional): Number of iterations for Griffin-Lim
#     n_mels (int, optional): Number of mel bands
#     fmin (float, optional): Minimum frequency
#     fmax (float or None, optional): Maximum frequency. If None, use sample_rate/2
#     power (float, optional): Exponent of the input spectrogram (1 for magnitude, 2 for power)
#
#     Returns
#     -------
#     np.ndarray
#         Audio time series
#     """
#     # Create a mel filterbank
#     mel_basis = librosa.filters.mel(
#         sr=sample_rate,
#         n_fft=n_fft,
#         n_mels=n_mels,
#         fmin=fmin,
#         fmax=fmax
#     )
#     # Compute the pseudo-inverse of the mel filterbank
#     mel_filterbank_inv = np.linalg.pinv(mel_basis)
#     # Convert the mel spectrogram to a linear spectrogram; the pseudo-inverse
#     # can produce small negative values, so clip at zero before any sqrt
#     linear_spec = np.maximum(np.dot(mel_filterbank_inv, mel_spectrogram), 0.0)
#     # If the input was a power spectrogram, take the square root
#     if power == 2.0:
#         linear_spec = np.sqrt(linear_spec)
#     # Run Griffin-Lim to estimate the phase and reconstruct audio
#     audio_data = librosa.griffinlim(
#         linear_spec,
#         hop_length=hop_length,
#         n_fft=n_fft,
#         n_iter=n_iter
#     )
#     return audio_data
# def visualize_spectrogram(
#     spectrogram: np.ndarray,
#     power: int = 1,
#     sample_rate: int = DEFAULT_SAMPLE_RATE,
#     n_fft: int = 512,
#     hop_length: int = 192,
#     win_length: int = 384,
#     gap_int: Optional[Tuple[float, float]] = None,
#     in_db: bool = False,
#     y_axis: str = 'log',
#     x_axis: str = 'time',
#     title: str = 'Spectrogram',
#     save_path: Optional[Union[str, Path]] = None
# ) -> Optional[plt.Figure]:
#     """
#     Visualize a spectrogram.
#
#     Parameters
#     ----------
#     spectrogram (np.ndarray): Spectrogram to visualize
#     power (int): Whether the spectrogram is in magnitude (1) or power (2) scale
#     sample_rate (int, optional): Sample rate of audio
#     hop_length (int, optional): Number of samples between successive frames
#     gap_int (float tuple, optional): Start and end time [s] of the gap (if given) to be plotted as vertical lines
#     in_db (bool, optional): Whether the spectrogram is already in dB scale
#     y_axis (str, optional): Scale for the y-axis ('linear', 'log', or 'mel')
#     x_axis (str, optional): Scale for the x-axis ('time' or 'frames')
#     title (str, optional): Title for the plot
#     save_path (str or Path or None, optional): Path to save the visualization. If None, the figure is returned.
#
#     Returns
#     -------
#     Figure or None
#         The matplotlib Figure object if save_path is None, otherwise None
#     """
#     if power not in (1, 2):
#         raise ValueError("Power must be 1 (magnitude) or 2 (power)")
#
#     # Convert to dB scale if needed
#     if in_db:
#         spectrogram_data = np.array(spectrogram)
#     elif power == 1:
#         spectrogram_data = librosa.amplitude_to_db(spectrogram, ref=np.max, amin=1e-5, top_db=80)
#     else:  # power == 2
#         spectrogram_data = librosa.power_to_db(spectrogram, ref=np.max, amin=1e-5, top_db=80)
#
#     fig, ax = plt.subplots(figsize=(10, 4))
#     img = librosa.display.specshow(
#         spectrogram_data,
#         sr=sample_rate,
#         n_fft=n_fft,
#         win_length=win_length,
#         hop_length=hop_length,
#         y_axis=y_axis,
#         x_axis=x_axis,
#         ax=ax
#     )
#     # Mark gap boundaries with vertical lines
#     if gap_int is not None:
#         gap_start_s, gap_end_s = gap_int
#         ax.axvline(x=gap_start_s, color='white', linestyle='--', label='Gap Start')
#         ax.axvline(x=gap_end_s, color='white', linestyle='--', label='Gap End')
#         ax.legend()
#
#     # Add colorbar and title
#     fig.colorbar(img, ax=ax, format='%+2.0f dB')
#     ax.set_title(title)
#     fig.tight_layout()
#
#     # Save or return the figure
#     if save_path is not None:
#         save_path = Path(save_path)
#         output_dir = save_path.parent
#         if output_dir and not output_dir.exists():
#             output_dir.mkdir(parents=True, exist_ok=True)
#         fig.savefig(save_path)
#         plt.close(fig)
#         return None
#     return fig