85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
import os
from typing import List, Optional, Tuple

import clip
import numpy
import onnxruntime as ort
from numpy import ndarray
from PIL import Image
from torch import Tensor
from torch.ao.ns.fx.utils import compute_cosine_similarity
from torchvision.datasets import CIFAR100
from tqdm import tqdm
|
|
|
|
# ONNX exports of the CLIP image ("vit") and text encoders. The file names
# suggest fp32, fp16 and int8/uint8-quantized variants; paths are relative to
# the current working directory.
vit_model_uint8 = "clip-image-encoder-quant-uint8.onnx"
vit_model_int8 = "clip-image-encoder-quant-int8.onnx"
vit_model_fp16 = "clip-image-encoder-fp16.onnx"
vit_model_fp32 = "clip-image-encoder.onnx"

text_model_uint8 = "clip-text-encoder-quant-uint8.onnx"
text_model_int8 = "clip-text-encoder-quant-int8.onnx"
text_model_fp16 = "clip-text-encoder-fp16.onnx"
text_model_fp32 = "clip-text-encoder.onnx"

# ONNX Runtime sessions for the two encoders. They are only declared here;
# the ``__main__`` block assigns them before the search loop starts.
image_encoder: ort.InferenceSession
text_encoder: ort.InferenceSession

# Only CLIP's image preprocessing transform is needed here (encoding is done
# by the ONNX sessions). Load on the CPU and do not bind the returned PyTorch
# model to any name, so it can be freed instead of staying resident (possibly
# on the GPU) for the whole run.
preprocess = clip.load('ViT-B/32', device="cpu")[1]

# (image_file_name, image_features) pairs appended by build_image_embedding.
image_embeddings: List[Tuple[str, Tensor]] = []

# NOTE(review): not referenced by any code visible in this file.
LIMIT = 100
|
|
|
|
|
|
def to_numpy(tensor: Tensor, dtype=None):
    """Return *tensor* as a NumPy array, optionally cast to *dtype*.

    The tensor is detached from autograd and moved to the CPU first. For a
    tensor already on the CPU the result may share memory with it, unless a
    cast is requested (casting always produces a new array).

    Args:
        tensor: Any torch tensor.
        dtype: Optional NumPy dtype (e.g. ``int`` or ``numpy.int64``) for the
            result. ``None`` keeps the tensor's own dtype.

    Returns:
        The converted ``numpy.ndarray``.
    """
    # detach() is a cheap no-op view for tensors that do not require grad,
    # and is required before .numpy() for those that do.
    r: ndarray = tensor.detach().cpu().numpy()
    if dtype is not None:
        # ndarray.astype returns a NEW array rather than casting in place, so
        # the result must be kept. Skipping it for None matters because
        # astype(None) would silently convert to float64.
        r = r.astype(dtype)
    return r
|
|
|
|
|
|
def build_image_embedding(images: List[str]):
    """Encode each file in *images* and append it to ``image_embeddings``.

    Every name is resolved inside the ``images`` directory, run through the
    CLIP ``preprocess`` transform (with a batch dimension added), and encoded
    with the module-level ``image_encoder`` ONNX session. Results are appended
    as ``(file_name, features)`` pairs. The list is not cleared first, so
    calling this twice indexes the same files twice.

    Args:
        images: File names (not paths) inside the ``images`` directory.
    """
    print("images on disk:")
    # Loop-invariant: the encoder's input name does not change per image.
    input_name = image_encoder.get_inputs()[0].name
    for name in tqdm(images, unit=" images", postfix="build_image_embedding..."):
        # The context manager closes the file handle even if preprocessing
        # fails; preprocess has materialised the pixels before we leave it.
        with Image.open(os.path.join("images", name)) as image:
            image_input = preprocess(image).unsqueeze(0)
        features = image_encoder.run(None, {input_name: to_numpy(image_input)})[0]
        image_embeddings.append((name, Tensor(features)))
|
|
|
|
|
|
def search_by_input(input_text: str) -> Optional[Tuple[str, Tensor]]:
    """Find the indexed image whose embedding best matches *input_text*.

    The text is tokenized with ``clip.tokenize`` and encoded with the
    module-level ``text_encoder`` ONNX session. It is then compared with every
    entry of ``image_embeddings`` via ``compute_cosine_similarity``.

    Args:
        input_text: Free-form search prompt.

    Returns:
        ``(image_file_name, similarity)`` for the best match, where
        ``similarity`` is the value returned by ``compute_cosine_similarity``.
        Returns ``None`` when there is nothing to match against.
    """
    text_input = clip.tokenize(input_text)
    print("Token:", text_input, len(text_input[0]))
    input_name = text_encoder.get_inputs()[0].name
    print("InputNames", input_name)
    # NOTE(review): dtype=int requests NumPy's default integer type; confirm
    # it matches the id dtype the ONNX text encoder expects (int32 vs int64).
    arr = text_encoder.run(None, {input_name: to_numpy(text_input, dtype=int)})[0]
    text_features = Tensor(arr)
    print("text_features", text_features)

    best_name: Optional[str] = None
    best_sim = float("-inf")
    for name, image_features in image_embeddings:
        sim = compute_cosine_similarity(image_features, text_features)
        # Strict ">" keeps the earliest image on ties.
        if sim > best_sim:
            best_name, best_sim = name, sim

    if best_name is None:
        # Nothing indexed (or no comparable similarity): report "no match"
        # as the annotation promises, instead of indexing into None.
        return None
    return best_name, best_sim
|
|
|
|
|
|
def image_search_test():
    """Index the ``images`` directory and run an interactive search loop.

    Each prompt read from stdin is matched with ``search_by_input``. The best
    match's file name and similarity are printed, and the image is opened
    with ``PIL.Image.Image.show``. The loop ends when stdin reaches EOF.
    """
    images = load_images()
    build_image_embedding(images)

    while True:
        try:
            prompt = input("Search photo:")
        except EOFError:
            # stdin closed (Ctrl-D / end of piped input): stop cleanly
            # instead of ending the program with a traceback.
            break
        hit = search_by_input(prompt)
        if hit is None:
            print("Result: no match")
            continue
        label, score = hit
        print("Result:", label, score)
        # Close the file handle once the viewer has been launched.
        with Image.open(os.path.join("images", label)) as image:
            image.show()
|
|
|
|
|
|
def load_images() -> List[str]:
    """Return the sorted names of the regular files directly in ``images``.

    Subdirectories are skipped because they cannot be opened as images.
    ``os.listdir`` yields entries in arbitrary order, so the names are sorted
    to make indexing order, and therefore tie-breaking, reproducible.
    """
    return sorted(
        name
        for name in os.listdir("images")
        if os.path.isfile(os.path.join("images", name))
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # Create the image session first, then the text session, and bind them to
    # the module-level names the search functions read; then start the loop.
    image_encoder, text_encoder = [
        ort.InferenceSession(model_path)
        for model_path in (vit_model_int8, text_model_int8)
    ]
    image_search_test()
|