Haseeb-001 commited on
Commit
098946e
Β·
verified Β·
1 Parent(s): 20f6b70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -73
app.py CHANGED
@@ -4,106 +4,80 @@ import faiss
4
  import pickle
5
  from groq import Groq
6
  from datasets import load_dataset
7
- from transformers import AutoTokenizer, pipeline
8
- import subprocess # For downloading if needed
9
 
10
  # Initialize Groq API
11
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
12
 
13
- # Download model (if necessary - try requirements.txt first)
14
- try:
15
- # Try loading directly (after requirements.txt)
16
- tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", cache_dir="./.cache")
17
- chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", tokenizer=tokenizer, cache_dir="./.cache")
18
- print("Model loaded successfully (direct load).") # Check in logs
19
- except Exception as e:
20
- try:
21
- # Fallback: Download using subprocess (less preferred)
22
- print("Trying to download model...") # Check in logs
23
- subprocess.run(["transformers-cli", "download", "rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog"], check=True) # Updated download command
24
- tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", cache_dir="./.cache")
25
- chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", tokenizer=tokenizer, cache_dir="./.cache")
26
- print("Model downloaded and loaded successfully (subprocess).") # Check in logs
27
- except Exception as download_e:
28
- st.error(f"Error loading/downloading chat model: {e}. Download error: {download_e}")
29
- st.stop()
30
 
 
31
 
32
 
33
- # Load datasets (with error handling)
34
- try:
35
- healthcare_ds = load_dataset("harishnair04/mtsamples")
36
- education_ds = load_dataset("ehovy/race", "all")
37
- finance_ds = load_dataset("warwickai/financial_phrasebank_mirror")
38
- except Exception as e:
39
- st.error(f"Error loading datasets: {e}")
40
- st.stop()
41
 
42
- # FAISS Index Setup (Simplified)
43
- index = faiss.IndexFlatL2(768) # Adjust dimension if needed
 
 
 
44
  chat_history = []
45
 
46
  # Streamlit UI Setup
47
  st.set_page_config(page_title="AI Chatbot", layout="wide")
48
  st.title("πŸ€– AI Chatbot (Healthcare, Education & Finance)")
49
 
50
- # ... (rest of your Streamlit UI code - sidebar, input, buttons)
 
 
 
 
 
51
 
52
  # Chat Interface
53
  user_input = st.text_input("πŸ’¬ Ask me anything:", placeholder="Type your query here...")
54
  if st.button("Send"):
55
  if user_input:
56
- # Dataset Selection (Improved)
57
- dataset = None
58
- if "health" in user_input.lower():
59
- dataset = healthcare_ds
60
- elif "education" in user_input.lower():
61
- dataset = education_ds
62
- elif "finance" in user_input.lower():
63
- dataset = finance_ds
64
-
65
- if dataset is None:
66
- st.warning("No relevant dataset found for your query. Please use keywords like 'health', 'education', or 'finance'.")
67
- st.stop()
68
-
69
- # RAG: Retrieve (Simplified and safer)
70
- retrieved_data = dataset['train'][0]['text'] if dataset and len(dataset['train']) > 0 and 'text' in dataset['train'][0] else "No relevant data retrieved."
71
-
72
- try:
73
- # Generate response (Groq)
74
- chat_completion = client.chat.completions.create(
75
- messages=[{"role": "user", "content": f"{user_input} {retrieved_data}"}],
76
- model="llama-3.3-70b-versatile"
77
- )
78
- response = chat_completion.choices[0].message.content
79
- except Exception as e:
80
- st.error(f"Error generating response: {e}")
81
- response = "Error generating response."
82
-
83
- # Save and display
84
  chat_history.append(f"User: {user_input}\nBot: {response}")
85
  st.text_area("πŸ€– AI Response:", value=response, height=200)
86
 
87
- # ... (rest of your Streamlit code - chat history display, save/load)
 
88
 
89
- # Persistence functions (pickle)
90
  def save_chat_history():
91
- try:
92
- with open("chat_history.pkl", "wb") as file:
93
- pickle.dump(chat_history, file)
94
- st.sidebar.success("Chat history saved permanently!")
95
- except Exception as e:
96
- st.sidebar.error(f"Error saving chat history: {e}")
97
 
98
  def load_chat_history():
99
  global chat_history
100
- try:
101
- if os.path.exists("chat_history.pkl"):
102
- with open("chat_history.pkl", "rb") as file:
103
- chat_history = pickle.load(file)
104
- except Exception as e:
105
- st.sidebar.warning(f"Error loading chat history (may be corrupted): {e}")
106
 
107
  load_chat_history()
108
  if st.sidebar.button("Save Chat History"):
109
- save_chat_history()
 
 
4
  import pickle
5
  from groq import Groq
6
  from datasets import load_dataset
7
+ from transformers import pipeline
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  # Initialize Groq API
11
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
12
 
13
+ model_name = "rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ chat_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
18
 
19
 
20
+ # Load datasets
21
+ healthcare_ds = load_dataset("harishnair04/mtsamples")
22
+ education_ds = load_dataset("ehovy/race", "all")
23
+ finance_ds = load_dataset("warwickai/financial_phrasebank_mirror")
 
 
 
 
24
 
25
+ # Load chat model
26
+ chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog")
27
+
28
+ # FAISS Index Setup
29
+ index = faiss.IndexFlatL2(768)
30
  chat_history = []
31
 
32
  # Streamlit UI Setup
33
  st.set_page_config(page_title="AI Chatbot", layout="wide")
34
  st.title("πŸ€– AI Chatbot (Healthcare, Education & Finance)")
35
 
36
+ # Sidebar for chat history
37
+ st.sidebar.title("πŸ“œ Chat History")
38
+ if st.sidebar.button("Download Chat History"):
39
+ with open("chat_history.txt", "w") as file:
40
+ file.write("\n".join(chat_history))
41
+ st.sidebar.success("Chat history saved!")
42
 
43
  # Chat Interface
44
  user_input = st.text_input("πŸ’¬ Ask me anything:", placeholder="Type your query here...")
45
  if st.button("Send"):
46
  if user_input:
47
+ # Determine dataset based on user query (Basic CAG Implementation)
48
+ dataset = healthcare_ds if "health" in user_input.lower() else \
49
+ education_ds if "education" in user_input.lower() else \
50
+ finance_ds
51
+
52
+ # RAG: Retrieve relevant data
53
+ retrieved_data = dataset['train'][0] # Simplified retrieval
54
+
55
+ # Generate response using Llama via Groq API
56
+ chat_completion = client.chat.completions.create(
57
+ messages=[{"role": "user", "content": f"{user_input} {retrieved_data}"}],
58
+ model="llama-3.3-70b-versatile"
59
+ )
60
+ response = chat_completion.choices[0].message.content
61
+
62
+ # Save chat to FAISS and display
 
 
 
 
 
 
 
 
 
 
 
 
63
  chat_history.append(f"User: {user_input}\nBot: {response}")
64
  st.text_area("πŸ€– AI Response:", value=response, height=200)
65
 
66
+ # Display past chats
67
+ st.sidebar.write("\n".join(chat_history))
68
 
69
+ # Save chat history using pickle for persistence
70
  def save_chat_history():
71
+ with open("chat_history.pkl", "wb") as file:
72
+ pickle.dump(chat_history, file)
 
 
 
 
73
 
74
  def load_chat_history():
75
  global chat_history
76
+ if os.path.exists("chat_history.pkl"):
77
+ with open("chat_history.pkl", "rb") as file:
78
+ chat_history = pickle.load(file)
 
 
 
79
 
80
  load_chat_history()
81
  if st.sidebar.button("Save Chat History"):
82
+ save_chat_history()
83
+ st.sidebar.success("Chat history saved permanently!")