Upload 34 files

- .env +1 -0
- .gitattributes +1 -0
- .github/workflows/action.yml +41 -0
- .github/workflows/huggingface-workflow.yml +48 -0
- .gitignore +13 -0
- .replit +43 -0
- .streamlit/config.toml +9 -0
- .streamlit/secrets.toml +7 -0
- README.md +87 -2
- app.py +11 -0
- assets/custom.css +59 -0
- attached_assets/Pasted-Below-is-a-design-proposal-for-a-Hugging-Face-based-system-that-lets-users-fine-tune-a-code-generati-1740904225626.txt +116 -0
- attached_assets/Pasted-For-a-robust-foundation-you-ll-want-to-configure-a-set-of-tools-that-catch-errors-early-maintain-c-1740904212802.txt +19 -0
- attached_assets/Pasted-For-a-robust-foundation-you-ll-want-to-configure-a-set-of-tools-that-catch-errors-early-maintain-c-1740906031222.txt +19 -0
- components/code_quality.py +347 -0
- components/dataset_preview.py +75 -0
- components/dataset_statistics.py +149 -0
- components/dataset_uploader.py +113 -0
- components/dataset_validation.py +181 -0
- components/dataset_version_control.py +276 -0
- components/dataset_visualization.py +502 -0
- components/fine_tuning/__init__.py +3 -0
- components/fine_tuning/finetune_ui.py +529 -0
- components/fine_tuning/model_interface.py +228 -0
- generated-icon.png +3 -0
- huggingface-spacefile +8 -0
- main.py +546 -0
- pyproject.toml +33 -0
- replit.nix +18 -0
- test_app.py +105 -0
- utils/dataset_utils.py +167 -0
- utils/huggingface_integration.py +99 -0
- utils/smolagents_integration.py +211 -0
- uv.lock +0 -0
.env
ADDED
@@ -0,0 +1 @@
+PYTHONPATH=.
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+generated-icon.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/action.yml
ADDED
@@ -0,0 +1,41 @@
name: Huggingface Login
description: "Login to Huggingface using token"
author: osbm
branding:
  icon: server
  color: yellow

inputs:
  username:
    description: "Huggingface Username"
    required: true
  key:
    description: "Huggingface token"
    required: true

  add_to_git_credentials:
    description: "Add to git credentials"
    required: false
    default: "false"

runs:
  using: "composite"
  steps:
    - name: Install huggingface-hub
      shell: bash
      run: |
        pip install huggingface-hub

    - name: Login to Huggingface
      shell: bash
      run: |
        mkdir -p ~/.cache/huggingface
        echo "${{ inputs.key }}" > ~/.cache/huggingface/token

    - name: Add to git credentials
      shell: bash
      if: inputs.add_to_git_credentials == 'true'
      run: |
        git config --global credential.helper store
        git config --global credential.https://huggingface.co.username ${{ inputs.username }}
        git config --global credential.https://huggingface.co.password ${{ inputs.key }}
.github/workflows/huggingface-workflow.yml
ADDED
@@ -0,0 +1,48 @@
name: Hugging Face Space Interaction

on:
  push:
    branches:
      - main

jobs:
  interact-with-space:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install huggingface_hub requests

      - name: Login to Hugging Face
        run: huggingface-cli login --token "${{ secrets.HF_TOKEN }}"

      - name: Example interaction with Space
        run: |
          python -c "
          import requests
          import os

          HF_TOKEN = os.environ.get('HF_TOKEN')
          headers = {'Authorization': f'Bearer {HF_TOKEN}'}
          API_URL = 'YOUR_SPACE_API_URL'  # Replace with your Space's API URL.

          payload = {'inputs': 'Your input data'}

          response = requests.post(API_URL, headers=headers, json=payload)
          if response.status_code == 200:
              print(response.json())
          else:
              print(f'Error: {response.status_code}, {response.text}')
          "
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
.gitignore
ADDED
@@ -0,0 +1,13 @@
__pycache__/
*.py[cod]
*$py.class
.env
.venv
env/
venv/
ENV/
database/data/*.db
fine_tuned_models/
.streamlit/secrets.toml
.ipynb_checkpoints/
.DS_Store
.replit
ADDED
@@ -0,0 +1,43 @@
modules = ["python-3.11"]

[nix]
channel = "stable-24_05"

[deployment]
deploymentTarget = "autoscale"
run = ["sh", "-c", "python -m streamlit run DataHubHub/app.py --server.address=0.0.0.0 --server.port=5000 --server.headless=true --server.enableCORS=false --server.enableXsrfProtection=false"]

[workflows]
runButton = "Project"

[[workflows.workflow]]
name = "Project"
mode = "parallel"
author = "agent"

[[workflows.workflow.tasks]]
task = "workflow.run"
args = "Streamlit Server"

[[workflows.workflow]]
name = "Streamlit Server"
author = "agent"

[workflows.workflow.metadata]
agentRequireRestartOnSave = false

[[workflows.workflow.tasks]]
task = "packager.installForAll"

[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python -m streamlit run DataHubHub/app.py --server.address=0.0.0.0 --server.port=5000 --server.headless=true --server.enableCORS=false --server.enableXsrfProtection=false"
waitForPort = 5000

[[ports]]
localPort = 5000
externalPort = 5000

[[ports]]
localPort = 8501
externalPort = 80
.streamlit/config.toml
ADDED
@@ -0,0 +1,9 @@
[server]
headless = true
address = "0.0.0.0"
port = 5000
enableCORS = true
enableXsrfProtection = false

[browser]
gatherUsageStats = false
.streamlit/secrets.toml
ADDED
@@ -0,0 +1,7 @@
# Secrets configuration for Hugging Face
# This file is for demonstration purposes only
# Replace with actual API keys if needed

[huggingface]
# Add your Hugging Face API token here if needed
# hf_token = "YOUR_HF_TOKEN"
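A minimal sketch (not part of this commit) of how application code might read this secret at runtime. It assumes the commented-out [huggingface] table and hf_token key shown above; the helper name and the broad exception guard are illustrative choices, not the project's actual API.

# Hypothetical helper: read the Hugging Face token from .streamlit/secrets.toml.
# Assumes the [huggingface] table and hf_token key shown above; returns None if absent.
import streamlit as st

def get_hf_token():
    try:
        return st.secrets["huggingface"]["hf_token"]
    except Exception:
        # No secrets file or no hf_token entry configured.
        return None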
README.md
CHANGED
@@ -1,13 +1,98 @@
 ---
 title: DataHubHub
 emoji: ⚡
-colorFrom:
+colorFrom: red
 colorTo: indigo
 sdk: streamlit
 sdk_version: 1.42.2
 app_file: app.py
 pinned: false
 license: apache-2.0
+language: en
 ---
 
-
+# ML Dataset & Code Generation Manager
+
+A comprehensive platform for ML dataset management and code generation with Hugging Face integration.
+
+## Features
+
+- **Dataset Management**: Upload, explore, and manage machine learning datasets
+- **Data Visualization**: Visualize dataset statistics and distributions
+- **Code Generation**: Fine-tune models for code generation tasks
+- **Code Quality Tools**: Improve code quality with integrated formatters, linters, and type checkers
+
+## Technology Stack
+
+- **Frontend**: Streamlit
+- **Backend**: Python
+- **Database**: SQLite (via SQLAlchemy)
+- **ML Integration**: Hugging Face Transformers, Datasets
+- **Visualization**: Plotly, Matplotlib
+
+## Project Structure
+
+```
+.
+├── app.py                          # Main application entry point
+├── components/                     # UI components
+│   ├── code_quality.py             # Code quality tools
+│   ├── dataset_preview.py          # Dataset preview component
+│   ├── dataset_statistics.py       # Dataset statistics component
+│   ├── dataset_uploader.py         # Dataset upload component
+│   ├── dataset_validation.py       # Dataset validation component
+│   ├── dataset_visualization.py    # Dataset visualization component
+│   └── fine_tuning/                # Fine-tuning components
+│       ├── finetune_ui.py          # Fine-tuning UI
+│       └── model_interface.py      # Model interface
+├── database/                       # Database configuration
+│   ├── models.py                   # Database models
+│   └── operations.py               # Database operations
+├── utils/                          # Utility functions
+│   ├── dataset_utils.py            # Dataset utilities
+│   ├── huggingface_integration.py  # Hugging Face integration
+│   └── smolagents_integration.py   # smolagents integration
+└── assets/                         # Static assets
+```
+
+## Deployment
+
+This application is designed to be deployed as a Hugging Face Space.
+
+### Hugging Face Space Deployment
+
+1. Fork this repository
+2. Create a new Hugging Face Space
+3. Connect the forked repository to your Space
+4. The application will be deployed automatically
+
+### Local Development
+
+1. Clone the repository
+2. Install dependencies:
+   ```
+   pip install streamlit pandas numpy plotly matplotlib scikit-learn SQLAlchemy huggingface-hub datasets transformers torch
+   ```
+3. Run the application:
+   ```
+   streamlit run app.py
+   ```
+
+## Configuration
+
+- `.streamlit/config.toml`: Streamlit configuration
+- `.streamlit/secrets.toml`: Secrets and API keys
+- `huggingface-spacefile`: Hugging Face Space configuration
+
+## API Keys
+
+To use the Hugging Face integration features, add your Hugging Face API token to `.streamlit/secrets.toml`:
+
+```toml
+[huggingface]
+hf_token = "YOUR_HF_TOKEN"
+```
+
+## License
+
+This project is licensed under the Apache 2.0 License - see the LICENSE file for details.
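Beyond connecting a forked repository, files can also be pushed to an existing Space programmatically. The sketch below is an assumption about one way to do that with huggingface_hub, not part of this commit; the Space id is a placeholder and HF_TOKEN must be a write-scoped token in the environment.

# Hypothetical deployment sketch: upload the working directory to an existing Space.
# "your-username/DataHubHub" is a placeholder Space id, not a real repository.
import os
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_folder(
    folder_path=".",
    repo_id="your-username/DataHubHub",
    repo_type="space",
    ignore_patterns=[".git/*", "fine_tuned_models/*"],
)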
app.py
ADDED
@@ -0,0 +1,11 @@
"""
ML Dataset & Code Generation Manager - Streamlit Application
This is the main entry point for the Streamlit application.
"""

# Import from main.py
from main import main

# Execute the main function
if __name__ == "__main__":
    main()
assets/custom.css
ADDED
@@ -0,0 +1,59 @@

/* Custom styles for ML Dataset & Code Generation Manager */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Space+Grotesk:wght@500;700&display=swap');

h1, h2, h3, h4, h5, h6 {
    font-family: 'Space Grotesk', sans-serif;
    font-weight: 700;
    color: #1A1C1F;
}

body {
    font-family: 'Inter', sans-serif;
    color: #1A1C1F;
    background-color: #F8F9FA;
}

.stButton button {
    background-color: #2563EB;
    color: white;
    border-radius: 4px;
    border: none;
    padding: 0.5rem 1rem;
    font-weight: 600;
}

.stButton button:hover {
    background-color: #1D4ED8;
}

/* Card styling */
.card {
    background-color: white;
    border-radius: 8px;
    padding: 1.5rem;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
    margin-bottom: 1rem;
}

/* Accent colors */
.accent-primary {
    color: #2563EB;
}

.accent-secondary {
    color: #84919A;
}

.accent-success {
    color: #10B981;
}

.accent-warning {
    color: #F59E0B;
}

.accent-danger {
    color: #EF4444;
}
attached_assets/Pasted-Below-is-a-design-proposal-for-a-Hugging-Face-based-system-that-lets-users-fine-tune-a-code-generati-1740904225626.txt
ADDED
@@ -0,0 +1,116 @@
Below is a design proposal for a Hugging Face–based system that lets users fine-tune a code generation model via a simple Streamlit interface.

Overview:
1. Model & Library Setup:
   • Use a pre-trained code generation model (e.g., CodeT5 or CodeT5-base) from Hugging Face.
   • Leverage the Hugging Face Transformers and Datasets libraries together with the Hugging Face Trainer API to perform fine-tuning.
2. Streamlit Interface:
   • Input Section: Users can upload a small dataset (e.g., a CSV file with code and target comments) or manually enter a few fine-tuning examples.
   • Hyperparameter Controls: Sliders or input boxes for settings like learning rate, number of epochs, batch size, and maybe even a choice of optimizer.
   • Execution Controls: Buttons to start fine-tuning and to monitor training progress (using, for example, real-time logging or a progress bar).
   • Output Section: Display training metrics (loss curves, evaluation scores) and allow users to run inference on new prompts once fine-tuning completes.
3. Back-end Process:
   • When the user initiates fine-tuning, the uploaded dataset is preprocessed (tokenization using the model's tokenizer).
   • A Trainer object is configured with the user-specified hyperparameters.
   • Fine-tuning is launched (this can run in a background thread or via caching intermediate results).
   • Once training is complete, the updated model can be saved to disk (or even directly loaded into the interface for inference).
4. Deployment & Reproducibility:
   • The whole pipeline (data upload, preprocessing, training, evaluation, and inference) should be reproducible.
   • Optionally, support saving the fine-tuned model and the training configuration to allow users to share their work.

Example Code Snippet (Simplified):

Below is a simplified version of what the Streamlit app might look like. (Note: In a production setup, you would want proper error handling and asynchronous processing.)

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
import torch

# Title
st.title("Fine-Tune Code Generation Model with Hugging Face & Streamlit")

# Sidebar: Hyperparameters
st.sidebar.header("Training Hyperparameters")
learning_rate = st.sidebar.slider("Learning Rate", 1e-6, 5e-5, 2e-5, 1e-6)
epochs = st.sidebar.number_input("Epochs", 1, 10, 3)
batch_size = st.sidebar.number_input("Batch Size", 4, 32, 8)

# Upload your fine-tuning data: CSV file with columns "input" and "target"
uploaded_file = st.file_uploader("Upload your fine-tuning dataset (CSV)", type="csv")

if uploaded_file is not None:
    import pandas as pd
    df = pd.read_csv(uploaded_file)
    st.write("Dataset preview:", df.head())
    # Convert to Hugging Face Dataset
    dataset = Dataset.from_pandas(df)
else:
    st.info("Please upload a CSV dataset with columns 'input' and 'target'.")

# Model selection
model_name = st.selectbox("Choose a model", ["Salesforce/codet5-base"])

# Load model and tokenizer
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer(name):
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSeq2SeqLM.from_pretrained(name)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer(model_name)

# Preprocess function for tokenization
def preprocess_function(examples):
    inputs = [f"translate code to comment: {ex}" for ex in examples["input"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["target"], max_length=64, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

if uploaded_file is not None:
    tokenized_dataset = dataset.map(preprocess_function, batched=True)

    # Setup training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        logging_steps=10,
        logging_dir='./logs',
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    if st.button("Start Fine-Tuning"):
        st.info("Fine-tuning started... This might take a while.")
        trainer.train()
        st.success("Fine-tuning complete!")

        # Save the model to disk (or load it for inference)
        model.save_pretrained("fine_tuned_model")
        tokenizer.save_pretrained("fine_tuned_model")
        st.write("Model saved to 'fine_tuned_model'.")

# Option to run inference on new inputs
user_input = st.text_area("Enter a new code prompt for inference:")
if user_input:
    inputs = tokenizer(f"translate code to comment: {user_input}", return_tensors="pt", truncation=True)
    outputs = model.generate(**inputs, max_length=64)
    generated_comment = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.write("Generated comment:", generated_comment)

Key Points:
   • User Interaction: The interface lets users set hyperparameters, upload datasets, and start fine-tuning.
   • Model Integration: It uses Hugging Face's pre-trained CodeT5 model and tokenizer, then fine-tunes on user-provided examples.
   • Reproducibility: The pipeline includes caching, dataset conversion, and saving the final model.
   • Extensibility: You can later add more options (e.g., additional hyperparameters, evaluation metrics, visualization of training progress).

This design should give you a robust, end-to-end solution to let users easily fine-tune a code generation model through a Streamlit interface. Would you like further details on any component of the design?
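The proposal mentions launching fine-tuning "in a background thread" without showing it. The following is a minimal, hedged sketch of that idea, assuming the trainer object from the snippet above; it is illustrative only and omits the thread-safety caveats that apply when worker threads touch Streamlit APIs.

# Hypothetical background-training sketch built around the `trainer` defined above.
import threading
import streamlit as st

def run_training(trainer, status):
    # Runs in a worker thread; only mutates a plain dict, never Streamlit state.
    trainer.train()
    status["finished"] = True

if st.button("Start Fine-Tuning (background)"):
    status = {"finished": False}
    worker = threading.Thread(target=run_training, args=(trainer, status), daemon=True)
    worker.start()
    st.session_state["training_status"] = status
    st.info("Fine-tuning launched in a background thread; rerun the app to check progress.")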
attached_assets/Pasted-For-a-robust-foundation-you-ll-want-to-configure-a-set-of-tools-that-catch-errors-early-maintain-c-1740904212802.txt
ADDED
@@ -0,0 +1,19 @@
For a robust foundation, you'll want to configure a set of tools that catch errors early, maintain code quality, and ensure your model and interface work as intended. Here's a comprehensive list:
• Linting and Formatting:
   • Pylint or Flake8: Both are excellent for catching stylistic issues, potential errors, and enforcing coding standards.
   • Black: An uncompromising code formatter that automatically reformats your code to a consistent style.
   • isort: Automatically sorts your imports, keeping them tidy and making merge conflicts less likely.
   • mypy: For static type checking—great for catching type mismatches early, especially in larger projects.
• Debugging:
   • pdb or ipdb: Python's built-in debugger (with ipdb providing a friendlier interface) lets you step through code interactively.
   • VS Code Debugger: If you're using VS Code, take advantage of its powerful debugging features with breakpoints, variable inspection, and integrated terminal support.
   • Streamlit's Debugging Tools: Streamlit now offers logging and error traceback views—integrate these for your interface to catch issues on the fly.
• Testing:
   • pytest: A flexible testing framework that supports fixtures and parameterized tests. It's widely used for both unit and integration tests.
   • unittest: Python's built-in framework for basic tests (though pytest often provides a more modern and user-friendly approach).
   • coverage.py: To measure how much of your code is exercised by your tests, ensuring thorough test coverage.
   • Tox: For running your tests in multiple environments, which is useful if your project depends on various Python versions or dependencies.
• Continuous Integration (CI):
   • GitHub Actions or GitLab CI: Automate your linting, testing, and building processes so that every commit triggers your checks—keeping your repository healthy over time.

Setting these up at the beginning ensures that your code stays clean, errors are caught early, and your automated code generation pipeline is both reliable and production-ready. This not only speeds up development but also builds a solid foundation for scaling and collaboration.
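As a small, hedged illustration of chaining the tools listed above before a commit (a local complement to the CI point), here is a sketch in Python. It assumes the tools are installed and that tests live under tests/; none of this script is part of the commit itself.

# Hypothetical pre-commit style check script: runs the recommended formatters,
# linters, and tests in sequence and stops at the first failure.
import subprocess
import sys

CHECKS = [
    ["black", "--check", "."],
    ["isort", "--check-only", "."],
    ["flake8", "."],
    ["mypy", "."],
    ["pytest", "-q", "tests"],
]

for cmd in CHECKS:
    print("Running:", " ".join(cmd))
    if subprocess.run(cmd).returncode != 0:
        sys.exit(f"Check failed: {' '.join(cmd)}")
print("All checks passed.")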
attached_assets/Pasted-For-a-robust-foundation-you-ll-want-to-configure-a-set-of-tools-that-catch-errors-early-maintain-c-1740906031222.txt
ADDED
@@ -0,0 +1,19 @@
For a robust foundation, you'll want to configure a set of tools that catch errors early, maintain code quality, and ensure your model and interface work as intended. Here's a comprehensive list:
• Linting and Formatting:
   • Pylint or Flake8: Both are excellent for catching stylistic issues, potential errors, and enforcing coding standards.
   • Black: An uncompromising code formatter that automatically reformats your code to a consistent style.
   • isort: Automatically sorts your imports, keeping them tidy and making merge conflicts less likely.
   • mypy: For static type checking—great for catching type mismatches early, especially in larger projects.
• Debugging:
   • pdb or ipdb: Python's built-in debugger (with ipdb providing a friendlier interface) lets you step through code interactively.
   • VS Code Debugger: If you're using VS Code, take advantage of its powerful debugging features with breakpoints, variable inspection, and integrated terminal support.
   • Streamlit's Debugging Tools: Streamlit now offers logging and error traceback views—integrate these for your interface to catch issues on the fly.
• Testing:
   • pytest: A flexible testing framework that supports fixtures and parameterized tests. It's widely used for both unit and integration tests.
   • unittest: Python's built-in framework for basic tests (though pytest often provides a more modern and user-friendly approach).
   • coverage.py: To measure how much of your code is exercised by your tests, ensuring thorough test coverage.
   • Tox: For running your tests in multiple environments, which is useful if your project depends on various Python versions or dependencies.
• Continuous Integration (CI):
   • GitHub Actions or GitLab CI: Automate your linting, testing, and building processes so that every commit triggers your checks—keeping your repository healthy over time.

Setting these up at the beginning ensures that your code stays clean, errors are caught early, and your automated code generation pipeline is both reliable and production-ready. This not only speeds up development but also builds a solid foundation for scaling and collaboration.
components/code_quality.py
ADDED
@@ -0,0 +1,347 @@
"""
Code quality tools and configuration for the application.
"""
import streamlit as st
import subprocess
import os
from pathlib import Path
import tempfile
import json

def render_code_quality_tools():
    """
    Render the code quality tools interface.
    """
    st.markdown("<h2>Code Quality Tools</h2>", unsafe_allow_html=True)

    # Tabs for different tools
    tab1, tab2, tab3, tab4 = st.tabs(["Linting", "Formatting", "Type Checking", "Testing"])

    with tab1:
        render_linting_tools()

    with tab2:
        render_formatting_tools()

    with tab3:
        render_type_checking_tools()

    with tab4:
        render_testing_tools()

def render_linting_tools():
    """
    Render linting tools interface.
    """
    st.markdown("### Linting with Pylint/Flake8")
    st.markdown("""
    Linting tools help identify potential errors, enforce coding standards, and encourage best practices.

    **Available Tools:**
    - **Pylint**: Comprehensive linter that checks for errors and enforces a coding standard
    - **Flake8**: Wrapper around PyFlakes, pycodestyle, and McCabe complexity checker
    """)

    # File upload for linting
    uploaded_file = st.file_uploader("Upload Python file for linting", type=["py"])

    linter = st.radio("Select linter", ["Pylint", "Flake8"])

    if uploaded_file and st.button("Run Linter"):
        with st.spinner(f"Running {linter}..."):
            # Save uploaded file to a temporary file
            with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_path = tmp_file.name

            try:
                if linter == "Pylint":
                    # Run pylint
                    result = subprocess.run(
                        ["pylint", tmp_path],
                        capture_output=True,
                        text=True
                    )
                else:
                    # Run flake8
                    result = subprocess.run(
                        ["flake8", tmp_path],
                        capture_output=True,
                        text=True
                    )

                # Display results
                st.subheader("Linting Results")
                if result.returncode == 0:
                    st.success("No issues found!")
                else:
                    st.error("Issues found:")
                    st.code(result.stdout or result.stderr, language="text")

            except Exception as e:
                st.error(f"Error running {linter}: {str(e)}")

            finally:
                # Clean up temporary file
                os.unlink(tmp_path)

def render_formatting_tools():
    """
    Render code formatting tools interface.
    """
    st.markdown("### Code Formatting with Black & isort")
    st.markdown("""
    Code formatters automatically reformat your code to follow a consistent style.

    **Available Tools:**
    - **Black**: The uncompromising Python code formatter
    - **isort**: A utility to sort imports alphabetically and automatically separate them into sections
    """)

    # File upload for formatting
    uploaded_file = st.file_uploader("Upload Python file for formatting", type=["py"])

    formatter = st.radio("Select formatter", ["Black", "isort", "Both"])

    if uploaded_file and st.button("Format Code"):
        with st.spinner(f"Running {formatter}..."):
            # Get original code
            original_code = uploaded_file.getvalue().decode("utf-8")

            # Save uploaded file to a temporary file
            with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_path = tmp_file.name

            try:
                formatted_code = ""

                if formatter in ["Black", "Both"]:
                    # Run black
                    result = subprocess.run(
                        ["black", tmp_path],
                        capture_output=True,
                        text=True
                    )

                    with open(tmp_path, "r") as f:
                        formatted_code = f.read()

                if formatter in ["isort", "Both"]:
                    # If both, use the code formatted by black
                    if formatter == "Both":
                        with open(tmp_path, "w") as f:
                            f.write(formatted_code)

                    # Run isort
                    result = subprocess.run(
                        ["isort", tmp_path],
                        capture_output=True,
                        text=True
                    )

                    with open(tmp_path, "r") as f:
                        formatted_code = f.read()

                # Display results side by side
                st.subheader("Formatting Results")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("#### Original Code")
                    st.code(original_code, language="python")

                with col2:
                    st.markdown("#### Formatted Code")
                    st.code(formatted_code, language="python")

            except Exception as e:
                st.error(f"Error running {formatter}: {str(e)}")

            finally:
                # Clean up temporary file
                os.unlink(tmp_path)

def render_type_checking_tools():
    """
    Render type checking tools interface.
    """
    st.markdown("### Type Checking with mypy")
    st.markdown("""
    Static type checking helps catch type errors before runtime.

    **Available Tool:**
    - **mypy**: Optional static typing for Python
    """)

    # File upload for type checking
    uploaded_file = st.file_uploader("Upload Python file for type checking", type=["py"])

    if uploaded_file and st.button("Check Types"):
        with st.spinner("Running mypy..."):
            # Save uploaded file to a temporary file
            with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_path = tmp_file.name

            try:
                # Run mypy
                result = subprocess.run(
                    ["mypy", tmp_path],
                    capture_output=True,
                    text=True
                )

                # Display results
                st.subheader("Type Checking Results")
                if result.returncode == 0:
                    st.success("No type issues found!")
                else:
                    st.error("Type issues found:")
                    st.code(result.stdout or result.stderr, language="text")

            except Exception as e:
                st.error(f"Error running mypy: {str(e)}")

            finally:
                # Clean up temporary file
                os.unlink(tmp_path)

def render_testing_tools():
    """
    Render testing tools interface.
    """
    st.markdown("### Testing with pytest")
    st.markdown("""
    Testing frameworks help ensure your code works as expected.

    **Available Tool:**
    - **pytest**: Simple and powerful testing framework
    """)

    # Test file upload
    test_file = st.file_uploader("Upload test file", type=["py"])

    # Code file upload (optional)
    code_file = st.file_uploader("Upload code file to test (optional)", type=["py"])

    if test_file and st.button("Run Tests"):
        with st.spinner("Running tests..."):
            # Create temporary directory for test files
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Save test file
                test_path = os.path.join(tmp_dir, "test_" + test_file.name)
                with open(test_path, "wb") as f:
                    f.write(test_file.getvalue())

                # Save code file if provided
                if code_file:
                    code_path = os.path.join(tmp_dir, code_file.name)
                    with open(code_path, "wb") as f:
                        f.write(code_file.getvalue())

                try:
                    # Run pytest
                    result = subprocess.run(
                        ["pytest", "-v", test_path],
                        capture_output=True,
                        text=True
                    )

                    # Display results
                    st.subheader("Test Results")
                    st.code(result.stdout, language="text")

                    if result.returncode == 0:
                        st.success("All tests passed!")
                    else:
                        st.error("Some tests failed.")

                except Exception as e:
                    st.error(f"Error running tests: {str(e)}")

def create_pylintrc():
    """
    Create a sample pylintrc configuration file.
    """
    pylintrc = """[MASTER]
# Python version
py-version = 3.8

# Parallel processing
jobs = 1

[MESSAGES CONTROL]
# Disable specific messages
disable=
    C0111,  # missing-docstring
    C0103,  # invalid-name
    R0903,  # too-few-public-methods
    R0913,  # too-many-arguments
    W0511,  # fixme

[FORMAT]
# Maximum line length
max-line-length = 100

# Expected indentation
indent-string = '    '

[DESIGN]
# Maximum number of locals for function / method body
max-locals = 15

# Maximum number of arguments for function / method
max-args = 5

# Maximum number of attributes for a class
max-attributes = 7
"""
    return pylintrc

def create_flake8_config():
    """
    Create a sample flake8 configuration file.
    """
    flake8_config = """[flake8]
max-line-length = 100
exclude = .git,__pycache__,build,dist
ignore =
    E203,  # whitespace before ':'
    E501,  # line too long
    W503   # line break before binary operator
"""
    return flake8_config

def create_mypy_config():
    """
    Create a sample mypy configuration file.
    """
    mypy_config = """[mypy]
python_version = 3.8
warn_return_any = True
warn_unused_configs = True
disallow_untyped_defs = False
disallow_incomplete_defs = False

[mypy.plugins.numpy.*]
follow_imports = skip

[mypy.plugins.pandas.*]
follow_imports = skip
"""
    return mypy_config

def create_pytest_config():
    """
    Create a sample pytest configuration file.
    """
    pytest_config = """[pytest]
testpaths = tests
python_files = test_*.py
python_functions = test_*
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    integration: marks tests as integration tests
"""
    return pytest_config
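The create_*_config() helpers at the end of this file are defined but not called anywhere within it. A hedged sketch of one way the UI might expose them, via st.download_button, follows; the function and module names are taken from the file above, but wiring them up this way is an assumption, not part of the commit.

# Hypothetical wiring for the config generators defined in components/code_quality.py.
import streamlit as st
from components.code_quality import (
    create_pylintrc,
    create_flake8_config,
    create_mypy_config,
    create_pytest_config,
)

def render_config_downloads():
    """Offer the sample tool configurations as downloadable files."""
    st.markdown("### Download sample configurations")
    for label, content, file_name in [
        ("pylint", create_pylintrc(), ".pylintrc"),
        ("flake8", create_flake8_config(), ".flake8"),
        ("mypy", create_mypy_config(), "mypy.ini"),
        ("pytest", create_pytest_config(), "pytest.ini"),
    ]:
        st.download_button(f"Download {label} config", content, file_name=file_name)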
components/dataset_preview.py
ADDED
@@ -0,0 +1,75 @@
import streamlit as st
import pandas as pd
import json

def render_dataset_preview(dataset, dataset_type):
    """
    Renders a preview of the dataset with pagination options.

    Args:
        dataset: The dataset to preview (pandas DataFrame)
        dataset_type: The type of dataset (csv, json, etc.)
    """
    if dataset is None:
        st.warning("No dataset to preview.")
        return

    st.markdown(f"<h3>Dataset Preview: {st.session_state.dataset_name}</h3>", unsafe_allow_html=True)

    # Show basic info
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Rows", f"{dataset.shape[0]:,}")
    with col2:
        st.metric("Columns", f"{dataset.shape[1]:,}")
    with col3:
        st.metric("Type", dataset_type.upper())

    # Preview options
    col1, col2 = st.columns([1, 3])
    with col1:
        num_rows = st.number_input("Rows to display", min_value=5, max_value=100, value=10, step=5)
    with col2:
        preview_mode = st.radio("Preview mode", ["Head", "Tail", "Sample"], horizontal=True)

    # Display dataset preview
    st.markdown("<div class='dataset-preview'>", unsafe_allow_html=True)

    if preview_mode == "Head":
        st.dataframe(dataset.head(num_rows), use_container_width=True)
    elif preview_mode == "Tail":
        st.dataframe(dataset.tail(num_rows), use_container_width=True)
    else:  # Sample
        st.dataframe(dataset.sample(min(num_rows, len(dataset))), use_container_width=True)

    st.markdown("</div>", unsafe_allow_html=True)

    # Show dataset schema
    with st.expander("Dataset Schema"):
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("**Column Types**")
            type_df = pd.DataFrame({
                'Column': dataset.dtypes.index,
                'Type': dataset.dtypes.values.astype(str)
            })
            st.dataframe(type_df, use_container_width=True)

        with col2:
            st.markdown("**Missing Values**")
            missing_df = pd.DataFrame({
                'Column': dataset.columns,
                'Missing': dataset.isna().sum().values,
                'Percentage': dataset.isna().sum().values / len(dataset) * 100
            })
            st.dataframe(missing_df.style.format({
                'Percentage': '{:.2f}%'
            }), use_container_width=True)

    # Raw data
    with st.expander("Raw Data (First 5 records)"):
        if dataset_type == 'csv':
            st.code(dataset.head(5).to_csv(index=False), language="text")
        else:  # json or jsonl
            st.code(dataset.head(5).to_json(orient='records', indent=2), language="json")
components/dataset_statistics.py
ADDED
@@ -0,0 +1,149 @@
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

def render_dataset_statistics(dataset, dataset_type):
    """
    Renders statistical analysis of the dataset.

    Args:
        dataset: The dataset to analyze (pandas DataFrame)
        dataset_type: The type of dataset (csv, json, etc.)
    """
    if dataset is None:
        st.warning("No dataset to analyze.")
        return

    st.markdown("<h3>Dataset Statistics</h3>", unsafe_allow_html=True)

    # Tabs for different kinds of statistics
    tab1, tab2, tab3 = st.tabs(["Summary Statistics", "Distribution Analysis", "Correlation Analysis"])

    with tab1:
        # Summary statistics
        st.markdown("### Summary Statistics")

        # Filter only numeric columns for statistics
        numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()

        if numeric_cols:
            # Display summary statistics
            st.dataframe(dataset[numeric_cols].describe().T.style.highlight_max(axis=1, color='#FFD21E'), use_container_width=True)

            # Top values for categorical columns
            categorical_cols = dataset.select_dtypes(exclude=[np.number]).columns.tolist()
            if categorical_cols:
                st.markdown("### Category Value Counts")
                selected_cat_col = st.selectbox("Select categorical column", categorical_cols)

                # Show top values and their counts
                value_counts = dataset[selected_cat_col].value_counts().head(10)
                fig = px.bar(
                    x=value_counts.index,
                    y=value_counts.values,
                    title=f"Top 10 values in {selected_cat_col}",
                    labels={"x": selected_cat_col, "y": "Count"},
                    color_discrete_sequence=["#2563EB"]
                )
                st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("No numeric columns found in the dataset.")

    with tab2:
        # Distribution analysis
        st.markdown("### Distribution Analysis")

        if numeric_cols:
            selected_num_col = st.selectbox("Select numeric column", numeric_cols)

            # Create distribution plot
            fig = px.histogram(
                dataset,
                x=selected_num_col,
                title=f"Distribution of {selected_num_col}",
                marginal="box",
                color_discrete_sequence=["#FFD21E"],
                template="simple_white"
            )
            st.plotly_chart(fig, use_container_width=True)

            # Basic distribution stats
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Mean", f"{dataset[selected_num_col].mean():.2f}")
            with col2:
                st.metric("Median", f"{dataset[selected_num_col].median():.2f}")
            with col3:
                st.metric("Min", f"{dataset[selected_num_col].min():.2f}")
            with col4:
                st.metric("Max", f"{dataset[selected_num_col].max():.2f}")
        else:
            st.warning("No numeric columns found in the dataset.")

    with tab3:
        # Correlation analysis
        st.markdown("### Correlation Analysis")

        if len(numeric_cols) > 1:
            # Compute correlation matrix
            corr_matrix = dataset[numeric_cols].corr()

            # Plot heatmap
            fig = px.imshow(
                corr_matrix,
                color_continuous_scale=["#84919A", "#FFFFFF", "#FFD21E"],
                title="Correlation Matrix",
                template="simple_white"
            )
            st.plotly_chart(fig, use_container_width=True)

            # Top correlated features
            st.markdown("### Top Correlated Features")

            # Convert correlation matrix to a long format
            corr_pairs = []
            for i in range(len(corr_matrix.columns)):
                for j in range(i+1, len(corr_matrix.columns)):
                    col1 = corr_matrix.columns[i]
                    col2 = corr_matrix.columns[j]
                    corr_value = corr_matrix.iloc[i, j]
                    corr_pairs.append((col1, col2, corr_value))

            # Sort by absolute correlation
            corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)

            # Display top 10 correlated pairs
            if corr_pairs:
                top_pairs = pd.DataFrame(corr_pairs[:10], columns=["Feature 1", "Feature 2", "Correlation"])
                st.dataframe(
                    top_pairs.style.format({
                        "Correlation": "{:.4f}"
                    }).background_gradient(subset=["Correlation"], cmap="coolwarm"),
                    use_container_width=True
                )

            # Scatter plot for the top correlated pair
            if corr_pairs:
                top_pair = corr_pairs[0]
                fig = px.scatter(
                    dataset,
                    x=top_pair[0],
                    y=top_pair[1],
                    title=f"Scatter plot: {top_pair[0]} vs {top_pair[1]} (Corr: {top_pair[2]:.4f})",
                    color_discrete_sequence=["#2563EB"],
                    template="simple_white"
                )
                fig.add_traces(
                    go.Scatter(
                        x=[None],
                        y=[None],
                        mode='lines',
                        line=dict(color="#FFD21E", width=3),
                        name='Best Fit'
                    )
                )
                st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("Need at least two numeric columns for correlation analysis.")
components/dataset_uploader.py
ADDED
@@ -0,0 +1,113 @@
import streamlit as st
import pandas as pd
import json
import io
from utils.dataset_utils import get_dataset_info, detect_dataset_format

def render_dataset_uploader():
    """
    Renders the dataset upload component that supports CSV and JSON formats.
    """
    st.markdown("""
    <div class="upload-container">
        <p>Upload your dataset in CSV or JSON format</p>
    </div>
    """, unsafe_allow_html=True)

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose a file",
        type=["csv", "json"],
        help="Upload a CSV or JSON file containing your dataset"
    )

    # Sample dataset option
    st.markdown("Or use a sample dataset:")
    sample_dataset = st.selectbox(
        "Select a sample dataset",
        ["None", "Iris Dataset", "Titanic Dataset", "Boston Housing Dataset"]
    )

    # Process uploaded file
    if uploaded_file is not None:
        try:
            # Check file extension
            file_extension = uploaded_file.name.split(".")[-1].lower()

            if file_extension == "csv":
                df = pd.read_csv(uploaded_file)
                dataset_type = "csv"
            elif file_extension == "json":
                # Try different JSON formats
                try:
                    # First try parsing as a regular JSON with records orientation
                    df = pd.read_json(uploaded_file)
                    dataset_type = "json"
                except:
                    # If that fails, try to parse as JSON Lines
                    try:
                        df = pd.read_json(uploaded_file, lines=True)
                        dataset_type = "jsonl"
                    except:
                        # If that also fails, load raw JSON and convert
                        content = json.loads(uploaded_file.getvalue().decode("utf-8"))
                        if isinstance(content, list):
                            df = pd.DataFrame(content)
                        elif isinstance(content, dict):
                            # Handle nested dict structures
                            if any(isinstance(v, list) for v in content.values()):
                                # Find the list field and use it
                                for key, value in content.items():
                                    if isinstance(value, list):
                                        df = pd.DataFrame(value)
                                        break
                            else:
                                # Flat dict or dict of dicts
                                df = pd.DataFrame([content])
                        dataset_type = "json"
            else:
                st.error(f"Unsupported file format: {file_extension}")
                return

            # Store dataset and its info in session state
            st.session_state.dataset = df
            st.session_state.dataset_name = uploaded_file.name
            st.session_state.dataset_type = dataset_type
            st.session_state.dataset_info = get_dataset_info(df)

        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")

    # Process sample dataset
    elif sample_dataset != "None":
        try:
            if sample_dataset == "Iris Dataset":
                # Load Iris dataset
                from sklearn.datasets import load_iris
                iris = load_iris()
                df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
                df['target'] = iris.target
                dataset_type = "csv"

            elif sample_dataset == "Titanic Dataset":
                # URL for Titanic dataset
                url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
                df = pd.read_csv(url)
                dataset_type = "csv"

            elif sample_dataset == "Boston Housing Dataset":
                # Load Boston Housing dataset
                from sklearn.datasets import fetch_california_housing
                housing = fetch_california_housing()
                df = pd.DataFrame(data=housing.data, columns=housing.feature_names)
                df['target'] = housing.target
                dataset_type = "csv"

            # Store dataset and its info in session state
            st.session_state.dataset = df
            st.session_state.dataset_name = sample_dataset
            st.session_state.dataset_type = dataset_type
            st.session_state.dataset_info = get_dataset_info(df)

        except Exception as e:
            st.error(f"Error loading sample dataset: {str(e)}")
components/dataset_validation.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
import pandas as pd
import numpy as np
import json
from utils.dataset_utils import check_column_completeness, detect_outliers

def render_dataset_validation(dataset, dataset_type):
    """
    Renders validation checks for the dataset.

    Args:
        dataset: The dataset to validate (pandas DataFrame)
        dataset_type: The type of dataset (csv, json, etc.)
    """
    if dataset is None:
        st.warning("No dataset to validate.")
        return

    st.markdown("<h3>Dataset Validation</h3>", unsafe_allow_html=True)

    # Data quality metrics
    col1, col2, col3, col4 = st.columns(4)

    # Calculate data quality metrics
    total_cells = dataset.shape[0] * dataset.shape[1]
    missing_cells = dataset.isna().sum().sum()
    missing_percentage = (missing_cells / total_cells) * 100 if total_cells > 0 else 0
    duplicate_rows = dataset.duplicated().sum()
    duplicate_percentage = (duplicate_rows / dataset.shape[0]) * 100 if dataset.shape[0] > 0 else 0

    with col1:
        st.metric("Completeness", f"{100 - missing_percentage:.2f}%")
    with col2:
        st.metric("Missing Values", f"{missing_cells:,} ({missing_percentage:.2f}%)")
    with col3:
        st.metric("Duplicate Rows", f"{duplicate_rows:,} ({duplicate_percentage:.2f}%)")
    with col4:
        # Quality score is a simple metric between 0-100 based on completeness and duplicates
        quality_score = 100 - (missing_percentage + duplicate_percentage)
        quality_score = max(0, min(100, quality_score))  # Clamp between 0 and 100
        st.metric("Quality Score", f"{quality_score:.2f}/100")

    # Tabs for different validation aspects
    tab1, tab2 = st.tabs(["Data Quality Issues", "Anomaly Detection"])

    with tab1:
        st.markdown("### Data Quality Issues")

        # Check for missing values by column
        missing_by_col = dataset.isna().sum()
        missing_by_col = missing_by_col[missing_by_col > 0]

        if not missing_by_col.empty:
            st.markdown("#### Missing Values by Column")
            missing_df = pd.DataFrame({
                'Column': missing_by_col.index,
                'Missing Count': missing_by_col.values,
                'Percentage': (missing_by_col.values / dataset.shape[0] * 100).round(2)
            })
            missing_df['Status'] = missing_df['Percentage'].apply(
                lambda x: "🟢 Good" if x < 5 else ("🟠 Warning" if x < 20 else "🔴 Critical")
            )

            st.dataframe(
                missing_df.style.format({
                    'Percentage': '{:.2f}%'
                }).background_gradient(subset=['Percentage'], cmap='Reds'),
                use_container_width=True
            )
        else:
            st.success("No missing values found in the dataset!")

        # Check for duplicate rows
        if duplicate_rows > 0:
            st.markdown("#### Duplicate Rows")
            st.warning(f"Found {duplicate_rows} duplicate rows ({duplicate_percentage:.2f}% of the dataset)")

            # Option to show duplicates
            if st.checkbox("Show duplicates"):
                st.dataframe(dataset[dataset.duplicated(keep='first')], use_container_width=True)
        else:
            st.success("No duplicate rows found in the dataset!")

        # Check column data types
        st.markdown("#### Column Data Types")
        type_issues = []

        for col in dataset.columns:
            dtype = dataset[col].dtype
            if dtype == 'object':
                # Check if it could be numeric
                try:
                    # Try to convert a sample to numeric
                    sample = dataset[col].dropna().head(100)
                    if len(sample) > 0:
                        numeric_count = pd.to_numeric(sample, errors='coerce').notna().sum()
                        if numeric_count / len(sample) > 0.8:  # If more than 80% can be converted
                            type_issues.append({
                                'Column': col,
                                'Current Type': 'object',
                                'Suggested Type': 'numeric',
                                'Issue': 'Column contains mostly numeric values but is stored as text'
                            })
                            continue
                except:
                    pass

                # Check if it could be datetime
                try:
                    sample = dataset[col].dropna().head(100)
                    if len(sample) > 0:
                        datetime_count = pd.to_datetime(sample, errors='coerce').notna().sum()
                        if datetime_count / len(sample) > 0.8:  # If more than 80% can be converted
                            type_issues.append({
                                'Column': col,
                                'Current Type': 'object',
                                'Suggested Type': 'datetime',
                                'Issue': 'Column contains mostly dates but is stored as text'
                            })
                except:
                    pass

        if type_issues:
            st.dataframe(pd.DataFrame(type_issues), use_container_width=True)
        else:
            st.success("No data type issues detected!")

        # Check for column completeness
        st.markdown("#### Column Completeness Check")
        completeness_results = check_column_completeness(dataset)
        if completeness_results:
            st.dataframe(pd.DataFrame(completeness_results), use_container_width=True)
        else:
            st.success("All columns have good completeness!")

    with tab2:
        st.markdown("### Anomaly Detection")

        # Detect outliers in numeric columns
        numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()

        if numeric_cols:
            selected_num_col = st.selectbox("Select column to check for outliers", numeric_cols)

            outliers, lower_bound, upper_bound = detect_outliers(dataset[selected_num_col])
            outlier_percentage = (len(outliers) / len(dataset)) * 100

            st.markdown(f"#### Outliers in column: {selected_num_col}")
            st.metric("Outliers Detected", f"{len(outliers)} ({outlier_percentage:.2f}%)")

            st.markdown(f"""
            **Bounds for outlier detection:**
            - Lower bound: {lower_bound:.4f}
            - Upper bound: {upper_bound:.4f}
            """)

            if len(outliers) > 0:
                # Plot with outliers highlighted
                import plotly.express as px

                # Create a new column for coloring
                temp_df = dataset.copy()
                temp_df['is_outlier'] = temp_df.index.isin(outliers)

                fig = px.box(
                    temp_df,
                    y=selected_num_col,
                    color='is_outlier',
                    color_discrete_map={True: "#FF5757", False: "#2563EB"},
                    title=f"Outliers in {selected_num_col}",
                    labels={"is_outlier": "Is Outlier"}
                )
                st.plotly_chart(fig, use_container_width=True)

                # Option to show outliers in table
                if st.checkbox("Show outlier data"):
                    st.dataframe(dataset.loc[outliers], use_container_width=True)
            else:
                st.success(f"No outliers detected in {selected_num_col}!")
        else:
            st.warning("No numeric columns found for outlier detection.")
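The validation tab above relies on detect_outliers from utils/dataset_utils.py (also part of this upload) returning an index of outlier rows plus the detection bounds. A minimal sketch of that contract, assuming the standard 1.5×IQR rule, is shown below; the shipped helper may differ in detail.

import pandas as pd

def detect_outliers_sketch(series: pd.Series):
    # Interquartile-range rule: flag values outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
    q1, q3 = series.quantile(0.25), series.quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    mask = (series < lower_bound) | (series > upper_bound)
    # Return index labels so callers can use dataset.loc[outliers] and index.isin(outliers).
    return series.index[mask], lower_bound, upper_bound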
components/dataset_version_control.py
ADDED
@@ -0,0 +1,276 @@
"""
Dataset version control UI component for the ML Dataset & Code Generation Manager.
Provides UI for viewing, comparing, and restoring dataset versions.
"""

import streamlit as st
import pandas as pd
import numpy as np
import datetime
import hashlib
import plotly.express as px
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

from database import version_control

def render_version_control_ui(dataset_id: int, df: Optional[pd.DataFrame] = None):
    """
    Render the version control UI for a dataset

    Args:
        dataset_id: ID of the dataset
        df: Current DataFrame of the dataset (optional)
    """
    st.header("Dataset Version Control")

    # Get all versions of the dataset
    versions = version_control.get_versions(dataset_id)

    if not versions:
        st.info("No versions found for this dataset. Save changes to create the first version.")

        if df is not None and st.button("Create Initial Version"):
            version = version_control.create_version(
                dataset_id=dataset_id,
                df=df,
                description="Initial version"
            )
            st.success(f"Created initial version: {version.version_id}")
            st.experimental_rerun()

        return

    # Display version history
    st.subheader("Version History")

    version_data = []
    for v in versions:
        version_data.append({
            "Version ID": v.version_id,
            "Date": v.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
            "Rows": v.metadata.get("rows", "N/A"),
            "Columns": v.metadata.get("columns", "N/A"),
            "Description": v.description
        })

    version_df = pd.DataFrame(version_data)
    st.dataframe(version_df, use_container_width=True)

    # Version actions section
    st.subheader("Version Actions")

    col1, col2 = st.columns(2)

    with col1:
        selected_version = st.selectbox(
            "Select Version",
            options=[v.version_id for v in versions],
            format_func=lambda x: f"{x} - {next((v.timestamp.strftime('%Y-%m-%d %H:%M:%S') for v in versions if v.version_id == x), '')}"
        )

        # Get selected version object
        selected_v = next((v for v in versions if v.version_id == selected_version), None)

        if selected_v:
            st.write(f"**Description:** {selected_v.description}")
            st.write(f"**Created:** {selected_v.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")

            # Display metadata
            if selected_v.metadata:
                with st.expander("Version Metadata"):
                    for key, value in selected_v.metadata.items():
                        if key != "column_names":  # Show column names separately
                            st.write(f"**{key}:** {value}")

                    if "column_names" in selected_v.metadata:
                        st.write("**Columns:**")
                        st.write(", ".join(selected_v.metadata["column_names"]))

    with col2:
        st.write("**Actions:**")

        if selected_v:
            # Load selected version
            if st.button("View Version Data"):
                version_df = version_control.load_version_data(selected_v)
                st.session_state["viewing_version_df"] = version_df
                st.session_state["viewing_version_id"] = selected_v.version_id

            # Restore version
            if st.button("Restore This Version"):
                if df is not None:
                    description = st.session_state.get("restore_description", f"Restored from {selected_v.version_id}")
                    new_version = version_control.restore_version(
                        dataset_id=dataset_id,
                        version_id=selected_v.version_id,
                        description=description
                    )
                    st.success(f"Restored version {selected_v.version_id} as new version {new_version.version_id}")
                    st.experimental_rerun()
                else:
                    st.error("Cannot restore version: No dataset provided")

        # Compare versions
        if len(versions) > 1:
            st.write("**Compare Versions:**")
            compare_v1 = st.selectbox("Version 1", options=[v.version_id for v in versions], key="compare_v1")
            compare_v2 = st.selectbox("Version 2", options=[v.version_id for v in versions], key="compare_v2")

            if st.button("Compare Versions"):
                if compare_v1 != compare_v2:
                    comparison = version_control.compare_versions(
                        dataset_id=dataset_id,
                        version_id1=compare_v1,
                        version_id2=compare_v2
                    )
                    st.session_state["version_comparison"] = comparison
                else:
                    st.warning("Please select different versions to compare")

    # Show version data if requested
    if "viewing_version_df" in st.session_state:
        st.subheader(f"Data for Version: {st.session_state['viewing_version_id']}")
        st.dataframe(st.session_state["viewing_version_df"], use_container_width=True)

        if st.button("Clear Version View"):
            del st.session_state["viewing_version_df"]
            del st.session_state["viewing_version_id"]
            st.experimental_rerun()

    # Show version comparison if requested
    if "version_comparison" in st.session_state:
        comparison = st.session_state["version_comparison"]
        st.subheader(f"Version Comparison")

        col1, col2 = st.columns(2)

        with col1:
            st.write(f"**Version 1:** {comparison['version1']}")
            st.write(f"**Date:** {comparison['version1_timestamp'].strftime('%Y-%m-%d %H:%M:%S')}")

        with col2:
            st.write(f"**Version 2:** {comparison['version2']}")
            st.write(f"**Date:** {comparison['version2_timestamp'].strftime('%Y-%m-%d %H:%M:%S')}")

        st.write(f"**Rows Changed:** {comparison['rows_diff']} ({'+' if comparison['rows_diff'] > 0 else ''}{comparison['rows_diff']})")

        if comparison["columns_added"]:
            st.write("**Columns Added:**")
            for col in comparison["columns_added"]:
                st.write(f"- {col}")

        if comparison["columns_removed"]:
            st.write("**Columns Removed:**")
            for col in comparison["columns_removed"]:
                st.write(f"- {col}")

        if comparison["columns_diff"]:
            st.write("**Columns Changed:**")
            for col, diff in comparison["columns_diff"].items():
                if diff.get("type_changed", False):
                    st.write(f"- {col}: Type changed from {diff['type1']} to {diff['type2']}")
                elif diff.get("values_changed", False):
                    st.write(f"- {col}: Values changed")

        if st.button("Clear Comparison"):
            del st.session_state["version_comparison"]
            st.experimental_rerun()

def render_save_version_ui(dataset_id: int, df: pd.DataFrame):
    """
    Render UI for saving a new version of a dataset

    Args:
        dataset_id: ID of the dataset
        df: DataFrame to save
    """
    st.subheader("Save Current Version")

    # Get latest version if any
    latest_version = version_control.get_latest_version(dataset_id)

    # Calculate changes if a previous version exists
    if latest_version:
        try:
            prev_df = version_control.load_version_data(latest_version)
            rows_diff = len(df) - len(prev_df)
            cols_diff = len(df.columns) - len(prev_df.columns)

            st.write(f"Changes from last version:")
            st.write(f"- Rows: {'+' if rows_diff > 0 else ''}{rows_diff}")
            st.write(f"- Columns: {'+' if cols_diff > 0 else ''}{cols_diff}")

            # Check content hash
            current_hash = hashlib.md5(df.to_json().encode()).hexdigest()[:8]
            if current_hash == latest_version.metadata.get("content_hash"):
                st.info("No changes detected in the data content since the last version.")
        except:
            st.warning("Could not compare with previous version.")

    # Input for version description
    description = st.text_area("Version Description", placeholder="Describe the changes in this version", key="version_description")

    # Save button
    if st.button("Save Version"):
        version = version_control.create_version(
            dataset_id=dataset_id,
            df=df,
            description=description
        )
        st.success(f"Created new version: {version.version_id}")

        return version

    return None

def render_version_visualization(dataset_id: int):
    """
    Render visualization of dataset versions

    Args:
        dataset_id: ID of the dataset
    """
    versions = version_control.get_versions(dataset_id)

    if not versions:
        st.info("No versions available to visualize.")
        return

    st.subheader("Version Metrics Visualization")

    # Prepare data for visualization
    viz_data = []
    for version in versions:
        viz_data.append({
            "Version": version.version_id[:8] + "...",  # Truncated ID for display
            "Date": version.timestamp,
            "Rows": version.metadata.get("rows", 0),
            "Columns": version.metadata.get("columns", 0),
            "Full Version ID": version.version_id,  # For tooltip
            "Description": version.description
        })

    viz_df = pd.DataFrame(viz_data)

    # Visualize row counts over versions
    fig1 = px.line(
        viz_df,
        x="Date",
        y="Rows",
        title="Dataset Size (Rows) Across Versions",
        markers=True,
        hover_data=["Full Version ID", "Description"]
    )
    st.plotly_chart(fig1, use_container_width=True)

    # Visualize column counts over versions
    fig2 = px.line(
        viz_df,
        x="Date",
        y="Columns",
        title="Dataset Structure (Columns) Across Versions",
        markers=True,
        hover_data=["Full Version ID", "Description"]
    )
    st.plotly_chart(fig2, use_container_width=True)
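A possible way to wire these three renderers into a dataset page is sketched below. The session key "current_dataset_id" and the page layout are illustrative assumptions; the render_* functions and the database.version_control backend come from the files in this upload.

import streamlit as st
from components.dataset_version_control import (
    render_version_control_ui,
    render_save_version_ui,
    render_version_visualization,
)

dataset_id = st.session_state.get("current_dataset_id")  # assumed session key
df = st.session_state.get("dataset")

if dataset_id is not None:
    if df is not None:
        render_save_version_ui(dataset_id, df)   # offer to snapshot the current data
    render_version_control_ui(dataset_id, df)    # browse, compare, and restore versions
    render_version_visualization(dataset_id)     # plot rows/columns across versions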
components/dataset_visualization.py
ADDED
@@ -0,0 +1,502 @@
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def render_dataset_visualization(dataset, dataset_type):
    """
    Renders visualizations for the dataset.

    Args:
        dataset: The dataset to visualize (pandas DataFrame)
        dataset_type: The type of dataset (csv, json, etc.)
    """
    if dataset is None:
        st.warning("No dataset to visualize.")
        return

    st.markdown("<h3>Dataset Visualization</h3>", unsafe_allow_html=True)

    # Get column types
    numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = dataset.select_dtypes(include=['object', 'category']).columns.tolist()
    date_cols = [col for col in dataset.columns if dataset[col].dtype == 'datetime64[ns]']

    # Add visualization options based on column types
    viz_type = st.selectbox(
        "Select visualization type",
        ["Distribution", "Correlation", "Categories", "Time Series", "Custom"],
        help="Choose the type of visualization to create"
    )

    if viz_type == "Distribution":
        if numeric_cols:
            # Select columns for distribution visualization
            selected_cols = st.multiselect(
                "Select columns to visualize",
                numeric_cols,
                default=numeric_cols[:min(3, len(numeric_cols))]
            )

            if not selected_cols:
                st.warning("Please select at least one column to visualize.")
                return

            # Distribution plots
            if len(selected_cols) == 1:
                # Single column histogram with density curve
                col = selected_cols[0]
                fig = px.histogram(
                    dataset,
                    x=col,
                    histnorm='probability density',
                    title=f"Distribution of {col}",
                    color_discrete_sequence=["#FFD21E"],
                    template="simple_white"
                )
                fig.add_traces(
                    go.Scatter(
                        x=dataset[col].sort_values(),
                        y=dataset[col].sort_values().reset_index(drop=True).rolling(
                            window=int(len(dataset[col])/10) if len(dataset[col]) > 10 else len(dataset[col]),
                            min_periods=1,
                            center=True
                        ).mean(),
                        mode='lines',
                        line=dict(color="#2563EB", width=3),
                        name='Smoothed'
                    )
                )
                st.plotly_chart(fig, use_container_width=True)
            else:
                # Multiple histograms in a grid
                num_cols = min(len(selected_cols), 2)
                num_rows = (len(selected_cols) + num_cols - 1) // num_cols

                fig = make_subplots(
                    rows=num_rows,
                    cols=num_cols,
                    subplot_titles=[f"Distribution of {col}" for col in selected_cols]
                )

                for i, col in enumerate(selected_cols):
                    row = i // num_cols + 1
                    col_pos = i % num_cols + 1

                    # Add histogram
                    fig.add_trace(
                        go.Histogram(
                            x=dataset[col],
                            name=col,
                            marker_color="#FFD21E"
                        ),
                        row=row, col=col_pos
                    )

                fig.update_layout(
                    title="Distribution of Selected Features",
                    showlegend=False,
                    template="simple_white",
                    height=300 * num_rows
                )
                st.plotly_chart(fig, use_container_width=True)

            # Show distribution statistics
            st.markdown("### Distribution Statistics")
            stats_df = dataset[selected_cols].describe().T
            st.dataframe(stats_df, use_container_width=True)
        else:
            st.warning("No numeric columns found for distribution visualization.")

    elif viz_type == "Correlation":
        if len(numeric_cols) >= 2:
            # Correlation matrix
            st.markdown("### Correlation Matrix")

            # Select columns for correlation
            selected_cols = st.multiselect(
                "Select columns for correlation analysis",
                numeric_cols,
                default=numeric_cols[:min(5, len(numeric_cols))]
            )

            if len(selected_cols) < 2:
                st.warning("Please select at least two columns for correlation analysis.")
                return

            # Compute correlation
            corr = dataset[selected_cols].corr()

            # Heatmap
            fig = px.imshow(
                corr,
                color_continuous_scale="RdBu_r",
                title="Correlation Matrix",
                template="simple_white",
                text_auto=True
            )
            st.plotly_chart(fig, use_container_width=True)

            # Scatter plot matrix for selected columns
            if len(selected_cols) > 2 and len(selected_cols) <= 5:  # Limit to 5 columns for readability
                st.markdown("### Scatter Plot Matrix")
                fig = px.scatter_matrix(
                    dataset,
                    dimensions=selected_cols,
                    color_discrete_sequence=["#2563EB"],
                    title="Scatter Plot Matrix",
                    template="simple_white"
                )
                fig.update_traces(diagonal_visible=False)
                st.plotly_chart(fig, use_container_width=True)

            # Correlation pairs as bar chart
            st.markdown("### Top Correlation Pairs")

            # Get correlation pairs
            corr_pairs = []
            for i in range(len(corr.columns)):
                for j in range(i+1, len(corr.columns)):
                    corr_pairs.append({
                        'Feature 1': corr.columns[i],
                        'Feature 2': corr.columns[j],
                        'Correlation': corr.iloc[i, j]
                    })

            # Sort by absolute correlation
            corr_pairs = sorted(corr_pairs, key=lambda x: abs(x['Correlation']), reverse=True)

            # Create bar chart
            if corr_pairs:
                # Convert to DataFrame
                corr_df = pd.DataFrame(corr_pairs)
                pair_labels = [f"{row['Feature 1']} & {row['Feature 2']}" for _, row in corr_df.iterrows()]

                # Bar chart
                fig = px.bar(
                    x=pair_labels,
                    y=[abs(c) for c in corr_df['Correlation']],
                    color=corr_df['Correlation'],
                    color_continuous_scale="RdBu_r",
                    labels={'x': 'Feature Pairs', 'y': 'Absolute Correlation'},
                    title="Top Feature Correlations"
                )
                st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("Need at least two numeric columns for correlation analysis.")

    elif viz_type == "Categories":
        if categorical_cols:
            # Select categorical column
            selected_cat = st.selectbox("Select categorical column", categorical_cols)

            # Category counts
            value_counts = dataset[selected_cat].value_counts()

            # Limit to top N categories if there are too many
            if len(value_counts) > 20:
                st.info(f"Showing top 20 categories out of {len(value_counts)}")
                value_counts = value_counts.head(20)

            # Bar chart
            fig = px.bar(
                x=value_counts.index,
                y=value_counts.values,
                title=f"Category Counts for {selected_cat}",
                labels={'x': selected_cat, 'y': 'Count'},
                color_discrete_sequence=["#FFD21E"]
            )
            st.plotly_chart(fig, use_container_width=True)

            # If there are numeric columns, show relationship with categorical
            if numeric_cols:
                st.markdown(f"### {selected_cat} vs Numeric Features")
                selected_num = st.selectbox("Select numeric column", numeric_cols)

                # Box plot
                fig = px.box(
                    dataset,
                    x=selected_cat,
                    y=selected_num,
                    title=f"{selected_cat} vs {selected_num}",
                    color_discrete_sequence=["#2563EB"],
                    template="simple_white"
                )
                st.plotly_chart(fig, use_container_width=True)

                # Statistics by category
                st.markdown(f"### Statistics of {selected_num} by {selected_cat}")
                stats_by_cat = dataset.groupby(selected_cat)[selected_num].describe()
                st.dataframe(stats_by_cat, use_container_width=True)
        else:
            st.warning("No categorical columns found for category visualization.")

    elif viz_type == "Time Series":
        # Check if there are potential date columns
        potential_date_cols = date_cols.copy()

        # Also check for object columns that might be dates
        for col in categorical_cols:
            # Sample the column to check if it contains date-like strings
            sample = dataset[col].dropna().head(5).tolist()
            if sample and all('/' in str(x) or '-' in str(x) for x in sample):
                potential_date_cols.append(col)

        if potential_date_cols:
            date_col = st.selectbox("Select date column", potential_date_cols)

            # Convert to datetime if it's not already
            if dataset[date_col].dtype != 'datetime64[ns]':
                try:
                    temp_df = dataset.copy()
                    temp_df[date_col] = pd.to_datetime(temp_df[date_col])
                except:
                    st.error(f"Could not convert {date_col} to datetime.")
                    return
            else:
                temp_df = dataset.copy()

            # Select numeric column for time series
            if numeric_cols:
                value_col = st.selectbox("Select value column", numeric_cols)

                # Aggregate by time period
                time_period = st.selectbox(
                    "Aggregate by",
                    ["Day", "Week", "Month", "Quarter", "Year"]
                )

                # Set up time grouping
                if time_period == "Day":
                    temp_df['period'] = temp_df[date_col].dt.date
                elif time_period == "Week":
                    temp_df['period'] = temp_df[date_col].dt.to_period('W').dt.start_time
                elif time_period == "Month":
                    temp_df['period'] = temp_df[date_col].dt.to_period('M').dt.start_time
                elif time_period == "Quarter":
                    temp_df['period'] = temp_df[date_col].dt.to_period('Q').dt.start_time
                else:  # Year
                    temp_df['period'] = temp_df[date_col].dt.year

                # Aggregate data
                agg_method = st.selectbox("Aggregation method", ["Mean", "Sum", "Min", "Max", "Count"])
                agg_map = {
                    "Mean": "mean",
                    "Sum": "sum",
                    "Min": "min",
                    "Max": "max",
                    "Count": "count"
                }

                time_series = temp_df.groupby('period')[value_col].agg(agg_map[agg_method]).reset_index()

                # Line chart
                fig = px.line(
                    time_series,
                    x='period',
                    y=value_col,
                    title=f"{agg_method} of {value_col} by {time_period}",
                    markers=True,
                    color_discrete_sequence=["#2563EB"],
                    template="simple_white"
                )
                fig.update_layout(
                    xaxis_title=time_period,
                    yaxis_title=f"{agg_method} of {value_col}"
                )
                st.plotly_chart(fig, use_container_width=True)

                # Show trendline option
                if st.checkbox("Show trendline"):
                    fig = px.scatter(
                        time_series,
                        x='period',
                        y=value_col,
                        trendline="ols",
                        title=f"{agg_method} of {value_col} by {time_period} with Trendline",
                        color_discrete_sequence=["#2563EB"],
                        template="simple_white"
                    )
                    fig.update_layout(
                        xaxis_title=time_period,
                        yaxis_title=f"{agg_method} of {value_col}"
                    )
                    st.plotly_chart(fig, use_container_width=True)

                # Table view of time series data
                st.dataframe(time_series, use_container_width=True)
            else:
                st.warning("No numeric columns found for time series values.")
        else:
            st.warning("No date columns found for time series visualization.")

    elif viz_type == "Custom":
        st.markdown("### Custom Visualization")
        st.info("Create a custom plot by selecting axes and plot type")

        # Select plot type
        plot_type = st.selectbox(
            "Select plot type",
            ["Scatter", "Line", "Bar", "Box", "Violin", "Histogram", "Pie", "3D Scatter"]
        )

        # Depending on the plot type, get required axes
        if plot_type in ["Scatter", "Line", "Bar", "3D Scatter"]:
            # For scatter/line/bar, we need x and y
            x_col = st.selectbox("X-axis", dataset.columns.tolist())
            y_col = st.selectbox("Y-axis", numeric_cols if numeric_cols else dataset.columns.tolist())

            # For 3D scatter, we need a z-axis
            if plot_type == "3D Scatter":
                z_col = st.selectbox("Z-axis", numeric_cols if numeric_cols else dataset.columns.tolist())

            # Optional color dimension
            use_color = st.checkbox("Add color dimension")
            color_col = None
            if use_color:
                color_col = st.selectbox("Color by", dataset.columns.tolist())

            # Create plot
            if plot_type == "Scatter":
                fig = px.scatter(
                    dataset,
                    x=x_col,
                    y=y_col,
                    color=color_col,
                    title=f"{y_col} vs {x_col}",
                    template="simple_white"
                )
            elif plot_type == "Line":
                fig = px.line(
                    dataset.sort_values(x_col),
                    x=x_col,
                    y=y_col,
                    color=color_col,
                    title=f"{y_col} vs {x_col}",
                    template="simple_white"
                )
            elif plot_type == "Bar":
                fig = px.bar(
                    dataset,
                    x=x_col,
                    y=y_col,
                    color=color_col,
                    title=f"{y_col} by {x_col}",
                    template="simple_white"
                )
            elif plot_type == "3D Scatter":
                fig = px.scatter_3d(
                    dataset,
                    x=x_col,
                    y=y_col,
                    z=z_col,
                    color=color_col,
                    title=f"3D Scatter: {x_col}, {y_col}, {z_col}",
                    template="simple_white"
                )

            st.plotly_chart(fig, use_container_width=True)

        elif plot_type in ["Box", "Violin"]:
            # For box/violin, we need x (categorical) and y (numeric)
            x_col = st.selectbox("X-axis (categories)", categorical_cols if categorical_cols else dataset.columns.tolist())
            y_col = st.selectbox("Y-axis (values)", numeric_cols if numeric_cols else dataset.columns.tolist())

            # Optional color dimension
            use_color = st.checkbox("Add color dimension")
            color_col = None
            if use_color:
                color_col = st.selectbox("Color by", dataset.columns.tolist())

            # Create plot
            if plot_type == "Box":
                fig = px.box(
                    dataset,
                    x=x_col,
                    y=y_col,
                    color=color_col,
                    title=f"Box Plot: {y_col} by {x_col}",
                    template="simple_white"
                )
            else:  # Violin
                fig = px.violin(
                    dataset,
                    x=x_col,
                    y=y_col,
                    color=color_col,
                    title=f"Violin Plot: {y_col} by {x_col}",
                    template="simple_white"
                )

            st.plotly_chart(fig, use_container_width=True)

        elif plot_type == "Histogram":
            # For histogram, we need just one column
            value_col = st.selectbox("Value column", dataset.columns.tolist())

            # Bins option
            n_bins = st.slider("Number of bins", 5, 100, 20)

            # Optional color dimension
            use_color = st.checkbox("Add color dimension")
            color_col = None
            if use_color:
                color_col = st.selectbox("Color by", dataset.columns.tolist())

            # Create plot
            fig = px.histogram(
                dataset,
                x=value_col,
                color=color_col,
                nbins=n_bins,
                title=f"Histogram of {value_col}",
                template="simple_white"
            )

            st.plotly_chart(fig, use_container_width=True)

        elif plot_type == "Pie":
            # For pie, we need a categorical column
            cat_col = st.selectbox("Category column", categorical_cols if categorical_cols else dataset.columns.tolist())

            # Optional value column
            use_values = st.checkbox("Use custom values")
            value_col = None
            if use_values and numeric_cols:
                value_col = st.selectbox("Value column", numeric_cols)

            # Limit to top N categories if there are too many
            top_n = st.slider("Limit to top N categories", 0, 20, 10,
                              help="Set to 0 to show all categories. Recommended to limit to top 10-15 categories for readability.")

            # Process data for pie chart
            if top_n > 0:
                if use_values and value_col:
                    pie_data = dataset.groupby(cat_col)[value_col].sum().reset_index()
                    pie_data = pie_data.sort_values(value_col, ascending=False).head(top_n)
                else:
                    value_counts = dataset[cat_col].value_counts().reset_index()
                    value_counts.columns = [cat_col, 'count']
                    pie_data = value_counts.head(top_n)
                    value_col = 'count'
            else:
                if use_values and value_col:
                    pie_data = dataset.groupby(cat_col)[value_col].sum().reset_index()
                else:
                    value_counts = dataset[cat_col].value_counts().reset_index()
                    value_counts.columns = [cat_col, 'count']
                    pie_data = value_counts
                    value_col = 'count'

            # Create plot
            fig = px.pie(
                pie_data,
                names=cat_col,
                values=value_col,
                title=f"Pie Chart of {cat_col}",
                template="simple_white"
            )

            st.plotly_chart(fig, use_container_width=True)
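A quick way to exercise render_dataset_visualization is to feed it a small synthetic frame with numeric, categorical, and datetime columns; the data and column names below are illustrative only, and the call assumes it runs inside a Streamlit app.

import numpy as np
import pandas as pd
from components.dataset_visualization import render_dataset_visualization

# Synthetic demo data covering the numeric, categorical, and date code paths.
demo = pd.DataFrame({
    "price": np.random.lognormal(mean=3.0, sigma=0.5, size=200),
    "quantity": np.random.randint(1, 20, size=200),
    "category": np.random.choice(["A", "B", "C"], size=200),
    "date": pd.date_range("2024-01-01", periods=200, freq="D"),
})
render_dataset_visualization(demo, "csv")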
components/fine_tuning/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Fine-tuning package for code generation models.
"""
components/fine_tuning/finetune_ui.py
ADDED
@@ -0,0 +1,529 @@
1 |
+
"""
|
2 |
+
Streamlit UI for fine-tuning code generation models.
|
3 |
+
"""
|
4 |
+
import streamlit as st
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
from datetime import datetime
|
10 |
+
import torch
|
11 |
+
import plotly.express as px
|
12 |
+
import plotly.graph_objects as go
|
13 |
+
from pathlib import Path
|
14 |
+
import json
|
15 |
+
import uuid
|
16 |
+
import threading
|
17 |
+
from transformers import TrainingArguments
|
18 |
+
from datasets import Dataset
|
19 |
+
|
20 |
+
from components.fine_tuning.model_interface import (
|
21 |
+
load_model_and_tokenizer,
|
22 |
+
preprocess_code_dataset,
|
23 |
+
setup_trainer,
|
24 |
+
generate_code_comment,
|
25 |
+
generate_code_from_comment,
|
26 |
+
save_training_config,
|
27 |
+
load_training_config
|
28 |
+
)
|
29 |
+
|
30 |
+
# Initialize training state
|
31 |
+
if 'training_run_id' not in st.session_state:
|
32 |
+
st.session_state.training_run_id = None
|
33 |
+
if 'training_status' not in st.session_state:
|
34 |
+
st.session_state.training_status = "idle" # idle, running, completed, failed
|
35 |
+
if 'training_progress' not in st.session_state:
|
36 |
+
st.session_state.training_progress = 0.0
|
37 |
+
if 'trained_model' not in st.session_state:
|
38 |
+
st.session_state.trained_model = None
|
39 |
+
if 'trained_tokenizer' not in st.session_state:
|
40 |
+
st.session_state.trained_tokenizer = None
|
41 |
+
if 'training_logs' not in st.session_state:
|
42 |
+
st.session_state.training_logs = []
|
43 |
+
if 'fine_tuning_dataset' not in st.session_state:
|
44 |
+
st.session_state.fine_tuning_dataset = None
|
45 |
+
|
46 |
+
# Directory for saving models
|
47 |
+
MODELS_DIR = Path("./fine_tuned_models")
|
48 |
+
MODELS_DIR.mkdir(exist_ok=True)
|
49 |
+
|
50 |
+
# Set for background training thread
|
51 |
+
training_thread = None
|
52 |
+
|
53 |
+
def render_dataset_preparation():
|
54 |
+
"""
|
55 |
+
Render the dataset preparation interface.
|
56 |
+
"""
|
57 |
+
st.markdown("### Dataset Preparation")
|
58 |
+
|
59 |
+
# Dataset input options
|
60 |
+
dataset_source = st.radio(
|
61 |
+
"Choose dataset source",
|
62 |
+
["Upload CSV", "Manual Input", "Use Current Dataset"],
|
63 |
+
help="Select how you want to provide your fine-tuning dataset"
|
64 |
+
)
|
65 |
+
|
66 |
+
if dataset_source == "Upload CSV":
|
67 |
+
uploaded_file = st.file_uploader(
|
68 |
+
"Upload fine-tuning dataset (CSV)",
|
69 |
+
type=["csv"],
|
70 |
+
help="CSV file with 'input' and 'target' columns"
|
71 |
+
)
|
72 |
+
|
73 |
+
if uploaded_file is not None:
|
74 |
+
try:
|
75 |
+
df = pd.read_csv(uploaded_file)
|
76 |
+
|
77 |
+
# Check if required columns exist
|
78 |
+
if "input" not in df.columns or "target" not in df.columns:
|
79 |
+
st.error("CSV must contain 'input' and 'target' columns.")
|
80 |
+
return
|
81 |
+
|
82 |
+
# Preview dataset
|
83 |
+
st.markdown("### Dataset Preview")
|
84 |
+
st.dataframe(df.head(), use_container_width=True)
|
85 |
+
|
86 |
+
# Dataset statistics
|
87 |
+
st.markdown("### Dataset Statistics")
|
88 |
+
col1, col2 = st.columns(2)
|
89 |
+
with col1:
|
90 |
+
st.metric("Number of examples", len(df))
|
91 |
+
with col2:
|
92 |
+
st.metric("Average input length", df["input"].astype(str).str.len().mean().round(1))
|
93 |
+
|
94 |
+
# Save dataset
|
95 |
+
if st.button("Use this dataset"):
|
96 |
+
st.session_state.fine_tuning_dataset = df
|
97 |
+
st.success(f"Dataset with {len(df)} examples loaded successfully!")
|
98 |
+
|
99 |
+
except Exception as e:
|
100 |
+
st.error(f"Error loading CSV: {str(e)}")
|
101 |
+
|
102 |
+
elif dataset_source == "Manual Input":
|
103 |
+
st.markdown("""
|
104 |
+
Enter pairs of inputs and targets for fine-tuning. For code-to-comment tasks, the input is code and
|
105 |
+
the target is a comment. For comment-to-code tasks, the input is a comment and the target is code.
|
106 |
+
""")
|
107 |
+
|
108 |
+
# Container for input fields
|
109 |
+
examples_container = st.container()
|
110 |
+
|
111 |
+
# Default number of example fields
|
112 |
+
if "num_examples" not in st.session_state:
|
113 |
+
st.session_state.num_examples = 3
|
114 |
+
|
115 |
+
# Add more examples button
|
116 |
+
if st.button("Add another example"):
|
117 |
+
st.session_state.num_examples += 1
|
118 |
+
|
119 |
+
# Input fields for examples
|
120 |
+
inputs = []
|
121 |
+
targets = []
|
122 |
+
|
123 |
+
with examples_container:
|
124 |
+
for i in range(st.session_state.num_examples):
|
125 |
+
st.markdown(f"### Example {i+1}")
|
126 |
+
col1, col2 = st.columns(2)
|
127 |
+
with col1:
|
128 |
+
input_text = st.text_area(f"Input {i+1}", key=f"input_{i}", height=150)
|
129 |
+
inputs.append(input_text)
|
130 |
+
with col2:
|
131 |
+
target_text = st.text_area(f"Target {i+1}", key=f"target_{i}", height=150)
|
132 |
+
targets.append(target_text)
|
133 |
+
|
134 |
+
# Create dataset from manual input
|
135 |
+
if st.button("Create Dataset from Examples"):
|
136 |
+
# Filter out empty examples
|
137 |
+
valid_examples = [(inp, tgt) for inp, tgt in zip(inputs, targets) if inp.strip() and tgt.strip()]
|
138 |
+
|
139 |
+
if valid_examples:
|
140 |
+
df = pd.DataFrame(valid_examples, columns=["input", "target"])
|
141 |
+
st.session_state.fine_tuning_dataset = df
|
142 |
+
|
143 |
+
# Preview dataset
|
144 |
+
st.markdown("### Dataset Preview")
|
145 |
+
st.dataframe(df, use_container_width=True)
|
146 |
+
st.success(f"Dataset with {len(df)} examples created successfully!")
|
147 |
+
else:
|
148 |
+
st.warning("No valid examples found. Please enter at least one input-target pair.")
|
149 |
+
|
150 |
+
elif dataset_source == "Use Current Dataset":
|
151 |
+
if st.session_state.dataset is None:
|
152 |
+
st.warning("No dataset is currently loaded. Please upload or select a dataset first.")
|
153 |
+
else:
|
154 |
+
st.markdown("### Current Dataset")
|
155 |
+
st.dataframe(st.session_state.dataset.head(), use_container_width=True)
|
156 |
+
|
157 |
+
# Column selection
|
158 |
+
col1, col2 = st.columns(2)
|
159 |
+
with col1:
|
160 |
+
input_col = st.selectbox("Select column for inputs", st.session_state.dataset.columns)
|
161 |
+
with col2:
|
162 |
+
target_col = st.selectbox("Select column for targets", st.session_state.dataset.columns)
|
163 |
+
|
164 |
+
# Create fine-tuning dataset
|
165 |
+
if st.button("Create Fine-Tuning Dataset"):
|
166 |
+
df = st.session_state.dataset[[input_col, target_col]].copy()
|
167 |
+
df.columns = ["input", "target"]
|
168 |
+
|
169 |
+
# Verify data types and convert to string if necessary
|
170 |
+
df["input"] = df["input"].astype(str)
|
171 |
+
df["target"] = df["target"].astype(str)
|
172 |
+
|
173 |
+
# Preview
|
174 |
+
st.dataframe(df.head(), use_container_width=True)
|
175 |
+
|
176 |
+
# Store dataset
|
177 |
+
st.session_state.fine_tuning_dataset = df
|
178 |
+
st.success(f"Fine-tuning dataset with {len(df)} examples created successfully!")
|
179 |
+
|
180 |
+
def render_model_training():
|
181 |
+
"""
|
182 |
+
Render the model training interface.
|
183 |
+
"""
|
184 |
+
st.markdown("### Model Training")
|
185 |
+
|
186 |
+
# Check if dataset is available
|
187 |
+
if st.session_state.fine_tuning_dataset is None:
|
188 |
+
st.warning("Please prepare a dataset in the 'Dataset Preparation' tab first.")
|
189 |
+
return
|
190 |
+
|
191 |
+
# Model selection
|
192 |
+
model_options = {
|
193 |
+
"Salesforce/codet5-small": "CodeT5 Small (60M params)",
|
194 |
+
"Salesforce/codet5-base": "CodeT5 Base (220M params)",
|
195 |
+
"Salesforce/codet5-large": "CodeT5 Large (770M params)",
|
196 |
+
"microsoft/codebert-base": "CodeBERT Base (125M params)",
|
197 |
+
"facebook/bart-base": "BART Base (140M params)"
|
198 |
+
}
|
199 |
+
|
200 |
+
model_name = st.selectbox(
|
201 |
+
"Select pre-trained model",
|
202 |
+
list(model_options.keys()),
|
203 |
+
format_func=lambda x: model_options[x],
|
204 |
+
help="Select the base model for fine-tuning"
|
205 |
+
)
|
206 |
+
|
207 |
+
# Task type
|
208 |
+
task_type = st.selectbox(
|
209 |
+
"Select task type",
|
210 |
+
["Code to Comment", "Comment to Code"],
|
211 |
+
help="Choose the direction of your task"
|
212 |
+
)
|
213 |
+
|
214 |
+
# Task prefix
|
215 |
+
if task_type == "Code to Comment":
|
216 |
+
task_prefix = "translate code to comment: "
|
217 |
+
else:
|
218 |
+
task_prefix = "translate comment to code: "
|
219 |
+
|
220 |
+
# Hyperparameters
|
221 |
+
st.markdown("### Training Hyperparameters")
|
222 |
+
|
223 |
+
col1, col2 = st.columns(2)
|
224 |
+
with col1:
|
225 |
+
learning_rate = st.select_slider(
|
226 |
+
"Learning Rate",
|
227 |
+
options=[1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4],
|
228 |
+
value=5e-5,
|
229 |
+
help="Step size for optimizer updates"
|
230 |
+
)
|
231 |
+
epochs = st.slider(
|
232 |
+
"Epochs",
|
233 |
+
min_value=1,
|
234 |
+
max_value=20,
|
235 |
+
value=3,
|
236 |
+
help="Number of complete passes through the dataset"
|
237 |
+
)
|
238 |
+
with col2:
|
239 |
+
batch_size = st.select_slider(
|
240 |
+
"Batch Size",
|
241 |
+
options=[1, 2, 4, 8, 16, 32],
|
242 |
+
value=8,
|
243 |
+
help="Number of examples processed in each training step"
|
244 |
+
)
|
245 |
+
max_input_length = st.slider(
|
246 |
+
"Max Input Length (tokens)",
|
247 |
+
min_value=64,
|
248 |
+
max_value=512,
|
249 |
+
value=256,
|
250 |
+
help="Maximum length of input sequences"
|
251 |
+
)
|
252 |
+
|
253 |
+
# Advanced options
|
254 |
+
with st.expander("Advanced Options"):
|
255 |
+
col1, col2 = st.columns(2)
|
256 |
+
with col1:
|
257 |
+
weight_decay = st.select_slider(
|
258 |
+
"Weight Decay",
|
259 |
+
options=[0.0, 0.01, 0.05, 0.1],
|
260 |
+
value=0.01,
|
261 |
+
help="L2 regularization"
|
262 |
+
)
|
263 |
+
warmup_steps = st.slider(
|
264 |
+
"Warmup Steps",
|
265 |
+
min_value=0,
|
266 |
+
max_value=1000,
|
267 |
+
value=100,
|
268 |
+
help="Steps for learning rate warmup"
|
269 |
+
)
|
270 |
+
with col2:
|
271 |
+
max_target_length = st.slider(
|
272 |
+
"Max Target Length (tokens)",
|
273 |
+
min_value=64,
|
274 |
+
max_value=512,
|
275 |
+
value=256,
|
276 |
+
help="Maximum length of target sequences"
|
277 |
+
)
|
278 |
+
gradient_accumulation = st.slider(
|
279 |
+
"Gradient Accumulation Steps",
|
280 |
+
min_value=1,
|
281 |
+
max_value=16,
|
282 |
+
value=1,
|
283 |
+
help="Number of steps to accumulate gradients"
|
284 |
+
)
|
285 |
+
|
286 |
+
# Model output configuration
|
287 |
+
st.markdown("### Model Output Configuration")
|
288 |
+
model_name_custom = st.text_input(
|
289 |
+
"Custom model name",
|
290 |
+
value=f"{model_name.split('/')[-1]}-finetuned-{task_type.lower().replace(' ', '-')}",
|
291 |
+
help="Name for your fine-tuned model"
|
292 |
+
)
|
293 |
+
|
294 |
+
# Training controls
|
295 |
+
st.markdown("### Training Controls")
|
296 |
+
|
297 |
+
# Check if training is in progress
|
298 |
+
if st.session_state.training_status == "running":
|
299 |
+
# Display progress
|
300 |
+
st.progress(st.session_state.training_progress)
|
301 |
+
|
302 |
+
# Show logs
|
303 |
+
if st.session_state.training_logs:
|
304 |
+
st.markdown("### Training Logs")
|
305 |
+
log_text = "\n".join(st.session_state.training_logs[-10:]) # Show last 10 logs
|
306 |
+
st.text_area("Latest logs", log_text, height=200, disabled=True)
|
307 |
+
|
308 |
+
# Stop button
|
309 |
+
if st.button("Stop Training"):
|
310 |
+
# Logic to stop training thread
|
311 |
+
st.session_state.training_status = "stopping"
|
312 |
+
st.warning("Stopping training after current epoch completes...")
|
313 |
+
|
314 |
+
elif st.session_state.training_status == "completed":
|
315 |
+
st.success(f"Training completed! Model saved as: {model_name_custom}")
|
316 |
+
|
317 |
+
# Show metrics if available
|
318 |
+
if "training_metrics" in st.session_state:
|
319 |
+
st.markdown("### Training Metrics")
|
320 |
+
metrics_df = pd.DataFrame(st.session_state.training_metrics)
|
321 |
+
st.line_chart(metrics_df)
|
322 |
+
|
323 |
+
# Reset button
|
324 |
+
if st.button("Start New Training"):
|
325 |
+
st.session_state.training_status = "idle"
|
326 |
+
st.session_state.training_progress = 0.0
|
327 |
+
st.session_state.training_logs = []
|
328 |
+
st.experimental_rerun()
|
329 |
+
|
330 |
+
else: # idle or failed
|
331 |
+
# If previously failed, show error
|
332 |
+
if st.session_state.training_status == "failed":
|
333 |
+
st.error("Previous training failed. See logs for details.")
|
334 |
+
if st.session_state.training_logs:
|
335 |
+
st.text_area("Error logs", "\n".join(st.session_state.training_logs[-5:]), height=100, disabled=True)
|
336 |
+
|
337 |
+
# Start training button
|
338 |
+
if st.button("Start Training"):
|
339 |
+
# Validate dataset
|
340 |
+
if len(st.session_state.fine_tuning_dataset) < 5:
|
341 |
+
st.warning("Dataset is very small. Consider adding more examples for better results.")
|
342 |
+
|
343 |
+
# Set up training configuration
|
344 |
+
training_config = {
|
345 |
+
"model_name": model_name,
|
346 |
+
"task_type": task_type,
|
347 |
+
"task_prefix": task_prefix,
|
348 |
+
"learning_rate": learning_rate,
|
349 |
+
"epochs": epochs,
|
350 |
+
"batch_size": batch_size,
|
351 |
+
"max_input_length": max_input_length,
|
352 |
+
"max_target_length": max_target_length,
|
353 |
+
"weight_decay": weight_decay,
|
354 |
+
"warmup_steps": warmup_steps,
|
355 |
+
"gradient_accumulation": gradient_accumulation,
|
356 |
+
"output_model_name": model_name_custom,
|
357 |
+
"dataset_size": len(st.session_state.fine_tuning_dataset)
|
358 |
+
}
|
359 |
+
|
360 |
+
# Update session state
|
361 |
+
st.session_state.training_status = "running"
|
362 |
+
st.session_state.training_progress = 0.0
|
363 |
+
st.session_state.training_logs = ["Training initialized..."]
|
364 |
+
st.session_state.training_run_id = str(uuid.uuid4())
|
365 |
+
|
366 |
+
# TODO: Start actual training process using transformers
|
367 |
+
st.info("Training would start here with the Hugging Face transformers library")
|
368 |
+
|
369 |
+
# For now, just simulate training progress
|
370 |
+
st.session_state.training_progress = 0.1
|
371 |
+
st.session_state.training_logs.append("Loaded model and tokenizer")
|
372 |
+
st.session_state.training_logs.append("Preprocessing dataset...")
|
373 |
+
|
374 |
+
# Rerun to update UI with progress
|
375 |
+
st.experimental_rerun()
|
376 |
+
|
377 |
+
def render_model_testing():
|
378 |
+
"""
|
379 |
+
Render the model testing interface.
|
380 |
+
"""
|
381 |
+
st.markdown("### Test & Use Model")
|
382 |
+
|
383 |
+
# Check if a model is trained/available
|
384 |
+
if st.session_state.trained_model is None and st.session_state.training_status != "completed":
|
385 |
+
# Look for saved models
|
386 |
+
saved_models = list(MODELS_DIR.glob("*/"))
|
387 |
+
if not saved_models:
|
388 |
+
st.warning("No trained models available. Please train a model first.")
|
            return

        # Let user select a saved model
        model_options = [model.name for model in saved_models]
        selected_model = st.selectbox("Select a saved model", model_options)

        if st.button("Load Selected Model"):
            st.info(f"Loading model {selected_model}...")
            # TODO: Load model logic
            st.session_state.trained_model = "loaded"  # Placeholder
            st.session_state.trained_tokenizer = "loaded"  # Placeholder
            st.success("Model loaded successfully!")

    else:
        # Model is available for testing
        model_type = "Code to Comment" if "code-to-comment" in st.session_state.get("model_name", "") else "Comment to Code"

        st.markdown(f"### Testing {model_type} Generation")

        if model_type == "Code to Comment":
            input_text = st.text_area(
                "Enter code snippet",
                height=200,
                help="Enter a code snippet to generate a comment"
            )

            if st.button("Generate Comment"):
                if input_text:
                    with st.spinner("Generating comment..."):
                        # TODO: Replace with actual model inference
                        result = f"/* This code {input_text.split()[0:3]} ... */"
                        st.markdown("### Generated Comment")
                        st.code(result)
                else:
                    st.warning("Please enter a code snippet.")

        else:  # Comment to Code
            input_text = st.text_area(
                "Enter comment/description",
                height=150,
                help="Enter a description to generate code"
            )

            language = st.selectbox(
                "Programming language",
                ["Python", "JavaScript", "Java", "C++", "Go"]
            )

            if st.button("Generate Code"):
                if input_text:
                    with st.spinner("Generating code..."):
                        # TODO: Replace with actual model inference
                        result = f"def example_function():\n    # {input_text}\n    pass"
                        st.markdown("### Generated Code")
                        st.code(result, language=language.lower())
                else:
                    st.warning("Please enter a comment or description.")

        # Batch testing
        with st.expander("Batch Testing"):
            st.markdown("Upload a CSV file with test cases to evaluate your model.")

            test_file = st.file_uploader(
                "Upload test cases (CSV)",
                type=["csv"],
                help="CSV file with 'input' and 'expected' columns"
            )

            if test_file is not None:
                try:
                    test_df = pd.read_csv(test_file)
                    st.dataframe(test_df.head(), use_container_width=True)

                    if st.button("Run Batch Test"):
                        with st.spinner("Running tests..."):
                            # TODO: Actual batch inference
                            st.success("Batch testing completed!")

                            # Dummy results
                            results = pd.DataFrame({
                                "input": test_df["input"],
                                "expected": test_df.get("expected", [""] * len(test_df)),
                                "generated": ["Sample output " + str(i) for i in range(len(test_df))],
                                "match_score": np.random.uniform(0.5, 1.0, len(test_df))
                            })

                            st.dataframe(results, use_container_width=True)

                            # Metrics
                            st.markdown("### Evaluation Metrics")
                            col1, col2 = st.columns(2)
                            with col1:
                                st.metric("Average Match Score", f"{results['match_score'].mean():.2f}")
                            with col2:
                                st.metric("Tests Passed", f"{sum(results['match_score'] > 0.8)}/{len(results)}")

                except Exception as e:
                    st.error(f"Error loading test file: {str(e)}")

def render_finetune_ui():
    """
    Render the fine-tuning UI for code generation models.
    """
    st.markdown("<h2>Fine-Tune Code Generation Model</h2>", unsafe_allow_html=True)

    # Overview and instructions
    with st.expander("About Fine-Tuning", expanded=False):
        st.markdown("""
        ## Fine-Tuning a Code Generation Model

        This interface allows you to fine-tune pre-trained code generation models from Hugging Face
        on your custom dataset to adapt them to your specific coding style or task.

        ### How to use:
        1. **Prepare your dataset** - Upload a CSV file with 'input' and 'target' columns:
           - For code-to-comment: 'input' = code snippets, 'target' = corresponding comments
           - For comment-to-code: 'input' = comments, 'target' = corresponding code snippets

        2. **Configure training** - Set hyperparameters like learning rate, batch size, and epochs

        3. **Start fine-tuning** - Launch the training process and monitor progress

        4. **Test your model** - Once training is complete, test your model on new inputs

        ### Tips for better results:
        - Use a consistent format for your code snippets and comments
        - Start with a small dataset (50-100 examples) to verify the process
        - Try different hyperparameters to find the best configuration
        """)

    # Main UI with tabs
    tab1, tab2, tab3 = st.tabs(["Dataset Preparation", "Model Training", "Test & Use Model"])

    with tab1:
        render_dataset_preparation()

    with tab2:
        render_model_training()

    with tab3:
        render_model_testing()
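A minimal sketch of the kind of file the "Batch Testing" expander above expects: a CSV with 'input' and 'expected' columns. The file name and example rows below are illustrative, not part of the upload.

# Hypothetical helper (illustrative only): build a tiny batch-test CSV
# in the shape the Batch Testing uploader expects.
import pandas as pd

test_cases = pd.DataFrame({
    "input": [
        "def add(a, b):\n    return a + b",
        "def is_even(n):\n    return n % 2 == 0",
    ],
    "expected": [
        "Adds two numbers and returns the result",
        "Returns True if n is even",
    ],
})
test_cases.to_csv("test_cases.csv", index=False)  # file name is an assumption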
components/fine_tuning/model_interface.py
ADDED
@@ -0,0 +1,228 @@
"""
Hugging Face model interface for code generation fine-tuning.
"""
import streamlit as st
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import Dataset
import numpy as np
import time
import os
from pathlib import Path
import uuid
import json

@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer(model_name):
    """
    Load a pre-trained model and tokenizer from Hugging Face.

    Args:
        model_name: Name of the model on Hugging Face (e.g., 'Salesforce/codet5-base')

    Returns:
        Tuple of (tokenizer, model)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model

def preprocess_code_dataset(dataset_df, tokenizer, max_input_length=256, max_target_length=256, task_prefix=""):
    """
    Preprocess the code dataset for fine-tuning.

    Args:
        dataset_df: Pandas DataFrame with 'input' and 'target' columns
        tokenizer: HuggingFace tokenizer
        max_input_length: Maximum length for input sequences
        max_target_length: Maximum length for target sequences
        task_prefix: Prefix to add to inputs (e.g., "translate code to comment: ")

    Returns:
        HuggingFace Dataset ready for training
    """
    def preprocess_function(examples):
        inputs = [task_prefix + text for text in examples["input"]]
        targets = examples["target"]

        model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding="max_length")

        # Set up the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=max_target_length, truncation=True, padding="max_length")

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    # Convert DataFrame to HuggingFace Dataset
    hf_dataset = Dataset.from_pandas(dataset_df)

    # Split dataset into train and validation
    splits = hf_dataset.train_test_split(test_size=0.1)
    train_dataset = splits["train"]
    eval_dataset = splits["test"]

    # Apply preprocessing
    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=["input", "target"]
    )
    eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=["input", "target"]
    )

    return train_dataset, eval_dataset

def setup_trainer(model, tokenizer, train_dataset, eval_dataset, output_dir, training_args):
    """
    Set up the Trainer for fine-tuning.

    Args:
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        train_dataset: Preprocessed training dataset
        eval_dataset: Preprocessed evaluation dataset
        output_dir: Directory to save model and checkpoints
        training_args: Dictionary of training arguments

    Returns:
        HuggingFace Trainer
    """
    # Define training arguments
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=training_args.get("batch_size", 8),
        per_device_eval_batch_size=training_args.get("batch_size", 8),
        learning_rate=training_args.get("learning_rate", 5e-5),
        num_train_epochs=training_args.get("epochs", 3),
        weight_decay=training_args.get("weight_decay", 0.01),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=False,
        gradient_accumulation_steps=training_args.get("gradient_accumulation", 1),
        warmup_steps=training_args.get("warmup_steps", 100),
        logging_dir=os.path.join(output_dir, "logs"),
        logging_steps=10,
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=tokenizer.pad_token_id,
        pad_to_multiple_of=8
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    return trainer

def generate_code_comment(model, tokenizer, code, max_length=100, task_prefix="translate code to comment: "):
    """
    Generate a comment for a given code snippet.

    Args:
        model: Fine-tuned model
        tokenizer: Tokenizer
        code: Input code snippet
        max_length: Maximum length of the generated comment
        task_prefix: Prefix to add to the input

    Returns:
        Generated comment as string
    """
    inputs = tokenizer(task_prefix + code, return_tensors="pt", padding=True, truncation=True)

    # Move inputs to the same device as model
    device = model.device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate
    output_ids = model.generate(
        inputs["input_ids"],
        max_length=max_length,
        num_beams=4,
        early_stopping=True
    )

    comment = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return comment

def generate_code_from_comment(model, tokenizer, comment, max_length=200, task_prefix="translate comment to code: "):
    """
    Generate code from a given comment/description.

    Args:
        model: Fine-tuned model
        tokenizer: Tokenizer
        comment: Input comment or description
        max_length: Maximum length of the generated code
        task_prefix: Prefix to add to the input

    Returns:
        Generated code as string
    """
    inputs = tokenizer(task_prefix + comment, return_tensors="pt", padding=True, truncation=True)

    # Move inputs to the same device as model
    device = model.device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate
    output_ids = model.generate(
        inputs["input_ids"],
        max_length=max_length,
        num_beams=4,
        early_stopping=True
    )

    code = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return code

def save_training_config(output_dir, config):
    """
    Save training configuration to a JSON file.

    Args:
        output_dir: Directory to save the configuration
        config: Dictionary with training configuration
    """
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

def load_training_config(model_dir):
    """
    Load training configuration from a JSON file.

    Args:
        model_dir: Directory with the saved model

    Returns:
        Dictionary with training configuration
    """
    config_path = os.path.join(model_dir, "training_config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            return json.load(f)
    return {}
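A minimal sketch of how the helpers in model_interface.py are meant to compose into one fine-tuning run. The dataset path, model name, output directory, and hyperparameter values are assumptions for illustration; only the function names and signatures come from the file above.

# Minimal end-to-end sketch (illustrative values, not part of the upload)
import pandas as pd
from components.fine_tuning.model_interface import (
    load_model_and_tokenizer,
    preprocess_code_dataset,
    setup_trainer,
    save_training_config,
)

df = pd.read_csv("datasets/processed_dataset.csv")  # must contain 'input' and 'target' columns
tokenizer, model = load_model_and_tokenizer("Salesforce/codet5-small")
train_ds, eval_ds = preprocess_code_dataset(df, tokenizer, task_prefix="translate code to comment: ")

training_args = {"batch_size": 8, "epochs": 3, "learning_rate": 5e-5}  # assumed values
output_dir = "fine_tuned_models/run-1"  # assumed path
trainer = setup_trainer(model, tokenizer, train_ds, eval_ds, output_dir, training_args)
trainer.train()
save_training_config(output_dir, training_args)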
generated-icon.png
ADDED
Git LFS Details
huggingface-spacefile
ADDED
@@ -0,0 +1,8 @@
title: ML Dataset & Code Generation Manager
emoji: 🤗
colorFrom: indigo
colorTo: blue
sdk: streamlit
sdk_version: 1.42.0
app_file: app.py
pinned: false
main.py
ADDED
@@ -0,0 +1,546 @@
import streamlit as st
import os
import pandas as pd
import numpy as np
import plotly.express as px
import json
from pathlib import Path

# Make sure necessary directories exist
os.makedirs('assets', exist_ok=True)
os.makedirs('database/data', exist_ok=True)
os.makedirs('fine_tuned_models', exist_ok=True)

# Page configuration
st.set_page_config(
    page_title="ML Dataset & Code Generation Manager",
    page_icon="🤗",
    layout="wide",
    initial_sidebar_state="expanded",
)

def load_css():
    """Load custom CSS styles"""
    css_dir = Path("assets")
    css_path = css_dir / "custom.css"

    if not css_path.exists():
        # Create assets directory if it doesn't exist
        css_dir.mkdir(exist_ok=True)

        # Create a basic CSS file if it doesn't exist
        with open(css_path, "w") as f:
            f.write("""
            /* Custom styles for ML Dataset & Code Generation Manager */
            @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Space+Grotesk:wght@500;700&display=swap');

            h1, h2, h3, h4, h5, h6 {
                font-family: 'Space Grotesk', sans-serif;
                font-weight: 700;
                color: #1A1C1F;
            }

            body {
                font-family: 'Inter', sans-serif;
                color: #1A1C1F;
                background-color: #F8F9FA;
            }

            .stButton button {
                background-color: #2563EB;
                color: white;
                border-radius: 4px;
                border: none;
                padding: 0.5rem 1rem;
                font-weight: 600;
            }

            .stButton button:hover {
                background-color: #1D4ED8;
            }

            /* Card styling */
            .card {
                background-color: white;
                border-radius: 8px;
                padding: 1.5rem;
                box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
                margin-bottom: 1rem;
            }

            /* Accent colors */
            .accent-primary {
                color: #2563EB;
            }

            .accent-secondary {
                color: #84919A;
            }

            .accent-success {
                color: #10B981;
            }

            .accent-warning {
                color: #F59E0B;
            }

            .accent-danger {
                color: #EF4444;
            }
            """)

    # Load custom CSS
    with open(css_path, "r") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

def render_finetune_ui():
    """
    Renders the fine-tuning UI for code generation models.
    """
    try:
        from components.fine_tuning.finetune_ui import render_finetune_ui as ft_ui
        ft_ui()
    except ImportError as e:
        st.error(f"Could not load fine-tuning UI: {e}")

        # Create default fine-tuning UI component if not exists
        os.makedirs("components/fine_tuning", exist_ok=True)
        if not os.path.exists("components/fine_tuning/__init__.py"):
            with open("components/fine_tuning/__init__.py", "w") as f:
                f.write('"""\nFine-tuning package for code generation models.\n"""\n')

        if not os.path.exists("components/fine_tuning/finetune_ui.py"):
            with open("components/fine_tuning/finetune_ui.py", "w") as f:
                f.write('''"""
Streamlit UI for fine-tuning code generation models.
"""
import streamlit as st
import pandas as pd
import os

def render_dataset_preparation():
    """
    Render the dataset preparation interface.
    """
    st.subheader("Dataset Preparation")
    st.write("Prepare your dataset for fine-tuning code generation models.")

    # Dataset upload
    uploaded_file = st.file_uploader("Upload your dataset", type=["csv", "json"])
    if uploaded_file is not None:
        try:
            if uploaded_file.name.endswith('.csv'):
                df = pd.read_csv(uploaded_file)
            else:
                df = pd.read_json(uploaded_file)

            st.write("Dataset Preview:")
            st.dataframe(df.head())

            # Example of data columns mapping
            st.subheader("Column Mapping")

            input_col = st.selectbox("Select input column (e.g., code)", df.columns)
            target_col = st.selectbox("Select target column (e.g., comment)", df.columns)

            # Sample transformation
            if st.button("Apply Transformation"):
                if input_col and target_col:
                    # Example transformation: simple trim/clean
                    df[input_col] = df[input_col].astype(str).str.strip()
                    df[target_col] = df[target_col].astype(str).str.strip()

                    st.write("Transformed Dataset:")
                    st.dataframe(df.head())

                    # Option to save processed dataset
                    if st.button("Save Processed Dataset"):
                        processed_path = os.path.join("datasets", "processed_dataset.csv")
                        os.makedirs("datasets", exist_ok=True)
                        df.to_csv(processed_path, index=False)
                        st.success(f"Dataset saved to {processed_path}")
        except Exception as e:
            st.error(f"Error processing dataset: {e}")

def render_model_training():
    """
    Render the model training interface.
    """
    st.subheader("Model Training")
    st.write("Configure and start training your model.")

    # Model selection
    model_options = [
        "Salesforce/codet5-small",
        "Salesforce/codet5-base",
        "microsoft/codebert-base",
        "microsoft/graphcodebert-base"
    ]

    selected_model = st.selectbox("Select base model", model_options)

    # Training parameters
    col1, col2 = st.columns(2)
    with col1:
        batch_size = st.number_input("Batch size", min_value=1, max_value=64, value=8)
        epochs = st.number_input("Number of epochs", min_value=1, max_value=100, value=3)
        learning_rate = st.number_input("Learning rate", min_value=0.00001, max_value=0.1, value=0.0001, format="%.5f")

    with col2:
        max_input_length = st.number_input("Max input length", min_value=32, max_value=512, value=128)
        max_target_length = st.number_input("Max target length", min_value=32, max_value=512, value=128)
        task_type = st.selectbox("Task type", ["Code to Comment", "Comment to Code"])

    # Training button (placeholder)
    if st.button("Start Training"):
        st.info("Training would start here. This is a placeholder.")
        # In a real implementation, this would call the training function
        # and display a progress bar or redirect to a training monitoring page

def render_model_testing():
    """
    Render the model testing interface.
    """
    st.subheader("Model Testing")
    st.write("Test your fine-tuned model with custom inputs.")

    # Model selection
    st.selectbox("Select fine-tuned model", ["No models available yet"])

    # Test input
    if st.selectbox("Task type", ["Code to Comment", "Comment to Code"]) == "Code to Comment":
        test_input = st.text_area("Enter code to generate a comment",
                                  value="def fibonacci(n):\\n    if n <= 1:\\n        return n\\n    else:\\n        return fibonacci(n-1) + fibonacci(n-2)")
        placeholder = "# This function implements the Fibonacci sequence recursively..."
    else:
        test_input = st.text_area("Enter comment to generate code",
                                  value="# A function that calculates the factorial of a number recursively")
        placeholder = "def factorial(n):\\n    if n == 0:\\n        return 1\\n    else:\\n        return n * factorial(n-1)"

    # Generate button (placeholder)
    if st.button("Generate"):
        st.code(placeholder, language="python")
        # In a real implementation, this would call the model inference function

def render_finetune_ui():
    """
    Render the fine-tuning UI for code generation models.
    """
    st.title("Fine-Tune Code Generation Models")

    tabs = st.tabs(["Dataset Preparation", "Model Training", "Model Testing"])

    with tabs[0]:
        render_dataset_preparation()

    with tabs[1]:
        render_model_training()

    with tabs[2]:
        render_model_testing()
''')

        # Try again after creating the files
        try:
            from components.fine_tuning.finetune_ui import render_finetune_ui as ft_ui
            ft_ui()
        except ImportError as e:
            st.error(f"Still could not load fine-tuning UI after creating files: {e}")
            st.info("Please restart the app to initialize the components.")

def render_code_quality_ui():
    """
    Renders the code quality tools UI.
    """
    try:
        from components.code_quality import render_code_quality_tools
        render_code_quality_tools()
    except ImportError:
        st.error("Code quality tools not found. Implementing basic version.")
        st.title("Code Quality Tools")
        st.write("This section will provide tools for code linting, formatting, and testing.")

        # Tabs for different code quality tools
        tabs = st.tabs(["Linting", "Formatting", "Type Checking", "Testing"])

        with tabs[0]:
            st.subheader("Code Linting")
            st.write("Tools for checking code quality and style.")
            st.code("# Coming soon: PyLint and Flake8 integration")

        with tabs[1]:
            st.subheader("Code Formatting")
            st.write("Tools for formatting code according to style guides.")
            st.code("# Coming soon: Black and isort integration")

        with tabs[2]:
            st.subheader("Type Checking")
            st.write("Tools for checking type annotations.")
            st.code("# Coming soon: MyPy integration")

        with tabs[3]:
            st.subheader("Testing")
            st.write("Tools for running tests and checking code coverage.")
            st.code("# Coming soon: PyTest integration")

def render_dataset_management_ui():
    """
    Renders the dataset management UI.
    """
    st.title("Dataset Management")

    # Tabs for different dataset operations
    tabs = st.tabs(["Upload", "Preview", "Statistics", "Visualization", "Validation", "Version Control"])

    with tabs[0]:
        try:
            from components.dataset_uploader import render_dataset_uploader
            render_dataset_uploader()
        except ImportError:
            st.subheader("Dataset Upload")
            st.write("Upload your datasets in CSV or JSON format.")

            uploaded_file = st.file_uploader("Choose a file", type=["csv", "json"])
            if uploaded_file is not None:
                try:
                    if uploaded_file.name.endswith('.csv'):
                        df = pd.read_csv(uploaded_file)
                        dataset_type = "csv"
                    else:
                        df = pd.read_json(uploaded_file)
                        dataset_type = "json"

                    st.session_state["dataset"] = df
                    st.session_state["dataset_type"] = dataset_type
                    st.success(f"Successfully loaded {dataset_type.upper()} file with {df.shape[0]} rows and {df.shape[1]} columns.")
                    st.dataframe(df.head())
                except Exception as e:
                    st.error(f"Error: {e}")

    with tabs[1]:
        if "dataset" in st.session_state:
            try:
                from components.dataset_preview import render_dataset_preview
                render_dataset_preview(st.session_state["dataset"], st.session_state["dataset_type"])
            except ImportError:
                st.subheader("Dataset Preview")
                st.dataframe(st.session_state["dataset"].head(10))
        else:
            st.info("Please upload a dataset first.")

    with tabs[2]:
        if "dataset" in st.session_state:
            try:
                from components.dataset_statistics import render_dataset_statistics
                render_dataset_statistics(st.session_state["dataset"], st.session_state["dataset_type"])
            except ImportError:
                st.subheader("Dataset Statistics")
                st.write("Basic statistics:")
                st.write(st.session_state["dataset"].describe())

                # Missing values
                missing_data = st.session_state["dataset"].isnull().sum()
                st.write("Missing values per column:")
                st.write(missing_data[missing_data > 0])
        else:
            st.info("Please upload a dataset first.")

    with tabs[3]:
        if "dataset" in st.session_state:
            try:
                from components.dataset_visualization import render_dataset_visualization
                render_dataset_visualization(st.session_state["dataset"], st.session_state["dataset_type"])
            except ImportError:
                st.subheader("Dataset Visualization")

                # Only show for numerical columns
                numeric_cols = st.session_state["dataset"].select_dtypes(include=[np.number]).columns.tolist()

                if len(numeric_cols) > 0:
                    col1, col2 = st.columns(2)

                    with col1:
                        x_axis = st.selectbox("X-axis", numeric_cols)

                    with col2:
                        y_axis = st.selectbox("Y-axis", numeric_cols, index=min(1, len(numeric_cols)-1))

                    fig = px.scatter(st.session_state["dataset"], x=x_axis, y=y_axis)
                    st.plotly_chart(fig, use_container_width=True)
                else:
                    st.write("No numerical columns available for visualization.")
        else:
            st.info("Please upload a dataset first.")

    with tabs[4]:
        if "dataset" in st.session_state:
            try:
                from components.dataset_validation import render_dataset_validation
                render_dataset_validation(st.session_state["dataset"], st.session_state["dataset_type"])
            except ImportError:
                st.subheader("Dataset Validation")

                # Simple validation checks
                st.write("Dataset Shape:", st.session_state["dataset"].shape)
                st.write("Duplicate Rows:", st.session_state["dataset"].duplicated().sum())

                # Missing values percentage
                missing_percent = (st.session_state["dataset"].isnull().sum() / len(st.session_state["dataset"])) * 100
                st.write("Missing Values Percentage:")
                st.write(missing_percent[missing_percent > 0])
        else:
            st.info("Please upload a dataset first.")

    with tabs[5]:
        if "dataset" in st.session_state:
            try:
                from components.dataset_version_control import render_version_control_ui, render_save_version_ui, render_version_visualization

                # If we have a dataset ID in session state, use it, otherwise prompt to save first
                if "dataset_id" in st.session_state:
                    dataset_id = st.session_state["dataset_id"]

                    # Show dataset version control UI
                    render_version_control_ui(dataset_id, st.session_state.get("dataset"))

                    # Show save version UI
                    st.divider()
                    if st.session_state.get("dataset") is not None:
                        new_version = render_save_version_ui(dataset_id, st.session_state["dataset"])
                        if new_version:
                            st.success(f"Created new version: {new_version.version_id}")

                    # Show version visualization
                    st.divider()
                    render_version_visualization(dataset_id)
                else:
                    # No dataset ID yet, so prompt to save the dataset first
                    st.info("To use version control, first save this dataset to the database.")

                    dataset_name = st.text_input("Dataset Name", value="My Dataset")
                    dataset_description = st.text_area("Dataset Description", value="Dataset uploaded for analysis")

                    if st.button("Save Dataset to Database"):
                        # Import database operations
                        from database.operations import DatasetOperations, DatasetVersionOperations

                        # Store dataset in database
                        dataset = DatasetOperations.store_dataframe_info(
                            df=st.session_state["dataset"],
                            name=dataset_name,
                            description=dataset_description,
                            source="local_upload"
                        )

                        # Store as initial version
                        initial_version = DatasetVersionOperations.create_version_from_dataframe(
                            dataset_id=dataset.id,
                            df=st.session_state["dataset"],
                            description="Initial version"
                        )

                        # Store dataset ID in session state
                        st.session_state["dataset_id"] = dataset.id

                        st.success(f"Dataset saved to database with ID: {dataset.id}")
                        st.success(f"Initial version created: {initial_version.version_id}")

                        # Rerun to show version control UI
                        st.experimental_rerun()
            except ImportError as e:
                st.subheader("Dataset Version Control")
                st.error(f"Could not load version control components: {e}")
                st.info("Please make sure all required components are installed.")
        else:
            st.info("Please upload a dataset first.")

def main():
    """
    Main function to run the application.
    """
    # Load custom CSS
    load_css()

    # Sidebar for navigation
    st.sidebar.title("ML Dataset & Code Gen Manager")

    # Navigation
    page = st.sidebar.radio("Navigation", ["Home", "Dataset Management", "Fine-Tuning", "Code Quality Tools"])

    # Display selected page
    if page == "Home":
        st.title("ML Dataset & Code Generation Manager")
        st.write("Welcome to the ML Dataset & Code Generation Manager. This platform helps you manage ML datasets and fine-tune code generation models.")

        # Main features in cards
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("""
            <div class="card">
                <h3>Dataset Management</h3>
                <p>Upload, analyze, visualize, and validate your ML datasets.</p>
                <ul>
                    <li>Support for CSV and JSON formats</li>
                    <li>Statistical analysis and visualization</li>
                    <li>Data validation and quality checks</li>
                    <li>Hugging Face Hub integration</li>
                </ul>
            </div>
            """, unsafe_allow_html=True)

            st.markdown("""
            <div class="card">
                <h3>Code Quality Tools</h3>
                <p>Tools for ensuring high-quality code.</p>
                <ul>
                    <li>Code linting with PyLint</li>
                    <li>Code formatting with Black and isort</li>
                    <li>Type checking with MyPy</li>
                    <li>Testing with PyTest</li>
                </ul>
            </div>
            """, unsafe_allow_html=True)

        with col2:
            st.markdown("""
            <div class="card">
                <h3>Fine-Tuning</h3>
                <p>Fine-tune code generation models on your custom datasets.</p>
                <ul>
                    <li>Support for CodeT5, CodeBERT models</li>
                    <li>Code-to-comment and comment-to-code tasks</li>
                    <li>Custom dataset preparation</li>
                    <li>Model testing and evaluation</li>
                </ul>
            </div>
            """, unsafe_allow_html=True)

            st.markdown("""
            <div class="card">
                <h3>Hugging Face Integration</h3>
                <p>Seamless integration with Hugging Face Hub.</p>
                <ul>
                    <li>Search and load models and datasets</li>
                    <li>Deploy fine-tuned models to Hugging Face Spaces</li>
                    <li>Share and collaborate on models and datasets</li>
                </ul>
            </div>
            """, unsafe_allow_html=True)

        # Get started section
        st.subheader("Get Started")
        st.write("To get started, navigate to the Dataset Management page to upload your data, or explore the Fine-Tuning page to train code generation models.")

    elif page == "Dataset Management":
        render_dataset_management_ui()

    elif page == "Fine-Tuning":
        render_finetune_ui()

    elif page == "Code Quality Tools":
        render_code_quality_ui()

if __name__ == "__main__":
    main()
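main.py imports every page component lazily inside try/except ImportError and falls back to an inline version when the module is missing. A minimal sketch of the contract an optional component is therefore expected to satisfy; the function body below is illustrative, only the name and call signature are taken from the Preview tab above.

# Sketch of a pluggable component (illustrative, not part of the upload):
# components/dataset_preview.py exposing the function main.py tries to import.
import streamlit as st
import pandas as pd

def render_dataset_preview(df: pd.DataFrame, dataset_type: str) -> None:
    """Matches the call render_dataset_preview(dataset, dataset_type) in main.py."""
    st.subheader("Dataset Preview")
    st.caption(f"Loaded from a {dataset_type.upper()} file")
    st.dataframe(df.head(10))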
pyproject.toml
ADDED
@@ -0,0 +1,33 @@
[project]
name = "repl-nix-workspace"
version = "0.1.0"
description = "Add your description here"
requires-python = ">=3.11"
dependencies = [
    "black>=25.1.0",
    "datasets>=3.3.2",
    "huggingface-hub>=0.29.1",
    "isort>=6.0.1",
    "matplotlib>=3.10.1",
    "mypy>=1.15.0",
    "numpy>=2.2.3",
    "pandas>=2.2.3",
    "plotly>=6.0.0",
    "pyarrow>=19.0.1",
    "pylint>=3.3.4",
    "pytest>=8.3.4",
    "scikit-learn>=1.6.1",
    "sqlalchemy>=2.0.38",
    "streamlit>=1.42.2",
    "torch>=2.6.0",
    "transformers>=4.49.0",
]

[[tool.uv.index]]
explicit = true
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"

[tool.uv.sources]
torch = [{ index = "pytorch-cpu", marker = "platform_system == 'Linux'" }]
torchvision = [{ index = "pytorch-cpu", marker = "platform_system == 'Linux'" }]
replit.nix
ADDED
@@ -0,0 +1,18 @@
{pkgs}: {
  deps = [
    pkgs.arrow-cpp
    pkgs.tk
    pkgs.tcl
    pkgs.qhull
    pkgs.gtk3
    pkgs.gobject-introspection
    pkgs.ghostscript
    pkgs.freetype
    pkgs.ffmpeg-full
    pkgs.cairo
    pkgs.glibcLocales
    pkgs.xsimd
    pkgs.pkg-config
    pkgs.libxcrypt
  ];
}
test_app.py
ADDED
@@ -0,0 +1,105 @@
"""
Simple test file for the ML Dataset & Code Generation Manager application.
This script checks basic aspects of the application structure and setup.
"""

import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path

def test_directory_structure():
    """Test if the required directories exist"""
    # Ensure necessary directories exist
    os.makedirs('database/data', exist_ok=True)
    os.makedirs('assets', exist_ok=True)
    os.makedirs('fine_tuned_models', exist_ok=True)

    # Check if directories exist
    assert Path("database").exists() and Path("database").is_dir(), "Database directory not found"
    assert Path("assets").exists() and Path("assets").is_dir(), "Assets directory not found"
    assert Path("fine_tuned_models").exists() and Path("fine_tuned_models").is_dir(), "Fine-tuned models directory not found"

    print("✅ Directory structure test passed")

def test_css_file():
    """Test if the CSS file exists"""
    css_file = Path("assets/custom.css")
    assert css_file.exists() and css_file.is_file(), "CSS file not found in assets directory"

    print("✅ CSS file test passed")

def test_huggingface_config():
    """Test if Hugging Face configuration file exists"""
    config_file = Path("huggingface-spacefile")
    assert config_file.exists() and config_file.is_file(), "Hugging Face configuration file not found"

    print("✅ Hugging Face configuration test passed")

def test_streamlit_config():
    """Test if Streamlit configuration exists"""
    config_dir = Path(".streamlit")
    config_file = config_dir / "config.toml"
    assert config_dir.exists() and config_dir.is_dir(), ".streamlit directory not found"
    assert config_file.exists() and config_file.is_file(), "config.toml file not found in .streamlit directory"

    print("✅ Streamlit configuration test passed")

def test_sample_dataframe():
    """Test creation of sample dataframes"""
    # Create a sample dataframe
    df = pd.DataFrame({
        "code": ["def hello():", "import numpy as np", "print('Hello')"],
        "comment": ["Function greeting", "Import numpy library", "Print hello message"]
    })

    # Test dataframe properties
    assert len(df) == 3
    assert list(df.columns) == ["code", "comment"]

    print("✅ Sample dataframe test passed")

def test_database_initialization():
    """Test if database can be initialized"""
    try:
        from database import init_db
        init_db()
        assert Path("database/data/mlmanager.db").exists(), "Database file was not created"
        print("✅ Database initialization test passed")
    except ImportError:
        print("⚠️ Could not import database module")
        assert False, "Database module not found"

def run_tests():
    """Run all tests"""
    print("Running tests for ML Dataset & Code Generation Manager...")

    test_directory_structure()
    test_css_file()
    test_huggingface_config()
    test_streamlit_config()
    test_sample_dataframe()
    test_database_initialization()

    print("\nAll tests passed! ✅")

def test_components_existence():
    """Test if core components directories exist"""
    # Check for components directory
    components_dir = Path("components")
    assert components_dir.exists() and components_dir.is_dir(), "Components directory not found"

    # Check for fine_tuning subdirectory
    fine_tuning_dir = components_dir / "fine_tuning"
    assert fine_tuning_dir.exists() and fine_tuning_dir.is_dir(), "Fine-tuning components directory not found"

    # Check for essential component files
    assert (components_dir / "code_quality.py").exists(), "Code quality component not found"
    assert (components_dir / "dataset_uploader.py").exists(), "Dataset uploader component not found"

    print("✅ Components existence test passed")

# Run the tests if executed directly
if __name__ == '__main__':
    run_tests()
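Because every check is a plain assert, the same file also runs under pytest (listed in pyproject.toml). A minimal programmatic invocation, equivalent to running `pytest test_app.py -v` from the shell, might look like this sketch:

# Sketch (not part of the upload): run the assertions above through pytest
import pytest

raise SystemExit(pytest.main(["-v", "test_app.py"]))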
utils/dataset_utils.py
ADDED
@@ -0,0 +1,167 @@
import pandas as pd
import numpy as np

def get_dataset_info(df):
    """
    Get basic information about a dataset.

    Args:
        df: Pandas DataFrame

    Returns:
        Dictionary with dataset information
    """
    info = {
        'rows': df.shape[0],
        'columns': df.shape[1],
        'missing_values': df.isna().sum().sum(),
        'duplicate_rows': df.duplicated().sum(),
        'memory_usage': df.memory_usage(deep=True).sum() / (1024 * 1024),  # MB
        'column_types': df.dtypes.astype(str).value_counts().to_dict(),
        'column_info': []
    }

    # Get info for each column
    for col in df.columns:
        col_info = {
            'name': col,
            'type': str(df[col].dtype),
            'missing': df[col].isna().sum(),
            'missing_pct': (df[col].isna().sum() / len(df)) * 100,
            'unique_values': df[col].nunique()
        }

        # Add additional info for numeric columns
        if pd.api.types.is_numeric_dtype(df[col]):
            col_info.update({
                'min': df[col].min(),
                'max': df[col].max(),
                'mean': df[col].mean(),
                'median': df[col].median(),
                'std': df[col].std()
            })

        # Add additional info for categorical/text columns
        elif pd.api.types.is_object_dtype(df[col]):
            # Get top values
            value_counts = df[col].value_counts().head(5).to_dict()
            col_info['top_values'] = value_counts

            # Estimate if it's a categorical column
            if df[col].nunique() / len(df) < 0.1:  # If less than 10% of rows have unique values
                col_info['likely_categorical'] = True
            else:
                col_info['likely_categorical'] = False

        info['column_info'].append(col_info)

    return info

def detect_dataset_format(df):
    """
    Try to detect the format/type of the dataset based on its structure.

    Args:
        df: Pandas DataFrame

    Returns:
        String indicating the likely format
    """
    # Check for text data
    text_cols = 0
    for col in df.columns:
        if pd.api.types.is_string_dtype(df[col]) and df[col].str.len().mean() > 100:
            text_cols += 1

    if text_cols / len(df.columns) > 0.5:
        return "text"

    # Check for time series data
    date_cols = 0
    for col in df.columns:
        if pd.api.types.is_datetime64_dtype(df[col]):
            date_cols += 1

    if date_cols > 0:
        return "time_series"

    # Check if it looks like tabular data
    numeric_cols = len(df.select_dtypes(include=[np.number]).columns)
    categorical_cols = len(df.select_dtypes(include=['object', 'category']).columns)

    if numeric_cols > 0 and categorical_cols > 0:
        return "mixed"
    elif numeric_cols > 0:
        return "numeric"
    elif categorical_cols > 0:
        return "categorical"

    # Default
    return "generic"

def check_column_completeness(df, threshold=0.8):
    """
    Check if columns have good completeness (less than 20% missing values by default).

    Args:
        df: Pandas DataFrame
        threshold: Completeness threshold (0.8 = 80% complete)

    Returns:
        List of columns with poor completeness
    """
    results = []
    for col in df.columns:
        missing_ratio = df[col].isna().sum() / len(df)
        completeness = 1 - missing_ratio

        if completeness < threshold:
            results.append({
                'Column': col,
                'Completeness': f"{completeness:.2%}",
                'Missing': f"{missing_ratio:.2%}",
                'Recommendation': 'Consider imputing or removing this column'
            })

    return results

def detect_outliers(series, method='iqr', factor=1.5):
    """
    Detect outliers in a pandas Series using IQR or Z-score method.

    Args:
        series: Pandas Series with numeric values
        method: 'iqr' or 'zscore'
        factor: Multiplier for IQR or Z-score threshold

    Returns:
        Tuple of (outlier_indices, lower_bound, upper_bound)
    """
    if method == 'iqr':
        # IQR method
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        iqr = q3 - q1

        lower_bound = q1 - factor * iqr
        upper_bound = q3 + factor * iqr

        outliers = series[(series < lower_bound) | (series > upper_bound)].index.tolist()

    else:  # zscore
        # Z-score method
        from scipy import stats
        z_scores = stats.zscore(series.dropna())
        abs_z_scores = abs(z_scores)

        # Filter for Z-scores above threshold
        outlier_indices = np.where(abs_z_scores > factor)[0]
        outliers = series.dropna().iloc[outlier_indices].index.tolist()

        # Compute equivalent bounds for consistency
        mean = series.mean()
        std = series.std()
        lower_bound = mean - factor * std
        upper_bound = mean + factor * std

    return outliers, lower_bound, upper_bound
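A small worked example for detect_outliers with the default IQR method; the sample values are illustrative, the bounds follow directly from the quantile arithmetic in the function above.

# Worked example (not part of the upload): IQR outlier detection
import pandas as pd
from utils.dataset_utils import detect_outliers

s = pd.Series([1, 2, 3, 4, 100])
outliers, lo, hi = detect_outliers(s, method='iqr', factor=1.5)
# q1 = 2, q3 = 4, IQR = 2, so bounds are [2 - 3, 4 + 3] = [-1.0, 7.0]
# and only index 4 (value 100) falls outside them.
print(outliers, lo, hi)  # [4] -1.0 7.0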
utils/huggingface_integration.py
ADDED
@@ -0,0 +1,99 @@
import streamlit as st
import pandas as pd
import os
from huggingface_hub import HfApi, list_datasets
from datasets import load_dataset

@st.cache_data(ttl=3600)
def search_huggingface_datasets(query, limit=20):
    """
    Search for datasets on Hugging Face Hub.

    Args:
        query: Search query string
        limit: Maximum number of results to return

    Returns:
        List of dataset metadata
    """
    try:
        api = HfApi()
        datasets = list_datasets(
            filter=query,
            limit=limit
        )

        # Convert to list of dicts with relevant info
        results = []
        for dataset in datasets:
            results.append({
                'id': dataset.id,
                'name': dataset.id.split('/')[-1],
                'description': dataset.description or "No description available",
                'author': dataset.author or "Unknown",
                'tags': dataset.tags,
                'downloads': dataset.downloads
            })

        return results
    except Exception as e:
        st.error(f"Error searching Hugging Face Hub: {str(e)}")
        return []

@st.cache_data(ttl=3600)
def load_huggingface_dataset(dataset_id, split='train'):
    """
    Load a dataset from Hugging Face Hub.

    Args:
        dataset_id: ID of the dataset on HF Hub (e.g., 'mnist', 'glue', etc.)
        split: Dataset split to load (e.g., 'train', 'test', 'validation')

    Returns:
        Pandas DataFrame containing the dataset
    """
    try:
        # Load the dataset
        dataset = load_dataset(dataset_id, split=split)

        # Convert to pandas DataFrame
        df = dataset.to_pandas()

        return df
    except Exception as e:
        st.error(f"Error loading dataset '{dataset_id}': {str(e)}")
        raise

def upload_to_huggingface(dataset, dataset_name, token=None):
    """
    Upload a dataset to Hugging Face Hub.

    Args:
        dataset: Pandas DataFrame to upload
        dataset_name: Name for the dataset
        token: Hugging Face API token (optional, will use environment variable if not provided)

    Returns:
        URL to the uploaded dataset
    """
    # Get token from environment if not provided
    if token is None:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("No Hugging Face token provided. Set the HF_TOKEN environment variable or pass a token.")

    try:
        # Convert to HF dataset
        from datasets import Dataset
        hf_dataset = Dataset.from_pandas(dataset)

        # Upload to HF Hub
        push_result = hf_dataset.push_to_hub(
            dataset_name,
            token=token
        )

        return f"https://huggingface.co/datasets/{push_result.repo_id}"
    except Exception as e:
        st.error(f"Error uploading to Hugging Face Hub: {str(e)}")
        raise
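A minimal sketch of how these helpers are meant to be called from a page in the app; the search query and dataset id below are illustrative, only the function names and return shapes come from the file above.

# Usage sketch (illustrative values, not part of the upload)
from utils.huggingface_integration import search_huggingface_datasets, load_huggingface_dataset

hits = search_huggingface_datasets("code", limit=5)
for hit in hits:
    print(hit["id"], hit["downloads"])

df = load_huggingface_dataset("imdb", split="train")  # any public dataset id would do here
print(df.shape)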
utils/smolagents_integration.py
ADDED
@@ -0,0 +1,211 @@
import streamlit as st
import pandas as pd
import numpy as np

def process_with_smolagents(dataset, operation, custom_code=None):
    """
    Process dataset using SmolaAgents for various operations.

    Args:
        dataset: Pandas DataFrame to process
        operation: Type of processing operation
        custom_code: Custom code to execute (for custom processing)

    Returns:
        Processed pandas DataFrame
    """
    if dataset is None:
        raise ValueError("No dataset provided")

    # Create a copy to avoid modifying the original
    processed_df = dataset.copy()

    try:
        if operation == "Data Cleaning":
            processed_df = clean_dataset(processed_df)
        elif operation == "Feature Engineering":
            processed_df = engineer_features(processed_df)
        elif operation == "Data Transformation":
            processed_df = transform_dataset(processed_df)
        elif operation == "Custom Processing" and custom_code:
            # Execute custom code
            # Note: This is a security risk in a real application
            # Should be replaced with a safer approach
            local_vars = {"df": processed_df}
            exec(custom_code, {"pd": pd, "np": np}, local_vars)
            processed_df = local_vars["df"]
        else:
            raise ValueError(f"Unsupported operation: {operation}")

        return processed_df

    except Exception as e:
        st.error(f"Error during processing: {str(e)}")
        raise

def clean_dataset(df):
    """
    Clean the dataset by handling missing values, duplicates, and outliers.

    Args:
        df: Pandas DataFrame to clean

    Returns:
        Cleaned pandas DataFrame
    """
    # Create a copy to avoid modifying the original
    cleaned_df = df.copy()

    # Remove duplicate rows
    cleaned_df = cleaned_df.drop_duplicates()

    # Handle missing values
    for col in cleaned_df.columns:
        # For numeric columns
        if pd.api.types.is_numeric_dtype(cleaned_df[col]):
            # If more than 20% missing, leave as is
            if cleaned_df[col].isna().mean() > 0.2:
                continue

            # Otherwise impute with median
            cleaned_df[col] = cleaned_df[col].fillna(cleaned_df[col].median())

        # For categorical columns
        elif pd.api.types.is_object_dtype(cleaned_df[col]):
            # If more than 20% missing, leave as is
            if cleaned_df[col].isna().mean() > 0.2:
                continue

            # Otherwise impute with mode
            mode_value = cleaned_df[col].mode()[0] if not cleaned_df[col].mode().empty else "Unknown"
            cleaned_df[col] = cleaned_df[col].fillna(mode_value)

    # Handle outliers in numeric columns
    for col in cleaned_df.select_dtypes(include=[np.number]).columns:
        # Skip if too many missing values
        if cleaned_df[col].isna().mean() > 0.1:
            continue

        # Calculate IQR
        q1 = cleaned_df[col].quantile(0.25)
        q3 = cleaned_df[col].quantile(0.75)
        iqr = q3 - q1

        # Define bounds
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr

        # Cap outliers instead of removing
        cleaned_df[col] = cleaned_df[col].clip(lower_bound, upper_bound)

    return cleaned_df

def engineer_features(df):
    """
    Perform basic feature engineering on the dataset.

    Args:
        df: Pandas DataFrame to process

    Returns:
        DataFrame with engineered features
    """
    # Create a copy to avoid modifying the original
    engineered_df = df.copy()

    # Get numeric columns
    numeric_cols = engineered_df.select_dtypes(include=[np.number]).columns

    # Skip if less than 2 numeric columns
    if len(numeric_cols) >= 2:
        # Create interaction features for pairs of numeric columns
        # Limit to first 5 columns to avoid feature explosion
        for i, col1 in enumerate(numeric_cols[:5]):
            for col2 in numeric_cols[i+1:5]:
                # Product interaction
                engineered_df[f"{col1}_{col2}_product"] = engineered_df[col1] * engineered_df[col2]

                # Ratio interaction (avoid division by zero)
                denominator = engineered_df[col2].replace(0, np.nan)
                engineered_df[f"{col1}_{col2}_ratio"] = engineered_df[col1] / denominator

    # Create binary features from categorical columns
    cat_cols = engineered_df.select_dtypes(include=['object', 'category']).columns
    for col in cat_cols:
        # Skip if too many unique values (>10)
        if engineered_df[col].nunique() > 10:
            continue

        # One-hot encode
        dummies = pd.get_dummies(engineered_df[col], prefix=col, drop_first=True)
        engineered_df = pd.concat([engineered_df, dummies], axis=1)

    # Create aggregated features
    if len(numeric_cols) >= 3:
        # Sum of all numeric features
        engineered_df['sum_numeric'] = engineered_df[numeric_cols].sum(axis=1)

        # Mean of all numeric features
        engineered_df['mean_numeric'] = engineered_df[numeric_cols].mean(axis=1)

        # Standard deviation of numeric features
        engineered_df['std_numeric'] = engineered_df[numeric_cols].std(axis=1)

    return engineered_df

def transform_dataset(df):
    """
    Perform data transformations on the dataset.

    Args:
        df: Pandas DataFrame to transform

    Returns:
        Transformed pandas DataFrame
    """
    from sklearn.preprocessing import StandardScaler, MinMaxScaler

    # Create a copy to avoid modifying the original
    transformed_df = df.copy()

    # Get numeric columns
    numeric_cols = transformed_df.select_dtypes(include=[np.number]).columns

    if len(numeric_cols) > 0:
        # Create scaled versions of numeric columns

        # Standard scaling (z-score)
        scaler = StandardScaler()
        scaled_data = scaler.fit_transform(transformed_df[numeric_cols])
        scaled_df = pd.DataFrame(
            scaled_data,
            columns=[f"{col}_scaled" for col in numeric_cols],
            index=transformed_df.index
        )

        # Min-max scaling (0-1 range)
        minmax_scaler = MinMaxScaler()
        minmax_data = minmax_scaler.fit_transform(transformed_df[numeric_cols])
        minmax_df = pd.DataFrame(
            minmax_data,
            columns=[f"{col}_normalized" for col in numeric_cols],
            index=transformed_df.index
        )

        # Log transform (for positive columns only)
        log_cols = []
        for col in numeric_cols:
            if (transformed_df[col] > 0).all():
                transformed_df[f"{col}_log"] = np.log(transformed_df[col])
                log_cols.append(f"{col}_log")

        # Combine all transformations
        transformed_df = pd.concat([transformed_df, scaled_df, minmax_df], axis=1)

    # One-hot encode categorical columns
    cat_cols = transformed_df.select_dtypes(include=['object', 'category']).columns
    if len(cat_cols) > 0:
        # One-hot encode all categorical columns
        transformed_df = pd.get_dummies(transformed_df, columns=cat_cols, drop_first=False)

    return transformed_df
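A minimal sketch of calling process_with_smolagents with the built-in cleaning pass and with the "Custom Processing" path; the example DataFrame and the custom snippet are illustrative, the operation names come from the dispatcher above.

# Usage sketch (illustrative data, not part of the upload)
import pandas as pd
from utils.smolagents_integration import process_with_smolagents

df = pd.DataFrame({
    "age": [25, None, 31, 28, 200],
    "city": ["Oslo", "Oslo", None, "Bergen", "Oslo"],
})

# Built-in cleaning: drops duplicates, imputes short gaps, caps extreme values
cleaned = process_with_smolagents(df, "Data Cleaning")
print(cleaned.isna().sum().sum(), "missing values after cleaning")

# Custom processing: the snippet runs against the working copy bound to `df`
custom = process_with_smolagents(df, "Custom Processing",
                                 custom_code="df['age_sq'] = df['age'] ** 2")
print(custom.columns.tolist())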
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff