import copy
import json
import os
import pathlib as pl
import time
import zipfile
from io import StringIO

import einops as eo
import matplotlib as mpl
import matplotlib.patches as patches
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import streamlit as st
import torch as t
import torchvision.transforms.functional as tvfunc
import yaml
from matplotlib import font_manager
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.patches import Rectangle
from PIL import Image
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch.utils.data import Dataset as torch_dset
from torch.utils.data.dataloader import DataLoader as dl
from tqdm.auto import tqdm

import analysis_funcs as anf
import classic_correction_algos as calgo
from loss_functions import corn_label_from_logits
from models import LitModel, EnsembleModel

TEMP_FOLDER = pl.Path("results") |
|
AVAILABLE_FONTS = [x.name for x in font_manager.fontManager.ttflist] |
|
PLOTS_FOLDER = pl.Path("plots") |
|
TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER / "temp_matplotlib_plot_stimulus.png" |
|
all_fonts = [x.name for x in font_manager.fontManager.ttflist] |
|
mpl.use("agg") |
|
|
|
DIST_MODELS_FOLDER = pl.Path("models") |
|
IMAGENET_MEAN = [0.485, 0.456, 0.406] |
|
IMAGENET_STD = [0.229, 0.224, 0.225] |
|
gradio_plots = pl.Path("plots") |
|
|
|
event_strs = [
    "EFIX",
    "EFIX R",
    "EFIX L",
    "SSACC",
    "ESACC",
    "SFIX",
    "MSG",
    "SBLINK",
    "EBLINK",
    "BUTTON",
    "INPUT",
    "END",
    "START",
    "DISPLAY ON",
]
names_dict = {
    "SSACC": {"Descr": "Start of Saccade", "Pattern": "SSACC <eye> <stime>"},
    "ESACC": {
        "Descr": "End of Saccade",
        "Pattern": "ESACC <eye> <stime> <etime> <dur> <sxp> <syp> <exp> <eyp> <ampl> <pv>",
    },
    "SFIX": {"Descr": "Start of Fixation", "Pattern": "SFIX <eye> <stime>"},
    "EFIX": {"Descr": "End of Fixation", "Pattern": "EFIX <eye> <stime> <etime> <dur> <axp> <ayp> <aps>"},
    "SBLINK": {"Descr": "Start of Blink", "Pattern": "SBLINK <eye> <stime>"},
    "EBLINK": {"Descr": "End of Blink", "Pattern": "EBLINK <eye> <stime> <etime> <dur>"},
    "DISPLAY ON": {"Descr": "Actual start of Trial", "Pattern": "DISPLAY ON"},
}
metadata_strs = ["DISPLAY COORDS", "GAZE_COORDS", "FRAMERATE"]

ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [ |
|
"warp", |
|
"regress", |
|
"compare", |
|
"attach", |
|
"segment", |
|
"split", |
|
"stretch", |
|
"chain", |
|
"slice", |
|
"cluster", |
|
"merge", |
|
"Wisdom_of_Crowds", |
|
"DIST", |
|
"DIST-Ensemble", |
|
"Wisdom_of_Crowds_with_DIST", |
|
"Wisdom_of_Crowds_with_DIST_Ensemble", |
|
] |
|
COLORS = px.colors.qualitative.Alphabet |
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that handles numpy arrays, pathlib Paths and Streamlit UploadedFiles.

    From https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
    Usage: json.dump(obj, f, cls=NumpyEncoder)
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (pl.Path, UploadedFile)):
            return str(obj)
        return json.JSONEncoder.default(self, obj)

class DSet(torch_dset):
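    """Dataset of fixation sequences plus optional character coordinates and stimulus images.

    Note: the tuple returned by __getitem__ varies in length depending on which of
    chars_center_coords_padded, padding_list and return_images_for_conv are set; the
    collate/model code downstream is expected to unpack accordingly.
    """
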
    def __init__(
        self,
        in_sequence: t.Tensor,
        chars_center_coords_padded: t.Tensor,
        out_categories: t.Tensor,
        trialslist: list,
        padding_list: list = None,
        padding_at_end: bool = False,
        return_images_for_conv: bool = False,
        im_partial_string: str = "fixations_chars_channel_sep",
        input_im_shape=[224, 224],
    ) -> None:
        super().__init__()

        self.in_sequence = in_sequence
        self.chars_center_coords_padded = chars_center_coords_padded
        self.out_categories = out_categories
        self.padding_list = padding_list
        self.padding_at_end = padding_at_end
        self.trialslist = trialslist
        self.return_images_for_conv = return_images_for_conv
        self.input_im_shape = input_im_shape
        if return_images_for_conv:
            self.im_partial_string = im_partial_string
            self.plot_files = [
                str(x["plot_file"]).replace("fixations_words", im_partial_string) for x in self.trialslist
            ]

    def __getitem__(self, index):
        if self.return_images_for_conv:
            im = Image.open(self.plot_files[index])
            if [im.size[1], im.size[0]] != self.input_im_shape:
                im = tvfunc.resize(im, self.input_im_shape)
            im = tvfunc.normalize(tvfunc.to_tensor(im), IMAGENET_MEAN, IMAGENET_STD)
        if self.chars_center_coords_padded is not None:
            if self.padding_list is not None:
                attention_mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
                if self.padding_at_end:
                    if self.padding_list[index] > 0:
                        attention_mask[-self.padding_list[index] :] = 0
                else:
                    attention_mask[: self.padding_list[index]] = 0
                if self.return_images_for_conv:
                    return (
                        self.in_sequence[index],
                        self.chars_center_coords_padded[index],
                        im,
                        attention_mask,
                        self.out_categories[index],
                    )
                return (
                    self.in_sequence[index],
                    self.chars_center_coords_padded[index],
                    attention_mask,
                    self.out_categories[index],
                )
            else:
                if self.return_images_for_conv:
                    return (
                        self.in_sequence[index],
                        self.chars_center_coords_padded[index],
                        im,
                        self.out_categories[index],
                    )
                else:
                    return (self.in_sequence[index], self.chars_center_coords_padded[index], self.out_categories[index])

        if self.padding_list is not None:
            attention_mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
            if self.padding_at_end:
                if self.padding_list[index] > 0:
                    attention_mask[-self.padding_list[index] :] = 0
            else:
                attention_mask[: self.padding_list[index]] = 0
            if self.return_images_for_conv:
                return (self.in_sequence[index], im, attention_mask, self.out_categories[index])
            else:
                return (self.in_sequence[index], attention_mask, self.out_categories[index])
        if self.return_images_for_conv:
            return (self.in_sequence[index], im, self.out_categories[index])
        else:
            return (self.in_sequence[index], self.out_categories[index])

    def __len__(self):
        if isinstance(self.in_sequence, t.Tensor):
            return self.in_sequence.shape[0]
        else:
            return len(self.in_sequence)


def download_url(url, target_filename):
    r = requests.get(url)
    r.raise_for_status()  # fail loudly instead of silently writing an error page to disk
    with open(target_filename, "wb") as f:
        f.write(r.content)
    return 0


def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    if "logger" in st.session_state:
        st.session_state["logger"].debug("asc_to_trial_ids entered")
    asc_encoding = "ISO-8859-15"  # default EyeLink export encoding; switch to "UTF-8" if needed
    trials_dict, lines = file_to_trials_and_lines(
        asc_file, asc_encoding, close_gap_between_words=close_gap_between_words
    )

    trials_by_ids = {trials_dict[idx]["trial_id"]: trials_dict[idx] for idx in trials_dict["paragraph_trials"]}
    if hasattr(asc_file, "name"):
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"Found {len(trials_by_ids)} trials in {asc_file.name}.")
    return trials_by_ids, lines


def get_trials_list(asc_file=None, close_gap_between_words=True):
    if "logger" in st.session_state:
        st.session_state["logger"].debug("get_trials_list entered")

    if asc_file is None:
        if "single_asc_file" in st.session_state and st.session_state["single_asc_file"] is not None:
            asc_file = st.session_state["single_asc_file"]
        else:
            if "logger" in st.session_state:
                st.session_state["logger"].warning("Asc file is None")
            return None

    if hasattr(asc_file, "name"):
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"get_trials_list entered with asc_file {asc_file.name}")

    trials_by_ids, lines = asc_to_trial_ids(asc_file, close_gap_between_words=close_gap_between_words)
    trial_keys = list(trials_by_ids.keys())

    return trial_keys, trials_by_ids, lines, asc_file


def save_trial_to_json(trial, savename):
    if "dffix" in trial:
        trial.pop("dffix")
    with open(savename, "w", encoding="utf-8") as f:
        json.dump(trial, f, ensure_ascii=False, indent=4, cls=NumpyEncoder)


def export_csv(dffix, trial):
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    trial_id = trial["trial_id"]
    savename = TEMP_FOLDER.joinpath(pl.Path(trial["fname"]).stem)
    trial_name = f"{savename}_{trial_id}_trial_info.json"
    csv_name = f"{savename}_{trial_id}.csv"
    dffix.to_csv(csv_name)
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Saved processed data as {csv_name}")
    save_trial_to_json(trial, trial_name)
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Saved processed trial data as {trial_name}")

    return csv_name, trial_name


def get_all_classic_preds(dffix, trial, classic_algos_cfg):
    corrections = []
    for algo, classic_params in copy.deepcopy(classic_algos_cfg).items():
        dffix = calgo.apply_classic_algo(dffix, trial, algo, classic_params)
        corrections.append(np.asarray(dffix.loc[:, f"y_{algo}"]))
    return dffix, corrections


def apply_woc(dffix, trial, corrections, algo_choice):
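    """Combine the per-algorithm corrected y values via a wisdom-of-the-crowd vote
    and map each combined y back to its line index in trial["y_char_unique"]."""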
    corrected_Y = calgo.wisdom_of_the_crowd(corrections)
    dffix.loc[:, f"y_{algo_choice}"] = corrected_Y
    dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    corrected_line_nums = [trial["y_char_unique"].index(y) for y in corrected_Y]
    dffix.loc[:, f"line_num_y_{algo_choice}"] = corrected_line_nums
    return dffix


def calc_xdiff_ydiff(line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False):
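    """Estimate the horizontal character pitch and vertical line spacing from the unique
    character coordinates; falls back to the first line height for single-line stimuli."""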
    x_diffs = np.unique(np.diff(line_xcoords_no_pad))
    if len(x_diffs) == 1:
        x_diff = x_diffs[0]
    elif not allow_multiple_values:
        x_diff = np.min(x_diffs)
    else:
        x_diff = x_diffs

    if np.unique(line_ycoords_no_pad).shape[0] == 1:
        return x_diff, line_heights[0]
    y_diffs = np.unique(np.diff(line_ycoords_no_pad))
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    elif len(y_diffs) == 0:
        y_diff = 0
    elif not allow_multiple_values:
        y_diff = np.min(y_diffs)
    else:
        y_diff = y_diffs
    return x_diff, y_diff


def add_words(trial, close_gap_between_words=True):
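    """Reconstruct word bounding boxes from trial["chars_list"].

    A word is closed when a separator character, a line wrap (x resets leftwards),
    or the end of the character list is reached. With close_gap_between_words=True,
    neighbouring boxes on the same line are widened to meet halfway across the gap.
    """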
    chars_list_reconstructed = []
    words_list = []
    word_start_idx = 0
    chars_df = pd.DataFrame(trial["chars_list"])
    chars_df["char_width"] = chars_df.char_xmax - chars_df.char_xmin
    space_width = chars_df.loc[chars_df["char"] == " ", "char_width"].mean()  # currently unused

    for idx, char_dict in enumerate(trial["chars_list"]):
        chars_list_reconstructed.append(char_dict)
        if (
            char_dict["char"] in [" ", ",", ";", ".", ":"]
            or (
                len(chars_list_reconstructed) > 2
                and (chars_list_reconstructed[-1]["char_xmin"] < chars_list_reconstructed[-2]["char_xmin"])
            )
            or len(chars_list_reconstructed) == len(trial["chars_list"])
        ):
            word_xmin = chars_list_reconstructed[word_start_idx]["char_xmin"]
            word_xmax = chars_list_reconstructed[-2]["char_xmax"]
            word_ymin = chars_list_reconstructed[word_start_idx]["char_ymin"]
            word_ymax = chars_list_reconstructed[word_start_idx]["char_ymax"]
            word_x_center = (word_xmax - word_xmin) / 2 + word_xmin
            word_y_center = (word_ymax - word_ymin) / 2 + word_ymin
            word = "".join(
                [
                    chars_list_reconstructed[idx]["char"]
                    for idx in range(word_start_idx, len(chars_list_reconstructed) - 1)
                ]
            )
            assigned_line = chars_list_reconstructed[word_start_idx]["assigned_line"]

            word_dict = dict(
                word=word,
                word_xmin=word_xmin,
                word_xmax=word_xmax,
                word_ymin=word_ymin,
                word_ymax=word_ymax,
                word_x_center=word_x_center,
                word_y_center=word_y_center,
                assigned_line=assigned_line,
            )
            if char_dict["char"] != " ":
                word_start_idx = idx
            else:
                word_start_idx = idx + 1
            words_list.append(word_dict)

    # After the loop, append the final character as its own word if it was not
    # captured by the last completed word (the join above excludes the most
    # recent character).
    last_letter_in_word = word_dict["word"][-1]
    last_letter_in_chars_list_reconstructed = char_dict["char"]
    if last_letter_in_word != last_letter_in_chars_list_reconstructed:
        word_dict = dict(
            word=char_dict["char"],
            word_xmin=char_dict["char_xmin"],
            word_xmax=char_dict["char_xmax"],
            word_ymin=char_dict["char_ymin"],
            word_ymax=char_dict["char_ymax"],
            word_x_center=char_dict["char_x_center"],
            word_y_center=char_dict["char_y_center"],
            assigned_line=assigned_line,
        )
        words_list.append(word_dict)

    if close_gap_between_words:
        for widx in range(1, len(words_list)):
            if words_list[widx]["assigned_line"] == words_list[widx - 1]["assigned_line"]:
                word_sep_half_width = (words_list[widx]["word_xmin"] - words_list[widx - 1]["word_xmax"]) / 2
                words_list[widx - 1]["word_xmax"] = words_list[widx - 1]["word_xmax"] + word_sep_half_width
                words_list[widx]["word_xmin"] = words_list[widx]["word_xmin"] - word_sep_half_width

    return words_list


def asc_lines_to_trials_by_trail_id(
    lines: list, paragraph_trials_only=False, fname: str = "", close_gap_between_words=True
) -> dict:
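    """Parse raw .asc lines into a dict of trials keyed by trial index.

    Collects trial metadata (display coords, framerate, start/end times) and, for
    trials with REGION CHAR lines, per-character bounding boxes, line assignments
    and the derived word list.
    """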
    if hasattr(fname, "name"):
        fname = fname.name
    fps = -999
    display_coords = -999  # sentinel: stays an int until the real tuple is parsed
    trials_dict = dict(paragraph_trials=[], paragraph_trial_IDs=[])
    trial_idx = -1
    removed_trial_ids = []
    last_trial_skipped = False  # guards against ABORTED/REPEATED lines before the first TRIALID
    for idx, l in enumerate(lines):
        parts = l.strip().split(" ")
        if "TRIALID" in l:
            trial_id = parts[-1]
            trial_idx += 1
            if trial_id[0] == "F":
                trial_is = "question"
            elif trial_id[0] == "P":
                trial_is = "practice"
            else:
                trial_is = "paragraph"
                trials_dict["paragraph_trials"].append(trial_idx)
                trials_dict["paragraph_trial_IDs"].append(trial_id)
            trials_dict[trial_idx] = dict(trial_id=trial_id, trial_id_idx=idx, trial_is=trial_is, filename=fname)
            last_trial_skipped = False

        elif "TRIAL_RESULT" in l or "stop_trial" in l:
            trials_dict[trial_idx]["trial_result_idx"] = idx
            trials_dict[trial_idx]["trial_result_timestamp"] = int(parts[0].split("\t")[1])
            if len(parts) > 2:
                trials_dict[trial_idx]["trial_result_number"] = int(parts[2])
        elif "DISPLAY COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "GAZE_COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "FRAMERATE" in l:
            l_idx = parts.index(metadata_strs[2])
            fps = float(parts[l_idx + 1])
        elif "TRIAL ABORTED" in l or "TRIAL REPEATED" in l:
            if not last_trial_skipped:
                if trial_is == "paragraph":
                    trials_dict["paragraph_trials"].remove(trial_idx)
                trial_idx -= 1
                removed_trial_ids.append(trial_id)
                last_trial_skipped = True

    if paragraph_trials_only:
        trials_dict_temp = trials_dict.copy()
        for k in trials_dict_temp.keys():
            if k not in ["paragraph_trials"] + trials_dict_temp["paragraph_trials"]:
                trials_dict.pop(k)
        if len(trials_dict_temp["paragraph_trials"]):
            trial_idx = trials_dict_temp["paragraph_trials"][-1]
        else:
            return trials_dict
    trials_dict["display_coords"] = display_coords
    trials_dict["fps"] = fps
    trials_dict["max_trial_idx"] = trial_idx
    enum = trials_dict["paragraph_trials"] if "paragraph_trials" in trials_dict.keys() else range(len(trials_dict))
    for trial_idx in enum:
        if trial_idx not in trials_dict.keys():
            continue
        chars_list = []
        if "display_coords" not in trials_dict[trial_idx].keys():
            trials_dict[trial_idx]["display_coords"] = trials_dict["display_coords"]
        trial_start_idx = trials_dict[trial_idx]["trial_id_idx"]
        trial_end_idx = trials_dict[trial_idx]["trial_result_idx"]
        trial_lines = lines[trial_start_idx:trial_end_idx]
        for idx, l in enumerate(trial_lines):
            parts = l.strip().split(" ")
            if "START" in l and " MSG" not in l:
                trials_dict[trial_idx]["start_idx"] = trial_start_idx + idx + 7  # data begins after the header block
                trials_dict[trial_idx]["start_time"] = int(parts[0].split("\t")[1])
            elif "END" in l and "ENDBUTTON" not in l and " MSG" not in l:
                trials_dict[trial_idx]["end_idx"] = trial_start_idx + idx - 2
                trials_dict[trial_idx]["end_time"] = int(parts[0].split("\t")[1])
            elif "SYNCTIME" in l:
                trials_dict[trial_idx]["synctime"] = trial_start_idx + idx
                trials_dict[trial_idx]["synctime_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET OFF" in l:
                trials_dict[trial_idx]["gaze_targ_off_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET ON" in l:
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "DISPLAY_SENTENCE" in l:
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "REGION CHAR" in l:
                rg_idx = parts.index("CHAR")
                if len(parts[rg_idx:]) > 8:
                    # extra split field: the character itself is a space
                    char = " "
                    idx_correction = 1
                elif len(parts[rg_idx:]) == 3:
                    # coordinates wrapped onto the following line
                    char = " "
                    if "REGION CHAR" not in trial_lines[idx + 1]:
                        parts = trial_lines[idx + 1].strip().split(" ")
                        idx_correction = -rg_idx - 4
                else:
                    char = parts[rg_idx + 3]
                    idx_correction = 0
                try:
                    char_dict = {
                        "char": char,
                        "char_xmin": float(parts[rg_idx + 4 + idx_correction]),
                        "char_ymin": float(parts[rg_idx + 5 + idx_correction]),
                        "char_xmax": float(parts[rg_idx + 6 + idx_correction]),
                        "char_ymax": float(parts[rg_idx + 7 + idx_correction]),
                    }
                    char_dict["char_y_center"] = (char_dict["char_ymax"] - char_dict["char_ymin"]) / 2 + char_dict[
                        "char_ymin"
                    ]
                    char_dict["char_x_center"] = (char_dict["char_xmax"] - char_dict["char_xmin"]) / 2 + char_dict[
                        "char_xmin"
                    ]
                    chars_list.append(char_dict)
                except Exception as e:
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(f"char_dict creation failed for parts {parts}")
                        st.session_state["logger"].warning(e)

        if "gaze_targ_on_time" in trials_dict[trial_idx]:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["gaze_targ_on_time"]
        else:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["start_time"]

        if len(chars_list) > 0:
            line_ycoords = []
            for idx in range(len(chars_list)):
                chars_list[idx]["char_line_y"] = (
                    chars_list[idx]["char_ymax"] - chars_list[idx]["char_ymin"]
                ) / 2 + chars_list[idx]["char_ymin"]
                if chars_list[idx]["char_line_y"] not in line_ycoords:
                    line_ycoords.append(chars_list[idx]["char_line_y"])
            for idx in range(len(chars_list)):
                chars_list[idx]["assigned_line"] = line_ycoords.index(chars_list[idx]["char_line_y"])

            line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
            line_xcoords_all = [x["char_x_center"] for x in chars_list]
            line_xcoords_no_pad = np.unique(line_xcoords_all)

            line_ycoords_all = [x["char_y_center"] for x in chars_list]
            line_ycoords_no_pad = np.unique(line_ycoords_all)

            trials_dict[trial_idx]["x_char_unique"] = list(line_xcoords_no_pad)
            trials_dict[trial_idx]["y_char_unique"] = list(line_ycoords_no_pad)
            x_diff, y_diff = calc_xdiff_ydiff(
                line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
            )
            trials_dict[trial_idx]["x_diff"] = float(x_diff)
            trials_dict[trial_idx]["y_diff"] = float(y_diff)
            trials_dict[trial_idx]["num_char_lines"] = len(line_ycoords_no_pad)
            trials_dict[trial_idx]["line_heights"] = line_heights
            trials_dict[trial_idx]["chars_list"] = chars_list

            words_list = add_words(trials_dict[trial_idx], close_gap_between_words=close_gap_between_words)
            trials_dict[trial_idx]["words_list"] = words_list

    return trials_dict


def file_to_trials_and_lines(uploaded_file, asc_encoding: str = "ISO-8859-15", close_gap_between_words=True):
    if isinstance(uploaded_file, (str, pl.Path)):
        with open(uploaded_file, "r", encoding=asc_encoding) as f:
            lines = f.readlines()
    else:
        stringio = StringIO(uploaded_file.getvalue().decode(asc_encoding))
        loaded_str = stringio.read()
        lines = loaded_str.split("\n")
    trials_dict = asc_lines_to_trials_by_trail_id(
        lines, True, uploaded_file, close_gap_between_words=close_gap_between_words
    )

    if "paragraph_trials" not in trials_dict.keys() and "trial_is" in trials_dict[0].keys():
        paragraph_trials = []
        for k in range(trials_dict["max_trial_idx"]):
            if trials_dict[k]["trial_is"] == "paragraph":
                paragraph_trials.append(k)
        trials_dict["paragraph_trials"] = paragraph_trials

    enum = (
        trials_dict["paragraph_trials"]
        if "paragraph_trials" in trials_dict.keys()
        else range(trials_dict["max_trial_idx"])
    )
    for k in enum:
        if "chars_list" in trials_dict[k].keys():
            max_line = trials_dict[k]["chars_list"][-1]["assigned_line"]
            words_on_lines = {x: [] for x in range(max_line + 1)}
            for x in trials_dict[k]["chars_list"]:
                words_on_lines[x["assigned_line"]].append(x["char"])
            sentence_list = ["".join(v) for idx, v in words_on_lines.items()]
            text = "\n".join(sentence_list)
            trials_dict[k]["sentence_list"] = sentence_list
            trials_dict[k]["text"] = text
            trials_dict[k]["max_line"] = max_line

    return trials_dict, lines


def get_plot_props(trial, available_fonts):
    if "font" in trial.keys():
        font = trial["font"]
        font_size = trial["font_size"]
        if font not in available_fonts:
            font = "DejaVu Sans Mono"
    else:
        font = "DejaVu Sans Mono"
        font_size = 21
    dpi = 100
    if "display_coords" in trial.keys():
        screen_res = (trial["display_coords"][2], trial["display_coords"][3])
    else:
        screen_res = (1920, 1080)
    return font, font_size, dpi, screen_res


def trial_to_dfs(
    trial: dict, lines: list, use_synctime: bool = False, save_lines_to_txt=False, cut_out_outer_fixations=False
):
    """trial should be a dict describing the line numbers of a trial;
    lines should be the list of lines from the .asc file."""

    if use_synctime and "synctime" in trial:
        idx0, idxend = trial["synctime"] + 1, trial["trial_result_idx"]
    else:
        idx0, idxend = trial["start_idx"], trial["end_idx"]

    line_dicts = []
    fixations_dicts = []
    blink_started = False

    fixation_started = False
    efix_count = 0
    sfix_count = 0
    sblink_count = 0

    if save_lines_to_txt:
        with open("Lines_plus500.txt", "w") as f:
            f.writelines(lines[idx0 - 500 : idxend + 500])
    eye_to_use = "R"
    for l in lines[idx0 : idxend + 1]:
        if "EFIX R" in l:
            eye_to_use = "R"
            break
        elif "EFIX L" in l:
            eye_to_use = "L"
            break
    for l in lines[idx0 : idxend + 1]:
        parts = [x.strip() for x in l.split("\t")]
        if f"EFIX {eye_to_use}" in l:
            efix_count += 1
            if fixation_started:
                if parts[1] == "." and parts[2] == ".":
                    continue
                fixations_dicts.append(
                    {
                        "start_time": float(parts[0].split()[-1].strip()),
                        "end_time": float(parts[1].strip()),
                        "duration": float(parts[2].strip()),
                        "x": float(parts[3].strip()),
                        "y": float(parts[4].strip()),
                        "pupil_size": float(parts[5].strip()),
                    }
                )
                if len(fixations_dicts) >= 2:
                    assert (
                        fixations_dicts[-1]["start_time"] > fixations_dicts[-2]["start_time"]
                    ), "start times not in order"
                fixation_started = False

        elif f"SFIX {eye_to_use}" in l:
            sfix_count += 1
            fixation_started = True
        elif f"SBLINK {eye_to_use}" in l:
            sblink_count += 1
            blink_started = True
        elif f"EBLINK {eye_to_use}" in l:
            blink_started = False
        elif not blink_started and not any(x in l for x in event_strs):
            # raw gaze sample line
            if len(parts) < 3 or (parts[1] == "." and parts[2] == "."):
                continue
            line_dicts.append(
                {
                    "idx": float(parts[0].strip()),
                    "x": float(parts[1].strip()),
                    "y": float(parts[2].strip()),
                    "p": float(parts[3].strip()),
                }
            )

    df = pd.DataFrame(line_dicts)
    dffix = pd.DataFrame(fixations_dicts)
    if len(fixations_dicts) > 0:
        dffix["corrected_start_time"] = dffix.start_time - trial["trial_start_time"]
        dffix["corrected_end_time"] = dffix.end_time - trial["trial_start_time"]
        dffix["fix_duration"] = dffix.corrected_end_time.values - dffix.corrected_start_time.values
        assert all(np.diff(dffix["corrected_start_time"]) > 0), "start times not in order"
    else:
        return df, pd.DataFrame(), trial

    if cut_out_outer_fixations:
        dffix = dffix[(dffix.x > -10) & (dffix.y > -10) & (dffix.x < 1050) & (dffix.y < 800)]
    trial["efix_count"] = efix_count
    trial["eye_to_use"] = eye_to_use
    trial["sfix_count"] = sfix_count
    trial["sblink_count"] = sblink_count
    return df, dffix, trial


def get_save_path(fpath, fname_ending):
    save_path = gradio_plots.joinpath(f"{fpath.stem}_{fname_ending}.png")
    return save_path


def save_im_load_convert(fpath, fig, fname_ending, mode):
    save_path = get_save_path(fpath, fname_ending)
    fig.savefig(save_path)
    im = Image.open(save_path).convert(mode)
    im.save(save_path)
    return im


def get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, dffix=None, prefix="word"):
    fig = plt.figure(figsize=(screen_res[0] / dpi, screen_res[1] / dpi), dpi=dpi)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    if dffix is not None:
        ax.set_ylim((dffix.y.min(), dffix.y.max()))
        ax.set_xlim((dffix.x.min(), dffix.x.max()))
    else:
        ax.set_ylim((words_df[f"{prefix}_y_center"].min() - y_margin, words_df[f"{prefix}_y_center"].max() + y_margin))
        ax.set_xlim((words_df[f"{prefix}_x_center"].min() - x_margin, words_df[f"{prefix}_x_center"].max() + x_margin))
    ax.invert_yaxis()
    fig.add_axes(ax)
    return fig, ax


def plot_text_boxes_fixations(
    fpath,
    dpi,
    screen_res,
    data_dir_sub,
    set_font_size: bool,
    font_size: int,
    use_words: bool,
    save_channel_repeats: bool,
    save_combo_grey_and_rgb: bool,
    dffix=None,
    trial=None,
):
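    """Render the stimulus text, word/char boxes and fixation scatter as greyscale
    images and stack them into one 3-channel image for the DIST conv branch."""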
    if isinstance(fpath, str):
        fpath = pl.Path(fpath)
    if use_words:
        prefix = "word"
    else:
        prefix = "char"
    if dffix is None:
        dffix = pd.read_csv(fpath)
    if trial is None:
        json_fpath = str(fpath).replace("_fixations.csv", "_trial.json")
        with open(json_fpath, "r") as f:
            trial = json.load(f)
    words_df = pd.DataFrame(trial[f"{prefix}s_list"])
    x_left = words_df[f"{prefix}_xmin"]
    x_right = words_df[f"{prefix}_xmax"]
    y_top = words_df[f"{prefix}_ymax"]
    y_bottom = words_df[f"{prefix}_ymin"]

    if f"{prefix}_x_center" not in words_df.columns:
        words_df[f"{prefix}_x_center"] = (words_df[f"{prefix}_xmax"] - words_df[f"{prefix}_xmin"]) / 2 + words_df[
            f"{prefix}_xmin"
        ]
        words_df[f"{prefix}_y_center"] = (words_df[f"{prefix}_ymax"] - words_df[f"{prefix}_ymin"]) / 2 + words_df[
            f"{prefix}_ymin"
        ]

    x_margin = words_df[f"{prefix}_x_center"].mean() / 8
    y_margin = words_df[f"{prefix}_y_center"].mean() / 4
    times = dffix.corrected_start_time - dffix.corrected_start_time.min()
    times = times / times.max()
    times = np.linspace(0.25, 1, len(times))  # note: overwrites the normalised values above with a linear ramp

    font = "monospace"
    if not set_font_size:
        font_size = trial["font_size"] * 27 // dpi

    font_props = FontProperties(family=font, style="normal", size=font_size)
    if save_combo_grey_and_rgb:
        fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
        ax.scatter(dffix.x, dffix.y, alpha=times, facecolor="b")
        for idx in range(len(x_left)):
            xdiff = x_right[idx] - x_left[idx]
            ydiff = y_top[idx] - y_bottom[idx]
            rect = patches.Rectangle(
                (x_left[idx] - 1, y_bottom[idx] - 1),
                xdiff,
                ydiff,
                alpha=0.9,
                linewidth=0.8,
                edgecolor="r",
                facecolor="none",
            )
            ax.text(
                words_df[f"{prefix}_x_center"][idx],
                words_df[f"{prefix}_y_center"][idx],
                words_df[prefix][idx],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
                color="g",
            )
            ax.add_patch(rect)
        fname_ending = f"{prefix}s_combo_rgb"
        words_combo_rgb_im = save_im_load_convert(fpath, fig, fname_ending, "RGB")
        plt.close("all")

        fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)

        ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
        for idx in range(len(x_left)):
            xdiff = x_right[idx] - x_left[idx]
            ydiff = y_top[idx] - y_bottom[idx]
            rect = patches.Rectangle(
                (x_left[idx] - 1, y_bottom[idx] - 1),
                xdiff,
                ydiff,
                alpha=0.9,
                linewidth=0.8,
                edgecolor="k",
                facecolor="none",
            )
            ax.text(
                words_df[f"{prefix}_x_center"][idx],
                words_df[f"{prefix}_y_center"][idx],
                words_df[prefix][idx],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
            )
            ax.add_patch(rect)
        fname_ending = f"{prefix}s_combo_grey"
        words_combo_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
        plt.close("all")

    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)

    ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.01)
    for idx in range(len(x_left)):
        ax.text(
            words_df[f"{prefix}_x_center"][idx],
            words_df[f"{prefix}_y_center"][idx],
            words_df[prefix][idx],
            horizontalalignment="center",
            verticalalignment="center",
            fontproperties=font_props,
        )
    fname_ending = f"{prefix}s_grey"
    words_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")

    plt.close("all")
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)

    ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.1)
    for idx in range(len(x_left)):
        xdiff = x_right[idx] - x_left[idx]
        ydiff = y_top[idx] - y_bottom[idx]
        rect = patches.Rectangle(
            (x_left[idx] - 1, y_bottom[idx] - 1), xdiff, ydiff, alpha=0.9, linewidth=1, edgecolor="k", facecolor="grey"
        )
        ax.add_patch(rect)
    fname_ending = f"{prefix}_boxes_grey"
    word_boxes_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")

    plt.close("all")

    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)

    ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
    fname_ending = "fix_scatter_grey"
    fix_scatter_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")

    plt.close("all")

    arr_combo = np.stack(
        [
            np.asarray(words_grey_im),
            np.asarray(word_boxes_grey_im),
            np.asarray(fix_scatter_grey_im),
        ],
        axis=2,
    )

    im_combo = Image.fromarray(arr_combo)
    fname_ending = f"{prefix}s_channel_sep"

    save_path = get_save_path(fpath, fname_ending)
    print(f"save_path for im combo is {save_path}")
    # the combined image is written to fpath itself; save_path is only logged above
    im_combo.save(fpath)

    if save_channel_repeats:
        arr_combo = np.stack([np.asarray(words_grey_im)] * 3, axis=2)
        im_combo = Image.fromarray(arr_combo)
        fname_ending = f"{prefix}s_channel_repeat"

        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)

        arr_combo = np.stack([np.asarray(word_boxes_grey_im)] * 3, axis=2)

        im_combo = Image.fromarray(arr_combo)
        fname_ending = f"{prefix}boxes_channel_repeat"

        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)

        arr_combo = np.stack([np.asarray(fix_scatter_grey_im)] * 3, axis=2)

        im_combo = Image.fromarray(arr_combo)
        fname_ending = "fix_channel_repeat"

        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)


def add_line_overlaps_to_sample(trial, sample):
    # Append, as an extra feature column, the index of the text line whose vertical
    # extent contains each fixation's y value (-1 if no line overlaps).
    char_df = pd.DataFrame(trial["chars_list"])
    line_overlaps = []
    for arr in sample:
        y_val = arr[1]
        line_overlap = t.tensor(-1, dtype=t.float32)
        for idx, (y_min, y_max) in enumerate(zip(char_df.char_ymin.unique(), char_df.char_ymax.unique())):
            if y_min <= y_val <= y_max:
                line_overlap = t.tensor(idx, dtype=t.float32)
                break
        line_overlaps.append(line_overlap)
    line_olaps_tensor = t.stack(line_overlaps, dim=0)
    sample = t.cat([sample, line_olaps_tensor.unsqueeze(1)], dim=1)
    return sample


def norm_coords_by_letter_min_x_y(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    chars_center_coords_list: list = None,
):
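    """Shift fixation (and optional character-center) coordinates so the text's
    top-left character corner sits at the origin."""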
    chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
    trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()

    min_x_chars = chars_df.char_xmin.min()
    min_y_chars = chars_df.char_ymin.min()

    norm_vector_subtract = t.zeros(
        (1, samplelist[sample_idx].shape[1]), dtype=samplelist[sample_idx].dtype, device=samplelist[sample_idx].device
    )
    norm_vector_subtract[0, 0] = norm_vector_subtract[0, 0] + 1 * min_x_chars
    norm_vector_subtract[0, 1] = norm_vector_subtract[0, 1] + 1 * min_y_chars

    samplelist[sample_idx] = samplelist[sample_idx] - norm_vector_subtract

    if chars_center_coords_list is not None:
        norm_vector_subtract = norm_vector_subtract.squeeze(0)[:2]
        if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_subtract.shape[-1] * 2:
            chars_center_coords_list[sample_idx][:, :2] -= norm_vector_subtract
            chars_center_coords_list[sample_idx][:, 2:] -= norm_vector_subtract
        else:
            chars_center_coords_list[sample_idx] -= norm_vector_subtract
    return trialslist, samplelist, chars_center_coords_list


def norm_coords_by_letter_positions(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    meanlist: list = None,
    stdlist: list = None,
    return_mean_std_lists=False,
    norm_by_char_averages=False,
    chars_center_coords_list: list = None,
    add_normalised_values_as_features=False,
):
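    """Scale coordinates by either the average character width/height or the line
    width/height, optionally appending the scaled values as extra features."""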
    chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
    trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()

    min_x_chars = chars_df.char_xmin.min()
    max_x_chars = chars_df.char_xmax.max()

    norm_vector_multi = t.ones(
        (1, samplelist[sample_idx].shape[1]), dtype=samplelist[sample_idx].dtype, device=samplelist[sample_idx].device
    )
    if norm_by_char_averages:
        chars_list = trialslist[sample_idx]["chars_list"]
        char_widths = np.asarray([x["char_xmax"] - x["char_xmin"] for x in chars_list])
        char_heights = np.asarray([x["char_ymax"] - x["char_ymin"] for x in chars_list])
        char_widths_average = np.mean(char_widths[char_widths > 0])
        char_heights_average = np.mean(char_heights[char_heights > 0])

        norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * char_widths_average
        norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * char_heights_average

    else:
        line_height = min(np.unique(trialslist[sample_idx]["line_heights"]))
        line_width = max_x_chars - min_x_chars
        norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * line_width
        norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * line_height
    assert ~t.any(t.isnan(norm_vector_multi)), "Nan found in char norming vector"

    norm_vector_multi = norm_vector_multi.squeeze(0)
    if add_normalised_values_as_features:
        norm_vector_multi = norm_vector_multi[norm_vector_multi != 1]
        normed_features = samplelist[sample_idx][:, : norm_vector_multi.shape[0]] / norm_vector_multi
        samplelist[sample_idx] = t.cat([samplelist[sample_idx], normed_features], dim=1)
    else:
        samplelist[sample_idx] = samplelist[sample_idx] / norm_vector_multi
    if chars_center_coords_list is not None:
        norm_vector_multi = norm_vector_multi[:2]
        if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_multi.shape[-1] * 2:
            chars_center_coords_list[sample_idx][:, :2] /= norm_vector_multi
            chars_center_coords_list[sample_idx][:, 2:] /= norm_vector_multi
        else:
            chars_center_coords_list[sample_idx] /= norm_vector_multi
    if return_mean_std_lists:
        mean_val = samplelist[sample_idx].mean(axis=0).cpu().numpy()
        meanlist.append(mean_val)
        std_val = samplelist[sample_idx].std(axis=0).cpu().numpy()
        stdlist.append(std_val)
        assert not any(np.isnan(mean_val)), "Nan found in mean_val"
        assert not any(np.isnan(std_val)), "Nan found in std_val"

        return trialslist, samplelist, meanlist, stdlist, chars_center_coords_list
    return trialslist, samplelist, chars_center_coords_list


def remove_compile_from_model(model):
    # torch.compile wraps submodules in _orig_mod; unwrap them so state dicts
    # match the uncompiled model
    if hasattr(model.project, "_orig_mod"):
        model.project = model.project._orig_mod
        model.chars_conv = model.chars_conv._orig_mod
        model.chars_classifier = model.chars_classifier._orig_mod
        model.layer_norm_in = model.layer_norm_in._orig_mod
        model.bert_model = model.bert_model._orig_mod
        model.linear = model.linear._orig_mod
    else:
        print(f"remove_compile_from_model not done since model.project {model.project} has no orig_mod")
    return model


def remove_compile_from_dict(state_dict):
    for key in list(state_dict.keys()):
        newkey = key.replace("._orig_mod.", ".")
        state_dict[newkey] = state_dict.pop(key)
    return state_dict


def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
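    """Draw characters/words and (optionally) their bounding boxes on a matplotlib
    axis; box_annotations, if given, is drawn above each box."""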
    font_props = FontProperties(family=font_to_use, style="normal", size=fontsize)
    if not plot_boxes and not plot_text:
        return None
    if box_annotations is None:
        enum = chars_list
    else:
        enum = zip(chars_list, box_annotations)
    for v in enum:
        if box_annotations is not None:
            v, annot_text = v
        x0, y0 = v[f"{prefix}_xmin"], v[f"{prefix}_ymin"]
        xdiff, ydiff = v[f"{prefix}_xmax"] - v[f"{prefix}_xmin"], v[f"{prefix}_ymax"] - v[f"{prefix}_ymin"]
        if plot_text:
            ax.text(
                v[f"{prefix}_x_center"],
                v[f"{prefix}_y_center"],
                v[prefix],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
            )
        if plot_boxes:
            ax.add_patch(Rectangle((x0, y0), xdiff, ydiff, edgecolor="grey", facecolor="none", lw=0.8, alpha=0.4))
        if box_annotations is not None:
            ax.annotate(
                str(annot_text),
                (x0 + xdiff / 2, y0),
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=FontProperties(family=font_to_use, style="normal", size=fontsize / 1.5),
            )


def plot_fixations_and_text(
    dffix: pd.DataFrame,
    trial: dict,
    plot_prefix="chars_",
    show=False,
    returnfig=False,
    save=False,
    savelocation="plot.png",
    font_to_use="DejaVu Sans Mono",
    fontsize=20,
    plot_classic=True,
    plot_boxes=True,
    plot_text=True,
    fig_size=(14, 8),
    dpi=300,
    turn_axis_on=True,
    algo_choice="slice",
):
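    """Plot raw fixations over the stimulus text and, if available, the corrected
    fixations for algo_choice with arrows from corrected to raw positions."""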
    fig, ax = plt.subplots(1, 1, figsize=fig_size, tight_layout=True, dpi=dpi)
    if f"{plot_prefix}list" in trial.keys():
        add_text_to_ax(
            trial[f"{plot_prefix}list"],
            ax,
            font_to_use,
            fontsize=fontsize,
            prefix=plot_prefix[:-2],
            plot_boxes=plot_boxes,
            plot_text=plot_text,
        )
    ax.plot(dffix.x, dffix.y, "kX", label="Raw Fixations", alpha=0.9)

    if plot_classic and f"line_num_{algo_choice}" in dffix.columns:
        ax.scatter(
            dffix.x,
            dffix[f"y_{algo_choice}"],
            marker="*",
            color="tab:green",
            label=f"{algo_choice} Prediction",
            alpha=0.9,
        )
        for x_before, y_before, x_after, y_after in zip(
            dffix.x.values, dffix[f"y_{algo_choice}"].values, dffix.x, dffix.y
        ):
            arr_delta_x = x_after - x_before
            arr_delta_y = y_after - y_before
            ax.arrow(x_before, y_before, arr_delta_x, arr_delta_y, color="tab:green", alpha=0.6)
    ax.set_ylabel("y (pixel)")
    ax.set_xlabel("x (pixel)")

    ax.invert_yaxis()
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    if not turn_axis_on:
        ax.axis("off")
    if save:
        plt.savefig(savelocation, dpi=dpi)
    if show:
        plt.show()
    if returnfig:
        return fig
    else:
        plt.close()
        return None


def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    gradio_temp_folder.mkdir(exist_ok=True)
    gradio_temp_unzipped_folder.mkdir(exist_ok=True)
    gradio_plots.mkdir(exist_ok=True)
    return 0


def get_classic_cfg(fname):
    with open(fname, "r") as f:
        classic_algos_cfg = json.load(f)
    return classic_algos_cfg


def find_and_load_model(model_date="20240104-223349"):
    model_cfg_file = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.yaml"))
    if len(model_cfg_file) == 0:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"No model cfg yaml found for {model_date}")
        return None, None
    model_cfg_file = model_cfg_file[0]
    with open(model_cfg_file) as f:
        model_cfg = yaml.safe_load(f)

    model_cfg["system_type"] = "linux"
    model_file = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.ckpt"))[0]
    model = load_model(model_file, model_cfg)

    return model, model_cfg


def load_model(model_file, cfg):
    try:
        model_loaded = t.load(model_file, map_location="cpu")
        if "hyper_parameters" in model_loaded.keys():
            model_cfg_temp = model_loaded["hyper_parameters"]["cfg"]
        else:
            model_cfg_temp = cfg
        model_state_dict = model_loaded["state_dict"]
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(e)
            st.session_state["logger"].warning(f"Failed to load {model_file}")
        return None
    model = LitModel(
        [1, 500, 3],
        model_cfg_temp["hidden_dim_bert"],
        model_cfg_temp["num_attention_heads"],
        model_cfg_temp["n_layers_BERT"],
        model_cfg_temp["loss_function"],
        1e-4,
        model_cfg_temp["weight_decay"],
        model_cfg_temp,
        model_cfg_temp["use_lr_warmup"],
        model_cfg_temp["use_reduce_on_plateau"],
        track_gradient_histogram=model_cfg_temp["track_gradient_histogram"],
        register_forw_hook=model_cfg_temp["track_activations_via_hook"],
        char_dims=model_cfg_temp["char_dims"],
    )
    model = remove_compile_from_model(model)
    model_state_dict = remove_compile_from_dict(model_state_dict)
    with t.no_grad():
        model.load_state_dict(model_state_dict, strict=False)
    model.eval()
    model.freeze()
    return model


def set_up_models(dist_models_folder):
    out_dict = {}
    if "logger" in st.session_state:
        st.session_state["logger"].info("Loading Ensemble")
    dist_models_with_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_True*.ckpt"))
    dist_models_without_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_False*.ckpt"))
    DIST_MODEL_DATE_WITH_NORM = dist_models_with_norm[0].stem.split("_")[1]

    models_without_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm]
    models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

    model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
    model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]

    models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
    models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]

    ensemble_model_avg = EnsembleModel(
        models_without_norm_df, models_with_norm_df, learning_rate=0.0058, use_simple_average=True
    )
    out_dict["ensemble_model_avg"] = ensemble_model_avg

    out_dict["model_cfg_without_norm_df"] = model_cfg_without_norm_df
    out_dict["model_cfg_with_norm_df"] = model_cfg_with_norm_df

    single_DIST_model, single_DIST_model_cfg = find_and_load_model(model_date=DIST_MODEL_DATE_WITH_NORM)
    out_dict["DIST_MODEL_DATE_WITH_NORM"] = DIST_MODEL_DATE_WITH_NORM
    out_dict["single_DIST_model"] = single_DIST_model
    out_dict["single_DIST_model_cfg"] = single_DIST_model_cfg
    return out_dict


def prep_data_for_dist(model_cfg, dffix, trial=None):
    if "logger" in st.session_state:
        st.session_state["logger"].debug("prep_data_for_dist entered")
    if trial is None:
        trial = st.session_state["trial"]
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    sample_tensor = t.tensor(dffix.loc[:, model_cfg["sample_cols"]].to_numpy(), dtype=t.float32)

    if model_cfg["add_line_overlap_feature"]:
        sample_tensor = add_line_overlaps_to_sample(trial, sample_tensor)

    has_nans = t.any(t.isnan(sample_tensor))
    assert not has_nans, "NaNs found in sample tensor"
    samplelist_eval = [sample_tensor]
    trialslist_eval = [trial]
    chars_center_coords_list_eval = None
    if model_cfg["norm_coords_by_letter_min_x_y"]:
        for sample_idx, _ in enumerate(samplelist_eval):
            trialslist_eval, samplelist_eval, chars_center_coords_list_eval = norm_coords_by_letter_min_x_y(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                chars_center_coords_list=chars_center_coords_list_eval,
            )

    if model_cfg["normalize_by_line_height_and_width"]:
        meanlist_eval, stdlist_eval = [], []
        for sample_idx, _ in enumerate(samplelist_eval):
            (
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                chars_center_coords_list_eval,
            ) = norm_coords_by_letter_positions(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                return_mean_std_lists=True,
                norm_by_char_averages=model_cfg["norm_by_char_averages"],
                chars_center_coords_list=chars_center_coords_list_eval,
                add_normalised_values_as_features=model_cfg["add_normalised_values_as_features"],
            )
    sample_tensor = samplelist_eval[0]
    sample_means = t.tensor(model_cfg["sample_means"], dtype=t.float32)
    sample_std = t.tensor(model_cfg["sample_std"], dtype=t.float32)
    sample_tensor = (sample_tensor - sample_means) / sample_std
    sample_tensor = sample_tensor.unsqueeze(0)

    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Using path {trial['plot_file']} for plotting")
    plot_text_boxes_fixations(
        fpath=trial["plot_file"],
        dpi=250,
        screen_res=(1024, 768),
        data_dir_sub=None,
        set_font_size=True,
        font_size=4,
        use_words=False,
        save_channel_repeats=False,
        save_combo_grey_and_rgb=False,
        dffix=dffix,
        trial=trial,
    )

    val_set = DSet(
        sample_tensor,
        None,
        t.zeros((1, sample_tensor.shape[1])),
        trialslist_eval,
        padding_list=[0],
        padding_at_end=model_cfg["padding_at_end"],
        return_images_for_conv=True,
        im_partial_string=model_cfg["im_partial_string"],
        input_im_shape=model_cfg["char_plot_shape"],
    )
    val_loader = dl(val_set, batch_size=1, shuffle=False, num_workers=0)
    return val_loader, val_set


def fold_in_seq_dim(out, y=None):
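    """Flatten the batch and sequence dimensions, (b, s, c) -> (b*s, c), for both
    logits and (optionally) targets, as expected by corn_label_from_logits."""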
    batch_size, seq_len, num_classes = out.shape

    out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
    if y is None:
        return out, None
    if len(y.shape) > 2:
        y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len)
    else:
        y = eo.rearrange(y, "b s -> (b s)", s=seq_len)
    return out, y


def logits_to_pred(out, y=None):
    # CORN ordinal logits -> line-number predictions, restored to (batch, seq) shape
    seq_len = out.shape[1]
    out, y = fold_in_seq_dim(out, y)
    preds = corn_label_from_logits(out)
    preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len)
    if y is not None:
        y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len)
    return preds, y


def get_DIST_preds(dffix, trial, models_dict=None):
    algo_choice = "DIST"

    if models_dict is None:
        if st.session_state["single_DIST_model"] is None or st.session_state["single_DIST_model_cfg"] is None:
            if "logger" in st.session_state:
                st.session_state["logger"].info("Model is None, reiniting model")
            st.session_state["single_DIST_model"], st.session_state["single_DIST_model_cfg"] = find_and_load_model(
                model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"]
            )
        model = st.session_state["single_DIST_model"]
        loader, dset = prep_data_for_dist(st.session_state["single_DIST_model_cfg"], dffix, trial)
    else:
        model = models_dict["single_DIST_model"]
        loader, dset = prep_data_for_dist(models_dict["single_DIST_model_cfg"], dffix, trial)
    batch = next(iter(loader))

    if "cpu" not in str(model.device):
        batch = [x.cuda() for x in batch]
    try:
        out = model(batch)
        preds, y = logits_to_pred(out, y=None)
        if "logger" in st.session_state:
            st.session_state["logger"].debug(
                f"y_char_unique are {trial['y_char_unique']} for trial {trial['trial_id']}"
            )
            st.session_state["logger"].debug(f"trial keys are {trial.keys()} for trial {trial['trial_id']}")
            st.session_state["logger"].debug(
                f"chars_list has len {len(trial['chars_list'])} for trial {trial['trial_id']}"
            )
        if len(trial["y_char_unique"]) < 1:
            y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
        else:
            y_char_unique = trial["y_char_unique"]
        num_lines = trial["num_char_lines"] - 1
        preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
        y_pred_DIST = [y_char_unique[idx] for idx in preds]

        dffix[f"line_num_{algo_choice}"] = preds
        dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
        dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"Exception on model(batch) for DIST \n{e}")
    return dffix


def get_DIST_ensemble_preds(
    dffix,
    trial,
    model_cfg_without_norm_df,
    model_cfg_with_norm_df,
    ensemble_model_avg,
):
    algo_choice = "DIST-Ensemble"
    loader_without_norm, dset_without_norm = prep_data_for_dist(model_cfg_without_norm_df, dffix, trial)
    loader_with_norm, dset_with_norm = prep_data_for_dist(model_cfg_with_norm_df, dffix, trial)
    batch_without_norm = next(iter(loader_without_norm))
    batch_with_norm = next(iter(loader_with_norm))
    out = ensemble_model_avg((batch_without_norm, batch_with_norm))
    preds, y = logits_to_pred(out[0]["out_avg"], y=None)
    if len(trial["y_char_unique"]) < 1:
        y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
    else:
        y_char_unique = trial["y_char_unique"]
    num_lines = trial["num_char_lines"] - 1
    preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
    if "logger" in st.session_state:
        st.session_state["logger"].debug(f"preds are {preds} for trial {trial['trial_id']}")
    y_pred_DIST = [y_char_unique[idx] for idx in preds]

    dffix[f"line_num_{algo_choice}"] = preds
    dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
    dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    return dffix


def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None, models_dict=None):
    if models_dict is None:
        if ensemble_model_avg is None and "ensemble_model_avg" not in st.session_state:
            if "logger" in st.session_state:
                st.session_state["logger"].info("Ensemble Model is None, reiniting model")
            dist_models_with_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_True*.ckpt")
            dist_models_without_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_False*.ckpt")

            models_without_norm_df = [
                find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm
            ]
            models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

            model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
            model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]

            models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
            models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]

            ensemble_model_avg = EnsembleModel(
                models_without_norm_df, models_with_norm_df, learning_rate=0.0, use_simple_average=True
            )
            st.session_state["ensemble_model_avg"] = ensemble_model_avg
            st.session_state["model_cfg_without_norm_df"] = model_cfg_without_norm_df
            st.session_state["model_cfg_with_norm_df"] = model_cfg_with_norm_df
        else:
            model_cfg_without_norm_df = st.session_state["model_cfg_without_norm_df"]
            model_cfg_with_norm_df = st.session_state["model_cfg_with_norm_df"]
            ensemble_model_avg = st.session_state["ensemble_model_avg"]
        dffix = get_DIST_ensemble_preds(
            dffix,
            trial,
            st.session_state["model_cfg_without_norm_df"],
            st.session_state["model_cfg_with_norm_df"],
            st.session_state["ensemble_model_avg"],
        )
    else:
        dffix = get_DIST_ensemble_preds(
            dffix,
            trial,
            models_dict["model_cfg_without_norm_df"],
            models_dict["model_cfg_with_norm_df"],
            models_dict["ensemble_model_avg"],
        )
    return dffix


def correct_df(
    dffix,
    algo_choice,
    trial=None,
    for_multi=False,
    ensemble_model_avg=None,
    is_outside_of_streamlit=False,
    classic_algos_cfg=None,
    models_dict=None,
):
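    """Apply one or several correction algorithms to a fixation dataframe.

    algo_choice may be a single name from ALGO_CHOICES or a list of names.
    Typical use (a sketch only, assuming an .asc file named "sub01.asc" and a
    classic-algorithm config file; both names are placeholders):

        trial_keys, trials_by_ids, lines, asc_file = get_trials_list("sub01.asc")
        trial = trials_by_ids[trial_keys[0]]
        df, dffix, trial = trial_to_dfs(trial, lines)
        cfg = get_classic_cfg("classic_algos_cfg.json")  # hypothetical cfg path
        dffix, (csv_name, trial_name) = correct_df(dffix, "slice", trial=trial, classic_algos_cfg=cfg)
    """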
|
if is_outside_of_streamlit: |
|
stqdm = tqdm |
|
else: |
|
from stqdm import stqdm |
|
if classic_algos_cfg is None: |
|
classic_algos_cfg = st.session_state["classic_algos_cfg"] |
|
if trial is None and not for_multi: |
|
trial = st.session_state["trial"] |
|
if "logger" in st.session_state: |
|
st.session_state["logger"].info(f"Applying {algo_choice} to fixations for trial {trial['trial_id']}") |
|
|
|
if isinstance(dffix, dict): |
|
dffix = dffix["value"] |
|
if "x" not in dffix.keys() or "x" not in dffix.keys(): |
|
if "logger" in st.session_state: |
|
st.session_state["logger"].warning(f"x or y not in dffix") |
|
if "logger" in st.session_state: |
|
st.session_state["logger"].warning(dffix.columns) |
|
return dffix |
|
if isinstance(algo_choice, list): |
|
algo_choices = algo_choice |
|
repeats = range(len(algo_choice)) |
|
else: |
|
algo_choices = [algo_choice] |
|
repeats = range(1) |
|
for algoIdx in stqdm(repeats, desc="Applying correction algorithms"): |
|
algo_choice = algo_choices[algoIdx] |
|
st_proc = time.process_time() |
|
st_wall = time.time() |
|
|
|
if algo_choice == "DIST": |
|
dffix = get_DIST_preds(dffix, trial, models_dict=models_dict) |
|
|
|
elif algo_choice == "DIST-Ensemble": |
|
dffix = get_EDIST_preds_with_model_check(dffix, trial, models_dict=models_dict) |
|
elif algo_choice == "Wisdom_of_Crowds_with_DIST": |
|
dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg) |
|
dffix = get_DIST_preds(dffix, trial, models_dict=models_dict) |
|
for _ in range(3): |
|
corrections.append(np.asarray(dffix.loc[:, "y_DIST"])) |
|
dffix = apply_woc(dffix, trial, corrections, algo_choice) |
|
elif algo_choice == "Wisdom_of_Crowds_with_DIST_Ensemble": |
|
dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg) |
|
dffix = get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg, models_dict=models_dict) |
|
for _ in range(3): |
|
corrections.append(np.asarray(dffix.loc[:, "y_DIST-Ensemble"])) |
|
dffix = apply_woc(dffix, trial, corrections, algo_choice) |
|
elif algo_choice == "Wisdom_of_Crowds": |
|
dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg) |
|
dffix = apply_woc(dffix, trial, corrections, algo_choice) |
|
|
|
else: |
|
algo_cfg = classic_algos_cfg[algo_choice] |
|
dffix = calgo.apply_classic_algo(dffix, trial, algo_choice, algo_cfg) |
|
dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1) |
|
|
|
et_proc = time.process_time() |
|
time_proc = et_proc - st_proc |
|
et_wall = time.time() |
|
time_wall = et_wall - st_wall |
|
if "logger" in st.session_state: |
|
st.session_state["logger"].info(f"time_proc {algo_choice} {time_proc}") |
|
if "logger" in st.session_state: |
|
st.session_state["logger"].info(f"time_wall {algo_choice} {time_wall}") |
|
if for_multi: |
|
return dffix |
|
else: |
|
if "start_time" in dffix.columns: |
|
dffix = dffix.drop(axis=1, labels=["start_time", "end_time"]) |
|
return dffix, export_csv(dffix, trial) |
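
# Hypothetical usage sketch (argument values assumed rather than taken from this
# app's entry points): correct one trial with a classic and a model-based
# algorithm outside of Streamlit, in which case classic_algos_cfg and
# models_dict must be passed in explicitly.
#
#   dffix, csv_export = correct_df(
#       dffix,
#       ["warp", "DIST"],
#       trial=trial,
#       classic_algos_cfg=classic_algos_cfg,
#       models_dict=models_dict,
#       is_outside_of_streamlit=True,
#   )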
|
|
|
def set_font_from_chars_list(trial): |
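"""Estimate a font size from the vertical spacing of character centers.

Takes the smallest unique line-to-line distance (rounded to the nearest half
pixel) and applies the 0.333 line-height-to-font-size ratio this module
assumes; falls back to an assumed 18 pt font when no ``chars_list`` is
present. The result is snapped to the nearest quarter point.
"""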
|
|
|
if "chars_list" in trial: |
|
chars_df = pd.DataFrame(trial["chars_list"]) |
|
line_diffs = np.diff(chars_df.char_y_center.unique()) |
|
y_diffs = np.unique(line_diffs) |
|
if len(y_diffs) == 1: |
|
y_diff = y_diffs[0] |
|
else: |
|
y_diff = np.min(y_diffs) |
|
y_diff = round(y_diff * 2) / 2 |
|
|
|
else: |
|
y_diff = 1 / 0.333 * 18 |
|
font_size = y_diff * 0.333 |
|
return round(font_size * 4) / 4  # snap to the nearest quarter point
|
|
|
def get_font_and_font_size_from_trial(trial): |
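"""Return the trial's font face and size, falling back first to
``trial["font_size"]`` and then to an estimate from the character grid when
the plot properties do not specify a size."""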
|
font_face, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS) |
|
|
|
if font_size is None and "font_size" in trial: |
|
font_size = trial["font_size"] |
|
elif font_size is None: |
|
font_size = set_font_from_chars_list(trial) |
|
return font_face, font_size |
|
|
|
|
|
def sigmoid(x): |
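"""Logistic function, used to squash fixation durations into a bounded marker-size range."""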
|
return 1 / (1 + np.exp(-x))
|
|
|
|
|
def matplotlib_plot_df( |
|
dffix, |
|
trial, |
|
algo_choice, |
|
stimulus_prefix="word", |
|
desired_dpi=300, |
|
fix_to_plot=[], |
|
stim_info_to_plot=["Words", "Word boxes"], |
|
box_annotations=None, |
|
): |
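"""Render the stimulus (word boxes and/or character text) and, optionally, raw
and corrected scanpaths to a matplotlib figure sized to the display resolution.

Falls back to a 1920x1080 canvas and DejaVu Sans Mono when the trial does not
specify display coordinates or a known font. The ``font_size / 3.89`` divisor
appears to be an empirical constant matching pixel metrics to point sizes at
the chosen dpi. Returns ``(fig, width_px, height_px)``.
"""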
|
chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None |
|
|
|
# Font face and size for rendering are derived further below; here we only
# warn when no character or word information is available.
if chars_df is None:

st.warning("No character or word information available to plot")
|
|
|
if "display_coords" in trial: |
|
desired_width_in_pixels = trial["display_coords"][2] + 1 |
|
desired_height_in_pixels = trial["display_coords"][3] + 1 |
|
else: |
|
desired_width_in_pixels = 1920 |
|
desired_height_in_pixels = 1080 |
|
|
|
figure_width = desired_width_in_pixels / desired_dpi |
|
figure_height = desired_height_in_pixels / desired_dpi |
|
|
|
fig = plt.figure(figsize=(figure_width, figure_height), dpi=desired_dpi) |
|
ax = fig.add_subplot(1, 1, 1) |
|
fig.subplots_adjust(bottom=0) |
|
fig.subplots_adjust(top=1) |
|
fig.subplots_adjust(right=1) |
|
fig.subplots_adjust(left=0) |
|
if "font" in trial and trial["font"] in AVAILABLE_FONTS: |
|
font_to_use = trial["font"] |
|
else: |
|
font_to_use = "DejaVu Sans Mono" |
|
if "font_size" in trial: |
|
font_size = trial["font_size"] |
|
else: |
|
font_size = 20 |
|
|
|
if f"{stimulus_prefix}s_list" in trial: |
|
add_text_to_ax( |
|
trial[f"{stimulus_prefix}s_list"], |
|
ax, |
|
font_to_use, |
|
prefix=stimulus_prefix, |
|
fontsize=font_size / 3.89, |
|
plot_text=False, |
|
plot_boxes="Word boxes" in stim_info_to_plot,
|
box_annotations=box_annotations, |
|
) |
|
|
|
if "chars_list" in trial: |
|
add_text_to_ax( |
|
trial["chars_list"], |
|
ax, |
|
font_to_use, |
|
prefix="char", |
|
fontsize=font_size / 3.89, |
|
plot_text="Words" in stim_info_to_plot,
|
plot_boxes=False, |
|
box_annotations=None, |
|
) |
|
|
|
if "Uncorrected Fixations" in fix_to_plot: |
|
ax.plot(dffix.x, dffix.y, label="Raw fixations", color="blue", alpha=0.6, linewidth=0.6) |
|
|
|
# Consecutive fixation pairs define the saccade arrow start points and directions.
xpos = dffix.x.values[:-1]

ypos = dffix.y.values[:-1]

xdir = dffix.x.values[1:] - xpos

ydir = dffix.y.values[1:] - ypos
|
for X, Y, dX, dY in zip(xpos, ypos, xdir, ydir): |
|
ax.annotate( |
|
"", |
|
xytext=(X, Y), |
|
xy=(X + 0.001 * dX, Y + 0.001 * dY), |
|
arrowprops=dict(arrowstyle="fancy", color="blue"), |
|
size=8, |
|
alpha=0.3, |
|
) |
|
if "Corrected Fixations" in fix_to_plot: |
|
if isinstance(algo_choice, list): |
|
algo_choices = algo_choice |
|
repeats = range(len(algo_choice)) |
|
else: |
|
algo_choices = [algo_choice] |
|
repeats = range(1) |
|
for algoIdx in repeats: |
|
algo_choice = algo_choices[algoIdx] |
|
if f"y_{algo_choice}" in dffix.columns: |
|
ax.plot( |
|
dffix.x, |
|
dffix.loc[:, f"y_{algo_choice}"], |
|
label="Raw fixations", |
|
color=COLORS[algoIdx], |
|
alpha=0.6, |
|
linewidth=0.6, |
|
) |
|
|
|
xpos = dffix.x.values[:-1]

ypos = dffix.loc[:, f"y_{algo_choice}"].values[:-1]

xdir = dffix.x.values[1:] - xpos

ydir = dffix.loc[:, f"y_{algo_choice}"].values[1:] - ypos
|
for X, Y, dX, dY in zip(xpos, ypos, xdir, ydir): |
|
ax.annotate( |
|
"", |
|
xytext=(X, Y), |
|
xy=(X + 0.001 * dX, Y + 0.001 * dY), |
|
arrowprops=dict(arrowstyle="fancy", color=COLORS[algoIdx]), |
|
size=8, |
|
alpha=0.3, |
|
) |
|
|
|
ax.set_xlim((0, desired_width_in_pixels)) |
|
ax.set_ylim((0, desired_height_in_pixels)) |
|
ax.invert_yaxis() |
|
|
|
return fig, desired_width_in_pixels, desired_height_in_pixels |
|
|
|
|
|
def plotly_plot_with_image( |
|
dffix, |
|
trial, |
|
algo_choice, |
|
to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"], |
|
scale_factor=0.5, |
|
): |
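"""Build an interactive plotly figure with the matplotlib-rendered stimulus as
a background image.

The invisible ``scale_helper`` scatter trace pins the axis ranges so the image
keeps its aspect ratio at any ``scale_factor``; fixation marker sizes encode
duration via a sigmoid.
"""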
|
fig, img_width, img_height = matplotlib_plot_df( |
|
dffix, trial, algo_choice, desired_dpi=300, fix_to_plot=[], stim_info_to_plot=to_plot_list |
|
) |
|
fig.savefig(TEMP_FIGURE_STIMULUS_PATH) |
|
fig = go.Figure() |
|
fig.add_trace( |
|
go.Scatter( |
|
x=[0, img_width * scale_factor], |
|
y=[img_height * scale_factor, 0], |
|
mode="markers", |
|
marker_opacity=0, |
|
name="scale_helper", |
|
) |
|
) |
|
|
|
fig.update_xaxes(visible=False, range=[0, img_width * scale_factor]) |
|
|
|
fig.update_yaxes( |
|
visible=False, |
|
range=[img_height * scale_factor, 0], |
|
scaleanchor="x", |
|
) |
|
if "Words" in to_plot_list or "Word boxes" in to_plot_list: |
|
imsource = Image.open(str(TEMP_FIGURE_STIMULUS_PATH)) |
|
fig.add_layout_image( |
|
dict( |
|
x=0, |
|
sizex=img_width * scale_factor, |
|
y=0, |
|
sizey=img_height * scale_factor, |
|
xref="x", |
|
yref="y", |
|
opacity=1.0, |
|
layer="below", |
|
sizing="stretch", |
|
source=imsource, |
|
) |
|
) |
|
|
|
if "Uncorrected Fixations" in to_plot_list: |
|
duration_scaled = dffix.duration - dffix.duration.min()

max_duration = duration_scaled.max()

# Guard against constant durations (max of 0), which would otherwise produce NaN marker sizes.
if max_duration > 0:

duration_scaled = ((duration_scaled / max_duration) - 0.5) * 3

else:

duration_scaled = duration_scaled - 0.5

duration = sigmoid(duration_scaled) * 50 * scale_factor
|
fig.add_trace( |
|
go.Scatter( |
|
x=dffix.x * scale_factor, |
|
y=dffix.y * scale_factor, |
|
mode="markers+lines+text", |
|
name="Raw fixations", |
|
marker=dict( |
|
color=COLORS[-1], |
|
symbol="arrow", |
|
size=duration.values, |
|
angleref="previous", |
|
line=dict(color="black", width=duration.values / 10), |
|
), |
|
line_width=2 * scale_factor, |
|
text=np.arange(len(dffix.x)), |
|
textposition="middle right", |
|
textfont=dict( |
|
family="sans serif", |
|
size=18 * scale_factor, |
|
), |
|
hoverinfo="text+x+y", |
|
opacity=0.9, |
|
) |
|
) |
|
|
|
if "Corrected Fixations" in to_plot_list: |
|
if isinstance(algo_choice, list): |
|
algo_choices = algo_choice |
|
repeats = range(len(algo_choice)) |
|
else: |
|
algo_choices = [algo_choice] |
|
repeats = range(1) |
|
for algoIdx in repeats: |
|
algo_choice = algo_choices[algoIdx] |
|
if f"y_{algo_choice}" in dffix.columns: |
|
fig.add_trace( |
|
go.Scatter( |
|
x=dffix.x * scale_factor, |
|
y=dffix.loc[:, f"y_{algo_choice}"] * scale_factor, |
|
mode="markers", |
|
name=f"{algo_choice} corrected", |
|
marker_color=COLORS[algoIdx], |
|
marker_size=10 * scale_factor, |
|
hoverinfo="text+x+y", |
|
opacity=0.75, |
|
) |
|
) |
|
|
|
fig.update_layout( |
|
plot_bgcolor=None, |
|
width=img_width * scale_factor, |
|
height=img_height * scale_factor, |
|
margin={"l": 0, "r": 0, "t": 0, "b": 0}, |
|
legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8), |
|
) |
|
|
|
for trace in fig["data"]: |
|
if trace["name"] == "scale_helper": |
|
trace["showlegend"] = False |
|
return fig |
|
|
|
|
|
def plot_y_corr(dffix, algo_choice, margin=dict(t=40, l=10, r=10, b=1)): |
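"""Plot the per-fixation y correction (corrected minus raw y, by fixation
index) for one or more algorithms. Returns an empty figure when no correction
column is present in ``dffix``."""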
|
num_datapoints = len(dffix.x) |
|
|
|
layout = dict( |
|
plot_bgcolor="white", |
|
autosize=True, |
|
margin=margin, |
|
xaxis=dict( |
|
title="Fixation Index", |
|
linecolor="black", |
|
range=[-1, num_datapoints + 1], |
|
showgrid=False, |
|
mirror="all", |
|
showline=True, |
|
), |
|
yaxis=dict( |
|
title="y correction", |
|
side="left", |
|
linecolor="black", |
|
showgrid=False, |
|
mirror="all", |
|
showline=True, |
|
), |
|
legend=dict(orientation="v", yanchor="middle", y=0.95, xanchor="left", x=1.05), |
|
) |
|
if isinstance(dffix, dict): |
|
dffix = dffix["value"] |
|
algo_string = algo_choice[0] if isinstance(algo_choice, list) else algo_choice |
|
if f"y_{algo_string}_correction" not in dffix.columns: |
|
st.session_state["logger"].warning("No correction column found in dataframe") |
|
return go.Figure(layout=layout) |
|
|
|
|
fig = go.Figure(layout=layout) |
|
|
|
if isinstance(algo_choice, list): |
|
algo_choices = algo_choice |
|
repeats = range(len(algo_choice)) |
|
else: |
|
algo_choices = [algo_choice] |
|
repeats = range(1) |
|
for algoIdx in repeats: |
|
algo_choice = algo_choices[algoIdx] |
|
fig.add_trace( |
|
go.Scatter( |
|
x=np.arange(num_datapoints), |
|
y=dffix.loc[:, f"y_{algo_choice}_correction"], |
|
mode="markers", |
|
name=f"{algo_choice} y correction", |
|
marker_color=COLORS[algoIdx], |
|
marker_size=3, |
|
showlegend=True, |
|
) |
|
) |
|
fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor="black") |
|
|
|
return fig |
|
|
|
|
|
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLOAD_LINK, EXAMPLES_FOLDER_PATH):
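"""Download the example .asc zip from OSF when it is not cached locally,
extract it when the expected four example files are not all present, and
return the list of .asc files found."""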
|
if not os.path.isdir(EXAMPLES_FOLDER): |
|
os.mkdir(EXAMPLES_FOLDER) |
|
|
|
if not os.path.exists(EXAMPLES_ASC_ZIP_FILENAME): |
|
download_url(OSF_DOWNLOAD_LINK, EXAMPLES_ASC_ZIP_FILENAME)
|
|
|
|
|
if os.path.exists(EXAMPLES_ASC_ZIP_FILENAME): |
|
if EXAMPLES_FOLDER_PATH.exists(): |
|
EXAMPLE_ASC_FILES = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
|
if len(EXAMPLE_ASC_FILES) != 4: |
|
try: |
|
with zipfile.ZipFile(EXAMPLES_ASC_ZIP_FILENAME, "r") as zip_ref: |
|
zip_ref.extractall(EXAMPLES_FOLDER) |
|
except Exception as e: |
|
st.session_state["logger"].warning(e) |
|
st.session_state["logger"].warning(f"Extracting {EXAMPLES_ASC_ZIP_FILENAME} failed") |
|
|
|
EXAMPLE_ASC_FILES = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
|
return EXAMPLE_ASC_FILES |
|
|
|
|
|
def process_trial_choice_single_csv(trial, algo_choice, file=None): |
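"""Prepare a single-CSV trial for display: attach its fixation dataframe and
character grid, derive plot properties, and optionally run ``correct_df`` with
the selected algorithm(s)."""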
|
trial_id = trial["trial_id"] |
|
if "dffix" in trial: |
|
dffix = trial["dffix"] |
|
else: |
|
if file is None: |
|
file = st.session_state["single_csv_file"] |
|
trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{file.name}_{trial_id}_2ndInput_chars_channel_sep.png")) |
|
trial["fname"] = str(file.name) |
|
dffix = trial["dffix"] = st.session_state["trials_by_ids_single_csv"][trial_id]["dffix"] |
|
|
|
font, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS) |
|
chars_df = pd.DataFrame(trial["chars_list"]) |
|
trial["chars_df"] = chars_df.to_dict() |
|
trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique()) |
|
if algo_choice is not None: |
|
dffix, _ = correct_df(dffix, algo_choice, trial) |
|
return dffix, trial, dpi, screen_res, font, font_size |
|
|
|
|
|
def add_default_font_and_character_props_to_state(trial): |
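"""Derive default layout properties from the character grid: the line spacing
(smallest unique distance between line centers, rounded to the nearest half
pixel), the text start position, and the font face and size."""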
|
chars_list = trial["chars_list"] |
|
chars_df = pd.DataFrame(trial["chars_list"]) |
|
line_diffs = np.diff(chars_df.char_y_center.unique()) |
|
y_diffs = np.unique(line_diffs) |
|
if len(y_diffs) == 1: |
|
y_diff = y_diffs[0] |
|
else: |
|
y_diff = np.min(y_diffs) |
|
y_diff = round(y_diff * 2) / 2 |
|
x_txt_start = chars_list[0]["char_xmin"] |
|
y_txt_start = chars_list[0]["char_y_center"] |
|
|
|
font_face, font_size = get_font_and_font_size_from_trial(trial) |
|
|
|
line_height = y_diff |
|
return y_diff, x_txt_start, y_txt_start, font_face, font_size, line_height |
|
|
|
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"): |
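"""Compute word- or character-level reading measures (fixation counts and
durations, landing positions and distances, regressions) via ``analysis_funcs``
and concatenate them into one dataframe indexed by ``{prefix}_index``.

When ``use_corrected_fixations`` is True, measures are computed on a copy of
``dffix`` whose ``y`` column is replaced by ``y_{correction_algo}``.
"""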
|
if use_corrected_fixations: |
|
dffix_copy = dffix.copy(deep=True)
|
dffix_copy["y"] = dffix_copy[f"y_{correction_algo}"] |
|
else: |
|
dffix_copy = dffix |
|
initial_landing_position_own_vals = anf.initial_landing_position_own(trial, dffix_copy, prefix).set_index( |
|
f"{prefix}_index" |
|
) |
|
second_pass_duration_own_vals = anf.second_pass_duration_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index") |
|
number_of_fixations_own_vals = anf.number_of_fixations_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index") |
|
initial_fixation_duration_own_vals = anf.initial_fixation_duration_own(trial, dffix_copy, prefix).set_index( |
|
f"{prefix}_index" |
|
) |
|
first_of_many_duration_own_vals = anf.first_of_many_duration_own(trial, dffix_copy, prefix).set_index( |
|
f"{prefix}_index" |
|
) |
|
total_fixation_duration_own_vals = anf.total_fixation_duration_own(trial, dffix_copy, prefix).set_index( |
|
f"{prefix}_index" |
|
) |
|
gaze_duration_own_vals = anf.gaze_duration_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index") |
|
go_past_duration_own_vals = anf.go_past_duration_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index") |
|
initial_landing_distance_own_vals = anf.initial_landing_distance_own(trial, dffix_copy, prefix).set_index( |
|
f"{prefix}_index" |
|
) |
|
landing_distances_own_vals = anf.landing_distances_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index") |
|
number_of_regressions_in_own_vals = anf.number_of_regressions_in_own(trial, dffix_copy, prefix).set_index( |
|
f"{prefix}_index" |
|
) |
|
own_measure_df = pd.concat( |
|
[ |
|
df.drop(prefix, axis=1) |
|
for df in [ |
|
number_of_fixations_own_vals, |
|
initial_fixation_duration_own_vals, |
|
first_of_many_duration_own_vals, |
|
total_fixation_duration_own_vals, |
|
gaze_duration_own_vals, |
|
go_past_duration_own_vals, |
|
second_pass_duration_own_vals, |
|
initial_landing_position_own_vals, |
|
initial_landing_distance_own_vals, |
|
landing_distances_own_vals, |
|
number_of_regressions_in_own_vals, |
|
] |
|
], |
|
axis=1, |
|
) |
|
own_measure_df[prefix] = number_of_fixations_own_vals[prefix] |
|
first_column = own_measure_df.pop(prefix) |
|
own_measure_df.insert(0, prefix, first_column) |
|
own_measure_df.insert(0, f"{prefix}_num", np.arange((own_measure_df.shape[0]))) |
|
return own_measure_df |