|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
from huggingface_hub import login |
|
|
|
|
|
st.set_page_config( |
|
page_title="Chat Avancé avec Apertus", |
|
page_icon="🚀", |
|
layout="wide", |
|
initial_sidebar_state="expanded", |
|
) |
|
|
|
|
|
st.markdown(""" |
|
<style> |
|
.stSpinner > div > div { |
|
border-top-color: #f63366; |
|
} |
|
.stChatMessage { |
|
background-color: #f0f2f6; |
|
border-radius: 10px; |
|
padding: 15px; |
|
margin-bottom: 10px; |
|
} |
|
</style> |
|
""", unsafe_allow_html=True) |
|
|
|
|
|
|
|
with st.sidebar: |
|
st.title("🚀 Paramètres") |
|
st.markdown("Configurez l'assistant et le modèle de langage.") |
|
|
|
|
|
st.subheader("Authentification Hugging Face") |
|
hf_token = st.text_input("Votre Token Hugging Face (hf_...)", type="password") |
|
if st.button("Se Connecter"): |
|
if hf_token: |
|
try: |
|
login(token=hf_token) |
|
st.success("Connecté à Hugging Face Hub !") |
|
st.session_state.hf_logged_in = True |
|
except Exception as e: |
|
st.error(f"Échec de la connexion : {e}") |
|
else: |
|
st.warning("Veuillez entrer un token Hugging Face.") |
|
|
|
|
|
st.subheader("Sélection du Modèle") |
|
model_options = { |
|
"Apertus 8B (Rapide)": "swiss-ai/Apertus-8B-Instruct-2509", |
|
"Apertus 70B (Puissant)": "swiss-ai/Apertus-70B-2509" |
|
} |
|
selected_model_name = st.selectbox("Choisissez un modèle :", options=list(model_options.keys())) |
|
model_id = model_options[selected_model_name] |
|
st.caption(f"ID du modèle : `{model_id}`") |
|
|
|
|
|
|
|
st.subheader("Paramètres de Génération") |
|
temperature = st.slider("Température", min_value=0.1, max_value=1.5, value=0.7, step=0.05, |
|
help="Plus la valeur est élevée, plus la réponse est créative et aléatoire.") |
|
max_new_tokens = st.slider("Tokens Max", min_value=64, max_value=1024, value=256, step=64, |
|
help="Longueur maximale de la réponse générée.") |
|
top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.95, step=0.05, |
|
help="Contrôle la diversité en sélectionnant les mots les plus probables dont la somme des probabilités dépasse ce seuil.") |
|
|
|
|
|
if st.button("🗑️ Effacer l'historique"): |
|
st.session_state.messages = [] |
|
st.experimental_rerun() |
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False) |
|
def load_model(model_identifier): |
|
"""Charge le tokenizer et le modèle avec quantification 4-bit.""" |
|
with st.spinner(f"Chargement du modèle '{model_identifier}'... Cela peut prendre un moment. ⏳"): |
|
bnb_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_quant_type="nf4", |
|
bnb_4bit_compute_dtype=torch.bfloat16, |
|
) |
|
tokenizer = AutoTokenizer.from_pretrained(model_identifier) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_identifier, |
|
quantization_config=bnb_config, |
|
device_map="auto", |
|
) |
|
return tokenizer, model |
|
|
|
|
|
try: |
|
tokenizer, model = load_model(model_id) |
|
except Exception as e: |
|
st.error(f"Impossible de charger le modèle. Assurez-vous d'être connecté si le modèle est privé. Erreur : {e}") |
|
st.stop() |
|
|
|
|
|
|
|
st.title("🤖 Chat avec Apertus") |
|
st.caption(f"Vous discutez actuellement avec **{selected_model_name}**.") |
|
|
|
|
|
if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
|
|
for message in st.session_state.messages: |
|
with st.chat_message(message["role"]): |
|
st.markdown(message["content"]) |
|
|
|
|
|
if prompt := st.chat_input("Posez votre question à Apertus..."): |
|
|
|
st.session_state.messages.append({"role": "user", "content": prompt}) |
|
with st.chat_message("user"): |
|
st.markdown(prompt) |
|
|
|
|
|
with st.chat_message("assistant"): |
|
response_placeholder = st.empty() |
|
with st.spinner("Réflexion en cours... 🤔"): |
|
|
|
|
|
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device) |
|
|
|
|
|
outputs = model.generate( |
|
**input_ids, |
|
max_new_tokens=max_new_tokens, |
|
do_sample=True, |
|
temperature=temperature, |
|
top_p=top_p, |
|
eos_token_id=tokenizer.eos_token_id |
|
) |
|
|
|
|
|
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
cleaned_response = response_text.replace(prompt, "").strip() |
|
|
|
response_placeholder.markdown(cleaned_response) |
|
|
|
|
|
st.session_state.messages.append({"role": "assistant", "content": cleaned_response}) |