Spaces:
Running
Running
| # query를 자동으로 읽고 쓰는 container를 정의 | |
| from __future__ import annotations | |
| import re | |
| from typing import Callable, TypeVar | |
| import streamlit as st | |
| __all__ = ["QueryWrapper", "get_base_url"] | |
| T = TypeVar("T") | |
| import hashlib | |
| import urllib.parse | |
def SHA1(msg: str) -> str:
    """Return the first eight hexadecimal characters of the SHA-1 digest of ``msg``.

    The text is encoded with ``str.encode()``'s default codec before hashing.
    """
    encoded = msg.encode()
    full_hex = hashlib.sha1(encoded).hexdigest()
    # Only a short prefix is kept so the token stays compact.
    return full_hex[:8]
def get_base_url():
    """Build ``<protocol>://<host>`` from the first active Streamlit session.

    The scheme and host are read from the client request attached to the first
    entry returned by the runtime's session manager; path, params, query and
    fragment are left empty.

    NOTE(review): this goes through the private ``_session_mgr`` attribute and
    always picks index 0 of the active sessions - confirm that this is the
    intended session and that at least one session always exists here.
    """
    runtime = st.runtime.get_instance()
    first_session = runtime._session_mgr.list_active_sessions()[0]
    request = first_session.client.request
    url_parts = [request.protocol, request.host, "", "", "", ""]
    return urllib.parse.urlunparse(url_parts)
class QueryWrapper:
    """Public handle for a widget whose value is mirrored in a URL query parameter.

    Creating an instance builds a ``_QueryWrapper`` and records it in the
    class-level ``queries`` registry under its query name. Calling the instance
    forwards every argument to that wrapper.
    """

    # Registry of the wrappers created so far, keyed by query name. Creating a
    # second instance with the same query name replaces the earlier entry.
    queries: dict[str, _QueryWrapper] = {}

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """Create the underlying wrapper and register it under ``query``."""
        wrapper = _QueryWrapper(query, label, use_hash)
        QueryWrapper.queries[query] = wrapper
        self.__wrapper = wrapper

    def __call__(self, *args, **kwargs):
        """Forward the call unchanged to the underlying wrapper and return its result."""
        return self.__wrapper(*args, **kwargs)

    @classmethod
    def get_sharable_link(cls) -> str:
        """Return the query-string fragment for every registered wrapper.

        Each registered wrapper is converted with ``str()`` and the pieces are
        joined with ``&``. Wrappers that render as an empty string would leave
        doubled or dangling separators, so runs of ``&`` are collapsed and
        leading/trailing ``&`` are stripped.

        Declared as a classmethod so it can be called on the class itself as
        well as on an instance; it only reads the class-level registry.
        """
        joined = "&".join(str(wrapper) for wrapper in cls.queries.values())
        return re.sub("&+", "&", joined).strip("&")
| class _QueryWrapper: | |
| ILLEGAL_CHARS = "&/=?" | |
| def __init__(self, query: str, label: str | None = None, use_hash: bool = True): | |
| self.query = query | |
| self.label = label or query | |
| self.use_hash = use_hash | |
| self.hash_table = {} | |
| self.key = None | |
| def __call__( | |
| self, | |
| base_container: Callable, | |
| legal_list: list[T], | |
| default: T | list[T] | None = None, | |
| *, | |
| key: str | None = None, | |
| **kwargs, | |
| ) -> T | list[T] | None: | |
| val_from_query = st.query_params.get_all(self.query.lower()) | |
| # print(val_from_query) | |
| legal = len(val_from_query) > 0 | |
| self.key = key or self.label | |
| self.hash_table = {SHA1(str(v)): v for v in legal_list} | |
| # filter out illegal values | |
| if legal and legal_list: | |
| val_from_query = [v for v in val_from_query if v in self.hash_table] | |
| # print(self.label, val_from_query, legal) | |
| if legal: | |
| selected = [self.hash_table[v] for v in val_from_query] | |
| elif default: | |
| selected = default | |
| elif self.label in st.session_state: | |
| selected = st.session_state[self.label] | |
| if legal_list: | |
| if isinstance(selected, list): | |
| selected = [v for v in selected if v in legal_list] | |
| elif selected not in legal_list: | |
| selected = [] | |
| else: | |
| selected = [] | |
| if selected is None: | |
| pass | |
| elif len(selected) == 1 and base_container in [st.selectbox, st.radio]: | |
| selected = selected[0] | |
| # print(self.label, selected) | |
| if base_container == st.checkbox: | |
| selected = base_container( | |
| self.label, | |
| legal_list, | |
| index=legal_list.index(selected) if selected in legal_list else None, | |
| key=self.key, | |
| **kwargs, | |
| ) | |
| elif base_container == st.multiselect: | |
| selected = base_container( | |
| self.label, legal_list, default=selected, key=self.key, **kwargs | |
| ) | |
| elif base_container == st.radio: | |
| selected = base_container( | |
| self.label, | |
| legal_list, | |
| index=legal_list.index(selected) if selected in legal_list else None, | |
| key=self.key, | |
| **kwargs, | |
| ) | |
| elif base_container == st.selectbox: | |
| selected = base_container( | |
| self.label, | |
| legal_list, | |
| index=legal_list.index(selected) if selected in legal_list else None, | |
| key=self.key, | |
| **kwargs, | |
| ) | |
| else: | |
| selected = base_container(self.label, legal_list, key=self.key, **kwargs) | |
| return st.session_state[self.key] | |
| def __str__(self): | |
| selected = st.session_state.get(self.key, None) | |
| if isinstance(selected, str): | |
| return f"{self.query.lower()}={SHA1(selected)}" | |
| elif isinstance(selected, list): | |
| return "&".join([f"{self.query.lower()}={SHA1(str(v))}" for v in selected]) | |
| else: | |
| return "" | |