# LagRAG_demo / app.py
# Author: SS8297 — commit 1f05d65 ("Update app.py", verified)
import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList, StoppingCriteria
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_tokenizer(name: str):
    """Load the tokenizer for *name* once per process.

    Streamlit re-executes this whole script on every user interaction, so
    without caching the tokenizer would be reloaded for every chat message.
    """
    return AutoTokenizer.from_pretrained(name)


# Initialize Tokenizer (cached across Streamlit reruns)
tokenizer = _load_tokenizer(model_name)
def read_file(file_path: str, encoding: str = "utf-8") -> str:
    """Return the full text contents of *file_path*.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding,
            which ``open`` would otherwise fall back to.

    Returns:
        The decoded file contents as a single string.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return file.read()
@st.cache_resource
def _load_generator_model(name: str, target_device: str):
    """Load the causal LM *name* once per process, in eval mode, on *target_device*.

    Streamlit reruns the script on every interaction; caching avoids
    reloading the weights for every chat message.
    """
    lm = AutoModelForCausalLM.from_pretrained(name)
    lm.eval()
    lm.to(target_device)
    return lm


@st.cache_resource
def _load_document_encoder(name: str):
    """Load the sentence-transformer used to embed user queries, once per process."""
    return SentenceTransformer(name)


@st.cache_resource
def _get_pinecone_client(api_key: str):
    """Create the Pinecone client once per process instead of on every rerun."""
    return Pinecone(api_key=api_key)


model = _load_generator_model(model_name, device)
document_encoder_model = _load_document_encoder("KBLab/sentence-bert-swedish-cased")
# Note: 'index1' has been pre-created in the pinecone console
# The Pinecone API key is read from Streamlit's secrets store (st.secrets).
pinecone_api_key = st.secrets["pinecone_api_key"]
pc = _get_pinecone_client(pinecone_api_key)
index = pc.Index("index1")
def query_pincecone_namespace(
    vector_databse_index: "pinecone.Index",
    q_embedding: "numpy.ndarray | Sequence[float]",
    namespace: str,
) -> str:
    """Return the ``"paragraph"`` metadata of the best match in *namespace*.

    Args:
        vector_databse_index: Pinecone index handle (as returned by
            ``Pinecone(...).Index(name)``); any object with a compatible
            ``query`` method works.
        q_embedding: Query embedding. Objects exposing ``tolist()`` (such as a
            NumPy array) are converted with it; any other iterable of floats
            is copied into a plain list.
        namespace: Pinecone namespace to search.

    Returns:
        The ``"paragraph"`` metadata string of the top match, or ``""`` when
        the query returns no matches (e.g. an empty namespace).

    Raises:
        Propagates the lookup error if the top match carries no
        ``"paragraph"`` metadata entry.
    """
    if hasattr(q_embedding, "tolist"):
        vector = q_embedding.tolist()
    else:
        vector = list(q_embedding)

    result = vector_databse_index.query(
        namespace=namespace,
        vector=vector,
        top_k=1,
        # Only the metadata is used below, so don't ship stored vectors back.
        include_values=False,
        include_metadata=True,
    )
    if not result.matches:
        return ""
    return result.matches[0].metadata["paragraph"]
def generate_prompt(llmprompt: str) -> str:
    """Wrap *llmprompt* in the GPT-SW3 instruct chat template.

    The result is an ``<|endoftext|><s>`` header, then a ``User:`` turn holding
    *llmprompt*, then a ``<s>`` separator and an open ``Bot:`` turn. Lines are
    joined with ``\\n``, and nothing follows the final ``Bot:``.
    """
    separator = "<s>"
    lines = [
        "<|endoftext|>" + separator,
        "User:",
        llmprompt,
        separator,
        "Bot:",
    ]
    return "\n".join(lines)
def encode_query(query: str) -> "numpy.ndarray":
    """Embed *query* with the module-level sentence-transformer encoder.

    This uses ``document_encoder_model`` (a ``SentenceTransformer``), not the
    generator's tokenizer. Presumably it is the same model that was used to
    embed the paragraphs stored in Pinecone (TODO confirm).

    Returns:
        The embedding from ``SentenceTransformer.encode`` for a single string.
        With the library's default ``convert_to_numpy=True`` this is a NumPy
        array, not a ``torch.Tensor``.
    """
    return document_encoder_model.encode(query)
class StopOnTokenCriteria(StoppingCriteria):
    """Stopping criterion that fires when a sequence's latest token is ``stop_token_id``.

    Each sequence in the batch is checked independently. The result is a
    boolean tensor of shape ``(batch_size,)``, matching the
    ``transformers.StoppingCriteria`` return contract. The original version
    looked only at row 0.
    """

    def __init__(self, stop_token_id):
        super().__init__()
        # Token id that ends generation (e.g. a turn separator).
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids is 2-D (batch_size, sequence_length); compare the last
        # column so every sequence gets its own done flag.
        return input_ids[:, -1] == self.stop_token_id
# Halt generation once the model emits the tokenizer's BOS token id.
# NOTE(review): this presumably matches the "<s>" turn separator used in the
# prompt template — confirm tokenizer.bos_token == "<s>".
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)

st.title("Paralegal Assistant")
st.subheader("RAG: föräldrabalken")

# Chat history lives in session_state so it persists across reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Streamlit redraws the page on every rerun, so replay each stored turn.
for past_turn in st.session_state["messages"]:
    speaker = past_turn["role"]
    with st.chat_message(speaker):
        st.markdown(past_turn["content"])
# React to user input
if prompt := st.chat_input("Skriv din fråga..."):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Retrieve the best-matching law paragraph for the question.
    query = query_pincecone_namespace(
        vector_databse_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )
    # Prompt (Swedish): "The following passage is part of the law: <paragraph>
    # Refer to the law and answer the following question in a factual, concise
    # and formal way: <question>". The leading space on the second literal keeps
    # the paragraph and the instruction from being glued together.
    llmprompt = (
        "Följande stycke är en del av lagen: "
        + query
        + " Referera till lagen och besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)

    # Convert prompt to tokens; pass the attention mask explicitly so
    # generate() does not have to infer it.
    encoded = tokenizer(llmprompt, return_tensors="pt").to(device)
    input_ids = encoded["input_ids"]

    # Generate tokens based on the prompt
    generated_token_ids = model.generate(
        inputs=input_ids,
        attention_mask=encoded["attention_mask"],
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]

    # Keep only the newly generated tokens (generate() returns prompt + output).
    new_token_ids = generated_token_ids[input_ids.shape[-1]:]
    # Drop the final token only when it is a terminator. If generation stopped
    # on max_new_tokens, the last token is real answer text and must be kept.
    terminators = {stop_on_token_criteria.stop_token_id, tokenizer.eos_token_id}
    if len(new_token_ids) > 0 and int(new_token_ids[-1]) in terminators:
        new_token_ids = new_token_ids[:-1]
    generated_text = tokenizer.decode(new_token_ids)
    response = generated_text

    # Show the retrieved paragraph above the answer. Store exactly what is shown
    # so the history replay after a rerun renders the same message.
    assistant_content = f"```{query}```\n" + response
    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(assistant_content)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": assistant_content})