27 lines
876 B
Python
27 lines
876 B
Python
import torch
|
|
|
|
from encoder.export_text_encoder import TextEncoder
|
|
|
|
# Export ImageEncoder of the CLIP to onnx model
|
|
if __name__ == '__main__':
|
|
import clip
|
|
|
|
device = "cpu"
|
|
model, preprocess = clip.load("ViT-B/32", device=device)
|
|
model.eval()
|
|
|
|
text_encoder = TextEncoder(embed_dim=512, context_length=77, vocab_size=49408,
|
|
transformer_width=512, transformer_heads=8, transformer_layers=12)
|
|
|
|
missing_keys, unexpected_keys = text_encoder.load_state_dict(model.state_dict(), strict=False)
|
|
|
|
text_encoder.eval()
|
|
|
|
input_tensor = clip.tokenize("a diagram").to(device)
|
|
traced_model = torch.jit.trace(text_encoder, input_tensor)
|
|
|
|
onnx_filename = 'clip-text-encoder.onnx'
|
|
|
|
torch.onnx.export(text_encoder, input_tensor, onnx_filename)
|
|
# python -m onnxsim clip-text-encoder.onnx clip-text-encoder-optimized.onnx
|