added stepback-Prompting
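Step-back prompting reformulates the user's concrete question into a more general question about the underlying concept, answers that first, and feeds the answer back in as additional context for the original question. In outline, the flow wired in below looks like this (a minimal standalone sketch, not the exact app.py code; the client setup and model name are placeholders for the app's llm_client and llm_remote_model):

    from openai import OpenAI  # assumption: llm_client is an OpenAI-compatible client

    client = OpenAI()       # placeholder setup
    MODEL = "remote-model"  # placeholder for llm_remote_model

    def step_back(query: str) -> str:
        # Reformulate a specific question into a broader "step-back" question,
        # mirroring stepBackPrompt() in app.py (system prompt abridged).
        system = ("Sie sind ein Experte für den IT-Grundschutz des BSI. "
                  "Formulieren Sie die Frage in eine Stepback-Frage um, "
                  "die nach einem Grundkonzept der Begrifflichkeit fragt.")
        resp = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "system", "content": system},
                      {"role": "user", "content": query}],
        )
        return resp.choices[0].message.content

    # ragPromptNew() retrieves chunks for both the original and the step-back
    # question, answers the step-back question via queryRemoteLLM(), and
    # appends that answer to the context before the final streamed completion.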
app.py CHANGED
@@ -291,8 +291,114 @@ class BSIChatbot:
         global rerankingModel
         rerankingModel = RAGPretrainedModel.from_pretrained(self.rerankModelPath)
 
+
+    def retrieval(self, query, rerankingStep):
+        retrieved_chunks = self.retrieveSimiliarEmbedding(query)
+        retrieved_chunks_text = []
+        # TODO Something is still wrong with the lists here
+        for chunk in retrieved_chunks:
+            # TODO Come up with something smarter here for all header levels
+            if "Header 1" in chunk.metadata.keys():
+                retrieved_chunks_text.append(
+                    f"The Document is: '{chunk.metadata['source']}'\nHeader of the Section is: '{chunk.metadata['Header 1']}' and Content of it:{chunk.page_content}")
+            else:
+                retrieved_chunks_text.append(
+                    f"The Document is: '{chunk.metadata['source']}'\nImage Description is: '{chunk.page_content}'")
+        for i, chunk in enumerate(retrieved_chunks_text, start=1):
+            print(f"Retrieved Chunk number {i}:\n{chunk}")
+
+        if rerankingStep == True:
+            if rerankingModel is None:
+                print("initializing Reranker-Model..")
+                self.initializeRerankingModel()
+            print("Starting Reranking Chunks...")
+            retrieved_chunks_text = rerankingModel.rerank(query, retrieved_chunks_text, k=5)
+            retrieved_chunks_text = [chunk["content"] for chunk in retrieved_chunks_text]
+
+            for i, chunk in enumerate(retrieved_chunks_text, start=1):
+                print(f"Reranked Chunk number {i}:\n{chunk}")
+
+        context = "\nExtracted documents:\n"
+        context += "".join(retrieved_chunks_text)
+
+        return query, context
+
+    def ragPromptNew(self, query, rerankingStep, history, stepBackPrompt):
+        global rerankingModel
+        prompt_in_chat_format = [
+            {
+                "role": "system",
+                "content": """You are a helpful Chatbot for the BSI IT-Grundschutz. Using the information contained in the context,
+give a comprehensive answer to the question.
+Respond only to the question asked, response should be concise and relevant but also give some context to the question.
+Provide the source document when relevant for the understanding.
+If the answer cannot be deduced from the context, do not give an answer.""",
+            },
+            {
+                "role": "user",
+                "content": """Context:
+{context}
+---
+Chat-History:
+{history}
+---
+Now here is the question you need to answer.
+
+Question: {question}""",
+            },
+        ]
+        # RAG_PROMPT_TEMPLATE = self.llmtokenizer.apply_chat_template(
+        #     prompt_in_chat_format, tokenize=False, add_generation_prompt=True
+        # )
+
+        # Everything except the last user request; run the plain query first
+        query, context = self.retrieval(query, True)
+        if stepBackPrompt == True:
+            stepBackQuery = self.stepBackPrompt(query)
+            stepBackQuery, stepBackContext = self.retrieval(stepBackQuery, True)
+            sysPrompt = """
+            You are a helpful Chatbot for the BSI IT-Grundschutz. Using the information contained in the context,
+            give a comprehensive answer to the question.
+            Respond only to the question asked, response should be concise and relevant but also give some context to the question.
+            Provide the source document when relevant for the understanding.
+            If the answer cannot be deduced from the context, do not give an answer.
+            """
+            stepBackAnswer = self.queryRemoteLLM(sysPrompt, stepBackQuery, True)
+            context += "Übergreifende Frage:" + stepBackQuery + "Übergreifender Context:" + stepBackAnswer
+
+        prompt_in_chat_format[-1]["content"] = prompt_in_chat_format[-1]["content"].format(
+            question=query, context=context, history=history[:-1]
+        )
+        final_prompt = prompt_in_chat_format
+
+        # final_prompt = prompt_in_chat_format[-1]["content"].format(
+        #     question=query, context=context, history=history[:-1]
+        # )
+
+        print(f"Query:\n{final_prompt}")
+        pattern = r"Filename:(.*?);"
+        last_value = final_prompt[-1]["content"]
+
+        match = re.findall(pattern, last_value)
+        self.images = match
+
+        stream = self.llm_client.chat.completions.create(
+            messages=final_prompt,
+            model=self.llm_remote_model,
+            stream=True
+        )
+        return stream
     #@spaces.GPU
-    def ragPromptRemote(self, query, rerankingStep, history):
+    def ragPromptRemote(self, query, rerankingStep, history, stepBackPrompt):
         global rerankingModel
         prompt_in_chat_format = [
             {
@@ -477,19 +583,26 @@ class BSIChatbot:
     def launchGr(self):
         gr.Interface.from_pipeline(self.llmpipeline).launch()
 
-    def queryRemoteLLM(self, systemPrompt, query):
-
-
-
-
-
+    def queryRemoteLLM(self, systemPrompt, query, summary):
+        if summary != True:
+            chat_completion = self.llm_client.chat.completions.create(
+                messages=[{"role": "system", "content": systemPrompt},
+                          {"role": "user", "content": "Step-Back Frage, die neu gestellt werden soll: " + query}],
+                model=self.llm_remote_model,
+            )
+        if summary == True:
+            chat_completion = self.llm_client.chat.completions.create(
+                messages=[{"role": "system", "content": systemPrompt},
+                          {"role": "user", "content": query}],
+                model=self.llm_remote_model,
+            )
+        Answer = chat_completion.choices[0].message.content  # extract the reply text for the caller
         return Answer
 
     def stepBackPrompt(self, query):
         systemPrompt = """
         Sie sind ein Experte für den IT-Grundschutz des BSI.
         Ihre Aufgabe ist es, eine Frage neu zu formulieren und sie in eine
-
+        Stepback-Frage umzuformulieren, die nach einem Grundkonzept der Begrifflichkeit fragt.
 
         Hier sind ein paar Beispiele:
         Ursprüngliche Frage: Welche Bausteine werden auf einen Webserver angewendet?
@@ -516,6 +629,8 @@ if __name__ == '__main__':
 
     renewEmbeddings = False
    reranking = True
+    stepBackEnable = True
+
    bot = BSIChatbot()
    bot.initializeEmbeddingModel(renewEmbeddings)
    if reranking == True:
@@ -564,7 +679,8 @@ if __name__ == '__main__':
         print(f"DBG: ragQuery hist -1:{history[-1].get('content')}")
         print(f"DBG: ragQuery hist 0:{history[0].get('content')}")
         print(f"DBG: fullHistory: {history}")
-        bot_response = bot.ragPromptRemote(history[-1].get('content'), reranking, history)
+        #bot_response = bot.ragPromptRemote(history[-1].get('content'), reranking, history)
+        bot_response = bot.ragPromptNew(history[-1].get('content'), reranking, history, stepBackEnable)
         history.append({"role": "assistant", "content": ""})
 
         image_gallery = returnImages()
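ragPromptNew returns the raw stream object from the chat completions call; the caller appends an empty assistant message to the history, which is then filled as the stream arrives. A minimal sketch of draining an OpenAI-style stream into that message (assuming the standard streaming delta shape):

    for chunk in bot_response:
        delta = chunk.choices[0].delta.content  # None for role/finish chunks
        if delta:
            history[-1]["content"] += delta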