"""Inference utilities."""
import logging
import torch
import numpy as np
from paccmann_predictor.models.paccmann import MCA
from pytoda.transforms import Compose
from pytoda.smiles.transforms import ToTensor
from configuration import (
    MODEL_WEIGHTS_URI,
    MODEL_PARAMS,
    SMILES_LANGUAGE,
    SMILES_TRANSFORMS,
)

logger = logging.getLogger("openapi_server:inference")
# NOTE: restrict torch's global CPU intra-op thread count to 1. This was set
# to avoid segfaults (root cause not recorded here — TODO document). It runs
# once at import time, so it applies to every later torch call, not only
# to predict().
torch.set_num_threads(1)


def _load_model(device: torch.device, estimate_confidence: bool) -> torch.nn.Module:
    """Build the MCA model on ``device``, load its weights and set eval mode."""
    logger.debug("loading model for prediction.")
    # Move the module first: load_state_dict copies values into the existing
    # parameters without relocating them, so map_location alone would leave
    # the parameters on the CPU.
    model = MCA(MODEL_PARAMS).to(device)
    model.load_state_dict(torch.load(MODEL_WEIGHTS_URI, map_location=device))
    model.eval()
    if estimate_confidence:
        logger.debug("associating SMILES language for confidence estimates.")
        model._associate_language(SMILES_LANGUAGE)
    logger.debug("model loaded.")
    return model


def predict(
    smiles: str, gene_expression: np.ndarray, estimate_confidence: bool = False
) -> dict:
    """
    Run PaccMann prediction.

    The encoded compound is repeated once per row of ``gene_expression``. The
    model therefore scores the same compound against every gene expression
    row in a single batch.

    Args:
        smiles (str): SMILES representing a compound.
        gene_expression (np.ndarray): gene expression data. Its first axis is
            used as the batch dimension (one prediction per row). Array-like
            inputs are converted with ``np.asarray``.
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. Defaults to False.
    Returns:
        dict: the prediction dictionary from the model.
    """
    logger.debug("running predict.")
    gene_expression = np.asarray(gene_expression)
    logger.debug("gene expression shape: %s.", gene_expression.shape)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug("device selected: %s.", device)
    model = _load_model(device, estimate_confidence)
    logger.debug("set up the transformation.")
    smiles_transform_fn = Compose(SMILES_TRANSFORMS + [ToTensor(device=device)])
    logger.debug("starting the prediction.")
    batch_size = gene_expression.shape[0]
    with torch.no_grad():
        # One copy of the encoded SMILES per gene expression row.
        smiles_tensor = (
            smiles_transform_fn(smiles).view(1, -1).repeat(batch_size, 1)
        )
        # Create the batch on the same device as the model and the SMILES
        # tensor. torch.tensor always copies, so the caller's array is never
        # aliased.
        gene_expression_tensor = torch.tensor(
            gene_expression, dtype=torch.float32, device=device
        )
        _, prediction_dict = model(
            smiles_tensor,
            gene_expression_tensor,
            confidence=estimate_confidence,
        )
    logger.debug("successful prediction.")
    return prediction_dict