
LGCD solver issue when using new data larger than a certain size #121

Open
rsuhendra opened this issue Apr 12, 2024 · 1 comment

Comments


rsuhendra commented Apr 12, 2024

The LGCD solver (which is the default for BatchCDL) seems to run into issues on new data larger than a certain size. Specifically, when I use the transform function, it freezes above a certain size threshold but works perfectly below it. None of the other solvers seem to have this issue, but they also don't seem to work as well.
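Here is a minimal sketch of the kind of call that hangs for me. The `BatchCDL` usage follows alphacsc's public API, but the shapes, parameter values, and the size threshold below are illustrative, not the exact ones from my data:

```python
import numpy as np
from alphacsc import BatchCDL

rng = np.random.RandomState(0)
n_trials, n_channels = 2, 1

# Fit on short signals, then transform signals of increasing length.
cdl = BatchCDL(n_atoms=3, n_times_atom=32)
cdl.fit(rng.randn(n_trials, n_channels, 5_000))

z_small = cdl.transform(rng.randn(n_trials, n_channels, 5_000))    # works
z_large = cdl.transform(rng.randn(n_trials, n_channels, 500_000))  # freezes
```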

@rsuhendra (Author) commented:

I think I've figured out the issue. In the _coordinate_descent_idx function, in the greedy case the following is defined:

```python
n_seg = max(n_times_valid // (2 * n_times_atom), 1)
n_times_seg = n_times_valid // n_seg + 1
```

and the segment bounds are then advanced to look at each segment:

```python
seg_bounds = [0, n_times_seg]
seg_bounds[0] += n_times_seg
seg_bounds[1] += n_times_seg
```

Most of the time here, `n_times_seg` will be roughly fixed at `n_times_seg = 2 * n_times_atom + 1`. When `n_times_valid` is large enough, however, `n_seg` becomes larger than `n_times_seg`.
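To make this concrete, here is a quick check with an illustrative `n_times_atom` (32 is just my choice here, not a value from the library):

```python
n_times_atom = 32

for n_times_valid in (1_000, 100_000):
    n_seg = max(n_times_valid // (2 * n_times_atom), 1)
    n_times_seg = n_times_valid // n_seg + 1
    print(n_times_valid, n_seg, n_times_seg)

# 1_000   -> n_seg = 15,   n_times_seg = 67  (close to 2 * 32 + 1 = 65)
# 100_000 -> n_seg = 1562, n_times_seg = 65  (n_seg much larger than n_times_seg)
```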

This is an issue because the iteration walks the signal in steps of `n_times_seg`, so there should be `n_seg` blocks of size `n_times_seg`. However, `n_seg` and `n_times_seg` are defined as if we were solving for `n_times_seg` blocks of size `n_seg`.

Specifically, we have `n_seg * (n_times_seg - 1) < n_times_valid < n_seg * n_times_seg` instead of having `(n_seg - 1) * n_times_seg < n_times_valid < n_seg * n_times_seg`.

As a result, for larger datasets the program will often freeze: blocks are defined that start past the end of the signal, so they are never iterated over and therefore never "solved".
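With the same illustrative numbers as above, you can see that the last block already starts past the end of the signal, so it can never be visited:

```python
n_times_atom = 32
n_times_valid = 100_000

n_seg = max(n_times_valid // (2 * n_times_atom), 1)  # 1562
n_times_seg = n_times_valid // n_seg + 1             # 65

# The final block starts at (n_seg - 1) * n_times_seg = 101_465,
# which is beyond n_times_valid, so it contains no coordinates at all.
print((n_seg - 1) * n_times_seg >= n_times_valid)    # True
```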

The two fixes for this are:

  1. Make it so that the function solves for `n_times_seg` blocks of size `n_seg`. For example, advancing the bounds by `n_seg` instead:

     ```python
     seg_bounds = [0, n_seg]
     seg_bounds[0] += n_seg
     seg_bounds[1] += n_seg
     ```

  2. Make it so that the number of blocks is properly defined (see the sanity check after this list):

     ```python
     n_times_seg = (2 * n_times_atom) + 1
     n_seg = n_times_valid // n_times_seg + 1
     ```
I personally went with option 2, though I'm not sure which one is more consistent with your paper. Let me know if I've made a mistake here.
