Instructions to use flexthink/discrete_wavlm_spk_rec_ecapatdn_lite with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- speechbrain
How to use flexthink/discrete_wavlm_spk_rec_ecapatdn_lite with speechbrain:
# interface not specified in config.json
- Notebooks
- Google Colab
- Kaggle
| from typing import Mapping | |
| import torch | |
| import math | |
| from speechbrain.inference.interfaces import Pretrained | |
class AttentionMLP(torch.nn.Module):
    """Two-layer MLP that scores its input and softmax-normalises the scores.

    Arguments
    ---------
    input_dim : int
        Size of the last dimension of the input tensor.
    hidden_dim : int
        Width of the hidden ReLU layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # The Sequential layout is part of the checkpoint format
        # (state-dict keys layers.0.* and layers.2.*), so keep it as is.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        """Return attention weights normalised over dimension 2 of ``x``.

        The result has the shape of ``x`` with its last dimension reduced
        to 1; the values along dimension 2 sum to one.
        """
        scores = self.layers(x)
        return scores.softmax(dim=2)
class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    All codebooks share a single ``torch.nn.Embedding`` table of
    ``num_codebooks * vocab_size`` rows; codebook ``i`` owns rows
    ``[i * vocab_size, (i + 1) * vocab_size)``.

    Arguments
    ---------
    num_codebooks : int
        Number of codebooks of the tokenizer.
    vocab_size : int
        Number of tokens per codebook.
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        Accepted for configuration compatibility; it is not passed to the
        embedding table and currently has no effect.
    init : bool (default: False)
        Stored as ``self.init`` but not read by this class; call
        ``init_embedding`` to load external embedding weights.
    freeze : bool (default: False)
        If True, the embedding table does not require gradients and
        ``forward`` never builds an autograd graph. If False, the table is
        trained alongside the rest of the pipeline.
    available_layers : list, optional
        Identifiers of all codebooks, in table order: codebook ``i`` of the
        table corresponds to ``available_layers[i]``. Required when
        ``layers`` is given.
    layers : list, optional
        Subset of ``available_layers`` present in the input. The selected
        codebooks are always used in ``available_layers`` order. When omitted
        or empty, all ``num_codebooks`` codebooks are used in table order.
    chunk_size : int (default: 100)
        Number of time steps processed at once by ``encode_logits``; bounds
        the size of its (batch, time, heads, vocab, emb) intermediate.

    Example
    -------
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 16)
    >>> tokens = torch.randint(0, 1024, (4, 10, 2))
    >>> emb(tokens).shape
    torch.Size([4, 10, 2, 16])
    >>> logits = torch.randn(4, 10, 2, 1024)
    >>> emb.encode_logits(logits).shape
    torch.Size([4, 10, 2, 16])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
        available_layers=None,
        layers=None,
        chunk_size=100,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init
        self.layers = layers
        self.available_layers = available_layers
        self.chunk_size = chunk_size
        self.register_buffer("offsets", self.build_offsets())
        self.register_buffer("layer_embs", self.compute_layer_embs())

    def _codebook_indices(self):
        """Return the table indices of the codebooks present in the input.

        Shared by ``build_offsets`` and ``compute_layer_embs`` so that token
        lookups (``forward``) and logit encodings (``encode_logits``) map each
        input head to the same codebook.

        Raises
        ------
        ValueError
            If ``layers`` is set without ``available_layers``, or contains
            identifiers missing from ``available_layers``.
        """
        if not self.layers:
            return list(range(self.num_codebooks))
        if self.available_layers is None:
            raise ValueError(
                "available_layers must be provided when layers is set"
            )
        unknown = set(self.layers) - set(self.available_layers)
        if unknown:
            raise ValueError(
                f"layers {unknown} not found in available_layers"
            )
        selected = set(self.layers)
        return [
            idx
            for idx, layer in enumerate(self.available_layers)
            if layer in selected
        ]

    def init_embedding(self, weights):
        """Replace the embedding table with ``weights``.

        The new parameter honours ``freeze``, and the ``layer_embs`` snapshot
        used by ``encode_logits`` is refreshed to match.
        """
        with torch.no_grad():
            self.embedding.weight = torch.nn.Parameter(
                weights, requires_grad=not self.freeze
            )
        self.layer_embs = self.compute_layer_embs()

    def build_offsets(self):
        """Return the first table row of each selected codebook.

        Returns
        -------
        torch.Tensor
            int64 tensor of shape [heads]; added to tokens of shape
            [..., heads] to make their IDs unique across codebooks.
        """
        indices = torch.tensor(self._codebook_indices(), dtype=torch.long)
        return indices * self.vocab_size

    def forward(self, in_tokens):
        """Computes the embedding for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            Integer tokens of shape (batch, time, heads), one column per
            selected codebook.

        Returns
        -------
        in_embs : torch.Tensor
            Embeddings of shape (batch, time, heads, emb_dim).
        """
        # Never re-enable autograd inside a caller's torch.no_grad() block.
        grad_enabled = torch.is_grad_enabled() and not self.freeze
        with torch.set_grad_enabled(grad_enabled):
            # Shift each head's tokens into its own slice of the shared table.
            in_tokens_offset = in_tokens + self.offsets.to(in_tokens.device)
            return self.embedding(in_tokens_offset.int())

    def compute_layer_embs(self):
        """Snapshot the embedding rows of each selected codebook.

        Returns
        -------
        torch.Tensor
            Tensor of shape (1, 1, heads, vocab_size, emb_dim), broadcastable
            against (batch, time, heads, vocab_size, 1) in ``encode_logits``.

        The snapshot is a detached copy. ``init_embedding`` and
        ``load_state_dict`` refresh it. A buffer that carries an autograd
        graph cannot be deep-copied, which is why it is detached.
        """
        weight = self.embedding.weight.detach()
        tables = weight.reshape(self.num_codebooks, self.vocab_size, -1)
        # List indexing is advanced indexing, so this returns a copy.
        layer_embs = tables[self._codebook_indices()]
        return layer_embs.unsqueeze(0).unsqueeze(0)

    def encode_logits(self, logits, length=None):
        """Embeds token logits using a straight-through estimator.

        The forward value is the embedding of each head's argmax token. It
        matches ``forward(logits.argmax(-1))`` up to floating-point error,
        as long as the table has not changed since ``layer_embs`` was last
        computed. Gradients w.r.t. ``logits`` flow through a Gumbel-softmax
        relaxation.

        Arguments
        ---------
        logits : torch.Tensor
            Token logits of shape (batch, length, heads, vocab_size).
        length : torch.Tensor, optional
            Unused; accepted for interface compatibility.

        Returns
        -------
        torch.Tensor
            Embeddings of shape (batch, length, heads, emb_dim).
        """
        # Differentiable (soft) sample over the vocabulary.
        units_soft = torch.nn.functional.gumbel_softmax(
            logits, hard=False, dim=-1
        )
        # Straight-through trick: forward with the hard argmax one-hot,
        # backward through the soft sample.
        argmax_idx = logits.argmax(dim=-1, keepdim=True)
        units_ref = torch.zeros_like(logits).scatter_(-1, argmax_idx, 1.0)
        units_hard = units_ref - units_soft.detach() + units_soft
        # Weighted sum of each codebook's embeddings, a few time steps at a
        # time to bound memory. split() also handles an empty time axis.
        return torch.cat(
            [
                (self.layer_embs * chunk.unsqueeze(-1)).sum(-2)
                for chunk in units_hard.split(self.chunk_size, dim=1)
            ],
            dim=1,
        )

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load parameters/buffers, then refresh the ``layer_embs`` snapshot.

        Extra keyword arguments (e.g. ``assign`` on newer PyTorch) are
        forwarded to ``torch.nn.Module.load_state_dict``.
        """
        result = super().load_state_dict(state_dict, strict, **kwargs)
        self.layer_embs = self.compute_layer_embs()
        return result
class DiscreteSpkEmb(Pretrained):
    """Computes speaker embeddings from discrete audio tokens or token logits.

    The hyperparameter file must define three modules, reached through
    ``self.mods``:

    * ``discrete_embedding_layer`` - embeds tokens (when called) or token
      logits (``encode_logits``) per head;
    * ``attention_mlp`` - produces weights used to pool the heads (axis 2);
    * ``embedding_model`` - maps the pooled features and relative lengths
      to the final embedding.

    Example
    -------
    Illustrative only; the source, save directory and token shapes depend
    on the deployed model.

    >>> spk_emb = DiscreteSpkEmb.from_hparams(
    ...     source="flexthink/discrete_wavlm_spk_rec_ecapatdn_lite",
    ...     savedir="pretrained_models/spk_emb",
    ... )  # doctest: +SKIP
    >>> embeddings = spk_emb(tokens)  # tokens: [batch, time, heads]  # doctest: +SKIP
    """

    def _pool_and_embed(self, embeddings, length=None):
        """Attention-pool per-head embeddings and run the embedding model.

        ``embeddings`` is (batch, time, heads, emb_dim), as produced by the
        discrete embedding layer. ``attention_mlp`` is assumed to return
        weights of shape (batch, time, heads, 1) (TODO confirm).
        """
        att_w = self.mods.attention_mlp(embeddings)
        # (B, T, 1, H) @ (B, T, H, E) -> (B, T, 1, E) -> (B, T, E)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        spk_embs = self.mods.embedding_model(feats, length)
        # NOTE(review): presumably embedding_model returns (B, 1, D); the
        # singleton axis 1 is dropped here.
        return spk_embs.squeeze(1)

    def encode_batch(self, audio, length=None):
        """Encodes a batch of discrete tokens into a single vector embedding.

        Arguments
        ---------
        audio : torch.tensor
            Batch of tokenized audio [batch, time, heads]
        length : torch.tensor
            Lengths of the sequences relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(sequence) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch
        """
        embeddings = self.mods.discrete_embedding_layer(audio)
        return self._pool_and_embed(embeddings, length)

    def encode_logits(self, logits, length=None):
        """Encodes token logits into a single vector embedding.

        Arguments
        ---------
        logits : torch.tensor
            Batch of token logits [batch, time, heads, tokens]
        length : torch.tensor
            Lengths of the sequences relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(sequence) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch
        """
        embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
        return self._pool_and_embed(embeddings, length)

    def forward(self, audio, length=None):
        """Encodes tokens or token logits into a single vector embedding.

        Arguments
        ---------
        audio : torch.tensor
            Batch of tokenized audio [batch, time, heads]
            or logits [batch, time, heads, tokens]
        length : torch.tensor
            Lengths of the sequences relative to the longest one in the
            batch, tensor of shape [batch]. Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch

        Raises
        ------
        ValueError
            If ``audio`` is neither 3-D (tokens) nor 4-D (logits).
        """
        audio_dim = audio.dim()
        if audio_dim == 3:
            return self.encode_batch(audio, length)
        if audio_dim == 4:
            return self.encode_logits(audio, length)
        raise ValueError(f"Unsupported audio shape {audio.shape}")