helamouri committed

Commit eca6215 · 0 Parent(s)

update model

.gitattributes ADDED
@@ -0,0 +1,38 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gguf filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
+ llama3_medichat filter=lfs diff=lfs merge=lfs -text
.github/workflows/deploy.yml ADDED
@@ -0,0 +1,115 @@
+ name: CI/CD Workflow
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+
+ jobs:
+   build-test-deploy:
+     runs-on: ubuntu-latest
+
+     steps:
+       # Checkout repository
+       - name: Checkout code
+         uses: actions/checkout@v3
+
+       # Set up Python
+       - name: Setup Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: '3.12.3'
+
+       # Install dependencies
+       - name: Install dependencies
+         run: |
+           python3 -m venv .venv
+           . .venv/bin/activate
+           pip install --upgrade pip
+           pip install -r requirements.txt
+
+       - name: Check for GPU Availability
+         id: gpu-check
+         run: |
+           if lspci | grep -i nvidia; then
+             echo "gpu=true" >> $GITHUB_ENV
+           else
+             echo "gpu=false" >> $GITHUB_ENV
+           fi
+
+       # Run tests
+       - name: Run Tests
+         if: env.gpu == 'true'
+         run: |
+           source .venv/bin/activate
+           pytest --maxfail=5 --disable-warnings
+
+       - name: Skip Tests (No GPU)
+         if: env.gpu == 'false'
+         run: |
+           echo "Skipping GPU-dependent tests: No GPU available."
+
+   sync-to-hub:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v3
+
+       - name: Set Git user identity
+         run: |
+           git config --global user.name "Hussein El Amouri"
+           git config --global user.email "[email protected]"
+
+       # - name: Set up Git LFS
+       #   run: |
+       #     git lfs install  # Ensure Git LFS is installed and set up
+
+       # - name: Track large files with Git LFS
+       #   run: |
+       #     # Track specific large files that exceed the 10 MB limit
+       #     git lfs track "*.gguf"         # Add GGUF model to LFS
+       #     git lfs track "*.safetensors"  # Add safetensors model to LFS
+       #     git lfs track "*.pt"           # Add optimizer checkpoint to LFS
+       #     git lfs track "*.json"         # Add tokenizer to LFS
+
+       #     # Add .gitattributes file to the staging area for Git LFS tracking
+       #     git add .gitattributes
+
+       - name: Push to Hugging Face
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: |
+           # git lfs ls-files
+           # git lfs fetch --all
+           # git lfs pull
+           # git rev-parse --is-shallow-repository
+           git filter-branch -- --all
+           git push https://helamouri:[email protected]/spaces/helamouri/medichat_assignment main --force  # Push to Hugging Face
+
+       # - name: Set up Hugging Face CLI
+       #   run: |
+       #     pip install huggingface_hub
+
+       # - name: Login to Hugging Face
+       #   env:
+       #     HF_TOKEN: ${{ secrets.HF_TOKEN }}
+       #   run: |
+       #     huggingface-cli login --token $HF_TOKEN
+
+       # - name: Sync with Hugging Face (including large files)
+       #   env:
+       #     HF_TOKEN: ${{ secrets.HF_TOKEN }}
+       #   run: |
+       #     # Initialize git-lfs
+       #     git lfs install
+
+       #     # Pull any LFS-tracked files (if needed)
+       #     git lfs pull
+
+       #     # Push the repository to Hugging Face
+       #     huggingface-cli upload spaces/helamouri/medichat_assignment ./* ./medichat_assignment
.gitignore ADDED
@@ -0,0 +1,9 @@
+ # Ignore large files that are tracked by Git LFS
+ *.log
+
+ # Ignore build directories (e.g., for Python, Java, etc.)
+ env/
+
+ # Ensure .gitattributes is not ignored (needed for Git LFS tracking)
+ !.gitattributes
Dokerfile ADDED
@@ -0,0 +1,29 @@
+ # Use the official NVIDIA CUDA image as a base (you can adjust the CUDA version if needed)
+ FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
+
+ # Set environment variable to avoid interactive prompts during package installation
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ # Install Python 3.9 and dependencies
+ RUN apt-get update && \
+     apt-get install -y python3.9 python3.9-dev python3.9-venv python3.9-distutils curl && \
+     ln -s /usr/bin/python3.9 /usr/bin/python && \
+     curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
+     python get-pip.py && \
+     rm get-pip.py
+
+ # Set the working directory
+ WORKDIR /src
+
+ # Copy the requirements and application files
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy all source code
+ COPY . .
+
+ # Expose the default Streamlit port
+ EXPOSE 8501
+
+ # Run the Streamlit app
+ CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
Makefile ADDED
@@ -0,0 +1,91 @@
+ SHELL := /bin/bash
+ # Makefile for Llama3.1:8B Project
+
+ # Variables
+ PYTHON = python
+ PIP = pip
+ VENV_DIR = ./env
+ VENV_PYTHON = $(VENV_DIR)/bin/python
+ VENV_PIP = $(VENV_DIR)/bin/pip
+ REQUIREMENTS = requirements.txt
+
+ # Default target
+ .DEFAULT_GOAL := help
+
+ # Help target
+ help:
+ 	@echo "Makefile for Llama3.1:8B Project"
+ 	@echo ""
+ 	@echo "Targets:"
+ 	@echo "  help       - Show this help message"
+ 	@echo "  setup      - Create virtual environment and install dependencies"
+ 	@echo "  run        - Run the main application"
+ 	@echo "  test       - Run unit tests"
+ 	@echo "  lint       - Run linters"
+ 	@echo "  clean      - Remove temporary files and directories"
+ 	@echo "  clean-venv - Remove virtual environment"
+ 	@echo "  purge      - Clean and reinstall everything"
+ 	@echo "  install    - Install or update dependencies"
+
+ # Check for Python and pip
+ check-deps:
+ 	@echo "Checking for Python and pip..."
+ 	@if ! command -v $(PYTHON) >/dev/null 2>&1; then \
+ 		echo "Python is not installed. Please install Python3."; \
+ 		exit 1; \
+ 	fi
+ 	@echo "Python is installed."
+ 	@if ! command -v $(PIP) >/dev/null 2>&1; then \
+ 		echo "pip is not installed. Installing pip..."; \
+ 		sudo apt update && sudo apt install -y python3-pip; \
+ 	fi
+ 	@echo "pip is installed."
+
+ # Create virtual environment and install dependencies
+ setup: check-deps
+ 	@echo "Setting up virtual environment..."
+ 	@if [ ! -d "$(VENV_DIR)" ]; then \
+ 		$(PYTHON) -m venv $(VENV_DIR); \
+ 		echo "Virtual environment created."; \
+ 	fi
+ 	@echo "Installing dependencies..."
+ 	$(VENV_PIP) install --upgrade pip
+ 	$(VENV_PIP) install -r $(REQUIREMENTS)
+ 	@echo "Setup completed."
+
+ # Run the main application
+ run:
+ 	@echo "Running the application..."
+ 	$(VENV_PYTHON) main.py
+
+ # Run tests
+ test:
+ 	@echo "Running tests..."
+ 	$(VENV_PYTHON) -m unittest discover tests
+
+ # Run linters
+ lint:
+ 	@echo "Running linters..."
+ 	$(VENV_PYTHON) -m flake8 src/ tests/
+
+ # Clean temporary files and directories
+ clean:
+ 	@echo "Cleaning temporary files and directories..."
+ 	find . -type f -name '*.pyc' -delete
+ 	find . -type d -name '__pycache__' -exec rm -r {} +
+ 	@echo "Cleanup completed."
+
+ # Clean virtual environment
+ clean-venv:
+ 	@echo "Removing virtual environment..."
+ 	rm -rf $(VENV_DIR)
+ 	@echo "Virtual environment removed."
+
+ # Purge: remove all and reinstall environment
+ purge: clean clean-venv setup
+
+ # Install or update dependencies
+ install:
+ 	@echo "Installing or updating dependencies..."
+ 	$(VENV_PIP) install -r $(REQUIREMENTS)
+ 	@echo "Dependencies installed or updated."
README.md ADDED
@@ -0,0 +1,45 @@
+ ---
+ title: MediChat
+ emoji: 🩺
+ colorFrom: blue
+ colorTo: yellow
+ sdk: streamlit
+ sdk_version: "1.40.1"  # Replace with the actual version of your SDK
+ app_file: app.py       # Replace with the main app file name
+ pinned: false
+ ---
+
+ [![CI/CD Workflow](https://github.com/hussein88al88amouri/medichat_assignment/actions/workflows/deploy.yml/badge.svg)](https://github.com/hussein88al88amouri/medichat_assignment/actions/workflows/deploy.yml)
+
+ # MediChat: AI-Powered Medical Consultation Assistant
+
+ MediChat is an intelligent chatbot designed to provide medical consultations using a fine-tuned Llama3.1:8B model. The project bridges advanced AI capabilities with practical healthcare assistance.
+
+ ## Features
+ - Fine-tuned model for medical conversations
+ - Interactive and user-friendly interface
+ - Secure and containerized deployment
+
+ ## How to Use
+ 1. Access the chatbot interface.
+ 2. Enter your medical query.
+ 3. Receive intelligent, context-aware responses.
+
+ ## Technical Details
+ - Model: Llama3.1:8B
+ - Framework: Streamlit
+
+ ## Installation
+ 1. Clone the repository:
+    ```bash
+    git clone https://github.com/your_username/medichat.git
+    cd medichat
+    ```
+ 2. Build and run the Docker container:
+    ```bash
+    docker build -t medichat-app .
+    docker run -p 8501:8501 medichat-app
+    ```
+ 3. Access the app at http://localhost:8501.
+
+ ## Limitations
+ This tool is not a replacement for professional medical advice.
+ For critical issues, always consult a licensed medical professional.
Setup.py ADDED
@@ -0,0 +1,76 @@
+ from setuptools import setup, find_packages
+ from pathlib import Path
+
+ # Read the requirements from the requirements.txt file
+ def parse_requirements():
+     requirements_path = Path(__file__).parent / 'requirements.txt'
+     with open(requirements_path, 'r') as file:
+         return [line.strip() for line in file if line.strip() and not line.startswith('#')]
+
+ setup(
+     # The name of your package.
+     name='medichat',
+
+     # A version number for your package.
+     version='0.1.0',
+
+     # A brief summary of what your package does.
+     description='A fine-tuned LLM for medical consultations based on the Meta-Llama 3.1 8B model.',
+
+     # The URL of your project's homepage.
+     url='https://github.com/hussein88al88amouri/medichat',
+
+     # The author's name.
+     author='Hussein El Amouri',
+
+     # The author's email address.
+     author_email='[email protected]',
+
+     # This defines which packages should be included in the distribution.
+     packages=find_packages(),
+
+     # Read dependencies from the requirements.txt
+     install_requires=parse_requirements(),
+
+     # Additional classification of your package.
+     classifiers=[
+         'Development Status :: 3 - Alpha',
+         'Intended Audience :: Developers',
+         'License :: OSI Approved :: MIT License',
+         'Programming Language :: Python :: 3',
+         'Programming Language :: Python :: 3.8',
+         'Programming Language :: Python :: 3.9',
+         'Programming Language :: Python :: 3.10',
+     ],
+
+     # A license for your package.
+     license='MIT',
+
+     # You can add entry points for command-line tools if your package includes such functionality.
+     entry_points={
+         'console_scripts': [
+             'medichat=medichat.cli:main',  # Adjust to your actual CLI entry point, if any
+         ],
+     },
+
+     # If you have data files (like configuration files), you can specify them here.
+     data_files=[
+         # Example of configuration files for saving the model, etc.
+         ('share/config', ['config/config.json']),
+     ],
+
+     # If your package has specific testing requirements or needs test dependencies, list them here.
+     extras_require={
+         'dev': ['pytest', 'tox'],  # Optional dependencies for development or testing
+         'docs': ['sphinx'],        # Optional dependencies for documentation generation
+     },
+
+     # Specify your package's minimum supported Python version
+     python_requires='>=3.8',
+
+     # If your package includes command-line scripts, you can list them here
+     scripts=['scripts/cli_script.py'],  # Update path if you have a script to run
+
+     # If your package includes C extensions or other modules, specify them here.
+     ext_modules=[],
+ )
User.code-workspace ADDED
@@ -0,0 +1,32 @@
+ {
+     "folders": [
+         {
+             "path": "C:/Users/hasso/AppData/Roaming/Code/User"
+         },
+         {
+             "path": ".."
+         }
+     ],
+     "settings": {
+         "workbench.colorCustomizations": {
+             "activityBar.activeBackground": "#5b5b5b",
+             "activityBar.background": "#5b5b5b",
+             "activityBar.foreground": "#e7e7e7",
+             "activityBar.inactiveForeground": "#e7e7e799",
+             "activityBarBadge.background": "#103010",
+             "activityBarBadge.foreground": "#e7e7e7",
+             "commandCenter.border": "#e7e7e799",
+             "sash.hoverBorder": "#5b5b5b",
+             "statusBar.background": "#424242",
+             "statusBar.foreground": "#e7e7e7",
+             "statusBarItem.hoverBackground": "#5b5b5b",
+             "statusBarItem.remoteBackground": "#424242",
+             "statusBarItem.remoteForeground": "#e7e7e7",
+             "titleBar.activeBackground": "#424242",
+             "titleBar.activeForeground": "#e7e7e7",
+             "titleBar.inactiveBackground": "#42424299",
+             "titleBar.inactiveForeground": "#e7e7e799"
+         },
+         "peacock.color": "#424242"
+     }
+ }
__init__.py ADDED
@@ -0,0 +1,51 @@
+ # __init__.py
+
+ # Import necessary modules or functions from submodules.
+ # This is where you can aggregate the public API of your package.
+
+ # Example:
+ # from .module_name import function_name, ClassName
+
+ # You can also import specific components to make them available directly from the package level.
+ # For example:
+ # from .subpackage.module_name import function_name
+ from src import llama3_finetune
+ from src import main
+
+ # Initialize any package-level variables or constants
+ # For example, if you have any version number or author info, you can define it here.
+
+ __version__ = "0.1.0"  # Replace with your actual package version
+ __author__ = "Hussein El Amouri"  # Replace with your name or the author name
+
+ # You can include initialization code here, if your package requires any.
+ # For example, setting up logging, initializing global variables, etc.
+
+ # Example:
+ # import logging
+ # logging.basicConfig(level=logging.INFO)
+
+ # Define a list of publicly exposed items (optional)
+ # This list is used to specify which functions, classes, or variables
+ # should be available when `from package_name import *` is used.
+
+ # Example:
+ __all__ = [
+     'function_name',  # List the names of the functions, classes, or variables you want exposed
+     'ClassName',
+ ]
+
+ # If your package uses a specific function or submodule as the primary entry point,
+ # you can set that here.
+ # For example, if the main function of the package is in a submodule called 'main.py',
+ # you can import that here:
+ # from .main import run
+
+ # Initialize any necessary package-specific code here, if needed
+ # Example for adding environment setup, database initialization, etc.
+
+ # If your package contains a command-line interface (CLI), you can import it here,
+ # so it can be executed as a script if the package is installed:
+ # from .cli import main
+
+ # Any other necessary imports that users should be aware of can go here.
app.py ADDED
@@ -0,0 +1,116 @@
+ import streamlit as st
+ # from unsloth import FastLanguageModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from llama_cpp import Llama
+ from huggingface_hub import hf_hub_download
+ import os
+ import sys
+
+ # # Suppress unwanted outputs (e.g., from unsloth or other libraries)
+ # def suppress_output():
+ #     sys.stdout = open(os.devnull, 'w')  # Redirect stdout to devnull
+ #     sys.stderr = open(os.devnull, 'w')  # Redirect stderr to devnull
+
+ # def restore_output():
+ #     sys.stdout = sys.__stdout__  # Restore stdout
+ #     sys.stderr = sys.__stderr__  # Restore stderr
+
+ # Load the model (GGUF format)
+ @st.cache_resource
+ def load_model():
+     # Define the repository and model filenames for the base model and the LoRA adapter
+     base_model_repo = "helamouri/Meta-Llama-3.1-8B-Q8_0.gguf"
+     base_model_filename = "Meta-Llama-3.1-8B-Q8_0.gguf"
+     # adapter_repo = "helamouri/medichat_assignment"
+     # adapter_filename = "llama3_medichat.gguf"  # adapter is not loaded below, so it is not downloaded
+     adapter_repo = "helamouri/model_medichat_finetuned_v1"
+
+     # Download the base model to a local path
+     base_model_path = hf_hub_download(repo_id=base_model_repo, filename=base_model_filename)
+     # adapter_model_path = hf_hub_download(repo_id=adapter_repo, filename=adapter_filename)
+
+     # Log paths for debugging
+     print(f"Base model path: {base_model_path}")
+     # print(f"Adapter model path: {adapter_model_path}")
+
+     # Load the full model (base model); the LoRA adapter is currently disabled
+     try:
+         model = Llama(model_path=base_model_path)  # , adapter_path=adapter_model_path)
+         print("Model loaded successfully.")
+     except ValueError as e:
+         print(f"Error loading model: {e}")
+         raise
+
+     return model
+
+ # Generate a response using Llama.cpp
+ def generate_response(model, prompt):
+     print('prompt')
+     print(prompt)
+     response = model(
+         prompt,
+         max_tokens=200,   # Maximum tokens for the response
+         temperature=0.7,  # Adjust for creativity (lower = deterministic)
+         top_p=0.9,        # Nucleus sampling
+         stop=["\n"]       # Stop generating when a newline is encountered
+     )
+     print('response["choices"]')
+     print(response["choices"])
+     return response["choices"][0]["text"]
+
+ # Load the model and tokenizer (GGUF format)
+ # @st.cache_resource
+ # def load_model():
+ #     model_name = "helamouri/model_medichat_finetuned_v1"  # Replace with your model's GGUF path
+ #     model = FastLanguageModel.from_pretrained(model_name, device='cpu')  # Load the model using unsloth
+ #     tokenizer = model.tokenizer  # Assuming the tokenizer is part of the GGUF model object
+ #     return tokenizer, model
+
+
+ # @st.cache_resource
+ # def load_model():
+ #     model_name = "helamouri/model_medichat_finetuned_v1"  # Replace with your model's path
+ #     # Load the tokenizer
+ #     tokenizer = AutoTokenizer.from_pretrained(model_name)
+ #     # Load the model (if it's a causal language model or suitable model type)
+ #     model = AutoModelForCausalLM.from_pretrained(model_name,
+ #                                                  device_map="cpu",
+ #                                                  revision="main",
+ #                                                  quantize=False,
+ #                                                  load_in_8bit=False,
+ #                                                  load_in_4bit=False,
+ #                                                  # torch_dtype=torch.float32
+ #                                                  )
+ #     return tokenizer, model
+
+ # Suppress unwanted outputs from unsloth or any other libraries during model loading
+ # suppress_output()
+
+ # Load the GGUF model
+ print('Loading the model')
+ model = load_model()
+ # Restore stdout and stderr
+ # restore_output()
+
+ # App layout
+ print('Setting App layout')
+ st.title("MediChat: Your AI Medical Consultation Assistant")
+ st.markdown("Ask me anything about your health!")
+ st.write("Enter your symptoms or medical questions below:")
+
+ # User input
+ print('Setting user interface')
+ user_input = st.text_input("Your Question:")
+ if st.button("Get Response"):
+     if user_input:
+         with st.spinner("Generating response..."):
+             # Generate Response
+             response = generate_response(model, user_input)
+             print('Response')
+             print(response)
+             # Display response
+             st.text_area("Response:", value=response, height=200)
+     else:
+         st.warning("Please enter a question.")
pytest.ini ADDED
@@ -0,0 +1,3 @@
+ [pytest]
+ markers =
+     gpu: marks tests that require a GPU
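The `gpu` marker registered here is applied throughout the test suite together with a CUDA `skipif` guard; a minimal sketch of the pattern (mirroring `tests/test_model.py`) looks like this:

```python
import pytest
import torch

# GPU-dependent tests carry the custom `gpu` marker and are skipped
# automatically on machines without CUDA.
@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
def test_gpu_feature():
    assert torch.cuda.is_available()
```

Registering the marker in `pytest.ini` keeps runs with `--strict-markers` free of unknown-marker warnings, and GPU tests can also be deselected explicitly with `pytest -m "not gpu"`.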
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ # Install the multi-backend version of bitsandbytes
+ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl
+ llama-cpp-python
+ datasets
+ huggingface_hub
+ huggingface_hub[cli]
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
+ xformers==0.0.28.post2
+ trl
+ peft
+ accelerate
+ bitsandbytes
+ torchvision
+ torch
+ sentencepiece
+ transformers[torch]>=4.45.1
+ streamlit==1.40.1
+ gguf>=0.10.0
+ pytest
+ flake8
space.yml ADDED
@@ -0,0 +1,19 @@
+ # docker-compose.yaml for Hugging Face Spaces
+ version: '3.8'
+
+ services:
+   app:
+     build:
+       context: .
+       dockerfile: Dockerfile
+     ports:
+       - "7860:7860"  # Default port for Streamlit or Gradio apps
+     environment:
+       HF_TOKEN: ${HF_TOKEN}  # Hugging Face API token
+     command: >
+       bash -c "
+       python3 -m venv .venv &&
+       . .venv/bin/activate &&
+       pip install -r requirements.txt &&
+       python main.py
+       "
src/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .config import *
+ from .model import load_model, configure_peft_model
+ from .dataset import load_and_prepare_dataset, formatting_prompts_func
+ from .training import train_model
+ from .inference import prepare_inference_inputs, generate_responses, stream_responses
+ from .save_model import save_model_and_tokenizer
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (601 Bytes).

src/__pycache__/config.cpython-312.pyc ADDED
Binary file (1.16 kB).

src/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (1.37 kB).

src/__pycache__/inference.cpython-312.pyc ADDED
Binary file (2.51 kB).

src/__pycache__/model.cpython-312.pyc ADDED
Binary file (1.41 kB).

src/__pycache__/save_model.cpython-312.pyc ADDED
Binary file (1.11 kB).

src/__pycache__/training.cpython-312.pyc ADDED
Binary file (743 Bytes).
src/config.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+
+
+ def get_device_map():
+     if torch.cuda.is_available():
+         return {'': torch.cuda.current_device()}
+     else:
+         return {}  # Or some default, fallback configuration
+
+ # General configuration
+ MAX_SEQ_LENGTH = 2**4
+ DTYPE = None
+ LOAD_IN_4BIT = True
+ DEVICE_MAP = get_device_map()  # e.g. {'': 0} when a CUDA device is available
+ EOS_TOKEN = None  # Set dynamically based on tokenizer
+
+ # Alpaca prompt template
+ ALPACA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ ###Instruction:
+ {}
+
+ ###Input:
+ {}
+
+ ###Response:
+ {}"""
+
+ # Training arguments
+ TRAIN_ARGS = {
+     "per_device_train_batch_size": 2,
+     "gradient_accumulation_steps": 4,
+     "warmup_steps": 5,
+     "max_steps": 60,
+     "learning_rate": 2e-4,
+     "fp16": not torch.cuda.is_bf16_supported(),
+     "bf16": torch.cuda.is_bf16_supported(),
+     "logging_steps": 1,
+     "optim": "adamw_8bit",
+     "weight_decay": 0.01,
+     "lr_scheduler_type": "linear",
+     "seed": 3407,
+     "output_dir": "outputs",
+ }
src/dataset.py ADDED
@@ -0,0 +1,22 @@
+ from datasets import load_dataset
+
+ def formatting_prompts_func(examples, template, eos_token):
+     instructions = examples["instruction"]
+     inputs = examples["input"]
+     outputs = examples["output"]
+
+     # Format the examples using the provided template
+     texts = []
+     for instruction, input_text, output in zip(instructions, inputs, outputs):
+         text = template.format(instruction, input_text, output) + eos_token
+         texts.append(text)
+
+     # Return a dictionary with the formatted text
+     return {"text": texts}
+
+ def load_and_prepare_dataset(dataset_name, nsamples, formatting_func, template, eos_token):
+     # Load the dataset and prepare it by applying the formatting function
+     dataset = load_dataset(dataset_name, split="train").select(range(nsamples))
+
+     # Map the formatting function over the dataset
+     return dataset.map(lambda examples: formatting_func(examples, template, eos_token), batched=True)
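For reference, a minimal sketch of how `formatting_prompts_func` combines with the Alpaca template from `src/config.py`; the batch below and the `"<EOS>"` string are illustrative placeholders (during training the tokenizer's real EOS token is passed instead):

```python
from src.config import ALPACA_PROMPT_TEMPLATE
from src.dataset import formatting_prompts_func

# A single-example batch in the instruction/input/output column format
# expected by formatting_prompts_func (illustrative content only).
examples = {
    "instruction": ["Answer the patient's question."],
    "input": ["I have had a mild headache for two days."],
    "output": ["Rest, stay hydrated, and see a doctor if it persists."],
}

batch = formatting_prompts_func(examples, ALPACA_PROMPT_TEMPLATE, "<EOS>")
print(batch["text"][0])  # Alpaca-style prompt ending with the placeholder EOS token
```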
src/fine_tune_llama.py ADDED
@@ -0,0 +1,38 @@
+ from src import *
+
+ # Load configuration
+ max_seq_length = config.MAX_SEQ_LENGTH
+ device_map = config.DEVICE_MAP
+ eos_token = config.EOS_TOKEN
+
+ # Load and configure model
+ model_name = "unsloth/Meta-Llama-3.1-8B"
+ model, tokenizer = load_model(model_name, max_seq_length, config.DTYPE, config.LOAD_IN_4BIT, device_map)
+ eos_token = tokenizer.eos_token
+
+ model = configure_peft_model(model, target_modules=["q_proj", "down_proj"])
+
+ # Prepare dataset
+ nsamples = 1000
+ dataset = load_and_prepare_dataset(
+     "lavita/ChatDoctor-HealthCareMagic-100k",
+     nsamples,
+     formatting_prompts_func,
+     config.ALPACA_PROMPT_TEMPLATE,
+     eos_token,
+ )
+
+ # Train model
+ trainer_stats = train_model(
+     model=model,
+     tokenizer=tokenizer,
+     train_dataset=dataset,
+     dataset_text_field="text",
+     max_seq_length=max_seq_length,
+     dataset_num_proc=2,
+     packing=False,
+     training_args=config.TRAIN_ARGS,
+ )
+
+ # Save the model
+ save_model_and_tokenizer(model, tokenizer, "./llama3_medichat")
src/inference.py ADDED
@@ -0,0 +1,51 @@
+ from transformers import TextStreamer
+
+ def prepare_inference_inputs(tokenizer, template, instruction, input_text, eos_token, device="cuda"):
+     """
+     Prepares the inputs for inference by formatting the prompt and tokenizing it.
+
+     Args:
+     - tokenizer: The tokenizer used for tokenization.
+     - template: The template string for the prompt format.
+     - instruction: The instruction to be included in the prompt.
+     - input_text: The input to be included in the prompt.
+     - eos_token: The end of sequence token.
+     - device: The device for the model ('cuda' or 'cpu').
+
+     Returns:
+     - Tokenized inputs ready for inference.
+     """
+     prompt = template.format(instruction, input_text, "") + eos_token
+     return tokenizer([prompt], return_tensors="pt").to(device)
+
+ def generate_responses(model, inputs, tokenizer, max_new_tokens=64):
+     """
+     Generates responses from the model based on the provided inputs.
+
+     Args:
+     - model: The pre-trained model for generation.
+     - inputs: The tokenized inputs to generate responses.
+     - tokenizer: The tokenizer used to decode the output.
+     - max_new_tokens: The maximum number of tokens to generate.
+
+     Returns:
+     - Decoded responses from the model.
+     """
+     outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
+     return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+ def stream_responses(model, inputs, tokenizer, max_new_tokens=128):
+     """
+     Streams the model's response using a text streamer.
+
+     Args:
+     - model: The pre-trained model for generation.
+     - inputs: The tokenized inputs to generate responses.
+     - tokenizer: The tokenizer used to decode the output.
+     - max_new_tokens: The maximum number of tokens to generate.
+
+     Returns:
+     - Streams the output directly.
+     """
+     text_streamer = TextStreamer(tokenizer)
+     model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_new_tokens)
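A minimal usage sketch of these helpers, assuming a GPU machine, the same base model used in `src/fine_tune_llama.py`, and a larger max sequence length (2048) than the default in `src/config.py` so the prompt fits; the instruction and input strings are illustrative:

```python
from src.config import ALPACA_PROMPT_TEMPLATE
from src.inference import prepare_inference_inputs, generate_responses
from src.model import load_model

# Load the 4-bit base model on GPU 0, mirroring src/fine_tune_llama.py and the tests
model, tokenizer = load_model("unsloth/Meta-Llama-3.1-8B", 2048, None, True, {'': 0})

inputs = prepare_inference_inputs(
    tokenizer,
    ALPACA_PROMPT_TEMPLATE,
    "Answer the patient's question.",            # instruction (illustrative)
    "I have had a mild headache for two days.",  # input text (illustrative)
    tokenizer.eos_token,
)
responses = generate_responses(model, inputs, tokenizer, max_new_tokens=64)
print(responses[0])
```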
src/model.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ from unsloth import FastLanguageModel
+
+ def load_model(model_name, max_seq_length, dtype, load_in_4bit, device_map):
+     try:
+         model, tokenizer = FastLanguageModel.from_pretrained(
+             model_name=model_name,
+             max_seq_length=max_seq_length,
+             dtype=dtype,
+             load_in_4bit=load_in_4bit,
+             device_map=device_map,
+         )
+         return model, tokenizer
+     except Exception as e:
+         raise RuntimeError(f"Failed to load model {model_name}: {e}")
+
+ def configure_peft_model(model, target_modules, lora_alpha=16, lora_dropout=0, random_state=3407, use_rslora=False):
+     try:
+         peft_model = FastLanguageModel.get_peft_model(
+             model=model,
+             target_modules=target_modules,
+             lora_alpha=lora_alpha,
+             lora_dropout=lora_dropout,
+             bias="none",
+             use_gradient_checkpointing="unsloth",
+             random_state=random_state,
+             use_rslora=use_rslora,
+             loftq_config=None,
+         )
+         return peft_model
+     except Exception as e:
+         raise RuntimeError(f"Failed to configure PEFT model: {e}")
src/save_model.py ADDED
@@ -0,0 +1,23 @@
+ import os
+
+ def save_model_and_tokenizer(model, tokenizer, save_directory):
+     """
+     Save model and tokenizer to the specified directory.
+
+     Args:
+     - model: The model to save.
+     - tokenizer: The tokenizer to save.
+     - save_directory: Directory where the model and tokenizer should be saved.
+     """
+     try:
+         # Ensure the save directory exists
+         os.makedirs(save_directory, exist_ok=True)
+
+         # Save model and tokenizer
+         model.save_pretrained(save_directory, safe_serialization=True)
+         tokenizer.save_pretrained(save_directory)
+
+         print(f"Model and tokenizer saved locally at {save_directory}")
+     except Exception as e:
+         print(f"Error saving model and tokenizer: {str(e)}")
+         raise
src/training.py ADDED
@@ -0,0 +1,20 @@
+ from trl import SFTTrainer
+ from transformers import TrainingArguments
+
+ def train_model(model, tokenizer, train_dataset, dataset_text_field, max_seq_length, dataset_num_proc, packing, training_args):
+     trainer = SFTTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=train_dataset,
+         dataset_text_field=dataset_text_field,
+         max_seq_length=max_seq_length,
+         dataset_num_proc=dataset_num_proc,
+         packing=packing,
+         args=TrainingArguments(**training_args),
+     )
+
+     # Train the model
+     train_results = trainer.train()
+
+     # Optionally, you can return more specific training information if necessary
+     return train_results
tests/__init__.py ADDED
File without changes
tests/__pycache__/test_config.cpython-312.pyc ADDED
Binary file (4.25 kB).

tests/__pycache__/test_dataset.cpython-312.pyc ADDED
Binary file (1.66 kB).

tests/__pycache__/test_inference.cpython-312.pyc ADDED
Binary file (1.86 kB).

tests/__pycache__/test_model.cpython-312.pyc ADDED
Binary file (1.83 kB).

tests/__pycache__/test_save_model.cpython-312.pyc ADDED
Binary file (2.89 kB).

tests/__pycache__/test_training.cpython-312.pyc ADDED
Binary file (1.86 kB).
tests/test_config.py ADDED
@@ -0,0 +1,79 @@
+ import pytest
+ import torch
+ from src.config import (MAX_SEQ_LENGTH, DTYPE, LOAD_IN_4BIT, DEVICE_MAP, EOS_TOKEN,
+                         ALPACA_PROMPT_TEMPLATE, TRAIN_ARGS)
+
+ # Test that required configuration keys are present
+ def test_required_config_keys():
+     assert MAX_SEQ_LENGTH is not None, "MAX_SEQ_LENGTH is not set."
+     assert TRAIN_ARGS is not None, "TRAIN_ARGS is not set."
+     assert ALPACA_PROMPT_TEMPLATE is not None, "ALPACA_PROMPT_TEMPLATE is not set."
+     assert DEVICE_MAP is not None, "DEVICE_MAP is not set."
+
+ # Test that MAX_SEQ_LENGTH is a power of two
+ def test_max_seq_length():
+     assert isinstance(MAX_SEQ_LENGTH, int), "MAX_SEQ_LENGTH should be an integer."
+     assert MAX_SEQ_LENGTH > 0, "MAX_SEQ_LENGTH should be greater than 0."
+     assert (MAX_SEQ_LENGTH & (MAX_SEQ_LENGTH - 1)) == 0, "MAX_SEQ_LENGTH should be a power of two."
+
+ # Test that TRAIN_ARGS dictionary contains required fields and types
+ def test_train_args():
+     required_keys = [
+         "per_device_train_batch_size",
+         "gradient_accumulation_steps",
+         "warmup_steps",
+         "max_steps",
+         "learning_rate",
+         "fp16",
+         "bf16",
+         "logging_steps",
+         "optim",
+         "weight_decay",
+         "lr_scheduler_type",
+         "seed",
+         "output_dir"
+     ]
+
+     for key in required_keys:
+         assert key in TRAIN_ARGS, f"Missing {key} in TRAIN_ARGS."
+
+     # Check types of specific fields
+     assert isinstance(TRAIN_ARGS["per_device_train_batch_size"], int), "per_device_train_batch_size should be an integer."
+     assert isinstance(TRAIN_ARGS["learning_rate"], float), "learning_rate should be a float."
+     assert isinstance(TRAIN_ARGS["output_dir"], str), "output_dir should be a string."
+
+ # Test that the DEVICE_MAP references a valid CUDA device
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_device_map():
+     device = DEVICE_MAP.get('', None)
+     assert device is not None, "DEVICE_MAP should reference a CUDA device."
+     assert isinstance(device, int), "DEVICE_MAP should be an integer (CUDA device ID)."
+     assert torch.cuda.is_available(), "CUDA is not available, but DEVICE_MAP points to a CUDA device."
+
+ # Test that the EOS_TOKEN is set dynamically based on the tokenizer
+ def test_eos_token():
+     assert EOS_TOKEN is not None, "EOS_TOKEN should be dynamically set based on tokenizer."
+
+ # Test the ALPACA_PROMPT_TEMPLATE for expected formatting
+ def test_alpaca_prompt_template():
+     test_instruction = "Test Instruction"
+     test_input = "Test Input"
+     test_output = "Test Output"
+
+     formatted_prompt = ALPACA_PROMPT_TEMPLATE.format(test_instruction, test_input, test_output)
+
+     # Ensure that the raw template contains placeholders and the formatted prompt keeps the section headers
+     assert "{}" in ALPACA_PROMPT_TEMPLATE, "ALPACA_PROMPT_TEMPLATE should contain placeholders."
+     assert "###Instruction:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Instruction'."
+     assert "###Input:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Input'."
+     assert "###Response:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Response'."
+
+ # Test that the LOAD_IN_4BIT setting is a boolean
+ def test_load_in_4bit():
+     assert isinstance(LOAD_IN_4BIT, bool), "LOAD_IN_4BIT should be a boolean."
+
+ # Test for the DTYPE (should be None or a valid data type)
+ def test_dtype():
+     assert DTYPE is None or isinstance(DTYPE, type), "DTYPE should be None or a valid data type."
tests/test_dataset.py ADDED
@@ -0,0 +1,43 @@
+ from src.dataset import formatting_prompts_func
+
+ def test_formatting_prompts_func():
+     # Test case with basic input
+     examples = {
+         "instruction": ["Test instruction"],
+         "input": ["Test input"],
+         "output": ["Test output"],
+     }
+     template = "Instruction: {}\nInput: {}\nOutput: {}"
+     eos_token = "<EOS>"
+
+     result = formatting_prompts_func(examples, template, eos_token)
+
+     # Check if result contains the 'text' key
+     assert "text" in result
+
+     # Check if result contains exactly one formatted entry
+     assert len(result["text"]) == 1
+
+     # Check if the formatted text is correct
+     expected = "Instruction: Test instruction\nInput: Test input\nOutput: Test output<EOS>"
+     assert result["text"][0] == expected
+
+     # Test with empty inputs (edge case)
+     examples_empty = {
+         "instruction": [""],
+         "input": [""],
+         "output": [""],
+     }
+     result_empty = formatting_prompts_func(examples_empty, template, eos_token)
+     assert result_empty["text"][0] == "Instruction: \nInput: \nOutput: <EOS>"
+
+     # Test with multiple examples
+     examples_multi = {
+         "instruction": ["Test instruction 1", "Test instruction 2"],
+         "input": ["Test input 1", "Test input 2"],
+         "output": ["Test output 1", "Test output 2"],
+     }
+     result_multi = formatting_prompts_func(examples_multi, template, eos_token)
+     assert len(result_multi["text"]) == 2
+     assert result_multi["text"][0] == "Instruction: Test instruction 1\nInput: Test input 1\nOutput: Test output 1<EOS>"
+     assert result_multi["text"][1] == "Instruction: Test instruction 2\nInput: Test input 2\nOutput: Test output 2<EOS>"
tests/test_inference.py ADDED
@@ -0,0 +1,44 @@
+ from src.inference import prepare_inference_inputs, generate_responses
+ from src.model import load_model
+ import pytest
+ import torch
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_gpu_feature():
+     # Your test code that needs a GPU
+     assert torch.cuda.is_available()
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ @pytest.fixture
+ def model_and_tokenizer():
+     """Fixture to load model and tokenizer for inference"""
+     model_name = "unsloth/Meta-Llama-3.1-8B"
+     model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+     return model, tokenizer
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_inference(model_and_tokenizer):
+     model, tokenizer = model_and_tokenizer
+
+     # Test input values
+     instruction = "What is your name?"
+     input_text = "Tell me about yourself."
+     eos_token = "<EOS>"
+
+     # Prepare inference inputs
+     inputs = prepare_inference_inputs(tokenizer, "Instruction: {}\nInput: {}", instruction, input_text, eos_token)
+
+     # Generate responses
+     responses = generate_responses(model, inputs, tokenizer, max_new_tokens=32)
+
+     # Assertions
+     assert isinstance(responses, list), f"Expected list, but got {type(responses)}"
+     assert len(responses) > 0, "Expected non-empty responses list"
+     assert isinstance(responses[0], str), f"Expected string, but got {type(responses[0])}"
+     assert len(responses[0]) > 0, "Expected non-empty string response"
+
+     # Optionally, assert that the response matches some expected pattern or content
+     assert "name" in responses[0].lower(), "Response does not contain expected content"
tests/test_model.py ADDED
@@ -0,0 +1,46 @@
+ from src.model import load_model, configure_peft_model
+ import pytest
+ import torch
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_gpu_feature():
+     # Your test code that needs a GPU
+     assert torch.cuda.is_available()
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_load_model():
+     model_name = "unsloth/Meta-Llama-3.1-8B"
+     model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+
+     # Check that model and tokenizer are not None
+     assert model is not None
+     assert tokenizer is not None
+
+     # Check that model is on the correct device (e.g., GPU or CPU)
+     assert next(model.parameters()).device == torch.device('cuda:0'), "Model should be loaded on CUDA device"
+
+     # Check that the tokenizer is an instance of the correct class
+     assert hasattr(tokenizer, "encode"), "Tokenizer should have the 'encode' method"
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_configure_peft_model():
+     model_name = "unsloth/Meta-Llama-3.1-8B"
+     model, _ = load_model(model_name, 16, None, True, {'': 0})
+
+     # Configure the PEFT model
+     peft_model = configure_peft_model(model, target_modules=["q_proj", "down_proj"])
+
+     # Check that PEFT model is not None
+     assert peft_model is not None, "PEFT model should not be None"
+
+     # Check that the PEFT model has a forward method
+     assert hasattr(peft_model, "forward"), "PEFT model should have a 'forward' method"
+
+     # Ensure that PEFT model can perform a forward pass (check if no error is raised)
+     try:
+         device = next(peft_model.parameters()).device
+         dummy_input = torch.randint(0, 1000, (1, 16)).to(device)  # Dummy input tensor on the model's device
+         peft_model(dummy_input)
+     except Exception as e:
+         pytest.fail(f"PEFT model forward pass failed: {e}")
tests/test_save_model.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import pytest
+ from src.save_model import save_model_and_tokenizer
+ from src.model import load_model
+ import torch
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_gpu_feature():
+     # Your test code that needs a GPU
+     assert torch.cuda.is_available()
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ @pytest.fixture
+ def model_and_tokenizer():
+     """Fixture to load the model and tokenizer for saving."""
+     model_name = "unsloth/Meta-Llama-3.1-8B"
+     model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+     return model, tokenizer
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_save_model(model_and_tokenizer):
+     model, tokenizer = model_and_tokenizer
+     save_directory = "./test_save_dir"
+
+     # Save model and tokenizer
+     save_model_and_tokenizer(model, tokenizer, save_directory)
+
+     # Check if the directory exists
+     assert os.path.exists(save_directory), f"Directory {save_directory} does not exist"
+
+     # Check for key model files
+     assert os.path.exists(os.path.join(save_directory, "config.json")), "config.json not found"
+     assert os.path.exists(os.path.join(save_directory, "tokenizer_config.json")), "tokenizer_config.json not found"
+     assert os.path.exists(os.path.join(save_directory, "pytorch_model.bin")), "pytorch_model.bin not found"
+
+     # Check that files are not empty
+     assert os.path.getsize(os.path.join(save_directory, "pytorch_model.bin")) > 0, "pytorch_model.bin is empty"
+     assert os.path.getsize(os.path.join(save_directory, "config.json")) > 0, "config.json is empty"
+     assert os.path.getsize(os.path.join(save_directory, "tokenizer_config.json")) > 0, "tokenizer_config.json is empty"
+
+     # Cleanup after test
+     for file in os.listdir(save_directory):
+         file_path = os.path.join(save_directory, file)
+         if os.path.isfile(file_path):
+             os.remove(file_path)
+     os.rmdir(save_directory)
tests/test_training.py ADDED
@@ -0,0 +1,64 @@
+ from src.training import train_model
+ from src.model import load_model
+ from src.dataset import formatting_prompts_func
+ from datasets import Dataset
+ import pytest
+ import torch
+ import os
+
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_gpu_feature():
+     # Your test code that needs a GPU
+     assert torch.cuda.is_available()
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ @pytest.fixture
+ def mock_dataset():
+     """Fixture to provide a mock dataset for training"""
+     data = {
+         "instruction": ["Test instruction 1", "Test instruction 2"],
+         "input": ["Test input 1", "Test input 2"],
+         "output": ["Test output 1", "Test output 2"]
+     }
+     formatted_data = formatting_prompts_func(data, template="Instruction: {}\nInput: {}\nOutput: {}", eos_token="<EOS>")
+     return Dataset.from_dict(formatted_data)
+
+ @pytest.mark.gpu
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
+ def test_train_model(mock_dataset):
+     """Test to ensure the training model function works with a mock dataset"""
+
+     # Load model
+     model_name = "unsloth/Meta-Llama-3.1-8B"
+     model, tokenizer = load_model(model_name, 16, None, True, {'': 0})
+
+     # Training arguments
+     training_args = {
+         "max_steps": 1,
+         "output_dir": "outputs"
+     }
+
+     # Train the model
+     train_stats = train_model(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=mock_dataset,
+         dataset_text_field="text",
+         max_seq_length=16,
+         dataset_num_proc=1,
+         packing=False,
+         training_args=training_args
+     )
+
+     # Assert that training statistics are returned
+     assert train_stats is not None
+
+     # trainer.train() returns a TrainOutput with 'global_step' and 'training_loss'
+     assert hasattr(train_stats, "global_step")
+     assert hasattr(train_stats, "training_loss")
+
+     # For further validation, assert that the output directory was created
+     assert os.path.isdir(training_args["output_dir"])