시작하기 전에 https://dbfoot.tistory.com/118를 참고하여 TensorFlow Object Detection API을 먼저 설치해야 합니다.
이 게시글에서 설명할 코드는 가상환경에서 실행합니다.
먼저 필요한 라이브러리 전부 import 해줍니다.
import tensorflow as tf
import os
import pathlib
import numpy as np
import zipfile
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from plot_object_detection_saved_model import IMAGE_PATHS
자신의 로컬에 설치된 레이블 파일을 인덱스에 연결시킵니다
PATH_TO_LABELS = '텐서플로우가 있는 폴더\\models\\research\\object_detection\\data\\mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS,use_display_name=True)
이제 모델을 가져옵니다.
모델은 여기서 가져오고 원하는 모델을 링크 주소 복사합니다.
제가 예시로 한 모델은
입니다.
모델을 불러오는 함수입니다.
def download_model(model_name, model_date):
base_url = 'http://download.tensorflow.org/models/object_detection/tf2/' #경로는 변하지 않음
model_file = model_name + '.tar.gz'
model_dir = tf.keras.utils.get_file(fname=model_name,
origin=base_url + model_date + '/' + model_file,
untar=True)
return str(model_dir)
#모델 날짜와 모델 이름만 바뀜
MODEL_DATE = '20200711'
MODEL_NAME = 'ssd_mobilenet_v2_320x320_coco17_tpu-8'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)
#모델 받고 저장하는 함수
def load_model(model_dir):
model_full_dir = model_dir + "/saved_model"
# Load saved model and build the detection function
detection_model = tf.saved_model.load(model_full_dir)
return detection_model
detection_model = load_model(PATH_TO_MODEL_DIR)
base_url 에는 http://download.tensorflow.org/models/object_detection/tf2/ 이 고정으로 들어갑니다.
MODEL_DATE 엔 링크에 있는 날짜
MODEL_NAME 엔 링크에 있는 모델 이름을 넣어줍니다.
텐서플로우로 디텍션한 영상을 출력하는 함수
def show_inference(detection_model,image_np) :
input_tensor = tf.convert_to_tensor(image_np)
input_tensor = input_tensor[tf.newaxis, ...]
detections = detection_model(input_tensor)
num_detections = int(detections.pop('num_detections'))
detections = {key: value[0, :num_detections].numpy()
for key, value in detections.items()}
detections['num_detections'] = num_detections
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
print(detections)
image_np_with_detections = image_np.copy()
viz_utils.visualize_boxes_and_labels_on_image_array(
image_np_with_detections,
detections['detection_boxes'],
detections['detection_classes'],
detections['detection_scores'],
category_index,
use_normalized_coordinates=True,
max_boxes_to_draw=200,
min_score_thresh=.30,
agnostic_mode=False)
cv2.imshow('result',image_np_with_detections)
텐서플로우로 디텍션한 영상을 저장하는 함수입니다
def save_inference(detection_model,image_np,video_writer):
input_tensor = tf.convert_to_tensor(image_np)
input_tensor = input_tensor[tf.newaxis, ...]
detections = detection_model(input_tensor)
num_detections = int(detections.pop('num_detections'))
detections = {key: value[0, :num_detections].numpy()
for key, value in detections.items()}
detections['num_detections'] = num_detections
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
print(detections)
image_np_with_detections = image_np.copy()
viz_utils.visualize_boxes_and_labels_on_image_array(
image_np_with_detections,
detections['detection_boxes'],
detections['detection_classes'],
detections['detection_scores'],
category_index,
use_normalized_coordinates=True,
max_boxes_to_draw=200,
min_score_thresh=.30,
agnostic_mode=False)
video_writer.write(image_np_with_detections)
동영상을 실행하는 코드
cap = cv2.VideoCapture('영상이 있는 경로와 파일명')
캠으로 동영상을 촬영하는 코드
cap = cv2.VideoCapture(0)
해당 영상을 저장하거나 불러올때 필요한 코드입니다.
if cap.isOpened() == False:
print('비디오 실행 애러')
else:
#동영상 저장하는 변수 코드
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))
out = cv2.VideoWriter('저장을 원하는 폴더 위치와 파일명.확장자명',cv2.VideoWriter_fourcc('M','J','P','G'),
20,#20은 초당 FPS수
(frame_width, frame_height))
#비디오 캡쳐에서 이미지를 1장씩 가져옵니다
#이 1장의 이미지를 오브젝트 디텍션합니다.
while cap.isOpened():
ret, frame = cap.read()
if ret == True:
#frame이 이미지에 대한 넘파이 어레이 이므로 이 frame을 오브젝트 디텍션한다
#불러오는 시간 체크
start_time = time.time()
#이 코딩 save_inference를 하면 영상이 저장 show_inference를 하면 화면 출력
#save_inference(detection_model,frame,out)
show_inference(detection_model,frame)
end_time = time.time()
print('연산에 걸린 시간',str(end_time-start_time))
if cv2.waitKey(27) & 0xFF == 27:
break
else:
break
cap.release()
out.release()
cv2.destroyWindow()