Skip to content

Commit c78be99

Browse files
Mountchicken authored and gaotongxiao committed
[Refactor] Refactor TextRecogVisualizer
1 parent 7e7a526 commit c78be99

File tree

4 files changed

+213
-1
lines changed

4 files changed

+213
-1
lines changed

.codespellrc

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
skip = *.ipynb
33
count =
44
quiet-level = 3
5-
ignore-words-list = convertor,convertors,formating,nin,wan,datas,hist
5+
ignore-words-list = convertor,convertors,formating,nin,wan,datas,hist,ned

mmocr/core/visualization/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .textrecog_visualizer import TextRecogLocalVisualizer
3+
4+
__all__ = ['TextRecogLocalVisualizer']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict, Optional, Tuple, Union
3+
4+
import cv2
5+
import mmcv
6+
import numpy as np
7+
from mmengine import Visualizer
8+
9+
from mmocr.core import TextRecogDataSample
10+
from mmocr.registry import VISUALIZERS
11+
12+
13+
@VISUALIZERS.register_module()
class TextRecogLocalVisualizer(Visualizer):
    """MMOCR Text Recognition Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): The origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        gt_color (str or tuple[int, int, int]): Colors of GT text. The tuple of
            color should be in RGB order. Or using an abbreviation of color,
            such as `'g'` for `'green'`. Defaults to 'g'.
        pred_color (str or tuple[int, int, int]): Colors of Predicted text.
            The tuple of color should be in RGB order. Or using an abbreviation
            of color, such as `'r'` for `'red'`. Defaults to 'r'.
    """

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 gt_color: Optional[Union[str, Tuple[int, int, int]]] = 'g',
                 pred_color: Optional[Union[str, Tuple[int, int,
                                                       int]]] = 'r') -> None:
        super().__init__(
            name=name,
            image=image,
            vis_backends=vis_backends,
            save_dir=save_dir)
        self.gt_color = gt_color
        self.pred_color = pred_color

    def _draw_text_panel(self, image: np.ndarray, text: str,
                         color: Union[str, Tuple[int, int,
                                                 int]]) -> np.ndarray:
        """Render ``text`` centered on a white canvas shaped like ``image``.

        Args:
            image (np.ndarray): Reference image; the panel copies its shape.
            text (str): The string to draw.
            color (str or tuple[int, int, int]): Text color.

        Returns:
            np.ndarray: The rendered text panel.
        """
        height, width = image.shape[:2]
        canvas = np.full_like(image, 255)
        self.set_image(canvas)
        # Scale the font so the whole string roughly fits the panel width.
        # max(..., 1) guards against a ZeroDivisionError on empty text.
        font_size = 0.5 * width / max(len(text), 1)
        self.draw_texts(
            text,
            np.array([width / 2, height / 2]),
            colors=color,
            font_sizes=font_size,
            vertical_alignments='center',
            horizontal_alignments='center')
        return self.get_image()

    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       gt_sample: Optional['TextRecogDataSample'] = None,
                       pred_sample: Optional['TextRecogDataSample'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
                       step: int = 0) -> None:
        """Visualize datasample and save to all backends.

        - The drawn result is a vertically stitched image: the (resized)
          input image on top, then the ground-truth text panel (if drawn),
          then the predicted text panel (if drawn).
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. This is usually used when the display
          is not available.

        Args:
            name (str): The image title. Defaults to 'image'.
            image (np.ndarray): The image to draw.
            gt_sample (:obj:`TextRecogDataSample`, optional): GT
                TextRecogDataSample. Defaults to None.
            pred_sample (:obj:`TextRecogDataSample`, optional): Predicted
                TextRecogDataSample. Defaults to None.
            draw_gt (bool): Whether to draw GT TextRecogDataSample.
                Defaults to True.
            draw_pred (bool): Whether to draw Predicted TextRecogDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_file (str): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.
        """
        gt_text_image = None
        pred_text_image = None
        # Resize the input to a fixed 64-pixel height, preserving aspect
        # ratio, so the text panels drawn below share its width.
        height, width = image.shape[:2]
        resize_height = 64
        resize_width = int(1.0 * width / height * resize_height)
        image = cv2.resize(image, (resize_width, resize_height))
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if draw_gt and gt_sample is not None and 'gt_text' in gt_sample:
            gt_text_image = self._draw_text_panel(image,
                                                  gt_sample.gt_text.item,
                                                  self.gt_color)

        if (draw_pred and pred_sample is not None
                and 'pred_text' in pred_sample):
            pred_text_image = self._draw_text_panel(image,
                                                    pred_sample.pred_text.item,
                                                    self.pred_color)

        # Stack the image with whichever text panels were drawn. If neither
        # sample was drawn, fall back to the resized image itself instead of
        # passing None downstream.
        panels = [image]
        if gt_text_image is not None:
            panels.append(gt_text_image)
        if pred_text_image is not None:
            panels.append(pred_text_image)
        drawn_img = np.concatenate(panels, axis=0)

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, drawn_img, step)

        if out_file is not None:
            # ``drawn_img`` is RGB; flip channels to BGR for cv2-backed
            # mmcv.imwrite.
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os.path as osp
3+
import tempfile
4+
import unittest
5+
6+
import cv2
7+
import numpy as np
8+
from mmengine.data import LabelData
9+
10+
from mmocr.core import TextRecogDataSample
11+
from mmocr.core.visualization import TextRecogLocalVisualizer
12+
13+
14+
# NOTE(review): class name says "TextDet" but it exercises
# TextRecogLocalVisualizer — presumably a copy-paste slip; confirm intent.
class TestTextDetLocalVisualizer(unittest.TestCase):
    """Unit tests for TextRecogLocalVisualizer.add_datasample."""

    def test_add_datasample(self):
        # Random RGB image; the visualizer resizes it to height 64, which
        # equals h here, so each stitched section is h pixels tall.
        h, w = 64, 128
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # test gt_text
        gt_recog_data_sample = TextRecogDataSample()
        img_meta = dict(img_shape=(12, 10, 3))
        gt_text = LabelData(metainfo=img_meta)
        gt_text.item = 'mmocr'
        gt_recog_data_sample.gt_text = gt_text

        recog_local_visualizer = TextRecogLocalVisualizer()
        recog_local_visualizer.add_datasample('image', image,
                                              gt_recog_data_sample)

        # test gt_text and pred_text
        pred_recog_data_sample = TextRecogDataSample()
        pred_text = LabelData(metainfo=img_meta)
        pred_text.item = 'MMOCR'
        pred_recog_data_sample.pred_text = pred_text

        with tempfile.TemporaryDirectory() as tmp_dir:
            # test out
            out_file = osp.join(tmp_dir, 'out_file.jpg')

            # draw_gt = True + gt_sample
            # Expect image + GT text panel stacked vertically -> 2h tall.
            recog_local_visualizer.add_datasample(
                'image', image, gt_recog_data_sample, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # draw_gt = True + gt_sample + pred_sample
            # Expect image + GT panel + pred panel -> 3h tall.
            recog_local_visualizer.add_datasample(
                'image',
                image,
                gt_recog_data_sample,
                pred_recog_data_sample,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 3, w, 3))

            # draw_gt = False + gt_sample + pred_sample
            # GT suppressed: image + pred panel only -> 2h tall.
            recog_local_visualizer.add_datasample(
                'image',
                image,
                gt_recog_data_sample,
                pred_recog_data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert that ``out_file`` exists and decodes to ``out_shape``."""
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)
        self.assertTrue(drawn_img.shape == out_shape)

0 commit comments

Comments
 (0)