Push model using huggingface_hub.
fuse_clip_hub.py
ADDED
@@ -0,0 +1,80 @@
+# fuse_clip_hub.py (keep the rest of your code unchanged)
+import inspect
+import json
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+from huggingface_hub import PyTorchModelHubMixin
+
+from fuse_clip.fuse_clip_arch import FuseCLIP
+from open_clip import get_input_dtype, SimpleTokenizer
+
+
+class FuseLIP(FuseCLIP, PyTorchModelHubMixin):
+    """FuseLIP with save_pretrained / from_pretrained / push_to_hub."""
+
+    # ---------- save ----------
+    def _save_pretrained(self, save_directory: Path, **kwargs):
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
+        (save_directory / "config.json").write_text(
+            json.dumps(self.get_config(), indent=2)
+        )
+        # copy TiTok VQ-VAE weights so offline loading works
+        # shutil.copy(
+        #     self.image_tokenizer.tokenizer_path,
+        #     save_directory / "titok_image_tokenizer.pt"
+        # )
+
+        # publish fuse_clip_hub.py alongside the weights
+        source_path = Path(inspect.getfile(FuseLIP))  # absolute path of this file
+        shutil.copy(source_path, save_directory / "fuse_clip_hub.py")
+
+    # ---------- load ----------
+    @classmethod
+    def _from_pretrained(cls, save_directory: Path, **kwargs):
+
+        cfg = json.loads(Path(save_directory, "config.json").read_text())
+
+        tokenizer = SimpleTokenizer(context_length=cfg["context_length"])
+        tokenizer.pad_token_id = 0
+
+        if cfg["mlm_probability"] > 0:
+            MASK_TOKEN = "[MASK]"
+            if MASK_TOKEN not in tokenizer.encoder:
+                # Assign a new token ID
+                mask_token_id = max(tokenizer.encoder.values()) + 1  # Get a new unique ID
+
+                # Add to the tokenizer's vocabulary
+                tokenizer.encoder[MASK_TOKEN] = mask_token_id
+                tokenizer.decoder[mask_token_id] = MASK_TOKEN
+
+                tokenizer.all_special_ids.append(mask_token_id)
+                tokenizer.mask_token = mask_token_id
+                tokenizer.vocab_size += 1
+
+                print(f"Added `[MASK]` token with ID {mask_token_id}")
+            else:
+                mask_token_id = tokenizer.encoder[MASK_TOKEN]
+                print(f"`[MASK]` token already exists with ID {mask_token_id}")
+
+
+        cfg["image_tokenizer_path"] = cfg["image_tokenizer"]
+        cfg["init_logit_scale"] = np.log(10)
+        cfg["init_logit_bias"] = -10
+        cfg["input_dtype"] = get_input_dtype("fp32")
+        del cfg["text_config"]
+        del cfg["image_tokenizer"]
+        del cfg["context_length"]
+
+        model = cls(**cfg, **kwargs)  # device / dtype can be injected via kwargs
+        state = torch.load(
+            Path(save_directory, "pytorch_model.bin"),
+            map_location="cpu"
+        )
+        model.load_state_dict(state, strict=True)
+        return model
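
For context, PyTorchModelHubMixin wires these private hooks into the public save_pretrained / from_pretrained / push_to_hub methods. A minimal usage sketch follows; the checkpoint path and repo id are placeholders, model is assumed to be an already-constructed FuseLIP instance, and note that recent huggingface_hub versions pass the checkpoint location to _from_pretrained as the model_id keyword argument, so the save_directory parameter name may need to match.

from fuse_clip_hub import FuseLIP

# Save weights, config.json, and fuse_clip_hub.py locally, then push to the Hub.
model.save_pretrained("./fuselip-checkpoint")   # calls _save_pretrained above
model.push_to_hub("your-username/fuselip")      # placeholder repo id

# Reload from the local checkpoint: _from_pretrained re-reads config.json,
# rebuilds the tokenizer, and loads the state dict.
reloaded = FuseLIP.from_pretrained("./fuselip-checkpoint")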