File size: 5,689 Bytes
1538b6a
9486a95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3116c2c
 
9486a95
 
 
 
 
 
 
 
 
 
 
 
 
 
3116c2c
9486a95
 
 
 
3116c2c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login

# --- CONFIGURATION DE LA PAGE ---
st.set_page_config(
    page_title="Chat Avancé avec Apertus",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded",
)

# --- STYLES CSS PERSONNALISÉS (Optionnel) ---
st.markdown("""
<style>
    .stSpinner > div > div {
        border-top-color: #f63366;
    }
    .stChatMessage {
        background-color: #f0f2f6;
        border-radius: 10px;
        padding: 15px;
        margin-bottom: 10px;
    }
</style>
""", unsafe_allow_html=True)


# --- BARRE LATÉRALE DE CONFIGURATION ---
with st.sidebar:
    st.title("🚀 Paramètres")
    st.markdown("Configurez l'assistant et le modèle de langage.")

    # --- Authentification Hugging Face ---
    st.subheader("Authentification Hugging Face")
    hf_token = st.text_input("Votre Token Hugging Face (hf_...)", type="password")
    if st.button("Se Connecter"):
        if hf_token:
            try:
                login(token=hf_token)
                st.success("Connecté à Hugging Face Hub !")
                st.session_state.hf_logged_in = True
            except Exception as e:
                st.error(f"Échec de la connexion : {e}")
        else:
            st.warning("Veuillez entrer un token Hugging Face.")

    # --- Sélection du Modèle ---
    st.subheader("Sélection du Modèle")
    model_options = {
        "Apertus 8B (Rapide)": "swiss-ai/Apertus-8B-Instruct-2509",
        "Apertus 70B (Puissant)": "swiss-ai/Apertus-70B-2509"
    }
    selected_model_name = st.selectbox("Choisissez un modèle :", options=list(model_options.keys()))
    model_id = model_options[selected_model_name]
    st.caption(f"ID du modèle : `{model_id}`")


    # --- Paramètres de Génération ---
    st.subheader("Paramètres de Génération")
    temperature = st.slider("Température", min_value=0.1, max_value=1.5, value=0.7, step=0.05,
                              help="Plus la valeur est élevée, plus la réponse est créative et aléatoire.")
    max_new_tokens = st.slider("Tokens Max", min_value=64, max_value=1024, value=256, step=64,
                                 help="Longueur maximale de la réponse générée.")
    top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.95, step=0.05,
                      help="Contrôle la diversité en sélectionnant les mots les plus probables dont la somme des probabilités dépasse ce seuil.")
    
    # --- Bouton pour effacer l'historique ---
    if st.button("🗑️ Effacer l'historique"):
        st.session_state.messages = []
        st.experimental_rerun()


# --- CHARGEMENT DU MODÈLE (MIS EN CACHE) ---
@st.cache_resource(show_spinner=False)
def load_model(model_identifier):
    """Charge le tokenizer et le modèle avec quantification 4-bit."""
    with st.spinner(f"Chargement du modèle '{model_identifier}'... Cela peut prendre un moment. ⏳"):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_identifier)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            quantization_config=bnb_config,
            device_map="auto",
        )
    return tokenizer, model

# Charge le modèle sélectionné
try:
    tokenizer, model = load_model(model_id)
except Exception as e:
    st.error(f"Impossible de charger le modèle. Assurez-vous d'être connecté si le modèle est privé. Erreur : {e}")
    st.stop()


# --- INTERFACE DE CHAT PRINCIPALE ---
st.title("🤖 Chat avec Apertus")
st.caption(f"Vous discutez actuellement avec **{selected_model_name}**.")

# Initialisation de l'historique du chat
if "messages" not in st.session_state:
    st.session_state.messages = []

# Affichage des messages de l'historique
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Zone de saisie utilisateur
if prompt := st.chat_input("Posez votre question à Apertus..."):
    # Ajout et affichage du message utilisateur
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # --- GÉNÉRATION DE LA RÉPONSE ---
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        with st.spinner("Réflexion en cours... 🤔"):
            # Préparation des entrées pour le modèle
            # Nous ne formaterons plus le prompt, le modèle instruct est déjà finetuné pour ça.
            input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Génération de la réponse
            outputs = model.generate(
                **input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id
            )
            
            # Décodage et nettoyage de la réponse
            response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Nettoyage pour retirer la question initiale de la réponse
            cleaned_response = response_text.replace(prompt, "").strip()
            
            response_placeholder.markdown(cleaned_response)
    
    # Ajout de la réponse de l'assistant à l'historique
    st.session_state.messages.append({"role": "assistant", "content": cleaned_response})