import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from astir.data import SCDataset
[docs]class AstirModel:
"""Abstract class to perform statistical inference to assign. This module is the super class of
`CellTypeModel` and `CellStateModel` and is not supposed to be instantiated.
"""
def __init__(
self,
dset: Optional[SCDataset],
random_seed: int,
dtype: torch.dtype,
device: torch.device = torch.device("cpu"),
) -> None:
if not isinstance(random_seed, int):
raise NotClassifiableError("Random seed is expected to be an integer.")
torch.manual_seed(random_seed)
np.random.seed(random_seed)
if dtype != torch.float32 and dtype != torch.float64:
raise NotClassifiableError(
"dtype must be one of torch.float32 and torch.float64."
)
elif dset is not None and dtype != dset.get_dtype():
raise NotClassifiableError("dtype must be the same as `dset`.")
self._dtype: torch.dtype = dtype
self._data: Dict[str, torch.Tensor] = {}
self._variables: Dict[str, torch.Tensor] = {}
self._losses: torch.Tensor = torch.tensor([], dtype=self._dtype)
self._assignment: pd.DataFrame = pd.DataFrame()
self._dset = dset
# self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self._device = device
self._is_converged = False
[docs] def get_losses(self) -> torch.Tensor:
"""Getter for losses.
:return: self.losses
"""
if len(self._losses) == 0:
raise Exception("The model has not been trained yet")
return self._losses
[docs] def get_scdataset(self) -> SCDataset:
"""Getter for the `SCDataset`.
:return: `self._dset`
"""
if self._dset is None:
raise Exception("the dataset is not provided")
return self._dset
[docs] def get_data(self) -> Dict[str, torch.Tensor]:
"""Get model data
:return: data
"""
if self._data == {}:
raise Exception("The model has not been initialized yet")
return self._data
[docs] def get_variables(self) -> Dict[str, torch.Tensor]:
"""Returns all variables
:return: self._variables
"""
if self._variables == {}:
raise Exception("The model has not been initialized yet")
return self._variables
[docs] def is_converged(self) -> bool:
"""Returns True if the model converged
:return: self._is_converged
"""
return self._is_converged
[docs] def get_assignment(self) -> pd.DataFrame:
"""Get the final assignment of the dataset.
:return: the final assignment of the dataset
"""
if self._assignment.shape == (0, 0):
raise Exception("The model has not been trained yet")
return self._assignment
def _param_init(self) -> None:
"""Initializes parameters and design matrices."""
raise NotImplementedError("AbstractModel is not supposed to be instantiated.")
def _forward(
self, Y: torch.Tensor, X: torch.Tensor, design: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""One forward pass"""
raise NotImplementedError("AbstractModel is not supposed to be instantiated.")
[docs] def fit(
self,
max_epochs: int,
learning_rate: float,
batch_size: int,
delta_loss: float,
delta_loss_batch: int,
msg: str,
) -> None:
"""Runs train loops until the convergence reaches delta_loss for
delta_loss_batch sizes or for max_epochs number of times
"""
raise NotImplementedError("AbstractModel is not supposed to be instantiated.")
class NotClassifiableError(RuntimeError):
"""Raised when the input data is not classifiable."""
pass