prelington committed on
Commit e9e5efd · verified · 1 Parent(s): 4828ab6

Create ProTalk_MemoryChat.py

Files changed (1)
  1. ProTalk_MemoryChat.py +52 -0
ProTalk_MemoryChat.py ADDED
@@ -0,0 +1,52 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import torch
+ import threading
+
+ model_name = "microsoft/phi-2"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+     low_cpu_mem_usage=True
+ ).to(device)
+
+ system_prompt = (
+     "You are ProTalk, a professional AI assistant. "
+     "You remember everything the user said in this session and respond politely, "
+     "clearly, and intelligently. Keep a coherent conversation history."
+ )
+
+ chat_history = []
+
+ def chat_loop():
+     print("ProTalk Memory Chat Online — type 'exit' to quit.\n")
+     while True:
+         user_input = input("User: ")
+         if user_input.lower() == "exit":
+             break
+         chat_history.append(f"User: {user_input}")
+         prompt = system_prompt + "\n" + "\n".join(chat_history) + "\nProTalk:"
+         inputs = tokenizer(prompt, return_tensors="pt").to(device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+         thread = threading.Thread(target=model.generate, kwargs={
+             "input_ids": inputs["input_ids"],
+             "max_new_tokens": 300,
+             "do_sample": True,
+             "temperature": 0.7,
+             "top_p": 0.9,
+             "repetition_penalty": 1.2,
+             "streamer": streamer
+         })
+         thread.start()
+         output_text = ""
+         for token in streamer:
+             print(token, end="", flush=True)
+             output_text += token
+         thread.join()
+         print()
+         chat_history.append(f"ProTalk: {output_text}")
+
+ if __name__ == "__main__":
+     chat_loop()
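
A note for anyone adapting this script: chat_history grows without bound, while phi-2's context window is 2,048 tokens, so a long session will eventually overflow the prompt. Below is a minimal sketch of one way to cap it, reusing the tokenizer and system_prompt defined in the file; the build_prompt helper and its token budget are illustrative assumptions, not part of this commit.

def build_prompt(history, max_tokens=1500):
    # Illustrative helper (not in the commit): keep only the most recent turns
    # that fit within the token budget, then rebuild the prompt in order.
    kept, total = [], 0
    for turn in reversed(history):
        n = len(tokenizer(turn)["input_ids"])
        if total + n > max_tokens:
            break
        kept.append(turn)
        total += n
    return system_prompt + "\n" + "\n".join(reversed(kept)) + "\nProTalk:"

Inside chat_loop, the existing prompt assignment could then call build_prompt(chat_history) instead of joining the full history.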