✏️ Editor's note
Every summer, the Milvus community works with the Institute of Software, Chinese Academy of Sciences, to prepare a range of engineering projects for college students in the "Summer of Open Source" program, and assigns mentors to answer their questions. Student Zhang Yumin performed well in the program. He believes in the joy that comes with every inch of progress, and he keeps trying to surpass himself while contributing to open source.
His project adds precision control to vector search in the Milvus database. Developers can choose the precision of the returned distances, which reduces memory consumption and makes the returned results easier to read.
Want to learn more about high-quality open source projects and what contributors have learned from them? See: Are there any open source projects worth participating in?
Project Description
Project name: Support specifying the precision of distances returned by search
Student profile: Zhang Yumin, a master's student in Electronic Information (Software Engineering) at the University of Chinese Academy of Sciences
Project mentor: Zhang Cai, software engineer at Zilliz
Mentor's comment: Zhang Yumin optimized the search function of the Milvus database so that results can be returned with a specified precision, which makes search more flexible. Users can choose the precision that suits their own needs, which is more convenient for them.
Support specifying the precision of distances returned by search
Task introduction
A vector search request returns the id and distance fields, and distance is a floating-point number. Milvus computes distances as 32-bit floats, but the Python SDK returns and displays them as 64-bit floats. As a result, the extra digits shown on the Python side carry no real precision. This project lets the caller specify the precision of the distances returned by search. That removes the meaningless digits on the Python side and also reduces some memory overhead.
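The following minimal Python sketch (not Milvus code) shows where the meaningless digits come from. It assumes that packing a value into 4 bytes with the standard struct module is a fair stand-in for a distance computed as a 32-bit float. That value is then widened to a 64-bit Python float, as happens on the SDK side.

```python
import struct

def as_float32(x: float) -> float:
    # Round-trip through a 4-byte IEEE 754 single to mimic a float32 distance,
    # then return it as an ordinary (64-bit) Python float
    return struct.unpack("<f", struct.pack("<f", x))[0]

d = as_float32(0.1)
print(d)            # the widened float32 value, showing spurious trailing digits
print(round(d, 3))  # rounded to 3 decimal places
```

The nearest float32 to 0.1 is not exactly 0.1, so the widened value prints as 0.10000000149011612. Rounding it to 3 decimal places, the default precision used later in this project, gives back a readable 0.1.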
Project Objectives
- Resolve the mismatch between the precision of computed results and the precision displayed
- Support returning distances with a specified precision in search
- Update the related documentation
Project steps
- Do preliminary research to understand the overall architecture of Milvus
- Work out the call relationships between modules
- Design a solution and verify the results
Project summary
What is the Milvus database?
Milvus is an open source vector database built to power AI applications and vector similarity search. On the front end, Milvus provides a user-friendly Python SDK (client). On the back end, the system is divided into four layers: the access layer (Access Layer), the coordinator service (Coordinator Service), the worker nodes (Worker Node), and storage (Storage):
(1) Access Layer: The facade of the system, consisting of a group of peer Proxy nodes. The access layer is the unified endpoint exposed to users. It forwards requests and collects execution results.
(2) Coordinator Service: The brain of the system, responsible for assigning tasks to worker nodes. There are four coordinator roles: root coordinator, data coordinator, query coordinator, and index coordinator.
(3) Worker Node: The limbs of the system. Worker nodes only passively execute the read and write requests issued by the coordinator service. There are currently three types of worker nodes: data nodes, query nodes, and index nodes.
(4) Storage: The skeleton of the system, on which all other functions are built. Milvus relies on three types of storage: metadata storage, message storage (log broker), and object storage.
From a language perspective, Milvus can be viewed as three layers: the SDK layer written in Python, the middle layer written in Go, and the core computing layer written in C++.
[Figure: Architecture diagram of the Milvus database]
How is a vector Search query executed?
On the Python SDK side, when a user calls the Search API, the call is wrapped in a gRPC request and sent to the Milvus back end, and the SDK waits for the response. On the back end, a Proxy node receives the request from the Python SDK and processes it. It then packages the request into a message, which a Producer sends to the message queue. Once the message is in the queue, the Coordinator dispatches it to the appropriate query node for consumption. The query node processes the message further and finally passes the information down to the computing layer written in C++. Depending on the situation, that layer calls different functions to compute the distances between vectors. When the computation is complete, the results are passed back up layer by layer until they reach the SDK.
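The flow above can be modeled as a toy Python pipeline. This is a hypothetical sketch, not Milvus code, and it makes several simplifications:
- All function names are invented for illustration.
- A queue.Queue stands in for the log broker.
- The coordinator's dispatch step is collapsed into a direct call.
- The "computing layer" is a brute-force Euclidean distance scan.

```python
import math
import queue

# Hypothetical stand-ins for the Milvus layers; all names are illustrative.
def compute_layer(vectors, query, topk):
    # "C++ computing layer": brute-force Euclidean distances, closest first
    scored = [(i, math.dist(v, query)) for i, v in enumerate(vectors)]
    return sorted(scored, key=lambda pair: pair[1])[:topk]

def proxy(request, message_queue):
    # "Access layer": wrap the request in a message and publish it
    message_queue.put({"query": request["query"], "topk": request["topk"]})

def query_node(message_queue, vectors):
    # "Worker node": consume one message and hand it to the computing layer
    message = message_queue.get()
    return compute_layer(vectors, message["query"], message["topk"])

def sdk_search(vectors, query, topk):
    # "SDK": send the request, then wait for results to flow back up
    message_queue = queue.Queue()  # stands in for the log broker
    proxy({"query": query, "topk": topk}, message_queue)
    return query_node(message_queue, vectors)
```

For example, sdk_search([[0, 0], [3, 4], [1, 1]], [0, 0], 2) returns ids 0 and 2 with distances 0.0 and about 1.414. Everything runs in a single process here. In Milvus, each of these layers is a separate service.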
Solution design
The overview above gives a general picture of how a vector search is processed. It also makes clear that to reach the project goal, we need to modify the Python SDK layer, the Go middle layer, and the C++ computing layer. The plan is as follows:
1. Modification steps in the Python layer:
Add a round_decimal parameter to the Search request to specify the precision of the returned distances. Before the gRPC request is built, the parameter also needs validity checks and exception handling:
```python
# param_copy is a copy of the user's search parameter dict; default precision is 3 decimal places
round_decimal = param_copy.get("round_decimal", 3)
if not isinstance(round_decimal, (int, str)):
    raise ParamError("round_decimal must be int or str")
try:
    round_decimal = int(round_decimal)
except ValueError:
    raise ParamError("round_decimal is not a valid integer")
# Only 0 to 6 decimal places are accepted
if round_decimal < 0 or round_decimal > 6:
    raise ParamError("round_decimal must be between 0 and 6")
if not isinstance(params, dict):
    raise ParamError("Search params must be a dict")
search_params = {"anns_field": anns_field, "topk": limit, "metric_type": metric_type,
                 "params": params, "round_decimal": round_decimal}
```
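The Go layer reads round_decimal as a string and converts it with strconv.Atoi. It therefore helps to picture how the search_params dict travels over gRPC. The sketch below rests on three assumptions:
- Each entry becomes a string key-value pair.
- Dict values (the index params) are JSON-encoded.
- The helper name to_key_value_pairs is hypothetical, and the exact pymilvus serialization may differ.

```python
import json

def to_key_value_pairs(search_params):
    # Hypothetical helper: every value is sent as a string, and dict values
    # (such as the index "params") are JSON-encoded first
    return [(str(key), json.dumps(value) if isinstance(value, dict) else str(value))
            for key, value in search_params.items()]

pairs = to_key_value_pairs({"anns_field": "embedding", "topk": 10, "metric_type": "L2",
                            "params": {"nprobe": 10}, "round_decimal": 3})
```

Under these assumptions, round_decimal arrives as the string "3", which is why the Go side has to parse it back into an integer.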
2. Modification steps in the Go layer:
Add a RoundDecimalKey constant to task.go. This keeps it consistent with the existing key constants and makes the value easy to look up later:
```go
const (
	InsertTaskName           = "InsertTask"
	CreateCollectionTaskName = "CreateCollectionTask"
	DropCollectionTaskName   = "DropCollectionTask"
	SearchTaskName           = "SearchTask"
	RetrieveTaskName         = "RetrieveTask"
	QueryTaskName            = "QueryTask"

	AnnsFieldKey    = "anns_field"
	TopKKey         = "topk"
	MetricTypeKey   = "metric_type"
	SearchParamsKey = "params"
	RoundDecimalKey = "round_decimal" // new: key for the requested precision

	HasCollectionTaskName      = "HasCollectionTask"
	DescribeCollectionTaskName = "DescribeCollectionTask"
)
```
Next, modify the PreExecute function to read the round_decimal value, build the queryInfo variable, and add error handling:
```go
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
if err != nil {
	return errors.New(SearchParamsKey + " not found in search_params")
}
// Read the requested precision from the same key-value list
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, st.query.SearchParams)
if err != nil {
	return errors.New(RoundDecimalKey + " not found in search_params")
}
roundDecimal, err := strconv.Atoi(roundDecimalStr)
if err != nil {
	return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is invalid")
}
queryInfo := &planpb.QueryInfo{
	Topk:         int64(topK),
	MetricType:   metricType,
	SearchParams: searchParams,
	RoundDecimal: int64(roundDecimal),
}
```
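The Go lookup above can be mirrored in Python to show its behavior in isolation. This is a hypothetical sketch with the following stand-ins:
- get_attr_by_key plays the role of funcutil.GetAttrByKeyFromRepeatedKV.
- int() plays the role of strconv.Atoi.
- The pairs are assumed to arrive as (key, string value) tuples.

```python
def get_attr_by_key(key, pairs):
    # Linear scan over (key, value) pairs, like a repeated KeyValuePair field
    for k, v in pairs:
        if k == key:
            return v
    raise KeyError(f"{key} not found in search_params")

def parse_round_decimal(pairs):
    # Look up the string value, then convert it, reporting bad input clearly
    raw = get_attr_by_key("round_decimal", pairs)
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"round_decimal {raw} is invalid") from None
```

Both failure paths match the Go code: a missing key is reported as "not found", and a non-numeric value as "invalid". One difference to keep in mind is that Python's int() is slightly more lenient than strconv.Atoi; for example, it accepts surrounding whitespace such as " 3 ".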
In addition, modify the query proto file to add a round_decimal field to QueryInfo:
```protobuf
message QueryInfo {
  int64 topk = 1;
  string metric_type = 3;
  string search_params = 4;
  int64 round_decimal = 5;  // new field
}
```
3. Modification steps in the C++ layer:
Add a new member round_decimal_ to the SearchInfo struct to receive the round_decimal value from the Go layer:
```cpp
struct SearchInfo {
    int64_t topk_;
    int64_t round_decimal_;  // new: requested number of decimal places
    FieldOffset field_offset_;
    MetricType metric_type_;
    nlohmann::json search_params_;
};
```
In the ParseVecNode and PlanNodeFromProto functions, populate the SearchInfo struct with the round_decimal value passed from the Go layer:
```cpp
std::unique_ptr<VectorPlanNode>
Parser::ParseVecNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);
    auto iter = out_body.begin();
    auto field_name = FieldName(iter.key());
    auto& vec_info = iter.value();
    Assert(vec_info.is_object());
    auto topk = vec_info["topk"];
    AssertInfo(topk > 0, "topk must be greater than 0");
    AssertInfo(topk < 16384, "topk is too large");
    auto field_offset = schema.get_offset(field_name);
    auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
        auto& field_meta = schema.operator[](field_name);
        auto data_type = field_meta.get_data_type();
        if (data_type == DataType::VECTOR_FLOAT) {
            return std::make_unique<FloatVectorANNS>();
        } else {
            return std::make_unique<BinaryVectorANNS>();
        }
    }();
    vec_node->search_info_.topk_ = topk;
    vec_node->search_info_.metric_type_ = GetMetricType(vec_info.at("metric_type"));
    vec_node->search_info_.search_params_ = vec_info.at("params");
    vec_node->search_info_.field_offset_ = field_offset;
    vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");  // new
    vec_node->placeholder_tag_ = vec_info.at("query");
    auto tag = vec_node->placeholder_tag_;
    AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");
    tag2field_.emplace(tag, field_offset);
    return vec_node;
}
```
```cpp
std::unique_ptr<VectorPlanNode>
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
    // TODO: add more buffs
    Assert(plan_node_proto.has_vector_anns());
    auto& anns_proto = plan_node_proto.vector_anns();
    auto expr_opt = [&]() -> std::optional<ExprPtr> {
        if (!anns_proto.has_predicates()) {
            return std::nullopt;
        } else {
            return ParseExpr(anns_proto.predicates());
        }
    }();
    auto& query_info_proto = anns_proto.query_info();
    SearchInfo search_info;
    auto field_id = FieldId(anns_proto.field_id());
    auto field_offset = schema.get_offset(field_id);
    search_info.field_offset_ = field_offset;
    search_info.metric_type_ = GetMetricType(query_info_proto.metric_type());
    search_info.topk_ = query_info_proto.topk();
    search_info.round_decimal_ = query_info_proto.round_decimal();  // new
    search_info.search_params_ = json::parse(query_info_proto.search_params());
    auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
        if (anns_proto.is_binary()) {
            return std::make_unique<BinaryVectorANNS>();
        } else {
            return std::make_unique<FloatVectorANNS>();
        }
    }();
    plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
    plan_node->predicate_ = std::move(expr_opt);
    plan_node->search_info_ = std::move(search_info);
    return plan_node;
}
```
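Here is a rough Python analogue of the checks ParseVecNode performs, without the C++ plumbing. The SearchInfo dataclass and parse_vec_info function are hypothetical stand-ins, and field_offset_ and the placeholder tag are left out. As with Json::at(), indexing a missing key raises an error instead of falling back to a default.

```python
import json
from dataclasses import dataclass

@dataclass
class SearchInfo:
    # Hypothetical Python stand-in for the C++ SearchInfo struct
    topk: int
    metric_type: str
    search_params: dict
    round_decimal: int

def parse_vec_info(vec_info):
    # Same bounds as the AssertInfo checks: 0 < topk < 16384
    topk = vec_info["topk"]
    if not 0 < topk < 16384:
        raise ValueError("topk must be greater than 0 and less than 16384")
    # Missing keys (including round_decimal) raise KeyError, like Json::at()
    return SearchInfo(topk=topk,
                      metric_type=vec_info["metric_type"],
                      search_params=vec_info["params"],
                      round_decimal=vec_info["round_decimal"])

example = parse_vec_info(json.loads(
    '{"topk": 10, "metric_type": "L2", "params": {"nprobe": 10}, "round_decimal": 3}'))
```

The main takeaway is that, in the JSON path shown above, round_decimal is required: a plan without it fails to parse instead of silently skipping rounding.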
Add a new member variable round_decimal_ to the SubSearchResult class, and update every place where a SubSearchResult is constructed:
```cpp
class SubSearchResult {
 public:
    SubSearchResult(int64_t num_queries, int64_t topk, MetricType metric_type, int64_t round_decimal)
        : metric_type_(metric_type),
          num_queries_(num_queries),
          topk_(topk),
          round_decimal_(round_decimal),
          labels_(num_queries * topk, -1),
          values_(num_queries * topk, init_value(metric_type)) {
    }
    // ... other members omitted

 private:
    int64_t round_decimal_;  // new: decimal places to keep; -1 disables rounding
    // ... other data members omitted
};
```
Add a new member function to SubSearchResult that rounds every result value to the requested precision at the end of the search:
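```cpp
#include <cmath>

void
SubSearchResult::round_values() {
    // -1 means no rounding was requested, so keep the raw distances
    if (round_decimal_ == -1) {
        return;
    }
    const float multiplier = std::pow(10.0, round_decimal_);
    // Scale, round to the nearest integer, then scale back for every distance
    for (auto& value : values_) {
        value = std::round(value * multiplier) / multiplier;
    }
}
```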
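The rounding rule in round_values() is "scale by 10^round_decimal, round to the nearest integer, scale back". The Python sketch below illustrates it under two assumptions:
- It works in 64-bit floats rather than the 32-bit floats used in C++.
- It approximates std::round, which rounds halfway cases away from zero, with floor(|x| + 0.5). Python's built-in round() rounds halfway cases to even instead.

```python
import math

def round_values(values, round_decimal):
    # -1 is the "leave distances untouched" sentinel, as in the C++ member function
    if round_decimal == -1:
        return list(values)
    multiplier = 10.0 ** round_decimal
    # Approximate std::round (halfway cases away from zero) instead of
    # Python's round(), which rounds halfway cases to even
    return [math.copysign(math.floor(abs(v) * multiplier + 0.5), v) / multiplier
            for v in values]
```

For example, with two decimal places, 0.125 becomes 0.13 here, whereas Python's round(0.125, 2) returns 0.12.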
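Add a new field round_decimal to the SearchDataset struct, and update every place where a SearchDataset is constructed. Because SearchDataset is built with positional brace initialization, every initializer must now pass round_decimal between topk and dim: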
```cpp
struct SearchDataset {
    MetricType metric_type;
    int64_t num_queries;
    int64_t topk;
    int64_t round_decimal;  // new
    int64_t dim;
    const void* query_data;
};
```
Modify the distance calculation functions in the C++ layer (FloatSearch, BinarySearchBruteForceFast, and others) to accept the round_decimal value:
```cpp
Status
FloatSearch(const segcore::SegmentGrowingImpl& segment,
            const query::SearchInfo& info,
            const float* query_data,
            int64_t num_queries,
            int64_t ins_barrier,
            const BitsetView& bitset,
            SearchResult& results) {
    auto& schema = segment.get_schema();
    auto& indexing_record = segment.get_indexing_record();
    auto& record = segment.get_insert_record();
    // step 1: binary search to find the barrier of the snapshot
    // auto del_barrier = get_barrier(deleted_record_, timestamp);
#if 0
    auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
    Assert(bitmap_holder);
    auto bitmap = bitmap_holder->bitmap_ptr;
#endif
    // step 2.1: get meta
    // step 2.2: get which vector field to search
    auto vecfield_offset = info.field_offset_;
    auto& field = schema[vecfield_offset];
    AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT");
    auto dim = field.get_dim();
    auto topk = info.topk_;
    auto total_count = topk * num_queries;
    auto metric_type = info.metric_type_;
    auto round_decimal = info.round_decimal_;  // new: requested precision
    // step 3: small indexing search
    // std::vector<int64_t> final_uids(total_count, -1);
    // std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
    SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
    dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
    auto vec_ptr = record.get_field_data<FloatVector>(vecfield_offset);
    int current_chunk_id = 0;
    // ... (remainder of FloatSearch omitted)
}
```
```cpp
SubSearchResult
BinarySearchBruteForceFast(MetricType metric_type,
                           int64_t dim,
                           const uint8_t* binary_chunk,
                           int64_t size_per_chunk,
                           int64_t topk,
                           int64_t num_queries,
                           int64_t round_decimal,
                           const uint8_t* query_data,
                           const faiss::BitsetView& bitset) {
    SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
    float* result_distances = sub_result.get_values();
    idx_t* result_labels = sub_result.get_labels();
    int64_t code_size = dim / 8;
    const idx_t block_size = size_per_chunk;
    raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,
               result_labels, bitset);
    sub_result.round_values();  // round distances to round_decimal places before returning
    return sub_result;
}
```
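To connect the pieces, the sketch below imitates the end of a brute-force search in Python:
- Results are laid out flat as num_queries * topk slots.
- Unfilled slots keep label -1.
- Rounding is applied once to every value at the end, as sub_result.round_values() does.

This is a hypothetical illustration, not the Milvus implementation. It uses float vectors with squared L2 distance rather than binary vectors, and +inf stands in for the metric's initial value.

```python
import math

def brute_force_search(base, queries, topk, round_decimal):
    """Return flat (labels, values) lists of length len(queries) * topk."""
    labels, values = [], []
    for q in queries:
        # Exhaustive scan: squared L2 distance from q to every base vector
        scored = sorted((sum((a - b) ** 2 for a, b in zip(v, q)), i)
                        for i, v in enumerate(base))
        best = scored[:topk]
        # Pad unfilled slots: label -1, distance +inf (stand-in for the init value)
        best += [(math.inf, -1)] * (topk - len(best))
        for dist, label in best:
            labels.append(label)
            values.append(dist)
    # Final step, like round_values(): round every value once; -1 disables rounding
    if round_decimal != -1:
        multiplier = 10.0 ** round_decimal
        values = [v if math.isinf(v)
                  else math.copysign(math.floor(abs(v) * multiplier + 0.5), v) / multiplier
                  for v in values]
    return labels, values
```

Rounding once at the end keeps the distance computation itself at full precision. Only the values handed back to the caller are affected.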
Result confirmation
1. Recompile the Milvus database:
2. Start the environment container:
3. Start the Milvus database:
4. Construct vector query request:
5. Check the results: by default, distances keep 3 decimal places, and trailing zeros are dropped:
Summary and thoughts
Taking part in Summer of Open Source was a valuable experience for me. During the program, I read the code of an open source project for the first time and worked on a multi-language project for the first time. It was also my first time using Make, gRPC, pytest, and other tools. While writing and testing code, I ran into many unexpected problems, such as "strange" dependency issues, compilation failures caused by the Conda environment, and failing tests. Facing these problems, I gradually learned to read error logs patiently and carefully and to think problems through. I learned to check and test the code, narrow down the scope of an error step by step, locate the faulty code, and try different solutions.
I learned a great deal from this program. I am very grateful to my mentor, Zhang Cai, for patiently answering my questions and pointing me in the right direction throughout development! I also hope you will pay more attention to the Milvus community; I am sure you will get something out of it!
Finally, feel free to get in touch with me (📮 deepmin@mail.deepexplore.top). My main research area is natural language processing. In my spare time, I like reading science fiction, watching anime, and tinkering with servers and personal websites, and I hang out on Stack Overflow and GitHub every day. I believe in the joy that comes with every inch of progress, and I hope we can make progress together.