cyrilvallez HF Staff commited on
Commit
843addb
·
verified ·
1 Parent(s): d90cde9

More robust with tied weight keys

Browse files
Files changed (1) hide show
  1. convert.py +8 -4
convert.py CHANGED
@@ -241,7 +241,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
241
 
242
 
243
  def convert_generic(
244
- model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
245
  ) -> ConversionResult:
246
  operations = []
247
  errors = []
@@ -262,7 +262,7 @@ def convert_generic(
262
  sf_in_repo = f"{prefix}.safetensors"
263
  sf_filename = os.path.join(folder, sf_in_repo)
264
  try:
265
- convert_file(pt_filename, sf_filename, discard_names=[])
266
  operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
267
  except Exception as e:
268
  errors.append((pt_filename, e))
@@ -275,6 +275,10 @@ def convert(
275
  pr_title = "Adding `safetensors` variant of this model"
276
  info = api.model_info(model_id, revision=revision)
277
  filenames = set(s.rfilename for s in info.siblings)
 
 
 
 
278
 
279
  with TemporaryDirectory() as d:
280
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
@@ -283,6 +287,7 @@ def convert(
283
  try:
284
  operations = None
285
  pr = previous_pr(api, model_id, pr_title, revision=revision)
 
286
 
287
  library_name = getattr(info, "library_name", None)
288
  if any(filename.endswith(".safetensors") for filename in filenames) and not force:
@@ -293,7 +298,6 @@ def convert(
293
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
294
  elif library_name == "transformers":
295
 
296
- discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
297
  if "pytorch_model.bin" in filenames:
298
  operations, errors = convert_single(
299
  model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
@@ -306,7 +310,7 @@ def convert(
306
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
307
  else:
308
  operations, errors = convert_generic(
309
- model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
310
  )
311
 
312
  if operations:
 
241
 
242
 
243
  def convert_generic(
244
+ model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str], discard_names: List[str],
245
  ) -> ConversionResult:
246
  operations = []
247
  errors = []
 
262
  sf_in_repo = f"{prefix}.safetensors"
263
  sf_filename = os.path.join(folder, sf_in_repo)
264
  try:
265
+ convert_file(pt_filename, sf_filename, discard_names=discard_names)
266
  operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
267
  except Exception as e:
268
  errors.append((pt_filename, e))
 
275
  pr_title = "Adding `safetensors` variant of this model"
276
  info = api.model_info(model_id, revision=revision)
277
  filenames = set(s.rfilename for s in info.siblings)
278
+ try:
279
+ discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
280
+ except:
281
+ discard_names = []
282
 
283
  with TemporaryDirectory() as d:
284
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
 
287
  try:
288
  operations = None
289
  pr = previous_pr(api, model_id, pr_title, revision=revision)
290
+
291
 
292
  library_name = getattr(info, "library_name", None)
293
  if any(filename.endswith(".safetensors") for filename in filenames) and not force:
 
298
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
299
  elif library_name == "transformers":
300
 
 
301
  if "pytorch_model.bin" in filenames:
302
  operations, errors = convert_single(
303
  model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
 
310
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
311
  else:
312
  operations, errors = convert_generic(
313
+ model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, discard_names=discard_names
314
  )
315
 
316
  if operations: