tymbos commited on
Commit
72577d1
·
verified ·
1 Parent(s): cb19003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -111
app.py CHANGED
@@ -2,10 +2,6 @@
2
  import os
3
  import gc
4
  import gradio as gr
5
- import requests
6
- import time
7
- from io import BytesIO
8
- import matplotlib.pyplot as plt
9
  from datasets import load_dataset
10
  from train_tokenizer import train_tokenizer
11
  from tokenizers import Tokenizer
@@ -13,27 +9,25 @@ from langdetect import detect, DetectorFactory
13
  from PIL import Image
14
  from datetime import datetime
15
  from concurrent.futures import ThreadPoolExecutor
 
 
 
16
 
17
  # Για επαναληψιμότητα στο langdetect
18
  DetectorFactory.seed = 0
19
 
20
  # Ρυθμίσεις
21
  CHECKPOINT_FILE = "checkpoint.txt"
22
- TOKENIZER_DIR = os.getcwd() # Χρησιμοποιεί τον τρέχοντα φάκελο
23
- #TOKENIZER_DIR = "tokenizer_model"
24
  TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
25
- MAX_SAMPLES = 5000000 # Αυξημένο όριο δειγμάτων
26
- DEFAULT_CHUNK_SIZE = 200000 # Μεγαλύτερο chunk size
27
- BATCH_SIZE = 1000 # Μέγεθος batch για φόρτωση δεδομένων
28
- NUM_WORKERS = 4 # Αριθμός workers για πολυνηματική επεξεργασία
29
 
30
  # Παγκόσμια μεταβλητή ελέγχου
31
  STOP_COLLECTION = False
32
 
33
- # Καταγραφή εκκίνησης
34
- startup_log = f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====\n"
35
- print(startup_log)
36
-
37
  def load_checkpoint():
38
  """Φόρτωση δεδομένων από το checkpoint."""
39
  if os.path.exists(CHECKPOINT_FILE):
@@ -50,7 +44,6 @@ def append_to_checkpoint(texts):
50
  def create_iterator(dataset_name, configs, split):
51
  """Βελτιωμένο iterator με batch φόρτωση και caching."""
52
  configs_list = [c.strip() for c in configs.split(",") if c.strip()]
53
-
54
  for config in configs_list:
55
  try:
56
  dataset = load_dataset(
@@ -58,52 +51,39 @@ def create_iterator(dataset_name, configs, split):
58
  name=config,
59
  split=split,
60
  streaming=True,
61
- cache_dir="./dataset_cache" # Ενεργοποίηση cache
62
  )
63
-
64
- # Φόρτωση δεδομένων σε batches
65
  while True:
66
  batch = list(dataset.take(BATCH_SIZE))
67
  if not batch:
68
  break
69
  dataset = dataset.skip(BATCH_SIZE)
70
-
71
- # Πολυνηματική επεξεργασία batch
72
  with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
73
  processed_texts = list(executor.map(process_example, batch))
74
-
75
  yield from filter(None, processed_texts)
76
-
77
  except Exception as e:
78
- print(f"⚠️ Σφάλμα φόρτωσης: {config}: {e}")
79
 
80
  def process_example(example):
81
  """Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
82
  try:
83
  text = example.get('text', '').strip()
84
- if text and detect(text) in ['el', 'en']: # Φιλτράρισμα γλώσσας
85
  return text
86
  return None
87
  except:
88
  return None
89
 
90
  def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
91
- """Βελτιωμένη συλλογή δεδομένων με μεγάλα chunks."""
92
  global STOP_COLLECTION
93
  STOP_COLLECTION = False
94
  total_processed = len(load_checkpoint())
95
-
96
- progress_messages = [
97
- f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}",
98
- f"⚙️ Ρυθμίσεις: Chunk Size={chunk_size}, Workers={NUM_WORKERS}"
99
- ]
100
-
101
  dataset_iterator = create_iterator(dataset_name, configs, split)
102
  chunk = []
103
-
104
  while not STOP_COLLECTION and total_processed < max_samples:
105
  try:
106
- # Φόρτωση chunk
107
  while len(chunk) < chunk_size:
108
  text = next(dataset_iterator)
109
  if text:
@@ -111,119 +91,97 @@ def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
111
  total_processed += 1
112
  if total_processed >= max_samples:
113
  break
114
-
115
- # Αποθήκευση chunk
116
  if chunk:
117
  append_to_checkpoint(chunk)
118
- progress_messages.append(
119
- f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})"
120
- )
121
  chunk = []
122
-
123
- # Εκκαθάριση μνήμης
124
  gc.collect()
125
-
126
  except StopIteration:
127
  progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
128
  break
129
  except Exception as e:
130
  progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
131
  break
132
-
133
  return "\n".join(progress_messages)
134
 
135
  def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
136
- """Βελτιωμένη εκπαίδευση tokenizer με χρήση cache."""
137
- print("🚀 Εκκίνηση εκπαίδευσης...")
138
- all_texts = load_checkpoint()
139
-
140
- # Παράλληλη επεξεργασία για εκπαίδευση
141
- tokenizer = train_tokenizer(
142
- all_texts,
143
- vocab_size=vocab_size,
144
- min_frequency=min_freq,
145
- output_dir=TOKENIZER_DIR,
146
- num_threads=NUM_WORKERS # Παράλληλη επεξεργασία
147
- )
148
-
149
- # Φόρτωση και δοκιμή tokenizer
150
- trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
151
- encoded = trained_tokenizer.encode(test_text)
152
- decoded = trained_tokenizer.decode(encoded.ids)
153
-
154
- # Δημιουργία γραφήματος
155
- fig, ax = plt.subplots()
156
- ax.hist([len(t) for t in encoded.tokens], bins=20)
157
- ax.set_xlabel('Μήκος Token')
158
- ax.set_ylabel('Συχνότητα')
159
- img_buffer = BytesIO()
160
- plt.savefig(img_buffer, format='png')
161
- plt.close()
162
-
163
- return ("✅ Εκπαίδευση ολοκληρώθηκε!", decoded, Image.open(img_buffer))
164
- print(f"Ο tokenizer αποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}")
165
 
166
  def analyze_checkpoint():
167
- """Νέα λειτουργία ανάλυσης δεδομένων."""
168
- texts = load_checkpoint()
169
- if not texts:
170
- return "Δεν βρέθηκαν δεδομένα για ανάλυση."
171
-
172
- # Βασική στατιστική
173
- total_chars = sum(len(t) for t in texts)
174
- avg_length = total_chars / len(texts) if texts else 0
175
-
176
- # Ανάλυση γλώσσας
177
- languages = {}
178
- for t in texts[:1000]: # Δειγματοληψία για ταχύτητα
179
- try:
180
- lang = detect(t)
181
- languages[lang] = languages.get(lang, 0) + 1
182
- except:
183
- continue
184
-
185
- report = [
186
- f"📊 Σύνολο δειγμάτων: {len(texts)}",
187
- f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
188
- "🌍 Γλώσσες (δείγμα 1000):",
189
- *[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
190
- ]
191
-
192
- return "\n".join(report)
193
 
194
  def restart_collection():
195
- """Διαγράφει το checkpoint και επανεκκινεί τη συλλογή."""
196
  global STOP_COLLECTION
197
  STOP_COLLECTION = False
198
  if os.path.exists(CHECKPOINT_FILE):
199
  os.remove(CHECKPOINT_FILE)
200
- print("🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή.")
201
  return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή."
202
 
203
-
204
  # Gradio Interface
205
  with gr.Blocks() as demo:
206
- gr.Markdown("## Βελτιωμένος Wikipedia Tokenizer Trainer")
207
-
208
  with gr.Row():
209
  with gr.Column(scale=2):
210
  dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
211
  configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
212
  split = gr.Dropdown(["train"], value="train", label="Split")
213
  chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
214
- vocab_size = gr.Slider(20000, 200000, value=50000, step=10000, label="Vocabulary Size")
215
- min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
216
  test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
217
- max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Maximum Samples")
218
-
219
  with gr.Row():
220
  start_btn = gr.Button("Start", variant="primary")
221
  stop_btn = gr.Button("Stop", variant="stop")
222
  restart_btn = gr.Button("Restart")
223
-
224
  analyze_btn = gr.Button("Analyze Data")
225
  train_btn = gr.Button("Train Tokenizer", variant="primary")
226
-
227
  with gr.Column(scale=3):
228
  progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
229
  gr.Markdown("### Αποτελέσματα")
@@ -232,10 +190,10 @@ with gr.Blocks() as demo:
232
 
233
  # Event handlers
234
  start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
235
- stop_btn.click(lambda: "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
236
- restart_btn.click(lambda: "🔄 Επαναφορά...", None, progress).then(restart_collection, None, progress)
237
  analyze_btn.click(analyze_checkpoint, None, progress)
238
  train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
239
- [progress, decoded_text, token_distribution])
240
 
241
  demo.queue().launch()
 
2
  import os
3
  import gc
4
  import gradio as gr
 
 
 
 
5
  from datasets import load_dataset
6
  from train_tokenizer import train_tokenizer
7
  from tokenizers import Tokenizer
 
9
  from PIL import Image
10
  from datetime import datetime
11
  from concurrent.futures import ThreadPoolExecutor
12
+ import matplotlib.pyplot as plt
13
+ from io import BytesIO
14
+ import traceback
15
 
16
  # Για επαναληψιμότητα στο langdetect
17
  DetectorFactory.seed = 0
18
 
19
  # Ρυθμίσεις
20
  CHECKPOINT_FILE = "checkpoint.txt"
21
+ TOKENIZER_DIR = "./tokenizer_model"
 
22
  TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
23
+ MAX_SAMPLES = 5000000
24
+ DEFAULT_CHUNK_SIZE = 200000
25
+ BATCH_SIZE = 1000
26
+ NUM_WORKERS = 4
27
 
28
  # Παγκόσμια μεταβλητή ελέγχου
29
  STOP_COLLECTION = False
30
 
 
 
 
 
31
  def load_checkpoint():
32
  """Φόρτωση δεδομένων από το checkpoint."""
33
  if os.path.exists(CHECKPOINT_FILE):
 
44
  def create_iterator(dataset_name, configs, split):
45
  """Βελτιωμένο iterator με batch φόρτωση και caching."""
46
  configs_list = [c.strip() for c in configs.split(",") if c.strip()]
 
47
  for config in configs_list:
48
  try:
49
  dataset = load_dataset(
 
51
  name=config,
52
  split=split,
53
  streaming=True,
54
+ cache_dir="./dataset_cache"
55
  )
 
 
56
  while True:
57
  batch = list(dataset.take(BATCH_SIZE))
58
  if not batch:
59
  break
60
  dataset = dataset.skip(BATCH_SIZE)
 
 
61
  with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
62
  processed_texts = list(executor.map(process_example, batch))
 
63
  yield from filter(None, processed_texts)
 
64
  except Exception as e:
65
+ print(f"⚠️ Σφάλμα φόρτωσης {config}: {e}")
66
 
67
  def process_example(example):
68
  """Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
69
  try:
70
  text = example.get('text', '').strip()
71
+ if text and detect(text) in ['el', 'en']:
72
  return text
73
  return None
74
  except:
75
  return None
76
 
77
  def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
78
+ """Συλλογή δεδομένων με streaming και checkpoints."""
79
  global STOP_COLLECTION
80
  STOP_COLLECTION = False
81
  total_processed = len(load_checkpoint())
82
+ progress_messages = [f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}"]
 
 
 
 
 
83
  dataset_iterator = create_iterator(dataset_name, configs, split)
84
  chunk = []
 
85
  while not STOP_COLLECTION and total_processed < max_samples:
86
  try:
 
87
  while len(chunk) < chunk_size:
88
  text = next(dataset_iterator)
89
  if text:
 
91
  total_processed += 1
92
  if total_processed >= max_samples:
93
  break
 
 
94
  if chunk:
95
  append_to_checkpoint(chunk)
96
+ progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})")
 
 
97
  chunk = []
 
 
98
  gc.collect()
 
99
  except StopIteration:
100
  progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
101
  break
102
  except Exception as e:
103
  progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
104
  break
 
105
  return "\n".join(progress_messages)
106
 
107
  def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
108
+ """Εκπαίδευση του tokenizer και έλεγχος ποιότητας."""
109
+ messages = ["🚀 Εκκίνηση εκπαίδευσης..."]
110
+ try:
111
+ all_texts = load_checkpoint()
112
+ messages.append("📚 Φόρτωση δεδομένων από checkpoint...")
113
+ tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS)
114
+ messages.append("✅ Εκπαίδευση ολοκληρώθηκε!")
115
+ trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
116
+ encoded = trained_tokenizer.encode(test_text)
117
+ decoded = trained_tokenizer.decode(encoded.ids)
118
+ fig, ax = plt.subplots()
119
+ ax.hist([len(t) for t in encoded.tokens], bins=20)
120
+ ax.set_xlabel('Μήκος Token')
121
+ ax.set_ylabel('Συχνότητα')
122
+ img_buffer = BytesIO()
123
+ plt.savefig(img_buffer, format='png')
124
+ plt.close()
125
+ return ("\n".join(messages), decoded, Image.open(img_buffer))
126
+ except Exception as e:
127
+ messages.append(f"❌ Σφάλμα: {str(e)}")
128
+ return ("\n".join(messages), "", None)
 
 
 
 
 
 
 
 
129
 
130
  def analyze_checkpoint():
131
+ """Ανάλυση δεδομένων από το checkpoint."""
132
+ messages = ["🔍 Έναρξη ανάλυσης..."]
133
+ try:
134
+ texts = load_checkpoint()
135
+ if not texts:
136
+ return "Δεν βρέθηκαν δεδομένα για ανάλυση."
137
+ total_chars = sum(len(t) for t in texts)
138
+ avg_length = total_chars / len(texts) if texts else 0
139
+ languages = {}
140
+ for t in texts[:1000]:
141
+ if len(t) > 20:
142
+ try:
143
+ lang = detect(t)
144
+ languages[lang] = languages.get(lang, 0) + 1
145
+ except Exception as e:
146
+ print(f"⚠️ Σφάλμα ανίχνευσης γλώσσας: {e}")
147
+ report = [
148
+ f"📊 Σύνολο δειγμάτων: {len(texts)}",
149
+ f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
150
+ "🌍 Γλώσσες (δείγμα 1000):",
151
+ *[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
152
+ ]
153
+ return "\n".join(messages + report)
154
+ except Exception as e:
155
+ messages.append(f"❌ Σφάλμα: {str(e)}")
156
+ return "\n".join(messages)
157
 
158
  def restart_collection():
159
+ """Διαγραφή checkpoint και επανεκκίνηση."""
160
  global STOP_COLLECTION
161
  STOP_COLLECTION = False
162
  if os.path.exists(CHECKPOINT_FILE):
163
  os.remove(CHECKPOINT_FILE)
 
164
  return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή."
165
 
 
166
  # Gradio Interface
167
  with gr.Blocks() as demo:
168
+ gr.Markdown("## Custom Tokenizer Trainer για GPT-2")
 
169
  with gr.Row():
170
  with gr.Column(scale=2):
171
  dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
172
  configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
173
  split = gr.Dropdown(["train"], value="train", label="Split")
174
  chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
175
+ vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Μέγεθος Λεξιλογίου")
176
+ min_freq = gr.Slider(1, 10, value=3, label="Ελάχιστη Συχνότητα")
177
  test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
178
+ max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Μέγιστα Δείγματα")
 
179
  with gr.Row():
180
  start_btn = gr.Button("Start", variant="primary")
181
  stop_btn = gr.Button("Stop", variant="stop")
182
  restart_btn = gr.Button("Restart")
 
183
  analyze_btn = gr.Button("Analyze Data")
184
  train_btn = gr.Button("Train Tokenizer", variant="primary")
 
185
  with gr.Column(scale=3):
186
  progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
187
  gr.Markdown("### Αποτελέσματα")
 
190
 
191
  # Event handlers
192
  start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
193
+ stop_btn.click(lambda: globals().update(STOP_COLLECTION=True) or "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
194
+ restart_btn.click(restart_collection, None, progress)
195
  analyze_btn.click(analyze_checkpoint, None, progress)
196
  train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
197
+ [progress, decoded_text, token_distribution])
198
 
199
  demo.queue().launch()