| from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel |
| from spleeter.separator import Separator |
|
|
|
|
class SpleeterConfig(PretrainedConfig):
    """Configuration for the Spleeter source-separation wrapper.

    Args:
        stems: How many stems the pretrained Spleeter model separates audio
            into; used to build the ``"spleeter:{stems}stems"`` identifier.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "spleeter"

    def __init__(self, stems: int = 2, **kwargs):
        # Record the model-specific field first, then let the base class
        # consume the generic configuration keyword arguments.
        self.stems = stems
        super().__init__(**kwargs)
|
|
|
|
class SpleeterModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a Spleeter ``Separator``.

    The separator is constructed from the pretrained Spleeter model named
    ``f"spleeter:{config.stems}stems"``.
    """

    config_class = SpleeterConfig

    def __init__(self, config):
        super().__init__(config)
        self.separator = Separator(f"spleeter:{config.stems}stems")

    def forward(
        self,
        audio_path: str,
        output_dir: str = "separated_audio",
        **separate_kwargs,
    ) -> str:
        """Separate the stems of an audio file and write them to disk.

        Args:
            audio_path: Path to the input audio file.
            output_dir: Destination directory handed to
                ``Separator.separate_to_file``. Defaults to
                ``"separated_audio"``, the previously hard-coded location.
            **separate_kwargs: Extra keyword arguments forwarded unchanged to
                ``Separator.separate_to_file`` (e.g. codec or filename format).

        Returns:
            ``output_dir``, the directory the separated stems were written to.
        """
        self.separator.separate_to_file(audio_path, output_dir, **separate_kwargs)
        # separate_to_file does not hand back the output location, so return
        # the destination explicitly to honour the documented contract.
        return output_dir
|
|
|
|
# Map the custom classes into the Auto* factories for this process:
# AutoConfig resolves model_type "spleeter" to SpleeterConfig, and
# AutoModel resolves SpleeterConfig to SpleeterModel.
AutoConfig.register("spleeter", SpleeterConfig)
AutoModel.register(SpleeterConfig, SpleeterModel)

# Tag each custom class with the Auto* class it should be loaded through;
# presumably so the mapping is written out when the config/model are
# saved as custom code — confirm against the transformers version in use.
SpleeterConfig.register_for_auto_class("AutoConfig")
SpleeterModel.register_for_auto_class("AutoModel")