dewiri committed on
Commit 7384f77 · verified · 1 Parent(s): 4db37a3

Update rag_pipeline.py

Files changed (1)
  1. rag_pipeline.py +39 -49
rag_pipeline.py CHANGED
@@ -1,70 +1,60 @@
- # rag_pipeline.py (debug version with index check & logging)
-
  import os
  import pickle
- import numpy as np
  import faiss
  from sentence_transformers import SentenceTransformer
- from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
- import umap.umap_ as umap
- from dotenv import load_dotenv
- from groq import Groq
- from openai import OpenAI
- import tqdm
-
- print("🚀 RAG app started")
-
- # === Load environment variables (available as Secrets in HF Spaces) ===
- openai_api_key = os.getenv("OPENAI_API_KEY")
- groq_api_key = os.getenv("GROQ_API_KEY")

- groq_client = Groq(api_key=groq_api_key) if groq_api_key else None
- openai_client = OpenAI(api_key=openai_api_key) if openai_api_key else None
-
- # === Load SentenceTransformer model ===
- print("📦 Loading SentenceTransformer model...")
  model = SentenceTransformer("Sahajtomar/German-semantic")
- print("✅ Model loaded")

- # === Load FAISS index and chunk mapping ===
- try:
-     print("📂 Loading FAISS index...")
-     if not os.path.exists("faiss/faiss_index.index"):
-         raise FileNotFoundError("❌ faiss_index.index is missing!")

-     if not os.path.exists("faiss/chunks_mapping.pkl"):
-         raise FileNotFoundError("❌ chunks_mapping.pkl is missing!")

-     index = faiss.read_index("faiss/faiss_index.index")
-     with open("faiss/chunks_mapping.pkl", "rb") as f:
-         token_split_texts = pickle.load(f)

-     chunk_embeddings = model.encode(token_split_texts, convert_to_numpy=True)
-     print("✅ FAISS & embeddings loaded")

-     # Initialize UMAP
-     umap_transform = umap.UMAP(random_state=0, transform_seed=0).fit(chunk_embeddings)
-     print("✅ UMAP fit finished")

- except Exception as e:
-     print(f"❌ Error while loading FAISS or chunks: {e}")
-     index = None
-     token_split_texts = []
-     chunk_embeddings = None
-     umap_transform = None

- def project_embeddings(embeddings, umap_transform):
-     umap_embeddings = np.empty((len(embeddings), 2))
-     for i, embedding in enumerate(tqdm.tqdm(embeddings, desc="Projecting Embeddings")):
-         umap_embeddings[i] = umap_transform.transform([embedding])
-     return umap_embeddings

  def retrieve(query, k=5):
-     if index is None or chunk_embeddings is None:
-         return ["No index available."], [], []
      query_embedding = model.encode([query], convert_to_numpy=True)
      distances, indices = index.search(query_embedding, k)
      retrieved_texts = [token_split_texts[i] for i in indices[0]]
      retrieved_embeddings = np.array([chunk_embeddings[i] for i in indices[0]])
      return retrieved_texts, retrieved_embeddings, distances[0]
 
  import os
  import pickle
+ import requests
  import faiss
+ import numpy as np
  from sentence_transformers import SentenceTransformer

+ # === Load model ===
  model = SentenceTransformer("Sahajtomar/German-semantic")

+ # === Google Drive direct links
+ url_index = "https://drive.google.com/uc?export=download&id=1QBg4vjitJ2xHEyp3Ae8TWJHwEHjbwgOO"
+ url_chunks = "https://drive.google.com/uc?export=download&id=1nsrAm_ozsK4GlmMui9yqZBjmgUfqU2qa"

+ # === Local paths
+ local_index = "faiss_index.index"
+ local_chunks = "chunks_mapping.pkl"

+ # === Download only if missing
+ def download_if_missing(url, local_path):
+     if not os.path.exists(local_path):
+         print(f"⬇️ Downloading {local_path} from Google Drive...")
+         r = requests.get(url)
+         if r.status_code == 200:
+             with open(local_path, "wb") as f:
+                 f.write(r.content)
+             print(f"✅ Downloaded: {local_path}")
+         else:
+             raise Exception(f"❌ Download failed for {local_path}")

+ download_if_missing(url_index, local_index)
+ download_if_missing(url_chunks, local_chunks)

+ # === Load files
+ print("📂 Loading FAISS index and chunks...")
+ index = faiss.read_index(local_index)

+ with open(local_chunks, "rb") as f:
+     token_split_texts = pickle.load(f)

+ chunk_embeddings = model.encode(token_split_texts, convert_to_numpy=True)

  def retrieve(query, k=5):
      query_embedding = model.encode([query], convert_to_numpy=True)
      distances, indices = index.search(query_embedding, k)
      retrieved_texts = [token_split_texts[i] for i in indices[0]]
+     return retrieved_texts
+
+ def build_prompt(query, texts):
+     context = "\n\n".join(texts)
+     return f"Answer the following question based on the context:\n\nContext:\n{context}\n\nQuestion:\n{query}"
+
+ def run_qa_pipeline(query, k=5):
+     retrieved = retrieve(query, k)
+     prompt = build_prompt(query, retrieved)
+     return f"🔍 Context found:\n\n{prompt}\n\n(Optionally insert your LLM answer here)"
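
For context, a minimal sketch of how the placeholder in run_qa_pipeline ("Optionally insert your LLM answer here") could be wired to an actual model call. Nothing below is part of this commit: the Groq client, the model name, and the answer_with_llm helper are assumptions, loosely following the groq/openai imports that the previous revision removed.

# Sketch only (not part of this commit): plugging an LLM call into the pipeline.
# Assumes the committed rag_pipeline.py is importable and GROQ_API_KEY is set;
# the model name "llama3-70b-8192" is an assumption, not taken from the repo.
import os
from groq import Groq
from rag_pipeline import retrieve, build_prompt

groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

def answer_with_llm(query, k=5):
    # Reuse the committed retrieval and prompt-building helpers.
    prompt = build_prompt(query, retrieve(query, k))
    response = groq_client.chat.completions.create(
        model="llama3-70b-8192",  # assumed model name
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content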