import os
import jsonschema
import asyncio
import json
import numpy as np
from six import BytesIO
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont
from inference.base_inference_engine import AbstractInferenceEngine
from inference.exceptions import InvalidModelConfiguration, InvalidInputData, ApplicationError
from object_detection.utils import label_map_util


# noinspection PyMethodMayBeStatic
class InferenceEngine(AbstractInferenceEngine):

    def __init__(self, model_path):
        self.label_path = ""
        self.NUM_CLASSES = None
        self.label_map = None
        self.labels = None
        self.label_map_dict = None
        self.categories = None
        self.category_index = None
        self.detect_fn = None
        self.font = ImageFont.truetype("./fonts/DejaVuSans.ttf", 20)
        super().__init__(model_path)

    def _load_image_into_numpy_array(self, path):
        """Load an image from file into a numpy array.

        Puts the image into a numpy array to feed into the TensorFlow graph.
        Note that by convention we put it into a numpy array with shape
        (height, width, channels), where channels=3 for RGB.

        Args:
            path: the file path to the image

        Returns:
            uint8 numpy array with shape (img_height, img_width, 3)
        """
        img_data = tf.io.gfile.GFile(path, 'rb').read()
        image = Image.open(BytesIO(img_data))
        (im_width, im_height) = image.size
        return np.array(image.getdata()).reshape(
            (im_height, im_width, 3)).astype(np.uint8)

    def _get_keypoint_tuples(self, eval_config):
        """Return a tuple list of keypoint edges from the eval config.

        Args:
            eval_config: an eval config containing the keypoint edges

        Returns:
            a list of edge tuples, each in the format (start, end)
        """
        tuple_list = []
        kp_list = eval_config.keypoint_edge
        for edge in kp_list:
            tuple_list.append((edge.start, edge.end))
        return tuple_list

    def load(self):
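        """
        Reads and validates config.json, loads the label map and builds the category index,
        then loads the TensorFlow SavedModel from the model path.
        :return:
        """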
        with open(os.path.join(self.model_path, 'config.json')) as f:
            data = json.load(f)
        try:
            self.validate_json_configuration(data)
            self.set_model_configuration(data)
        except ApplicationError as e:
            raise e

        self.label_path = os.path.join(self.model_path, 'object-detection.pbtxt')
        self.label_map = label_map_util.load_labelmap(self.label_path)
        self.categories = label_map_util.convert_label_map_to_categories(
            self.label_map,
            max_num_classes=label_map_util.get_max_label_map_index(self.label_map),
            use_display_name=True)
        self.category_index = label_map_util.create_category_index(self.categories)
        self.label_map_dict = label_map_util.get_label_map_dict(self.label_map, use_display_name=True)
        self.labels = [label for label in self.label_map_dict]
        # allow memory growth so TensorFlow does not reserve all GPU memory up front
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
        self.detect_fn = tf.saved_model.load(self.model_path)

    async def infer(self, input_data, draw, predict_batch):
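        """
        Runs object detection on a single image and builds the bounding-box response.
        :param input_data: image file to run inference on
        :param draw: when True, draw the predictions on the image and save it instead of returning the response
        :param predict_batch: when True, include the image name in the response
        :return: dictionary with a 'bounding-boxes' list (None when draw is True)
        """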
        # Briefly yield to the event loop so other requests are not blocked during inference.
        await asyncio.sleep(0.00001)
        try:
            pillow_image = Image.open(input_data.file).convert('RGB')
            np_image = np.array(pillow_image)
        except Exception:
            raise InvalidInputData('corrupted image')
        try:
            with open(os.path.join(self.model_path, 'config.json')) as f:
                data = json.load(f)
        except Exception:
            raise InvalidModelConfiguration('config.json not found or corrupted')
        json_confidence = data['confidence']
        json_predictions = data['predictions']

        # Add a batch dimension: the detection function expects a batched input tensor.
        input_tensor = tf.convert_to_tensor(np_image)
        input_tensor = input_tensor[tf.newaxis, ...]
        detections = self.detect_fn(input_tensor)

        height, width, _ = np_image.shape

        names = []
        confidence = []
        ids = []
        bounding_boxes = []
        names_start = []
        scores = detections['detection_scores'][0].numpy()
        boxes = detections['detection_boxes'][0].numpy()
        classes = detections['detection_classes'][0].numpy().astype(int)
        classes_names = [self.category_index.get(i) for i in classes]
        for name in classes_names:
            if name is not None:
                names_start.append(name['name'])

        # Boxes come back normalized as [ymin, xmin, ymax, xmax]; scale them to pixel coordinates
        # and keep at most the configured number of predictions above the confidence threshold.
        for i in range(min(json_predictions, len(scores))):
            if scores[i] * 100 >= json_confidence:
                ymin = max(0, int(round(boxes[i][0] * height)))
                xmin = max(0, int(round(boxes[i][1] * width)))
                ymax = max(0, int(round(boxes[i][2] * height)))
                xmax = max(0, int(round(boxes[i][3] * width)))
                bounding_boxes.append({'left': xmin, 'top': ymin, 'right': xmax, 'bottom': ymax})
                confidence.append(float(scores[i] * 100))
                ids.append(int(classes[i]))
                names.append(names_start[i])

        responses_list = zip(names, confidence, bounding_boxes, ids)

        output = []
        for name, conf, box, class_id in responses_list:
            output.append({'ObjectClassName': name, 'confidence': conf, 'coordinates': box,
                           'ObjectClassId': class_id})

        if predict_batch:
            response = {'bounding-boxes': output, 'ImageName': input_data.filename}
        else:
            response = {'bounding-boxes': output}
        if not draw:
            return response
        else:
            try:
                self.draw_image(pillow_image, response)
            except ApplicationError as e:
                raise e
            except Exception as e:
                raise e

    async def run_batch(self, input_data, draw, predict_batch):
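        """
        Runs inference on each image of a batch and collects the individual responses.
        :param input_data: list of image files
        :param draw: forwarded to infer
        :param predict_batch: forwarded to infer
        :return: list of inference responses
        """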
        result_list = []
        for image in input_data:
            post_process = await self.infer(image, draw, predict_batch)
            if post_process is not None:
                result_list.append(post_process)
        return result_list

    def draw_image(self, image, response):
        """
        Draws the predicted bounding boxes, class names and confidences on the image and saves it.
        :param image: image of type pillow image
        :param response: inference response
        :return:
        """
        draw = ImageDraw.Draw(image)
        for bbox in response['bounding-boxes']:
            draw.rectangle([bbox['coordinates']['left'], bbox['coordinates']['top'], bbox['coordinates']['right'],
                            bbox['coordinates']['bottom']], outline="red")
            left = bbox['coordinates']['left']
            top = bbox['coordinates']['top']
            conf = "{0:.2f}".format(bbox['confidence'])
            draw.text((int(left), int(top) - 20), str(conf) + "% " + str(bbox['ObjectClassName']), 'red', self.font)
        image.save('./result.jpg', 'JPEG')

    def free(self):
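        """
        Releases resources held by the engine. Nothing to free for this engine.
        :return:
        """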
        pass

    def validate_variables(self):
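        """
        Checks that the variables folder contains both the variables .index file
        and the .data-00000-of-00001 file.
        :return: True if both files are present, False otherwise
        """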
        valid = False

        index_file = None
        meta_file = None

        for var in os.listdir(os.path.join(self.model_path, 'variables')):
            if var.startswith("variables") and var.endswith(".index"):
                index_file = var
            elif var.startswith("variables") and var.endswith(".data-00000-of-00001"):
                meta_file = var

        if meta_file is not None and index_file is not None:
            valid = True

        return valid

    def validate_configuration(self):
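        """
        Validates the model folder structure: the variables folder, saved_model.pb and
        object-detection.pbtxt must all be present.
        :return: True if the configuration is valid, raises InvalidModelConfiguration otherwise
        """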
        # check if the variables folder exists
        if not os.path.isdir(os.path.join(self.model_path, 'variables')):
            raise InvalidModelConfiguration('variables folder not found')

        # check if the variables are valid
        if not self.validate_variables():
            raise InvalidModelConfiguration('variables folder structure not valid')

        # check if the weights file exists
        if not os.path.exists(os.path.join(self.model_path, 'saved_model.pb')):
            raise InvalidModelConfiguration('saved_model.pb not found')
        # check if the labels file exists
        if not os.path.exists(os.path.join(self.model_path, 'object-detection.pbtxt')):
            raise InvalidModelConfiguration('object-detection.pbtxt not found')
        return True

    def set_model_configuration(self, data):
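        """
        Stores the framework, type and network fields from config.json in the engine
        configuration and sets the number of classes.
        :param data: parsed config.json content
        """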
        self.configuration['framework'] = data['framework']
        self.configuration['type'] = data['type']
        self.configuration['network'] = data['network']
        self.NUM_CLASSES = data['number_of_classes']

    def validate_json_configuration(self, data):
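        """
        Validates the parsed config.json against the ConfigurationSchema.json schema.
        :param data: parsed config.json content
        """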
        with open(os.path.join('inference', 'ConfigurationSchema.json')) as f:
            schema = json.load(f)
        try:
            jsonschema.validate(data, schema)
        except Exception as e:
            raise InvalidModelConfiguration(e)