简介
Grounding DINO 是一种先进的零样本目标检测模型,由 IDEA Research 开发。它通过将基于 Transformer 的检测器 DINO 与Grounded Pre-Training相结合,实现了通过人类输入(如类别名称或指代表达)对任意物体进行检测。
例如在不需要任何训练的情况下,告诉Grounding DINO找出图像中人所在的位置,Grounding DINO就能标注出人的坐标。如下:
演示流程:
基本原理
在Grounding DINO中,作者想要完成这样一项任务:根据人类文字输入去检测任意类别的目标,称作开放世界目标检测问题(open-set object detection)。
完成open-set object detection的关键是将language信息引入到目标的通用特征表示中。例如,GLIP利用对比学习的方式在目标检测和文字短语之间建立起了联系,它在close-set和open-set数据集上都有很好的表现。尽管如此,GLIP是基于传统的one-stage detector结构,因此还有一定的局限性。
受很多前期工作的启发(GLIP、DINO等),作者提出了Grounding DINO,它相对于GLIP有以下几点优势:
- Grounding DINO 的transformer结构更接近于NLP模型,因此它更容易同时处理图片和文字;
- Transformer-based detector在处理大型数据集时被证明有优势;
- 作为DETR的变种,DINO能够完成end-to-end的训练,而且不需要NMS等额外的后处理。
网络结构:
Grounding DINO的整体结构如上图所示。Grounding DINO是一个双encoder单decoder结构,它包含了
- 一个image backbone用于提取image feature
- 一个text backbone用于提取text feature
- 一个feature enhancer用于融合image和text feature
- 一个language-guide query selection模块用于query初始化
- 一个cross-modality decoder用于bbox预测
特点与优势
- 零样本检测能力:Grounding DINO 能够在没有目标数据集标注的情况下,通过文本提示检测未见过的类别。例如,在 COCO 数据集的零样本检测基准测试中,它达到了 52.5 AP。
- 强大的跨模态融合能力:通过深度的视觉与语言模态融合,模型在开放集目标检测和指代表达理解任务中表现出色。
- 端到端优化:基于 Transformer 的架构使得 Grounding DINO 可以端到端地进行优化,无需手工设计模块。
应用场景
- Grounding DINO 可以广泛应用于需要灵活目标检测的场景
- 自动驾驶:通过自然语言描述检测特定的交通标志或障碍物。
- 机器人视觉:根据指令识别和操作物体。
- 图像标注与内容理解:自动识别图像中的对象并生成描述。
安装
参考:https://blog.csdn.net/weixin_44151034/article/details/139362032
1.安装虚拟环境
conda create -n dino python=3.10 -y
2.安装pytorch
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
如果有显卡在安装pytorch之前需要了解显卡驱动支持的最高cuda版本,安装pytorch的同时会安装cuda和cudnn等模块,保证cuda版本在显卡支持的范围之内。
nvidia-smi
4.下载项目并安装
git clone https://github.com/IDEA-Research/GroundingDINO.git
cd GroundingDINO
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
5.下载权重
预训练模型:groundingdino_swint_ogc.pth 这里可能需要使用魔法
mkdir weights
cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
到这里为止就安装好了Grounding DINO,下面就来使用它。
运行
Grounding DINO 能做的事情包括理解文字信息,找出图像中文字信息描述的对象。比如告诉Grounding DINO 找出图像中人所在的位置。
在项目根目录下创建test.py文件,位置一定不能错,关系到寻找配置文件的路径。需要输入给模型包括一个图像和一段文字。准备一张图像,再准备一段文字,文字为想要检测的物体,用空格或句号隔开,如: "chair . person . cell . flower"
from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
#加载模型
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weights/groundingdino_swint_ogc.pth")
#要预测的图片路径
IMAGE_PATH = "1.jpeg"
#要预测的类别提示,可以输入多个类中间用英文句号隔开
TEXT_PROMPT = "chair . person . cell . flower"
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
image_source, image = load_image(IMAGE_PATH)
boxes, logits, phrases = predict(
model=model,
image=image,
caption=TEXT_PROMPT,
box_threshold=BOX_TRESHOLD,
text_threshold=TEXT_TRESHOLD
)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
#保存预测的图片,保存到outputs文件夹中,名称为annotated_image.jpg
cv2.imwrite("annotated_image.jpg", annotated_frame)
执行代码:
得到标注结果的图像:
自带页面
Grounding DINO 使用代码推理还有一个更方便的网页端推理页面。在demo目录下面的gradio_demo.py 是一个python实现的前端推理页面。
由于我在写这篇文章时发现直接跑有点问题,所以需要稍作修改。主要修改在load_model_hf函数中,修改了模型文件的加载方式。原本从huggingface获取但是模型文件已经不在了,修改成使用上面下载的模型。
import argparse
from functools import partial
import cv2
import requests
import os
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import warnings
import torch
# prepare the environment
os.system("python setup.py build develop --user")
os.system("pip install packaging==21.3")
os.system("pip install gradio==3.50.2")
warnings.filterwarnings("ignore")
import gradio as gr
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download
# Use this command for evaluate the Grounding DINO model
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "weights/groundingdino_swint_ogc.pth"
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
args = SLConfig.fromfile(model_config_path)
model = build_model(args)
args.device = device
# cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
# checkpoint = torch.load(cache_file, map_location='cpu')
checkpoint = torch.load(filename, map_location='cpu')
log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
# print("Model loaded from {} \n => {}".format(cache_file, log))
_ = model.eval()
return model
def image_transform_grounding(init_image):
transform = T.Compose([
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image, _ = transform(init_image, None) # 3, h, w
return init_image, image
def image_transform_grounding_for_vis(init_image):
transform = T.Compose([
T.RandomResize([800], max_size=1333),
])
image, _ = transform(init_image, None) # 3, h, w
return image
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
init_image = input_image.convert("RGB")
original_size = init_image.size
_, image_tensor = image_transform_grounding(init_image)
image_pil: Image = image_transform_grounding_for_vis(init_image)
# run grounidng
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
return image_with_box
if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
args = parser.parse_args()
block = gr.Blocks().queue()
with block:
gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
gr.Markdown("### Open-World Detection with Grounding DINO")
with gr.Row():
with gr.Column():
input_image = gr.Image(source='upload', type="pil")
grounding_caption = gr.Textbox(label="Detection Prompt")
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
box_threshold = gr.Slider(
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
text_threshold = gr.Slider(
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
with gr.Column():
gallery = gr.outputs.Image(
type="pil",
# label="grounding results"
).style(full_width=True, full_height=True)
# gallery = gr.Gallery(label="Generated images", show_label=False).style(
# grid=[1], height="auto", container=True, full_width=True, full_height=True)
run_button.click(fn=run_grounding, inputs=[
input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
执行 python demo/gradio_demo.py ,注意一定要在这个路径下。首先会安装一些库,然后启动
hon3.10/site-packages (from anyio->httpx->gradio==3.50.2) (1.3.1)
final text_encoder_type: bert-base-uncased
Running on local URL: http://0.0.0.0:7579
To create a public link, set `share=True` in `launch()`.
IMPORTANT: You are using gradio version 3.50.2, however version 4.44.1 is available, please upgrade.
--------
打开 http://IP:7579/ 就能看到页面。上传图片完成推理
各位亦菲、胡歌们,觉得不错请支持点赞,谢谢~