# github_search_visualizations / task_visualizations.py
# Author avatar caption: lambdaofgod's picture
# Commit message: task counts visualization
# Commit: 732d800
import pandas as pd
import ast
import json
import plotly.express as px
import plotly.graph_objects as go
class TaskVisualizations:
    """Build sunburst charts of task counts grouped by research area.

    Two tables are loaded at construction time, both obtained by joining a
    task -> area CSV with a task -> count CSV on the ``task`` column:

    * ``tasks_with_areas_df`` - built from ``task_counts_path``
    * ``selected_tasks_with_areas_df`` - built from
      ``selected_task_counts_path``

    The rest of the class reads the ``area``, ``task`` and ``count`` columns
    of those tables.
    """

    def __init__(
        self, task_counts_path, selected_task_counts_path, tasks_with_areas_path
    ):
        """Load both task tables.

        Args:
            task_counts_path: CSV with the counts of all tasks.
            selected_task_counts_path: CSV with the counts of the selected
                tasks.
            tasks_with_areas_path: CSV mapping each task to its area.
        """
        self.tasks_with_areas_df = self.load_tasks_with_areas_df(
            task_counts_path, tasks_with_areas_path
        )
        self.selected_tasks_with_areas_df = self.load_tasks_with_areas_df(
            selected_task_counts_path, tasks_with_areas_path
        )

    @classmethod
    def load_tasks_with_areas_df(
        cls, task_counts_path, tasks_with_areas_path="data/paperswithcode_tasks.csv"
    ) -> pd.DataFrame:
        """Read both CSV files and join them on the ``task`` column.

        The join is ``DataFrame.merge`` with its default (inner) behaviour, so
        tasks present in only one of the two files are dropped.

        Args:
            task_counts_path: CSV that contains a ``task`` column plus the
                count data.
            tasks_with_areas_path: CSV that contains a ``task`` column plus
                the area data.

        Returns:
            The area table with the columns of the count table appended.
        """
        task_counts_df = pd.read_csv(task_counts_path)
        raw_tasks_with_areas_df = pd.read_csv(tasks_with_areas_path)
        return raw_tasks_with_areas_df.merge(task_counts_df, on="task")

    @classmethod
    def get_topk_merge_others(
        cls, df, by_col, val_col, k=10, val_threshold=1000
    ) -> pd.DataFrame:
        """Keep the largest labels of ``by_col`` and fold the rest into "other".

        Rows are ranked by ``val_col`` in descending order. A label survives
        when it occurs among the first ``k`` ranked rows with a value of at
        least ``val_threshold``; every other row is relabelled ``"other"``.
        The relabelled rows are then summed per label.

        If a label occurs more than once among the first ``k`` rows, the last
        (i.e. smallest) of its values is the one compared to the threshold.

        Args:
            df: Input table; it is not modified.
            by_col: Name of the label column.
            val_col: Name of the numeric column to rank and sum by.
            k: Number of top-ranked rows considered for keeping their label.
            val_threshold: Minimum value a top-``k`` row needs to keep its
                label.

        Returns:
            A frame indexed by ``by_col`` (sorted by label) with a single
            ``val_col`` column holding the per-label sums.
        """
        ranked = df.sort_values(val_col, ascending=False)
        top_rows = ranked.head(k)
        # Later duplicates overwrite earlier ones, so a repeated label is
        # judged by its smallest value inside the top k rows.
        top_values = dict(zip(top_rows[by_col], top_rows[val_col]))
        kept_labels = [
            label for label, value in top_values.items() if value >= val_threshold
        ]
        labels = ranked[by_col].where(ranked[by_col].isin(kept_labels), "other")
        # ``assign`` builds a new frame, so the caller's data stays untouched.
        relabelled = ranked.assign(**{by_col: labels})
        # The string "sum" is the supported spelling; handing the builtin
        # ``sum`` to ``agg`` is deprecated in recent pandas releases.
        return relabelled.groupby(by_col).agg({val_col: "sum"})

    @classmethod
    def get_displayed_tasks_with_areas_df(
        cls, tasks_with_areas_df, min_task_count
    ) -> pd.DataFrame:
        """Prepare the ``area`` / ``task`` / ``count`` table shown in the chart.

        Steps:

        1. Drop every row that contains a missing value.
        2. Relabel tasks whose count is below ``min_task_count`` as
           ``"other"``.
        3. Per area, apply :meth:`get_topk_merge_others` (default ``k``) with
           ``min_task_count`` as threshold, which also sums all ``"other"``
           rows of that area into one row.
        4. Append the summed count to every task label (``"<task> <count>"``).

        Args:
            tasks_with_areas_df: Table with ``area``, ``task`` and ``count``
                columns; it is not modified.
            min_task_count: Minimum count a task needs to be shown under its
                own name.

        Returns:
            A frame with columns ``area``, ``task`` and ``count`` and a fresh
            integer index, ordered by area and then by task label. An empty
            frame with those columns is returned when no rows remain after
            dropping missing values.
        """
        output_columns = ["area", "task", "count"]
        cleaned = tasks_with_areas_df.dropna()
        if cleaned.empty:
            # Nothing to group: return a well-formed empty table.
            return pd.DataFrame(columns=output_columns)

        frequent_enough = cleaned["count"] >= min_task_count
        relabelled = cleaned.assign(
            task=cleaned["task"].where(frequent_enough, "other")
        )

        per_area_frames = []
        for area, area_df in relabelled.groupby("area"):
            merged = cls.get_topk_merge_others(
                area_df, "task", "count", val_threshold=min_task_count
            ).reset_index()
            merged.insert(0, "area", area)
            per_area_frames.append(merged)

        displayed = pd.concat(per_area_frames, ignore_index=True)
        displayed["task"] = displayed["task"] + " " + displayed["count"].map(str)
        return displayed

    def get_tasks_sunburst(self, min_task_count, which_df="selected"):
        """Return a plotly express sunburst figure of areas and their tasks.

        Args:
            min_task_count: Minimum count a task needs to be shown under its
                own name; rarer tasks are merged into "other" per area.
            which_df: ``"selected"`` uses ``selected_tasks_with_areas_df``;
                any other value uses ``tasks_with_areas_df``.

        Returns:
            The figure created by ``px.sunburst`` with the ``area`` -> ``task``
            hierarchy, sized by ``count``.
        """
        if which_df == "selected":
            source_df = self.selected_tasks_with_areas_df
        else:
            source_df = self.tasks_with_areas_df
        displayed_tasks_with_areas_df = self.get_displayed_tasks_with_areas_df(
            source_df, min_task_count
        )
        return px.sunburst(
            displayed_tasks_with_areas_df, path=["area", "task"], values="count"
        )