Optimize token_classification/filter.py and find_label_issues_batched.py for performance #1072

Open · wants to merge 2 commits into master

Conversation

gogetron (Contributor)

Summary

This PR partially addresses #862

🎯 Purpose: Improve the performance of find_label_issues in token_classification/filter.py and of the more general find_label_issues_batched function in experimental/label_issues_batched.py

We can skip one list comprehension in find_label_issues. Flattening pred_probs is much faster with np.vstack, which only iterates over the top level of the list instead of over every token.
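The flattening speedup can be sketched as follows (illustrative toy data, not the exact PR diff):

```python
import numpy as np

# Ragged per-sentence predicted probabilities: a list of
# (sentence_length, num_classes) arrays of varying lengths.
pred_probs = [np.random.random((5, 3)), np.random.random((7, 3))]

# Slower: flatten token by token with a nested list comprehension.
flat_slow = np.array([row for probs in pred_probs for row in probs])

# Faster: stack whole arrays at once, iterating only over the top-level list.
flat_fast = np.vstack(pred_probs)

assert np.array_equal(flat_slow, flat_fast)
```

Both produce the same (total_tokens, num_classes) array; np.vstack just does the copying in a few large chunks instead of one row at a time.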

In addition, after profiling the function I found that most of the work happens in label_issues_batched, so I made a few changes there as well. By computing the label_issues_mask in batches we reduce memory usage significantly while also slightly improving runtime.
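The batching idea can be sketched roughly as below. The helper name and threshold comparison are hypothetical, not the PR's exact code; the point is filling a preallocated boolean mask batch by batch so the comparison never materializes full-size temporaries:

```python
import numpy as np

def batched_issue_mask(scores, threshold, batch_size=10_000):
    # Allocate the output mask once, then fill it slice by slice.
    # Each comparison only creates a batch_size-long temporary,
    # instead of one temporary the size of the whole dataset.
    mask = np.empty(len(scores), dtype=bool)
    for start in range(0, len(scores), batch_size):
        end = start + batch_size
        mask[start:end] = scores[start:end] < threshold
    return mask

scores = np.random.random(1_000_000)
assert np.array_equal(batched_issue_mask(scores, 0.5), scores < 0.5)
```

The result is identical to the unbatched comparison; only peak memory changes.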

For memory measurements I used the memory-profiler library; the benchmarking code is copied below. I also sorted the imports in the modified files.

Code Setup

import random

import numpy as np

from cleanlab.token_classification.filter import find_label_issues

np.random.seed(0)
random.seed(0)
%load_ext memory_profiler

TOTAL_EXAMPLES = 150_000
MAX_LENGTH = 100
NUM_CLASSES = 50

def create_dataset():
    labels = []
    pred_probs = []
    for _ in range(TOTAL_EXAMPLES):
        length = random.randint(2, MAX_LENGTH)
        new_labels = np.random.randint(NUM_CLASSES, size=length)
        labels.append(new_labels)
        probs = np.random.random((length, NUM_CLASSES))
        probs /= probs.sum(axis=1, keepdims=True)
        pred_probs.append(probs)
    return labels, pred_probs

# Create input data
labels, pred_probs = create_dataset()
# Execute once to avoid the tensorflow import time in the benchmark.
issues = find_label_issues(labels, pred_probs, n_jobs=1, low_memory=True)

Current version

%%timeit
%memit find_label_issues(labels, pred_probs, n_jobs=1, low_memory=True)
# peak memory: 10559.92 MiB, increment: 6451.74 MiB
# peak memory: 10618.27 MiB, increment: 6481.02 MiB
# peak memory: 10618.55 MiB, increment: 6480.28 MiB
# peak memory: 10615.83 MiB, increment: 6475.28 MiB
# peak memory: 10618.69 MiB, increment: 6476.10 MiB
# peak memory: 10618.55 MiB, increment: 6474.94 MiB
# peak memory: 10618.63 MiB, increment: 6473.01 MiB
# peak memory: 10613.63 MiB, increment: 6468.02 MiB
# 10.1 s ± 509 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR

%%timeit
%memit find_label_issues(labels, pred_probs, n_jobs=1, low_memory=True)
# peak memory: 7392.07 MiB, increment: 3560.62 MiB
# peak memory: 7384.52 MiB, increment: 3569.43 MiB
# peak memory: 7404.90 MiB, increment: 3589.71 MiB
# peak memory: 7393.79 MiB, increment: 3578.36 MiB
# peak memory: 7404.90 MiB, increment: 3589.45 MiB
# peak memory: 7398.82 MiB, increment: 3583.40 MiB
# peak memory: 7386.59 MiB, increment: 3570.19 MiB
# peak memory: 7387.36 MiB, increment: 3570.96 MiB
# 6.23 s ± 97.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Testing

🔍 Testing Done: Existing tests.

References

Reviewer Notes



codecov bot commented Mar 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.11%. Comparing base (abd0924) to head (f619bc5).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1072      +/-   ##
==========================================
- Coverage   96.15%   96.11%   -0.04%     
==========================================
  Files          74       74              
  Lines        5850     5849       -1     
  Branches     1044     1043       -1     
==========================================
- Hits         5625     5622       -3     
- Misses        134      135       +1     
- Partials       91       92       +1     


@jwmueller requested a review from elisno on March 29, 2024 at 23:05.