# Files
# coco 7846a45f2c a
# 2026-07-03 15:47:27 +08:00
#
# 157 lines
# 5.5 KiB
# Python
import os
from typing import List, Optional, Tuple

import clip
import numpy
import onnxruntime as ort
import torch
from numpy import ndarray
from PIL import Image
from torch import Tensor
from torch.ao.ns.fx.utils import compute_cosine_similarity
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# ONNX model files for the CLIP image (ViT) encoder at each precision.
vit_model_uint8 = "clip-image-encoder-quant-uint8.onnx"
vit_model_int8 = "clip-image-encoder-quant-int8.onnx"
vit_model_fp16 = "clip-image-encoder-fp16.onnx"
vit_model_fp32 = "clip-image-encoder.onnx"
# ONNX model files for the CLIP text encoder at each precision.
text_model_uint8 = "clip-text-encoder-quant-uint8.onnx"
text_model_int8 = "clip-text-encoder-quant-int8.onnx"
text_model_fp16 = "clip-text-encoder-fp16.onnx"
text_model_fp32 = "clip-text-encoder.onnx"
# The onnxruntime sessions are only declared here; they are assigned in the
# ``__main__`` block before any embedding/search function is called.
image_encoder: ort.InferenceSession
text_encoder: ort.InferenceSession
# Only CLIP's preprocessing transform is kept (the PyTorch model is discarded
# into ``_``); encoding itself runs through the ONNX sessions above.
# NOTE: this runs at import time.
_, preprocess = clip.load('ViT-B/32')
# CIFAR-100 test split (train=False); download=True fetches it into
# ../../dataset when missing. NOTE: also runs at import time; expanduser is a
# no-op here because the path has no leading "~".
cifar100 = CIFAR100(root=os.path.expanduser("../../dataset"), download=True, train=False)
# Shared search index of (key, image_features) pairs. The key is a CIFAR-100
# label in build_image_embedding but an int index into the on-disk image list
# in build_image_embedding_from_disk, despite the ``str`` in the annotation.
image_embeddings: list[Tuple[str, Tensor]] = []
# Number of CIFAR-100 test images used by build_image_embedding / benchmark.
LIMIT = 100
def to_numpy(tensor: Tensor, dtype=None) -> ndarray:
    """Convert a torch tensor to a NumPy array on the CPU.

    :param tensor: tensor to convert. It is detached first, so tensors that
        require grad are accepted.
    :param dtype: optional NumPy dtype (or Python type such as ``int``) to
        cast the result to. ``None`` keeps the tensor's own dtype.
    :return: the converted array.
    """
    r: ndarray = tensor.detach().cpu().numpy()
    if dtype is not None:
        # ``astype`` returns a new array; it does not convert in place.
        r = r.astype(dtype)
    return r


def build_image_embedding():
    """Encode the first ``LIMIT`` CIFAR-100 test images into ``image_embeddings``.

    Each image goes through CLIP ``preprocess`` and the global
    ``image_encoder`` ONNX session; ``(label, features)`` is appended to the
    global ``image_embeddings`` list (the list is not cleared first).
    """
    input_name = image_encoder.get_inputs()[0].name  # loop-invariant
    for i in tqdm(range(LIMIT), unit=" images", postfix="build_image_embedding..."):
        image, class_id = cifar100[i]
        label = cifar100.classes[class_id]
        image_input = preprocess(image).unsqueeze(0)
        image_features = image_encoder.run(None, {input_name: to_numpy(image_input)})[0]
        image_embeddings.append((label, Tensor(numpy.array(image_features))))


def search_by_input(input_text: str) -> Optional[Tuple[str, Tensor]]:
    """Return the entry of ``image_embeddings`` that best matches ``input_text``.

    The text is tokenized with ``clip.tokenize`` and encoded by the global
    ``text_encoder`` ONNX session. Every stored image embedding is then scored
    by cosine similarity against the text features.

    :param input_text: free-form query text.
    :return: the ``(key, image_features)`` entry with the highest similarity
        (the first one on ties), or ``None`` if ``image_embeddings`` is empty.
        ``key`` is whatever the index builder stored (a CIFAR-100 label or an
        index into the on-disk image list).
    """
    text_input = clip.tokenize(input_text)
    encoder_input = text_encoder.get_inputs()[0]
    # Feed token ids in the integer type the exported graph declares, because
    # onnxruntime rejects an input whose dtype differs (int32 vs int64).
    # Unrecognised declarations keep the tokenizer's own dtype (None).
    token_dtype = {
        "tensor(int32)": numpy.int32,
        "tensor(int64)": numpy.int64,
    }.get(encoder_input.type)
    arr = text_encoder.run(
        None, {encoder_input.name: to_numpy(text_input, dtype=token_dtype)}
    )[0]
    text_features = Tensor(arr)
    print("text_features", text_features)
    if not image_embeddings:
        return None
    # Score every stored embedding in one batched call rather than one
    # Python-level call per embedding per query. Each stored vector and the
    # text vector are flattened to 1-D before comparison.
    gallery = torch.stack([features.reshape(-1) for _, features in image_embeddings])
    sims = torch.nn.functional.cosine_similarity(
        gallery, text_features.reshape(1, -1), dim=1
    )
    best = int(torch.argmax(sims))  # first maximal index on ties
    print(sims[best])
    return image_embeddings[best]
def benchmark():
    """Measure text-to-image retrieval accuracy on the first ``LIMIT`` CIFAR-100 test images.

    Rebuilds the global embedding index with the current ``image_encoder``.
    Then, for each image, it queries ``"a photo of a {label}"`` and counts a
    hit when the best-matching stored entry carries the same label. Prints
    the resulting accuracy.
    """
    # build_image_embedding appends to the shared global list. Without this,
    # entries from an earlier run (e.g. a different encoder precision) stay
    # in the index, compete in the search and skew the accuracy.
    image_embeddings.clear()
    build_image_embedding()
    correct = 0
    count = 0
    for i in tqdm(range(LIMIT), unit=" images", postfix="validating..."):
        _, class_id = cifar100[i]
        label = cifar100.classes[class_id]
        res = search_by_input(f"a photo of a {label}")
        if res is not None and res[0] == label:
            correct += 1
        count += 1
    if count == 0:
        print("Accuracy: n/a (no images evaluated)")
        return
    print("Accuracy:", correct / count)
def image_search_test():
    """Interactively search the images in ``./images`` by text.

    Rebuilds the embedding index from the files returned by ``load_images``.
    It then repeatedly reads a query from stdin and opens the best-matching
    image with ``Image.show``. Blank queries are skipped; EOF (Ctrl-D) or
    Ctrl-C ends the loop.
    """
    images = load_images()
    if not images:
        print("No images found in 'images/'.")
        return
    # Start from an empty index so entries left by an earlier benchmark()
    # (keyed by CIFAR-100 label strings, not file indices) cannot be returned.
    image_embeddings.clear()
    build_image_embedding_from_disk(images)
    while True:
        try:
            prompt = input("Search photo:")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not prompt.strip():
            continue
        match = search_by_input(prompt)
        if match is None:
            continue
        # The second element is the stored image embedding, not a score.
        index, _features = match
        with Image.open(os.path.join("images", images[index])) as image:
            image.show()
# for i in tqdm(range(len(images)), unit=" images", postfix="validating..."):
# image, class_id = cifar100[i]
# label = cifar100.classes[class_id]
# res = search_by_input(f"a photo of a {label}")
# if res[0] == label:
# correct += 1
# count += 1
# print("Accuracy:", correct / LIMIT)
# Testing fp32 version...
# 100%|██████████| 5000/5000 [04:03<00:00, 20.57 images/s, build_image_embedding...]
# 100%|██████████| 5000/5000 [28:51<00:00, 2.89 images/s, validating...]
# Accuracy: 0.9404
# ==========================
# Testing fp32 version...
# 100%|██████████| 2000/2000 [01:29<00:00, 35.84 images/s, build_image_embedding...]
# 100%|██████████| 2000/2000 [03:21<00:00, 11.91 images/s, validating...]
# Accuracy: 0.871
# ==========================
# Testing uint8 version...
# 100%|██████████| 5000/5000 [02:29<00:00, 33.35 images/s, build_image_embedding...]
# 100%|██████████| 5000/5000 [15:10<00:00, 5.49 images/s, validating...]
# Accuracy: 0.7814
# ==========================
# Testing int8 version...
# 100%|██████████| 2000/2000 [00:55<00:00, 35.84 images/s, build_image_embedding...]
# 100%|██████████| 2000/2000 [02:47<00:00, 11.91 images/s, validating...]
# Accuracy: 0.825
# ==========================
# Testing int8 version...
# 100%|██████████| 5000/5000 [02:17<00:00, 36.47 images/s, build_image_embedding...]
# 100%|██████████| 5000/5000 [16:27<00:00, 5.06 images/s, validating...]
# Accuracy: 0.8304
def load_images() -> List[str]:
    """Return the names of the image files in the ``images`` directory.

    Subdirectories and hidden entries (names starting with ``.``, such as
    ``.DS_Store``) are skipped, since they cannot be opened as images. The
    names are sorted so the indices stored by
    build_image_embedding_from_disk are reproducible between runs
    (``os.listdir`` returns entries in arbitrary order).
    """
    return sorted(
        name
        for name in os.listdir("images")
        if not name.startswith(".") and os.path.isfile(os.path.join("images", name))
    )
def build_image_embedding_from_disk(images: List[str]):
    """Encode the listed files from ``images/`` into ``image_embeddings``.

    Each file is opened with PIL and run through CLIP ``preprocess`` and the
    global ``image_encoder`` ONNX session. It is appended as
    ``(i, features)``, where ``i`` is the file's position in ``images``, so a
    search hit maps back to ``images[i]``. The list is not cleared first.

    :param images: file names relative to the ``images`` directory.
    """
    print("images on disk:")
    input_name = image_encoder.get_inputs()[0].name  # loop-invariant
    progress = tqdm(images, unit=" images", postfix="build_image_embedding...")
    for i, file_name in enumerate(progress):
        # Pillow only auto-closes single-frame files once pixel data is loaded.
        # The context manager releases the handle in every case (multi-frame
        # files, or preprocess raising).
        with Image.open(os.path.join("images", file_name)) as image:
            image_input = preprocess(image).unsqueeze(0)
        image_features = image_encoder.run(None, {input_name: to_numpy(image_input)})[0]
        image_embeddings.append((i, Tensor(numpy.array(image_features))))
if __name__ == '__main__':
    # Interactive text-to-image search using the int8-quantized encoders.
    # The two sessions are created image-first, then text.
    image_encoder, text_encoder = (
        ort.InferenceSession(vit_model_int8),
        ort.InferenceSession(text_model_int8),
    )
    image_search_test()
    # Alternative run: CIFAR-100 retrieval accuracy per precision.
    # LIMIT = 100
    # print("Testing int8 version...")
    # benchmark()
    # print("==========================")
    # print("Testing fp32 version...")
    # # fp32(original), CIFAR ~2000, acc: 0.87
    # # fp16, CIFAR ~500, acc: 0.814
    # image_encoder = ort.InferenceSession(vit_model_fp32)
    # text_encoder = ort.InferenceSession(text_model_fp32)
    # benchmark()