Created
September 2, 2020 18:25
-
-
Save 26medias/505408fd58505cb1c1d0083c7a2cd69a to your computer and use it in GitHub Desktop.
TensorRT YoloV3 running on a video via OpenCV
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import cv2 as cv
import os
import sys

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- imported for its side effect: initializes the CUDA context
from PIL import ImageDraw
from PIL import Image

from data_processing_cv import PreprocessYOLO, PostprocessYOLO, ALL_CATEGORIES

# Fix: `sys` and `os` were imported a second time here; they are already
# imported above. Keep the sys.path tweak so the TensorRT sample's `common`
# helper module (buffer allocation / inference helpers) can be imported
# from the parent directory.
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common

TRT_LOGGER = trt.Logger()
print(cv.__version__)

# Model files: the ONNX source model and the cached serialized TensorRT engine.
onnx_file_path = 'yolov3.onnx'
engine_file_path = "yolov3.trt"

# YoloV3 network input resolution, as (height, width).
input_resolution_yolov3_HW = (608, 608)
# Create a pre-processor object by specifying the required input resolution for YOLOv3
preprocessor = PreprocessYOLO(input_resolution_yolov3_HW)
# Imported from onnx_to_tensorrt.py, modified for OpenCV
def draw_bboxes(image_raw, bboxes, confidences, categories, all_categories, bbox_color='blue'):
    """Draw one rectangle per detection on ``image_raw`` (modified in place).

    Args:
        image_raw: OpenCV (numpy) image the boxes are drawn onto.
        bboxes: iterable of (x, y, width, height) boxes in image coordinates.
        confidences: per-box scores (printed only; not rendered).
        categories: per-box class indices (printed only; not rendered).
        all_categories: class-name lookup (unused here, kept for API parity).
        bbox_color: unused (kept for API parity with the PIL original).

    Returns:
        The same ``image_raw`` object, for call-chaining convenience.
    """
    print(bboxes, confidences, categories)
    for box, score, category in zip(bboxes, confidences, categories):
        # Bug fix: x_coord/y_coord/width/height were referenced but never
        # defined (NameError on the first detection) — unpack them from box.
        x_coord, y_coord, width, height = box
        cv.rectangle(image_raw,
                     (round(x_coord), round(y_coord)),
                     (round(x_coord + width), round(y_coord + height)),
                     (0, 255, 255), 1)
    return image_raw
# Imported from onnx_to_tensorrt.py
def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = 1 << 28  # 256MiB
            builder.max_batch_size = 1
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run yolov3_to_onnx.py first to generate it.'.format(onnx_file_path))
                # Fix: use sys.exit — the bare exit() builtin is a site-module
                # convenience and may not exist in all runtimes.
                sys.exit(0)
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    print('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            # The actual yolov3.onnx is generated with batch size 64. Reshape input to batch size 1
            network.get_input(0).shape = [1, 3, 608, 608]
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_cuda_engine(network)
            # Bug fix: build_cuda_engine returns None on failure; the original
            # then crashed with AttributeError on engine.serialize().
            if engine is None:
                print('ERROR: Failed to build the TensorRT engine.')
                return None
            print("Completed creating Engine")
            # Cache the engine so subsequent runs can skip the slow build.
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
# Create the engine (loads the cached yolov3.trt if present, otherwise builds
# it from the ONNX file — building can take several minutes).
engine = get_engine(onnx_file_path, engine_file_path)
# One execution context plus host/device buffers, allocated once and reused
# for every video frame below.
context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)
# Apply YoloV3 on an OpenCV frame, return boxes, classes, scores
def getYolo(img):
    """Run YoloV3 inference on a single OpenCV (BGR) frame.

    Returns the (boxes, classes, scores) triple produced by PostprocessYOLO,
    with boxes in the original frame's coordinate space.
    """
    # OpenCV frames are BGR; the preprocessor works on a PIL (RGB) image.
    pil_frame = Image.fromarray(cv.cvtColor(img, cv.COLOR_BGR2RGB))
    raw_image, net_input = preprocessor.process(pil_frame)
    # Original image size as (width, height) — needed to map boxes back.
    original_size_wh = raw_image.size
    # Shapes of the three YOLO output feature maps for a 608x608 input.
    yolo_output_shapes = [(1, 255, 19, 19), (1, 255, 38, 38), (1, 255, 76, 76)]
    # Copy the preprocessed image into the host input buffer; do_inference_v2
    # transfers it to the GPU, executes, and returns flat host arrays.
    inputs[0].host = net_input
    flat_outputs = common.do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    # Restore the per-scale feature-map shapes expected by the post-processor.
    shaped_outputs = [flat.reshape(shape) for flat, shape in zip(flat_outputs, yolo_output_shapes)]
    postprocessor = PostprocessYOLO(
        # A list of 3 three-dimensional tuples for the YOLO masks
        yolo_masks=[(6, 7, 8), (3, 4, 5), (0, 1, 2)],
        # A list of 9 two-dimensional tuples for the YOLO anchors
        yolo_anchors=[(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
                      (59, 119), (116, 90), (156, 198), (373, 326)],
        # Threshold for object coverage, float value between 0 and 1
        obj_threshold=0.6,
        # Threshold for non-max suppression algorithm, float value between 0 and 1
        nms_threshold=0.5,
        yolo_input_resolution=input_resolution_yolov3_HW,
    )
    # Run the post-processing algorithms on the TensorRT outputs and get the
    # bounding box details of detected objects.
    boxes, classes, scores = postprocessor.process(shaped_outputs, original_size_wh)
    return boxes, classes, scores
# Read a video, run YoloV3 at every frame
cap = cv.VideoCapture('traffic.mp4')
# Fix: create the display window once, instead of re-creating it on every
# frame inside the loop (the original called cv.namedWindow per iteration).
cv.namedWindow("result", cv.WINDOW_AUTOSIZE)
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break
    # Apply YoloV3
    boxes, classes, scores = getYolo(frame)
    # Draw the rects (draw_bboxes modifies `frame` in place)
    draw_bboxes(frame, boxes, scores, classes, ALL_CATEGORIES)
    cv.imshow("result", frame)
    # Quit on 'q'; waitKey(1) also pumps the GUI event loop.
    if cv.waitKey(1) == ord('q'):
        break
cap.release()
cv.destroyAllWindows()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment