kmeans聚类anchor框代码分享
189******30 发布于2020-02 浏览:2892 回复:1
0
收藏

# K-means clustering of ground-truth box sizes to derive anchor boxes.
# Meant to be run from the work directory of the AI Studio insect-detection
# project; the printed cluster centres are sorted from small to large.
import os

import numpy as np

WORK_DIR = '/home/aistudio/work'   # project work directory (was `%cd` target)
TRAIN_DIR = 'insects/train'        # annotation directory, relative to WORK_DIR
N_ANCHORS = 9                      # number of clusters / anchors to produce
KMEANS_SEED = 0                    # fixed seed so repeated runs agree


def load_box_sizes(split_dir=TRAIN_DIR):
    """Return the sizes of all ground-truth boxes found under *split_dir*.

    Args:
        split_dir: annotation directory passed to ``get_annotations``.

    Returns:
        Array of shape (N, 2), one row per box, taken from columns 2:4 of
        each ``gt_bbox`` row (treated as (w, h) -- assumes the reader returns
        boxes as (x, y, w, h); verify against insects_reader).

    Raises:
        ValueError: if no boxes are found.
    """
    # Imported lazily: insects_reader presumably lives in WORK_DIR (the
    # original ran `%cd` before importing it).
    from insects_reader import get_insect_names, get_annotations

    records = get_annotations(get_insect_names(), split_dir)
    rows = [box for record in records for box in record['gt_bbox']]
    if not rows:
        raise ValueError(f'no ground-truth boxes found under {split_dir!r}')
    return np.array(rows)[:, 2:4]


def sort_anchors(centers):
    """Round (w, h) cluster centres and sort them by ascending w + h.

    Args:
        centers: array-like of shape (k, 2), one (w, h) pair per row.

    Returns:
        int32 array of shape (k, 3) with columns (w, h, w + h), rows ordered
        by the last column; ties keep their input order.
    """
    # Round to the nearest pixel instead of truncating toward zero.
    wh = np.rint(np.asarray(centers, dtype=float)).astype('int32')
    table = np.column_stack((wh, wh[:, 0] + wh[:, 1]))
    return table[np.argsort(table[:, 2], kind='stable')]


# Both a notebook cell and `python script.py` run with __name__ == '__main__';
# the guard only keeps the helpers importable without running the pipeline.
if __name__ == '__main__':
    from matplotlib import pyplot as plt
    from sklearn.cluster import KMeans

    os.chdir(WORK_DIR)

    boxes = load_box_sizes(TRAIN_DIR)
    print(boxes)

    # n_init is pinned explicitly: scikit-learn >= 1.2 warns that its default
    # changes, and 10 restarts was the historical default.
    model1 = KMeans(n_clusters=N_ANCHORS, n_init=10, random_state=KMEANS_SEED)
    model1.fit(boxes)
    # Cluster index of every box; used to colour the scatter plot.
    C_i = model1.predict(boxes)
    centers = np.rint(model1.cluster_centers_).astype('int32')
    print(centers)

    plt.figure(figsize=(12, 8), dpi=80)
    plt.scatter(boxes[:, 0], boxes[:, 1], s=1, c=C_i)
    plt.scatter(centers[:, 0], centers[:, 1], s=50, c='red')
    plt.show()

    centers_sorted = sort_anchors(centers)
    print('--------')
    print(centers_sorted)

收藏
点赞
0
个赞
共1条回复 最后由用户已被禁言回复于2022-04
#2 189******30 回复于2020-02
# K-means clustering of ground-truth box sizes to derive anchor boxes, with
# each (w, h) size also added as (h, w) before clustering -- presumably so
# the anchors do not depend on the orientation of the objects (confirm).
import os

import numpy as np

WORK_DIR = '/home/aistudio/work'   # project work directory (was `%cd` target)
TRAIN_DIR = 'insects/train'        # annotation directory, relative to WORK_DIR
N_ANCHORS = 9                      # number of clusters / anchors to produce
KMEANS_SEED = 0                    # fixed seed so repeated runs agree


def load_box_sizes(split_dir=TRAIN_DIR):
    """Return the sizes of all ground-truth boxes found under *split_dir*.

    Args:
        split_dir: annotation directory passed to ``get_annotations``.

    Returns:
        Array of shape (N, 2), one row per box, taken from columns 2:4 of
        each ``gt_bbox`` row (treated as (w, h) -- assumes the reader returns
        boxes as (x, y, w, h); verify against insects_reader).

    Raises:
        ValueError: if no boxes are found.
    """
    # Imported lazily: insects_reader presumably lives in WORK_DIR (the
    # original ran `%cd` before importing it).
    from insects_reader import get_insect_names, get_annotations

    records = get_annotations(get_insect_names(), split_dir)
    rows = [box for record in records for box in record['gt_bbox']]
    if not rows:
        raise ValueError(f'no ground-truth boxes found under {split_dir!r}')
    return np.array(rows)[:, 2:4]


def add_rotated_sizes(wh):
    """Return *wh* followed by a copy of it with the two columns swapped.

    Every (w, h) row gains an (h, w) twin -- the size the same box would have
    after a 90-degree rotation -- so the data set is symmetric under swapping
    width and height.

    Args:
        wh: array-like of shape (N, 2).

    Returns:
        New array of shape (2N, 2): the original rows, then the swapped rows.
        The input is not modified.
    """
    wh = np.asarray(wh)
    return np.concatenate((wh, wh[:, ::-1]))


def sort_anchors(centers):
    """Round (w, h) cluster centres and sort them by ascending w + h.

    Args:
        centers: array-like of shape (k, 2), one (w, h) pair per row.

    Returns:
        int32 array of shape (k, 3) with columns (w, h, w + h), rows ordered
        by the last column; ties keep their input order.
    """
    # Round to the nearest pixel instead of truncating toward zero.
    wh = np.rint(np.asarray(centers, dtype=float)).astype('int32')
    table = np.column_stack((wh, wh[:, 0] + wh[:, 1]))
    return table[np.argsort(table[:, 2], kind='stable')]


# Both a notebook cell and `python script.py` run with __name__ == '__main__';
# the guard only keeps the helpers importable without running the pipeline.
if __name__ == '__main__':
    from matplotlib import pyplot as plt
    from sklearn.cluster import KMeans

    os.chdir(WORK_DIR)

    boxes = load_box_sizes(TRAIN_DIR)
    data = add_rotated_sizes(boxes)
    print(boxes)
    print(boxes.shape, data.shape)

    # n_init is pinned explicitly: scikit-learn >= 1.2 warns that its default
    # changes, and 10 restarts was the historical default.
    model1 = KMeans(n_clusters=N_ANCHORS, n_init=10, random_state=KMEANS_SEED)
    model1.fit(data)
    # Cluster index of every point in `data`; used to colour the plot.
    C_i = model1.predict(data)
    centers = np.rint(model1.cluster_centers_).astype('int32')
    print(centers)

    # Plot the points the model was actually fit on (original + swapped),
    # not just the original boxes, so the centres line up with their clusters.
    plt.figure(figsize=(12, 8), dpi=80)
    plt.scatter(data[:, 0], data[:, 1], s=1, c=C_i)
    plt.scatter(centers[:, 0], centers[:, 1], s=50, c='red')
    plt.show()

    centers_sorted = sort_anchors(centers)
    print('--------')
    print(centers_sorted)
0
TOP
切换版块