有一个需求:计算图片的相似度
需要解决两个问题:
- 生成 ahash
- 存储和计算 ahash 之间的距离
生成 ahash
『生成 ahash』 选用 python 下面的一个 imagehash 库。(github:https://github.com/JohannesBu...)
from io import BytesIO
import numpy
import imagehash
from PIL import Image
def create_vector(file: BytesIO) -> bytes:
image = Image.open(file)
hash = imagehash.average_hash(image)
_vector = []
for h in hash.hash:
_vector.extend(h)
vector = bytes(
numpy.packbits(
[
int(v)
for v in _vector
],
axis=-1
).tolist()
)
return vector
create_vector 函数输出的类型是 bytes,就是二进制序列
imagehash.average_hash(image) 输出的 hash 对象,hash 对象有一个 hash 属性,这个属性的类型是
list[list[bool]]
打印出来就是长下面这样子,其实就是一个 8x8=64 bit 的序列[[False False False False False False False False] [ True False False False True False False False] [False False True True True True False False] [False False False True True False True True] [False False True True True False False False] [False True True True True False False False] [False True True True True False True True] [False False False True True False True True]]
向量数据库
『存储和计算 ahash 之间的距离』选用 milvus
创建集合
定义集合:
import settings
from pymilvus import (
connections,
Collection,
FieldSchema,
CollectionSchema,
DataType,
)
from loggers import logger
connections.connect(
host=settings.MILVUS_CONFIG.host,
port=settings.MILVUS_CONFIG.port,
)
schema = CollectionSchema([
FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
FieldSchema("meta_id", DataType.INT64),
FieldSchema("company_id", DataType.INT64),
FieldSchema("image_vector", dtype=DataType.BINARY_VECTOR, dim=64)
])
# 集合不存在,则会自动创建集合;已存在,不会重复创建
collection = Collection(settings.MILVUS_CONFIG.collection.name, schema)
使用的向量类型是 dtype=DataType.BINARY_VECTOR
,
为什么不选 float 是因为我不知道怎么把 ahash 转成 float
关于向量索引的问题,因为我选用的 『向量类型』 是 BINARY_VECTOR。所以,索引类型只有 BIN_FLAT 和 BIN_IVF_FLAT 可选了
具体可看 https://milvus.io/docs/v2.2.x...
插入 ahash 到 milvus
class TestVector(unittest.TestCase):
def test_insert_vector(self):
"""
插入 ahash 到 milvus
python -m unittest testing.test_milvus.TestVector.test_insert_vector
"""
oss_file_path = 'image_hash/testing/WechatIMG193.jpeg'
file = BytesIO(bucket.get_object(oss_file_path).read())
vector = create_vector(file)
m_pk = insert_vector(vector, meta_id=2, company_id=1)
logger.debug(f'milvus pk: {m_pk}')
查询 ahash from milvus
def test_search(self):
"""
批量调用后端接口入库
python -m unittest testing.test_milvus.TestVector.test_search
"""
oss_file_path = 'image_hash/testing/WechatIMG193.jpeg'
file = BytesIO(open(BASE_DIR/'testing'/'resource'/'WechatIMG193.jpeg','rb').read())
vector = create_vector(file)
logger.debug(vector)
rows: list[dict[str, Any]] = collection.search(
data=[vector],
param={"metric_type": 'HAMMING', "params": {"nprobe": 32}},
anns_field='image_vector',
output_fields=['id', 'meta_id', 'company_id'],
limit=10,
)
logger.debug(rows)
logger.debug(type(rows))
注意 metric_type ,因为我选用的 『向量类型』 是 BINARY_VECTOR,所以,metric_type 要选择支持 BINARY_VECTOR 的才行
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。