import os
import time
import tempfile

import streamlit as st
import torch
import jiwer
from transformers import pipeline
from pydub import AudioSegment
from datasets import load_dataset
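# Note: pydub shells out to a system ffmpeg install to decode mp3/ogg/m4a
# uploads, and the `datasets` Audio feature relies on the soundfile package
# at runtime to decode LibriSpeech audio, so both must be present in the
# environment even though neither needs an import here.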
# Page configuration
st.set_page_config(page_title="Audio-to-Text with Grammar Check", page_icon="🎤", layout="wide")
# Model configurations (three ASR models plus one grammar-correction model)
MODELS = {
    "automatic-speech-recognition": {
        "whisper-tiny": "openai/whisper-tiny",
        "whisper-small": "openai/whisper-small",
        "whisper-base": "openai/whisper-base"
    },
    "text2text-generation": {
        "grammar-synthesis-small": "pszemraj/grammar-synthesis-small"
    }
}
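# Approximate Whisper checkpoint sizes, per the OpenAI model cards:
# tiny ≈ 39M parameters, base ≈ 74M, small ≈ 244M. Larger checkpoints
# generally trade runtime for lower WER, which is what the evaluator
# tab below measures.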
# Cached model loading: st.cache_resource keeps one pipeline per
# (model_key, task) pair alive across Streamlit reruns
@st.cache_resource
def load_model(model_key, task):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with st.spinner(f"Loading {model_key} model..."):
        return pipeline(task, model=MODELS[task][model_key], device=device)
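# Assumption: passing a device string ("cuda"/"cpu") to pipeline() requires a
# reasonably recent transformers release; older versions expect an integer
# index instead (0 for the first GPU, -1 for CPU).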
def convert_audio_to_wav(audio_file):
    """Convert uploaded audio to WAV format"""
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            audio = AudioSegment.from_file(audio_file)
            audio.export(tmp_file.name, format="wav")
            return tmp_file.name
    except Exception as e:
        st.error(f"Audio conversion failed: {str(e)}")
        return None
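# The temp file is created with delete=False, so the caller owns cleanup;
# the Audio Processor tab below unlinks the file once processing is done.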
def evaluate_asr_accuracy(transcription, reference):
    """Calculate word error rate (WER) and character error rate (CER)"""
    ref_processed = reference.lower().strip()
    hyp_processed = transcription.lower().strip()
    if not ref_processed:
        return 0.0, 0.0
    wer = jiwer.wer(ref_processed, hyp_processed)
    cer = jiwer.cer(ref_processed, hyp_processed)
    return wer, cer
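# Example: jiwer.wer("hello world", "hello word") == 0.5 (one substituted
# word out of two reference words); lower is better and 0.0 is a perfect
# match. Lowercasing first keeps LibriSpeech's all-caps references from
# inflating the error rates.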
# Cached dataset loading; streaming with .take() pulls only the first
# num_samples items instead of downloading the full test-clean split
@st.cache_data
def load_cached_dataset(num_samples=1):
    st.info("Loading dataset...")
    try:
        dataset = load_dataset(
            "librispeech_asr",
            "clean",
            split="test",
            streaming=True,
            trust_remote_code=True
        ).take(num_samples)
        return [sample for sample in dataset]
    except Exception as e:
        st.error(f"Dataset loading failed: {str(e)}")
        return None
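# Each yielded sample is a dict shaped roughly like
#   {"audio": {"array": np.ndarray, "sampling_rate": 16000, "path": ...},
#    "text": "AN UPPERCASE REFERENCE TRANSCRIPT", ...}
# which is why the evaluator below reads sample["audio"]["array"] and
# sample["text"].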
def main():
    st.title("🎤 Audio Grammar Evaluation System for Language Learners")
    # Session state for persisting results across reruns
    if "transcription" not in st.session_state:
        st.session_state.transcription = ""
    if "grammar_feedback" not in st.session_state:
        st.session_state.grammar_feedback = ""
    # Audio processing and model evaluation tabs
    tab1, tab2 = st.tabs(["Audio Processor", "Model Evaluator"])
    with tab1:
        st.subheader("Upload & Process Audio")
        audio_file = st.file_uploader("Upload audio file", type=["mp3", "wav", "ogg", "m4a"])
        if audio_file:
            # Use the uploaded file's own MIME type for playback
            st.audio(audio_file, format=audio_file.type)
            wav_path = convert_audio_to_wav(audio_file)
            if wav_path:
                asr_model = load_model("whisper-tiny", "automatic-speech-recognition")
                with st.spinner("Generating transcription..."):
                    transcription = asr_model(wav_path)["text"]
                st.session_state.transcription = transcription
                st.text_area("Transcription Result", transcription, height=150)
                if st.session_state.transcription:
                    grammar_model = load_model("grammar-synthesis-small", "text2text-generation")
                    with st.spinner("Checking grammar..."):
                        # The grammar-synthesis checkpoints are trained on raw
                        # text, so no instruction prefix is prepended
                        grammar_feedback = grammar_model(transcription)[0]["generated_text"]
                    st.session_state.grammar_feedback = grammar_feedback
                    st.success("Grammar Corrected Text:")
                    st.write(grammar_feedback)
                os.unlink(wav_path)
    with tab2:
        st.subheader("Triple Model Evaluation with Runtime")
        # Model selection, one selectbox per column
        model_options = list(MODELS["automatic-speech-recognition"].keys())
        col1, col2, col3 = st.columns(3)
        with col1:
            selected_model1 = st.selectbox("Select Model 1", model_options, index=0)
        with col2:
            selected_model2 = st.selectbox("Select Model 2", model_options, index=1)
        with col3:
            selected_model3 = st.selectbox("Select Model 3", model_options, index=2)
        if st.button("Run Triple Evaluation"):
            dataset = load_cached_dataset(num_samples=1)
            if not dataset:
                return
            # Load the three selected models
            selected_names = [selected_model1, selected_model2, selected_model3]
            models = [load_model(name, "automatic-speech-recognition") for name in selected_names]
            results = []
            total_runtimes = {name: 0.0 for name in selected_names}
            for i, sample in enumerate(dataset):
                with st.spinner(f"Processing sample {i + 1}..."):
                    # Raw arrays are assumed to already be at Whisper's
                    # 16 kHz input rate, which holds for LibriSpeech
                    audio_array = sample["audio"]["array"]
                    reference_text = sample["text"]
                    # Time each model on the same sample and score its
                    # transcription against the reference
                    for name, model in zip(selected_names, models):
                        start_time = time.perf_counter()
                        transcription = model(audio_array)["text"]
                        runtime = time.perf_counter() - start_time
                        total_runtimes[name] += runtime
                        wer, cer = evaluate_asr_accuracy(transcription, reference_text)
                        results.append({
                            "Model": name,
                            "Runtime": f"{runtime:.4f}s",
                            "WER": f"{wer*100:.2f}%",
                            "CER": f"{cer*100:.2f}%"
                        })
            # Display results
            st.subheader("Model Evaluation Results")
            st.table(results)
if __name__ == "__main__":
    main()
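# Run locally with `streamlit run app.py` (filename assumed; a Hugging Face
# Space launches the app automatically from its configured entry point).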