-
Notifications
You must be signed in to change notification settings - Fork 512
/
base_tokenizer.py
129 lines (105 loc) · 4.09 KB
/
base_tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Any
from torch import Tensor, nn
from corenet.utils import logger
class BaseTextTokenizer(nn.Module):
    """Base class for text tokenizers.

    Child classes provide the vocabulary-specific pieces (``vocab_size``, the
    ``*_token_id`` properties, ``tok_encode`` and ``tok_decode``). This base
    class registers the shared ``--text-tokenizer.*`` options and reads the
    special-token strings from the parsed options.

    Args:
        opts: Command-line arguments.
    """

    def __init__(self, opts: argparse.Namespace) -> None:
        super().__init__()
        self.opts = opts

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add the shared text-tokenizer options to ``parser``.

        The options are only registered when this is called on
        ``BaseTextTokenizer`` itself; calling it on a subclass adds nothing.

        Args:
            parser: Argument parser to extend.

        Returns:
            The same ``parser`` instance.
        """
        if cls is BaseTextTokenizer:
            group = parser.add_argument_group(title=cls.__name__)
            group.add_argument(
                "--text-tokenizer.name",
                type=str,
                default=None,
                help="Name of the text tokenizer (e.g., clip). Defaults to None.",
            )
            group.add_argument(
                "--text-tokenizer.sot-token",
                type=str,
                default=None,
                help="Start of the text token. Defaults to None (i.e., users must specify the value if it needs to be used.).",
            )
            group.add_argument(
                "--text-tokenizer.eot-token",
                type=str,
                default=None,
                help="End of the text token. Defaults to None (i.e., users must specify the value if it needs to be used.).",
            )
            group.add_argument(
                "--text-tokenizer.pad-token",
                type=str,
                default=None,
                help="Pad token. Defaults to None (i.e., users must specify the value if it needs to be used.).",
            )
        return parser

    def _get_required_token(self, opt_name: str, label: str) -> str:
        """Read a special-token string from ``self.opts``.

        Args:
            opt_name: Attribute name on ``self.opts`` holding the token,
                e.g. ``"text_tokenizer.eot_token"``.
            label: Human-readable token name used in the error message.

        Returns:
            The configured token string.
        """
        # No getattr default: a missing option still raises AttributeError,
        # exactly as the per-property lookups did.
        token = getattr(self.opts, opt_name)
        if token is None:
            # NOTE(review): returning a ``str`` relies on ``logger.error``
            # aborting; if it only logs, ``None`` reaches the caller — confirm.
            logger.error(
                f"{label} can't be None. Please specify using '{opt_name}' in config file."
            )
        return token

    @property
    def vocab_size(self) -> int:
        """Text vocabulary size."""
        raise NotImplementedError("Child classes must implement this method.")

    @property
    def eot_token(self) -> str:
        """End of text token, read from ``text_tokenizer.eot_token``."""
        return self._get_required_token("text_tokenizer.eot_token", "EOT token")

    @property
    def eot_token_id(self) -> int:
        """Token index for EOT token."""
        raise NotImplementedError("Child classes must implement this method.")

    @property
    def sot_token(self) -> str:
        """Start of text token, read from ``text_tokenizer.sot_token``."""
        return self._get_required_token("text_tokenizer.sot_token", "SOT token")

    @property
    def sot_token_id(self) -> int:
        """Token index for SOT token."""
        raise NotImplementedError("Child classes must implement this method.")

    @property
    def pad_token(self) -> str:
        """Padding token, read from ``text_tokenizer.pad_token``."""
        return self._get_required_token("text_tokenizer.pad_token", "Padding token")

    @property
    def pad_token_id(self) -> int:
        """Padding index."""
        raise NotImplementedError("Child classes must implement this method.")

    def tok_encode(self, input_sentence: str) -> Tensor:
        """Encodes a sentence into a tensor of token ids."""
        raise NotImplementedError("Child classes must implement this method.")

    def tok_decode(self, token_ids: Any) -> str:
        """Decodes token ids into a sentence."""
        raise NotImplementedError("Child classes must implement this method.")

    def forward(self, input_sentence: str) -> Tensor:
        """Tokenize the input sentence.

        Args:
            input_sentence: Pre-processed input sentence.

        Returns:
            Tensor containing tokenized sequence (whatever ``tok_encode`` returns).

        .. note::
            Input sentence should be pre-processed (e.g., lower case).
        """
        return self.tok_encode(input_sentence)