import json
import logging
from typing import Any, Dict, List, Optional
from uuid import uuid4

from pydantic import PrivateAttr
from sqlalchemy import Column, DateTime, String, create_engine, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from obsei.misc.utils import obj_to_json
from obsei.workflow.base_store import BaseStore
from obsei.workflow.workflow import WorkflowState, WorkflowConfig, Workflow

logger = logging.getLogger(__name__)

Base = declarative_base()  # type: Any


class ORMBase(Base):  # type: ignore
    """Abstract declarative base adding an id and created/updated timestamps."""

    __abstract__ = True

    # Primary key; a random UUID string is generated client-side when omitted.
    id = Column(String(100), default=lambda: str(uuid4()), primary_key=True)
    created = Column(DateTime, server_default=func.now())
    # NOTE(review): ``server_onupdate`` presumably only tells SQLAlchemy that
    # the server generates a value on UPDATE; it does not look like it installs
    # a trigger, so this column may never change after insert. Confirm whether
    # ``onupdate=func.now()`` was intended.
    updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())


class WorkflowTable(ORMBase):
    """ORM mapping for the ``workflow`` table; all payloads are JSON strings."""

    __tablename__ = "workflow"

    config = Column(String(2000), nullable=False)
    source_state = Column(String(500), nullable=True)
    sink_state = Column(String(500), nullable=True)
    analyzer_state = Column(String(500), nullable=True)


class WorkflowStore(BaseStore):
    """Store that persists workflows and their states through SQLAlchemy.

    A single session is opened in ``__init__`` and reused by every method.
    """

    # NOTE(review): annotated as ``sessionmaker`` but ``__init__`` assigns the
    # object returned by calling the sessionmaker, i.e. a session instance.
    _session: sessionmaker = PrivateAttr()

    def __init__(self, url: str = "sqlite:///obsei.db", **data: Any):
        """Create the engine and tables for *url* and open the session.

        Args:
            url: SQLAlchemy database URL.
            **data: Forwarded unchanged to ``BaseStore.__init__``.
        """
        super().__init__(**data)
        engine = create_engine(url)
        ORMBase.metadata.create_all(engine)
        local_session = sessionmaker(bind=engine)
        self._session = local_session()

    def get(self, identifier: str) -> Optional[Workflow]:
        """Return the workflow stored under *identifier*, or ``None`` if absent."""
        row = self._session.query(WorkflowTable).filter_by(id=identifier).first()
        if row is None:
            return None
        return self._convert_sql_row_to_workflow_data(row)

    def get_all(self) -> List[Workflow]:
        """Return every stored workflow."""
        rows = self._session.query(WorkflowTable).all()
        return [self._convert_sql_row_to_workflow_data(row) for row in rows]

    def get_workflow_state(self, identifier: str) -> Optional[WorkflowState]:
        """Return the combined state of workflow *identifier*.

        Returns ``None`` when the workflow does not exist or when none of its
        three state columns holds a truthy value.
        """
        row = (
            self._session.query(
                WorkflowTable.source_state,
                WorkflowTable.analyzer_state,
                WorkflowTable.sink_state,
            )
            # Query.filter() accepts only SQL expressions; the previous
            # ``filter(id=identifier)`` keyword form raised TypeError.
            .filter(WorkflowTable.id == identifier)
            .first()
        )
        # The converter itself maps a missing row (None) to None.
        return self._convert_sql_row_to_workflow_state(row)

    def get_source_state(self, identifier: str) -> Optional[Dict[str, Any]]:
        """Return the decoded source state, or ``None`` if missing/unset."""
        return self._load_state(WorkflowTable.source_state, identifier)

    def get_sink_state(self, identifier: str) -> Optional[Dict[str, Any]]:
        """Return the decoded sink state, or ``None`` if missing/unset."""
        return self._load_state(WorkflowTable.sink_state, identifier)

    def get_analyzer_state(self, identifier: str) -> Optional[Dict[str, Any]]:
        """Return the decoded analyzer state, or ``None`` if missing/unset."""
        return self._load_state(WorkflowTable.analyzer_state, identifier)

    def _load_state(self, column: Any, identifier: str) -> Optional[Dict[str, Any]]:
        """Fetch one JSON state column for workflow *identifier* and decode it.

        Args:
            column: The ``WorkflowTable`` column attribute to select.
            identifier: Primary key of the workflow row.

        Returns:
            The ``json.loads`` result, or ``None`` when no row matches or the
            column is NULL.
        """
        row = (
            self._session.query(column)
            .filter(WorkflowTable.id == identifier)
            .first()
        )
        # An unknown identifier yields no row; report it as "no state" rather
        # than failing with IndexError.
        if row is None:
            return None
        raw_state = row[0]
        return None if raw_state is None else json.loads(raw_state)

    def add_workflow(self, workflow: Workflow) -> None:
        """Insert *workflow* as a new row and commit."""
        self._session.add(
            WorkflowTable(
                id=workflow.id,
                config=obj_to_json(workflow.config),
                source_state=obj_to_json(workflow.states.source_state),
                sink_state=obj_to_json(workflow.states.sink_state),
                analyzer_state=obj_to_json(workflow.states.analyzer_state),
            )
        )
        self._commit_transaction()

    def update_workflow(self, workflow: Workflow) -> None:
        """Overwrite the config and all three states of *workflow* and commit."""
        self._session.query(WorkflowTable).filter_by(id=workflow.id).update(
            {
                WorkflowTable.config: obj_to_json(workflow.config),
                WorkflowTable.source_state: obj_to_json(workflow.states.source_state),
                WorkflowTable.sink_state: obj_to_json(workflow.states.sink_state),
                WorkflowTable.analyzer_state: obj_to_json(
                    workflow.states.analyzer_state
                ),
            },
            synchronize_session=False,
        )
        self._commit_transaction()

    def update_workflow_state(
        self, workflow_id: str, workflow_state: WorkflowState
    ) -> None:
        """Overwrite all three states of workflow *workflow_id* and commit."""
        self._session.query(WorkflowTable).filter_by(id=workflow_id).update(
            {
                WorkflowTable.source_state: obj_to_json(workflow_state.source_state),
                WorkflowTable.sink_state: obj_to_json(workflow_state.sink_state),
                WorkflowTable.analyzer_state: obj_to_json(
                    workflow_state.analyzer_state
                ),
            },
            synchronize_session=False,
        )
        self._commit_transaction()

    def update_source_state(self, workflow_id: str, state: Dict[str, Any]) -> None:
        """Overwrite the source state of workflow *workflow_id* and commit."""
        self._session.query(WorkflowTable).filter_by(id=workflow_id).update(
            {WorkflowTable.source_state: obj_to_json(state)},
            synchronize_session=False,
        )
        self._commit_transaction()

    def update_sink_state(self, workflow_id: str, state: Dict[str, Any]) -> None:
        """Overwrite the sink state of workflow *workflow_id* and commit."""
        self._session.query(WorkflowTable).filter_by(id=workflow_id).update(
            {WorkflowTable.sink_state: obj_to_json(state)},
            synchronize_session=False,
        )
        self._commit_transaction()

    def update_analyzer_state(self, workflow_id: str, state: Dict[str, Any]) -> None:
        """Overwrite the analyzer state of workflow *workflow_id* and commit."""
        self._session.query(WorkflowTable).filter_by(id=workflow_id).update(
            {WorkflowTable.analyzer_state: obj_to_json(state)},
            synchronize_session=False,
        )
        self._commit_transaction()

    def delete_workflow(self, id: str) -> None:
        """Delete the workflow row whose primary key is *id* and commit."""
        self._session.query(WorkflowTable).filter_by(id=id).delete()
        self._commit_transaction()

    def _commit_transaction(self) -> Any:
        """Commit the session; on any failure log, roll back and re-raise."""
        try:
            self._session.commit()
        except Exception as ex:
            logger.error("Transaction rollback: %s", ex.__cause__)
            # Rollback is important here otherwise self.session will be in
            # inconsistent state and next call will fail
            self._session.rollback()
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _convert_sql_row_to_workflow_state(row: Any) -> Optional[WorkflowState]:
        """Build a ``WorkflowState`` from a row exposing the three state columns.

        Returns ``None`` when *row* is ``None`` or when every decoded state is
        falsy (NULL or an empty JSON value).
        """
        if row is None:
            return None

        source_state_dict = (
            None if row.source_state is None else json.loads(row.source_state)
        )
        sink_state_dict = (
            None if row.sink_state is None else json.loads(row.sink_state)
        )
        analyzer_state_dict = (
            None if row.analyzer_state is None else json.loads(row.analyzer_state)
        )

        # Only materialise a state object when at least one part carries data.
        if not (source_state_dict or sink_state_dict or analyzer_state_dict):
            return None

        return WorkflowState(
            source_state=source_state_dict,
            sink_state=sink_state_dict,
            analyzer_state=analyzer_state_dict,
        )

    @staticmethod
    def _convert_sql_row_to_workflow_data(row: Any) -> Workflow:
        """Build a ``Workflow`` from a full ``WorkflowTable`` row."""
        config_dict = json.loads(row.config)
        return Workflow(
            id=row.id,
            config=WorkflowConfig(**config_dict),
            states=WorkflowStore._convert_sql_row_to_workflow_state(row),
        )