# Source code for SpaSRL._srl

# -*- coding: utf-8 -*-

from typing import Callable, Optional, Union
from anndata import AnnData
from ._compat import Literal

import torch
import numpy as np

from scipy.sparse import issparse, spmatrix, csr_matrix

from ._initialize import initialize_Z
from ._solve import solve_P, solve_Z


# Names of the distance metrics accepted by the ``metric`` argument of
# ``run_SRL``. Listed one per line so additions/removals diff cleanly.
# NOTE(review): 'kulsinski' has been dropped from recent SciPy and
# scikit-learn releases; confirm it is still accepted by the installed versions.
_Metric = Literal[
    'cityblock',
    'cosine',
    'euclidean',
    'l1',
    'l2',
    'manhattan',
    'braycurtis',
    'canberra',
    'chebyshev',
    'correlation',
    'dice',
    'hamming',
    'jaccard',
    'kulsinski',
    'mahalanobis',
    'minkowski',
    'rogerstanimoto',
    'russellrao',
    'seuclidean',
    'sokalmichener',
    'sokalsneath',
    'sqeuclidean',
    'yule',
]

# Alternative to a named metric: a callable mapping two NumPy arrays to a
# float distance.
_Metric_fn = Callable[[np.ndarray, np.ndarray], float]


def _prepare_Z_mask(
    Z_mask: Union[np.ndarray, spmatrix],
    n_obs: int,
    landmarks: np.ndarray,
) -> np.ndarray:
    '''
    Validate a user-supplied ``Z_mask`` and return it as a dense binary array.

    Accepts either a full sample-by-sample graph of shape ``(n_obs, n_obs)``
    or a graph of shape ``(n_obs, len(landmarks))`` whose columns follow the
    order of ``landmarks``. A full graph is restricted to the landmark columns.

    The input is never modified; a new ``float32`` array of shape
    ``(n_obs, len(landmarks))`` with every non-zero entry set to 1 is returned.

    Raises
    ------
    ValueError
        If ``Z_mask`` is not 2-D or its shape matches neither accepted form.
    '''
    if not issparse(Z_mask):
        Z_mask = np.asarray(Z_mask)
    n_landmarks = landmarks.shape[0]
    shape = Z_mask.shape
    if (
        len(shape) != 2
        or shape[0] != n_obs
        or shape[1] not in (n_landmarks, n_obs)
    ):
        raise ValueError(
            'The shape of Z_mask needs to be (adata.n_obs, adata.n_obs) '
            f'({n_obs}, {n_obs}) or (adata.n_obs, n_landmarks) '
            f'({n_obs}, {n_landmarks}), but given {shape}.'
        )
    if shape[1] != n_landmarks:
        # Full sample-by-sample graph: keep only the landmark columns, which
        # are the columns of the representation solved for in run_SRL.
        if issparse(Z_mask):
            Z_mask = Z_mask.tocsc()[:, landmarks]
        else:
            Z_mask = Z_mask[:, landmarks]
    if issparse(Z_mask):
        Z_mask = Z_mask.toarray()
    # Binarize into a fresh array so the caller's Z_mask is left untouched.
    return (np.asarray(Z_mask) != 0).astype(np.float32)


def run_SRL(
    adata: AnnData,
    n_neighbors: Optional[int] = None,
    n_pcs: Optional[int] = None,
    metric: Union[_Metric, _Metric_fn] = 'euclidean',
    Lambda: float = 0.1,
    n_iterations: int = 500,
    n_discriminant: Optional[int] = None,
    Z_mask: Optional[Union[np.ndarray, spmatrix]] = None,
    use_landmarks: Optional[bool] = None,
    use_highly_variable: Optional[bool] = None,
    device: Optional[str] = None,
    random_state: int = 0,
    key_added: Optional[str] = None,
    copy: bool = False,
) -> Optional[AnnData]:
    '''
    Run self-representation learning [Wong17]_ [Shi21]_.

    Parameters
    ----------
    adata
        Annotated data matrix.
    n_neighbors
        Number of neighbors for constructing a
        :class:`~sklearn.neighbors.NearestNeighbors` graph as hard constraints
        in self-representation learning. Defaults to the number of samples
        divided by 10, with a minimum of 20 and capped at the number of
        samples. Not used when ``Z_mask`` is given.
    n_pcs
        Number of principal components to use for constructing a
        :class:`~sklearn.neighbors.NearestNeighbors` graph as hard constraints
        in self-representation learning. Defaults to the number of features
        divided by 50 with a minimum of 20. Not used (and recorded as
        ``None``) when ``Z_mask`` is given.
    metric
        Distance metric to use for constructing a
        :class:`~sklearn.neighbors.NearestNeighbors` graph as hard constraints
        in self-representation learning.
        NOTE: this function does not currently pass ``metric`` to the graph
        construction; it is only recorded in the output ``params``.
    Lambda
        Hyperparameter for sparsity regularization.
    n_iterations
        Number of iterations for the optimization.
    n_discriminant
        Number of discriminant vectors to store for label transfer.
        By default stores all discriminant vectors.
    Z_mask
        Customized sample-by-sample graph of shape
        (`adata.n_obs`, `adata.n_obs`) as hard constraints in
        self-representation learning. When landmarks are used, a graph of
        shape (`adata.n_obs`, number of landmarks) whose columns follow the
        landmark order is also accepted, and a full graph is restricted to
        the landmark columns. Non-zero entries are treated as 1; the input is
        not modified. By default computes a
        :class:`~sklearn.neighbors.NearestNeighbors` graph.
    use_landmarks
        Whether to use landmarks as hard constraints in self-representation
        learning, stored in `adata.obs['is_landmark']`. By default uses them
        if they have been selected beforehand.
    use_highly_variable
        Whether to use highly variable genes only, stored in
        `adata.var['highly_variable']`. By default uses them if they have
        been determined beforehand.
    device
        The desired device for `PyTorch` computation. By default uses cuda if
        cuda is available, cpu otherwise.
    random_state
        Change to use different initial states for the optimization.
    key_added
        If not specified, the self-representation learning data is stored in
        `adata.uns['representation']`, representation is stored in
        `adata.obsp['representation']` and discriminant matrix is stored in
        `adata.uns['representation']['discriminant']`.
        If specified, the self-representation learning data is added to
        `adata.uns[key_added]`, representation is stored in
        `adata.obsp[key_added+'_representation']` and discriminant matrix is
        stored in `adata.uns[key_added]['discriminant']`.
    copy
        Return a copy instead of writing to ``adata``.

    Returns
    -------
    Depending on ``copy``, returns or updates ``adata`` with the following
    fields. See ``key_added`` parameter description for the storage path of
    representation and discriminant.

    representation : :class:`~scipy.sparse.csr_matrix` (.obsp)
        The self-representation of samples.
    discriminant : :class:`~numpy.ndarray` (.uns[``key_added``])
        The discriminant vectors for label transfer.

    Raises
    ------
    ValueError
        If ``use_highly_variable`` or ``use_landmarks`` is requested but the
        corresponding annotation is missing, or if ``Z_mask`` has an invalid
        shape.
    '''
    adata = adata.copy() if copy else adata

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    if use_landmarks is None:
        use_landmarks = 'is_landmark' in adata.obs
    if use_landmarks and 'is_landmark' not in adata.obs:
        raise ValueError(
            'Did not find adata.obs[\'is_landmark\']. '
            'Select landmarks first or set use_landmarks=False.'
        )
    # Indices of the samples allowed as representers (columns of Z); every
    # sample when landmarks are not used.
    landmarks = (
        np.flatnonzero(adata.obs['is_landmark'])
        if use_landmarks
        else np.arange(adata.n_obs)
    )

    if n_neighbors is None:
        n_neighbors = min(max(adata.n_obs // 10, 20), adata.n_obs)

    if use_highly_variable and 'highly_variable' not in adata.var.keys():
        raise ValueError(
            'Did not find adata.var[\'highly_variable\']. '
            'Either your data already only consists of highly-variable genes '
            'or consider running `pp.highly_variable_genes` first.'
        )
    if use_highly_variable is None:
        use_highly_variable = 'highly_variable' in adata.var.keys()
    adata_use = (
        adata[:, adata.var['highly_variable']] if use_highly_variable else adata
    )

    if n_discriminant is None:
        n_discriminant = adata_use.n_vars

    # Transpose so that each column of X is one sample.
    X = adata_use.X.toarray().T if issparse(adata_use.X) else adata_use.X.T
    n = X.shape[1]

    if Z_mask is None:
        Z_mask, n_pcs = initialize_Z(X, landmarks, n_neighbors, n_pcs, random_state)
    else:
        Z_mask = _prepare_Z_mask(Z_mask, n, landmarks)
        # n_pcs plays no role when the constraint graph is supplied.
        n_pcs = None

    X = torch.Tensor(X).type(torch.float32).to(device)
    Z_mask = torch.Tensor(Z_mask).type(torch.float32).to(device)
    Z, _ = solve_Z(X, landmarks, Lambda, Z_mask, n_iterations, device)
    P = solve_P(X, landmarks, Z)
    Z = Z.cpu().numpy()
    P = P.cpu().numpy()

    # Zero out negligible (and any negative) coefficients.
    eps = 1e-14
    Z[Z < eps] = 0

    if key_added is None:
        key_added = 'representation'
        conns_key = 'representation'
    else:
        conns_key = key_added + '_representation'
    # The same obsp entry is advertised as both connectivities and distances.
    dists_key = conns_key

    adata.uns[key_added] = {
        'connectivities_key': conns_key,
        'distances_key': dists_key,
        'discriminant': P[:, :n_discriminant],
        'var_names_use': adata_use.var_names.to_numpy(),
        'params': {
            # Average number of non-zero coefficients per sample.
            'n_neighbors': np.count_nonzero(Z) // Z.shape[0],
            'n_pcs': n_pcs,
            'metric': str(metric),
            'Lambda': Lambda,
            'n_iterations': n_iterations,
            'n_discriminant': n_discriminant,
            'use_landmarks': use_landmarks,
            'use_highly_variable': use_highly_variable,
            'random_state': random_state,
            # NOTE(review): presumably lets neighbors-graph consumers (e.g.
            # UMAP embedding) treat this entry like a neighbors result — confirm.
            'method': 'umap',
        },
    }

    # Z has one row per sample and one column per landmark; scatter only its
    # non-zero coefficients into an (n_obs, n_obs) sparse matrix so no
    # explicit zeros are stored.
    rows, cols = np.nonzero(Z)
    adata.obsp[conns_key] = csr_matrix(
        (Z[rows, cols], (rows, landmarks[cols])), shape=(n, n)
    )

    return adata if copy else None