File size: 4,828 Bytes
67ca15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import io
import uuid
import time
import json
import logging
import tempfile
import threading

from flask import Flask, request, jsonify, send_file
from transformers import pipeline
from gtts import gTTS
from pydub import AudioSegment

# ================= CONFIG =================

# Directory where generated response WAV files are written before being sent
# back to the client; created at import time if it does not exist yet.
TEMP_AUDIO_DIR = "/tmp/audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)

# Hugging Face model identifiers handed to transformers.pipeline() below.
STT_MODEL = "openai/whisper-tiny"
LLM_MODEL = "google/flan-t5-base"

# NOTE(review): MAX_AUDIO_SECONDS is not referenced in the visible code --
# confirm whether an audio-length limit was meant to be enforced.
MAX_AUDIO_SECONDS = 10
MAX_TEXT_LEN = 200  # maximum number of characters passed to TTS

CLEANUP_INTERVAL = 300      # seconds to sleep between cleanup sweeps
FILE_EXPIRE_TIME = 600     # seconds since last modification before deletion

# ================= LOG =================

# Configure root logging once; every message in this module goes through
# the module-level `logger`.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)

# ================= APP =================

app = Flask(__name__)
app.config["TEMP_AUDIO_DIR"] = TEMP_AUDIO_DIR

# ================= LOAD MODELS =================

# Both pipelines are built at import time, so importing this module blocks
# until the models are loaded. They are explicitly placed on the CPU.
logger.info("Loading STT model...")
stt_pipeline = pipeline(
    "automatic-speech-recognition",
    model=STT_MODEL,
    device="cpu"
)

logger.info("Loading LLM model...")
llm_pipeline = pipeline(
    "text2text-generation",
    model=LLM_MODEL,
    device="cpu"
)

logger.info("Models loaded successfully")

# ================= UTILS =================

def generate_tts_audio(text: str) -> bytes:
    """
    Generate WAV 16kHz mono audio from text.

    Newlines are replaced with spaces, blank input falls back to
    "I understand.", and the text is truncated to MAX_TEXT_LEN characters
    before synthesis. gTTS writes an MP3, which pydub re-encodes as WAV.

    Args:
        text: Text to synthesize.

    Returns:
        The WAV file contents, or b"" if synthesis or conversion failed
        (the error is logged rather than raised).
    """
    try:
        text = text.replace("\n", " ").strip()
        if not text:
            text = "I understand."

        text = text[:MAX_TEXT_LEN]
        logger.info(f"TTS: {text}")

        # Keep both intermediate files in a private temporary directory so
        # they are removed even when gTTS or pydub raises, and so the MP3
        # name cannot collide with another file in the shared temp dir.
        with tempfile.TemporaryDirectory() as work_dir:
            mp3_path = os.path.join(work_dir, "speech.mp3")
            wav_path = os.path.join(work_dir, "speech.wav")

            gTTS(text=text, lang="en").save(mp3_path)

            audio = AudioSegment.from_file(mp3_path)
            audio = audio.set_frame_rate(16000).set_channels(1)
            # export() hands back the open output file; close it so the
            # data is flushed and the directory can be deleted cleanly.
            audio.export(wav_path, format="wav").close()

            with open(wav_path, "rb") as f:
                return f.read()

    except Exception as e:
        # Best-effort contract: callers treat b"" as "TTS failed".
        logger.error(f"TTS error: {e}", exc_info=True)
        return b""


def cleanup_temp_files():
    """
    Periodically delete expired files from TEMP_AUDIO_DIR.

    Runs forever and is meant to be the target of a background thread.
    Each sweep removes every regular file whose modification time is more
    than FILE_EXPIRE_TIME seconds old, then sleeps CLEANUP_INTERVAL seconds.
    Errors are logged and never propagate out of the loop.
    """
    while True:
        try:
            now = time.time()
            for filename in os.listdir(TEMP_AUDIO_DIR):
                path = os.path.join(TEMP_AUDIO_DIR, filename)
                # Handle failures per file so one bad entry does not stop
                # the rest of the sweep.
                try:
                    if not os.path.isfile(path):
                        continue
                    if now - os.path.getmtime(path) > FILE_EXPIRE_TIME:
                        os.remove(path)
                except FileNotFoundError:
                    # The file vanished between listing and stat/remove;
                    # nothing left to clean up.
                    continue
                except OSError as e:
                    logger.warning(f"Cleanup error: {e}")
        except Exception as e:
            # e.g. the directory itself is missing or unreadable.
            logger.warning(f"Cleanup error: {e}")

        time.sleep(CLEANUP_INTERVAL)


# ================= ROUTES =================

@app.route("/health", methods=["GET"])
def health():
    """Return a JSON status document naming the configured STT and LLM models."""
    return jsonify(status="ok", stt=STT_MODEL, llm=LLM_MODEL)


@app.route("/process_audio", methods=["POST"])
def process_audio():
    """
    Transcribe an uploaded audio clip, answer it with the LLM, and reply
    with synthesized speech.

    Expects a multipart upload with a file part named "audio".

    Returns:
        200 with an audio/wav body on success;
        400 JSON error when the "audio" part is missing or holds fewer than
        1000 bytes;
        500 JSON error when TTS produces no audio or any step raises.
    """
    try:
        if "audio" not in request.files:
            return jsonify({"error": "No audio file"}), 400

        audio_file = request.files["audio"]
        raw_audio = audio_file.read()

        if len(raw_audio) < 1000:
            return jsonify({"error": "Audio too short"}), 400

        # ================= STT =================
        logger.info("Running STT...")
        # Given bytes, the ASR pipeline decodes and resamples the audio
        # itself; `sampling_rate` is not a call-time argument and is
        # rejected or ignored depending on the transformers version.
        # NOTE(review): if clients upload headerless PCM rather than an
        # encoded file, pass {"raw": ndarray, "sampling_rate": 16000}
        # instead -- confirm against the client.
        stt_result = stt_pipeline(raw_audio)

        user_text = stt_result.get("text", "").strip()
        logger.info(f"User said: {user_text}")

        # Give the LLM something to answer when nothing was recognized.
        if not user_text:
            user_text = "Hello"

        # ================= LLM =================
        logger.info("Running LLM...")
        llm_result = llm_pipeline(
            user_text,
            max_new_tokens=64,
            do_sample=False
        )

        answer = llm_result[0]["generated_text"]
        logger.info(f"Answer: {answer}")

        # ================= TTS =================
        audio_response = generate_tts_audio(answer)

        # generate_tts_audio signals failure with empty bytes.
        if not audio_response:
            return jsonify({"error": "TTS failed"}), 500

        # Unique name per request so concurrent requests never share a file.
        file_id = str(uuid.uuid4())
        filepath = os.path.join(TEMP_AUDIO_DIR, f"{file_id}.wav")

        with open(filepath, "wb") as f:
            f.write(audio_response)

        return send_file(
            filepath,
            mimetype="audio/wav",
            as_attachment=False,
            download_name="response.wav"
        )

    except Exception as e:
        # Top-level boundary: log the traceback, hide details from the client.
        logger.error(f"Processing error: {e}", exc_info=True)
        return jsonify({"error": "Internal error"}), 500


# ================= STARTUP =================

if __name__ == "__main__":
    # Start the temp-file sweeper as a daemon thread so it never keeps the
    # process alive on its own.
    sweeper = threading.Thread(target=cleanup_temp_files, daemon=True)
    sweeper.start()

    # Listen on all interfaces, port 7860, one thread per request.
    server_options = {"host": "0.0.0.0", "port": 7860, "threaded": True}
    app.run(**server_options)