Commit 4d3bceb (parent: 4e93adb) by luanpoppe

feat: adding the ability to create a summary with iterative_refinement
langchain_backend/main.py CHANGED

@@ -1,10 +1,11 @@
 import os
-from langchain_backend.utils import create_prompt_llm_chain, create_retriever, getPDF
+from langchain_backend.utils import create_prompt_llm_chain, create_retriever, getPDF, create_llm
 from langchain_backend import utils
 from langchain.chains import create_retrieval_chain
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_chroma import Chroma
 from langchain_openai import OpenAIEmbeddings
+from langchain.chains.summarize import load_summarize_chain
 
 os.environ.get("OPENAI_API_KEY")
 
@@ -29,18 +30,27 @@ def get_llm_answer(system_prompt, user_prompt, pdf_url, model, embedding):
     retriever = create_retriever(pages, vectorstore)
     rag_chain = create_retrieval_chain(retriever, create_prompt_llm_chain(system_prompt, model))
     results = rag_chain.invoke({"input": user_prompt})
-    print('allIds MAIN FILE: ', utils.allIds)
+    # print('allIds MAIN FILE: ', utils.allIds)
     vectorstore.delete(utils.allIds)
     vectorstore.delete_collection()
     utils.allIds = []
-    print('utils.allIds: ', utils.allIds)
+    # print('utils.allIds: ', utils.allIds)
     return results
 
-def get_llm_answer_summary(system_prompt, user_prompt, pdf_url, model):
+def get_llm_answer_summary(system_prompt, user_prompt, pdf_url, model, isIterativeRefinement):
     print('model: ', model)
+    print('isIterativeRefinement: ', isIterativeRefinement)
+    print('\n\n\n')
     pages = getPDF(pdf_url)
-    rag_chain = create_prompt_llm_chain(system_prompt, model)
+    if not isIterativeRefinement:
+        rag_chain = create_prompt_llm_chain(system_prompt, model)
 
-    results = rag_chain.invoke({"input": user_prompt, "context": pages})
+        results = rag_chain.invoke({"input": user_prompt, "context": pages})
 
-    return results
+        return results
+    else:
+        chain = load_summarize_chain(create_llm(model), "refine", True)
+        result = chain.invoke({"input_documents": pages})
+        print('result: ', result)
+        # Note: to pass custom prompts --> chain = load_summarize_chain(llm, "refine", True, question_prompt=initial_prompt, refine_prompt=PromptTemplate.from_template(refine_prompt))
+        # For more options, open the source of load_summarize_chain and, from there, the source of _load_refine_chain; the options are the parameters that the latter function accepts.
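Note: in the refine branch above, result is printed but never returned, so callers such as resumos/views.py will receive None. Below is a minimal sketch of the custom-prompt variant hinted at in the inline comment, with an explicit return; the prompt texts and the helper name summarize_with_refinement are illustrative assumptions, not from the repository:

    from langchain.chains.summarize import load_summarize_chain
    from langchain_core.prompts import PromptTemplate

    # Illustrative prompts; the "refine" chain expects {text} in the initial
    # prompt and {existing_answer} plus {text} in the refine prompt.
    initial_prompt = PromptTemplate.from_template(
        "Write a concise summary of the following text:\n\n{text}"
    )
    refine_prompt = PromptTemplate.from_template(
        "Here is an existing summary: {existing_answer}\n"
        "Refine it using this additional context:\n\n{text}"
    )

    def summarize_with_refinement(llm, pages):
        # "refine" summarizes the first document, then iteratively refines
        # the running summary with each subsequent document.
        chain = load_summarize_chain(
            llm,
            chain_type="refine",
            verbose=True,
            question_prompt=initial_prompt,
            refine_prompt=refine_prompt,
        )
        result = chain.invoke({"input_documents": pages})
        return result["output_text"]  # return the summary instead of only printing it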
langchain_backend/utils.py CHANGED

@@ -9,7 +9,10 @@ from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
 from setup.environment import default_model
 from uuid import uuid4
 
-
+os.environ["LANGCHAIN_TRACING_V2"]="true"
+os.environ["LANGCHAIN_ENDPOINT"]="https://api.smith.langchain.com"
+os.environ.get("LANGCHAIN_API_KEY")
+os.environ["LANGCHAIN_PROJECT"]="VELLA"
 os.environ.get("OPENAI_API_KEY")
 os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

@@ -29,8 +32,8 @@ def getPDF(file_paths):
     # loader = PyPDFLoader(file_paths, extract_images=False)
     # pages = loader.load_and_split(text_splitter)
     for page in pages:
-        print('\n')
-        print('allIds: ', allIds)
+        # print('\n')
+        # print('allIds: ', allIds)
         documentId = str(uuid4())
         allIds.append(documentId)
         page.id = documentId

@@ -50,16 +53,7 @@ def create_retriever(documents, vectorstore):
     return retriever
 
 def create_prompt_llm_chain(system_prompt, modelParam):
-    if modelParam == default_model:
-        model = ChatOpenAI(model=modelParam)
-    else:
-        model = HuggingFaceEndpoint(
-            repo_id=modelParam,
-            task="text-generation",
-            # max_new_tokens=100,
-            do_sample=False,
-            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN")
-        )
+    model = create_llm(modelParam)
 
     system_prompt = system_prompt + "\n\n" + "{context}"
     prompt = ChatPromptTemplate.from_messages(

@@ -69,4 +63,16 @@ def create_prompt_llm_chain(system_prompt, modelParam):
         ]
     )
     question_answer_chain = create_stuff_documents_chain(model, prompt)
-    return question_answer_chain
+    return question_answer_chain
+
+def create_llm(modelParam):
+    if modelParam == default_model:
+        return ChatOpenAI(model=modelParam)
+    else:
+        return HuggingFaceEndpoint(
+            repo_id=modelParam,
+            task="text-generation",
+            # max_new_tokens=100,
+            do_sample=False,
+            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+        )
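Note: of the new module-level lines, os.environ.get("LANGCHAIN_API_KEY") only reads and discards the value; the key must already be present in the process environment for LangSmith tracing to authenticate. A stricter variant, as a sketch (not part of the commit), would fail fast when the key is missing:

    import os

    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    os.environ["LANGCHAIN_PROJECT"] = "VELLA"
    if not os.environ.get("LANGCHAIN_API_KEY"):
        raise RuntimeError("Set LANGCHAIN_API_KEY to enable LangSmith tracing")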
resumos/serializer.py CHANGED

@@ -20,5 +20,6 @@ from rest_framework import serializers
 class ResumoPDFSerializer(serializers.Serializer):
     files = serializers.ListField(child=serializers.FileField(), required=True)
     system_prompt = serializers.CharField(required=True)
-    user_message = serializers.CharField(required=False)
-    model = serializers.CharField(required=False)
+    user_message = serializers.CharField(required=False, default="")
+    model = serializers.CharField(required=False)
+    iterative_refinement = serializers.BooleanField(required=False, default=False)
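Note: with these defaults, validated_data always carries user_message and iterative_refinement even when the client omits them. A hypothetical client call follows; only the field names come from the serializer, while the host and endpoint path are assumptions:

    import requests

    resp = requests.post(
        "http://localhost:8000/resumo",  # hypothetical endpoint path
        files=[("files", open("doc.pdf", "rb"))],
        data={
            "system_prompt": "Summarize the attached document.",
            "iterative_refinement": "true",  # omitted -> defaults to False
        },
    )
    print(resp.json())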
resumos/views.py CHANGED

@@ -18,9 +18,9 @@ class ResumoView(APIView):
         serializer = ResumoPDFSerializer(data=request.data)
         if serializer.is_valid(raise_exception=True):
             listaPDFs = []
-            data = request.data
+            data = serializer.validated_data
             model = serializer.validated_data.get("model", default_model)
-            user_message = data.get("user_message", "")
+            print('serializer.validated_data: ', serializer.validated_data)
 
             for file in serializer.validated_data['files']:
                 print("file: ", file)

@@ -32,7 +32,7 @@ class ResumoView(APIView):
                 listaPDFs.append(temp_file_path)
             # print('listaPDFs: ', listaPDFs)
 
-            resposta_llm = get_llm_answer_summary(data["system_prompt"], user_message, listaPDFs, model=model)
+            resposta_llm = get_llm_answer_summary(data["system_prompt"], data["user_message"], listaPDFs, model=model, isIterativeRefinement=data["iterative_refinement"])
 
             for file in listaPDFs:
                 os.remove(file)
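Note: switching data from request.data to serializer.validated_data is what makes the direct lookups data["user_message"] and data["iterative_refinement"] safe: DRF applies serializer defaults only to validated_data, never to the raw request payload. A self-contained illustration (a simplified serializer without the file field; run it inside a configured Django project, e.g. via python manage.py shell):

    from rest_framework import serializers

    class DemoSerializer(serializers.Serializer):
        system_prompt = serializers.CharField(required=True)
        user_message = serializers.CharField(required=False, default="")
        iterative_refinement = serializers.BooleanField(required=False, default=False)

    payload = {"system_prompt": "Summarize."}
    s = DemoSerializer(data=payload)
    s.is_valid(raise_exception=True)
    print(s.validated_data["iterative_refinement"])  # False: default applied
    print(payload.get("iterative_refinement"))       # None: raw payload has no default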