import spaces  # import first so ZeroGPU can patch CUDA initialization
import gradio as gr
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
import torch
import traceback
from accelerate import Accelerator
print(f"Is CUDA available: {torch.cuda.is_available()}") # True
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Define the device correctly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") # Debugging info
# Base model (LLaMA 3.1 8B) from Meta
base_model_name = "meta-llama/Llama-3.1-8B"
# Your fine-tuned LoRA adapter (uploaded to Hugging Face)
lora_model_name = "starnernj/Early-Christian-Church-Fathers-LLaMA-3.1-Fine-Tuned"
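# The original imports pulled in bitsandbytes and BitsAndBytesConfig without
# using them, which suggests 4-bit quantized loading was considered. A minimal
# sketch of that variant, if GPU memory is tight (assumes bitsandbytes is
# installed; it would replace the from_pretrained call below):
#
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )
# model = AutoModelForCausalLM.from_pretrained(
#     base_model_name,
#     quantization_config=bnb_config,
#     device_map="auto",
# )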
# Log in at startup because LLaMA 3.1 8B is a gated model
login(token=os.getenv("HuggingFaceFineGrainedReadToken"))

# Load the tokenizer and model once at startup rather than on every request
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,  # half precision so the 8B model fits in GPU memory
)

# Load LoRA adapter
model = PeftModel.from_pretrained(model, lora_model_name)

# Place the model on the GPU
accelerator = Accelerator()
model = accelerator.prepare(model)

# Function to generate responses; @spaces.GPU requests a GPU for the duration
# of the call, so it belongs on the function that actually runs inference
@spaces.GPU
def chatbot_response(user_input):
    try:
        inputs = tokenizer(user_input, return_tensors="pt").to(device)
        # max_length caps prompt + generated tokens; max_new_tokens would cap
        # only the generated text
        outputs = model.generate(**inputs, max_length=200)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Report the actual exception type rather than assuming AssertionError
        error_message = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
        print(error_message)  # ✅ Logs detailed error messages
        return "An error occurred. Check the logs for details."
# Build the Gradio chat interface
interface = gr.Interface(
fn=chatbot_response,
inputs=gr.Textbox(lines=2, placeholder="Ask me about the Christian Church Fathers..."),
outputs="text",
title="Early Christian Church Fathers Fine-Tuned LLaMA 3.1 8B with LoRA",
description="A chatbot using a fine-tuned LoRA adapter on LLaMA 3.1 8B, tuned on thousands of writings of the early Christian Church Fathers.",
)
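# Optional: enabling Gradio's request queue would serve concurrent users in
# order instead of all at once (a sketch; max_size is an arbitrary choice):
# interface.queue(max_size=20)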
# Launch the Gradio chatbot
interface.launch()