RealTimeKernel.from_moscot(tp) Out of Memory #1146

Open
AlinaKurjan opened this issue Dec 12, 2023 · 27 comments

Labels: question (Further information is requested)
@AlinaKurjan commented Dec 12, 2023

Hi, thank you very much for the moscot and CellRank resources! I was wondering if you could please advise on the best way to get around the "XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2160972208 bytes." error when converting to the RealTimeKernel. My adata object is 73131 obs × 23552 var.

I've managed to run tp.solve by specifying batch_size=10000, but I cannot convert it to a kernel using either GPU or CPU. Any advice please?

AlinaKurjan added the "question" label Dec 12, 2023
@Marius1311 (Collaborator) commented:

Hi @AlinaKurjan, this is most likely a problem with adaptive thresholding when we go from the dense (but not materialized) OT coupling in moscot to the sparse (and materialized) transition matrix in CellRank. Do you have any advice here @MUCDK?

@AlinaKurjan (Author) commented:

Hi @Marius1311, thank you for your reply. I've managed to get it to work in the end by subsampling to 95% of my dataset. It would be great to have some sort of different solution for scaling to larger projects in the future :)

@MUCDK commented Dec 13, 2023

Hi @AlinaKurjan,

Have you tried sparsifying the matrix using https://github.com/theislab/moscot/blob/ba231ea4f7511caaa1915e2b7e28e88d709d45ad/src/moscot/base/output.py#L170 / https://moscot.readthedocs.io/en/latest/genapi/moscot.backends.ott.OTTOutput.sparsify.html#moscot.backends.ott.OTTOutput.sparsify ?

As Marius mentioned, CellRank needs to materialize the transport matrix, hence it becomes quadratic in memory. If you first sparsify the output, i.e. set very small values to 0, you can pass a sparse matrix.
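For instance, a minimal sketch of that workflow (the threshold value and batch size here are illustrative, not recommendations; tp is the solved TemporalProblem):

couplings = {}
for (t1, t2), sol in tp.solutions.items():
    # zero out couplings below `value`; the returned output exposes a
    # scipy.sparse transport_matrix instead of a dense one
    sparse_sol = sol.sparsify(mode="threshold", value=1e-8, batch_size=1024)
    couplings[t1, t2] = sparse_sol.transport_matrix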

@AlinaKurjan (Author) commented:

Hi @MUCDK. I've tried sparsifying on the full dataset with different parameters and batch sizes, but it made no difference and still gave the Resource exhausted error, usually not immediately but further down the line.

@MUCDK commented Dec 13, 2023

Hi @AlinaKurjan,

Thanks for this, it's super helpful feedback.

Would you mind copy-pasting the code you tried? That would make it easier for us to track down the error.

@AlinaKurjan (Author) commented:

@MUCDK @Marius1311 sorry for the delay. Here's what I've tried:

import scanpy as sc
from moscot.problems.time import TemporalProblem
from cellrank.kernels import RealTimeKernel

sc.pp.filter_genes(adata, min_cells=5)
sc.pp.highly_variable_genes(adata, n_top_genes=2000, layer='log1p_norm')
sc.pp.pca(adata)
sc.pp.neighbors(adata, random_state=0)

tp = TemporalProblem(adata)
tp = tp.score_genes_for_marginals(
    gene_set_proliferation="human", gene_set_apoptosis="human"
)
sc.pl.embedding(
    adata, basis="draw_graph_fa", color=["C_scANVI_original", "proliferation", "apoptosis"]
)
tp = tp.prepare(time_key="agefloat")
tp = tp.solve(epsilon=1e-3, tau_a=0.99, tau_b=0.999, scale_cost="mean",
              batch_size=10000, device='gpu')

# note the misspelling 'sparcify_kwargs' here -- discussed below
tmk = RealTimeKernel.from_moscot(tp, sparse_mode='threshold',
                                 sparcify_kwargs={'value': 0.5, 'batch_size': 1000})

Batch sizes of 100, 1000, 5000, 10000 and values of 0.1, 0.5, 0.9 all gave the same resource-exhausted errors down the line. Alternatively,

tmk = RealTimeKernel.from_moscot(tp, sparse_mode='min_row', sparsify_kwargs={'batch_size': 1000})

gave the same problem.

You might notice that I am not entirely sure how a threshold value or a good batch size should be picked. To be honest, I did not understand what the thresholding does from the description. It would be great if some guidance on a starting point could be added. Apologies if I'm missing something really obvious.

While I have you, can I please ask for your opinion on whether CellRank can be used with a combined scRNA-seq and snRNA-seq dataset? I'm working with human developmental data, with embryonic counts coming from sc and fetal counts from sn data. I'm obviously very interested in trying to identify trajectories, but not sure if that's possible given the sources of bias. I've integrated and corrected my sc and sn data with scVI-scANVI, but since that doesn't batch-correct the counts, I am not sure how "analyseable" this combined data is for trajectory work... Any thoughts would be highly appreciated (and apologies if this is not the right place to ask).

AlinaKurjan changed the title from "RealTimeKernel.from_moscot(tp) Our of Memory" to "RealTimeKernel.from_moscot(tp) Out of Memory" Dec 13, 2023
@MUCDK commented Dec 13, 2023

Hi @AlinaKurjan,

Is this literally the code you used? There might be an error: you use sparcify_kwargs instead of sparsify_kwargs, hence the sparsify_kwargs are not passed. I will look into the moscot sparsification later.

@AlinaKurjan (Author) commented:

> Is this literally the code you used? There might be an error: you use sparcify_kwargs instead of sparsify_kwargs, hence the sparsify_kwargs are not passed.

Sorry, I was typing this part of the code from memory as I had already deleted it. I am sure I typed it right in the actual code; the output it produced was different from the one I got when running without sparsification. I am currently running a few processes so I can't rerun, but I'll post the outputs tomorrow if you'd like.

@Marius1311 (Collaborator) commented:

Replying to the second part of your question @AlinaKurjan:

Using the combination of moscot.time + CellRank's RealTimeKernel, your data integration strategy should not be a problem for downstream trajectory analysis, as both of these methods operate in latent space embeddings, and don't need access to corrected counts.

  • moscot.time computes OT couplings based on cost matrices, which I guess you compute in the corrected scANVI space (you did not specify this explicitly in your code snippet).
  • CellRank's RealTimeKernel then combines these cross-time point transitions with within-time point transitions, where the latter are computed using within-time-point kNN graphs, which can also be computed in the same scANVI latent space.

Thus, the final combined Markov chain can be entirely computed without access to corrected counts, so this should work! This will enable you to compute macrostates and fate probabilities. Once you go from there to gene-level analysis, like driver genes and gene trends, things could get a bit tricky, as these analyses do need counts. However, I'm not sure how much of a problem this would cause - I think you will just have to try. Up to fate probabilities, you should be fine as you can work entirely in latent spaces.
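A concrete sketch of that setup (assuming the scANVI embedding is stored under adata.obsm["X_scANVI"]; the key name and neighbor count are illustrative):

import scanpy as sc

# build the kNN graph in the batch-corrected scANVI latent space; the
# within-time-point transitions are then computed on this graph
sc.pp.neighbors(adata, n_neighbors=30, use_rep="X_scANVI", random_state=0)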

@michalk8 (Collaborator) commented:

@AlinaKurjan the issue seems to be that the sparsification happens on GPU, since jax preallocates ~90% of the GPU memory. One way to check this would be to convert the output object to CPU using OTTOutput.to:

from anndata import AnnData

problem = ...
sparsify_kwargs = ...

couplings = {}
for (t1, t2), solution in problem.solutions.items():
    adata_src = problem[t1, t2].adata_src
    adata_tgt = problem[t1, t2].adata_tgt

    solution = solution.to(device='cpu')
    solution = solution.sparsify(mode='threshold', **sparsify_kwargs)

    tmat = solution.transport_matrix  # should be scipy.sparse array
    couplings[t1, t2] = AnnData(tmat, obs=adata_src.obs, var=adata_tgt.obs)

This will come at some performance cost, but should still be fast. Lmk if this works - initializing the kernel should be straightforward from this (see here).
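A related knob, assuming the upfront preallocation itself contributes to the OOM: JAX's standard environment flags can disable or shrink it (they must be set before jax is first imported):

import os

# disable JAX's upfront GPU memory preallocation...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or cap the preallocated fraction instead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # import only after the flags are set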

@AlinaKurjan (Author) commented Dec 13, 2023

@Marius1311 Thank you for such a detailed reply, that's super useful!! Essentially, when you say that I need to compute the OT couplings in the scANVI space, do you mean that I should be supplying neighbors calculated using the scANVI embeddings when creating the TemporalProblem (because I've supplied the PCA-based ones as in the tutorial...)? If so, would that logic also apply to RNA velocity moments (which I'm computing separately for sc and sn)?

The reason I was confused about what neighbors to supply was because I've gone through Dana Pe'er's Harmony and Palantir workflows prior to this to get an augmented affinity matrix (X_aug_aff) using scANVI embeddings, then the diffusion space based on this (X_diff), and then the multiscale space (X_msdiff). I was planning to proceed with downstream trajectory analysis using X_diff neighbors, but now realising that's likely not the right call. Any opinion on this would be extremely useful.

I imagine it's not necessary to do the harmony and palantir stuff if the kernel is built on experimental time (I originally built a kernel based on pseudotime, but the results were not clear)?

import pandas as pd
import harmony
import palantir
import scanpy as sc

data_df = pd.DataFrame(adata.obsm["X_scANVI"], index=adata.obs_names)
aug_aff, aff = harmony.core.augmented_affinity_matrix(
    data_df, tp, timepoint_connections, n_neighbors=30, pc_components=None
)
adata.obsm["X_aug_aff"] = aug_aff
dm_res = palantir.utils.run_diffusion_maps(adata.obsm["X_aug_aff"], knn=30)
sc.pp.neighbors(adata, n_neighbors=30, use_rep="X_diff")
sc.tl.paga(adata, "C_scANVI_original")
ms_data = palantir.utils.determine_multiscale_space(dm_res, n_eigs=7)
adata.obsm["X_msdiff"] = ms_data.values
sc.tl.tsne(adata, use_rep="X_msdiff", perplexity=50, learning_rate=adata.shape[0] / 12)
sc.pp.neighbors(adata, n_neighbors=30, use_rep="X_msdiff")
sc.tl.draw_graph(adata, init_pos="X_tsne", maxiter=500)

@AlinaKurjan (Author) commented:

> @AlinaKurjan the issue seems to be that the sparsification happens on GPU [...] One way to check this would be to convert the output object to CPU using OTTOutput.to [...]

@michalk8 Awesome, thank you so much! Will give it a go tomorrow and get back to you with the results.

@AlinaKurjan (Author) commented:

@michalk8

problem = tp
sparsify_kwargs = {'value': 0.2, 'batch_size': 10000}

couplings = {}
for (t1, t2), solution in problem.solutions.items():
    adata_src = problem[t1, t2].adata_src
    adata_tgt = problem[t1, t2].adata_tgt

    solution = solution.to(device='cpu')
    solution = solution.sparsify(mode='threshold', **sparsify_kwargs)

    tmat = solution.transport_matrix  # should be scipy.sparse array
    couplings[t1, t2] = AnnData(tmat, obs=adata_src.obs, var=adata_tgt.obs)

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[11], line 10
      7 adata_tgt = problem[t1, t2].adata_tgt
      9 solution = solution.to(device='cpu')
---> 10 solution = solution.sparsify(mode='threshold', **sparsify_kwargs)
     12 tmat = solution.transport_matrix  # should be scipy.sparse array
     13 couplings[t1, t2] = AnnData(tmat, obs=adata_src.obs, var=adata_tgt.obs)

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/moscot/base/output.py:242, in BaseSolverOutput.sparsify(self, mode, value, batch_size, n_samples, seed)
    240 for batch in range(0, k, batch_size):
    241     x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float)
--> 242     res = np.array(func(x, scale_by_marginals=False))
    243     res[res < thr] = 0.0
    244     tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res))

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/moscot/base/output.py:130, in BaseSolverOutput.pull(self, x, scale_by_marginals)
    128 if scale_by_marginals:
    129     x = self._scale_by_marginals(x, forward=False)
--> 130 return self._apply(x, forward=False)

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/moscot/backends/ott/output.py:177, in OTTOutput._apply(self, x, forward)
    175 if x.ndim == 1:
    176     return self._output.apply(x, axis=1 - forward)
--> 177 return self._output.apply(x.T, axis=1 - forward).T

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/ott/solvers/linear/sinkhorn.py:471, in SinkhornOutput.apply(self, inputs, axis)
    469 def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    470   """Apply the transport to a ndarray; axis=1 for its transpose."""
--> 471   return self.ot_prob.geom.apply_transport_from_potentials(
    472       self.f, self.g, inputs, axis=axis
    473   )

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/ott/geometry/geometry.py:468, in Geometry.apply_transport_from_potentials(self, f, g, vec, axis)
    464 if vec.ndim == 1:
    465   return self._apply_transport_from_potentials(
    466       f, g, vec[jnp.newaxis, :], axis
    467   )[0, :]
--> 468 return self._apply_transport_from_potentials(f, g, vec, axis)

    [... skipping hidden 3 frame]

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/ott/geometry/geometry.py:431, in Geometry._apply_transport_from_potentials(self, f, g, vec, axis)
    426 @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None])
    427 def _apply_transport_from_potentials(
    428     self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int
    429 ) -> jnp.ndarray:
    430   """Apply lse_kernel to arbitrary vector while keeping track of signs."""
--> 431   lse_res, lse_sgn = self.apply_lse_kernel(
    432       f, g, self.epsilon, vec=vec, axis=axis
    433   )
    434   lse_res += f if axis == 1 else g
    435   return lse_sgn * jnp.exp(lse_res / self.epsilon)

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/ott/geometry/pointcloud.py:271, in PointCloud.apply_lse_kernel(self, f, g, eps, vec, axis)
    268 else:
    269   raise ValueError(axis)
--> 271 _, (h_res, h_sign) = jax.lax.scan(
    272     fun, init=(f, g, eps, vec), xs=jnp.arange(n)
    273 )
    274 h_res, h_sign = jnp.concatenate(h_res), jnp.concatenate(h_sign)
    275 h_res_rest, h_sign_rest = finalize(n * self.batch_size)

    [... skipping hidden 13 frame]

File ~/conda/envs/cellrankV2/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py:1349, in ExecuteReplicated.__call__(self, *args)
   1344   self._handle_token_bufs(
   1345       results.disassemble_prefix_into_single_device_arrays(
   1346           len(self.ordered_effects)),
   1347       results.consume_token())
   1348 else:
-> 1349   results = self.xla_executable.execute_sharded(input_bufs)
   1350 if dispatch.needs_check_special():
   1351   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 8442223150208 bytes.

Same error with batch_size=1000; batch_size=100 crashes the kernel. Any other parameters I can try?
I also don't know how to pick the right threshold for this...

@Marius1311 (Collaborator) commented:

> @Marius1311 Thank you for such a detailed reply, that's super useful!! Essentially, when you say that I need to compute the OT couplings in the scANVI space, do you mean that I should be supplying neighbors calculated using the scANVI embeddings when creating the TemporalProblem (because I've supplied the PCA-based ones as in the tutorial...)? If so, would that logic also apply to RNA velocity moments (which I'm computing separately for sc and sn)?

No, you don't need to supply a kNN-graph, you can directly supply the scANVI embedding in the temporal problem. You would simply do something like

tp = tp.prepare(time_key="day", joint_attr="X_scanvi")

As to RNA velocity, that is a bit trickier to combine with batch correction, as it does rely on corrected counts - and not only that, but also on corrected spliced/unspliced counts. Importantly, batch-effect correction needs to preserve the ratio between spliced and unspliced counts, so here it gets tricky. Latent space-based methods, like veloVI, are probably your best bet here. See the veloVI docs: https://velovi.readthedocs.io/en/latest/index.html

However, I'm not sure whether VeloVI currently supports data integration for RNA velocity computation - you could open an issue in their repo.

> The reason I was confused about what neighbors to supply was because I've gone through Dana Pe'er's Harmony and Palantir workflows prior to this to get an augmented affinity matrix (X_aug_aff) using scANVI embeddings, then the diffusion space based on this (X_diff), and then the multiscale space (X_msdiff). I was planning to proceed with downstream trajectory analysis using X_diff neighbors, but now realising that's likely not the right call. Any opinion on this would be extremely useful.

As a starting point, I would run both moscot and CellRank directly in the scANVI embedding, as I pointed out above, and see how far that gets you.

> I imagine it's not necessary to do the harmony and palantir stuff if the kernel is built on experimental time (I originally built a kernel based on pseudotime, but the results were not clear)?

The Pe'er lab's harmony is for integration across time points, which, in this context, you might not need, as you're using OT to map across time points.

@michalk8 (Collaborator) commented:

@AlinaKurjan there are 2 batch sizes, one in sparsify and the other one in problem.solve.

The one in the solve method only materializes the point cloud of shape [n, m] in [n, batch_size] (or [batch_size, m]) chunks, which are then used when computing T @ x, where T is the transport matrix and x is a vector/matrix.

The batch_size passed in sparsify controls how big the x matrix is: to sparsify the values, you need to evaluate the transport matrix on a diagonal matrix D as T(D) - it just correctly chunks D and roughly does [T(d_chunk) for d_chunk in d_chunks].

The batch_size in your solve call is 10_000, so T(d_chunk) will be evaluated roughly on [n, 10_000] for a specific d_chunk. I'd consider decreasing the batch_size in solve to something lower (say, 1024 - it should still be fast).
There might be a way to set this without re-running the solve method, but it's not really convenient (can try later to see if it works) - @MUCDK we might want to include some convenience wrapper for this later.
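A conceptual sketch of that chunked evaluation (an illustration of the idea, not moscot's actual implementation; apply_fn, the threshold, and all shapes are assumptions):

import numpy as np
import scipy.sparse as sp

def sparsify_by_threshold(apply_fn, n_cols, thr, batch_size=1024):
    # apply_fn(x) computes T @ x without ever materializing T
    blocks = []
    for start in range(0, n_cols, batch_size):
        width = min(batch_size, n_cols - start)
        d_chunk = np.eye(n_cols, width, -start)  # columns start..start+width-1 of I
        block = np.asarray(apply_fn(d_chunk))    # [n_rows, width] slice of T
        block[block < thr] = 0.0                 # drop near-zero couplings
        blocks.append(sp.csr_matrix(block))
    return sp.hstack(blocks).tocsr()

# toy usage with an explicit dense T standing in for the OT coupling
T = np.random.rand(5, 7)
T_sparse = sparsify_by_threshold(lambda x: T @ x, n_cols=7, thr=0.5)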

@MUCDK commented Dec 14, 2023

> @AlinaKurjan there are 2 batch sizes, one in sparsify and the other one in problem.solve. [...] @MUCDK we might want to include some convenience wrapper for this later.

Agree, tracked in theislab/moscot#639 (comment)

@AlinaKurjan (Author) commented Dec 18, 2023

@michalk8 @MUCDK Sorry for the delay and thank you so much for all your help so far, your input is invaluable! I've tried reducing the batch size of tp.solve to 1024. It took quite a bit longer to run (maybe half an hour more), but nothing major. Unfortunately, I'm still getting errors as below:

problem = tp
sparsify_kwargs = {'batch_size':1024}

couplings = {}
for (t1, t2), solution in problem.solutions.items():
    adata_src = problem[t1, t2].adata_src
    adata_tgt = problem[t1, t2].adata_tgt
    solution = solution.to(device='gpu')
    solution = solution.sparsify(mode='min_row', **sparsify_kwargs)
    tmat = solution.transport_matrix  # should be scipy.sparse array
    couplings[t1, t2] = AnnData(tmat, obs=adata_src.obs, var=adata_tgt.obs)
2023-12-18 10:35:05.321394: W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 17.77GiB (19076431872 bytes) by rematerialization; only reduced to 37.94GiB (40734793728 bytes)
2023-12-18 10:35:15.380041: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 37.94GiB (rounded to 40734793728)requested by op 
2023-12-18 10:35:15.380329: W external/tsl/tsl/framework/bfc_allocator.cc:497] ________________________________________________*****___________________*********___*****xx_____****
2023-12-18 10:35:15.380392: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 40734793728 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   57.92MiB
              constant allocation:         0B
        maybe_live_out allocation:   37.94GiB
     preallocated temp allocation:         0B
                 total allocation:   37.99GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 37.94GiB
		Operator: op_name="jit(_where)/jit(main)/select_n" source_file="/tmp/ipykernel_1779629/2320153915.py" source_line=1
		XLA Label: fusion
		Shape: f32[1024,486,20463]
		==========================

	Buffer 2:
		Size: 37.94MiB
		Entry Parameter Subshape: f32[486,20463]
		==========================

	Buffer 3:
		Size: 19.98MiB
		Entry Parameter Subshape: pred[1024,20463]
		==========================

	Buffer 4:
		Size: 4B
		Entry Parameter Subshape: f32[]
		==========================

Same issue with threshold specification rather than 'min_row'.
Attempting to run with the 'cpu' setting kills the kernel...

@AlinaKurjan (Author) commented:

Also, I don't know if this is useful, but here's the output of tp.solve:

tp = tp.solve(epsilon=1e-3, tau_a=0.99, tau_b=0.999, scale_cost="mean",
              batch_size=1024, device='gpu')
INFO     Solving `8` problems                                                                                      
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(23014, 20463)].                                
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20463, 26196)].                                
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(421, 1644)].                                   
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(468, 421)].                                    
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(21, 212)].                                     
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(1644, 23014)].                                 
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(212, 468)].                                    
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(736, 21)].                                     

Could it be because my groups are so unbalanced?

@michalk8 (Collaborator) commented:

The output is still on GPU; can you please try:

# solution = solution.to(device='gpu')
solution = solution.to(device='cpu')

Another thing is that even with this batch size, you're trying to materialize an array of shape f32[1024,486,20463], which has 10183698432 elements, each of 4 bytes, which is approx. 37.93GiB. Either decreasing the batch_size during solve (and later during sparsify as well) is the way to go (I assume 486 is the dimensionality of your data, which itself is quite big), or putting the object on CPU (if there's enough RAM) should work (or both).
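For reference, the arithmetic behind that buffer (values taken from the XLA log above):

elements = 1024 * 486 * 20463   # f32[1024, 486, 20463] from the peak-buffer report
print(elements)                 # 10183698432
print(elements * 4 / 2**30)     # ≈ 37.94 GiB, matching the allocator message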

@fafa92 commented Apr 17, 2024

Hi, I'm also getting the same error.

This is what I'm trying to do:
atac:
n_obs × n_vars = 50637 × 370058

rna:
n_obs × n_vars = 120444 × 2000

data: shape=(50637, 120444)

tp = TranslationProblem(adata_src=atac_seq_data, adata_tgt=rna_seq_data)
tp = tp.prepare(src_attr="ATAC_lsi_l2_norm", tgt_attr="X_pca")

problem = tp

sparsify_kwargs = {'batch_size':1024}
couplings = {}
for (t1, t2), solution in problem.solutions.items():
    adata_src = problem[t1, t2].adata_src
    adata_tgt = problem[t1, t2].adata_tgt
    solution = solution.to(device='cpu')
    solution = solution.sparsify(mode='min_row', **sparsify_kwargs)
    tmat = solution.transport_matrix  # should be scipy.sparse array
    couplings[t1, t2] = AnnData(tmat, obs=adata_src.obs, var=adata_tgt.obs)

tp = tp.solve(alpha=1.0, epsilon=1e-3 ,device='gpu')

output:

INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(50637, 120444)].   
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 24412468528 bytes.

Am I missing something in my solution?
Any help would be appreciated, thanks in advance!

@MUCDK commented Apr 18, 2024

Hi @fafa92,

Thanks for reporting this!

There are two options:

  1. You move the solution to the CPU (here you seem to get a CUDA OOM). The sparsification will then take a while, but you probably have more memory available on the CPU.
  2. Try a different batch_size in the sparsify kwargs.

Moreover, just to be sure: you get the OOM error from the sparsify, and not from the solve, right?
Just wondering, because in your solve call you don't set the rank. As you are using the TranslationProblem, this scales quadratically in the number of data points if you don't set the rank.

@AlinaKurjan (Author) commented Apr 18, 2024

Hi @MUCDK and @fafa92, I did get it to work in the end with no extra code modifications (i.e. with the tutorial code provided). There are a few things that could be the reason for it working this time:

  1. This is pure assumption, but perhaps what is in adata.raw and adata.X affects how the data is processed and potentially leads to errors? @MUCDK I have not looked at the source code, so apologies if this is not right, but this is really the only thing that changed for my dataset besides the environment installation and additional tp calculations for plots (see below). In the working version, I am storing normalised full counts in .raw and scGen-corrected counts in X. It also works if I revert .raw to X before initiating the tp.
  2. There could have been some underlying dependency conflicts (though that doesn't explain why subsampling worked). Here are the versions of packages that work for me:
-----
anndata     0.9.2
scanpy      1.9.6
-----
PIL                 10.1.0
absl                NA
asttokens           NA
backcall            0.2.0
brotli              1.1.0
cellrank            2.0.2
certifi             2023.11.17
charset_normalizer  3.3.2
chex                0.1.8
cloudpickle         3.0.0
colorama            0.4.6
comm                0.1.4
contextlib2         NA
cycler              0.12.1
cython_runtime      NA
dateutil            2.8.2
debugpy             1.8.0
decorator           5.1.1
docrep              0.3.2
executing           2.0.1
flax                0.6.1
fsspec              2023.12.2
get_annotations     NA
gmpy2               2.1.2
h5py                3.10.0
idna                3.6
importlib_resources NA
ipykernel           6.26.0
jax                 0.4.13
jaxlib              0.4.12
jaxopt              NA
jedi                0.19.1
joblib              1.3.2
kiwisolver          1.4.5
lightning_fabric    1.9.4
lightning_utilities 0.10.0
llvmlite            0.41.1
matplotlib          3.7.1
matplotlib_inline   0.1.6
ml_collections      NA
ml_dtypes           0.2.0
moscot              0.3.3
mpl_toolkits        NA
mpmath              1.3.0
msgpack             1.0.7
mudata              0.2.3
multipledispatch    0.6.0
natsort             8.4.0
networkx            3.1
numba               0.58.1
numpy               1.24.4
numpyro             0.12.1
opt_einsum          v3.3.0
optax               0.1.7
ott                 0.4.4
packaging           23.2
pandas              2.0.3
parso               0.8.3
patsy               0.5.4
petsc4py            3.20.2
pexpect             4.8.0
pickleshare         0.7.5
pkg_resources       NA
platformdirs        4.1.0
progressbar         4.2.0
prompt_toolkit      3.0.41
psutil              5.9.5
ptyprocess          0.7.0
pure_eval           0.2.2
pydev_ipython       NA
pydevconsole        NA
pydevd              2.9.5
pydevd_file_utils   NA
pydevd_plugins      NA
pydevd_tracing      NA
pygam               0.8.0
pygments            2.17.2
pygpcca             1.0.4
pyparsing           3.1.1
pyro                1.8.6+4be5c2e
python_utils        NA
pytorch_lightning   1.9.4
pytz                2023.3.post1
requests            2.31.0
rich                NA
scipy               1.10.1
scvelo              0.3.1
scvi                0.20.3
seaborn             0.12.2
session_info        1.0.0
six                 1.16.0
sklearn             1.3.2
slepc4py            3.20.1
socks               1.7.1
stack_data          0.6.2
statsmodels         0.14.0
sympy               1.12
threadpoolctl       3.2.0
toolz               0.12.0
torch               2.0.0
torchmetrics        1.2.1
tornado             6.3.3
tqdm                4.66.1
traitlets           5.14.0
tree                0.1.8
typing_extensions   NA
urllib3             2.1.0
wcwidth             0.2.12
wrapt               1.16.0
yaml                6.0.1
zipp                NA
zmq                 25.1.2
-----
IPython             8.12.2
jupyter_client      8.6.0
jupyter_core        5.5.0
-----
Python 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36) [GCC 12.3.0]
Linux-5.15.0-100-generic-x86_64-with-glibc2.10
  3. Maybe you absolutely have to run the sankey and cell transition calculations? The code that now works for me:
tp = TemporalProblem(adata)
tp = tp.score_genes_for_marginals(
    gene_set_proliferation="human", gene_set_apoptosis="human"
)
tp = tp.prepare(time_key="agefloat", joint_attr='corrected_latent')
tp = tp.solve(epsilon=1e-3, tau_a=0.99, tau_b=0.999, scale_cost="mean",
              batch_size=1200, device='gpu')

tp.sankey(
    source=6.5,
    target=20,
    source_groups="cell_type",
    target_groups="cell_type",
    threshold=0.2,
)
mtp.sankey(tp, dpi=100, figsize=(40, 8), save=os.path.join(FIGURES_FOLDERNAME,'sankey_corrlatent.svg'))

# keys: ordered list of time points (assumed defined earlier)
for i in range(len(keys) - 1):
    t1 = keys[i]
    t2 = keys[i + 1]
    dict_key = f'{t1}_{t2}'

    ct_desc = tp.cell_transition(t1, t2, "cell_type", "cell_type", forward=True, key_added=f"transitions_{t1}_{t2}")
    mtp.cell_transition(tp, fontsize=8, figsize=(5, 5), return_fig=True, key=f"transitions_{t1}_{t2}",
                        save=f'{dict_key}_descendants.svg')

tmk = RealTimeKernel.from_moscot(tp)           

Those are the changes I could think of. Hope this helps :)

@MUCDK commented Apr 18, 2024

Hi @AlinaKurjan,

Happy it worked for you in the end!
The preprocessing, e.g. the PCA computation, all happens on the CPU, so it should not affect whether you get a GPU OOM error.

Possibly you worked on a different GPU with more memory? Otherwise I'm not sure.
Anyway, thanks a lot for posting this!

@fafa92 please be aware that you are working with the TranslationProblem, which requires the rank parameter to be set to reduce memory. The TemporalProblem can have an exact solution in linear memory complexity as it is a "linear OT" problem, while the TranslationProblem "maps across incomparable spaces" and hence, in its default variant, is quadratic in memory; the only way to reduce memory here is to set the rank parameter, e.g. to 3000.
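A back-of-the-envelope check of why the default full-rank coupling cannot fit here, assuming float32 (cell numbers taken from the report above):

n_atac, n_rna = 50637, 120444
# materializing the full coupling needs n_atac * n_rna * 4 bytes
print(n_atac * n_rna * 4 / 2**30)  # ≈ 22.7 GiB, on the order of the reported
                                   # 24412468528-byte OOM allocation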

@fafa92 commented Apr 18, 2024

Hi @MUCDK and @AlinaKurjan. Thanks for the help. I'm trying to apply this tutorial to my data:
https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html#

The first problem was solved using rank. Thank you @MUCDK for the hint. However, I'm facing a few other issues.

  1. I'm getting this output for my tp.solve:

tp = tp.solve(alpha=1.0, epsilon=1e-3,rank=3000, device='gpu')

INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(50637, 120444)].                                       
WARNING  Solver did not converge

Does not converging mean I need to re-run it? If so, what changes do you recommend?

This came out of the visualization part; I'm not sure if it looks like this because tp.solve did not converge.

[image attachment]

I set the dimensionality reduction for both datasets to LSI-l2 (it used to be PCA for RNA and LSI-l2 for ATAC), but it's still not converging. I got a more robust figure, but the datasets do not overlap like in the example! What might be causing this?

[image attachment]

@MUCDK commented Apr 19, 2024

Yes, whenever the solver hasn't converged, it's better to rerun it.

In general, doing pure GW as proposed in the SCOT paper is often not good enough to match cells; hence, as described in our tutorial (the lower part), we recommend using the extension to the fused setting, i.e. you can get gene activity and then integrate the data with scVI.

Regarding the convergence, this is hard to tell. You could try different epsilon values, e.g. 0 (which is allowed when you do low-rank) or 1.0. @giovp do you have any recommendations regarding the convergence?
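A sketch of that fused setup (the joint_attr key "X_scVI" and the alpha/rank values are assumptions for illustration, not recommendations):

from moscot.problems.cross_modality import TranslationProblem

# a shared feature space (e.g. a joint scVI embedding over gene activity and
# RNA) turns pure GW into fused GW; alpha < 1 weights in the linear term
tp = TranslationProblem(adata_src=atac_seq_data, adata_tgt=rna_seq_data)
tp = tp.prepare(src_attr="ATAC_lsi_l2_norm", tgt_attr="X_pca", joint_attr="X_scVI")
tp = tp.solve(alpha=0.7, epsilon=1e-3, rank=3000, device="gpu")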

@Marius1311 (Collaborator) commented:

I think at this point, it makes sense to move the discussion to the moscot repo.

@fafa92 commented Apr 23, 2024

@MUCDK scVI integration helped a bit, thanks for the suggestion! More RAM came to rescue the convergence problem :)
