-
Notifications
You must be signed in to change notification settings - Fork 512
/
image_pil.py
2205 lines (1880 loc) · 78.7 KB
/
image_pil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
import copy
import math
import random
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms as T
from torchvision.transforms import functional as F
from corenet.data.transforms import TRANSFORMATIONS_REGISTRY, BaseTransformation
from corenet.data.transforms.utils import jaccard_numpy, setup_size
from corenet.options.parse_args import JsonValidator
from corenet.utils import logger
# Maps the interpolation names accepted on the command line (they are used as
# ``choices`` for the ``--image-augmentation.*.interpolation`` options) to the
# corresponding torchvision interpolation modes. "cubic" is an alias of "bicubic".
INTERPOLATION_MODE_MAP = dict(
    nearest=T.InterpolationMode.NEAREST,
    bilinear=T.InterpolationMode.BILINEAR,
    bicubic=T.InterpolationMode.BICUBIC,
    cubic=T.InterpolationMode.BICUBIC,
    box=T.InterpolationMode.BOX,
    hamming=T.InterpolationMode.HAMMING,
    lanczos=T.InterpolationMode.LANCZOS,
)
def _interpolation_modes_from_str(name: str) -> T.InterpolationMode:
    """Return the torchvision interpolation mode registered under ``name``.

    Args:
        name: A key of ``INTERPOLATION_MODE_MAP`` (e.g. ``"bilinear"``).

    Returns:
        The matching ``T.InterpolationMode``.

    Raises:
        KeyError: if ``name`` is not a supported interpolation name. The message
            lists the supported names.
    """
    try:
        return INTERPOLATION_MODE_MAP[name]
    except KeyError:
        raise KeyError(
            "Unsupported interpolation mode: {!r}. Supported modes are: {}".format(
                name, ", ".join(INTERPOLATION_MODE_MAP)
            )
        ) from None
def _crop_fn(data: Dict, top: int, left: int, height: int, width: int) -> Dict:
"""Helper function for cropping"""
img = data["image"]
data["image"] = F.crop(img, top=top, left=left, height=height, width=width)
if "mask" in data:
mask = data.pop("mask")
data["mask"] = F.crop(mask, top=top, left=left, height=height, width=width)
if "box_coordinates" in data:
boxes = data.pop("box_coordinates")
area_before_cropping = (boxes[..., 2] - boxes[..., 0]) * (
boxes[..., 3] - boxes[..., 1]
)
boxes[..., 0::2] = np.clip(boxes[..., 0::2] - left, a_min=0, a_max=left + width)
boxes[..., 1::2] = np.clip(boxes[..., 1::2] - top, a_min=0, a_max=top + height)
area_after_cropping = (boxes[..., 2] - boxes[..., 0]) * (
boxes[..., 3] - boxes[..., 1]
)
area_ratio = area_after_cropping / (area_before_cropping + 1)
# keep the boxes whose area is atleast 20% of the area before cropping
keep = area_ratio >= 0.2
box_labels = data.pop("box_labels")
data["box_coordinates"] = boxes[keep]
data["box_labels"] = box_labels[keep]
if "instance_mask" in data:
assert "instance_coords" in data
instance_masks = data.pop("instance_mask")
data["instance_mask"] = F.crop(
instance_masks, top=top, left=left, height=height, width=width
)
instance_coords = data.pop("instance_coords")
instance_coords[..., 0::2] = np.clip(
instance_coords[..., 0::2] - left, a_min=0, a_max=left + width
)
instance_coords[..., 1::2] = np.clip(
instance_coords[..., 1::2] - top, a_min=0, a_max=top + height
)
data["instance_coords"] = instance_coords
return data
def _resize_mask_nearest(mask, size_h: int, size_w: int):
    """Resize a mask-like input to ``(size_h, size_w)`` with nearest-neighbour sampling.

    ``mask`` can be a PIL image or a tensor. A tensor whose first dimension is 0
    (e.g. Mask-RCNN targets without instances) cannot be resized by ``F.resize``,
    so an empty ``[0, size_h, size_w]`` tensor with the same dtype/device is
    returned instead.
    """
    if isinstance(mask, torch.Tensor) and mask.shape[0] == 0:
        return torch.zeros([0, size_h, size_w], dtype=mask.dtype, device=mask.device)
    return F.resize(
        img=mask,
        size=[size_h, size_w],
        interpolation=T.InterpolationMode.NEAREST,
    )


def _resize_fn(
    data: Dict,
    size: Union[Sequence, int],
    interpolation: Optional[
        Union[T.InterpolationMode, str]
    ] = T.InterpolationMode.BILINEAR,
) -> Dict:
    """Resize the image and every supported annotation in ``data``.

    Args:
        data: Sample dictionary. ``"image"`` is required. ``"mask"``,
            ``"box_coordinates"`` and ``"instance_mask"`` (together with
            ``"instance_coords"``) are resized as well when present.
        size: Either ``(height, width)`` for a fixed output size, or an int that
            the shorter image side is resized to while the longer side keeps the
            aspect ratio.
        interpolation: Interpolation for the image, either a
            ``T.InterpolationMode`` or a key of ``INTERPOLATION_MODE_MAP``. Masks
            always use nearest-neighbour interpolation.

    Returns:
        The same dictionary with resized entries. When ``size`` is an int and the
        shorter side already equals it, ``data`` is returned unchanged.

    Raises:
        TypeError: if ``size`` is neither an int nor a sequence of length 2.
    """
    img = data["image"]
    w, h = F.get_image_size(img)
    if isinstance(size, Sequence) and len(size) == 2:
        size_h, size_w = size[0], size[1]
    elif isinstance(size, int):
        if (w <= h and w == size) or (h <= w and h == size):
            return data
        if w < h:
            size_h = int(size * h / w)
            size_w = size
        else:
            size_w = int(size * w / h)
            size_h = size
    else:
        raise TypeError(
            "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format(
                size
            )
        )

    if isinstance(interpolation, str):
        interpolation = _interpolation_modes_from_str(name=interpolation)

    data["image"] = F.resize(
        img=img, size=[size_h, size_w], interpolation=interpolation
    )

    if "mask" in data:
        data["mask"] = _resize_mask_nearest(data.pop("mask"), size_h, size_w)

    if "box_coordinates" in data:
        boxes = data.pop("box_coordinates")
        # In-place multiplication by a float fails on integer numpy arrays, so
        # promote them first (mirroring the float32 cast of instance_coords below).
        if isinstance(boxes, np.ndarray) and not np.issubdtype(
            boxes.dtype, np.floating
        ):
            boxes = boxes.astype(np.float32)
        boxes[:, 0::2] *= 1.0 * size_w / w
        boxes[:, 1::2] *= 1.0 * size_h / h
        data["box_coordinates"] = boxes

    if "instance_mask" in data:
        assert "instance_coords" in data
        # Instance masks can be empty tensors too, so they use the same guarded resize.
        data["instance_mask"] = _resize_mask_nearest(
            data.pop("instance_mask"), size_h, size_w
        )
        instance_coords = data.pop("instance_coords")
        instance_coords = instance_coords.astype(np.float32)
        instance_coords[..., 0::2] *= 1.0 * size_w / w
        instance_coords[..., 1::2] *= 1.0 * size_h / h
        data["instance_coords"] = instance_coords
    return data
def _pad_fn(
data: Dict,
padding: Union[int, Sequence],
fill: Optional[int] = 0,
padding_mode: Optional[str] = "constant",
) -> Dict:
# Taken from the functional_tensor.py pad
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
padding = [pad_left, pad_top, pad_right, pad_bottom]
data["image"] = F.pad(data.pop("image"), padding, fill, padding_mode)
if "mask" in data:
data["mask"] = F.pad(data.pop("mask"), padding, 0, "constant")
if "box_coordinates" in data:
# labels remain unchanged
boxes = data.pop("box_coordinates")
boxes[:, 0::2] += pad_left
boxes[:, 1::2] += pad_top
data["box_coordinates"] = boxes
return data
@TRANSFORMATIONS_REGISTRY.register(name="fixed_size_crop", type="image_pil")
class FixedSizeCrop(BaseTransformation):
    """Bring every input to exactly ``(crop_height, crop_width)``.

    Sides longer than the target are cropped at a random offset. Sides shorter
    than the target are padded on the right/bottom with ``fill`` using
    ``padding_mode``.
    """

    def __init__(
        self, opts, size: Optional[Union[int, Tuple[int, int]]] = None, *args, **kwargs
    ):
        super().__init__(opts, *args, **kwargs)
        # An explicit ``size`` (handy for variable-resolution samplers) wins over
        # the value stored in the config.
        if size is None:
            size = getattr(opts, "image_augmentation.fixed_size_crop.size", None)
        self.fill = getattr(opts, "image_augmentation.fixed_size_crop.fill", 0)
        self.padding_mode = getattr(
            opts, "image_augmentation.fixed_size_crop.padding_mode", "constant"
        )
        crop_size = setup_size(
            size,
            error_msg="Please provide either int or (int, int) for size in {}.".format(
                self.__class__.__name__
            ),
        )
        self.crop_height = crop_size[0]
        self.crop_width = crop_size[1]

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        prefix = "--image-augmentation.fixed-size-crop"
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            prefix + ".enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            prefix + ".size",
            type=int,
            nargs="+",
            default=None,
            help="Image size either as an int or (int, int).",
        )
        group.add_argument(
            prefix + ".fill",
            type=int,
            default=0,
            help="Fill value to be used during padding operation. Defaults to 0.",
        )
        group.add_argument(
            prefix + ".padding-mode",
            type=str,
            default="constant",
            help="Padding modes. Defaults to constant",
        )
        return parser

    def __call__(self, data: Dict, *args, **kwargs) -> Dict:
        width, height = F.get_image_size(data["image"])
        kept_h = min(height, self.crop_height)
        kept_w = min(width, self.crop_width)

        if kept_h != height or kept_w != width:
            # One random draw positions the window on both axes, so the top and
            # left offsets are the same fraction of their available slack.
            u = random.random()
            data = _crop_fn(
                data,
                top=int(max(height - self.crop_height, 0) * u),
                left=int(max(width - self.crop_width, 0) * u),
                height=kept_h,
                width=kept_w,
            )

        missing_h = max(self.crop_height - kept_h, 0)
        missing_w = max(self.crop_width - kept_w, 0)
        if missing_h or missing_w:
            data = _pad_fn(
                data,
                padding=[0, 0, missing_w, missing_h],
                fill=self.fill,
                padding_mode=self.padding_mode,
            )
        return data

    def __repr__(self):
        template = "{}(crop_size=({}, {}), fill={}, padding_mode={})"
        return template.format(
            self.__class__.__name__,
            self.crop_height,
            self.crop_width,
            self.fill,
            self.padding_mode,
        )
@TRANSFORMATIONS_REGISTRY.register(name="scale_jitter", type="image_pil")
class ScaleJitter(BaseTransformation):
    """Randomly resizes the input within the scale range.

    On every call a factor ``scale`` is drawn uniformly from ``scale_range``. The
    image (and the annotations ``_resize_fn`` handles) is then resized by
    ``scale`` times the largest ratio that fits the image inside ``target_size``.
    Both sides use the same ratio, so the aspect ratio is kept up to rounding.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        target_size = getattr(opts, "image_augmentation.scale_jitter.target_size", None)
        if target_size is None:
            logger.error(
                "Target size can't be None in {}.".format(self.__class__.__name__)
            )
        target_size = setup_size(
            target_size,
            error_msg="Need either an int or (int, int) for target size in {}".format(
                self.__class__.__name__
            ),
        )

        scale_range = getattr(opts, "image_augmentation.scale_jitter.scale_range", None)
        if scale_range is None:
            logger.error(
                "Scale range can't be None in {}".format(self.__class__.__name__)
            )
        if not (isinstance(scale_range, Sequence) and len(scale_range) == 2):
            logger.error(
                "Need (float, float) for scale range in {}. Got: {}".format(
                    self.__class__.__name__, scale_range
                )
            )
        if scale_range[0] > scale_range[1]:
            logger.error(
                "Expected scale_range[0] <= scale_range[1] in {}. Got: {}".format(
                    self.__class__.__name__, scale_range
                )
            )

        interpolation = getattr(
            opts, "image_augmentation.scale_jitter.interpolation", "bilinear"
        )
        if isinstance(interpolation, str):
            interpolation = _interpolation_modes_from_str(name=interpolation)

        super().__init__(opts, *args, **kwargs)
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.scale-jitter.enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.scale-jitter.interpolation",
            type=str,
            default="bilinear",
            choices=list(INTERPOLATION_MODE_MAP.keys()),
            help="Interpolation method. Defaults to bilinear interpolation",
        )
        group.add_argument(
            "--image-augmentation.scale-jitter.target-size",
            type=int,
            nargs="+",
            default=None,
            help="Target image size either as an int or (int, int).",
        )
        group.add_argument(
            "--image-augmentation.scale-jitter.scale-range",
            type=float,
            nargs="+",
            default=None,
            help="Scale range as (float, float).",
        )
        return parser

    def __call__(self, data: Dict, *args, **kwargs) -> Dict:
        img = data["image"]
        orig_width, orig_height = F.get_image_size(img)
        scale = self.scale_range[0] + random.random() * (
            self.scale_range[1] - self.scale_range[0]
        )
        # NOTE(review): target_size[1] is compared with the height and
        # target_size[0] with the width, whereas FixedSizeCrop reads index 0 of
        # its setup_size() result as the height. This only matters for
        # non-square target sizes; confirm which order callers expect.
        r = (
            min(
                self.target_size[1] / orig_height,
                self.target_size[0] / orig_width,
            )
            * scale
        )
        new_width = int(orig_width * r)
        new_height = int(orig_height * r)
        data = _resize_fn(
            data, size=(new_height, new_width), interpolation=self.interpolation
        )
        return data

    def __repr__(self):
        return "{}(scale_range={}, target_size={}, interpolation={})".format(
            self.__class__.__name__,
            self.scale_range,
            self.target_size,
            self.interpolation,
        )
@TRANSFORMATIONS_REGISTRY.register(name="random_resized_crop", type="image_pil")
class RandomResizedCrop(BaseTransformation, T.RandomResizedCrop):
    """
    This class crops a random portion of an image and resizes it to a given size.

    The crop window is sampled by ``T.RandomResizedCrop.get_params`` from the
    configured ``scale`` and ``ratio``. The sample is then cropped with
    ``_crop_fn`` and resized to ``size`` with ``_resize_fn``, so masks, boxes and
    instance annotations follow the image.
    """

    def __init__(
        self, opts: argparse.Namespace, size: Union[Sequence, int], *args, **kwargs
    ) -> None:
        interpolation = getattr(
            opts, "image_augmentation.random_resized_crop.interpolation"
        )
        scale = getattr(opts, "image_augmentation.random_resized_crop.scale")
        ratio = getattr(opts, "image_augmentation.random_resized_crop.aspect_ratio")
        BaseTransformation.__init__(self, opts=opts)
        T.RandomResizedCrop.__init__(
            self,
            size=size,
            scale=scale,
            ratio=ratio,
            interpolation=interpolation,
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.random-resized-crop.enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.random-resized-crop.interpolation",
            type=str,
            default="bilinear",
            choices=list(INTERPOLATION_MODE_MAP.keys()),
            help="Interpolation method for resizing. Defaults to bilinear.",
        )
        group.add_argument(
            "--image-augmentation.random-resized-crop.scale",
            type=JsonValidator(Tuple[float, float]),
            default=(0.08, 1.0),
            help="Specifies the lower and upper bounds for the random area of the crop, before resizing."
            " The scale is defined with respect to the area of the original image. Defaults to "
            "(0.08, 1.0)",
        )
        group.add_argument(
            "--image-augmentation.random-resized-crop.aspect-ratio",
            # ``float or tuple`` evaluates to plain ``float``, which yields a
            # non-sequence ratio; parse a (min, max) pair like ``scale`` does.
            type=JsonValidator(Tuple[float, float]),
            default=(3.0 / 4.0, 4.0 / 3.0),
            help="lower and upper bounds for the random aspect ratio of the crop, before resizing. "
            "Defaults to (3./4., 4./3.)",
        )
        return parser

    def get_rrc_params(self, image: Image.Image) -> Tuple[int, int, int, int]:
        """Sample a crop window ``(top, left, height, width)`` for ``image``."""
        return T.RandomResizedCrop.get_params(
            img=image, scale=self.scale, ratio=self.ratio
        )

    def __call__(self, data: Dict) -> Dict:
        """
        Input data format:
            data: mapping of: {
                "image": [Height, Width, Channels],
                "mask": [Height, Width],
                "box_coordinates": [Num_boxes, 4] as (x1, y1, x2, y2),
                "box_labels": [Num_boxes],
            }
        Output data format: Same as the input
        """
        img = data["image"]
        i, j, h, w = self.get_rrc_params(image=img)
        data = _crop_fn(data=data, top=i, left=j, height=h, width=w)
        return _resize_fn(data=data, size=self.size, interpolation=self.interpolation)

    def __repr__(self) -> str:
        return "{}(scale={}, ratio={}, size={}, interpolation={})".format(
            self.__class__.__name__,
            self.scale,
            self.ratio,
            self.size,
            self.interpolation,
        )
@TRANSFORMATIONS_REGISTRY.register(name="auto_augment", type="image_pil")
class AutoAugment(BaseTransformation, T.AutoAugment):
    """
    This class implements the `AutoAugment data augmentation <https://arxiv.org/pdf/1805.09501.pdf>`_ method.

    Only the image is transformed; the transform is meant for classification
    data and reports an error when boxes or masks are present.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        policy_name = getattr(
            opts, "image_augmentation.auto_augment.policy", "imagenet"
        )
        interpolation = getattr(
            opts, "image_augmentation.auto_augment.interpolation", "bilinear"
        )
        supported_policies = {
            "imagenet": T.AutoAugmentPolicy.IMAGENET,
            "cifar10": T.AutoAugmentPolicy.CIFAR10,
            "svhn": T.AutoAugmentPolicy.SVHN,
        }
        if policy_name not in supported_policies:
            raise NotImplementedError(
                "Auto-augment policy {} is not supported in {}. Supported policies: {}".format(
                    policy_name,
                    self.__class__.__name__,
                    ", ".join(supported_policies),
                )
            )
        policy = supported_policies[policy_name]

        if isinstance(interpolation, str):
            interpolation = _interpolation_modes_from_str(name=interpolation)
        BaseTransformation.__init__(self, opts=opts)
        T.AutoAugment.__init__(self, policy=policy, interpolation=interpolation)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.auto-augment.enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.auto-augment.policy",
            type=str,
            default="imagenet",
            help="Auto-augment policy name (imagenet, cifar10 or svhn). Defaults to imagenet.",
        )
        group.add_argument(
            "--image-augmentation.auto-augment.interpolation",
            type=str,
            default="bilinear",
            choices=list(INTERPOLATION_MODE_MAP.keys()),
            help="Auto-augment interpolation method. Defaults to bilinear interpolation",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        # "instance_mask" is the key used by the crop/resize/flip helpers in this
        # file; "instance_masks" is kept so the previous check still applies.
        if any(
            key in data
            for key in ("box_coordinates", "mask", "instance_mask", "instance_masks")
        ):
            logger.error(
                "{} is only supported for classification tasks".format(
                    self.__class__.__name__
                )
            )
        img = data["image"]
        img = super().forward(img)
        data["image"] = img
        return data

    def __repr__(self) -> str:
        return "{}(policy={}, interpolation={})".format(
            self.__class__.__name__, self.policy, self.interpolation
        )
@TRANSFORMATIONS_REGISTRY.register(name="rand_augment", type="image_pil")
class RandAugment(BaseTransformation, T.RandAugment):
    """
    This class implements the `RandAugment data augmentation <https://arxiv.org/abs/1909.13719>`_ method.

    Only the image is transformed; the transform is meant for classification
    data and reports an error when boxes or masks are present.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        num_ops = getattr(opts, "image_augmentation.rand_augment.num_ops", 2)
        magnitude = getattr(opts, "image_augmentation.rand_augment.magnitude", 9)
        num_magnitude_bins = getattr(
            opts, "image_augmentation.rand_augment.num_magnitude_bins", 31
        )
        interpolation = getattr(
            opts, "image_augmentation.rand_augment.interpolation", "bilinear"
        )
        BaseTransformation.__init__(self, opts=opts)
        if isinstance(interpolation, str):
            interpolation = _interpolation_modes_from_str(name=interpolation)
        T.RandAugment.__init__(
            self,
            num_ops=num_ops,
            magnitude=magnitude,
            num_magnitude_bins=num_magnitude_bins,
            interpolation=interpolation,
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.rand-augment.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.rand-augment.num-ops",
            type=int,
            default=2,
            help="Number of augmentation transformations to apply sequentially. Defaults to 2.",
        )
        group.add_argument(
            "--image-augmentation.rand-augment.magnitude",
            type=int,
            default=9,
            help="Magnitude for all the transformations. Defaults to 9",
        )
        group.add_argument(
            "--image-augmentation.rand-augment.num-magnitude-bins",
            type=int,
            default=31,
            help="The number of different magnitude values. Defaults to 31.",
        )
        group.add_argument(
            "--image-augmentation.rand-augment.interpolation",
            type=str,
            default="bilinear",
            choices=list(INTERPOLATION_MODE_MAP.keys()),
            help="Desired interpolation method. Defaults to bilinear",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        # "instance_mask" is the key used by the crop/resize/flip helpers in this
        # file; "instance_masks" is kept so the previous check still applies.
        if any(
            key in data
            for key in ("box_coordinates", "mask", "instance_mask", "instance_masks")
        ):
            logger.error(
                "{} is only supported for classification tasks".format(
                    self.__class__.__name__
                )
            )
        img = data["image"]
        img = super().forward(img)
        data["image"] = img
        return data

    def __repr__(self) -> str:
        return "{}(num_ops={}, magnitude={}, num_magnitude_bins={}, interpolation={})".format(
            self.__class__.__name__,
            self.num_ops,
            self.magnitude,
            self.num_magnitude_bins,
            self.interpolation,
        )
@TRANSFORMATIONS_REGISTRY.register(name="trivial_augment_wide", type="image_pil")
class TrivialAugmentWide(BaseTransformation, T.TrivialAugmentWide):
    """
    This class implements the `TrivialAugment (Wide) data augmentation <https://arxiv.org/abs/2103.10158>`_ method.

    Only the image is transformed; the transform is meant for classification
    data and reports an error when boxes or masks are present.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        num_magnitude_bins = getattr(
            opts,
            "image_augmentation.trivial_augment_wide.num_magnitude_bins",
            31,
        )
        interpolation = getattr(
            opts,
            "image_augmentation.trivial_augment_wide.interpolation",
            "bilinear",
        )
        BaseTransformation.__init__(self, opts=opts)
        if isinstance(interpolation, str):
            interpolation = _interpolation_modes_from_str(name=interpolation)
        T.TrivialAugmentWide.__init__(
            self,
            num_magnitude_bins=num_magnitude_bins,
            interpolation=interpolation,
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.trivial-augment-wide.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.trivial-augment-wide.num-magnitude-bins",
            type=int,
            default=31,
            help="The number of different magnitude values. Defaults to 31.",
        )
        group.add_argument(
            "--image-augmentation.trivial-augment-wide.interpolation",
            type=str,
            default="bilinear",
            choices=list(INTERPOLATION_MODE_MAP.keys()),
            help="Desired interpolation method. Defaults to bilinear",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        # "instance_mask" is the key used by the crop/resize/flip helpers in this
        # file; "instance_masks" is kept so the previous check still applies.
        if any(
            key in data
            for key in ("box_coordinates", "mask", "instance_mask", "instance_masks")
        ):
            logger.error(
                "{} is only supported for classification tasks".format(
                    self.__class__.__name__
                )
            )
        img = data["image"]
        img = super().forward(img)
        data["image"] = img
        return data

    def __repr__(self) -> str:
        return "{}(num_magnitude_bins={}, interpolation={})".format(
            self.__class__.__name__,
            self.num_magnitude_bins,
            self.interpolation,
        )
@TRANSFORMATIONS_REGISTRY.register(name="random_horizontal_flip", type="image_pil")
class RandomHorizontalFlip(BaseTransformation):
    """
    This class implements random horizontal flipping method.

    With probability ``p`` the image is mirrored left-to-right, together with the
    mask, box coordinates and instance masks/coordinates when present.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        p = getattr(opts, "image_augmentation.random_horizontal_flip.p")
        super().__init__(opts=opts)
        self.p = p

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.random-horizontal-flip.enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.random-horizontal-flip.p",
            type=float,
            default=0.5,
            help="Probability for applying random horizontal flip",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        # random.random() lies in [0, 1), so a strict comparison flips with
        # probability exactly p: never for p=0 and always for p=1.
        if random.random() < self.p:
            img = data["image"]
            width, height = F.get_image_size(img)
            data["image"] = F.hflip(img)
            if "mask" in data:
                mask = data.pop("mask")
                data["mask"] = F.hflip(mask)
            if "box_coordinates" in data:
                boxes = data.pop("box_coordinates")
                # [..., 2::-2] selects columns (2, 0), i.e. (x2, x1), so the
                # mirrored x1/x2 are width - x2 and width - x1.
                boxes[..., 0::2] = width - boxes[..., 2::-2]
                data["box_coordinates"] = boxes
            if "instance_mask" in data:
                assert "instance_coords" in data
                instance_coords = data.pop("instance_coords")
                instance_coords[..., 0::2] = width - instance_coords[..., 2::-2]
                data["instance_coords"] = instance_coords
                instance_masks = data.pop("instance_mask")
                data["instance_mask"] = F.hflip(instance_masks)
        return data

    def __repr__(self) -> str:
        return "{}(p={})".format(self.__class__.__name__, self.p)
@TRANSFORMATIONS_REGISTRY.register(name="random_rotate", type="image_pil")
class RandomRotate(BaseTransformation):
    """
    This class implements random rotation method.

    The angle is drawn uniformly from ``[-angle, angle]``. The image is rotated
    with bilinear interpolation and filled with 0. The mask (if any) is rotated
    with nearest interpolation and filled with ``mask_fill``.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        super().__init__(opts=opts)
        self.angle = getattr(opts, "image_augmentation.random_rotate.angle", 10)
        self.mask_fill = getattr(opts, "image_augmentation.random_rotate.mask_fill", 0)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.random-rotate.enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.random-rotate.angle",
            type=float,
            default=10,
            help="Angle for rotation. Defaults to 10. The angle is sampled "
            "uniformly from [-angle, angle]",
        )
        group.add_argument(
            "--image-augmentation.random-rotate.mask-fill",
            # Without an explicit type, CLI values would reach F.rotate as strings.
            type=int,
            default=0,
            help="Fill value for the segmentation mask. Defaults to 0.",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        data_keys = list(data.keys())
        if "box_coordinates" in data_keys or "instance_mask" in data_keys:
            logger.error(
                "{} supports only images and masks".format(self.__class__.__name__)
            )

        rand_angle = random.uniform(-self.angle, self.angle)
        img = data.pop("image")
        data["image"] = F.rotate(
            img,
            angle=rand_angle,
            interpolation=F.InterpolationMode.BILINEAR,
            fill=0,
        )
        if "mask" in data:
            mask = data.pop("mask")
            data["mask"] = F.rotate(
                mask,
                angle=rand_angle,
                interpolation=F.InterpolationMode.NEAREST,
                fill=self.mask_fill,
            )
        return data

    def __repr__(self) -> str:
        return "{}(angle={}, mask_fill={})".format(
            self.__class__.__name__, self.angle, self.mask_fill
        )
@TRANSFORMATIONS_REGISTRY.register(name="resize", type="image_pil")
class Resize(BaseTransformation):
    """
    This class implements resizing operation.

    .. note::
        Two possible modes for resizing.
        1. Resize while maintaining aspect ratio. To enable this option, pass int as a size
        2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size

    .. note::
        If img_size is passed as a positional argument, then it will override size from args
    """

    def __init__(
        self,
        opts,
        img_size: Optional[Union[Tuple[int, int], int]] = None,
        *args,
        **kwargs
    ) -> None:
        interpolation = getattr(
            opts, "image_augmentation.resize.interpolation", "bilinear"
        )
        super().__init__(opts=opts)

        # img_size argument is useful for implementing multi-scale sampler
        size = (
            getattr(opts, "image_augmentation.resize.size", None)
            if img_size is None
            else img_size
        )
        if size is None:
            logger.error("Size can not be None in {}".format(self.__class__.__name__))

        # A one-element sequence means "keep aspect ratio" (int mode). Any other
        # sequence must be (height, width); this also rejects an empty sequence.
        if isinstance(size, Sequence) and len(size) == 1:
            size = size[0]
        elif isinstance(size, Sequence) and len(size) != 2:
            logger.error(
                "The length of size should be either 1 or 2 in {}. Got: {}".format(
                    self.__class__.__name__, size
                )
            )

        if not isinstance(size, (Sequence, int)):
            logger.error(
                "Size needs to be either Tuple of length 2 or an integer in {}. Got: {}".format(
                    self.__class__.__name__, size
                )
            )

        self.size = size
        self.interpolation = interpolation
        self.maintain_aspect_ratio = isinstance(size, int)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.resize.enable",
            action="store_true",
            help="use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.resize.interpolation",
            type=str,
            default="bilinear",
            choices=list(INTERPOLATION_MODE_MAP.keys()),
            help="Desired interpolation method for resizing. Defaults to bilinear",
        )
        group.add_argument(
            "--image-augmentation.resize.size",
            type=int,
            nargs="+",
            default=256,
            help="Resize image to the specified size. If int is passed, then shorter side is resized "
            "to the specified size and longest side is resized while maintaining aspect ratio. "
            "Defaults to 256.",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        return _resize_fn(data, size=self.size, interpolation=self.interpolation)

    def __repr__(self) -> str:
        return "{}(size={}, interpolation={}, maintain_aspect_ratio={})".format(
            self.__class__.__name__,
            self.size,
            self.interpolation,
            self.maintain_aspect_ratio,
        )
@TRANSFORMATIONS_REGISTRY.register(name="center_crop", type="image_pil")
class CenterCrop(BaseTransformation):
"""
This class implements center cropping method.
.. note::
This class assumes that the input size is greater than or equal to the desired size.
"""
def __init__(self, opts, *args, **kwargs) -> None:
super().__init__(opts=opts)
size = getattr(opts, "image_augmentation.center_crop.size", None)
if size is None:
logger.error("Size cannot be None in {}".format(self.__class__.__name__))
if isinstance(size, Sequence) and len(size) == 2:
self.height, self.width = size[0], size[1]
elif isinstance(size, Sequence) and len(size) == 1:
self.height = self.width = size[0]
elif isinstance(size, int):
self.height = self.width = size
else:
logger.error("Scale should be either an int or tuple of ints")
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
group = parser.add_argument_group(title=cls.__name__)
group.add_argument(
"--image-augmentation.center-crop.enable",
action="store_true",
help="use {}. This flag is useful when you want to study the effect of different "
"transforms.".format(cls.__name__),
)
group.add_argument(
"--image-augmentation.center-crop.size",
type=int,
nargs="+",
default=224,
help="Center crop size. Defaults to None.",
)