시작하기 전에 https://dbfoot.tistory.com/118를 참고하여 TensorFlow Object Detection API을 먼저 설치해야 합니다.
이 게시글에서 설명할 코드는 가상환경에서 실행합니다.
먼저 필요한 라이브러리 전부 import 해줍니다.
import tensorflow as tf
import os
import pathlib
import numpy as np
import zipfile
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from plot_object_detection_saved_model import IMAGE_PATHS
라이브러리를 실행시켰다면 제 로컬에 설치된 레이블 파일을 인덱스와 연결 시켜줍니다.
PATH_TO_LABELS = 'C:\\Users\\5-1\\Documents\\TensorFlow\\models\\research\\object_detection\\data\\mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS,use_display_name=True)
모델을 다운로드하는 함수
위 링크에서 모델들을 가져올 수 있습니다.
제가 보여드릴 예제는 3번째 CenterNet HourGlass104 1024x1024 입니다.
해당 링크를 우클릭 후 링크 주소 복사합니다.
아래 링크가 해당 복사한 주소 입니다.
def download_model(model_name, model_date):
#base_url은 변하지 않습니다.
base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
model_file = model_name + '.tar.gz'
model_dir = tf.keras.utils.get_file(fname=model_name,
origin=base_url + model_date + '/' + model_file,
untar=True)
return str(model_dir)
#모델 날짜와 모델 이름만 바뀜
MODEL_DATE = '20200711'
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)
받아온 텐서플로우 모델을 불러옵니다.
def load_model(model_dir):
model_full_dir = model_dir + "/saved_model"
# Load saved model and build the detection function
detection_model = tf.saved_model.load(model_full_dir)
return detection_model
detection_model = load_model(PATH_TO_MODEL_DIR)
그리고 테스트할 이미지 경로를 입력하는 코드 입니다.
PATH_TO_IMAGE_DIR = pathlib.Path('data\\images') #이미지가 있는 경로를 적어주시면 됩니다.
IMAGE_PATHS = list(PATH_TO_IMAGE_DIR.glob('*.jpg'))
#이미지 이름하고 확장자명을 적어주시면 됩니다. *은 해당 폴더안에 .jpg 확장자명을 가지고 있는 이미지는 모두 사용한다는 뜻입니다.
print(IMAGE_PATHS)
이미지를 가져왔으면 모든 이미지를 컴퓨터가 읽을 수 있게 넘파이 행렬로 변경해줍니다.
def load_image_into_numpy_array(path):
print(str(path))
return cv2.imread(str(path))
for image_path in IMAGE_PATHS:
print('Running inference for {}... '.format(image_path), end='')
image_np = load_image_into_numpy_array(image_path)
input_tensor = tf.convert_to_tensor(image_np)
input_tensor = input_tensor[tf.newaxis, ...]
detections = detection_model(input_tensor)
num_detections = int(detections.pop('num_detections'))
detections = {key: value[0, :num_detections].numpy()
for key, value in detections.items()}
detections['num_detections'] = num_detections
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
print(detections)
image_np_with_detections = image_np.copy()
viz_utils.visualize_boxes_and_labels_on_image_array(
image_np_with_detections,
detections['detection_boxes'],
detections['detection_classes'],
detections['detection_scores'],
category_index,
use_normalized_coordinates=True,
max_boxes_to_draw=200,
min_score_thresh=.30,
agnostic_mode=False)
cv2.imshow(str(image_path),image_np_with_detections)
그리고 코드가 종료된 후 사진이 바로 꺼지는걸 막기 위해 waitkey() 함수를 사용해 줍니다.
cv2.waitKey(0)
cv2.destroyWindow()
그리고 실행시켜보시면 해당 사진으로 이미지 오브젝트 디텍션이 잘 되는걸 보실 수 있습니다.