"""
Cell Type Model
"""
import re
import warnings
from collections import OrderedDict
from typing import Dict, Generator, List, Optional, Tuple, Union
import h5py
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Dirichlet
from scipy import stats
from sklearn.preprocessing import StandardScaler
from torch.autograd import Variable
from torch.distributions import (
LowRankMultivariateNormal,
MultivariateNormal,
Normal,
StudentT,
)
from torch.utils.data import DataLoader, Dataset
from tqdm import trange
from astir.data import SCDataset
from .abstract import AstirModel
from .celltype_recognet import TypeRecognitionNet
[docs]class CellTypeModel(AstirModel):
"""Class to perform statistical inference to assign cells to cell types.
:param dset: the input expression dataset (an ``SCDataset``); when ``None``, no
parameters are initialized at construction time, defaults to None
:param random_seed: the random seed for parameter initialization, defaults to 1234
:param dtype: the data type of parameters, should be the same as `dset`, defaults to
torch.float64
:param device: the torch device the model parameters are placed on, defaults to
``torch.device("cpu")``
"""
def __init__(
    self,
    dset: Optional[SCDataset] = None,
    random_seed: int = 1234,
    dtype: torch.dtype = torch.float64,
    device: torch.device = torch.device("cpu"),
) -> None:
    """Set up the shared model state and, when data is given, the parameters.

    :param dset: the dataset to model; when ``None`` the parameter
        initialization step is skipped, defaults to None
    :param random_seed: forwarded to the parent constructor, defaults to 1234
    :param dtype: forwarded to the parent constructor, defaults to
        torch.float64
    :param device: forwarded to the parent constructor, defaults to
        ``torch.device("cpu")``
    """
    super().__init__(dset, random_seed, dtype, device)
    if dset is None:
        # Without a dataset there is nothing to size the parameters against.
        return
    self._param_init()
def _param_init(self) -> None:
    """Initializes parameters and design matrices.

    Builds the recognition network, the marker matrix ``rho``, the Dirichlet
    prior over mixture weights and the trainable variables ``mu``,
    ``log_sigma``, ``log_delta``, ``p`` and ``alpha_logits``.

    :raises Exception: if no dataset has been provided
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")
    G = self._dset.get_n_features()
    C = self._dset.get_n_classes()

    # The network is created before any sampling below so that random draws
    # happen in the same order as they always have.
    self._recog = TypeRecognitionNet(C, G).to(self._device, dtype=self._dtype)

    # Establish data
    self._data: Dict[str, torch.Tensor] = {
        "rho": self._dset.get_marker_mat().to(self._device),
    }
    # Dirichlet prior with C + 1 components, each with concentration C + 1.
    self._alpha_prior = Dirichlet(
        torch.ones(C + 1, dtype=self._dtype).to(self._device) * (C + 1)
    )

    # log_delta starts as small Gaussian noise (mean 0, scale 0.1), G x (C + 1).
    log_delta_prior = torch.distributions.Normal(
        torch.tensor(0, dtype=self._dtype),
        torch.tensor(0.1, dtype=self._dtype),
    )
    log_delta_init = log_delta_prior.sample((G, C + 1))

    # mu starts at the log of the dataset-provided initial means, one column.
    mu_init = torch.log(
        torch.tensor(self._dset.get_mu_init(), dtype=self._dtype)
    ).to(self._device)
    mu_init = mu_init.reshape(-1, 1)

    # Add additional (zero) columns of mu for anything in the design matrix
    P = self._dset.get_design().shape[1]
    mu_init = torch.cat(
        [
            mu_init,
            torch.zeros((G, P - 1), dtype=self._dtype, device=self._device),
        ],
        1,
    )

    # Create initialization dictionary
    initializations = {
        "mu": mu_init,
        "log_sigma": torch.log(self._dset.get_sigma()).to(self._device),
        "log_delta": log_delta_init,
        "p": torch.zeros((G, C + 1), dtype=self._dtype, device=self._device),
        "alpha_logits": torch.ones(C + 1, dtype=self._dtype, device=self._device),
    }

    # Create trainable variables. ``detach`` guarantees each one is a leaf
    # tensor (required by the optimizer) without the deprecated ``Variable``
    # wrapper.
    self._variables: Dict[str, torch.Tensor] = {
        name: init.clone().detach().to(self._device).requires_grad_(True)
        for name, init in initializations.items()
    }
def load_hdf5(self, hdf5_name: str) -> None:
    """Initializes Cell Type Model from a hdf5 file type

    Reads the cell type assignments, the model parameters, the marker
    matrix ``rho``, the recorded losses and the recognition network weights
    from the ``celltype_model`` group of the file. Parameters, ``rho`` and
    the network weights are placed on ``self._device`` so they live on the
    same device as the recognition network.

    :param hdf5_name: file path
    """
    self._assignment = pd.read_hdf(hdf5_name, "celltype_model/celltype_assignments")

    def _read(dataset) -> torch.Tensor:
        # Materialize an HDF5 dataset as a tensor on the model's device.
        return torch.tensor(np.array(dataset), device=self._device)

    variable_names = ("mu", "log_sigma", "log_delta", "p", "alpha_logits")
    weight_names = (
        "hidden_1.weight",
        "hidden_1.bias",
        "hidden_2.weight",
        "hidden_2.bias",
    )
    with h5py.File(hdf5_name, "r") as f:
        grp = f["celltype_model"]
        param = grp["parameters"]
        self._variables = {name: _read(param[name]) for name in variable_names}
        self._data = {
            "rho": _read(param["rho"]),
        }
        # Losses stay on the default device: fit() concatenates them with a
        # tensor created by torch.tensor(...) without a device argument.
        self._losses = torch.tensor(np.array(grp["losses"]["losses"]))
        rec = grp["recog_net"]
        state_dict = OrderedDict((name, _read(rec[name])) for name in weight_names)

    hidden1_W = state_dict["hidden_1.weight"]
    hidden2_W = state_dict["hidden_2.weight"]
    # The network is sized from the stored weights: the second layer has one
    # output more than the first constructor argument, and the first layer's
    # weight shape supplies the remaining two arguments.
    self._recog = TypeRecognitionNet(
        hidden2_W.shape[0] - 1, hidden1_W.shape[1], hidden1_W.shape[0]
    ).to(device=self._device, dtype=self._dtype)
    self._recog.load_state_dict(state_dict)
    self._recog.eval()
def _forward(
    self, Y: torch.Tensor, X: torch.Tensor, design: torch.Tensor
) -> torch.Tensor:
    """Evaluate the training objective on one batch.

    :param Y: a batch of samples from the dataset, viewable as N x G
    :param X: the normalized batch fed to the recognition network
    :param design: the matching rows of the design matrix (N x P)
    :raises Exception: if no dataset has been provided
    :return: the negated ELBO of the batch (the quantity to minimize)
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")

    n_features = self._dset.get_n_features()
    n_components = self._dset.get_n_classes() + 1  # cell types plus "Other"
    n_cells = Y.shape[0]
    rho = self._data["rho"]

    # Observations replicated once per mixture component: N x (C + 1) x G.
    observed = Y.view(n_cells, 1, n_features).repeat(1, n_components, 1)

    # Log-scale mean: marker effect plus the design-matrix baseline.
    marker_effect = torch.exp(self._variables["log_delta"]) * rho
    baseline = torch.mm(design, self._variables["mu"].T)  # N x P times P x G
    baseline = baseline.view(-1, n_features, 1).repeat(1, 1, n_components)
    log_mean = marker_effect + baseline

    # Variance modelling: a rank-one factor plus a diagonal term.
    corr = torch.sigmoid(self._variables["p"])
    scale = torch.exp(self._variables["log_sigma"])
    factor = (rho * corr).T * scale
    diagonal = torch.pow(scale, 2) * (1 - torch.pow(rho * corr, 2)).T
    # The trailing singleton dimension of the factor is the "rank".
    factor = factor.view(1, n_components, n_features, 1).repeat(n_cells, 1, 1, 1)
    # Small jitter keeps the diagonal strictly positive.
    diagonal = diagonal.view(1, n_components, n_features).repeat(n_cells, 1, 1) + 1e-6

    likelihood = LowRankMultivariateNormal(
        loc=torch.exp(log_mean).permute(0, 2, 1),
        cov_factor=factor,
        cov_diag=diagonal,
    )
    log_lik = likelihood.log_prob(observed)

    # Recognition network output and its log.
    gamma, log_gamma = self._recog.forward(X)

    logits = self._variables["alpha_logits"]
    log_alpha = F.log_softmax(logits, dim=0)
    prior_term = self._alpha_prior.log_prob(F.softmax(logits, dim=0))

    elbo = (gamma * (log_lik + log_alpha - log_gamma)).sum() + prior_term
    return -elbo
def fit(
    self,
    max_epochs: int = 50,
    learning_rate: float = 1e-3,
    batch_size: int = 128,
    delta_loss: float = 1e-3,
    delta_loss_batch: int = 10,
    msg: str = "",
) -> None:
    """Runs train loops until the relative change of the epoch loss drops
    to ``delta_loss`` or for ``max_epochs`` number of times.

    After training, the recognition net is applied to the whole dataset and
    the result is stored as the assignment table; the per-epoch losses are
    appended to the stored loss history.

    :param max_epochs: number of train loop iterations, defaults to 50
    :param learning_rate: the learning rate, defaults to 0.001
    :param batch_size: the batch size, defaults to 128
    :param delta_loss: stops iteration once the relative change of the loss
        between two consecutive epochs reaches delta_loss, defaults to 0.001
    :param delta_loss_batch: accepted for interface compatibility; it is not
        read by the current implementation, defaults to 10
    :param msg: iterator bar message, defaults to empty string
    :raises Exception: if no dataset has been provided
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")

    # Make dataloader
    dataloader = DataLoader(
        self._dset, batch_size=min(batch_size, len(self._dset)), shuffle=True
    )

    losses: List[torch.Tensor] = []
    # Relative loss change; starts above any sensible tolerance so the first
    # epoch can never be declared converged.
    per = torch.tensor(1)

    # Construct optimizer over model variables and recognition net weights
    opt_params = list(self._variables.values()) + list(self._recog.parameters())
    optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

    _, exprs_X, _ = self._dset[:]  # calls dset.get_item

    iterator = trange(
        max_epochs,
        desc="training restart" + msg,
        unit="epochs",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]",
    )
    for _ in iterator:
        loss = torch.tensor(0.0, dtype=self._dtype)
        for batch in dataloader:
            Y, X, design = batch
            optimizer.zero_grad()
            L = self._forward(Y, X, design)
            L.backward()
            optimizer.step()
            # Accumulate the epoch loss without tracking gradients.
            with torch.no_grad():
                loss = loss + L
        if len(losses) > 0:
            per = abs((loss - losses[-1]) / losses[-1])
        losses.append(loss)
        iterator.set_postfix_str("current loss: " + str(round(float(loss), 1)))
        if per <= delta_loss:
            self._is_converged = True
            iterator.close()
            break

    # Save output. Inference only: no autograd graph is needed here.
    with torch.no_grad():
        assignment_probs = self._recog.forward(exprs_X)[0]
    self._assignment = pd.DataFrame(assignment_probs.detach().cpu().numpy())
    self._assignment.columns = self._dset.get_classes() + ["Other"]
    self._assignment.index = self._dset.get_cell_names()

    if self._losses.shape[0] == 0:
        self._losses = torch.tensor(losses)
    else:
        # Append to the history of earlier fit() calls.
        self._losses = torch.cat(
            (self._losses.view(self._losses.shape[0]), torch.tensor(losses)), dim=0
        )
def predict(self, new_dset: SCDataset) -> pd.DataFrame:
    """Feed `new_dset` to the recognition net to get a prediction.

    :param new_dset: the dataset to be predicted; slicing it with ``[:]``
        must yield a 3-tuple whose second element is the network input
    :return: the resulting cell type assignment, one row per cell and one
        column per network output (default integer labels)
    """
    _, exprs_X, _ = new_dset[:]
    probabilities = self._recog.forward(exprs_X)[0]
    return pd.DataFrame(probabilities.detach().cpu().numpy())
def get_recognet(self) -> TypeRecognitionNet:
    """Return the recognition network held by this model.

    :return: the network currently stored on the model
    """
    recognition_net = self._recog
    return recognition_net
def _most_likely_celltype(
    self,
    row: pd.DataFrame,
    threshold: float,
    cell_types: List[str],
    assignment_type: str,
) -> str:
    """Pick the label for one row of the assignment matrix.

    :param row: one row of the cell assignment matrix (read through ``.values``)
    :param threshold: in ``"threshold"`` mode, a cell whose highest
        probability is below this value is labelled ``Unknown``
    :param cell_types: the names of cell types, in the same order as the
        entries of the row
    :param assignment_type: ``"threshold"`` or ``"max"``; see
        :meth:`astir.CellTypeModel.get_celltypes` for full documentation
    :return: the name of the highest-probability cell type, or ``"Unknown"``
    """
    probabilities = row.values
    best = np.max(probabilities)

    # "threshold" mode rejects weak calls; "max" mode rejects ties at the top.
    undecided = (assignment_type == "threshold" and best < threshold) or (
        assignment_type == "max" and np.count_nonzero(probabilities == best) > 1
    )
    return "Unknown" if undecided else cell_types[np.argmax(probabilities)]
def get_celltypes(
    self,
    threshold: float = 0.7,
    assignment_type: str = "threshold",
    prob_assign: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
    """Get the most likely cell types.

    In ``'threshold'`` mode a cell is assigned to its most probable cell
    type unless that probability is below ``threshold``, in which case
    "Unknown" is returned. In ``'max'`` mode a cell is assigned to its most
    probable cell type, or "Unknown" if several types share the maximum.

    :param threshold: the probability threshold below which a cell is
        labelled "Unknown" in ``'threshold'`` mode, defaults to 0.7
    :param assignment_type: either 'threshold' or 'max'; any other value
        triggers a warning and falls back to 'threshold'. Defaults to
        'threshold'.
    :param prob_assign: assignment probabilities to use instead of the
        model's own assignment table (rows are cells, columns are cell
        types), defaults to None
    :return: a data frame with a single ``cell_type`` column holding the
        most likely cell type for each cell
    """
    type_probability = self.get_assignment() if prob_assign is None else prob_assign

    if assignment_type not in ("threshold", "max"):
        warnings.warn(
            "Wrong assignment type. Defaults the assignment " "type to threshold."
        )
        assignment_type = "threshold"

    # The threshold is never consulted in 'max' mode, so tell the caller when
    # a non-default value was supplied. (0.7 mirrors the signature default.)
    if assignment_type == "max" and threshold != 0.7:
        warnings.warn(
            "Assignment type is 'max' but probability "
            "threshold value was passed in. Probability "
            "threshold value will be ignored."
        )

    cell_types = list(type_probability.columns)
    cell_type_assignments = type_probability.apply(
        self._most_likely_celltype,
        axis=1,
        assignment_type=assignment_type,
        threshold=threshold,
        cell_types=cell_types,
    )
    cell_type_assignments = pd.DataFrame(cell_type_assignments)
    cell_type_assignments.columns = ["cell_type"]
    return cell_type_assignments
def _compare_marker_between_types(
    self,
    curr_type: str,
    celltype_to_compare: str,
    marker: str,
    cell_types: List[str],
    alpha: float = 0.05,
) -> Optional[dict]:
    """For two cell types and a protein, check that the marker is expressed
    at a higher level in ``curr_type`` than in ``celltype_to_compare``.

    A two-sample t-test is run on the marker's expression in the cells
    assigned to each type. The comparison passes only when the statistic is
    positive and the p-value is below ``alpha``.

    :param curr_type: the cell type that should express the marker
    :param celltype_to_compare: a cell type that should not highly express
        the marker
    :param marker: the marker protein for curr_type
    :param cell_types: the cell type assigned to each cell, in dataset order
    :param alpha: significance level of the t-test, defaults to 0.05
    :raises Exception: if no dataset has been provided
    :return: ``None`` when the marker is significantly higher in
        ``curr_type``; otherwise a dict describing the problem, with a
        ``note`` when either group has fewer than two cells
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")

    marker_mask = np.array(self._dset.get_features()) == marker
    assigned = np.array(cell_types)
    cells_x = assigned == curr_type
    cells_y = assigned == celltype_to_compare

    # Convert the expression matrix once and slice it for both groups.
    exprs = self._dset.get_exprs().detach().cpu().numpy()
    # x - expression of the marker in cells assigned to curr_type
    # y - expression of the marker in cells assigned to celltype_to_compare
    x = exprs[cells_x, marker_mask]
    y = exprs[cells_y, marker_mask]

    # Defaults used when a group is too small for a t-test; lowercase names
    # because the np.NaN / np.Inf aliases no longer exist in NumPy 2.
    stat = np.nan
    pval = np.inf
    note: Optional[str] = "Only 1 cell in a type: comparison not possible"
    if len(x) > 1 and len(y) > 1:
        tt = stats.ttest_ind(x, y)
        stat = tt.statistic
        pval = tt.pvalue
        note = None

    if stat > 0 and pval < alpha:
        return None
    return {
        "current_marker": marker,
        "curr_type": curr_type,
        "celltype_to_compare": celltype_to_compare,
        "mean_A": x.mean(),
        "mean_Y": y.mean(),
        "p-val": pval,
        "note": note,
    }
def plot_clustermap(
    self,
    plot_name: str = "celltype_protein_cluster.png",
    threshold: float = 0.7,
    figsize: Tuple[float, float] = (7.0, 5.0),
    prob_assign: Optional[pd.DataFrame] = None,
) -> None:
    """Save a clustered heatmap of per-cell protein expression, with each
    cell annotated by its assigned cell type.

    :param plot_name: output file name including the extension
        (e.g. .png or .jpg), defaults to "celltype_protein_cluster.png"
    :param threshold: the probability threshold passed on to
        ``get_celltypes``, defaults to 0.7
    :param figsize: the size of the figure, defaults to (7.0, 5.0)
    :param prob_assign: assignment probabilities passed on to
        ``get_celltypes``, defaults to None
    :raises Exception: if no dataset has been provided
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")

    frame = self._dset.get_exprs_df()

    # Standardize every feature column on its own.
    scaler = StandardScaler()
    for column in frame.columns:
        as_column_vector = frame[column].values.reshape((frame[column].shape[0], 1))
        frame[column] = scaler.fit_transform(as_column_vector)

    # Order the cells by assigned type, then split the labels off again.
    frame["cell_type"] = self.get_celltypes(
        threshold=threshold, prob_assign=prob_assign
    )
    frame = frame.sort_values(by=["cell_type"])
    labels = frame.pop("cell_type")

    # One palette colour per distinct label.
    distinct_labels = labels.unique()
    palette = sns.color_palette("BrBG", len(distinct_labels))
    lut = dict(zip(distinct_labels, palette))
    col_colors = pd.DataFrame(labels.map(lut))

    grid = sns.clustermap(
        frame.T,
        xticklabels=False,
        cmap="vlag",
        col_cluster=False,
        col_colors=col_colors,
        figsize=figsize,
    )

    # Zero-size bars exist only to give the legend one entry per cell type.
    for label in distinct_labels:
        grid.ax_col_dendrogram.bar(0, 0, color=lut[label], label=label, linewidth=0)
    grid.ax_col_dendrogram.legend(
        title="Cell Types", loc="center", ncol=3, bbox_to_anchor=(0.8, 0.8)
    )
    grid.savefig(plot_name, dpi=150)
def diagnostics(self, cell_type_assignments: list, alpha: float) -> pd.DataFrame:
    """Run diagnostics on cell type assignments

    For every assigned cell type and each of its markers, the marker's
    expression is compared against every other assigned cell type that is
    not supposed to express it; failed comparisons are collected.

    See :meth:`astir.Astir.diagnostics_celltype` for full documentation

    :param cell_type_assignments: the cell type assigned to each cell
    :param alpha: significance level handed to the pairwise comparison
    :raises Exception: if no dataset has been provided
    :return: one row per failed comparison (empty when there are none)
    """
    if self._dset is None:
        raise Exception("the dataset is not provided")

    # rho as a labelled table: feature names on the rows, cell type names
    # (plus "Other") on the columns.
    marker_table = pd.DataFrame(
        self._data["rho"].detach().cpu().numpy(),
        index=self._dset.get_features(),
        columns=self._dset.get_classes() + ["Other"],
    )

    issues = []
    for expected_high in self._dset.get_classes():
        if expected_high not in cell_type_assignments:
            continue
        own_markers = marker_table.index[marker_table[expected_high] == 1]
        for marker_name in own_markers:
            # every cell type that shouldn't highly express this marker
            non_expressing = marker_table.columns[marker_table.loc[marker_name] == 0]
            for expected_low in non_expressing:
                if expected_low not in cell_type_assignments:
                    continue
                finding = self._compare_marker_between_types(
                    expected_high,
                    expected_low,
                    marker_name,
                    cell_type_assignments,
                    alpha,
                )
                if finding is not None:
                    issues.append(finding)

    report_columns = [
        "feature",
        "should be expressed higher in",
        "than",
        "mean cell type 1",
        "mean cell type 2",
        "p-value",
        "note",
    ]
    if not issues:
        return pd.DataFrame(columns=report_columns)

    # Build from the dicts first and relabel afterwards: passing ``columns=``
    # here would select by key instead of renaming.
    report = pd.DataFrame(issues)
    report.columns = report_columns
    return report