iasbeck committed
Commit a4f5679 · 1 Parent(s): 7a254b5

RAG implementation.

Files changed (7)
  1. .gitignore +3 -0
  2. README.md +4 -4
  3. app.py +136 -0
  4. rag_test.py +118 -0
  5. requirements.txt +114 -0
  6. test.txt +0 -0
  7. train.txt +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .idea
+ venv
+ .env
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: Ross Gpt
- emoji: 🏃
- colorFrom: red
- colorTo: pink
+ emoji: 🌍
+ colorFrom: yellow
+ colorTo: gray
  sdk: streamlit
- sdk_version: 1.41.0
+ sdk_version: 1.40.2
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,136 @@
+ import streamlit as st
+ import multiprocessing
+ from langchain.docstore.document import Document as LangChainDocument
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from huggingface_hub import login
+ from loguru import logger
+ import os
+ from dotenv import load_dotenv
+
+ multiprocessing.freeze_support()
+ load_dotenv()
+ access_token = os.getenv("ACCESS_TOKEN")
+ login(token=access_token)
+ logger.info('Login successful.')
+
+ logger.info('Loading the file the RAG will be based on.')
+ with open('train.txt', 'r') as f:
+     data = f.read()
+
+ logger.info('Representing the document as a LangChainDocument.')
+ raw_database = LangChainDocument(page_content=data)
+
+ MARKDOWN_SEPARATORS = [
+     "\n#{1,6} ",
+     "```\n",
+     "\n\\*\\*\\*+\n",
+     "\n---+\n",
+     "\n___+\n",
+     "\n\n",
+     "\n",
+     " ",
+     "",
+ ]
+
+ logger.info('Splitting the document into chunks.')
+ splitter = RecursiveCharacterTextSplitter(separators=MARKDOWN_SEPARATORS, chunk_size=1000, chunk_overlap=100)
+ process_data = splitter.split_documents([raw_database])
+ process_data = process_data[:5]  # TODO: remove later
+
+ embedding_model_name = "thenlper/gte-small"
+ logger.info(f'Embedding model: {embedding_model_name}.')
+ embedding_model = HuggingFaceEmbeddings(
+     model_name=embedding_model_name,
+     multi_process=True,
+     model_kwargs={"device": "cuda"},
+     encode_kwargs={"normalize_embeddings": True},  # set `True` for cosine similarity
+ )
+
+ logger.info('Building the in-memory vector database.')
+ vectors = FAISS.from_documents(process_data, embedding_model)
+
+ from transformers import pipeline
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+ # model_name = "meta-llama/Llama-3.2-1B"
+ model_name = "HuggingFaceH4/zephyr-7b-beta"
+ # model_name = "mistralai/Mistral-7B-Instruct-v0.3"
+ # model_name = "meta-llama/Llama-3.2-3B-Instruct"
+ logger.info(f'Loading the main language model: {model_name}')
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+ model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ llm_model = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     task="text-generation",
+     do_sample=True,
+     temperature=0.4,
+     repetition_penalty=1.1,
+     return_full_text=False,
+     max_new_tokens=500
+ )
+ logger.info(f'Model {model_name} loaded successfully.')
+
+ prompt = """
+ <|system|>
+ You are a helpful assistant that answers medical questions based on the real information provided from different sources and in the context.
+ Give a rational and well-written response. If you don't have proper info in the context, answer "I don't know".
+ Respond only to the question asked.
+
+ <|user|>
+ Context:
+ {}
+ ---
+ Here is the question you need to answer.
+
+ Question: {}
+ ---
+ <|assistant|>
+ """
+
+ st.title("Echo Bot")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ question = st.chat_input("How can I help you?")
+ if question:
+     with st.chat_message("user"):
+         st.markdown(question)  # show the user's question, not the prompt template
+
+     st.session_state.messages.append({"role": "user", "content": question})
+
+     search_results = vectors.similarity_search(question, k=3)
+
+     logger.info('Context: ')
+     for i, search_result in enumerate(search_results):
+         logger.info(f"{i + 1}) {search_result.page_content}")
+
+     context = " ".join([search_result.page_content for search_result in search_results])
+     final_prompt = prompt.format(context, question)
+     logger.info(f'\n{final_prompt}\n')
+
+     answer = llm_model(final_prompt)
+     text_answer = answer[0]['generated_text']
+
+     logger.info(f"AI response: {text_answer}")
+
+     with st.chat_message("assistant"):
+         st.markdown(text_answer)
+
+     st.session_state.messages.append({"role": "assistant", "content": text_answer})
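
A note on the app above: Streamlit re-executes the entire script on every chat interaction, so as committed, app.py reloads the embedding model, rebuilds the FAISS index, and re-quantizes the LLM for each message. A minimal sketch of caching the index build across reruns with Streamlit's st.cache_resource; the helper name build_vector_store is hypothetical and not part of this commit:

import streamlit as st
from langchain.docstore.document import Document as LangChainDocument
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

@st.cache_resource  # built once per process and reused on every Streamlit rerun
def build_vector_store(path: str = "train.txt"):
    # Hypothetical helper: the same load/split/embed steps as in app.py.
    with open(path, "r") as f:
        doc = LangChainDocument(page_content=f.read())
    chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents([doc])
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    return FAISS.from_documents(chunks, embeddings)

vectors = build_vector_store()

The same pattern would apply to the pipeline(...) LLM setup, which is the most expensive step of the script.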
rag_test.py ADDED
@@ -0,0 +1,118 @@
+ import multiprocessing
+ from langchain.docstore.document import Document as LangChainDocument
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from huggingface_hub import login
+ from loguru import logger
+ from transformers import pipeline
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ import os
+ from dotenv import load_dotenv
+
+
+ def main():
+     logger.info('Loading the file the RAG will be based on.')
+     with open('train.txt', 'r') as f:
+         data = f.read()
+
+     logger.info('Representing the document as a LangChainDocument.')
+     raw_database = LangChainDocument(page_content=data)
+
+     MARKDOWN_SEPARATORS = [
+         "\n#{1,6} ",
+         "```\n",
+         "\n\\*\\*\\*+\n",
+         "\n---+\n",
+         "\n___+\n",
+         "\n\n",
+         "\n",
+         " ",
+         "",
+     ]
+
+     logger.info('Splitting the document into chunks.')
+     splitter = RecursiveCharacterTextSplitter(separators=MARKDOWN_SEPARATORS, chunk_size=1000, chunk_overlap=100)
+     process_data = splitter.split_documents([raw_database])
+     process_data = process_data[:5]  # TODO: remove later
+
+     embedding_model_name = "thenlper/gte-small"
+     logger.info(f'Embedding model: {embedding_model_name}.')
+     embedding_model = HuggingFaceEmbeddings(
+         model_name=embedding_model_name,
+         multi_process=True,
+         model_kwargs={"device": "cuda"},
+         encode_kwargs={"normalize_embeddings": True},  # set `True` for cosine similarity
+     )
+
+     logger.info('Building the in-memory vector database.')
+     vectors = FAISS.from_documents(process_data, embedding_model)
+
+     # model_name = "meta-llama/Llama-3.2-1B"
+     model_name = "HuggingFaceH4/zephyr-7b-beta"
+     # model_name = "mistralai/Mistral-7B-Instruct-v0.3"
+     # model_name = "meta-llama/Llama-3.2-3B-Instruct"
+     logger.info(f'Loading the main language model: {model_name}')
+
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+     model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     llm_model = pipeline(
+         model=model,
+         tokenizer=tokenizer,
+         task="text-generation",
+         do_sample=True,
+         temperature=0.4,
+         repetition_penalty=1.1,
+         return_full_text=False,
+         max_new_tokens=500
+     )
+     logger.info(f'Model {model_name} loaded successfully.')
+
+     prompt = """
+ <|system|>
+ You are a helpful assistant that answers medical questions based on the real information provided from different sources and in the context.
+ Give a rational and well-written response. If you don't have proper info in the context, answer "I don't know".
+ Respond only to the question asked.
+
+ <|user|>
+ Context:
+ {}
+ ---
+ Here is the question you need to answer.
+
+ Question: {}
+ ---
+ <|assistant|>
+ """
+
+     question = "What is Cardiogenic shock?"
+     search_results = vectors.similarity_search(question, k=3)
+
+     logger.info('Context: ')
+     for i, search_result in enumerate(search_results):
+         logger.info(f"{i + 1}) {search_result.page_content}")
+
+     context = " ".join([search_result.page_content for search_result in search_results])
+     final_prompt = prompt.format(context, question)
+     logger.info(f'\n{final_prompt}\n')
+
+     answer = llm_model(final_prompt)
+
+     logger.info(f"AI response: {answer[0]['generated_text']}")
+
+
+ if __name__ == '__main__':
+     multiprocessing.freeze_support()
+     load_dotenv()  # load ACCESS_TOKEN from .env before reading it
+     access_token = os.getenv("ACCESS_TOKEN")
+     login(token=access_token)
+     logger.info('Login successful.')
+     main()
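
Both scripts hard-code model_kwargs={"device": "cuda"} for the embeddings, which raises an error on CPU-only machines (for example, a Space without GPU hardware). A hedged sketch of the standard fallback via torch.cuda.is_available(); this check is not in the commit:

import torch
from langchain_huggingface import HuggingFaceEmbeddings

device = "cuda" if torch.cuda.is_available() else "cpu"  # pick the GPU only when present
embedding_model = HuggingFaceEmbeddings(
    model_name="thenlper/gte-small",
    model_kwargs={"device": device},
    encode_kwargs={"normalize_embeddings": True},
)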
requirements.txt ADDED
@@ -0,0 +1,114 @@
+ accelerate==1.2.0
+ aiohappyeyeballs==2.4.4
+ aiohttp==3.11.10
+ aiosignal==1.3.1
+ altair==5.5.0
+ annotated-types==0.7.0
+ anyio==4.7.0
+ attrs==24.2.0
+ bitsandbytes==0.45.0
+ blinker==1.9.0
+ cachetools==5.5.0
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ dataclasses-json==0.6.7
+ datasets==3.2.0
+ dill==0.3.8
+ filelock==3.16.1
+ frozenlist==1.5.0
+ fsspec==2024.9.0
+ gitdb==4.0.11
+ GitPython==3.1.43
+ greenlet==3.1.1
+ h11==0.14.0
+ httpcore==1.0.7
+ httpx==0.28.1
+ httpx-sse==0.4.0
+ huggingface-hub==0.26.5
+ idna==3.10
+ Jinja2==3.1.4
+ joblib==1.4.2
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ langchain==0.3.11
+ langchain-community==0.3.11
+ langchain-core==0.3.24
+ langchain-huggingface==0.1.2
+ langchain-text-splitters==0.3.2
+ langsmith==0.2.2
+ loguru==0.7.3
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ marshmallow==3.23.1
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ mypy-extensions==1.0.0
+ narwhals==1.17.0
+ networkx==3.4.2
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ orjson==3.10.12
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.0.0
+ propcache==0.2.1
+ protobuf==5.29.1
+ psutil==6.1.0
+ pyarrow==18.1.0
+ pydantic==2.10.3
+ pydantic-settings==2.6.1
+ pydantic_core==2.27.1
+ pydeck==0.9.1
+ Pygments==2.18.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ pytz==2024.2
+ PyYAML==6.0.2
+ referencing==0.35.1
+ regex==2024.11.6
+ requests==2.32.3
+ requests-toolbelt==1.0.0
+ rich==13.9.4
+ rpds-py==0.22.3
+ safetensors==0.4.5
+ scikit-learn==1.6.0
+ scipy==1.14.1
+ sentence-transformers==3.3.1
+ six==1.17.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ SQLAlchemy==2.0.36
+ streamlit==1.41.0
+ sympy==1.13.1
+ tenacity==9.0.0
+ threadpoolctl==3.5.0
+ tokenizers==0.21.0
+ toml==0.10.2
+ torch==2.5.1
+ tornado==6.4.2
+ tqdm==4.67.1
+ transformers==4.47.0
+ triton==3.1.0
+ typing-inspect==0.9.0
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ watchdog==6.0.0
+ xxhash==3.5.0
+ yarl==1.18.3
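
Since both scripts rebuild the FAISS index from train.txt on every start, one more hedged sketch: the pinned langchain-community version can persist the index with save_local and reload it with load_local, skipping the re-embedding step. The faiss_index directory name is hypothetical and not part of this commit, and allow_dangerous_deserialization=True is required because the stored metadata is pickled:

import os
from langchain.docstore.document import Document as LangChainDocument
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

INDEX_DIR = "faiss_index"  # hypothetical location, not part of this commit
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")

if os.path.isdir(INDEX_DIR):
    # Reload the saved index instead of re-embedding train.txt.
    vectors = FAISS.load_local(INDEX_DIR, embeddings, allow_dangerous_deserialization=True)
else:
    with open("train.txt", "r") as f:
        docs = [LangChainDocument(page_content=f.read())]
    vectors = FAISS.from_documents(docs, embeddings)
    vectors.save_local(INDEX_DIR)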
test.txt ADDED
The diff for this file is too large to render. See raw diff
 
train.txt ADDED
The diff for this file is too large to render. See raw diff