Mehrdad-S committed on
Commit
5ba152f
·
verified ·
1 Parent(s): d2975ab

Update evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +13 -26
evaluate.py CHANGED
@@ -1,32 +1,19 @@
1
- from transformers import AutoTokenizer, AutoModel
2
- import torch
3
  from datasets import load_dataset
4
- from sklearn.metrics.pairwise import cosine_similarity
5
- import numpy as np
6
 
7
  def evaluate_model(model_name):
8
  try:
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModel.from_pretrained(model_name)
11
- except:
12
- return None
13
-
14
- dataset = load_dataset("persiannlp/STS-pairs", split="test[:100]")
15
- embeddings1, embeddings2 = [], []
16
-
17
- for item in dataset:
18
- inputs1 = tokenizer(item["sentence1"], return_tensors="pt", truncation=True, padding=True)
19
- inputs2 = tokenizer(item["sentence2"], return_tensors="pt", truncation=True, padding=True)
20
 
21
- with torch.no_grad():
22
- embed1 = model(**inputs1).last_hidden_state[:, 0, :]
23
- embed2 = model(**inputs2).last_hidden_state[:, 0, :]
 
 
 
24
 
25
- embeddings1.append(embed1.squeeze().numpy())
26
- embeddings2.append(embed2.squeeze().numpy())
27
-
28
- sims = [cosine_similarity([e1], [e2])[0][0] for e1, e2 in zip(embeddings1, embeddings2)]
29
- labels = [item["similarity_score"] for item in dataset]
30
-
31
- corr = np.corrcoef(sims, labels)[0, 1]
32
- return float(corr)
 
 
 
1
  from datasets import load_dataset
2
+ from sentence_transformers import SentenceTransformer, util
 
3
 
4
  def evaluate_model(model_name):
5
  try:
6
+ model = SentenceTransformer(model_name)
7
+ dataset = load_dataset("arshiaafshani/persian-natural-fluently", split="train[:200]")
 
 
 
 
 
 
 
 
 
8
 
9
+ scores = []
10
+ for row in dataset:
11
+ emb1 = model.encode(row["instruction"], convert_to_tensor=True)
12
+ emb2 = model.encode(row["output"], convert_to_tensor=True)
13
+ sim_score = float(util.cos_sim(emb1, emb2)[0])
14
+ scores.append(sim_score)
15
 
16
+ return sum(scores) / len(scores)
17
+ except Exception as e:
18
+ print(f"Evaluation failed: {e}")
19
+ return None