import os
import time
import streamlit as st
from transformers import pipeline
from pydub import AudioSegment
import tempfile
import torch
from datasets import load_dataset
import jiwer
import librosa  # not used directly; kept because `datasets` audio decoding typically relies on it
import soundfile  # not used directly; kept as an audio decoding backend for `datasets`
# Page configuration
st.set_page_config(page_title="Audio-to-Text with Grammar Check", page_icon="🎤", layout="wide")
# Model configurations: three ASR checkpoints plus one grammar-correction model
MODELS = {
"automatic-speech-recognition": {
"whisper-tiny": "openai/whisper-tiny",
"whisper-small": "openai/whisper-small",
"whisper-base": "openai/whisper-base"
},
"text2text-generation": {
"flan-t5-base": "pszemraj/grammar-synthesis-small"
}
}
# Cached model loading
@st.cache_resource
def load_model(model_key, task):
device = "cuda" if torch.cuda.is_available() else "cpu"
with st.spinner(f"Loading {model_key} model..."):
return pipeline(task, model=MODELS[task][model_key], device=device)
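# pydub delegates decoding of mp3/ogg/m4a to ffmpeg, so ffmpeg must be
# available on the host (e.g. installed via packages.txt on Hugging Face Spaces).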
def convert_audio_to_wav(audio_file):
"""Convert uploaded audio to WAV format"""
try:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
audio = AudioSegment.from_file(audio_file)
audio.export(tmp_file.name, format="wav")
return tmp_file.name
except Exception as e:
st.error(f"Audio conversion failed: {str(e)}")
return None
def evaluate_asr_accuracy(transcription, reference):
"""Calculate WER and CER accuracy"""
ref_processed = reference.lower().strip()
hyp_processed = transcription.lower().strip()
if not ref_processed:
return 0.0, 0.0
wer = jiwer.wer(ref_processed, hyp_processed)
cer = jiwer.cer(ref_processed, hyp_processed)
return wer, cer
# Cached dataset loading with audio decoding
@st.cache_data(show_spinner=False)
def load_cached_dataset(num_samples=1):
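# Streaming avoids downloading the full LibriSpeech test split;
# .take(num_samples) pulls only the first few examples.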
st.info("Loading dataset...")
try:
dataset = load_dataset(
"librispeech_asr",
"clean",
split="test",
streaming=True,
trust_remote_code=True
).take(num_samples)
return list(dataset)
except Exception as e:
st.error(f"Dataset loading failed: {str(e)}")
return None
def main():
st.title("🎤 Audio Grammar Evaluation System for Language Learners")
# Session state for persisting results
if "transcription" not in st.session_state:
st.session_state.transcription = ""
if "grammar_feedback" not in st.session_state:
st.session_state.grammar_feedback = ""
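# Session state survives Streamlit's script reruns, so results persist
# across widget interactions.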
# Audio processing tab
tab1, tab2 = st.tabs(["Audio Processor", "Model Evaluator"])
with tab1:
st.subheader("Upload & Process Audio")
audio_file = st.file_uploader("Upload audio file", type=["mp3", "wav", "ogg", "m4a"])
if audio_file:
st.audio(audio_file, format=audio_file.type or "audio/wav")
wav_path = convert_audio_to_wav(audio_file)
if wav_path:
asr_model = load_model("whisper-tiny", "automatic-speech-recognition")
with st.spinner("Generating transcription..."):
transcription = asr_model(wav_path)["text"]
st.session_state.transcription = transcription
st.text_area("Transcription Result", transcription, height=150)
if st.session_state.transcription:
grammar_model = load_model("flan-t5-base", "text2text-generation")
with st.spinner("Checking grammar..."):
# grammar-synthesis-small is trained to correct raw text directly,
# so the transcription is passed without an instruction prefix
grammar_feedback = grammar_model(
st.session_state.transcription
)[0]["generated_text"]
st.session_state.grammar_feedback = grammar_feedback
st.success("Grammar Corrected Text:")
st.write(grammar_feedback)
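# Clean up the temporary WAV file created by convert_audio_to_wav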
os.unlink(wav_path)
with tab2:
st.subheader("Triple Model Evaluation with Runtime")
# Model selection
model_options = list(MODELS["automatic-speech-recognition"].keys())
col1, col2, col3 = st.columns(3)
with col1:
selected_model1 = st.selectbox("Select Model 1", model_options, index=0)
with col2:
selected_model2 = st.selectbox("Select Model 2", model_options, index=1)
with col3:
selected_model3 = st.selectbox("Select Model 3", model_options, index=2)
if st.button("Run Triple Evaluation"):
dataset = load_cached_dataset(num_samples=1)
if not dataset:
return
# Load three models
model1 = load_model(selected_model1, "automatic-speech-recognition")
model2 = load_model(selected_model2, "automatic-speech-recognition")
model3 = load_model(selected_model3, "automatic-speech-recognition")
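# Thanks to st.cache_resource, selecting the same checkpoint more than once
# reuses the already-loaded pipeline instead of loading it again.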
results = []
total_runtime_model1 = 0.0
total_runtime_model2 = 0.0
total_runtime_model3 = 0.0
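# time.perf_counter() is a monotonic, high-resolution clock, suitable for
# timing each transcription call; per-sample runtimes are also accumulated.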
for i, sample in enumerate(dataset):
with st.spinner(f"Processing Sample..."):
audio_array = sample["audio"]["array"]
reference_text = sample["text"]
# Evaluate Model 1
start_time = time.perf_counter()
transcription1 = model1(audio_input)["text"]
end_time = time.perf_counter()
runtime1 = end_time - start_time
total_runtime_model1 += runtime1
wer1, cer1 = evaluate_asr_accuracy(transcription1, reference_text)
# Evaluate Model 2
start_time = time.perf_counter()
transcription2 = model2(audio_input)["text"]
end_time = time.perf_counter()
runtime2 = end_time - start_time
total_runtime_model2 += runtime2
wer2, cer2 = evaluate_asr_accuracy(transcription2, reference_text)
# Evaluate Model 3
start_time = time.perf_counter()
transcription3 = model3(audio_input)["text"]
end_time = time.perf_counter()
runtime3 = end_time - start_time
total_runtime_model3 += runtime3
wer3, cer3 = evaluate_asr_accuracy(transcription3, reference_text)
# Organize results
model1_result = {
"Model": selected_model1,
"Runtime": f"{runtime1:.4f}s",
"WER": f"{wer1*100:.2f}%",
"CER": f"{cer1*100:.2f}%"
}
model2_result = {
"Model": selected_model2,
"Runtime": f"{runtime2:.4f}s",
"WER": f"{wer2*100:.2f}%",
"CER": f"{cer2*100:.2f}%"
}
model3_result = {
"Model": selected_model3,
"Runtime": f"{runtime3:.4f}s",
"WER": f"{wer3*100:.2f}%",
"CER": f"{cer3*100:.2f}%"
}
results.extend([model1_result, model2_result, model3_result])
# Display results
st.subheader("Model Evaluation Results")
st.table(results)
if __name__ == "__main__":
main()