chs20 committed on
Commit 5e8482a · verified · 1 Parent(s): 90b95b4

Push model using huggingface_hub.

Files changed (1)
  1. fuse_clip_hub.py +80 -0
fuse_clip_hub.py ADDED
@@ -0,0 +1,80 @@
+ # fuse_clip_hub.py (keep the rest of your code unchanged)
+ import inspect
+ import json
+ import shutil
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+ from fuse_clip.fuse_clip_arch import FuseCLIP
+ from open_clip import get_input_dtype, SimpleTokenizer
+
+
+ class FuseLIP(FuseCLIP, PyTorchModelHubMixin):
+     """FuseLIP with save_pretrained / from_pretrained / push_to_hub."""
+
+     # ---------- save ----------
+     def _save_pretrained(self, save_directory: Path, **kwargs):
+         save_directory = Path(save_directory)
+         save_directory.mkdir(parents=True, exist_ok=True)
+
+         torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
+         (save_directory / "config.json").write_text(
+             json.dumps(self.get_config(), indent=2)
+         )
+         # optionally copy the TiTok VQ-VAE weights so offline loading works
+         # shutil.copy(
+         #     self.image_tokenizer.tokenizer_path,
+         #     save_directory / "titok_image_tokenizer.pt"
+         # )
+
+         # publish fuse_clip_hub.py alongside the weights
+         source_path = Path(inspect.getfile(FuseLIP))  # absolute path of this file
+         shutil.copy(source_path, save_directory / "fuse_clip_hub.py")
+
+     # ---------- load ----------
+     @classmethod
+     def _from_pretrained(
+         cls,
+         *,
+         model_id: str,
+         revision=None,
+         cache_dir=None,
+         force_download=False,
+         proxies=None,
+         resume_download=None,
+         local_files_only=False,
+         token=None,
+         map_location="cpu",
+         **kwargs,
+     ):
+         # PyTorchModelHubMixin calls this hook with keyword arguments only;
+         # `model_id` is either a local directory or a repo id on the Hub.
+         if Path(model_id).is_dir():
+             config_file = Path(model_id, "config.json")
+             weights_file = Path(model_id, "pytorch_model.bin")
+         else:
+             hub_kwargs = dict(
+                 repo_id=model_id, revision=revision, cache_dir=cache_dir,
+                 force_download=force_download, proxies=proxies,
+                 resume_download=resume_download,
+                 local_files_only=local_files_only, token=token,
+             )
+             config_file = hf_hub_download(filename="config.json", **hub_kwargs)
+             weights_file = hf_hub_download(filename="pytorch_model.bin", **hub_kwargs)
+
+         cfg = json.loads(Path(config_file).read_text())
+
+         tokenizer = SimpleTokenizer(context_length=cfg["context_length"])
+         tokenizer.pad_token_id = 0
+
+         if cfg["mlm_probability"] > 0:
+             MASK_TOKEN = "[MASK]"
+             if MASK_TOKEN not in tokenizer.encoder:
+                 # assign a new, unique token ID and add it to the vocabulary
+                 mask_token_id = max(tokenizer.encoder.values()) + 1
+                 tokenizer.encoder[MASK_TOKEN] = mask_token_id
+                 tokenizer.decoder[mask_token_id] = MASK_TOKEN
+
+                 tokenizer.all_special_ids.append(mask_token_id)
+                 tokenizer.mask_token = mask_token_id
+                 tokenizer.vocab_size += 1
+
+                 print(f"Added `[MASK]` token with ID {mask_token_id}")
+             else:
+                 mask_token_id = tokenizer.encoder[MASK_TOKEN]
+                 print(f"`[MASK]` token already exists with ID {mask_token_id}")
+
+         # map the stored config onto FuseCLIP constructor arguments
+         cfg["image_tokenizer_path"] = cfg["image_tokenizer"]
+         cfg["init_logit_scale"] = np.log(10)
+         cfg["init_logit_bias"] = -10
+         cfg["input_dtype"] = get_input_dtype("fp32")
+         del cfg["text_config"]
+         del cfg["image_tokenizer"]
+         del cfg["context_length"]
+
+         model = cls(**cfg, **kwargs)  # device / dtype can be injected via kwargs
+         state = torch.load(weights_file, map_location=map_location)
+         model.load_state_dict(state, strict=True)
+         return model
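
For context, a minimal usage sketch of the round trip this file enables. It assumes an already-constructed FuseLIP instance (the constructor arguments come from fuse_clip.fuse_clip_arch and are omitted here), that the exported config round-trips cleanly through get_config(), and the repo id below is a placeholder:

    from fuse_clip_hub import FuseLIP

    # `model` is an existing FuseLIP instance built elsewhere in the training code
    model.save_pretrained("./fuselip_export")        # writes pytorch_model.bin, config.json, fuse_clip_hub.py
    model.push_to_hub("your-username/fuselip-demo")  # placeholder repo id; needs a token, e.g. via `huggingface-cli login`

    # later, restore the checkpoint from the local export (or from the Hub repo id)
    restored = FuseLIP.from_pretrained("./fuselip_export")
    restored.eval()

push_to_hub() reuses _save_pretrained() under the hood, so the same three files end up in the Hub repository.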