Yolov5封装detect.py面向对象

news/发布时间2024/6/15 18:15:16

主要目标是适应摄像头rtsp流的检测

如果是普通文件夹或者图片,run中的while True去掉即可。

web_client是根据需求创建的客户端,将检测到的数据打包发送给服务器

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.Usage:$ python path/to/detect.py --source path/to/img.jpg --weights yolov5s.pt --img 640
"""import argparse
import json
import os
import sys
import time
import moment
from pathlib import Pathimport cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnnFILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relativefrom models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import load_classifier, select_device, time_syncfrom mytools import read_yaml_all, base64_encode_img
from message_base import MessageBase
from websocket_client import WebClientclass Detect:def __init__(self, config: dict, client: WebClient):self.config = configself.weights = self.config.get("weights")  # weights pathself.source = self.config.get("source")  # source self.imgsz = self.config.get("imgsz")  # imgszself.conf_thres = self.config.get("conf_thres")self.iou_thres = self.config.get("iou_thres")self.max_det = self.config.get("max_det")self.device = self.config.get("device")  # "cpu" or "0,1,2,3"self.view_img = self.config.get("view_img")  # show resultsself.save_txt = self.config.get("save_txt")  # save results to *.txtself.save_conf = self.config.get("save_conf")  # save confidences in --save-txt labelsself.save_crop = self.config.get("save_crop")  # save cropped prediction boxesself.nosave = self.config.get("nosave")  # do not save images/videosself.classes = self.config.get("classes")  # filter by class: --class 0, or --class 0 2 3self.agnostic_nms = self.config.get("agnostic_nms")  # class-agnostic NMSself.augment = self.config.get("augment")  # augmented inferenceself.visualize = self.config.get("visualize")  # visualize featuresself.update = self.config.get("update")  # update all modelsself.save_path = self.config.get("save_path")  # save results to project/nameself.line_thickness = self.config.get("line_thickness")  # bounding box thickness (pixels)self.hide_labels = self.config.get("hide_labels")  # hide labelsself.hide_conf = self.config.get("hide_conf")  # hide confidencesself.half = self.config.get("half")  # use FP16 half-precision inferenceself.dnn = self.config.get("dnn")  # use OpenCV DNN for ONNX inferenceself.func_device = self.config.get("func_device")  # 对应功能的设备名字self.save_img = not self.nosave and not self.source.endswith('.txt')  # save inference imagesself.webcam = self.source.isnumeric() or self.source.endswith('.txt') or self.source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))set_logging()self.device = select_device(self.device)self.half = self.device.type != 'cpu'  # half precision only supported on CUDAself.model = attempt_load(self.weights, map_location=self.device)self.imgsz = check_img_size(self.imgsz, s=int(self.model.stride.max()))self.stride = int(self.model.stride.max())self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names# 获取数据if self.webcam:self.view_img = check_imshow()cudnn.benchmark = True  # set True to speed up constant image size inferenceself.dataset = LoadStreams(self.source, img_size=self.imgsz, stride=self.stride, auto=True)self.bs = len(self.dataset)  # batch_sizeelse:self.dataset = LoadImages(self.source, img_size=self.imgsz, stride=self.stride, auto=True)self.bs = 1  # batch_sizeself.client = client  # 客户端self.last_time = moment.now()self.check_time_step = 5  # 每隔多少时间检测一次os.mkdir(self.save_path) if not os.path.exists(self.save_path) else Nonedef inference(self, img):img = torch.from_numpy(img).to(self.device)img = img.half() if self.half else img.float()  # uint8 to fp16/32img /= 255.0  # 0 - 255 to 0.0 - 1.0if img.ndimension() == 3:img = img.unsqueeze(0)pred = self.model(img, augment=self.augment)[0]# NMSpred = non_max_suppression(pred, self.conf_thres, self.iou_thres,self.classes, self.agnostic_nms, max_det=self.max_det)return preddef process(self, im0s, img, pred, path):for i, det in enumerate(pred):  # per imageif self.webcam:  # batch_size >= 1p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), self.dataset.countelse:p, s, im0, frame = path, '', im0s.copy(), getattr(self.dataset, 'frame', 0)p = Path(p)  # to Pathtxt_path = str(self.save_path + "/" + 'labels' + "/" + p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')  # img.txts += '%gx%g ' % img.shape[2:]  # print stringgn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwhimc = im0.copy() if self.save_crop else im0  # for save_cropannotator = Annotator(im0, line_width=self.line_thickness, example=str(self.names))if len(det):# Rescale boxes from img_size to im0 sizedet[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()# Print resultsfor c in det[:, -1].unique():n = (det[:, -1] == c).sum()  # detections per classs += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string# Write resultsfor *xyxy, conf, cls in reversed(det):c = int(cls)label = self.names[c]# if label == "person":if label:  # 根据对应标签做处理# annotator.box_label(xyxy, label, color=colors(c, True)) # 画框t = int(time.time())img_path = f"{self.save_path}/{self.func_device}_{label}_{t}.jpg"crop = save_one_box(xyxy, imc, img_path, BGR=True)x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])data = {"device": self.func_device,"value": {"label": label,"time": t,"locate": (x1, y1, x2, y2),"crop": base64_encode_img(crop)}}data = json.dumps(data)  # 打包数据try:self.client.send(data)  # 客户端发送数据passexcept Exception as err:print("发送失败:", err)self.client.connect()self.client.send(data)print("重连成功!")print(data)# if self.save_txt:  # Write to file#     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(#         -1).tolist()  # normalized xywh#     line = (cls, *xywh, conf) if self.save_conf else (cls, *xywh)  # label format#     with open(txt_path + '.txt', 'a') as f:#         f.write(('%g ' * len(line)).rstrip() % line + '\n')# 画框# if self.save_img or self.save_crop or self.view_img:  # Add bbox to image#     c = int(cls)  # integer class#     label = None if self.hide_labels else (self.names[c] if self.hide_conf else#                                            f'{self.names[c]} {conf:.2f}')#     annotator.box_label(xyxy, label, color=colors(c, True))def run(self):self.client.connect()while True:for path, img, im0s, vid_cap in self.dataset:if self.last_time.__lt__(moment.now()):self.last_time = moment.now().add(seconds=self.check_time_step)try:pred = self.inference(img)self.process(im0s, img, pred, path)              except Exception as err:print(err)if self.save_txt or self.save_img:s = f"\n{len(list(self.save_path.glob('labels/*.txt')))} labels saved to {self.save_path / 'labels'}" if self.save_txt else ''print(f"Results saved to {colorstr('bold', self.save_path)}{s}")if self.update:strip_optimizer(self.weights)  # update model (to fix SourceChangeWarning)if __name__ == "__main__":message_base = MessageBase()wc = WebClient("192.168.6.28", 8000)configs = read_yaml_all("yolo_configs.yaml")config = read_yaml_all("configs.yaml")device_name = config.get("DEVICE_LIST")[0]device_source = config.get("RTSP_URLS").get(device_name)configs["source"] = device_sourceconfigs["func_device"] = device_nameprint(configs)detect = Detect(configs, wc)detect.run()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.shwantai.cn/a/06026331.html

如若内容造成侵权/违法违规/事实不符,请联系万泰站长网进行投诉反馈email:xxxxxxxx@qq.com,一经查实,立即删除!

相关文章

免费SSL证书和付费SSL证书的区别点

背景: 在了解免费SSL证书和付费SSL证书的区别之前,先带大家了解一下SSL证书的概念和作用。 SSL证书的概念: SSL证书就是基于http超文本传输协议的延伸,在http访问的基础上增加了一个文本传输加密的协议,由于http是明…

蓝桥杯(5):python动态规划DF[2:背包问题]

1 0-1背包介绍【每件物品只能拿1件或者不拿】 1.1 简介 贪心是不可以的!!! 1.2 状态 及状态转移 转移解释:要么不选 则上一个直接转移过来【dp[i-1][j]】,要么是选这个之后体积为j 则上一个对应的就是【dp[i-1][j-wi]…

营销中的归因人工智能

Attribution AI in marketing 归因人工智能作为智能服务的一部分,是一种多渠道算法归因服务,根据特定结果计算客户互动的影响和增量影响。有了归因人工智能,营销人员可以通过了解每个客户互动对客户旅程每个阶段的影响来衡量和优化营销和广告…

canvas画图,画矩形、圆形、直线可拖拽移动,可拖拽更改尺寸大小

提示:canvas画图,画矩形,圆形,直线,曲线可拖拽移动 文章目录 前言一、画矩形,圆形,直线,曲线可拖拽移动总结 前言 一、画矩形,圆形,直线,曲线可拖…

Python | Leetcode Python题解之第10题正则表达式匹配

题目: 题解: class Solution:def isMatch(self, s: str, p: str) -> bool:m, n len(s), len(p)dp [False] * (n1)# 初始化dp[0] Truefor j in range(1, n1):if p[j-1] *:dp[j] dp[j-2]# 状态更新for i in range(1, m1):dp2 [False] * (n1) …

Maplesoft Maple 2024(数学科学计算)mac/win

Maplesoft Maple是一款强大的数学计算软件,提供了丰富的功能和工具,用于数学建模、符号计算、数据可视化等领域的数学分析和解决方案。 Mac版软件下载:Maplesoft Maple 2024 for mac激活版 WIn版软件下载:Maplesoft Maple 2024特别…