Spaces:
Build error
Build error
ZhifengKong
commited on
Commit
·
92740f3
1
Parent(s):
15f9587
upload
Browse files- LICENSE +21 -0
- LICENSE_OPT_IML.md +65 -0
- app.py +223 -0
- audio/wav1.wav +0 -0
- audio/wav2.wav +0 -0
- audio/wav3.wav +0 -0
- audio/wav4.wav +0 -0
- audio/wav5.wav +0 -0
- audio/wav6.wav +0 -0
- chat.yaml +23 -0
- data.py +243 -0
- inference_utils.py +81 -0
- ms_clap/.DS_Store +0 -0
- ms_clap/.gitignore +350 -0
- ms_clap/CODE_OF_CONDUCT.md +9 -0
- ms_clap/LICENSE +21 -0
- ms_clap/README.md +120 -0
- ms_clap/SECURITY.md +41 -0
- ms_clap/SUPPORT.md +25 -0
- ms_clap/requirements.txt +50 -0
- ms_clap/src/.DS_Store +0 -0
- ms_clap/src/CLAPWrapper.py +458 -0
- ms_clap/src/__init__.py +0 -0
- ms_clap/src/audio_captioning.py +25 -0
- ms_clap/src/configs/config_2022.yml +26 -0
- ms_clap/src/configs/config_2023.yml +26 -0
- ms_clap/src/configs/config_clapcap.yml +34 -0
- ms_clap/src/esc50_dataset.py +82 -0
- ms_clap/src/models/__init__.py +6 -0
- ms_clap/src/models/audio.py +186 -0
- ms_clap/src/models/clap.py +109 -0
- ms_clap/src/models/config.py +128 -0
- ms_clap/src/models/htsat.py +956 -0
- ms_clap/src/models/mapper.py +200 -0
- ms_clap/src/models/pytorch_utils.py +184 -0
- ms_clap/src/models/utils.py +26 -0
- ms_clap/src/zero_shot_classification.py +46 -0
- ms_clap/src/zero_shot_predictions.py +51 -0
- requirements.txt +14 -0
- src/.DS_Store +0 -0
- src/__init__.py +2 -0
- src/factory.py +198 -0
- src/flamingo.py +260 -0
- src/flamingo_lm.py +177 -0
- src/helpers.py +379 -0
- src/utils.py +54 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 NVIDIA CORPORATION.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
LICENSE_OPT_IML.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h2 align="center"> OPT-IML 175B LICENSE AGREEMENT </h2>
|
2 |
+
|
3 |
+
This License Agreement (as may be amended in accordance with this License Agreement, **“License”**), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (**“Licensee”** or **“you”**) and Meta Platforms, Inc. (**“Meta”** or **“we”**) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Meta under this License (**“Software”**) and any specifications, manuals, documentation, and other written information provided by Meta related to the Software (**“Documentation”**).
|
4 |
+
|
5 |
+
**By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Meta that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.**
|
6 |
+
<br><br>
|
7 |
+
1. **LICENSE GRANT**
|
8 |
+
<br><br>
|
9 |
+
a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Meta grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Meta’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Meta’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
|
10 |
+
<br><br>
|
11 |
+
b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
|
12 |
+
<br><br>
|
13 |
+
c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Meta and its licensors reserve all rights not expressly granted by this License.
|
14 |
+
<br><br>
|
15 |
+
2. **RESTRICTIONS**
|
16 |
+
<br><br>
|
17 |
+
You will not, and will not permit, assist or cause any third party to:
|
18 |
+
<br><br>
|
19 |
+
a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law, including accessing the Software Products from an embargoed country as prohibited by the U.S. government, and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
|
20 |
+
<br><br>
|
21 |
+
b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
|
22 |
+
<br><br>
|
23 |
+
c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Meta in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Meta; or
|
24 |
+
<br><br>
|
25 |
+
d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
|
26 |
+
<br><br>
|
27 |
+
3. **ATTRIBUTION**
|
28 |
+
<br><br>
|
29 |
+
Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “OPT-IML 175B is licensed under the OPT-175B license, Copyright (c) Meta Platforms, Inc. All Rights Reserved.”
|
30 |
+
<br><br>
|
31 |
+
4. **DISCLAIMERS**
|
32 |
+
<br><br>
|
33 |
+
THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. META EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. META MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
|
34 |
+
<br><br>
|
35 |
+
5. **LIMITATION OF LIABILITY**
|
36 |
+
<br><br>
|
37 |
+
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL META BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF META HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, **“SOFTWARE MATERIALS”**) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A **“HIGH-RISK USE”**). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
|
38 |
+
<br><br>
|
39 |
+
6. **INDEMNIFICATION**
|
40 |
+
<br><br>
|
41 |
+
You will indemnify, defend and hold harmless Meta and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the **“Meta Parties”**) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Meta Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, **“Claims”**) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Meta Parties of any such Claims, and cooperate with Meta Parties in defending such Claims. You will also grant the Meta Parties sole control of the defense or settlement, at Meta’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Meta or the other Meta Parties.
|
42 |
+
<br><br>
|
43 |
+
7. **TERMINATION; SURVIVAL**
|
44 |
+
<br><br>
|
45 |
+
a. This License will automatically terminate upon any breach by you of the terms of this License.
|
46 |
+
<br><br>
|
47 |
+
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
|
48 |
+
<br><br>
|
49 |
+
c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
|
50 |
+
<br><br>
|
51 |
+
8. **THIRD PARTY MATERIALS**
|
52 |
+
<br><br>
|
53 |
+
The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, **“Third Party Materials”**), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Meta does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
|
54 |
+
<br><br>
|
55 |
+
9. **TRADEMARKS**
|
56 |
+
<br><br>
|
57 |
+
Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Meta without the prior written permission of Meta, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
|
58 |
+
<br><br>
|
59 |
+
10. **APPLICABLE LAW; DISPUTE RESOLUTION**
|
60 |
+
<br><br>
|
61 |
+
This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
|
62 |
+
<br><br>
|
63 |
+
11. **MISCELLANEOUS**
|
64 |
+
<br><br>
|
65 |
+
If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Meta to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Meta regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Meta regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
app.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
from pydub import AudioSegment
|
11 |
+
import soundfile as sf
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import laion_clap
|
16 |
+
|
17 |
+
from inference_utils import prepare_tokenizer, prepare_model, inference
|
18 |
+
from data import AudioTextDataProcessor
|
19 |
+
|
20 |
+
|
21 |
+
def load_laionclap():
|
22 |
+
model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').cuda()
|
23 |
+
model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
|
24 |
+
model.eval()
|
25 |
+
return model
|
26 |
+
|
27 |
+
|
28 |
+
def int16_to_float32(x):
|
29 |
+
return (x / 32767.0).astype(np.float32)
|
30 |
+
|
31 |
+
|
32 |
+
def float32_to_int16(x):
|
33 |
+
x = np.clip(x, a_min=-1., a_max=1.)
|
34 |
+
return (x * 32767.).astype(np.int16)
|
35 |
+
|
36 |
+
|
37 |
+
def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
|
38 |
+
if file_path.endswith('.mp3'):
|
39 |
+
audio = AudioSegment.from_file(file_path)
|
40 |
+
if len(audio) > (start + duration) * 1000:
|
41 |
+
audio = audio[start * 1000:(start + duration) * 1000]
|
42 |
+
|
43 |
+
if audio.frame_rate != target_sr:
|
44 |
+
audio = audio.set_frame_rate(target_sr)
|
45 |
+
|
46 |
+
if audio.channels > 1:
|
47 |
+
audio = audio.set_channels(1)
|
48 |
+
|
49 |
+
data = np.array(audio.get_array_of_samples())
|
50 |
+
if audio.sample_width == 2:
|
51 |
+
data = data.astype(np.float32) / np.iinfo(np.int16).max
|
52 |
+
elif audio.sample_width == 4:
|
53 |
+
data = data.astype(np.float32) / np.iinfo(np.int32).max
|
54 |
+
else:
|
55 |
+
raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
|
56 |
+
|
57 |
+
else:
|
58 |
+
with sf.SoundFile(file_path) as audio:
|
59 |
+
original_sr = audio.samplerate
|
60 |
+
channels = audio.channels
|
61 |
+
|
62 |
+
max_frames = int((start + duration) * original_sr)
|
63 |
+
|
64 |
+
audio.seek(int(start * original_sr))
|
65 |
+
frames_to_read = min(max_frames, len(audio))
|
66 |
+
data = audio.read(frames_to_read)
|
67 |
+
|
68 |
+
if data.max() > 1 or data.min() < -1:
|
69 |
+
data = data / max(abs(data.max()), abs(data.min()))
|
70 |
+
|
71 |
+
if original_sr != target_sr:
|
72 |
+
if channels == 1:
|
73 |
+
data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
|
74 |
+
else:
|
75 |
+
data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
|
76 |
+
else:
|
77 |
+
if channels != 1:
|
78 |
+
data = data.T[0]
|
79 |
+
|
80 |
+
if data.min() >= 0:
|
81 |
+
data = 2 * data / abs(data.max()) - 1.0
|
82 |
+
else:
|
83 |
+
data = data / max(abs(data.max()), abs(data.min()))
|
84 |
+
return data
|
85 |
+
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
|
89 |
+
try:
|
90 |
+
data = load_audio(audio_file, target_sr=48000)
|
91 |
+
|
92 |
+
except Exception as e:
|
93 |
+
print(audio_file, 'unsuccessful due to', e)
|
94 |
+
return [0.0] * len(outputs)
|
95 |
+
|
96 |
+
audio_data = data.reshape(1, -1)
|
97 |
+
audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda()
|
98 |
+
audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
|
99 |
+
|
100 |
+
text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
|
101 |
+
|
102 |
+
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
|
103 |
+
cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
|
104 |
+
return cos_similarity.squeeze().cpu().numpy()
|
105 |
+
|
106 |
+
|
107 |
+
inference_kwargs = {
|
108 |
+
"do_sample": True,
|
109 |
+
"top_k": 50,
|
110 |
+
"top_p": 0.95,
|
111 |
+
"num_return_sequences": 10
|
112 |
+
}
|
113 |
+
|
114 |
+
config = yaml.load(open('chat.yaml'), Loader=yaml.FullLoader)
|
115 |
+
clap_config = config['clap_config']
|
116 |
+
model_config = config['model_config']
|
117 |
+
|
118 |
+
text_tokenizer = prepare_tokenizer(model_config)
|
119 |
+
DataProcessor = AudioTextDataProcessor(
|
120 |
+
data_root='./',
|
121 |
+
clap_config=clap_config,
|
122 |
+
tokenizer=text_tokenizer,
|
123 |
+
max_tokens=512,
|
124 |
+
)
|
125 |
+
|
126 |
+
laionclap_model = load_laionclap()
|
127 |
+
|
128 |
+
model = prepare_model(
|
129 |
+
model_config=model_config,
|
130 |
+
clap_config=clap_config,
|
131 |
+
checkpoint_path='chat.pt'
|
132 |
+
)
|
133 |
+
|
134 |
+
|
135 |
+
def inference_item(name, prompt):
|
136 |
+
item = {
|
137 |
+
'name': str(name),
|
138 |
+
'prefix': 'The task is dialog.',
|
139 |
+
'prompt': str(prompt)
|
140 |
+
}
|
141 |
+
processed_item = DataProcessor.process(item)
|
142 |
+
|
143 |
+
outputs = inference(
|
144 |
+
model, text_tokenizer, item, processed_item,
|
145 |
+
inference_kwargs,
|
146 |
+
)
|
147 |
+
|
148 |
+
laionclap_scores = compute_laionclap_text_audio_sim(
|
149 |
+
item["name"],
|
150 |
+
laionclap_model,
|
151 |
+
outputs
|
152 |
+
)
|
153 |
+
|
154 |
+
outputs_joint = [(output, score) for (output, score) in zip(outputs, laionclap_scores)]
|
155 |
+
outputs_joint.sort(key=lambda x: -x[1])
|
156 |
+
|
157 |
+
return outputs_joint[0][0]
|
158 |
+
|
159 |
+
|
160 |
+
with gr.Blocks(title="Audio Flamingo - Demo") as ui:
|
161 |
+
|
162 |
+
gr.HTML(
|
163 |
+
"""
|
164 |
+
<div style="text-align: center; max-width: 900px; margin: 0 auto;">
|
165 |
+
<div
|
166 |
+
style="
|
167 |
+
display: inline-flex;
|
168 |
+
align-items: center;
|
169 |
+
gap: 0.8rem;
|
170 |
+
font-size: 1.5rem;
|
171 |
+
"
|
172 |
+
>
|
173 |
+
<h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
|
174 |
+
Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
|
175 |
+
</h1>
|
176 |
+
</div>
|
177 |
+
<p style="margin-bottom: 10px; font-size: 125%">
|
178 |
+
<a href="https://arxiv.org/abs/2402.01831">[Paper]</a> <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a> <a href="https://audioflamingo.github.io/">[Demo]</a>
|
179 |
+
</p>
|
180 |
+
</div>
|
181 |
+
"""
|
182 |
+
)
|
183 |
+
gr.HTML(
|
184 |
+
"""
|
185 |
+
<div>
|
186 |
+
<h3>Model Overview</h3>
|
187 |
+
Audio Flamingo is an audio language model that can understand sounds beyond speech.
|
188 |
+
It can also answer questions about the sound in natural language.
|
189 |
+
Examples of questions include:
|
190 |
+
"Can you briefly describe what you hear in this audio?",
|
191 |
+
"What is the emotion conveyed in this music?",
|
192 |
+
"Where is this audio usually heard?",
|
193 |
+
or "What place is this music usually played at?".
|
194 |
+
</div>
|
195 |
+
"""
|
196 |
+
)
|
197 |
+
|
198 |
+
name = gr.Textbox(
|
199 |
+
label="Audio file path (choose one from: audio/wav{1--6}.wav)",
|
200 |
+
value="audio/wav5.wav"
|
201 |
+
)
|
202 |
+
prompt = gr.Textbox(
|
203 |
+
label="Instruction",
|
204 |
+
value='Can you briefly describe what you hear in this audio?'
|
205 |
+
)
|
206 |
+
|
207 |
+
with gr.Row():
|
208 |
+
play_audio_button = gr.Button("Play Audio")
|
209 |
+
audio_output = gr.Audio(label="Playback")
|
210 |
+
play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)
|
211 |
+
|
212 |
+
inference_button = gr.Button("Inference")
|
213 |
+
|
214 |
+
output_text = gr.Textbox(label="Audio Flamingo output")
|
215 |
+
|
216 |
+
inference_button.click(
|
217 |
+
fn=inference_item,
|
218 |
+
inputs=[name, prompt],
|
219 |
+
outputs=output_text
|
220 |
+
)
|
221 |
+
|
222 |
+
ui.queue()
|
223 |
+
ui.launch()
|
audio/wav1.wav
ADDED
Binary file (960 kB). View file
|
|
audio/wav2.wav
ADDED
Binary file (960 kB). View file
|
|
audio/wav3.wav
ADDED
Binary file (960 kB). View file
|
|
audio/wav4.wav
ADDED
Binary file (441 kB). View file
|
|
audio/wav5.wav
ADDED
Binary file (441 kB). View file
|
|
audio/wav6.wav
ADDED
Binary file (441 kB). View file
|
|
chat.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clap_config:
|
2 |
+
method: microsoft-clap
|
3 |
+
audio_embed_dim: 1024
|
4 |
+
config_root: ./ms_clap/src/configs
|
5 |
+
model_name: 'clapcap'
|
6 |
+
checkpoint: ./clapcap_weights_2023.pth
|
7 |
+
window_length: 7.0
|
8 |
+
window_overlap: 5.25
|
9 |
+
max_num_window: 16
|
10 |
+
max_num_fewshot: 4
|
11 |
+
|
12 |
+
model_config:
|
13 |
+
cache_dir: None
|
14 |
+
lang_encoder_path: facebook/opt-iml-max-1.3b
|
15 |
+
tokenizer_path: facebook/opt-iml-max-1.3b
|
16 |
+
cross_attn_every_n_layers: 1
|
17 |
+
audio_transformer_kwargs: {
|
18 |
+
n_head: 8,
|
19 |
+
n_layers: 3,
|
20 |
+
d_inner: 2048,
|
21 |
+
max_num_media: 128,
|
22 |
+
max_window_per_audio: 16,
|
23 |
+
}
|
data.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import functools
|
5 |
+
import io
|
6 |
+
import json
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable the tokenizer parallelism warning
|
10 |
+
import random
|
11 |
+
import re
|
12 |
+
import string
|
13 |
+
import subprocess
|
14 |
+
import sys
|
15 |
+
import yaml
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
from collections import defaultdict
|
20 |
+
from copy import deepcopy
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from functools import partial
|
23 |
+
from pydub import AudioSegment
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torchvision
|
28 |
+
from torch.utils.data import DataLoader, Dataset, get_worker_info
|
29 |
+
from torch.utils.data.distributed import DistributedSampler
|
30 |
+
|
31 |
+
|
32 |
+
from transformers import AutoTokenizer
|
33 |
+
|
34 |
+
import librosa
|
35 |
+
import soundfile as sf
|
36 |
+
|
37 |
+
|
38 |
+
def int16_to_float32(x):
|
39 |
+
return (x / 32767.0).astype(np.float32)
|
40 |
+
|
41 |
+
|
42 |
+
def float32_to_int16(x):
|
43 |
+
x = np.clip(x, a_min=-1., a_max=1.)
|
44 |
+
return (x * 32767.).astype(np.int16)
|
45 |
+
|
46 |
+
|
47 |
+
class AudioTextDataProcessor:
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
data_root: str,
|
51 |
+
clap_config: dict,
|
52 |
+
tokenizer,
|
53 |
+
max_tokens: int,
|
54 |
+
**kwargs
|
55 |
+
):
|
56 |
+
self.data_root = data_root
|
57 |
+
self.clap_config = clap_config
|
58 |
+
self.tokenizer = tokenizer
|
59 |
+
self.tokenizer.padding_side = "right"
|
60 |
+
self.max_tokens = max_tokens
|
61 |
+
|
62 |
+
def get_num_windows(self, T, sr):
|
63 |
+
clap_config = self.clap_config
|
64 |
+
window_length = int(float(clap_config["window_length"]) * sr)
|
65 |
+
window_overlap = int(float(clap_config["window_overlap"]) * sr)
|
66 |
+
max_num_window = int(clap_config["max_num_window"])
|
67 |
+
|
68 |
+
num_windows = 1
|
69 |
+
if T <= window_length:
|
70 |
+
num_windows = 1
|
71 |
+
full_length = window_length
|
72 |
+
elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
|
73 |
+
num_windows = max_num_window
|
74 |
+
full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
|
75 |
+
else:
|
76 |
+
num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
|
77 |
+
full_length = num_windows * window_length - (num_windows - 1) * window_overlap
|
78 |
+
|
79 |
+
return num_windows, full_length
|
80 |
+
|
81 |
+
def load_audio(self, file_path, target_sr=44100, duration=30.0, start=0.0):
|
82 |
+
if file_path.endswith('.mp3'):
|
83 |
+
audio = AudioSegment.from_file(file_path)
|
84 |
+
if len(audio) > (start + duration) * 1000:
|
85 |
+
audio = audio[start * 1000:(start + duration) * 1000]
|
86 |
+
|
87 |
+
if audio.frame_rate != target_sr:
|
88 |
+
audio = audio.set_frame_rate(target_sr)
|
89 |
+
|
90 |
+
if audio.channels > 1:
|
91 |
+
audio = audio.set_channels(1)
|
92 |
+
|
93 |
+
data = np.array(audio.get_array_of_samples())
|
94 |
+
if audio.sample_width == 2:
|
95 |
+
data = data.astype(np.float32) / np.iinfo(np.int16).max
|
96 |
+
elif audio.sample_width == 4:
|
97 |
+
data = data.astype(np.float32) / np.iinfo(np.int32).max
|
98 |
+
else:
|
99 |
+
raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
|
100 |
+
|
101 |
+
else:
|
102 |
+
with sf.SoundFile(file_path) as audio:
|
103 |
+
original_sr = audio.samplerate
|
104 |
+
channels = audio.channels
|
105 |
+
|
106 |
+
max_frames = int((start + duration) * original_sr)
|
107 |
+
|
108 |
+
audio.seek(int(start * original_sr))
|
109 |
+
frames_to_read = min(max_frames, len(audio))
|
110 |
+
data = audio.read(frames_to_read)
|
111 |
+
|
112 |
+
if data.max() > 1 or data.min() < -1:
|
113 |
+
data = data / max(abs(data.max()), abs(data.min()))
|
114 |
+
|
115 |
+
if original_sr != target_sr:
|
116 |
+
if channels == 1:
|
117 |
+
data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
|
118 |
+
else:
|
119 |
+
data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
|
120 |
+
else:
|
121 |
+
if channels != 1:
|
122 |
+
data = data.T[0]
|
123 |
+
|
124 |
+
if data.min() >= 0:
|
125 |
+
data = 2 * data / abs(data.max()) - 1.0
|
126 |
+
else:
|
127 |
+
data = data / max(abs(data.max()), abs(data.min()))
|
128 |
+
|
129 |
+
assert len(data.shape) == 1, data.shape
|
130 |
+
return data
|
131 |
+
|
132 |
+
def compute_sliding_window(self, audio_file, audio_start=0.0):
|
133 |
+
if type(audio_start) == str:
|
134 |
+
audio_start = float(audio_start)
|
135 |
+
|
136 |
+
clap_config = self.clap_config
|
137 |
+
|
138 |
+
if clap_config["method"] == 'laion-clap':
|
139 |
+
sr = 48000
|
140 |
+
elif clap_config["method"] == 'microsoft-clap':
|
141 |
+
sr = 44100
|
142 |
+
else:
|
143 |
+
raise NotImplementedError
|
144 |
+
|
145 |
+
window_length = int(float(clap_config["window_length"]) * sr)
|
146 |
+
window_overlap = int(float(clap_config["window_overlap"]) * sr)
|
147 |
+
max_num_window = int(clap_config["max_num_window"])
|
148 |
+
duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]
|
149 |
+
|
150 |
+
audio_data = self.load_audio(audio_file, sr, duration, audio_start)
|
151 |
+
T = len(audio_data)
|
152 |
+
num_windows, full_length = self.get_num_windows(T, sr)
|
153 |
+
|
154 |
+
if full_length > T:
|
155 |
+
audio_data = np.append(audio_data, np.zeros(full_length - T))
|
156 |
+
audio_data = audio_data.reshape(1, -1)
|
157 |
+
audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
|
158 |
+
|
159 |
+
audio_clips = []
|
160 |
+
audio_embed_mask = torch.zeros(max_num_window)
|
161 |
+
for i in range(num_windows):
|
162 |
+
start = i * (window_length - window_overlap)
|
163 |
+
audio_clips.append(audio_data_tensor[:, start:start+window_length])
|
164 |
+
audio_embed_mask[i] = 1
|
165 |
+
|
166 |
+
assert sum(audio_embed_mask) == num_windows
|
167 |
+
|
168 |
+
if num_windows < max_num_window:
|
169 |
+
for _ in range(max_num_window - num_windows):
|
170 |
+
audio_clips.append(torch.zeros_like(audio_clips[-1]))
|
171 |
+
|
172 |
+
audio_clips = torch.cat(audio_clips) # (max_num_window, window_length * sr) cuda tensor
|
173 |
+
|
174 |
+
return audio_clips, audio_embed_mask
|
175 |
+
|
176 |
+
def preprocess_string_for_eval(self, x):
|
177 |
+
x = x.rstrip().lstrip()
|
178 |
+
x = x.lower()
|
179 |
+
return x
|
180 |
+
|
181 |
+
def process(self, item):
|
182 |
+
if type(item['name']) is str:
|
183 |
+
audio_files = [os.path.join(self.data_root, item['name'])]
|
184 |
+
audio_starts = [0 if 'audio_start' not in item else float(item['audio_start'])]
|
185 |
+
else:
|
186 |
+
audio_files = [os.path.join(self.data_root, name) for name in item['name']]
|
187 |
+
audio_starts = [0] * len(audio_files) if 'audio_start' not in item else item['audio_start']
|
188 |
+
|
189 |
+
audio_clips, audio_embed_mask = [], []
|
190 |
+
for audio_file, audio_start in zip(audio_files, audio_starts):
|
191 |
+
this_audio_clips, this_audio_embed_mask = self.compute_sliding_window(audio_file, audio_start)
|
192 |
+
audio_clips.append(this_audio_clips)
|
193 |
+
audio_embed_mask.append(this_audio_embed_mask)
|
194 |
+
|
195 |
+
audio_clips = torch.cat(audio_clips)
|
196 |
+
audio_embed_mask = torch.cat(audio_embed_mask)
|
197 |
+
|
198 |
+
correct_num_windows = int(self.clap_config["max_num_window"]) * int(self.clap_config["max_num_fewshot"])
|
199 |
+
if len(audio_clips) < correct_num_windows:
|
200 |
+
audio_clips = torch.cat([
|
201 |
+
audio_clips,
|
202 |
+
torch.zeros(correct_num_windows - len(audio_clips), audio_clips.shape[1])
|
203 |
+
])
|
204 |
+
audio_embed_mask = torch.cat([
|
205 |
+
audio_embed_mask,
|
206 |
+
torch.zeros(correct_num_windows - len(audio_embed_mask))
|
207 |
+
])
|
208 |
+
|
209 |
+
audio_clips.requires_grad = False
|
210 |
+
audio_embed_mask.requires_grad = False
|
211 |
+
|
212 |
+
assert type(item['name']) is str
|
213 |
+
|
214 |
+
# simple data - 1 audio, 1 text
|
215 |
+
if 'prompt' in item:
|
216 |
+
text_prompt = item['prompt'].lower()
|
217 |
+
prefix = item['prefix'].lower() # the task is xxx.
|
218 |
+
sample = "{}{} <audio>{}\nanswer:{}".format(
|
219 |
+
self.tokenizer.bos_token,
|
220 |
+
self.preprocess_string_for_eval(prefix),
|
221 |
+
self.preprocess_string_for_eval(text_prompt),
|
222 |
+
self.tokenizer.sep_token
|
223 |
+
)
|
224 |
+
|
225 |
+
# dialog data - 1 audio, multiple text
|
226 |
+
elif 'dialogue' in item:
|
227 |
+
dialogue = item['dialogue']
|
228 |
+
prefix = item['prefix'].lower() # the task is dialog.
|
229 |
+
sample = f"{self.tokenizer.bos_token}{prefix}<audio>"
|
230 |
+
for each_round in dialogue:
|
231 |
+
sample = sample + f"user: {each_round['user']} \nassistant: {self.tokenizer.sep_token}"
|
232 |
+
if 'assistant' in each_round:
|
233 |
+
sample = sample + f"{each_round['assistant']}<|endofchunk|>{self.tokenizer.eos_token}\n"
|
234 |
+
|
235 |
+
text = self.tokenizer(
|
236 |
+
sample,
|
237 |
+
max_length=self.max_tokens*5,
|
238 |
+
padding="longest",
|
239 |
+
truncation="only_first",
|
240 |
+
return_tensors="pt"
|
241 |
+
)
|
242 |
+
|
243 |
+
return (item['name'], audio_clips, audio_embed_mask, text["input_ids"], text["attention_mask"])
|
inference_utils.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import os
|
5 |
+
import string
|
6 |
+
import yaml
|
7 |
+
from copy import deepcopy
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from transformers import AutoTokenizer, set_seed
|
11 |
+
set_seed(0)
|
12 |
+
|
13 |
+
from data import AudioTextDataProcessor
|
14 |
+
from src.factory import create_model_and_transforms
|
15 |
+
|
16 |
+
|
17 |
+
def prepare_tokenizer(model_config):
|
18 |
+
tokenizer_path = model_config['tokenizer_path']
|
19 |
+
cache_dir = model_config['cache_dir']
|
20 |
+
text_tokenizer = AutoTokenizer.from_pretrained(
|
21 |
+
tokenizer_path,
|
22 |
+
local_files_only=False,
|
23 |
+
trust_remote_code=True,
|
24 |
+
cache_dir=cache_dir,
|
25 |
+
)
|
26 |
+
text_tokenizer.add_special_tokens(
|
27 |
+
{"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
|
28 |
+
)
|
29 |
+
if text_tokenizer.pad_token is None:
|
30 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
31 |
+
if text_tokenizer.sep_token is None:
|
32 |
+
text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
|
33 |
+
return text_tokenizer
|
34 |
+
|
35 |
+
|
36 |
+
def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
|
37 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable the tokenizer parallelism warning
|
38 |
+
model, tokenizer = create_model_and_transforms(
|
39 |
+
**model_config,
|
40 |
+
clap_config=clap_config,
|
41 |
+
use_local_files=False,
|
42 |
+
gradient_checkpointing=False,
|
43 |
+
freeze_lm_embeddings=False,
|
44 |
+
)
|
45 |
+
model.eval()
|
46 |
+
model = model.to(device_id)
|
47 |
+
|
48 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
49 |
+
model_state_dict = checkpoint["model_state_dict"]
|
50 |
+
model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
|
51 |
+
model.load_state_dict(model_state_dict, False)
|
52 |
+
|
53 |
+
return model
|
54 |
+
|
55 |
+
|
56 |
+
def inference(model, tokenizer, item, processed_item, inference_kwargs, device_id=0):
|
57 |
+
filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
|
58 |
+
audio_clips = audio_clips.to(device_id, dtype=None, non_blocking=True)
|
59 |
+
audio_embed_mask = audio_embed_mask.to(device_id, dtype=None, non_blocking=True)
|
60 |
+
input_ids = input_ids.to(device_id, dtype=None, non_blocking=True).squeeze()
|
61 |
+
|
62 |
+
media_token_id = tokenizer.encode("<audio>")[-1]
|
63 |
+
eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
|
64 |
+
sep_token_id = tokenizer.sep_token_id
|
65 |
+
eos_token_id = tokenizer.eos_token_id
|
66 |
+
|
67 |
+
outputs = model.generate(
|
68 |
+
audio_x=audio_clips.unsqueeze(0),
|
69 |
+
audio_x_mask=audio_embed_mask.unsqueeze(0),
|
70 |
+
lang_x=input_ids.unsqueeze(0),
|
71 |
+
eos_token_id=eos_token_id,
|
72 |
+
max_new_tokens=128,
|
73 |
+
**inference_kwargs,
|
74 |
+
)
|
75 |
+
|
76 |
+
outputs_decoded = [
|
77 |
+
tokenizer.decode(output).split(tokenizer.sep_token)[-1].replace(tokenizer.eos_token, '').replace(tokenizer.pad_token, '').replace('<|endofchunk|>', '') for output in outputs
|
78 |
+
]
|
79 |
+
|
80 |
+
return outputs_decoded
|
81 |
+
|
ms_clap/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
ms_clap/.gitignore
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Ignore Visual Studio temporary files, build results, and
|
2 |
+
## files generated by popular Visual Studio add-ons.
|
3 |
+
##
|
4 |
+
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
5 |
+
|
6 |
+
# User-specific files
|
7 |
+
*.rsuser
|
8 |
+
*.suo
|
9 |
+
*.user
|
10 |
+
*.userosscache
|
11 |
+
*.sln.docstates
|
12 |
+
|
13 |
+
# User-specific files (MonoDevelop/Xamarin Studio)
|
14 |
+
*.userprefs
|
15 |
+
|
16 |
+
# Mono auto generated files
|
17 |
+
mono_crash.*
|
18 |
+
|
19 |
+
# Build results
|
20 |
+
[Dd]ebug/
|
21 |
+
[Dd]ebugPublic/
|
22 |
+
[Rr]elease/
|
23 |
+
[Rr]eleases/
|
24 |
+
x64/
|
25 |
+
x86/
|
26 |
+
[Aa][Rr][Mm]/
|
27 |
+
[Aa][Rr][Mm]64/
|
28 |
+
bld/
|
29 |
+
[Bb]in/
|
30 |
+
[Oo]bj/
|
31 |
+
[Ll]og/
|
32 |
+
[Ll]ogs/
|
33 |
+
|
34 |
+
# Visual Studio 2015/2017 cache/options directory
|
35 |
+
.vs/
|
36 |
+
# Uncomment if you have tasks that create the project's static files in wwwroot
|
37 |
+
#wwwroot/
|
38 |
+
|
39 |
+
# Visual Studio 2017 auto generated files
|
40 |
+
Generated\ Files/
|
41 |
+
|
42 |
+
# MSTest test Results
|
43 |
+
[Tt]est[Rr]esult*/
|
44 |
+
[Bb]uild[Ll]og.*
|
45 |
+
|
46 |
+
# NUnit
|
47 |
+
*.VisualState.xml
|
48 |
+
TestResult.xml
|
49 |
+
nunit-*.xml
|
50 |
+
|
51 |
+
# Build Results of an ATL Project
|
52 |
+
[Dd]ebugPS/
|
53 |
+
[Rr]eleasePS/
|
54 |
+
dlldata.c
|
55 |
+
|
56 |
+
# Benchmark Results
|
57 |
+
BenchmarkDotNet.Artifacts/
|
58 |
+
|
59 |
+
# .NET Core
|
60 |
+
project.lock.json
|
61 |
+
project.fragment.lock.json
|
62 |
+
artifacts/
|
63 |
+
|
64 |
+
# StyleCop
|
65 |
+
StyleCopReport.xml
|
66 |
+
|
67 |
+
# Files built by Visual Studio
|
68 |
+
*_i.c
|
69 |
+
*_p.c
|
70 |
+
*_h.h
|
71 |
+
*.ilk
|
72 |
+
*.meta
|
73 |
+
*.obj
|
74 |
+
*.iobj
|
75 |
+
*.pch
|
76 |
+
*.pdb
|
77 |
+
*.ipdb
|
78 |
+
*.pgc
|
79 |
+
*.pgd
|
80 |
+
*.rsp
|
81 |
+
*.sbr
|
82 |
+
*.tlb
|
83 |
+
*.tli
|
84 |
+
*.tlh
|
85 |
+
*.tmp
|
86 |
+
*.tmp_proj
|
87 |
+
*_wpftmp.csproj
|
88 |
+
*.log
|
89 |
+
*.vspscc
|
90 |
+
*.vssscc
|
91 |
+
.builds
|
92 |
+
*.pidb
|
93 |
+
*.svclog
|
94 |
+
*.scc
|
95 |
+
|
96 |
+
# Chutzpah Test files
|
97 |
+
_Chutzpah*
|
98 |
+
|
99 |
+
# Visual C++ cache files
|
100 |
+
ipch/
|
101 |
+
*.aps
|
102 |
+
*.ncb
|
103 |
+
*.opendb
|
104 |
+
*.opensdf
|
105 |
+
*.sdf
|
106 |
+
*.cachefile
|
107 |
+
*.VC.db
|
108 |
+
*.VC.VC.opendb
|
109 |
+
|
110 |
+
# Visual Studio profiler
|
111 |
+
*.psess
|
112 |
+
*.vsp
|
113 |
+
*.vspx
|
114 |
+
*.sap
|
115 |
+
|
116 |
+
# Visual Studio Trace Files
|
117 |
+
*.e2e
|
118 |
+
|
119 |
+
# TFS 2012 Local Workspace
|
120 |
+
$tf/
|
121 |
+
|
122 |
+
# Guidance Automation Toolkit
|
123 |
+
*.gpState
|
124 |
+
|
125 |
+
# ReSharper is a .NET coding add-in
|
126 |
+
_ReSharper*/
|
127 |
+
*.[Rr]e[Ss]harper
|
128 |
+
*.DotSettings.user
|
129 |
+
|
130 |
+
# TeamCity is a build add-in
|
131 |
+
_TeamCity*
|
132 |
+
|
133 |
+
# DotCover is a Code Coverage Tool
|
134 |
+
*.dotCover
|
135 |
+
|
136 |
+
# AxoCover is a Code Coverage Tool
|
137 |
+
.axoCover/*
|
138 |
+
!.axoCover/settings.json
|
139 |
+
|
140 |
+
# Visual Studio code coverage results
|
141 |
+
*.coverage
|
142 |
+
*.coveragexml
|
143 |
+
|
144 |
+
# NCrunch
|
145 |
+
_NCrunch_*
|
146 |
+
.*crunch*.local.xml
|
147 |
+
nCrunchTemp_*
|
148 |
+
|
149 |
+
# MightyMoose
|
150 |
+
*.mm.*
|
151 |
+
AutoTest.Net/
|
152 |
+
|
153 |
+
# Web workbench (sass)
|
154 |
+
.sass-cache/
|
155 |
+
|
156 |
+
# Installshield output folder
|
157 |
+
[Ee]xpress/
|
158 |
+
|
159 |
+
# DocProject is a documentation generator add-in
|
160 |
+
DocProject/buildhelp/
|
161 |
+
DocProject/Help/*.HxT
|
162 |
+
DocProject/Help/*.HxC
|
163 |
+
DocProject/Help/*.hhc
|
164 |
+
DocProject/Help/*.hhk
|
165 |
+
DocProject/Help/*.hhp
|
166 |
+
DocProject/Help/Html2
|
167 |
+
DocProject/Help/html
|
168 |
+
|
169 |
+
# Click-Once directory
|
170 |
+
publish/
|
171 |
+
|
172 |
+
# Publish Web Output
|
173 |
+
*.[Pp]ublish.xml
|
174 |
+
*.azurePubxml
|
175 |
+
# Note: Comment the next line if you want to checkin your web deploy settings,
|
176 |
+
# but database connection strings (with potential passwords) will be unencrypted
|
177 |
+
*.pubxml
|
178 |
+
*.publishproj
|
179 |
+
|
180 |
+
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
181 |
+
# checkin your Azure Web App publish settings, but sensitive information contained
|
182 |
+
# in these scripts will be unencrypted
|
183 |
+
PublishScripts/
|
184 |
+
|
185 |
+
# NuGet Packages
|
186 |
+
*.nupkg
|
187 |
+
# NuGet Symbol Packages
|
188 |
+
*.snupkg
|
189 |
+
# The packages folder can be ignored because of Package Restore
|
190 |
+
**/[Pp]ackages/*
|
191 |
+
# except build/, which is used as an MSBuild target.
|
192 |
+
!**/[Pp]ackages/build/
|
193 |
+
# Uncomment if necessary however generally it will be regenerated when needed
|
194 |
+
#!**/[Pp]ackages/repositories.config
|
195 |
+
# NuGet v3's project.json files produces more ignorable files
|
196 |
+
*.nuget.props
|
197 |
+
*.nuget.targets
|
198 |
+
|
199 |
+
# Microsoft Azure Build Output
|
200 |
+
csx/
|
201 |
+
*.build.csdef
|
202 |
+
|
203 |
+
# Microsoft Azure Emulator
|
204 |
+
ecf/
|
205 |
+
rcf/
|
206 |
+
|
207 |
+
# Windows Store app package directories and files
|
208 |
+
AppPackages/
|
209 |
+
BundleArtifacts/
|
210 |
+
Package.StoreAssociation.xml
|
211 |
+
_pkginfo.txt
|
212 |
+
*.appx
|
213 |
+
*.appxbundle
|
214 |
+
*.appxupload
|
215 |
+
|
216 |
+
# Visual Studio cache files
|
217 |
+
# files ending in .cache can be ignored
|
218 |
+
*.[Cc]ache
|
219 |
+
# but keep track of directories ending in .cache
|
220 |
+
!?*.[Cc]ache/
|
221 |
+
|
222 |
+
# Others
|
223 |
+
ClientBin/
|
224 |
+
~$*
|
225 |
+
*~
|
226 |
+
*.dbmdl
|
227 |
+
*.dbproj.schemaview
|
228 |
+
*.jfm
|
229 |
+
*.pfx
|
230 |
+
*.publishsettings
|
231 |
+
orleans.codegen.cs
|
232 |
+
|
233 |
+
# Including strong name files can present a security risk
|
234 |
+
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
235 |
+
#*.snk
|
236 |
+
|
237 |
+
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
238 |
+
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
239 |
+
#bower_components/
|
240 |
+
|
241 |
+
# RIA/Silverlight projects
|
242 |
+
Generated_Code/
|
243 |
+
|
244 |
+
# Backup & report files from converting an old project file
|
245 |
+
# to a newer Visual Studio version. Backup files are not needed,
|
246 |
+
# because we have git ;-)
|
247 |
+
_UpgradeReport_Files/
|
248 |
+
Backup*/
|
249 |
+
UpgradeLog*.XML
|
250 |
+
UpgradeLog*.htm
|
251 |
+
ServiceFabricBackup/
|
252 |
+
*.rptproj.bak
|
253 |
+
|
254 |
+
# SQL Server files
|
255 |
+
*.mdf
|
256 |
+
*.ldf
|
257 |
+
*.ndf
|
258 |
+
|
259 |
+
# Business Intelligence projects
|
260 |
+
*.rdl.data
|
261 |
+
*.bim.layout
|
262 |
+
*.bim_*.settings
|
263 |
+
*.rptproj.rsuser
|
264 |
+
*- [Bb]ackup.rdl
|
265 |
+
*- [Bb]ackup ([0-9]).rdl
|
266 |
+
*- [Bb]ackup ([0-9][0-9]).rdl
|
267 |
+
|
268 |
+
# Microsoft Fakes
|
269 |
+
FakesAssemblies/
|
270 |
+
|
271 |
+
# GhostDoc plugin setting file
|
272 |
+
*.GhostDoc.xml
|
273 |
+
|
274 |
+
# Node.js Tools for Visual Studio
|
275 |
+
.ntvs_analysis.dat
|
276 |
+
node_modules/
|
277 |
+
|
278 |
+
# Visual Studio 6 build log
|
279 |
+
*.plg
|
280 |
+
|
281 |
+
# Visual Studio 6 workspace options file
|
282 |
+
*.opt
|
283 |
+
|
284 |
+
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
285 |
+
*.vbw
|
286 |
+
|
287 |
+
# Visual Studio LightSwitch build output
|
288 |
+
**/*.HTMLClient/GeneratedArtifacts
|
289 |
+
**/*.DesktopClient/GeneratedArtifacts
|
290 |
+
**/*.DesktopClient/ModelManifest.xml
|
291 |
+
**/*.Server/GeneratedArtifacts
|
292 |
+
**/*.Server/ModelManifest.xml
|
293 |
+
_Pvt_Extensions
|
294 |
+
|
295 |
+
# Paket dependency manager
|
296 |
+
.paket/paket.exe
|
297 |
+
paket-files/
|
298 |
+
|
299 |
+
# FAKE - F# Make
|
300 |
+
.fake/
|
301 |
+
|
302 |
+
# CodeRush personal settings
|
303 |
+
.cr/personal
|
304 |
+
|
305 |
+
# Python Tools for Visual Studio (PTVS)
|
306 |
+
__pycache__/
|
307 |
+
*.pyc
|
308 |
+
|
309 |
+
# Cake - Uncomment if you are using it
|
310 |
+
# tools/**
|
311 |
+
# !tools/packages.config
|
312 |
+
|
313 |
+
# Tabs Studio
|
314 |
+
*.tss
|
315 |
+
|
316 |
+
# Telerik's JustMock configuration file
|
317 |
+
*.jmconfig
|
318 |
+
|
319 |
+
# BizTalk build output
|
320 |
+
*.btp.cs
|
321 |
+
*.btm.cs
|
322 |
+
*.odx.cs
|
323 |
+
*.xsd.cs
|
324 |
+
|
325 |
+
# OpenCover UI analysis results
|
326 |
+
OpenCover/
|
327 |
+
|
328 |
+
# Azure Stream Analytics local run output
|
329 |
+
ASALocalRun/
|
330 |
+
|
331 |
+
# MSBuild Binary and Structured Log
|
332 |
+
*.binlog
|
333 |
+
|
334 |
+
# NVidia Nsight GPU debugger configuration file
|
335 |
+
*.nvuser
|
336 |
+
|
337 |
+
# MFractors (Xamarin productivity tool) working folder
|
338 |
+
.mfractor/
|
339 |
+
|
340 |
+
# Local History for Visual Studio
|
341 |
+
.localhistory/
|
342 |
+
|
343 |
+
# BeatPulse healthcheck temp database
|
344 |
+
healthchecksdb
|
345 |
+
|
346 |
+
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
347 |
+
MigrationBackup/
|
348 |
+
|
349 |
+
# Ionide (cross platform F# VS Code tools) working folder
|
350 |
+
.ionide/
|
ms_clap/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Microsoft Open Source Code of Conduct
|
2 |
+
|
3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
4 |
+
|
5 |
+
Resources:
|
6 |
+
|
7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
9 |
+
- Contact [[email protected]](mailto:[email protected]) with questions or concerns
|
ms_clap/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Microsoft Corporation.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE
|
ms_clap/README.md
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###### [Overview](#CLAP) | [Setup](#Setup) | [CLAP weights](#CLAP-weights) | [Usage](#Usage) | [Examples](#Examples) | [Citation](#Citation)
|
2 |
+
|
3 |
+
# CLAP
|
4 |
+
|
5 |
+
CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning.
|
6 |
+
|
7 |
+
<img width="832" alt="clap_diagrams" src="https://github.com/bmartin1/CLAP/assets/26778834/c5340a09-cc0c-4e41-ad5a-61546eaa824c">
|
8 |
+
|
9 |
+
## Setup
|
10 |
+
|
11 |
+
Install the dependencies: `pip install -r requirements.txt` using Python 3 to get started.
|
12 |
+
|
13 |
+
If you have [conda](https://www.anaconda.com) installed, you can run the following:
|
14 |
+
|
15 |
+
```shell
|
16 |
+
git clone https://github.com/microsoft/CLAP.git && \
|
17 |
+
cd CLAP && \
|
18 |
+
conda create -n clap python=3.10 && \
|
19 |
+
conda activate clap && \
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
## NEW CLAP weights
|
24 |
+
Download CLAP weights: versions _2022_, _2023_, and _clapcap_: [Pretrained Model \[Zenodo\]](https://zenodo.org/record/8378278)
|
25 |
+
|
26 |
+
_clapcap_ is the audio captioning model that uses the 2023 encoders.
|
27 |
+
|
28 |
+
## Usage
|
29 |
+
|
30 |
+
- Zero-Shot Classification and Retrieval
|
31 |
+
```python
|
32 |
+
# Load model (Choose between versions '2022' or '2023')
|
33 |
+
from src import CLAP
|
34 |
+
|
35 |
+
clap_model = CLAP("<PATH TO WEIGHTS>", version = '2023', use_cuda=False)
|
36 |
+
|
37 |
+
# Extract text embeddings
|
38 |
+
text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])
|
39 |
+
|
40 |
+
# Extract audio embeddings
|
41 |
+
audio_embeddings = clap_model.get_audio_embeddings(file_paths: List[str])
|
42 |
+
|
43 |
+
# Compute similarity between audio and text embeddings
|
44 |
+
similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
45 |
+
```
|
46 |
+
|
47 |
+
- Audio Captioning
|
48 |
+
```python
|
49 |
+
# Load model (Choose version 'clapcap')
|
50 |
+
from src import CLAP
|
51 |
+
|
52 |
+
clap_model = CLAP("<PATH TO WEIGHTS>", version = 'clapcap', use_cuda=False)
|
53 |
+
|
54 |
+
# Generate audio captions
|
55 |
+
captions = clap_model.generate_caption(file_paths: List[str])
|
56 |
+
```
|
57 |
+
|
58 |
+
## Examples
|
59 |
+
Take a look at `CLAP\src\` for usage examples.
|
60 |
+
|
61 |
+
To run Zero-Shot Classification on the ESC50 dataset try the following:
|
62 |
+
|
63 |
+
```bash
|
64 |
+
> cd src && python zero_shot_classification.py
|
65 |
+
```
|
66 |
+
Output (version 2023)
|
67 |
+
```bash
|
68 |
+
ESC50 Accuracy: 93.9%
|
69 |
+
```
|
70 |
+
|
71 |
+
## Citation
|
72 |
+
|
73 |
+
Kindly cite our work if you find it useful.
|
74 |
+
|
75 |
+
[CLAP: Learning Audio Concepts from Natural Language Supervision](https://ieeexplore.ieee.org/abstract/document/10095889)
|
76 |
+
```
|
77 |
+
@inproceedings{CLAP2022,
|
78 |
+
title={Clap learning audio concepts from natural language supervision},
|
79 |
+
author={Elizalde, Benjamin and Deshmukh, Soham and Al Ismail, Mahmoud and Wang, Huaming},
|
80 |
+
booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
|
81 |
+
pages={1--5},
|
82 |
+
year={2023},
|
83 |
+
organization={IEEE}
|
84 |
+
}
|
85 |
+
```
|
86 |
+
|
87 |
+
[Natural Language Supervision for General-Purpose Audio Representations](https://arxiv.org/abs/2309.05767)
|
88 |
+
```
|
89 |
+
@misc{CLAP2023,
|
90 |
+
title={Natural Language Supervision for General-Purpose Audio Representations},
|
91 |
+
author={Benjamin Elizalde and Soham Deshmukh and Huaming Wang},
|
92 |
+
year={2023},
|
93 |
+
eprint={2309.05767},
|
94 |
+
archivePrefix={arXiv},
|
95 |
+
primaryClass={cs.SD},
|
96 |
+
url={https://arxiv.org/abs/2309.05767}
|
97 |
+
}
|
98 |
+
```
|
99 |
+
|
100 |
+
## Contributing
|
101 |
+
|
102 |
+
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
103 |
+
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
104 |
+
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
105 |
+
|
106 |
+
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
107 |
+
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
108 |
+
provided by the bot. You will only need to do this once across all repos using our CLA.
|
109 |
+
|
110 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
111 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
112 |
+
contact [[email protected]](mailto:[email protected]) with any additional questions or comments.
|
113 |
+
|
114 |
+
## Trademarks
|
115 |
+
|
116 |
+
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
117 |
+
trademarks or logos is subject to and must follow
|
118 |
+
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
119 |
+
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
120 |
+
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
ms_clap/SECURITY.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.8 BLOCK -->
|
2 |
+
|
3 |
+
## Security
|
4 |
+
|
5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
6 |
+
|
7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
|
8 |
+
|
9 |
+
## Reporting Security Issues
|
10 |
+
|
11 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
12 |
+
|
13 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
|
14 |
+
|
15 |
+
If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
|
16 |
+
|
17 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
|
18 |
+
|
19 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
20 |
+
|
21 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
22 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
23 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
24 |
+
* Any special configuration required to reproduce the issue
|
25 |
+
* Step-by-step instructions to reproduce the issue
|
26 |
+
* Proof-of-concept or exploit code (if possible)
|
27 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
28 |
+
|
29 |
+
This information will help us triage your report more quickly.
|
30 |
+
|
31 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
|
32 |
+
|
33 |
+
## Preferred Languages
|
34 |
+
|
35 |
+
We prefer all communications to be in English.
|
36 |
+
|
37 |
+
## Policy
|
38 |
+
|
39 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
|
40 |
+
|
41 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
ms_clap/SUPPORT.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: The maintainer of this repo has not yet edited this file
|
2 |
+
|
3 |
+
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
4 |
+
|
5 |
+
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
6 |
+
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
|
7 |
+
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
|
8 |
+
|
9 |
+
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
10 |
+
|
11 |
+
# Support
|
12 |
+
|
13 |
+
## How to file issues and get help
|
14 |
+
|
15 |
+
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
16 |
+
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
17 |
+
feature request as a new Issue.
|
18 |
+
|
19 |
+
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
20 |
+
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
21 |
+
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
22 |
+
|
23 |
+
## Microsoft Support Policy
|
24 |
+
|
25 |
+
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
ms_clap/requirements.txt
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
appdirs==1.4.4
|
2 |
+
audioread==3.0.0
|
3 |
+
certifi==2022.12.7
|
4 |
+
cffi==1.15.1
|
5 |
+
charset-normalizer==3.0.1
|
6 |
+
colorama==0.4.6
|
7 |
+
decorator==5.1.1
|
8 |
+
filelock==3.9.0
|
9 |
+
flit_core==3.6.0
|
10 |
+
huggingface-hub==0.12.1
|
11 |
+
idna==3.4
|
12 |
+
importlib-metadata==6.0.0
|
13 |
+
importlib-resources==5.12.0
|
14 |
+
jaraco.classes==3.2.3
|
15 |
+
joblib==1.2.0
|
16 |
+
lazy_loader==0.1
|
17 |
+
librosa==0.10.0
|
18 |
+
llvmlite==0.39.1
|
19 |
+
mkl-service==2.4.0
|
20 |
+
more-itertools==9.0.0
|
21 |
+
msgpack==1.0.4
|
22 |
+
numba==0.56.4
|
23 |
+
numpy==1.23.5
|
24 |
+
packaging==23.0
|
25 |
+
pandas==1.4.2
|
26 |
+
pooch==1.6.0
|
27 |
+
pycparser==2.21
|
28 |
+
pywin32-ctypes==0.2.0
|
29 |
+
PyYAML==6.0
|
30 |
+
regex==2022.10.31
|
31 |
+
requests==2.28.2
|
32 |
+
scikit-learn==1.2.1
|
33 |
+
scipy==1.10.1
|
34 |
+
setuptools==65.6.3
|
35 |
+
six==1.16.0
|
36 |
+
soundfile==0.12.1
|
37 |
+
soxr==0.3.3
|
38 |
+
threadpoolctl==3.1.0
|
39 |
+
tokenizers==0.13.2
|
40 |
+
torch==1.13.1
|
41 |
+
torchaudio==0.13.1
|
42 |
+
torchlibrosa==0.1.0
|
43 |
+
torchvision==0.14.1
|
44 |
+
tqdm==4.64.1
|
45 |
+
transformers==4.26.1
|
46 |
+
typing_extensions==4.4.0
|
47 |
+
urllib3==1.26.14
|
48 |
+
wheel==0.38.4
|
49 |
+
wincertstore==0.2
|
50 |
+
zipp==3.14.0
|
ms_clap/src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
ms_clap/src/CLAPWrapper.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
warnings.filterwarnings("ignore")
|
3 |
+
import random
|
4 |
+
import torchaudio
|
5 |
+
# from torch._six import string_classes
|
6 |
+
import collections
|
7 |
+
import re
|
8 |
+
import numpy as np
|
9 |
+
from transformers import AutoTokenizer, logging
|
10 |
+
try:
|
11 |
+
from models.clap import CLAP
|
12 |
+
from models.mapper import get_clapcap
|
13 |
+
except:
|
14 |
+
from .models.clap import CLAP
|
15 |
+
from .models.mapper import get_clapcap
|
16 |
+
import math
|
17 |
+
import torchaudio.transforms as T
|
18 |
+
import os
|
19 |
+
import torch
|
20 |
+
from importlib_resources import files
|
21 |
+
import argparse
|
22 |
+
import yaml
|
23 |
+
import sys
|
24 |
+
logging.set_verbosity_error()
|
25 |
+
|
26 |
+
|
27 |
+
class CLAPWrapper():
|
28 |
+
"""
|
29 |
+
A class for interfacing CLAP model.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, model_fp, config_root, version, use_cuda=False):
|
33 |
+
self.supported_versions = ['2022', '2023', 'clapcap']
|
34 |
+
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
35 |
+
self.file_path = os.path.realpath(__file__)
|
36 |
+
self.default_collate_err_msg_format = (
|
37 |
+
"default_collate: batch must contain tensors, numpy arrays, numbers, "
|
38 |
+
"dicts or lists; found {}")
|
39 |
+
self.config_root = config_root
|
40 |
+
self.config_as_str = self.get_config_path(version)
|
41 |
+
self.model_fp = model_fp
|
42 |
+
self.use_cuda = use_cuda
|
43 |
+
self.version = version
|
44 |
+
if 'clapcap' in self.version:
|
45 |
+
self.clapcap, self.tokenizer, self.args = self.load_clapcap()
|
46 |
+
else:
|
47 |
+
self.clap, self.tokenizer, self.args = self.load_clap()
|
48 |
+
|
49 |
+
def get_config_path(self, version):
|
50 |
+
if version in self.supported_versions:
|
51 |
+
# config_root = /home/zkong/audio_flamingo/audio_flamingo_v1/microsoft_clap/src/configs
|
52 |
+
return f"{self.config_root}/config_{version}.yml"
|
53 |
+
else:
|
54 |
+
raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
|
55 |
+
|
56 |
+
def read_config_as_args(self,config_path,args=None,is_config_str=False):
|
57 |
+
return_dict = {}
|
58 |
+
|
59 |
+
if config_path is not None:
|
60 |
+
if is_config_str:
|
61 |
+
yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
|
62 |
+
else:
|
63 |
+
with open(config_path, "r") as f:
|
64 |
+
yml_config = yaml.load(f, Loader=yaml.FullLoader)
|
65 |
+
|
66 |
+
if args != None:
|
67 |
+
for k, v in yml_config.items():
|
68 |
+
if k in args.__dict__:
|
69 |
+
args.__dict__[k] = v
|
70 |
+
else:
|
71 |
+
sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
|
72 |
+
else:
|
73 |
+
for k, v in yml_config.items():
|
74 |
+
return_dict[k] = v
|
75 |
+
|
76 |
+
args = args if args != None else return_dict
|
77 |
+
return argparse.Namespace(**args)
|
78 |
+
|
79 |
+
def load_clap(self):
|
80 |
+
r"""Load CLAP model with args from config file"""
|
81 |
+
|
82 |
+
args = self.read_config_as_args(self.config_as_str, is_config_str=False)
|
83 |
+
|
84 |
+
if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
|
85 |
+
self.token_keys = ['input_ids', 'attention_mask']
|
86 |
+
elif 'bert' in args.text_model:
|
87 |
+
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
|
88 |
+
|
89 |
+
clap = CLAP(
|
90 |
+
audioenc_name=args.audioenc_name,
|
91 |
+
sample_rate=args.sampling_rate,
|
92 |
+
window_size=args.window_size,
|
93 |
+
hop_size=args.hop_size,
|
94 |
+
mel_bins=args.mel_bins,
|
95 |
+
fmin=args.fmin,
|
96 |
+
fmax=args.fmax,
|
97 |
+
classes_num=args.num_classes,
|
98 |
+
out_emb=args.out_emb,
|
99 |
+
text_model=args.text_model,
|
100 |
+
transformer_embed_dim=args.transformer_embed_dim,
|
101 |
+
d_proj=args.d_proj
|
102 |
+
)
|
103 |
+
|
104 |
+
# Load pretrained weights for model
|
105 |
+
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
|
106 |
+
|
107 |
+
# We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`:
|
108 |
+
# Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
|
109 |
+
clap.load_state_dict(model_state_dict)
|
110 |
+
|
111 |
+
clap.eval() # set clap in eval mode
|
112 |
+
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
|
113 |
+
if 'gpt' in args.text_model:
|
114 |
+
tokenizer.add_special_tokens({'pad_token': '!'})
|
115 |
+
|
116 |
+
if self.use_cuda and torch.cuda.is_available():
|
117 |
+
clap = clap.cuda()
|
118 |
+
|
119 |
+
return clap, tokenizer, args
|
120 |
+
|
121 |
+
def load_clapcap(self):
|
122 |
+
r"""Load CLAP model with args from config file"""
|
123 |
+
|
124 |
+
args = self.read_config_as_args(self.config_as_str, is_config_str=False)
|
125 |
+
args.prefix_dim = args.d_proj
|
126 |
+
text_model = args.text_model
|
127 |
+
args.text_model = args.text_decoder
|
128 |
+
args.cross_attention = True if 'cross' in args.clapcap_model.lower() else False
|
129 |
+
|
130 |
+
if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
|
131 |
+
self.token_keys = ['input_ids', 'attention_mask']
|
132 |
+
elif 'bert' in args.text_model:
|
133 |
+
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
|
134 |
+
|
135 |
+
clap = CLAP(
|
136 |
+
audioenc_name=args.audioenc_name,
|
137 |
+
sample_rate=args.sampling_rate,
|
138 |
+
window_size=args.window_size,
|
139 |
+
hop_size=args.hop_size,
|
140 |
+
mel_bins=args.mel_bins,
|
141 |
+
fmin=args.fmin,
|
142 |
+
fmax=args.fmax,
|
143 |
+
classes_num=args.num_classes,
|
144 |
+
out_emb=args.out_emb,
|
145 |
+
text_model=text_model,
|
146 |
+
transformer_embed_dim=args.transformer_embed_dim,
|
147 |
+
d_proj=args.d_proj
|
148 |
+
)
|
149 |
+
|
150 |
+
clapcap = get_clapcap(args.clapcap_model)(clap, args.text_decoder, args.prefix_length, args.prefix_length_clip, args.prefix_dim,
|
151 |
+
args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
|
152 |
+
|
153 |
+
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
|
154 |
+
clapcap.load_state_dict(model_state_dict)
|
155 |
+
|
156 |
+
clapcap.eval() # set clap in eval mode
|
157 |
+
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
|
158 |
+
if 'gpt' in args.text_model:
|
159 |
+
tokenizer.add_special_tokens({'pad_token': '!'})
|
160 |
+
|
161 |
+
if self.use_cuda and torch.cuda.is_available():
|
162 |
+
clapcap = clapcap.cuda()
|
163 |
+
|
164 |
+
return clapcap, tokenizer, args
|
165 |
+
|
166 |
+
def default_collate(self, batch):
|
167 |
+
r"""Puts each data field into a tensor with outer dimension batch size"""
|
168 |
+
elem = batch[0]
|
169 |
+
elem_type = type(elem)
|
170 |
+
if isinstance(elem, torch.Tensor):
|
171 |
+
out = None
|
172 |
+
if torch.utils.data.get_worker_info() is not None:
|
173 |
+
# If we're in a background process, concatenate directly into a
|
174 |
+
# shared memory tensor to avoid an extra copy
|
175 |
+
numel = sum([x.numel() for x in batch])
|
176 |
+
storage = elem.storage()._new_shared(numel)
|
177 |
+
out = elem.new(storage)
|
178 |
+
return torch.stack(batch, 0, out=out)
|
179 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
180 |
+
and elem_type.__name__ != 'string_':
|
181 |
+
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
|
182 |
+
# array of string classes and object
|
183 |
+
if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
184 |
+
raise TypeError(
|
185 |
+
self.default_collate_err_msg_format.format(elem.dtype))
|
186 |
+
|
187 |
+
return self.default_collate([torch.as_tensor(b) for b in batch])
|
188 |
+
elif elem.shape == (): # scalars
|
189 |
+
return torch.as_tensor(batch)
|
190 |
+
elif isinstance(elem, float):
|
191 |
+
return torch.tensor(batch, dtype=torch.float64)
|
192 |
+
elif isinstance(elem, int):
|
193 |
+
return torch.tensor(batch)
|
194 |
+
# elif isinstance(elem, string_classes):
|
195 |
+
# return batch
|
196 |
+
elif isinstance(elem, collections.abc.Mapping):
|
197 |
+
return {key: self.default_collate([d[key] for d in batch]) for key in elem}
|
198 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
199 |
+
return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
|
200 |
+
elif isinstance(elem, collections.abc.Sequence):
|
201 |
+
# check to make sure that the elements in batch have consistent size
|
202 |
+
it = iter(batch)
|
203 |
+
elem_size = len(next(it))
|
204 |
+
if not all(len(elem) == elem_size for elem in it):
|
205 |
+
raise RuntimeError(
|
206 |
+
'each element in list of batch should be of equal size')
|
207 |
+
transposed = zip(*batch)
|
208 |
+
return [self.default_collate(samples) for samples in transposed]
|
209 |
+
|
210 |
+
raise TypeError(self.default_collate_err_msg_format.format(elem_type))
|
211 |
+
|
212 |
+
def read_audio(self, audio_path, resample=False):
|
213 |
+
r"""Loads audio file or array and returns a torch tensor"""
|
214 |
+
# Randomly sample a segment of audio_duration from the clip or pad to match duration
|
215 |
+
audio_time_series, sample_rate = torchaudio.load(audio_path)
|
216 |
+
|
217 |
+
resample_rate = self.args.sampling_rate
|
218 |
+
if resample:
|
219 |
+
resampler = T.Resample(sample_rate, resample_rate)
|
220 |
+
audio_time_series = resampler(audio_time_series)
|
221 |
+
return audio_time_series, sample_rate
|
222 |
+
|
223 |
+
def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
|
224 |
+
r"""Loads audio file and returns raw audio."""
|
225 |
+
# Randomly sample a segment of audio_duration from the clip or pad to match duration
|
226 |
+
audio_time_series, sample_rate = self.read_audio(audio_path, resample=False)
|
227 |
+
audio_time_series = audio_time_series.reshape(-1)
|
228 |
+
|
229 |
+
# audio_time_series is shorter than predefined audio duration,
|
230 |
+
# so audio_time_series is extended
|
231 |
+
if audio_duration*sample_rate >= audio_time_series.shape[0]:
|
232 |
+
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
|
233 |
+
audio_time_series.shape[0]))
|
234 |
+
# Repeat audio_time_series by repeat_factor to match audio_duration
|
235 |
+
audio_time_series = audio_time_series.repeat(repeat_factor)
|
236 |
+
# remove excess part of audio_time_series
|
237 |
+
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
|
238 |
+
else:
|
239 |
+
# audio_time_series is longer than predefined audio duration,
|
240 |
+
# so audio_time_series is trimmed
|
241 |
+
start_index = random.randrange(
|
242 |
+
audio_time_series.shape[0] - audio_duration*sample_rate)
|
243 |
+
audio_time_series = audio_time_series[start_index:start_index +
|
244 |
+
audio_duration*sample_rate]
|
245 |
+
return torch.FloatTensor(audio_time_series)
|
246 |
+
|
247 |
+
# modified by Kong
|
248 |
+
def load_audio_clip_into_tensor(self, audio_clip, audio_duration, resample=False):
|
249 |
+
r"""Loads audio clip and returns raw audio."""
|
250 |
+
# Randomly sample a segment of audio_duration from the clip or pad to match duration
|
251 |
+
sample_rate = 44100
|
252 |
+
audio_time_series = audio_clip.reshape(-1)
|
253 |
+
|
254 |
+
# audio_time_series is shorter than predefined audio duration,
|
255 |
+
# so audio_time_series is extended
|
256 |
+
assert audio_duration * sample_rate >= audio_time_series.shape[0], \
|
257 |
+
'dur * sr = {} should be larger than len = {}'.format(audio_duration * sample_rate, audio_time_series.shape[0])
|
258 |
+
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
|
259 |
+
audio_time_series.shape[0]))
|
260 |
+
# Repeat audio_time_series by repeat_factor to match audio_duration
|
261 |
+
audio_time_series = audio_time_series.repeat(repeat_factor)
|
262 |
+
# remove excess part of audio_time_series
|
263 |
+
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
|
264 |
+
|
265 |
+
# return torch.FloatTensor(audio_time_series)
|
266 |
+
return audio_time_series # already on cuda device
|
267 |
+
|
268 |
+
def preprocess_audio(self, audio_files, resample):
|
269 |
+
r"""Load list of audio files and return raw audio"""
|
270 |
+
audio_tensors = []
|
271 |
+
for audio_file in audio_files:
|
272 |
+
audio_tensor = self.load_audio_into_tensor(
|
273 |
+
audio_file, self.args.duration, resample)
|
274 |
+
audio_tensor = audio_tensor.reshape(
|
275 |
+
1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
|
276 |
+
audio_tensors.append(audio_tensor)
|
277 |
+
return self.default_collate(audio_tensors)
|
278 |
+
|
279 |
+
# modified by Kong
|
280 |
+
def preprocess_audio_clips(self, audio_clips, resample=False):
|
281 |
+
r"""Load list of audio clips and return raw audio"""
|
282 |
+
audio_tensors = []
|
283 |
+
for audio_clip in audio_clips:
|
284 |
+
audio_tensor = self.load_audio_clip_into_tensor(
|
285 |
+
audio_clip, self.args.duration, resample=False)
|
286 |
+
audio_tensor = audio_tensor.reshape(
|
287 |
+
1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
|
288 |
+
audio_tensors.append(audio_tensor)
|
289 |
+
return self.default_collate(audio_tensors)
|
290 |
+
|
291 |
+
def preprocess_text(self, text_queries):
|
292 |
+
r"""Load list of class labels and return tokenized text"""
|
293 |
+
tokenized_texts = []
|
294 |
+
for ttext in text_queries:
|
295 |
+
if 'gpt' in self.args.text_model:
|
296 |
+
ttext = ttext + ' <|endoftext|>'
|
297 |
+
tok = self.tokenizer.encode_plus(
|
298 |
+
text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt")
|
299 |
+
for key in self.token_keys:
|
300 |
+
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
|
301 |
+
tokenized_texts.append(tok)
|
302 |
+
return self.default_collate(tokenized_texts)
|
303 |
+
|
304 |
+
def get_text_embeddings(self, class_labels):
|
305 |
+
r"""Load list of class labels and return text embeddings"""
|
306 |
+
preprocessed_text = self.preprocess_text(class_labels)
|
307 |
+
return self._get_text_embeddings(preprocessed_text)
|
308 |
+
|
309 |
+
def get_audio_embeddings(self, audio_files, resample):
|
310 |
+
r"""Load list of audio files and return a audio embeddings"""
|
311 |
+
preprocessed_audio = self.preprocess_audio(audio_files, resample)
|
312 |
+
return self._get_audio_embeddings(preprocessed_audio)
|
313 |
+
|
314 |
+
# modified by Kong
|
315 |
+
def get_audio_embeddings_from_clips(self, audio_clips, resample=False):
|
316 |
+
r"""Load list of audio files and return a audio embeddings"""
|
317 |
+
preprocessed_audio = self.preprocess_audio_clips(audio_clips, resample)
|
318 |
+
return self._get_audio_embeddings(preprocessed_audio)
|
319 |
+
|
320 |
+
def _get_text_embeddings(self, preprocessed_text):
|
321 |
+
r"""Load preprocessed text and return text embeddings"""
|
322 |
+
with torch.no_grad():
|
323 |
+
return self.clap.caption_encoder(preprocessed_text)
|
324 |
+
|
325 |
+
# modified by Kong
|
326 |
+
def _get_audio_embeddings(self, preprocessed_audio):
|
327 |
+
r"""Load preprocessed audio and return a audio embeddings"""
|
328 |
+
with torch.no_grad():
|
329 |
+
preprocessed_audio = preprocessed_audio.reshape(
|
330 |
+
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
|
331 |
+
#Append [0] the audio emebdding, [1] has output class probabilities
|
332 |
+
if 'clapcap' in self.version:
|
333 |
+
return self.clapcap.clap(preprocessed_audio)[0]
|
334 |
+
else:
|
335 |
+
return self.clap.audio_encoder(preprocessed_audio)[0]
|
336 |
+
|
337 |
+
def _generic_batch_inference(self, func, *args):
|
338 |
+
r"""Process audio and/or text per batch"""
|
339 |
+
input_tmp = args[0]
|
340 |
+
batch_size = args[-1]
|
341 |
+
# args[0] has audio_files, args[1] has class_labels
|
342 |
+
inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
|
343 |
+
args0_len = len(args[0])
|
344 |
+
# compute text_embeddings once for all the audio_files batches
|
345 |
+
if len(inputs) == 2:
|
346 |
+
text_embeddings = self.get_text_embeddings(args[1])
|
347 |
+
inputs = [args[0], args[1], text_embeddings]
|
348 |
+
dataset_idx = 0
|
349 |
+
for _ in range(math.ceil(args0_len/batch_size)):
|
350 |
+
next_batch_idx = dataset_idx + batch_size
|
351 |
+
# batch size is bigger than available audio/text items
|
352 |
+
if next_batch_idx >= args0_len:
|
353 |
+
inputs[0] = input_tmp[dataset_idx:]
|
354 |
+
return func(*tuple(inputs))
|
355 |
+
else:
|
356 |
+
inputs[0] = input_tmp[dataset_idx:next_batch_idx]
|
357 |
+
yield func(*tuple(inputs))
|
358 |
+
dataset_idx = next_batch_idx
|
359 |
+
|
360 |
+
def get_audio_embeddings_per_batch(self, audio_files, batch_size):
|
361 |
+
r"""Load preprocessed audio and return a audio embeddings per batch"""
|
362 |
+
return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
|
363 |
+
|
364 |
+
def get_text_embeddings_per_batch(self, class_labels, batch_size):
|
365 |
+
r"""Load preprocessed text and return text embeddings per batch"""
|
366 |
+
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
|
367 |
+
|
368 |
+
def compute_similarity(self, audio_embeddings, text_embeddings):
|
369 |
+
r"""Compute similarity between text and audio embeddings"""
|
370 |
+
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
371 |
+
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
372 |
+
|
373 |
+
logit_scale = self.clap.logit_scale.exp()
|
374 |
+
similarity = logit_scale*text_embeddings @ audio_embeddings.T
|
375 |
+
return similarity.T
|
376 |
+
|
377 |
+
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
|
378 |
+
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
|
379 |
+
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
|
380 |
+
|
381 |
+
def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
|
382 |
+
r"""Generate audio captions for each audio recording in a batch"""
|
383 |
+
captions = []
|
384 |
+
audio_tensors = self.preprocess_audio(audio_files, resample)
|
385 |
+
|
386 |
+
with torch.no_grad():
|
387 |
+
prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
|
388 |
+
if self.args.normalize_prefix:
|
389 |
+
prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
|
390 |
+
prefix_embed = self.clapcap.clap_project(prefix).view(-1, self.args.prefix_length, self.clapcap.gpt.transformer.wte.weight.shape[1])
|
391 |
+
|
392 |
+
for i in range(len(audio_tensors)):
|
393 |
+
gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),\
|
394 |
+
beam_size=beam_size,\
|
395 |
+
entry_length=entry_length,\
|
396 |
+
temperature=temperature)[0]
|
397 |
+
captions.append(gen_caption.capitalize())
|
398 |
+
return captions
|
399 |
+
|
400 |
+
def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
|
401 |
+
entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
|
402 |
+
r"""Generate captions by beam search decoding"""
|
403 |
+
self.clapcap.eval()
|
404 |
+
stop_token_index = self.tokenizer.encode(stop_token)[0]
|
405 |
+
tokens = None
|
406 |
+
scores = None
|
407 |
+
device = next(self.clapcap.parameters()).device
|
408 |
+
seq_lengths = torch.ones(beam_size, device=device)
|
409 |
+
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
|
410 |
+
with torch.no_grad():
|
411 |
+
if embed is not None:
|
412 |
+
generated = embed
|
413 |
+
else:
|
414 |
+
if tokens is None:
|
415 |
+
tokens = torch.tensor(self.tokenizer.encode(prompt))
|
416 |
+
tokens = tokens.unsqueeze(0).to(device)
|
417 |
+
generated = self.clapcap.gpt.transformer.wte(tokens)
|
418 |
+
for i in range(entry_length):
|
419 |
+
outputs = self.clapcap.gpt(inputs_embeds=generated)
|
420 |
+
logits = outputs.logits
|
421 |
+
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
422 |
+
logits = logits.softmax(-1).log()
|
423 |
+
if scores is None:
|
424 |
+
scores, next_tokens = logits.topk(beam_size, -1)
|
425 |
+
generated = generated.expand(beam_size, *generated.shape[1:])
|
426 |
+
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
|
427 |
+
if tokens is None:
|
428 |
+
tokens = next_tokens
|
429 |
+
else:
|
430 |
+
tokens = tokens.expand(beam_size, *tokens.shape[1:])
|
431 |
+
tokens = torch.cat((tokens, next_tokens), dim=1)
|
432 |
+
else:
|
433 |
+
logits[is_stopped] = -float(np.inf)
|
434 |
+
logits[is_stopped, 0] = 0
|
435 |
+
scores_sum = scores[:, None] + logits
|
436 |
+
seq_lengths[~is_stopped] += 1
|
437 |
+
scores_sum_average = scores_sum / seq_lengths[:, None]
|
438 |
+
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
|
439 |
+
next_tokens_source = next_tokens // scores_sum.shape[1]
|
440 |
+
seq_lengths = seq_lengths[next_tokens_source]
|
441 |
+
next_tokens = next_tokens % scores_sum.shape[1]
|
442 |
+
next_tokens = next_tokens.unsqueeze(1)
|
443 |
+
tokens = tokens[next_tokens_source]
|
444 |
+
tokens = torch.cat((tokens, next_tokens), dim=1)
|
445 |
+
generated = generated[next_tokens_source]
|
446 |
+
scores = scores_sum_average * seq_lengths
|
447 |
+
is_stopped = is_stopped[next_tokens_source]
|
448 |
+
next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
|
449 |
+
generated = torch.cat((generated, next_token_embed), dim=1)
|
450 |
+
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
|
451 |
+
if is_stopped.all():
|
452 |
+
break
|
453 |
+
scores = scores / seq_lengths
|
454 |
+
output_list = tokens.cpu().numpy()
|
455 |
+
output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
|
456 |
+
order = scores.argsort(descending=True)
|
457 |
+
output_texts = [output_texts[i] for i in order]
|
458 |
+
return output_texts
|
ms_clap/src/__init__.py
ADDED
File without changes
|
ms_clap/src/audio_captioning.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is an example using CLAPCAP for audio captioning.
|
3 |
+
"""
|
4 |
+
from CLAPWrapper import CLAPWrapper
|
5 |
+
|
6 |
+
# Load and initialize CLAP
|
7 |
+
weights_path = "weights_path"
|
8 |
+
clap_model = CLAPWrapper(weights_path, version = 'clapcap', use_cuda=False)
|
9 |
+
|
10 |
+
#Load audio files
|
11 |
+
audio_files = ['audio_file']
|
12 |
+
|
13 |
+
# Generate captions for the recording
|
14 |
+
captions = clap_model.generate_caption(audio_files, resample=True, beam_size=5, entry_length=67, temperature=0.01)
|
15 |
+
|
16 |
+
# Print the result
|
17 |
+
for i in range(len(audio_files)):
|
18 |
+
print(f"Audio file: {audio_files[i]} \n")
|
19 |
+
print(f"Generated caption: {captions[i]} \n")
|
20 |
+
|
21 |
+
"""
|
22 |
+
The output (the exact caption may vary):
|
23 |
+
|
24 |
+
The birds are singing in the trees.
|
25 |
+
"""
|
ms_clap/src/configs/config_2022.yml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TEXT ENCODER CONFIG
|
2 |
+
text_model: 'bert-base-uncased'
|
3 |
+
text_len: 100
|
4 |
+
transformer_embed_dim: 768
|
5 |
+
freeze_text_encoder_weights: True
|
6 |
+
|
7 |
+
# AUDIO ENCODER CONFIG
|
8 |
+
audioenc_name: 'Cnn14'
|
9 |
+
out_emb: 2048
|
10 |
+
sampling_rate: 44100
|
11 |
+
duration: 5
|
12 |
+
fmin: 50
|
13 |
+
fmax: 14000
|
14 |
+
n_fft: 1028
|
15 |
+
hop_size: 320
|
16 |
+
mel_bins: 64
|
17 |
+
window_size: 1024
|
18 |
+
|
19 |
+
# PROJECTION SPACE CONFIG
|
20 |
+
d_proj: 1024
|
21 |
+
temperature: 0.003
|
22 |
+
|
23 |
+
# TRAINING AND EVALUATION CONFIG
|
24 |
+
num_classes: 527
|
25 |
+
batch_size: 1024
|
26 |
+
demo: False
|
ms_clap/src/configs/config_2023.yml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TEXT ENCODER CONFIG
|
2 |
+
text_model: 'gpt2'
|
3 |
+
text_len: 77
|
4 |
+
transformer_embed_dim: 768
|
5 |
+
freeze_text_encoder_weights: True
|
6 |
+
|
7 |
+
# AUDIO ENCODER CONFIG
|
8 |
+
audioenc_name: 'HTSAT'
|
9 |
+
out_emb: 768
|
10 |
+
sampling_rate: 44100
|
11 |
+
duration: 7
|
12 |
+
fmin: 50
|
13 |
+
fmax: 8000 #14000
|
14 |
+
n_fft: 1024 # 1028
|
15 |
+
hop_size: 320
|
16 |
+
mel_bins: 64
|
17 |
+
window_size: 1024
|
18 |
+
|
19 |
+
# PROJECTION SPACE CONFIG
|
20 |
+
d_proj: 1024
|
21 |
+
temperature: 0.003
|
22 |
+
|
23 |
+
# TRAINING AND EVALUATION CONFIG
|
24 |
+
num_classes: 527
|
25 |
+
batch_size: 1024
|
26 |
+
demo: False
|
ms_clap/src/configs/config_clapcap.yml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TEXT ENCODER CONFIG
|
2 |
+
text_model: 'gpt2'
|
3 |
+
transformer_embed_dim: 768
|
4 |
+
freeze_text_encoder_weights: True
|
5 |
+
|
6 |
+
# AUDIO ENCODER CONFIG
|
7 |
+
audioenc_name: 'HTSAT'
|
8 |
+
out_emb: 768
|
9 |
+
sampling_rate: 44100
|
10 |
+
duration: 7
|
11 |
+
fmin: 50
|
12 |
+
fmax: 8000
|
13 |
+
n_fft: 1024
|
14 |
+
hop_size: 320
|
15 |
+
mel_bins: 64
|
16 |
+
window_size: 1024
|
17 |
+
|
18 |
+
# PROJECTION SPACE CONFIG
|
19 |
+
d_proj: 1024
|
20 |
+
temperature: 0.003
|
21 |
+
|
22 |
+
# TRAINING AND EVALUATION CONFIG
|
23 |
+
batch_size: 128
|
24 |
+
num_classes: 527
|
25 |
+
|
26 |
+
# CLAPCAP CONFIG
|
27 |
+
clapcap_model: 'ClapCaption'
|
28 |
+
text_decoder: 'gpt2'
|
29 |
+
prefix_length: 40
|
30 |
+
prefix_length_clip: 40
|
31 |
+
mapping_type: 'transformer'
|
32 |
+
num_layers: 8
|
33 |
+
normalize_prefix: True
|
34 |
+
freeze_gpt_weights: True
|
ms_clap/src/esc50_dataset.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from torchvision.datasets.utils import download_url
|
3 |
+
from tqdm import tqdm
|
4 |
+
import pandas as pd
|
5 |
+
import os
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch
|
8 |
+
|
9 |
+
class AudioDataset(Dataset):
|
10 |
+
def __init__(self, root: str, download: bool = True):
|
11 |
+
self.root = os.path.expanduser(root)
|
12 |
+
if download:
|
13 |
+
self.download()
|
14 |
+
|
15 |
+
def __getitem__(self, index):
|
16 |
+
raise NotImplementedError
|
17 |
+
|
18 |
+
def download(self):
|
19 |
+
raise NotImplementedError
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
raise NotImplementedError
|
23 |
+
|
24 |
+
|
25 |
+
class ESC50(AudioDataset):
|
26 |
+
base_folder = 'ESC-50-master'
|
27 |
+
url = "https://github.com/karoldvl/ESC-50/archive/master.zip"
|
28 |
+
filename = "ESC-50-master.zip"
|
29 |
+
num_files_in_dir = 2000
|
30 |
+
audio_dir = 'audio'
|
31 |
+
label_col = 'category'
|
32 |
+
file_col = 'filename'
|
33 |
+
meta = {
|
34 |
+
'filename': os.path.join('meta','esc50.csv'),
|
35 |
+
}
|
36 |
+
|
37 |
+
def __init__(self, root, reading_transformations: nn.Module = None, download: bool = True):
|
38 |
+
super().__init__(root)
|
39 |
+
self._load_meta()
|
40 |
+
|
41 |
+
self.targets, self.audio_paths = [], []
|
42 |
+
self.pre_transformations = reading_transformations
|
43 |
+
print("Loading audio files")
|
44 |
+
# self.df['filename'] = os.path.join(self.root, self.base_folder, self.audio_dir) + os.sep + self.df['filename']
|
45 |
+
self.df['category'] = self.df['category'].str.replace('_',' ')
|
46 |
+
|
47 |
+
for _, row in tqdm(self.df.iterrows()):
|
48 |
+
file_path = os.path.join(self.root, self.base_folder, self.audio_dir, row[self.file_col])
|
49 |
+
self.targets.append(row[self.label_col])
|
50 |
+
self.audio_paths.append(file_path)
|
51 |
+
|
52 |
+
def _load_meta(self):
|
53 |
+
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
|
54 |
+
|
55 |
+
self.df = pd.read_csv(path)
|
56 |
+
self.class_to_idx = {}
|
57 |
+
self.classes = [x.replace('_',' ') for x in sorted(self.df[self.label_col].unique())]
|
58 |
+
for i, category in enumerate(self.classes):
|
59 |
+
self.class_to_idx[category] = i
|
60 |
+
|
61 |
+
def __getitem__(self, index):
|
62 |
+
"""
|
63 |
+
Args:
|
64 |
+
index (int): Index
|
65 |
+
Returns:
|
66 |
+
tuple: (image, target) where target is index of the target class.
|
67 |
+
"""
|
68 |
+
file_path, target = self.audio_paths[index], self.targets[index]
|
69 |
+
idx = torch.tensor(self.class_to_idx[target])
|
70 |
+
one_hot_target = torch.zeros(len(self.classes)).scatter_(0, idx, 1).reshape(1,-1)
|
71 |
+
return file_path, target, one_hot_target
|
72 |
+
|
73 |
+
def __len__(self):
|
74 |
+
return len(self.audio_paths)
|
75 |
+
|
76 |
+
def download(self):
|
77 |
+
download_url(self.url, self.root, self.filename)
|
78 |
+
|
79 |
+
# extract file
|
80 |
+
from zipfile import ZipFile
|
81 |
+
with ZipFile(os.path.join(self.root, self.filename), 'r') as zip:
|
82 |
+
zip.extractall(path=self.root)
|
ms_clap/src/models/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import clap
|
2 |
+
from . import audio
|
3 |
+
from . import htsat
|
4 |
+
from . import config
|
5 |
+
from . import pytorch_utils
|
6 |
+
from . import htsat
|
ms_clap/src/models/audio.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
5 |
+
|
6 |
+
try:
|
7 |
+
from models.htsat import HTSATWrapper
|
8 |
+
except:
|
9 |
+
from .htsat import HTSATWrapper
|
10 |
+
|
11 |
+
def get_audio_encoder(name: str):
|
12 |
+
if name == "Cnn14":
|
13 |
+
return Cnn14
|
14 |
+
elif name == "HTSAT":
|
15 |
+
return HTSATWrapper
|
16 |
+
else:
|
17 |
+
raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
|
18 |
+
|
19 |
+
|
20 |
+
class ConvBlock(nn.Module):
|
21 |
+
def __init__(self, in_channels, out_channels):
|
22 |
+
|
23 |
+
super(ConvBlock, self).__init__()
|
24 |
+
|
25 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
26 |
+
out_channels=out_channels,
|
27 |
+
kernel_size=(3, 3), stride=(1, 1),
|
28 |
+
padding=(1, 1), bias=False)
|
29 |
+
|
30 |
+
self.conv2 = nn.Conv2d(in_channels=out_channels,
|
31 |
+
out_channels=out_channels,
|
32 |
+
kernel_size=(3, 3), stride=(1, 1),
|
33 |
+
padding=(1, 1), bias=False)
|
34 |
+
|
35 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
36 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
37 |
+
|
38 |
+
|
39 |
+
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
|
40 |
+
|
41 |
+
x = input
|
42 |
+
x = F.relu_(self.bn1(self.conv1(x)))
|
43 |
+
x = F.relu_(self.bn2(self.conv2(x)))
|
44 |
+
if pool_type == 'max':
|
45 |
+
x = F.max_pool2d(x, kernel_size=pool_size)
|
46 |
+
elif pool_type == 'avg':
|
47 |
+
x = F.avg_pool2d(x, kernel_size=pool_size)
|
48 |
+
elif pool_type == 'avg+max':
|
49 |
+
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
50 |
+
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
51 |
+
x = x1 + x2
|
52 |
+
else:
|
53 |
+
raise Exception('Incorrect argument!')
|
54 |
+
|
55 |
+
return x
|
56 |
+
|
57 |
+
|
58 |
+
class ConvBlock5x5(nn.Module):
|
59 |
+
def __init__(self, in_channels, out_channels):
|
60 |
+
|
61 |
+
super(ConvBlock5x5, self).__init__()
|
62 |
+
|
63 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
64 |
+
out_channels=out_channels,
|
65 |
+
kernel_size=(5, 5), stride=(1, 1),
|
66 |
+
padding=(2, 2), bias=False)
|
67 |
+
|
68 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
69 |
+
|
70 |
+
|
71 |
+
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
|
72 |
+
|
73 |
+
x = input
|
74 |
+
x = F.relu_(self.bn1(self.conv1(x)))
|
75 |
+
if pool_type == 'max':
|
76 |
+
x = F.max_pool2d(x, kernel_size=pool_size)
|
77 |
+
elif pool_type == 'avg':
|
78 |
+
x = F.avg_pool2d(x, kernel_size=pool_size)
|
79 |
+
elif pool_type == 'avg+max':
|
80 |
+
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
81 |
+
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
82 |
+
x = x1 + x2
|
83 |
+
else:
|
84 |
+
raise Exception('Incorrect argument!')
|
85 |
+
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class AttBlock(nn.Module):
|
90 |
+
def __init__(self, n_in, n_out, activation='linear', temperature=1.):
|
91 |
+
super(AttBlock, self).__init__()
|
92 |
+
|
93 |
+
self.activation = activation
|
94 |
+
self.temperature = temperature
|
95 |
+
self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
|
96 |
+
self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
|
97 |
+
|
98 |
+
self.bn_att = nn.BatchNorm1d(n_out)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
# x: (n_samples, n_in, n_time)
|
102 |
+
norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
|
103 |
+
cla = self.nonlinear_transform(self.cla(x))
|
104 |
+
x = torch.sum(norm_att * cla, dim=2)
|
105 |
+
return x, norm_att, cla
|
106 |
+
|
107 |
+
def nonlinear_transform(self, x):
|
108 |
+
if self.activation == 'linear':
|
109 |
+
return x
|
110 |
+
elif self.activation == 'sigmoid':
|
111 |
+
return torch.sigmoid(x)
|
112 |
+
|
113 |
+
|
114 |
+
class Cnn14(nn.Module):
|
115 |
+
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
|
116 |
+
fmax, classes_num, out_emb):
|
117 |
+
|
118 |
+
super(Cnn14, self).__init__()
|
119 |
+
|
120 |
+
window = 'hann'
|
121 |
+
center = True
|
122 |
+
pad_mode = 'reflect'
|
123 |
+
ref = 1.0
|
124 |
+
amin = 1e-10
|
125 |
+
top_db = None
|
126 |
+
|
127 |
+
# Spectrogram extractor
|
128 |
+
self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
|
129 |
+
win_length=window_size, window=window, center=center, pad_mode=pad_mode,
|
130 |
+
freeze_parameters=True)
|
131 |
+
|
132 |
+
# Logmel feature extractor
|
133 |
+
self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
|
134 |
+
n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
|
135 |
+
freeze_parameters=True)
|
136 |
+
|
137 |
+
self.bn0 = nn.BatchNorm2d(64)
|
138 |
+
|
139 |
+
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
140 |
+
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
141 |
+
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
142 |
+
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
143 |
+
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
|
144 |
+
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
|
145 |
+
|
146 |
+
# out_emb is 2048 for best Cnn14
|
147 |
+
self.fc1 = nn.Linear(2048, out_emb, bias=True)
|
148 |
+
self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True)
|
149 |
+
|
150 |
+
def forward(self, input, mixup_lambda=None):
|
151 |
+
"""
|
152 |
+
Input: (batch_size, data_length)
|
153 |
+
"""
|
154 |
+
|
155 |
+
x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
|
156 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
157 |
+
|
158 |
+
x = x.transpose(1, 3)
|
159 |
+
x = self.bn0(x)
|
160 |
+
x = x.transpose(1, 3)
|
161 |
+
|
162 |
+
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
|
163 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
164 |
+
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
|
165 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
166 |
+
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
|
167 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
168 |
+
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
|
169 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
170 |
+
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
|
171 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
172 |
+
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
|
173 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
174 |
+
x = torch.mean(x, dim=3)
|
175 |
+
|
176 |
+
(x1, _) = torch.max(x, dim=2)
|
177 |
+
x2 = torch.mean(x, dim=2)
|
178 |
+
x = x1 + x2
|
179 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
180 |
+
x = F.relu_(self.fc1(x))
|
181 |
+
embedding = F.dropout(x, p=0.5, training=self.training)
|
182 |
+
clipwise_output = torch.sigmoid(self.fc_audioset(x))
|
183 |
+
|
184 |
+
output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}
|
185 |
+
|
186 |
+
return output_dict
|
ms_clap/src/models/clap.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
from transformers import AutoModel
|
6 |
+
from .audio import get_audio_encoder
|
7 |
+
|
8 |
+
class Projection(nn.Module):
|
9 |
+
def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
|
10 |
+
super().__init__()
|
11 |
+
self.linear1 = nn.Linear(d_in, d_out, bias=False)
|
12 |
+
self.linear2 = nn.Linear(d_out, d_out, bias=False)
|
13 |
+
self.layer_norm = nn.LayerNorm(d_out)
|
14 |
+
self.drop = nn.Dropout(p)
|
15 |
+
|
16 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
17 |
+
embed1 = self.linear1(x)
|
18 |
+
embed2 = self.drop(self.linear2(F.gelu(embed1)))
|
19 |
+
embeds = self.layer_norm(embed1 + embed2)
|
20 |
+
return embeds
|
21 |
+
|
22 |
+
class AudioEncoder(nn.Module):
|
23 |
+
def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
|
24 |
+
hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
audio_encoder = get_audio_encoder(audioenc_name)
|
28 |
+
|
29 |
+
self.base = audio_encoder(
|
30 |
+
sample_rate, window_size,
|
31 |
+
hop_size, mel_bins, fmin, fmax,
|
32 |
+
classes_num, d_in)
|
33 |
+
|
34 |
+
self.projection = Projection(d_in, d_out)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
out_dict = self.base(x)
|
38 |
+
audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
|
39 |
+
projected_vec = self.projection(audio_features)
|
40 |
+
return projected_vec, audio_classification_output
|
41 |
+
|
42 |
+
class TextEncoder(nn.Module):
|
43 |
+
def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
|
44 |
+
super().__init__()
|
45 |
+
self.text_model = text_model
|
46 |
+
self.base = AutoModel.from_pretrained(text_model)
|
47 |
+
|
48 |
+
if 'clip' in text_model:
|
49 |
+
self.clip_text_projection = self.base.text_projection
|
50 |
+
self.base = self.base.text_model
|
51 |
+
if 'base' in text_model:
|
52 |
+
transformer_embed_dim = 512
|
53 |
+
|
54 |
+
self.projection = Projection(transformer_embed_dim, d_out)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
if 'clip' in self.text_model:
|
58 |
+
pooled_output = self.base(**x)[1] # get pooled output
|
59 |
+
out = self.clip_text_projection(pooled_output) # get CLS token output
|
60 |
+
elif 'gpt' in self.text_model:
|
61 |
+
batch_size = x['input_ids'].shape[0]
|
62 |
+
hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
|
63 |
+
|
64 |
+
sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
|
65 |
+
out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
|
66 |
+
else:
|
67 |
+
out = self.base(**x)[0]
|
68 |
+
out = out[:, 0, :] # get CLS token output
|
69 |
+
|
70 |
+
projected_vec = self.projection(out)
|
71 |
+
|
72 |
+
return projected_vec
|
73 |
+
|
74 |
+
class CLAP(nn.Module):
|
75 |
+
def __init__(self,
|
76 |
+
# audio
|
77 |
+
audioenc_name: str,
|
78 |
+
sample_rate: int,
|
79 |
+
window_size: int,
|
80 |
+
hop_size: int,
|
81 |
+
mel_bins: int,
|
82 |
+
fmin: int,
|
83 |
+
fmax: int,
|
84 |
+
classes_num: int,
|
85 |
+
out_emb: int,
|
86 |
+
# text
|
87 |
+
text_model: str,
|
88 |
+
transformer_embed_dim: int,
|
89 |
+
# common
|
90 |
+
d_proj: int,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
|
95 |
+
self.audio_encoder = AudioEncoder(
|
96 |
+
audioenc_name, out_emb, d_proj,
|
97 |
+
sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
|
98 |
+
|
99 |
+
self.caption_encoder = TextEncoder(
|
100 |
+
d_proj, text_model, transformer_embed_dim
|
101 |
+
)
|
102 |
+
|
103 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
104 |
+
|
105 |
+
def forward(self, audio, text):
|
106 |
+
audio_embed, _ = self.audio_encoder(audio)
|
107 |
+
caption_embed = self.caption_encoder(text)
|
108 |
+
|
109 |
+
return caption_embed, audio_embed, self.logit_scale.exp()
|
ms_clap/src/models/config.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
|
4 |
+
# The configuration for training the model
|
5 |
+
|
6 |
+
exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
|
7 |
+
workspace = "/home/kechen/Research/HTSAT" # the folder of your code
|
8 |
+
dataset_path = "/home/Research/audioset" # the dataset path
|
9 |
+
desed_folder = "/home/Research/DESED" # the desed file
|
10 |
+
|
11 |
+
dataset_type = "audioset" # "audioset" "esc-50" "scv2"
|
12 |
+
index_type = "full_train" # only works for audioset
|
13 |
+
balanced_data = True # only works for audioset
|
14 |
+
|
15 |
+
loss_type = "clip_bce" #
|
16 |
+
# AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
|
17 |
+
|
18 |
+
# trained from a checkpoint, or evaluate a single model
|
19 |
+
resume_checkpoint = None
|
20 |
+
# "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
|
21 |
+
|
22 |
+
esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
|
23 |
+
|
24 |
+
|
25 |
+
debug = False
|
26 |
+
|
27 |
+
random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
|
28 |
+
batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
|
29 |
+
learning_rate = 1e-3 # 1e-4 also workable
|
30 |
+
max_epoch = 100
|
31 |
+
num_workers = 3
|
32 |
+
|
33 |
+
lr_scheduler_epoch = [10,20,30]
|
34 |
+
lr_rate = [0.02, 0.05, 0.1]
|
35 |
+
|
36 |
+
# these data preparation optimizations do not bring many improvements, so deprecated
|
37 |
+
enable_token_label = False # token label
|
38 |
+
class_map_path = "class_hier_map.npy"
|
39 |
+
class_filter = None
|
40 |
+
retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
|
41 |
+
9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
|
42 |
+
token_label_range = [0.2,0.6]
|
43 |
+
enable_time_shift = False # shift time
|
44 |
+
enable_label_enhance = False # enhance hierarchical label
|
45 |
+
enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
# for model's design
|
50 |
+
enable_tscam = True # enbale the token-semantic layer
|
51 |
+
|
52 |
+
# for signal processing
|
53 |
+
sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
|
54 |
+
clip_samples = sample_rate * 10 # audio_set 10-sec clip
|
55 |
+
window_size = 1024
|
56 |
+
hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
|
57 |
+
mel_bins = 64
|
58 |
+
fmin = 50
|
59 |
+
fmax = 14000
|
60 |
+
shift_max = int(clip_samples * 0.5)
|
61 |
+
|
62 |
+
# for data collection
|
63 |
+
classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
|
64 |
+
patch_size = (25, 4) # deprecated
|
65 |
+
crop_size = None # int(clip_samples * 0.5) deprecated
|
66 |
+
|
67 |
+
# for htsat hyperparamater
|
68 |
+
htsat_window_size = 8
|
69 |
+
htsat_spec_size = 256
|
70 |
+
htsat_patch_size = 4
|
71 |
+
htsat_stride = (4, 4)
|
72 |
+
htsat_num_head = [4,8,16,32]
|
73 |
+
htsat_dim = 96
|
74 |
+
htsat_depth = [2,2,6,2]
|
75 |
+
|
76 |
+
swin_pretrain_path = None
|
77 |
+
# "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
|
78 |
+
|
79 |
+
# Some Deprecated Optimization in the model design, check the model code for details
|
80 |
+
htsat_attn_heatmap = False
|
81 |
+
htsat_hier_output = False
|
82 |
+
htsat_use_max = False
|
83 |
+
|
84 |
+
|
85 |
+
# for ensemble test
|
86 |
+
|
87 |
+
ensemble_checkpoints = []
|
88 |
+
ensemble_strides = []
|
89 |
+
|
90 |
+
|
91 |
+
# weight average folder
|
92 |
+
wa_folder = "/home/version_0/checkpoints/"
|
93 |
+
# weight average output filename
|
94 |
+
wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
|
95 |
+
|
96 |
+
esm_model_pathes = [
|
97 |
+
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
|
98 |
+
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
|
99 |
+
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
|
100 |
+
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
|
101 |
+
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
|
102 |
+
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
|
103 |
+
]
|
104 |
+
|
105 |
+
# for framewise localization
|
106 |
+
heatmap_dir = "/home/Research/heatmap_output"
|
107 |
+
test_file = "htsat-test-ensemble"
|
108 |
+
fl_local = False # indicate if we need to use this dataset for the framewise detection
|
109 |
+
fl_dataset = "/home/Research/desed/desed_eval.npy"
|
110 |
+
fl_class_num = [
|
111 |
+
"Speech", "Frying", "Dishes", "Running_water",
|
112 |
+
"Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
|
113 |
+
"Cat", "Dog", "Vacuum_cleaner"
|
114 |
+
]
|
115 |
+
|
116 |
+
# map 527 classes into 10 classes
|
117 |
+
fl_audioset_mapping = [
|
118 |
+
[0,1,2,3,4,5,6,7],
|
119 |
+
[366, 367, 368],
|
120 |
+
[364],
|
121 |
+
[288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
|
122 |
+
[369],
|
123 |
+
[382],
|
124 |
+
[310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
|
125 |
+
[81, 82, 83, 84, 85],
|
126 |
+
[74, 75, 76, 77, 78, 79],
|
127 |
+
[377]
|
128 |
+
]
|
ms_clap/src/models/htsat.py
ADDED
@@ -0,0 +1,956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
|
4 |
+
# Model Core
|
5 |
+
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
|
6 |
+
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
|
7 |
+
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import pdb
|
11 |
+
import math
|
12 |
+
import random
|
13 |
+
from numpy.core.fromnumeric import clip, reshape
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.utils.checkpoint as checkpoint
|
17 |
+
|
18 |
+
# import os
|
19 |
+
import sys
|
20 |
+
sys.path.append('/home/zkong/audio_flamingo/audio_flamingo_v1/v0.2/open_flamingo/my_ms_clap/models')
|
21 |
+
|
22 |
+
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
23 |
+
from torchlibrosa.augmentation import SpecAugmentation
|
24 |
+
|
25 |
+
from itertools import repeat
|
26 |
+
from typing import List
|
27 |
+
try:
|
28 |
+
from models.pytorch_utils import do_mixup, interpolate
|
29 |
+
import models.config as config
|
30 |
+
except:
|
31 |
+
from .pytorch_utils import do_mixup, interpolate
|
32 |
+
from . import config
|
33 |
+
# from CLAP_API.models.pytorch_utils import do_mixup, interpolate
|
34 |
+
# from CLAP_API.models import config
|
35 |
+
|
36 |
+
import torch.nn.functional as F
|
37 |
+
import collections.abc
|
38 |
+
import warnings
|
39 |
+
|
40 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
41 |
+
|
42 |
+
def _ntuple(n):
|
43 |
+
def parse(x):
|
44 |
+
if isinstance(x, collections.abc.Iterable):
|
45 |
+
return x
|
46 |
+
return tuple(repeat(x, n))
|
47 |
+
return parse
|
48 |
+
|
49 |
+
to_1tuple = _ntuple(1)
|
50 |
+
to_2tuple = _ntuple(2)
|
51 |
+
to_3tuple = _ntuple(3)
|
52 |
+
to_4tuple = _ntuple(4)
|
53 |
+
to_ntuple = _ntuple
|
54 |
+
|
55 |
+
|
56 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
57 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
58 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
59 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
60 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
61 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
62 |
+
'survival rate' as the argument.
|
63 |
+
"""
|
64 |
+
if drop_prob == 0. or not training:
|
65 |
+
return x
|
66 |
+
keep_prob = 1 - drop_prob
|
67 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
68 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
69 |
+
random_tensor.floor_() # binarize
|
70 |
+
output = x.div(keep_prob) * random_tensor
|
71 |
+
return output
|
72 |
+
|
73 |
+
|
74 |
+
class DropPath(nn.Module):
|
75 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
76 |
+
"""
|
77 |
+
def __init__(self, drop_prob=None):
|
78 |
+
super(DropPath, self).__init__()
|
79 |
+
self.drop_prob = drop_prob
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
return drop_path(x, self.drop_prob, self.training)
|
83 |
+
|
84 |
+
class PatchEmbed(nn.Module):
|
85 |
+
""" 2D Image to Patch Embedding
|
86 |
+
"""
|
87 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
|
88 |
+
super().__init__()
|
89 |
+
img_size = to_2tuple(img_size)
|
90 |
+
patch_size = to_2tuple(patch_size)
|
91 |
+
patch_stride = to_2tuple(patch_stride)
|
92 |
+
self.img_size = img_size
|
93 |
+
self.patch_size = patch_size
|
94 |
+
self.patch_stride = patch_stride
|
95 |
+
self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
|
96 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
97 |
+
self.flatten = flatten
|
98 |
+
self.in_chans = in_chans
|
99 |
+
self.embed_dim = embed_dim
|
100 |
+
|
101 |
+
padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
|
102 |
+
|
103 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
|
104 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
B, C, H, W = x.shape
|
108 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
109 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
110 |
+
x = self.proj(x)
|
111 |
+
if self.flatten:
|
112 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
113 |
+
x = self.norm(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
class Mlp(nn.Module):
|
117 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
118 |
+
"""
|
119 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
120 |
+
super().__init__()
|
121 |
+
out_features = out_features or in_features
|
122 |
+
hidden_features = hidden_features or in_features
|
123 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
124 |
+
self.act = act_layer()
|
125 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
126 |
+
self.drop = nn.Dropout(drop)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
x = self.fc1(x)
|
130 |
+
x = self.act(x)
|
131 |
+
x = self.drop(x)
|
132 |
+
x = self.fc2(x)
|
133 |
+
x = self.drop(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
137 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
138 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
139 |
+
def norm_cdf(x):
|
140 |
+
# Computes standard normal cumulative distribution function
|
141 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
142 |
+
|
143 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
144 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
145 |
+
"The distribution of values may be incorrect.",
|
146 |
+
stacklevel=2)
|
147 |
+
|
148 |
+
with torch.no_grad():
|
149 |
+
# Values are generated by using a truncated uniform distribution and
|
150 |
+
# then using the inverse CDF for the normal distribution.
|
151 |
+
# Get upper and lower cdf values
|
152 |
+
l = norm_cdf((a - mean) / std)
|
153 |
+
u = norm_cdf((b - mean) / std)
|
154 |
+
|
155 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
156 |
+
# [2l-1, 2u-1].
|
157 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
158 |
+
|
159 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
160 |
+
# standard normal
|
161 |
+
tensor.erfinv_()
|
162 |
+
|
163 |
+
# Transform to proper mean, std
|
164 |
+
tensor.mul_(std * math.sqrt(2.))
|
165 |
+
tensor.add_(mean)
|
166 |
+
|
167 |
+
# Clamp to ensure it's in the proper range
|
168 |
+
tensor.clamp_(min=a, max=b)
|
169 |
+
return tensor
|
170 |
+
|
171 |
+
|
172 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
173 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
174 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
175 |
+
normal distribution. The values are effectively drawn from the
|
176 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
177 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
178 |
+
the bounds. The method used for generating the random values works
|
179 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
180 |
+
Args:
|
181 |
+
tensor: an n-dimensional `torch.Tensor`
|
182 |
+
mean: the mean of the normal distribution
|
183 |
+
std: the standard deviation of the normal distribution
|
184 |
+
a: the minimum cutoff value
|
185 |
+
b: the maximum cutoff value
|
186 |
+
Examples:
|
187 |
+
>>> w = torch.empty(3, 5)
|
188 |
+
>>> nn.init.trunc_normal_(w)
|
189 |
+
"""
|
190 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
191 |
+
|
192 |
+
|
193 |
+
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
|
194 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
195 |
+
if mode == 'fan_in':
|
196 |
+
denom = fan_in
|
197 |
+
elif mode == 'fan_out':
|
198 |
+
denom = fan_out
|
199 |
+
elif mode == 'fan_avg':
|
200 |
+
denom = (fan_in + fan_out) / 2
|
201 |
+
|
202 |
+
variance = scale / denom
|
203 |
+
|
204 |
+
if distribution == "truncated_normal":
|
205 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
206 |
+
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
|
207 |
+
elif distribution == "normal":
|
208 |
+
tensor.normal_(std=math.sqrt(variance))
|
209 |
+
elif distribution == "uniform":
|
210 |
+
bound = math.sqrt(3 * variance)
|
211 |
+
tensor.uniform_(-bound, bound)
|
212 |
+
else:
|
213 |
+
raise ValueError(f"invalid distribution {distribution}")
|
214 |
+
|
215 |
+
|
216 |
+
def lecun_normal_(tensor):
|
217 |
+
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
|
218 |
+
|
219 |
+
|
220 |
+
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
|
221 |
+
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
|
222 |
+
|
223 |
+
def window_partition(x, window_size):
|
224 |
+
"""
|
225 |
+
Args:
|
226 |
+
x: (B, H, W, C)
|
227 |
+
window_size (int): window size
|
228 |
+
Returns:
|
229 |
+
windows: (num_windows*B, window_size, window_size, C)
|
230 |
+
"""
|
231 |
+
B, H, W, C = x.shape
|
232 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
233 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
234 |
+
return windows
|
235 |
+
|
236 |
+
|
237 |
+
def window_reverse(windows, window_size, H, W):
|
238 |
+
"""
|
239 |
+
Args:
|
240 |
+
windows: (num_windows*B, window_size, window_size, C)
|
241 |
+
window_size (int): Window size
|
242 |
+
H (int): Height of image
|
243 |
+
W (int): Width of image
|
244 |
+
Returns:
|
245 |
+
x: (B, H, W, C)
|
246 |
+
"""
|
247 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
248 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
249 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
250 |
+
return x
|
251 |
+
|
252 |
+
|
253 |
+
class WindowAttention(nn.Module):
|
254 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
255 |
+
It supports both of shifted and non-shifted window.
|
256 |
+
Args:
|
257 |
+
dim (int): Number of input channels.
|
258 |
+
window_size (tuple[int]): The height and width of the window.
|
259 |
+
num_heads (int): Number of attention heads.
|
260 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
261 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
262 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
263 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
264 |
+
"""
|
265 |
+
|
266 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
267 |
+
|
268 |
+
super().__init__()
|
269 |
+
self.dim = dim
|
270 |
+
self.window_size = window_size # Wh, Ww
|
271 |
+
self.num_heads = num_heads
|
272 |
+
head_dim = dim // num_heads
|
273 |
+
self.scale = qk_scale or head_dim ** -0.5
|
274 |
+
|
275 |
+
# define a parameter table of relative position bias
|
276 |
+
self.relative_position_bias_table = nn.Parameter(
|
277 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
278 |
+
|
279 |
+
# get pair-wise relative position index for each token inside the window
|
280 |
+
coords_h = torch.arange(self.window_size[0])
|
281 |
+
coords_w = torch.arange(self.window_size[1])
|
282 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
283 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
284 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
285 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
286 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
287 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
288 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
289 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
290 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
291 |
+
|
292 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
293 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
294 |
+
self.proj = nn.Linear(dim, dim)
|
295 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
296 |
+
|
297 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
298 |
+
self.softmax = nn.Softmax(dim=-1)
|
299 |
+
|
300 |
+
def forward(self, x, mask=None):
|
301 |
+
"""
|
302 |
+
Args:
|
303 |
+
x: input features with shape of (num_windows*B, N, C)
|
304 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
305 |
+
"""
|
306 |
+
B_, N, C = x.shape
|
307 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
308 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
309 |
+
|
310 |
+
q = q * self.scale
|
311 |
+
attn = (q @ k.transpose(-2, -1))
|
312 |
+
|
313 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
314 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
315 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
316 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
317 |
+
|
318 |
+
if mask is not None:
|
319 |
+
nW = mask.shape[0]
|
320 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
321 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
322 |
+
attn = self.softmax(attn)
|
323 |
+
else:
|
324 |
+
attn = self.softmax(attn)
|
325 |
+
|
326 |
+
attn = self.attn_drop(attn)
|
327 |
+
|
328 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
329 |
+
x = self.proj(x)
|
330 |
+
x = self.proj_drop(x)
|
331 |
+
return x, attn
|
332 |
+
|
333 |
+
def extra_repr(self):
|
334 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
335 |
+
|
336 |
+
|
337 |
+
# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
|
338 |
+
class SwinTransformerBlock(nn.Module):
|
339 |
+
r""" Swin Transformer Block.
|
340 |
+
Args:
|
341 |
+
dim (int): Number of input channels.
|
342 |
+
input_resolution (tuple[int]): Input resulotion.
|
343 |
+
num_heads (int): Number of attention heads.
|
344 |
+
window_size (int): Window size.
|
345 |
+
shift_size (int): Shift size for SW-MSA.
|
346 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
347 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
348 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
349 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
350 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
351 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
352 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
353 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
354 |
+
"""
|
355 |
+
|
356 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
357 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
358 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
|
359 |
+
super().__init__()
|
360 |
+
self.dim = dim
|
361 |
+
self.input_resolution = input_resolution
|
362 |
+
self.num_heads = num_heads
|
363 |
+
self.window_size = window_size
|
364 |
+
self.shift_size = shift_size
|
365 |
+
self.mlp_ratio = mlp_ratio
|
366 |
+
self.norm_before_mlp = norm_before_mlp
|
367 |
+
if min(self.input_resolution) <= self.window_size:
|
368 |
+
# if window size is larger than input resolution, we don't partition windows
|
369 |
+
self.shift_size = 0
|
370 |
+
self.window_size = min(self.input_resolution)
|
371 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
372 |
+
|
373 |
+
self.norm1 = norm_layer(dim)
|
374 |
+
self.attn = WindowAttention(
|
375 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
376 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
377 |
+
|
378 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
379 |
+
if self.norm_before_mlp == 'ln':
|
380 |
+
self.norm2 = nn.LayerNorm(dim)
|
381 |
+
elif self.norm_before_mlp == 'bn':
|
382 |
+
self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
|
383 |
+
else:
|
384 |
+
raise NotImplementedError
|
385 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
386 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
387 |
+
|
388 |
+
if self.shift_size > 0:
|
389 |
+
# calculate attention mask for SW-MSA
|
390 |
+
H, W = self.input_resolution
|
391 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
392 |
+
h_slices = (slice(0, -self.window_size),
|
393 |
+
slice(-self.window_size, -self.shift_size),
|
394 |
+
slice(-self.shift_size, None))
|
395 |
+
w_slices = (slice(0, -self.window_size),
|
396 |
+
slice(-self.window_size, -self.shift_size),
|
397 |
+
slice(-self.shift_size, None))
|
398 |
+
cnt = 0
|
399 |
+
for h in h_slices:
|
400 |
+
for w in w_slices:
|
401 |
+
img_mask[:, h, w, :] = cnt
|
402 |
+
cnt += 1
|
403 |
+
|
404 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
405 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
406 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
407 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
408 |
+
else:
|
409 |
+
attn_mask = None
|
410 |
+
|
411 |
+
self.register_buffer("attn_mask", attn_mask)
|
412 |
+
|
413 |
+
def forward(self, x):
|
414 |
+
# pdb.set_trace()
|
415 |
+
H, W = self.input_resolution
|
416 |
+
# print("H: ", H)
|
417 |
+
# print("W: ", W)
|
418 |
+
# pdb.set_trace()
|
419 |
+
B, L, C = x.shape
|
420 |
+
# assert L == H * W, "input feature has wrong size"
|
421 |
+
|
422 |
+
shortcut = x
|
423 |
+
x = self.norm1(x)
|
424 |
+
x = x.view(B, H, W, C)
|
425 |
+
|
426 |
+
# cyclic shift
|
427 |
+
if self.shift_size > 0:
|
428 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
429 |
+
else:
|
430 |
+
shifted_x = x
|
431 |
+
|
432 |
+
# partition windows
|
433 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
434 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
435 |
+
|
436 |
+
# W-MSA/SW-MSA
|
437 |
+
attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
438 |
+
|
439 |
+
# merge windows
|
440 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
441 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
442 |
+
|
443 |
+
# reverse cyclic shift
|
444 |
+
if self.shift_size > 0:
|
445 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
446 |
+
else:
|
447 |
+
x = shifted_x
|
448 |
+
x = x.view(B, H * W, C)
|
449 |
+
|
450 |
+
# FFN
|
451 |
+
x = shortcut + self.drop_path(x)
|
452 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
453 |
+
|
454 |
+
return x, attn
|
455 |
+
|
456 |
+
def extra_repr(self):
|
457 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
458 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
459 |
+
|
460 |
+
|
461 |
+
|
462 |
+
class PatchMerging(nn.Module):
|
463 |
+
r""" Patch Merging Layer.
|
464 |
+
Args:
|
465 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
466 |
+
dim (int): Number of input channels.
|
467 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
468 |
+
"""
|
469 |
+
|
470 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
471 |
+
super().__init__()
|
472 |
+
self.input_resolution = input_resolution
|
473 |
+
self.dim = dim
|
474 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
475 |
+
self.norm = norm_layer(4 * dim)
|
476 |
+
|
477 |
+
def forward(self, x):
|
478 |
+
"""
|
479 |
+
x: B, H*W, C
|
480 |
+
"""
|
481 |
+
H, W = self.input_resolution
|
482 |
+
B, L, C = x.shape
|
483 |
+
assert L == H * W, "input feature has wrong size"
|
484 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
485 |
+
|
486 |
+
x = x.view(B, H, W, C)
|
487 |
+
|
488 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
489 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
490 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
491 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
492 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
493 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
494 |
+
|
495 |
+
x = self.norm(x)
|
496 |
+
x = self.reduction(x)
|
497 |
+
|
498 |
+
return x
|
499 |
+
|
500 |
+
def extra_repr(self):
|
501 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
502 |
+
|
503 |
+
|
504 |
+
class BasicLayer(nn.Module):
|
505 |
+
""" A basic Swin Transformer layer for one stage.
|
506 |
+
Args:
|
507 |
+
dim (int): Number of input channels.
|
508 |
+
input_resolution (tuple[int]): Input resolution.
|
509 |
+
depth (int): Number of blocks.
|
510 |
+
num_heads (int): Number of attention heads.
|
511 |
+
window_size (int): Local window size.
|
512 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
513 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
514 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
515 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
516 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
517 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
518 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
519 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
520 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
521 |
+
"""
|
522 |
+
|
523 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
524 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
525 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
526 |
+
norm_before_mlp='ln'):
|
527 |
+
|
528 |
+
super().__init__()
|
529 |
+
self.dim = dim
|
530 |
+
self.input_resolution = input_resolution
|
531 |
+
self.depth = depth
|
532 |
+
self.use_checkpoint = use_checkpoint
|
533 |
+
|
534 |
+
# build blocks
|
535 |
+
self.blocks = nn.ModuleList([
|
536 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
537 |
+
num_heads=num_heads, window_size=window_size,
|
538 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
539 |
+
mlp_ratio=mlp_ratio,
|
540 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
541 |
+
drop=drop, attn_drop=attn_drop,
|
542 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
543 |
+
norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
|
544 |
+
for i in range(depth)])
|
545 |
+
|
546 |
+
# patch merging layer
|
547 |
+
if downsample is not None:
|
548 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
549 |
+
else:
|
550 |
+
self.downsample = None
|
551 |
+
|
552 |
+
def forward(self, x):
|
553 |
+
attns = []
|
554 |
+
for blk in self.blocks:
|
555 |
+
if self.use_checkpoint:
|
556 |
+
x = checkpoint.checkpoint(blk, x)
|
557 |
+
else:
|
558 |
+
x, attn = blk(x)
|
559 |
+
if not self.training:
|
560 |
+
attns.append(attn.unsqueeze(0))
|
561 |
+
if self.downsample is not None:
|
562 |
+
x = self.downsample(x)
|
563 |
+
if not self.training:
|
564 |
+
attn = torch.cat(attns, dim = 0)
|
565 |
+
attn = torch.mean(attn, dim = 0)
|
566 |
+
return x, attn
|
567 |
+
|
568 |
+
def extra_repr(self):
|
569 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
570 |
+
|
571 |
+
|
572 |
+
# The Core of HTSAT
|
573 |
+
class HTSAT_Swin_Transformer(nn.Module):
|
574 |
+
r"""HTSAT based on the Swin Transformer
|
575 |
+
Args:
|
576 |
+
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
|
577 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
578 |
+
path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
|
579 |
+
in_chans (int): Number of input image channels. Default: 1 (mono)
|
580 |
+
num_classes (int): Number of classes for classification head. Default: 527
|
581 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
582 |
+
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
|
583 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
584 |
+
window_size (int): Window size. Default: 8
|
585 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
586 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
587 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
588 |
+
drop_rate (float): Dropout rate. Default: 0
|
589 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
590 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
591 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
592 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
593 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
594 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
595 |
+
config (module): The configuration Module from config.py
|
596 |
+
"""
|
597 |
+
|
598 |
+
def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
|
599 |
+
in_chans=1, num_classes=527,
|
600 |
+
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
|
601 |
+
window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
602 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
603 |
+
norm_layer=nn.LayerNorm,
|
604 |
+
ape=False, patch_norm=True,
|
605 |
+
use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
|
606 |
+
super(HTSAT_Swin_Transformer, self).__init__()
|
607 |
+
|
608 |
+
self.config = config
|
609 |
+
self.spec_size = spec_size
|
610 |
+
self.patch_stride = patch_stride
|
611 |
+
self.patch_size = patch_size
|
612 |
+
self.window_size = window_size
|
613 |
+
self.embed_dim = embed_dim
|
614 |
+
self.depths = depths
|
615 |
+
self.ape = ape
|
616 |
+
self.in_chans = in_chans
|
617 |
+
self.num_classes = num_classes
|
618 |
+
self.num_heads = num_heads
|
619 |
+
self.num_layers = len(self.depths)
|
620 |
+
self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
|
621 |
+
|
622 |
+
self.drop_rate = drop_rate
|
623 |
+
self.attn_drop_rate = attn_drop_rate
|
624 |
+
self.drop_path_rate = drop_path_rate
|
625 |
+
|
626 |
+
self.qkv_bias = qkv_bias
|
627 |
+
self.qk_scale = None
|
628 |
+
|
629 |
+
self.patch_norm = patch_norm
|
630 |
+
self.norm_layer = norm_layer if self.patch_norm else None
|
631 |
+
self.norm_before_mlp = norm_before_mlp
|
632 |
+
self.mlp_ratio = mlp_ratio
|
633 |
+
|
634 |
+
self.use_checkpoint = use_checkpoint
|
635 |
+
|
636 |
+
# process mel-spec ; used only once
|
637 |
+
self.freq_ratio = self.spec_size // self.config.mel_bins
|
638 |
+
window = 'hann'
|
639 |
+
center = True
|
640 |
+
pad_mode = 'reflect'
|
641 |
+
ref = 1.0
|
642 |
+
amin = 1e-10
|
643 |
+
top_db = None
|
644 |
+
self.interpolate_ratio = 32 # Downsampled ratio
|
645 |
+
# Spectrogram extractor
|
646 |
+
self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
|
647 |
+
win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
|
648 |
+
freeze_parameters=True)
|
649 |
+
# Logmel feature extractor
|
650 |
+
self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
|
651 |
+
n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
|
652 |
+
freeze_parameters=True)
|
653 |
+
# Spec augmenter
|
654 |
+
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
|
655 |
+
freq_drop_width=8, freq_stripes_num=2) # 2 2
|
656 |
+
self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
|
657 |
+
|
658 |
+
|
659 |
+
# split spctrogram into non-overlapping patches
|
660 |
+
self.patch_embed = PatchEmbed(
|
661 |
+
img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
|
662 |
+
embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
|
663 |
+
|
664 |
+
num_patches = self.patch_embed.num_patches
|
665 |
+
patches_resolution = self.patch_embed.grid_size
|
666 |
+
self.patches_resolution = patches_resolution
|
667 |
+
|
668 |
+
# absolute position embedding
|
669 |
+
if self.ape:
|
670 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
|
671 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
672 |
+
|
673 |
+
self.pos_drop = nn.Dropout(p=self.drop_rate)
|
674 |
+
|
675 |
+
# stochastic depth
|
676 |
+
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
|
677 |
+
|
678 |
+
# build layers
|
679 |
+
self.layers = nn.ModuleList()
|
680 |
+
for i_layer in range(self.num_layers):
|
681 |
+
layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
|
682 |
+
input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
683 |
+
patches_resolution[1] // (2 ** i_layer)),
|
684 |
+
depth=self.depths[i_layer],
|
685 |
+
num_heads=self.num_heads[i_layer],
|
686 |
+
window_size=self.window_size,
|
687 |
+
mlp_ratio=self.mlp_ratio,
|
688 |
+
qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
|
689 |
+
drop=self.drop_rate, attn_drop=self.attn_drop_rate,
|
690 |
+
drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
|
691 |
+
norm_layer=self.norm_layer,
|
692 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
693 |
+
use_checkpoint=use_checkpoint,
|
694 |
+
norm_before_mlp=self.norm_before_mlp)
|
695 |
+
self.layers.append(layer)
|
696 |
+
|
697 |
+
self.norm = self.norm_layer(self.num_features)
|
698 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
699 |
+
self.maxpool = nn.AdaptiveMaxPool1d(1)
|
700 |
+
|
701 |
+
if self.config.enable_tscam:
|
702 |
+
SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
|
703 |
+
self.tscam_conv = nn.Conv2d(
|
704 |
+
in_channels = self.num_features,
|
705 |
+
out_channels = self.num_classes,
|
706 |
+
kernel_size = (SF,3),
|
707 |
+
padding = (0,1)
|
708 |
+
)
|
709 |
+
self.head = nn.Linear(num_classes, num_classes)
|
710 |
+
else:
|
711 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
712 |
+
|
713 |
+
self.apply(self._init_weights)
|
714 |
+
|
715 |
+
def _init_weights(self, m):
|
716 |
+
if isinstance(m, nn.Linear):
|
717 |
+
trunc_normal_(m.weight, std=.02)
|
718 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
719 |
+
nn.init.constant_(m.bias, 0)
|
720 |
+
elif isinstance(m, nn.LayerNorm):
|
721 |
+
nn.init.constant_(m.bias, 0)
|
722 |
+
nn.init.constant_(m.weight, 1.0)
|
723 |
+
|
724 |
+
@torch.jit.ignore
|
725 |
+
def no_weight_decay(self):
|
726 |
+
return {'absolute_pos_embed'}
|
727 |
+
|
728 |
+
@torch.jit.ignore
|
729 |
+
def no_weight_decay_keywords(self):
|
730 |
+
return {'relative_position_bias_table'}
|
731 |
+
|
732 |
+
def forward_features(self, x):
|
733 |
+
frames_num = x.shape[2]
|
734 |
+
x = self.patch_embed(x)
|
735 |
+
if self.ape:
|
736 |
+
x = x + self.absolute_pos_embed
|
737 |
+
x = self.pos_drop(x)
|
738 |
+
for i, layer in enumerate(self.layers):
|
739 |
+
x, attn = layer(x)
|
740 |
+
|
741 |
+
if self.config.enable_tscam:
|
742 |
+
# for x
|
743 |
+
x = self.norm(x)
|
744 |
+
B, N, C = x.shape
|
745 |
+
SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
|
746 |
+
ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
|
747 |
+
x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
|
748 |
+
B, C, F, T = x.shape
|
749 |
+
# group 2D CNN
|
750 |
+
c_freq_bin = F // self.freq_ratio
|
751 |
+
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
|
752 |
+
x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
|
753 |
+
|
754 |
+
# get latent_output
|
755 |
+
latent_output = self.avgpool(torch.flatten(x,2))
|
756 |
+
latent_output = torch.flatten(latent_output, 1)
|
757 |
+
|
758 |
+
# display the attention map, if needed
|
759 |
+
if self.config.htsat_attn_heatmap:
|
760 |
+
# for attn
|
761 |
+
attn = torch.mean(attn, dim = 1)
|
762 |
+
attn = torch.mean(attn, dim = 1)
|
763 |
+
attn = attn.reshape(B, SF, ST)
|
764 |
+
c_freq_bin = SF // self.freq_ratio
|
765 |
+
attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
|
766 |
+
attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
|
767 |
+
attn = attn.mean(dim = 1)
|
768 |
+
attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
|
769 |
+
attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
|
770 |
+
attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
|
771 |
+
attn = attn.unsqueeze(dim = 2)
|
772 |
+
|
773 |
+
x = self.tscam_conv(x)
|
774 |
+
x = torch.flatten(x, 2) # B, C, T
|
775 |
+
|
776 |
+
if self.config.htsat_attn_heatmap:
|
777 |
+
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
|
778 |
+
else:
|
779 |
+
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
|
780 |
+
|
781 |
+
x = self.avgpool(x)
|
782 |
+
x = torch.flatten(x, 1)
|
783 |
+
|
784 |
+
if self.config.loss_type == "clip_ce":
|
785 |
+
output_dict = {
|
786 |
+
'framewise_output': fpx, # already sigmoided
|
787 |
+
'clipwise_output': x,
|
788 |
+
'latent_output': latent_output
|
789 |
+
}
|
790 |
+
else:
|
791 |
+
output_dict = {
|
792 |
+
'framewise_output': fpx, # already sigmoided
|
793 |
+
'clipwise_output': torch.sigmoid(x),
|
794 |
+
'latent_output': latent_output
|
795 |
+
}
|
796 |
+
|
797 |
+
else:
|
798 |
+
x = self.norm(x) # B N C
|
799 |
+
B, N, C = x.shape
|
800 |
+
|
801 |
+
fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
|
802 |
+
B, C, F, T = fpx.shape
|
803 |
+
c_freq_bin = F // self.freq_ratio
|
804 |
+
fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
|
805 |
+
fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
|
806 |
+
fpx = torch.sum(fpx, dim = 2)
|
807 |
+
fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
|
808 |
+
x = self.avgpool(x.transpose(1, 2)) # B C 1
|
809 |
+
x = torch.flatten(x, 1)
|
810 |
+
if self.num_classes > 0:
|
811 |
+
x = self.head(x)
|
812 |
+
fpx = self.head(fpx)
|
813 |
+
output_dict = {'framewise_output': torch.sigmoid(fpx),
|
814 |
+
'clipwise_output': torch.sigmoid(x)}
|
815 |
+
return output_dict
|
816 |
+
|
817 |
+
def crop_wav(self, x, crop_size, spe_pos = None):
|
818 |
+
time_steps = x.shape[2]
|
819 |
+
tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
|
820 |
+
for i in range(len(x)):
|
821 |
+
if spe_pos is None:
|
822 |
+
crop_pos = random.randint(0, time_steps - crop_size - 1)
|
823 |
+
else:
|
824 |
+
crop_pos = spe_pos
|
825 |
+
tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
|
826 |
+
return tx
|
827 |
+
|
828 |
+
# Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
|
829 |
+
def reshape_wav2img(self, x):
|
830 |
+
B, C, T, F = x.shape
|
831 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
832 |
+
target_F = self.spec_size // self.freq_ratio
|
833 |
+
assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
|
834 |
+
# to avoid bicubic zero error
|
835 |
+
if T < target_T:
|
836 |
+
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
|
837 |
+
if F < target_F:
|
838 |
+
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
|
839 |
+
x = x.permute(0,1,3,2).contiguous()
|
840 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
|
841 |
+
# print(x.shape)
|
842 |
+
x = x.permute(0,1,3,2,4).contiguous()
|
843 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
|
844 |
+
return x
|
845 |
+
|
846 |
+
# Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
|
847 |
+
def repeat_wat2img(self, x, cur_pos):
|
848 |
+
B, C, T, F = x.shape
|
849 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
850 |
+
target_F = self.spec_size // self.freq_ratio
|
851 |
+
assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
|
852 |
+
# to avoid bicubic zero error
|
853 |
+
if T < target_T:
|
854 |
+
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
|
855 |
+
if F < target_F:
|
856 |
+
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
|
857 |
+
x = x.permute(0,1,3,2).contiguous() # B C F T
|
858 |
+
x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
|
859 |
+
x = x.repeat(repeats = (1,1,4,1))
|
860 |
+
return x
|
861 |
+
|
862 |
+
def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
|
863 |
+
x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
|
864 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
865 |
+
|
866 |
+
|
867 |
+
x = x.transpose(1, 3)
|
868 |
+
x = self.bn0(x)
|
869 |
+
x = x.transpose(1, 3)
|
870 |
+
if self.training:
|
871 |
+
x = self.spec_augmenter(x)
|
872 |
+
if self.training and mixup_lambda is not None:
|
873 |
+
x = do_mixup(x, mixup_lambda)
|
874 |
+
|
875 |
+
if infer_mode:
|
876 |
+
# in infer mode. we need to handle different length audio input
|
877 |
+
frame_num = x.shape[2]
|
878 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
879 |
+
repeat_ratio = math.floor(target_T / frame_num)
|
880 |
+
x = x.repeat(repeats=(1,1,repeat_ratio,1))
|
881 |
+
x = self.reshape_wav2img(x)
|
882 |
+
output_dict = self.forward_features(x)
|
883 |
+
elif self.config.enable_repeat_mode:
|
884 |
+
if self.training:
|
885 |
+
cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
|
886 |
+
x = self.repeat_wat2img(x, cur_pos)
|
887 |
+
output_dict = self.forward_features(x)
|
888 |
+
else:
|
889 |
+
output_dicts = []
|
890 |
+
for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
|
891 |
+
tx = x.clone()
|
892 |
+
tx = self.repeat_wat2img(tx, cur_pos)
|
893 |
+
output_dicts.append(self.forward_features(tx))
|
894 |
+
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
|
895 |
+
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
|
896 |
+
for d in output_dicts:
|
897 |
+
clipwise_output += d["clipwise_output"]
|
898 |
+
framewise_output += d["framewise_output"]
|
899 |
+
clipwise_output = clipwise_output / len(output_dicts)
|
900 |
+
framewise_output = framewise_output / len(output_dicts)
|
901 |
+
|
902 |
+
output_dict = {
|
903 |
+
'framewise_output': framewise_output,
|
904 |
+
'clipwise_output': clipwise_output
|
905 |
+
}
|
906 |
+
else:
|
907 |
+
if x.shape[2] > self.freq_ratio * self.spec_size:
|
908 |
+
if self.training:
|
909 |
+
x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
|
910 |
+
x = self.reshape_wav2img(x)
|
911 |
+
output_dict = self.forward_features(x)
|
912 |
+
else:
|
913 |
+
# Change: Hard code here
|
914 |
+
overlap_size = 344 #(x.shape[2] - 1) // 4
|
915 |
+
output_dicts = []
|
916 |
+
crop_size = 689 #(x.shape[2] - 1) // 2
|
917 |
+
for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
|
918 |
+
tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
|
919 |
+
tx = self.reshape_wav2img(tx)
|
920 |
+
output_dicts.append(self.forward_features(tx))
|
921 |
+
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
|
922 |
+
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
|
923 |
+
latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
|
924 |
+
for d in output_dicts:
|
925 |
+
clipwise_output += d["clipwise_output"]
|
926 |
+
framewise_output += d["framewise_output"]
|
927 |
+
latent_output += d["latent_output"]
|
928 |
+
clipwise_output = clipwise_output / len(output_dicts)
|
929 |
+
framewise_output = framewise_output / len(output_dicts)
|
930 |
+
latent_output = latent_output / len(output_dicts)
|
931 |
+
output_dict = {
|
932 |
+
'framewise_output': framewise_output,
|
933 |
+
'clipwise_output': clipwise_output,
|
934 |
+
'latent_output': latent_output,
|
935 |
+
}
|
936 |
+
else: # this part is typically used, and most easy one
|
937 |
+
x = self.reshape_wav2img(x)
|
938 |
+
output_dict = self.forward_features(x)
|
939 |
+
# x = self.head(x)
|
940 |
+
return output_dict
|
941 |
+
|
942 |
+
class HTSATWrapper(nn.Module):
|
943 |
+
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
|
944 |
+
fmax, classes_num, out_emb):
|
945 |
+
super().__init__()
|
946 |
+
|
947 |
+
# print("parameters are being overidden when using HTSAT")
|
948 |
+
# print("HTSAT only support loading a pretrained model on AudioSet")
|
949 |
+
# @TODO later look at what parameters are same and can be merged
|
950 |
+
|
951 |
+
self.htsat = HTSAT_Swin_Transformer(config=config)
|
952 |
+
|
953 |
+
def forward(self, x):
|
954 |
+
out_dict = self.htsat(x)
|
955 |
+
out_dict['embedding'] = out_dict['latent_output']
|
956 |
+
return out_dict
|
ms_clap/src/models/mapper.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as nnf
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from enum import Enum
|
7 |
+
from transformers import GPT2LMHeadModel
|
8 |
+
from typing import Tuple, Optional, Union
|
9 |
+
|
10 |
+
def get_clapcap(name: str):
|
11 |
+
if name == "ClapCaption":
|
12 |
+
return ClapCaptionModel
|
13 |
+
else:
|
14 |
+
raise Exception('The ClapCap model {} is incorrect or not supported'.format(name))
|
15 |
+
|
16 |
+
class MappingType(Enum):
|
17 |
+
MLP = 'mlp'
|
18 |
+
Transformer = 'transformer'
|
19 |
+
|
20 |
+
class MLP(nn.Module):
|
21 |
+
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
|
22 |
+
super(MLP, self).__init__()
|
23 |
+
layers = []
|
24 |
+
for i in range(len(sizes) - 1):
|
25 |
+
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
26 |
+
if i < len(sizes) - 2:
|
27 |
+
layers.append(act())
|
28 |
+
self.model = nn.Sequential(*layers)
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
31 |
+
return self.model(x)
|
32 |
+
|
33 |
+
|
34 |
+
class MlpTransformer(nn.Module):
|
35 |
+
def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
|
36 |
+
super().__init__()
|
37 |
+
out_d = out_d if out_d is not None else in_dim
|
38 |
+
self.fc1 = nn.Linear(in_dim, h_dim)
|
39 |
+
self.act = act
|
40 |
+
self.fc2 = nn.Linear(h_dim, out_d)
|
41 |
+
self.dropout = nn.Dropout(dropout)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
x = self.fc1(x)
|
45 |
+
x = self.act(x)
|
46 |
+
x = self.dropout(x)
|
47 |
+
x = self.fc2(x)
|
48 |
+
x = self.dropout(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
class MultiHeadAttention(nn.Module):
|
52 |
+
|
53 |
+
def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
|
54 |
+
super().__init__()
|
55 |
+
self.num_heads = num_heads
|
56 |
+
head_dim = dim_self // num_heads
|
57 |
+
self.scale = head_dim ** -0.5
|
58 |
+
self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
|
59 |
+
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
|
60 |
+
self.project = nn.Linear(dim_self, dim_self)
|
61 |
+
self.dropout = nn.Dropout(dropout)
|
62 |
+
|
63 |
+
def forward(self, x, y=None, mask=None):
|
64 |
+
y = y if y is not None else x
|
65 |
+
b, n, c = x.shape
|
66 |
+
_, m, d = y.shape
|
67 |
+
# b n h dh
|
68 |
+
queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
|
69 |
+
# b m 2 h dh
|
70 |
+
keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
|
71 |
+
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
|
72 |
+
attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
|
73 |
+
if mask is not None:
|
74 |
+
if mask.dim() == 2:
|
75 |
+
mask = mask.unsqueeze(1)
|
76 |
+
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
|
77 |
+
attention = attention.softmax(dim=2)
|
78 |
+
out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
|
79 |
+
out = self.project(out)
|
80 |
+
return out, attention
|
81 |
+
|
82 |
+
|
83 |
+
class TransformerLayer(nn.Module):
|
84 |
+
|
85 |
+
def forward_with_attention(self, x, y=None, mask=None):
|
86 |
+
x_, attention = self.attn(self.norm1(x), y, mask)
|
87 |
+
x = x + x_
|
88 |
+
x = x + self.mlp(self.norm2(x))
|
89 |
+
return x, attention
|
90 |
+
|
91 |
+
def forward(self, x, y=None, mask=None):
|
92 |
+
x = x + self.attn(self.norm1(x), y, mask)[0]
|
93 |
+
x = x + self.mlp(self.norm2(x))
|
94 |
+
return x
|
95 |
+
|
96 |
+
def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
|
97 |
+
norm_layer: nn.Module = nn.LayerNorm):
|
98 |
+
super().__init__()
|
99 |
+
self.norm1 = norm_layer(dim_self)
|
100 |
+
self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
|
101 |
+
self.norm2 = norm_layer(dim_self)
|
102 |
+
self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
|
103 |
+
|
104 |
+
|
105 |
+
class Transformer(nn.Module):
|
106 |
+
def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
|
107 |
+
mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
|
108 |
+
super(Transformer, self).__init__()
|
109 |
+
dim_ref = dim_ref if dim_ref is not None else dim_self
|
110 |
+
self.enc_dec = enc_dec
|
111 |
+
if enc_dec:
|
112 |
+
num_layers = num_layers * 2
|
113 |
+
layers = []
|
114 |
+
for i in range(num_layers):
|
115 |
+
if i % 2 == 0 and enc_dec: # cross
|
116 |
+
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
|
117 |
+
elif enc_dec: # self
|
118 |
+
layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
|
119 |
+
else: # self or cross
|
120 |
+
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
|
121 |
+
self.layers = nn.ModuleList(layers)
|
122 |
+
|
123 |
+
def forward_with_attention(self, x, y=None, mask=None):
|
124 |
+
attentions = []
|
125 |
+
for layer in self.layers:
|
126 |
+
x, att = layer.forward_with_attention(x, y, mask)
|
127 |
+
attentions.append(att)
|
128 |
+
return x, attentions
|
129 |
+
|
130 |
+
def forward(self, x, y=None, mask=None):
|
131 |
+
for i, layer in enumerate(self.layers):
|
132 |
+
if i % 2 == 0 and self.enc_dec: # cross
|
133 |
+
x = layer(x, y)
|
134 |
+
elif self.enc_dec: # self
|
135 |
+
x = layer(x, x, mask)
|
136 |
+
else: # self or cross
|
137 |
+
x = layer(x, y, mask)
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class TransformerMapper(nn.Module):
|
142 |
+
def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
|
143 |
+
super(TransformerMapper, self).__init__()
|
144 |
+
self.clip_length = clip_length
|
145 |
+
self.transformer = Transformer(dim_embedding, 8, num_layers)
|
146 |
+
self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
|
147 |
+
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
x = self.linear(x).view(x.shape[0], self.clip_length, -1)
|
151 |
+
prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
|
152 |
+
prefix = torch.cat((x, prefix), dim=1)
|
153 |
+
out = self.transformer(prefix)[:, self.clip_length:]
|
154 |
+
return out
|
155 |
+
|
156 |
+
class ClapCaptionModel(nn.Module):
|
157 |
+
def __init__(self, clap, text_decoder: str, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
|
158 |
+
num_layers: int = 8, normalize_prefix: bool = True, mapping_type: str = None,\
|
159 |
+
freeze_audio_encoder_weights: bool = True, freeze_gpt_weights: bool = True):
|
160 |
+
super(ClapCaptionModel, self).__init__()
|
161 |
+
self.clap = clap.audio_encoder
|
162 |
+
self.prefix_length = prefix_length
|
163 |
+
self.normalize_prefix = normalize_prefix
|
164 |
+
self.gpt = GPT2LMHeadModel.from_pretrained(text_decoder)
|
165 |
+
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
166 |
+
if mapping_type == 'mlp':
|
167 |
+
self.clap_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
|
168 |
+
self.gpt_embedding_size * prefix_length))
|
169 |
+
else:
|
170 |
+
self.clap_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
|
171 |
+
clip_length, num_layers)
|
172 |
+
|
173 |
+
# Freeze all CLAP parameters
|
174 |
+
if freeze_audio_encoder_weights:
|
175 |
+
for p in self.clap.parameters():
|
176 |
+
p.requires_grad = False
|
177 |
+
|
178 |
+
if freeze_gpt_weights:
|
179 |
+
for p in self.gpt.parameters():
|
180 |
+
p.requires_grad = False
|
181 |
+
|
182 |
+
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
|
183 |
+
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
|
184 |
+
|
185 |
+
def forward(self, audios: torch.Tensor, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None,
|
186 |
+
labels: Optional[torch.Tensor] = None):
|
187 |
+
# get audio embeddings
|
188 |
+
prefix, _ = self.clap(audios)
|
189 |
+
# normalize prefix (audio embedding)
|
190 |
+
if self.normalize_prefix:
|
191 |
+
prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
|
192 |
+
|
193 |
+
embedding_text = self.gpt.transformer.wte(tokens['input_ids'])
|
194 |
+
prefix_projections = self.clap_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
|
195 |
+
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
|
196 |
+
if labels is not None:
|
197 |
+
dummy_token = self.get_dummy_token(tokens['input_ids'].shape[0], tokens['input_ids'].device)
|
198 |
+
labels = torch.cat((dummy_token, tokens), dim=1)
|
199 |
+
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
|
200 |
+
return out
|
ms_clap/src/models/pytorch_utils.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
def move_data_to_device(x, device):
|
8 |
+
if 'float' in str(x.dtype):
|
9 |
+
x = torch.Tensor(x)
|
10 |
+
elif 'int' in str(x.dtype):
|
11 |
+
x = torch.LongTensor(x)
|
12 |
+
else:
|
13 |
+
return x
|
14 |
+
|
15 |
+
return x.to(device)
|
16 |
+
|
17 |
+
|
18 |
+
def do_mixup(x, mixup_lambda):
|
19 |
+
"""Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
|
20 |
+
(1, 3, 5, ...).
|
21 |
+
Args:
|
22 |
+
x: (batch_size * 2, ...)
|
23 |
+
mixup_lambda: (batch_size * 2,)
|
24 |
+
Returns:
|
25 |
+
out: (batch_size, ...)
|
26 |
+
"""
|
27 |
+
out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
|
28 |
+
x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
|
29 |
+
return out
|
30 |
+
|
31 |
+
|
32 |
+
def append_to_dict(dict, key, value):
|
33 |
+
if key in dict.keys():
|
34 |
+
dict[key].append(value)
|
35 |
+
else:
|
36 |
+
dict[key] = [value]
|
37 |
+
|
38 |
+
|
39 |
+
def interpolate(x, ratio):
|
40 |
+
"""Interpolate data in time domain. This is used to compensate the
|
41 |
+
resolution reduction in downsampling of a CNN.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
x: (batch_size, time_steps, classes_num)
|
45 |
+
ratio: int, ratio to interpolate
|
46 |
+
Returns:
|
47 |
+
upsampled: (batch_size, time_steps * ratio, classes_num)
|
48 |
+
"""
|
49 |
+
(batch_size, time_steps, classes_num) = x.shape
|
50 |
+
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
|
51 |
+
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
|
52 |
+
return upsampled
|
53 |
+
|
54 |
+
|
55 |
+
def pad_framewise_output(framewise_output, frames_num):
|
56 |
+
"""Pad framewise_output to the same length as input frames. The pad value
|
57 |
+
is the same as the value of the last frame.
|
58 |
+
Args:
|
59 |
+
framewise_output: (batch_size, frames_num, classes_num)
|
60 |
+
frames_num: int, number of frames to pad
|
61 |
+
Outputs:
|
62 |
+
output: (batch_size, frames_num, classes_num)
|
63 |
+
"""
|
64 |
+
pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
|
65 |
+
"""tensor for padding"""
|
66 |
+
|
67 |
+
output = torch.cat((framewise_output, pad), dim=1)
|
68 |
+
"""(batch_size, frames_num, classes_num)"""
|
69 |
+
|
70 |
+
return output
|
71 |
+
|
72 |
+
|
73 |
+
def count_parameters(model):
|
74 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
75 |
+
|
76 |
+
|
77 |
+
def count_flops(model, audio_length):
|
78 |
+
"""Count flops. Code modified from others' implementation.
|
79 |
+
"""
|
80 |
+
multiply_adds = True
|
81 |
+
list_conv2d=[]
|
82 |
+
def conv2d_hook(self, input, output):
|
83 |
+
batch_size, input_channels, input_height, input_width = input[0].size()
|
84 |
+
output_channels, output_height, output_width = output[0].size()
|
85 |
+
|
86 |
+
kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
|
87 |
+
bias_ops = 1 if self.bias is not None else 0
|
88 |
+
|
89 |
+
params = output_channels * (kernel_ops + bias_ops)
|
90 |
+
flops = batch_size * params * output_height * output_width
|
91 |
+
|
92 |
+
list_conv2d.append(flops)
|
93 |
+
|
94 |
+
list_conv1d=[]
|
95 |
+
def conv1d_hook(self, input, output):
|
96 |
+
batch_size, input_channels, input_length = input[0].size()
|
97 |
+
output_channels, output_length = output[0].size()
|
98 |
+
|
99 |
+
kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
|
100 |
+
bias_ops = 1 if self.bias is not None else 0
|
101 |
+
|
102 |
+
params = output_channels * (kernel_ops + bias_ops)
|
103 |
+
flops = batch_size * params * output_length
|
104 |
+
|
105 |
+
list_conv1d.append(flops)
|
106 |
+
|
107 |
+
list_linear=[]
|
108 |
+
def linear_hook(self, input, output):
|
109 |
+
batch_size = input[0].size(0) if input[0].dim() == 2 else 1
|
110 |
+
|
111 |
+
weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
|
112 |
+
bias_ops = self.bias.nelement()
|
113 |
+
|
114 |
+
flops = batch_size * (weight_ops + bias_ops)
|
115 |
+
list_linear.append(flops)
|
116 |
+
|
117 |
+
list_bn=[]
|
118 |
+
def bn_hook(self, input, output):
|
119 |
+
list_bn.append(input[0].nelement() * 2)
|
120 |
+
|
121 |
+
list_relu=[]
|
122 |
+
def relu_hook(self, input, output):
|
123 |
+
list_relu.append(input[0].nelement() * 2)
|
124 |
+
|
125 |
+
list_pooling2d=[]
|
126 |
+
def pooling2d_hook(self, input, output):
|
127 |
+
batch_size, input_channels, input_height, input_width = input[0].size()
|
128 |
+
output_channels, output_height, output_width = output[0].size()
|
129 |
+
|
130 |
+
kernel_ops = self.kernel_size * self.kernel_size
|
131 |
+
bias_ops = 0
|
132 |
+
params = output_channels * (kernel_ops + bias_ops)
|
133 |
+
flops = batch_size * params * output_height * output_width
|
134 |
+
|
135 |
+
list_pooling2d.append(flops)
|
136 |
+
|
137 |
+
list_pooling1d=[]
|
138 |
+
def pooling1d_hook(self, input, output):
|
139 |
+
batch_size, input_channels, input_length = input[0].size()
|
140 |
+
output_channels, output_length = output[0].size()
|
141 |
+
|
142 |
+
kernel_ops = self.kernel_size[0]
|
143 |
+
bias_ops = 0
|
144 |
+
|
145 |
+
params = output_channels * (kernel_ops + bias_ops)
|
146 |
+
flops = batch_size * params * output_length
|
147 |
+
|
148 |
+
list_pooling2d.append(flops)
|
149 |
+
|
150 |
+
def foo(net):
|
151 |
+
childrens = list(net.children())
|
152 |
+
if not childrens:
|
153 |
+
if isinstance(net, nn.Conv2d):
|
154 |
+
net.register_forward_hook(conv2d_hook)
|
155 |
+
elif isinstance(net, nn.Conv1d):
|
156 |
+
net.register_forward_hook(conv1d_hook)
|
157 |
+
elif isinstance(net, nn.Linear):
|
158 |
+
net.register_forward_hook(linear_hook)
|
159 |
+
elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
|
160 |
+
net.register_forward_hook(bn_hook)
|
161 |
+
elif isinstance(net, nn.ReLU):
|
162 |
+
net.register_forward_hook(relu_hook)
|
163 |
+
elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
|
164 |
+
net.register_forward_hook(pooling2d_hook)
|
165 |
+
elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
|
166 |
+
net.register_forward_hook(pooling1d_hook)
|
167 |
+
else:
|
168 |
+
print('Warning: flop of module {} is not counted!'.format(net))
|
169 |
+
return
|
170 |
+
for c in childrens:
|
171 |
+
foo(c)
|
172 |
+
|
173 |
+
# Register hook
|
174 |
+
foo(model)
|
175 |
+
|
176 |
+
device = device = next(model.parameters()).device
|
177 |
+
input = torch.rand(1, audio_length).to(device)
|
178 |
+
|
179 |
+
out = model(input)
|
180 |
+
|
181 |
+
total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \
|
182 |
+
sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d)
|
183 |
+
|
184 |
+
return total_flops
|
ms_clap/src/models/utils.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import yaml
|
3 |
+
import sys
|
4 |
+
|
5 |
+
def read_config_as_args(config_path,args=None,is_config_str=False):
|
6 |
+
return_dict = {}
|
7 |
+
|
8 |
+
if config_path is not None:
|
9 |
+
if is_config_str:
|
10 |
+
yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
|
11 |
+
else:
|
12 |
+
with open(config_path, "r") as f:
|
13 |
+
yml_config = yaml.load(f, Loader=yaml.FullLoader)
|
14 |
+
|
15 |
+
if args != None:
|
16 |
+
for k, v in yml_config.items():
|
17 |
+
if k in args.__dict__:
|
18 |
+
args.__dict__[k] = v
|
19 |
+
else:
|
20 |
+
sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
|
21 |
+
else:
|
22 |
+
for k, v in yml_config.items():
|
23 |
+
return_dict[k] = v
|
24 |
+
|
25 |
+
args = args if args != None else return_dict
|
26 |
+
return argparse.Namespace(**args)
|
ms_clap/src/zero_shot_classification.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is an example using CLAP to perform zeroshot
|
3 |
+
classification on ESC50 (https://github.com/karolpiczak/ESC-50).
|
4 |
+
"""
|
5 |
+
|
6 |
+
from CLAPWrapper import CLAPWrapper
|
7 |
+
from esc50_dataset import ESC50
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
from sklearn.metrics import accuracy_score
|
12 |
+
|
13 |
+
# Load dataset
|
14 |
+
root_path = "root_path" # Folder with ESC-50-master/
|
15 |
+
dataset = ESC50(root=root_path, download=True) #If download=False code assumes base_folder='ESC-50-master' in esc50_dataset.py
|
16 |
+
prompt = 'this is the sound of '
|
17 |
+
y = [prompt + x for x in dataset.classes]
|
18 |
+
|
19 |
+
# Load and initialize CLAP
|
20 |
+
weights_path = "weights_path"
|
21 |
+
clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
|
22 |
+
|
23 |
+
# Computing text embeddings
|
24 |
+
text_embeddings = clap_model.get_text_embeddings(y)
|
25 |
+
|
26 |
+
# Computing audio embeddings
|
27 |
+
y_preds, y_labels = [], []
|
28 |
+
for i in tqdm(range(len(dataset))):
|
29 |
+
x, _, one_hot_target = dataset.__getitem__(i)
|
30 |
+
audio_embeddings = clap_model.get_audio_embeddings([x], resample=True)
|
31 |
+
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
32 |
+
y_pred = F.softmax(similarity.detach().cpu(), dim=1).numpy()
|
33 |
+
y_preds.append(y_pred)
|
34 |
+
y_labels.append(one_hot_target.detach().cpu().numpy())
|
35 |
+
|
36 |
+
|
37 |
+
y_labels, y_preds = np.concatenate(y_labels, axis=0), np.concatenate(y_preds, axis=0)
|
38 |
+
acc = accuracy_score(np.argmax(y_labels, axis=1), np.argmax(y_preds, axis=1))
|
39 |
+
print('ESC50 Accuracy {}'.format(acc))
|
40 |
+
|
41 |
+
"""
|
42 |
+
The output:
|
43 |
+
|
44 |
+
ESC50 Accuracy: 93.9%
|
45 |
+
|
46 |
+
"""
|
ms_clap/src/zero_shot_predictions.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is an example using CLAP for zero-shot inference.
|
3 |
+
"""
|
4 |
+
from CLAPWrapper import CLAPWrapper
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
# Define classes for zero-shot
|
8 |
+
# Should be in lower case and can be more than one word
|
9 |
+
classes = ['coughing','sneezing','drinking sipping', 'breathing', 'brushing teeth']
|
10 |
+
ground_truth = ['coughing']
|
11 |
+
# Add prompt
|
12 |
+
prompt = 'this is a sound of '
|
13 |
+
class_prompts = [prompt + x for x in classes]
|
14 |
+
#Load audio files
|
15 |
+
audio_files = ['audio_file']
|
16 |
+
|
17 |
+
# Load and initialize CLAP
|
18 |
+
weights_path = "weights_path"
|
19 |
+
# Setting use_cuda = True will load the model on a GPU using CUDA
|
20 |
+
clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
|
21 |
+
|
22 |
+
# compute text embeddings from natural text
|
23 |
+
text_embeddings = clap_model.get_text_embeddings(class_prompts)
|
24 |
+
|
25 |
+
# compute the audio embeddings from an audio file
|
26 |
+
audio_embeddings = clap_model.get_audio_embeddings(audio_files, resample=True)
|
27 |
+
|
28 |
+
# compute the similarity between audio_embeddings and text_embeddings
|
29 |
+
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
30 |
+
|
31 |
+
similarity = F.softmax(similarity, dim=1)
|
32 |
+
values, indices = similarity[0].topk(5)
|
33 |
+
|
34 |
+
# Print the results
|
35 |
+
print("Ground Truth: {}".format(ground_truth))
|
36 |
+
print("Top predictions:\n")
|
37 |
+
for value, index in zip(values, indices):
|
38 |
+
print(f"{classes[index]:>16s}: {100 * value.item():.2f}%")
|
39 |
+
|
40 |
+
"""
|
41 |
+
The output (the exact numbers may vary):
|
42 |
+
|
43 |
+
Ground Truth: coughing
|
44 |
+
Top predictions:
|
45 |
+
|
46 |
+
coughing: 98.55%
|
47 |
+
sneezing: 1.24%
|
48 |
+
drinking sipping: 0.15%
|
49 |
+
breathing: 0.02%
|
50 |
+
brushing teeth: 0.01%
|
51 |
+
"""
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
scipy
|
3 |
+
scikit-learn
|
4 |
+
librosa
|
5 |
+
soundfile
|
6 |
+
pydub
|
7 |
+
torch==2.0.1
|
8 |
+
torchaudio==2.0.2
|
9 |
+
torchlibrosa==0.1.0
|
10 |
+
torchvision==0.15.2
|
11 |
+
transformers==4.27.4
|
12 |
+
einops
|
13 |
+
huggingface-hub
|
14 |
+
laion-clap==1.1.3
|
src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
src/factory.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
import sys
|
8 |
+
sys.path.append('../')
|
9 |
+
|
10 |
+
from typing import Optional
|
11 |
+
from copy import deepcopy
|
12 |
+
|
13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
14 |
+
from ms_clap.src.CLAPWrapper import CLAPWrapper
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
try:
|
20 |
+
from .flamingo import Flamingo
|
21 |
+
from .flamingo_lm import FlamingoLMMixin
|
22 |
+
from .utils import extend_instance
|
23 |
+
except:
|
24 |
+
from flamingo import Flamingo
|
25 |
+
from flamingo_lm import FlamingoLMMixin
|
26 |
+
from utils import extend_instance
|
27 |
+
|
28 |
+
|
29 |
+
class CLAP(nn.Module):
|
30 |
+
def __init__(self, clap_config):
|
31 |
+
super(CLAP, self).__init__()
|
32 |
+
self.method = clap_config["method"]
|
33 |
+
device_id = f'cuda:{torch.cuda.current_device()}'
|
34 |
+
|
35 |
+
if self.method == 'laion-clap':
|
36 |
+
# https://github.com/LAION-AI/CLAP
|
37 |
+
if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
|
38 |
+
amodel = 'HTSAT-tiny'
|
39 |
+
elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
|
40 |
+
amodel = 'HTSAT-base'
|
41 |
+
else:
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
enable_fusion = 'fusion' in clap_config["model_name"].lower()
|
45 |
+
self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device_id)
|
46 |
+
self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
|
47 |
+
|
48 |
+
for param in self.laion_clap.parameters():
|
49 |
+
param.requires_grad = False
|
50 |
+
self.laion_clap.eval()
|
51 |
+
|
52 |
+
print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))
|
53 |
+
|
54 |
+
elif self.method == 'microsoft-clap':
|
55 |
+
# https://github.com/microsoft/CLAP
|
56 |
+
self.ms_clap = CLAPWrapper(
|
57 |
+
clap_config["checkpoint"],
|
58 |
+
config_root=clap_config["config_root"],
|
59 |
+
version=clap_config['model_name'],
|
60 |
+
use_cuda=True
|
61 |
+
)
|
62 |
+
|
63 |
+
if clap_config['model_name'] in ['2022', '2023']:
|
64 |
+
for param in self.ms_clap.clap.parameters():
|
65 |
+
param.requires_grad = False
|
66 |
+
self.ms_clap.clap.eval()
|
67 |
+
else:
|
68 |
+
for param in self.ms_clap.clapcap.parameters():
|
69 |
+
param.requires_grad = False
|
70 |
+
self.ms_clap.clapcap.eval()
|
71 |
+
|
72 |
+
print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))
|
73 |
+
|
74 |
+
else:
|
75 |
+
raise NotImplementedError
|
76 |
+
|
77 |
+
def forward(self, audio_clips):
|
78 |
+
|
79 |
+
if len(audio_clips.shape) == 2:
|
80 |
+
audio_clips = audio_clips.unsqueeze(0)
|
81 |
+
assert len(audio_clips.shape) == 3
|
82 |
+
|
83 |
+
audio_embeds = []
|
84 |
+
for x in audio_clips:
|
85 |
+
if self.method == 'laion-clap':
|
86 |
+
audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
|
87 |
+
elif self.method == 'microsoft-clap':
|
88 |
+
audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)
|
89 |
+
|
90 |
+
audio_embeds.append(audio_embed)
|
91 |
+
|
92 |
+
audio_embeds = torch.stack(audio_embeds, dim=0)
|
93 |
+
audio_embeds.requires_grad = False
|
94 |
+
|
95 |
+
return audio_embeds
|
96 |
+
|
97 |
+
|
98 |
+
def create_model_and_transforms(
|
99 |
+
clap_config: dict,
|
100 |
+
lang_encoder_path: str,
|
101 |
+
tokenizer_path: str,
|
102 |
+
audio_transformer_kwargs: dict,
|
103 |
+
cross_attn_every_n_layers: int = 1,
|
104 |
+
use_local_files: bool = False,
|
105 |
+
decoder_layers_attr_name: str = None,
|
106 |
+
freeze_lm_embeddings: bool = False,
|
107 |
+
unfreeze_full_lm: bool = False,
|
108 |
+
cache_dir: Optional[str] = None,
|
109 |
+
**flamingo_kwargs,
|
110 |
+
):
|
111 |
+
clap = CLAP(clap_config)
|
112 |
+
|
113 |
+
text_tokenizer = AutoTokenizer.from_pretrained(
|
114 |
+
tokenizer_path,
|
115 |
+
local_files_only=use_local_files,
|
116 |
+
trust_remote_code=True,
|
117 |
+
cache_dir=cache_dir,
|
118 |
+
)
|
119 |
+
text_tokenizer.add_special_tokens(
|
120 |
+
{"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
|
121 |
+
)
|
122 |
+
if text_tokenizer.pad_token is None:
|
123 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
124 |
+
if text_tokenizer.sep_token is None:
|
125 |
+
text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
|
126 |
+
|
127 |
+
lang_encoder = AutoModelForCausalLM.from_pretrained(
|
128 |
+
lang_encoder_path,
|
129 |
+
local_files_only=use_local_files,
|
130 |
+
trust_remote_code=True,
|
131 |
+
cache_dir=cache_dir,
|
132 |
+
)
|
133 |
+
|
134 |
+
extend_instance(lang_encoder, FlamingoLMMixin)
|
135 |
+
|
136 |
+
if decoder_layers_attr_name is None:
|
137 |
+
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
138 |
+
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
139 |
+
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
140 |
+
|
141 |
+
unfreeze_clap = False
|
142 |
+
|
143 |
+
model = Flamingo(
|
144 |
+
clap,
|
145 |
+
unfreeze_clap,
|
146 |
+
lang_encoder,
|
147 |
+
text_tokenizer.encode("<|endofchunk|>")[-1],
|
148 |
+
text_tokenizer.encode("<audio>")[-1],
|
149 |
+
text_tokenizer.sep_token_id,
|
150 |
+
audio_embed_dim=clap_config["audio_embed_dim"],
|
151 |
+
audio_transformer_kwargs=audio_transformer_kwargs,
|
152 |
+
cross_attn_every_n_layers=cross_attn_every_n_layers,
|
153 |
+
**flamingo_kwargs,
|
154 |
+
)
|
155 |
+
|
156 |
+
model.requires_grad_(False)
|
157 |
+
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
158 |
+
|
159 |
+
model.audio_transformer.requires_grad_(True)
|
160 |
+
model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
|
161 |
+
if not freeze_lm_embeddings:
|
162 |
+
model.lang_encoder.get_input_embeddings().requires_grad_(True)
|
163 |
+
|
164 |
+
if unfreeze_full_lm:
|
165 |
+
model.lang_encoder.requires_grad_(True)
|
166 |
+
|
167 |
+
if unfreeze_clap:
|
168 |
+
model.clap.requires_grad_(True)
|
169 |
+
|
170 |
+
print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
|
171 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
172 |
+
sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
|
173 |
+
sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad)
|
174 |
+
))
|
175 |
+
|
176 |
+
return model, text_tokenizer
|
177 |
+
|
178 |
+
|
179 |
+
def _infer_decoder_layers_attr_name(model):
|
180 |
+
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
|
181 |
+
if k.lower() in model.__class__.__name__.lower():
|
182 |
+
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
|
183 |
+
|
184 |
+
raise ValueError(
|
185 |
+
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
|
190 |
+
"opt": "model.decoder.layers",
|
191 |
+
"gptj": "transformer.h",
|
192 |
+
"gpt-j": "transformer.h",
|
193 |
+
"pythia": "gpt_neox.layers",
|
194 |
+
"llama": "model.layers",
|
195 |
+
"gptneoxforcausallm": "gpt_neox.layers",
|
196 |
+
"mpt": "transformer.blocks",
|
197 |
+
"mosaicgpt": "transformer.blocks",
|
198 |
+
}
|
src/flamingo.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from einops import rearrange
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from torch.distributed.fsdp.wrap import (
|
12 |
+
enable_wrap,
|
13 |
+
wrap,
|
14 |
+
)
|
15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
16 |
+
from torch.distributed.fsdp import (
|
17 |
+
FullyShardedDataParallel as FSDP,
|
18 |
+
)
|
19 |
+
|
20 |
+
try:
|
21 |
+
from .helpers import TransformerEncoder
|
22 |
+
from .utils import apply_with_stopping_condition
|
23 |
+
except:
|
24 |
+
from helpers import TransformerEncoder
|
25 |
+
from utils import apply_with_stopping_condition
|
26 |
+
|
27 |
+
|
28 |
+
class Flamingo(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
clap: nn.Module,
|
32 |
+
unfreeze_clap: bool,
|
33 |
+
lang_encoder: nn.Module,
|
34 |
+
eoc_token_id: int,
|
35 |
+
media_token_id: int,
|
36 |
+
sep_token_id: int,
|
37 |
+
audio_embed_dim: int,
|
38 |
+
audio_transformer_kwargs: dict,
|
39 |
+
cross_attn_every_n_layers: int = 1,
|
40 |
+
gradient_checkpointing: bool = False,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self.eoc_token_id = eoc_token_id
|
45 |
+
self.media_token_id = media_token_id
|
46 |
+
self.sep_token_id = sep_token_id
|
47 |
+
self.audio_embed_dim = audio_embed_dim
|
48 |
+
self.clap = clap # .to(torch.cuda.current_device())
|
49 |
+
self.unfreeze_clap = unfreeze_clap
|
50 |
+
self.clap.requires_grad_(unfreeze_clap)
|
51 |
+
|
52 |
+
if hasattr(lang_encoder.config, "d_model"):
|
53 |
+
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
|
54 |
+
else:
|
55 |
+
self.lang_dim = lang_encoder.config.hidden_size
|
56 |
+
|
57 |
+
n_head = audio_transformer_kwargs["n_head"]
|
58 |
+
n_layers = audio_transformer_kwargs["n_layers"]
|
59 |
+
d_inner = audio_transformer_kwargs["d_inner"]
|
60 |
+
max_num_media = audio_transformer_kwargs["max_num_media"]
|
61 |
+
max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
|
62 |
+
assert audio_embed_dim % n_head == 0
|
63 |
+
|
64 |
+
self.audio_transformer = TransformerEncoder(
|
65 |
+
d_word_vec=audio_embed_dim,
|
66 |
+
n_layers=n_layers,
|
67 |
+
n_head=n_head,
|
68 |
+
d_k=audio_embed_dim // n_head,
|
69 |
+
d_v=audio_embed_dim // n_head,
|
70 |
+
d_model=audio_embed_dim,
|
71 |
+
d_inner=d_inner,
|
72 |
+
dropout=0.0,
|
73 |
+
n_position=max_num_media,
|
74 |
+
scale_emb=True
|
75 |
+
)
|
76 |
+
|
77 |
+
self.lang_encoder = lang_encoder
|
78 |
+
self.lang_encoder.init_flamingo(
|
79 |
+
media_token_id=media_token_id,
|
80 |
+
lang_hidden_size=self.lang_dim,
|
81 |
+
audio_hidden_size=self.audio_embed_dim,
|
82 |
+
max_window_per_audio=max_window_per_audio,
|
83 |
+
cross_attn_every_n_layers=cross_attn_every_n_layers,
|
84 |
+
gradient_checkpointing=gradient_checkpointing,
|
85 |
+
)
|
86 |
+
|
87 |
+
self._use_gradient_checkpointing = gradient_checkpointing
|
88 |
+
self.audio_transformer._use_gradient_checkpointing = gradient_checkpointing
|
89 |
+
self.clap._use_gradient_checkpointing = gradient_checkpointing
|
90 |
+
|
91 |
+
def forward(
|
92 |
+
self,
|
93 |
+
audio_x: torch.Tensor,
|
94 |
+
audio_x_mask: torch.Tensor,
|
95 |
+
lang_x: torch.Tensor,
|
96 |
+
attention_mask: torch.Tensor = None,
|
97 |
+
labels: torch.Tensor = None,
|
98 |
+
clear_conditioned_layers: bool = True,
|
99 |
+
past_key_values=None,
|
100 |
+
use_cache: bool = False,
|
101 |
+
):
|
102 |
+
assert (
|
103 |
+
self.lang_encoder.initialized_flamingo
|
104 |
+
), "Flamingo layers are not initialized. Please call `init_flamingo` first."
|
105 |
+
|
106 |
+
assert (
|
107 |
+
self.lang_encoder._use_cached_audio_x or audio_x is not None
|
108 |
+
), "Must provide either audio_x or have precached media using cache_media()."
|
109 |
+
|
110 |
+
if self.lang_encoder._use_cached_audio_x:
|
111 |
+
assert (
|
112 |
+
audio_x is None
|
113 |
+
), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
|
114 |
+
assert self.lang_encoder.is_conditioned()
|
115 |
+
|
116 |
+
else:
|
117 |
+
self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
|
118 |
+
self._condition_media_locations(input_ids=lang_x)
|
119 |
+
|
120 |
+
output = self.lang_encoder(
|
121 |
+
input_ids=lang_x,
|
122 |
+
attention_mask=attention_mask,
|
123 |
+
labels=labels,
|
124 |
+
past_key_values=past_key_values,
|
125 |
+
use_cache=use_cache,
|
126 |
+
)
|
127 |
+
|
128 |
+
if clear_conditioned_layers:
|
129 |
+
self.lang_encoder.clear_conditioned_layers()
|
130 |
+
|
131 |
+
return output
|
132 |
+
|
133 |
+
def generate(
|
134 |
+
self,
|
135 |
+
audio_x: torch.Tensor,
|
136 |
+
audio_x_mask: torch.Tensor,
|
137 |
+
lang_x: torch.Tensor,
|
138 |
+
attention_mask: torch.Tensor = None,
|
139 |
+
**kwargs,
|
140 |
+
):
|
141 |
+
num_beams = kwargs.pop("num_beams", 1)
|
142 |
+
if num_beams > 1:
|
143 |
+
audio_x = audio_x.repeat_interleave(num_beams, dim=0)
|
144 |
+
|
145 |
+
self.lang_encoder._use_cached_audio_x = True
|
146 |
+
self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
|
147 |
+
|
148 |
+
eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
|
149 |
+
output = self.lang_encoder.generate(
|
150 |
+
input_ids=lang_x,
|
151 |
+
attention_mask=attention_mask,
|
152 |
+
eos_token_id=eos_token_id,
|
153 |
+
num_beams=num_beams,
|
154 |
+
**kwargs,
|
155 |
+
)
|
156 |
+
|
157 |
+
self.lang_encoder.clear_conditioned_layers()
|
158 |
+
self.lang_encoder._use_cached_audio_x = False
|
159 |
+
return output
|
160 |
+
|
161 |
+
def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
|
162 |
+
"""
|
163 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
164 |
+
"""
|
165 |
+
|
166 |
+
assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
audio_embeds = self.clap(audio_x)
|
170 |
+
B, L, D = audio_embeds.shape # L is number of windows, D is feature dim
|
171 |
+
assert D == self.audio_embed_dim
|
172 |
+
|
173 |
+
assert audio_x_mask.ndim == 2, "audio_x_mask should be of shape (B, L)"
|
174 |
+
|
175 |
+
if B > 1 and audio_x_mask.shape[0] == 1:
|
176 |
+
audio_x_mask = audio_x_mask.repeat(B, 1)
|
177 |
+
|
178 |
+
assert audio_x_mask.shape[0] == B and audio_x_mask.shape[1] == L, "{} != ({},{})".format(audio_x_mask.shape, B, L)
|
179 |
+
|
180 |
+
audio_x_out = self.audio_transformer(audio_embeds) # B, L, D
|
181 |
+
audio_x_out = audio_x_out.unsqueeze(2) # B, L, n=1, D
|
182 |
+
audio_x_mask = audio_x_mask.unsqueeze(2) # B, L, n=1
|
183 |
+
|
184 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
185 |
+
layer.condition_audio_x(audio_x_out, audio_x_mask)
|
186 |
+
|
187 |
+
def wrap_fsdp(self, wrapper_kwargs, device_id):
|
188 |
+
# unfreeze the decoder layers
|
189 |
+
for block in self.lang_encoder.old_decoder_blocks:
|
190 |
+
block.requires_grad_(True)
|
191 |
+
|
192 |
+
# wrap in FSDP
|
193 |
+
with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
|
194 |
+
self.audio_transformer = wrap(wrap(self.audio_transformer))
|
195 |
+
self.lang_encoder.old_decoder_blocks = nn.ModuleList(
|
196 |
+
wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
|
197 |
+
)
|
198 |
+
self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
|
199 |
+
wrap(wrap(layer)) if layer is not None else None
|
200 |
+
for layer in self.lang_encoder.gated_cross_attn_layers
|
201 |
+
)
|
202 |
+
self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
|
203 |
+
self.lang_encoder.set_input_embeddings(
|
204 |
+
wrap(wrap(self.lang_encoder.get_input_embeddings()))
|
205 |
+
)
|
206 |
+
|
207 |
+
if hasattr(self.lang_encoder, 'set_output_embeddings'):
|
208 |
+
self.lang_encoder.set_output_embeddings(
|
209 |
+
wrap(wrap(self.lang_encoder.get_output_embeddings()))
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
print('skip wrapping output embeddings')
|
213 |
+
|
214 |
+
# manually move non-FSDP managed parameters to device_id
|
215 |
+
# these are all in lang_encoder
|
216 |
+
apply_with_stopping_condition(
|
217 |
+
module=self.lang_encoder,
|
218 |
+
apply_fn=lambda m: m.to(device_id),
|
219 |
+
apply_condition=lambda m: len(list(m.children())) == 0,
|
220 |
+
stopping_condition=lambda m: isinstance(m, FSDP),
|
221 |
+
)
|
222 |
+
|
223 |
+
# clap shouldn't be wrapped; should be on each gpu
|
224 |
+
if self.unfreeze_clap:
|
225 |
+
apply_with_stopping_condition(
|
226 |
+
module=self.clap,
|
227 |
+
apply_fn=lambda m: m.to(device_id),
|
228 |
+
apply_condition=lambda m: len(list(m.children())) == 0,
|
229 |
+
stopping_condition=lambda m: isinstance(m, FSDP),
|
230 |
+
)
|
231 |
+
|
232 |
+
# exclude the original decoder layers from the optimizer
|
233 |
+
for block in self.lang_encoder.old_decoder_blocks:
|
234 |
+
for p in block.parameters():
|
235 |
+
p.exclude_from_optimizer = True
|
236 |
+
|
237 |
+
# set up clip_grad_norm_ function
|
238 |
+
def clip_grad_norm_(max_norm):
|
239 |
+
self.audio_transformer.clip_grad_norm_(max_norm)
|
240 |
+
for layer in self.lang_encoder.gated_cross_attn_layers:
|
241 |
+
if layer is not None:
|
242 |
+
layer.clip_grad_norm_(max_norm)
|
243 |
+
self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
|
244 |
+
|
245 |
+
self.clip_grad_norm_ = clip_grad_norm_
|
246 |
+
|
247 |
+
def _condition_media_locations(self, input_ids: torch.Tensor):
|
248 |
+
media_locations = (input_ids == self.media_token_id)
|
249 |
+
|
250 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
251 |
+
layer.condition_media_locations(media_locations)
|
252 |
+
|
253 |
+
def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
|
254 |
+
self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
|
255 |
+
self._condition_media_locations(input_ids=input_ids)
|
256 |
+
self.lang_encoder._use_cached_audio_x = True
|
257 |
+
|
258 |
+
def uncache_media(self):
|
259 |
+
self.lang_encoder.clear_conditioned_layers()
|
260 |
+
self.lang_encoder._use_cached_audio_x = False
|
src/flamingo_lm.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
try:
|
10 |
+
from .helpers import GatedCrossAttentionBlock
|
11 |
+
from .utils import getattr_recursive, setattr_recursive
|
12 |
+
except:
|
13 |
+
from helpers import GatedCrossAttentionBlock
|
14 |
+
from utils import getattr_recursive, setattr_recursive
|
15 |
+
|
16 |
+
|
17 |
+
class FlamingoLayer(nn.Module):
|
18 |
+
"""
|
19 |
+
FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
self.gated_cross_attn_layer = gated_cross_attn_layer
|
27 |
+
self.decoder_layer = decoder_layer
|
28 |
+
self.audio_x = None
|
29 |
+
self.audio_x_mask = None
|
30 |
+
self.few_shot_mask = None
|
31 |
+
self.media_locations = None
|
32 |
+
if self.gated_cross_attn_layer is not None:
|
33 |
+
self.gated_cross_attn_layer._use_gradient_checkpointing = (
|
34 |
+
gradient_checkpointing
|
35 |
+
)
|
36 |
+
self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
|
37 |
+
|
38 |
+
def is_conditioned(self) -> bool:
|
39 |
+
"""Check whether the layer is conditioned."""
|
40 |
+
return (self.audio_x is not None) and (self.audio_x_mask is not None) and (self.media_locations is not None)
|
41 |
+
|
42 |
+
def condition_audio_x(self, audio_x, audio_x_mask):
|
43 |
+
self.audio_x = audio_x
|
44 |
+
self.audio_x_mask = audio_x_mask
|
45 |
+
|
46 |
+
def condition_media_locations(self, media_locations):
|
47 |
+
self.media_locations = media_locations
|
48 |
+
|
49 |
+
def condition_use_cached_media(self, use_cached_media):
|
50 |
+
self.use_cached_media = use_cached_media
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self,
|
54 |
+
lang_x,
|
55 |
+
attention_mask=None,
|
56 |
+
**decoder_layer_kwargs,
|
57 |
+
):
|
58 |
+
if self.gated_cross_attn_layer is not None:
|
59 |
+
if self.audio_x is None:
|
60 |
+
raise ValueError("audio_x must be conditioned before forward pass")
|
61 |
+
|
62 |
+
if self.media_locations is None:
|
63 |
+
raise ValueError(
|
64 |
+
"media_locations must be conditioned before forward pass"
|
65 |
+
)
|
66 |
+
|
67 |
+
lang_x = self.gated_cross_attn_layer(
|
68 |
+
lang_x,
|
69 |
+
self.audio_x,
|
70 |
+
self.audio_x_mask,
|
71 |
+
media_locations=self.media_locations,
|
72 |
+
use_cached_media=self.use_cached_media,
|
73 |
+
)
|
74 |
+
|
75 |
+
# Normal decoder layer
|
76 |
+
lang_x = self.decoder_layer(
|
77 |
+
lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
|
78 |
+
)
|
79 |
+
return lang_x
|
80 |
+
|
81 |
+
|
82 |
+
class FlamingoLMMixin(nn.Module):
|
83 |
+
"""
|
84 |
+
Mixin to add cross-attention layers to a language model.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
88 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
89 |
+
|
90 |
+
def _get_decoder_layers(self):
|
91 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
92 |
+
|
93 |
+
def _set_decoder_layers(self, value):
|
94 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
95 |
+
|
96 |
+
def init_flamingo(
|
97 |
+
self,
|
98 |
+
media_token_id,
|
99 |
+
lang_hidden_size,
|
100 |
+
audio_hidden_size,
|
101 |
+
max_window_per_audio,
|
102 |
+
cross_attn_every_n_layers,
|
103 |
+
gradient_checkpointing,
|
104 |
+
):
|
105 |
+
"""
|
106 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
107 |
+
"""
|
108 |
+
self.old_decoder_blocks = self._get_decoder_layers()
|
109 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
110 |
+
[
|
111 |
+
GatedCrossAttentionBlock(
|
112 |
+
dim=lang_hidden_size,
|
113 |
+
dim_audio=audio_hidden_size,
|
114 |
+
max_window_per_audio=max_window_per_audio,
|
115 |
+
only_attend_immediate_media=False,
|
116 |
+
)
|
117 |
+
if (layer_idx + 1) % cross_attn_every_n_layers == 0
|
118 |
+
else None
|
119 |
+
for layer_idx, _ in enumerate(self._get_decoder_layers())
|
120 |
+
]
|
121 |
+
)
|
122 |
+
self.init_flamingo_layers(gradient_checkpointing)
|
123 |
+
self.media_token_id = media_token_id
|
124 |
+
self.initialized_flamingo = True
|
125 |
+
self._use_cached_audio_x = False
|
126 |
+
|
127 |
+
def init_flamingo_layers(self, gradient_checkpointing):
|
128 |
+
"""
|
129 |
+
Re initializes the FlamingoLayers.
|
130 |
+
Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks
|
131 |
+
"""
|
132 |
+
self._set_decoder_layers(
|
133 |
+
nn.ModuleList(
|
134 |
+
[
|
135 |
+
FlamingoLayer(
|
136 |
+
gated_cross_attn_layer, decoder_layer, gradient_checkpointing
|
137 |
+
)
|
138 |
+
for gated_cross_attn_layer, decoder_layer in zip(
|
139 |
+
self.gated_cross_attn_layers, self.old_decoder_blocks
|
140 |
+
)
|
141 |
+
]
|
142 |
+
)
|
143 |
+
)
|
144 |
+
|
145 |
+
def forward(self, input_ids, attention_mask, **kwargs):
|
146 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
147 |
+
if not self.initialized_flamingo:
|
148 |
+
raise ValueError(
|
149 |
+
"Flamingo layers are not initialized. Please call `init_flamingo` first."
|
150 |
+
)
|
151 |
+
|
152 |
+
media_locations = input_ids == self.media_token_id
|
153 |
+
|
154 |
+
use_cached_media_locations = (
|
155 |
+
self._use_cached_audio_x
|
156 |
+
and self.is_conditioned()
|
157 |
+
and not media_locations.any()
|
158 |
+
)
|
159 |
+
|
160 |
+
for layer in self._get_decoder_layers():
|
161 |
+
if not use_cached_media_locations:
|
162 |
+
layer.condition_media_locations(media_locations)
|
163 |
+
layer.condition_use_cached_media(use_cached_media_locations)
|
164 |
+
|
165 |
+
kwargs["input_ids"] = input_ids
|
166 |
+
kwargs["attention_mask"] = attention_mask
|
167 |
+
return super().forward(**kwargs)
|
168 |
+
|
169 |
+
def is_conditioned(self) -> bool:
|
170 |
+
"""Check whether all decoder layers are already conditioned."""
|
171 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
172 |
+
|
173 |
+
def clear_conditioned_layers(self):
|
174 |
+
for layer in self._get_decoder_layers():
|
175 |
+
layer.condition_audio_x(None, None)
|
176 |
+
layer.condition_media_locations(None)
|
177 |
+
layer.condition_use_cached_media(None)
|
src/helpers.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
# Adapted from https://github.com/lucidrains/flamingo-pytorch under the MIT license.
|
8 |
+
# LICENSE is in incl_licenses directory.
|
9 |
+
|
10 |
+
# Adapted from https://github.com/jadore801120/attention-is-all-you-need-pytorch under the MIT license.
|
11 |
+
# LICENSE is in incl_licenses directory.
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
from einops_exts import rearrange_many
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import einsum, nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
|
23 |
+
def exists(val):
|
24 |
+
return val is not None
|
25 |
+
|
26 |
+
|
27 |
+
def FeedForward(dim, mult=4):
|
28 |
+
inner_dim = int(dim * mult)
|
29 |
+
return nn.Sequential(
|
30 |
+
nn.LayerNorm(dim),
|
31 |
+
nn.Linear(dim, inner_dim, bias=False),
|
32 |
+
nn.GELU(),
|
33 |
+
nn.Linear(inner_dim, dim, bias=False),
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
class ScaledDotProductAttention(nn.Module):
|
38 |
+
''' Scaled Dot-Product Attention '''
|
39 |
+
|
40 |
+
def __init__(self, temperature, attn_dropout=0.1):
|
41 |
+
super().__init__()
|
42 |
+
self.temperature = temperature
|
43 |
+
self.dropout = nn.Dropout(attn_dropout)
|
44 |
+
|
45 |
+
def forward(self, q, k, v, mask=None):
|
46 |
+
|
47 |
+
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
|
48 |
+
|
49 |
+
if mask is not None:
|
50 |
+
attn = attn.masked_fill(mask == 0, -1e9)
|
51 |
+
|
52 |
+
attn = self.dropout(F.softmax(attn, dim=-1))
|
53 |
+
output = torch.matmul(attn, v)
|
54 |
+
|
55 |
+
return output, attn
|
56 |
+
|
57 |
+
|
58 |
+
class MultiHeadAttention(nn.Module):
|
59 |
+
''' Multi-Head Attention module '''
|
60 |
+
|
61 |
+
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
|
62 |
+
super().__init__()
|
63 |
+
|
64 |
+
self.n_head = n_head
|
65 |
+
self.d_k = d_k
|
66 |
+
self.d_v = d_v
|
67 |
+
|
68 |
+
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
|
69 |
+
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
|
70 |
+
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
|
71 |
+
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
|
72 |
+
|
73 |
+
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
|
74 |
+
|
75 |
+
self.dropout = nn.Dropout(dropout)
|
76 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
77 |
+
|
78 |
+
|
79 |
+
def forward(self, q, k, v, mask=None):
|
80 |
+
|
81 |
+
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
82 |
+
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
|
83 |
+
|
84 |
+
residual = q
|
85 |
+
|
86 |
+
# Pass through the pre-attention projection: b x lq x (n*dv)
|
87 |
+
# Separate different heads: b x lq x n x dv
|
88 |
+
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
89 |
+
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
90 |
+
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
91 |
+
|
92 |
+
# Transpose for attention dot product: b x n x lq x dv
|
93 |
+
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
94 |
+
|
95 |
+
if mask is not None:
|
96 |
+
mask = mask.unsqueeze(1) # For head axis broadcasting.
|
97 |
+
|
98 |
+
q, attn = self.attention(q, k, v, mask=mask)
|
99 |
+
|
100 |
+
# Transpose to move the head dimension back: b x lq x n x dv
|
101 |
+
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
|
102 |
+
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
|
103 |
+
q = self.dropout(self.fc(q))
|
104 |
+
q += residual
|
105 |
+
|
106 |
+
q = self.layer_norm(q)
|
107 |
+
|
108 |
+
return q, attn
|
109 |
+
|
110 |
+
|
111 |
+
class PositionwiseFeedForward(nn.Module):
|
112 |
+
''' A two-feed-forward-layer module '''
|
113 |
+
|
114 |
+
def __init__(self, d_in, d_hid, dropout=0.1):
|
115 |
+
super().__init__()
|
116 |
+
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
|
117 |
+
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
|
118 |
+
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
|
119 |
+
self.dropout = nn.Dropout(dropout)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
|
123 |
+
residual = x
|
124 |
+
|
125 |
+
x = self.w_2(F.relu(self.w_1(x)))
|
126 |
+
x = self.dropout(x)
|
127 |
+
x += residual
|
128 |
+
|
129 |
+
x = self.layer_norm(x)
|
130 |
+
|
131 |
+
return x
|
132 |
+
|
133 |
+
|
134 |
+
class PositionalEncoding(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self, d_hid, n_position=200):
|
137 |
+
super(PositionalEncoding, self).__init__()
|
138 |
+
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
|
139 |
+
|
140 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
141 |
+
|
142 |
+
def get_position_angle_vec(position):
|
143 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
144 |
+
|
145 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
146 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
147 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
148 |
+
|
149 |
+
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
return x + self.pos_table[:, :x.size(1)].clone().detach()
|
153 |
+
|
154 |
+
|
155 |
+
class EncoderLayer(nn.Module):
|
156 |
+
''' Compose with two layers '''
|
157 |
+
|
158 |
+
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
|
159 |
+
super(EncoderLayer, self).__init__()
|
160 |
+
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
|
161 |
+
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
|
162 |
+
|
163 |
+
def forward(self, enc_input, slf_attn_mask=None):
|
164 |
+
enc_output, enc_slf_attn = self.slf_attn(
|
165 |
+
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
166 |
+
enc_output = self.pos_ffn(enc_output)
|
167 |
+
return enc_output, enc_slf_attn
|
168 |
+
|
169 |
+
|
170 |
+
class TransformerEncoder(nn.Module):
|
171 |
+
''' A encoder model with self attention mechanism. '''
|
172 |
+
|
173 |
+
def __init__(
|
174 |
+
self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
|
175 |
+
d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):
|
176 |
+
|
177 |
+
super().__init__()
|
178 |
+
|
179 |
+
if n_position > 0:
|
180 |
+
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
|
181 |
+
else:
|
182 |
+
self.position_enc = lambda x: x
|
183 |
+
self.dropout = nn.Dropout(p=dropout)
|
184 |
+
self.layer_stack = nn.ModuleList([
|
185 |
+
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
|
186 |
+
for _ in range(n_layers)])
|
187 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
188 |
+
self.scale_emb = scale_emb
|
189 |
+
self.d_model = d_model
|
190 |
+
|
191 |
+
def forward(self, src_seq, return_attns=False):
|
192 |
+
if len(src_seq.shape) == 2:
|
193 |
+
src_seq = src_seq.unsqueeze(1)
|
194 |
+
B, L, D = src_seq.shape
|
195 |
+
|
196 |
+
enc_slf_attn_list = []
|
197 |
+
|
198 |
+
causal_mask = None
|
199 |
+
|
200 |
+
enc_output = src_seq
|
201 |
+
if self.scale_emb:
|
202 |
+
enc_output = enc_output * self.d_model ** 0.5
|
203 |
+
enc_output = self.dropout(self.position_enc(enc_output))
|
204 |
+
enc_output = self.layer_norm(enc_output)
|
205 |
+
|
206 |
+
for enc_layer in self.layer_stack:
|
207 |
+
enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=causal_mask)
|
208 |
+
enc_slf_attn_list += [enc_slf_attn] if return_attns else []
|
209 |
+
|
210 |
+
if return_attns:
|
211 |
+
return enc_output, enc_slf_attn_list
|
212 |
+
return enc_output
|
213 |
+
|
214 |
+
|
215 |
+
# gated cross attention
|
216 |
+
class MaskedCrossAttention(nn.Module):
|
217 |
+
def __init__(
|
218 |
+
self,
|
219 |
+
*,
|
220 |
+
dim,
|
221 |
+
dim_audio,
|
222 |
+
max_window_per_audio,
|
223 |
+
dim_head=64,
|
224 |
+
heads=8,
|
225 |
+
only_attend_immediate_media=True,
|
226 |
+
):
|
227 |
+
super().__init__()
|
228 |
+
self.max_window_per_audio = max_window_per_audio
|
229 |
+
self.scale = dim_head**-0.5
|
230 |
+
self.heads = heads
|
231 |
+
inner_dim = dim_head * heads
|
232 |
+
|
233 |
+
self.norm = nn.LayerNorm(dim)
|
234 |
+
|
235 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
236 |
+
self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
|
237 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
238 |
+
|
239 |
+
self.only_attend_immediate_media = only_attend_immediate_media
|
240 |
+
|
241 |
+
def forward(
|
242 |
+
self,
|
243 |
+
x,
|
244 |
+
media, media_mask,
|
245 |
+
media_locations=None,
|
246 |
+
use_cached_media=False
|
247 |
+
):
|
248 |
+
|
249 |
+
if not use_cached_media:
|
250 |
+
assert (
|
251 |
+
media_locations.shape[1] == x.shape[1]
|
252 |
+
), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
|
253 |
+
|
254 |
+
T_txt = x.shape[1]
|
255 |
+
B, L = media.shape[:2]
|
256 |
+
assert media.shape[2] == 1 # extra dim
|
257 |
+
assert L % self.max_window_per_audio == 0 # should be 4 or 8 times
|
258 |
+
h = self.heads
|
259 |
+
|
260 |
+
x = self.norm(x)
|
261 |
+
|
262 |
+
q = self.to_q(x)
|
263 |
+
media = rearrange(media, "b t n d -> b (t n) d")
|
264 |
+
|
265 |
+
k, v = self.to_kv(media).chunk(2, dim=-1)
|
266 |
+
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
|
267 |
+
|
268 |
+
q = q * self.scale
|
269 |
+
|
270 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
271 |
+
|
272 |
+
# mask padded audio embeddings
|
273 |
+
media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool() # n = 1 is extra dim
|
274 |
+
sim = sim.masked_fill(~media_mask, -torch.finfo(sim.dtype).max)
|
275 |
+
|
276 |
+
assert self.only_attend_immediate_media is False
|
277 |
+
|
278 |
+
# mask media locations
|
279 |
+
if exists(media_locations):
|
280 |
+
few_shot_mask = torch.zeros(B, T_txt, L).bool().to(sim.device)
|
281 |
+
for batch_idx in range(B):
|
282 |
+
media_locations_b = media_locations[batch_idx].nonzero() # locations of <audio>
|
283 |
+
if len(media_locations_b.shape) > 1:
|
284 |
+
media_locations_b = media_locations_b.squeeze(-1)
|
285 |
+
|
286 |
+
for i in range(-1, len(media_locations_b)):
|
287 |
+
if i == -1:
|
288 |
+
if len(media_locations_b) == 1:
|
289 |
+
text_start, text_end = 0, T_txt
|
290 |
+
else:
|
291 |
+
text_start, text_end = 0, media_locations_b[i+1]
|
292 |
+
|
293 |
+
elif i == len(media_locations_b) - 1:
|
294 |
+
text_start, text_end = media_locations_b[i], T_txt
|
295 |
+
|
296 |
+
else:
|
297 |
+
text_start, text_end = media_locations_b[i], media_locations_b[i+1]
|
298 |
+
|
299 |
+
if self.only_attend_immediate_media:
|
300 |
+
look_at_window_start = max(i,0) * self.max_window_per_audio
|
301 |
+
else:
|
302 |
+
look_at_window_start = 0
|
303 |
+
look_at_window_end = (max(i,0) + 1) * self.max_window_per_audio
|
304 |
+
|
305 |
+
few_shot_mask[batch_idx, text_start:text_end, look_at_window_start:look_at_window_end] = True
|
306 |
+
|
307 |
+
sim = sim.masked_fill(~few_shot_mask.unsqueeze(1), -torch.finfo(sim.dtype).max)
|
308 |
+
|
309 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
310 |
+
attn = sim.softmax(dim=-1)
|
311 |
+
|
312 |
+
if exists(media_locations) and self.only_attend_immediate_media:
|
313 |
+
text_without_media_mask = text_time == 0
|
314 |
+
text_without_media_mask = rearrange(
|
315 |
+
text_without_media_mask, "b i -> b 1 i 1"
|
316 |
+
)
|
317 |
+
attn = attn.masked_fill(text_without_media_mask, 0.0)
|
318 |
+
|
319 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
320 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
321 |
+
return self.to_out(out)
|
322 |
+
|
323 |
+
|
324 |
+
class GatedCrossAttentionBlock(nn.Module):
|
325 |
+
def __init__(
|
326 |
+
self,
|
327 |
+
*,
|
328 |
+
dim,
|
329 |
+
dim_audio,
|
330 |
+
max_window_per_audio,
|
331 |
+
dim_head=64,
|
332 |
+
heads=8,
|
333 |
+
ff_mult=4,
|
334 |
+
only_attend_immediate_media=True,
|
335 |
+
):
|
336 |
+
super().__init__()
|
337 |
+
self.attn = MaskedCrossAttention(
|
338 |
+
dim=dim,
|
339 |
+
dim_audio=dim_audio,
|
340 |
+
max_window_per_audio=max_window_per_audio,
|
341 |
+
dim_head=dim_head,
|
342 |
+
heads=heads,
|
343 |
+
only_attend_immediate_media=only_attend_immediate_media,
|
344 |
+
)
|
345 |
+
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
|
346 |
+
|
347 |
+
self.ff = FeedForward(dim, mult=ff_mult)
|
348 |
+
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
|
349 |
+
|
350 |
+
def forward(
|
351 |
+
self,
|
352 |
+
x,
|
353 |
+
media,
|
354 |
+
media_mask,
|
355 |
+
media_locations=None,
|
356 |
+
use_cached_media=False,
|
357 |
+
):
|
358 |
+
x = (
|
359 |
+
self.attn(
|
360 |
+
x,
|
361 |
+
media,
|
362 |
+
media_mask,
|
363 |
+
media_locations=media_locations,
|
364 |
+
use_cached_media=use_cached_media,
|
365 |
+
)
|
366 |
+
* self.attn_gate.tanh()
|
367 |
+
+ x
|
368 |
+
)
|
369 |
+
x = self.ff(x) * self.ff_gate.tanh() + x
|
370 |
+
|
371 |
+
return x
|
372 |
+
|
373 |
+
|
374 |
+
if __name__ == '__main__':
|
375 |
+
enc = TransformerEncoder().cuda()
|
376 |
+
x = torch.randn(4, 512).cuda()
|
377 |
+
output = enc(x)
|
378 |
+
enc._use_gradient_checkpointing = True
|
379 |
+
print(output.shape)
|
src/utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
5 |
+
# LICENSE is in incl_licenses directory.
|
6 |
+
|
7 |
+
def extend_instance(obj, mixin):
|
8 |
+
"""Apply mixins to a class instance after creation"""
|
9 |
+
base_cls = obj.__class__
|
10 |
+
base_cls_name = obj.__class__.__name__
|
11 |
+
obj.__class__ = type(
|
12 |
+
base_cls_name, (mixin, base_cls), {}
|
13 |
+
) # mixin needs to go first for our forward() logic to work
|
14 |
+
|
15 |
+
|
16 |
+
def getattr_recursive(obj, att):
|
17 |
+
"""
|
18 |
+
Return nested attribute of obj
|
19 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
20 |
+
"""
|
21 |
+
if att == "":
|
22 |
+
return obj
|
23 |
+
i = att.find(".")
|
24 |
+
if i < 0:
|
25 |
+
return getattr(obj, att)
|
26 |
+
else:
|
27 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
28 |
+
|
29 |
+
|
30 |
+
def setattr_recursive(obj, att, val):
|
31 |
+
"""
|
32 |
+
Set nested attribute of obj
|
33 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
34 |
+
"""
|
35 |
+
if "." in att:
|
36 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
37 |
+
setattr(obj, att.split(".")[-1], val)
|
38 |
+
|
39 |
+
|
40 |
+
def apply_with_stopping_condition(
|
41 |
+
module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
|
42 |
+
):
|
43 |
+
if stopping_condition(module):
|
44 |
+
return
|
45 |
+
if apply_condition(module):
|
46 |
+
apply_fn(module, **other_args)
|
47 |
+
for child in module.children():
|
48 |
+
apply_with_stopping_condition(
|
49 |
+
child,
|
50 |
+
apply_fn,
|
51 |
+
apply_condition=apply_condition,
|
52 |
+
stopping_condition=stopping_condition,
|
53 |
+
**other_args
|
54 |
+
)
|