#!/usr/bin/env python3
import logging
from typing import List, Optional

import gradio as gr
import numpy as np
import torch  # Required for torch.device below; this import was missing
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configuration
DATA_FILE = "data-mtc.txt"  # This file is no longer used in the Space
DATABASE_DIR = "."          # Database files are in the root directory
CHUNK_SIZE = 800
TOP_K_RESULTS = 100
SIMILARITY_THRESHOLD = 0.1  # FAISS returns L2 distances: lower means more similar

BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :
1. Utilisez EXCLUSIVEMENT le contexte fourni.
2. Structurez la réponse en :
   - Définition principale.
   - Caractéristiques clés (3 points maximum).
   - Relations avec d'autres concepts.
3. Si aucune information pertinente, indiquez-le clairement.

Contexte :
{context}
"""

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler()  # Output to console in the Space
    ],
)

# Embedding Model Integration
device = torch.device("cpu")
embedding_model = SentenceTransformer(
    "Snowflake/snowflake-arctic-embed-l",
    device=device,
    trust_remote_code=True,
)


class HuggingFaceEmbeddings(Embeddings):
    """Embedding management using a Hugging Face SentenceTransformer."""

    def _generate_embedding(self, text: str) -> np.ndarray:
        """Generate an embedding via the Hugging Face model."""
        try:
            return np.array(embedding_model.encode(text.strip()), dtype=np.float32)
        except Exception as e:
            logging.error(f"Embedding error: {e}")
            raise RuntimeError("Failed to generate embedding") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._generate_embedding(text).tolist() for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._generate_embedding(text).tolist()


def initialize_vector_store() -> FAISS:
    """Load the prebuilt FAISS index from disk."""
    embeddings = HuggingFaceEmbeddings()
    try:
        logging.info("Loading existing database...")
        return FAISS.load_local(
            DATABASE_DIR,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    except Exception as e:
        logging.error(f"FAISS loading error: {e}")
        raise


# Lazily loaded singletons, so the LLM and the vector store are initialized
# once instead of being reloaded on every request as in the original code.
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
_llm = None
_tokenizer = None
_vector_store = None


def get_llm():
    """Load the generation model and tokenizer once, then reuse them."""
    global _llm, _tokenizer
    if _llm is None:
        _llm = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, torch_dtype="auto", device_map="auto"
        )
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return _llm, _tokenizer


def get_vector_store() -> FAISS:
    """Load the FAISS index once, then reuse it."""
    global _vector_store
    if _vector_store is None:
        _vector_store = initialize_vector_store()
    return _vector_store


def generate_response(user_input: str, vector_store: FAISS) -> Optional[str]:
    """Retrieve relevant context, then generate a grounded answer."""
    try:
        # Over-fetch, then keep only matches below the distance threshold.
        docs_scores = vector_store.similarity_search_with_score(
            user_input,
            k=TOP_K_RESULTS * 3,
        )
        filtered_docs = [
            (doc, score) for doc, score in docs_scores
            if score < SIMILARITY_THRESHOLD
        ]
        filtered_docs.sort(key=lambda x: x[1])

        if not filtered_docs:
            return ("Aucune correspondance trouvée dans les textes MTC. "
                    "Essayez avec des termes plus spécifiques.")

        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
        context = "\n".join(
            f"=== Source {i + 1} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs)
        )

        model, tokenizer = get_llm()
        # Pass the retrieved context through the system prompt. The original
        # built this prompt but never used it, so the model never saw the context.
        prompt = BASE_SYSTEM_PROMPT.format(context=context)
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_input},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.batch_decode(
            generated_ids[:, model_inputs.input_ids.shape[-1]:],
            skip_special_tokens=True,
        )
        return response[0] if response else (
            "Réponse indisponible - Veuillez reformuler votre question."
        )
    except Exception as e:
        logging.error(f"Generation error: {e}", exc_info=True)
        return "Une erreur s'est produite lors de la génération de la réponse."


def chatbot(query, history):
    """Gradio callback: answer a question and append it to the chat history."""
    history = history or []
    try:
        response = generate_response(query, get_vector_store())
        answer = response or "Aucune réponse générée."
    except Exception as e:
        logging.error(f"Chatbot error: {e}")
        answer = f"Une erreur s'est produite : {e}"
    # gr.Chatbot(type="messages") expects a list of {"role", "content"} dicts;
    # the original returned a bare string, which the component cannot render.
    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": answer})
    return history


# Gradio Interface Setup with Enhanced UI
with gr.Blocks(title="MTC Chatbot") as demo:
    gr.Markdown("# Apprenez-en plus sur le savoir MTC!")
    chatbot_ui = gr.Chatbot(label="MTC Assistant", type="messages")
    input_box = gr.Textbox(
        placeholder="Posez votre question ici...",
        label="Votre question",
    )
    input_box.submit(chatbot, inputs=[input_box, chatbot_ui], outputs=chatbot_ui)

if __name__ == "__main__":
    demo.launch()