# ChatGLM4CS313 / app.py
# Author: kietnt0603 — commit cda859f ("Update app.py", verified)
import streamlit as st
import os
import torch
from datasets import DatasetDict, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, logging
# Set transformers' logging verbosity to ERROR (hide info/warning output).
logging.set_verbosity_error()
# Hugging Face Hub checkpoint to load.
model_name = 'THUDM/chatglm3-6b'
#############################################
# bitsandbytes parameters
#############################################
# Activate 4-bit precision for base model loading
use_4bit = True
# Compute dtype of 4-bit base models (name of a torch dtype attribute)
bnb_4bit_compute_dtype = 'float16'
# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = 'nf4'
# Activate nested (double) quantization for 4-bit base models
use_nested_quant = False
# device mapping
device = torch.device("cpu")  # Set device to CPU
# Place the whole model ("" = root module) on `device`. device_map values are
# device strings or GPU indices; -1 is the `pipeline(device=-1)` convention
# and is not a valid CPU marker here, so derive the string from `device`.
# NOTE(review): 4-bit bitsandbytes loading on CPU depends on the installed
# bitsandbytes/transformers versions — confirm it is supported where deployed.
device_map = {"": str(device)}
# Resolve the dtype name to the torch dtype object (here torch.float16).
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
# Quantization recipe passed to from_pretrained when the model is loaded.
bnb_config = BitsAndBytesConfig(
  bnb_4bit_compute_dtype=compute_dtype,        # dtype used for matmuls on dequantized weights
  bnb_4bit_quant_type=bnb_4bit_quant_type,     # 'nf4' or 'fp4'
  bnb_4bit_use_double_quant=use_nested_quant,  # nested quantization toggle
  load_in_4bit=use_4bit,                       # enable 4-bit weight loading
)
# Print a hint when the GPU could use bfloat16 instead of float16.
# Query CUDA only when it is available: get_device_capability() raises on
# CPU-only hosts, which is exactly the setup this app targets (device = cpu).
if compute_dtype == torch.float16 and use_4bit and torch.cuda.is_available():
  major, _ = torch.cuda.get_device_capability()
  # Compute capability 8.x and newer supports bfloat16.
  if major >= 8:
    print('='*80)
    print('Your GPU supports bfloat16, you can accelerate using the argument --bf16')
    print('='*80)
@st.cache_resource
def _load_model_and_tokenizer():
  """Load the quantized model and its tokenizer once per server process.

  Streamlit re-executes this whole script on every widget interaction;
  caching the loaded objects avoids reloading the checkpoint on each rerun.

  Returns:
    (model, tokenizer) tuple.
  """
  loaded_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map=device_map,
  )
  # NOTE(review): use_cache=False disables the KV cache, which usually slows
  # generation in model.chat; looks carried over from a fine-tuning setup —
  # confirm it is intended for inference.
  loaded_model.config.use_cache = False
  # NOTE(review): pretraining_tp is a Llama config field; presumably a no-op
  # for ChatGLM — confirm.
  loaded_model.config.pretraining_tp = 1
  loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
  loaded_tokenizer.padding_side = 'left'
  return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer()
# Set the title of the Streamlit app
st.title("Chatbot with LangChain and HuggingFace Model")
# Streamlit reruns this script on every interaction, so state that must
# survive between turns is kept in st.session_state, not in plain variables.
if "history" not in st.session_state:
  st.session_state.history = []  # history object passed to/returned by model.chat
if "conversation_history" not in st.session_state:
  st.session_state.conversation_history = []  # transcript lines shown in the UI
# Placeholder for the conversation transcript (rendered above the input box)
conversation_text = st.empty()
# Get the user input
user_input = st.text_input("You: ")
history = st.session_state.history
conversation_history = st.session_state.conversation_history
# If the user has submitted non-empty input
if st.button("Send") and user_input:
  # Generate the chatbot's response, continuing the stored chat history
  response, history = model.chat(tokenizer, user_input, history=history)
  st.session_state.history = history
  # Record both sides of the turn in the displayed transcript
  conversation_history.append(f"You: {user_input}")
  conversation_history.append(f"Bot: {response}")
# Render the whole transcript in one call; repeated .markdown() calls on a
# single st.empty() placeholder would overwrite each other.
if conversation_history:
  transcript_lines = ["**Conversation:**", ""]
  transcript_lines.extend(f"- {message}" for message in conversation_history)
  conversation_text.markdown("\n".join(transcript_lines))