calculating
commited on
Commit
•
4f90f1b
1
Parent(s):
896572e
committing...
Browse files- app.py +1 -1
- sample.wav +0 -0
- utils/dist.py +11 -12
app.py
CHANGED
@@ -14,7 +14,7 @@ import os
|
|
14 |
# Global variables for model and tokenizer
|
15 |
global_generator = None
|
16 |
global_tokenizer = None
|
17 |
-
default_audio_path = "
|
18 |
|
19 |
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
|
20 |
"""Initialize the model and tokenizer"""
|
|
|
14 |
# Global variables for model and tokenizer
|
15 |
global_generator = None
|
16 |
global_tokenizer = None
|
17 |
+
default_audio_path = "sample.wav" # Changed from "testingtesting.wav"
|
18 |
|
19 |
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
|
20 |
"""Initialize the model and tokenizer"""
|
sample.wav
ADDED
Binary file (786 kB). View file
|
|
utils/dist.py
CHANGED
@@ -8,6 +8,7 @@ import requests
|
|
8 |
import hashlib
|
9 |
|
10 |
from io import BytesIO
|
|
|
11 |
|
12 |
def rank0():
|
13 |
rank = os.environ.get('RANK')
|
@@ -75,17 +76,12 @@ def init_dist():
|
|
75 |
return rank, local_rank, world_size
|
76 |
|
77 |
def load_ckpt(load_from_location, expected_hash=None):
|
|
|
78 |
if local0():
|
79 |
-
|
80 |
-
|
81 |
-
save_path = f"
|
82 |
-
|
83 |
-
response = requests.get(url, stream=True)
|
84 |
-
total_size = int(response.headers.get('content-length', 0))
|
85 |
-
with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
|
86 |
-
for chunk in response.iter_content(chunk_size=8192):
|
87 |
-
f.write(chunk)
|
88 |
-
pbar.update(len(chunk))
|
89 |
if expected_hash is not None:
|
90 |
with open(save_path, 'rb') as f:
|
91 |
file_hash = hashlib.md5(f.read()).hexdigest()
|
@@ -94,6 +90,9 @@ def load_ckpt(load_from_location, expected_hash=None):
|
|
94 |
os.remove(save_path)
|
95 |
return load_ckpt(load_from_location, expected_hash)
|
96 |
if T.distributed.is_initialized():
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
99 |
return loaded
|
|
|
8 |
import hashlib
|
9 |
|
10 |
from io import BytesIO
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
def rank0():
|
14 |
rank = os.environ.get('RANK')
|
|
|
76 |
return rank, local_rank, world_size
|
77 |
|
78 |
def load_ckpt(load_from_location, expected_hash=None):
|
79 |
+
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1' #Disable this to speed up debugging errors with downloading from the hub
|
80 |
if local0():
|
81 |
+
repo_id = "si-pbc/hertz-dev"
|
82 |
+
print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
|
83 |
+
save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt")
|
84 |
+
print0(f'Downloaded checkpoint to {save_path}')
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
if expected_hash is not None:
|
86 |
with open(save_path, 'rb') as f:
|
87 |
file_hash = hashlib.md5(f.read()).hexdigest()
|
|
|
90 |
os.remove(save_path)
|
91 |
return load_ckpt(load_from_location, expected_hash)
|
92 |
if T.distributed.is_initialized():
|
93 |
+
save_path = [save_path]
|
94 |
+
T.distributed.broadcast_object_list(save_path, src=0)
|
95 |
+
save_path = save_path[0]
|
96 |
+
loaded = T.load(save_path, weights_only=False, map_location='cpu')
|
97 |
+
print0(f'Loaded checkpoint from {save_path}')
|
98 |
return loaded
|