import logging
from threading import Thread
from typing import Any, Dict, Generator, List, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

logger = logging.getLogger("plutus.model")
logging.basicConfig(level=logging.INFO)

MAIN_MODEL_NAME = "Remostart/Plutus_Tutor_model"
SUMMARY_MODEL_NAME = "Remostart/Plutus_Tutor_model"


class PlutusModel:
    """
    Handles the main learning model:
    - Teaching prompt
    - Synchronous generation
    - Streaming generation
    - Explaining recommendations
    """

    def __init__(self, model_name: str = MAIN_MODEL_NAME):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[INIT] Main model running on: {self.device}")

        self.tokenizer = None
        self.model = None
        self._load()

    def _load(self):
        """Loads the main teaching model and tokenizer."""
        try:
            logger.info(f"[LOAD] Loading tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            kwargs = {"torch_dtype": torch.float16} if self.device == "cuda" else {}

            logger.info("[LOAD] Loading main model weights...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto" if self.device == "cuda" else None,
                low_cpu_mem_usage=True,
                **kwargs,
            )
            self.model.eval()
            logger.info("[READY] Main model successfully loaded.")
        except Exception as e:
            logger.exception("Main model loading failed")
            raise RuntimeError(f"Main model loading failed: {e}")

    def create_prompt(
        self,
        personality: str,
        level: str,
        topic: str,
        extra_context: Optional[str] = None,
    ) -> str:
        prompt = (
            "You are PlutusTutor — the best expert in Cardano's Plutus smart contract ecosystem.\n\n"
            "User Info:\n"
            f"- Personality: {personality}\n"
            f"- Level: {level}\n"
            f"- Topic: {topic}\n\n"
            "Your task:\n"
            "- Teach with extreme clarity.\n"
            "- Give structured explanations.\n"
            "- Include examples and, where needed, code.\n"
            "- Avoid useless filler.\n"
            "- Adapt tone slightly to user personality.\n\n"
        )
        if extra_context:
            prompt += f"Additional Context:\n{extra_context}\n\n"
        prompt += "Begin teaching now.\n\nAssistant:"
        return prompt

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 700,
        temperature: float = 0.4,
        top_p: float = 0.5,
    ) -> str:
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=self.tokenizer.eos_token_id,
                # Causal LM tokenizers often have no pad token; fall back to EOS.
                pad_token_id=(
                    self.tokenizer.pad_token_id
                    if self.tokenizer.pad_token_id is not None
                    else self.tokenizer.eos_token_id
                ),
            )
            decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the completion is returned.
            if decoded.startswith(prompt):
                decoded = decoded[len(prompt):].strip()
            return decoded
        except Exception as e:
            logger.exception("Generation failed")
            return f"[Generation Error] {e}"

    def stream_generate(
        self,
        prompt: str,
        max_new_tokens: int = 400,
        temperature: float = 0.4,
        top_p: float = 0.5,
    ) -> Generator[str, None, None]:
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            # Run generation on a background thread; the streamer hands
            # decoded text fragments back to this generator as they arrive.
            thread = Thread(target=self.model.generate, kwargs={
                **inputs,
                "streamer": streamer,
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
                "temperature": temperature,
                "top_p": top_p,
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": (
                    self.tokenizer.pad_token_id
                    if self.tokenizer.pad_token_id is not None
                    else self.tokenizer.eos_token_id
                ),
            })
            thread.start()

            accumulated = ""
            for chunk in streamer:
                accumulated += chunk
                # Yield the full text so far, not just the latest chunk.
                yield accumulated
        except Exception as e:
            logger.exception("Streaming failed")
            yield f"[Streaming Error] {e}"

    def summarize_recommendations(
        self,
        topic: str,
        items: List[Dict[str, Any]],
        personality: Optional[str] = None,
        level: Optional[str] = None,
        max_new_tokens: int = 120,
    ) -> str:
        bullet_list = [
            f"- {item['type'].upper()}: {item.get('title') or item.get('url')} ({item['url']})"
            for item in items
        ]
        refs = "\n".join(bullet_list)

        prompt = (
            f"The user is learning: {topic}\n"
            "Here are recommended videos and documents:\n\n"
            f"{refs}\n\n"
            "Explain why these choices are perfect for the user.\n"
            f"Personality: {personality}\n"
            f"Skill Level: {level}\n"
            "Tone should be confident and friendly.\n\nAssistant:"
        )
        return self.generate(prompt, max_new_tokens=max_new_tokens)


class SummaryModel:
    """
    Runs the summarization LLM:
    - Summarizes the full teaching text
    - Adds clarity + structure
    - Used in the /summary endpoint
    """

    def __init__(self, model_name: str = SUMMARY_MODEL_NAME):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[INIT] Summary model running on: {self.device}")

        self.tokenizer = None
        self.model = None
        self._load()

    def _load(self):
        try:
            logger.info(f"[LOAD] Loading summary tokenizer: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            kwargs = {"torch_dtype": torch.float16} if self.device == "cuda" else {}

            logger.info(f"[LOAD] Loading summary model: {self.model_name}")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto" if self.device == "cuda" else None,
                low_cpu_mem_usage=True,
                **kwargs,
            )
            self.model.eval()
            logger.info("[READY] Summary model loaded.")
        except Exception as e:
            logger.exception("Summary model loading failed")
            raise RuntimeError(f"Summary model loading failed: {e}")

    def summarize_text(
        self,
        full_teaching: str,
        topic: str,
        level: str,
        recommended: List[Dict[str, Any]],
        max_new_tokens: int = 500,
    ) -> str:
        # Format RAG references
        refs = "\n".join([
            f"- {item['type'].upper()}: {item.get('title') or item.get('url')} ({item['url']})"
            for item in recommended
        ]) if recommended else "None"

        prompt = (
            "You are a world-class summarization assistant.\n\n"
            f"TOPIC: {topic}\n"
            f"LEVEL: {level}\n\n"
            "Here is the full teaching content you must summarize:\n\n"
            f"{full_teaching}\n\n"
            "Now produce a clean, structured, extremely clear summary.\n"
            "After the summary, recommend these resources clearly:\n\n"
            f"{refs}\n\n"
            "Assistant:"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        out = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.2,
            top_p=0.85,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        decoded = self.tokenizer.decode(out[0], skip_special_tokens=True)
        if decoded.startswith(prompt):
            decoded = decoded[len(prompt):].strip()
        return decoded.strip()