import pandas as pd
import numpy as np
import anndata as ad
import scanpy as sc
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.colors import to_rgb, to_rgba, rgb_to_hsv, hsv_to_rgb, to_rgba_array, PowerNorm, LinearSegmentedColormap, to_hex, ListedColormap, LogNorm, Normalize
from matplotlib.cm import ScalarMappable
import seaborn as sns
import scipy.stats
import scipy.cluster
import scipy.sparse
from ..tools._enrichments import get_compositions, get_contributions, enrichments
from .. import utils
from .. import tools
from .. import get
from .. import preprocessing
import joblib
import gc
from numba import njit
def get_min_max(vector, log=True):
if log:
maximum = np.log(np.nanmax(vector))
minimum = np.log(np.nanmin(vector[vector > 0]))
delta = (maximum - minimum) * 0.1
maximum += delta
minimum -= delta
maximum = np.exp(maximum)
minimum = np.exp(minimum)
else:
maximum = np.nanmax(vector)
minimum = np.nanmin(vector)
#delta = (maximum - minimum) * 0.1
#maximum += delta
#minimum -= delta
return minimum, maximum
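# A minimal usage sketch for `get_min_max` (hypothetical values, not part of
# the library API): with `log=True` the range is padded by 10% in log space,
# so the returned bounds lie strictly outside the data range.
def _demo_get_min_max():
    values = np.array([1.0, 10.0, 100.0])
    lo, hi = get_min_max(values, log=True)  # lo < 1.0, hi > 100.0
    lo_lin, hi_lin = get_min_max(values, log=False)  # exactly (1.0, 100.0)
    return lo, hi, lo_lin, hi_lin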
def _correlations(x, y, log=False):
if log:
x, y = np.log(x), np.log(y)
x_y_finite = np.isfinite(x) & np.isfinite(y)
x, y = x[x_y_finite], y[x_y_finite]
nDOF = len(x) - 1
x2 = np.sum((x-y-x.mean()+y.mean())**2) / (np.var(y) * nDOF)
pc = scipy.stats.pearsonr(x,y)[0]
return pc, x2
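# A quick sanity check for `_correlations` (hypothetical data): perfectly
# linearly related inputs give a Pearson r of 1, while the chi^2-like measure
# quantifies the mean-subtracted squared deviation between x and y.
def _demo_correlations():
    x = np.array([1.0, 2.0, 3.0, 4.0])
    pc, x2 = _correlations(x, 2 * x)
    return pc, x2  # pc == 1.0; x2 > 0 since x and 2*x differ beyond a shift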
def _scatter_plot(ax,x,y,sizes=None,colors=None,alpha=1.0,marker='o',log=False):
ax.scatter(x=x,y=y,s=sizes,c=colors,alpha=alpha,marker=marker)
pc, x2 = _correlations(x, y, log=log)
ax.annotate('r=%.2f'%(pc), (0.05,0.95), xycoords='axes fraction', horizontalalignment='left', verticalalignment='top')
ax.annotate('$\\chi^2_m$=%.2f'%(x2), (0.95,0.00), xycoords='axes fraction', horizontalalignment='right', verticalalignment='bottom')
if log:
ax.set_xscale('log')
ax.set_yscale('log')
def _composition_bar(compositions, colors, horizontal=True, ax=None, legend=True):
n_freqs = len(compositions.index)
x = np.arange(compositions.shape[0])
fig = None
if horizontal:
if ax is None:
fig, ax = plt.subplots(figsize=(11,1*(n_freqs+2)))
bottom = np.full_like(x,0)
for t in colors.index[::1]:
if t in compositions:
ax.barh(x, compositions[t], height=0.55, label=t, left=bottom, color=colors[t])
bottom = bottom + np.array(compositions[t])
else:
ax.barh(x, np.zeros(compositions.shape[0]), height=0.55, label=t, left=bottom, color=colors[t])
ax.set_xlim([0, 1])
ax.set_xlabel('composition')
ax.set_ylim(x.min()-0.75,x.max()+0.75)
ax.set_yticks(x)
ax.set_yticklabels(compositions.index, ha='right')
if legend:
ax.legend(bbox_to_anchor=(-0.0, 1.0), loc='lower left', ncol=6)
else:
if ax is None:
fig, ax = plt.subplots(figsize=(1*(n_freqs+2),8))
bottom = np.array(x)*0+1
for t in colors.index[::-1]:
if t in compositions:
bottom = bottom - np.array(compositions[t])
ax.bar(x, compositions[t], width=0.55, label=t, bottom=bottom, color=colors[t])
else:
ax.bar(x, np.zeros(compositions.shape[0]), width=0.55, label=t, bottom=bottom, color=colors[t])
ax.set_ylim([0, 1])
ax.set_ylabel('composition')
ax.set_xlim(x.min()-0.75,x.max()+0.75)
ax.set_xticks(x)
ax.set_xticklabels(compositions.index, rotation=30,va='top',ha='right')
if legend:
ax.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
if fig is not None:
fig.tight_layout()
return fig
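# A minimal usage sketch for `_composition_bar` (hypothetical compositions):
# two samples stacked from three categories, drawn as horizontal bars.
def _demo_composition_bar():
    compositions = pd.DataFrame(
        [[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]],
        index=['sample1', 'sample2'], columns=['A', 'B', 'C'])
    colors = pd.Series({'A': 'tab:red', 'B': 'tab:green', 'C': 'tab:blue'})
    return _composition_bar(compositions, colors, horizontal=True)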
def _complement_color(r,g,b):
h, s, v = rgb_to_hsv((r,g,b))
_r, _g, _b = hsv_to_rgb(((h+0.5)%1,s,v))
return _r, _g, _b
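# `_complement_color` rotates the hue by half a turn while keeping saturation
# and value; e.g. the complement of pure red is cyan.
def _demo_complement_color():
    return _complement_color(1.0, 0.0, 0.0)  # approximately (0.0, 1.0, 1.0)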
def _scatter_frame(df, colors, ax_, id_=None, cmap=None, cmap_vmin_vmax=None, out_min_max=None, joint=True, point_size=3, grid=False, margin=0.25, scale=None, title='', rasterized=True):
pie = False
if not isinstance(rasterized, bool):
if rasterized != 'pie':
raise ValueError(f'`rasterized` has to be boolean or "pie"!')
if cmap is not None:
raise ValueError(f'`rasterized=="pie"` cannot use a `cmap`!')
pie = True
rasterized = False
color_array = np.array([colors[c] for c in df.columns.difference(['x','y'])])
dpi = ax_[0].get_figure().get_dpi()
point_radius = np.sqrt(point_size / np.pi) / 2 * dpi
# convert from pixel to axes units
point_radius = (ax_[0].transAxes.inverted().transform((point_radius, 0)) - ax_[0].transAxes.inverted().transform((0, 0)))[0]
if id_ is None:
n_cols = len(colors)
id_ = np.full(fill_value=-1,shape=(n_cols,2),dtype=object)
n_types = len(colors)
for ax in ax_:
ax.set_aspect(1)
ax.grid(grid)
if joint is None:
_ax_ = ax_[1:]
_id_ = id_[1:]
elif joint:
_ax_ = [None] * n_types
_id_ = [None] * n_types
else:
_ax_ = ax_
_id_ = id_
axsize = ax_[0].get_window_extent().transformed(ax_[0].get_figure().dpi_scale_trans.inverted()).size
canvas_size = axsize
coords = df[['x','y']].to_numpy()
# use the largest possible amount of space of the axes
coords_min = coords.min(axis=0) if len(coords) > 0 else np.nan
_coords = coords - coords_min
coords_range = _coords.max(axis=0) if len(coords) > 0 else np.nan
scale_0 = (canvas_size[0]-2*margin) / coords_range[0]
scale_1 = (canvas_size[1]-2*margin) / coords_range[1]
if scale is None:
scale = min(scale_0, scale_1)
# center the coords
offset_0 = (canvas_size[0] - coords_range[0] * scale) / 2
offset_1 = (canvas_size[1] - coords_range[1] * scale) / 2
# coords_min is mapped to (offset_0,offset_1)
# => coords_min-(offset_0,offset_1)/scale is mapped to (0,0)
# canvas_size / scale is the range
# => extent_min = coords_min - (offset_0,offset_1) / scale
# => extent_max = extent_min + canvas_size / scale
extent_min = coords_min - np.array((offset_0,offset_1)) / scale
extent_max = extent_min + canvas_size / scale
extent = [extent_min[0], extent_max[0], extent_min[1], extent_max[1]]
x, y = df['x'], df['y']
for (t,c),ax,id in zip(colors.items(),_ax_,_id_):
if pie:
only_color_array = color_array.copy()
only_color_array[only_color_array != c] = '#fff0' # transparent white
weights = df[df.columns.difference(['x','y'])].to_numpy()
def plotit(ax, id, color_array):
for _weights,_x,_y in zip(weights, x, y):
ax.pie(_weights,colors=color_array,radius=point_radius,center=(_x,_y), frame=True)
ax.set_xlim(extent[0:2])
ax.set_ylim(extent[2:4])
                if out_min_max is not None: # get values for backpropagation of the value range
                    out_min_max[id[0],id[1]] = weights.min(), weights.max()
if joint is None or not joint:
plotit(ax, id, only_color_array)
ax.set_title(f'{title}: {t}' if title is not None and title != '' else f'{t}')
if joint is None or joint:
if (t,c) == list(iter(colors.items()))[0]: # plot full pies only once
plotit(ax_[0], id_[0],color_array)
elif cmap is None:
r, g, b = to_rgb(c)
_r, _g, _b = _complement_color(r,g,b) # complement colors for negative values
if t in df.columns:
vals = np.maximum(np.minimum(df[t].to_numpy(),1),-1)
# dont plot invisible points
visible = vals != 0
vals, _x, _y = vals[visible], x[visible], y[visible]
# separately plot efficient cases
which1p = vals >= +1
vals1p, _x1p, _y1p = vals[which1p], _x[which1p], _y[which1p]
which1m = vals <= -1
vals1m, _x1m, _y1m = vals[which1m], _x[which1m], _y[which1m]
# and the rest
remaining = (vals > -1) & (vals < 1)
vals, _x, _y = vals[remaining], _x[remaining], _y[remaining]
color1p = [r,g,b]
color1m = [_r,_g,_b]
                # per point: pick the base color for positive values and its
                # complement for negative ones, with alpha = |value|
                color = np.hstack([np.array([[r,_r],[g,_g],[b,_b]]).T[(1-np.sign(vals).astype(np.int8))//2],np.abs(vals)[:,None]])
def plotit(ax, id):
if len(_x) > 0:
ax.scatter(_x, _y, color=color, s=point_size, edgecolors='none', rasterized=rasterized)
if len(_x1p) > 0:
ax.scatter(_x1p, _y1p, color=color1p, s=point_size, edgecolors='none', rasterized=rasterized)
if len(_x1m) > 0:
ax.scatter(_x1m, _y1m, color=color1m, s=point_size, edgecolors='none', rasterized=rasterized)
ax.set_xlim(extent[0:2])
ax.set_ylim(extent[2:4])
if out_min_max is not None: # get values for backpropagation of the value range
out_min_max[id[0],id[1]] = vals.min(), vals.max()
if joint is None or not joint:
plotit(ax, id)
ax.set_title(f'{title}: {t}' if title is not None and title != '' else f'{t}')
if joint is None or joint:
plotit(ax_[0], id_[0])
else:
if t in df.columns:
vals = df[t].to_numpy()
def plotit(ax, id):
if cmap_vmin_vmax is not None:
vmin, vmax = cmap_vmin_vmax
norm = Normalize(vmin=vmin, vmax=vmax)
else:
vmin, vmax = vals.min(), vals.max()
norm = None
ax.scatter(x, y, c=vals, s=point_size, edgecolors='none', cmap=cmap, norm=norm, rasterized=rasterized)
ax.set_xlim(extent[0:2])
ax.set_ylim(extent[2:4])
if out_min_max is not None: # get values for backpropagation of the value range
out_min_max[id[0],id[1]] = vmin, vmax
if joint is None or not joint:
plotit(ax, id)
ax.set_title(f'{title}: {t}' if title is not None and title != '' else f'{t}')
if joint is None or joint:
plotit(ax_[0], id_[0])
@njit(cache=True)
def _make_stencil(point_radius):
extra_size = int(np.ceil(point_radius - 0.5))
stencil_size = 1 + 2 * extra_size
stencil = np.zeros((stencil_size,stencil_size))
# monte carlo evaluate density profile of a disc... There have to be better options, but maybe not more straightforward ones...
# for one quadrant
np.random.seed(42)
point_radius2 = point_radius**2
n_samples = int(point_radius2*10000)
samples = np.random.uniform(0,1,size=(n_samples,2)) * point_radius
quadrant = np.zeros((1+extra_size,1+extra_size))
for sample in samples:
d2 = (sample**2).sum()
if d2 < point_radius2:
_sample = sample+0.5
quadrant[int(_sample[0]),int(_sample[1])] += 1
# symmetrize octant
quadrant = quadrant + quadrant.T # using += with numba and matrices gives wrong results: https://github.com/numba/numba/issues/6949
    # replicate to other quadrants; use explicit end indices, since `:-extra_size`
    # would be an empty slice for `extra_size == 0` (sub-pixel point radii)
    stencil[:extra_size+1,:extra_size+1] += quadrant[::-1,::-1]
    stencil[extra_size:,:extra_size+1] += quadrant[::,::-1]
    stencil[:extra_size+1,extra_size:] += quadrant[::-1,::]
    stencil[extra_size:,extra_size:] += quadrant[::,::]
# normalize
stencil /= stencil.sum()
return stencil
@njit(cache=True)
def _draw_weights(canvas_size, weights, data, offset_x, offset_y, stencil):
extra_size = (stencil.shape[0] - 1) // 2
canvas = np.zeros(canvas_size)
for w,(x,y) in zip(weights,data):
xo,yo = int(x+offset_x),int(y+offset_y)
canvas[(xo-extra_size):(xo+extra_size+1),(yo-extra_size):(yo+extra_size+1)] += w * stencil
return canvas
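# A minimal sketch of the low-level rasterization helpers (hypothetical
# inputs, not part of the library API): `_make_stencil` builds a normalized
# disc density profile and `_draw_weights` splats weighted points onto a
# pixel canvas with it.
def _demo_draw_weights():
    stencil = _make_stencil(3.0)  # disc of radius 3 pixels, sums to 1
    data = np.array([[10.0, 10.0], [25.0, 25.0], [40.0, 10.0]])
    weights = np.ones(len(data))
    canvas = _draw_weights((50, 50), weights, data, 0.0, 0.0, stencil)
    return canvas  # canvas.sum() equals weights.sum(), as the stencil is normalized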
def _render_frame(df, colors, ax_, id_=None, cmap=None, cmap_vmin_vmax=None, out_min_max=None, joint=True, point_size=3, grid=False, margin=0.25, scale=None, title='', color_mix_mode='xyv'):
typing = df[df.columns.intersection(colors.index)]
coords = df[['x','y']].to_numpy()
coords_min = coords.min(axis=0) if len(coords) > 0 else np.nan
coords = coords - coords_min
coords_range = coords.max(axis=0) if len(coords) > 0 else np.nan
if cmap is None: # empty dots dont get rendered
nonzero = np.abs(typing.to_numpy()).sum(axis=1) > 0
typing = typing.loc[nonzero]
coords = coords[nonzero]
n_types = len(typing.columns)
weights = typing
dpi = ax_[0].get_figure().get_dpi()
if id_ is None:
n_cols = len(colors)
id_ = np.full(fill_value=-1,shape=(n_cols,2),dtype=object)
for ax in ax_:
ax.grid(grid)
if joint is None:
_ax_ = ax_[1:]
_id_ = id_[1:]
elif joint:
_ax_ = [None] * n_types
_id_ = [None] * n_types
else:
_ax_ = ax_
_id_ = id_
point_radius = np.sqrt(point_size / np.pi) / 72 * dpi
margin = int(np.ceil(margin * dpi + point_radius))
axsize = ax_[0].get_window_extent().transformed(ax_[0].get_figure().dpi_scale_trans.inverted()).size
canvas_size = ((axsize)*dpi).astype(int)
# use the largest possible amount of space of the axes
scale_0 = (canvas_size[0]-1-2*margin) / coords_range[0]
scale_1 = (canvas_size[1]-1-2*margin) / coords_range[1]
if scale is None:
scale = min(scale_0, scale_1)
else:
scale = scale * dpi
# center the coords
offset_0 = (canvas_size[0]-1 - coords_range[0] * scale) / 2
offset_1 = (canvas_size[1]-1 - coords_range[1] * scale) / 2
# coords_min is mapped to (offset_0,offset_1)
# => coords_min-(offset_0,offset_1)/scale is mapped to (0,0)
# canvas_size / scale is the range
# => extent_min = coords_min - (offset_0,offset_1) / scale
# => extent_max = extent_min + canvas_size / scale
extent_min = coords_min - np.array((offset_0,offset_1)) / scale
extent_max = extent_min + canvas_size / scale
extent = [extent_min[0], extent_max[0], extent_min[1], extent_max[1]]
coords = coords * scale
stencil = _make_stencil(point_radius)
canvases = { c: _draw_weights((canvas_size[0],canvas_size[1]), weights[c].to_numpy(), coords, offset_0, offset_1, stencil) for c in weights.columns}
sum_canvas = sum(canvases.values())
finite_sum = sum_canvas!=0
colors = {i:np.array(to_rgb(colors[i])) for i in colors.index}
# alpha just tells us whether there is data on the pixel or not
def get_alpha(norm_canvas, stencil):
alpha = np.log1p(np.abs(norm_canvas)) / np.log1p(stencil.max())
alpha[alpha>1] = 1 # cut off values with too high alpha
return alpha
def add_alpha(canvas, norm_canvas, stencil):
canvasA = np.zeros_like(canvas, shape=(*canvas.shape[:-1],4))
canvasA[...,:-1] = canvas
canvasA[...,-1] = get_alpha(norm_canvas, stencil)
return canvasA
if cmap is None:
#norm_canvas = _draw_weights((canvas_size[0],canvas_size[1]), np.ones(shape=len(weights)), coords, offset_0, offset_1, stencil)
norm_canvas = sum([np.abs(canvas) for canvas in canvases.values()])
#finite_norm = norm_canvas!=0
if joint is None or joint:
canvas = mix_base_colors(np.stack([canvases[t] for t in canvases],axis=-1), np.array([colors[t] for t in canvases]), mode=color_mix_mode)
for i in range(3): # remove zero weight colors
canvas[...,i][~finite_sum] = 1
canvas[canvas>1] = 1 # numerical accuracy issues
canvasA = add_alpha(canvas, norm_canvas, stencil)
ax_[0].imshow(canvasA.swapaxes(0,1), origin='lower', extent=extent)
if out_min_max is not None: # get values for backpropagation of the value range
                out_min_max[id_[0][0],id_[0][1]] = sum_canvas.min()/stencil.max(), sum_canvas.max()/stencil.max()
if joint is None or not joint:
for (t,c),ax,id in zip(colors.items(),_ax_,_id_):
if t in canvases:
canvas = canvases[t][...,None] * colors[t]
finite_t = canvases[t]!=0
for i in range(3): # normalize the colors by the weights
canvas[...,i][finite_t] = canvas[...,i][finite_t] / np.abs(canvases[t][finite_t])
canvas[canvas>1] = 1 # numerical accuracy issues
canvasA = add_alpha(canvas, np.abs(canvases[t]), stencil)
ax.imshow(canvasA.swapaxes(0,1), origin='lower', extent=extent)
ax.set_title(f'{title}: {t}' if title is not None and title != '' else f'{t}')
if out_min_max is not None: # get values for backpropagation of the value range
out_min_max[id[0],id[1]] = canvases[t].min()/stencil.max(), canvases[t].max()/stencil.max()
else:
norm_canvas = _draw_weights((canvas_size[0],canvas_size[1]), np.ones(shape=len(weights)), coords, offset_0, offset_1, stencil)
finite_norm = norm_canvas!=0
if joint is None or joint:
canvas = sum_canvas.copy()
canvas[finite_norm] = canvas[finite_norm] / norm_canvas[finite_norm] # normalize the values by the weights
alpha = get_alpha(norm_canvas, stencil)
if cmap_vmin_vmax is not None:
vmin, vmax = cmap_vmin_vmax
else:
vmin, vmax = canvas[finite_norm].min(), canvas[finite_norm].max()
norm = Normalize(vmin=vmin, vmax=vmax)
ax_[0].imshow(canvas.swapaxes(0,1), alpha=alpha.swapaxes(0,1), origin='lower', extent=extent, cmap=cmap, norm=norm)
if out_min_max is not None: # get values for backpropagation of the value range
out_min_max[id_[0][0],id_[0][1]] = vmin, vmax
if joint is None or not joint:
for (t,c),ax,id in zip(colors.items(),_ax_,_id_):
if t in canvases:
canvas = canvases[t].copy()
canvas[finite_norm] = canvas[finite_norm] / norm_canvas[finite_norm] # normalize the values by the weights
alpha = get_alpha(norm_canvas, stencil)
if cmap_vmin_vmax is not None:
vmin, vmax = cmap_vmin_vmax
norm = Normalize(vmin=vmin, vmax=vmax)
else:
vmin, vmax = canvas[finite_norm].min(), canvas[finite_norm].max()
norm = None
ax.imshow(canvas.swapaxes(0,1), alpha=alpha.swapaxes(0,1), origin='lower', extent=extent, cmap=cmap, norm=norm)
ax.set_title(f'{title}: {t}' if title is not None and title != '' else f'{t}')
if out_min_max is not None: # get values for backpropagation of the value range
out_min_max[id[0],id[1]] = vmin, vmax
def _set_axes_labels(ax, axes_labels):
if axes_labels is not None:
if not pd.api.types.is_list_like(axes_labels) or len(axes_labels) != 2:
raise ValueError(f'`axes_labels` {axes_labels!r} is not a list-like of 2 elements!')
ax.set_xlabel(axes_labels[0])
ax.set_ylabel(axes_labels[1])
def spatial_distribution_plot(typing_data, coords, colors, n_cols=1, axs=None, joint=None, normalize=True, point_size=3, cmap=None, cmap_vmin_vmax=None, out_min_max=None, scale=None, grid=False, margin=0.25, render=False, rasterized=True, noticks=False, axes_labels=None, on_data_legend=None):
n_types = len(colors.index)
if joint is None:
n_y = (n_types + n_cols) // n_cols
elif joint:
n_cols = 1
n_y = 1
else:
n_y = (n_types + n_cols - 1) // n_cols
n_solutions = len(typing_data)
if axs is None:
fig, axs = plt.subplots(n_y,n_solutions*n_cols,figsize=(7*n_solutions*n_cols,7*n_y), squeeze=False)
else:
fig = None
if axs.shape != (n_y,n_solutions*n_cols):
raise Exception('spatial_distribution_plot got axs of wrong dimensions: need %s, got %s' % ((n_y,n_solutions*n_cols),axs.shape))
_axs = np.full(fill_value=None,shape=(n_solutions,n_cols*n_y),dtype=object)
_idx = np.full(fill_value=-1,shape=(n_solutions,n_cols*n_y,2),dtype=object)
counter = 0
for i in range(n_cols):
for j in range(n_y):
for c in range(n_solutions):
_axs[c,counter] = axs[j,i*n_solutions+c]
_idx[c,counter,:] = np.array([j,i*n_solutions+c])
counter += 1
for id_,ax_,df,title in zip(_idx,_axs,typing_data,typing_data.index):
df = df.copy() # dont change the original df
df = df[colors.index.intersection(df.columns)]
if normalize: # normalize values per column to give alphas between 0 and 1
df = df - df.min(axis=0).to_numpy()
df = df / df.max(axis=0).to_numpy()
df[['x','y']] = coords
df = df.loc[:,~df.isna().all(axis=0)] # remove all-nan columns
df = df.loc[~df.isna().any(axis=1)] # remove any-nan rows
if isinstance(render, bool) and not render:
_scatter_frame(df, colors, ax_, id_=id_, cmap=cmap, cmap_vmin_vmax=cmap_vmin_vmax, out_min_max=out_min_max, joint=joint, point_size=point_size, grid=grid, margin=margin, scale=scale, title=title, rasterized=rasterized)
else:
if not rasterized:
raise ValueError(f'`render!=False` only works when `rasterized==True`')
if isinstance(render, bool):
color_mix_mode = 'xyv'
else:
color_mix_mode = render
_render_frame (df, colors, ax_, id_=id_, cmap=cmap, cmap_vmin_vmax=cmap_vmin_vmax, out_min_max=out_min_max, joint=joint, point_size=point_size, grid=grid, margin=margin, scale=scale, title=title, color_mix_mode=color_mix_mode)
for ax in ax_:
_set_axes_labels(ax, axes_labels)
if noticks:
ax.set_xticks([])
ax.set_yticks([])
if joint is None or joint:
ax_[0].set_title(title)
if on_data_legend is not None:
def weighted_median(df, value_col, weights_col):
df = df[[value_col, weights_col]].sort_values(value_col)
return df[value_col][df[weights_col].cumsum() >= 0.5 * df[weights_col].sum()].iloc[0]
def find_closest_point(center, coords, weights):
coords = coords.loc[weights >= weights.max() * 0.9].to_numpy() # only consider points with weights above 90% of the max weight
dists = utils.cdist(np.array(center).reshape((1,2)),coords)
                return coords[np.argmin(dists)]
for annotation in [c for c in df.columns if c not in ['x','y']]:
medians = []
for direction in ['x','y']:
medians.append(weighted_median(df, direction, annotation))
closest = find_closest_point(medians, df[['x','y']], df[annotation])
for ax in ax_:
ax.text(*closest, on_data_legend[annotation] if annotation in on_data_legend else annotation, ha='center', va='center')
return fig
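# A minimal, self-contained sketch of `spatial_distribution_plot`
# (hypothetical random data, not part of the library API): one "solution"
# with two annotation categories given as a pd.Series of DataFrames, plotted
# jointly with the non-rendering scatter backend.
def _demo_spatial_distribution_plot():
    rng = np.random.default_rng(42)
    coords = pd.DataFrame(rng.uniform(0, 1, size=(50, 2)), columns=['x', 'y'])
    weights = pd.DataFrame(rng.dirichlet([1.0, 1.0], size=50), columns=['A', 'B'])
    typing_data = pd.Series({'demo': weights})
    colors = pd.Series({'A': 'tab:red', 'B': 'tab:blue'})
    return spatial_distribution_plot(typing_data, coords, colors, joint=True, render=False)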
def write_results_to_adata(adata, typing_data, pfn_factors=None, pfn_key='platform_normalization_factors'):
if pfn_key in adata.varm:
pfns = adata.varm[pfn_key]
else:
pfns = pd.DataFrame(index=adata.var.index)
for typing, data in typing_data.items():
adata.obsm[typing] = data.reindex(adata.obs.index)
if pfn_factors is not None and typing in pfn_factors and pfn_factors[typing] is not None:
pfns[typing] = pfn_factors[typing]
if len(pfns.columns) > 0:
adata.varm[pfn_key] = pfns
def get_default_colors(n, offset=0):
"""\
Chooses default colors.
This is a convenience wrapper around :func:`seaborn.color_palette` which
    provides a quasi-periodicity of 10, i.e. colors that are 10 positions
    apart are related.
Parameters
----------
n
Number of colors to choose OR a list of keys to choose colors for.
offset
The number of chosen colors to skip before starting to use them. This
is useful for generating different sets of colors.
Returns
-------
If `n` is a number, returns a list of colors. If `n` is a list-like,\
returns a mapping of the elements of `n` to colors.
"""
if pd.api.types.is_list_like(n):
default_colors = get_default_colors(len(n), offset)
return {name:color for name,color in zip(n,default_colors)}
else:
default_colors = [ to_hex(c) for c in [*sns.color_palette("bright"),*sns.color_palette("deep"),*sns.color_palette("dark"),*sns.color_palette("pastel")] ]
default_colors *= ((n+offset) // len(default_colors) + 1)
return default_colors[offset:(n+offset)]
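# Usage sketch for `get_default_colors`: a number yields a list of hex
# colors, a list of keys yields a mapping from key to color.
def _demo_get_default_colors():
    as_list = get_default_colors(3)              # list of 3 hex strings
    as_mapping = get_default_colors(['A', 'B'])  # {'A': ..., 'B': ...}
    return as_list, as_mapping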
def mix_base_colors(weights, base_colors_rgb, mode='xyv'):
"""\
Mix colors "additively". In contrast to weighted averages over "rgb" values
    (which result in quite dark colors), the average can be done in "xyv"
space, which is "hsv" with the "hs" part converted from polar to cartesian
coordinates.
Parameters
----------
weights
A weight tensor with last dimension `n_base` describing mixtures of
`n_base` colors; this can be a :class:`~numpy.ndarray` or a
:class:`~pandas.DataFrame`.
base_colors_rgb
An `n_base X 3` matrix defining the base colors in rgb space; this must
be a :class:`~numpy.ndarray`.
mode
The mixing mode; available are:
- 'rgb': average the rgb values in the rgb cube
        - 'xyv': average the xyv values in the hsv cylinder
Returns
-------
Returns the color mixtures depending on the type of `weights` either as\
:class:`~numpy.ndarray` or :class:`~pandas.DataFrame`.
"""
weights_index = None
if hasattr(weights, 'to_numpy'):
weights_index = weights.index
weights = weights.to_numpy()
weights = weights / weights.sum(axis=-1)[...,None]
if mode == 'xyv':
base_colors_hsv = np.array([matplotlib.colors.rgb_to_hsv(base_color_rgb[:3]) for base_color_rgb in base_colors_rgb])
base_colors_xyv = np.array([
np.cos(base_colors_hsv[:,0] * 2 * np.pi) * base_colors_hsv[:,1],
np.sin(base_colors_hsv[:,0] * 2 * np.pi) * base_colors_hsv[:,1],
base_colors_hsv[:,2]
]).T
mixed_colors_xyv = weights @ base_colors_xyv
mixed_colors_hsv = np.stack([
np.arctan2(mixed_colors_xyv[...,1], mixed_colors_xyv[...,0])/(2 * np.pi),
np.sqrt(mixed_colors_xyv[...,0]**2 + mixed_colors_xyv[...,1]**2),
mixed_colors_xyv[...,2]
],axis=-1)
mixed_colors_hsv[mixed_colors_hsv[...,0]<0,0] += 1
mixed_colors_rgb = matplotlib.colors.hsv_to_rgb(mixed_colors_hsv)
if weights_index is not None:
mixed_colors_rgb = pd.DataFrame(mixed_colors_rgb, index=weights_index)
elif mode == 'rgb':
mixed_colors_rgb = weights @ base_colors_rgb
else:
raise ValueError(f'The mode "{mode}" is not implemented!')
return mixed_colors_rgb
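# A small sketch comparing the two mixing modes of `mix_base_colors`
# (hypothetical weights): an equal mix of red and blue gives a dark purple in
# "rgb" mode, but keeps full value (a light magenta) in "xyv" mode.
def _demo_mix_base_colors():
    weights = np.array([[0.5, 0.5]])
    base_colors_rgb = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    in_rgb = mix_base_colors(weights, base_colors_rgb, mode='rgb')  # [[0.5, 0.0, 0.5]]
    in_xyv = mix_base_colors(weights, base_colors_rgb, mode='xyv')
    return in_rgb, in_xyv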
def _filter_types(typing_data, types, colors, show_only):
if show_only is not None:
show_only = pd.Index(show_only)
if show_only.isin(types).all():
colors = colors[show_only]
types = colors.index
else:
raise Exception('Not all selected types %s are available in the data %s!' % (show_only, types))
typing_data = typing_data.map(lambda data: data.reindex(columns=types))
return typing_data, types, colors
def _get_colors(colors, types):
if colors is not None:
colors = pd.Series(colors)
if types.isin(colors.index).all():
#colors = colors[colors.index.intersection(types)]
types = colors.index
else:
raise Exception('Not all types %s are given colors with %s!' % (types, colors))
else:
colors = pd.Series(get_default_colors(len(types)), index=types)
return colors, types
def _get_adatas(adata):
if isinstance(adata, ad.AnnData):
adatas = pd.Series(index=[''],dtype=object)
adatas[''] = adata
elif isinstance(adata, pd.DataFrame):
adatas = pd.Series(index=[''],dtype=object)
adatas[''] = utils.dataframe2anndata(adata, None, None)
elif isinstance(adata, dict):
#adatas = pd.Series(adata) # it could be so simple - but it does not work for adatas...
adatas = pd.Series(index=adata.keys(),dtype=object)
for k,v in adata.items():
if isinstance(v, pd.DataFrame):
v = utils.dataframe2anndata(v, None, None)
adatas[k] = v
else:
adatas = adata
if (adatas.index.value_counts() != 1).any():
        raise ValueError('The series of adatas has non-unique indices: %s!' % adatas.index)
return adatas
def _validate_args(adata, keys, colors, show_only, reads=False, method_labels=None, counts_location=None, compositional=True):
adatas = _get_adatas(adata)
if isinstance(keys, str):
methods = [keys]
else:
methods = keys
try:
#methods = pd.Series({k:v for k,v in methods.items()}) # check basically whether methods is a dict or pd.Series (having an items method), mapping sample names to methods
# check basically whether methods is a dict or pd.Series (having an items method), mapping sample names to methods
_methods = {}
for k,v in methods.items():
if isinstance(v,str):
v = [v]
_methods[k] = v
methods = pd.Series(_methods)
    except AttributeError: # `methods` is a plain (list of) string(s) without an `items` method
        methods = pd.Series([methods]*len(adatas), index=adatas.index) # assume same methods available for all samples
typing_data = []
types = pd.Index([],dtype=object)
for sample, adata in adatas.items():
for method in methods[sample]:
if pd.api.types.is_list_like(method):
data = pd.DataFrame(index=adata.obs.index)
for element in method:
if element in adata.obs:
data[element] = adata.obs[element]
if not pd.api.types.is_numeric_dtype(data[element]):
                            raise Exception(f'The adata obs key {element} contains non-numeric data%s!' % ('' if sample == '' else (' for sample "%s"' % sample)))
elif element in adata.var.index:
_data = adata[:,element].X
if scipy.sparse.issparse(_data):
                            _data = _data.toarray()
data[element] = _data.flatten()
else:
raise Exception(f'The key {element} is neither in obs.columns nor in var.index%s!' % ('' if sample == '' else (' for sample "%s"' % sample)))
elif method in adata.obsm:
data = adata.obsm[method].copy()
elif method in adata.obs:
iscat = hasattr(adata.obs[method], 'cat')
number = pd.api.types.is_numeric_dtype(adata.obs[method])
if iscat or not number:
data = pd.get_dummies(adata.obs[method])
if data.shape[1] > 100:
                        print(f'More than 100 categories were found in column `.obs[{method}]`, which will probably lead to slow plotting!')
else:
data = pd.DataFrame({method:adata.obs[method]})
else:
                raise Exception(f'The method key {method} is neither an obs/obsm key nor a list-like%s!' % ('' if sample == '' else (' for sample "%s"' % sample)))
if compositional:
data *= (data>=0).to_numpy() # some methods may export negative values... (RCTD did that once)
data = (data / data.fillna(0).sum(axis=1).to_numpy()[:,None]).fillna(0)
if reads: # scale by read count
counts = get.counts(adata, counts_location=counts_location)[adata.obs.index]
if counts.shape[1] != 0: # ...but only if there is data for that
data *= np.array(counts.X.sum(axis=1)).flatten()[:,None]
if compositional == 'catsize':
sizes = get_cellsize(adata, data)
data /= sizes[data.columns].to_numpy()
data = (data / data.fillna(0).sum(axis=1).to_numpy()[:,None]).fillna(0)
types = types.union(data.columns,sort=False)
method_label = method if ((method_labels is None) or (method not in method_labels)) else method_labels[method]
if sample == '' and method_label == '':
name = ''
elif sample == '':
name = method_label
elif isinstance(method_label, str) and method_label == '':
name = sample
else:
name = f'{sample}; {method_label}'
typing_data.append((name, sample, method, data))
typing_data = pd.DataFrame(typing_data, columns=['name','sample','method','data']).set_index('name')
colors, types = _get_colors(colors, types)
    if compositional and len(types) < 2:
        print(f'`compositional==True`, but there were less than 2 categories: {types!r}')
typing_data['data'], types, colors = _filter_types(typing_data['data'], types, colors, show_only)
return typing_data, adatas, methods, types, colors
def _validate_scatter_args(adata, position_key, keys, colors, show_only, method_labels=None, counts_location=None, compositional=True):
typing_data, adatas, methods, types, colors = _validate_args(adata, keys, colors, show_only, method_labels=method_labels, counts_location=counts_location, compositional=compositional)
coords = {}
for sample, adata in adatas.items():
_coords = get.positions(adata, position_key)
_coords.columns = ['x','y']
coords[sample] = _coords#.rename(columns={_coords.columns[0]:'x',_coords.columns[1]:'y'})
return typing_data, adatas, methods, types, colors, coords
def subplots(
n_x=1,
n_y=1,
axsize=(5,5),
hspace=0.15,
wspace=0.15,
x_padding=None,
y_padding=None,
title=None,
sharex='none',
sharey='none',
width_ratios=None,
height_ratios=None,
x_shifts=None,
y_shifts=None,
dpi=None,
**kwargs,
):
"""\
Creates a new figure with a grid of subplots.
This is a convenience wrapper around :func:`matplotlib.pyplot.subplots`
with parameters for axis instead of figure and absolute instead of relative
units.
Parameters
----------
n_x
Number of plots in horizontal/x direction
n_y
Number of plots in vertical/y direction
axsize
Size of a single axis in the plot
hspace
Relative vertical spacing between plots
wspace
Relative horizontal spacing between plots
x_padding
Absolute horizontal spacing between plots; this setting overrides
`wspace`; if `None`, use the value from `wspace`
y_padding
Absolute vertical spacing between plots; this setting overrides
`hspace`; if `None`, use the value from `hspace`
title
Sets the figure suptitle
sharex
Parameter for sharing the x-axes between the subplots; see the
documentation of :func:`matplotlib.pyplot.subplots`
sharey
Parameter for sharing the y-axes between the subplots; see the
documentation of :func:`matplotlib.pyplot.subplots`
width_ratios
Sets the ratios of widths for the columns of the subplot grid keeping
the width of the widest column; if `None`, all columns have the same
width
height_ratios
Sets the ratios of heights for the rows of the subplot grid keeping the
height of the highest row; if `None`, all rows have the same height
x_shifts
The absolute shifts in position in horizontal/x direction per column of
subplots; if `None`, the columns are not shifted
y_shifts
The absolute shifts in position in vertical/y direction per row of
subplots; if `None`, the rows are not shifted
dpi
The dpi setting to use for this figure
**kwargs
Extra keyword arguments are forwarded to
:func:`matplotlib.pyplot.subplots`
Returns
-------
A pair of the created :class:`~matplotlib.figure.Figure` and an 2d\
:class:`~numpy.ndarray` of :class:`~matplotlib.axes.Axes`.
"""
if x_padding is not None:
wspace = x_padding / axsize[0]
if y_padding is not None:
hspace = y_padding / axsize[1]
if width_ratios is None:
effective_n_x = n_x
effective_wspace = wspace
else:
effective_n_x = sum(width_ratios) / max(width_ratios)
effective_wspace = wspace * n_x / effective_n_x
if height_ratios is None:
effective_n_y = n_y
effective_hspace = hspace
else:
effective_n_y = sum(height_ratios) / max(height_ratios)
effective_hspace = hspace * n_y / effective_n_y
# find the values of the figure size which will keep the axis size identical for all numbers of columns and rows...
fig_height = axsize[1] * (effective_n_y + hspace * (n_y - 1))
fig_width = axsize[0] * (effective_n_x + wspace * (n_x - 1))
top = 1.0
    if title is not None:
title_space = 0.75
fig_height += title_space
top = 1 - title_space / fig_height
if dpi is not None:
kwargs = {**kwargs}
kwargs['dpi'] = dpi
fig, axs = plt.subplots(n_y,n_x,figsize=(fig_width,fig_height), squeeze=False, sharex=sharex, sharey=sharey, gridspec_kw={'wspace':effective_wspace,'hspace':effective_hspace,'left':0,'right':1,'top':top,'bottom':0,'width_ratios': width_ratios,'height_ratios': height_ratios}, **kwargs)
if title is not None:
fig.suptitle(title, fontsize=16, y=1)
    if x_shifts is not None or y_shifts is not None:
        if x_shifts is None:
            x_shifts = [0.0] * n_x
        if y_shifts is None:
            y_shifts = [0.0] * n_y
for i_x in range(n_x):
for i_y in range(n_y):
[left,bottom,width,height] = axs[i_y,i_x].get_position().bounds
axs[i_y,i_x].set_position([left+x_shifts[i_x],bottom+y_shifts[i_y],width,height])
return fig, axs
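# Usage sketch for `subplots` (hypothetical sizes): a 3x2 grid of 4x3 inch
# axes with 0.5 inch of horizontal padding between the columns.
def _demo_subplots():
    fig, axs = subplots(n_x=3, n_y=2, axsize=(4, 3), x_padding=0.5)
    assert axs.shape == (2, 3)  # axs is indexed as [row, column]
    return fig, axs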
def _add_legend_or_colorbars(fig, axs, colors, cmap=None, min_max=None, scale_legend=1.0):
if cmap is None:
axs[0,-1].legend(handles=[mpatches.Patch(color=color, label=ind) for (ind, color) in colors.items() ],
bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
elif cmap is not None:
rel_dpi_factor = fig.get_dpi() / 72
height_pxl = 200 * rel_dpi_factor * scale_legend
width_pxl = 15 * rel_dpi_factor * scale_legend
offset_top_pxl = 0 * rel_dpi_factor * scale_legend
offset_left_pxl = 10 * rel_dpi_factor * scale_legend
for irow in range(axs.shape[0]):
for jcol in range(axs.shape[1]):
ax = axs[irow,jcol]
left,bottom = fig.transFigure.inverted().transform(ax.transAxes.transform((1,1))+np.array([offset_left_pxl,-offset_top_pxl-height_pxl]))
width,height = fig.transFigure.inverted().transform(fig.transFigure.transform((0,0))+np.array([width_pxl,height_pxl]))
cax = fig.add_axes((left, bottom, width, height))
norm = Normalize(vmin=min_max[irow,jcol,0], vmax=min_max[irow,jcol,1])
fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=cax)
def scatter(
adata,
keys,
position_key=('x','y'),
group_key=None,
colors=None,
show_only=None,
axsize=(5,5),
padding=0.5,
margin=0.0,
sharex=False,
sharey=False,
share_scaling=True,
n_cols=1,
joint=True,
method_labels=None,
counts_location=None,
compositional=False,
normalize=False,
point_size=3,
cmap=None,
cmap_vmin_vmax=None,
legend=True,
on_data_legend=None,
title=None,
render=True,
rasterized=True,
background_color=None,
grid=False,
noticks=False,
axes_labels=None,
ax=None,
):
"""\
Scatter plots of annotation.
Parameters
----------
adata
An :class:`~anndata.AnnData` including annotation in `.obs` and/or
`.obsm`. Can also be a mapping of labels to :class:`~anndata.AnnData`
to specify multiple datasets. The :class:`~anndata.AnnData` instances
can be replaced also by :class:`~pandas.DataFrame`, which are then
treated like the `.obs` of an :class:`~anndata.AnnData`.
keys
The `.obs`/`.obsm` annotation keys to compare. Can be a single string,
or a list of strings, or a mapping of the labels of `adata` to strings
or lists of strings. In the list or mapping variant, categorical `.obs`
keys can be replaced by list-likes of numerical `.obs` keys or gene
names can be used.
position_key
The `.obsm` key or array-like of `.obs` keys with the position space
coordinates.
group_key
An `.obs` key with categorical group information to split `adata` prior
to plotting. This works only if `adata` is a single
:class:`~anndata.AnnData` instance.
colors
A mapping of annotation values to colors. If `None`, default colors are
used.
show_only
A subset of annotation values to restrict the plotting to.
axsize
Tuple of width and height of a single axis. If one of them is `None`,
it is determined from the aspect ratio of the data. If it is a single
scalar value, then this is interpreted as a conversion factor from
data units to axsize units and `share_scaling` is ignored.
padding
The absolute padding between the plots.
margin
The absolute margin between the outermost data points and the boundary
of the plot
sharex
Whether to use common x axis for all axes.
sharey
Whether to use common y axis for all axes.
    share_scaling
        Whether to have the data units in all plots map to the same size in
        pixels.
n_cols
Number of "columns" to plot: If larger than 1 splits columns of plots
into `n_cols` columns.
joint
Whether to plot only one scatter plot with all annotation categories or
only the scatter plots with one annotation category per plot. If
`None`, plot both.
    method_labels
        A mapping from the strings in `keys` to (shorter) names to show in
        the plot.
counts_location
A string or tuple specifying where the count matrix is stored, e.g.
`'X'`, `('raw','X')`, `('raw','obsm','my_counts_key')`,
`('layer','my_counts_key')`, ... For details see
:func:`~tacco.get.counts`.
compositional
Whether the annotation is to be interpreted as compositional data or as
arbitrary numbers. Compositional data is normalized to sum to 1 per
observation. Can also be 'catsize', which rescales the compositions by
the average observed size of an annotation category in terms of the
contents of `.X`, e.g. reads.
normalize
Whether to shift the data to non-negative values and normalize them by
their maximum.
point_size
The size of the points in the plot. Like in matplotlib, this is a
measure of point area and provided in units of "squared points"
corresponding to (1/72)^2 inch^2 = (25.4/72)^2 mm^2.
cmap
A string/colormap to override the `colors` with. Makes sense mostly for
numeric data.
cmap_vmin_vmax
A tuple giving the range of values for the colormap.
legend
Whether to plot a legend
on_data_legend
A mapping from annotation values to (shortened) versions of the labels
to use for labels on the plot at the center of the annotation;
annotations not occurring in the mapping are used as is; if `None`, no
annotation is plotted on the data.
title
The title of the figure
    render
        Whether the scatter plot should be custom rendered using the dpi
        setting from matplotlib or plotted using matplotlib's scatter
        function. If `True`, annotations from the same (and overlapping)
        positions are added up symmetrically; if `False`, they are plotted on
        top of each other using an alpha channel proportional to the weight.
        `True` also has the advantage that only the scatter part of the
        figure is exported as a pixelated version if the plot is exported as
        a vector graphic, while the rest, like labels and axes, remains
        vectorized. If a string, it selects the type of color averaging used
        in the rendering by specifying one of the modes available in
        :func:`~tacco.plots.mix_base_colors`, e.g. "xyv" or "rgb", with "xyv"
        being equivalent to `True`.
    rasterized
        Whether to rasterize the interior of the plot, even when it is later
        exported as a vector graphic. This leads to much smaller files for
        many (data) points. `rasterized==False` is incompatible with `render`
        being `True` or a string.
        This parameter also provides experimental support for plotting pie
        charts per dot via `rasterized=="pie"` together with `render==False`.
        This is much slower, so it is only usable for very few points.
background_color
The background color to draw the points on.
grid
Whether to draw a grid
noticks
Whether to switch off ticks on the axes.
axes_labels
Labels to write on the axes as an list-like of the two labels.
ax
The 2d array of :class:`~matplotlib.axes.Axes` instances to plot on.
The array dimensions have to agree with the number of axes which would
be created automatically if `ax` was not supplied. If it is a single
instance it is treated as a 1x1 array.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
"""
if group_key is not None:
if not isinstance(adata, ad.AnnData):
raise ValueError(f'The `group_key` {group_key!r} is not `None`, but `adata` is not a single `AnnData` instance!')
if group_key not in adata.obs:
raise ValueError(f'The `group_key` {group_key!r} is not available in `adata.obs`!')
adata = { c: adata[df.index] for c, df in adata.obs.groupby(group_key, observed=False) if len(df) > 0 }
typing_data, adatas, methods, types, colors, coords = _validate_scatter_args(adata, position_key, keys, colors, show_only, method_labels=method_labels, counts_location=counts_location, compositional=compositional)
n_solutions, n_samples, n_types = len(typing_data), len(adatas), len(types)
if joint is None:
n_types = n_types + 1
elif joint:
n_types = 1
n_cols = 1
n_x = n_solutions*n_cols
n_y = (n_types + n_cols - 1) // n_cols
axsize = np.array(axsize)
scale = None
# special treatment of various semi-automatic axsizes
if len(axsize.shape) == 0 or axsize[0] is None or axsize[1] is None or ax is not None or share_scaling:
minxs, maxxs = [], []
minys, maxys = [], []
for i_sample, sample in enumerate(typing_data['sample'].unique()):
minx,maxx = get_min_max(coords[sample].iloc[:,0], log=False)
miny,maxy = get_min_max(coords[sample].iloc[:,1], log=False)
minxs.append(minx), maxxs.append(maxx)
minys.append(miny), maxys.append(maxy)
minxs, maxxs = np.array(minxs), np.array(maxxs)
minys, maxys = np.array(minys), np.array(maxys)
sizexs = maxxs-minxs
sizeys = maxys-minys
maxsizex = sizexs.max()
maxsizey = sizeys.max()
minx, maxx = minxs.min(), maxxs.max()
miny, maxy = minys.min(), maxys.max()
sizex = maxx-minx
sizey = maxy-miny
if ax is not None: # If `ax` is given, use it. This also ensures that axs is always defined below
if isinstance(ax, matplotlib.axes.Axes):
axs = np.array([[ax]])
else:
axs = ax
if axs.shape != (n_y,n_x):
raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(n_y,n_x)!r} supplied was {axs.shape!r}!')
axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size
elif len(axsize.shape) == 0: # if it is a single scalar value, use that as a global scale
# this also implies that all axes have to have a common scaling
if not (sharex or sharey):
share_scaling = True
if sharex:
axsizex = sizex * axsize
else:
axsizex = maxsizex * axsize
if sharey:
axsizey = sizey * axsize
else:
axsizey = maxsizey * axsize
axsize = np.array([axsizex, axsizey])
elif axsize[0] is None or axsize[1] is None:
# determine the missing axsize
if axsize[0] is None and axsize[1] is None:
raise ValueError(f'The parameter `axsize` got `(None, None)`, while only one of the entries of the tuple can be `None`!')
aspect_ratio = (sizey if sharey else sizeys) / (sizex if sharex else sizexs)
if not (sharex and sharey): # if aspect_ratio is not yet fixed and still a vector
#aspect_ratio = max(aspect_ratio)
#aspect_ratio = min(aspect_ratio)
#np.argmax(np.abs(aspect_ratio - 1))
if sharex: # x is fixed: make y as large as needed
aspect_ratio = max(aspect_ratio)
elif sharey: # y is fixed: make x as large as needed
aspect_ratio = min(aspect_ratio)
else:
aspect_ratio = aspect_ratio.mean() # we are free to take the one which fits best - in some arbitrary sense...
if axsize[0] is None:
axsize[0] = axsize[1] / aspect_ratio
else: # axsize[1] is None
axsize[1] = axsize[0] * aspect_ratio
if share_scaling:
scale_0 = (axsize[0]-2*margin) / maxsizex
scale_1 = (axsize[1]-2*margin) / maxsizey
scale = min(scale_0, scale_1)
aspect_ratio = axsize[1] / axsize[0]
wspace, hspace = padding / axsize
if ax is None:
fig, axs = subplots(n_x,n_y,axsize=axsize,sharex=sharex,sharey=sharey,wspace=wspace,hspace=hspace,title=title)
else:
fig = axs[0,0].get_figure()
min_max = np.zeros((*axs.shape,2))
for i_sample, sample in enumerate(typing_data['sample'].unique()):
sub = typing_data[typing_data['sample']==sample]
n_methods = len(sub)
ax = axs[:,(i_sample*n_methods*n_cols):((i_sample+1)*n_methods*n_cols)]
_min_max = None if cmap is None else min_max[:,(i_sample*n_methods*n_cols):((i_sample+1)*n_methods*n_cols)]
spatial_distribution_plot(sub['data'], coords[sample], colors, axs=ax, n_cols=n_cols, joint=joint, normalize=normalize, point_size=point_size, cmap=cmap, cmap_vmin_vmax=cmap_vmin_vmax, out_min_max=_min_max, scale=scale, grid=grid, margin=margin, render=render, rasterized=rasterized, noticks=noticks, axes_labels=axes_labels, on_data_legend=on_data_legend)
if background_color is not None:
for _ax in ax.flatten():
_ax.set_facecolor(background_color)
if share_scaling:
fig.canvas.draw() # make the transformations have well defined values
ax_heights = []
ax_widths = []
for _ax in axs.flatten():
ax_x_low,ax_y_low = _ax.transData.inverted().transform(_ax.transAxes.transform((0,0)))
ax_x_high,ax_y_high = _ax.transData.inverted().transform(_ax.transAxes.transform((1,1)))
ax_height,ax_width = (ax_y_high-ax_y_low),(ax_x_high-ax_x_low)
ax_heights.append(ax_height)
ax_widths.append(ax_width)
max_height = max(ax_heights)
max_width = max(ax_widths)
for _ax in axs.flatten():
ax_x_low,ax_y_low = _ax.transData.inverted().transform(_ax.transAxes.transform((0,0)))
ax_x_high,ax_y_high = _ax.transData.inverted().transform(_ax.transAxes.transform((1,1)))
delta = max_height - (ax_y_high - ax_y_low)
ax_y_high += delta / 2
ax_y_low -= delta / 2
_ax.set_ylim((ax_y_low,ax_y_high))
delta = max_width - (ax_x_high - ax_x_low)
ax_x_high += delta / 2
ax_x_low -= delta / 2
_ax.set_xlim((ax_x_low,ax_x_high))
fig.canvas.draw() # make the transformations have well defined values
# make axes use the full axsize
for _ax in axs.flatten():
ax_x_low,ax_y_low = _ax.transData.inverted().transform(_ax.transAxes.transform((0,0)))
ax_x_high,ax_y_high = _ax.transData.inverted().transform(_ax.transAxes.transform((1,1)))
axes_ratio = (ax_y_high-ax_y_low)/(ax_x_high-ax_x_low)
if aspect_ratio > axes_ratio:
# increase height
delta = (aspect_ratio / axes_ratio - 1) * (ax_y_high - ax_y_low)
ax_y_high += delta / 2
ax_y_low -= delta / 2
_ax.set_ylim((ax_y_low,ax_y_high))
else:
# increase width
delta = (axes_ratio / aspect_ratio - 1) * (ax_x_high - ax_x_low)
ax_x_high += delta / 2
ax_x_low -= delta / 2
_ax.set_xlim((ax_x_low,ax_x_high))
if legend:
_add_legend_or_colorbars(fig, axs, colors, cmap=cmap, min_max=min_max)
return fig
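# A minimal, self-contained sketch of `scatter` (synthetic random data, not
# part of the library API): positions are taken from the 'x'/'y' obs columns
# and each point is colored by its categorical cluster label.
def _demo_scatter():
    rng = np.random.default_rng(42)
    obs = pd.DataFrame({
        'x': rng.uniform(0, 1, 100),
        'y': rng.uniform(0, 1, 100),
        'cluster': pd.Categorical(rng.choice(['A', 'B'], 100)),
    })
    adata = ad.AnnData(X=np.ones((100, 1)), obs=obs)
    return scatter(adata, keys='cluster', position_key=('x', 'y'))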
def get_cellsize(adata, key='OTT', reference_adata=None, reference_key='OTT', pfn_key=None, counts_location=None):
if isinstance(key, pd.DataFrame):
solution = key
elif key in adata.obsm:
solution = adata.obsm[key]
else:
solution = pd.get_dummies(adata.obs[key])
if reference_adata is not None:
counts = get.counts(adata, counts_location=counts_location)[adata.obs.index,adata.var.index]
pfn = adata.var.index.isin(reference_adata.var.index).astype(float) # basic level of rescaling: using common genes only
bare_read_count = pd.Series(np.array(counts.X @ pfn).flatten(), index=adata.obs.index)
#print('get_cell_size0:',pfn.sum())
if pfn_key is not None:
if pfn_key in adata.varm and key in adata.varm[pfn_key]:
pfn = pfn * adata.varm[pfn_key][key].to_numpy()
#print('multiplying', adata.varm[pfn_key][key].to_numpy())
if pfn_key in reference_adata.varm and reference_key in reference_adata.varm[pfn_key]:
pfn = pfn / reference_adata.varm[pfn_key][reference_key].reindex(adata.var.index).to_numpy()
#print('dividing', reference_adata.varm[pfn_key][reference_key].reindex(adata.var.index).to_numpy())
#print('get_cell_size1:',pfn.sum())
pfn = np.nan_to_num(pfn).astype(float)
#print('get_cell_size2:',pfn.sum())
read_count = pd.Series(np.array(counts.X @ pfn).flatten(), index=adata.obs.index)
read_count *= bare_read_count.sum() / read_count.sum()
else:
counts = get.counts(adata, counts_location=counts_location)[adata.obs.index]
read_count = pd.Series(np.array(counts.X.sum(axis=1)).flatten(), index=adata.obs.index)
by_cells = solution / solution.sum(axis=1).to_numpy()[:,None]
by_reads = by_cells * read_count[solution.index].to_numpy()[:,None]
cellsize = by_reads.sum(axis=0) / by_cells.sum(axis=0).to_numpy()
return cellsize
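# A minimal sketch of `get_cellsize` (synthetic counts, assuming the default
# `counts_location` finds the counts in `adata.X`): the average number of
# counts per unit cell of each annotation category, as a pd.Series indexed by
# category.
def _demo_get_cellsize():
    rng = np.random.default_rng(42)
    adata = ad.AnnData(X=rng.poisson(5.0, size=(30, 8)).astype(float))
    adata.obs['label'] = pd.Categorical(rng.choice(['A', 'B', 'C'], 30))
    return get_cellsize(adata, key='label')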
def _validate_cross_args(adata, keys, colors, show_only, basis_adata, basis_keys, reads, method_labels, counts_location=None):
typing_data, adatas, methods, types, colors = _validate_args(adata, keys, colors, show_only, reads=reads, method_labels=method_labels, counts_location=counts_location)
if basis_adata is None:
basis_adata = adata
if basis_keys is None:
basis_keys = keys
basis_typing_data, basis_adatas, basis_methods, basis_types, basis_colors = _validate_args(basis_adata, basis_keys, colors, show_only, reads=reads, method_labels=method_labels, counts_location=counts_location)
if len(types.intersection(basis_types)) != len(types):
        raise Exception('The available types provided in adata %s and basis adata %s do not agree!' % (types, basis_types))
return typing_data, adatas, methods, types, colors, basis_typing_data, basis_adatas, basis_methods
def _cross_scatter(data, axs, colors, marker='o'):
if data.shape != axs.shape:
        raise Exception('The shapes of the data and axs arrays %s and %s do not agree!' % (data.shape, axs.shape))
n_y, n_x = data.shape
for x in range(n_x):
for y in range(n_y):
x_data,y_data = data.iloc[y,x]
ax = axs[y,x]
ax.grid(True)
common = colors.index
common = common.intersection(x_data[x_data > 0].index)
common = common.intersection(y_data[y_data > 0].index)
_scatter_plot(ax,x=x_data[common],y=y_data[common],
colors=colors[common],
marker=marker, log=True)
if y == n_y-1:
ax.set_xlabel(data.columns[x],rotation=15,va='top',ha='right')
ax.tick_params(axis='x', rotation=45, which='both')
if x == 0:
ax.set_ylabel(data.index[y],rotation='horizontal',ha='right')
axs[0,n_x-1].legend(handles=[mpatches.Patch(color=color, label=ind) for (ind, color) in colors.items() ],
bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
def cellsize(
adata,
keys,
colors=None,
show_only=None,
axsize=(1.5,1.5),
basis_adata=None,
basis_keys=None,
pfn_key=None,
use_reference=False,
method_labels=None,
counts_location=None
):
"""\
Scatter plots of average cell sizes in an annotation category in the whole
dataset against some reference dataset. These cell sizes are given as the
average number of counts per unit cell.
Parameters
----------
adata
An :class:`~anndata.AnnData` including annotation in `.obs` and/or `.obsm`.
Can also be a mapping of labels to :class:`~anndata.AnnData` to specify
multiple datasets.
keys
The `.obs`/`.obsm` annotation keys to compare. Can be a single string, or a
list of strings, or a mapping of the labels of `adata` to strings or lists
of strings.
colors
A mapping of annotation values to colors. If `None`, default colors are used.
show_only
A subset of annotation values to restrict the plotting to.
axsize
        Tuple of width and height of a single axis.
basis_adata
like `adata`, but for reference data.
basis_keys
like `adata`, but for reference data.
pfn_key
A `.varm` key containing platform normalization factors. Ignored if platform
normalization is not requested via `use_reference`.
use_reference
Whether and how to use reference data for the determination of cell sizes in
non-reference data.
Can be a single choice or an array of choices. Possible choices are:
        - `False`: Don't use the reference
        - `True`: Determine the cell sizes on the set of common genes with the
          reference data
        - 'pfn': Like `True`, but additionally use platform normalization
          factors.
        - `None`: Use all settings which are available, i.e. this includes
          `True`, `False`, and, if `pfn_key!=None`, also 'pfn'.
method_labels
A mapping from the strings in `keys` and `basis_keys` to (shorter) names
to show in the plot.
counts_location
A string or tuple specifying where the count matrix is stored, e.g. `'X'`,
`('raw','X')`, `('raw','obsm','my_counts_key')`, `('layer','my_counts_key')`,
... For details see :func:`~tacco.get.counts`.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
"""
typing_data, adatas, methods, types, colors, basis_typing_data, basis_adatas, basis_methods = _validate_cross_args(adata, keys, colors, show_only, basis_adata, basis_keys, reads=False, method_labels=method_labels, counts_location=counts_location)
n_solutions, n_samples, n_types = len(typing_data), len(adatas), len(types)
n_solutions_basis, n_samples_basis = len(basis_typing_data), len(basis_adatas)
possible_use_refs = [False,True,'pfn']
if isinstance(use_reference, (str,bool)):
use_reference = [use_reference]
elif use_reference is None:
if pfn_key is None:
use_reference = [False,True]
else:
use_reference = possible_use_refs
if 'pfn' in use_reference and pfn_key is None:
raise Exception('If platform normalization is desired (as indicated by specifying "pfn" in the "use_reference" argument), the argument "pfn_key" has to be specified!')
use_ref_labels = []
for use_ref in use_reference:
if use_ref not in possible_use_refs:
raise Exception('The argument use_reference got the unknown value "%s"! Only %s are possible.' % (use_ref, possible_use_refs))
if isinstance(use_ref,str):
if use_ref == 'pfn':
use_ref_label = ' - platform corrected'
elif use_ref: # ignore pfn for cellsize estimation
use_ref_label = ' - common genes'
else: # ignore reference for cellsize estimation
use_ref_label = ''
use_ref_labels.append(use_ref_label)
n_x = n_solutions
n_y = n_solutions_basis * len(use_reference)
fig, axs = subplots(n_x,n_y,axsize=axsize,sharex='all',sharey='all',wspace=0,hspace=0)
data = pd.DataFrame(dtype=object,columns=typing_data.index,index=[basis_typing_data.index[y % n_solutions_basis] + use_ref_labels[y // n_solutions_basis] for y in range(n_y)])
for x in range(n_x):
for y in range(n_y):
x_key = typing_data['method'].iloc[x]
x_adata = adatas[typing_data['sample'].iloc[x]]
y_key = basis_typing_data['method'].iloc[y % n_solutions_basis]
y_adata = basis_adatas[basis_typing_data['sample'].iloc[y % n_solutions_basis]]
use_ref = use_reference[y // n_solutions_basis]
if isinstance(use_ref,str):
if use_ref == 'pfn':
cell_sizes_x = get_cellsize(x_adata, x_key, reference_adata=y_adata, reference_key=y_key, counts_location=counts_location)#, pfn_key=None) # no platform normalization effects, only on common genes
cell_sizes_y = get_cellsize(y_adata, y_key, reference_adata=x_adata, reference_key=x_key, counts_location=counts_location, pfn_key=pfn_key) # include platform normalization effects for both adatas in these size estimations
elif use_ref: # ignore pfn for cellsize estimation
                cell_sizes_x = get_cellsize(x_adata, x_key, reference_adata=y_adata, counts_location=counts_location) # no platform normalization effects, only on common genes
                cell_sizes_y = get_cellsize(y_adata, y_key, reference_adata=x_adata, counts_location=counts_location) # no platform normalization effects, only on common genes
else: # ignore reference for cellsize estimation
cell_sizes_x = get_cellsize(x_adata, x_key, counts_location=counts_location)
cell_sizes_y = get_cellsize(y_adata, y_key, counts_location=counts_location)
data.iloc[y,x] = (cell_sizes_x,cell_sizes_y)
_cross_scatter(data, axs, colors, marker='x')
return fig
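# A minimal sketch of the `cellsize` plot (synthetic data, not part of the
# library API): with the default `use_reference=False` and without a separate
# `basis_adata`, this scatters the per-category cell sizes of one labeling
# against themselves.
def _demo_cellsize():
    rng = np.random.default_rng(42)
    adata = ad.AnnData(X=rng.poisson(5.0, size=(40, 8)).astype(float))
    adata.obs['label'] = pd.Categorical(rng.choice(['A', 'B', 'C'], 40))
    return cellsize(adata, keys='label')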
def _validate_frequency_args(adata, keys, colors, show_only, basis_adata, basis_keys, reads, method_labels=None, counts_location=None):
typing_data, adatas, methods, types, _colors = _validate_args(adata, keys, colors, show_only, reads, method_labels=method_labels, counts_location=counts_location)
if basis_adata is not None:
basis_typing_data, basis_adatas, basis_methods, basis_types, _colors = _validate_args(basis_adata, basis_keys, colors, show_only, reads, method_labels=method_labels, counts_location=counts_location)
if len(types.intersection(basis_types)) != len(types):
            raise Exception('The available types provided in adata %s and basis adata %s do not agree!' % (types, basis_types))
typing_data = pd.concat([basis_typing_data, typing_data])
return typing_data, adatas, methods, types, _colors
def frequency_bar(
adata,
keys,
colors=None,
show_only=None,
axsize=None,
basis_adata=None,
basis_keys=None,
horizontal=True,
reads=False,
method_labels=None,
counts_location=None,
ax=None,
):
"""\
Bar plots of the total frequency of annotation in the whole dataset against
some reference dataset.
Parameters
----------
adata
An :class:`~anndata.AnnData` including annotation in `.obs` and/or
`.obsm`. Can also be a mapping of labels to :class:`~anndata.AnnData`
to specify multiple datasets. The :class:`~anndata.AnnData` instances
can be replaced also by :class:`~pandas.DataFrame`, which are then
treated like the `.obs` of an :class:`~anndata.AnnData`.
keys
The `.obs`/`.obsm` annotation keys to compare. Can be a single string,
or a list of strings, or a mapping of the labels of `adata` to strings
or lists of strings.
colors
A mapping of annotation values to colors. If `None`, default colors are
used.
show_only
A subset of annotation values to restrict the plotting to.
axsize
        Tuple of width and height of a single axis. If `None`, a default size
        is chosen automatically.
basis_adata
like `adata`, but for reference data.
basis_keys
like `adata`, but for reference data.
horizontal
Whether to draw the bars horizontally.
reads
Whether to work with read or cell count fractions.
method_labels
A mapping from the strings in `keys` and `basis_keys` to (shorter)
names to show in the plot.
counts_location
A string or tuple specifying where the count matrix is stored, e.g.
`'X'`, `('raw','X')`, `('raw','obsm','my_counts_key')`,
`('layer','my_counts_key')`, ... For details see
:func:`~tacco.get.counts`.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting.
Returns
-------
The :class:`~matplotlib.figure.Figure` containing the plot.
"""
typing_data, adatas, methods, types, colors = _validate_frequency_args(adata, keys, colors, show_only, basis_adata, basis_keys, reads, method_labels=method_labels, counts_location=counts_location)
n_solutions = len(typing_data)
if ax is None:
if axsize is None:
if horizontal:
axsize=(9.1,0.7*n_solutions)
else:
axsize=(0.7*n_solutions,5)
fig, axs = subplots(axsize=axsize)
ax = axs[0,0]
else:
fig = ax.figure
type_freqs = pd.DataFrame({ method: data.sum(axis=0) for method, data in typing_data['data'].items() })
norm_tf = (type_freqs/type_freqs.sum(axis=0)).T
norm_tf = norm_tf.fillna(0)
_composition_bar(norm_tf, colors, horizontal=horizontal, ax=ax)
return fig
def frequency(
adata,
keys,
colors=None,
show_only=None,
axsize=(1.5,1.5),
basis_adata=None,
basis_keys=None,
reads=False,
method_labels=None,
counts_location=None
):
"""\
Scatter plots of the total frequency of annotation in the whole dataset
against some reference dataset.
Parameters
----------
adata
An :class:`~anndata.AnnData` including annotation in `.obs` and/or
`.obsm`. Can also be a mapping of labels to :class:`~anndata.AnnData`
to specify multiple datasets. The :class:`~anndata.AnnData` instances
can be replaced also by :class:`~pandas.DataFrame`, which are then
treated like the `.obs` of an :class:`~anndata.AnnData`.
keys
The `.obs`/`.obsm` annotation keys to compare. Can be a single string,
or a list of strings, or a mapping of the labels of `adata` to strings
or lists of strings.
colors
A mapping of annotation values to colors. If `None`, default colors are
used.
show_only
A subset of annotation values to restrict the plotting to.
axsize
Tuple of width and height of a single axis.
basis_adata
like `adata`, but for reference data.
basis_keys
like `keys`, but for reference data.
reads
Whether to work with read or cell count fractions.
method_labels
A mapping from the strings in `keys` and `basis_keys` to (shorter)
names to show in the plot.
counts_location
A string or tuple specifying where the count matrix is stored, e.g.
`'X'`, `('raw','X')`, `('raw','obsm','my_counts_key')`,
`('layer','my_counts_key')`, ... For details see
:func:`~tacco.get.counts`.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
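Examples
--------
A minimal usage sketch with hypothetical data and keys, assuming the
`tc.pl` alias for :mod:`tacco.plots`:

>>> import tacco as tc
>>> # scatter total annotation frequencies of `adata` against a reference
>>> fig = tc.pl.frequency(adata, keys='predicted_type',
...     basis_adata=reference, basis_keys='cell_type')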
"""
typing_data, adatas, methods, types, colors, basis_typing_data, basis_adatas, basis_methods = _validate_cross_args(adata, keys, colors, show_only, basis_adata, basis_keys, reads=reads, method_labels=method_labels, counts_location=counts_location)
n_solutions, n_samples, n_types = len(typing_data), len(adatas), len(types)
n_solutions_basis, n_samples_basis = len(basis_typing_data), len(basis_adatas)
n_x = n_solutions
n_y = n_solutions_basis
fig, axs = subplots(n_x,n_y,axsize=axsize,sharex='all',sharey='all',wspace=0,hspace=0)
type_freqs = pd.DataFrame({ method: data.sum(axis=0) for method, data in typing_data['data'].items() })
type_freqs = type_freqs / type_freqs.sum(axis=0) # normalize freqs
basis_type_freqs = pd.DataFrame({ method: data.sum(axis=0) for method, data in basis_typing_data['data'].items() })
basis_type_freqs = basis_type_freqs / basis_type_freqs.sum(axis=0) # normalize freqs
data = pd.DataFrame(dtype=object,columns=typing_data.index,index=basis_typing_data.index)
for x in range(n_x):
for y in range(n_y):
x_data = type_freqs.iloc[:,x]
y_data = basis_type_freqs.iloc[:,y]
data.iloc[y,x] = (x_data,y_data)
_cross_scatter(data, axs, colors, marker='o')
return fig
def comparison(
adata,
keys,
colors=None,
show_only=None,
axsize=(2.0,2.0),
basis_adata=None,
basis_keys=None,
method_labels=None,
counts_location=None,
point_size=2,
joint=None
):
"""\
Scatterplots of the annotation fractions of different annotations, e.g. of
different annotation methods or of methods and a ground truth.
Parameters
----------
adata
An :class:`~anndata.AnnData` including annotation in `.obs` and/or `.obsm`.
Can also be a mapping of labels to :class:`~anndata.AnnData` to specify
multiple datasets.
keys
The `.obs`/`.obsm` annotation keys to compare. Can be a single string, or a
list of strings, or a mapping of the labels of `adata` to strings or lists
of strings.
colors
A mapping of annotation values to colors. If `None`, default colors are used.
show_only
A subset of annotation values to restrict the plotting to.
axsize
Tuple of width and height of a single axis.
basis_adata
like `adata`, but for reference data.
basis_keys
like `keys`, but for reference data.
method_labels
A mapping from the strings in `keys` and `basis_keys` to (shorter) names
to show in the plot.
counts_location
A string or tuple specifying where the count matrix is stored, e.g. `'X'`,
`('raw','X')`, `('raw','obsm','my_counts_key')`, `('layer','my_counts_key')`,
... For details see :func:`~tacco.get.counts`.
point_size
The size of the points in the plot.
joint
Whether to plot only one scatter plot with all annotation categories or only
the scatter plots with one annotation category per plot. If `None`, plot both.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
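Examples
--------
A minimal usage sketch with hypothetical keys; here the same `adata`
carries both the method annotation and the ground truth, and `tc.pl` is
assumed as the alias for :mod:`tacco.plots`:

>>> import tacco as tc
>>> # per-category and joint scatter plots of a method against the truth
>>> fig = tc.pl.comparison(adata, keys='predicted_type',
...     basis_adata=adata, basis_keys='true_type')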
"""
typing_data, adatas, methods, types, colors, basis_typing_data, basis_adatas, basis_methods = _validate_cross_args(adata, keys, colors, show_only, basis_adata, basis_keys, reads=False, method_labels=method_labels, counts_location=counts_location)
n_solutions, n_samples, n_types = len(typing_data), len(adatas), len(types)
n_solutions_basis, n_samples_basis = len(basis_typing_data), len(basis_adatas)
n_types = len(colors.index)
n_x = n_solutions
if joint is None:
n_y = n_types + 1
elif joint:
n_y = 1
else:
n_y = n_types
n_y = n_solutions_basis * n_y
fig, axs = subplots(n_x,n_y,axsize=axsize,sharex='all',sharey='all',wspace=0.25,hspace=0.25)
for i,basis_name in enumerate(basis_typing_data.index):
basis_sample, basis_type_freqs = basis_typing_data.loc[basis_name,['sample', 'data']]
basis_type_freqs /= basis_type_freqs.sum(axis=1).to_numpy()[:,None]
for j,name in enumerate(typing_data.index):
sample, type_freqs = typing_data.loc[name,['sample', 'data']]
type_freqs /= type_freqs.sum(axis=1).to_numpy()[:,None]
common = type_freqs.index.intersection(basis_type_freqs.index)
type_freqs = type_freqs.loc[common]
_basis_type_freqs = basis_type_freqs.loc[common]
def plotit(t, ax):
x = type_freqs[t]
y = _basis_type_freqs[t]
ax.scatter(x, y, color=colors[t], s=point_size, label=t, edgecolors='none')
if joint is None:
joint_ax = axs[i * (n_types + 1),j]
for it,t in enumerate(colors.index):
plotit(t, joint_ax)
ax = axs[i * (n_types + 1) + it + 1, j]
plotit(t, ax)
pc, x2 = _correlations(type_freqs[t], _basis_type_freqs[t], log=False)
#ax.set_title('%s\nr=%.2f $\\chi^2_m$=%.2f'%(t, pc, x2))
ax.set_title('r=%.2f $\\chi^2_m$=%.2f'%(pc, x2))
pc, x2 = _correlations(type_freqs.to_numpy().flatten(), _basis_type_freqs.to_numpy().flatten(), log=False)
joint_ax.set_title('r=%.2f $\\chi^2_m$=%.2f'%(pc, x2))
elif joint:
joint_ax = axs[i,j]
for t in colors.index:
plotit(t, joint_ax)
pc, x2 = _correlations(type_freqs.to_numpy().flatten(), _basis_type_freqs.to_numpy().flatten(), log=False)
joint_ax.set_title('r=%.2f $\\chi^2_m$=%.2f'%(pc, x2))
else:
for it,t in enumerate(colors.index):
ax = axs[i * n_types + it, j]
plotit(t, ax)
pc, x2 = _correlations(type_freqs[t], _basis_type_freqs[t], log=False)
#ax.set_title('%s\nr=%.2f $\\chi^2_m$=%.2f'%(t, pc, x2))
ax.set_title('r=%.2f $\\chi^2_m$=%.2f'%(pc, x2))
for ax,name in zip(axs[n_y-1,:],typing_data.index):
ax.set_xlabel(name,rotation=15,va='top',ha='right')
#ax.tick_params(axis='x', rotation=45, which='both')
for ax,basis_name in zip(axs[:,0],list(basis_typing_data.index) * (n_y//n_solutions_basis)):
ax.set_ylabel(basis_name,rotation='horizontal',ha='right')
axs[0,n_x-1].legend(handles=[mpatches.Patch(color=color, label=ind) for (ind, color) in colors.items() ],
bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
return fig
def compositions(
adata,
value_key,
group_key,
basis_adata=None,
basis_value_key=None,
basis_group_key=None,
fillna=None,
restrict_groups=None,
restrict_values=None,
basis_restrict_groups=None,
basis_restrict_values=None,
reads=False,
colors=None,
horizontal=False,
axsize=None,
ax=None,
legend=True,
):
"""\
Plot compositions of groups. In contrast to :func:`~tacco.plots.contribution`, compositions
have to add up to one.
Parameters
----------
adata
An :class:`~anndata.AnnData` with annotation in `.obs`.
value_key
The `.obs` or `.obsm` key with the values to determine the
enrichment for.
group_key
The `.obs` key with categorical group information.
basis_adata
Another :class:`~anndata.AnnData` with annotation in `.obs`
to compare. If `None`, only the `adata` composition is shown.
basis_value_key
The `.obs` or `.obsm` key for `basis_adata` with the values
to determine the enrichment for. If `None`, `value_key` is used.
basis_group_key
The `.obs` key with categorical group information for
`basis_adata`. If `None`, `group_key` is used.
fillna
If `None`, observations containing NA in the values are filtered.
Else, NA values are replaced with this value.
restrict_groups
A list-like containing the groups within which the enrichment analysis is
to be done. If `None`, all groups are included.
restrict_values
A list-like containing the values within which the enrichment analysis is
to be done. If `None`, all values are included. Works only for categorical
values.
basis_restrict_groups
Like `restrict_groups` but for `basis_adata`.
basis_restrict_values
Like `restrict_values` but for `basis_adata`.
reads
Whether to weight the values by the total count per observation.
colors
The mapping of value names to colors. If `None`, a set of
standard colors is used.
horizontal
Whether to plot the bar plot horizontally.
axsize
Tuple of width and height of a single axis. If `None`, use
automatic values.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting.
legend
Whether to include the legend.
Returns
-------
A :class:`~matplotlib.figure.Figure` if `ax` is `None`, else `None`.
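Examples
--------
A minimal usage sketch with hypothetical `.obs` keys, assuming the
`tc.pl` alias for :mod:`tacco.plots`:

>>> import tacco as tc
>>> # stacked cell type composition per anatomical region
>>> fig = tc.pl.compositions(adata, value_key='cell_type',
...     group_key='region')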
"""
compositions = get_compositions(adata=adata, value_key=value_key, group_key=group_key, fillna=fillna, restrict_groups=restrict_groups, restrict_values=restrict_values, reads=reads)
colors, types = _get_colors(colors, compositions.columns)
if basis_adata is not None:
if basis_value_key is None:
basis_value_key = value_key
if basis_group_key is None:
basis_group_key = group_key
basis_compositions = get_compositions(adata=basis_adata, value_key=basis_value_key, group_key=basis_group_key, fillna=fillna, restrict_groups=basis_restrict_groups, restrict_values=basis_restrict_values, reads=reads)
basis_compositions.index = basis_compositions.index.astype(str) + ' (reference)'
compositions = pd.concat([basis_compositions,compositions])
n_solutions = len(compositions.index)
if ax is not None:
fig = None
else:
if axsize is None:
if horizontal:
axsize=(9.1,0.7*n_solutions)
else:
axsize=(0.7*n_solutions,5)
fig, axs = subplots(axsize=axsize)
ax = axs[0,0]
_composition_bar(compositions, colors, horizontal=horizontal, ax=ax, legend=legend)
return fig
def _prep_contributions(
adata,
value_key,
group_key,
sample_key=None,
basis_adata=None,
basis_value_key=None,
basis_group_key=None,
basis_sample_key=None,
fillna=None,
restrict_groups=None,
restrict_values=None,
basis_restrict_groups=None,
basis_restrict_values=None,
reduction='sum',
normalization=None,
assume_counts=None,
reads=False,
colors=None,
):
"""\
Prepares contribution data.
Parameters
----------
adata
An :class:`~anndata.AnnData` with annotation in `.obs`.
value_key
The `.obs` or `.obsm` key with the values to determine the
enrichment for.
group_key
The `.obs` key with categorical group information.
sample_key
The `.obs` key with categorical sample information. If `None`,
only the aggregated data is plotted. Otherwise the data is aggregated
per sample and total and per-sample values are plotted.
basis_adata
Another :class:`~anndata.AnnData` with annotation in `.obs`
to compare. If `None`, only the `adata` composition is shown.
basis_value_key
The `.obs` or `.obsm` key for `basis_adata` with the values
to determine the enrichment for. If `None`, `value_key` is used.
basis_group_key
The `.obs` key with categorical group information for
`basis_adata`. If `None`, `group_key` is used.
basis_sample_key
The `.obs` key with categorical sample information for
`basis_adata`. If `None`, `sample_key` is used.
fillna
If `None`, observations containing NA in the values are filtered.
Else, NA values are replaced with this value.
restrict_groups
A list-like containing the groups within which the enrichment analysis
is to be done. If `None`, all groups are included.
restrict_values
A list-like containing the values within which the enrichment analysis
is to be done. If `None`, all values are included. Works only for
categorical values.
basis_restrict_groups
Like `restrict_groups` but for `basis_adata`.
basis_restrict_values
Like `restrict_values` but for `basis_adata`.
reduction
The reduction to apply on each (group,sample) subset of the data.
Possible values are:
- 'sum': sum of the values over observations
- 'mean': mean of the values over observations
- 'median': median of the values over observations
- a callable mapping a :class:`~pandas.DataFrame` to its reduced
counterpart
normalization
The normalization to apply on each reduced (group,sample) subset of the
data. Possible values are:
- 'sum': normalize values by their sum (yields fractions)
- 'percent': like 'sum' scaled by 100 (yields percentages)
- 'gmean': normalize values by their geometric mean (yields
contributions which make more sense for enrichments than fractions,
due to zero-sum issue; see :func:`~tacco.tools.enrichments`)
- 'clr': "Center logratio transform"; like 'gmean' with additional log
transform; makes the distribution more normal and better suited for t
tests
- `None`: no normalization
- a value name from `value_key`: all values are normalized to this
contribution
- a callable mapping a :class:`~pandas.DataFrame` to its normalized
counterpart
assume_counts
Only relevant for `normalization=='gmean'` and `normalization=='clr'`;
whether to regularize zeros by adding a pseudo count of 1 or by
replacing them by 1e-3 of the minimum value. If `None`, check whether
the data are consistent with count data and assume counts accordingly,
except if `reads==True`, in which case `assume_counts==True` is used.
reads
Whether to weight the values by the total count per observation.
colors
The mapping of value names to colors. If `None`, a set of
standard colors is used.
Returns
-------
A tuple containing contributions, detailed_contributions, colors, types.
"""
contributions = get_contributions(adata=adata, value_key=value_key, group_key=group_key, sample_key=None, fillna=fillna, restrict_groups=restrict_groups, restrict_values=restrict_values, reads=reads, reduction=reduction, normalization=normalization, assume_counts=assume_counts)
detailed_contributions = None
if sample_key is not None:
detailed_contributions = get_contributions(adata=adata, value_key=value_key, group_key=group_key, sample_key=sample_key, fillna=fillna, restrict_groups=restrict_groups, restrict_values=restrict_values, reads=reads, reduction=reduction, normalization=normalization, assume_counts=assume_counts)
colors, types = _get_colors(colors, contributions.columns)
if basis_adata is not None:
if basis_value_key is None:
basis_value_key = value_key
if basis_group_key is None:
basis_group_key = group_key
if basis_sample_key is None:
basis_sample_key = sample_key
basis_contributions = get_contributions(adata=basis_adata, value_key=basis_value_key, group_key=basis_group_key, sample_key=None, fillna=fillna, restrict_groups=basis_restrict_groups, restrict_values=basis_restrict_values, reads=reads, reduction=reduction, normalization=normalization, assume_counts=assume_counts)
basis_contributions.index = basis_contributions.index.astype(str) + ' (reference)'
contributions = pd.concat([basis_contributions,contributions])
if sample_key is not None:
if basis_sample_key is None:
basis_sample_key = sample_key
detailed_basis_contributions = get_contributions(adata=basis_adata, value_key=basis_value_key, group_key=basis_group_key, sample_key=basis_sample_key, fillna=fillna, restrict_groups=basis_restrict_groups, restrict_values=basis_restrict_values, reads=reads, reduction=reduction, normalization=normalization, assume_counts=assume_counts)
detailed_basis_contributions.index = detailed_basis_contributions.index.set_levels([detailed_basis_contributions.index.levels[0].astype(str) + ' (reference)',detailed_basis_contributions.index.levels[1].astype(str) ])
detailed_basis_contributions.index.names = detailed_contributions.index.names
detailed_contributions = pd.concat([detailed_basis_contributions,detailed_contributions])
return contributions, detailed_contributions, colors, types
def contribution(
adata,
value_key,
group_key,
sample_key=None,
basis_adata=None,
basis_value_key=None,
basis_group_key=None,
basis_sample_key=None,
fillna=None,
restrict_groups=None,
restrict_values=None,
basis_restrict_groups=None,
basis_restrict_values=None,
reduction='sum',
normalization='gmean',
assume_counts=None,
reads=False,
colors=None,
axsize=None,
log=True,
ax=None,
):
"""\
Plot contribution to groups. In contrast to :func:`~tacco.plots.compositions`,
contributions don't have to add up to one.
Parameters
----------
adata
An :class:`~anndata.AnnData` with annotation in `.obs`.
value_key
The `.obs` or `.obsm` key with the values to determine the
enrichment for.
group_key
The `.obs` key with categorical group information.
sample_key
The `.obs` key with categorical sample information. If `None`,
only the aggregated data is plotted. Otherwise the data is aggregated
per sample and total and per-sample values are plotted.
basis_adata
Another :class:`~anndata.AnnData` with annotation in `.obs`
to compare. If `None`, only the `adata` composition is shown.
basis_value_key
The `.obs` or `.obsm` key for `basis_adata` with the values
to determine the enrichment for. If `None`, `value_key` is used.
basis_group_key
The `.obs` key with categorical group information for
`basis_adata`. If `None`, `group_key` is used.
basis_sample_key
The `.obs` key with categorical sample information for
`basis_adata`. If `None`, `sample_key` is used.
fillna
If `None`, observations containing NA in the values are filtered.
Else, NA values are replaced with this value.
restrict_groups
A list-like containing the groups within which the enrichment analysis
is to be done. If `None`, all groups are included.
restrict_values
A list-like containing the values within which the enrichment analysis
is to be done. If `None`, all values are included. Works only for
categorical values.
basis_restrict_groups
Like `restrict_groups` but for `basis_adata`.
basis_restrict_values
Like `restrict_values` but for `basis_adata`.
reduction
The reduction to apply on each (group,sample) subset of the data.
Possible values are:
- 'sum': sum of the values over observations
- 'mean': mean of the values over observations
- 'median': median of the values over observations
- a callable mapping a :class:`~pandas.DataFrame` to its reduced
counterpart
normalization
The normalization to apply on each reduced (group,sample) subset of the
data. Possible values are:
- 'sum': normalize values by their sum (yields fractions)
- 'percent': like 'sum' scaled by 100 (yields percentages)
- 'gmean': normalize values by their geometric mean (yields
contributions which make more sense for enrichments than fractions,
due to zero-sum issue; see :func:`~tacco.tools.enrichments`)
- 'clr': "Center logratio transform"; like 'gmean' with additional log
transform; makes the distribution more normal and better suited for t
tests
- `None`: no normalization
- a value name from `value_key`: all values are normalized to this
contribution
- a callable mapping a :class:`~pandas.DataFrame` to its normalized
counterpart
assume_counts
Only relevant for `normalization=='gmean'` and `normalization=='clr'`;
whether to regularize zeros by adding a pseudo count of 1 or by
replacing them by 1e-3 of the minimum value. If `None`, check whether
the data are consistent with count data and assume counts accordingly,
except if `reads==True`, in which case `assume_counts==True` is used.
reads
Whether to weight the values by the total count per observation.
colors
The mapping of value names to colors. If `None`, a set of
standard colors is used.
axsize
Tuple of width and height of a single axis. If `None`, use
automatic values.
log
Whether to plot on the log scale.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting.
Returns
-------
A :class:`~matplotlib.figure.Figure` if `ax` is `None`, else `None`.
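Examples
--------
A minimal usage sketch with hypothetical keys, assuming the `tc.pl`
alias for :mod:`tacco.plots`:

>>> import tacco as tc
>>> # gmean-normalized cell type contributions per region, with
>>> # per-sample bars drawn on top of the aggregated ones
>>> fig = tc.pl.contribution(adata, value_key='cell_type',
...     group_key='region', sample_key='sample')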
"""
contributions, detailed_contributions, colors, types = _prep_contributions(
adata=adata,
value_key=value_key,
group_key=group_key,
sample_key=sample_key,
basis_adata=basis_adata,
basis_value_key=basis_value_key,
basis_group_key=basis_group_key,
basis_sample_key=basis_sample_key,
fillna=fillna,
restrict_groups=restrict_groups,
restrict_values=restrict_values,
basis_restrict_groups=basis_restrict_groups,
basis_restrict_values=basis_restrict_values,
reduction=reduction,
normalization=normalization,
assume_counts=assume_counts,
reads=reads,
colors=colors,
)
labels = contributions.columns.astype(str)
total_bars_width = 0.8
n_states = len(contributions.index)
bars_separation = total_bars_width / (n_states * 3)
bar_separation = total_bars_width / (n_states * 5)
if ax is not None:
fig = None
else:
if axsize is None:
axsize = (len(labels) * (0.3 * n_states + .2), 4)
fig, ax = subplots(axsize=axsize)
ax = ax[0,0]
ax.set_axisbelow(True)
ax.yaxis.grid()
alpha = 1 if detailed_contributions is None else 0.5
minor_x_labeling = {'position':[],'label':[]}
for i_state, state in enumerate(contributions.index):
for i_column, column in enumerate(contributions.columns):
bar_width = (total_bars_width-n_states*bar_separation)/n_states
bar_start = i_column - 0.5*total_bars_width + i_state * total_bars_width/n_states + 0.5*bar_separation
ax.bar(bar_start, contributions.loc[state,column], bar_width, color=colors[column], align='edge', alpha=alpha)
minor_x_labeling['position'].append(i_column - 0.5*total_bars_width + (i_state+0.5)/(n_states) * total_bars_width)
minor_x_labeling['label'].append(state)
if detailed_contributions is not None:
for i_state, state in enumerate(contributions.index):
df = detailed_contributions[detailed_contributions.index.get_level_values(0) == state]
n_samples = df.shape[0]
for i_column, column in enumerate(contributions.columns):
heights = df[column].sort_values(ascending=False)
bar_start = i_column - 0.5*total_bars_width + 0.5*bars_separation + i_state/n_states * total_bars_width
bar_width = (total_bars_width - n_states * bars_separation) / (n_states * n_samples)
ax.bar(bar_start + np.arange(0,(n_samples-0.5)*bar_width,bar_width), heights, bar_width, color=colors[column], align='edge')
# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('contribution')
if log:
ax.set_yscale('log')
x = np.arange(len(labels)) # the label locations
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_xlim(x[0]-0.5,x[-1]+0.5)
minor_x_labeling = pd.DataFrame(minor_x_labeling).sort_values('position')
minor_x_labeling['position'] += 1e-2 # tiny shift to make the central minor annotation not be superseded by the major annotation...
ax.set_xticks(minor_x_labeling['position'],minor=True)
ax.set_xticklabels(minor_x_labeling['label'], rotation=50, ha='right',minor=True)
for t in ax.get_xticklabels(minor=True):
t.set_y(-0.05)
ax.tick_params( axis='x', which='major', bottom=False, top=False )
ax.tick_params( axis='x', which='minor', bottom=False, top=False )
return fig
def heatmap(
adata,
value_key,
group_key,
basis_adata=None,
basis_value_key=None,
basis_group_key=None,
fillna=None,
restrict_groups=None,
restrict_values=None,
basis_restrict_groups=None,
basis_restrict_values=None,
reduction='sum',
normalization=None,
assume_counts=None,
reads=False,
colors=None,
alpha=None,
axsize=None,
axes_labels=None,
annotation=None,
value_cluster=False,
group_cluster=False,
value_dendrogram=False,
group_dendrogram=False,
value_order=None,
group_order=None,
group_labels_rotation=None,
ax=None,
cmap=None,
cmap_center=0,
cmap_vmin_vmax=None,
complement_colors=True,
colorbar=True,
colorbar_label=None,
):
"""\
Plot heatmap of contribution to groups.
Parameters
----------
adata
An :class:`~anndata.AnnData` with annotation in `.obs`. Can also be a
:class:`~pandas.DataFrame` with the data to show in the heatmap. In the
latter case all adata processing arguments are ignored.
value_key
The `.obs` or `.obsm` key with the values to determine the
enrichment for.
group_key
The `.obs` key with categorical group information.
basis_adata
Another :class:`~anndata.AnnData` with annotation in `.obs`
to compare. If `None`, only the `adata` composition is shown.
basis_value_key
The `.obs` or `.obsm` key for `basis_adata` with the values
to determine the enrichment for. If `None`, `value_key` is used.
basis_group_key
The `.obs` key with categorical group information for
`basis_adata`. If `None`, `group_key` is used.
fillna
If `None`, observations containing NA in the values are filtered.
Else, NA values are replaced with this value.
restrict_groups
A list-like containing the groups within which the enrichment analysis
is to be done. If `None`, all groups are included.
restrict_values
A list-like containing the values within which the enrichment analysis
is to be done. If `None`, all values are included. Works only for
categorical values.
basis_restrict_groups
Like `restrict_groups` but for `basis_adata`.
basis_restrict_values
Like `restrict_values` but for `basis_adata`.
reduction
The reduction to apply on each (group,sample) subset of the data.
Possible values are:
- 'sum': sum of the values over observations
- 'mean': mean of the values over observations
- 'median': median of the values over observations
- a callable mapping a :class:`~pandas.DataFrame` to its reduced
counterpart
normalization
The normalization to apply on each reduced (group,sample) subset of the
data. Possible values are:
- 'sum': normalize values by their sum (yields fractions)
- 'percent': like 'sum' scaled by 100 (yields percentages)
- 'gmean': normalize values by their geometric mean (yields
contributions which make more sense for enrichments than fractions,
due to zero-sum issue; see :func:`~tacco.tools.enrichments`)
- 'clr': "Center logratio transform"; like 'gmean' with additional log
transform; makes the distribution more normal and better suited for t
tests
- `None`: no normalization
- a value name from `value_key`: all values are normalized to this
contribution
- a callable mapping a :class:`~pandas.DataFrame` to its normalized
counterpart
assume_counts
Only relevant for `normalization=='gmean'` and `normalization=='clr'`;
whether to regularize zeros by adding a pseudo count of 1 or by
replacing them by 1e-3 of the minimum value. If `None`, check whether
the data are consistent with count data and assume counts accordingly,
except if `reads==True`, in which case `assume_counts==True` is used.
reads
Whether to weight the values by the total count per observation.
colors
The mapping of value names to colors. If `None`, a set of
standard colors is used.
alpha
A value-group-dataframe specifying a separate alpha value for all cells
in the heatmap. If `None`, no transparency is used.
axsize
Tuple of width and height of a single axis. If `None`, use
automatic values.
axes_labels
Labels to write on the axes as an list-like of the two labels.
annotation
A :class:`~pandas.DataFrame` containing annotation for each heatmap
cell. If "value", annotate by the values. If `None`, don't annotate.
If a tuple of "value" and a :class:`~pandas.DataFrame`, append the
annotation from the dataframe to the values.
value_cluster
Whether to cluster and reorder the values.
group_cluster
Whether to cluster and reorder the groups.
value_dendrogram
Whether to draw a dendrogram for the values. If `True`, this implies
`value_cluster=True`.
group_dendrogram
Whether to draw a dendrogram for the groups. If `True`, this implies
`group_cluster=True`.
value_order
Set the order of the values explicitly with a list or to be close to
diagonal by specifying "diag"; this option is incompatible with
`value_cluster` and `value_dendrogram`.
group_order
Set the order of the groups explicitly with a list or to be close to
diagonal by specifying "diag"; this option is incompatible with
`group_cluster` and `group_dendrogram`.
group_labels_rotation
Adjusts the rotation of the group labels in degree.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting. Incompatible with dendrogram plotting.
cmap
A string/colormap to override the `colors` with.
cmap_center
A value to use as the center of the colormap. E.g. choosing `0` assigns
the central color to the value `0` for every colormap in the plot (i.e.
`0` will get white, and positive and negative values get the color and
the complement color given by `colors` if `cmap` is `None`, or the
corresponding colors of the supplied colormap if `cmap` is not `None`).
If `None`, the colormap spans the entire value range.
cmap_vmin_vmax
A tuple giving the range of values for the colormap. This can be
modified by `cmap_center`.
complement_colors
Whether to use complement colors for values below `cmap_center` if
`cmap==None`.
colorbar
Whether to draw a colorbar; only available if `cmap` is not `None`.
colorbar_label
The label to use for the colorbar; only available if `colorbar` is
`True`.
Returns
-------
A :class:`~matplotlib.figure.Figure` if `ax` is `None`, else `None`.
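Examples
--------
A minimal usage sketch with hypothetical keys, assuming the `tc.pl`
alias for :mod:`tacco.plots`:

>>> import tacco as tc
>>> # heatmap of cell type fractions per region, annotated with values
>>> fig = tc.pl.heatmap(adata, value_key='cell_type', group_key='region',
...     normalization='sum', annotation='value')
>>> # a plain pandas DataFrame can be plotted directly
>>> fig = tc.pl.heatmap(my_dataframe, None, None, cmap='viridis')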
"""
if isinstance(adata, pd.DataFrame):
contributions = adata.fillna(0)
types = contributions.index
else:
contributions, detailed_contributions, colors, types = _prep_contributions(
adata=adata,
value_key=value_key,
group_key=group_key,
sample_key=None,
basis_adata=basis_adata,
basis_value_key=basis_value_key,
basis_group_key=basis_group_key,
basis_sample_key=None,
fillna=fillna,
restrict_groups=restrict_groups,
restrict_values=restrict_values,
basis_restrict_groups=basis_restrict_groups,
basis_restrict_values=basis_restrict_values,
reduction=reduction,
normalization=normalization,
assume_counts=assume_counts,
reads=reads,
colors=colors,
)
value_cluster = value_cluster or value_dendrogram
group_cluster = group_cluster or group_dendrogram
if value_order is not None and value_cluster:
raise ValueError('The options `value_cluster` and `value_dendrogram` are incompatible with the option `value_order`.')
if group_order is not None and group_cluster:
raise ValueError('The options `group_cluster` and `group_dendrogram` are incompatible with the option `group_order`.')
group_order_diag = isinstance(group_order, str) and group_order == 'diag'
value_order_diag = isinstance(value_order, str) and value_order == 'diag'
if group_order_diag and value_order is not None and not value_order_diag:
raise ValueError('The option `group_order=="diag"` is incompatible with values for `value_order` other than `None` or "diag".')
if value_order_diag and group_order is not None and not group_order_diag:
raise ValueError('The option `value_order=="diag"` is incompatible with values for `group_order` other than `None` or "diag".')
if group_order_diag and group_cluster:
raise ValueError('The option `group_order=="diag"` is incompatible with the options `group_cluster` and `group_dendrogram`.')
if value_order_diag and value_cluster:
raise ValueError('The option `value_order=="diag"` is incompatible with the options `value_cluster` and `value_dendrogram`.')
if ax is not None:
if value_dendrogram or group_dendrogram:
raise ValueError('The options `value_dendrogram` and `group_dendrogram` are incompatible with the option `ax`.')
fig = ax.get_figure()
else:
n_ax_x, n_ax_y = 1+value_dendrogram, 1+group_dendrogram
if axsize is None:
y_per_value = 0.25
x_per_group = y_per_value if annotation is None else 0.7
plot_size_y = contributions.shape[1] * y_per_value
plot_size_x = contributions.shape[0] * x_per_group
axsize = (plot_size_x,plot_size_y)
else:
plot_size_x,plot_size_y = axsize
width_ratios = [plot_size_x,0.5*np.log(contributions.shape[1])] if value_dendrogram else None
height_ratios = [0.5*np.log(contributions.shape[0]),plot_size_y] if group_dendrogram else None
fig,axs = subplots(n_ax_x, n_ax_y, axsize=axsize, wspace=0, hspace=0, width_ratios=width_ratios, height_ratios=height_ratios)
ax = axs[-1,0]
if group_labels_rotation is None:
if annotation is None:
group_labels_rotation = 45
else:
group_labels_rotation = 30
if value_cluster:
Z = scipy.cluster.hierarchy.linkage(contributions.T, method='average', metric='cosine')
dn = scipy.cluster.hierarchy.dendrogram(Z, ax=(axs[-1,-1] if value_dendrogram else None), orientation='right', color_threshold=0, above_threshold_color='tab:gray', no_plot=(not value_dendrogram))
if value_dendrogram:
axs[-1,-1].set_axis_off()
reordering = pd.Series(dn['ivl']).astype(int).to_numpy()
contributions = contributions.iloc[:,reordering]
if group_cluster:
Z = scipy.cluster.hierarchy.linkage(contributions, method='average', metric='cosine')
dn = scipy.cluster.hierarchy.dendrogram(Z, ax=(axs[0,0] if group_dendrogram else None), orientation='top', color_threshold=0, above_threshold_color='tab:gray', no_plot=(not group_dendrogram))
if group_dendrogram:
axs[0,0].set_axis_off()
reordering = pd.Series(dn['ivl']).astype(int).to_numpy()
contributions = contributions.iloc[reordering]
if value_dendrogram and group_dendrogram:
axs[0,-1].set_axis_off()
if isinstance(value_order,str) and value_order == 'diag':
# permute towards diagonal
for i in range(10):
contributions = contributions.iloc[:,np.argsort(np.argmax(contributions.to_numpy(),axis=0))]
contributions = contributions.iloc[np.argsort(np.argmax(contributions.to_numpy().T,axis=0))]
else:
if value_order is not None:
contributions = contributions.loc[:,value_order]
if group_order is not None:
contributions = contributions.loc[group_order]
_x, _y = [np.arange(-0.5,s+0.5,1) for s in contributions.shape]
x, y = (_x[:-1] + _x[1:]) / 2, (_y[:-1] + _y[1:]) / 2
def _vmin_vmax(vmin, vmax):
if cmap_vmin_vmax is not None:
vmin, vmax = cmap_vmin_vmax
if cmap_center is not None:
delta_max = vmax - cmap_center
delta_min = -(vmin - cmap_center)
delta = max(delta_max, delta_min)
vmax = cmap_center + delta
vmin = cmap_center - delta
return vmin, vmax
if cmap is None:
rgba = np.zeros((*contributions.shape,4))
for j in range(contributions.shape[1]):
r, g, b = to_rgb(colors[contributions.columns[j]])
_r, _g, _b = _complement_color(r,g,b) # complement colors for negative values
vmin, vmax = _vmin_vmax(contributions.iloc[:,j].min(), contributions.iloc[:,j].max())
norm = Normalize(vmin=vmin, vmax=vmax)
if complement_colors:
cmap_j=LinearSegmentedColormap.from_list(contributions.columns[j], [(0,(_r, _g, _b)),(0.5,(1, 1, 1)),(1,(r, g, b))])
else:
cmap_j=LinearSegmentedColormap.from_list(contributions.columns[j], [(0,(1, 1, 1)),(1,(r, g, b))])
mapper = ScalarMappable(norm=norm, cmap=cmap_j)
rgba[:,j,:] = mapper.to_rgba(contributions.iloc[:,j].to_numpy())
else:
vmin, vmax = _vmin_vmax(contributions.to_numpy().min(), contributions.to_numpy().max())
norm = Normalize(vmin=vmin, vmax=vmax)
mapper = ScalarMappable(norm=norm, cmap=cmap)
rgba = mapper.to_rgba(contributions.to_numpy())
if colorbar and cmap is not None: # a global colorbar is only well-defined if a single colormap is used
height_pxl = 200
width_pxl = 15
offset_top_pxl = 0
offset_left_pxl = 20
left,bottom = fig.transFigure.inverted().transform(ax.transAxes.transform((1,1))+np.array([offset_left_pxl,-offset_top_pxl-height_pxl]))
width,height = fig.transFigure.inverted().transform(fig.transFigure.transform((0,0))+np.array([width_pxl,height_pxl]))
cax = fig.add_axes((left, bottom, width, height))
cb = fig.colorbar(mapper, cax=cax, label=colorbar_label)
if alpha is not None:
alpha = alpha.reindex(index=contributions.index, columns=contributions.columns)
rgba[...,-1] = alpha
ax.imshow(np.swapaxes(rgba, 0, 1), origin='lower', aspect='auto')
if annotation is not None:
for i,ind in enumerate(contributions.index):
for j,col in enumerate(contributions.columns):
if isinstance(annotation, str):
if annotation == 'value':
ann = f'{contributions.loc[ind,col]:.2}'
else:
raise ValueError(f'`annotation` got unknown string argument "{annotation}"')
elif (hasattr(annotation, 'shape') and annotation.shape == (2,)) or (not hasattr(annotation, 'shape') and len(annotation) == 2):
if isinstance(annotation[0], str) and annotation[0] == 'value':
if ind in annotation[1].index and col in annotation[1].columns:
val = annotation[1].loc[ind,col]
else:
val = '' # nothing to append if the annotation dataframe does not cover this cell
ann = f'{contributions.loc[ind,col]:.2}{val}'
else:
raise ValueError('`annotation` got a tuple argument where the first entry is not "value"')
else:
ann = annotation.loc[ind,col]
ax.annotate(ann, xy=(x[i],y[j]), ha='center', va='center')
ax.set_xticks(x)
ax.set_xticklabels(contributions.index, rotation=group_labels_rotation, ha=('right' if group_labels_rotation not in [0,90] else 'center'))
ax.set_yticks(y)
ax.set_yticklabels(contributions.columns)
_set_axes_labels(ax, axes_labels)
return fig
def _asterisks_from_pvals(pvals):
pvals = pvals.astype(object)
anstr = pvals.astype(str)
anstr.loc[:,:] = ''
anstr[(pvals > 0) & (np.abs(pvals) <= 0.05)] = '$^{\\ast}$'
anstr[(pvals > 0) & (np.abs(pvals) <= 0.01)] = '${^{\\ast}}{^{\\ast}}$'
anstr[(pvals > 0) & (np.abs(pvals) <= 0.001)] = '${^{\\ast}}{^{\\ast}}{^{\\ast}}$'
anstr[(pvals < 0) & (np.abs(pvals) <= 0.05)] = '$_{\\ast}$'
anstr[(pvals < 0) & (np.abs(pvals) <= 0.01)] = '${_{\\ast}}{_{\\ast}}$'
anstr[(pvals < 0) & (np.abs(pvals) <= 0.001)] = '${_{\\ast}}{_{\\ast}}{_{\\ast}}$'
return anstr
def sigmap(
adata,
value_key,
group_key,
sample_key=None,
position_key=None,
position_split=2,
min_obs=0,
basis_adata=None,
basis_value_key=None,
basis_group_key=None,
basis_sample_key=None,
basis_position_key=None,
basis_position_split=None,
basis_min_obs=None,
fillna=None,
restrict_groups=None,
restrict_values=None,
basis_restrict_groups=None,
basis_restrict_values=None,
p_corr='fdr_bh',
method='mwu',
reduction=None,
normalization=None,
assume_counts=None,
reads=False,
colors=None,
axsize=None,
value_dendrogram=False,
group_dendrogram=False,
value_order=None,
group_order=None,
ax=None,
):
"""\
Plot heatmap of contribution to groups and mark significant differences
with asterisks.
Parameters
----------
adata
An :class:`~anndata.AnnData` with annotation in `.obs`.
value_key
The `.obs` or `.obsm` key with the values to determine the
enrichment for.
group_key
The `.obs` key with categorical group information.
sample_key
The `.obs` key with categorical sample information for p-value
determination. If `None`, only the aggregated data is used.
position_key
The `.obsm` key or array-like of `.obs` keys with the position space
coordinates. If `None`, no position splits are performed.
position_split
The number of splits per spatial dimension before enrichment. Can be a
tuple with the spatial dimension as length to assign a different split
per dimension. If `None`, no position splits are performed. See also
`min_obs`.
min_obs
The minimum number of observations per sample: if fewer observations are
available, the sample is not used. This also limits the number of
`position_split` to stop splitting if the split would decrease the
number of observations below this threshold.
basis_adata
Another :class:`~anndata.AnnData` with annotation in `.obs`
to compare. If `None`, only the `adata` composition is shown.
basis_value_key
The `.obs` or `.obsm` key for `basis_adata` with the values
to determine the enrichment for. If `None`, `value_key` is used.
basis_group_key
The `.obs` key with categorical group information for
`basis_adata`. If `None`, `group_key` is used.
basis_sample_key
The `.obs` key with categorical sample information for
`basis_adata`. If `None`, `sample_key` is used.
basis_position_key
Like `position_key` but for `basis_adata`. If `None`, no position
splits are performed.
basis_position_split
Like `position_split` but for `basis_adata`. If `None`,
`position_split` is used.
basis_min_obs
Like `min_obs` but for `basis_adata`. If `None`, `min_obs` is used.
fillna
If `None`, observations containing NA in the values are filtered.
Else, NA values are replaced with this value.
restrict_groups
A list-like containing the groups within which the enrichment analysis
is to be done. If `None`, all groups are included.
restrict_values
A list-like containing the values within which the enrichment analysis
is to be done. If `None`, all values are included. Works only for
categorical values.
basis_restrict_groups
Like `restrict_groups` but for `basis_adata`.
basis_restrict_values
Like `restrict_values` but for `basis_adata`.
p_corr
The name of the p-value correction method to use. Possible values are
the ones available in
:func:`~statsmodels.stats.multitest.multipletests`. If `None`, no
p-value correction is performed.
method
Specification of methods to use for enrichment. Available are:
- 'fisher': Fisher's exact test; only for categorical values. Ignores
the `reduction` and `normalization` arguments.
- 'mwu': Mann-Whitney U test
reduction
The reduction to apply on each (group,sample) subset of the data.
Possible values are:
- 'sum': sum of the values over observations
- 'mean': mean of the values over observations
- 'median': median of the values over observations
- `None`: use observations directly
- a callable mapping a :class:`~pandas.DataFrame` to its reduced
counterpart
normalization
The normalization to apply on each reduced (group,sample) subset of the
data. Possible values are:
- 'sum': normalize values by their sum (yields fractions)
- 'percent': like 'sum' scaled by 100 (yields percentages)
- 'gmean': normalize values by their geometric mean (yields
contributions which make more sense for enrichments than fractions,
due to zero-sum issue; see :func:`~tacco.tools.enrichments`)
- 'clr': "Center logratio transform"; like 'gmean' with additional log
transform; makes the distribution more normal and better suited for t
tests
- `None`: no normalization
- a value name from `value_key`: all values are normalized to this
contribution
- a callable mapping a :class:`~pandas.DataFrame` to its normalized
counterpart
assume_counts
Only relevant for `normalization=='gmean'` and `normalization=='clr'`;
whether to regularize zeros by adding a pseudo count of 1 or by
replacing them by 1e-3 of the minimum value. If `None`, check whether
the data are consistent with count data and assume counts accordingly,
except if `reads==True`, in which case `assume_counts==True` is used.
reads
Whether to weight the values by the total count per observation.
colors
The mapping of value names to colors. If `None`, a set of standard
colors is used.
axsize
Tuple of width and height of a single axis. If `None`, use automatic
values.
value_dendrogram
Whether to draw a dendrogram for the values.
group_dendrogram
Whether to draw a dendrogram for the groups.
value_order
Set the order of the values explicitly; this option is incompatible
with `value_dendrogram`.
group_order
Set the order of the groups explicitly; this option is incompatible
with `group_dendrogram`.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting. Incompatible with dendrogram plotting.
Returns
-------
A :class:`~matplotlib.figure.Figure` if `ax` is `None`, else `None`.
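Examples
--------
A minimal usage sketch with hypothetical keys, assuming the `tc.pl`
alias for :mod:`tacco.plots`:

>>> import tacco as tc
>>> # heatmap of per-region cell type contributions, marked with
>>> # significance asterisks from Mann-Whitney U tests across samples
>>> fig = tc.pl.sigmap(adata, value_key='cell_type', group_key='region',
...     sample_key='sample', method='mwu')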
"""
pvals = enrichments(
adata=adata,value_key=value_key,group_key=group_key,sample_key=sample_key, method=method,fillna=fillna,
position_key=position_key, position_split=position_split,min_obs=min_obs,p_corr=p_corr,
restrict_groups=restrict_groups,restrict_values=restrict_values, reads=reads,
reduction=reduction,normalization=normalization,assume_counts=assume_counts,
)
if basis_adata is not None:
if basis_position_split is None:
basis_position_split = position_split
if basis_min_obs is None:
basis_min_obs = min_obs
if basis_sample_key is None:
basis_sample_key = sample_key
if basis_group_key is None:
basis_group_key = group_key
if basis_value_key is None:
basis_value_key = value_key
basis_pvals = enrichments(
adata=basis_adata,value_key=basis_value_key,group_key=basis_group_key,sample_key=basis_sample_key, method=method,fillna=fillna,
position_key=basis_position_key, position_split=basis_position_split,min_obs=basis_min_obs,p_corr=p_corr,
restrict_groups=basis_restrict_groups,restrict_values=basis_restrict_values, reads=reads,
reduction=reduction,normalization=normalization,assume_counts=assume_counts,
)
basis_pvals.rename(columns={basis_value_key:value_key}, inplace=True)
basis_pvals[basis_group_key] = basis_pvals[basis_group_key].cat.rename_categories(lambda c: c + ' (reference)')
basis_pvals.rename(columns={basis_group_key:group_key}, inplace=True)
joint_categories = [*list(basis_pvals[group_key].cat.categories), *list(pvals[group_key].cat.categories)]
pvals = pd.concat([basis_pvals,pvals])
pvals[group_key] = pvals[group_key].astype(pd.CategoricalDtype(joint_categories,ordered=True))
ann = pd.pivot_table(pvals[pvals['enrichment']=='enriched'],values=f'p_{method}_{p_corr}',index=group_key,columns=value_key)
annp = pd.pivot_table(pvals[pvals['enrichment']!='enriched'],values=f'p_{method}_{p_corr}',index=group_key,columns=value_key)
ann[annp<ann] = -annp
anstr = _asterisks_from_pvals(ann)
fig = heatmap(
adata=adata,value_key=value_key,group_key=group_key,
basis_adata=basis_adata,basis_value_key=basis_value_key,basis_group_key=basis_group_key,
fillna=fillna,restrict_groups=restrict_groups,restrict_values=restrict_values,
basis_restrict_groups=basis_restrict_groups,basis_restrict_values=basis_restrict_values,
reduction=reduction,normalization=normalization,reads=reads,colors=colors,axsize=axsize,annotation=('value',anstr),
value_dendrogram=value_dendrogram, group_dendrogram=group_dendrogram,assume_counts=assume_counts,
ax=ax, colorbar=False, complement_colors=False, cmap_center=None,
value_order=value_order,group_order=group_order,
);
return fig
def significances(
significances,
p_key,
value_key,
group_key,
enrichment_key='enrichment',
enriched_label='enriched',
pmax=0.05,
pmin=1e-5,
annotate_pvalues=True,
value_cluster=False,
group_cluster=False,
value_order=None,
group_order=None,
axsize=None,
ax = None,
scale_legend=1.0
):
"""\
Plot enrichment significances.
Parameters
----------
significances
A :class:`~pandas.DataFrame` with p-values and their annotation. If it
contains significances for enrichment and depletion, this direction has
to be specified with values "enriched" and something else (e.g.
"depleted" or "purified") in a column "enrichment" of the DataFrame.
See also the parameters `enrichment_key` and `enriched_label`.
p_key
The key with the p-values.
value_key
The key with the values for which the enrichment was determined.
group_key
The key with the groups in which the enrichment was determined.
enrichment_key
The key with the direction of enrichment, i.e. something like "enriched"
and "purified". See also parameter `enriched_label`. Default:
"enrichment".
enriched_label
The value under the key `enrichment_key` which indicates something like
enrichment. Default: "enriched".
pmax
The maximum p-value to show.
pmin
The minimum p-value on the color scale.
annotate_pvalues
Whether to annotate p-values.
value_cluster
Whether to cluster and reorder the values.
group_cluster
Whether to cluster and reorder the groups.
value_order
Set the order of the values explicitly with a list or to be close to
diagonal by specifying "diag"; this option is incompatible with
`value_cluster`.
group_order
Set the order of the groups explicitly with a list or to be close to
diagonal by specifying "diag"; this option is incompatible with
`group_cluster`.
axsize
Tuple of width and height of a single axis. If `None`, use
automatic values.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting. Incompatible with dendrogram plotting.
scale_legend
Set to scale height and width of legend.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
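Examples
--------
A minimal usage sketch; the enrichment table comes e.g. from
:func:`~tacco.tools.enrichments`, whose p-value column is named after
its `method` and `p_corr` settings (here 'mwu' with 'fdr_bh'); keys and
the `tc.pl`/`tc.tl` aliases are assumptions:

>>> import tacco as tc
>>> enr = tc.tl.enrichments(adata, value_key='cell_type',
...     group_key='region', method='mwu')
>>> fig = tc.pl.significances(enr, p_key='p_mwu_fdr_bh',
...     value_key='cell_type', group_key='region')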
"""
small_value = 1e-300
max_log = -np.log(pmin)
min_log = -np.log(pmax)
depleted_label = None
if enrichment_key is not None:
if enrichment_key not in significances:
raise ValueError(f'The column "{enrichment_key}" does not exist in the supplied dataframe! If this is intentional and you want to suppress this error, supply "enrichment_key=None" as argument.')
unique_significance_labels = significances[enrichment_key].unique()
if len(unique_significance_labels) == 1:
enriched_label = unique_significance_labels[0]
else:
if len(unique_significance_labels) > 2 or enriched_label not in unique_significance_labels:
raise ValueError(f'The column "{enrichment_key}" is expected to have exactly 2 different values: "{enriched_label}" and something else, e.g. "depleted" or "purified"! The supplied column has the values {significances[enrichment_key].unique()}!')
depleted_label = unique_significance_labels[unique_significance_labels!=enriched_label][0]
if depleted_label is not None:
enr_e = pd.pivot(significances[significances[enrichment_key]==enriched_label], index=value_key, columns=group_key, values=p_key)
enr_p = pd.pivot(significances[significances[enrichment_key]==depleted_label], index=value_key, columns=group_key, values=p_key)
enr_e = np.maximum(enr_e,small_value)
enr_p = np.maximum(enr_p,small_value)
enr_p = enr_p.reindex_like(enr_e)
enr = pd.DataFrame(np.where(enr_e < enr_p, -np.log(enr_e), np.log(enr_p)),index=enr_e.index,columns=enr_e.columns)
ann = pd.DataFrame(np.where(enr_e < enr_p, enr_e, enr_p),index=enr_e.index,columns=enr_e.columns)
else:
ann = pd.pivot(significances, index=value_key, columns=group_key, values=p_key)
enr = -np.log(ann)
# avoid discrepancies between color and annotation by basing both color and annotation on cuts on the same values
enr = pd.DataFrame(np.where(ann > pmax, 0, enr),index=enr.index,columns=enr.columns)
ann = pd.DataFrame(np.where(ann > pmax, '', ann.applymap(lambda x: f'{x:.2}')),index=enr.index,columns=enr.columns)
enr = enr.T
ann = ann.T
if not annotate_pvalues:
ann = None
# setup the plotting
enriched_color = (1.0, 0.07058823529411765, 0.09019607843137255)
depleted_color = (0.30196078431372547, 0.5215686274509804, 0.7098039215686275)
null_color = (0.9,0.9,0.9)
slightly_weight = 0.2
slightly_enriched_color, slightly_depleted_color = mix_base_colors(
np.array([[slightly_weight,1-slightly_weight,0.0],[0.0,1-slightly_weight,slightly_weight],]),
np.array([list(enriched_color),list(null_color),list(depleted_color)])
)
if depleted_label is None:
ct2 = min_log/max_log
cdict = {'red': [[0.0, null_color[0], null_color[0]],
[ct2, null_color[0], slightly_enriched_color[0]],
[1.0, enriched_color[0], enriched_color[0]]],
'green': [[0.0, null_color[1], null_color[1]],
[ct2, null_color[1], slightly_enriched_color[1]],
[1.0, enriched_color[1], enriched_color[1]]],
'blue': [[0.0, null_color[2], null_color[2]],
[ct2, null_color[2], slightly_enriched_color[2]],
[1.0, enriched_color[2], enriched_color[2]]]}
cmap = LinearSegmentedColormap('sigmap', segmentdata=cdict, N=256)
fig = heatmap(enr, None, None, cmap=cmap, cmap_vmin_vmax=(0.0,max_log), cmap_center=max_log/2, annotation=ann, colorbar=False, value_cluster=value_cluster, group_cluster=group_cluster, value_order=value_order, group_order=group_order, axsize=axsize, ax=ax);
else:
ct1 = 0.5 * (1 - min_log/max_log)
ct2 = 0.5 * (1 + min_log/max_log)
cdict = {'red': [[0.0, depleted_color[0], depleted_color[0]],
[ct1, slightly_depleted_color[0], null_color[0]],
[ct2, null_color[0], slightly_enriched_color[0]],
[1.0, enriched_color[0], enriched_color[0]]],
'green': [[0.0, depleted_color[1], depleted_color[1]],
[ct1, slightly_depleted_color[1], null_color[1]],
[ct2, null_color[1], slightly_enriched_color[1]],
[1.0, enriched_color[1], enriched_color[1]]],
'blue': [[0.0, depleted_color[2], depleted_color[2]],
[ct1, slightly_depleted_color[2], null_color[2]],
[ct2, null_color[2], slightly_enriched_color[2]],
[1.0, enriched_color[2], enriched_color[2]]]}
cmap = LinearSegmentedColormap('sigmap', segmentdata=cdict, N=256)
fig = heatmap(enr, None, None, cmap=cmap, cmap_vmin_vmax=(-max_log,max_log), cmap_center=0.0, annotation=ann, colorbar=False, value_cluster=value_cluster, group_cluster=group_cluster, value_order=value_order, group_order=group_order, axsize=axsize, ax=ax);
rel_dpi_factor = fig.get_dpi() / 72
height_pxl = 200 * rel_dpi_factor * scale_legend
width_pxl = 15 * rel_dpi_factor * scale_legend
offset_top_pxl = 0 * rel_dpi_factor * scale_legend
offset_left_pxl = 30 * rel_dpi_factor * scale_legend
if ax is None:
ax = fig.axes[0]
left,bottom = fig.transFigure.inverted().transform(ax.transAxes.transform((1,1))+np.array([offset_left_pxl,-offset_top_pxl-height_pxl]))
width,height = fig.transFigure.inverted().transform(fig.transFigure.transform((0,0))+np.array([width_pxl,height_pxl]))
cax = fig.add_axes((left, bottom, width, height))
norm = Normalize(vmin=(0.0 if depleted_label is None else -max_log), vmax=max_log)
cb = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=cax)
if depleted_label is None:
cb.set_ticks([min_log,max_log])
cb.set_ticklabels([pmax,pmin])
else:
cb.set_ticks([-max_log,-min_log,min_log,max_log])
cb.set_ticklabels([pmin,pmax,pmax,pmin])
cb.ax.annotate('enriched', xy=(0, 1), xycoords='axes fraction', xytext=(-3, -5), textcoords='offset pixels', horizontalalignment='right', verticalalignment='top', rotation=90, fontsize=10*scale_legend)
cb.ax.annotate('insignificant', xy=(0, (0.0 if depleted_label is None else 0.5)), xycoords='axes fraction', xytext=(-3, 5), textcoords='offset pixels', horizontalalignment='right', verticalalignment=('bottom' if depleted_label is None else 'center'), rotation=90, fontsize=10*scale_legend)
if depleted_label is not None:
cb.ax.annotate('depleted', xy=(0, 0), xycoords='axes fraction', xytext=(-3, 5), textcoords='offset pixels', horizontalalignment='right', verticalalignment='bottom', rotation=90, fontsize=10*scale_legend)
return fig
def _escape_math_special_characters(string):
string = string.replace('_', '\\_')
return string
def _get_co_occurrence(adata, analysis_key, show_only, show_only_center, colors, score_key, log_base):
if analysis_key not in adata.uns:
raise ValueError(f'`analysis_key` "{analysis_key}" is not found in `adata.uns`! Make sure to run tc.tl.co_occurrence first.')
if score_key in adata.uns[analysis_key]:
mean_scores = adata.uns[analysis_key][score_key]
if score_key.startswith('log'):
if log_base is not None:
mean_scores = mean_scores / np.log(log_base)
else:
raise ValueError(f'The score_key {score_key!r} was not found!')
if mean_scores is None:
raise ValueError(f'The score {score_key!r} in `adata.uns[{analysis_key}]` is None!')
intervals = adata.uns[analysis_key]['interval']
annotation = adata.uns[analysis_key]['annotation']
center = adata.uns[analysis_key]['center']
if show_only is not None:
if isinstance(show_only, str):
show_only = [show_only]
select = annotation.isin(show_only)
if select.sum() < len(show_only):
raise ValueError(f'The `show_only` categories {[s for s in show_only if s not in annotation]!r} are not available in the data!')
annotation = annotation[select]
mean_scores = mean_scores[select,:,:]
# check if the order is the same as in the show_only selection
permutation = np.argsort(annotation)[np.argsort(np.argsort(show_only))]
if not np.all(permutation == np.arange(len(permutation))):
annotation = annotation[permutation]
mean_scores = mean_scores[permutation,:,:]
if show_only_center is not None:
if isinstance(show_only_center, str):
show_only_center = [show_only_center]
select = center.isin(show_only_center)
if select.sum() < len(show_only_center):
raise ValueError(f'The `show_only_center` categories {[s for s in show_only_center if s not in center]!r} are not available in the data!')
center = center[select]
mean_scores = mean_scores[:,select,:]
# check if the order is the same as in the show_only_center selection
permutation = np.argsort(center)[np.argsort(np.argsort(show_only_center))]
if not np.all(permutation == np.arange(len(permutation))):
center = center[permutation]
mean_scores = mean_scores[:,permutation,:]
colors, types = _get_colors(colors, pd.Series(annotation))
return mean_scores, intervals, annotation, center, colors, types
def _get_cooc_expression_label(score_key,log_base):
if score_key == 'occ':
expression = '$\\frac{p(anno|center)}{p(anno)}$'
elif score_key == 'log_occ':
base_str = '' if log_base is None else f'_{log_base}'
expression = '$log' + base_str + '\\left(\\frac{p(anno|center)}{p(anno)}\\right)$'
elif score_key == 'z':
expression = '$\\frac{log(N(anno,center)) - \\mathrm{random\\,expectation}}{\\mathrm{standard\\,deviation}}$'
elif score_key == 'composition':
expression = '$p(anno|center)$'
elif score_key == 'log_composition':
base_str = '' if log_base is None else f'_{log_base}'
expression = '$log' + base_str + '(p(anno|center))$'
elif score_key == 'distance_distribution':
expression = '$p(dist|anno,center)$'
elif score_key == 'log_distance_distribution':
base_str = '' if log_base is None else f'_{log_base}'
expression = '$log' + base_str + '(p(dist|anno,center))$'
elif score_key == 'relative_distance_distribution':
expression = '$\\frac{p(dist|anno,center)}{p(dist|*,center)}$'
elif score_key == 'log_relative_distance_distribution':
base_str = '' if log_base is None else f'_{log_base}'
expression = '$log' + base_str + '\\left(\\frac{p(dist|anno,center)}{p(dist|*,center)}\\right)$'
else:
base_str = '' if log_base is None else f'_{log_base}'
expression = '$log' + base_str + '\\left(\\frac{p(anno|center)/p(anno)}{gmean\\left(p(anno|center)/p(anno)\\right)}\\right)$'
return expression
def co_occurrence(
adata,
analysis_key,
score_key='log_occ',
log_base=None,
colors=None,
show_only=None,
show_only_center=None,
axsize=(4,3),
sharex=True,
sharey='col',
wspace=0.15,
hspace=0.3,
legend=True,
grid=True,
merged=False,
ax=None
):
"""\
Plot co-occurrence as determined by :func:`~tacco.tools.co_occurrence`.
Parameters
----------
adata
An :class:`~anndata.AnnData` with the co-occurrence analysis in `.uns`.
Can also be a mapping of labels to :class:`~anndata.AnnData` to specify
multiple datasets.
analysis_key
The `.uns` key with the co-occurrence analysis result.
score_key
The `.uns[analysis_key]` key of the score to use or the
`.uns[analysis_key]['comparisons']` sub-key specifying the comparison
to plot. Available keys `.uns[analysis_key]` include:
- 'occ': co-occurrence
- 'log_occ': logarithm of the co-occurrence
- 'log2_occ': base-2 logarithm of the co-occurrence; this is not a
real key but a convenience option that rescales the 'log_occ' values
- 'z': z-score of the log of the neighbourship counts with respect to
random neighbourships
- 'composition': distance dependent composition, `p(anno|center, dist)`
- 'log_composition': log of 'composition'
- 'distance_distribution': distribution of distances between anno and
center `p(dist|anno,center)`
- 'log_distance_distribution': log of 'distance_distribution'
- 'relative_distance_distribution': 'distance_distribution' normalized
to `p(dist|*,center)`, the distance distribution of any annotation to
the center
- 'log_relative_distance_distribution': log of
'relative_distance_distribution'
log_base
The base of the logarithm to use for plotting if `score_key` is
'log_occ' or a comparison key. If `None`, use the natural logarithm.
colors
The mapping of value names to colors. If `None`, a set of
standard colors is used.
show_only
A subset of annotation values to restrict the plotting to.
show_only_center
A subset of the center annotation values to restrict the plotting
to.
axsize
Tuple of width and height of a single axis.
sharex, sharey
Whether and how to use common x/y axis. Options include `True`,
`False`, "col", "row", "none", and "all".
wspace, hspace
Control the spacing between the plots.
legend
Whether to include the legend.
grid
Whether to plot a grid.
merged
Whether to merge the plots for all :class:`~anndata.AnnData` instances
into a single row of plots. This only makes sense if multiple instances
are provided in `adata`.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
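Examples
--------
A minimal usage sketch, assuming the co-occurrence analysis was stored
under `adata.uns['cooc']` by :func:`~tacco.tools.co_occurrence`; the
analysis key and the category name are illustrative::
import tacco as tc
fig = tc.pl.co_occurrence(adata, 'cooc', score_key='log_occ', show_only_center=['Tumor'])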
"""
adatas = _get_adatas(adata)
fig = None
linestyles = ['solid','dashed','dotted','dashdot']
for adata_i, (adata_name, adata) in enumerate(adatas.items()):
mean_scores, intervals, annotation, center, colors, types = _get_co_occurrence(adata, analysis_key, show_only, show_only_center, colors, score_key, log_base)
if merged:
if len(adatas) > len(linestyles):
raise ValueError(f'`merged==True` is only possible with up to {len(linestyles)} adatas!')
if fig is None:
if ax is not None:
if isinstance(ax, matplotlib.axes.Axes):
axs = np.array([[ax]])
else:
axs = ax
if axs.shape != (1, len(center)):
raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(1, len(center))!r}, supplied was {axs.shape!r}!')
axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size
fig = axs[0,0].get_figure()
else:
fig, axs = subplots(len(center), 1, axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey, )
for ir, nr in enumerate(center):
x = (intervals[1:] + intervals[:-1]) / 2
for ja, na in enumerate(annotation):
y = mean_scores[ja,ir,:]
linestyle = linestyles[adata_i]
axs[0,ir].plot(x, y, color=colors[na], linestyle=linestyle)
adata_title = f'center {center.name}={nr}: '
anno_title = 'annotation: ' + annotation.name
axs[0,ir].set_xlabel('distance')
axs[0,ir].set_ylabel(anno_title)
expression = _get_cooc_expression_label(score_key,log_base)
axs[0,ir].set_title(adata_title + expression)
axs[0,ir].grid(grid)
else:
if fig is None:
if ax is not None:
if isinstance(ax, matplotlib.axes.Axes):
axs = np.array([[ax]])
else:
axs = ax
if axs.shape != (len(adatas), len(center)):
raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(adatas), len(center))!r}, supplied was {axs.shape!r}!')
axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size
fig = axs[0,0].get_figure()
else:
fig, axs = subplots(len(center), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey, )
for ir, nr in enumerate(center):
x = (intervals[1:] + intervals[:-1]) / 2
for ja, na in enumerate(annotation):
y = mean_scores[ja,ir,:]
linestyle = linestyles[nr==na]
axs[adata_i,ir].plot(x, y, color=colors[na], linestyle=linestyle)
adata_title = f'{adata_name}, center {center.name}={nr}: ' if adata_name != '' else f'center {center.name}={nr}: '
anno_title = 'annotation: ' + annotation.name
axs[adata_i,ir].set_xlabel('distance')
axs[adata_i,ir].set_ylabel(anno_title)
expression = _get_cooc_expression_label(score_key,log_base)
axs[adata_i,ir].set_title(adata_title + expression)
axs[adata_i,ir].grid(grid)
if legend:
handles = []
if merged:
handles.extend([mlines.Line2D([], [], color='gray', label=adata_name, linestyle=linestyle) for ((adata_name, adata), linestyle) in zip(adatas.items(), linestyles) ])
handles.extend([mpatches.Patch(color=color, label=ind) for (ind, color) in zip(annotation, colors[annotation]) ])
axs[0,len(center)-1].legend(handles=handles, bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
return fig
def co_occurrence_matrix(
adata,
analysis_key,
score_key='log_occ',
log_base=None,
colors=None,
show_only=None,
show_only_center=None,
axsize=None,
hspace=None,
wspace=None,
x_padding=2.0,
y_padding=2.0,
value_cluster=False,
group_cluster=False,
restrict_intervals=None,
p_corr='fdr_bh',
cmap='bwr',
cmap_vmin_vmax=None,
legend=True,
ax = None,
scale_legend=1.0
):
"""\
Plot co-occurrence as determined by :func:`~tacco.tools.co_occurrence` or
:func:`~tacco.tools.co_occurrence_matrix` as a matrix.
Parameters
----------
adata
An :class:`~anndata.AnnData` with the co-occurrence analysis in `.uns`.
Can also be a mapping of labels to :class:`~anndata.AnnData` to specify
multiple datasets.
analysis_key
The `.uns` key with the co-occurrence analysis result.
score_key
The `.uns[analysis_key]` key of the score to use. Available keys
`.uns[analysis_key]` include:
- 'occ': co-occurrence
- 'log_occ': logarithm of the co-occurrence
- 'log2_occ': base-2 logarithm of the co-occurrence; this is not a
real key but a convenience option that rescales the 'log_occ' values
- 'z': z-score of the log of the neighbourship counts with respect to
random neighbourships
- 'composition': distance dependent composition, `p(anno|center, dist)`
- 'log_composition': log of 'composition'
- 'distance_distribution': distribution of distances between anno and
center `p(dist|anno,center)`
- 'log_distance_distribution': log of 'distance_distribution'
- 'relative_distance_distribution': 'distance_distribution' normalized
to `p(dist|*,center)`, the distance distribution of any annotation to
the center
- 'log_relative_distance_distribution': log of
'relative_distance_distribution'
log_base
The base of the logarithm to use for plotting if `score_key` is a log
quantity. If `None`, use the natural logarithm.
colors
The mapping of value names to colors. If `None`, a set of
standard colors is used.
show_only
A subset of annotation values to restrict the plotting to.
show_only_center
A subset of the center annotation values to restrict the plotting
to.
axsize
Tuple of width and height of a single axis. If `None`, some heuristic
value is used.
hspace, wspace
Relative vertical and horizontal spacing between plots.
x_padding, y_padding
Absolute horizontal and vertical spacing between plots; if not `None`,
these override `wspace` and `hspace`.
value_cluster
Whether to cluster and reorder the values.
group_cluster
Whether to cluster and reorder the groups.
restrict_intervals
A list-like containing the indices of the intervals to plot. If `None`,
all intervals are included.
cmap
A string/colormap to override the `colors` with globally.
cmap_vmin_vmax
A tuple giving the range of values for the colormap.
legend
Whether to include the legend.
scale_legend
Set to scale height and width of the legend.
ax
The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a
fresh figure for plotting. Incompatible with dendrogram plotting.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
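Examples
--------
A minimal usage sketch, assuming the co-occurrence analysis was stored
under `adata.uns['cooc']`; the analysis key, interval indices, and value
range are illustrative::
import tacco as tc
fig = tc.pl.co_occurrence_matrix(adata, 'cooc', score_key='z', restrict_intervals=[0,3], cmap_vmin_vmax=(-5,5))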
"""
adatas = _get_adatas(adata)
min_max = None
# first pass through the data to get global min/max of the values for colormap
for adata_i, (adata_name, adata) in enumerate(adatas.items()):
mean_scores, intervals, annotation, center, colors, types = _get_co_occurrence(adata, analysis_key, show_only, show_only_center, colors, score_key=score_key, log_base=log_base)
if min_max is None:
if restrict_intervals is None:
restrict_intervals = np.arange(len(intervals)-1)
min_max = np.zeros((len(adatas),len(restrict_intervals),2))
data = mean_scores[:,:,restrict_intervals]
min_max[adata_i,:,:] = data.min(),data.max()
min_max[:,:,0] = min_max[:,:,0].min()
min_max[:,:,1] = min_max[:,:,1].max()
if cmap_vmin_vmax is not None:
min_max[:,:,:] = np.array(cmap_vmin_vmax)
if axsize is None:
axsize = (0.2*len(center),0.2*len(annotation))
if ax is not None:
if isinstance(ax, matplotlib.axes.Axes):
axs = np.array([[ax]])
else:
axs = ax
if axs.shape != (len(adatas), len(restrict_intervals)):
raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(adatas), len(restrict_intervals))!r}, supplied was {axs.shape!r}!')
axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size
fig = axs[0,0].get_figure()
else:
fig, axs = subplots(len(restrict_intervals), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, x_padding=x_padding, y_padding=y_padding, )
# second pass for actual plotting
for adata_i, (adata_name, adata) in enumerate(adatas.items()):
mean_scores, intervals, annotation, center, colors, types = _get_co_occurrence(adata, analysis_key, show_only, show_only_center, colors, score_key=score_key, log_base=log_base)
for _ii, rii in enumerate(restrict_intervals):
interval = f'({intervals[rii]},{intervals[rii+1]})'
ax = axs[adata_i,_ii]
data = mean_scores[:,:,rii]
data = pd.DataFrame(data, index=annotation, columns=center).T
heatmap(
data,
value_key=None, group_key=None,
colors=colors,
value_cluster=value_cluster, group_cluster=group_cluster,
ax=ax,
cmap=cmap,
cmap_center=None,
cmap_vmin_vmax=min_max[adata_i,_ii],
group_labels_rotation=90,
colorbar=False,
);
adata_title = f'{adata_name}, interval {interval}: ' if adata_name != '' else f'interval {interval}: '
expression = _get_cooc_expression_label(score_key,log_base)
ax.set_title(adata_title + expression)
anno_title = 'annotation: ' + _escape_math_special_characters(annotation.name)
anno_center_title = 'center: ' + _escape_math_special_characters(center.name)
ax.set_ylabel(anno_title)
ax.set_xlabel(anno_center_title)
if legend:
_add_legend_or_colorbars(fig, axs, colors, cmap=cmap, min_max=min_max, scale_legend=scale_legend)
return fig
def annotated_heatmap(
adata,
obs_key=None,
var_key=None,
n_genes=None,
var_highlight=None,
obs_colors=None,
var_colors=None,
cmap='bwr',
cmap_center=0,
cmap_vmin_vmax=(-2,2),
trafo=None,
axsize=(4,4),
):
"""\
Plots a heatmap of cells and genes grouped by categorical annotations.
Parameters
----------
adata
An :class:`~anndata.AnnData` with annotation in `.obs` and/or `.var`.
obs_key
The `.obs` key with the categorical `obs` annotation to use. If `None`,
the observations are not grouped and plotted in the order in which they
appear in `adata`.
var_key
The `.var` key with the categorical `var` annotation to use. Can also
be a mapping of annotations to list-likes of var names. If `None`, the
genes are not grouped and plotted in the order in which they appear in
`adata`. If `n_genes` is set, the meaning of this key is modified.
n_genes
The number of differentially expressed genes to find for the groups of
observations. The differentially expressed genes will be used in place
of a categorical `var` annotation. Setting `n_genes` changes the
behaviour of `var_key`: It is interpreted as a categorical `.obs` key
defining the groups of observations for which to derive the
differentially expressed genes. If `var_key` is `None`, `obs_key` is
used for that.
var_highlight
A list-like of var names to annotate.
obs_colors
A dict-like with the colors to use for the `obs_key` annotation. If
`None`, default colors are used.
var_colors
A dict-like with the colors to use for the `var_key` annotation. If
`None`, default colors are used, except if `n_genes` is set which
triggers the usage of `obs_colors` for `var_colors`.
cmap
A string/colormap to use in the heatmap.
cmap_center
A value to use as center of the colormap. E.g. choosing `0` sets the
central color to `0`. If `None`, the colormap spans the entire value
range.
cmap_vmin_vmax
A tuple giving the range of values for the colormap. This can be
modified by `cmap_center`.
trafo
Whether to normalize, logarithmize, and scale the data prior to
plotting. This makes sense if bare count data is supplied. If the data
is already preprocessed, this can be set to `False`. If `None`, a
heuristic tries to figure out whether count data was supplied, and if
so performs the preprocessing.
axsize
Tuple of width and height of the main axis.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
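Examples
--------
A minimal usage sketch, assuming a categorical cell type annotation in
`adata.obs['type']`; the key name and the number of genes are
illustrative::
import tacco as tc
fig = tc.pl.annotated_heatmap(adata, obs_key='type', n_genes=5)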
"""
if trafo is None:
got_counts = True
try:
preprocessing.check_counts_validity(adata.X)
except Exception:
got_counts = False
if got_counts:
print('`adata` looks like it contains bare counts and will be normalized, logarithmized, and scaled for plotting. If that is not desired (and to get rid of this message), explicitly set `trafo` to something other than `None`.')
trafo = True
else:
trafo = False
if trafo or n_genes:
adata = adata.copy()
if trafo:
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.log1p(adata)
if var_key is not None or n_genes is not None:
if n_genes is not None:
if var_key is not None:
if hasattr(var_key, 'items'):
raise ValueError(f'`var_key` cannot be a dict-like if `n_genes` is set!')
if var_key not in adata.obs:
raise ValueError(f'`var_key` {var_key!r} is not a column of `adata.obs` even though `n_genes` is set!')
if not hasattr(adata.obs[var_key], 'cat'):
print(f'WARNING: `var_key` {var_key!r} is not a categorical column of `adata.obs` even though `n_genes` is set! Treating it as a categorical column...')
group_key = var_key
elif obs_key is not None:
if obs_key not in adata.obs:
raise ValueError(f'`obs_key` {obs_key!r} is not a column of `adata.obs`!')
if not hasattr(adata.obs[obs_key], 'cat'):
print(f'WARNING: `obs_key` {obs_key!r} is not a categorical column of `adata.obs`! Treating it as a categorical column...')
group_key = obs_key
if var_colors is None:
var_colors = obs_colors
else:
raise ValueError(f'`n_genes` can only be used if at least one of [`obs_key`, `var_key`] is not `None`!')
ukey = utils.find_unused_key(adata.uns)
sc.tl.rank_genes_groups(adata, group_key, key_added=ukey, n_genes=n_genes,)
marker = pd.DataFrame(adata.uns[ukey]['names'])
del adata.uns[ukey]
else: # if var_key is not None:
if hasattr(var_key, 'items'):
marker = var_key
else:
if var_key not in adata.var:
raise ValueError(f'`var_key` {var_key!r} is not a column of `adata.var`!')
if not hasattr(adata.var[var_key], 'cat'):
print(f'WARNING: `var_key` {var_key!r} is not a categorical column of `adata.var`! Treating it as a categorical column...')
marker = {c: df.index for c,df in adata.var.groupby(var_key, observed=False)}
all_marker = [c for l,m in marker.items() for c in m]
# reorder genes to represent the annotation
adata = adata[:,all_marker]
if var_colors is None:
var_colors = to_rgba_array(get_default_colors(len(marker)))
var_colors = {cat:col for cat,col in zip(marker.keys(),var_colors)}
else:
var_colors = {cat:to_rgba(col) for cat,col in var_colors.items()}
all_marker_colors = [var_colors[l] for l,m in marker.items() for g in m]
all_marker_labels = [l for l,m in marker.items() for c in m]
marker_centers = { l: np.median(np.arange(len(all_marker_labels))[np.array(all_marker_labels) == l]) for l in marker.keys() }
if obs_key is not None:
if obs_key not in adata.obs:
raise ValueError(f'`obs_key` {obs_key!r} is not a column of `adata.obs`!')
if not hasattr(adata.obs[obs_key], 'cat'):
print(f'WARNING: `obs_key` {obs_key!r} is not a categorical column of `adata.obs`! Treating it as a categorical column...')
cells = {c: df.index for c,df in adata.obs.groupby(obs_key, observed=False)}
all_cells = [c for l,m in cells.items() for c in m]
# reorder cells to represent the annotation
adata = adata[all_cells]
if obs_colors is None:
obs_colors = to_rgba_array(get_default_colors(len(cells)))
obs_colors = {cat:col for cat,col in zip(cells.keys(),obs_colors)}
else:
obs_colors = {cat:to_rgba(col) for cat,col in obs_colors.items()}
all_cell_colors = [obs_colors[l] for l,m in cells.items() for c in m]
all_cell_labels = [l for l,m in cells.items() for c in m]
cell_centers = { l: np.median(np.arange(len(all_cell_labels))[np.array(all_cell_labels) == l]) for l in cells.keys() }
if var_highlight is not None:
# reorder highlighted genes to simplify labelling
var_highlight = pd.Series(adata.var.index)[adata.var.index.isin(var_highlight)]
highlight_centers = pd.Series(var_highlight.index, index=var_highlight )
if trafo:
if adata.is_view:
adata = adata.copy()
sc.pp.scale(adata)
data = adata.X
if scipy.sparse.issparse(data):
data = data.toarray()
if cmap_vmin_vmax is None:
cmap_vmin_vmax = [data.min(),data.max()]
cmap_vmin_vmax = np.array(cmap_vmin_vmax)
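# if a colormap center is requested, symmetrize the value range around it:
# shift the bounds so the center sits at 0, widen both sides to the larger
# absolute bound, and shift back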
if cmap_center is not None:
shifted = cmap_vmin_vmax - cmap_center
abs_max = np.max(np.abs(shifted))
shifted[:] = [-abs_max,abs_max]
cmap_vmin_vmax = shifted + cmap_center
# plotting
fig,axs=subplots(2,2,axsize=axsize,width_ratios=[0.03,axsize[0]/4],height_ratios=[0.03,axsize[1]/4],x_padding=0.05,y_padding=0.05)
axs[0,0].axis('off')
im = axs[1,1].imshow(data.T,aspect='auto',cmap=cmap, vmin=cmap_vmin_vmax[0], vmax=cmap_vmin_vmax[1])
axs[1,1].set_xticks([])
axs[1,1].set_yticks([])
rel_dpi_factor = fig.get_dpi() / 72
cax_width = 100 * rel_dpi_factor # color bar width in pixel
cax_height = 10 * rel_dpi_factor # color bar height in pixel
cax_offset = 10 * rel_dpi_factor # color bar y offset in pixel
cax_l = 1 - fig.transFigure.inverted().transform([cax_width,0])[0] + fig.transFigure.inverted().transform([0,0])[0]
cax_b = 0 - fig.transFigure.inverted().transform([0,cax_height+cax_offset])[1] + fig.transFigure.inverted().transform([0,0])[1]
cax_w = fig.transFigure.inverted().transform([cax_width,0])[0] - fig.transFigure.inverted().transform([0,0])[0]
cax_h = fig.transFigure.inverted().transform([0,cax_height])[1] - fig.transFigure.inverted().transform([0,0])[1]
cax = fig.add_axes((cax_l,cax_b,cax_w,cax_h))
fig.colorbar(im, cax=cax, orientation='horizontal')
# labelling
def _collides(ann1, ann2, offset1, offset2, fig, direction):
if offset1 != offset2:
return False
extent1 = ann1.get_window_extent(fig.canvas.get_renderer())
extent2 = ann2.get_window_extent(fig.canvas.get_renderer())
if direction == 'x':
return extent2.x1 > extent1.x0
else:
return extent1.y1 > extent2.y0
def _collides_any(ann, off, anns, offs, fig, direction):
for ann2, off2 in zip(anns, offs):
if _collides(ann, ann2, off, off2, fig, direction):
return True
return False
def _find_offset(ann, anns, offs, fig, direction):
off = 0
while _collides_any(ann, off, anns, offs, fig, direction):
off += 1
return off
def _find_shift(ann, anns, fig, direction):
if len(anns) == 0:
return 0
extent1 = ann.get_window_extent(fig.canvas.get_renderer())
extent2 = anns[-1].get_window_extent(fig.canvas.get_renderer())
if direction == 'x':
delta = extent2.x1 - extent1.x0
else:
delta = extent1.y1 - extent2.y0
if delta > 0:
return delta
else:
return 0
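# The helpers above implement a greedy label layout: `_find_offset` moves a
# colliding label to the next free discrete row, while `_find_shift` slides a
# label along the axis just far enough to clear the previously placed one.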
if obs_key is not None:
axs[0,1].imshow(np.array([all_cell_colors]),aspect='auto')
axs[0,1].set_xticks([])
axs[0,1].set_yticks([])
anns = []
offs = []
bar = (axs[0,1].transData.inverted().transform([0,-15*rel_dpi_factor])[1] - axs[0,1].transData.inverted().transform([0,0])[1])
for l,c in cell_centers.items():
ann = axs[0,1].annotate(l, (c, -0.5), (c, -0.5-bar), ha="center", va="center", rotation=0, size=10, arrowprops={'arrowstyle':'-'},)
offset = _find_offset(ann, anns, offs, fig, direction='x')
ann.xyann = (ann.xyann[0],ann.xyann[1]-offset*1.8)
anns.append(ann)
offs.append(offset)
else:
axs[0,1].axis('off')
if var_key is not None or n_genes is not None:
axs[1,0].imshow(np.array([all_marker_colors]).swapaxes(0,1),aspect='auto')
axs[1,0].set_xticks([])
axs[1,0].set_yticks([])
anns=[]
bar = (axs[1,0].transData.inverted().transform([15*rel_dpi_factor,0])[0] - axs[1,0].transData.inverted().transform([0,0])[0])
for l,c in marker_centers.items():
ann = axs[1,0].annotate(l, (-0.5, c), (-0.5-bar, c), ha="right", va="center", rotation=0, size=10, arrowprops={'arrowstyle':'-'},)
shift = _find_shift(ann, anns, fig, direction='y')
shift = (axs[1,0].transData.inverted().transform([0,shift])[1] - axs[1,0].transData.inverted().transform([0,0])[1])
ann.xyann = (ann.xyann[0],ann.xyann[1]-shift)
anns.append(ann)
else:
axs[1,0].axis('off')
if var_highlight is not None:
anns=[]
bar = (axs[1,1].transData.inverted().transform([15*rel_dpi_factor,0])[0] - axs[1,1].transData.inverted().transform([0,0])[0])
offset = (axs[1,1].transData.inverted().transform([0*rel_dpi_factor,0])[0] - axs[1,1].transData.inverted().transform([0,0])[0]) - 0.5
for l,c in highlight_centers.items():
ann = axs[1,1].annotate(l, (len(adata.obs.index)+offset, c), (len(adata.obs.index)+offset+bar, c), ha="left", va="center", rotation=0, size=10, arrowprops={'arrowstyle':'-'}, annotation_clip=False,)
shift = _find_shift(ann, anns, fig, direction='y')
shift = (axs[1,1].transData.inverted().transform([0,shift])[1] - axs[1,1].transData.inverted().transform([0,0])[1])
ann.xyann = (ann.xyann[0],ann.xyann[1]-shift)
anns.append(ann)
return fig
@njit(parallel=False,fastmath=True,cache=True)
def _anno_hist(mist, intervals, weights, anno):
Nobs,Nanno = anno.shape
assert Nobs == len(mist)
assert Nobs == len(weights)
Nd = len(intervals)-1
hist = np.zeros((Nd,))
sums = np.zeros((Nd,Nanno))
for i in range(Nobs):
di = mist[i]
_di = np.argmax(di <= intervals) - 1
if _di >= 0: # values outside (intervals[0], intervals[-1]] give _di == -1 and are skipped
hist[_di] += weights[i]
sums[_di] += weights[i] * anno[i]
for d in range(Nd):
if hist[d] > 0: # avoid 0/0 for empty bins
sums[d] /= hist[d]
else:
sums[d,:] = np.nan
return sums
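# Illustrative note on the binning in `_anno_hist`: with intervals = np.array([0., 1., 2.]),
# a value di = 1.5 gives np.argmax(di <= intervals) - 1 == 1, i.e. the half-open bin (1, 2];
# di = 0.0 or di = 2.5 both give _di == -1 and are skipped.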
def annotation_coordinate(
adata,
annotation_key,
coordinate_key,
group_key=None,
reference_key=None,
max_coordinate=None,
delta_coordinate=None,
axsize=None,
colors=None,
stacked=True,
verbose=1,
):
"""\
Plots an annotation density with respect to a scalar coordinate.
Parameters
----------
adata
A :class:`~anndata.AnnData`.
annotation_key
The `.obs` or `.obsm` key to plot.
coordinate_key
The `.obs` key or (`.obsm` key, column name) pair with the scalar
coordinate(s).
group_key
A categorical group annotation. The plot is done separately per group.
If `None`, all observations are treated as a single group.
reference_key
The `.obs` key to use as weights (i.e. the weight to use for
calculating the annotation density). If `None`, use `1` per
observation, which makes sense for categorical annotations or for
fractional annotations that sum to `1` per observation.
max_coordinate
The maximum coordinate to use. If `None` or `np.inf`, uses the maximum
coordinate in the data.
delta_coordinate
The width in coordinate for coordinate discretization. If `None`, takes
`max_coordinate/100`.
axsize
Tuple of width and height of a single axis in the plot.
colors
A mapping of annotation values to colors. If `None`, default colors are
used.
stacked
Whether to plot the different annotations on different scales on
separate stacked plots or on the same scale in a single plot.
verbose
Level of verbosity, with `0` (no output), `1` (some output), ...
Returns
-------
A :class:`~matplotlib.figure.Figure`.
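Examples
--------
A minimal usage sketch, assuming fractional cell type annotation in
`adata.obsm['type']` and a scalar distance coordinate in
`adata.obs['region_dist']`; key names and coordinate scales are
illustrative::
import tacco as tc
fig = tc.pl.annotation_coordinate(adata, annotation_key='type', coordinate_key='region_dist', max_coordinate=500, delta_coordinate=10)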
"""
if group_key is None:
group_adatas = {'':adata}
else:
group_adatas = {group:adata[df.index] for group,df in adata.obs.groupby(group_key, observed=False) if len(df)>0 }
if annotation_key in adata.obs:
annotation = adata.obs[annotation_key]
if hasattr(annotation, 'cat'):
annotation = pd.get_dummies(annotation)
else:
annotation = pd.DataFrame(annotation)
elif annotation_key in adata.obsm:
annotation = adata.obsm[annotation_key].copy()
else:
raise ValueError(f'The `annotation_key` {annotation_key!r} is neither in `adata.obs` nor `adata.obsm`!')
if pd.api.types.is_list_like(coordinate_key):
if len(coordinate_key) == 2 and coordinate_key[0] in adata.obsm and coordinate_key[1] in adata.obsm[coordinate_key[0]]:
coordinates = adata.obsm[coordinate_key[0]][coordinate_key[1]]
else:
raise ValueError(f'The `coordinate_key` {coordinate_key!r} is list-like, but not something of length 2 containing an `adata.obsm` key and a column name therein!')
coordinate_key = f'{coordinate_key[1]}'
elif coordinate_key in adata.obs:
coordinates = adata.obs[coordinate_key]
else:
raise ValueError(f'The `coordinate_key` {coordinate_key!r} is not in `adata.obs` and not a valid specification for something in `adata.obsm`!')
if max_coordinate is None or max_coordinate == np.inf:
max_coordinate = coordinates.to_numpy().max()
max_coordinate = float(max_coordinate)
if delta_coordinate is None:
delta_coordinate = max_coordinate / 100
n_intervals = int(max_coordinate / delta_coordinate)
max_coordinate = n_intervals * delta_coordinate
intervals = np.arange(0,max_coordinate+delta_coordinate*0.5,delta_coordinate)
midpoints = (intervals[1:] + intervals[:-1]) * 0.5
if reference_key is None:
reference_weights = pd.Series(np.ones(len(adata.obs.index),dtype=float),index=adata.obs.index)
elif reference_key in adata.obs:
reference_weights = adata.obs[reference_key]
else:
raise ValueError(f'The `reference_key` {reference_key!r} is not in `adata.obs`!')
annotation_categories = annotation.columns
colors,annotation_categories = _get_colors(colors, annotation_categories)
if stacked:
if axsize is None:
axsize = (6,0.7)
fig,axs = subplots(len(group_adatas),len(annotation.columns),axsize=axsize,sharex=True,sharey='row',y_padding=0,x_padding=0)
else:
if axsize is None:
axsize = (6,4)
fig,axs = subplots(len(group_adatas),1,axsize=axsize,sharex=True,sharey='row',y_padding=0,x_padding=0)
for gi,(group,group_adata) in enumerate(group_adatas.items()):
group_annotation = annotation.loc[group_adata.obs.index].to_numpy()
group_coordinates = coordinates.loc[group_adata.obs.index].to_numpy()
group_reference_weights = reference_weights.loc[group_adata.obs.index].to_numpy()
no_nan = ~np.isnan(group_annotation).any(axis=1)
anno = _anno_hist(group_coordinates[no_nan], intervals, group_reference_weights[no_nan], group_annotation[no_nan])
anno = pd.DataFrame(anno,index=midpoints,columns=annotation.columns)
if stacked:
for i,c in enumerate(colors.index):
axs[i,gi].plot(anno[c],c=colors[c],label=c)
axs[i,gi].xaxis.grid(True)
else:
for i,c in enumerate(colors.index):
axs[0,gi].plot(anno[c],c=colors[c],label=c)
axs[0,gi].xaxis.grid(True)
group_string = f' in {group}' if group != '' else ''
axs[0,gi].set_title(f'{annotation_key} VS distance from {coordinate_key}{group_string}')
axs[0,-1].legend(handles=[mpatches.Patch(color=color, label=ind) for (ind, color) in colors.items() ],
bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
return fig
def dotplot(
adata,
genes,
group_key,
log1p=True,
marks=None,
marks_colors=None,
swap_axes=True,
):
"""\
Dot plot of expression values.
This is similar to :func:`scanpy.pl.dotplot` with customizations, e.g. the
option to mark selected dots.
Parameters
----------
adata
An :class:`~anndata.AnnData` including annotation in `.obs`.
genes
The `.var` index values (genes) to plot, as a list-like.
group_key
An `.obs` key with categorical group information.
log1p
Whether to log1p-transform the data prior to plotting.
marks
A :class:`pandas.DataFrame` containing categorical markers for the dots
with genes in the rows and groups in the columns.
marks_colors
A mapping from the categories in `marks` to colors; if `None`, default
colors are used.
swap_axes
If `False`, the x axis contains the genes and the y axis the groups.
Otherwise the axes are swapped.
Returns
-------
A :class:`~matplotlib.figure.Figure`.
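Examples
--------
A minimal usage sketch, assuming a categorical annotation in
`adata.obs['type']`; gene and key names are illustrative::
import tacco as tc
fig = tc.pl.dotplot(adata, genes=['Epcam','Ptprc','Pecam1'], group_key='type')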
"""
if not pd.Index(genes).isin(adata.var.index).all():
raise ValueError(f'The genes {pd.Index(genes).difference(adata.var.index)!r} are not available in `adata.var`!')
markers = genes[::-1]
if group_key not in adata.obs.columns:
raise ValueError(f'The `group_key` {group_key!r} is not available in `adata.obs`!')
if hasattr(adata.obs[group_key], 'cat'):
cluster = adata.obs[group_key].cat.categories
else:
cluster = adata.obs[group_key].unique()
if swap_axes:
xticklabels = cluster
yticklabels = markers
else:
xticklabels = markers
yticklabels = cluster
fig,axs = subplots(axsize=0.25*np.array([len(xticklabels),len(yticklabels)]))
x = np.arange(len(xticklabels)) # the label locations
y = np.arange(len(yticklabels)) # the label locations
if not swap_axes:
x = x[::-1]
y = y[::-1]
axs[0,0].set_xticks(x)
axs[0,0].set_xticklabels(xticklabels, rotation=45, ha='right',)
axs[0,0].set_yticks(y)
axs[0,0].set_yticklabels(yticklabels)
axs[0,0].set_xlim((x.min()-0.5,x.max()+0.5))
axs[0,0].set_ylim((y.min()-0.5,y.max()+0.5))
axs[0,0].set_axisbelow(True)
axs[0,0].grid(True)
marker_counts = adata[:,markers].to_df()
if log1p:
marker_counts = np.log1p(marker_counts)
mean_exp = pd.DataFrame({c: marker_counts.loc[df.index].mean(axis=0) for c,df in adata.obs.groupby(group_key, observed=False) })
mean_pos = pd.DataFrame({c: (marker_counts.loc[df.index] != 0).mean(axis=0) for c,df in adata.obs.groupby(group_key, observed=False) })
if marks is not None:
marks = marks.reindex_like(mean_pos)
mean_exp_index_name = 'index' if mean_exp.index.name is None else mean_exp.index.name
mean_pos_index_name = 'index' if mean_pos.index.name is None else mean_pos.index.name
mean_exp = pd.melt(mean_exp, ignore_index=False).reset_index().rename(columns={mean_exp_index_name:'value','variable':'cluster','value':'mean_exp'})
mean_pos = pd.melt(mean_pos, ignore_index=False).reset_index().rename(columns={mean_pos_index_name:'value','variable':'cluster','value':'mean_pos'})
if marks is not None:
marks.index.name = None
marks.columns.name = None
marks = pd.melt(marks, ignore_index=False).reset_index().rename(columns={'index':'value','variable':'cluster','value':'marks'})
if marks_colors is None:
marks_colors = get_default_colors(marks['marks'].unique())
all_df = pd.merge(mean_exp, mean_pos, on=['value', 'cluster'])
if marks is not None:
all_df = pd.merge(all_df, marks, on=['value', 'cluster'], how='outer')
if all_df['marks'].isna().any():
raise ValueError(f'There were gene-group combinations without a match in "marks"!')
if swap_axes:
all_df['x'] = all_df['cluster'].map(pd.Series(x,index=xticklabels))
all_df['y'] = all_df['value'].map(pd.Series(y,index=yticklabels))
else:
all_df['x'] = all_df['value'].map(pd.Series(x,index=xticklabels))
all_df['y'] = all_df['cluster'].map(pd.Series(y,index=yticklabels))
legend_items = []
mean_exp_min, mean_exp_max = all_df['mean_exp'].min(), all_df['mean_exp'].max()
norm = Normalize(vmin=mean_exp_min, vmax=mean_exp_max)
cmap = 'Reds'
mapper = ScalarMappable(norm=norm, cmap=cmap)
color = [ tuple(x) for x in mapper.to_rgba(all_df['mean_exp'].to_numpy()) ]
legend_items.append(mpatches.Patch(color='#0000', label='mean expression'))
mean_exp_for_legend = np.linspace(mean_exp_min, mean_exp_max, 4)
legend_items.extend([mpatches.Patch(color=color, label=f'{ind:.2f}') for color,ind in zip(mapper.to_rgba(mean_exp_for_legend),mean_exp_for_legend)])
mean_pos_min, mean_pos_max = all_df['mean_pos'].min(), all_df['mean_pos'].max()
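# scatter sizes are areas in points^2: the largest expressing fraction maps to
# a marker of 14 points diameter and smaller fractions scale quadratically, so
# the legend recovers marker diameters via the square root below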
def size_map(x):
return (x/mean_pos_max * 14)**2
size = size_map(all_df['mean_pos'])
legend_items.append(mpatches.Patch(color='#0000', label='fraction of expressing cells'))
mean_pos_for_legend = np.linspace(mean_pos_min, mean_pos_max, 5)[1:]
legend_items.extend([mlines.Line2D([], [], color='#aaa', linestyle='none', marker='o', markersize=np.sqrt(size_map(ind)), label=f'{ind:.2f}') for ind in mean_pos_for_legend])
edgecolors = '#aaa' if marks is None else all_df['marks'].map(marks_colors)
if marks is not None:
marks_name = marks_colors.name if hasattr(marks_colors, 'name') else ''
legend_items.append(mpatches.Patch(color='#0000', label=marks_name))
legend_items.extend([mlines.Line2D([], [], color='#aaa', linestyle='none', fillstyle='none', markeredgecolor=color, marker='o', markersize=np.sqrt(size_map(mean_pos_for_legend[-2])), label=f'{ind}') for ind,color in marks_colors.items()])
axs[0,0].scatter(all_df['x'], all_df['y'], c=color, s=size, edgecolors=edgecolors)
axs[0,0].legend(handles=legend_items, bbox_to_anchor=(1, 1), loc='upper left', ncol=1)
return fig