Source code for thor.analysis.ccc

from typing import Dict, Optional
import logging

import numpy as np
import pandas as pd
import scanpy as sc
import anndata
from sklearn.neighbors import NearestNeighbors

from thor.utils import get_adata_layer_array, get_library_id, update_kwargs_exclusive, get_scalefactors, on_patch_rect

logger = logging.getLogger(__name__)

def precompute_nearest_pairs_distances(
    adata: anndata.AnnData,
    cutoff: float,
    spatial_key: Optional[str] = "spatial"
) -> None:
    """
    Store the distances between all cells lying within ``cutoff`` of each other.

    A scikit-learn :class:`~sklearn.neighbors.NearestNeighbors` model is fitted on the
    spatial coordinates and queried with the same coordinates. The resulting radius
    neighbors graph (``mode="distance"``) is written to ``adata.obsp["spatial_distance"]``.

    Parameters
    ----------
    adata
        Annotated data matrix. It is modified in place.
    cutoff
        Search radius, expressed in the same units as the coordinates stored under
        ``spatial_key``.
    spatial_key
        The key of the spatial coordinates in ``adata.obsm``.

    Returns
    -------
    None
        The result is stored in ``adata.obsp["spatial_distance"]``.

    Examples
    --------
    >>> import scanpy as sc
    >>> import thor
    >>> adata = sc.datasets.visium_sge()
    >>> distance = 500 # distance cutoff in microns
    >>> microns_per_pixel = 0.25
    >>> distance_pixel = distance / microns_per_pixel
    >>> thor.analy.precompute_nearest_pairs_distances(adata, cutoff=distance_pixel)
    """
    coordinates = adata.obsm[spatial_key]
    radius_model = NearestNeighbors(radius=cutoff).fit(coordinates)
    # The fitted points are also the query points, so every cell is compared
    # against every other cell of the same dataset.
    adata.obsp["spatial_distance"] = radius_model.radius_neighbors_graph(
        coordinates, mode="distance"
    )


def get_pathway(pathway_name: str, pathways_df: pd.DataFrame, name_col_index: int = 2) -> pd.DataFrame:
    """
    Select the rows of the pathways database that belong to one pathway.

    The comparison is case-insensitive: both the requested name and the names in the
    database are upper-cased before they are compared.

    Parameters
    ----------
    pathway_name : str
        The name of the pathway to retrieve.
    pathways_df : pandas.DataFrame
        The database of pathways to search.
    name_col_index : int, optional
        Positional index of the column containing the pathway names in the database.
        Default is 2.

    Returns
    -------
    pandas.DataFrame
        The rows of ``pathways_df`` whose pathway name matches ``pathway_name``.
        Empty if nothing matches.
    """
    wanted = pathway_name.upper()
    names_upper = pathways_df.iloc[:, name_col_index].str.upper()
    is_match = names_upper.isin([wanted])
    return pathways_df[is_match]


def split_pathways(pathways_df: pd.DataFrame, name_col_index: int = 2) -> Dict[str, pd.DataFrame]:
    """
    Split the pathways database into a dictionary of pathways.

    Rows are assigned to a pathway with the same case-insensitive comparison used by
    ``get_pathway``: names are upper-cased before they are compared. Names that differ
    only by case therefore each appear as a key and map to the same set of rows.
    Rows with a missing pathway name are not assigned to any pathway.

    Parameters
    ----------
    pathways_df : pandas.DataFrame
        The DataFrame containing the pathways column.
    name_col_index : int, optional
        Positional index of the column containing the pathway names. Default is 2.

    Returns
    -------
    dict
        A dictionary of pathways, where the keys are the pathway names (in order of
        first appearance) and the values are the corresponding DataFrames.
    """
    names = pathways_df.iloc[:, name_col_index]

    # Missing names have no ``.upper()`` and can never match a row, so they are
    # skipped instead of raising an AttributeError.
    unique_names = names.dropna().unique()
    if len(unique_names) == 0:
        return {}

    # Upper-case the name column once rather than once per pathway.
    names_upper = names.str.upper()

    return {
        name: pathways_df[names_upper.isin([name.upper()])]
        for name in unique_names
    }


def add_image_row_col(adata, spatial_key="spatial"):
    """
    Copy the spatial coordinates into ``adata.obs`` as "imagerow" and "imagecol".

    Parameters
    ----------
    adata
        Annotated data matrix. ``adata.obs`` is modified in place.
    spatial_key
        The key of the spatial coordinates in ``adata.obsm``.
    """
    coordinates = adata.obsm[spatial_key]
    # NOTE(review): column 0 goes to "imagerow" and column 1 to "imagecol".
    # If the coordinates are stored as (x, y), this mapping may be swapped —
    # confirm against the consumers of these columns.
    for obs_column, axis in (("imagerow", 0), ("imagecol", 1)):
        adata.obs.loc[:, obs_column] = coordinates[:, axis]


def prepare_adata(adata, layer=None, img_key="hires", gene_symbols_key="feature_name"):
    """
    Build a cleaned copy of ``adata`` for the cell-cell communication analysis.

    All changes are made on a copy; the input object itself is not modified here.

    Parameters
    ----------
    adata
        Annotated data matrix. ``adata.uns["spatial"]`` and ``adata.obsm["spatial"]``
        must be present.
    layer
        Layer key forwarded to ``get_adata_layer_array``; the returned array becomes
        ``X`` of the copy.
    img_key
        Value stored under ``uns["spatial"][library_id]["use_quality"]``.
    gene_symbols_key
        Column of ``adata.var`` holding the gene symbols. When the column exists, its
        values replace ``var_names``; otherwise ``var_names`` are kept.

    Returns
    -------
    The processed data, restricted to cells whose expression sums to more than zero
    and to genes whose names do not contain ``"_"``.
    """
    ad_sym = adata.copy()

    # Dropping ``.raw`` is best-effort. Failures are logged and ignored, but
    # KeyboardInterrupt/SystemExit are no longer swallowed as they were with a
    # bare ``except``.
    try:
        del ad_sym.raw
    except Exception:
        logging.getLogger(__name__).debug(
            "Could not delete .raw from the AnnData copy.", exc_info=True
        )

    if gene_symbols_key in ad_sym.var.columns:
        ad_sym.var_names = ad_sym.var.loc[:, gene_symbols_key]

    # use specified layer
    ad_sym.X = get_adata_layer_array(adata, layer_key=layer)
    # Same best-effort policy as for ``.raw`` above.
    try:
        del ad_sym.layers
    except Exception:
        logging.getLogger(__name__).debug(
            "Could not delete .layers from the AnnData copy.", exc_info=True
        )

    # set image to use
    library_id = get_library_id(adata)
    ad_sym.uns["spatial"][library_id]["use_quality"] = img_key
    add_image_row_col(ad_sym)

    # filter cells with no expression
    ad_sym = ad_sym[ad_sym.X.sum(axis=1) > 0]

    # Remove incompatible genes: names containing any of these characters.
    # The characters are joined into a regular expression alternation.
    incompatible_gene_characters = ("_",)
    incompatible_pattern = "|".join(incompatible_gene_characters)
    compatible_genes = ~ad_sym.var.index.str.contains(incompatible_pattern)
    return ad_sym[:, compatible_genes]


def run_commot(adata, region=None, gene_symbols_key="feature_name", **kwargs):
    """
    Run the cell-cell communication analysis using the modified `COMMOT <https://doi.org/10.1038/s41592-022-01728-4>`_ method.

    This is a wrapper around :py:func:`commot.tl.spatial_communication`. The data are
    first copied, optionally restricted to a rectangular region, and cleaned with
    ``prepare_adata``. The cell-cell distances are then precomputed as a sparse matrix
    with ``precompute_nearest_pairs_distances`` before COMMOT is called.

    Parameters
    ----------
    adata : :class:`~anndata.AnnData`
        Annotated data matrix. The analysis runs on a copy.
    region : :py:class:`list`, optional
        The region to analyze in the format [left, right, lower, upper]. left < right and lower < upper.
        If omitted, all cells are used.
    gene_symbols_key : :py:class:`str`, optional
        The key for gene symbols in adata.var. Default is "feature_name".
    kwargs: :py:class:`dict`, optional
        Additional keyword arguments for :py:func:`commot.tl.spatial_communication`.
        The entry ``"dis_thr"`` of the filtered keyword arguments is used as the cutoff
        for the precomputed distances.

    Returns
    -------
    adata: :class:`~anndata.AnnData`
        The processed copy of the data that was passed to COMMOT.

    See Also
    --------
    :py:func:`commot.tl.spatial_communication`
    """
    import commot as ct

    if region is None:
        ad_selected = adata.copy()
    else:
        region = np.array(region)
        in_region = on_patch_rect(adata.obsm["spatial"], region)
        ad_selected = adata[in_region].copy()

    ad_selected = prepare_adata(ad_selected, gene_symbols_key=gene_symbols_key)

    # Keep only the keyword arguments meant for COMMOT's entry point.
    commot_kwargs = update_kwargs_exclusive(ct.tl.spatial_communication, kwargs)

    # The distance matrix is limited to pairs within COMMOT's distance threshold.
    precompute_nearest_pairs_distances(ad_selected, commot_kwargs["dis_thr"])
    ct.tl.spatial_communication(ad_selected, **commot_kwargs)

    return ad_selected
def plot_commot(adata, region=None, **kwargs):
    """
    Plot the cell-cell communication analysis results on a spatial plot.

    The communication direction is computed first with
    ``commot.tl.communication_direction`` and the result is drawn with
    ``commot.pl.plot_cell_communication``. The figure is shown with ``plt.show()``.

    Parameters
    ----------
    adata : :class:`~anndata.AnnData`
        Annotated data matrix.
    region : list, optional
        The region to plot in the format [left, right, lower, upper]. left < right and lower < upper.
        The values are multiplied by the ``tissue_hires_scalef`` scale factor before
        they are used as axis limits.
    kwargs
        Additional keyword arguments for the `commot.tl.communication_direction` and
        `commot.pl.plot_cell_communication` functions.
    """
    import commot as ct
    from matplotlib import pyplot as plt

    direction_kwargs = update_kwargs_exclusive(ct.tl.communication_direction, kwargs)
    ct.tl.communication_direction(adata, **direction_kwargs)

    drawing_kwargs = update_kwargs_exclusive(ct.pl.plot_cell_communication, kwargs)
    ax = ct.pl.plot_cell_communication(adata, **drawing_kwargs)

    if region is not None:
        scalef = get_scalefactors(adata)['tissue_hires_scalef']
        y_limits = (region[2] * scalef, region[3] * scalef)
        x_limits = (region[0] * scalef, region[1] * scalef)
        ax.set_xlim(*x_limits)
        ax.set_ylim(*y_limits)
        # Flip the y axis after the limits have been set.
        ax.invert_yaxis()

    plt.show()


def commot_to_dynamo(ad_commot, pathway, database, lr="sender", basis="spatial"):
    """Convert the COMMOT output to dynamo format.

    Parameters
    ----------
    ad_commot : :class:`~anndata.AnnData`
        Annotated data matrix. The COMMOT results are stored in `obsm`.
    pathway : str
        The pathway name.
    database : str
        The database name.
    lr : str, optional
        The ligand-receptor direction. Default is "sender".
    basis : str, optional
        The low-dimensional embedding used for COMMOT analysis. Default is "spatial" for spatial transcriptomics data.

    Returns
    -------
    ad_sig : :class:`~anndata.AnnData`
        Annotated data matrix. The ligand-receptor directions are stored in `obsm`.

    Notes
    -----
    Heads-up: the function has not been tested on multiple datasets yet.
    """
    summary_table = ad_commot.obsm[f"commot-{database}-sum-{lr}"]

    # skip summary: keep the columns whose name has more than two "-"-separated
    # parts, then drop the last of the remaining columns.
    has_three_parts = summary_table.columns.map(lambda name: name.count("-") >= 2)
    signal_table = summary_table.loc[:, has_three_parts].iloc[:, :-1]
    ad_sig = sc.AnnData(signal_table)

    # change key names to the dynamo convention
    embedding = ad_commot.obsm[basis]
    vector_field = ad_commot.obsm[f"commot_{lr}_vf-{database}-{pathway}"]
    ad_sig.obsm[f"X_{basis}"] = embedding
    ad_sig.obsm[f"velocity_{basis}"] = vector_field
    ad_sig.obs = ad_commot.obs

    return ad_sig