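"""Minimal RAG pipeline: chunk a local text file, embed the chunks into an
in-memory FAISS index, and answer questions with a 4-bit quantized chat model."""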
import os
import time

import torch
from dotenv import load_dotenv
from huggingface_hub import login
from langchain.docstore.document import Document as LangChainDocument
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

vector_database_built = False  # reserved for caching the vector store; currently unused


def load_document():
    logger.info('Loading the source file the RAG will be built on.')
    with open('train.txt', 'r', encoding='utf-8') as f:
        data = f.read()

    logger.info('Wrapping the raw text in a LangChainDocument.')
    raw_database = LangChainDocument(page_content=data)
    return raw_database


def generate_chunks(raw_database):
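    # Separators ordered coarse-to-fine: Markdown headers, code fences,
    # horizontal rules, then paragraphs, lines, words, and characters.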
    MARKDOWN_SEPARATORS = [
        "\n#{1,6} ",
        "```\n",
        "\n\\*\\*\\*+\n",
        "\n---+\n",
        "\n___+\n",
        "\n\n",
        "\n",
        " ",
        "",
    ]

    logger.info('Splitting the document into chunks.')
    # 1000-character chunks with a 100-character overlap so context survives chunk boundaries.
    splitter = RecursiveCharacterTextSplitter(separators=MARKDOWN_SEPARATORS, chunk_size=1000, chunk_overlap=100)
    process_data = splitter.split_documents([raw_database])
    process_data = process_data[:5]  # TODO: remove later (keeps the index tiny while debugging)

    embedding_model_name = "thenlper/gte-small"
    logger.info(f'Embedding model: {embedding_model_name}.')
    embedding_model = HuggingFaceEmbeddings(
        model_name=embedding_model_name,
        multi_process=True,  # embed chunks in parallel worker processes
        model_kwargs={"device": "cuda"},  # TODO: make the device configurable later
        encode_kwargs={"normalize_embeddings": True},  # set `True` for cosine similarity
    )

    return process_data, embedding_model


def build_vector_database():
    raw_database = load_document()
    process_data, embedding_model = generate_chunks(raw_database)

    logger.info('Building the in-memory FAISS vector database.')
    vectors = FAISS.from_documents(process_data, embedding_model)
    return vectors


def load_model():
    load_dotenv()
    login(token=os.getenv('HF_TOKEN'))  # expects HF_TOKEN in a local .env file
    time.sleep(2)  # small grace period after login

    # model_name = "meta-llama/Llama-3.2-1B"
    model_name = "HuggingFaceH4/zephyr-7b-beta"
    # model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    # model_name = "meta-llama/Llama-3.2-3B-Instruct"
    logger.info(f'Loading the main language model: {model_name}')

    # 4-bit NF4 quantization with double quantization shrinks the 7B model's
    # memory footprint; matrix math runs in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    llm_model = pipeline(
        model=model,
        tokenizer=tokenizer,
        task="text-generation",
        do_sample=True,
        temperature=0.4,  # low temperature keeps answers focused
        repetition_penalty=1.1,
        return_full_text=False,  # return only the completion, not the prompt
        max_new_tokens=500,
    )
    logger.info(f'Model {model_name} loaded successfully.')

    return llm_model


def get_answer(question, use_context=True):
    llm_model = load_model()

    if use_context:
        vectors = build_vector_database()  # only build the index when retrieval is used
        prompt = """
        <|system|>
        You are a helpful assistant that answers on medical questions based on the real information provided from different sources and in the context.
        Give the rational and well written response. If you don't have proper info in the context, answer "I don't know"
        Respond only to the question asked.
    
        <|user|>
        Context:
        {}
        ---
        Here is the question you need to answer.
    
        Question: {}
        ---
        <|assistant|>
        """

        search_results = vectors.similarity_search(question, k=3)
        logger.info('Context:')
        for i, search_result in enumerate(search_results):
            logger.info(f"{i + 1}) {search_result.page_content}")

        context = " ".join([search_result.page_content for search_result in search_results])

        final_prompt = prompt.format(context, question)
        logger.info(f'Final prompt: \n{final_prompt}\n')
        answer = llm_model(final_prompt)
        logger.info(f"AI answer: {answer[0]['generated_text']}")

    else:
        prompt = """
        <|system|>
        You are a helpful assistant that answers on medical questions based on the real information provided from different sources and in the context.
        Give the rational and well written response. If you don't have proper info in the context, answer "I don't know"
        Respond only to the question asked.
    
        <|user|>
        ---
        Here is the question you need to answer.
    
        Question: {}
        ---
        <|assistant|>
        """

        final_prompt = prompt.format(question)
        logger.info(f'Final prompt: \n{final_prompt}\n')
        answer = llm_model(final_prompt)
        logger.info(f"AI answer: {answer[0]['generated_text']}")

    return answer[0]['generated_text']
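

# Minimal usage sketch. Assumptions: a local 'train.txt' exists and a CUDA GPU
# is available; the sample question below is hypothetical. The __main__ guard
# also matters because HuggingFaceEmbeddings(multi_process=True) spawns workers.
if __name__ == '__main__':
    sample_question = "What are the common symptoms of type 2 diabetes?"
    print(get_answer(sample_question, use_context=True))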