File size: 9,004 Bytes
5c99fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# File: app.py

import os
import uvicorn
import pickle
import faiss
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

# --- This part of the code runs ONCE when the server starts up ---
print("--- Server starting up: Loading all models into memory... ---")

# --- 1. Load RAG Components (FAISS Index and Texts) ---
# Both retrieval artefacts are read from a 'data' folder relative to the
# directory the server is launched from.
data_path = "./data/"
texts_path = os.path.join(data_path, "faiss_texts.pkl")
index_path = os.path.join(data_path, "faiss_index.index")

# Check for both files explicitly. faiss.read_index reports a missing file as a
# RuntimeError from its native layer, not FileNotFoundError, so an
# `except FileNotFoundError` around the load would only ever cover the pickle.
if not (os.path.isfile(texts_path) and os.path.isfile(index_path)):
    print("❌ ERROR: FAISS index or texts not found in the './data/' directory.")
    print("Please make sure 'faiss_texts.pkl' and 'faiss_index.index' are present.")
    # Stop the server with a non-zero status so supervisors see the failure.
    # SystemExit is raised directly because the `exit()` helper is injected by
    # the `site` module and is not guaranteed to exist in every runtime.
    raise SystemExit(1)

# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only ship 'faiss_texts.pkl' files that were produced by a trusted pipeline.
with open(texts_path, "rb") as f:
    texts = pickle.load(f)
index = faiss.read_index(index_path)
print("βœ… RAG components (FAISS index and texts) loaded successfully.")

# --- 2. Load the AI Models ---
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# Device string handed to the embedding model when encoding queries.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"βœ… Running on device: {device}")

print("⏳ Loading embedding model (Qwen)...")
embedding_model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')
print("βœ… Embedding model loaded.")

print("⏳ Loading main LLM (microsoft/MediPhi-Clinical)...")
model_name = "microsoft/MediPhi-Clinical"
# NOTE(security): trust_remote_code=True runs Python code shipped in the model
# repository; keep it only while the checkpoint source is trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # Same 4-bit quantisation as before, expressed through quantization_config:
    # passing load_in_4bit straight to from_pretrained is deprecated.
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    trust_remote_code=True,
)
print("βœ… Main LLM loaded.")

print("⏳ Loading classifier model (facebook/bart-large-mnli)...")
# One zero-shot classifier instance is shared by the intent check and the
# seriousness check.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("βœ… Classifier loaded.")
print("--- All models loaded. Server is ready. ---")


# --- HELPER FUNCTIONS (Copied and cleaned from your notebook) ---

def query_rag(query: str, top_k: int = 5) -> list:
    """Return the stored passages closest to *query* in the FAISS index.

    Args:
        query: Free-text question to embed and search for.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` entries of the module-level ``texts``, in the order the
        index ranked them. Empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []

    query_vec = embedding_model.encode([query], convert_to_numpy=True, device=device)
    # The index is searched with a single float32 row vector.
    query_vec = np.asarray(query_vec, dtype=np.float32).reshape(1, -1)
    _, neighbor_ids = index.search(query_vec, top_k)

    # FAISS pads the result with id -1 when it finds fewer than top_k
    # neighbours; indexing texts[-1] would silently return the last passage.
    return [texts[i] for i in neighbor_ids[0] if i >= 0]

class ChatMessage(BaseModel): # One conversation turn; declared before the helper functions whose signatures reference it
    # Speaker of the turn. generate_answer only includes turns whose role is
    # 'user' or 'assistant' in the prompt.
    role: str
    # Text of the turn.
    content: str

def _format_history(history: List[ChatMessage]) -> str:
    """Render earlier user/assistant turns as the transcript embedded in the prompt."""
    parts = []
    for message in history:
        # Real newlines here: the doubled backslash used previously put the
        # two characters "\" and "n" into the prompt instead of a line break.
        if message.role == 'user':
            parts.append(f"Previous Question: {message.content}\n")
        elif message.role == 'assistant':
            parts.append(f"Your Previous Answer: {message.content}\n---\n")
    return "".join(parts)


def generate_answer(query: str, context: list, history: List[ChatMessage], tokens: int = 1000) -> str:
    """Generate the assistant's answer to *query* with the main LLM.

    Args:
        query: The user's current question.
        context: Retrieved passages to ground the answer in.
        history: Earlier conversation turns; only 'user' and 'assistant'
            roles are included in the prompt.
        tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The generated text up to (and excluding) the first ``<END>`` marker,
        with surrounding whitespace stripped.
    """
    formatted_history = _format_history(history)
    # Put the passages in as plain text separated by blank lines; interpolating
    # the list itself would embed a Python repr (brackets, quotes, escaped
    # newlines) in the prompt.
    context_block = "\n\n".join(str(passage) for passage in context)

    system_prompt = f"""You are a reliable, calm, and knowledgeable medical first aid assistant. You can remember the past conversation.
 
Your job is to give a **complete, step-by-step guide**.
 
Follow these strict rules:
- Use clear, simple, reassuring language.
- Avoid generic bullet lists β€” instead, structure the response in short sections.
- Highlight urgent actions with **bold** text.
- **If the query describes or implies an emergency / injury / accident / sudden illness / life-threatening scenario β†’ switch 
to FIRST AID MODE.**
 
- **If the query is informational, preventive, or explanatory (e.g., about conditions, symptoms, nutrition, health advice) β†’ 
switch to GENERAL MEDICAL MODE.**
- In case of emergencies, instruct the user to contact their *local* emergency services or visit the nearest medical facility.
- If the retrieved context mentions a specific emergency hotline (e.g., 911, 000, 112), replace it with the phrase 
β€˜your local emergency number’. In case a country is mentioned do not include it in the reponse.
- Do not mention or speculate about the user's country, location, or local hotlines β€” use only "your local emergency 
number" when advising to call emergency services.
- All words must be spelled correctly in standard English; do not produce garbled or partial words (e.g., β€œImmedi000”). If unsure, rewrite the phrase normally.
- If a protocol (like FAST or DRSABCD) appears, **expand and explain each letter.**
 
 
### FIRST AID MODE (for emergencies)
 
**Format:**
 
**Situation:** short summary of the problem  
Start with a calm reassurance: β€œStay calm β€” quick action can make a big difference.”  
 
Then follow this exact structure:
 
1. **Immediate Actions** β€” 3–6 numbered steps showing exactly what to do first  
2. **Rationale** β€” one short paragraph explaining *why* those steps matter  
3. **What NOT to do** β€” 2–4 clear things to avoid  
4. **🚨 When to Seek Immediate Help** β€” short list of red-flag signs then append `<END>` and nothing after it

 
**Rules:**
- Highlight urgent actions with **bold**.  
- Always replace emergency numbers with β€œyour local emergency number”.  
- If a first aid protocol (like DRSABCD or FAST) appears, **expand and explain each step clearly**.
 
---
 
### GENERAL MEDICAL MODE (for non-emergencies)
 
**Format:**
 
Then use this structure:
1. **Overview** β€” simple explanation of the concept or condition  
2. **Causes / Mechanism** β€” brief summary of underlying reason(s)  
3. **Common Symptoms or Signs** β€” bullet points if relevant  
4. **Management / Prevention** β€” clear, general guidance (no prescription-only details)  
5. **When to See a Doctor** β€” short note for red flags then append `<END>` and nothing after it. 

---

### Chat History:
{formatted_history}

### Context for the new question:
{context_block}

### Current Question:
{query}

Now provide your single response below:
"""

    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
    # Greedy decoding, as before. temperature/top_p are no longer passed
    # because they have no effect when do_sample=False.
    output = model.generate(
        **inputs,
        max_new_tokens=tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # A causal LM returns the prompt tokens followed by the new ones. Dropping
    # the prompt by length is more reliable than splitting the decoded text on
    # a marker phrase that the model may repeat in its answer.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return completion.split("<END>")[0].strip()

# Greeting phrases as word tuples so they are matched on whole words only.
_GREETING_PHRASES = tuple(
    tuple(phrase.split())
    for phrase in ("hello", "hi", "hey", "good morning", "good afternoon", "good evening")
)


def _query_words(query: str) -> list:
    """Lower-case *query* and split it into alphanumeric words, dropping punctuation."""
    return "".join(ch if ch.isalnum() else " " for ch in query.lower()).split()


def _greeting_span(words: list, start: int) -> int:
    """Return how many words the greeting phrase starting at ``words[start]`` covers, or 0."""
    for phrase in _GREETING_PHRASES:
        if tuple(words[start:start + len(phrase)]) == phrase:
            return len(phrase)
    return 0


def classify_intent_advanced(query: str, history: List[ChatMessage]) -> str:
    """Classify *query* as ``"greeting"``, ``"medical"`` or ``"off_topic"``.

    Greetings are matched on whole words. Substring matching would treat
    "this child is choking" as a greeting ("hi" occurs inside "this" and
    "child"), so an emergency would be answered with a canned hello.

    Decision order:
      1. The query consists only of greeting phrases -> ``"greeting"``.
      2. The zero-shot classifier picks the medical label with score > 0.55
         -> ``"medical"``, even if the query also contains a greeting.
      3. The query contains a greeting word -> ``"greeting"``.
      4. Otherwise -> ``"off_topic"``.

    Args:
        query: The user's current message.
        history: Earlier conversation turns. Accepted for interface
            compatibility; not used by the classification.

    Returns:
        One of ``"greeting"``, ``"medical"``, ``"off_topic"``.
    """
    words = _query_words(query)
    has_greeting = False
    has_other_words = False
    position = 0
    while position < len(words):
        span = _greeting_span(words, position)
        if span:
            has_greeting = True
            position += span
        else:
            has_other_words = True
            position += 1

    # Fast path: a bare greeting never needs the classifier.
    if has_greeting and not has_other_words:
        return "greeting"

    candidate_labels = ["medical first-aid question", "general conversation"]
    result = classifier(query, candidate_labels)
    top_label = result['labels'][0]
    top_score = result['scores'][0]

    if top_label == "medical first-aid question" and top_score > 0.55:
        print(f"🧠 Intent classified as MEDICAL ({top_score:.2f})")
        return "medical"

    # Non-medical small talk that opens with a greeting still gets the
    # friendly greeting reply rather than the off-topic refusal.
    if has_greeting:
        return "greeting"

    print(f"🧠 Intent classified as OFF_TOPIC ({top_score:.2f})")
    return "off_topic"

def classify_response_seriousness(response_text: str) -> bool:
    """Tell whether a generated answer reads as an urgent medical emergency.

    Runs the shared zero-shot classifier over *response_text* with an
    "urgent" and a "minor" label, logs the winning label and its score, and
    returns True only when the urgent label wins with a score above 0.6.
    """
    urgent_label = "urgent medical emergency"
    verdict = classifier(response_text, [urgent_label, "minor first-aid advice"])
    best_label, best_score = verdict['labels'][0], verdict['scores'][0]
    print(f"βš•οΈ Seriousness check β†’ {best_label} ({best_score:.2f})")

    if best_label != urgent_label:
        return False
    return best_score > 0.6


# --- FASTAPI SERVER LOGIC ---

# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()

class QueryRequest(BaseModel):
    # Request body of POST /ask.
    # The user's current question.
    query: str
    # Earlier conversation turns; an empty list when the client sends none.
    history: List[ChatMessage] = []

class QueryResponse(BaseModel):
    # Response body of POST /ask.
    # Text to show the user: a canned reply or the generated medical answer.
    answer: str
    # True when a generated answer was classified as an urgent emergency;
    # always False for canned replies. Presumably drives a hospital prompt in
    # the client UI - confirm against the frontend.
    show_hospital_modal: bool

@app.get("/")  # Root route, used as a health check
def read_root():
    """Report that the server is up by returning a fixed status payload."""
    health = {"status": "ok"}
    return health

@app.post("/ask", response_model=QueryResponse)
async def ask_question(request: QueryRequest):
    """Answer one user question.

    The query is first classified by intent. Greetings and off-topic messages
    get a fixed reply. Medical questions go through retrieval, answer
    generation and a seriousness check that decides ``show_hospital_modal``.

    Raises:
        HTTPException: status 500 when any step of the pipeline fails.
    """
    query = request.query
    history = request.history
    print(f"Received query: '{query}'")

    # NOTE(review): the model calls below are blocking and run inside an
    # `async def` handler, so they hold the event loop for the duration of a
    # request. Confirm whether that is acceptable for the expected load.
    try:
        # Classification is inside the try so classifier failures are logged
        # and reported like every other pipeline error.
        intent = classify_intent_advanced(query, history)
        print(f"Classified intent as: {intent}")

        if intent == "greeting":
            answer = "Hello! I am a first aid assistant. How can I help you with your medical situation today?"
            return QueryResponse(answer=answer, show_hospital_modal=False)

        if intent == "medical":
            context = query_rag(query)
            answer_text = generate_answer(query, context, history)
            is_serious = classify_response_seriousness(answer_text)
            return QueryResponse(answer=answer_text, show_hospital_modal=is_serious)

        # "off_topic", plus any label this handler does not recognise. Without
        # a final branch an unexpected label would return None and fail
        # response validation.
        answer = "I apologize, but I am a specialized medical first aid assistant. I cannot provide information on topics outside of that scope."
        return QueryResponse(answer=answer, show_hospital_modal=False)

    except Exception as e:
        print(f"An error occurred: {e}")
        # NOTE(review): `detail` exposes the internal error text to the client;
        # consider a generic message if this API is public.
        raise HTTPException(status_code=500, detail=str(e)) from e