# Source code for tacco.preprocessing._reference

import time

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from scipy.sparse import issparse

from .. import get
from .. import utils
from ..utils._utils import _run_OT
from . import _qc as qc
from ..utils._utils import _infer_annotation_key

def _normalize_reference(adata, annotation_key=None):
    ''' Normalize profiles per annotation and annotation per cell - if available. '''

    annotation_key = _infer_annotation_key(adata, annotation_key)

    if annotation_key in adata.obsm:
        # each cell's annotation weights should sum to 1
        per_cell_totals = np.array(adata.obsm[annotation_key].sum(axis=1))
        adata.obsm[annotation_key] /= per_cell_totals[:, None]

    if annotation_key in adata.varm:
        # each annotation's profile should sum to 1
        per_annotation_totals = np.array(adata.varm[annotation_key].sum(axis=0))
        adata.varm[annotation_key] /= per_annotation_totals

def construct_reference_profiles(
    adata,
    annotation_key=None,
    counts_location=None,
    inplace=True,
    normalize=True,
    target_sum=None,
    trafo=None,
):
    """\
    Constructs reference profiles from categorical annotations.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` including expression data in `.X` and a
        categorical annotation in `.obs`.
    annotation_key
        The `.obs` key where the annotation is stored; this will also be used
        as a `.varm` key to store the resulting annotation profiles.
    counts_location
        A string or tuple specifying where the count matrix is stored, e.g.
        `'X'`, `('raw','X')`, `('raw','obsm','my_counts_key')`,
        `('layer','my_counts_key')`, ... For details see
        :func:`~tacco.get.counts`.
    inplace
        Whether to update the input :class:`~anndata.AnnData` or return a
        copy.
    normalize
        Whether to normalize the reference annotation and profiles.
    target_sum
        The number of counts to normalize every observation to before
        computing profiles. If `None`, no normalization is performed.
    trafo
        What transformation to apply to the expression before computing the
        profiles. After the computation, the transformation is inverted to get
        profiles in the same units as the expression. Possible values are:

        - "log1p": log(x+1)
        - "sqrt": sqrt(x)
        - `None`: no transformation

    Returns
    -------
    Returns an :class:`~anndata.AnnData` containing the constructed profiles\\
    depending on `inplace` either as copy or as a reference to the original\\
    `adata`.

    """

    if adata is None:
        raise ValueError('"adata" cannot be None!')

    annotation_key = _infer_annotation_key(adata, annotation_key)

    reference = get.counts(adata, counts_location=counts_location, annotation=annotation_key, copy=False)
    #qc.check_counts_validity(reference.X)

    if annotation_key not in reference.obs:
        raise ValueError(f'The key {annotation_key!r} is not available in .obs!')

    # any rescaling/transformation must not leak back into the caller's data
    if target_sum is not None or trafo is not None:
        reference = reference.copy()
    if target_sum is not None:
        sc.pp.normalize_total(reference, target_sum=target_sum)
    if trafo == 'log1p':
        utils.log1p(reference)
    elif trafo == 'sqrt':
        utils.sqrt(reference)

    # one-hot encode the annotation and column-normalize it, so the matrix
    # product below yields the per-category mean expression
    onehot = pd.get_dummies(reference.obs[annotation_key], dtype=reference.X.dtype)
    cells_per_category = onehot.sum(axis=0)
    onehot /= cells_per_category.to_numpy()

    profiles = reference.X.T @ onehot.to_numpy()
    profiles = pd.DataFrame(profiles, index=reference.var.index, columns=onehot.columns)

    # invert the transformation so the profiles are in expression units again
    if trafo == 'log1p':
        profiles = np.expm1(profiles)
    elif trafo == 'sqrt':
        profiles *= profiles

    if not inplace:
        adata = adata.copy()

    adata.varm[annotation_key] = profiles.reindex(index=adata.var.index)

    if normalize:
        _normalize_reference(adata, annotation_key)

    return adata
def refine_reference(
    adata,
    annotation_key=None,
    counts_location=None,
    inplace=False,
    normalize=True,
    regularization=1e-3,
):
    """\
    Refines a reference data set by scaling profiles and annotation to match
    the expression data. Specifically, determines the normalization factors
    n(cg) in the read model p(cga) = n(cg) p(g|a) p(a|c) for the joint
    probability distribution of cells c, genes g, and annotation a per read to
    p(cg) from the expression data, and updates the profiles p(g|a) and the
    annotation p(a|c) from the marginals of p(cga).

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` including expression data in `.X` and
        profiles in `.varm` and/or annotation in `.obs` or `.obsm`.
    annotation_key
        The `.obs`, `.obsm`, and/or `.varm` key where the annotation and
        profiles are and/or will be stored.
    counts_location
        A string or tuple specifying where the count matrix is stored, e.g.
        `'X'`, `('raw','X')`, `('raw','obsm','my_counts_key')`,
        `('layer','my_counts_key')`, ... For details see
        :func:`~tacco.get.counts`.
    inplace
        Whether to modify the input :class:`~anndata.AnnData` or return a
        copy.
    normalize
        Whether to normalize the reference annotation and profiles.
    regularization
        Relative factor to determine a regularization addition to the profiles
        to avoid unsolvable count distributions (e.g. for some (g,c):
        sum_a p(g|a) * p(a|c) = 0, but p(cg) != 0). If set to 0, no
        regularization is done.

    Returns
    -------
    Returns an :class:`~anndata.AnnData` containing the refined reference,\\
    depending on `copy` either as copy or as a reference to the original\\
    `adata`.

    """

    if adata is None:
        raise ValueError('"adata" cannot be None!')

    annotation_key = _infer_annotation_key(adata, annotation_key)

    counts = get.counts(adata, counts_location=counts_location, annotation=False, copy=False)
    qc.check_counts_validity(counts.X)

    if annotation_key not in adata.obsm and annotation_key not in adata.obs:
        raise ValueError('The key "%s" is neither available in .obsm nor .obs!' % annotation_key)
    if annotation_key not in adata.varm:
        raise ValueError('The key "%s" is not available in .varm!' % annotation_key)

    if not inplace:
        adata = adata.copy()

    _normalize_reference(adata, annotation_key)

    # turn a categorical annotation into per-cell annotation weights
    if annotation_key in adata.obs:
        adata.obsm[annotation_key] = pd.get_dummies(adata.obs[annotation_key])
        del adata.obs[annotation_key]

    if regularization != 0:
        # add a small uniform contribution so that no (gene, cell) pair gets a
        # predicted probability of exactly 0 while having nonzero counts
        adata.varm[annotation_key] += 1 / len(adata.var.index) * regularization
        adata.obsm[annotation_key] += 1 / len(adata.obs.index) * regularization
        _normalize_reference(adata, annotation_key)

    p_g_a = adata.varm[annotation_key]  # (genes, annotations)
    p_a_c = adata.obsm[annotation_key].T  # (annotations, cells)

    p_g_a = p_g_a.to_numpy()
    p_a_c = p_a_c.to_numpy()

    #p(cga) = n(cg) p(g|a) p(a|c)
    #p(cg) = sum_a p(cga) = n(cg) sum_a p(g|a) p(a|c)
    #n(cg) = p(cg) / sum_a p(g|a) p(a|c)
    if issparse(counts.X):
        p_cg = counts.X.tocoo()
        if p_cg is counts.X:
            p_cg = p_cg.copy()
        p_cg.data *= 1 / p_cg.data.sum()  # normalize as joint probability
        temp_data = np.empty_like(p_cg.data)
        # evaluates (p_g_a @ p_a_c) only at the nonzero positions of p_cg;
        # p_cg is (cells, genes), so genes index its columns and cells its rows
        utils.sparse_result_gemmT(p_g_a, p_a_c.T, p_cg.col, p_cg.row, temp_data)
        utils.divide(p_cg.data, temp_data, out=p_cg.data)
        # p_cg now contains n_cg
    else:
        p_cg = counts.X.copy()
        p_cg *= 1 / p_cg.sum(axis=None)  # normalize as joint probability
        temp = p_g_a @ p_a_c
        # BUGFIX: temp is (genes, cells) while p_cg is (cells, genes); divide
        # by the transpose to align the axes (the sparse branch above already
        # evaluates the product in the (cells, genes) orientation).
        p_cg /= temp.T
        # p_cg now contains n_cg

    #p'(ga) = sum_c p(cga) = p(g|a) sum_c n(cg) p(a|c)
    adata.varm[annotation_key] = pd.DataFrame(p_g_a * (p_a_c @ p_cg).T, index=adata.var.index, columns=adata.varm[annotation_key].columns)
    #p'(ac) = sum_g p(cga) = p(a|c) sum_g p(g|a) n(cg)
    adata.obsm[annotation_key] = pd.DataFrame(p_a_c.T * (p_cg @ p_g_a), index=adata.obs.index, columns=adata.varm[annotation_key].columns)

    if normalize:
        _normalize_reference(adata, annotation_key)

    return adata