# Katsukiai Dataset Trainer — Streamlit app for browsing and fine-tuning on
# Katsukiai datasets hosted on the Hugging Face Hub.
import streamlit as st
from huggingface_hub import login, HfApi, snapshot_download
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from transformers import BertTokenizer, BertForSequenceClassification
import pandas as pd
import os
# Streamlit app configuration
st.set_page_config(page_title="Katsukiai Dataset Trainer", layout="wide")

# Sidebar navigation: one radio button per top-level tab.
st.sidebar.title("Navigation")
tabs = ["Train", "Train with DeepSeek-V3", "Select Dataset and Format", "About", "Settings"]
selected_tab = st.sidebar.radio("Select Tab", tabs)

# Per-session settings, initialised once and edited from the Settings tab.
_DEFAULT_SETTINGS = {
    "token": "",
    "username": "",
    "use_torch": False,
    "use_bert": False,
}
if "settings" not in st.session_state:
    st.session_state.settings = dict(_DEFAULT_SETTINGS)
# Functions
def load_katsukiai_dataset(dataset_name):
    """Load ``Katsukiai/<dataset_name>`` from the Hugging Face Hub.

    Passes the session's Hugging Face token when one has been saved in
    Settings; otherwise authenticates anonymously (token=None).
    """
    token = st.session_state.settings["token"] or None
    return load_dataset(f"Katsukiai/{dataset_name}", token=token)
def train_with_bert(dataset, model_name="bert-base-uncased", *,
                    num_labels=2, num_train_epochs=3,
                    per_device_train_batch_size=8):
    """Fine-tune a BERT sequence classifier on *dataset*.

    Parameters
    ----------
    dataset : a datasets ``DatasetDict``; assumed to contain "train" and
        "test" splits with "text" and "labels" columns — TODO confirm this
        holds for every Katsukiai dataset offered in the UI.
    model_name : Hub checkpoint to start from.
    num_labels : number of classification labels (keyword-only; was
        hard-coded to 2).
    num_train_epochs, per_device_train_batch_size : training
        hyperparameters (keyword-only; previously hard-coded).

    Returns a short status string for display in the UI.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    def tokenize_function(examples):
        # Pad to the model max length so batches have a uniform shape.
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    training_args = TrainingArguments(
        # Output folder is namespaced by the username saved in Settings.
        output_dir=f"./converted/results_{st.session_state.settings['username']}",
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        save_steps=10_000,
        save_total_limit=2,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
    )
    trainer.train()
    return "BERT Training Complete"
def train_with_deepseek(dataset, model_name="deepseek-ai/DeepSeek-V3", *,
                        num_train_epochs=3, per_device_train_batch_size=4,
                        max_length=512):
    """Fine-tune a causal language model (DeepSeek-V3 by default) on *dataset*.

    Parameters
    ----------
    dataset : a datasets ``DatasetDict``; assumed to have a "train" split
        with a "text" column — TODO confirm against the selected dataset.
    model_name : Hub checkpoint; loaded with trust_remote_code for
        DeepSeek's custom model code.
    num_train_epochs, per_device_train_batch_size, max_length :
        hyperparameters (keyword-only; previously hard-coded).

    Returns a short status string for display in the UI.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    def tokenize_function(examples):
        tokens = tokenizer(examples["text"], padding="max_length", truncation=True,
                           max_length=max_length)
        # Bug fix: causal-LM fine-tuning needs a "labels" column — without it
        # Trainer has no loss to optimize. Standard practice is labels ==
        # input_ids (the model shifts them internally).
        tokens["labels"] = [list(ids) for ids in tokens["input_ids"]]
        return tokens

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    training_args = TrainingArguments(
        output_dir=f"./deepseek/results_{st.session_state.settings['username']}",
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=2,
        save_steps=10_000,
        save_total_limit=2,
        fp16=True,  # mixed precision for efficiency — NOTE(review): requires a CUDA device
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
    )
    trainer.train()
    return "DeepSeek-V3 Training Complete"
# Tab content
if selected_tab == "Train":
    # "Train" tab: pick any dataset published by the "katsukiai" author and
    # train with whichever backend is enabled in Settings.
    st.title("Train Katsukiai Dataset")
    api = HfApi()
    datasets_list = [entry.id.split("/")[-1] for entry in api.list_datasets(author="katsukiai")]
    dataset_name = st.selectbox("Select Dataset", datasets_list)
    if st.button("Start Training"):
        dataset = load_katsukiai_dataset(dataset_name)
        settings = st.session_state.settings
        if settings["use_bert"]:
            st.success(train_with_bert(dataset))
        elif settings["use_torch"]:
            st.write("Training with Torch (custom implementation required)")
        else:
            st.write("Basic training (no specific model selected)")
elif selected_tab == "Train with DeepSeek-V3":
st.title("Train with DeepSeek-V3")
dataset_name = st.selectbox("Select Dataset", [d.id.split("/")[-1] for d in api.list_datasets(author="katsukiai")])
if st.button("Train with DeepSeek"):
if st.session_state.settings["token"]:
login(st.session_state.settings["token"])
dataset = load_katsukiai_dataset(dataset_name)
result = train_with_deepseek(dataset)
st.success(result)
else:
st.error("Please set Hugging Face token in Settings")
elif selected_tab == "Select Dataset and Format":
st.title("Select Dataset and Format")
api = HfApi()
datasets_list = [d.id.split("/")[-1] for d in api.list_datasets(author="katsukiai")]
dataset_name = st.selectbox("Select Dataset", datasets_list)
format_option = st.selectbox("Select Format", ["csv", "json", "parquet"])
if st.button("Load Dataset"):
dataset = load_katsukiai_dataset(dataset_name)
df = pd.DataFrame(dataset["train"])
if format_option == "csv":
st.download_button("Download CSV", df.to_csv(index=False), "dataset.csv")
elif format_option == "json":
st.download_button("Download JSON", df.to_json(), "dataset.json")
else:
st.download_button("Download Parquet", df.to_parquet(), "dataset.parquet")
elif selected_tab == "About":
st.title("About")
st.write("This app trains models on Katsukiai datasets from Hugging Face.")
st.write("Features:")
st.write("- Train with BERT or custom Torch models")
st.write("- Train using DeepSeek-V3 from Hugging Face")
st.write("- Dataset selection and format conversion")
st.write("Built with Streamlit, Hugging Face Hub, and PyTorch.")
elif selected_tab == "Settings":
st.title("Settings")
token = st.text_input("Hugging Face Token", value=st.session_state.settings["token"])
username = st.text_input("Username (for output folder)", value=st.session_state.settings["username"])
use_torch = st.checkbox("Use Torch", value=st.session_state.settings["use_torch"])
use_bert = st.checkbox("Use BERT & Tokenizer", value=st.session_state.settings["use_bert"])
if st.button("Save Settings"):
st.session_state.settings.update({
"token": token,
"username": username,
"use_torch": use_torch,
"use_bert": use_bert
})
if username and not os.path.exists(f"./converted/results_{username}"):
os.makedirs(f"./converted/results_{username}")
st.success("Settings saved!") |