Commit · eca6215
Parent(s):
update model
Files changed:
- .gitattributes +38 -0
- .github/workflows/deploy.yml +115 -0
- .gitignore +9 -0
- Dokerfile +29 -0
- Makefile +91 -0
- README.md +45 -0
- Setup.py +76 -0
- User.code-workspace +32 -0
- __init__.py +51 -0
- app.py +116 -0
- pytest.ini +3 -0
- requirements.txt +20 -0
- space.yml +19 -0
- src/__init__.py +6 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/dataset.cpython-312.pyc +0 -0
- src/__pycache__/inference.cpython-312.pyc +0 -0
- src/__pycache__/model.cpython-312.pyc +0 -0
- src/__pycache__/save_model.cpython-312.pyc +0 -0
- src/__pycache__/training.cpython-312.pyc +0 -0
- src/config.py +44 -0
- src/dataset.py +22 -0
- src/fine_tune_llama.py +38 -0
- src/inference.py +51 -0
- src/model.py +32 -0
- src/save_model.py +23 -0
- src/training.py +20 -0
- tests/__init__.py +0 -0
- tests/__pycache__/test_config.cpython-312.pyc +0 -0
- tests/__pycache__/test_dataset.cpython-312.pyc +0 -0
- tests/__pycache__/test_inference.cpython-312.pyc +0 -0
- tests/__pycache__/test_model.cpython-312.pyc +0 -0
- tests/__pycache__/test_save_model.cpython-312.pyc +0 -0
- tests/__pycache__/test_training.cpython-312.pyc +0 -0
- tests/test_config.py +79 -0
- tests/test_dataset.py +43 -0
- tests/test_inference.py +44 -0
- tests/test_model.py +46 -0
- tests/test_save_model.py +49 -0
- tests/test_training.py +64 -0
.gitattributes
ADDED
@@ -0,0 +1,38 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gguf filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
+llama3_medichat filter=lfs diff=lfs merge=lfs -text
.github/workflows/deploy.yml
ADDED
@@ -0,0 +1,115 @@
+name: CI/CD Workflow
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  build-test-deploy:
+    runs-on: ubuntu-latest
+
+    steps:
+      # Checkout repository
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      # Set up Python
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12.3'
+
+      # Install dependencies
+      - name: Install dependencies
+        run: |
+          python3 -m venv .venv
+          . .venv/bin/activate
+          pip install --upgrade pip
+          pip install -r requirements.txt
+
+      - name: Check for GPU Availability
+        id: gpu-check
+        run: |
+          if lspci | grep -i nvidia; then
+            echo "gpu=true" >> $GITHUB_ENV
+          else
+            echo "gpu=false" >> $GITHUB_ENV
+          fi
+
+      # Run tests
+      - name: Run Tests
+        if: env.gpu == 'true'
+        run: |
+          source .venv/bin/activate
+          pytest --maxfail=5 --disable-warnings
+
+      - name: Skip Tests (No GPU)
+        if: env.gpu == 'false'
+        run: |
+          echo "Skipping GPU-dependent tests: No GPU available."
+
+  sync-to-hub:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      - name: Set Git user identity
+        run: |
+          git config --global user.name "Hussein El Amouri"
+          git config --global user.email "[email protected]"
+
+      # - name: Set up Git LFS
+      #   run: |
+      #     git lfs install # Ensure Git LFS is installed and set up
+
+      # - name: Track large files with Git LFS
+      #   run: |
+      #     # Track specific large files that exceed the 10 MB limit
+      #     git lfs track "*.gguf" # Add GGUF model to LFS
+      #     git lfs track "*.safetensors" # Add safetensors model to LFS
+      #     git lfs track "*.pt" # Add optimizer checkpoint to LFS
+      #     git lfs track "*.json" # Add tokenizer to LFS
+
+      #     # Add .gitattributes file to the staging area for Git LFS tracking
+      #     git add .gitattributes
+
+      - name: Push to Hugging Face
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+
+          # git lfs ls-files
+          # git lfs fetch --all
+          # git lfs pull
+          # git rev-parse --is-shallow-repository
+          git filter-branch -- --all
+          git push https://helamouri:[email protected]/spaces/helamouri/medichat_assignment main --force # Push to Hugging Face
+
+      # - name: Set up Hugging Face CLI
+      #   run: |
+      #     pip install huggingface_hub
+
+      # - name: Login to Hugging Face
+      #   env:
+      #     HF_TOKEN: ${{ secrets.HF_TOKEN }}
+      #   run: |
+      #     huggingface-cli login --token $HF_TOKEN
+
+      # - name: Sync with Hugging Face (including large files)
+      #   env:
+      #     HF_TOKEN: ${{ secrets.HF_TOKEN }}
+      #   run: |
+      #     # Initialize git-lfs
+      #     git lfs install
+
+      #     # Pull any LFS-tracked files (if needed)
+      #     git lfs pull
+
+      #     # Push the repository to Hugging Face
+      #     huggingface-cli upload spaces/helamouri/medichat_assignment ./* ./medichat_assignment
.gitignore
ADDED
@@ -0,0 +1,9 @@
+# Ignore large files that are tracked by Git LFS
+*.log
+
+# Ignore build directories (e.g., for Python, Java, etc.)
+env/
+
+
+# Ensure .gitattributes is not ignored (needed for Git LFS tracking)
+!.gitattributes
Dokerfile
ADDED
@@ -0,0 +1,29 @@
+# Use the official NVIDIA CUDA image as a base (you can adjust the CUDA version if needed)
+FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
+
+# Set environment variable to avoid interactive prompts during package installation
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install Python 3.9 and dependencies
+RUN apt-get update && \
+    apt-get install -y python3.9 python3.9-dev python3.9-venv python3.9-distutils curl && \
+    ln -s /usr/bin/python3.9 /usr/bin/python && \
+    curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
+    python get-pip.py && \
+    rm get-pip.py
+
+# Set the working directory
+WORKDIR /src
+
+# Copy the requirements and application files
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy all source code
+COPY . .
+
+# Expose the default Streamlit port
+EXPOSE 8501
+
+# Run the Streamlit app
+CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
Makefile
ADDED
@@ -0,0 +1,91 @@
+SHELL := /bin/bash
+# Makefile for Llama3.1:8B Project
+
+# Variables
+PYTHON = python
+PIP = pip
+VENV_DIR = ./env
+VENV_PYTHON = $(VENV_DIR)/bin/python
+VENV_PIP = $(VENV_DIR)/bin/pip
+REQUIREMENTS = requirements.txt
+
+# Default target
+.DEFAULT_GOAL := help
+
+# Help target
+help:
+	@echo "Makefile for Llama3.1:8B Project"
+	@echo ""
+	@echo "Targets:"
+	@echo " help - Show this help message"
+	@echo " setup - Create virtual environment and install dependencies"
+	@echo " run - Run the main application"
+	@echo " test - Run unit tests"
+	@echo " lint - Run linters"
+	@echo " clean - Remove temporary files and directories"
+	@echo " clean-venv - Remove virtual environment"
+	@echo " purge - Clean and reinstall everything"
+	@echo " install - Install or update dependencies"
+
+# Check for Python and pip
+check-deps:
+	@echo "Checking for Python and pip..."
+	@if ! command -v $(PYTHON) >/dev/null 2>&1; then \
+		echo "Python is not installed. Please install Python3."; \
+		exit 1; \
+	fi
+	@echo "Python is installed."
+	@if ! command -v $(PIP) >/dev/null 2>&1; then \
+		echo "pip is not installed. Installing pip..."; \
+		sudo apt update && sudo apt install -y python3-pip; \
+	fi
+	@echo "pip is installed."
+
+# Create virtual environment and install dependencies
+setup: check-deps
+	@echo "Setting up virtual environment..."
+	@if [ ! -d "$(VENV_DIR)" ]; then \
+		$(PYTHON) -m venv $(VENV_DIR); \
+		echo "Virtual environment created."; \
+	fi
+	@echo "Installing dependencies..."
+	$(VENV_PIP) install --upgrade pip
+	$(VENV_PIP) install -r $(REQUIREMENTS)
+	@echo "Setup completed."
+
+# Run the main application
+run:
+	@echo "Running the application..."
+	$(VENV_PYTHON) main.py
+
+# Run tests
+test:
+	@echo "Running tests..."
+	$(VENV_PYTHON) -m unittest discover tests
+
+# Run linters
+lint:
+	@echo "Running linters..."
+	$(VENV_PYTHON) -m flake8 src/ tests/
+
+# Clean temporary files and directories
+clean:
+	@echo "Cleaning temporary files and directories..."
+	find . -type f -name '*.pyc' -delete
+	find . -type d -name '__pycache__' -exec rm -r {} +
+	@echo "Cleanup completed."
+
+# Clean virtual environment
+clean-venv:
+	@echo "Removing virtual environment..."
+	rm -rf $(VENV_DIR)
+	@echo "Virtual environment removed."
+
+# Purge: remove all and reinstall environment
+purge: clean clean-venv setup
+
+# Install or update dependencies
+install:
+	@echo "Installing or updating dependencies..."
+	$(VENV_PIP) install -r $(REQUIREMENTS)
+	@echo "Dependencies installed or updated."
README.md
ADDED
@@ -0,0 +1,45 @@
+---
+title: MediChat
+emoji: 🩺
+colorFrom: blue
+colorTo: yellow
+sdk: streamlit
+sdk_version: "1.40.1" # Replace with the actual version of your SDK
+app_file: app.py # Replace with the main app file name
+pinned: false
+---
+
+[CI/CD Workflow](https://github.com/hussein88al88amouri/medichat_assignment/actions/workflows/deploy.yml)
+
+# MediChat: AI-Powered Medical Consultation Assistant
+
+MediChat is an intelligent chatbot designed to provide medical consultations using a fine-tuned Llama3.1:8B model. The project bridges advanced AI capabilities with practical healthcare assistance.
+
+## Features
+- Fine-tuned model for medical conversations
+- Interactive and user-friendly interface
+- Secure and containerized deployment
+
+## How to Use
+1. Access the chatbot interface.
+2. Input your medical query.
+3. Receive intelligent and context-aware responses.
+
+## Technical Details
+- Model: Llama3.1:8B
+- Framework: Streamlit
+
+## Installation
+1. Clone the repository:
+   ```bash
+   git clone https://github.com/your_username/medichat.git
+   cd medichat
+   ```
+2. Build and run the Docker container:
+   ```bash
+   docker build -t medichat-app .
+   docker run -p 8501:8501 medichat-app
+   ```
+3. Access the app at http://localhost:8501.
+
+## Limitations
+This tool is not a replacement for professional medical advice.
+For critical issues, always consult a licensed medical professional.
Setup.py
ADDED
@@ -0,0 +1,76 @@
+from setuptools import setup, find_packages
+from pathlib import Path
+
+# Read the requirements from the requirements.txt file
+def parse_requirements():
+    requirements_path = Path(__file__).parent / 'requirements.txt'
+    with open(requirements_path, 'r') as file:
+        return [line.strip() for line in file if line.strip() and not line.startswith('#')]
+
+setup(
+    # The name of your package.
+    name='medichat',
+
+    # A version number for your package.
+    version='0.1.0',
+
+    # A brief summary of what your package does.
+    description='A fine-tuned LLM for medical consultations based on the Meta-Llama 3.1 8B model.',
+
+    # The URL of your project's homepage.
+    url='https://github.com/hussein88al88amouri/medichat',
+
+    # The author’s name.
+    author='Hussein El Amouri',
+
+    # The author’s email address.
+    author_email='[email protected]',
+
+    # This defines which packages should be included in the distribution.
+    packages=find_packages(),
+
+    # Read dependencies from the requirements.txt
+    install_requires=parse_requirements(),
+
+    # Additional classification of your package.
+    classifiers=[
+        'Development Status :: 3 - Alpha',
+        'Intended Audience :: Developers',
+        'License :: OSI Approved :: MIT License',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
+    ],
+
+    # A license for your package.
+    license='MIT',
+
+    # You can add entry points for command-line tools if your package includes such functionality.
+    entry_points={
+        'console_scripts': [
+            'medichat=medichat.cli:main',  # Adjust to your actual CLI entry point, if any
+        ],
+    },
+
+    # If you have data files (like configuration files), you can specify them here.
+    data_files=[
+        # Example of configuration files for saving the model, etc.
+        ('share/config', ['config/config.json']),
+    ],
+
+    # If your package has specific testing requirements or needs test dependencies, list them here.
+    extras_require={
+        'dev': ['pytest', 'tox'],  # Optional dependencies for development or testing
+        'docs': ['sphinx'],  # Optional dependencies for documentation generation
+    },
+
+    # Specify your package's minimum supported Python version
+    python_requires='>=3.8',
+
+    # If your package includes command-line scripts, you can list them here
+    scripts=['scripts/cli_script.py'],  # Update path if you have a script to run
+
+    # If your package includes C extensions or other modules, specify them here.
+    ext_modules=[],
+)
User.code-workspace
ADDED
@@ -0,0 +1,32 @@
+{
+    "folders": [
+        {
+            "path": "C:/Users/hasso/AppData/Roaming/Code/User"
+        },
+        {
+            "path": ".."
+        }
+    ],
+    "settings": {
+        "workbench.colorCustomizations": {
+            "activityBar.activeBackground": "#5b5b5b",
+            "activityBar.background": "#5b5b5b",
+            "activityBar.foreground": "#e7e7e7",
+            "activityBar.inactiveForeground": "#e7e7e799",
+            "activityBarBadge.background": "#103010",
+            "activityBarBadge.foreground": "#e7e7e7",
+            "commandCenter.border": "#e7e7e799",
+            "sash.hoverBorder": "#5b5b5b",
+            "statusBar.background": "#424242",
+            "statusBar.foreground": "#e7e7e7",
+            "statusBarItem.hoverBackground": "#5b5b5b",
+            "statusBarItem.remoteBackground": "#424242",
+            "statusBarItem.remoteForeground": "#e7e7e7",
+            "titleBar.activeBackground": "#424242",
+            "titleBar.activeForeground": "#e7e7e7",
+            "titleBar.inactiveBackground": "#42424299",
+            "titleBar.inactiveForeground": "#e7e7e799"
+        },
+        "peacock.color": "#424242"
+    }
+}
__init__.py
ADDED
@@ -0,0 +1,51 @@
+# __init__.py
+
+# Import necessary modules or functions from submodules.
+# This is where you can aggregate the public API of your package.
+
+# Example:
+# from .module_name import function_name, ClassName
+
+# You can also import specific components to make them available directly from the package level.
+# For example:
+# from .subpackage.module_name import function_name
+from src import llama3_finetune
+from src import main
+
+# Initialize any package-level variables or constants
+# For example, if you have any version number or author info, you can define it here.
+
+__version__ = "0.1.0"  # Replace with your actual package version
+__author__ = "Hussein El Amouri"  # Replace with your name or the author name
+
+# You can include initialization code here, if your package requires any.
+# For example, setting up logging, initializing global variables, etc.
+
+# Example:
+# import logging
+# logging.basicConfig(level=logging.INFO)
+
+# Define a list of publicly exposed items (optional)
+# This list is used to specify which functions, classes, or variables
+# should be available when `from package_name import *` is used.
+
+# Example:
+__all__ = [
+    'function_name',  # List the names of the functions, classes, or variables you want exposed
+    'ClassName',
+]
+
+# If your package uses a specific function or submodule as the primary entry point,
+# you can set that here.
+# For example, if the main function of the package is in a submodule called 'main.py',
+# you can import that here:
+# from .main import run
+
+# Initialize any necessary package-specific code here, if needed
+# Example for adding environment setup, database initialization, etc.
+
+# If your package contains a command-line interface (CLI), you can import it here,
+# so it can be executed as a script if the package is installed:
+# from .cli import main
+
+# Any other necessary imports that users should be aware of can go here.
app.py
ADDED
@@ -0,0 +1,116 @@
+import streamlit as st
+# from unsloth import FastLanguageModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+from llama_cpp import Llama
+from huggingface_hub import hf_hub_download
+import os
+import sys
+
+# # Suppress unwanted outputs (e.g., from unsloth or other libraries)
+# def suppress_output():
+#     sys.stdout = open(os.devnull, 'w')  # Redirect stdout to devnull
+#     sys.stderr = open(os.devnull, 'w')  # Redirect stderr to devnull
+
+# def restore_output():
+#     sys.stdout = sys.__stdout__  # Restore stdout
+#     sys.stderr = sys.__stderr__  # Restore stderr
+
+# Load the model (GGUF format)
+@st.cache_resource
+def load_model():
+    # Define the repository and model filenames for both the base model and LoRA adapter
+    base_model_repo = "helamouri/Meta-Llama-3.1-8B-Q8_0.gguf"
+    base_model_filename = "Meta-Llama-3.1-8B-Q8_0.gguf"
+    adapter_repo = "helamouri/medichat_assignment"
+    adapter_filename = "llama3_medichat.gguf"  # adapter filename used by the download call below
+    adapter_repo = "helamouri/model_medichat_finetuned_v1"
+
+    # Download the base model and adapter model to local paths
+    base_model_path = hf_hub_download(repo_id=base_model_repo, filename=base_model_filename)
+    adapter_model_path = hf_hub_download(repo_id=adapter_repo, filename=adapter_filename)
+
+    # Log paths for debugging
+    print(f"Base model path: {base_model_path}")
+    print(f"Adapter model path: {adapter_model_path}")
+
+    # Load the full model (base model) and the adapter (LoRA)
+    try:
+        model = Llama(model_path=base_model_path)  # , adapter_path=adapter_model_path)
+        print("Model loaded successfully.")
+    except ValueError as e:
+        print(f"Error loading model: {e}")
+        raise
+
+    return model
+
+# Generate a response using Llama.cpp
+def generate_response(model, prompt):
+    print('prompt')
+    print(prompt)
+    response = model(
+        prompt,
+        max_tokens=200,   # Maximum tokens for the response
+        temperature=0.7,  # Adjust for creativity (lower = deterministic)
+        top_p=0.9,        # Nucleus sampling
+        stop=["\n"]       # Stop generating when newline is encountered
+    )
+    print('response["choices"]')
+    print(response["choices"])
+    return response["choices"][0]["text"]
+
+# Load the model and tokenizer (GGUF format)
+# @st.cache_resource
+# def load_model():
+#     model_name = "helamouri/model_medichat_finetuned_v1"  # Replace with your model's GGUF path
+#     model = FastLanguageModel.from_pretrained(model_name, device='cpu')  # Load the model using unsloth
+#     tokenizer = model.tokenizer  # Assuming the tokenizer is part of the GGUF model object
+#     return tokenizer, model
+
+
+# @st.cache_resource
+# def load_model():
+#     model_name = "helamouri/model_medichat_finetuned_v1"  # Replace with your model's path
+#     # Load the tokenizer
+#     tokenizer = AutoTokenizer.from_pretrained(model_name)
+#     # Load the model (if it's a causal language model or suitable model type)
+#     model = AutoModelForCausalLM.from_pretrained(model_name,
+#                                                  device_map="cpu",
+#                                                  revision="main",
+#                                                  quantize=False,
+#                                                  load_in_8bit=False,
+#                                                  load_in_4bit=False,
+#                                                  # torch_dtype=torch.float32
+#                                                  )
+#     return tokenizer, model
+
+# Suppress unwanted outputs from unsloth or any other libraries during model loading
+# suppress_output()
+
+# Load the GGUF model
+print('Loading the model')
+model = load_model()
+# Restore stdout and stderr
+
+# restore_output()
+
+# App layout
+print('Setting App layout')
+st.title("MediChat: Your AI Medical Consultation Assistant")
+st.markdown("Ask me anything about your health!")
+st.write("Enter your symptoms or medical questions below:")
+
+# User input
+print('Setting user interface')
+user_input = st.text_input("Your Question:")
+if st.button("Get Response"):
+    if user_input:
+        with st.spinner("Generating response..."):
+            # Generate Response
+            response = generate_response(model, user_input)
+            print('Response')
+            print(response)
+            # Display response
+            st.text_area("Response:", value=response, height=200)
+    else:
+        st.warning("Please enter a question.")
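The Streamlit app above wraps a single llama-cpp-python call. For reference, here is a minimal standalone sketch of the same call pattern used in `generate_response`; the local GGUF path and the prompt text are illustrative assumptions, not files shipped with this commit:

```python
# Minimal sketch of the llama_cpp usage in app.py, outside Streamlit.
# Assumes a local GGUF file; the path below is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="./Meta-Llama-3.1-8B-Q8_0.gguf")  # placeholder path

prompt = (
    "###Instruction:\n"
    "Answer the patient's question briefly.\n\n"
    "###Input:\n"
    "I have had a mild headache since this morning.\n\n"
    "###Response:\n"
)
result = llm(
    prompt,
    max_tokens=200,   # same generation settings as generate_response()
    temperature=0.7,
    top_p=0.9,
    stop=["\n"],      # stops at the first newline, as in app.py
)
print(result["choices"][0]["text"])
```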
pytest.ini
ADDED
@@ -0,0 +1,3 @@
+[pytest]
+markers =
+    gpu: marks tests that require a GPU
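The `gpu` marker declared here is what the test modules use to guard GPU-only tests. A hedged helper sketch (not part of this commit) that deselects those tests programmatically, equivalent to running `pytest -m "not gpu"` with the same flags the CI workflow passes:

```python
# run_tests.py — optional helper, assumed name; not included in the repository.
# Runs the suite while skipping tests marked with @pytest.mark.gpu.
import sys
import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-m", "not gpu", "--maxfail=5", "--disable-warnings"]))
```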
requirements.txt
ADDED
@@ -0,0 +1,20 @@
+# Install the multi-backend version of bitsandbytes
+https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl
+llama-cpp-python
+datasets
+huggingface_hub
+huggingface_hub[cli]
+unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
+xformers==0.0.28.post2
+trl
+peft
+accelerate
+bitsandbytes
+torchvision
+torch
+sentencepiece
+transformers[torch]>=4.45.1
+streamlit==1.40.1
+gguf>=0.10.0
+pytest
+flake8
space.yml
ADDED
@@ -0,0 +1,19 @@
+# docker-compose.yaml for Hugging Face Spaces
+version: '3.8'
+
+services:
+  app:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    ports:
+      - "7860:7860"  # Default port for Streamlit or Gradio apps
+    environment:
+      HF_TOKEN: ${HF_TOKEN}  # Hugging Face API token
+    command: >
+      bash -c "
+      python3 -m venv .venv &&
+      . .venv/bin/activate &&
+      pip install -r requirements.txt &&
+      python main.py
+      "
src/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .config import *
+from .model import load_model, configure_peft_model
+from .dataset import load_and_prepare_dataset, formatting_prompts_func
+from .training import train_model
+from .inference import prepare_inference_inputs, generate_responses, stream_responses
+from .save_model import save_model_and_tokenizer
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (601 Bytes).
src/__pycache__/config.cpython-312.pyc
ADDED
Binary file (1.16 kB).
src/__pycache__/dataset.cpython-312.pyc
ADDED
Binary file (1.37 kB).
src/__pycache__/inference.cpython-312.pyc
ADDED
Binary file (2.51 kB).
src/__pycache__/model.cpython-312.pyc
ADDED
Binary file (1.41 kB).
src/__pycache__/save_model.cpython-312.pyc
ADDED
Binary file (1.11 kB).
src/__pycache__/training.cpython-312.pyc
ADDED
Binary file (743 Bytes).
src/config.py
ADDED
@@ -0,0 +1,44 @@
+import torch
+
+
+def get_device_map():
+    if torch.cuda.is_available():
+        return {'': torch.cuda.current_device()}
+    else:
+        return {}  # Or some default, fallback configuration
+
+# General configuration
+MAX_SEQ_LENGTH = 2**4
+DTYPE = None
+LOAD_IN_4BIT = True
+DEVICE_MAP = get_device_map()  # already a dict mapping '' to the current CUDA device id (empty on CPU)
+EOS_TOKEN = None  # Set dynamically based on tokenizer
+
+# Alpaca prompt template
+ALPACA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+###Instruction:
+{}
+
+###Input:
+{}
+
+###Response:
+{}"""
+
+# Training arguments
+TRAIN_ARGS = {
+    "per_device_train_batch_size": 2,
+    "gradient_accumulation_steps": 4,
+    "warmup_steps": 5,
+    "max_steps": 60,
+    "learning_rate": 2e-4,
+    "fp16": not torch.cuda.is_bf16_supported(),
+    "bf16": torch.cuda.is_bf16_supported(),
+    "logging_steps": 1,
+    "optim": "adamw_8bit",
+    "weight_decay": 0.01,
+    "lr_scheduler_type": "linear",
+    "seed": 3407,
+    "output_dir": "outputs",
+}
src/dataset.py
ADDED
@@ -0,0 +1,22 @@
+from datasets import load_dataset
+
+def formatting_prompts_func(examples, template, eos_token):
+    instructions = examples["instruction"]
+    inputs = examples["input"]
+    outputs = examples["output"]
+
+    # Format the examples using the provided template
+    texts = []
+    for instruction, input_text, output in zip(instructions, inputs, outputs):
+        text = template.format(instruction, input_text, output) + eos_token
+        texts.append(text)
+
+    # Return a dictionary with the formatted text
+    return {"text": texts}
+
+def load_and_prepare_dataset(dataset_name, nsamples, formatting_func, template, eos_token):
+    # Load the dataset and prepare it by applying the formatting function
+    dataset = load_dataset(dataset_name, split="train").select(range(nsamples))
+
+    # Map the formatting function over the dataset
+    return dataset.map(lambda examples: formatting_func(examples, template, eos_token), batched=True)
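Taken together with the Alpaca template in `src/config.py`, `formatting_prompts_func` turns each instruction/input/output triple into one training string. A small illustrative sketch follows; the example texts and the `</s>` EOS token are assumptions, since the real EOS token is taken from the tokenizer at runtime:

```python
# Illustrative only: shows the string produced for one example.
from src.config import ALPACA_PROMPT_TEMPLATE
from src.dataset import formatting_prompts_func

batch = {
    "instruction": ["Answer the patient's question."],
    "input": ["I have had a mild cough for three days."],
    "output": ["Stay hydrated, rest, and see a doctor if it persists."],
}
formatted = formatting_prompts_func(batch, ALPACA_PROMPT_TEMPLATE, eos_token="</s>")
print(formatted["text"][0])  # Alpaca-style prompt, response, then the EOS token
```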
src/fine_tune_llama.py
ADDED
@@ -0,0 +1,38 @@
+from src import *
+
+# Load configuration
+max_seq_length = config.MAX_SEQ_LENGTH
+device_map = config.DEVICE_MAP
+eos_token = config.EOS_TOKEN
+
+# Load and configure model
+model_name = "unsloth/Meta-Llama-3.1-8B"
+model, tokenizer = load_model(model_name, max_seq_length, config.DTYPE, config.LOAD_IN_4BIT, device_map)
+eos_token = tokenizer.eos_token
+
+model = configure_peft_model(model, target_modules=["q_proj", "down_proj"])
+
+# Prepare dataset
+nsamples = 1000
+dataset = load_and_prepare_dataset(
+    "lavita/ChatDoctor-HealthCareMagic-100k",
+    nsamples,
+    formatting_prompts_func,
+    config.ALPACA_PROMPT_TEMPLATE,
+    eos_token,
+)
+
+# Train model
+trainer_stats = train_model(
+    model=model,
+    tokenizer=tokenizer,
+    train_dataset=dataset,
+    dataset_text_field="text",
+    max_seq_length=max_seq_length,
+    dataset_num_proc=2,
+    packing=False,
+    training_args=config.TRAIN_ARGS,
+)
+
+# Save the model
+save_model_and_tokenizer(model, tokenizer, "./llama3_medichat")
src/inference.py
ADDED
@@ -0,0 +1,51 @@
+from transformers import TextStreamer
+
+def prepare_inference_inputs(tokenizer, template, instruction, input_text, eos_token, device="cuda"):
+    """
+    Prepares the inputs for inference by formatting the prompt and tokenizing it.
+
+    Args:
+    - tokenizer: The tokenizer used for tokenization.
+    - template: The template string for the prompt format.
+    - instruction: The instruction to be included in the prompt.
+    - input_text: The input to be included in the prompt.
+    - eos_token: The end of sequence token.
+    - device: The device for the model ('cuda' or 'cpu').
+
+    Returns:
+    - Tokenized inputs ready for inference.
+    """
+    prompt = template.format(instruction, input_text, "") + eos_token
+    return tokenizer([prompt], return_tensors="pt").to(device)
+
+def generate_responses(model, inputs, tokenizer, max_new_tokens=64):
+    """
+    Generates responses from the model based on the provided inputs.
+
+    Args:
+    - model: The pre-trained model for generation.
+    - inputs: The tokenized inputs to generate responses.
+    - tokenizer: The tokenizer used to decode the output.
+    - max_new_tokens: The maximum number of tokens to generate.
+
+    Returns:
+    - Decoded responses from the model.
+    """
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
+    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+def stream_responses(model, inputs, tokenizer, max_new_tokens=128):
+    """
+    Streams the model's response using a text streamer.
+
+    Args:
+    - model: The pre-trained model for generation.
+    - inputs: The tokenized inputs to generate responses.
+    - tokenizer: The tokenizer used to decode the output.
+    - max_new_tokens: The maximum number of tokens to generate.
+
+    Returns:
+    - Streams the output directly.
+    """
+    text_streamer = TextStreamer(tokenizer)
+    model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_new_tokens)
src/model.py
ADDED
@@ -0,0 +1,32 @@
+import torch
+from unsloth import FastLanguageModel
+
+def load_model(model_name, max_seq_length, dtype, load_in_4bit, device_map):
+    try:
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=model_name,
+            max_seq_length=max_seq_length,
+            dtype=dtype,
+            load_in_4bit=load_in_4bit,
+            device_map=device_map,
+        )
+        return model, tokenizer
+    except Exception as e:
+        raise RuntimeError(f"Failed to load model {model_name}: {e}")
+
+def configure_peft_model(model, target_modules, lora_alpha=16, lora_dropout=0, random_state=3407, use_rslora=False):
+    try:
+        peft_model = FastLanguageModel.get_peft_model(
+            model=model,
+            target_modules=target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            bias="none",
+            use_gradient_checkpointing="unsloth",
+            random_state=random_state,
+            use_rslora=use_rslora,
+            loftq_config=None,
+        )
+        return peft_model
+    except Exception as e:
+        raise RuntimeError(f"Failed to configure PEFT model: {e}")
src/save_model.py
ADDED
@@ -0,0 +1,23 @@
+import os
+
+def save_model_and_tokenizer(model, tokenizer, save_directory):
+    """
+    Save model and tokenizer to the specified directory.
+
+    Args:
+    - model: The model to save.
+    - tokenizer: The tokenizer to save.
+    - save_directory: Directory where the model and tokenizer should be saved.
+    """
+    try:
+        # Ensure the save directory exists
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save model and tokenizer
+        model.save_pretrained(save_directory, safe_serialization=True)
+        tokenizer.save_pretrained(save_directory)
+
+        print(f"Model and tokenizer saved locally at {save_directory}")
+    except Exception as e:
+        print(f"Error saving model and tokenizer: {str(e)}")
+        raise
src/training.py
ADDED
@@ -0,0 +1,20 @@
+from trl import SFTTrainer
+from transformers import TrainingArguments
+
+def train_model(model, tokenizer, train_dataset, dataset_text_field, max_seq_length, dataset_num_proc, packing, training_args):
+    trainer = SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        train_dataset=train_dataset,
+        dataset_text_field=dataset_text_field,
+        max_seq_length=max_seq_length,
+        dataset_num_proc=dataset_num_proc,
+        packing=packing,
+        args=TrainingArguments(**training_args),
+    )
+
+    # Train the model
+    train_results = trainer.train()
+
+    # Optionally, you can return more specific training information if necessary
+    return train_results
tests/__init__.py
ADDED
File without changes
tests/__pycache__/test_config.cpython-312.pyc
ADDED
Binary file (4.25 kB).
tests/__pycache__/test_dataset.cpython-312.pyc
ADDED
Binary file (1.66 kB).
tests/__pycache__/test_inference.cpython-312.pyc
ADDED
Binary file (1.86 kB).
tests/__pycache__/test_model.cpython-312.pyc
ADDED
Binary file (1.83 kB).
tests/__pycache__/test_save_model.cpython-312.pyc
ADDED
Binary file (2.89 kB).
tests/__pycache__/test_training.cpython-312.pyc
ADDED
Binary file (1.86 kB).
tests/test_config.py
ADDED
@@ -0,0 +1,79 @@
+import pytest
+import torch
+from src.config import (MAX_SEQ_LENGTH, DTYPE, LOAD_IN_4BIT, DEVICE_MAP, EOS_TOKEN,
+                        ALPACA_PROMPT_TEMPLATE, TRAIN_ARGS)
+
+# Test that required configuration keys are present
+def test_required_config_keys():
+    assert MAX_SEQ_LENGTH is not None, "MAX_SEQ_LENGTH is not set."
+    assert TRAIN_ARGS is not None, "TRAIN_ARGS is not set."
+    assert ALPACA_PROMPT_TEMPLATE is not None, "ALPACA_PROMPT_TEMPLATE is not set."
+    assert DEVICE_MAP is not None, "DEVICE_MAP is not set."
+
+# Test that MAX_SEQ_LENGTH is a power of two
+def test_max_seq_length():
+    assert isinstance(MAX_SEQ_LENGTH, int), "MAX_SEQ_LENGTH should be an integer."
+    assert MAX_SEQ_LENGTH > 0, "MAX_SEQ_LENGTH should be greater than 0."
+    assert (MAX_SEQ_LENGTH & (MAX_SEQ_LENGTH - 1)) == 0, "MAX_SEQ_LENGTH should be a power of two."
+
+# Test that TRAIN_ARGS dictionary contains required fields and types
+def test_train_args():
+    required_keys = [
+        "per_device_train_batch_size",
+        "gradient_accumulation_steps",
+        "warmup_steps",
+        "max_steps",
+        "learning_rate",
+        "fp16",
+        "bf16",
+        "logging_steps",
+        "optim",
+        "weight_decay",
+        "lr_scheduler_type",
+        "seed",
+        "output_dir"
+    ]
+
+    for key in required_keys:
+        assert key in TRAIN_ARGS, f"Missing {key} in TRAIN_ARGS."
+
+    # Check types of specific fields
+    assert isinstance(TRAIN_ARGS["per_device_train_batch_size"], int), "per_device_train_batch_size should be an integer."
+    assert isinstance(TRAIN_ARGS["learning_rate"], float), "learning_rate should be a float."
+    assert isinstance(TRAIN_ARGS["output_dir"], str), "output_dir should be a string."
+
+# Test that the DEVICE_MAP references a valid CUDA device
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_device_map():
+    device = DEVICE_MAP.get('', None)
+    assert device is not None, "DEVICE_MAP should reference a CUDA device."
+    assert isinstance(device, int), "DEVICE_MAP should be an integer (CUDA device ID)."
+    assert torch.cuda.is_available(), "CUDA is not available, but DEVICE_MAP points to a CUDA device."
+
+# Test that the EOS_TOKEN is set dynamically based on the tokenizer
+def test_eos_token():
+    assert EOS_TOKEN is not None, "EOS_TOKEN should be dynamically set based on tokenizer."
+
+# Test the ALPACA_PROMPT_TEMPLATE for expected formatting
+def test_alpaca_prompt_template():
+    test_instruction = "Test Instruction"
+    test_input = "Test Input"
+    test_output = "Test Output"
+
+    formatted_prompt = ALPACA_PROMPT_TEMPLATE.format(test_instruction, test_input, test_output)
+
+    # Ensure that the prompt template contains the required placeholders
+    assert "{}" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain placeholders."
+    assert "###Instruction:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Instruction'."
+    assert "###Input:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Input'."
+    assert "###Response:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Response'."
+
+# Test that the LOAD_IN_4BIT setting is a boolean
+def test_load_in_4bit():
+    assert isinstance(LOAD_IN_4BIT, bool), "LOAD_IN_4BIT should be a boolean."
+
+# Test for the DTYPE (should be None or a valid data type)
+def test_dtype():
+    assert DTYPE is None or isinstance(DTYPE, type), "DTYPE should be None or a valid data type."
+
tests/test_dataset.py
ADDED
@@ -0,0 +1,43 @@
+from src.dataset import formatting_prompts_func
+
+def test_formatting_prompts_func():
+    # Test case with basic input
+    examples = {
+        "instruction": ["Test instruction"],
+        "input": ["Test input"],
+        "output": ["Test output"],
+    }
+    template = "Instruction: {}\nInput: {}\nOutput: {}"
+    eos_token = "<EOS>"
+
+    result = formatting_prompts_func(examples, template, eos_token)
+
+    # Check if result contains the 'text' key
+    assert "text" in result
+
+    # Check if result contains exactly one formatted entry
+    assert len(result["text"]) == 1
+
+    # Check if the formatted text is correct
+    expected = "Instruction: Test instruction\nInput: Test input\nOutput: Test output<EOS>"
+    assert result["text"][0] == expected
+
+    # Test with empty inputs (edge case)
+    examples_empty = {
+        "instruction": [""],
+        "input": [""],
+        "output": [""],
+    }
+    result_empty = formatting_prompts_func(examples_empty, template, eos_token)
+    assert result_empty["text"][0] == "Instruction: \nInput: \nOutput: <EOS>"
+
+    # Test with multiple examples
+    examples_multi = {
+        "instruction": ["Test instruction 1", "Test instruction 2"],
+        "input": ["Test input 1", "Test input 2"],
+        "output": ["Test output 1", "Test output 2"],
+    }
+    result_multi = formatting_prompts_func(examples_multi, template, eos_token)
+    assert len(result_multi["text"]) == 2
+    assert result_multi["text"][0] == "Instruction: Test instruction 1\nInput: Test input 1\nOutput: Test output 1<EOS>"
+    assert result_multi["text"][1] == "Instruction: Test instruction 2\nInput: Test input 2\nOutput: Test output 2<EOS>"
ADDED
@@ -0,0 +1,44 @@
|
+from src.inference import prepare_inference_inputs, generate_responses
+from src.model import load_model
+import pytest
+import torch
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_gpu_feature():
+    # Your test code that needs a GPU
+    assert torch.cuda.is_available()
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+@pytest.fixture
+def model_and_tokenizer():
+    """Fixture to load model and tokenizer for inference"""
+    model_name = "unsloth/Meta-Llama-3.1-8B"
+    model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+    return model, tokenizer
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_inference(model_and_tokenizer):
+    model, tokenizer = model_and_tokenizer
+
+    # Test input values
+    instruction = "What is your name?"
+    input_text = "Tell me about yourself."
+    eos_token = "<EOS>"
+
+    # Prepare inference inputs
+    inputs = prepare_inference_inputs(tokenizer, "Instruction: {}\nInput: {}", instruction, input_text, eos_token)
+
+    # Generate responses
+    responses = generate_responses(model, inputs, tokenizer, max_new_tokens=32)
+
+    # Assertions
+    assert isinstance(responses, list), f"Expected list, but got {type(responses)}"
+    assert len(responses) > 0, "Expected non-empty responses list"
+    assert isinstance(responses[0], str), f"Expected string, but got {type(responses[0])}"
+    assert len(responses[0]) > 0, "Expected non-empty string response"
+
+    # Optionally, assert that the response matches some expected pattern or content
+    assert "name" in responses[0].lower(), "Response does not contain expected content"
tests/test_model.py
ADDED
@@ -0,0 +1,46 @@
+from src.model import load_model, configure_peft_model
+import pytest
+import torch
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_gpu_feature():
+    # Your test code that needs a GPU
+    assert torch.cuda.is_available()
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_load_model():
+    model_name = "unsloth/Meta-Llama-3.1-8B"
+    model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+
+    # Check that model and tokenizer are not None
+    assert model is not None
+    assert tokenizer is not None
+
+    # Check that model is on the correct device (e.g., GPU or CPU)
+    assert next(model.parameters()).device == torch.device('cuda:0'), "Model should be loaded on CUDA device"
+
+    # Check that the tokenizer is an instance of the correct class
+    assert hasattr(tokenizer, "encode"), "Tokenizer should have the 'encode' method"
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_configure_peft_model():
+    model_name = "unsloth/Meta-Llama-3.1-8B"
+    model, _ = load_model(model_name, 16, None, True, {'': 0})
+
+    # Configure the PEFT model
+    peft_model = configure_peft_model(model, target_modules=["q_proj", "down_proj"])
+
+    # Check that PEFT model is not None
+    assert peft_model is not None, "PEFT model should not be None"
+
+    # Check that the PEFT model has a forward method
+    assert hasattr(peft_model, "forward"), "PEFT model should have a 'forward' method"
+
+    # Ensure that PEFT model can perform a forward pass (check if no error is raised)
+    try:
+        dummy_input = torch.randint(0, 1000, (1, 16))  # Dummy input tensor
+        peft_model(dummy_input)
+    except Exception as e:
+        pytest.fail(f"PEFT model forward pass failed: {e}")
tests/test_save_model.py
ADDED
@@ -0,0 +1,49 @@
+import os
+import pytest
+from src.save_model import save_model_and_tokenizer
+from src.model import load_model
+import torch
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_gpu_feature():
+    # Your test code that needs a GPU
+    assert torch.cuda.is_available()
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+@pytest.fixture
+def model_and_tokenizer():
+    """Fixture to load the model and tokenizer for saving."""
+    model_name = "unsloth/Meta-Llama-3.1-8B"
+    model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+    return model, tokenizer
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_save_model(model_and_tokenizer):
+    model, tokenizer = model_and_tokenizer
+    save_directory = "./test_save_dir"
+
+    # Save model and tokenizer
+    save_model_and_tokenizer(model, tokenizer, save_directory)
+
+    # Check if the directory exists
+    assert os.path.exists(save_directory), f"Directory {save_directory} does not exist"
+
+    # Check for key model files
+    assert os.path.exists(os.path.join(save_directory, "config.json")), "config.json not found"
+    assert os.path.exists(os.path.join(save_directory, "tokenizer_config.json")), "tokenizer_config.json not found"
+    assert os.path.exists(os.path.join(save_directory, "pytorch_model.bin")), "pytorch_model.bin not found"
+
+    # Check that files are not empty
+    assert os.path.getsize(os.path.join(save_directory, "pytorch_model.bin")) > 0, "pytorch_model.bin is empty"
+    assert os.path.getsize(os.path.join(save_directory, "config.json")) > 0, "config.json is empty"
+    assert os.path.getsize(os.path.join(save_directory, "tokenizer_config.json")) > 0, "tokenizer_config.json is empty"
+
+    # Cleanup after test
+    for file in os.listdir(save_directory):
+        file_path = os.path.join(save_directory, file)
+        if os.path.isfile(file_path):
+            os.remove(file_path)
+    os.rmdir(save_directory)
tests/test_training.py
ADDED
@@ -0,0 +1,64 @@
+from src.training import train_model
+from src.model import load_model
+from src.dataset import formatting_prompts_func
+from datasets import Dataset
+import pytest
+import torch
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_gpu_feature():
+    # Your test code that needs a GPU
+    assert torch.cuda.is_available()
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+@pytest.fixture
+def mock_dataset():
+    """Fixture to provide a mock dataset for training"""
+    data = {
+        "instruction": ["Test instruction 1", "Test instruction 2"],
+        "input": ["Test input 1", "Test input 2"],
+        "output": ["Test output 1", "Test output 2"]
+    }
+    formatted_data = formatting_prompts_func(data, template="Instruction: {}\nInput: {}\nOutput: {}", eos_token="<EOS>")
+    return Dataset.from_dict(formatted_data)
+
+@pytest.mark.gpu
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+def test_train_model(mock_dataset):
+    """Test to ensure the training model function works with a mock dataset"""
+
+    # Load model
+    model_name = "unsloth/Meta-Llama-3.1-8B"
+    model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+
+    # Training arguments
+    training_args = {
+        "max_steps": 1,
+        "output_dir": "outputs"
+    }
+
+    # Train the model
+    train_stats = train_model(
+        model=model,
+        tokenizer=tokenizer,
+        train_dataset=mock_dataset,
+        dataset_text_field="text",
+        max_seq_length=16,
+        dataset_num_proc=1,
+        packing=False,
+        training_args=training_args
+    )
+
+    # Assert that training statistics are returned
+    assert train_stats is not None
+
+    # Optionally, check for specific fields in `train_stats` (e.g., loss, global_step)
+    # Since trainer.train() returns an object that has 'global_step' and 'train_loss', we can assert them
+    assert hasattr(train_stats, "global_step")
+    assert hasattr(train_stats, "train_loss")
+
+    # For further validation, assert that the model directory was created (outputs directory)
+    assert "outputs" in train_stats.args.output_dir