抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

文章主要是记录 HNSW 原理以及 Search 过程代码的学习。

论文浅读

目前已有较详细的解析,附上参考链接:

大白话理解 HNSW

【论文笔记】HNSW(Hierarchical Navigable Small World graphs)

近似最近邻算法 HNSW 学习笔记(二) 主要算法伪代码分析

Search 过程代码

hnswlib 库中截取部分代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// hnswlib/hnswalg.h

// KNN 搜索函数
/**
* 目前来看,这里与伪代码中不太一致
* 伪代码中将每次搜索 layer 单独写成一个函数
* 实现中没有单独写成一个函数,仅写了一个 search base layer 函数
* 实现中通过取消 check deletions 和忽略 stop conditions 提高性能
*/
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

// currObj: 进入 base layer 的 enterpoint
tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_),dist_func_param_);

// 与 pseudo code 中从最高层遍历至倒数第二层的逻辑一致
for (int level = maxlevel_; level > 0; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;

// 获取了当前 level 的数据
data = (unsigned int *) get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;

tableint *datal = (tableint *) (data + 1);
// 比较当前 level 的点
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

// 如果当前层(非 base layer)存在更近的点
// 改变 currObj,继续遍历当前层
/**
没搞懂:changed = true 后,while 继续循环,这一层重新遍历一遍?
*/
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t,tableint>>, CompareByFirst> top_candidates;
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
if (bare_bone_search) {
// 搜索底层函数
top_candidates = searchBaseLayerST<true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
} else {
top_candidates = searchBaseLayerST<false>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

while (top_candidates.size() > k) {
top_candidates.pop();
}
while (top_candidates.size() > 0) {
std::pair<dist_t, tableint> rez = top_candidates.top();
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
top_candidates.pop();
}
return result;
}


// 搜索底层函数
// bare_bone_search means there is no check for deletions and stop condition is ignoredin return of extra performance
template <bool bare_bone_search = true, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t,tableint>>, CompareByFirst>
searchBaseLayerST(
tableint ep_id,
const void *data_point,
size_t ef,
BaseFilterFunctor* isIdAllowed = nullptr,
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {

VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t,tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t,tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
if (bare_bone_search ||
(!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
char* ep_data = getDataByInternalId(ep_id);
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
if (!bare_bone_search && stop_condition) {
stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
}
candidate_set.emplace(-dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidate_set.emplace(-lowerBound, ep_id);
}

visited_array[ep_id] = visited_array_tag;

while (!candidate_set.empty()) {
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
dist_t candidate_dist = -current_node_pair.first;

bool flag_stop_search;
if (bare_bone_search) {
flag_stop_search = candidate_dist > lowerBound;
} else {
if (stop_condition) {
flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
} else {
flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
}
}
if (flag_stop_search) {
break;
}
candidate_set.pop();

tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
if (collect_metrics) {
metric_hops++;
metric_distance_computations+=size;
}

// 未完...
}
}