import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from huggingface_hub import login # --- CONFIGURATION DE LA PAGE --- st.set_page_config( page_title="Chat Avancé avec Apertus", page_icon="🚀", layout="wide", initial_sidebar_state="expanded", ) # --- STYLES CSS PERSONNALISÉS (Optionnel) --- st.markdown(""" """, unsafe_allow_html=True) # --- BARRE LATÉRALE DE CONFIGURATION --- with st.sidebar: st.title("🚀 Paramètres") st.markdown("Configurez l'assistant et le modèle de langage.") # --- Authentification Hugging Face --- st.subheader("Authentification Hugging Face") hf_token = st.text_input("Votre Token Hugging Face (hf_...)", type="password") if st.button("Se Connecter"): if hf_token: try: login(token=hf_token) st.success("Connecté à Hugging Face Hub !") st.session_state.hf_logged_in = True except Exception as e: st.error(f"Échec de la connexion : {e}") else: st.warning("Veuillez entrer un token Hugging Face.") # --- Sélection du Modèle --- st.subheader("Sélection du Modèle") model_options = { "Apertus 8B (Rapide)": "swiss-ai/Apertus-8B-Instruct-2509", "Apertus 70B (Puissant)": "swiss-ai/Apertus-70B-2509" } selected_model_name = st.selectbox("Choisissez un modèle :", options=list(model_options.keys())) model_id = model_options[selected_model_name] st.caption(f"ID du modèle : `{model_id}`") # --- Paramètres de Génération --- st.subheader("Paramètres de Génération") temperature = st.slider("Température", min_value=0.1, max_value=1.5, value=0.7, step=0.05, help="Plus la valeur est élevée, plus la réponse est créative et aléatoire.") max_new_tokens = st.slider("Tokens Max", min_value=64, max_value=1024, value=256, step=64, help="Longueur maximale de la réponse générée.") top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.95, step=0.05, help="Contrôle la diversité en sélectionnant les mots les plus probables dont la somme des probabilités dépasse ce seuil.") # --- Bouton pour effacer l'historique --- if st.button("🗑️ Effacer l'historique"): st.session_state.messages = [] st.experimental_rerun() # --- CHARGEMENT DU MODÈLE (MIS EN CACHE) --- @st.cache_resource(show_spinner=False) def load_model(model_identifier): """Charge le tokenizer et le modèle avec quantification 4-bit.""" with st.spinner(f"Chargement du modèle '{model_identifier}'... Cela peut prendre un moment. ⏳"): bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) tokenizer = AutoTokenizer.from_pretrained(model_identifier) model = AutoModelForCausalLM.from_pretrained( model_identifier, quantization_config=bnb_config, device_map="auto", ) return tokenizer, model # Charge le modèle sélectionné try: tokenizer, model = load_model(model_id) except Exception as e: st.error(f"Impossible de charger le modèle. Assurez-vous d'être connecté si le modèle est privé. Erreur : {e}") st.stop() # --- INTERFACE DE CHAT PRINCIPALE --- st.title("🤖 Chat avec Apertus") st.caption(f"Vous discutez actuellement avec **{selected_model_name}**.") # Initialisation de l'historique du chat if "messages" not in st.session_state: st.session_state.messages = [] # Affichage des messages de l'historique for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) # Zone de saisie utilisateur if prompt := st.chat_input("Posez votre question à Apertus..."): # Ajout et affichage du message utilisateur st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) # --- GÉNÉRATION DE LA RÉPONSE --- with st.chat_message("assistant"): response_placeholder = st.empty() with st.spinner("Réflexion en cours... 🤔"): # Préparation des entrées pour le modèle # Nous ne formaterons plus le prompt, le modèle instruct est déjà finetuné pour ça. input_ids = tokenizer(prompt, return_tensors="pt").to(model.device) # Génération de la réponse outputs = model.generate( **input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, eos_token_id=tokenizer.eos_token_id ) # Décodage et nettoyage de la réponse response_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # Nettoyage pour retirer la question initiale de la réponse cleaned_response = response_text.replace(prompt, "").strip() response_placeholder.markdown(cleaned_response) # Ajout de la réponse de l'assistant à l'historique st.session_state.messages.append({"role": "assistant", "content": cleaned_response})