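"""Utility module for the fixation-correction app: parses EyeLink .asc recordings
into trials, builds torch datasets, applies classic and DIST-based line-assignment
algorithms, and renders stimuli and fixations with matplotlib and plotly."""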
import zipfile
import os
import plotly.express as px
import plotly.graph_objects as go
from torch.utils.data.dataloader import DataLoader as dl
import yaml
from io import StringIO
import torch as t
import numpy as np
import pandas as pd
from torch.utils.data import Dataset as torch_dset
from PIL import Image
import torchvision.transforms.functional as tvfunc
import json
from matplotlib import pyplot as plt
import matplotlib.patches as patches
from matplotlib.font_manager import FontProperties
import pathlib as pl
import matplotlib as mpl
import streamlit as st
from streamlit.runtime.uploaded_file_manager import UploadedFile
import einops as eo
import copy
# import stqdm
from tqdm.auto import tqdm
import time
import requests
from matplotlib.patches import Rectangle
from matplotlib import font_manager
from models import LitModel, EnsembleModel
from loss_functions import corn_label_from_logits
import classic_correction_algos as calgo
import analysis_funcs as anf
TEMP_FOLDER = pl.Path("results")
AVAILABLE_FONTS = [x.name for x in font_manager.fontManager.ttflist]
PLOTS_FOLDER = pl.Path("plots")
TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER / "temp_matplotlib_plot_stimulus.png"
all_fonts = AVAILABLE_FONTS  # same list as AVAILABLE_FONTS; avoid recomputing it
mpl.use("agg")
DIST_MODELS_FOLDER = pl.Path("models")
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
gradio_plots = PLOTS_FOLDER  # alias used by the plotting helpers below
event_strs = [
"EFIX",
"EFIX R",
"EFIX L",
"SSACC",
"ESACC",
"SFIX",
"MSG",
"SBLINK",
"EBLINK",
"BUTTON",
"INPUT",
"END",
"START",
"DISPLAY ON",
]
names_dict = {
"SSACC": {"Descr": "Start of Saccade", "Pattern": "SSACC <eye > <stime>"},
"ESACC": {
"Descr": "End of Saccade",
"Pattern": "ESACC <eye > <stime> <etime > <dur> <sxp > <syp> <exp > <eyp> <ampl > <pv >",
},
"SFIX": {"Descr": "Start of Fixation", "Pattern": "SFIX <eye > <stime>"},
"EFIX": {"Descr": "End of Fixation", "Pattern": "EFIX <eye > <stime> <etime > <dur> <axp > <ayp> <aps >"},
"SBLINK": {"Descr": "Start of Blink", "Pattern": "SBLINK <eye > <stime>"},
"EBLINK": {"Descr": "End of Blink", "Pattern": "EBLINK <eye > <stime> <etime > <dur>"},
"DISPLAY ON": {"Descr": "Actual start of Trial", "Pattern": "DISPLAY ON"},
}
metadata_strs = ["DISPLAY COORDS", "GAZE_COORDS", "FRAMERATE"]
ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [
"warp",
"regress",
"compare",
"attach",
"segment",
"split",
"stretch",
"chain",
"slice",
"cluster",
"merge",
"Wisdom_of_Crowds",
"DIST",
"DIST-Ensemble",
"Wisdom_of_Crowds_with_DIST",
"Wisdom_of_Crowds_with_DIST_Ensemble",
]
COLORS = px.colors.qualitative.Alphabet
class NumpyEncoder(json.JSONEncoder):
"From https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable"
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
        elif isinstance(obj, (pl.Path, UploadedFile)):
            return str(obj)
return json.JSONEncoder.default(self, obj)
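# Example usage of NumpyEncoder (as in save_trial_to_json below):
#   json.dump(trial, f, ensure_ascii=False, indent=4, cls=NumpyEncoder)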
class DSet(torch_dset):
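    """Dataset of fixation sequences plus optional per-trial stimulus images.

    The tuples returned by __getitem__ vary in arity depending on whether
    character center coordinates, padding (attention masks) and conv images
    are requested; see the branches in __getitem__ for the exact orderings."""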
def __init__(
self,
in_sequence: t.Tensor,
chars_center_coords_padded: t.Tensor,
out_categories: t.Tensor,
trialslist: list,
padding_list: list = None,
padding_at_end: bool = False,
return_images_for_conv: bool = False,
im_partial_string: str = "fixations_chars_channel_sep",
input_im_shape=[224, 224],
) -> None:
super().__init__()
self.in_sequence = in_sequence
self.chars_center_coords_padded = chars_center_coords_padded
self.out_categories = out_categories
self.padding_list = padding_list
self.padding_at_end = padding_at_end
self.trialslist = trialslist
self.return_images_for_conv = return_images_for_conv
self.input_im_shape = input_im_shape
if return_images_for_conv:
self.im_partial_string = im_partial_string
self.plot_files = [
str(x["plot_file"]).replace("fixations_words", im_partial_string) for x in self.trialslist
]
def __getitem__(self, index):
if self.return_images_for_conv:
im = Image.open(self.plot_files[index])
if [im.size[1], im.size[0]] != self.input_im_shape:
im = tvfunc.resize(im, self.input_im_shape)
im = tvfunc.normalize(tvfunc.to_tensor(im), IMAGENET_MEAN, IMAGENET_STD)
if self.chars_center_coords_padded is not None:
if self.padding_list is not None:
attention_mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
if self.padding_at_end:
if self.padding_list[index] > 0:
attention_mask[-self.padding_list[index] :] = 0
else:
attention_mask[: self.padding_list[index]] = 0
if self.return_images_for_conv:
return (
self.in_sequence[index],
self.chars_center_coords_padded[index],
im,
attention_mask,
self.out_categories[index],
)
return (
self.in_sequence[index],
self.chars_center_coords_padded[index],
attention_mask,
self.out_categories[index],
)
else:
if self.return_images_for_conv:
return (
self.in_sequence[index],
self.chars_center_coords_padded[index],
im,
self.out_categories[index],
)
else:
return (self.in_sequence[index], self.chars_center_coords_padded[index], self.out_categories[index])
if self.padding_list is not None:
attention_mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
if self.padding_at_end:
if self.padding_list[index] > 0:
attention_mask[-self.padding_list[index] :] = 0
else:
attention_mask[: self.padding_list[index]] = 0
if self.return_images_for_conv:
return (self.in_sequence[index], im, attention_mask, self.out_categories[index])
else:
return (self.in_sequence[index], attention_mask, self.out_categories[index])
if self.return_images_for_conv:
return (self.in_sequence[index], im, self.out_categories[index])
else:
return (self.in_sequence[index], self.out_categories[index])
def __len__(self):
if isinstance(self.in_sequence, t.Tensor):
return self.in_sequence.shape[0]
else:
return len(self.in_sequence)
def download_url(url, target_filename):
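    """Download `url` and write the response body to `target_filename`; returns 0."""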
    r = requests.get(url)
    with open(target_filename, "wb") as f:
        f.write(r.content)
return 0
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
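    """Parse an .asc file into a {trial_id: trial_dict} mapping of paragraph trials
    and return it together with the raw file lines."""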
if "logger" in st.session_state:
st.session_state["logger"].debug("asc_to_trial_ids entered")
    asc_encoding = "ISO-8859-15"  # default; "UTF-8" is the other supported encoding
trials_dict, lines = file_to_trials_and_lines(
asc_file, asc_encoding, close_gap_between_words=close_gap_between_words
)
trials_by_ids = {trials_dict[idx]["trial_id"]: trials_dict[idx] for idx in trials_dict["paragraph_trials"]}
if hasattr(asc_file, "name"):
if "logger" in st.session_state:
st.session_state["logger"].info(f"Found {len(trials_by_ids)} trials in {asc_file.name}.")
return trials_by_ids, lines
def get_trials_list(asc_file=None, close_gap_between_words=True):
if "logger" in st.session_state:
st.session_state["logger"].debug("get_trials_list entered")
    if asc_file is None:
if "single_asc_file" in st.session_state.keys() and st.session_state["single_asc_file"] is not None:
asc_file = st.session_state["single_asc_file"]
else:
if "logger" in st.session_state:
st.session_state["logger"].warning("Asc file is None")
return None
if hasattr(asc_file, "name"):
if "logger" in st.session_state:
st.session_state["logger"].info(f"get_trials_list entered with asc_file {asc_file.name}")
trials_by_ids, lines = asc_to_trial_ids(asc_file, close_gap_between_words=close_gap_between_words)
trial_keys = list(trials_by_ids.keys())
return trial_keys, trials_by_ids, lines, asc_file
def save_trial_to_json(trial, savename):
if "dffix" in trial:
trial.pop("dffix")
with open(savename, "w", encoding="utf-8") as f:
json.dump(trial, f, ensure_ascii=False, indent=4, cls=NumpyEncoder)
def export_csv(dffix, trial):
if isinstance(dffix, dict):
dffix = dffix["value"]
trial_id = trial["trial_id"]
savename = TEMP_FOLDER.joinpath(pl.Path(trial["fname"]).stem)
trial_name = f"{savename}_{trial_id}_trial_info.json"
csv_name = f"{savename}_{trial_id}.csv"
dffix.to_csv(csv_name)
if "logger" in st.session_state:
st.session_state["logger"].info(f"Saved processed data as {csv_name}")
save_trial_to_json(trial, trial_name)
if "logger" in st.session_state:
st.session_state["logger"].info(f"Saved processed trial data as {trial_name}")
return csv_name, trial_name
def get_all_classic_preds(dffix, trial, classic_algos_cfg):
corrections = []
for algo, classic_params in copy.deepcopy(classic_algos_cfg).items():
dffix = calgo.apply_classic_algo(dffix, trial, algo, classic_params)
corrections.append(np.asarray(dffix.loc[:, f"y_{algo}"]))
return dffix, corrections
def apply_woc(dffix, trial, corrections, algo_choice):
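    """Combine candidate corrections via wisdom-of-the-crowd voting and store the
    result as the y_{algo_choice} / line_num_y_{algo_choice} columns of dffix."""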
corrected_Y = calgo.wisdom_of_the_crowd(corrections)
dffix.loc[:, f"y_{algo_choice}"] = corrected_Y
dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
corrected_line_nums = [trial["y_char_unique"].index(y) for y in corrected_Y]
dffix.loc[:, f"line_num_y_{algo_choice}"] = corrected_line_nums
return dffix
def calc_xdiff_ydiff(line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False):
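    """Return the typical horizontal spacing between character centers and the
    vertical spacing between text lines (smallest difference unless
    allow_multiple_values is set); falls back to the first line height when all
    characters sit on a single line."""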
x_diffs = np.unique(np.diff(line_xcoords_no_pad))
if len(x_diffs) == 1:
x_diff = x_diffs[0]
elif not allow_multiple_values:
x_diff = np.min(x_diffs)
else:
x_diff = x_diffs
if np.unique(line_ycoords_no_pad).shape[0] == 1:
return x_diff, line_heights[0]
y_diffs = np.unique(np.diff(line_ycoords_no_pad))
if len(y_diffs) == 1:
y_diff = y_diffs[0]
elif len(y_diffs) == 0:
y_diff = 0
elif not allow_multiple_values:
y_diff = np.min(y_diffs)
else:
y_diff = y_diffs
return x_diff, y_diff
def add_words(trial, close_gap_between_words=True):
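    """Reconstruct word boxes from a trial's chars_list. A word ends at
    whitespace/punctuation, when the x coordinate wraps back (new line), or at the
    final character; with close_gap_between_words, adjacent boxes on a line are
    extended to meet halfway across the separating gap."""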
chars_list_reconstructed = []
words_list = []
word_start_idx = 0
chars_df = pd.DataFrame(trial["chars_list"])
chars_df["char_width"] = chars_df.char_xmax - chars_df.char_xmin
space_width = chars_df.loc[chars_df["char"] == " ", "char_width"].mean()
for idx, char_dict in enumerate(trial["chars_list"]):
on_line_num = char_dict["assigned_line"]
chars_list_reconstructed.append(char_dict)
if (
char_dict["char"] in [" ", ",", ";", ".", ":"]
or (
len(chars_list_reconstructed) > 2
and (chars_list_reconstructed[-1]["char_xmin"] < chars_list_reconstructed[-2]["char_xmin"])
)
or len(chars_list_reconstructed) == len(trial["chars_list"])
):
word_xmin = chars_list_reconstructed[word_start_idx]["char_xmin"]
word_xmax = chars_list_reconstructed[-2]["char_xmax"]
word_ymin = chars_list_reconstructed[word_start_idx]["char_ymin"]
word_ymax = chars_list_reconstructed[word_start_idx]["char_ymax"]
word_x_center = (word_xmax - word_xmin) / 2 + word_xmin
word_y_center = (word_ymax - word_ymin) / 2 + word_ymin
word = "".join(
[
chars_list_reconstructed[idx]["char"]
for idx in range(word_start_idx, len(chars_list_reconstructed) - 1)
]
)
assigned_line = chars_list_reconstructed[word_start_idx]["assigned_line"]
word_dict = dict(
word=word,
word_xmin=word_xmin,
word_xmax=word_xmax,
word_ymin=word_ymin,
word_ymax=word_ymax,
word_x_center=word_x_center,
word_y_center=word_y_center,
assigned_line=assigned_line,
)
if char_dict["char"] != " ":
word_start_idx = idx
else:
word_start_idx = idx + 1
words_list.append(word_dict)
else:
last_letter_in_word = word_dict["word"][-1]
last_letter_in_chars_list_reconstructed = char_dict["char"]
if last_letter_in_word != last_letter_in_chars_list_reconstructed:
word_dict = dict(
word=char_dict["char"],
word_xmin=char_dict["char_xmin"],
word_xmax=char_dict["char_xmax"],
word_ymin=char_dict["char_ymin"],
word_ymax=char_dict["char_ymax"],
word_x_center=char_dict["char_x_center"],
word_y_center=char_dict["char_y_center"],
assigned_line=assigned_line,
)
words_list.append(word_dict)
if close_gap_between_words:
for widx in range(1, len(words_list)):
if words_list[widx]["assigned_line"] == words_list[widx - 1]["assigned_line"]:
word_sep_half_width = (words_list[widx]["word_xmin"] - words_list[widx - 1]["word_xmax"]) / 2
words_list[widx - 1]["word_xmax"] = words_list[widx - 1]["word_xmax"] + word_sep_half_width
words_list[widx]["word_xmin"] = words_list[widx]["word_xmin"] - word_sep_half_width
return words_list
def asc_lines_to_trials_by_trial_id(
    lines: list, paragraph_trials_only=False, fname: str = "", close_gap_between_words=True
) -> dict:
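    """Scan raw .asc lines for TRIALID/TRIAL_RESULT markers, metadata (display
    coords, framerate) and REGION CHAR definitions, and return a dict keyed by
    trial index with stimulus geometry (chars_list, words_list, line coordinates)
    attached to each paragraph trial."""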
if hasattr(fname, "name"):
fname = fname.name
fps = -999
display_coords = -999
trials_dict = dict(paragraph_trials=[], paragraph_trial_IDs=[])
trial_idx = -1
removed_trial_ids = []
for idx, l in enumerate(lines):
parts = l.strip().split(" ")
if "TRIALID" in l:
trial_id = parts[-1]
trial_idx += 1
if trial_id[0] == "F":
trial_is = "question"
elif trial_id[0] == "P":
trial_is = "practice"
else:
trial_is = "paragraph"
trials_dict["paragraph_trials"].append(trial_idx)
trials_dict["paragraph_trial_IDs"].append(trial_id)
trials_dict[trial_idx] = dict(trial_id=trial_id, trial_id_idx=idx, trial_is=trial_is, filename=fname)
last_trial_skipped = False
elif "TRIAL_RESULT" in l or "stop_trial" in l:
trials_dict[trial_idx]["trial_result_idx"] = idx
trials_dict[trial_idx]["trial_result_timestamp"] = int(parts[0].split("\t")[1])
if len(parts) > 2:
trials_dict[trial_idx]["trial_result_number"] = int(parts[2])
elif "DISPLAY COORDS" in l and isinstance(display_coords, int):
display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
elif "GAZE_COORDS" in l and isinstance(display_coords, int):
display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
elif "FRAMERATE" in l:
l_idx = parts.index(metadata_strs[2])
fps = float(parts[l_idx + 1])
elif "TRIAL ABORTED" in l or "TRIAL REPEATED" in l:
if not last_trial_skipped:
if trial_is == "paragraph":
trials_dict["paragraph_trials"].remove(trial_idx)
trial_idx -= 1
removed_trial_ids.append(trial_id)
last_trial_skipped = True
if paragraph_trials_only:
trials_dict_temp = trials_dict.copy()
for k in trials_dict_temp.keys():
if k not in ["paragraph_trials"] + trials_dict_temp["paragraph_trials"]:
trials_dict.pop(k)
if len(trials_dict_temp["paragraph_trials"]):
trial_idx = trials_dict_temp["paragraph_trials"][-1]
else:
return trials_dict
trials_dict["display_coords"] = display_coords
trials_dict["fps"] = fps
trials_dict["max_trial_idx"] = trial_idx
enum = trials_dict["paragraph_trials"] if "paragraph_trials" in trials_dict.keys() else range(len(trials_dict))
for trial_idx in enum:
if trial_idx not in trials_dict.keys():
continue
chars_list = []
if "display_coords" not in trials_dict[trial_idx].keys():
trials_dict[trial_idx]["display_coords"] = trials_dict["display_coords"]
trial_start_idx = trials_dict[trial_idx]["trial_id_idx"]
trial_end_idx = trials_dict[trial_idx]["trial_result_idx"]
trial_lines = lines[trial_start_idx:trial_end_idx]
for idx, l in enumerate(trial_lines):
parts = l.strip().split(" ")
if "START" in l and " MSG" not in l:
trials_dict[trial_idx]["start_idx"] = trial_start_idx + idx + 7
trials_dict[trial_idx]["start_time"] = int(parts[0].split("\t")[1])
elif "END" in l and "ENDBUTTON" not in l and " MSG" not in l:
trials_dict[trial_idx]["end_idx"] = trial_start_idx + idx - 2
trials_dict[trial_idx]["end_time"] = int(parts[0].split("\t")[1])
elif "SYNCTIME" in l:
trials_dict[trial_idx]["synctime"] = trial_start_idx + idx
trials_dict[trial_idx]["synctime_time"] = int(parts[0].split("\t")[1])
elif "GAZE TARGET OFF" in l:
trials_dict[trial_idx]["gaze_targ_off_time"] = int(parts[0].split("\t")[1])
elif "GAZE TARGET ON" in l:
trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
elif "DISPLAY_SENTENCE" in l: # some .asc files seem to use this
trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
elif "REGION CHAR" in l:
rg_idx = parts.index("CHAR")
if len(parts[rg_idx:]) > 8:
char = " "
idx_correction = 1
elif len(parts[rg_idx:]) == 3:
char = " "
if "REGION CHAR" not in trial_lines[idx + 1]:
parts = trial_lines[idx + 1].strip().split(" ")
idx_correction = -rg_idx - 4
else:
char = parts[rg_idx + 3]
idx_correction = 0
try:
char_dict = {
"char": char,
"char_xmin": float(parts[rg_idx + 4 + idx_correction]),
"char_ymin": float(parts[rg_idx + 5 + idx_correction]),
"char_xmax": float(parts[rg_idx + 6 + idx_correction]),
"char_ymax": float(parts[rg_idx + 7 + idx_correction]),
}
char_dict["char_y_center"] = (char_dict["char_ymax"] - char_dict["char_ymin"]) / 2 + char_dict[
"char_ymin"
]
char_dict["char_x_center"] = (char_dict["char_xmax"] - char_dict["char_xmin"]) / 2 + char_dict[
"char_xmin"
]
chars_list.append(char_dict)
except Exception as e:
if "logger" in st.session_state:
st.session_state["logger"].warning(f"char_dict creation failed for parts {parts}")
if "logger" in st.session_state:
st.session_state["logger"].warning(e)
if "gaze_targ_on_time" in trials_dict[trial_idx]:
trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["gaze_targ_on_time"]
else:
trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["start_time"]
if len(chars_list) > 0:
line_ycoords = []
for idx in range(len(chars_list)):
chars_list[idx]["char_line_y"] = (
chars_list[idx]["char_ymax"] - chars_list[idx]["char_ymin"]
) / 2 + chars_list[idx]["char_ymin"]
if chars_list[idx]["char_line_y"] not in line_ycoords:
line_ycoords.append(chars_list[idx]["char_line_y"])
for idx in range(len(chars_list)):
chars_list[idx]["assigned_line"] = line_ycoords.index(chars_list[idx]["char_line_y"])
line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
line_xcoords_all = [x["char_x_center"] for x in chars_list]
line_xcoords_no_pad = np.unique(line_xcoords_all)
line_ycoords_all = [x["char_y_center"] for x in chars_list]
line_ycoords_no_pad = np.unique(line_ycoords_all)
trials_dict[trial_idx]["x_char_unique"] = list(line_xcoords_no_pad)
trials_dict[trial_idx]["y_char_unique"] = list(line_ycoords_no_pad)
x_diff, y_diff = calc_xdiff_ydiff(
line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
)
trials_dict[trial_idx]["x_diff"] = float(x_diff)
trials_dict[trial_idx]["y_diff"] = float(y_diff)
trials_dict[trial_idx]["num_char_lines"] = len(line_ycoords_no_pad)
trials_dict[trial_idx]["line_heights"] = line_heights
trials_dict[trial_idx]["chars_list"] = chars_list
words_list = add_words(trials_dict[trial_idx], close_gap_between_words=close_gap_between_words)
trials_dict[trial_idx]["words_list"] = words_list
return trials_dict
def file_to_trials_and_lines(uploaded_file, asc_encoding: str = "ISO-8859-15", close_gap_between_words=True):
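    """Read an .asc file (filesystem path or Streamlit UploadedFile) and delegate to
    asc_lines_to_trials_by_trial_id; also assembles the per-trial text from each
    trial's chars_list."""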
if isinstance(uploaded_file, str) or isinstance(uploaded_file, pl.Path):
with open(uploaded_file, "r", encoding=asc_encoding) as f:
lines = f.readlines()
else:
stringio = StringIO(uploaded_file.getvalue().decode(asc_encoding))
loaded_str = stringio.read()
lines = loaded_str.split("\n")
    trials_dict = asc_lines_to_trials_by_trial_id(
        lines, True, uploaded_file, close_gap_between_words=close_gap_between_words
    )
if "paragraph_trials" not in trials_dict.keys() and "trial_is" in trials_dict[0].keys():
paragraph_trials = []
for k in range(trials_dict["max_trial_idx"]):
if trials_dict[k]["trial_is"] == "paragraph":
paragraph_trials.append(k)
trials_dict["paragraph_trials"] = paragraph_trials
enum = (
trials_dict["paragraph_trials"]
if "paragraph_trials" in trials_dict.keys()
else range(trials_dict["max_trial_idx"])
)
for k in enum:
if "chars_list" in trials_dict[k].keys():
max_line = trials_dict[k]["chars_list"][-1]["assigned_line"]
words_on_lines = {x: [] for x in range(max_line + 1)}
            for x in trials_dict[k]["chars_list"]:
                words_on_lines[x["assigned_line"]].append(x["char"])
            sentence_list = ["".join(v) for v in words_on_lines.values()]
            text = "\n".join(sentence_list)
trials_dict[k]["sentence_list"] = sentence_list
trials_dict[k]["text"] = text
trials_dict[k]["max_line"] = max_line
return trials_dict, lines
def get_plot_props(trial, available_fonts):
if "font" in trial.keys():
font = trial["font"]
font_size = trial["font_size"]
if font not in available_fonts:
font = "DejaVu Sans Mono"
else:
font = "DejaVu Sans Mono"
font_size = 21
dpi = 100
if "display_coords" in trial.keys():
screen_res = (trial["display_coords"][2], trial["display_coords"][3])
else:
screen_res = (1920, 1080)
return font, font_size, dpi, screen_res
def trial_to_dfs(
trial: dict, lines: list, use_synctime: bool = False, save_lines_to_txt=False, cut_out_outer_fixations=False
):
"""trial should be dict of line numbers of trials.
lines should be list of lines from .asc file."""
if use_synctime and "synctime" in trial:
idx0, idxend = trial["synctime"] + 1, trial["trial_result_idx"]
else:
idx0, idxend = trial["start_idx"], trial["end_idx"]
line_dicts = []
fixations_dicts = []
blink_started = False
fixation_started = False
efix_count = 0
sfix_count = 0
sblink_count = 0
if save_lines_to_txt:
with open("Lines_plus500.txt", "w") as f:
f.writelines(lines[idx0 - 500 : idxend + 500])
eye_to_use = "R"
for l in lines[idx0 : idxend + 1]:
if "EFIX R" in l:
eye_to_use = "R"
break
elif "EFIX L" in l:
eye_to_use = "L"
break
for l in lines[idx0 : idxend + 1]:
parts = [x.strip() for x in l.split("\t")]
if f"EFIX {eye_to_use}" in l:
efix_count += 1
if fixation_started:
if parts[1] == "." and parts[2] == ".":
continue
fixations_dicts.append(
{
"start_time": float(parts[0].split()[-1].strip()),
"end_time": float(parts[1].strip()),
"duration": float(parts[2].strip()),
"x": float(parts[3].strip()),
"y": float(parts[4].strip()),
"pupil_size": float(parts[5].strip()),
}
)
if len(fixations_dicts) >= 2:
assert (
fixations_dicts[-1]["start_time"] > fixations_dicts[-2]["start_time"]
), "start times not in order"
fixation_started = False
elif f"SFIX {eye_to_use}" in l:
sfix_count += 1
fixation_started = True
elif f"SBLINK {eye_to_use}" in l:
sblink_count += 1
blink_started = True
        if not blink_started and not any(x in l for x in event_strs):
if len(parts) < 3 or (parts[1] == "." and parts[2] == "."):
continue
line_dicts.append(
{
"idx": float(parts[0].strip()),
"x": float(parts[1].strip()),
"y": float(parts[2].strip()),
"p": float(parts[3].strip()),
}
)
elif f"EBLINK {eye_to_use}" in l:
blink_started = False
df = pd.DataFrame(line_dicts)
dffix = pd.DataFrame(fixations_dicts)
if len(fixations_dicts) > 0:
dffix["corrected_start_time"] = dffix.start_time - trial["trial_start_time"]
dffix["corrected_end_time"] = dffix.end_time - trial["trial_start_time"]
dffix["fix_duration"] = dffix.corrected_end_time.values - dffix.corrected_start_time.values
assert all(np.diff(dffix["corrected_start_time"]) > 0), "start times not in order"
    else:
        return df, pd.DataFrame(), trial
if cut_out_outer_fixations:
dffix = dffix[(dffix.x > -10) & (dffix.y > -10) & (dffix.x < 1050) & (dffix.y < 800)]
trial["efix_count"] = efix_count
trial["eye_to_use"] = eye_to_use
trial["sfix_count"] = sfix_count
trial["sblink_count"] = sblink_count
return df, dffix, trial
def get_save_path(fpath, fname_ending):
save_path = gradio_plots.joinpath(f"{fpath.stem}_{fname_ending}.png")
return save_path
def save_im_load_convert(fpath, fig, fname_ending, mode):
save_path = get_save_path(fpath, fname_ending)
fig.savefig(save_path)
im = Image.open(save_path).convert(mode)
im.save(save_path)
return im
def get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, dffix=None, prefix="word"):
fig = plt.figure(figsize=(screen_res[0] / dpi, screen_res[1] / dpi), dpi=dpi)
ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
ax.set_axis_off()
if dffix is not None:
ax.set_ylim((dffix.y.min(), dffix.y.max()))
ax.set_xlim((dffix.x.min(), dffix.x.max()))
else:
ax.set_ylim((words_df[f"{prefix}_y_center"].min() - y_margin, words_df[f"{prefix}_y_center"].max() + y_margin))
ax.set_xlim((words_df[f"{prefix}_x_center"].min() - x_margin, words_df[f"{prefix}_x_center"].max() + x_margin))
ax.invert_yaxis()
fig.add_axes(ax)
return fig, ax
def plot_text_boxes_fixations(
fpath,
dpi,
screen_res,
data_dir_sub,
set_font_size: bool,
font_size: int,
use_words: bool,
save_channel_repeats: bool,
save_combo_grey_and_rgb: bool,
dffix=None,
trial=None,
):
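    """Render the stimulus text, the word/char boxes and the fixation scatter as
    separate grayscale images, then stack them into the 3-channel
    "*_channel_sep" image consumed by the DIST conv branch."""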
if isinstance(fpath, str):
fpath = pl.Path(fpath)
if use_words:
prefix = "word"
else:
prefix = "char"
if dffix is None:
dffix = pd.read_csv(fpath)
if trial is None:
json_fpath = str(fpath).replace("_fixations.csv", "_trial.json")
with open(json_fpath, "r") as f:
trial = json.load(f)
words_df = pd.DataFrame(trial[f"{prefix}s_list"])
    x_left = words_df[f"{prefix}_xmin"]
    x_right = words_df[f"{prefix}_xmax"]
y_top = words_df[f"{prefix}_ymax"]
y_bottom = words_df[f"{prefix}_ymin"]
if f"{prefix}_x_center" not in words_df.columns:
words_df[f"{prefix}_x_center"] = (words_df[f"{prefix}_xmax"] - words_df[f"{prefix}_xmin"]) / 2 + words_df[
f"{prefix}_xmin"
]
words_df[f"{prefix}_y_center"] = (words_df[f"{prefix}_ymax"] - words_df[f"{prefix}_ymin"]) / 2 + words_df[
f"{prefix}_ymin"
]
x_margin = words_df[f"{prefix}_x_center"].mean() / 8
y_margin = words_df[f"{prefix}_y_center"].mean() / 4
    # alpha ramp over fixation order (earlier fixations drawn more transparent)
    times = np.linspace(0.25, 1, len(dffix))
    font = "monospace"
    if not set_font_size:
        font_size = trial["font_size"] * 27 // dpi
font_props = FontProperties(family=font, style="normal", size=font_size)
if save_combo_grey_and_rgb:
fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
ax.scatter(dffix.x, dffix.y, alpha=times, facecolor="b")
for idx in range(len(x_left)):
xdiff = x_right[idx] - x_left[idx]
ydiff = y_top[idx] - y_bottom[idx]
rect = patches.Rectangle(
(x_left[idx] - 1, y_bottom[idx] - 1),
xdiff,
ydiff,
alpha=0.9,
linewidth=0.8,
edgecolor="r",
facecolor="none",
) # seems to need one pixel offset
ax.text(
words_df[f"{prefix}_x_center"][idx],
words_df[f"{prefix}_y_center"][idx],
words_df[prefix][idx],
horizontalalignment="center",
verticalalignment="center",
fontproperties=font_props,
color="g",
)
ax.add_patch(rect)
fname_ending = f"{prefix}s_combo_rgb"
words_combo_rgb_im = save_im_load_convert(fpath, fig, fname_ending, "RGB")
plt.close("all")
fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
for idx in range(len(x_left)):
xdiff = x_right[idx] - x_left[idx]
ydiff = y_top[idx] - y_bottom[idx]
rect = patches.Rectangle(
(x_left[idx] - 1, y_bottom[idx] - 1),
xdiff,
ydiff,
alpha=0.9,
linewidth=0.8,
edgecolor="k",
facecolor="none",
) # seems to need one pixel offset
ax.text(
words_df[f"{prefix}_x_center"][idx],
words_df[f"{prefix}_y_center"][idx],
words_df[prefix][idx],
horizontalalignment="center",
verticalalignment="center",
fontproperties=font_props,
)
ax.add_patch(rect)
fname_ending = f"{prefix}s_combo_grey"
words_combo_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
plt.close("all")
fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.01)
for idx in range(len(x_left)):
ax.text(
words_df[f"{prefix}_x_center"][idx],
words_df[f"{prefix}_y_center"][idx],
words_df[prefix][idx],
horizontalalignment="center",
verticalalignment="center",
fontproperties=font_props,
)
fname_ending = f"{prefix}s_grey"
words_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
plt.close("all")
fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.1)
for idx in range(len(x_left)):
xdiff = x_right[idx] - x_left[idx]
ydiff = y_top[idx] - y_bottom[idx]
rect = patches.Rectangle(
(x_left[idx] - 1, y_bottom[idx] - 1), xdiff, ydiff, alpha=0.9, linewidth=1, edgecolor="k", facecolor="grey"
) # seems to need one pixel offset
ax.add_patch(rect)
fname_ending = f"{prefix}_boxes_grey"
word_boxes_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
plt.close("all")
fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
fname_ending = "fix_scatter_grey"
fix_scatter_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
plt.close("all")
arr_combo = np.stack(
[
np.asarray(words_grey_im),
np.asarray(word_boxes_grey_im),
np.asarray(fix_scatter_grey_im),
],
axis=2,
)
im_combo = Image.fromarray(arr_combo)
fname_ending = f"{prefix}s_channel_sep"
save_path = get_save_path(fpath, fname_ending)
print(f"save_path for im combo is {save_path}")
im_combo.save(fpath)
if save_channel_repeats:
arr_combo = np.stack([np.asarray(words_grey_im)] * 3, axis=2)
im_combo = Image.fromarray(arr_combo)
fname_ending = f"{prefix}s_channel_repeat"
save_path = get_save_path(fpath, fname_ending)
im_combo.save(save_path)
arr_combo = np.stack([np.asarray(word_boxes_grey_im)] * 3, axis=2)
im_combo = Image.fromarray(arr_combo)
fname_ending = f"{prefix}boxes_channel_repeat"
save_path = get_save_path(fpath, fname_ending)
im_combo.save(save_path)
arr_combo = np.stack([np.asarray(fix_scatter_grey_im)] * 3, axis=2)
im_combo = Image.fromarray(arr_combo)
fname_ending = "fix_channel_repeat"
save_path = get_save_path(fpath, fname_ending)
im_combo.save(save_path)
def add_line_overlaps_to_sample(trial, sample):
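    """Append a feature column marking, for each fixation, the index of the text
    line whose vertical band contains the fixation's y value (-1 if none); assumes
    the unique char_ymin/char_ymax values pair up per text line."""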
char_df = pd.DataFrame(trial["chars_list"])
line_overlaps = []
for arr in sample:
y_val = arr[1]
line_overlap = t.tensor(-1, dtype=t.float32)
for idx, (x1, x2) in enumerate(zip(char_df.char_ymin.unique(), char_df.char_ymax.unique())):
if x1 <= y_val <= x2:
line_overlap = t.tensor(idx, dtype=t.float32)
break
line_overlaps.append(line_overlap)
line_olaps_tensor = t.stack(line_overlaps, dim=0)
sample = t.cat([sample, line_olaps_tensor.unsqueeze(1)], dim=1)
return sample
def norm_coords_by_letter_min_x_y(
sample_idx: int,
trialslist: list,
samplelist: list,
chars_center_coords_list: list = None,
):
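    """Shift fixation coordinates (and, if given, character center coordinates) so
    that the minimum character x/y becomes the origin; mutates the lists in place."""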
chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()
min_x_chars = chars_df.char_xmin.min()
min_y_chars = chars_df.char_ymin.min()
    norm_vector_subtract = t.zeros(
        (1, samplelist[sample_idx].shape[1]), dtype=samplelist[sample_idx].dtype, device=samplelist[sample_idx].device
    )
    norm_vector_subtract[0, 0] = min_x_chars
    norm_vector_subtract[0, 1] = min_y_chars
    samplelist[sample_idx] = samplelist[sample_idx] - norm_vector_subtract
    if chars_center_coords_list is not None:
        norm_vector_subtract = norm_vector_subtract.squeeze(0)[:2]
        if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_subtract.shape[-1] * 2:
            chars_center_coords_list[sample_idx][:, :2] -= norm_vector_subtract
            chars_center_coords_list[sample_idx][:, 2:] -= norm_vector_subtract
        else:
            chars_center_coords_list[sample_idx] -= norm_vector_subtract
return trialslist, samplelist, chars_center_coords_list
def norm_coords_by_letter_positions(
sample_idx: int,
trialslist: list,
samplelist: list,
meanlist: list = None,
stdlist: list = None,
return_mean_std_lists=False,
norm_by_char_averages=False,
chars_center_coords_list: list = None,
add_normalised_values_as_features=False,
):
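    """Scale fixation coordinates (and, if given, character center coordinates) by
    line width and height, or by average character width/height when
    norm_by_char_averages is set. Optionally appends the normalised values as extra
    features and collects per-sample mean/std for later standardisation."""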
chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()
min_x_chars = chars_df.char_xmin.min()
max_x_chars = chars_df.char_xmax.max()
norm_vector_multi = t.ones(
(1, samplelist[sample_idx].shape[1]), dtype=samplelist[sample_idx].dtype, device=samplelist[sample_idx].device
)
if norm_by_char_averages:
chars_list = trialslist[sample_idx]["chars_list"]
char_widths = np.asarray([x["char_xmax"] - x["char_xmin"] for x in chars_list])
char_heights = np.asarray([x["char_ymax"] - x["char_ymin"] for x in chars_list])
char_widths_average = np.mean(char_widths[char_widths > 0])
char_heights_average = np.mean(char_heights[char_heights > 0])
norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * char_widths_average
norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * char_heights_average
else:
line_height = min(np.unique(trialslist[sample_idx]["line_heights"]))
line_width = max_x_chars - min_x_chars
norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * line_width
norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * line_height
assert ~t.any(t.isnan(norm_vector_multi)), "Nan found in char norming vector"
norm_vector_multi = norm_vector_multi.squeeze(0)
if add_normalised_values_as_features:
norm_vector_multi = norm_vector_multi[norm_vector_multi != 1]
normed_features = samplelist[sample_idx][:, : norm_vector_multi.shape[0]] / norm_vector_multi
samplelist[sample_idx] = t.cat([samplelist[sample_idx], normed_features], dim=1)
else:
samplelist[sample_idx] = samplelist[sample_idx] / norm_vector_multi # in case time or pupil size is included
if chars_center_coords_list is not None:
norm_vector_multi = norm_vector_multi[:2]
if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_multi.shape[-1] * 2:
chars_center_coords_list[sample_idx][:, :2] /= norm_vector_multi
chars_center_coords_list[sample_idx][:, 2:] /= norm_vector_multi
else:
chars_center_coords_list[sample_idx] /= norm_vector_multi
if return_mean_std_lists:
mean_val = samplelist[sample_idx].mean(axis=0).cpu().numpy()
meanlist.append(mean_val)
std_val = samplelist[sample_idx].std(axis=0).cpu().numpy()
stdlist.append(std_val)
        assert not np.any(np.isnan(mean_val)), "Nan found in mean_val"
        assert not np.any(np.isnan(std_val)), "Nan found in std_val"
return trialslist, samplelist, meanlist, stdlist, chars_center_coords_list
return trialslist, samplelist, chars_center_coords_list
def remove_compile_from_model(model):
if hasattr(model.project, "_orig_mod"):
model.project = model.project._orig_mod
model.chars_conv = model.chars_conv._orig_mod
model.chars_classifier = model.chars_classifier._orig_mod
model.layer_norm_in = model.layer_norm_in._orig_mod
model.bert_model = model.bert_model._orig_mod
model.linear = model.linear._orig_mod
else:
print(f"remove_compile_from_model not done since model.project {model.project} has no orig_mod")
return model
def remove_compile_from_dict(state_dict):
for key in list(state_dict.keys()):
newkey = key.replace("._orig_mod.", ".")
state_dict[newkey] = state_dict.pop(key)
return state_dict
def add_text_to_ax(
chars_list,
ax,
font_to_use="DejaVu Sans Mono",
fontsize=21,
prefix="char",
plot_boxes=True,
plot_text=True,
box_annotations=None,
):
font_props = FontProperties(family=font_to_use, style="normal", size=fontsize)
if not plot_boxes and not plot_text:
return None
if box_annotations is None:
enum = chars_list
else:
enum = zip(chars_list, box_annotations)
for v in enum:
if box_annotations is not None:
v, annot_text = v
x0, y0 = v[f"{prefix}_xmin"], v[f"{prefix}_ymin"]
xdiff, ydiff = v[f"{prefix}_xmax"] - v[f"{prefix}_xmin"], v[f"{prefix}_ymax"] - v[f"{prefix}_ymin"]
if plot_text:
ax.text(
v[f"{prefix}_x_center"],
v[f"{prefix}_y_center"],
v[prefix],
horizontalalignment="center",
verticalalignment="center",
fontproperties=font_props,
)
if plot_boxes:
ax.add_patch(Rectangle((x0, y0), xdiff, ydiff, edgecolor="grey", facecolor="none", lw=0.8, alpha=0.4))
if box_annotations is not None:
ax.annotate(
str(annot_text),
(x0 + xdiff / 2, y0),
horizontalalignment="center",
verticalalignment="center",
fontproperties=FontProperties(family=font_to_use, style="normal", size=fontsize / 1.5),
)
def plot_fixations_and_text(
dffix: pd.DataFrame,
trial: dict,
plot_prefix="chars_",
show=False,
returnfig=False,
save=False,
savelocation="plot.png",
font_to_use="DejaVu Sans Mono",
fontsize=20,
plot_classic=True,
plot_boxes=True,
plot_text=True,
fig_size=(14, 8),
dpi=300,
turn_axis_on=True,
algo_choice="slice",
):
fig, ax = plt.subplots(1, 1, figsize=fig_size, tight_layout=True, dpi=dpi)
if f"{plot_prefix}list" in trial.keys():
add_text_to_ax(
trial[f"{plot_prefix}list"],
ax,
font_to_use,
fontsize=fontsize,
prefix=plot_prefix[:-2],
plot_boxes=plot_boxes,
plot_text=plot_text,
)
ax.plot(dffix.x, dffix.y, "kX", label="Raw Fixations", alpha=0.9)
if plot_classic and f"line_num_{algo_choice}" in dffix.columns:
ax.scatter(
dffix.x,
dffix[f"y_{algo_choice}"],
marker="*",
color="tab:green",
label=f"{algo_choice} Prediction",
alpha=0.9,
)
for x_before, y_before, x_after, y_after in zip(
dffix.x.values, dffix[f"y_{algo_choice}"].values, dffix.x, dffix.y
):
arr_delta_x = x_after - x_before
arr_delta_y = y_after - y_before
ax.arrow(x_before, y_before, arr_delta_x, arr_delta_y, color="tab:green", alpha=0.6)
ax.set_ylabel("y (pixel)")
ax.set_xlabel("x (pixel)")
ax.invert_yaxis()
ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
if not turn_axis_on:
ax.axis("off")
if save:
plt.savefig(savelocation, dpi=dpi)
if show:
plt.show()
if returnfig:
return fig
else:
plt.close()
return None
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
gradio_temp_folder.mkdir(exist_ok=True)
gradio_temp_unzipped_folder.mkdir(exist_ok=True)
gradio_plots.mkdir(exist_ok=True)
return 0
def get_classic_cfg(fname):
with open(fname, "r") as f:
jsonsstring = f.read()
classic_algos_cfg = json.loads(jsonsstring)
classic_algos_cfg["slice"] = classic_algos_cfg["slice"]
classic_algos_cfg = classic_algos_cfg
return classic_algos_cfg
def find_and_load_model(model_date="20240104-223349"):
model_cfg_file = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.yaml"))
if len(model_cfg_file) == 0:
if "logger" in st.session_state:
st.session_state["logger"].warning(f"No model cfg yaml found for {model_date}")
return None, None
model_cfg_file = model_cfg_file[0]
with open(model_cfg_file) as f:
model_cfg = yaml.safe_load(f)
model_cfg["system_type"] = "linux"
    model_file = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.ckpt"))[0]
model = load_model(model_file, model_cfg)
return model, model_cfg
def load_model(model_file, cfg):
try:
model_loaded = t.load(model_file, map_location="cpu")
if "hyper_parameters" in model_loaded.keys():
model_cfg_temp = model_loaded["hyper_parameters"]["cfg"]
else:
model_cfg_temp = cfg
model_state_dict = model_loaded["state_dict"]
except Exception as e:
if "logger" in st.session_state:
st.session_state["logger"].warning(e)
if "logger" in st.session_state:
st.session_state["logger"].warning(f"Failed to load {model_file}")
return None
model = LitModel(
[1, 500, 3],
model_cfg_temp["hidden_dim_bert"],
model_cfg_temp["num_attention_heads"],
model_cfg_temp["n_layers_BERT"],
model_cfg_temp["loss_function"],
1e-4,
model_cfg_temp["weight_decay"],
model_cfg_temp,
model_cfg_temp["use_lr_warmup"],
model_cfg_temp["use_reduce_on_plateau"],
track_gradient_histogram=model_cfg_temp["track_gradient_histogram"],
register_forw_hook=model_cfg_temp["track_activations_via_hook"],
char_dims=model_cfg_temp["char_dims"],
)
model = remove_compile_from_model(model)
model_state_dict = remove_compile_from_dict(model_state_dict)
with t.no_grad():
model.load_state_dict(model_state_dict, strict=False)
model.eval()
model.freeze()
return model
def set_up_models(dist_models_folder):
out_dict = {}
if "logger" in st.session_state:
st.session_state["logger"].info("Loading Ensemble")
dist_models_with_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_True*.ckpt"))
dist_models_without_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_False*.ckpt"))
DIST_MODEL_DATE_WITH_NORM = dist_models_with_norm[0].stem.split("_")[1]
models_without_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm]
models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]
model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]
models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]
ensemble_model_avg = EnsembleModel(
models_without_norm_df, models_with_norm_df, learning_rate=0.0058, use_simple_average=True
)
out_dict["ensemble_model_avg"] = ensemble_model_avg
out_dict["model_cfg_without_norm_df"] = model_cfg_without_norm_df
out_dict["model_cfg_with_norm_df"] = model_cfg_with_norm_df
single_DIST_model, single_DIST_model_cfg = find_and_load_model(model_date=DIST_MODEL_DATE_WITH_NORM)
out_dict["DIST_MODEL_DATE_WITH_NORM"] = DIST_MODEL_DATE_WITH_NORM
out_dict["single_DIST_model"] = single_DIST_model
out_dict["single_DIST_model_cfg"] = single_DIST_model_cfg
return out_dict
def prep_data_for_dist(model_cfg, dffix, trial=None):
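    """Build a (DataLoader, DSet) pair for DIST inference from a fixation dataframe:
    assembles the feature tensor, applies the normalisations requested by model_cfg,
    standardises with the cfg's stored mean/std and renders the stimulus image for
    the conv branch."""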
if "logger" in st.session_state:
st.session_state["logger"].debug("prep_data_for_dist entered")
if trial is None:
trial = st.session_state["trial"]
if isinstance(dffix, dict):
dffix = dffix["value"]
sample_tensor = t.tensor(dffix.loc[:, model_cfg["sample_cols"]].to_numpy(), dtype=t.float32)
if model_cfg["add_line_overlap_feature"]:
sample_tensor = add_line_overlaps_to_sample(trial, sample_tensor)
has_nans = t.any(t.isnan(sample_tensor))
assert not has_nans, "NaNs found in sample tensor"
samplelist_eval = [sample_tensor]
trialslist_eval = [trial]
chars_center_coords_list_eval = None
if model_cfg["norm_coords_by_letter_min_x_y"]:
for sample_idx, _ in enumerate(samplelist_eval):
trialslist_eval, samplelist_eval, chars_center_coords_list_eval = norm_coords_by_letter_min_x_y(
sample_idx,
trialslist_eval,
samplelist_eval,
chars_center_coords_list=chars_center_coords_list_eval,
)
if model_cfg["normalize_by_line_height_and_width"]:
meanlist_eval, stdlist_eval = [], []
for sample_idx, _ in enumerate(samplelist_eval):
(
trialslist_eval,
samplelist_eval,
meanlist_eval,
stdlist_eval,
chars_center_coords_list_eval,
) = norm_coords_by_letter_positions(
sample_idx,
trialslist_eval,
samplelist_eval,
meanlist_eval,
stdlist_eval,
return_mean_std_lists=True,
norm_by_char_averages=model_cfg["norm_by_char_averages"],
chars_center_coords_list=chars_center_coords_list_eval,
add_normalised_values_as_features=model_cfg["add_normalised_values_as_features"],
)
sample_tensor = samplelist_eval[0]
sample_means = t.tensor(model_cfg["sample_means"], dtype=t.float32)
sample_std = t.tensor(model_cfg["sample_std"], dtype=t.float32)
sample_tensor = (sample_tensor - sample_means) / sample_std
sample_tensor = sample_tensor.unsqueeze(0)
if "logger" in st.session_state:
st.session_state["logger"].info(f"Using path {trial['plot_file']} for plotting")
plot_text_boxes_fixations(
fpath=trial["plot_file"],
dpi=250,
screen_res=(1024, 768),
data_dir_sub=None,
set_font_size=True,
font_size=4,
use_words=False,
save_channel_repeats=False,
save_combo_grey_and_rgb=False,
dffix=dffix,
trial=trial,
)
val_set = DSet(
sample_tensor,
None,
t.zeros((1, sample_tensor.shape[1])),
trialslist_eval,
padding_list=[0],
padding_at_end=model_cfg["padding_at_end"],
return_images_for_conv=True,
im_partial_string=model_cfg["im_partial_string"],
input_im_shape=model_cfg["char_plot_shape"],
)
val_loader = dl(val_set, batch_size=1, shuffle=False, num_workers=0)
return val_loader, val_set
def fold_in_seq_dim(out, y=None):
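    """Flatten batch and sequence dimensions of `out` (and `y`, if given) so the
    ordinal logits can be decoded per fixation."""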
batch_size, seq_len, num_classes = out.shape
out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
if y is None:
return out, None
if len(y.shape) > 2:
y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len)
else:
y = eo.rearrange(y, "b s -> (b s)", s=seq_len)
return out, y
def logits_to_pred(out, y=None):
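    """Decode CORN-style ordinal logits into per-fixation line predictions,
    restoring the batch/sequence shape."""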
seq_len = out.shape[1]
out, y = fold_in_seq_dim(out, y)
preds = corn_label_from_logits(out)
preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len)
    if y is not None:
        y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len)
    return preds, y
def get_DIST_preds(dffix, trial, models_dict=None):
algo_choice = "DIST"
if models_dict is None:
if st.session_state["single_DIST_model"] is None or st.session_state["single_DIST_model_cfg"] is None:
st.session_state["single_DIST_model"], st.session_state["single_DIST_model_cfg"] = find_and_load_model(
model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"]
)
if "logger" in st.session_state:
st.session_state["logger"].info("Model is None, reiniting model")
else:
model = st.session_state["single_DIST_model"]
loader, dset = prep_data_for_dist(st.session_state["single_DIST_model_cfg"], dffix, trial)
else:
model = models_dict["single_DIST_model"]
loader, dset = prep_data_for_dist(models_dict["single_DIST_model_cfg"], dffix, trial)
batch = next(iter(loader))
if "cpu" not in str(model.device):
batch = [x.cuda() for x in batch]
try:
out = model(batch)
preds, y = logits_to_pred(out, y=None)
if "logger" in st.session_state:
st.session_state["logger"].debug(
f"y_char_unique are {trial['y_char_unique']} for trial {trial['trial_id']}"
)
if "logger" in st.session_state:
st.session_state["logger"].debug(f"trial keys are {trial.keys()} for trial {trial['trial_id']}")
if "logger" in st.session_state:
st.session_state["logger"].debug(
f"chars_list has len {len(trial['chars_list'])} for trial {trial['trial_id']}"
)
if "logger" in st.session_state:
st.session_state["logger"].debug(f"y_char_unique {trial['y_char_unique']} for trial {trial['trial_id']}")
if len(trial["y_char_unique"]) < 1:
y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
else:
y_char_unique = trial["y_char_unique"]
num_lines = trial["num_char_lines"] - 1
preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
y_pred_DIST = [y_char_unique[idx] for idx in preds]
dffix[f"line_num_{algo_choice}"] = preds
dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
except Exception as e:
if "logger" in st.session_state:
st.session_state["logger"].warning(f"Exception on model(batch) for DIST \n{e}")
return dffix
def get_DIST_ensemble_preds(
dffix,
trial,
model_cfg_without_norm_df,
model_cfg_with_norm_df,
ensemble_model_avg,
):
algo_choice = "DIST-Ensemble"
loader_without_norm, dset_without_norm = prep_data_for_dist(model_cfg_without_norm_df, dffix, trial)
loader_with_norm, dset_with_norm = prep_data_for_dist(model_cfg_with_norm_df, dffix, trial)
batch_without_norm = next(iter(loader_without_norm))
batch_with_norm = next(iter(loader_with_norm))
out = ensemble_model_avg((batch_without_norm, batch_with_norm))
preds, y = logits_to_pred(out[0]["out_avg"], y=None)
if len(trial["y_char_unique"]) < 1:
y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
else:
y_char_unique = trial["y_char_unique"]
num_lines = trial["num_char_lines"] - 1
preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
if "logger" in st.session_state:
st.session_state["logger"].debug(f"preds are {preds} for trial {trial['trial_id']}")
y_pred_DIST = [y_char_unique[idx] for idx in preds]
dffix[f"line_num_{algo_choice}"] = preds
dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
return dffix
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None, models_dict=None):
if models_dict is None:
if ensemble_model_avg is None and "ensemble_model_avg" not in st.session_state:
if "logger" in st.session_state:
st.session_state["logger"].info("Ensemble Model is None, reiniting model")
dist_models_with_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_True*.ckpt")
dist_models_without_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_False*.ckpt")
models_without_norm_df = [
find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm
]
models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]
model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]
models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]
ensemble_model_avg = EnsembleModel(
models_without_norm_df, models_with_norm_df, learning_rate=0.0, use_simple_average=True
)
st.session_state["ensemble_model_avg"] = ensemble_model_avg
st.session_state["model_cfg_without_norm_df"] = model_cfg_without_norm_df
st.session_state["model_cfg_with_norm_df"] = model_cfg_with_norm_df
else:
model_cfg_without_norm_df = st.session_state["model_cfg_without_norm_df"]
model_cfg_with_norm_df = st.session_state["model_cfg_with_norm_df"]
ensemble_model_avg = st.session_state["ensemble_model_avg"]
dffix = get_DIST_ensemble_preds(
dffix,
trial,
st.session_state["model_cfg_without_norm_df"],
st.session_state["model_cfg_with_norm_df"],
st.session_state["ensemble_model_avg"],
)
else:
dffix = get_DIST_ensemble_preds(
dffix,
trial,
models_dict["model_cfg_without_norm_df"],
models_dict["model_cfg_with_norm_df"],
models_dict["ensemble_model_avg"],
)
return dffix
def correct_df(
dffix,
algo_choice,
trial=None,
for_multi=False,
ensemble_model_avg=None,
is_outside_of_streamlit=False,
classic_algos_cfg=None,
models_dict=None,
):
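    """Apply one or more correction algorithms to dffix, adding y_{algo} and
    y_{algo}_correction columns (plus line_num columns where the algorithm predicts
    line indices); per-algorithm wall and process times are logged.

    Hypothetical usage outside streamlit (argument values assumed):
        dffix = correct_df(dffix, ["warp", "slice"], trial=trial, for_multi=True,
                           classic_algos_cfg=cfg, is_outside_of_streamlit=True)
    """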
if is_outside_of_streamlit:
stqdm = tqdm
else:
from stqdm import stqdm
if classic_algos_cfg is None:
classic_algos_cfg = st.session_state["classic_algos_cfg"]
if trial is None and not for_multi:
trial = st.session_state["trial"]
if "logger" in st.session_state:
st.session_state["logger"].info(f"Applying {algo_choice} to fixations for trial {trial['trial_id']}")
if isinstance(dffix, dict):
dffix = dffix["value"]
if "x" not in dffix.keys() or "x" not in dffix.keys():
if "logger" in st.session_state:
st.session_state["logger"].warning(f"x or y not in dffix")
if "logger" in st.session_state:
st.session_state["logger"].warning(dffix.columns)
return dffix
if isinstance(algo_choice, list):
algo_choices = algo_choice
repeats = range(len(algo_choice))
else:
algo_choices = [algo_choice]
repeats = range(1)
for algoIdx in stqdm(repeats, desc="Applying correction algorithms"):
algo_choice = algo_choices[algoIdx]
st_proc = time.process_time()
st_wall = time.time()
if algo_choice == "DIST":
dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)
elif algo_choice == "DIST-Ensemble":
dffix = get_EDIST_preds_with_model_check(dffix, trial, models_dict=models_dict)
elif algo_choice == "Wisdom_of_Crowds_with_DIST":
dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)
for _ in range(3):
corrections.append(np.asarray(dffix.loc[:, "y_DIST"]))
dffix = apply_woc(dffix, trial, corrections, algo_choice)
elif algo_choice == "Wisdom_of_Crowds_with_DIST_Ensemble":
dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
dffix = get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg, models_dict=models_dict)
for _ in range(3):
corrections.append(np.asarray(dffix.loc[:, "y_DIST-Ensemble"]))
dffix = apply_woc(dffix, trial, corrections, algo_choice)
elif algo_choice == "Wisdom_of_Crowds":
dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
dffix = apply_woc(dffix, trial, corrections, algo_choice)
else:
algo_cfg = classic_algos_cfg[algo_choice]
dffix = calgo.apply_classic_algo(dffix, trial, algo_choice, algo_cfg)
dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
et_proc = time.process_time()
time_proc = et_proc - st_proc
et_wall = time.time()
time_wall = et_wall - st_wall
if "logger" in st.session_state:
st.session_state["logger"].info(f"time_proc {algo_choice} {time_proc}")
if "logger" in st.session_state:
st.session_state["logger"].info(f"time_wall {algo_choice} {time_wall}")
if for_multi:
return dffix
else:
if "start_time" in dffix.columns:
dffix = dffix.drop(axis=1, labels=["start_time", "end_time"])
return dffix, export_csv(dffix, trial)
def set_font_from_chars_list(trial):
if "chars_list" in trial:
chars_df = pd.DataFrame(trial["chars_list"])
line_diffs = np.diff(chars_df.char_y_center.unique())
y_diffs = np.unique(line_diffs)
if len(y_diffs) == 1:
y_diff = y_diffs[0]
else:
y_diff = np.min(y_diffs)
y_diff = round(y_diff * 2) / 2
else:
y_diff = 1 / 0.333 * 18
    font_size = y_diff * 0.333  # pixel to point conversion
    return round(font_size * 4) / 4  # round to the nearest quarter point
def get_font_and_font_size_from_trial(trial):
font_face, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS)
if font_size is None and "font_size" in trial:
font_size = trial["font_size"]
elif font_size is None:
font_size = set_font_from_chars_list(trial)
return font_face, font_size
def sigmoid(x):
return 1 / (1 + np.exp(-1 * x))
def matplotlib_plot_df(
dffix,
trial,
algo_choice,
stimulus_prefix="word",
desired_dpi=300,
fix_to_plot=[],
stim_info_to_plot=["Words", "Word boxes"],
box_annotations=None,
):
    chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None
    if chars_df is None:
        st.warning("No character or word information available to plot")
if "display_coords" in trial:
desired_width_in_pixels = trial["display_coords"][2] + 1
desired_height_in_pixels = trial["display_coords"][3] + 1
else:
desired_width_in_pixels = 1920
desired_height_in_pixels = 1080
figure_width = desired_width_in_pixels / desired_dpi
figure_height = desired_height_in_pixels / desired_dpi
fig = plt.figure(figsize=(figure_width, figure_height), dpi=desired_dpi)
ax = fig.add_subplot(1, 1, 1)
fig.subplots_adjust(bottom=0)
fig.subplots_adjust(top=1)
fig.subplots_adjust(right=1)
fig.subplots_adjust(left=0)
if "font" in trial and trial["font"] in AVAILABLE_FONTS:
font_to_use = trial["font"]
else:
font_to_use = "DejaVu Sans Mono"
if "font_size" in trial:
font_size = trial["font_size"]
else:
font_size = 20
if f"{stimulus_prefix}s_list" in trial:
add_text_to_ax(
trial[f"{stimulus_prefix}s_list"],
ax,
font_to_use,
prefix=stimulus_prefix,
fontsize=font_size / 3.89,
plot_text=False,
plot_boxes=True if "Word boxes" in stim_info_to_plot else False,
box_annotations=box_annotations,
)
if "chars_list" in trial:
add_text_to_ax(
trial["chars_list"],
ax,
font_to_use,
prefix="char",
fontsize=font_size / 3.89,
plot_text=True if "Words" in stim_info_to_plot else False,
plot_boxes=False,
box_annotations=None,
)
if "Uncorrected Fixations" in fix_to_plot:
ax.plot(dffix.x, dffix.y, label="Raw fixations", color="blue", alpha=0.6, linewidth=0.6)
        x0 = dffix.x.values[:-1]
        x1 = dffix.x.values[1:]
        y0 = dffix.y.values[:-1]
        y1 = dffix.y.values[1:]
        xpos = x0
        ypos = y0
        xdir = x1 - x0
        ydir = y1 - y0
for X, Y, dX, dY in zip(xpos, ypos, xdir, ydir):
ax.annotate(
"",
xytext=(X, Y),
xy=(X + 0.001 * dX, Y + 0.001 * dY),
arrowprops=dict(arrowstyle="fancy", color="blue"),
size=8,
alpha=0.3,
)
if "Corrected Fixations" in fix_to_plot:
if isinstance(algo_choice, list):
algo_choices = algo_choice
repeats = range(len(algo_choice))
else:
algo_choices = [algo_choice]
repeats = range(1)
for algoIdx in repeats:
algo_choice = algo_choices[algoIdx]
if f"y_{algo_choice}" in dffix.columns:
ax.plot(
dffix.x,
dffix.loc[:, f"y_{algo_choice}"],
label="Raw fixations",
color=COLORS[algoIdx],
alpha=0.6,
linewidth=0.6,
)
                x0 = dffix.x.values[:-1]
                x1 = dffix.x.values[1:]
                y_corrected = dffix.loc[:, f"y_{algo_choice}"].values
                y0 = y_corrected[:-1]
                y1 = y_corrected[1:]
                xpos = x0
                ypos = y0
                xdir = x1 - x0
                ydir = y1 - y0
for X, Y, dX, dY in zip(xpos, ypos, xdir, ydir):
ax.annotate(
"",
xytext=(X, Y),
xy=(X + 0.001 * dX, Y + 0.001 * dY),
arrowprops=dict(arrowstyle="fancy", color=COLORS[algoIdx]),
size=8,
alpha=0.3,
)
ax.set_xlim((0, desired_width_in_pixels))
ax.set_ylim((0, desired_height_in_pixels))
ax.invert_yaxis()
return fig, desired_width_in_pixels, desired_height_in_pixels
def plotly_plot_with_image(
dffix,
trial,
algo_choice,
to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"],
scale_factor=0.5,
):
fig, img_width, img_height = matplotlib_plot_df(
dffix, trial, algo_choice, desired_dpi=300, fix_to_plot=[], stim_info_to_plot=to_plot_list
)
fig.savefig(TEMP_FIGURE_STIMULUS_PATH)
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=[0, img_width * scale_factor],
y=[img_height * scale_factor, 0],
mode="markers",
marker_opacity=0,
name="scale_helper",
)
)
fig.update_xaxes(visible=False, range=[0, img_width * scale_factor])
fig.update_yaxes(
visible=False,
range=[img_height * scale_factor, 0],
scaleanchor="x",
)
if "Words" in to_plot_list or "Word boxes" in to_plot_list:
imsource = Image.open(str(TEMP_FIGURE_STIMULUS_PATH))
fig.add_layout_image(
dict(
x=0,
sizex=img_width * scale_factor,
y=0,
sizey=img_height * scale_factor,
xref="x",
yref="y",
opacity=1.0,
layer="below",
sizing="stretch",
source=imsource,
)
)
if "Uncorrected Fixations" in to_plot_list:
duration_scaled = dffix.duration - dffix.duration.min()
duration_scaled = ((duration_scaled / duration_scaled.max()) - 0.5) * 3
duration = sigmoid(duration_scaled) * 50 * scale_factor
fig.add_trace(
go.Scatter(
x=dffix.x * scale_factor,
y=dffix.y * scale_factor,
mode="markers+lines+text",
name="Raw fixations",
marker=dict(
color=COLORS[-1],
symbol="arrow",
size=duration.values,
angleref="previous",
line=dict(color="black", width=duration.values / 10),
),
line_width=2 * scale_factor,
text=np.arange(len(dffix.x)),
textposition="middle right",
textfont=dict(
family="sans serif",
size=18 * scale_factor,
),
hoverinfo="text+x+y",
opacity=0.9,
)
)
if "Corrected Fixations" in to_plot_list:
if isinstance(algo_choice, list):
algo_choices = algo_choice
repeats = range(len(algo_choice))
else:
algo_choices = [algo_choice]
repeats = range(1)
for algoIdx in repeats:
algo_choice = algo_choices[algoIdx]
if f"y_{algo_choice}" in dffix.columns:
fig.add_trace(
go.Scatter(
x=dffix.x * scale_factor,
y=dffix.loc[:, f"y_{algo_choice}"] * scale_factor,
mode="markers",
name=f"{algo_choice} corrected",
marker_color=COLORS[algoIdx],
marker_size=10 * scale_factor,
hoverinfo="text+x+y",
opacity=0.75,
)
)
fig.update_layout(
plot_bgcolor=None,
width=img_width * scale_factor,
height=img_height * scale_factor,
margin={"l": 0, "r": 0, "t": 0, "b": 0},
legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
)
for trace in fig["data"]:
if trace["name"] == "scale_helper":
trace["showlegend"] = False
return fig
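# Example (sketch, not executed): build an interactive Plotly view of a trial with
# the rendered stimulus as background. Displaying via st.plotly_chart assumes a
# Streamlit context, as used elsewhere in this app.
#
#   fig = plotly_plot_with_image(
#       dffix,
#       trial,
#       algo_choice=["warp", "cluster"],
#       to_plot_list=["Uncorrected Fixations", "Corrected Fixations", "Words"],
#       scale_factor=0.5,
#   )
#   st.plotly_chart(fig)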
def plot_y_corr(dffix, algo_choice, margin=dict(t=40, l=10, r=10, b=1)):
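    """Plot the per-fixation vertical correction (y_<algo>_correction) for each algorithm.

    Returns a styled but empty figure and logs a warning if the expected
    correction column is missing from dffix.
    """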
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    num_datapoints = len(dffix.x)
layout = dict(
plot_bgcolor="white",
autosize=True,
margin=margin,
xaxis=dict(
title="Fixation Index",
linecolor="black",
range=[-1, num_datapoints + 1],
showgrid=False,
mirror="all",
showline=True,
),
yaxis=dict(
title="y correction",
side="left",
linecolor="black",
showgrid=False,
mirror="all",
showline=True,
),
legend=dict(orientation="v", yanchor="middle", y=0.95, xanchor="left", x=1.05),
)
algo_string = algo_choice[0] if isinstance(algo_choice, list) else algo_choice
if f"y_{algo_string}_correction" not in dffix.columns:
st.session_state["logger"].warning("No correction column found in dataframe")
return go.Figure(layout=layout)
fig = go.Figure(layout=layout)
if isinstance(algo_choice, list):
algo_choices = algo_choice
repeats = range(len(algo_choice))
else:
algo_choices = [algo_choice]
repeats = range(1)
for algoIdx in repeats:
algo_choice = algo_choices[algoIdx]
fig.add_trace(
go.Scatter(
x=np.arange(num_datapoints),
y=dffix.loc[:, f"y_{algo_choice}_correction"],
mode="markers",
name=f"{algo_choice} y correction",
marker_color=COLORS[algoIdx],
marker_size=3,
showlegend=True,
)
)
fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor="black")
return fig
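# Example (sketch, not executed): visualize how much each algorithm moved the
# fixations vertically; requires the y_<algo>_correction columns added during
# correction (see correct_df).
#
#   fig = plot_y_corr(dffix, ["warp", "cluster"])
#   st.plotly_chart(fig)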
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH):
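    """Download (if needed) and extract the zipped example .asc files from OSF.

    Returns the list of .asc paths found in EXAMPLES_FOLDER_PATH after extraction.
    """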
if not os.path.isdir(EXAMPLES_FOLDER):
os.mkdir(EXAMPLES_FOLDER)
if not os.path.exists(EXAMPLES_ASC_ZIP_FILENAME):
download_url(OSF_DOWNLAOD_LINK, EXAMPLES_ASC_ZIP_FILENAME)
    if os.path.exists(EXAMPLES_ASC_ZIP_FILENAME):
        if EXAMPLES_FOLDER_PATH.exists():
            EXAMPLE_ASC_FILES = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
            if len(EXAMPLE_ASC_FILES) != 4:  # the archive is expected to yield 4 example .asc files
                try:
                    with zipfile.ZipFile(EXAMPLES_ASC_ZIP_FILENAME, "r") as zip_ref:
                        zip_ref.extractall(EXAMPLES_FOLDER)
                except Exception as e:
                    st.session_state["logger"].warning(e)
                    st.session_state["logger"].warning(f"Extracting {EXAMPLES_ASC_ZIP_FILENAME} failed")
    EXAMPLE_ASC_FILES = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
return EXAMPLE_ASC_FILES
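# Example (sketch, not executed): fetch and unpack the bundled example recordings,
# with the folder, zip filename, and OSF link constants supplied by the caller at
# app startup.
#
#   example_files = download_example_ascs(
#       EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH
#   )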
def process_trial_choice_single_csv(trial, algo_choice, file=None):
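    """Prepare a trial loaded from a single CSV and optionally run a correction algorithm.

    Fetches the trial's fixation dataframe (from the trial dict or session state),
    attaches the character dataframe and unique line centers, and returns the
    (possibly corrected) dffix together with plotting properties.
    """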
trial_id = trial["trial_id"]
if "dffix" in trial:
dffix = trial["dffix"]
else:
if file is None:
file = st.session_state["single_csv_file"]
trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{file.name}_{trial_id}_2ndInput_chars_channel_sep.png"))
trial["fname"] = str(file.name)
dffix = trial["dffix"] = st.session_state["trials_by_ids_single_csv"][trial_id]["dffix"]
font, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS)
chars_df = pd.DataFrame(trial["chars_list"])
trial["chars_df"] = chars_df.to_dict()
trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
if algo_choice is not None:
dffix, _ = correct_df(dffix, algo_choice, trial)
return dffix, trial, dpi, screen_res, font, font_size
def add_default_font_and_character_props_to_state(trial):
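    """Derive default line spacing, text origin, and font properties from a trial's chars_list."""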
chars_list = trial["chars_list"]
chars_df = pd.DataFrame(trial["chars_list"])
    # estimate line spacing from the gaps between successive line centers;
    # fall back to the smallest gap when the spacing is not uniform
    line_diffs = np.diff(chars_df.char_y_center.unique())
    y_diffs = np.unique(line_diffs)
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    else:
        y_diff = np.min(y_diffs)
    y_diff = round(y_diff * 2) / 2  # snap to the nearest half pixel
x_txt_start = chars_list[0]["char_xmin"]
y_txt_start = chars_list[0]["char_y_center"]
font_face, font_size = get_font_and_font_size_from_trial(trial)
line_height = y_diff
return y_diff, x_txt_start, y_txt_start, font_face, font_size, line_height
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
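    """Compute all word- or character-level reading measures for one trial.

    When use_corrected_fixations is True, the measures are computed on a copy of
    dffix whose y column has been replaced by y_<correction_algo>.
    """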
if use_corrected_fixations:
dffix_copy = copy.deepcopy(dffix)
dffix_copy["y"] = dffix_copy[f"y_{correction_algo}"]
else:
dffix_copy = dffix
    # every measure function shares the signature f(trial, dffix, prefix) and returns
    # a dataframe indexed by <prefix>_index that includes a <prefix> column
    measure_funcs = [
        anf.number_of_fixations_own,
        anf.initial_fixation_duration_own,
        anf.first_of_many_duration_own,
        anf.total_fixation_duration_own,
        anf.gaze_duration_own,
        anf.go_past_duration_own,
        anf.second_pass_duration_own,
        anf.initial_landing_position_own,
        anf.initial_landing_distance_own,
        anf.landing_distances_own,
        anf.number_of_regressions_in_own,
    ]
    measure_dfs = [func(trial, dffix_copy, prefix).set_index(f"{prefix}_index") for func in measure_funcs]
    own_measure_df = pd.concat([df.drop(prefix, axis=1) for df in measure_dfs], axis=1)
    own_measure_df[prefix] = measure_dfs[0][prefix]
first_column = own_measure_df.pop(prefix)
own_measure_df.insert(0, prefix, first_column)
    own_measure_df.insert(0, f"{prefix}_num", np.arange(own_measure_df.shape[0]))
return own_measure_df
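# Example (sketch, not executed): compute reading measures on warp-corrected
# fixations. prefix="word" is an assumption for illustration; `trial` must contain
# the word/character lists that the functions in analysis_funcs expect.
#
#   word_measures_df = get_all_measures(
#       trial, dffix, prefix="word", use_corrected_fixations=True, correction_algo="warp"
#   )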