import streamlit as st
import uuid
import sys
import requests
import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
MAX_HISTORY_LENGTH = 5

if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
else:
    user_id = str(uuid.uuid4())
    st.session_state['user_id'] = user_id

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

if "chats" not in st.session_state:
    st.session_state.chats = [
        {
            'id': 0,
            'question': '',
            'answer': ''
        }
    ]

if "questions" not in st.session_state:
    st.session_state.questions = []

if "answers" not in st.session_state:
    st.session_state.answers = []

if "input" not in st.session_state:
    st.session_state.input = ""
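# Streamlit re-runs this script on every interaction; st.session_state persists
# the keys initialised above across those reruns for the current browser session.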
st.markdown("""
    <style>
        .block-container {
            padding-top: 32px;
            padding-bottom: 32px;
            padding-left: 0;
            padding-right: 0;
        }
        .element-container img {
            background-color: #000000;
        }
        .main-header {
            font-size: 24px;
        }
    </style>
""", unsafe_allow_html=True)
def write_top_bar():
    col1, col2, col3 = st.columns([1, 10, 2])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        header = "Cogwise Intelligent Assistant"
        st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
    with col3:
        clear = st.button("Clear Chat")
    return clear


clear = write_top_bar()

if clear:
    st.session_state.questions = []
    st.session_state.answers = []
    st.session_state.input = ""
    st.session_state["chat_history"] = []
def handle_input():
    input = st.session_state.input

    question_with_id = {
        'question': input,
        'id': len(st.session_state.questions)
    }
    st.session_state.questions.append(question_with_id)

    chat_history = st.session_state["chat_history"]
    if len(chat_history) >= MAX_HISTORY_LENGTH:
        # Drop the oldest turn in place so the stored history stays bounded.
        del chat_history[0]

    # api_url = "https://9pl792yjf9.execute-api.us-east-1.amazonaws.com/beta/chatcogwise"
    # api_request_data = {"question": input, "session": user_id}
    # api_response = requests.post(api_url, json=api_request_data)
    # result = api_response.json()
    # answer = result['answer']
    # !pip install -Uqqq pip --progress-bar off
    # !pip install -qqq bitsandbytes==0.39.0
    # !pip install -qqq torch==2.0.1 --progress-bar off
    # !pip install -qqq -U git+https://github.com/huggingface/transformers.git@e03a9cc --progress-bar off
    # !pip install -qqq -U git+https://github.com/huggingface/peft.git@42a184f --progress-bar off
    # !pip install -qqq -U git+https://github.com/huggingface/accelerate.git@c9fbb71 --progress-bar off
    # !pip install -qqq datasets==2.12.0 --progress-bar off
    # !pip install -qqq loralib==0.1.1 --progress-bar off
    # !pip install einops
    import os
    # from pprint import pprint
    # import json

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # notebook_login()
    # 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

    # """### Load dataset"""
    dataset_name = "nisaar/Lawyer_GPT_India"
    # dataset_name = "patrick11434/TEST_LLM_DATASET"
    dataset = load_dataset(dataset_name, split="train")
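    # Optional sanity check: `dataset` is loaded above but not used again in this
    # handler, so inspecting one record can confirm the split loaded correctly
    # (kept commented out to avoid extra console output):
    # print(dataset[0])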
| # """## Load adapters from the Hub | |
| # You can also directly load adapters from the Hub using the commands below: | |
| # """ | |
| # change peft_model_id | |
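    # 4-bit NF4 quantization with double quantization keeps the Falcon-7B base
    # model small enough for a single GPU; bfloat16 is the compute dtype used
    # when the quantized weights are de-quantized for matmuls.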
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
    config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    # Falcon's tokenizer ships without a pad token, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    model = PeftModel.from_pretrained(model, peft_model_id)
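    # Performance note: everything from bnb_config down to PeftModel.from_pretrained
    # above runs inside handle_input(), so the quantized base model and adapter are
    # reloaded on every user message. A common Streamlit pattern (a possible
    # refactor, not what this app currently does) is to hoist the loading into a
    # cached helper so it happens once per process, e.g.:
    #
    # @st.cache_resource
    # def load_model_and_tokenizer():
    #     # ... same loading code as above ...
    #     return model, tokenizer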
| """## Inference | |
| You can then directly use the trained model or the model that you have loaded from the 🤗 Hub for inference as you would do it usually in `transformers`. | |
| """ | |
| generation_config = model.generation_config | |
| generation_config.max_new_tokens = 200 | |
| generation_config_temperature = 1 | |
| generation_config.top_p = 0.7 | |
| generation_config.num_return_sequences = 1 | |
| generation_config.pad_token_id = tokenizer.eos_token_id | |
| generation_config_eod_token_id = tokenizer.eos_token_id | |
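    # Note: temperature and top_p only affect generation when sampling is enabled
    # (generation_config.do_sample = True); under the default greedy decoding they
    # are ignored, so enable sampling if varied answers are desired.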
    DEVICE = "cuda:0"

    # Commented out IPython magic to ensure Python compatibility.
    # %%time
    # prompt = f"""
    # <human>: Who appoints the Chief Justice of India?
    # <assistant>:
    # """.strip()
    #
    # encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # with torch.inference_mode():
    #     outputs = model.generate(
    #         input_ids=encoding.input_ids,
    #         attention_mask=encoding.attention_mask,
    #         generation_config=generation_config,
    #     )
    # print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    def generate_response(question: str) -> str:
        prompt = f"""
<human>: {question}
<assistant>:
""".strip()
        encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                generation_config=generation_config,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Return only the text generated after the assistant marker.
        assistant_start = '<assistant>:'
        response_start = response.find(assistant_start)
        return response[response_start + len(assistant_start):].strip()
    # prompt = "Debate the merits and demerits of introducing simultaneous elections in India?"
    prompt = input
    answer = generate_response(prompt)
    print(answer)
    # answer = 'Yes'

    chat_history.append((input, answer))

    st.session_state.answers.append({
        'answer': answer,
        'id': len(st.session_state.questions)
    })
    st.session_state.input = ""
def write_user_message(md):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        st.warning(md['question'])


def render_answer(answer):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        st.info(answer)


def write_chat_message(md, q):
    chat = st.container()
    with chat:
        render_answer(md['answer'])


with st.container():
    for (q, a) in zip(st.session_state.questions, st.session_state.answers):
        write_user_message(q)
        write_chat_message(a, q)

st.markdown('---')
input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
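# To run the app locally (assuming this file is saved as app.py and the pinned
# dependencies listed above are installed):
#     streamlit run app.py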