import os
import spaces
import gradio as gr
from huggingface_hub import InferenceClient, login
import time
import traceback
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
import bitsandbytes
import torch
@spaces.GPU  # Forces GPU allocation before execution
def force_gpu_allocation():
    pass  # Dummy function to trigger GPU setup
# Base model (LLaMA 3.1 8B) from Meta
base_model_name = "meta-llama/Llama-3.1-8B"
# Your fine-tuned LoRA adapter (uploaded to Hugging Face)
lora_model_name = "starnernj/Early-Christian-Church-Fathers-LLaMA-3.1-Fine-Tuned"
# Login because LLaMA 3.1 8B is a gated model
login(token=os.getenv("HuggingFaceFineGrainedReadToken"))
# Enable 4-bit Quantization with BitsAndBytes
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # ✅ Enables 4-bit quantization for memory efficiency
    bnb_4bit_compute_dtype=torch.float16,  # ✅ Uses float16 for compute performance
    bnb_4bit_use_double_quant=True,        # ✅ Double quantization for extra memory savings
    bnb_4bit_quant_type="nf4"              # ✅ Normalized Float-4 for better accuracy
)
# Load base model with 4-bit quantization
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quantization_config,
    device_map="auto"
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Load LoRA Adapter
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, lora_model_name)
# Function to generate responses
def chatbot_response(user_input):
    try:
        inputs = tokenizer(user_input, return_tensors="pt").to("cuda")
        outputs = model.generate(**inputs, max_length=200)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        error_message = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_message)  # ✅ Logs detailed error messages
        return "An error occurred. Check the logs for details."
# Launch the Gradio chatbot
interface = gr.Interface(
    fn=chatbot_response,
    inputs=gr.Textbox(lines=2, placeholder="Ask me about the Christian Church Fathers..."),
    outputs="text",
    title="Early Christian Church Fathers Fine-Tuned LLaMA 3.1 8B with LoRA",
    description="A chatbot using a fine-tuned LoRA adapter on LLaMA 3.1 8B, tuned on thousands of writings of the early Christian Church Fathers.",
)
if __name__ == "__main__":
    interface.launch()