katsukiai commited on
Commit
8a4e856
·
verified ·
1 Parent(s): 37323c0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from huggingface_hub import login, HfApi, snapshot_download
3
+ from datasets import load_dataset
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
6
+ from transformers import BertTokenizer, BertForSequenceClassification
7
+ import pandas as pd
8
+ import os
9
+
10
+ # Streamlit app configuration
11
+ st.set_page_config(page_title="Katsukiai Dataset Trainer", layout="wide")
12
+
13
+ # Sidebar for navigation
14
+ st.sidebar.title("Navigation")
15
+ tabs = ["Train", "Train with DeepSeek-V3", "Select Dataset and Format", "About", "Settings"]
16
+ selected_tab = st.sidebar.radio("Select Tab", tabs)
17
+
18
+ # Settings state
19
+ if "settings" not in st.session_state:
20
+ st.session_state.settings = {
21
+ "token": "",
22
+ "username": "",
23
+ "use_torch": False,
24
+ "use_bert": False
25
+ }
26
+
27
+ # Functions
28
+ def load_katsukiai_dataset(dataset_name):
29
+ return load_dataset(f"Katsukiai/{dataset_name}", token=st.session_state.settings["token"] if st.session_state.settings["token"] else None)
30
+
31
+ def train_with_bert(dataset, model_name="bert-base-uncased"):
32
+ tokenizer = BertTokenizer.from_pretrained(model_name)
33
+ model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
34
+
35
+ def tokenize_function(examples):
36
+ return tokenizer(examples["text"], padding="max_length", truncation=True)
37
+
38
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
39
+ tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
40
+
41
+ training_args = TrainingArguments(
42
+ output_dir=f"./converted/results_{st.session_state.settings['username']}",
43
+ num_train_epochs=3,
44
+ per_device_train_batch_size=8,
45
+ save_steps=10_000,
46
+ save_total_limit=2,
47
+ )
48
+
49
+ KILL trainer = Trainer(
50
+ model=model,
51
+ args=training_args,
52
+ train_dataset=tokenized_dataset["train"],
53
+ eval_dataset=tokenized_dataset["test"]
54
+ )
55
+ trainer.train()
56
+ return "BERT Training Complete"
57
+
58
+ def train_with_deepseek(dataset, model_name="deepseek-ai/DeepSeek-V3"):
59
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
60
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
61
+
62
+ def tokenize_function(examples):
63
+ return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
64
+
65
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
66
+ tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask"])
67
+
68
+ training_args = TrainingArguments(
69
+ output_dir=f"./deepseek/results_{st.session_state.settings['username']}",
70
+ num_train_epochs=3,
71
+ per_device_train_batch_size=4,
72
+ gradient_accumulation_steps=2,
73
+ save_steps=10_000,
74
+ save_total_limit=2,
75
+ fp16=True # Mixed precision for efficiency
76
+ )
77
+
78
+ trainer = Trainer(
79
+ model=model,
80
+ args=training_args,
81
+ train_dataset=tokenized_dataset["train"],
82
+ )
83
+ trainer.train()
84
+ return "DeepSeek-V3 Training Complete"
85
+
86
+ # Tab content
87
+ if selected_tab == "Train":
88
+ st.title("Train Katsukiai Dataset")
89
+ api = HfApi()
90
+ datasets_list = [d.id.split("/")[-1] for d in api.list_datasets(author="Katsukiai")]
91
+ dataset_name = st.selectbox("Select Dataset", datasets_list)
92
+ if st.button("Start Training"):
93
+ dataset = load_katsukiai_dataset(dataset_name)
94
+ if st.session_state.settings["use_bert"]:
95
+ result = train_with_bert(dataset)
96
+ st.success(result)
97
+ elif st.session_state.settings["use_torch"]:
98
+ st.write("Training with Torch (custom implementation required)")
99
+ else:
100
+ st.write("Basic training (no specific model selected)")
101
+
102
+ elif selected_tab == "Train with DeepSeek-V3":
103
+ st.title("Train with DeepSeek-V3")
104
+ dataset_name = st.selectbox("Select Dataset", [d.id.split("/")[-1] for d in api.list_datasets(author="Katsukiai")])
105
+ if st.button("Train with DeepSeek"):
106
+ if st.session_state.settings["token"]:
107
+ login(st.session_state.settings["token"])
108
+ dataset = load_katsukiai_dataset(dataset_name)
109
+ result = train_with_deepseek(dataset)
110
+ st.success(result)
111
+ else:
112
+ st.error("Please set Hugging Face token in Settings")
113
+
114
+ elif selected_tab == "Select Dataset and Format":
115
+ st.title("Select Dataset and Format")
116
+ api = HfApi()
117
+ datasets_list = [d.id.split("/")[-1] for d in api.list_datasets(author="katsukiai")]
118
+ dataset_name = st.selectbox("Select Dataset", datasets_list)
119
+ format_option = st.selectbox("Select Format", ["csv", "json", "parquet"])
120
+ if st.button("Load Dataset"):
121
+ dataset = load_katsukiai_dataset(dataset_name)
122
+ df = pd.DataFrame(dataset["train"])
123
+ if format_option == "csv":
124
+ st.download_button("Download CSV", df.to_csv(index=False), "dataset.csv")
125
+ elif format_option == "json":
126
+ st.download_button("Download JSON", df.to_json(), "dataset.json")
127
+ else:
128
+ st.download_button("Download Parquet", df.to_parquet(), "dataset.parquet")
129
+
130
+ elif selected_tab == "About":
131
+ st.title("About")
132
+ st.write("This app trains models on Katsukiai datasets from Hugging Face.")
133
+ st.write("Features:")
134
+ st.write("- Train with BERT or custom Torch models")
135
+ st.write("- Train using DeepSeek-V3 from Hugging Face")
136
+ st.write("- Dataset selection and format conversion")
137
+ st.write("Built with Streamlit, Hugging Face Hub, and PyTorch.")
138
+
139
+ elif selected_tab == "Settings":
140
+ st.title("Settings")
141
+ token = st.text_input("Hugging Face Token", value=st.session_state.settings["token"])
142
+ username = st.text_input("Username (for output folder)", value=st.session_state.settings["username"])
143
+ use_torch = st.checkbox("Use Torch", value=st.session_state.settings["use_torch"])
144
+ use_bert = st.checkbox("Use BERT & Tokenizer", value=st.session_state.settings["use_bert"])
145
+
146
+ if st.button("Save Settings"):
147
+ st.session_state.settings.update({
148
+ "token": token,
149
+ "username": username,
150
+ "use_torch": use_torch,
151
+ "use_bert": use_bert
152
+ })
153
+ if username and not os.path.exists(f"./results_{username}"):
154
+ os.makedirs(f"./results_{username}")
155
+ st.success("Settings saved!")