Enabling timestamps changes text/reduces accuracy #30815

Closed
2 of 4 tasks
jaggzh opened this issue May 14, 2024 · 4 comments

jaggzh commented May 14, 2024

System Info

  • transformers version: 4.40.2
  • Platform: Linux-6.1.0-20-amd64-x86_64-with-glibc2.36
  • Python version: 3.11.2
  • Huggingface_hub version: 0.21.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.30.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: NO
    - mixed_precision: no
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
    - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): 2.16.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: (I believe so? torch.cuda.is_available() => True)
  • Using distributed or parallel set-up in script?: (Not that I'm aware of)

Who can help?

@sanchit-gandhi @ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. With a fine-tuned whisper model
  2. Enable return_timestamps=True in the generate() call
  3. Compare the predicted results against a generate() call without return_timestamps=True (minimal sketch below)
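
For example, a minimal comparison looks something like this (just a sketch -- the model directory and audio path are placeholders, not my actual files):

    import librosa
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    processor = WhisperProcessor.from_pretrained("path/to/fine-tuned-whisper")        # placeholder
    model = WhisperForConditionalGeneration.from_pretrained("path/to/fine-tuned-whisper")

    audio, sr = librosa.load("sample.flac", sr=16000)                                 # placeholder
    input_features = processor(audio, sampling_rate=sr, return_tensors="pt").input_features

    ids_plain = model.generate(input_features, language="en")
    ids_ts = model.generate(input_features, language="en", return_timestamps=True)

    # The decoded text differs between the two calls:
    print(processor.batch_decode(ids_plain, skip_special_tokens=True)[0])
    print(processor.batch_decode(ids_ts, skip_special_tokens=True)[0])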

Expected behavior

The text with and without timestamps "should" match, no? But with timestamps it somehow interferes, changing the text and, in this case, decreasing its accuracy.

This is a fine-tuned model, with a complex voice (patient whispers, breathing on a ventilator), and so far with insufficient data for better training. My point here is that I believe the model will therefore be more susceptible to influences that can deteriorate its recognition. However, my main questions are:

  1. How does the inclusion of timestamps, via generate(..., return_timestamps=True), end up affecting the whole process?
  2. Is there anything that can be done to keep the 'original' (non-timestamp-based) accuracy?

My code (it's a bit of a mess as I experiment):

    for predwav in predwavs:
        aa,sr=librosa.load(predwav, sr=16000)
        sample=aa.astype(np.float64)

        input_features = processor(sample, sampling_rate=sr, return_tensors="pt").input_features 
        # generate token ids

        print("Generating...")
        # Raises error. don't use:
        # pids = model.generate(input_features, return_timestamps=True, return_token_timestamps=True, language='en')
        #    ".../transformers/tokenization_utils.py", line 976, in convert_ids_to_tokens
        #    index = int(index)
        #            ^^^^^^^^^^
        #    ValueError: invalid literal for int() with base 10: 's'

        # Allows timestamps but reduces transcription accuracy:
        pids = model.generate(input_features, return_timestamps=True, language='en')

        # Highest accuracy is without timestamps:
        # pids = model.generate(input_features, language='en')
        print("/Generating.")
        # decode token ids to text
        # print("batch_decode()")
        # transcription = processor.batch_decode(pids, skip_special_tokens=False)
        # print("/batch_decode")
        # print(f"Transcription: {transcription}")
        print("Timestamp info:")
        for pidi, pid in enumerate(pids):
            # timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
            pdict = processor.tokenizer.decode(pid, output_offsets=True)
            print(f"Predicted id [{pidi}] text: {pdict['text']}")
            print(f"Predicted id [{pidi}] offsets: {pdict['offsets']}")
        import ipdb; ipdb.set_trace(context=16); pass

With generate()'s return_timestamps=True:

Predicted id [0] text: <|startoftranscript|><|en|><|transcribe|> There is a time... ...of a subconscious development. It don't work. Bureau work. The branch, the branch.<|endoftext|>

Predicted id [0] offsets: [{'text': ' There is a time...', 'timestamp': (0.0, 2.6)}, {'text': ' ...of a subconscious development.', 'timestamp': (14.6, 17.6)}, {'text': " It don't work.", 'timestamp': (20.6, 22.6)}, {'text': ' Bureau work.', 'timestamp': (23.400000000000002, 24.400000000000002)}, {'text': ' The branch, the branch.', 'timestamp': (25.6, 27.6)}]

Without generate()'s return_timestamps=True:

Predicted id [0] text: <|startoftranscript|><|en|><|transcribe|><|notimestamps|> there is it time... what is that chin? round one? you know what? the brown strap is... the brown strap is...<|endoftext|>

Predicted id [0] offsets: []

Full code below. (Please don't look at it unless you have to!)

#!/usr/bin/env python3
import os
# This is the directory created by run_speech_recognition_seq2seq.py
whdir_def="..../voice-training-dataset-create/whisper-custom-en"
# whdir_def="whisper-custom-en/checkpoint-1100"

def get_last_checkpoint_dir(dstr):
    # looks in dstr for checkpoint-* directories, picking latest mtime and returning its full rel path
    latest_checkpoint_dir = None
    latest_mtime = -1
    for item in os.listdir(dstr):
        item_path = os.path.join(dstr, item)
        if os.path.isdir(item_path) and item.startswith('checkpoint-'):
            mtime = os.path.getmtime(item_path)
            if mtime > latest_mtime:
                latest_mtime = mtime
                latest_checkpoint_dir = item_path
    return latest_checkpoint_dir


# whdir_def=get_last_checkpoint_dir("whisper-custom-en")
print(f"Whisper dir! {whdir_def}")

predwavs=[
    '..../patient--2024-04-27_06-32-part1.flac',
   ]

use_test_dataset_hf = False
use_test_dataset_moz = False

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from transformers import HfArgumentParser
import librosa
import numpy as np
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import argparse
import torch
    
# NOTE!  For loading large datasets, like common-voice, seq2seq's hf mozilla commonvoice loader (at least for v11)
# will try to preprocess the ENTIRE SET, even if you set your splits to small %'s.
# I modified mine to:
     #if training_args.do_train:
     #    raw_datasets["train"] = load_dataset(
     #        data_args.dataset_name,
     #        data_args.dataset_config_name,
     #        #split=data_args.train_split_name,
     #  # THIS LINE HERE AND IN THE .do_eval right below this one
     #        split=f'{data_args.train_split_name}[:1%]',  # Load only the first 1%
     #        cache_dir=model_args.cache_dir,
     #        token=model_args.token,
     #        #verification_mode='all_checks',
     #    )
# AND LOWER, BEFORE prepare_dataset(), slice the dataset (or it'll still preproc everything):
# These 4 lines:
    # if data_args.max_train_samples is not None:
    #     raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
    # if data_args.max_eval_samples is not None:
    #     raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

    # def prepare_dataset(batch):

@dataclass
class AdditionalArguments:
    output_probabilities: bool = field(default=False, metadata={"help": "Output word probabilities"})

def main():
    global predwavs

    parser = argparse.ArgumentParser(description="Speech-to-Text Prediction")
    parser.add_argument("--whdir", type=str, help="Directory of the Whisper model")
    parser.add_argument("--predwav", type=str, help="Path to the audio file for prediction")
    parser.add_argument("-p", "--output_probabilities", action="store_true", help="Output word probabilities")
    parser.add_argument("-cp", "--cnt_probs", type=int, default=10, help="Number of top candidate tokens to output with their probabilities")

    # Parse arguments
    args = parser.parse_args()

    # Rest of your code
    # load model and processor
    # processor = WhisperProcessor.from_pretrained("openai/whisper-large")
    # model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
    # Load model and processor
    whdir = args.whdir if args.whdir is not None else whdir_def
    if args.predwav:
        predwavs = [args.predwav]

    print(f"Predicting on wave(s):\n{predwavs}")
    processor = WhisperProcessor.from_pretrained(whdir)
    model = WhisperForConditionalGeneration.from_pretrained(whdir)
    model.config.forced_decoder_ids = None
    
    # load dummy dataset and read audio files
    #if use_test_dataset_hf:
    #    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    #elif use_test_dataset_moz:
    #    ds = load_dataset(
    #        "mozilla-foundation/common_voice_11_0",
    #        "en",
    #        #split=data_args.train_split_name,
    #        split=f'train[:15%]',  # Load only the first %
    #        #cache_dir=model_args.cache_dir,
    #        token=True
    #        #verification_mode='all_checks',
    #    )
    #    example = ds[0]["audio"]
    #    sample = example['array']
    #    sr = example['array']
    #else:
    for predwav in predwavs:
        aa,sr=librosa.load(predwav, sr=16000)
        sample=aa.astype(np.float64)

        input_features = processor(sample, sampling_rate=sr, return_tensors="pt").input_features 
        # generate token ids

        print("Generating...")
        # Raises error. don't use:
        # pids = model.generate(input_features, return_timestamps=True, return_token_timestamps=True, language='en')
        #    ".../transformers/tokenization_utils.py", line 976, in convert_ids_to_tokens
        #    index = int(index)
        #            ^^^^^^^^^^
        #    ValueError: invalid literal for int() with base 10: 's'

        # Allows timestamps but reduces transcription accuracy:
        #pids = model.generate(input_features, return_timestamps=True, language='en')

        # Highest accuracy is without timestamps:
        pids = model.generate(input_features, language='en')
        print("/Generating.")
        # decode token ids to text
        # print("batch_decode()")
        # transcription = processor.batch_decode(pids, skip_special_tokens=False)
        # print("/batch_decode")
        # print(f"Transcription: {transcription}")
        print("Timestamp info:")
        import ipdb; ipdb.set_trace(context=16); pass
        for pidi, pid in enumerate(pids):
            # timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
            pdict = processor.tokenizer.decode(pid, output_offsets=True)
            print(f"Predicted id [{pidi}] text: {pdict['text']}")
            print(f"Predicted id [{pidi}] offsets: {pdict['offsets']}")
        import sys; sys.exit()
        import ipdb; ipdb.set_trace(context=16); pass
        
        transcription = processor.batch_decode(pids, skip_special_tokens=True)
        print(f"Transcription: {transcription}")

        if args.output_probabilities:
            # Use generate to handle decoder inputs automatically
            generated_outputs = model.generate(input_features, output_scores=True, return_dict_in_generate=True)
            scores = generated_outputs.scores  # List of tensors of scores for each step
        
            for stepi, step_scores in enumerate(scores):
                probabilities = F.softmax(step_scores, dim=-1)
                top_probs, top_indices = torch.topk(probabilities, args.cnt_probs, dim=-1)
                print(f"[{stepi}] Step")
                for i in range(args.cnt_probs):
                    token_id = top_indices[0][i].item()
                    word = processor.tokenizer.decode([token_id])
                    prob = top_probs[0][i].item()
                    print(f"  Token {i + 1}: {word} - {prob:.4f}")

if __name__ == '__main__':
    main()
@ArthurZucker (Collaborator) commented:

cc @kamilakesbi as well!
@jaggzh we are going to need an audio file we can work with, and it would be great if you could reduce the reproducer to a minimal amount of custom code!

sanchit-gandhi (Contributor) commented May 16, 2024

Hey @jaggzh - thanks for reporting. This is actually the intended behaviour with Whisper. To understand why, recall that Whisper predicts the distribution over the next token $y_{i}$ conditionally over all previous tokens $\boldsymbol{y}_{0:i-1}$:

$$ y_{i} \sim P\left(y | \boldsymbol{y}_{0:i-1}\right) $$

When we decode without timestamps, we generate sequences with the following format:

<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> The cat sat on the mat.<|endoftext|>

Note the task token at index 4: the <|notimestamps|> indicates to the model that it should not predict timestamps.

To decode with timestamps, we ensure that the <|notimestamps|> is not generated at position 4, which triggers the model to predict with timestamp tokens:

<|startoftranscript|> <|en|> <|transcribe|> <|0.00|> The cat sat on the mat.<|4.22|><|endoftext|>

=> we can see here that the sequence of token ids changes in two ways:

  1. The task tokens are changed (we drop the <|notimestamps|> token from position 4)
  2. We predict timestamp tokens as part of the generated sequence (in this example, <|0.00|> and <|4.22|>). The key here is understanding that these timestamp tokens are predicted in the same way as the text tokens: auto-regressively based on the conditional probability distribution over previous tokens.

Since the sequence of token ids $\boldsymbol{y}_{0:i-1}$ changes, the predictions for token $y_{i}$ also change (by nature of the conditional probability distribution that we predict). Therefore, it's possible that the generations with timestamps differ from those without timestamps.

Generally, what we observe is that enabling timestamps gives less accurate transcriptions for short-form audio, and more accurate for long-form audio (whether you're using the chunked or sequential decoding algorithms).
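
As a quick way to see this, you can print both generated sequences with the special tokens kept (an illustrative snippet re-using the model, processor, and input_features from your script above):

    ids_plain = model.generate(input_features, language="en")
    ids_ts = model.generate(input_features, language="en", return_timestamps=True)

    # Keep special tokens visible to compare the prefixes and the timestamp tokens
    print(processor.tokenizer.decode(ids_plain[0]))
    # -> <|startoftranscript|><|en|><|transcribe|><|notimestamps|> ...text...<|endoftext|>
    print(processor.tokenizer.decode(ids_ts[0], decode_with_timestamps=True))
    # -> <|startoftranscript|><|en|><|transcribe|><|0.00|> ...text...<|endoftext|>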

@sanchit-gandhi (Contributor) commented:

Closing the issue since it is in-fact the intended behaviour from Whisper, but happy to answer any follow-up questions you have! Feel free to post on this comment thread 🤗

jaggzh (Author) commented May 25, 2024

> 2. We predict timestamp tokens as part of the generated sequence (in this example, `<|0.00|>` and `<|4.22|>`). The key here is understanding that these timestamp tokens are predicted in the same way as the text tokens: auto-regressively based on the conditional probability distribution over previous tokens.

Thank you so much for the extremely helpful and detailed explanation!
Are the timestamp tokens then generated during training, and are they actually at {.02f} (0.02 s) resolution? (I'm dealing with short-form audio, from maybe 0.3 to 6 s max, and very few samples reach above 1 s -- it's for someone with speech issues.) If I were to give up precision in the timestamps, say {.1f}, I'm thinking it might give the model less variation in the timestamp tokens and an easier time learning higher accuracy. My actual main goal is not timestamp accuracy -- they can be rough -- but rather not to damage the transcription accuracy [much] in the process.
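
(A rough way I'd check the granularity, using the processor from my script above -- I'm assuming the timestamp ids start right after <|notimestamps|> and render at the tokenizer's default 0.02 s precision, so treat that as my assumption:)

    tok = processor.tokenizer
    ts_begin = tok.convert_tokens_to_ids("<|notimestamps|>") + 1   # assumed first timestamp id
    for i in range(4):
        # decode_with_timestamps renders timestamp ids at the default 0.02 s precision
        print(tok.decode([ts_begin + i], decode_with_timestamps=True))
    # -> <|0.00|> <|0.02|> <|0.04|> <|0.06|>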

Nevertheless, since it's short-form disjoint speech, I began working on a project that does some nice automatic breaking up of audio with auto-calibrated silence detection. It's a module that operates as a generator function, returning each clip and its time offset, so I can use it in different projects (including my data prep or prediction code). With such short utterances, I'm then able to get the timestamp of each clip, and that will be sufficient for my needs.

(It's off topic, but in case anyone's interested -- not that they'll see this closed issue...)
You can find it here:
https://gist.github.com/jaggzh/e9a5b31afc218b8d44fd5ddb976c8c96
(If run directly it accepts an audio file to test one's settings, but I didn't incorporate arg parsing, so one has to modify the code to change them.)

It handles evaluating a provided audio file (file only right now; it can't yet be used on a live audio stream). It examines a requested number of seconds of audio at a time (a chunk) and, within that, small examination windows, taking each window's max amplitude. It treats the lowest of those maxima as the noise floor. It then takes the max it heard (discarding some of the loudest via maxamp_discard_frac) and uses a fraction between the floor and that max as the acceptable signal (voice) level.
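
The calibration boils down to something like this (a simplified sketch of the idea, not the gist's actual code; aside from maxamp_discard_frac, the names and defaults here are just illustrative):

    import numpy as np

    def calibrated_threshold(chunk, sr, win_s=0.05,
                             maxamp_discard_frac=0.02, signal_frac=0.35):
        # Peak amplitude of each small examination window within the chunk
        win = max(1, min(int(win_s * sr), len(chunk)))
        peaks = np.array([np.abs(chunk[i:i + win]).max()
                          for i in range(0, len(chunk) - win + 1, win)])
        noise_floor = peaks.min()                    # quietest window ~ noise floor
        peaks_sorted = np.sort(peaks)
        keep = max(1, int(len(peaks_sorted) * (1 - maxamp_discard_frac)))
        max_heard = peaks_sorted[keep - 1]           # loudest peak after discarding outliers
        # Accept as signal (voice) anything above a fraction of the floor-to-max range
        return noise_floor + signal_frac * (max_heard - noise_floor)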

[Screenshot attachment: SS_20240525_023824]

The purpose was to adjust automatically, instead of using the fixed dB thresholds of many of the solutions I found.

If plotting, it ends up using my non-breaking key module (kbnb) -- that import can just be left out if you're not using it. Otherwise it's included in the gist, along with bansi.py for some perdy colors also used in the plotting.

In any case, it's also a good example of matplotlib running and updating its window in the bg, non-blocking. :)
