DavidWill's picture
Update app.py
09a365f verified
import gradio as gr
import torch
from transformers import pipeline
import librosa
import tempfile
import os
import soundfile as sf
# Step 1: 載入預訓練音樂分類模型(Audio Spectrogram Transformer)
# 這個模型在 AudioSet 上訓練,可分類多種聲音事件,包括音樂流派
classifier = pipeline(
"audio-classification",
model="MIT/ast-finetuned-audioset-10-10-0.4593",
device=-1 # 強制使用 CPU
)
# 我們只篩選與「音樂流派」相關的標籤(可自訂)
MUSIC_GENRES = {
"acoustic_guitar", "electric_guitar", "piano", "violin", "flute",
"drum", "electronic_music", "classical_music", "rock_music",
"pop_music", "hip_hop_music", "jazz_music", "opera", "country_music"
}
def classify_music_genre(audio_file):
"""
輸入:上傳的音檔(任何格式)
輸出:Top-3 預測流派 + 機率
"""
if audio_file is None:
return "請上傳音檔!"
# Step 2: 用 librosa 載入音檔,並重採樣到 16kHz(模型預期輸入)
audio, sr = librosa.load(audio_file, sr=16000)
# Step 3: 儲存為臨時 wav 檔(pipeline 需要檔案路徑)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
sf.write(tmpfile.name, audio, sr) # 👈 修正:使用 soundfile 寫檔
temp_path = tmpfile.name
# Step 4: 進行分類
try:
results = classifier(temp_path)
except Exception as e:
os.unlink(temp_path)
return f"分類失敗:{str(e)}"
# Step 5: 過濾出音樂相關標籤 + 格式化輸出
filtered_results = [
r for r in results if any(genre in r['label'].lower() for genre in MUSIC_GENRES)
][:3] # 取 Top 3
if not filtered_results:
filtered_results = results[:3] # 若無音樂標籤,顯示原始 Top 3
# 格式化為易讀文字
output = "\n".join([
f"🎵 {r['label']} ({r['score']:.2%})"
for r in filtered_results
])
# 清理臨時檔案
os.unlink(temp_path)
return output
# Step 6: 建立 Gradio UI
demo = gr.Interface(
fn=classify_music_genre,
inputs=gr.Audio(type="filepath", label="上傳音樂片段(建議 5~15 秒)"),
outputs=gr.Textbox(
label="AI 預測的音樂流派 Top 3",
lines=6, # 👈 設定顯示行數(高度)
max_lines=10, # 👈 最多可顯示的行數(超過會出現捲軸)
interactive=False # 👈 防止使用者編輯(可選)
),
title="🎧 音樂流派分類器(CPU 友好版)",
description="""
上傳一小段音樂,AI 會預測它屬於哪些流派!
支援格式:mp3, wav, m4a, ogg 等(需瀏覽器支援)
""",
#examples=[
# ["sample_music.mp3"]
#["example2.mp3"]
#], # 可選:上傳範例音檔到 Space Files
allow_flagging="never"
)
# 啟動!
if __name__ == "__main__":
demo.launch()