calculating committed on
Commit
4f90f1b
1 Parent(s): 896572e

committing...

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. sample.wav +0 -0
  3. utils/dist.py +11 -12
app.py CHANGED
@@ -14,7 +14,7 @@ import os
14
  # Global variables for model and tokenizer
15
  global_generator = None
16
  global_tokenizer = None
17
- default_audio_path = "testingtesting.wav" # Your default audio file
18
 
19
  def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
20
  """Initialize the model and tokenizer"""
 
14
  # Global variables for model and tokenizer
15
  global_generator = None
16
  global_tokenizer = None
17
+ default_audio_path = "sample.wav" # Changed from "testingtesting.wav"
18
 
19
  def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
20
  """Initialize the model and tokenizer"""
sample.wav ADDED
Binary file (786 kB). View file
 
utils/dist.py CHANGED
@@ -8,6 +8,7 @@ import requests
8
  import hashlib
9
 
10
  from io import BytesIO
 
11
 
12
  def rank0():
13
  rank = os.environ.get('RANK')
@@ -75,17 +76,12 @@ def init_dist():
75
  return rank, local_rank, world_size
76
 
77
  def load_ckpt(load_from_location, expected_hash=None):
 
78
  if local0():
79
- os.makedirs('ckpt', exist_ok=True)
80
- url = f"https://ckpt.si.inc/hertz-dev/{load_from_location}.pt"
81
- save_path = f"ckpt/{load_from_location}.pt"
82
- if not os.path.exists(save_path):
83
- response = requests.get(url, stream=True)
84
- total_size = int(response.headers.get('content-length', 0))
85
- with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
86
- for chunk in response.iter_content(chunk_size=8192):
87
- f.write(chunk)
88
- pbar.update(len(chunk))
89
  if expected_hash is not None:
90
  with open(save_path, 'rb') as f:
91
  file_hash = hashlib.md5(f.read()).hexdigest()
@@ -94,6 +90,9 @@ def load_ckpt(load_from_location, expected_hash=None):
94
  os.remove(save_path)
95
  return load_ckpt(load_from_location, expected_hash)
96
  if T.distributed.is_initialized():
97
- T.distributed.barrier() # so that ranks don't try to load checkpoint before it's finished downloading
98
- loaded = T.load(f"ckpt/{load_from_location}.pt", weights_only=False, map_location='cpu')
 
 
 
99
  return loaded
 
8
  import hashlib
9
 
10
  from io import BytesIO
11
+ from huggingface_hub import hf_hub_download
12
 
13
  def rank0():
14
  rank = os.environ.get('RANK')
 
76
  return rank, local_rank, world_size
77
 
78
  def load_ckpt(load_from_location, expected_hash=None):
79
+ os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1' # Disable this to make errors when downloading from the hub quicker to debug
80
  if local0():
81
+ repo_id = "si-pbc/hertz-dev"
82
+ print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
83
+ save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt")
84
+ print0(f'Downloaded checkpoint to {save_path}')
 
 
 
 
 
 
85
  if expected_hash is not None:
86
  with open(save_path, 'rb') as f:
87
  file_hash = hashlib.md5(f.read()).hexdigest()
 
90
  os.remove(save_path)
91
  return load_ckpt(load_from_location, expected_hash)
92
  if T.distributed.is_initialized():
93
+ save_path = [save_path]
94
+ T.distributed.broadcast_object_list(save_path, src=0)
95
+ save_path = save_path[0]
96
+ loaded = T.load(save_path, weights_only=False, map_location='cpu')
97
+ print0(f'Loaded checkpoint from {save_path}')
98
  return loaded