Source code for SpaSRL._select_landmarks
# -*- coding: utf-8 -*-
from typing import Optional
from anndata import AnnData
import numpy as np
import pandas as pd
from scipy.sparse import issparse
from sklearn.preprocessing import normalize
from tqdm import trange
from ._lasso_landmarks import lasso_landmarks
def select_landmarks(
    adata: AnnData,
    n_landmarks: int,
    Lambda: float = 0.5,
    reltol: float = 1e-3,
    use_highly_variable: Optional[bool] = None,
    random_state: int = 0,
    copy: bool = False,
) -> Optional[AnnData]:
    '''
    Select landmark samples [Matsushima19]_.

    Landmarks are chosen greedily. At every iteration one randomly drawn
    observation is used to score all not-yet-selected observations by the
    squared, soft-thresholded inner product between the current residual
    of the drawn observation and each candidate; the best-scoring
    candidate becomes a landmark. At most ``2 * n_landmarks`` iterations
    are run, and an iteration whose best score is exactly zero selects
    nothing, so fewer than ``n_landmarks`` landmarks may be returned.

    Parameters
    ----------
    adata
        Annotated data matrix.
    n_landmarks
        Number of landmarks to be selected. Values larger than
        ``adata.n_obs`` are clamped to ``adata.n_obs``.
    Lambda
        Hyperparameter for sparsity regularization.
    reltol
        Relative tolerance in optimization.
    use_highly_variable
        Whether to use highly variable genes only, stored in `adata.var['highly_variable']`.
        By default uses them if they have been determined beforehand.
    random_state
        Change to use different initial states for the optimization.
    copy
        Return a copy instead of writing to ``adata``.

    Returns
    -------
    Depending on ``copy``, returns or updates ``adata`` with the following fields.

    .obs['is_landmark']
        Boolean indicator of landmark samples.
    .uns['select_landmarks']
        Dictionary whose ``'params'`` entry records the parameters used;
        ``'n_landmarks'`` there is the number of landmarks actually selected.

    Raises
    ------
    ValueError
        If ``n_landmarks`` is negative, or if ``use_highly_variable`` is
        ``True`` but ``adata.var['highly_variable']`` does not exist.
    '''
    # Fail early with a clear message; a negative count would otherwise
    # surface later as an unrelated "negative dimensions" ValueError.
    if n_landmarks < 0:
        raise ValueError(
            f'n_landmarks must be non-negative, got {n_landmarks}.'
        )

    adata = adata.copy() if copy else adata
    n_landmarks = min(n_landmarks, adata.n_obs)

    has_hvg = 'highly_variable' in adata.var.keys()
    if use_highly_variable is True and not has_hvg:
        raise ValueError(
            'Did not find adata.var[\'highly_variable\']. '
            'Either your data already only consists of highly-variable genes '
            'or consider running `pp.highly_variable_genes` first.'
        )
    if use_highly_variable is None:
        use_highly_variable = has_hvg
    adata_use = (
        adata[:, adata.var['highly_variable']] if use_highly_variable else adata
    )

    # Work on the transposed matrix (variables x observations); axis=0
    # rescales every column, i.e. every observation.
    X = adata_use.X.toarray().T if issparse(adata_use.X) else adata_use.X.T
    X = normalize(X.astype(np.float64), axis=0)

    rng = np.random.RandomState(seed=random_state)
    n = adata.n_obs
    max_t = n_landmarks * 2
    num_I_t = 1  # number of random observations examined per iteration
    num_rand_idx = n_landmarks * 2 * num_I_t

    # Pre-draw the observation indices consumed by the iterations below.
    # The order of RNG calls is kept as-is so results for a given
    # ``random_state`` stay reproducible.
    if n >= num_rand_idx:
        rand_idx = rng.choice(np.arange(n), num_rand_idx)
    else:
        rand_idx = np.concatenate(
            [rng.permutation(n) for _ in range(num_rand_idx // n + 1)]
        )

    S_t = np.array([], dtype=int)  # indices of the landmarks selected so far
    norms = np.sum(X**2, axis=0)
    # State shared with ``lasso_landmarks`` (defined in ._lasso_landmarks);
    # the keys are part of that helper's interface and must not be renamed.
    stats = {
        'reltol': reltol,
        'Lambda': Lambda,
        'normsSt': np.array([]),
        'XSt': np.empty((X.shape[0], 0)),
        'W': np.zeros((n_landmarks, n)),
        'R': X,
    }

    all_idx = np.arange(n)  # loop invariant, hoisted out of the iterations

    # select landmark samples
    # The context manager guarantees the progress bar is closed even when
    # the loop is left early through ``break``.
    with trange(max_t) as pbar:
        for t in pbar:
            I_t = rand_idx[(t * num_I_t):((t + 1) * num_I_t)]
            if S_t.shape[0] != 0:
                # NOTE(review): presumably refreshes stats['W'] / stats['R']
                # for the drawn observations; confirm in lasso_landmarks.
                for i in I_t:
                    stats = lasso_landmarks(X, S_t, stats, i, rng)

            candidate = np.setdiff1d(all_idx, S_t)
            grad_L = np.matmul(stats['R'][:, I_t].T, X[:, candidate])
            grad = np.zeros((num_I_t, n))
            # Soft-thresholding of grad_L by Lambda.
            grad[:, candidate] = np.minimum(
                grad_L + Lambda, np.maximum(0, grad_L - Lambda)
            )
            # A drawn observation must not vote for itself.
            for j, drawn in enumerate(I_t):
                grad[j, drawn] = 0
            grad2sum = np.sum(grad**2, axis=0)

            # argsort (rather than argmax) is kept on purpose so that
            # tie-breaking between equal scores is unchanged.
            dSt = np.argsort(grad2sum)[::-1][0]
            if grad2sum[dSt] != 0:
                S_t = np.append(S_t, dSt)
                pbar.set_postfix_str(f'selected landmarks: {S_t.shape[0]}')
                stats['normsSt'] = np.append(stats['normsSt'], norms[dSt])
                stats['XSt'] = np.concatenate(
                    (stats['XSt'], X[:, dSt][:, None]), axis=1
                )
                # The count only changes here, so this is the only place
                # the stopping condition can become true.
                if S_t.shape[0] == n_landmarks:
                    break

    landmarks = pd.Series(False, index=adata.obs_names)
    landmarks.iloc[S_t] = True

    adata.uns['select_landmarks'] = {
        'params': {
            'n_landmarks': np.count_nonzero(landmarks),
            'Lambda': Lambda,
            'reltol': reltol,
            'use_highly_variable': use_highly_variable,
            'random_state': random_state,
        }
    }
    adata.obs['is_landmark'] = landmarks

    return adata if copy else None