Spaces:
Running
Running
More robust with tied weight keys
Browse files- convert.py +8 -4
convert.py
CHANGED
@@ -241,7 +241,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
|
|
241 |
|
242 |
|
243 |
def convert_generic(
|
244 |
-
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
|
245 |
) -> ConversionResult:
|
246 |
operations = []
|
247 |
errors = []
|
@@ -262,7 +262,7 @@ def convert_generic(
|
|
262 |
sf_in_repo = f"{prefix}.safetensors"
|
263 |
sf_filename = os.path.join(folder, sf_in_repo)
|
264 |
try:
|
265 |
-
convert_file(pt_filename, sf_filename, discard_names=
|
266 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
267 |
except Exception as e:
|
268 |
errors.append((pt_filename, e))
|
@@ -275,6 +275,10 @@ def convert(
|
|
275 |
pr_title = "Adding `safetensors` variant of this model"
|
276 |
info = api.model_info(model_id, revision=revision)
|
277 |
filenames = set(s.rfilename for s in info.siblings)
|
|
|
|
|
|
|
|
|
278 |
|
279 |
with TemporaryDirectory() as d:
|
280 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
@@ -283,6 +287,7 @@ def convert(
|
|
283 |
try:
|
284 |
operations = None
|
285 |
pr = previous_pr(api, model_id, pr_title, revision=revision)
|
|
|
286 |
|
287 |
library_name = getattr(info, "library_name", None)
|
288 |
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
|
@@ -293,7 +298,6 @@ def convert(
|
|
293 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
294 |
elif library_name == "transformers":
|
295 |
|
296 |
-
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
|
297 |
if "pytorch_model.bin" in filenames:
|
298 |
operations, errors = convert_single(
|
299 |
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
|
@@ -306,7 +310,7 @@ def convert(
|
|
306 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
307 |
else:
|
308 |
operations, errors = convert_generic(
|
309 |
-
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
|
310 |
)
|
311 |
|
312 |
if operations:
|
|
|
241 |
|
242 |
|
243 |
def convert_generic(
|
244 |
+
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str], discard_names: List[str],
|
245 |
) -> ConversionResult:
|
246 |
operations = []
|
247 |
errors = []
|
|
|
262 |
sf_in_repo = f"{prefix}.safetensors"
|
263 |
sf_filename = os.path.join(folder, sf_in_repo)
|
264 |
try:
|
265 |
+
convert_file(pt_filename, sf_filename, discard_names=discard_names)
|
266 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
267 |
except Exception as e:
|
268 |
errors.append((pt_filename, e))
|
|
|
275 |
pr_title = "Adding `safetensors` variant of this model"
|
276 |
info = api.model_info(model_id, revision=revision)
|
277 |
filenames = set(s.rfilename for s in info.siblings)
|
278 |
+
try:
|
279 |
+
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
|
280 |
+
except:
|
281 |
+
discard_names = []
|
282 |
|
283 |
with TemporaryDirectory() as d:
|
284 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
|
|
287 |
try:
|
288 |
operations = None
|
289 |
pr = previous_pr(api, model_id, pr_title, revision=revision)
|
290 |
+
|
291 |
|
292 |
library_name = getattr(info, "library_name", None)
|
293 |
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
|
|
|
298 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
299 |
elif library_name == "transformers":
|
300 |
|
|
|
301 |
if "pytorch_model.bin" in filenames:
|
302 |
operations, errors = convert_single(
|
303 |
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
|
|
|
310 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
311 |
else:
|
312 |
operations, errors = convert_generic(
|
313 |
+
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, discard_names=discard_names
|
314 |
)
|
315 |
|
316 |
if operations:
|