dewiri committed on
Commit
d9b374b
·
verified ·
1 Parent(s): eb74260

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +83 -27
rag_pipeline.py CHANGED
@@ -4,6 +4,18 @@ import requests
4
  import faiss
5
  import numpy as np
6
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # === Modell laden ===
9
  print("🧠 Lade SentenceTransformer...")
@@ -13,54 +25,98 @@ model = SentenceTransformer("Sahajtomar/German-semantic")
13
  url_index = "https://drive.google.com/uc?export=download&id=1QBg4vjitJ2xHEyp3Ae8TWJHwEHjbwgOO"
14
  url_chunks = "https://drive.google.com/uc?export=download&id=1nsrAm_ozsK4GlmMui9yqZBjmgUfqU2qa"
15
 
16
- # === Lokale Dateipfade
17
  local_index = "faiss_index.index"
18
  local_chunks = "chunks_mapping.pkl"
19
 
20
- # === Datei-Download bei Bedarf
21
- def download_if_missing(url, local_path):
22
- if not os.path.exists(local_path):
23
- print(f"⬇️ Lade {local_path} von Google Drive...")
24
  r = requests.get(url)
25
  if r.status_code == 200:
26
- with open(local_path, "wb") as f:
27
  f.write(r.content)
28
- print(f"✅ Heruntergeladen: {local_path}")
29
  else:
30
- raise Exception(f"❌ Download fehlgeschlagen für {local_path}")
31
 
32
  download_if_missing(url_index, local_index)
33
  download_if_missing(url_chunks, local_chunks)
34
 
35
- # === FAISS & Chunks laden
36
- print("📂 Lade FAISS Index und Text-Chunks...")
37
- index = faiss.read_index(local_index)
38
-
39
  with open(local_chunks, "rb") as f:
40
  token_split_texts = pickle.load(f)
 
41
 
42
- print(f"✅ Geladene Chunks: {len(token_split_texts)}")
 
 
 
 
43
 
44
- # === Embedding nur auf den ersten 10 Chunks testen
45
- print("⚙️ Starte Embedding-Berechnung auf 10 Chunks...")
46
- test_chunks = token_split_texts[:10]
47
- chunk_embeddings = model.encode(test_chunks, convert_to_numpy=True)
48
- print("✅ Embeddings kodiert")
49
-
50
- # === Abruffunktion
51
  def retrieve(query, k=5):
52
  query_embedding = model.encode([query], convert_to_numpy=True)
53
  distances, indices = index.search(query_embedding, k)
54
- retrieved_texts = [test_chunks[i] for i in indices[0]]
55
- return retrieved_texts
56
 
57
- # === Prompt Builder
58
  def build_prompt(query, texts):
59
  context = "\n\n".join(texts)
60
- return f"Beantworte die folgende Frage basierend auf dem Kontext:\n\nKontext:\n{context}\n\nFrage:\n{query}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # === Hauptfunktion für Gradio
63
  def run_qa_pipeline(query, k=5):
64
- retrieved = retrieve(query, k)
65
- prompt = build_prompt(query, retrieved)
66
- return f"🔍 Kontext gefunden:\n\n{prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import faiss
5
  import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
from dotenv import load_dotenv
from openai import OpenAI
from groq import Groq

# === Load API keys from .env / environment ===
load_dotenv()
openai_key = os.getenv("OPENAI_API_KEY")
groq_key = os.getenv("GROQ_API_KEY")

# A client is only constructed when its key is present; callers must
# handle the None case (no key configured).
openai_client = OpenAI(api_key=openai_key) if openai_key else None
groq_client = Groq(api_key=groq_key) if groq_key else None
19
 
20
# === Load embedding model ===
print("🧠 Lade SentenceTransformer...")

# Google Drive download URLs for the prebuilt index and the chunk mapping.
url_index = "https://drive.google.com/uc?export=download&id=1QBg4vjitJ2xHEyp3Ae8TWJHwEHjbwgOO"
url_chunks = "https://drive.google.com/uc?export=download&id=1nsrAm_ozsK4GlmMui9yqZBjmgUfqU2qa"

# Local cache paths for the downloaded artifacts.
local_index = "faiss_index.index"
local_chunks = "chunks_mapping.pkl"
30
 
31
# === Download on demand
def download_if_missing(url, path):
    """Download *url* to *path* unless the file already exists locally.

    Parameters:
        url:  direct-download URL (here: a Google Drive export link).
        path: local file path used as the cache location.

    Raises:
        RuntimeError: if the HTTP request does not return status 200.
        requests.RequestException: on network failure or timeout.
    """
    if os.path.exists(path):
        return  # already cached — nothing to do
    print(f"⬇️ Lade {path} von Google Drive...")
    # Timeout so a hung connection cannot block application startup forever.
    r = requests.get(url, timeout=60)
    if r.status_code != 200:
        # RuntimeError is still caught by callers using `except Exception`.
        raise RuntimeError(f"❌ Fehler beim Herunterladen von {path}")
    with open(path, "wb") as f:
        f.write(r.content)
    print(f"✅ Heruntergeladen: {path}")
42
 
43
# Fetch both artifacts on first startup.
download_if_missing(url_index, local_index)
download_if_missing(url_chunks, local_chunks)

# === Load FAISS
print("📂 Lade FAISS & Chunks...")

# NOTE(review): pickle.load on a downloaded file executes arbitrary code if
# the source is ever compromised — acceptable only because the Drive link is
# controlled by the project owner.
with open(local_chunks, "rb") as f:
    token_split_texts = pickle.load(f)
print(f"✅ {len(token_split_texts)} Chunks geladen.")

# NOTE(review): the downloaded faiss_index.index file is never read here; the
# index is rebuilt from scratch by re-embedding every chunk. Confirm whether
# the index download is still needed.
chunk_embeddings = model.encode(token_split_texts, convert_to_numpy=True)
embedding_dim = chunk_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(chunk_embeddings)
print(f"✅ FAISS Index mit {index.ntotal} Einträgen.")
57
 
58
+ # === Ähnliche Chunks abrufen
 
 
 
 
 
 
59
def retrieve(query, k=5):
    """Return up to *k* chunk texts most similar to *query*.

    FAISS pads its result with index -1 when fewer than *k* vectors exist in
    the index. The previous guard `i < len(token_split_texts)` let -1 slip
    through, and Python's negative indexing then silently returned the LAST
    chunk — a wrong passage. Both bounds are checked now.
    """
    query_embedding = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, k)
    safe_indices = [i for i in indices[0] if 0 <= i < len(token_split_texts)]
    return [token_split_texts[i] for i in safe_indices]
64
 
65
+ # === Prompt zusammenbauen
66
def build_prompt(query, texts):
    """Assemble the German QA prompt from the retrieved context chunks."""
    joined_context = "\n\n".join(texts)
    sections = [
        "Beantworte die folgende Frage basierend auf dem Kontext.",
        "",
        "Kontext:",
        joined_context,
        "",
        "Frage:",
        query,
        "",  # trailing newline, matching the original template
    ]
    return "\n".join(sections)
76
+
77
# === Query OpenAI
def ask_openai(prompt):
    """Send *prompt* to GPT-4 and return the stripped answer text.

    Returns a user-facing error string (not an exception) when no
    OpenAI client is configured.
    """
    if openai_client is None:
        return "❌ Kein OpenAI API Key gefunden"
    messages = [
        {"role": "system", "content": "Du bist ein hilfsbereiter Catan-Regel-Experte."},
        {"role": "user", "content": prompt},
    ]
    response = openai_client.chat.completions.create(model="gpt-4", messages=messages)
    return response.choices[0].message.content.strip()
89
+
90
# === Query Groq
def ask_groq(prompt):
    """Send *prompt* to Llama-3 70B via Groq and return the stripped answer.

    Returns a user-facing error string (not an exception) when no
    Groq client is configured.
    """
    if groq_client is None:
        return "❌ Kein Groq API Key gefunden"
    messages = [
        {"role": "system", "content": "Du bist ein hilfsbereiter Catan-Regel-Experte."},
        {"role": "user", "content": prompt},
    ]
    response = groq_client.chat.completions.create(model="llama3-70b-8192", messages=messages)
    return response.choices[0].message.content.strip()
102
 
103
# === Hauptfunktion für Gradio
def run_qa_pipeline(query, k=5):
    """Retrieve context for *query*, ask the available LLM, format the answer.

    Prefers OpenAI when both clients are configured, falls back to Groq.
    Always returns a user-facing string (errors included) so the Gradio UI
    never surfaces a raw exception.

    Fix: the last line previously contained two `return` statements fused
    onto one line (`return f"❌ Fehler: ..." return f"🔍 Kontext ..."`) — a
    merge artifact that is a syntax error; the dead leftover is removed.
    """
    try:
        retrieved = retrieve(query, k)
        if not retrieved:
            return "⚠️ Keine relevanten Textstellen gefunden."
        prompt = build_prompt(query, retrieved)
        print("📨 Prompt gesendet...")

        if openai_client:
            answer = ask_openai(prompt)
        elif groq_client:
            answer = ask_groq(prompt)
        else:
            return "⚠️ Kein LLM API-Key vorhanden. Bitte OPENAI_API_KEY oder GROQ_API_KEY hinterlegen."

        return f"📌 Frage: {query}\n\n📖 Antwort:\n{answer}"

    except Exception as e:
        # Top-level UI boundary: convert any failure into a displayable message.
        return f"❌ Fehler: {str(e)}"