import os
import time
import streamlit as st
from transformers import pipeline
from pydub import AudioSegment
import tempfile
import torch
from datasets import load_dataset
import jiwer
# librosa and soundfile are not called directly, but `datasets` needs these
# backends installed to decode the LibriSpeech audio samples
import librosa
import soundfile

# Page configuration
st.set_page_config(page_title="Audio-to-Text with Grammar Check", page_icon="🎤", layout="wide")

# Model configurations: three Whisper ASR models plus one grammar-correction model
MODELS = {
    "automatic-speech-recognition": {
        "whisper-tiny": "openai/whisper-tiny",
        "whisper-small": "openai/whisper-small",
        "whisper-base": "openai/whisper-base"
    },
    "text2text-generation": {
        "flan-t5-base": "pszemraj/grammar-synthesis-small"
    }
}

# Cached model loading: st.cache_resource keeps one pipeline instance per
# (model_key, task) pair alive across Streamlit reruns, so each model is
# downloaded and initialized only once per session
@st.cache_resource
def load_model(model_key, task):
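    # Prefer the GPU when available; transformers pipelines accept a plain
    # device string ("cuda" / "cpu")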
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with st.spinner(f"Loading {model_key} model..."):
        return pipeline(task, model=MODELS[task][model_key], device=device)

def convert_audio_to_wav(audio_file):
    """Convert uploaded audio to WAV format"""
    try:
        # Rewind the upload buffer in case it was already read (e.g. by st.audio)
        audio_file.seek(0)
        # delete=False keeps the file on disk after the context manager exits
        # so the ASR pipeline can read it; the caller removes it with os.unlink()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            audio = AudioSegment.from_file(audio_file)
            audio.export(tmp_file.name, format="wav")
            return tmp_file.name
    except Exception as e:
        st.error(f"Audio conversion failed: {str(e)}")
        return None

def evaluate_asr_accuracy(transcription, reference):
    """Calculate word error rate (WER) and character error rate (CER)"""
    ref_processed = reference.lower().strip()
    hyp_processed = transcription.lower().strip()
    
    if not ref_processed:
        return 0.0, 0.0
    
    wer = jiwer.wer(ref_processed, hyp_processed)
    cer = jiwer.cer(ref_processed, hyp_processed)
    
    return wer, cer

# Cached dataset loading: streaming avoids downloading the full LibriSpeech
# split; each sample's audio is decoded to a float array on access
@st.cache_data(show_spinner=False)
def load_cached_dataset(num_samples=1):
    st.info("Loading dataset...")
    try:
        dataset = load_dataset(
            "librispeech_asr", 
            "clean", 
            split="test", 
            streaming=True, 
            trust_remote_code=True
        ).take(num_samples)
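        # Materialize the stream into a plain list so st.cache_data can
        # pickle the decoded samples and reuse them across reruns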
        return [sample for sample in dataset]
    except Exception as e:
        st.error(f"Dataset loading failed: {str(e)}")
        return None

def main():
    st.title("🎤 Audio Grammar Evaluation System for Language Learners")
    
    # Session state for persisting results
    if "transcription" not in st.session_state:
        st.session_state.transcription = ""
    if "grammar_feedback" not in st.session_state:
        st.session_state.grammar_feedback = ""

    # Audio processing tab
    tab1, tab2 = st.tabs(["Audio Processor", "Model Evaluator"])
    
    with tab1:
        st.subheader("Upload & Process Audio")
        audio_file = st.file_uploader("Upload audio file", type=["mp3", "wav", "ogg", "m4a"])
        
        if audio_file:
            # Use the upload's own MIME type so the embedded player matches the file
            st.audio(audio_file, format=audio_file.type or "audio/wav")
            wav_path = convert_audio_to_wav(audio_file)
            
            if wav_path:
                asr_model = load_model("whisper-tiny", "automatic-speech-recognition")
                
                with st.spinner("Generating transcription..."):
                    transcription = asr_model(wav_path)["text"]
                    st.session_state.transcription = transcription
                    st.text_area("Transcription Result", transcription, height=150)
                
                if st.session_state.transcription:
                    grammar_model = load_model("grammar-synthesis-small", "text2text-generation")
                    with st.spinner("Checking grammar..."):
                        # The grammar-synthesis checkpoint is trained on raw
                        # text, so the transcription is passed in directly
                        # rather than wrapped in an instruction prompt
                        grammar_feedback = grammar_model(
                            st.session_state.transcription
                        )[0]["generated_text"]
                        st.session_state.grammar_feedback = grammar_feedback
                        st.success("Grammar Corrected Text:")
                        st.write(grammar_feedback)
                
                os.unlink(wav_path)

    with tab2:
        st.subheader("Triple Model Evaluation with Runtime")
        
        # Model selection
        model_options = list(MODELS["automatic-speech-recognition"].keys())
        col1, col2, col3 = st.columns(3)
        with col1:
            selected_model1 = st.selectbox("Select Model 1", model_options, index=0)
        with col2:
            selected_model2 = st.selectbox("Select Model 2", model_options, index=1)
        with col3:
            selected_model3 = st.selectbox("Select Model 3", model_options, index=2)
        
        if st.button("Run Triple Evaluation"):
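            # Every model transcribes the same cached samples, so the runtimes
            # and error rates below are directly comparable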
            dataset = load_cached_dataset(num_samples=1)
            if not dataset:
                return
            
            # Load three models
            model1 = load_model(selected_model1, "automatic-speech-recognition")
            model2 = load_model(selected_model2, "automatic-speech-recognition")
            model3 = load_model(selected_model3, "automatic-speech-recognition")
            
            results = []
            total_runtime_model1 = 0.0
            total_runtime_model2 = 0.0
            total_runtime_model3 = 0.0
            
            for i, sample in enumerate(dataset):
                with st.spinner(f"Processing sample {i + 1}..."):
                    # Pass the sampling rate along with the raw array; a bare
                    # array would be assumed to be at the feature extractor's
                    # default rate
                    audio_input = {
                        "raw": sample["audio"]["array"],
                        "sampling_rate": sample["audio"]["sampling_rate"],
                    }
                    reference_text = sample["text"]
                    
                    # Evaluate Model 1
                    start_time = time.perf_counter()
                    transcription1 = model1(audio_input)["text"]
                    end_time = time.perf_counter()
                    runtime1 = end_time - start_time
                    total_runtime_model1 += runtime1
                    wer1, cer1 = evaluate_asr_accuracy(transcription1, reference_text)
                    
                    # Evaluate Model 2
                    start_time = time.perf_counter()
                    transcription2 = model2(audio_input)["text"]
                    end_time = time.perf_counter()
                    runtime2 = end_time - start_time
                    total_runtime_model2 += runtime2
                    wer2, cer2 = evaluate_asr_accuracy(transcription2, reference_text)
                    
                    # Evaluate Model 3
                    start_time = time.perf_counter()
                    transcription3 = model3(audio_input)["text"]
                    end_time = time.perf_counter()
                    runtime3 = end_time - start_time
                    total_runtime_model3 += runtime3
                    wer3, cer3 = evaluate_asr_accuracy(transcription3, reference_text)
                    
                    # Organize results
                    model1_result = {
                        "Model": selected_model1,
                        "Runtime": f"{runtime1:.4f}s",
                        "WER": f"{wer1*100:.2f}%",
                        "CER": f"{cer1*100:.2f}%"
                    }
                    model2_result = {
                        "Model": selected_model2,
                        "Runtime": f"{runtime2:.4f}s",
                        "WER": f"{wer2*100:.2f}%",
                        "CER": f"{cer2*100:.2f}%"
                    }
                    model3_result = {
                        "Model": selected_model3,
                        "Runtime": f"{runtime3:.4f}s",
                        "WER": f"{wer3*100:.2f}%",
                        "CER": f"{cer3*100:.2f}%"
                    }
                    results.extend([model1_result, model2_result, model3_result])
            
            # Display results
            st.subheader("Model Evaluation Results")
            st.table(results)
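
            # Summed runtimes across all processed samples
            st.write(f"Total runtime: {selected_model1} {total_runtime_model1:.4f}s, "
                     f"{selected_model2} {total_runtime_model2:.4f}s, "
                     f"{selected_model3} {total_runtime_model3:.4f}s")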
            

if __name__ == "__main__":
    main()