vivek2001123 committed on
Commit 89e90e1
Parent(s): b4f4e1e

Update app.py

Files changed (1): app.py +253 -2
app.py CHANGED
@@ -1,4 +1,255 @@
  import streamlit as st

- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import uuid
+ import sys
+ import requests
+
+ USER_ICON = "images/user-icon.png"
+ AI_ICON = "images/ai-icon.png"
+ MAX_HISTORY_LENGTH = 5
+
+ if 'user_id' in st.session_state:
+     user_id = st.session_state['user_id']
+ else:
+     user_id = str(uuid.uuid4())
+     st.session_state['user_id'] = user_id
+
+ if 'chat_history' not in st.session_state:
+     st.session_state['chat_history'] = []
+
+ if "chats" not in st.session_state:
+     st.session_state.chats = [
+         {
+             'id': 0,
+             'question': '',
+             'answer': ''
+         }
+     ]
+
+ if "questions" not in st.session_state:
+     st.session_state.questions = []
+
+ if "answers" not in st.session_state:
+     st.session_state.answers = []
+
+ if "input" not in st.session_state:
+     st.session_state.input = ""
+
+ st.markdown("""
+ <style>
+ .block-container {
+     padding-top: 32px;
+     padding-bottom: 32px;
+     padding-left: 0;
+     padding-right: 0;
+ }
+ .element-container img {
+     background-color: #000000;
+ }
+ .main-header {
+     font-size: 24px;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ def write_top_bar():
+     col1, col2, col3 = st.columns([1, 10, 2])
+     with col1:
+         st.image(AI_ICON, use_column_width='always')
+     with col2:
+         header = "Cogwise Intelligent Assistant"
+         st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
+     with col3:
+         clear = st.button("Clear Chat")
+     return clear
+
+ clear = write_top_bar()
+
+ if clear:
+     st.session_state.questions = []
+     st.session_state.answers = []
+     st.session_state.input = ""
+     st.session_state["chat_history"] = []
+
+ def handle_input():
+     input = st.session_state.input
+     question_with_id = {
+         'question': input,
+         'id': len(st.session_state.questions)
+     }
+     st.session_state.questions.append(question_with_id)
+
+     chat_history = st.session_state["chat_history"]
+     if len(chat_history) == MAX_HISTORY_LENGTH:
+         # Drop the oldest turn in place; rebinding the name with
+         # chat_history[:-1] would leave st.session_state["chat_history"]
+         # unchanged and the cap would never apply.
+         chat_history.pop(0)
+
+     # api_url = "https://9pl792yjf9.execute-api.us-east-1.amazonaws.com/beta/chatcogwise"
+     # api_request_data = {"question": input, "session": user_id}
+     # api_response = requests.post(api_url, json=api_request_data)
+     # result = api_response.json()
+     # answer = result['answer']
+
+     # Notebook-style install commands cannot run inside app.py; they are
+     # kept here as a record of the intended environment:
+     #   pip install -Uqqq pip --progress-bar off
+     #   pip install -qqq bitsandbytes==0.39.0
+     #   pip install -qqq torch==2.0.1 --progress-bar off
+     #   pip install -qqq -U git+https://github.com/huggingface/transformers.git@e03a9cc --progress-bar off
+     #   pip install -qqq -U git+https://github.com/huggingface/peft.git@42a184f --progress-bar off
+     #   pip install -qqq -U git+https://github.com/huggingface/accelerate.git@c9fbb71 --progress-bar off
+     #   pip install -qqq datasets==2.12.0 --progress-bar off
+     #   pip install -qqq loralib==0.1.1 --progress-bar off
+     #   pip install einops
+
+     import os
+     # from pprint import pprint
+     # import json
+
+     import bitsandbytes as bnb
+     import pandas as pd
+     import torch
+     import torch.nn as nn
+     import transformers
+     from datasets import load_dataset
+     from huggingface_hub import notebook_login
+     from peft import (
+         LoraConfig,
+         PeftConfig,
+         PeftModel,  # needed below; a bare star import is illegal inside a function
+         get_peft_model,
+         prepare_model_for_kbit_training,
+     )
+     from transformers import (
+         AutoConfig,
+         AutoModelForCausalLM,
+         AutoTokenizer,
+         BitsAndBytesConfig,
+     )
+
+     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+     # notebook_login() is meant for notebooks; in an app, authenticate via the
+     # HUGGING_FACE_HUB_TOKEN environment variable (or huggingface_hub.login)
+     # and never commit a token to the repository.
+     # notebook_login()
+
+     # ### Load dataset
+     # (load_dataset is already imported above; the dataset is not used
+     # anywhere below, so this load is informational only.)
+     dataset_name = "nisaar/Lawyer_GPT_India"
+     # dataset_name = "patrick11434/TEST_LLM_DATASET"
+     dataset = load_dataset(dataset_name, split="train")
+
+     # ## Load adapters from the Hub
+     # You can also directly load adapters from the Hub using the commands below.
+
+     # change peft_model_id
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+
+     peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
+     config = PeftConfig.from_pretrained(peft_model_id)
+     model = AutoModelForCausalLM.from_pretrained(
+         config.base_model_name_or_path,
+         return_dict=True,
+         quantization_config=bnb_config,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     model = PeftModel.from_pretrained(model, peft_model_id)
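+     # NOTE: everything from the dependency list down to this point executes on
+     # every call to handle_input(); see the cached-loader sketch after the diff
+     # for one way to hoist this setup so it runs only once.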
+
+     # ## Inference
+     # You can then directly use the trained model, or the model you have
+     # loaded from the 🤗 Hub, for inference as you usually would in
+     # `transformers`.
+
+     generation_config = model.generation_config
+     generation_config.max_new_tokens = 200
+     generation_config.temperature = 1
+     generation_config.top_p = 0.7
+     generation_config.num_return_sequences = 1
+     generation_config.pad_token_id = tokenizer.eos_token_id
+     generation_config.eos_token_id = tokenizer.eos_token_id
+
+     DEVICE = "cuda:0"
+
+     # Commented out IPython magic to ensure Python compatibility.
+     # %%time
+     # prompt = f"""
+     # <human>: Who appoints the Chief Justice of India?
+     # <assistant>:
+     # """.strip()
+     #
+     # encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+     # with torch.inference_mode():
+     #     outputs = model.generate(
+     #         input_ids=encoding.input_ids,
+     #         attention_mask=encoding.attention_mask,
+     #         generation_config=generation_config,
+     #     )
+     # print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+     def generate_response(question: str) -> str:
+         prompt = f"""
+ <human>: {question}
+ <assistant>:
+ """.strip()
+         encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+         with torch.inference_mode():
+             outputs = model.generate(
+                 input_ids=encoding.input_ids,
+                 attention_mask=encoding.attention_mask,
+                 generation_config=generation_config,
+             )
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         assistant_start = '<assistant>:'
+         response_start = response.find(assistant_start)
+         return response[response_start + len(assistant_start):].strip()
+
+     # prompt = "Debate the merits and demerits of introducing simultaneous elections in India?"
+     prompt = input
+     # generate_response() already returns the answer string; wrapping it in
+     # print() would store None in answer.
+     answer = generate_response(prompt)
+
+     # answer = 'Yes'
+     chat_history.append((input, answer))
+
+     st.session_state.answers.append({
+         'answer': answer,
+         'id': len(st.session_state.questions)
+     })
+     st.session_state.input = ""
+
+ def write_user_message(md):
+     col1, col2 = st.columns([1, 12])
+     with col1:
+         st.image(USER_ICON, use_column_width='always')
+     with col2:
+         st.warning(md['question'])
+
+ def render_answer(answer):
+     col1, col2 = st.columns([1, 12])
+     with col1:
+         st.image(AI_ICON, use_column_width='always')
+     with col2:
+         st.info(answer)
+
+ def write_chat_message(md, q):
+     chat = st.container()
+     with chat:
+         render_answer(md['answer'])
+
+ with st.container():
+     for (q, a) in zip(st.session_state.questions, st.session_state.answers):
+         write_user_message(q)
+         write_chat_message(a, q)
+
+ st.markdown('---')
+ input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
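
Loading the 7B checkpoint inside handle_input() repeats the full download and
quantization on every submitted message. A minimal sketch of how the heavy
setup could be hoisted behind Streamlit's st.cache_resource so it runs once
per process, assuming the same model identifiers as the diff above
(load_chat_model is a hypothetical helper, not part of this commit):

import streamlit as st
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

@st.cache_resource  # evaluated once per server process, reused across reruns
def load_chat_model(peft_model_id: str = "nisaar/falcon7b-Indian_Law_150Prompts"):
    # Same 4-bit NF4 quantization settings as in the commit above.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    model = PeftModel.from_pretrained(model, peft_model_id)
    return model, tokenizer

model, tokenizer = load_chat_model()

handle_input() could then call generate_response() against the cached pair
instead of rebuilding the model for every message.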