Ashhar commited on
Commit
061239f
1 Parent(s): 60901ed

gsheet save tool

Browse files
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .env
2
+ .venv
3
+ __pycache__/
4
+ .gitattributes
5
+ gradio_cached_examples/
6
+ app_*.py
app.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import datetime as DT
4
+ import pytz
5
+ import time
6
+ import json
7
+ import re
8
+ from typing import List
9
+ from transformers import AutoTokenizer
10
+ from gradio_client import Client
11
+ from tools import toolsInfo
12
+
13
+ from dotenv import load_dotenv
14
+ load_dotenv()
15
+
16
+ useGpt4 = os.environ.get("USE_GPT_4") == "1"
17
+
18
+ if useGpt4:
19
+ from openai import OpenAI
20
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
21
+ MODEL = "gpt-4o-mini"
22
+ MAX_CONTEXT = 128000
23
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")
24
+ else:
25
+ from groq import Groq
26
+ client = Groq(
27
+ api_key=os.environ.get("GROQ_API_KEY"),
28
+ )
29
+ MODEL = "llama-3.1-70b-versatile"
30
+ MODEL = "llama3-groq-70b-8192-tool-use-preview"
31
+ MAX_CONTEXT = 8000
32
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
33
+
34
+
35
+ def countTokens(text):
36
+ text = str(text)
37
+ tokens = tokenizer.encode(text, add_special_tokens=False)
38
+ return len(tokens)
39
+
40
+
41
+ SYSTEM_MSG = f"""
42
+ You are a personalized email generator for cold outreach. You take the user through a workflow. Step by step.
43
+
44
+ - You ask for industry of the recipient
45
+ - His/her role
46
+ - More details about the recipient
47
+
48
+ Highlight the exact entity you're requesting for.
49
+ Once collected, you store these info in a Google Sheet
50
+
51
+ """
52
+
53
+ USER_ICON = "icons/man.png"
54
+ ASSISTANT_ICON = "icons/magic-wand(1).png"
55
+ TOOL_ICON = "icons/completed-task.png"
56
+
57
+ IMAGE_LOADER = "icons/ripple.svg"
58
+ TEXT_LOADER = "icons/balls.svg"
59
+ START_MSG = "Let's start 😊"
60
+
61
+ ROLE_TO_AVATAR = {
62
+ "user": USER_ICON,
63
+ "assistant": ASSISTANT_ICON,
64
+ "tool": TOOL_ICON,
65
+ }
66
+
67
+ st.set_page_config(
68
+ page_title="EmailGenie",
69
+ page_icon=ASSISTANT_ICON,
70
+ )
71
+ ipAddress = st.context.headers.get("x-forwarded-for")
72
+
73
+
74
+ def __nowInIST() -> DT.datetime:
75
+ return DT.datetime.now(pytz.timezone("Asia/Kolkata"))
76
+
77
+
78
+ def pprint(log: str):
79
+ now = __nowInIST()
80
+ now = now.strftime("%Y-%m-%d %H:%M:%S")
81
+ print(f"[{now}] [{ipAddress}] {log}")
82
+
83
+
84
+ pprint("\n")
85
+
86
+ st.markdown(
87
+ """
88
+ <style>
89
+ @keyframes blinker {
90
+ 0% {
91
+ opacity: 1;
92
+ }
93
+ 50% {
94
+ opacity: 0.2;
95
+ }
96
+ 100% {
97
+ opacity: 1;
98
+ }
99
+ }
100
+
101
+ .blinking {
102
+ animation: blinker 3s ease-out infinite;
103
+ }
104
+
105
+ .code {
106
+ color: green;
107
+ border-radius: 3px;
108
+ padding: 2px 4px; /* Padding around the text */
109
+ font-family: 'Courier New', Courier, monospace; /* Monospace font */
110
+ }
111
+
112
+ </style>
113
+ """,
114
+ unsafe_allow_html=True
115
+ )
116
+
117
+
118
+ def __isInvalidResponse(response: str):
119
+ # new line followed by small case char
120
+ if len(re.findall(r'\n[a-z]', response)) > 3:
121
+ return True
122
+
123
+ # lot of repeating words
124
+ if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
125
+ return True
126
+
127
+ # lots of paragraphs
128
+ if len(re.findall(r'\n\n', response)) > 15:
129
+ return True
130
+
131
+
132
+ def __matchingKeywordsCount(keywords: List[str], text: str):
133
+ return sum([
134
+ 1 if keyword in text else 0
135
+ for keyword in keywords
136
+ ])
137
+
138
+
139
+ def __isStringNumber(s: str) -> bool:
140
+ try:
141
+ float(s)
142
+ return True
143
+ except ValueError:
144
+ return False
145
+
146
+
147
+ def __resetButtonState():
148
+ st.session_state["buttonValue"] = ""
149
+
150
+
151
+ def __setStartMsg(msg):
152
+ st.session_state.startMsg = msg
153
+
154
+
155
+ if "chatHistory" not in st.session_state:
156
+ st.session_state.chatHistory = []
157
+
158
+ if "messages" not in st.session_state:
159
+ st.session_state.messages = []
160
+
161
+ if "buttonValue" not in st.session_state:
162
+ __resetButtonState()
163
+
164
+ if "startMsg" not in st.session_state:
165
+ st.session_state.startMsg = ""
166
+
167
+
168
+ def __getMessages():
169
+ def getContextSize():
170
+ currContextSize = countTokens(SYSTEM_MSG) + countTokens(st.session_state.messages) + 100
171
+ pprint(f"{currContextSize=}")
172
+ return currContextSize
173
+
174
+ while getContextSize() > MAX_CONTEXT:
175
+ pprint("Context size exceeded, removing first message")
176
+ st.session_state.messages.pop(0)
177
+
178
+ return st.session_state.messages
179
+
180
+
181
+ tools = [
182
+ toolsInfo["saveInGSheet"]["schema"]
183
+ ]
184
+
185
+
186
+ def __showTaskStatus(msg):
187
+ taskContainer = st.container()
188
+ taskContainer.image(TOOL_ICON, width=30)
189
+ taskContainer.markdown(
190
+ f"""
191
+ <div class='code'>
192
+ {msg}
193
+ </div>
194
+ """,
195
+ unsafe_allow_html=True
196
+ )
197
+
198
+
199
+ def __processToolCalls(tool_calls):
200
+ for toolCall in tool_calls:
201
+ functionName = toolCall.function.name
202
+ functionToCall = toolsInfo[functionName]["func"]
203
+ functionArgs = json.loads(toolCall.function.arguments)
204
+ functionResult = functionToCall(**functionArgs)
205
+ functionResponse = functionResult["response"]
206
+ shouldShow = functionResult["shouldShow"]
207
+ pprint(f"{functionResponse=}")
208
+
209
+ if shouldShow:
210
+ st.session_state.chatHistory.append(
211
+ {
212
+ "role": "tool",
213
+ "content": functionResponse,
214
+ }
215
+ )
216
+
217
+ __showTaskStatus(functionResponse)
218
+
219
+ st.session_state.messages.append(
220
+ {
221
+ "role": "tool",
222
+ "tool_call_id": toolCall.id,
223
+ "name": functionName,
224
+ "content": functionResponse,
225
+ }
226
+ )
227
+
228
+ def __process_stream_chunk(chunk):
229
+ delta = chunk.choices[0].delta
230
+ if delta.content:
231
+ return delta.content
232
+ elif delta.tool_calls:
233
+ return delta.tool_calls[0]
234
+ return None
235
+
236
+
237
+ def __addToolCallsToMsgs(toolCalls):
238
+ st.session_state.messages.append(
239
+ {
240
+ "role": "assistant",
241
+ "tool_calls": [
242
+ {
243
+ "id": toolCall.id,
244
+ "function": {
245
+ "name": toolCall.function.name,
246
+ "arguments": toolCall.function.arguments,
247
+ },
248
+ "type": toolCall.type,
249
+ }
250
+ for toolCall in toolCalls
251
+ ],
252
+ }
253
+ )
254
+
255
+ def __add_tool_call(tool_call):
256
+ st.session_state.messages.append({
257
+ "role": "assistant",
258
+ "tool_calls": [{
259
+ "id": tool_call.id,
260
+ "function": {
261
+ "name": tool_call.function.name,
262
+ "arguments": tool_call.function.arguments,
263
+ },
264
+ "type": tool_call.type,
265
+ }]
266
+ })
267
+
268
+ def predict1():
269
+ shouldStream = True
270
+
271
+ messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
272
+ messagesFormatted.extend(__getMessages())
273
+ contextSize = countTokens(messagesFormatted)
274
+ pprint(f"{contextSize=} | {MODEL}")
275
+
276
+ response = client.chat.completions.create(
277
+ model=MODEL,
278
+ messages=messagesFormatted,
279
+ temperature=0.8,
280
+ max_tokens=4000,
281
+ stream=shouldStream,
282
+ tools=tools
283
+ )
284
+
285
+ content = ""
286
+ tool_call = None
287
+
288
+ for chunk in response:
289
+ chunk_content = __process_stream_chunk(chunk)
290
+ if isinstance(chunk_content, str):
291
+ content += chunk_content
292
+ yield chunk_content
293
+ elif chunk_content:
294
+ if not tool_call:
295
+ tool_call = chunk_content
296
+ else:
297
+ tool_call.function.arguments += chunk_content.function.arguments
298
+
299
+ if tool_call:
300
+ pprint(f"{tool_call=}")
301
+
302
+ __addToolCallsToMsgs([tool_call])
303
+ try:
304
+ __processToolCalls([tool_call])
305
+ return predict()
306
+ except Exception as e:
307
+ pprint(e)
308
+
309
+
310
+ def __dedupeToolCalls(toolCalls: list):
311
+ toolCallsDict = {}
312
+ for toolCall in toolCalls:
313
+ toolCallsDict[toolCall.function.name] = toolCall
314
+ dedupedToolCalls = list(toolCallsDict.values())
315
+
316
+ if len(toolCalls) != len(dedupedToolCalls):
317
+ pprint("Deduped tool calls!")
318
+ pprint(f"{toolCalls=} -> {dedupedToolCalls=}")
319
+
320
+ return dedupedToolCalls
321
+
322
+
323
+ def predict():
324
+ shouldStream = False
325
+
326
+ messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
327
+ messagesFormatted.extend(__getMessages())
328
+ contextSize = countTokens(messagesFormatted)
329
+ pprint(f"{contextSize=} | {MODEL}")
330
+ pprint(f"{messagesFormatted=}")
331
+
332
+ response = client.chat.completions.create(
333
+ model=MODEL,
334
+ messages=messagesFormatted,
335
+ temperature=0.8,
336
+ max_tokens=4000,
337
+ stream=shouldStream,
338
+ tools=tools
339
+ )
340
+ # pprint(f"llmResponse: {response}")
341
+
342
+ if shouldStream:
343
+ content = ""
344
+ toolCall = None
345
+
346
+ for chunk in response:
347
+ chunkContent = __process_stream_chunk(chunk)
348
+ if isinstance(chunkContent, str):
349
+ content += chunkContent
350
+ yield chunkContent
351
+ elif chunkContent:
352
+ if not toolCall:
353
+ toolCall = chunkContent
354
+ else:
355
+ toolCall.function.arguments += chunkContent.function.arguments
356
+
357
+ toolCalls = [toolCall] if toolCall else []
358
+ else:
359
+ responseMessage = response.choices[0].message
360
+ # pprint(f"{responseMessage=}")
361
+ responseContent = responseMessage.content
362
+ pprint(f"{responseContent=}")
363
+ if responseContent:
364
+ yield responseContent
365
+ toolCalls = responseMessage.tool_calls
366
+ # pprint(f"{toolCalls=}")
367
+
368
+ if toolCalls:
369
+ pprint(f"{toolCalls=}")
370
+ toolCalls = __dedupeToolCalls(toolCalls)
371
+ __addToolCallsToMsgs(toolCalls)
372
+ try:
373
+ __processToolCalls(toolCalls)
374
+ return predict()
375
+ except Exception as e:
376
+ pprint(e)
377
+
378
+
379
+ st.title("EmailGenie 💌")
380
+ if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
381
+ st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
382
+
383
+ for chat in st.session_state.chatHistory:
384
+ role = chat["role"]
385
+ content = chat["content"]
386
+ imagePath = chat.get("image")
387
+ avatar = ROLE_TO_AVATAR[role]
388
+ with st.chat_message(role, avatar=avatar):
389
+ if role == "tool":
390
+ st.markdown(
391
+ f"""
392
+ <div class='code'>
393
+ {content}
394
+ </div>
395
+ """,
396
+ unsafe_allow_html=True
397
+ )
398
+ else:
399
+ st.markdown(content)
400
+
401
+ if imagePath:
402
+ st.image(imagePath)
403
+
404
+ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state["startMsg"]):
405
+ __resetButtonState()
406
+ __setStartMsg("")
407
+
408
+ with st.chat_message("user", avatar=USER_ICON):
409
+ st.markdown(prompt)
410
+ pprint(f"{prompt=}")
411
+ st.session_state.messages.append({"role": "user", "content": prompt })
412
+ st.session_state.chatHistory.append({"role": "user", "content": prompt })
413
+
414
+ with st.chat_message("assistant", avatar=ASSISTANT_ICON):
415
+ responseContainer = st.empty()
416
+
417
+ def __printAndGetResponse():
418
+ response = ""
419
+ # responseContainer.markdown(".....")
420
+ responseContainer.image(TEXT_LOADER)
421
+ responseGenerator = predict()
422
+
423
+ for chunk in responseGenerator:
424
+ response += chunk
425
+ if __isInvalidResponse(response):
426
+ pprint(f"{response=}")
427
+ return
428
+ responseContainer.markdown(response)
429
+
430
+ return response
431
+
432
+ response = __printAndGetResponse()
433
+ while not response:
434
+ pprint("Empty response. Retrying..")
435
+ time.sleep(0.5)
436
+ response = __printAndGetResponse()
437
+
438
+ pprint(f"{response=}")
439
+
440
+ def selectButton(optionLabel):
441
+ st.session_state["buttonValue"] = optionLabel
442
+ pprint(f"Selected: {optionLabel}")
443
+
444
+ # responseParts = response.split(JSON_SEPARATOR)
445
+
446
+ # jsonStr = None
447
+ # if len(responseParts) > 1:
448
+ # [response, jsonStr] = responseParts
449
+
450
+ # if jsonStr:
451
+ # try:
452
+ # json.loads(jsonStr)
453
+ # jsonObj = json.loads(jsonStr)
454
+ # options = jsonObj["options"]
455
+
456
+ # for option in options:
457
+ # st.button(
458
+ # option["label"],
459
+ # key=option["id"],
460
+ # on_click=lambda label=option["label"]: selectButton(label)
461
+ # )
462
+ # # st.code(jsonStr, language="json")
463
+ # except Exception as e:
464
+ # pprint(e)
465
+
466
+ st.session_state.messages.append({
467
+ "role": "assistant",
468
+ "content": response,
469
+ })
470
+ st.session_state.chatHistory.append({
471
+ "role": "assistant",
472
+ "content": response,
473
+ })
icons/balls.svg ADDED
icons/completed-task.png ADDED
icons/magic-wand(1) - Copy.png ADDED
icons/magic-wand(1).png ADDED
icons/magic-wand.gif ADDED
icons/magic-wand.png ADDED
icons/man.png ADDED
icons/ripple.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ groq
3
+ transformers
4
+ gradio_client
5
+ oauth2client
6
+ gspread
tools.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gspread
3
+ from oauth2client.service_account import ServiceAccountCredentials
4
+ import os
5
+ import json
6
+
7
+ # from dotenv import load_dotenv
8
+ # load_dotenv()
9
+
10
+ GCP_JSON_KEY = os.environ.get("GCP_JSON_KEY")
11
+
12
+
13
+ def saveInGSheet(
14
+ industry: str,
15
+ role: str,
16
+ profileDetails: str
17
+ ):
18
+ client = gspread.service_account_from_dict(json.loads(GCP_JSON_KEY))
19
+
20
+ workBook = "test_sheet"
21
+ try:
22
+ spreadsheet = client.open(workBook)
23
+ except gspread.SpreadsheetNotFound:
24
+ spreadsheet = client.create(workBook)
25
+ spreadsheet.share('[email protected]', perm_type='user', role='writer')
26
+ print(f"Created new sheet: {spreadsheet.url}")
27
+
28
+ sheet = spreadsheet.sheet1
29
+ sheet.append_row([
30
+ industry,
31
+ role,
32
+ profileDetails
33
+ ])
34
+ return {
35
+ "response": "Saved in GSheet",
36
+ "shouldShow": True
37
+ }
38
+
39
+
40
+ # saveInGSheet({
41
+ # "name": "Ashhar",
42
+ # "email": "XXXXXXXXXXXXXXXXXXXXX",
43
+ # "subject": "Test Subject",
44
+ # "message": "Test Message",
45
+ # "sent": True
46
+ # })
47
+
48
+
49
+ toolsInfo = {
50
+ "saveInGSheet": {
51
+ "func": saveInGSheet,
52
+ "schema": {
53
+ "type": "function",
54
+ "function": {
55
+ "name": "saveInGSheet",
56
+ "description": "Saves the profile details in the GSheet",
57
+ "parameters": {
58
+ "type": "object",
59
+ "properties": {
60
+ "industry": {
61
+ "type": "string",
62
+ "description": "Industry of the person"
63
+ },
64
+ "role": {
65
+ "type": "string",
66
+ "description": "Role of the person"
67
+ },
68
+ "profileDetails": {
69
+ "type": "string",
70
+ "description": "Profile details of the person"
71
+ }
72
+ },
73
+ "required": ["industry", "role", "profileDetails"]
74
+ }
75
+ }
76
+ },
77
+ }
78
+ }
79
+
80
+
81
+ # def load_json_and_print():
82
+ # with open('emailgenie-434420-b9c81c93bb39.json', 'r') as json_file:
83
+ # data = json.load(json_file)
84
+ # json_string = json.dumps(data)
85
+ # print(json_string)
86
+
87
+ # load_json_and_print()