---
license: mit
language:
- en
pipeline_tag: zero-shot-image-classification
tags:
- vision
- simple
- small
---
# tinyvvision 🧠✨
**tinyvvision** is a compact, synthetic curriculum-trained vision-language model designed to demonstrate real zero-shot capability in a minimal setup. Despite its small size (~630k parameters), it aligns images and captions effectively by learning shared visual-language embeddings.
## What tinyvvision can do:
- Match simple geometric shapes (circles, stars, hearts, triangles, etc.) and descriptive captions (e.g., "a red circle", "a yellow star").
- Perform genuine zero-shot generalization, meaning it can correctly match captions to shapes and colors it has never explicitly encountered during training.
## Model Details:
- **Type**: Contrastive embedding (CLIP-style, zero-shot); see the objective sketch after this list
- **Parameters**: ~630,000 (tiny!)
- **Training data**: Fully synthetic—randomly generated shapes, letters, numbers, and symbols paired with descriptive text captions.
- **Architecture**:
  - **Image Encoder**: Simple CNN
  - **Text Encoder**: Small embedding layer + bidirectional GRU
  - **Embedding Dim**: 128-dimensional shared embedding space
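
The training code is not included in this card, so the block below is only a minimal sketch of the CLIP-style contrastive objective referenced above. The function name and the temperature value are illustrative assumptions, not values taken from the actual training run.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(img_emb, txt_emb, temperature=0.07):
    # Normalize both embedding batches so dot products become cosine similarities.
    img_emb = F.normalize(img_emb, dim=-1)
    txt_emb = F.normalize(txt_emb, dim=-1)
    # (batch, batch) matrix: every image scored against every caption.
    logits = img_emb @ txt_emb.t() / temperature
    targets = torch.arange(img_emb.size(0), device=img_emb.device)
    # Matching image/caption pairs sit on the diagonal; penalize both directions.
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
```

During training, `img_emb` and `txt_emb` would be the outputs of the image and text encoders for a batch of matching image/caption pairs.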
## Examples of Zero-Shot Matching:
- **Seen during training**: "a red circle" → correctly matches the drawn red circle.
- **Never seen**: "a teal lightning bolt" → correctly matches a hand-drawn lightning bolt, a shape and color it never encountered during training.
## Limitations:
- tinyvvision is designed as a demonstration of zero-shot embedding and generalization on synthetic data. It is not trained on real-world data or complex scenarios. While robust within its domain (simple geometric shapes and clear captions), results may vary significantly on more complicated or out-of-domain inputs.
## How to Test tinyvvision:
Check out the provided inference script to easily test your own shapes and captions. Feel free to challenge tinyvvision with new, unseen combinations to explore its generalization capability!
```python
from huggingface_hub import hf_hub_download
import torch, re, numpy as np, math
from PIL import Image, ImageDraw, ImageFont
repo = "ProCreations/tinyvvision"
pth = hf_hub_download(repo, "cortexclip-mini.pth")  # download the checkpoint from the Hub
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
state = torch.load(pth, map_location=device)
idx2tok = state["vocab"]  # token list saved alongside the weights
tok2idx = {t: i for i, t in enumerate(idx2tok)}
def encode_txt(s, maxlen=16):
    # Lower-case, split into word/punctuation tokens, map to ids, pad to maxlen.
    toks = re.findall(r"\w+|[^\w\s]", s.lower())
    ids = [tok2idx.get(t, 0) for t in toks][:maxlen]
    return ids + [0] * (maxlen - len(ids))

class TE(torch.nn.Module):
    # Text encoder: embedding + 2-layer bidirectional GRU, projected to the shared 128-d space.
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(len(idx2tok), 64)
        self.gru = torch.nn.GRU(64, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.out_proj = torch.nn.Linear(256, 128)
    def forward(self, x):
        e, _ = self.gru(self.emb(x))
        return self.out_proj(e[:, -1])

class IE(torch.nn.Module):
    # Image encoder: small CNN pooled to 4x4, projected to the shared 128-d space.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, 1, 2), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1, 1), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, 1, 1), torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((4, 4)), torch.nn.Flatten(),
            torch.nn.Linear(128 * 4 * 4, 128), torch.nn.ReLU()
        )
    def forward(self, x):
        return self.conv(x)
te, ie = TE().to(device), IE().to(device)
te.load_state_dict(state["text_encoder"])
ie.load_state_dict(state["image_encoder"])
te.eval(); ie.eval()
# ----- CUSTOMIZE YOUR EXAMPLES HERE -----
# To try your own image:
# 1. Replace the 'custom_image()' function with your image drawing/loading code.
# 2. Replace 'custom_caption' with your own caption for the image.
def custom_image():
    # Example: Draw your own "blue hexagon" shape below!
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.regular_polygon((32, 32, 22), n_sides=6, fill="blue")
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
custom_caption = "a blue hexagon"
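# Optional (not part of the original script): load a real image file from disk instead of
# drawing one. The function name and path handling are illustrative; any 64x64 RGB tensor works.
def load_image_file(path):
    img = Image.open(path).convert("RGB").resize((64, 64))
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)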
# ----- FUN DEMO EXAMPLES -----
def draw_red_heart():
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.polygon([(32, 18), (50, 34), (32, 56), (14, 34)], fill="red")  # simple heart
    dr.ellipse((18, 12, 32, 32), fill="red")
    dr.ellipse((32, 12, 46, 32), fill="red")
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)

def draw_purple_star():
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    points = [(32 + 20 * math.cos(math.radians(a)), 32 + 20 * math.sin(math.radians(a))) for a in range(-90, 270, 72)]
    for i in range(5):
        dr.line([points[i], points[(i + 2) % 5]], fill="purple", width=7)
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)

def draw_orange_pentagon():
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.regular_polygon((32, 32, 22), n_sides=5, fill="orange")
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
demo_imgs = [
    (custom_image(), custom_caption),
    (draw_red_heart(), "a red heart"),
    (draw_purple_star(), "a purple star"),
    (draw_orange_pentagon(), "an orange pentagon"),
]
captions = [c for (_,c) in demo_imgs]
img_tensors = [im for (im,_) in demo_imgs]
cap_ids = torch.tensor([encode_txt(c) for c in captions], device=device)
with torch.no_grad():
    txt_emb = te(cap_ids)
    for i, (img, caption) in enumerate(zip(img_tensors, captions)):
        im_emb = ie(img)
        # Cosine similarity between this image and every candidate caption.
        sim = torch.nn.functional.cosine_similarity(im_emb, txt_emb).cpu().numpy()
        best = int(np.argmax(sim))
        print(f"Input image {i+1}: '{caption}'")
        print("  Similarity scores:")
        for j, c in enumerate(captions):
            print(f"    {c}: {sim[j]:.4f}")
        print("  Best match:", captions[best], "\n")
```
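If you want to reuse the script as a zero-shot classifier, a minimal wrapper could look like the sketch below. It assumes `te`, `ie`, `encode_txt`, `device`, and `custom_image` are already defined as in the script above; the `classify` name and the 0.07 softmax temperature are illustrative choices, not part of the released checkpoint.

```python
# Minimal zero-shot classification helper built on the script above
# (reuses te, ie, encode_txt, device, and custom_image from that script).
def classify(image_tensor, candidate_captions):
    with torch.no_grad():
        img_emb = ie(image_tensor)
        txt_emb = te(torch.tensor([encode_txt(c) for c in candidate_captions], device=device))
        sims = torch.nn.functional.cosine_similarity(img_emb, txt_emb)
        probs = torch.softmax(sims / 0.07, dim=0)  # temperature is an illustrative choice
    # Return captions ranked from most to least likely.
    return sorted(zip(candidate_captions, probs.tolist()), key=lambda p: -p[1])

print(classify(custom_image(), ["a blue hexagon", "a red heart", "a purple star"]))
```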
✨ **Enjoy experimenting!** ✨