wi-lab commited on
Commit
309f04a
·
verified ·
1 Parent(s): f9384cb

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +7 -3
utils.py CHANGED
@@ -207,6 +207,7 @@ def evaluate_model(model, test_loader, device):
207
 
208
  # Visualization
209
  import matplotlib.cm as cm
 
210
  def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
211
  """
212
  Plots the F1-score over epochs or number of training samples.
@@ -222,11 +223,14 @@ def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
222
  markers = ['o', 's', 'D', '^', 'v', 'P', '*', 'X', 'h'] # Different markers for curves
223
 
224
  for r in range(test_f1_scores.shape[0]):
225
- color = colors(r / (test_f1_scores.shape[0] - 1)) # Normalize color index
226
  marker = markers[r % len(markers)] # Cycle through markers
227
  if flag == 0:
228
- plt.plot(test_f1_scores[r], linewidth=2, marker=marker, markersize=5, markeredgewidth=1.5,
229
- markeredgecolor=color, color=color, label=f"{input_types[r]}")
 
 
 
230
  else:
231
  plt.plot(n_train, test_f1_scores[r], linewidth=2, marker=marker, markersize=6, markeredgewidth=1.5,
232
  markeredgecolor=color, markerfacecolor='none', color=color, label=f"{input_types[r]}")
 
207
 
208
  # Visualization
209
  import matplotlib.cm as cm
210
+
211
  def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
212
  """
213
  Plots the F1-score over epochs or number of training samples.
 
223
  markers = ['o', 's', 'D', '^', 'v', 'P', '*', 'X', 'h'] # Different markers for curves
224
 
225
  for r in range(test_f1_scores.shape[0]):
226
+ color = colors(0.5 if test_f1_scores.shape[0] == 1 else r / (test_f1_scores.shape[0] - 1)) # Normalize color index
227
  marker = markers[r % len(markers)] # Cycle through markers
228
  if flag == 0:
229
+ if test_f1_scores.shape[0] == 1:
230
+ plt.plot(test_f1_scores[r], linewidth=2, color=color, label=f"{input_types[r]}")
231
+ else:
232
+ plt.plot(test_f1_scores[r], linewidth=2, marker=marker, markersize=5, markeredgewidth=1.5,
233
+ markeredgecolor=color, color=color, label=f"{input_types[r]}")
234
  else:
235
  plt.plot(n_train, test_f1_scores[r], linewidth=2, marker=marker, markersize=6, markeredgewidth=1.5,
236
  markeredgecolor=color, markerfacecolor='none', color=color, label=f"{input_types[r]}")