"""
Cell State Model
"""
import warnings
from collections import OrderedDict
from typing import Dict, Generator, List, Optional, Tuple, Union
import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import trange
from astir.data import SCDataset
from .abstract import AstirModel
from .cellstate_recognet import StateRecognitionNet
class CellStateModel(AstirModel):
    """Class to perform statistical inference on the activation
    of states (pathways) across cells.

    :param dset: the input gene expression dataset, defaults to None
    :param const: See parameter ``const`` in
        :func:`astir.models.StateRecognitionNet`, defaults to 2
    :param dropout_rate: See parameter ``dropout_rate`` in
        :func:`astir.models.StateRecognitionNet`, defaults to 0
    :param batch_norm: See parameter ``batch_norm`` in
        :func:`astir.models.StateRecognitionNet`, defaults to False
    :param random_seed: the random seed number to reproduce results, defaults to 42
    :param dtype: torch datatype to use in the model, defaults to torch.float64
    :param device: torch.device's cpu or gpu, defaults to torch.device("cpu")
    """

    def __init__(
        self,
        dset: Optional[SCDataset] = None,
        const: int = 2,
        dropout_rate: float = 0,
        batch_norm: bool = False,
        random_seed: int = 42,
        dtype: torch.dtype = torch.float64,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__(dset, random_seed, dtype, device)

        # Seed both torch and numpy so training runs are reproducible.
        self.random_seed = random_seed
        torch.manual_seed(self.random_seed)
        np.random.seed(self.random_seed)

        # The optimizer is created lazily on the first call to fit().
        self._optimizer: Optional[torch.optim.Adam] = None
        self.const, self.dropout_rate, self.batch_norm = const, dropout_rate, batch_norm

        # Parameters can only be initialized when a dataset is available;
        # load_hdf5() is the alternative initialization path.
        if self._dset is not None:
            self._param_init()

        # Convergence flag: set True once the delta-loss criterion is met in fit().
        self._is_converged = False

    def _param_init(self) -> None:
        """Initialize trainable variables, fixed data and the recognition net.

        :raises Exception: if no dataset has been provided
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        C = self._dset.get_n_classes()
        G = self._dset.get_n_features()

        initializations = {
            "log_sigma": torch.log(self._dset.get_sigma().mean()),
            "mu": torch.reshape(self._dset.get_mu(), (1, -1)),
        }
        # log_w starts as the log of Uniform(0, 1.5) draws, one per (class, feature).
        d = torch.distributions.Uniform(
            torch.tensor(0.0, dtype=self._dtype), torch.tensor(1.5, dtype=self._dtype)
        )
        initializations["log_w"] = torch.log(d.sample((C, G)))

        # Trainable variables: moved to the model device with gradients enabled.
        self._variables = {
            n: i.to(self._device).detach().clone().requires_grad_()
            for (n, i) in initializations.items()
        }
        # Fixed (non-trainable) data: the transposed marker matrix rho.
        self._data = {
            "rho": self._dset.get_marker_mat().T.to(self._device),
        }
        self._recog = StateRecognitionNet(
            C,
            G,
            const=self.const,
            dropout_rate=self.dropout_rate,
            batch_norm=self.batch_norm,
        ).to(device=self._device, dtype=self._dtype)

    def load_hdf5(self, hdf5_name: str) -> None:
        """Initializes Cell State Model from a hdf5 file type

        :param hdf5_name: file path
        """
        self._assignment = pd.read_hdf(
            hdf5_name, "cellstate_model/cellstate_assignments"
        )
        with h5py.File(hdf5_name, "r") as f:
            grp = f["cellstate_model"]
            param = grp["parameters"]
            self._variables = {
                "mu": torch.tensor(np.array(param["mu"])),
                "log_sigma": torch.tensor(np.array(param["log_sigma"])),
                "log_w": torch.tensor(np.array(param["log_w"])),
            }
            self._data = {"rho": torch.tensor(np.array(param["rho"]))}
            self._losses = torch.tensor(np.array(grp["losses"]["losses"]))

            # Rebuild the recognition net from the saved layer weights.
            rec = grp["recog_net"]
            hidden1_W = torch.tensor(np.array(rec["linear1.weight"]))
            hidden2_W = torch.tensor(np.array(rec["linear2.weight"]))
            hidden3_mu_W = torch.tensor(np.array(rec["linear3_mu.weight"]))
            hidden3_std_W = torch.tensor(np.array(rec["linear3_std.weight"]))
            state_dict = OrderedDict(
                {
                    "linear1.weight": hidden1_W,
                    "linear1.bias": torch.tensor(np.array(rec["linear1.bias"])),
                    "linear2.weight": hidden2_W,
                    "linear2.bias": torch.tensor(np.array(rec["linear2.bias"])),
                    "linear3_mu.weight": hidden3_mu_W,
                    "linear3_mu.bias": torch.tensor(np.array(rec["linear3_mu.bias"])),
                    "linear3_std.weight": hidden3_std_W,
                    "linear3_std.bias": torch.tensor(np.array(rec["linear3_std.bias"])),
                }
            )
            # C and G are recovered from the saved weight shapes:
            # linear3_mu outputs C values, linear1 takes G inputs.
            self._recog = StateRecognitionNet(
                hidden3_mu_W.shape[0],
                hidden1_W.shape[1],
                const=self.const,
                dropout_rate=self.dropout_rate,
                batch_norm=self.batch_norm,
            ).to(device=self._device, dtype=self._dtype)
            self._recog.load_state_dict(state_dict)
            self._recog.eval()

    def _loss_fn(
        self,
        mu_z: torch.Tensor,
        std_z: torch.Tensor,
        z_sample: torch.Tensor,
        y_in: torch.Tensor,
    ) -> torch.Tensor:
        """Returns the calculated (negative-ELBO-style) loss, averaged over
        the batch.

        :param mu_z: the predicted mean of z
        :param std_z: the predicted log standard deviation of z
            (exponentiated below before use)
        :param z_sample: the sampled z values
        :param y_in: the input data
        :return: the loss
        """
        S = y_in.shape[0]

        # log posterior q(z) approx p(z|y)
        q_z_dist = torch.distributions.Normal(loc=mu_z, scale=torch.exp(std_z))
        log_q_z = q_z_dist.log_prob(z_sample)

        # log likelihood p(y|z): y ~ Normal(mu + z (rho * w), sigma)
        rho_w = torch.mul(self._data["rho"], torch.exp(self._variables["log_w"]))
        mean = self._variables["mu"] + torch.matmul(z_sample, rho_w)
        std = torch.exp(self._variables["log_sigma"]).reshape(1, -1)
        p_y_given_z_dist = torch.distributions.Normal(loc=mean, scale=std)
        log_p_y_given_z = p_y_given_z_dist.log_prob(y_in)

        # log prior p(z): standard normal
        p_z_dist = torch.distributions.Normal(0, 1)
        log_p_z = p_z_dist.log_prob(z_sample)

        loss = (1 / S) * (
            torch.sum(log_q_z) - torch.sum(log_p_y_given_z) - torch.sum(log_p_z)
        )
        return loss

    def _forward(
        self,
        Y: Optional[torch.Tensor],
        X: Optional[torch.Tensor] = None,
        design: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """One forward pass through the recognition net, with a
        reparameterized sample of z.

        :param Y: dataset to do forward pass on
        :param X: unused; accepted for signature compatibility
        :param design: unused; accepted for signature compatibility
        :return: mu_z, std_z (log standard deviation), z_sample
        """
        mu_z, std_z = self._recog(Y)

        # Reparameterization trick: z = mu + eps * exp(std_z), eps ~ N(0, 1).
        std = torch.exp(std_z)
        eps = torch.randn_like(std)
        z_sample = eps * std + mu_z

        return mu_z, std_z, z_sample

    def fit(
        self,
        max_epochs: int = 50,
        learning_rate: float = 1e-3,
        batch_size: int = 128,
        delta_loss: float = 1e-3,
        delta_loss_batch: int = 10,
        msg: str = "",
    ) -> None:
        """ Runs train loops until the convergence reaches delta_loss for\
        delta_loss_batch sizes or for max_epochs number of times

        :param max_epochs: number of train loop iterations, defaults to 50
        :param learning_rate: the learning rate, defaults to 1e-3
        :param batch_size: the batch size, defaults to 128
        :param delta_loss: stops iteration once the loss rate reaches\
            delta_loss, defaults to 0.001
        :param delta_loss_batch: the batch size to consider delta loss,\
            defaults to 10
        :param msg: iterator bar message, defaults to empty string
        :raises Exception: if no dataset has been provided
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")

        # Returns early if the model has already converged
        if self._is_converged:
            return

        # Create an optimizer if there is no optimizer
        if self._optimizer is None:
            opt_params = list(self._recog.parameters())
            opt_params += list(self._variables.values())  # type: ignore
            self._optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

        iterator = trange(
            max_epochs,
            desc="training restart" + msg,
            unit="epochs",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]",
        )
        train_iterator = DataLoader(
            self._dset, batch_size=min(batch_size, len(self._dset))
        )
        for ep in iterator:
            # Each batch yields (raw, scaled, design); the scaled expression
            # matrix (x_in) is what the model trains on.
            for i, (y_in, x_in, _) in enumerate(train_iterator):
                self._optimizer.zero_grad()
                mu_z, std_z, z_samples = self._forward(
                    x_in.type(self._dtype).to(self._device)
                )
                loss = self._loss_fn(
                    mu_z, std_z, z_samples, x_in.type(self._dtype).to(self._device)
                )
                loss.backward()
                self._optimizer.step()

            loss_detached = loss.cpu().detach().item()
            self._losses = torch.cat(
                (self._losses, torch.tensor([loss_detached], dtype=self._dtype))
            )

            # Convergence test: relative improvement of the mean loss over the
            # last delta_loss_batch epochs versus the window one epoch earlier.
            if len(self._losses) > delta_loss_batch:
                curr_mean = torch.mean(self._losses[-delta_loss_batch:])
                prev_mean = torch.mean(self._losses[-delta_loss_batch - 1 : -1])
                curr_delta_loss = (prev_mean - curr_mean) / prev_mean
                delta_cond_met = 0 <= curr_delta_loss.item() < delta_loss
            else:
                delta_cond_met = False

            iterator.set_postfix_str("current loss: " + str(round(loss_detached, 1)))
            if delta_cond_met:
                self._is_converged = True
                iterator.close()
                break

        # Cache the per-cell state assignments as a labelled DataFrame.
        g = self.get_final_mu_z().detach().cpu().numpy()
        self._assignment = pd.DataFrame(g)
        self._assignment.columns = self._dset.get_classes()
        self._assignment.index = self._dset.get_cell_names()

    def get_recognet(self) -> StateRecognitionNet:
        """Getter for the recognition net

        :return: the recognition net
        """
        return self._recog

    def get_final_mu_z(self, new_dset: Optional[SCDataset] = None) -> torch.Tensor:
        """Returns the mean of the predicted z values for each core

        :param new_dset: returns the predicted z values of this dataset on
            the existing model. If None, it predicts using the existing
            dataset, defaults to None
        :raises Exception: if neither ``new_dset`` nor the model's own
            dataset is available
        :return: the mean of the predicted z values for each core
        """
        if new_dset is None:
            if self._dset is None:
                raise Exception("the dataset is not provided")
            _, x_in, _ = self._dset[:]  # the scaled expression matrix
        else:
            _, x_in, _ = new_dset[:]
        final_mu_z, _, _ = self._forward(x_in.type(self._dtype).to(self._device))
        return final_mu_z

    def get_correlations(self) -> np.ndarray:
        """Returns a C (# of pathways) X G (# of proteins) matrix
        where each element represents the correlation value of the pathway
        and the protein

        :raises Exception: if no dataset has been provided
        :return: matrix of correlation between all pathway and protein pairs.
        """
        if self._dset is None:
            raise Exception("No dataset input to the model")
        state_assignment = self.get_final_mu_z().detach().cpu().numpy()
        y_in = self._dset.get_exprs()
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()

        # Make a matrix of correlations between all states and proteins
        corr_mat = np.zeros((C, G))
        for c in range(C):
            for g in range(G):
                states = state_assignment[:, c]
                protein = y_in[:, g].cpu()
                corr_mat[c, g] = np.corrcoef(protein, states)[0, 1]
        return corr_mat

    def diagnostics(self) -> pd.DataFrame:
        """Run diagnostics on cell type assignments

        See :meth:`astir.Astir.diagnostics_cellstate` for full documentation

        :raises Exception: if no dataset has been provided
        :return: one row per flagged (pathway, non-marker protein) pair
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        feature_names = self._dset.get_features()
        state_names = self._dset.get_classes()
        corr_mat = self.get_correlations()

        # Correlation values of all marker proteins; non-marker entries are
        # masked with +inf so they never win the per-pathway minimum.
        marker_mat = self._dset.get_marker_mat().T.cpu().numpy()
        marker_corr = marker_mat * corr_mat
        marker_corr[marker_mat == 0] = np.inf

        # Smallest correlation values for each pathway
        min_marker_corr = np.min(marker_corr, axis=1).reshape(-1, 1)
        min_marker_proteins = np.take(feature_names, np.argmin(marker_corr, axis=1))

        # Correlation values of all non marker proteins; marker entries are
        # masked with -inf so they can never exceed the marker minimum.
        non_marker_mat = 1 - self._dset.get_marker_mat().T.cpu().numpy()
        non_marker_corr = non_marker_mat * corr_mat
        non_marker_corr[non_marker_mat == 0] = -np.inf

        # Any correlation values where non marker proteins is greater than
        # the smallest correlation values of marker proteins
        bad_corr_marker = np.array(non_marker_corr > min_marker_corr, dtype=np.int32)

        # Problem summary
        indices = np.argwhere(bad_corr_marker > 0)
        col_names = [
            "pathway",
            "protein A",
            "correlation of protein A",
            "protein B",
            "correlation of protein B",
            "note",
        ]
        problems = []
        for index in indices:
            state_index = index[0]
            protein_index = index[1]
            state = state_names[index[0]]
            marker_protein = min_marker_proteins[state_index]
            non_marker_protein = feature_names[protein_index]
            # Key insertion order must match col_names: the columns are
            # renamed positionally below.
            problem = {
                "pathway": state,
                "marker_protein": marker_protein,
                "corr_of_marker_protein": min_marker_corr[state_index][0],
                "non_marker_protein": non_marker_protein,
                "corr_of_non_marker_protein": non_marker_corr[
                    state_index, protein_index
                ],
                "msg": "{} is marker for {} but {} isn't".format(
                    marker_protein, state, non_marker_protein
                ),
            }
            problems.append(problem)

        if len(problems) > 0:
            df_issues = pd.DataFrame(problems)
            df_issues.columns = col_names
        else:
            df_issues = pd.DataFrame(columns=col_names)
        return df_issues
class NotClassifiableError(RuntimeError):
    """Error signalling that the supplied input data cannot be classified."""