同步链接: https://yangshun.win/blogs/684610df
github code: https://github.com/busyboxs/BaiDuAICPP
图像主体检测能检测出图片主体的坐标位置,可使用该接口裁剪出图像主体区域,配合图像识别接口提升识别精度。广泛适用于美图类 app、辅助智能识图等业务场景中。
应用场景
- 智能美图:根据用户上传照片进行主体检测,实现图像裁剪或背景虚化等功能,可应用于含美图功能 app 等业务场景中
- 图像识别辅助:可使用图像主体检测裁剪出图像主体区域,配合图像识别接口提升识别精度
接口描述
用户向服务请求检测图像中的主体位置。
请求说明
- HTTP 方法: POST
- 请求 URL: https://aip.baidubce.com/rest/2.0/image-classify/v1/object_detect
- URL参数: access_token
- Header 参数: Content-Type = application/x-www-form-urlencoded
- Body 参数:见下表
返回说明
返回参数如下表:
返回示例如下:
{
"log_id": 895582300,
"result": {
"width": 486,
"top": 76,
"left": 134,
"height": 394
}
}
C++ 代码实现调用
这里假设已经将环境配置好了,环境配置的文章可以参考 Windows 下使用 Vcpkg 配置百度 AI 图像识别 C++开发环境(VS2017)[https://yangshun.win/blogs/3b103680/]。
为了方便,首先根据返回参数定义了一个结构体,该结构体包括了返回参数中的参数,如下:
struct ObjDetInfo {
uint32_t left;
uint32_t top;
uint32_t width;
uint32_t height;
void print() {
std::cout << std::setw(20) << std::setfill('-') << '\n';
std::cout << "left: " << left << "\n";
std::cout << "top: " << top << "\n";
std::cout << "width: " << width << "\n";
std::cout << "height: " << height << "\n";
}
void draw(cv::Mat &img) {
cv::Rect rect(left, top, width, height);
cv::rectangle(img, rect, cv::Scalar(255, 0, 255), 3);
}
};
在 ObjInfo 结构体中,定义了一个 print 方法以打印获取的结果,draw 方法以在图像上画出边框。
然后定义了一个类来调用接口并获取结果
class ObjectDetection
{
public:
ObjectDetection();
~ObjectDetection();
Json::Value request(std::string imgBase64, std::map& options);
// only get first result
void getResult(ObjDetInfo& result);
private:
Json::Value obj_;
std::string url_;
// file to save token key
std::string filename_;
};
类中的私有成员 obj_ 表示返回结果对应的 json 对象。url_ 表示请求的 url,filename_ 表示用于存储 access token 的文件的文件名。
request 函数输入请求图像的 base64 编码以及请求参数,返回一个 json 对象,json 对象中包含请求的结果。
getResult 获取请求的结果。
完整代码如下
util.h 和 util.cpp 代码参见 (简单调用篇 01) 通用物体和场景识别高级版 - C++ 简单调用[https://yangshun.win/blogs/cd08a730/]
ObjectDetection.h 代码如下:
#pragma once
#include
#include
#include
#include "util.h"
struct ObjDetInfo {
uint32_t left;
uint32_t top;
uint32_t width;
uint32_t height;
void print() {
std::cout << std::setw(20) << std::setfill('-') << '\n';
std::cout << "left: " << left << "\n";
std::cout << "top: " << top << "\n";
std::cout << "width: " << width << "\n";
std::cout << "height: " << height << "\n";
}
void draw(cv::Mat &img) {
cv::Rect rect(left, top, width, height);
cv::rectangle(img, rect, cv::Scalar(255, 0, 255), 3);
}
};
class ObjectDetection
{
public:
ObjectDetection();
~ObjectDetection();
Json::Value request(std::string imgBase64, std::map& options);
// only get first result
void getResult(ObjDetInfo& result);
private:
Json::Value obj_;
std::string url_;
// file to save token key
std::string filename_;
};
void objectDetectionTest();
ObjectDetection.cpp 代码如下:
#include "ObjectDetection.h"
ObjectDetection::ObjectDetection()
{
filename_ = "tokenKey";
url_ = "https://aip.baidubce.com/rest/2.0/image-classify/v1/object_detect";
}
ObjectDetection::~ObjectDetection()
{
}
Json::Value ObjectDetection::request(std::string imgBase64, std::map& options)
{
std::string response;
Json::Value obj;
std::string token;
// 1. get HTTP post body
std::string body;
mergeHttpPostBody(body, imgBase64, options);
// 2. get HTTP url with access token
std::string url = url_;
getHttpPostUrl(url, filename_, token);
// 3. post request, response store the result
int status_code = httpPostRequest(url, body, response);
if (status_code != CURLcode::CURLE_OK) {
obj["curl_error_code"] = status_code;
obj_ = obj;
return obj; // TODO: maybe should exit
}
// 4. make string to json object
generateJson(response, obj);
// if access token is invalid or expired, we will get a new one
if (obj["error_code"].asInt() == 110 || obj["error_code"].asInt() == 111) {
token = getTokenKey();
writeFile(filename_, token);
return request(imgBase64, options);
}
obj_ = obj;
checkErrorWithExit(obj);
return obj;
}
void ObjectDetection::getResult(ObjDetInfo & result)
{
result.left = obj_["result"].get("left", "0").asInt();
result.top = obj_["result"].get("top", "0").asInt();
result.width = obj_["result"].get("width", "0").asInt();
result.height = obj_["result"].get("height", "0").asInt();
}
void objectDetectionTest()
{
std::cout << "size: " << sizeof(ObjDetInfo) << "\n";
// read image and encode to base64
std::string img_file = "./images/cat.jpg";
std::string out;
readImageFile(img_file.c_str(), out);
std::string img_base64 = base64_encode(out.c_str(), (int)out.size());
// set options
std::map options;
options["with_face"] = "0";
Json::Value obj;
ObjectDetection objDetObj;
obj = objDetObj.request(img_base64, options);
//std::cout << (obj.get("result", "null")) << std::endl;
ObjDetInfo result;
objDetObj.getResult(result);
result.print();
cv::Mat img = cv::imread(img_file);
result.draw(img);
cv::namedWindow("Object Detection", cv::WINDOW_NORMAL);
cv::imshow("Object Detection", img);
cv::waitKey();
}
main.cpp 代码如下:
#include "util.h"
#include "ObjectDetection.h"
#include
int main() {
objectDetectionTest();
system("pause");
return EXIT_SUCCESS;
}
运行结果