157 lines
5.5 KiB
Python
157 lines
5.5 KiB
Python
import clip
|
|
import numpy
|
|
import onnxruntime as ort
|
|
from typing import Tuple, List
|
|
|
|
from PIL import Image
|
|
from numpy import ndarray
|
|
from torch.ao.ns.fx.utils import compute_cosine_similarity
|
|
from tqdm import tqdm
|
|
from torch import Tensor
|
|
import os
|
|
from torchvision.datasets import CIFAR100
|
|
|
|
vit_model_uint8 = "clip-image-encoder-quant-uint8.onnx"
|
|
vit_model_int8 = "clip-image-encoder-quant-int8.onnx"
|
|
vit_model_fp16 = "clip-image-encoder-fp16.onnx"
|
|
vit_model_fp32 = "clip-image-encoder.onnx"
|
|
text_model_uint8 = "clip-text-encoder-quant-uint8.onnx"
|
|
text_model_int8 = "clip-text-encoder-quant-int8.onnx"
|
|
text_model_fp16 = "clip-text-encoder-fp16.onnx"
|
|
text_model_fp32 = "clip-text-encoder.onnx"
|
|
|
|
image_encoder: ort.InferenceSession
|
|
text_encoder: ort.InferenceSession
|
|
|
|
_, preprocess = clip.load('ViT-B/32')
|
|
|
|
cifar100 = CIFAR100(root=os.path.expanduser("../../dataset"), download=True, train=False)
|
|
|
|
image_embeddings: list[Tuple[str, Tensor]] = []
|
|
|
|
LIMIT = 100
|
|
|
|
|
|
def to_numpy(tensor: Tensor, dtype=None):
|
|
r: ndarray = tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
|
r.astype(dtype=dtype)
|
|
return r
|
|
|
|
|
|
def build_image_embedding():
|
|
for i in tqdm(range(LIMIT), unit=" images", postfix="build_image_embedding..."):
|
|
image, class_id = cifar100[i]
|
|
label = cifar100.classes[class_id]
|
|
image_input = preprocess(image).unsqueeze(0)
|
|
image_features = image_encoder.run(None, {image_encoder.get_inputs()[0].name: to_numpy(image_input)})[0]
|
|
|
|
image_features = Tensor(numpy.array(image_features))
|
|
image_embeddings.append((label, image_features))
|
|
|
|
|
|
def search_by_input(input_text: str) -> Tuple[str, Tensor] or None:
|
|
text_input = clip.tokenize(input_text)
|
|
arr = text_encoder.run(None, {text_encoder.get_inputs()[0].name: to_numpy(text_input, dtype=int)})[0]
|
|
text_features = Tensor(arr)
|
|
print("text_features", text_features)
|
|
MAX_SIM = -999
|
|
result: Tuple[str, Tensor] or None = None
|
|
for img_ebd in image_embeddings:
|
|
sim = compute_cosine_similarity(img_ebd[1], text_features)
|
|
if sim > MAX_SIM:
|
|
MAX_SIM = sim
|
|
result = img_ebd
|
|
print(MAX_SIM)
|
|
return result
|
|
|
|
|
|
def benchmark():
|
|
build_image_embedding()
|
|
correct = 0
|
|
count = 0
|
|
for i in tqdm(range(LIMIT), unit=" images", postfix="validating..."):
|
|
image, class_id = cifar100[i]
|
|
label = cifar100.classes[class_id]
|
|
res = search_by_input(f"a photo of a {label}")
|
|
if res[0] == label:
|
|
correct += 1
|
|
count += 1
|
|
|
|
print("Accuracy:", correct / LIMIT)
|
|
|
|
|
|
def image_search_test():
|
|
images = load_images()
|
|
build_image_embedding_from_disk(images)
|
|
while True:
|
|
prompt = input("Search photo:")
|
|
index, score = search_by_input(prompt)
|
|
Image.open("images/" + images[index]).show()
|
|
|
|
# for i in tqdm(range(len(images)), unit=" images", postfix="validating..."):
|
|
# image, class_id = cifar100[i]
|
|
# label = cifar100.classes[class_id]
|
|
# res = search_by_input(f"a photo of a {label}")
|
|
# if res[0] == label:
|
|
# correct += 1
|
|
# count += 1
|
|
|
|
# print("Accuracy:", correct / LIMIT)
|
|
|
|
|
|
# Testing fp32 version...
|
|
# 100%|██████████| 5000/5000 [04:03<00:00, 20.57 images/s, build_image_embedding...]
|
|
# 100%|██████████| 5000/5000 [28:51<00:00, 2.89 images/s, validating...]
|
|
# Accuracy: 0.9404
|
|
# ==========================
|
|
# Testing fp32 version...
|
|
# 100%|██████████| 2000/2000 [01:29<00:00, 35.84 images/s, build_image_embedding...]
|
|
# 100%|██████████| 2000/2000 [03:21<00:00, 11.91 images/s, validating...]
|
|
# Accuracy: 0.871
|
|
# ==========================
|
|
|
|
# Testing uint8 version...
|
|
# 100%|██████████| 5000/5000 [02:29<00:00, 33.35 images/s, build_image_embedding...]
|
|
# 100%|██████████| 5000/5000 [15:10<00:00, 5.49 images/s, validating...]
|
|
# Accuracy: 0.7814
|
|
# ==========================
|
|
# Testing int8 version...
|
|
# 100%|██████████| 2000/2000 [00:55<00:00, 35.84 images/s, build_image_embedding...]
|
|
# 100%|██████████| 2000/2000 [02:47<00:00, 11.91 images/s, validating...]
|
|
# Accuracy: 0.825
|
|
# ==========================
|
|
# Testing int8 version...
|
|
# 100%|██████████| 5000/5000 [02:17<00:00, 36.47 images/s, build_image_embedding...]
|
|
# 100%|██████████| 5000/5000 [16:27<00:00, 5.06 images/s, validating...]
|
|
# Accuracy: 0.8304
|
|
|
|
def load_images() -> List[str]:
|
|
return os.listdir("images")
|
|
|
|
|
|
def build_image_embedding_from_disk(images: List[str]):
|
|
print("images on disk:")
|
|
for i in tqdm(range(len(images)), unit=" images", postfix="build_image_embedding..."):
|
|
image = Image.open("images/" + images[i])
|
|
image_input = preprocess(image).unsqueeze(0)
|
|
image_features = image_encoder.run(None, {image_encoder.get_inputs()[0].name: to_numpy(image_input)})[0]
|
|
|
|
image_features = Tensor(numpy.array(image_features))
|
|
image_embeddings.append((i, image_features))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
image_encoder = ort.InferenceSession(vit_model_int8)
|
|
text_encoder = ort.InferenceSession(text_model_int8)
|
|
image_search_test()
|
|
# LIMIT = 100
|
|
# print("Testing int8 version...")
|
|
# benchmark()
|
|
# print("==========================")
|
|
# print("Testing fp32 version...")
|
|
# # fp32(original), CIFAR ~2000, acc: 0.87
|
|
# # fp16, CIFAR ~500, acc: 0.814
|
|
# image_encoder = ort.InferenceSession(vit_model_fp32)
|
|
# text_encoder = ort.InferenceSession(text_model_fp32)
|
|
# benchmark()
|