boltuix commited on
Commit
4bbeb21
·
verified ·
1 Parent(s): 1654e52

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -211
README.md CHANGED
@@ -351,217 +351,8 @@ bert-local is trained using **bert-mini** for multi-class text classification. H
351
  ```
352
 
353
  ### Training Code
354
- ```python
355
- import pandas as pd
356
- from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, TrainerCallback
357
- from sklearn.model_selection import train_test_split
358
- from sklearn.metrics import accuracy_score, f1_score
359
- import torch
360
- from torch.utils.data import Dataset
361
- import shutil
362
- from tqdm import tqdm
363
- import numpy as np
364
-
365
- # === 0. Define model and output paths ===
366
- MODEL_NAME = "bert-mini"
367
- OUTPUT_DIR = "./bert-local"
368
-
369
- # === 1. Custom callback for tqdm progress bar ===
370
- class TQDMProgressBarCallback(TrainerCallback):
371
- def __init__(self):
372
- super().__init__()
373
- self.progress_bar = None
374
-
375
- def on_train_begin(self, args, state, control, **kwargs):
376
- self.total_steps = state.max_steps
377
- self.progress_bar = tqdm(total=self.total_steps, desc="Training", unit="step")
378
-
379
- def on_step_end(self, args, state, control, **kwargs):
380
- self.progress_bar.update(1)
381
- self.progress_bar.set_postfix({
382
- "epoch": f"{state.epoch:.2f}",
383
- "step": state.global_step
384
- })
385
-
386
- def on_train_end(self, args, state, control, **kwargs):
387
- if self.progress_bar is not None:
388
- self.progress_bar.close()
389
- self.progress_bar = None
390
-
391
- # === 2. Load and preprocess data ===
392
- dataset_path = 'dataset.csv'
393
- df = pd.read_csv(dataset_path)
394
- df = df.dropna(subset=['category'])
395
- df.columns = ['label', 'text'] # Rename columns
396
-
397
- # === 3. Encode labels ===
398
- labels = sorted(df["label"].unique())
399
- label_to_id = {label: idx for idx, label in enumerate(labels)}
400
- id_to_label = {idx: label for label, idx in label_to_id.items()}
401
- df['label'] = df['label'].map(label_to_id)
402
-
403
- # === 4. Train-val split ===
404
- train_texts, val_texts, train_labels, val_labels = train_test_split(
405
- df['text'].tolist(), df['label'].tolist(), test_size=0.2, random_state=42, stratify=df['label']
406
- )
407
-
408
- # === 5. Tokenizer ===
409
- tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
410
-
411
- # === 6. Dataset class ===
412
- class CategoryDataset(Dataset):
413
- def __init__(self, texts, labels, tokenizer, max_length=128):
414
- self.texts = texts
415
- self.labels = labels
416
- self.tokenizer = tokenizer
417
- self.max_length = max_length
418
-
419
- def __len__(self):
420
- return len(self.texts)
421
-
422
- def __getitem__(self, idx):
423
- encoding = self.tokenizer(
424
- self.texts[idx],
425
- padding='max_length',
426
- truncation=True,
427
- max_length=self.max_length,
428
- return_tensors='pt'
429
- )
430
- return {
431
- 'input_ids': encoding['input_ids'].squeeze(0),
432
- 'attention_mask': encoding['attention_mask'].squeeze(0),
433
- 'labels': torch.tensor(self.labels[idx], dtype=torch.long)
434
- }
435
-
436
- # === 7. Load datasets ===
437
- train_dataset = CategoryDataset(train_texts, train_labels, tokenizer)
438
- val_dataset = CategoryDataset(val_texts, val_labels, tokenizer)
439
-
440
- # === 8. Load model with num_labels ===
441
- model = BertForSequenceClassification.from_pretrained(
442
- MODEL_NAME,
443
- num_labels=len(label_to_id)
444
- )
445
-
446
- # === 9. Define metrics for evaluation ===
447
- def compute_metrics(eval_pred):
448
- logits, labels = eval_pred
449
- predictions = np.argmax(logits, axis=-1)
450
- acc = accuracy_score(labels, predictions)
451
- f1 = f1_score(labels, predictions, average='weighted')
452
- return {
453
- 'accuracy': acc,
454
- 'f1_weighted': f1,
455
- }
456
-
457
- # === 10. Training arguments ===
458
- training_args = TrainingArguments(
459
- output_dir='./results',
460
- run_name="bert-local",
461
- num_train_epochs=5,
462
- per_device_train_batch_size=16,
463
- per_device_eval_batch_size=16,
464
- warmup_steps=500,
465
- weight_decay=0.01,
466
- logging_dir='./logs',
467
- logging_steps=10,
468
- eval_strategy="epoch",
469
- report_to="none"
470
- )
471
-
472
- # === 11. Trainer setup ===
473
- trainer = Trainer(
474
- model=model,
475
- args=training_args,
476
- train_dataset=train_dataset,
477
- eval_dataset=val_dataset,
478
- compute_metrics=compute_metrics,
479
- callbacks=[TQDMProgressBarCallback()]
480
- )
481
-
482
- # === 12. Train and evaluate ===
483
- trainer.train()
484
- trainer.evaluate()
485
-
486
- # === 13. Save model and tokenizer ===
487
- model.config.label2id = label_to_id
488
- model.config.id2label = id_to_label
489
- model.config.num_labels = len(label_to_id)
490
-
491
- model.save_pretrained(OUTPUT_DIR)
492
- tokenizer.save_pretrained(OUTPUT_DIR)
493
-
494
- # === 14. Zip model directory ===
495
- shutil.make_archive("bert-local", 'zip', OUTPUT_DIR)
496
- print("✅ Training complete. Model and tokenizer saved to ./bert-local")
497
- print("✅ Model directory zipped to bert-local.zip")
498
-
499
- # === 15. Test function with confidence threshold ===
500
- def run_test_cases(model, tokenizer, test_sentences, label_to_id, id_to_label, confidence_threshold=0.5):
501
- model.eval()
502
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
503
- model.to(device)
504
-
505
- correct = 0
506
- total = len(test_sentences)
507
- results = []
508
-
509
- for text, expected_label in test_sentences:
510
- encoding = tokenizer(
511
- text,
512
- padding='max_length',
513
- truncation=True,
514
- max_length=128,
515
- return_tensors='pt'
516
- )
517
- input_ids = encoding['input_ids'].to(device)
518
- attention_mask = encoding['attention_mask'].to(device)
519
-
520
- with torch.no_grad():
521
- outputs = model(input_ids, attention_mask=attention_mask)
522
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
523
- max_prob, predicted_id = torch.max(probs, dim=1)
524
- predicted_label = id_to_label[predicted_id.item()]
525
- if max_prob.item() < confidence_threshold:
526
- predicted_label = "unknown"
527
-
528
- is_correct = (predicted_label == expected_label)
529
- if is_correct:
530
- correct += 1
531
- results.append({
532
- "sentence": text,
533
- "expected": expected_label,
534
- "predicted": predicted_label,
535
- "confidence": max_prob.item(),
536
- "correct": is_correct
537
- })
538
-
539
- accuracy = correct / total * 100
540
- print(f"\nTest Cases Accuracy: {accuracy:.2f}% ({correct}/{total} correct)")
541
-
542
- for r in results:
543
- status = "✓" if r["correct"] else "✗"
544
- print(f"{status} '{r['sentence']}'")
545
- print(f" Expected: {r['expected']}, Predicted: {r['predicted']}, Confidence: {r['confidence']:.3f}")
546
-
547
- assert accuracy >= 70, f"Test failed: Accuracy {accuracy:.2f}% < 70%"
548
- return results
549
-
550
- # === 16. Sample test sentences for testing ===
551
- test_sentences = [
552
- ("Where is the nearest airport to this location?", "airport"),
553
- ("Can I bring a laptop through airport security?", "airport"),
554
- ("How do I get to the closest airport terminal?", "airport"),
555
- ("Need help finding an accounting firm for tax planning.", "accounting firm"),
556
- ("Can an accounting firm help with financial audits?", "accounting firm"),
557
- ("Looking for an accounting firm to manage payroll.", "accounting firm"),
558
- ]
559
-
560
- print("\nRunning test cases...")
561
- test_results = run_test_cases(model, tokenizer, test_sentences, label_to_id, id_to_label)
562
- print("✅ Test cases completed.")
563
- ```
564
-
565
  ---
566
 
567
  ## Evaluation 📈
 
351
  ```
352
 
353
  ### Training Code
354
+ - 📍 Get training [Source Code](https://huggingface.co/boltuix/bert-local/blob/main/colab_training_code.ipynb) 🌟
355
+ - 📍 Dataset (comming soon..)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  ---
357
 
358
  ## Evaluation 📈