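"""Gradio app for a personalised Cardano Plutus tutor.

Loads the ubiodee/Plutus_Tutor_new causal language model and streams
explanations tailored to the user's personality type, programming level,
and chosen topic.
"""
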
import gradio as gr
import torch
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load model & tokenizer
MODEL_NAME = "ubiodee/Plutus_Tutor_new"
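# device_map="auto" places the weights on the available device(s) and requires
# the `accelerate` package; float16 roughly halves memory use versus float32.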
try:
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.info("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.eval()
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model or tokenizer: {str(e)}")
    raise

# Define options for dropdowns
PERSONALITY_TYPES = ["Autistic", "Dyslexic", "Expressive", "Nerd", "Visual", "Other"]
PROGRAMMING_LEVELS = ["Beginner", "Intermediate", "Professional"]
TOPICS = [
    "What is Plutus",
    "Introduction to Validation",
    "Smart Contracts",
    "Versioning in Plutus",
    "Monad",
    "Other",  # Add more as needed
]

# Prompt template to guide the model
def create_prompt(personality, level, topic):
    return (
        f"User: Teach me about {topic} in Plutus. I am a {level} programmer "
        f"with {personality} traits. Make the explanation tailored to my needs, "
        f"easy to understand, and engaging.\nAssistant:"
    )
# Response function that streams tokens to the UI as they are generated
def generate_response(personality, level, topic):
    try:
        logger.info("Processing selections...")
        prompt = create_prompt(personality, level, topic)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Use a streamer for token-by-token generation
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": 500,
            "do_sample": True,
            "temperature": 0.4,
            "top_p": 0.5,
            "eos_token_id": tokenizer.eos_token_id,
            # Many causal-LM tokenizers define no pad token; fall back to EOS
            # to avoid a None pad_token_id during generation.
            "pad_token_id": (
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        }
        # Run generation in a separate thread so this function can consume
        # the streamer without blocking
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text.strip()
        # Make sure the generation thread has finished before returning
        thread.join()
        logger.info("Response generated successfully.")
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        yield f"Error: {str(e)}"
# Gradio UI with dropdowns and button
with gr.Blocks(title="Cardano Plutus AI Assistant") as demo:
    gr.Markdown("### Your Personalised Plutus Tutor")
    gr.Markdown("Select your personality type, programming level, and topic, then click Generate.")
    personality = gr.Dropdown(
        choices=PERSONALITY_TYPES,
        label="Personality Type",
        value="Dyslexic",  # Default
    )
    level = gr.Dropdown(
        choices=PROGRAMMING_LEVELS,
        label="Programming Level",
        value="Beginner",  # Default
    )
    topic = gr.Dropdown(
        choices=TOPICS,
        label="Topic",
        value="What is Plutus",  # Default
    )
    generate_btn = gr.Button("Generate")
    output = gr.Textbox(
        label="Model Response",
        show_label=True,
        lines=10,
        placeholder="Generated content will appear here...",
    )
    generate_btn.click(
        fn=generate_response,
        inputs=[personality, level, topic],
        outputs=output,
    )

# Launch the app
try:
    logger.info("Launching Gradio interface...")
    demo.launch()
except Exception as e:
    logger.error(f"Error launching Gradio: {str(e)}")
    raise
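# On Hugging Face Spaces a bare demo.launch() is sufficient; for local testing,
# demo.launch(share=True) would create a temporary public URL instead.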