ZhifengKong committed
Commit 92740f3 · 1 Parent(s): 15f9587
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 NVIDIA CORPORATION.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
LICENSE_OPT_IML.md ADDED
@@ -0,0 +1,65 @@
+ <h2 align="center"> OPT-IML 175B LICENSE AGREEMENT </h2>
+
+ This License Agreement (as may be amended in accordance with this License Agreement, **“License”**), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (**“Licensee”** or **“you”**) and Meta Platforms, Inc. (**“Meta”** or **“we”**) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Meta under this License (**“Software”**) and any specifications, manuals, documentation, and other written information provided by Meta related to the Software (**“Documentation”**).
+
+ **By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Meta that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.**
+ <br><br>
+ 1. **LICENSE GRANT**
+ <br><br>
+ a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Meta grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Meta’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Meta’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
+ <br><br>
+ b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
+ <br><br>
+ c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Meta and its licensors reserve all rights not expressly granted by this License.
+ <br><br>
+ 2. **RESTRICTIONS**
+ <br><br>
+ You will not, and will not permit, assist or cause any third party to:
+ <br><br>
+ a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law, including accessing the Software Products from an embargoed country as prohibited by the U.S. government, and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
+ <br><br>
+ b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
+ <br><br>
+ c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Meta in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Meta; or
+ <br><br>
+ d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
+ <br><br>
+ 3. **ATTRIBUTION**
+ <br><br>
+ Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “OPT-IML 175B is licensed under the OPT-175B license, Copyright (c) Meta Platforms, Inc. All Rights Reserved.”
+ <br><br>
+ 4. **DISCLAIMERS**
+ <br><br>
+ THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. META EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. META MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
+ <br><br>
+ 5. **LIMITATION OF LIABILITY**
+ <br><br>
+ TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL META BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF META HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, **“SOFTWARE MATERIALS”**) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A **“HIGH-RISK USE”**). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
+ <br><br>
+ 6. **INDEMNIFICATION**
+ <br><br>
+ You will indemnify, defend and hold harmless Meta and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the **“Meta Parties”**) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Meta Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, **“Claims”**) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Meta Parties of any such Claims, and cooperate with Meta Parties in defending such Claims. You will also grant the Meta Parties sole control of the defense or settlement, at Meta’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Meta or the other Meta Parties.
+ <br><br>
+ 7. **TERMINATION; SURVIVAL**
+ <br><br>
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
+ <br><br>
+ b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
+ <br><br>
+ c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
+ <br><br>
+ 8. **THIRD PARTY MATERIALS**
+ <br><br>
+ The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, **“Third Party Materials”**), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Meta does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
+ <br><br>
+ 9. **TRADEMARKS**
+ <br><br>
+ Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Meta without the prior written permission of Meta, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
+ <br><br>
+ 10. **APPLICABLE LAW; DISPUTE RESOLUTION**
+ <br><br>
+ This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
+ <br><br>
+ 11. **MISCELLANEOUS**
+ <br><br>
+ If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Meta to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Meta regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Meta regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
app.py ADDED
@@ -0,0 +1,223 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ import os
+ import yaml
+
+ import gradio as gr
+
+ import librosa
+ from pydub import AudioSegment
+ import soundfile as sf
+
+ import numpy as np
+ import torch
+ import laion_clap
+
+ from inference_utils import prepare_tokenizer, prepare_model, inference
+ from data import AudioTextDataProcessor
+
+
+ def load_laionclap():
+     model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').cuda()
+     model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
+     model.eval()
+     return model
+
+
+ def int16_to_float32(x):
+     return (x / 32767.0).astype(np.float32)
+
+
+ def float32_to_int16(x):
+     x = np.clip(x, a_min=-1., a_max=1.)
+     return (x * 32767.).astype(np.int16)
+
+
+ def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
+     if file_path.endswith('.mp3'):
+         audio = AudioSegment.from_file(file_path)
+         if len(audio) > (start + duration) * 1000:
+             audio = audio[start * 1000:(start + duration) * 1000]
+
+         if audio.frame_rate != target_sr:
+             audio = audio.set_frame_rate(target_sr)
+
+         if audio.channels > 1:
+             audio = audio.set_channels(1)
+
+         data = np.array(audio.get_array_of_samples())
+         if audio.sample_width == 2:
+             data = data.astype(np.float32) / np.iinfo(np.int16).max
+         elif audio.sample_width == 4:
+             data = data.astype(np.float32) / np.iinfo(np.int32).max
+         else:
+             raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
+
+     else:
+         with sf.SoundFile(file_path) as audio:
+             original_sr = audio.samplerate
+             channels = audio.channels
+
+             max_frames = int((start + duration) * original_sr)
+
+             audio.seek(int(start * original_sr))
+             frames_to_read = min(max_frames, len(audio))
+             data = audio.read(frames_to_read)
+
+             if data.max() > 1 or data.min() < -1:
+                 data = data / max(abs(data.max()), abs(data.min()))
+
+         if original_sr != target_sr:
+             if channels == 1:
+                 data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
+             else:
+                 data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
+         else:
+             if channels != 1:
+                 data = data.T[0]
+
+     if data.min() >= 0:
+         data = 2 * data / abs(data.max()) - 1.0
+     else:
+         data = data / max(abs(data.max()), abs(data.min()))
+     return data
+
+
+ @torch.no_grad()
+ def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
+     try:
+         data = load_audio(audio_file, target_sr=48000)
+
+     except Exception as e:
+         print(audio_file, 'unsuccessful due to', e)
+         return [0.0] * len(outputs)
+
+     audio_data = data.reshape(1, -1)
+     audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda()
+     audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
+
+     text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
+
+     cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
+     cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
+     return cos_similarity.squeeze().cpu().numpy()
+
+
+ inference_kwargs = {
+     "do_sample": True,
+     "top_k": 50,
+     "top_p": 0.95,
+     "num_return_sequences": 10
+ }
+
+ config = yaml.load(open('chat.yaml'), Loader=yaml.FullLoader)
+ clap_config = config['clap_config']
+ model_config = config['model_config']
+
+ text_tokenizer = prepare_tokenizer(model_config)
+ DataProcessor = AudioTextDataProcessor(
+     data_root='./',
+     clap_config=clap_config,
+     tokenizer=text_tokenizer,
+     max_tokens=512,
+ )
+
+ laionclap_model = load_laionclap()
+
+ model = prepare_model(
+     model_config=model_config,
+     clap_config=clap_config,
+     checkpoint_path='chat.pt'
+ )
+
+
+ def inference_item(name, prompt):
+     item = {
+         'name': str(name),
+         'prefix': 'The task is dialog.',
+         'prompt': str(prompt)
+     }
+     processed_item = DataProcessor.process(item)
+
+     outputs = inference(
+         model, text_tokenizer, item, processed_item,
+         inference_kwargs,
+     )
+
+     laionclap_scores = compute_laionclap_text_audio_sim(
+         item["name"],
+         laionclap_model,
+         outputs
+     )
+
+     outputs_joint = [(output, score) for (output, score) in zip(outputs, laionclap_scores)]
+     outputs_joint.sort(key=lambda x: -x[1])
+
+     return outputs_joint[0][0]
+
+
+ with gr.Blocks(title="Audio Flamingo - Demo") as ui:
+
+     gr.HTML(
+         """
+         <div style="text-align: center; max-width: 900px; margin: 0 auto;">
+           <div
+             style="
+               display: inline-flex;
+               align-items: center;
+               gap: 0.8rem;
+               font-size: 1.5rem;
+             "
+           >
+             <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
+               Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
+             </h1>
+           </div>
+           <p style="margin-bottom: 10px; font-size: 125%">
+             <a href="https://arxiv.org/abs/2402.01831">[Paper]</a> <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a> <a href="https://audioflamingo.github.io/">[Demo]</a>
+           </p>
+         </div>
+         """
+     )
+     gr.HTML(
+         """
+         <div>
+           <h3>Model Overview</h3>
+           Audio Flamingo is an audio language model that can understand sounds beyond speech.
+           It can also answer questions about the sound in natural language.
+           Examples of questions include:
+           "Can you briefly describe what you hear in this audio?",
+           "What is the emotion conveyed in this music?",
+           "Where is this audio usually heard?",
+           or "What place is this music usually played at?".
+         </div>
+         """
+     )
+
+     name = gr.Textbox(
+         label="Audio file path (choose one from: audio/wav{1--6}.wav)",
+         value="audio/wav5.wav"
+     )
+     prompt = gr.Textbox(
+         label="Instruction",
+         value='Can you briefly describe what you hear in this audio?'
+     )
+
+     with gr.Row():
+         play_audio_button = gr.Button("Play Audio")
+         audio_output = gr.Audio(label="Playback")
+         play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)
+
+     inference_button = gr.Button("Inference")
+
+     output_text = gr.Textbox(label="Audio Flamingo output")
+
+     inference_button.click(
+         fn=inference_item,
+         inputs=[name, prompt],
+         outputs=output_text
+     )
+
+ ui.queue()
+ ui.launch()
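Note on the flow in `inference_item` above: the model samples `num_return_sequences=10` candidate answers, scores each against the input audio with LAION-CLAP cosine similarity, and returns the best-scoring one. The sketch below isolates that reranking step; it is illustrative only, and `candidates`, `audio_embed`, and `text_embed` stand for the generated strings and the CLAP embeddings produced by the functions above (not names defined in `app.py`).

```python
import torch

def rerank_by_clap_similarity(candidates, audio_embed, text_embed):
    # candidates: list of N generated answer strings
    # audio_embed: (1, D) CLAP audio embedding; text_embed: (N, D) CLAP text embeddings
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    scores = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)  # (N,)
    best = int(torch.argmax(scores).item())
    return candidates[best], scores
```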
audio/wav1.wav ADDED
Binary file (960 kB)
audio/wav2.wav ADDED
Binary file (960 kB)
audio/wav3.wav ADDED
Binary file (960 kB)
audio/wav4.wav ADDED
Binary file (441 kB)
audio/wav5.wav ADDED
Binary file (441 kB)
audio/wav6.wav ADDED
Binary file (441 kB)
chat.yaml ADDED
@@ -0,0 +1,23 @@
+ clap_config:
+   method: microsoft-clap
+   audio_embed_dim: 1024
+   config_root: ./ms_clap/src/configs
+   model_name: 'clapcap'
+   checkpoint: ./clapcap_weights_2023.pth
+   window_length: 7.0
+   window_overlap: 5.25
+   max_num_window: 16
+   max_num_fewshot: 4
+
+ model_config:
+   cache_dir: None
+   lang_encoder_path: facebook/opt-iml-max-1.3b
+   tokenizer_path: facebook/opt-iml-max-1.3b
+   cross_attn_every_n_layers: 1
+   audio_transformer_kwargs: {
+     n_head: 8,
+     n_layers: 3,
+     d_inner: 2048,
+     max_num_media: 128,
+     max_window_per_audio: 16,
+   }
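For orientation, the sliding-window settings above determine how much audio is fed to the CLAP encoder at once: with `window_length: 7.0`, `window_overlap: 5.25`, and `max_num_window: 16`, the covered span is `16 * (7.0 - 5.25) + 5.25 = 33.25` seconds, matching the `duration=33.25` default in `app.py`'s `load_audio`. A minimal sketch of that arithmetic, mirroring `compute_sliding_window` in `data.py`:

```python
# Maximum audio duration covered by the CLAP sliding windows (values from chat.yaml).
window_length = 7.0    # seconds per window
window_overlap = 5.25  # seconds of overlap between consecutive windows
max_num_window = 16

duration = max_num_window * (window_length - window_overlap) + window_overlap
print(duration)  # 33.25 seconds
```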
data.py ADDED
@@ -0,0 +1,243 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ import functools
+ import io
+ import json
+ import math
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
+ import random
+ import re
+ import string
+ import subprocess
+ import sys
+ import yaml
+
+ import numpy as np
+
+ from collections import defaultdict
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from functools import partial
+ from pydub import AudioSegment
+ from tqdm import tqdm
+
+ import torch
+ import torchvision
+ from torch.utils.data import DataLoader, Dataset, get_worker_info
+ from torch.utils.data.distributed import DistributedSampler
+
+
+ from transformers import AutoTokenizer
+
+ import librosa
+ import soundfile as sf
+
+
+ def int16_to_float32(x):
+     return (x / 32767.0).astype(np.float32)
+
+
+ def float32_to_int16(x):
+     x = np.clip(x, a_min=-1., a_max=1.)
+     return (x * 32767.).astype(np.int16)
+
+
+ class AudioTextDataProcessor:
+     def __init__(
+         self,
+         data_root: str,
+         clap_config: dict,
+         tokenizer,
+         max_tokens: int,
+         **kwargs
+     ):
+         self.data_root = data_root
+         self.clap_config = clap_config
+         self.tokenizer = tokenizer
+         self.tokenizer.padding_side = "right"
+         self.max_tokens = max_tokens
+
+     def get_num_windows(self, T, sr):
+         clap_config = self.clap_config
+         window_length = int(float(clap_config["window_length"]) * sr)
+         window_overlap = int(float(clap_config["window_overlap"]) * sr)
+         max_num_window = int(clap_config["max_num_window"])
+
+         num_windows = 1
+         if T <= window_length:
+             num_windows = 1
+             full_length = window_length
+         elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
+             num_windows = max_num_window
+             full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
+         else:
+             num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
+             full_length = num_windows * window_length - (num_windows - 1) * window_overlap
+
+         return num_windows, full_length
+
+     def load_audio(self, file_path, target_sr=44100, duration=30.0, start=0.0):
+         if file_path.endswith('.mp3'):
+             audio = AudioSegment.from_file(file_path)
+             if len(audio) > (start + duration) * 1000:
+                 audio = audio[start * 1000:(start + duration) * 1000]
+
+             if audio.frame_rate != target_sr:
+                 audio = audio.set_frame_rate(target_sr)
+
+             if audio.channels > 1:
+                 audio = audio.set_channels(1)
+
+             data = np.array(audio.get_array_of_samples())
+             if audio.sample_width == 2:
+                 data = data.astype(np.float32) / np.iinfo(np.int16).max
+             elif audio.sample_width == 4:
+                 data = data.astype(np.float32) / np.iinfo(np.int32).max
+             else:
+                 raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
+
+         else:
+             with sf.SoundFile(file_path) as audio:
+                 original_sr = audio.samplerate
+                 channels = audio.channels
+
+                 max_frames = int((start + duration) * original_sr)
+
+                 audio.seek(int(start * original_sr))
+                 frames_to_read = min(max_frames, len(audio))
+                 data = audio.read(frames_to_read)
+
+                 if data.max() > 1 or data.min() < -1:
+                     data = data / max(abs(data.max()), abs(data.min()))
+
+             if original_sr != target_sr:
+                 if channels == 1:
+                     data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
+                 else:
+                     data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
+             else:
+                 if channels != 1:
+                     data = data.T[0]
+
+         if data.min() >= 0:
+             data = 2 * data / abs(data.max()) - 1.0
+         else:
+             data = data / max(abs(data.max()), abs(data.min()))
+
+         assert len(data.shape) == 1, data.shape
+         return data
+
+     def compute_sliding_window(self, audio_file, audio_start=0.0):
+         if type(audio_start) == str:
+             audio_start = float(audio_start)
+
+         clap_config = self.clap_config
+
+         if clap_config["method"] == 'laion-clap':
+             sr = 48000
+         elif clap_config["method"] == 'microsoft-clap':
+             sr = 44100
+         else:
+             raise NotImplementedError
+
+         window_length = int(float(clap_config["window_length"]) * sr)
+         window_overlap = int(float(clap_config["window_overlap"]) * sr)
+         max_num_window = int(clap_config["max_num_window"])
+         duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]
+
+         audio_data = self.load_audio(audio_file, sr, duration, audio_start)
+         T = len(audio_data)
+         num_windows, full_length = self.get_num_windows(T, sr)
+
+         if full_length > T:
+             audio_data = np.append(audio_data, np.zeros(full_length - T))
+         audio_data = audio_data.reshape(1, -1)
+         audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
+
+         audio_clips = []
+         audio_embed_mask = torch.zeros(max_num_window)
+         for i in range(num_windows):
+             start = i * (window_length - window_overlap)
+             audio_clips.append(audio_data_tensor[:, start:start+window_length])
+             audio_embed_mask[i] = 1
+
+         assert sum(audio_embed_mask) == num_windows
+
+         if num_windows < max_num_window:
+             for _ in range(max_num_window - num_windows):
+                 audio_clips.append(torch.zeros_like(audio_clips[-1]))
+
+         audio_clips = torch.cat(audio_clips)  # (max_num_window, window_length * sr) cuda tensor
+
+         return audio_clips, audio_embed_mask
+
+     def preprocess_string_for_eval(self, x):
+         x = x.rstrip().lstrip()
+         x = x.lower()
+         return x
+
+     def process(self, item):
+         if type(item['name']) is str:
+             audio_files = [os.path.join(self.data_root, item['name'])]
+             audio_starts = [0 if 'audio_start' not in item else float(item['audio_start'])]
+         else:
+             audio_files = [os.path.join(self.data_root, name) for name in item['name']]
+             audio_starts = [0] * len(audio_files) if 'audio_start' not in item else item['audio_start']
+
+         audio_clips, audio_embed_mask = [], []
+         for audio_file, audio_start in zip(audio_files, audio_starts):
+             this_audio_clips, this_audio_embed_mask = self.compute_sliding_window(audio_file, audio_start)
+             audio_clips.append(this_audio_clips)
+             audio_embed_mask.append(this_audio_embed_mask)
+
+         audio_clips = torch.cat(audio_clips)
+         audio_embed_mask = torch.cat(audio_embed_mask)
+
+         correct_num_windows = int(self.clap_config["max_num_window"]) * int(self.clap_config["max_num_fewshot"])
+         if len(audio_clips) < correct_num_windows:
+             audio_clips = torch.cat([
+                 audio_clips,
+                 torch.zeros(correct_num_windows - len(audio_clips), audio_clips.shape[1])
+             ])
+             audio_embed_mask = torch.cat([
+                 audio_embed_mask,
+                 torch.zeros(correct_num_windows - len(audio_embed_mask))
+             ])
+
+         audio_clips.requires_grad = False
+         audio_embed_mask.requires_grad = False
+
+         assert type(item['name']) is str
+
+         # simple data - 1 audio, 1 text
+         if 'prompt' in item:
+             text_prompt = item['prompt'].lower()
+             prefix = item['prefix'].lower()  # the task is xxx.
+             sample = "{}{} <audio>{}\nanswer:{}".format(
+                 self.tokenizer.bos_token,
+                 self.preprocess_string_for_eval(prefix),
+                 self.preprocess_string_for_eval(text_prompt),
+                 self.tokenizer.sep_token
+             )
+
+         # dialog data - 1 audio, multiple text
+         elif 'dialogue' in item:
+             dialogue = item['dialogue']
+             prefix = item['prefix'].lower()  # the task is dialog.
+             sample = f"{self.tokenizer.bos_token}{prefix}<audio>"
+             for each_round in dialogue:
+                 sample = sample + f"user: {each_round['user']} \nassistant: {self.tokenizer.sep_token}"
+                 if 'assistant' in each_round:
+                     sample = sample + f"{each_round['assistant']}<|endofchunk|>{self.tokenizer.eos_token}\n"
+
+         text = self.tokenizer(
+             sample,
+             max_length=self.max_tokens*5,
+             padding="longest",
+             truncation="only_first",
+             return_tensors="pt"
+         )
+
+         return (item['name'], audio_clips, audio_embed_mask, text["input_ids"], text["attention_mask"])
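As a quick illustration of the windowing logic above, the sketch below applies the same formula as `get_num_windows` to a hypothetical 20-second clip at 44.1 kHz using the `chat.yaml` settings (7.0 s windows, 5.25 s overlap, up to 16 windows); the clip needs 9 windows and is zero-padded to the corresponding full length of 21 seconds. The numbers are illustrative, not from the repository.

```python
import numpy as np

# Same formula as AudioTextDataProcessor.get_num_windows, with chat.yaml settings.
sr = 44100                      # microsoft-clap sample rate
window_length = int(7.0 * sr)   # samples per window
window_overlap = int(5.25 * sr) # overlapping samples between windows
max_num_window = 16

T = int(20.0 * sr)              # a hypothetical 20-second clip
num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
full_length = num_windows * window_length - (num_windows - 1) * window_overlap

print(num_windows)       # 9
print(full_length / sr)  # 21.0 -> the clip is zero-padded to 21 seconds
```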
inference_utils.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ import os
+ import string
+ import yaml
+ from copy import deepcopy
+
+ import torch
+ from transformers import AutoTokenizer, set_seed
+ set_seed(0)
+
+ from data import AudioTextDataProcessor
+ from src.factory import create_model_and_transforms
+
+
+ def prepare_tokenizer(model_config):
+     tokenizer_path = model_config['tokenizer_path']
+     cache_dir = model_config['cache_dir']
+     text_tokenizer = AutoTokenizer.from_pretrained(
+         tokenizer_path,
+         local_files_only=False,
+         trust_remote_code=True,
+         cache_dir=cache_dir,
+     )
+     text_tokenizer.add_special_tokens(
+         {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
+     )
+     if text_tokenizer.pad_token is None:
+         text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+     if text_tokenizer.sep_token is None:
+         text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
+     return text_tokenizer
+
+
+ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
+     model, tokenizer = create_model_and_transforms(
+         **model_config,
+         clap_config=clap_config,
+         use_local_files=False,
+         gradient_checkpointing=False,
+         freeze_lm_embeddings=False,
+     )
+     model.eval()
+     model = model.to(device_id)
+
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+     model_state_dict = checkpoint["model_state_dict"]
+     model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
+     model.load_state_dict(model_state_dict, False)
+
+     return model
+
+
+ def inference(model, tokenizer, item, processed_item, inference_kwargs, device_id=0):
+     filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
+     audio_clips = audio_clips.to(device_id, dtype=None, non_blocking=True)
+     audio_embed_mask = audio_embed_mask.to(device_id, dtype=None, non_blocking=True)
+     input_ids = input_ids.to(device_id, dtype=None, non_blocking=True).squeeze()
+
+     media_token_id = tokenizer.encode("<audio>")[-1]
+     eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
+     sep_token_id = tokenizer.sep_token_id
+     eos_token_id = tokenizer.eos_token_id
+
+     outputs = model.generate(
+         audio_x=audio_clips.unsqueeze(0),
+         audio_x_mask=audio_embed_mask.unsqueeze(0),
+         lang_x=input_ids.unsqueeze(0),
+         eos_token_id=eos_token_id,
+         max_new_tokens=128,
+         **inference_kwargs,
+     )
+
+     outputs_decoded = [
+         tokenizer.decode(output).split(tokenizer.sep_token)[-1].replace(tokenizer.eos_token, '').replace(tokenizer.pad_token, '').replace('<|endofchunk|>', '') for output in outputs
+     ]
+
+     return outputs_decoded
+
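A minimal sketch of how these helpers fit together, mirroring `app.py` above (the config file `chat.yaml`, the checkpoint `chat.pt`, and the sample audio path and prompt are the ones used by the demo; this is an illustration, not an additional file in the commit):

```python
import yaml

from data import AudioTextDataProcessor
from inference_utils import prepare_tokenizer, prepare_model, inference

config = yaml.load(open('chat.yaml'), Loader=yaml.FullLoader)

tokenizer = prepare_tokenizer(config['model_config'])
model = prepare_model(config['model_config'], config['clap_config'], checkpoint_path='chat.pt')

processor = AudioTextDataProcessor(
    data_root='./', clap_config=config['clap_config'],
    tokenizer=tokenizer, max_tokens=512,
)

item = {'name': 'audio/wav5.wav', 'prefix': 'The task is dialog.',
        'prompt': 'Can you briefly describe what you hear in this audio?'}

# Returns a list of decoded candidate answers (one per sampled sequence).
outputs = inference(model, tokenizer, item, processor.process(item),
                    {"do_sample": True, "top_k": 50, "top_p": 0.95, "num_return_sequences": 10})
print(outputs[0])
```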
ms_clap/.DS_Store ADDED
Binary file (6.15 kB)
ms_clap/.gitignore ADDED
@@ -0,0 +1,350 @@
+ ## Ignore Visual Studio temporary files, build results, and
+ ## files generated by popular Visual Studio add-ons.
+ ##
+ ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
+
+ # User-specific files
+ *.rsuser
+ *.suo
+ *.user
+ *.userosscache
+ *.sln.docstates
+
+ # User-specific files (MonoDevelop/Xamarin Studio)
+ *.userprefs
+
+ # Mono auto generated files
+ mono_crash.*
+
+ # Build results
+ [Dd]ebug/
+ [Dd]ebugPublic/
+ [Rr]elease/
+ [Rr]eleases/
+ x64/
+ x86/
+ [Aa][Rr][Mm]/
+ [Aa][Rr][Mm]64/
+ bld/
+ [Bb]in/
+ [Oo]bj/
+ [Ll]og/
+ [Ll]ogs/
+
+ # Visual Studio 2015/2017 cache/options directory
+ .vs/
+ # Uncomment if you have tasks that create the project's static files in wwwroot
+ #wwwroot/
+
+ # Visual Studio 2017 auto generated files
+ Generated\ Files/
+
+ # MSTest test Results
+ [Tt]est[Rr]esult*/
+ [Bb]uild[Ll]og.*
+
+ # NUnit
+ *.VisualState.xml
+ TestResult.xml
+ nunit-*.xml
+
+ # Build Results of an ATL Project
+ [Dd]ebugPS/
+ [Rr]eleasePS/
+ dlldata.c
+
+ # Benchmark Results
+ BenchmarkDotNet.Artifacts/
+
+ # .NET Core
+ project.lock.json
+ project.fragment.lock.json
+ artifacts/
+
+ # StyleCop
+ StyleCopReport.xml
+
+ # Files built by Visual Studio
+ *_i.c
+ *_p.c
+ *_h.h
+ *.ilk
+ *.meta
+ *.obj
+ *.iobj
+ *.pch
+ *.pdb
+ *.ipdb
+ *.pgc
+ *.pgd
+ *.rsp
+ *.sbr
+ *.tlb
+ *.tli
+ *.tlh
+ *.tmp
+ *.tmp_proj
+ *_wpftmp.csproj
+ *.log
+ *.vspscc
+ *.vssscc
+ .builds
+ *.pidb
+ *.svclog
+ *.scc
+
+ # Chutzpah Test files
+ _Chutzpah*
+
+ # Visual C++ cache files
+ ipch/
+ *.aps
+ *.ncb
+ *.opendb
+ *.opensdf
+ *.sdf
+ *.cachefile
+ *.VC.db
+ *.VC.VC.opendb
+
+ # Visual Studio profiler
+ *.psess
+ *.vsp
+ *.vspx
+ *.sap
+
+ # Visual Studio Trace Files
+ *.e2e
+
+ # TFS 2012 Local Workspace
+ $tf/
+
+ # Guidance Automation Toolkit
+ *.gpState
+
+ # ReSharper is a .NET coding add-in
+ _ReSharper*/
+ *.[Rr]e[Ss]harper
+ *.DotSettings.user
+
+ # TeamCity is a build add-in
+ _TeamCity*
+
+ # DotCover is a Code Coverage Tool
+ *.dotCover
+
+ # AxoCover is a Code Coverage Tool
+ .axoCover/*
+ !.axoCover/settings.json
+
+ # Visual Studio code coverage results
+ *.coverage
+ *.coveragexml
+
+ # NCrunch
+ _NCrunch_*
+ .*crunch*.local.xml
+ nCrunchTemp_*
+
+ # MightyMoose
+ *.mm.*
+ AutoTest.Net/
+
+ # Web workbench (sass)
+ .sass-cache/
+
+ # Installshield output folder
+ [Ee]xpress/
+
+ # DocProject is a documentation generator add-in
+ DocProject/buildhelp/
+ DocProject/Help/*.HxT
+ DocProject/Help/*.HxC
+ DocProject/Help/*.hhc
+ DocProject/Help/*.hhk
+ DocProject/Help/*.hhp
+ DocProject/Help/Html2
+ DocProject/Help/html
+
+ # Click-Once directory
+ publish/
+
+ # Publish Web Output
+ *.[Pp]ublish.xml
+ *.azurePubxml
+ # Note: Comment the next line if you want to checkin your web deploy settings,
+ # but database connection strings (with potential passwords) will be unencrypted
+ *.pubxml
+ *.publishproj
+
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
+ # checkin your Azure Web App publish settings, but sensitive information contained
+ # in these scripts will be unencrypted
+ PublishScripts/
+
+ # NuGet Packages
+ *.nupkg
+ # NuGet Symbol Packages
+ *.snupkg
+ # The packages folder can be ignored because of Package Restore
+ **/[Pp]ackages/*
+ # except build/, which is used as an MSBuild target.
+ !**/[Pp]ackages/build/
+ # Uncomment if necessary however generally it will be regenerated when needed
+ #!**/[Pp]ackages/repositories.config
+ # NuGet v3's project.json files produces more ignorable files
+ *.nuget.props
+ *.nuget.targets
+
+ # Microsoft Azure Build Output
+ csx/
+ *.build.csdef
+
+ # Microsoft Azure Emulator
+ ecf/
+ rcf/
+
+ # Windows Store app package directories and files
+ AppPackages/
+ BundleArtifacts/
+ Package.StoreAssociation.xml
+ _pkginfo.txt
+ *.appx
+ *.appxbundle
+ *.appxupload
+
+ # Visual Studio cache files
+ # files ending in .cache can be ignored
+ *.[Cc]ache
+ # but keep track of directories ending in .cache
+ !?*.[Cc]ache/
+
+ # Others
+ ClientBin/
+ ~$*
+ *~
+ *.dbmdl
+ *.dbproj.schemaview
+ *.jfm
+ *.pfx
+ *.publishsettings
+ orleans.codegen.cs
+
+ # Including strong name files can present a security risk
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
+ #*.snk
+
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+ #bower_components/
+
+ # RIA/Silverlight projects
+ Generated_Code/
+
+ # Backup & report files from converting an old project file
+ # to a newer Visual Studio version. Backup files are not needed,
+ # because we have git ;-)
+ _UpgradeReport_Files/
+ Backup*/
+ UpgradeLog*.XML
+ UpgradeLog*.htm
+ ServiceFabricBackup/
+ *.rptproj.bak
+
+ # SQL Server files
+ *.mdf
+ *.ldf
+ *.ndf
+
+ # Business Intelligence projects
+ *.rdl.data
+ *.bim.layout
+ *.bim_*.settings
+ *.rptproj.rsuser
+ *- [Bb]ackup.rdl
+ *- [Bb]ackup ([0-9]).rdl
+ *- [Bb]ackup ([0-9][0-9]).rdl
+
+ # Microsoft Fakes
+ FakesAssemblies/
+
+ # GhostDoc plugin setting file
+ *.GhostDoc.xml
+
+ # Node.js Tools for Visual Studio
+ .ntvs_analysis.dat
+ node_modules/
+
+ # Visual Studio 6 build log
+ *.plg
+
+ # Visual Studio 6 workspace options file
+ *.opt
+
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+ *.vbw
+
+ # Visual Studio LightSwitch build output
+ **/*.HTMLClient/GeneratedArtifacts
+ **/*.DesktopClient/GeneratedArtifacts
+ **/*.DesktopClient/ModelManifest.xml
+ **/*.Server/GeneratedArtifacts
+ **/*.Server/ModelManifest.xml
+ _Pvt_Extensions
+
+ # Paket dependency manager
+ .paket/paket.exe
+ paket-files/
+
+ # FAKE - F# Make
+ .fake/
+
+ # CodeRush personal settings
+ .cr/personal
+
+ # Python Tools for Visual Studio (PTVS)
+ __pycache__/
+ *.pyc
+
+ # Cake - Uncomment if you are using it
+ # tools/**
+ # !tools/packages.config
+
+ # Tabs Studio
+ *.tss
+
+ # Telerik's JustMock configuration file
+ *.jmconfig
+
+ # BizTalk build output
+ *.btp.cs
+ *.btm.cs
+ *.odx.cs
+ *.xsd.cs
+
+ # OpenCover UI analysis results
+ OpenCover/
+
+ # Azure Stream Analytics local run output
+ ASALocalRun/
+
+ # MSBuild Binary and Structured Log
+ *.binlog
+
+ # NVidia Nsight GPU debugger configuration file
+ *.nvuser
+
+ # MFractors (Xamarin productivity tool) working folder
+ .mfractor/
+
+ # Local History for Visual Studio
+ .localhistory/
+
+ # BeatPulse healthcheck temp database
+ healthchecksdb
+
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
+ MigrationBackup/
+
+ # Ionide (cross platform F# VS Code tools) working folder
+ .ionide/
ms_clap/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
+ # Microsoft Open Source Code of Conduct
+
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+ Resources:
+
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+ - Contact [[email protected]](mailto:[email protected]) with questions or concerns
ms_clap/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
ms_clap/README.md ADDED
@@ -0,0 +1,120 @@
+ ###### [Overview](#CLAP) | [Setup](#Setup) | [CLAP weights](#CLAP-weights) | [Usage](#Usage) | [Examples](#Examples) | [Citation](#Citation)
+
+ # CLAP
+
+ CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks, achieving SoTA in several of them, including classification, retrieval, and captioning.
+
+ <img width="832" alt="clap_diagrams" src="https://github.com/bmartin1/CLAP/assets/26778834/c5340a09-cc0c-4e41-ad5a-61546eaa824c">
+
+ ## Setup
+
+ Install the dependencies with `pip install -r requirements.txt` using Python 3 to get started.
+
+ If you have [conda](https://www.anaconda.com) installed, you can run the following:
+
+ ```shell
+ git clone https://github.com/microsoft/CLAP.git && \
+ cd CLAP && \
+ conda create -n clap python=3.10 && \
+ conda activate clap && \
+ pip install -r requirements.txt
+ ```
+
+ ## NEW CLAP weights
+ Download CLAP weights: versions _2022_, _2023_, and _clapcap_: [Pretrained Model \[Zenodo\]](https://zenodo.org/record/8378278)
+
+ _clapcap_ is the audio captioning model that uses the 2023 encoders.
+
+ ## Usage
+
+ - Zero-Shot Classification and Retrieval
+ ```python
+ # Load model (Choose between versions '2022' or '2023')
+ from src import CLAP
+
+ clap_model = CLAP("<PATH TO WEIGHTS>", version = '2023', use_cuda=False)
+
+ # Extract text embeddings (class_labels: List[str])
+ text_embeddings = clap_model.get_text_embeddings(class_labels)
+
+ # Extract audio embeddings (file_paths: List[str])
+ audio_embeddings = clap_model.get_audio_embeddings(file_paths)
+
+ # Compute similarity between audio and text embeddings
+ similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
+ ```
+
+ - Audio Captioning
+ ```python
+ # Load model (Choose version 'clapcap')
+ from src import CLAP
+
+ clap_model = CLAP("<PATH TO WEIGHTS>", version = 'clapcap', use_cuda=False)
+
+ # Generate audio captions (file_paths: List[str])
+ captions = clap_model.generate_caption(file_paths)
+ ```
+
+ ## Examples
+ Take a look at `CLAP\src\` for usage examples.
+
+ To run Zero-Shot Classification on the ESC50 dataset, try the following:
+
+ ```bash
+ > cd src && python zero_shot_classification.py
+ ```
+ Output (version 2023)
+ ```bash
+ ESC50 Accuracy: 93.9%
+ ```
+
+ ## Citation
+
+ Kindly cite our work if you find it useful.
+
+ [CLAP: Learning Audio Concepts from Natural Language Supervision](https://ieeexplore.ieee.org/abstract/document/10095889)
+ ```
+ @inproceedings{CLAP2022,
+   title={Clap learning audio concepts from natural language supervision},
+   author={Elizalde, Benjamin and Deshmukh, Soham and Al Ismail, Mahmoud and Wang, Huaming},
+   booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+   pages={1--5},
+   year={2023},
+   organization={IEEE}
+ }
+ ```
+
+ [Natural Language Supervision for General-Purpose Audio Representations](https://arxiv.org/abs/2309.05767)
+ ```
+ @misc{CLAP2023,
+   title={Natural Language Supervision for General-Purpose Audio Representations},
+   author={Benjamin Elizalde and Soham Deshmukh and Huaming Wang},
+   year={2023},
+   eprint={2309.05767},
+   archivePrefix={arXiv},
+   primaryClass={cs.SD},
+   url={https://arxiv.org/abs/2309.05767}
+ }
+ ```
+
+ ## Contributing
+
+ This project welcomes contributions and suggestions. Most contributions require you to agree to a
+ Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
+ the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
+
+ When you submit a pull request, a CLA bot will automatically determine whether you need to provide
+ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
+ provided by the bot. You will only need to do this once across all repos using our CLA.
+
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
+ contact [[email protected]](mailto:[email protected]) with any additional questions or comments.
+
+ ## Trademarks
+
+ This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
+ trademarks or logos is subject to and must follow
+ [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
+ Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
+ Any use of third-party trademarks or logos is subject to those third parties' policies.
ms_clap/SECURITY.md ADDED
@@ -0,0 +1,41 @@
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.8 BLOCK -->
+
+ ## Security
+
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
+
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
+
+ ## Reporting Security Issues
+
+ **Please do not report security vulnerabilities through public GitHub issues.**
+
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
+
+ If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
+
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
+
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+ * Full paths of source file(s) related to the manifestation of the issue
+ * The location of the affected source code (tag/branch/commit or direct URL)
+ * Any special configuration required to reproduce the issue
+ * Step-by-step instructions to reproduce the issue
+ * Proof-of-concept or exploit code (if possible)
+ * Impact of the issue, including how an attacker might exploit the issue
+
+ This information will help us triage your report more quickly.
+
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
+
+ ## Preferred Languages
+
+ We prefer all communications to be in English.
+
+ ## Policy
+
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
+
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
ms_clap/SUPPORT.md ADDED
@@ -0,0 +1,25 @@
+ # TODO: The maintainer of this repo has not yet edited this file
+
+ **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
+
+ - **No CSS support:** Fill out this template with information about how to file issues and get help.
+ - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
+ - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
+
+ *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
+
+ # Support
+
+ ## How to file issues and get help
+
+ This project uses GitHub Issues to track bugs and feature requests. Please search the existing
+ issues before filing new issues to avoid duplicates. For new issues, file your bug or
+ feature request as a new Issue.
+
+ For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
+ FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
+ CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
+
+ ## Microsoft Support Policy
+
+ Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
ms_clap/requirements.txt ADDED
@@ -0,0 +1,50 @@
+ appdirs==1.4.4
+ audioread==3.0.0
+ certifi==2022.12.7
+ cffi==1.15.1
+ charset-normalizer==3.0.1
+ colorama==0.4.6
+ decorator==5.1.1
+ filelock==3.9.0
+ flit_core==3.6.0
+ huggingface-hub==0.12.1
+ idna==3.4
+ importlib-metadata==6.0.0
+ importlib-resources==5.12.0
+ jaraco.classes==3.2.3
+ joblib==1.2.0
+ lazy_loader==0.1
+ librosa==0.10.0
+ llvmlite==0.39.1
+ mkl-service==2.4.0
+ more-itertools==9.0.0
+ msgpack==1.0.4
+ numba==0.56.4
+ numpy==1.23.5
+ packaging==23.0
+ pandas==1.4.2
+ pooch==1.6.0
+ pycparser==2.21
+ pywin32-ctypes==0.2.0
+ PyYAML==6.0
+ regex==2022.10.31
+ requests==2.28.2
+ scikit-learn==1.2.1
+ scipy==1.10.1
+ setuptools==65.6.3
+ six==1.16.0
+ soundfile==0.12.1
+ soxr==0.3.3
+ threadpoolctl==3.1.0
+ tokenizers==0.13.2
+ torch==1.13.1
+ torchaudio==0.13.1
+ torchlibrosa==0.1.0
+ torchvision==0.14.1
+ tqdm==4.64.1
+ transformers==4.26.1
+ typing_extensions==4.4.0
+ urllib3==1.26.14
+ wheel==0.38.4
+ wincertstore==0.2
+ zipp==3.14.0
ms_clap/src/.DS_Store ADDED
Binary file (6.15 kB)
ms_clap/src/CLAPWrapper.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+ import random
4
+ import torchaudio
5
+ # from torch._six import string_classes
6
+ import collections
7
+ import re
8
+ import numpy as np
9
+ from transformers import AutoTokenizer, logging
10
+ try:
11
+ from models.clap import CLAP
12
+ from models.mapper import get_clapcap
13
+ except:
14
+ from .models.clap import CLAP
15
+ from .models.mapper import get_clapcap
16
+ import math
17
+ import torchaudio.transforms as T
18
+ import os
19
+ import torch
20
+ from importlib_resources import files
21
+ import argparse
22
+ import yaml
23
+ import sys
24
+ logging.set_verbosity_error()
25
+
26
+
27
+ class CLAPWrapper():
28
+ """
29
+ A class for interfacing CLAP model.
30
+ """
31
+
32
+ def __init__(self, model_fp, config_root, version, use_cuda=False):
33
+ self.supported_versions = ['2022', '2023', 'clapcap']
34
+ self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
35
+ self.file_path = os.path.realpath(__file__)
36
+ self.default_collate_err_msg_format = (
37
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
38
+ "dicts or lists; found {}")
39
+ self.config_root = config_root
40
+ self.config_as_str = self.get_config_path(version)
41
+ self.model_fp = model_fp
42
+ self.use_cuda = use_cuda
43
+ self.version = version
44
+ if 'clapcap' in self.version:
45
+ self.clapcap, self.tokenizer, self.args = self.load_clapcap()
46
+ else:
47
+ self.clap, self.tokenizer, self.args = self.load_clap()
48
+
49
+ def get_config_path(self, version):
50
+ if version in self.supported_versions:
51
+ # config_root = /home/zkong/audio_flamingo/audio_flamingo_v1/microsoft_clap/src/configs
52
+ return f"{self.config_root}/config_{version}.yml"
53
+ else:
54
+ raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
55
+
56
+ def read_config_as_args(self, config_path, args=None, is_config_str=False):
57
+ return_dict = {}
58
+
59
+ if config_path is not None:
60
+ if is_config_str:
61
+ yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
62
+ else:
63
+ with open(config_path, "r") as f:
64
+ yml_config = yaml.load(f, Loader=yaml.FullLoader)
65
+
66
+ if args != None:
67
+ for k, v in yml_config.items():
68
+ if k in args.__dict__:
69
+ args.__dict__[k] = v
70
+ else:
71
+ sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
72
+ else:
73
+ for k, v in yml_config.items():
74
+ return_dict[k] = v
75
+
76
+ args = args if args != None else return_dict
77
+ return argparse.Namespace(**args)
78
+
79
+ def load_clap(self):
80
+ r"""Load CLAP model with args from config file"""
81
+
82
+ args = self.read_config_as_args(self.config_as_str, is_config_str=False)
83
+
84
+ if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
85
+ self.token_keys = ['input_ids', 'attention_mask']
86
+ elif 'bert' in args.text_model:
87
+ self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
88
+
89
+ clap = CLAP(
90
+ audioenc_name=args.audioenc_name,
91
+ sample_rate=args.sampling_rate,
92
+ window_size=args.window_size,
93
+ hop_size=args.hop_size,
94
+ mel_bins=args.mel_bins,
95
+ fmin=args.fmin,
96
+ fmax=args.fmax,
97
+ classes_num=args.num_classes,
98
+ out_emb=args.out_emb,
99
+ text_model=args.text_model,
100
+ transformer_embed_dim=args.transformer_embed_dim,
101
+ d_proj=args.d_proj
102
+ )
103
+
104
+ # Load pretrained weights for model
105
+ model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
106
+
107
+ # We unwrap the DDP model before saving. If a checkpoint is saved without unwrapping, the model needs to be unwrapped before `load_state_dict`:
108
+ # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
109
+ clap.load_state_dict(model_state_dict)
110
+
111
+ clap.eval() # set clap in eval mode
112
+ tokenizer = AutoTokenizer.from_pretrained(args.text_model)
113
+ if 'gpt' in args.text_model:
114
+ tokenizer.add_special_tokens({'pad_token': '!'})
115
+
116
+ if self.use_cuda and torch.cuda.is_available():
117
+ clap = clap.cuda()
118
+
119
+ return clap, tokenizer, args
120
+
121
+ def load_clapcap(self):
122
+ r"""Load CLAP model with args from config file"""
123
+
124
+ args = self.read_config_as_args(self.config_as_str, is_config_str=False)
125
+ args.prefix_dim = args.d_proj
126
+ text_model = args.text_model
127
+ args.text_model = args.text_decoder
128
+ args.cross_attention = True if 'cross' in args.clapcap_model.lower() else False
129
+
130
+ if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
131
+ self.token_keys = ['input_ids', 'attention_mask']
132
+ elif 'bert' in args.text_model:
133
+ self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
134
+
135
+ clap = CLAP(
136
+ audioenc_name=args.audioenc_name,
137
+ sample_rate=args.sampling_rate,
138
+ window_size=args.window_size,
139
+ hop_size=args.hop_size,
140
+ mel_bins=args.mel_bins,
141
+ fmin=args.fmin,
142
+ fmax=args.fmax,
143
+ classes_num=args.num_classes,
144
+ out_emb=args.out_emb,
145
+ text_model=text_model,
146
+ transformer_embed_dim=args.transformer_embed_dim,
147
+ d_proj=args.d_proj
148
+ )
149
+
150
+ clapcap = get_clapcap(args.clapcap_model)(clap, args.text_decoder, args.prefix_length, args.prefix_length_clip, args.prefix_dim,
151
+ args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
152
+
153
+ model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
154
+ clapcap.load_state_dict(model_state_dict)
155
+
156
+ clapcap.eval() # set clapcap in eval mode
157
+ tokenizer = AutoTokenizer.from_pretrained(args.text_model)
158
+ if 'gpt' in args.text_model:
159
+ tokenizer.add_special_tokens({'pad_token': '!'})
160
+
161
+ if self.use_cuda and torch.cuda.is_available():
162
+ clapcap = clapcap.cuda()
163
+
164
+ return clapcap, tokenizer, args
165
+
166
+ def default_collate(self, batch):
167
+ r"""Puts each data field into a tensor with outer dimension batch size"""
168
+ elem = batch[0]
169
+ elem_type = type(elem)
170
+ if isinstance(elem, torch.Tensor):
171
+ out = None
172
+ if torch.utils.data.get_worker_info() is not None:
173
+ # If we're in a background process, concatenate directly into a
174
+ # shared memory tensor to avoid an extra copy
175
+ numel = sum([x.numel() for x in batch])
176
+ storage = elem.storage()._new_shared(numel)
177
+ out = elem.new(storage)
178
+ return torch.stack(batch, 0, out=out)
179
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
180
+ and elem_type.__name__ != 'string_':
181
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
182
+ # array of string classes and object
183
+ if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
184
+ raise TypeError(
185
+ self.default_collate_err_msg_format.format(elem.dtype))
186
+
187
+ return self.default_collate([torch.as_tensor(b) for b in batch])
188
+ elif elem.shape == (): # scalars
189
+ return torch.as_tensor(batch)
190
+ elif isinstance(elem, float):
191
+ return torch.tensor(batch, dtype=torch.float64)
192
+ elif isinstance(elem, int):
193
+ return torch.tensor(batch)
194
+ # elif isinstance(elem, string_classes):
195
+ # return batch
196
+ elif isinstance(elem, collections.abc.Mapping):
197
+ return {key: self.default_collate([d[key] for d in batch]) for key in elem}
198
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
199
+ return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
200
+ elif isinstance(elem, collections.abc.Sequence):
201
+ # check to make sure that the elements in batch have consistent size
202
+ it = iter(batch)
203
+ elem_size = len(next(it))
204
+ if not all(len(elem) == elem_size for elem in it):
205
+ raise RuntimeError(
206
+ 'each element in list of batch should be of equal size')
207
+ transposed = zip(*batch)
208
+ return [self.default_collate(samples) for samples in transposed]
209
+
210
+ raise TypeError(self.default_collate_err_msg_format.format(elem_type))
211
+
212
+ def read_audio(self, audio_path, resample=False):
213
+ r"""Loads audio file or array and returns a torch tensor"""
214
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
215
+ audio_time_series, sample_rate = torchaudio.load(audio_path)
216
+
217
+ resample_rate = self.args.sampling_rate
218
+ if resample:
219
+ resampler = T.Resample(sample_rate, resample_rate)
220
+ audio_time_series = resampler(audio_time_series)
221
+ return audio_time_series, sample_rate
222
+
223
+ def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
224
+ r"""Loads audio file and returns raw audio."""
225
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
226
+ audio_time_series, sample_rate = self.read_audio(audio_path, resample=False)
227
+ audio_time_series = audio_time_series.reshape(-1)
228
+
229
+ # audio_time_series is shorter than predefined audio duration,
230
+ # so audio_time_series is extended
231
+ if audio_duration*sample_rate >= audio_time_series.shape[0]:
232
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
233
+ audio_time_series.shape[0]))
234
+ # Repeat audio_time_series by repeat_factor to match audio_duration
235
+ audio_time_series = audio_time_series.repeat(repeat_factor)
236
+ # remove excess part of audio_time_series
237
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
238
+ else:
239
+ # audio_time_series is longer than predefined audio duration,
240
+ # so audio_time_series is trimmed
241
+ start_index = random.randrange(
242
+ audio_time_series.shape[0] - audio_duration*sample_rate)
243
+ audio_time_series = audio_time_series[start_index:start_index +
244
+ audio_duration*sample_rate]
245
+ return torch.FloatTensor(audio_time_series)
246
+
247
+ # modified by Kong
248
+ def load_audio_clip_into_tensor(self, audio_clip, audio_duration, resample=False):
249
+ r"""Loads audio clip and returns raw audio."""
250
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
251
+ sample_rate = 44100
252
+ audio_time_series = audio_clip.reshape(-1)
253
+
254
+ # audio_time_series is shorter than predefined audio duration,
255
+ # so audio_time_series is extended
256
+ assert audio_duration * sample_rate >= audio_time_series.shape[0], \
257
+ 'dur * sr = {} should be larger than len = {}'.format(audio_duration * sample_rate, audio_time_series.shape[0])
258
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
259
+ audio_time_series.shape[0]))
260
+ # Repeat audio_time_series by repeat_factor to match audio_duration
261
+ audio_time_series = audio_time_series.repeat(repeat_factor)
262
+ # remove excess part of audio_time_series
263
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
264
+
265
+ # return torch.FloatTensor(audio_time_series)
266
+ return audio_time_series # already on cuda device
267
+
268
+ def preprocess_audio(self, audio_files, resample):
269
+ r"""Load list of audio files and return raw audio"""
270
+ audio_tensors = []
271
+ for audio_file in audio_files:
272
+ audio_tensor = self.load_audio_into_tensor(
273
+ audio_file, self.args.duration, resample)
274
+ audio_tensor = audio_tensor.reshape(
275
+ 1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
276
+ audio_tensors.append(audio_tensor)
277
+ return self.default_collate(audio_tensors)
278
+
279
+ # modified by Kong
280
+ def preprocess_audio_clips(self, audio_clips, resample=False):
281
+ r"""Load list of audio clips and return raw audio"""
282
+ audio_tensors = []
283
+ for audio_clip in audio_clips:
284
+ audio_tensor = self.load_audio_clip_into_tensor(
285
+ audio_clip, self.args.duration, resample=False)
286
+ audio_tensor = audio_tensor.reshape(
287
+ 1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
288
+ audio_tensors.append(audio_tensor)
289
+ return self.default_collate(audio_tensors)
290
+
291
+ def preprocess_text(self, text_queries):
292
+ r"""Load list of class labels and return tokenized text"""
293
+ tokenized_texts = []
294
+ for ttext in text_queries:
295
+ if 'gpt' in self.args.text_model:
296
+ ttext = ttext + ' <|endoftext|>'
297
+ tok = self.tokenizer.encode_plus(
298
+ text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt")
299
+ for key in self.token_keys:
300
+ tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
301
+ tokenized_texts.append(tok)
302
+ return self.default_collate(tokenized_texts)
303
+
304
+ def get_text_embeddings(self, class_labels):
305
+ r"""Load list of class labels and return text embeddings"""
306
+ preprocessed_text = self.preprocess_text(class_labels)
307
+ return self._get_text_embeddings(preprocessed_text)
308
+
309
+ def get_audio_embeddings(self, audio_files, resample):
310
+ r"""Load list of audio files and return a audio embeddings"""
311
+ preprocessed_audio = self.preprocess_audio(audio_files, resample)
312
+ return self._get_audio_embeddings(preprocessed_audio)
313
+
314
+ # modified by Kong
315
+ def get_audio_embeddings_from_clips(self, audio_clips, resample=False):
316
+ r"""Load list of audio files and return a audio embeddings"""
317
+ preprocessed_audio = self.preprocess_audio_clips(audio_clips, resample)
318
+ return self._get_audio_embeddings(preprocessed_audio)
319
+
320
+ def _get_text_embeddings(self, preprocessed_text):
321
+ r"""Load preprocessed text and return text embeddings"""
322
+ with torch.no_grad():
323
+ return self.clap.caption_encoder(preprocessed_text)
324
+
325
+ # modified by Kong
326
+ def _get_audio_embeddings(self, preprocessed_audio):
327
+ r"""Load preprocessed audio and return a audio embeddings"""
328
+ with torch.no_grad():
329
+ preprocessed_audio = preprocessed_audio.reshape(
330
+ preprocessed_audio.shape[0], preprocessed_audio.shape[2])
331
+ # Index [0] is the audio embedding; [1] has the output class probabilities
332
+ if 'clapcap' in self.version:
333
+ return self.clapcap.clap(preprocessed_audio)[0]
334
+ else:
335
+ return self.clap.audio_encoder(preprocessed_audio)[0]
336
+
337
+ def _generic_batch_inference(self, func, *args):
338
+ r"""Process audio and/or text per batch"""
339
+ input_tmp = args[0]
340
+ batch_size = args[-1]
341
+ # args[0] has audio_files, args[1] has class_labels
342
+ inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
343
+ args0_len = len(args[0])
344
+ # compute text_embeddings once for all the audio_files batches
345
+ if len(inputs) == 2:
346
+ text_embeddings = self.get_text_embeddings(args[1])
347
+ inputs = [args[0], args[1], text_embeddings]
348
+ dataset_idx = 0
349
+ for _ in range(math.ceil(args0_len/batch_size)):
350
+ next_batch_idx = dataset_idx + batch_size
351
+ # batch size is bigger than available audio/text items
352
+ if next_batch_idx >= args0_len:
353
+ inputs[0] = input_tmp[dataset_idx:]
354
+ return func(*tuple(inputs))
355
+ else:
356
+ inputs[0] = input_tmp[dataset_idx:next_batch_idx]
357
+ yield func(*tuple(inputs))
358
+ dataset_idx = next_batch_idx
359
+
360
+ def get_audio_embeddings_per_batch(self, audio_files, batch_size):
361
+ r"""Load preprocessed audio and return a audio embeddings per batch"""
362
+ return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
363
+
364
+ def get_text_embeddings_per_batch(self, class_labels, batch_size):
365
+ r"""Load preprocessed text and return text embeddings per batch"""
366
+ return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
367
+
368
+ def compute_similarity(self, audio_embeddings, text_embeddings):
369
+ r"""Compute similarity between text and audio embeddings"""
370
+ audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
371
+ text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
372
+
373
+ logit_scale = self.clap.logit_scale.exp()
374
+ similarity = logit_scale*text_embeddings @ audio_embeddings.T
375
+ return similarity.T
376
+
377
+ def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
378
+ r"""Compute classification probabilities for each audio recording in a batch and each class label"""
379
+ return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
380
+
381
+ def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
382
+ r"""Generate audio captions for each audio recording in a batch"""
383
+ captions = []
384
+ audio_tensors = self.preprocess_audio(audio_files, resample)
385
+
386
+ with torch.no_grad():
387
+ prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
388
+ if self.args.normalize_prefix:
389
+ prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
390
+ prefix_embed = self.clapcap.clap_project(prefix).view(-1, self.args.prefix_length, self.clapcap.gpt.transformer.wte.weight.shape[1])
391
+
392
+ for i in range(len(audio_tensors)):
393
+ gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),\
394
+ beam_size=beam_size,\
395
+ entry_length=entry_length,\
396
+ temperature=temperature)[0]
397
+ captions.append(gen_caption.capitalize())
398
+ return captions
399
+
400
+ def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
401
+ entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
402
+ r"""Generate captions by beam search decoding"""
403
+ self.clapcap.eval()
404
+ stop_token_index = self.tokenizer.encode(stop_token)[0]
405
+ tokens = None
406
+ scores = None
407
+ device = next(self.clapcap.parameters()).device
408
+ seq_lengths = torch.ones(beam_size, device=device)
409
+ is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
410
+ with torch.no_grad():
411
+ if embed is not None:
412
+ generated = embed
413
+ else:
414
+ if tokens is None:
415
+ tokens = torch.tensor(self.tokenizer.encode(prompt))
416
+ tokens = tokens.unsqueeze(0).to(device)
417
+ generated = self.clapcap.gpt.transformer.wte(tokens)
418
+ for i in range(entry_length):
419
+ outputs = self.clapcap.gpt(inputs_embeds=generated)
420
+ logits = outputs.logits
421
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
422
+ logits = logits.softmax(-1).log()
423
+ if scores is None:
424
+ scores, next_tokens = logits.topk(beam_size, -1)
425
+ generated = generated.expand(beam_size, *generated.shape[1:])
426
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
427
+ if tokens is None:
428
+ tokens = next_tokens
429
+ else:
430
+ tokens = tokens.expand(beam_size, *tokens.shape[1:])
431
+ tokens = torch.cat((tokens, next_tokens), dim=1)
432
+ else:
433
+ logits[is_stopped] = -float(np.inf)
434
+ logits[is_stopped, 0] = 0
435
+ scores_sum = scores[:, None] + logits
436
+ seq_lengths[~is_stopped] += 1
437
+ scores_sum_average = scores_sum / seq_lengths[:, None]
438
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
439
+ next_tokens_source = next_tokens // scores_sum.shape[1]
440
+ seq_lengths = seq_lengths[next_tokens_source]
441
+ next_tokens = next_tokens % scores_sum.shape[1]
442
+ next_tokens = next_tokens.unsqueeze(1)
443
+ tokens = tokens[next_tokens_source]
444
+ tokens = torch.cat((tokens, next_tokens), dim=1)
445
+ generated = generated[next_tokens_source]
446
+ scores = scores_sum_average * seq_lengths
447
+ is_stopped = is_stopped[next_tokens_source]
448
+ next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
449
+ generated = torch.cat((generated, next_token_embed), dim=1)
450
+ is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
451
+ if is_stopped.all():
452
+ break
453
+ scores = scores / seq_lengths
454
+ output_list = tokens.cpu().numpy()
455
+ output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
456
+ order = scores.argsort(descending=True)
457
+ output_texts = [output_texts[i] for i in order]
458
+ return output_texts
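For reference, a minimal usage sketch of the wrapper defined above. The checkpoint path, config directory, and audio file are placeholders; it assumes a 2023 CLAP checkpoint and the `configs` directory shipped in this commit.

```python
# Sketch: audio/text embeddings and similarity with CLAPWrapper (paths are placeholders).
from CLAPWrapper import CLAPWrapper

clap = CLAPWrapper(
    model_fp="CLAP_weights_2023.pth",    # placeholder checkpoint path
    config_root="ms_clap/src/configs",   # directory holding config_2023.yml
    version="2023",
    use_cuda=False,
)

class_labels = ["a dog barking", "rain falling", "a siren wailing"]
audio_files = ["example.wav"]            # placeholder audio path

text_emb = clap.get_text_embeddings(class_labels)
audio_emb = clap.get_audio_embeddings(audio_files, resample=True)

# compute_similarity returns one row per audio clip and one column per label.
similarity = clap.compute_similarity(audio_emb, text_emb)
print(similarity.softmax(dim=-1))
```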
ms_clap/src/__init__.py ADDED
File without changes
ms_clap/src/audio_captioning.py ADDED
@@ -0,0 +1,25 @@
1
+ """
2
+ This is an example using CLAPCAP for audio captioning.
3
+ """
4
+ from CLAPWrapper import CLAPWrapper
5
+
6
+ # Load and initialize CLAP
7
+ weights_path = "weights_path"
8
+ clap_model = CLAPWrapper(weights_path, config_root="config_root", version = 'clapcap', use_cuda=False)
9
+
10
+ #Load audio files
11
+ audio_files = ['audio_file']
12
+
13
+ # Generate captions for the recording
14
+ captions = clap_model.generate_caption(audio_files, resample=True, beam_size=5, entry_length=67, temperature=0.01)
15
+
16
+ # Print the result
17
+ for i in range(len(audio_files)):
18
+ print(f"Audio file: {audio_files[i]} \n")
19
+ print(f"Generated caption: {captions[i]} \n")
20
+
21
+ """
22
+ The output (the exact caption may vary):
23
+
24
+ The birds are singing in the trees.
25
+ """
ms_clap/src/configs/config_2022.yml ADDED
@@ -0,0 +1,26 @@
1
+ # TEXT ENCODER CONFIG
2
+ text_model: 'bert-base-uncased'
3
+ text_len: 100
4
+ transformer_embed_dim: 768
5
+ freeze_text_encoder_weights: True
6
+
7
+ # AUDIO ENCODER CONFIG
8
+ audioenc_name: 'Cnn14'
9
+ out_emb: 2048
10
+ sampling_rate: 44100
11
+ duration: 5
12
+ fmin: 50
13
+ fmax: 14000
14
+ n_fft: 1028
15
+ hop_size: 320
16
+ mel_bins: 64
17
+ window_size: 1024
18
+
19
+ # PROJECTION SPACE CONFIG
20
+ d_proj: 1024
21
+ temperature: 0.003
22
+
23
+ # TRAINING AND EVALUATION CONFIG
24
+ num_classes: 527
25
+ batch_size: 1024
26
+ demo: False
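These YAML files are consumed by `read_config_as_args` in `CLAPWrapper.py`, which flattens them into an `argparse.Namespace`. A standalone sketch of that conversion, assuming the file sits at the path shown:

```python
# Sketch: load config_2022.yml into a Namespace, mirroring read_config_as_args.
import argparse
import yaml

with open("ms_clap/src/configs/config_2022.yml", "r") as f:  # placeholder path
    cfg = yaml.load(f, Loader=yaml.FullLoader)

args = argparse.Namespace(**cfg)
print(args.text_model, args.sampling_rate, args.duration)  # bert-base-uncased 44100 5
```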
ms_clap/src/configs/config_2023.yml ADDED
@@ -0,0 +1,26 @@
1
+ # TEXT ENCODER CONFIG
2
+ text_model: 'gpt2'
3
+ text_len: 77
4
+ transformer_embed_dim: 768
5
+ freeze_text_encoder_weights: True
6
+
7
+ # AUDIO ENCODER CONFIG
8
+ audioenc_name: 'HTSAT'
9
+ out_emb: 768
10
+ sampling_rate: 44100
11
+ duration: 7
12
+ fmin: 50
13
+ fmax: 8000 #14000
14
+ n_fft: 1024 # 1028
15
+ hop_size: 320
16
+ mel_bins: 64
17
+ window_size: 1024
18
+
19
+ # PROJECTION SPACE CONFIG
20
+ d_proj: 1024
21
+ temperature: 0.003
22
+
23
+ # TRAINING AND EVALUATION CONFIG
24
+ num_classes: 527
25
+ batch_size: 1024
26
+ demo: False
ms_clap/src/configs/config_clapcap.yml ADDED
@@ -0,0 +1,34 @@
1
+ # TEXT ENCODER CONFIG
2
+ text_model: 'gpt2'
3
+ transformer_embed_dim: 768
4
+ freeze_text_encoder_weights: True
5
+
6
+ # AUDIO ENCODER CONFIG
7
+ audioenc_name: 'HTSAT'
8
+ out_emb: 768
9
+ sampling_rate: 44100
10
+ duration: 7
11
+ fmin: 50
12
+ fmax: 8000
13
+ n_fft: 1024
14
+ hop_size: 320
15
+ mel_bins: 64
16
+ window_size: 1024
17
+
18
+ # PROJECTION SPACE CONFIG
19
+ d_proj: 1024
20
+ temperature: 0.003
21
+
22
+ # TRAINING AND EVALUATION CONFIG
23
+ batch_size: 128
24
+ num_classes: 527
25
+
26
+ # CLAPCAP CONFIG
27
+ clapcap_model: 'ClapCaption'
28
+ text_decoder: 'gpt2'
29
+ prefix_length: 40
30
+ prefix_length_clip: 40
31
+ mapping_type: 'transformer'
32
+ num_layers: 8
33
+ normalize_prefix: True
34
+ freeze_gpt_weights: True
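In `generate_caption`, the CLAP audio embedding (`d_proj` wide) is mapped to `prefix_length` GPT-2 token embeddings before decoding. A shape-only sketch of what this config implies; the `nn.Linear` below is just a stand-in for the learned mapping network (`mapping_type: 'transformer'` in the real model).

```python
# Sketch of the prefix shapes implied by this config. The Linear is only a
# stand-in for clapcap.clap_project (a transformer mapper in the actual model).
import torch

batch_size, d_proj, prefix_length, gpt_hidden = 2, 1024, 40, 768  # 768 = gpt2 hidden size

audio_embedding = torch.randn(batch_size, d_proj)
clap_project = torch.nn.Linear(d_proj, prefix_length * gpt_hidden)

prefix_embed = clap_project(audio_embedding).view(-1, prefix_length, gpt_hidden)
print(prefix_embed.shape)  # torch.Size([2, 40, 768]), fed to GPT-2 as inputs_embeds
```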
ms_clap/src/esc50_dataset.py ADDED
@@ -0,0 +1,82 @@
1
+ from torch.utils.data import Dataset
2
+ from torchvision.datasets.utils import download_url
3
+ from tqdm import tqdm
4
+ import pandas as pd
5
+ import os
6
+ import torch.nn as nn
7
+ import torch
8
+
9
+ class AudioDataset(Dataset):
10
+ def __init__(self, root: str, download: bool = True):
11
+ self.root = os.path.expanduser(root)
12
+ if download:
13
+ self.download()
14
+
15
+ def __getitem__(self, index):
16
+ raise NotImplementedError
17
+
18
+ def download(self):
19
+ raise NotImplementedError
20
+
21
+ def __len__(self):
22
+ raise NotImplementedError
23
+
24
+
25
+ class ESC50(AudioDataset):
26
+ base_folder = 'ESC-50-master'
27
+ url = "https://github.com/karoldvl/ESC-50/archive/master.zip"
28
+ filename = "ESC-50-master.zip"
29
+ num_files_in_dir = 2000
30
+ audio_dir = 'audio'
31
+ label_col = 'category'
32
+ file_col = 'filename'
33
+ meta = {
34
+ 'filename': os.path.join('meta','esc50.csv'),
35
+ }
36
+
37
+ def __init__(self, root, reading_transformations: nn.Module = None, download: bool = True):
38
+ super().__init__(root)
39
+ self._load_meta()
40
+
41
+ self.targets, self.audio_paths = [], []
42
+ self.pre_transformations = reading_transformations
43
+ print("Loading audio files")
44
+ # self.df['filename'] = os.path.join(self.root, self.base_folder, self.audio_dir) + os.sep + self.df['filename']
45
+ self.df['category'] = self.df['category'].str.replace('_',' ')
46
+
47
+ for _, row in tqdm(self.df.iterrows()):
48
+ file_path = os.path.join(self.root, self.base_folder, self.audio_dir, row[self.file_col])
49
+ self.targets.append(row[self.label_col])
50
+ self.audio_paths.append(file_path)
51
+
52
+ def _load_meta(self):
53
+ path = os.path.join(self.root, self.base_folder, self.meta['filename'])
54
+
55
+ self.df = pd.read_csv(path)
56
+ self.class_to_idx = {}
57
+ self.classes = [x.replace('_',' ') for x in sorted(self.df[self.label_col].unique())]
58
+ for i, category in enumerate(self.classes):
59
+ self.class_to_idx[category] = i
60
+
61
+ def __getitem__(self, index):
62
+ """
63
+ Args:
64
+ index (int): Index
65
+ Returns:
66
+ tuple: (file_path, target, one_hot_target) where target is the class name and one_hot_target is its one-hot encoding.
67
+ """
68
+ file_path, target = self.audio_paths[index], self.targets[index]
69
+ idx = torch.tensor(self.class_to_idx[target])
70
+ one_hot_target = torch.zeros(len(self.classes)).scatter_(0, idx, 1).reshape(1,-1)
71
+ return file_path, target, one_hot_target
72
+
73
+ def __len__(self):
74
+ return len(self.audio_paths)
75
+
76
+ def download(self):
77
+ download_url(self.url, self.root, self.filename)
78
+
79
+ # extract file
80
+ from zipfile import ZipFile
81
+ with ZipFile(os.path.join(self.root, self.filename), 'r') as zip:
82
+ zip.extractall(path=self.root)
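A minimal usage sketch of the dataset class above; the `root` directory is a placeholder, and the first call downloads and extracts the ESC-50 archive (several hundred MB).

```python
# Sketch: build and index the ESC-50 dataset defined above (root is a placeholder).
from esc50_dataset import ESC50

dataset = ESC50(root="esc50_data", download=True)
print(len(dataset), "clips across", len(dataset.classes), "classes")

file_path, target, one_hot = dataset[0]
print(file_path, target, one_hot.shape)  # <root>/ESC-50-master/audio/..., e.g. 'dog', torch.Size([1, 50])
```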
ms_clap/src/models/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from . import clap
2
+ from . import audio
3
+ from . import htsat
4
+ from . import config
5
+ from . import pytorch_utils
6
+ from . import htsat
ms_clap/src/models/audio.py ADDED
@@ -0,0 +1,186 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
5
+
6
+ try:
7
+ from models.htsat import HTSATWrapper
8
+ except:
9
+ from .htsat import HTSATWrapper
10
+
11
+ def get_audio_encoder(name: str):
12
+ if name == "Cnn14":
13
+ return Cnn14
14
+ elif name == "HTSAT":
15
+ return HTSATWrapper
16
+ else:
17
+ raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
18
+
19
+
20
+ class ConvBlock(nn.Module):
21
+ def __init__(self, in_channels, out_channels):
22
+
23
+ super(ConvBlock, self).__init__()
24
+
25
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
26
+ out_channels=out_channels,
27
+ kernel_size=(3, 3), stride=(1, 1),
28
+ padding=(1, 1), bias=False)
29
+
30
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
31
+ out_channels=out_channels,
32
+ kernel_size=(3, 3), stride=(1, 1),
33
+ padding=(1, 1), bias=False)
34
+
35
+ self.bn1 = nn.BatchNorm2d(out_channels)
36
+ self.bn2 = nn.BatchNorm2d(out_channels)
37
+
38
+
39
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
40
+
41
+ x = input
42
+ x = F.relu_(self.bn1(self.conv1(x)))
43
+ x = F.relu_(self.bn2(self.conv2(x)))
44
+ if pool_type == 'max':
45
+ x = F.max_pool2d(x, kernel_size=pool_size)
46
+ elif pool_type == 'avg':
47
+ x = F.avg_pool2d(x, kernel_size=pool_size)
48
+ elif pool_type == 'avg+max':
49
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
50
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
51
+ x = x1 + x2
52
+ else:
53
+ raise Exception('Incorrect argument!')
54
+
55
+ return x
56
+
57
+
58
+ class ConvBlock5x5(nn.Module):
59
+ def __init__(self, in_channels, out_channels):
60
+
61
+ super(ConvBlock5x5, self).__init__()
62
+
63
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
64
+ out_channels=out_channels,
65
+ kernel_size=(5, 5), stride=(1, 1),
66
+ padding=(2, 2), bias=False)
67
+
68
+ self.bn1 = nn.BatchNorm2d(out_channels)
69
+
70
+
71
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
72
+
73
+ x = input
74
+ x = F.relu_(self.bn1(self.conv1(x)))
75
+ if pool_type == 'max':
76
+ x = F.max_pool2d(x, kernel_size=pool_size)
77
+ elif pool_type == 'avg':
78
+ x = F.avg_pool2d(x, kernel_size=pool_size)
79
+ elif pool_type == 'avg+max':
80
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
81
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
82
+ x = x1 + x2
83
+ else:
84
+ raise Exception('Incorrect argument!')
85
+
86
+ return x
87
+
88
+
89
+ class AttBlock(nn.Module):
90
+ def __init__(self, n_in, n_out, activation='linear', temperature=1.):
91
+ super(AttBlock, self).__init__()
92
+
93
+ self.activation = activation
94
+ self.temperature = temperature
95
+ self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
96
+ self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
97
+
98
+ self.bn_att = nn.BatchNorm1d(n_out)
99
+
100
+ def forward(self, x):
101
+ # x: (n_samples, n_in, n_time)
102
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
103
+ cla = self.nonlinear_transform(self.cla(x))
104
+ x = torch.sum(norm_att * cla, dim=2)
105
+ return x, norm_att, cla
106
+
107
+ def nonlinear_transform(self, x):
108
+ if self.activation == 'linear':
109
+ return x
110
+ elif self.activation == 'sigmoid':
111
+ return torch.sigmoid(x)
112
+
113
+
114
+ class Cnn14(nn.Module):
115
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
116
+ fmax, classes_num, out_emb):
117
+
118
+ super(Cnn14, self).__init__()
119
+
120
+ window = 'hann'
121
+ center = True
122
+ pad_mode = 'reflect'
123
+ ref = 1.0
124
+ amin = 1e-10
125
+ top_db = None
126
+
127
+ # Spectrogram extractor
128
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
129
+ win_length=window_size, window=window, center=center, pad_mode=pad_mode,
130
+ freeze_parameters=True)
131
+
132
+ # Logmel feature extractor
133
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
134
+ n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
135
+ freeze_parameters=True)
136
+
137
+ self.bn0 = nn.BatchNorm2d(64)
138
+
139
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
140
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
141
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
142
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
143
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
144
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
145
+
146
+ # out_emb is 2048 for best Cnn14
147
+ self.fc1 = nn.Linear(2048, out_emb, bias=True)
148
+ self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True)
149
+
150
+ def forward(self, input, mixup_lambda=None):
151
+ """
152
+ Input: (batch_size, data_length)
153
+ """
154
+
155
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
156
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
157
+
158
+ x = x.transpose(1, 3)
159
+ x = self.bn0(x)
160
+ x = x.transpose(1, 3)
161
+
162
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
163
+ x = F.dropout(x, p=0.2, training=self.training)
164
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
165
+ x = F.dropout(x, p=0.2, training=self.training)
166
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
167
+ x = F.dropout(x, p=0.2, training=self.training)
168
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
169
+ x = F.dropout(x, p=0.2, training=self.training)
170
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
171
+ x = F.dropout(x, p=0.2, training=self.training)
172
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
173
+ x = F.dropout(x, p=0.2, training=self.training)
174
+ x = torch.mean(x, dim=3)
175
+
176
+ (x1, _) = torch.max(x, dim=2)
177
+ x2 = torch.mean(x, dim=2)
178
+ x = x1 + x2
179
+ x = F.dropout(x, p=0.5, training=self.training)
180
+ x = F.relu_(self.fc1(x))
181
+ embedding = F.dropout(x, p=0.5, training=self.training)
182
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
183
+
184
+ output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}
185
+
186
+ return output_dict
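A quick shape check for the `Cnn14` encoder above, using the hyperparameters from `config_2022.yml`; the import path assumes the script is run from `ms_clap/src/models` with `torchlibrosa` installed.

```python
# Sketch: run Cnn14 on random audio and inspect the output shapes
# (hyperparameters follow config_2022.yml; input is 5 s of noise at 44.1 kHz).
import torch
from audio import Cnn14  # assumes the working directory is ms_clap/src/models

model = Cnn14(sample_rate=44100, window_size=1024, hop_size=320, mel_bins=64,
              fmin=50, fmax=14000, classes_num=527, out_emb=2048)
model.eval()

waveform = torch.randn(2, 44100 * 5)          # (batch_size, data_length)
with torch.no_grad():
    out = model(waveform)
print(out["embedding"].shape)                 # torch.Size([2, 2048])
print(out["clipwise_output"].shape)           # torch.Size([2, 527])
```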
ms_clap/src/models/clap.py ADDED
@@ -0,0 +1,109 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from transformers import AutoModel
6
+ from .audio import get_audio_encoder
7
+
8
+ class Projection(nn.Module):
9
+ def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
10
+ super().__init__()
11
+ self.linear1 = nn.Linear(d_in, d_out, bias=False)
12
+ self.linear2 = nn.Linear(d_out, d_out, bias=False)
13
+ self.layer_norm = nn.LayerNorm(d_out)
14
+ self.drop = nn.Dropout(p)
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ embed1 = self.linear1(x)
18
+ embed2 = self.drop(self.linear2(F.gelu(embed1)))
19
+ embeds = self.layer_norm(embed1 + embed2)
20
+ return embeds
21
+
22
+ class AudioEncoder(nn.Module):
23
+ def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
24
+ hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
25
+ super().__init__()
26
+
27
+ audio_encoder = get_audio_encoder(audioenc_name)
28
+
29
+ self.base = audio_encoder(
30
+ sample_rate, window_size,
31
+ hop_size, mel_bins, fmin, fmax,
32
+ classes_num, d_in)
33
+
34
+ self.projection = Projection(d_in, d_out)
35
+
36
+ def forward(self, x):
37
+ out_dict = self.base(x)
38
+ audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
39
+ projected_vec = self.projection(audio_features)
40
+ return projected_vec, audio_classification_output
41
+
42
+ class TextEncoder(nn.Module):
43
+ def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
44
+ super().__init__()
45
+ self.text_model = text_model
46
+ self.base = AutoModel.from_pretrained(text_model)
47
+
48
+ if 'clip' in text_model:
49
+ self.clip_text_projection = self.base.text_projection
50
+ self.base = self.base.text_model
51
+ if 'base' in text_model:
52
+ transformer_embed_dim = 512
53
+
54
+ self.projection = Projection(transformer_embed_dim, d_out)
55
+
56
+ def forward(self, x):
57
+ if 'clip' in self.text_model:
58
+ pooled_output = self.base(**x)[1] # get pooled output
59
+ out = self.clip_text_projection(pooled_output) # get CLS token output
60
+ elif 'gpt' in self.text_model:
61
+ batch_size = x['input_ids'].shape[0]
62
+ hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
63
+
64
+ sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
65
+ out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
66
+ else:
67
+ out = self.base(**x)[0]
68
+ out = out[:, 0, :] # get CLS token output
69
+
70
+ projected_vec = self.projection(out)
71
+
72
+ return projected_vec
73
+
74
+ class CLAP(nn.Module):
75
+ def __init__(self,
76
+ # audio
77
+ audioenc_name: str,
78
+ sample_rate: int,
79
+ window_size: int,
80
+ hop_size: int,
81
+ mel_bins: int,
82
+ fmin: int,
83
+ fmax: int,
84
+ classes_num: int,
85
+ out_emb: int,
86
+ # text
87
+ text_model: str,
88
+ transformer_embed_dim: int,
89
+ # common
90
+ d_proj: int,
91
+ ):
92
+ super().__init__()
93
+
94
+
95
+ self.audio_encoder = AudioEncoder(
96
+ audioenc_name, out_emb, d_proj,
97
+ sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)
98
+
99
+ self.caption_encoder = TextEncoder(
100
+ d_proj, text_model, transformer_embed_dim
101
+ )
102
+
103
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
104
+
105
+ def forward(self, audio, text):
106
+ audio_embed, _ = self.audio_encoder(audio)
107
+ caption_embed = self.caption_encoder(text)
108
+
109
+ return caption_embed, audio_embed, self.logit_scale.exp()
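For context, the forward pass above returns `(caption_embed, audio_embed, logit_scale)`, which is the usual input to a symmetric contrastive objective. A sketch of that objective follows; it is illustrative only and not the training code shipped with this commit.

```python
# Sketch: the symmetric contrastive loss typically built from CLAP's forward outputs.
# Illustrative only; the actual training loop is not part of this commit.
import torch
import torch.nn.functional as F

def clap_contrastive_loss(caption_embed, audio_embed, logit_scale):
    caption_embed = F.normalize(caption_embed, dim=-1)
    audio_embed = F.normalize(audio_embed, dim=-1)
    logits = logit_scale * audio_embed @ caption_embed.t()       # (B, B)
    targets = torch.arange(logits.shape[0], device=logits.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

# Example with random embeddings in the d_proj=1024 projection space:
print(clap_contrastive_loss(torch.randn(4, 1024), torch.randn(4, 1024),
                            torch.tensor(1 / 0.07)).item())
```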
ms_clap/src/models/config.py ADDED
@@ -0,0 +1,128 @@
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # The configuration for training the model
5
+
6
+ exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
7
+ workspace = "/home/kechen/Research/HTSAT" # the folder of your code
8
+ dataset_path = "/home/Research/audioset" # the dataset path
9
+ desed_folder = "/home/Research/DESED" # the desed file
10
+
11
+ dataset_type = "audioset" # "audioset" "esc-50" "scv2"
12
+ index_type = "full_train" # only works for audioset
13
+ balanced_data = True # only works for audioset
14
+
15
+ loss_type = "clip_bce" #
16
+ # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
17
+
18
+ # trained from a checkpoint, or evaluate a single model
19
+ resume_checkpoint = None
20
+ # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
21
+
22
+ esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
23
+
24
+
25
+ debug = False
26
+
27
+ random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
28
+ batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
29
+ learning_rate = 1e-3 # 1e-4 also workable
30
+ max_epoch = 100
31
+ num_workers = 3
32
+
33
+ lr_scheduler_epoch = [10,20,30]
34
+ lr_rate = [0.02, 0.05, 0.1]
35
+
36
+ # these data preparation optimizations do not bring many improvements, so deprecated
37
+ enable_token_label = False # token label
38
+ class_map_path = "class_hier_map.npy"
39
+ class_filter = None
40
+ retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
41
+ 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
42
+ token_label_range = [0.2,0.6]
43
+ enable_time_shift = False # shift time
44
+ enable_label_enhance = False # enhance hierarchical label
45
+ enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
46
+
47
+
48
+
49
+ # for model's design
50
+ enable_tscam = True # enable the token-semantic layer
51
+
52
+ # for signal processing
53
+ sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
54
+ clip_samples = sample_rate * 10 # audio_set 10-sec clip
55
+ window_size = 1024
56
+ hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
57
+ mel_bins = 64
58
+ fmin = 50
59
+ fmax = 14000
60
+ shift_max = int(clip_samples * 0.5)
61
+
62
+ # for data collection
63
+ classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
64
+ patch_size = (25, 4) # deprecated
65
+ crop_size = None # int(clip_samples * 0.5) deprecated
66
+
67
+ # for htsat hyperparamater
68
+ htsat_window_size = 8
69
+ htsat_spec_size = 256
70
+ htsat_patch_size = 4
71
+ htsat_stride = (4, 4)
72
+ htsat_num_head = [4,8,16,32]
73
+ htsat_dim = 96
74
+ htsat_depth = [2,2,6,2]
75
+
76
+ swin_pretrain_path = None
77
+ # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
78
+
79
+ # Some Deprecated Optimization in the model design, check the model code for details
80
+ htsat_attn_heatmap = False
81
+ htsat_hier_output = False
82
+ htsat_use_max = False
83
+
84
+
85
+ # for ensemble test
86
+
87
+ ensemble_checkpoints = []
88
+ ensemble_strides = []
89
+
90
+
91
+ # weight average folder
92
+ wa_folder = "/home/version_0/checkpoints/"
93
+ # weight average output filename
94
+ wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
95
+
96
+ esm_model_pathes = [
97
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
98
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
99
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
100
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
101
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
102
+ "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
103
+ ]
104
+
105
+ # for framewise localization
106
+ heatmap_dir = "/home/Research/heatmap_output"
107
+ test_file = "htsat-test-ensemble"
108
+ fl_local = False # indicate if we need to use this dataset for the framewise detection
109
+ fl_dataset = "/home/Research/desed/desed_eval.npy"
110
+ fl_class_num = [
111
+ "Speech", "Frying", "Dishes", "Running_water",
112
+ "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
113
+ "Cat", "Dog", "Vacuum_cleaner"
114
+ ]
115
+
116
+ # map 527 classes into 10 classes
117
+ fl_audioset_mapping = [
118
+ [0,1,2,3,4,5,6,7],
119
+ [366, 367, 368],
120
+ [364],
121
+ [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
122
+ [369],
123
+ [382],
124
+ [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
125
+ [81, 82, 83, 84, 85],
126
+ [74, 75, 76, 77, 78, 79],
127
+ [377]
128
+ ]
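The `fl_audioset_mapping` above groups AudioSet's 527 classes into the 10 DESED classes listed in `fl_class_num`. A small sketch of how such a mapping could be applied to framewise probabilities; taking the max over each group is an illustrative choice here, not taken from the upstream HTS-AT code.

```python
# Sketch: collapse 527-class AudioSet probabilities into the 10 DESED classes using
# fl_audioset_mapping. Max over each grouped class set is an illustrative choice.
import torch
import config  # this module; assumes the working directory is ms_clap/src/models

frames = torch.rand(8, 527)                     # (time_frames, audioset_classes)
collapsed = torch.stack(
    [frames[:, idxs].max(dim=1).values for idxs in config.fl_audioset_mapping],
    dim=1,
)                                               # (time_frames, 10)
print(dict(zip(config.fl_class_num, collapsed[0].tolist())))
```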
ms_clap/src/models/htsat.py ADDED
@@ -0,0 +1,956 @@
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Model Core
5
+ # The code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+
9
+ import logging
10
+ import pdb
11
+ import math
12
+ import random
13
+ from numpy.core.fromnumeric import clip, reshape
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.utils.checkpoint as checkpoint
17
+
18
+ # import os
19
+ import sys
20
+ sys.path.append('/home/zkong/audio_flamingo/audio_flamingo_v1/v0.2/open_flamingo/my_ms_clap/models')
21
+
22
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
23
+ from torchlibrosa.augmentation import SpecAugmentation
24
+
25
+ from itertools import repeat
26
+ from typing import List
27
+ try:
28
+ from models.pytorch_utils import do_mixup, interpolate
29
+ import models.config as config
30
+ except:
31
+ from .pytorch_utils import do_mixup, interpolate
32
+ from . import config
33
+ # from CLAP_API.models.pytorch_utils import do_mixup, interpolate
34
+ # from CLAP_API.models import config
35
+
36
+ import torch.nn.functional as F
37
+ import collections.abc
38
+ import warnings
39
+
40
+ from torch.nn.init import _calculate_fan_in_and_fan_out
41
+
42
+ def _ntuple(n):
43
+ def parse(x):
44
+ if isinstance(x, collections.abc.Iterable):
45
+ return x
46
+ return tuple(repeat(x, n))
47
+ return parse
48
+
49
+ to_1tuple = _ntuple(1)
50
+ to_2tuple = _ntuple(2)
51
+ to_3tuple = _ntuple(3)
52
+ to_4tuple = _ntuple(4)
53
+ to_ntuple = _ntuple
54
+
55
+
56
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
57
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
58
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
59
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
60
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
61
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
62
+ 'survival rate' as the argument.
63
+ """
64
+ if drop_prob == 0. or not training:
65
+ return x
66
+ keep_prob = 1 - drop_prob
67
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
68
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
69
+ random_tensor.floor_() # binarize
70
+ output = x.div(keep_prob) * random_tensor
71
+ return output
72
+
73
+
74
+ class DropPath(nn.Module):
75
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
76
+ """
77
+ def __init__(self, drop_prob=None):
78
+ super(DropPath, self).__init__()
79
+ self.drop_prob = drop_prob
80
+
81
+ def forward(self, x):
82
+ return drop_path(x, self.drop_prob, self.training)
83
+
84
+ class PatchEmbed(nn.Module):
85
+ """ 2D Image to Patch Embedding
86
+ """
87
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
88
+ super().__init__()
89
+ img_size = to_2tuple(img_size)
90
+ patch_size = to_2tuple(patch_size)
91
+ patch_stride = to_2tuple(patch_stride)
92
+ self.img_size = img_size
93
+ self.patch_size = patch_size
94
+ self.patch_stride = patch_stride
95
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
96
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
97
+ self.flatten = flatten
98
+ self.in_chans = in_chans
99
+ self.embed_dim = embed_dim
100
+
101
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
102
+
103
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
104
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
105
+
106
+ def forward(self, x):
107
+ B, C, H, W = x.shape
108
+ assert H == self.img_size[0] and W == self.img_size[1], \
109
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
110
+ x = self.proj(x)
111
+ if self.flatten:
112
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
113
+ x = self.norm(x)
114
+ return x
115
+
116
+ class Mlp(nn.Module):
117
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
118
+ """
119
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
120
+ super().__init__()
121
+ out_features = out_features or in_features
122
+ hidden_features = hidden_features or in_features
123
+ self.fc1 = nn.Linear(in_features, hidden_features)
124
+ self.act = act_layer()
125
+ self.fc2 = nn.Linear(hidden_features, out_features)
126
+ self.drop = nn.Dropout(drop)
127
+
128
+ def forward(self, x):
129
+ x = self.fc1(x)
130
+ x = self.act(x)
131
+ x = self.drop(x)
132
+ x = self.fc2(x)
133
+ x = self.drop(x)
134
+ return x
135
+
136
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
137
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
138
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
139
+ def norm_cdf(x):
140
+ # Computes standard normal cumulative distribution function
141
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
142
+
143
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
144
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
145
+ "The distribution of values may be incorrect.",
146
+ stacklevel=2)
147
+
148
+ with torch.no_grad():
149
+ # Values are generated by using a truncated uniform distribution and
150
+ # then using the inverse CDF for the normal distribution.
151
+ # Get upper and lower cdf values
152
+ l = norm_cdf((a - mean) / std)
153
+ u = norm_cdf((b - mean) / std)
154
+
155
+ # Uniformly fill tensor with values from [l, u], then translate to
156
+ # [2l-1, 2u-1].
157
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
158
+
159
+ # Use inverse cdf transform for normal distribution to get truncated
160
+ # standard normal
161
+ tensor.erfinv_()
162
+
163
+ # Transform to proper mean, std
164
+ tensor.mul_(std * math.sqrt(2.))
165
+ tensor.add_(mean)
166
+
167
+ # Clamp to ensure it's in the proper range
168
+ tensor.clamp_(min=a, max=b)
169
+ return tensor
170
+
171
+
172
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
173
+ # type: (Tensor, float, float, float, float) -> Tensor
174
+ r"""Fills the input Tensor with values drawn from a truncated
175
+ normal distribution. The values are effectively drawn from the
176
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
177
+ with values outside :math:`[a, b]` redrawn until they are within
178
+ the bounds. The method used for generating the random values works
179
+ best when :math:`a \leq \text{mean} \leq b`.
180
+ Args:
181
+ tensor: an n-dimensional `torch.Tensor`
182
+ mean: the mean of the normal distribution
183
+ std: the standard deviation of the normal distribution
184
+ a: the minimum cutoff value
185
+ b: the maximum cutoff value
186
+ Examples:
187
+ >>> w = torch.empty(3, 5)
188
+ >>> nn.init.trunc_normal_(w)
189
+ """
190
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
191
+
192
+
193
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
194
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
195
+ if mode == 'fan_in':
196
+ denom = fan_in
197
+ elif mode == 'fan_out':
198
+ denom = fan_out
199
+ elif mode == 'fan_avg':
200
+ denom = (fan_in + fan_out) / 2
201
+
202
+ variance = scale / denom
203
+
204
+ if distribution == "truncated_normal":
205
+ # constant is stddev of standard normal truncated to (-2, 2)
206
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
207
+ elif distribution == "normal":
208
+ tensor.normal_(std=math.sqrt(variance))
209
+ elif distribution == "uniform":
210
+ bound = math.sqrt(3 * variance)
211
+ tensor.uniform_(-bound, bound)
212
+ else:
213
+ raise ValueError(f"invalid distribution {distribution}")
214
+
215
+
216
+ def lecun_normal_(tensor):
217
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
218
+
219
+
220
+ # The code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
221
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
222
+
223
+ def window_partition(x, window_size):
224
+ """
225
+ Args:
226
+ x: (B, H, W, C)
227
+ window_size (int): window size
228
+ Returns:
229
+ windows: (num_windows*B, window_size, window_size, C)
230
+ """
231
+ B, H, W, C = x.shape
232
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
233
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
234
+ return windows
235
+
236
+
237
+ def window_reverse(windows, window_size, H, W):
238
+ """
239
+ Args:
240
+ windows: (num_windows*B, window_size, window_size, C)
241
+ window_size (int): Window size
242
+ H (int): Height of image
243
+ W (int): Width of image
244
+ Returns:
245
+ x: (B, H, W, C)
246
+ """
247
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
248
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
249
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
250
+ return x
251
+
252
+
253
+ class WindowAttention(nn.Module):
254
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
255
+ It supports both of shifted and non-shifted window.
256
+ Args:
257
+ dim (int): Number of input channels.
258
+ window_size (tuple[int]): The height and width of the window.
259
+ num_heads (int): Number of attention heads.
260
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
261
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
262
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
263
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
264
+ """
265
+
266
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
267
+
268
+ super().__init__()
269
+ self.dim = dim
270
+ self.window_size = window_size # Wh, Ww
271
+ self.num_heads = num_heads
272
+ head_dim = dim // num_heads
273
+ self.scale = qk_scale or head_dim ** -0.5
274
+
275
+ # define a parameter table of relative position bias
276
+ self.relative_position_bias_table = nn.Parameter(
277
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
278
+
279
+ # get pair-wise relative position index for each token inside the window
280
+ coords_h = torch.arange(self.window_size[0])
281
+ coords_w = torch.arange(self.window_size[1])
282
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
283
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
284
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
285
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
286
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
287
+ relative_coords[:, :, 1] += self.window_size[1] - 1
288
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
289
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
290
+ self.register_buffer("relative_position_index", relative_position_index)
291
+
292
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
293
+ self.attn_drop = nn.Dropout(attn_drop)
294
+ self.proj = nn.Linear(dim, dim)
295
+ self.proj_drop = nn.Dropout(proj_drop)
296
+
297
+ trunc_normal_(self.relative_position_bias_table, std=.02)
298
+ self.softmax = nn.Softmax(dim=-1)
299
+
300
+ def forward(self, x, mask=None):
301
+ """
302
+ Args:
303
+ x: input features with shape of (num_windows*B, N, C)
304
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
305
+ """
306
+ B_, N, C = x.shape
307
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
308
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
309
+
310
+ q = q * self.scale
311
+ attn = (q @ k.transpose(-2, -1))
312
+
313
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
314
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
315
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
316
+ attn = attn + relative_position_bias.unsqueeze(0)
317
+
318
+ if mask is not None:
319
+ nW = mask.shape[0]
320
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
321
+ attn = attn.view(-1, self.num_heads, N, N)
322
+ attn = self.softmax(attn)
323
+ else:
324
+ attn = self.softmax(attn)
325
+
326
+ attn = self.attn_drop(attn)
327
+
328
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
329
+ x = self.proj(x)
330
+ x = self.proj_drop(x)
331
+ return x, attn
332
+
333
+ def extra_repr(self):
334
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
335
+
336
+
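As a quick orientation, here is a minimal shape-check sketch for the `WindowAttention` module above (assuming the class is importable from this file; the 96 channels, 8x8 window, 4 heads, and batch of 4 windows are illustrative values, not ones fixed by this repository):

```python
import torch

# Hypothetical smoke test for WindowAttention as defined above.
attn_layer = WindowAttention(dim=96, window_size=(8, 8), num_heads=4)
x = torch.randn(4, 8 * 8, 96)       # (num_windows*B, N, C)
out, attn = attn_layer(x)           # this variant also returns the attention map
print(out.shape)                    # torch.Size([4, 64, 96])
print(attn.shape)                   # torch.Size([4, 4, 64, 64]) -> (B_, num_heads, N, N)
```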
337
+ # The model is built from Swin Transformer blocks, so pretrained Swin Transformer weights can be reused
338
+ class SwinTransformerBlock(nn.Module):
339
+ r""" Swin Transformer Block.
340
+ Args:
341
+ dim (int): Number of input channels.
342
+ input_resolution (tuple[int]): Input resolution.
343
+ num_heads (int): Number of attention heads.
344
+ window_size (int): Window size.
345
+ shift_size (int): Shift size for SW-MSA.
346
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
347
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
348
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
349
+ drop (float, optional): Dropout rate. Default: 0.0
350
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
351
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
352
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
353
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354
+ """
355
+
356
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
357
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
358
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
359
+ super().__init__()
360
+ self.dim = dim
361
+ self.input_resolution = input_resolution
362
+ self.num_heads = num_heads
363
+ self.window_size = window_size
364
+ self.shift_size = shift_size
365
+ self.mlp_ratio = mlp_ratio
366
+ self.norm_before_mlp = norm_before_mlp
367
+ if min(self.input_resolution) <= self.window_size:
368
+ # if window size is larger than input resolution, we don't partition windows
369
+ self.shift_size = 0
370
+ self.window_size = min(self.input_resolution)
371
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
372
+
373
+ self.norm1 = norm_layer(dim)
374
+ self.attn = WindowAttention(
375
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
376
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
377
+
378
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
379
+ if self.norm_before_mlp == 'ln':
380
+ self.norm2 = nn.LayerNorm(dim)
381
+ elif self.norm_before_mlp == 'bn':
382
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
383
+ else:
384
+ raise NotImplementedError
385
+ mlp_hidden_dim = int(dim * mlp_ratio)
386
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
387
+
388
+ if self.shift_size > 0:
389
+ # calculate attention mask for SW-MSA
390
+ H, W = self.input_resolution
391
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
392
+ h_slices = (slice(0, -self.window_size),
393
+ slice(-self.window_size, -self.shift_size),
394
+ slice(-self.shift_size, None))
395
+ w_slices = (slice(0, -self.window_size),
396
+ slice(-self.window_size, -self.shift_size),
397
+ slice(-self.shift_size, None))
398
+ cnt = 0
399
+ for h in h_slices:
400
+ for w in w_slices:
401
+ img_mask[:, h, w, :] = cnt
402
+ cnt += 1
403
+
404
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
405
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
406
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
407
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
408
+ else:
409
+ attn_mask = None
410
+
411
+ self.register_buffer("attn_mask", attn_mask)
412
+
413
+ def forward(self, x):
414
+ # pdb.set_trace()
415
+ H, W = self.input_resolution
416
+ # print("H: ", H)
417
+ # print("W: ", W)
418
+ # pdb.set_trace()
419
+ B, L, C = x.shape
420
+ # assert L == H * W, "input feature has wrong size"
421
+
422
+ shortcut = x
423
+ x = self.norm1(x)
424
+ x = x.view(B, H, W, C)
425
+
426
+ # cyclic shift
427
+ if self.shift_size > 0:
428
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
429
+ else:
430
+ shifted_x = x
431
+
432
+ # partition windows
433
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
434
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
435
+
436
+ # W-MSA/SW-MSA
437
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
438
+
439
+ # merge windows
440
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
441
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
442
+
443
+ # reverse cyclic shift
444
+ if self.shift_size > 0:
445
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
446
+ else:
447
+ x = shifted_x
448
+ x = x.view(B, H * W, C)
449
+
450
+ # FFN
451
+ x = shortcut + self.drop_path(x)
452
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
453
+
454
+ return x, attn
455
+
456
+ def extra_repr(self):
457
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
458
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
459
+
460
+
461
+
462
+ class PatchMerging(nn.Module):
463
+ r""" Patch Merging Layer.
464
+ Args:
465
+ input_resolution (tuple[int]): Resolution of input feature.
466
+ dim (int): Number of input channels.
467
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468
+ """
469
+
470
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
471
+ super().__init__()
472
+ self.input_resolution = input_resolution
473
+ self.dim = dim
474
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
475
+ self.norm = norm_layer(4 * dim)
476
+
477
+ def forward(self, x):
478
+ """
479
+ x: B, H*W, C
480
+ """
481
+ H, W = self.input_resolution
482
+ B, L, C = x.shape
483
+ assert L == H * W, "input feature has wrong size"
484
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
485
+
486
+ x = x.view(B, H, W, C)
487
+
488
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
489
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
490
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
491
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
492
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
493
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
494
+
495
+ x = self.norm(x)
496
+ x = self.reduction(x)
497
+
498
+ return x
499
+
500
+ def extra_repr(self):
501
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
502
+
503
+
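A brief sketch of what the patch-merging step above does to tensor shapes (the 16x16 resolution, batch of 2, and 96 channels are assumptions for illustration):

```python
import torch

# PatchMerging halves the spatial resolution and doubles the channel dimension.
merge = PatchMerging(input_resolution=(16, 16), dim=96)
tokens = torch.randn(2, 16 * 16, 96)     # (B, H*W, C)
merged = merge(tokens)
print(merged.shape)                      # torch.Size([2, 64, 192]) -> (B, H/2*W/2, 2*C)
```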
504
+ class BasicLayer(nn.Module):
505
+ """ A basic Swin Transformer layer for one stage.
506
+ Args:
507
+ dim (int): Number of input channels.
508
+ input_resolution (tuple[int]): Input resolution.
509
+ depth (int): Number of blocks.
510
+ num_heads (int): Number of attention heads.
511
+ window_size (int): Local window size.
512
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
513
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
514
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
515
+ drop (float, optional): Dropout rate. Default: 0.0
516
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
517
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
518
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
519
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
520
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
521
+ """
522
+
523
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
524
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
525
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
526
+ norm_before_mlp='ln'):
527
+
528
+ super().__init__()
529
+ self.dim = dim
530
+ self.input_resolution = input_resolution
531
+ self.depth = depth
532
+ self.use_checkpoint = use_checkpoint
533
+
534
+ # build blocks
535
+ self.blocks = nn.ModuleList([
536
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
537
+ num_heads=num_heads, window_size=window_size,
538
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
539
+ mlp_ratio=mlp_ratio,
540
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
541
+ drop=drop, attn_drop=attn_drop,
542
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
543
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
544
+ for i in range(depth)])
545
+
546
+ # patch merging layer
547
+ if downsample is not None:
548
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
549
+ else:
550
+ self.downsample = None
551
+
552
+ def forward(self, x):
553
+ attns = []
554
+ for blk in self.blocks:
555
+ if self.use_checkpoint:
556
+ x, attn = checkpoint.checkpoint(blk, x)
557
+ else:
558
+ x, attn = blk(x)
559
+ if not self.training:
560
+ attns.append(attn.unsqueeze(0))
561
+ if self.downsample is not None:
562
+ x = self.downsample(x)
563
+ if not self.training:
564
+ attn = torch.cat(attns, dim = 0)
565
+ attn = torch.mean(attn, dim = 0)
566
+ return x, attn
567
+
568
+ def extra_repr(self):
569
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
570
+
571
+
572
+ # The Core of HTSAT
573
+ class HTSAT_Swin_Transformer(nn.Module):
574
+ r"""HTSAT based on the Swin Transformer
575
+ Args:
576
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
577
+ patch_size (int | tuple(int)): Patch size. Default: 4
578
+ patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
579
+ in_chans (int): Number of input image channels. Default: 1 (mono)
580
+ num_classes (int): Number of classes for classification head. Default: 527
581
+ embed_dim (int): Patch embedding dimension. Default: 96
582
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
583
+ num_heads (tuple(int)): Number of attention heads in different layers.
584
+ window_size (int): Window size. Default: 8
585
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
586
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
587
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
588
+ drop_rate (float): Dropout rate. Default: 0
589
+ attn_drop_rate (float): Attention dropout rate. Default: 0
590
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
591
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
592
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
593
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
594
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
595
+ config (module): The configuration Module from config.py
596
+ """
597
+
598
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
599
+ in_chans=1, num_classes=527,
600
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
601
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
602
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
603
+ norm_layer=nn.LayerNorm,
604
+ ape=False, patch_norm=True,
605
+ use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
606
+ super(HTSAT_Swin_Transformer, self).__init__()
607
+
608
+ self.config = config
609
+ self.spec_size = spec_size
610
+ self.patch_stride = patch_stride
611
+ self.patch_size = patch_size
612
+ self.window_size = window_size
613
+ self.embed_dim = embed_dim
614
+ self.depths = depths
615
+ self.ape = ape
616
+ self.in_chans = in_chans
617
+ self.num_classes = num_classes
618
+ self.num_heads = num_heads
619
+ self.num_layers = len(self.depths)
620
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
621
+
622
+ self.drop_rate = drop_rate
623
+ self.attn_drop_rate = attn_drop_rate
624
+ self.drop_path_rate = drop_path_rate
625
+
626
+ self.qkv_bias = qkv_bias
627
+ self.qk_scale = None
628
+
629
+ self.patch_norm = patch_norm
630
+ self.norm_layer = norm_layer if self.patch_norm else None
631
+ self.norm_before_mlp = norm_before_mlp
632
+ self.mlp_ratio = mlp_ratio
633
+
634
+ self.use_checkpoint = use_checkpoint
635
+
636
+ # process mel-spec ; used only once
637
+ self.freq_ratio = self.spec_size // self.config.mel_bins
638
+ window = 'hann'
639
+ center = True
640
+ pad_mode = 'reflect'
641
+ ref = 1.0
642
+ amin = 1e-10
643
+ top_db = None
644
+ self.interpolate_ratio = 32 # Downsampled ratio
645
+ # Spectrogram extractor
646
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
647
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
648
+ freeze_parameters=True)
649
+ # Logmel feature extractor
650
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
651
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
652
+ freeze_parameters=True)
653
+ # Spec augmenter
654
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
655
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
656
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
657
+
658
+
659
+ # split spectrogram into non-overlapping patches
660
+ self.patch_embed = PatchEmbed(
661
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
662
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
663
+
664
+ num_patches = self.patch_embed.num_patches
665
+ patches_resolution = self.patch_embed.grid_size
666
+ self.patches_resolution = patches_resolution
667
+
668
+ # absolute position embedding
669
+ if self.ape:
670
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
671
+ trunc_normal_(self.absolute_pos_embed, std=.02)
672
+
673
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
674
+
675
+ # stochastic depth
676
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
677
+
678
+ # build layers
679
+ self.layers = nn.ModuleList()
680
+ for i_layer in range(self.num_layers):
681
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
682
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
683
+ patches_resolution[1] // (2 ** i_layer)),
684
+ depth=self.depths[i_layer],
685
+ num_heads=self.num_heads[i_layer],
686
+ window_size=self.window_size,
687
+ mlp_ratio=self.mlp_ratio,
688
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
689
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
690
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
691
+ norm_layer=self.norm_layer,
692
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
693
+ use_checkpoint=use_checkpoint,
694
+ norm_before_mlp=self.norm_before_mlp)
695
+ self.layers.append(layer)
696
+
697
+ self.norm = self.norm_layer(self.num_features)
698
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
699
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
700
+
701
+ if self.config.enable_tscam:
702
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
703
+ self.tscam_conv = nn.Conv2d(
704
+ in_channels = self.num_features,
705
+ out_channels = self.num_classes,
706
+ kernel_size = (SF,3),
707
+ padding = (0,1)
708
+ )
709
+ self.head = nn.Linear(num_classes, num_classes)
710
+ else:
711
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
712
+
713
+ self.apply(self._init_weights)
714
+
715
+ def _init_weights(self, m):
716
+ if isinstance(m, nn.Linear):
717
+ trunc_normal_(m.weight, std=.02)
718
+ if isinstance(m, nn.Linear) and m.bias is not None:
719
+ nn.init.constant_(m.bias, 0)
720
+ elif isinstance(m, nn.LayerNorm):
721
+ nn.init.constant_(m.bias, 0)
722
+ nn.init.constant_(m.weight, 1.0)
723
+
724
+ @torch.jit.ignore
725
+ def no_weight_decay(self):
726
+ return {'absolute_pos_embed'}
727
+
728
+ @torch.jit.ignore
729
+ def no_weight_decay_keywords(self):
730
+ return {'relative_position_bias_table'}
731
+
732
+ def forward_features(self, x):
733
+ frames_num = x.shape[2]
734
+ x = self.patch_embed(x)
735
+ if self.ape:
736
+ x = x + self.absolute_pos_embed
737
+ x = self.pos_drop(x)
738
+ for i, layer in enumerate(self.layers):
739
+ x, attn = layer(x)
740
+
741
+ if self.config.enable_tscam:
742
+ # for x
743
+ x = self.norm(x)
744
+ B, N, C = x.shape
745
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
746
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
747
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
748
+ B, C, F, T = x.shape
749
+ # group 2D CNN
750
+ c_freq_bin = F // self.freq_ratio
751
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
752
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
753
+
754
+ # get latent_output
755
+ latent_output = self.avgpool(torch.flatten(x,2))
756
+ latent_output = torch.flatten(latent_output, 1)
757
+
758
+ # display the attention map, if needed
759
+ if self.config.htsat_attn_heatmap:
760
+ # for attn
761
+ attn = torch.mean(attn, dim = 1)
762
+ attn = torch.mean(attn, dim = 1)
763
+ attn = attn.reshape(B, SF, ST)
764
+ c_freq_bin = SF // self.freq_ratio
765
+ attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
766
+ attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
767
+ attn = attn.mean(dim = 1)
768
+ attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
769
+ attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
770
+ attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
771
+ attn = attn.unsqueeze(dim = 2)
772
+
773
+ x = self.tscam_conv(x)
774
+ x = torch.flatten(x, 2) # B, C, T
775
+
776
+ if self.config.htsat_attn_heatmap:
777
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
778
+ else:
779
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
780
+
781
+ x = self.avgpool(x)
782
+ x = torch.flatten(x, 1)
783
+
784
+ if self.config.loss_type == "clip_ce":
785
+ output_dict = {
786
+ 'framewise_output': fpx, # already sigmoided
787
+ 'clipwise_output': x,
788
+ 'latent_output': latent_output
789
+ }
790
+ else:
791
+ output_dict = {
792
+ 'framewise_output': fpx, # already sigmoided
793
+ 'clipwise_output': torch.sigmoid(x),
794
+ 'latent_output': latent_output
795
+ }
796
+
797
+ else:
798
+ x = self.norm(x) # B N C
799
+ B, N, C = x.shape
800
+
801
+ fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
802
+ B, C, F, T = fpx.shape
803
+ c_freq_bin = F // self.freq_ratio
804
+ fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
805
+ fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
806
+ fpx = torch.sum(fpx, dim = 2)
807
+ fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
808
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
809
+ x = torch.flatten(x, 1)
810
+ if self.num_classes > 0:
811
+ x = self.head(x)
812
+ fpx = self.head(fpx)
813
+ output_dict = {'framewise_output': torch.sigmoid(fpx),
814
+ 'clipwise_output': torch.sigmoid(x)}
815
+ return output_dict
816
+
817
+ def crop_wav(self, x, crop_size, spe_pos = None):
818
+ time_steps = x.shape[2]
819
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
820
+ for i in range(len(x)):
821
+ if spe_pos is None:
822
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
823
+ else:
824
+ crop_pos = spe_pos
825
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
826
+ return tx
827
+
828
+ # Reshape the waveform spectrogram to an image-like size, so that a pretrained Swin Transformer model can be used
829
+ def reshape_wav2img(self, x):
830
+ B, C, T, F = x.shape
831
+ target_T = int(self.spec_size * self.freq_ratio)
832
+ target_F = self.spec_size // self.freq_ratio
833
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
834
+ # to avoid bicubic zero error
835
+ if T < target_T:
836
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
837
+ if F < target_F:
838
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
839
+ x = x.permute(0,1,3,2).contiguous()
840
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
841
+ # print(x.shape)
842
+ x = x.permute(0,1,3,2,4).contiguous()
843
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
844
+ return x
845
+
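The reshape above folds the time axis into the frequency axis so that a (T, F) mel spectrogram becomes a square spec_size x spec_size image that pretrained Swin weights can consume. A standalone sketch of the same permute/reshape, assuming spec_size=256 and 64 mel bins (so freq_ratio=4; values chosen only for illustration):

```python
import torch

B, C, T, F = 2, 1, 1024, 64   # assumed: 1024 mel frames, 64 mel bins
freq_ratio = 4                # spec_size // mel_bins = 256 // 64

x = torch.randn(B, C, T, F)
x = x.permute(0, 1, 3, 2)                               # (B, C, F, T)
x = x.reshape(B, C, F, freq_ratio, T // freq_ratio)     # split time into freq_ratio chunks
x = x.permute(0, 1, 3, 2, 4).reshape(B, C, F * freq_ratio, T // freq_ratio)
print(x.shape)                                          # torch.Size([2, 1, 256, 256])
```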
846
+ # Repeat the waveform spectrogram to an image-like size, so that a pretrained Swin Transformer model can be used
847
+ def repeat_wat2img(self, x, cur_pos):
848
+ B, C, T, F = x.shape
849
+ target_T = int(self.spec_size * self.freq_ratio)
850
+ target_F = self.spec_size // self.freq_ratio
851
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
852
+ # to avoid bicubic zero error
853
+ if T < target_T:
854
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
855
+ if F < target_F:
856
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
857
+ x = x.permute(0,1,3,2).contiguous() # B C F T
858
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
859
+ x = x.repeat(repeats = (1,1,4,1))
860
+ return x
861
+
862
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
863
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
864
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
865
+
866
+
867
+ x = x.transpose(1, 3)
868
+ x = self.bn0(x)
869
+ x = x.transpose(1, 3)
870
+ if self.training:
871
+ x = self.spec_augmenter(x)
872
+ if self.training and mixup_lambda is not None:
873
+ x = do_mixup(x, mixup_lambda)
874
+
875
+ if infer_mode:
876
+ # in infer mode, we need to handle audio inputs of different lengths
877
+ frame_num = x.shape[2]
878
+ target_T = int(self.spec_size * self.freq_ratio)
879
+ repeat_ratio = math.floor(target_T / frame_num)
880
+ x = x.repeat(repeats=(1,1,repeat_ratio,1))
881
+ x = self.reshape_wav2img(x)
882
+ output_dict = self.forward_features(x)
883
+ elif self.config.enable_repeat_mode:
884
+ if self.training:
885
+ cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
886
+ x = self.repeat_wat2img(x, cur_pos)
887
+ output_dict = self.forward_features(x)
888
+ else:
889
+ output_dicts = []
890
+ for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
891
+ tx = x.clone()
892
+ tx = self.repeat_wat2img(tx, cur_pos)
893
+ output_dicts.append(self.forward_features(tx))
894
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
895
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
896
+ for d in output_dicts:
897
+ clipwise_output += d["clipwise_output"]
898
+ framewise_output += d["framewise_output"]
899
+ clipwise_output = clipwise_output / len(output_dicts)
900
+ framewise_output = framewise_output / len(output_dicts)
901
+
902
+ output_dict = {
903
+ 'framewise_output': framewise_output,
904
+ 'clipwise_output': clipwise_output
905
+ }
906
+ else:
907
+ if x.shape[2] > self.freq_ratio * self.spec_size:
908
+ if self.training:
909
+ x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
910
+ x = self.reshape_wav2img(x)
911
+ output_dict = self.forward_features(x)
912
+ else:
913
+ # Note: crop_size and overlap_size are hard-coded below
914
+ overlap_size = 344 #(x.shape[2] - 1) // 4
915
+ output_dicts = []
916
+ crop_size = 689 #(x.shape[2] - 1) // 2
917
+ for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
918
+ tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
919
+ tx = self.reshape_wav2img(tx)
920
+ output_dicts.append(self.forward_features(tx))
921
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
922
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
923
+ latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
924
+ for d in output_dicts:
925
+ clipwise_output += d["clipwise_output"]
926
+ framewise_output += d["framewise_output"]
927
+ latent_output += d["latent_output"]
928
+ clipwise_output = clipwise_output / len(output_dicts)
929
+ framewise_output = framewise_output / len(output_dicts)
930
+ latent_output = latent_output / len(output_dicts)
931
+ output_dict = {
932
+ 'framewise_output': framewise_output,
933
+ 'clipwise_output': clipwise_output,
934
+ 'latent_output': latent_output,
935
+ }
936
+ else: # this branch is typically used, and is the simplest one
937
+ x = self.reshape_wav2img(x)
938
+ output_dict = self.forward_features(x)
939
+ # x = self.head(x)
940
+ return output_dict
941
+
942
+ class HTSATWrapper(nn.Module):
943
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
944
+ fmax, classes_num, out_emb):
945
+ super().__init__()
946
+
947
+ # print("parameters are being overridden when using HTSAT")
948
+ # print("HTSAT only supports loading a pretrained model on AudioSet")
949
+ # @TODO later look at what parameters are same and can be merged
950
+
951
+ self.htsat = HTSAT_Swin_Transformer(config=config)
952
+
953
+ def forward(self, x):
954
+ out_dict = self.htsat(x)
955
+ out_dict['embedding'] = out_dict['latent_output']
956
+ return out_dict
ms_clap/src/models/mapper.py ADDED
@@ -0,0 +1,200 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as nnf
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from enum import Enum
7
+ from transformers import GPT2LMHeadModel
8
+ from typing import Tuple, Optional, Union
9
+
10
+ def get_clapcap(name: str):
11
+ if name == "ClapCaption":
12
+ return ClapCaptionModel
13
+ else:
14
+ raise Exception('The ClapCap model {} is incorrect or not supported'.format(name))
15
+
16
+ class MappingType(Enum):
17
+ MLP = 'mlp'
18
+ Transformer = 'transformer'
19
+
20
+ class MLP(nn.Module):
21
+ def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
22
+ super(MLP, self).__init__()
23
+ layers = []
24
+ for i in range(len(sizes) - 1):
25
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
26
+ if i < len(sizes) - 2:
27
+ layers.append(act())
28
+ self.model = nn.Sequential(*layers)
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ return self.model(x)
32
+
33
+
34
+ class MlpTransformer(nn.Module):
35
+ def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
36
+ super().__init__()
37
+ out_d = out_d if out_d is not None else in_dim
38
+ self.fc1 = nn.Linear(in_dim, h_dim)
39
+ self.act = act
40
+ self.fc2 = nn.Linear(h_dim, out_d)
41
+ self.dropout = nn.Dropout(dropout)
42
+
43
+ def forward(self, x):
44
+ x = self.fc1(x)
45
+ x = self.act(x)
46
+ x = self.dropout(x)
47
+ x = self.fc2(x)
48
+ x = self.dropout(x)
49
+ return x
50
+
51
+ class MultiHeadAttention(nn.Module):
52
+
53
+ def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
54
+ super().__init__()
55
+ self.num_heads = num_heads
56
+ head_dim = dim_self // num_heads
57
+ self.scale = head_dim ** -0.5
58
+ self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
59
+ self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
60
+ self.project = nn.Linear(dim_self, dim_self)
61
+ self.dropout = nn.Dropout(dropout)
62
+
63
+ def forward(self, x, y=None, mask=None):
64
+ y = y if y is not None else x
65
+ b, n, c = x.shape
66
+ _, m, d = y.shape
67
+ # b n h dh
68
+ queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
69
+ # b m 2 h dh
70
+ keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
71
+ keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
72
+ attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
73
+ if mask is not None:
74
+ if mask.dim() == 2:
75
+ mask = mask.unsqueeze(1)
76
+ attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
77
+ attention = attention.softmax(dim=2)
78
+ out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
79
+ out = self.project(out)
80
+ return out, attention
81
+
82
+
83
+ class TransformerLayer(nn.Module):
84
+
85
+ def forward_with_attention(self, x, y=None, mask=None):
86
+ x_, attention = self.attn(self.norm1(x), y, mask)
87
+ x = x + x_
88
+ x = x + self.mlp(self.norm2(x))
89
+ return x, attention
90
+
91
+ def forward(self, x, y=None, mask=None):
92
+ x = x + self.attn(self.norm1(x), y, mask)[0]
93
+ x = x + self.mlp(self.norm2(x))
94
+ return x
95
+
96
+ def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
97
+ norm_layer: nn.Module = nn.LayerNorm):
98
+ super().__init__()
99
+ self.norm1 = norm_layer(dim_self)
100
+ self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
101
+ self.norm2 = norm_layer(dim_self)
102
+ self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
103
+
104
+
105
+ class Transformer(nn.Module):
106
+ def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
107
+ mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
108
+ super(Transformer, self).__init__()
109
+ dim_ref = dim_ref if dim_ref is not None else dim_self
110
+ self.enc_dec = enc_dec
111
+ if enc_dec:
112
+ num_layers = num_layers * 2
113
+ layers = []
114
+ for i in range(num_layers):
115
+ if i % 2 == 0 and enc_dec: # cross
116
+ layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
117
+ elif enc_dec: # self
118
+ layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
119
+ else: # self or cross
120
+ layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
121
+ self.layers = nn.ModuleList(layers)
122
+
123
+ def forward_with_attention(self, x, y=None, mask=None):
124
+ attentions = []
125
+ for layer in self.layers:
126
+ x, att = layer.forward_with_attention(x, y, mask)
127
+ attentions.append(att)
128
+ return x, attentions
129
+
130
+ def forward(self, x, y=None, mask=None):
131
+ for i, layer in enumerate(self.layers):
132
+ if i % 2 == 0 and self.enc_dec: # cross
133
+ x = layer(x, y)
134
+ elif self.enc_dec: # self
135
+ x = layer(x, x, mask)
136
+ else: # self or cross
137
+ x = layer(x, y, mask)
138
+ return x
139
+
140
+
141
+ class TransformerMapper(nn.Module):
142
+ def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
143
+ super(TransformerMapper, self).__init__()
144
+ self.clip_length = clip_length
145
+ self.transformer = Transformer(dim_embedding, 8, num_layers)
146
+ self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
147
+ self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
148
+
149
+ def forward(self, x):
150
+ x = self.linear(x).view(x.shape[0], self.clip_length, -1)
151
+ prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
152
+ prefix = torch.cat((x, prefix), dim=1)
153
+ out = self.transformer(prefix)[:, self.clip_length:]
154
+ return out
155
+
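A sketch of how the mapper above turns a single audio embedding into a fixed-length prefix for the language model (the 512/768 dimensions, prefix length of 10, and 2 layers are assumptions here, not values mandated by this file):

```python
import torch

mapper = TransformerMapper(dim_clip=512, dim_embedding=768,
                           prefix_length=10, clip_length=10, num_layers=2)
audio_emb = torch.randn(4, 512)          # e.g. one CLAP embedding per clip
prefix = mapper(audio_emb)
print(prefix.shape)                      # torch.Size([4, 10, 768])
```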
156
+ class ClapCaptionModel(nn.Module):
157
+ def __init__(self, clap, text_decoder: str, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
158
+ num_layers: int = 8, normalize_prefix: bool = True, mapping_type: str = None,\
159
+ freeze_audio_encoder_weights: bool = True, freeze_gpt_weights: bool = True):
160
+ super(ClapCaptionModel, self).__init__()
161
+ self.clap = clap.audio_encoder
162
+ self.prefix_length = prefix_length
163
+ self.normalize_prefix = normalize_prefix
164
+ self.gpt = GPT2LMHeadModel.from_pretrained(text_decoder)
165
+ self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
166
+ if mapping_type == 'mlp':
167
+ self.clap_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
168
+ self.gpt_embedding_size * prefix_length))
169
+ else:
170
+ self.clap_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
171
+ clip_length, num_layers)
172
+
173
+ # Freeze all CLAP parameters
174
+ if freeze_audio_encoder_weights:
175
+ for p in self.clap.parameters():
176
+ p.requires_grad = False
177
+
178
+ if freeze_gpt_weights:
179
+ for p in self.gpt.parameters():
180
+ p.requires_grad = False
181
+
182
+ def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
183
+ return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
184
+
185
+ def forward(self, audios: torch.Tensor, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None,
186
+ labels: Optional[torch.Tensor] = None):
187
+ # get audio embeddings
188
+ prefix, _ = self.clap(audios)
189
+ # normalize prefix (audio embedding)
190
+ if self.normalize_prefix:
191
+ prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
192
+
193
+ embedding_text = self.gpt.transformer.wte(tokens['input_ids'])
194
+ prefix_projections = self.clap_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
195
+ embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
196
+ if labels is not None:
197
+ dummy_token = self.get_dummy_token(tokens['input_ids'].shape[0], tokens['input_ids'].device)
198
+ labels = torch.cat((dummy_token, tokens['input_ids']), dim=1)
199
+ out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
200
+ return out
ms_clap/src/models/pytorch_utils.py ADDED
@@ -0,0 +1,184 @@
1
+ import numpy as np
2
+ import time
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ def move_data_to_device(x, device):
8
+ if 'float' in str(x.dtype):
9
+ x = torch.Tensor(x)
10
+ elif 'int' in str(x.dtype):
11
+ x = torch.LongTensor(x)
12
+ else:
13
+ return x
14
+
15
+ return x.to(device)
16
+
17
+
18
+ def do_mixup(x, mixup_lambda):
19
+ """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
20
+ (1, 3, 5, ...).
21
+ Args:
22
+ x: (batch_size * 2, ...)
23
+ mixup_lambda: (batch_size * 2,)
24
+ Returns:
25
+ out: (batch_size, ...)
26
+ """
27
+ out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
28
+ x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
29
+ return out
30
+
31
+
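A minimal sketch of do_mixup: a batch of 4 examples collapses into 2 mixed examples, pairing indices (0, 1) and (2, 3). The waveform length and lambda values below are illustrative.

```python
import torch

x = torch.randn(4, 16000)                       # (batch_size * 2, ...)
lam = torch.tensor([0.7, 0.3, 0.5, 0.5])        # (batch_size * 2,)
mixed = do_mixup(x, lam)
print(mixed.shape)                              # torch.Size([2, 16000])
```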
32
+ def append_to_dict(dict, key, value):
33
+ if key in dict.keys():
34
+ dict[key].append(value)
35
+ else:
36
+ dict[key] = [value]
37
+
38
+
39
+ def interpolate(x, ratio):
40
+ """Interpolate data in time domain. This is used to compensate the
41
+ resolution reduction in downsampling of a CNN.
42
+
43
+ Args:
44
+ x: (batch_size, time_steps, classes_num)
45
+ ratio: int, ratio to interpolate
46
+ Returns:
47
+ upsampled: (batch_size, time_steps * ratio, classes_num)
48
+ """
49
+ (batch_size, time_steps, classes_num) = x.shape
50
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
51
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
52
+ return upsampled
53
+
54
+
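For example (sizes assumed), framewise predictions that were downsampled 32x in time can be brought back to the original frame rate:

```python
import torch

x = torch.randn(2, 10, 527)          # (batch_size, time_steps, classes_num)
up = interpolate(x, ratio=32)
print(up.shape)                      # torch.Size([2, 320, 527])
```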
55
+ def pad_framewise_output(framewise_output, frames_num):
56
+ """Pad framewise_output to the same length as input frames. The pad value
57
+ is the same as the value of the last frame.
58
+ Args:
59
+ framewise_output: (batch_size, frames_num, classes_num)
60
+ frames_num: int, number of frames to pad
61
+ Outputs:
62
+ output: (batch_size, frames_num, classes_num)
63
+ """
64
+ pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
65
+ """tensor for padding"""
66
+
67
+ output = torch.cat((framewise_output, pad), dim=1)
68
+ """(batch_size, frames_num, classes_num)"""
69
+
70
+ return output
71
+
72
+
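And a matching sketch for pad_framewise_output, which extends the framewise output to the requested number of frames by repeating the last frame (sizes assumed):

```python
import torch

x = torch.randn(2, 100, 527)
padded = pad_framewise_output(x, frames_num=128)
print(padded.shape)                  # torch.Size([2, 128, 527])
```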
73
+ def count_parameters(model):
74
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
75
+
76
+
77
+ def count_flops(model, audio_length):
78
+ """Count flops. Code modified from others' implementation.
79
+ """
80
+ multiply_adds = True
81
+ list_conv2d=[]
82
+ def conv2d_hook(self, input, output):
83
+ batch_size, input_channels, input_height, input_width = input[0].size()
84
+ output_channels, output_height, output_width = output[0].size()
85
+
86
+ kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
87
+ bias_ops = 1 if self.bias is not None else 0
88
+
89
+ params = output_channels * (kernel_ops + bias_ops)
90
+ flops = batch_size * params * output_height * output_width
91
+
92
+ list_conv2d.append(flops)
93
+
94
+ list_conv1d=[]
95
+ def conv1d_hook(self, input, output):
96
+ batch_size, input_channels, input_length = input[0].size()
97
+ output_channels, output_length = output[0].size()
98
+
99
+ kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
100
+ bias_ops = 1 if self.bias is not None else 0
101
+
102
+ params = output_channels * (kernel_ops + bias_ops)
103
+ flops = batch_size * params * output_length
104
+
105
+ list_conv1d.append(flops)
106
+
107
+ list_linear=[]
108
+ def linear_hook(self, input, output):
109
+ batch_size = input[0].size(0) if input[0].dim() == 2 else 1
110
+
111
+ weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
112
+ bias_ops = self.bias.nelement()
113
+
114
+ flops = batch_size * (weight_ops + bias_ops)
115
+ list_linear.append(flops)
116
+
117
+ list_bn=[]
118
+ def bn_hook(self, input, output):
119
+ list_bn.append(input[0].nelement() * 2)
120
+
121
+ list_relu=[]
122
+ def relu_hook(self, input, output):
123
+ list_relu.append(input[0].nelement() * 2)
124
+
125
+ list_pooling2d=[]
126
+ def pooling2d_hook(self, input, output):
127
+ batch_size, input_channels, input_height, input_width = input[0].size()
128
+ output_channels, output_height, output_width = output[0].size()
129
+
130
+ kernel_ops = self.kernel_size * self.kernel_size
131
+ bias_ops = 0
132
+ params = output_channels * (kernel_ops + bias_ops)
133
+ flops = batch_size * params * output_height * output_width
134
+
135
+ list_pooling2d.append(flops)
136
+
137
+ list_pooling1d=[]
138
+ def pooling1d_hook(self, input, output):
139
+ batch_size, input_channels, input_length = input[0].size()
140
+ output_channels, output_length = output[0].size()
141
+
142
+ kernel_ops = self.kernel_size[0]
143
+ bias_ops = 0
144
+
145
+ params = output_channels * (kernel_ops + bias_ops)
146
+ flops = batch_size * params * output_length
147
+
148
+ list_pooling2d.append(flops)
149
+
150
+ def foo(net):
151
+ childrens = list(net.children())
152
+ if not childrens:
153
+ if isinstance(net, nn.Conv2d):
154
+ net.register_forward_hook(conv2d_hook)
155
+ elif isinstance(net, nn.Conv1d):
156
+ net.register_forward_hook(conv1d_hook)
157
+ elif isinstance(net, nn.Linear):
158
+ net.register_forward_hook(linear_hook)
159
+ elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
160
+ net.register_forward_hook(bn_hook)
161
+ elif isinstance(net, nn.ReLU):
162
+ net.register_forward_hook(relu_hook)
163
+ elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
164
+ net.register_forward_hook(pooling2d_hook)
165
+ elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
166
+ net.register_forward_hook(pooling1d_hook)
167
+ else:
168
+ print('Warning: flop of module {} is not counted!'.format(net))
169
+ return
170
+ for c in childrens:
171
+ foo(c)
172
+
173
+ # Register hook
174
+ foo(model)
175
+
176
+ device = next(model.parameters()).device
177
+ input = torch.rand(1, audio_length).to(device)
178
+
179
+ out = model(input)
180
+
181
+ total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \
182
+ sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d)
183
+
184
+ return total_flops
ms_clap/src/models/utils.py ADDED
@@ -0,0 +1,26 @@
1
+ import argparse
2
+ import yaml
3
+ import sys
4
+
5
+ def read_config_as_args(config_path,args=None,is_config_str=False):
6
+ return_dict = {}
7
+
8
+ if config_path is not None:
9
+ if is_config_str:
10
+ yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
11
+ else:
12
+ with open(config_path, "r") as f:
13
+ yml_config = yaml.load(f, Loader=yaml.FullLoader)
14
+
15
+ if args != None:
16
+ for k, v in yml_config.items():
17
+ if k in args.__dict__:
18
+ args.__dict__[k] = v
19
+ else:
20
+ sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
21
+ else:
22
+ for k, v in yml_config.items():
23
+ return_dict[k] = v
24
+
25
+ args = args if args != None else return_dict
26
+ return argparse.Namespace(**args)
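A minimal usage sketch of the helper above, assuming it is imported from this module; the YAML keys shown are placeholders rather than the repository's actual config:

```python
cfg = read_config_as_args("sample_rate: 44100\nmel_bins: 64", is_config_str=True)
print(cfg.sample_rate, cfg.mel_bins)   # 44100 64
```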
ms_clap/src/zero_shot_classification.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ This is an example using CLAP to perform zeroshot
3
+ classification on ESC50 (https://github.com/karolpiczak/ESC-50).
4
+ """
5
+
6
+ from CLAPWrapper import CLAPWrapper
7
+ from esc50_dataset import ESC50
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from sklearn.metrics import accuracy_score
12
+
13
+ # Load dataset
14
+ root_path = "root_path" # Folder with ESC-50-master/
15
+ dataset = ESC50(root=root_path, download=True) #If download=False code assumes base_folder='ESC-50-master' in esc50_dataset.py
16
+ prompt = 'this is the sound of '
17
+ y = [prompt + x for x in dataset.classes]
18
+
19
+ # Load and initialize CLAP
20
+ weights_path = "weights_path"
21
+ clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
22
+
23
+ # Computing text embeddings
24
+ text_embeddings = clap_model.get_text_embeddings(y)
25
+
26
+ # Computing audio embeddings
27
+ y_preds, y_labels = [], []
28
+ for i in tqdm(range(len(dataset))):
29
+ x, _, one_hot_target = dataset.__getitem__(i)
30
+ audio_embeddings = clap_model.get_audio_embeddings([x], resample=True)
31
+ similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
32
+ y_pred = F.softmax(similarity.detach().cpu(), dim=1).numpy()
33
+ y_preds.append(y_pred)
34
+ y_labels.append(one_hot_target.detach().cpu().numpy())
35
+
36
+
37
+ y_labels, y_preds = np.concatenate(y_labels, axis=0), np.concatenate(y_preds, axis=0)
38
+ acc = accuracy_score(np.argmax(y_labels, axis=1), np.argmax(y_preds, axis=1))
39
+ print('ESC50 Accuracy {}'.format(acc))
40
+
41
+ """
42
+ The output:
43
+
44
+ ESC50 Accuracy: 93.9%
45
+
46
+ """
ms_clap/src/zero_shot_predictions.py ADDED
@@ -0,0 +1,51 @@
1
+ """
2
+ This is an example using CLAP for zero-shot inference.
3
+ """
4
+ from CLAPWrapper import CLAPWrapper
5
+ import torch.nn.functional as F
6
+
7
+ # Define classes for zero-shot
8
+ # Should be in lower case and can be more than one word
9
+ classes = ['coughing','sneezing','drinking sipping', 'breathing', 'brushing teeth']
10
+ ground_truth = ['coughing']
11
+ # Add prompt
12
+ prompt = 'this is a sound of '
13
+ class_prompts = [prompt + x for x in classes]
14
+ #Load audio files
15
+ audio_files = ['audio_file']
16
+
17
+ # Load and initialize CLAP
18
+ weights_path = "weights_path"
19
+ # Setting use_cuda = True will load the model on a GPU using CUDA
20
+ clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
21
+
22
+ # compute text embeddings from natural text
23
+ text_embeddings = clap_model.get_text_embeddings(class_prompts)
24
+
25
+ # compute the audio embeddings from an audio file
26
+ audio_embeddings = clap_model.get_audio_embeddings(audio_files, resample=True)
27
+
28
+ # compute the similarity between audio_embeddings and text_embeddings
29
+ similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
30
+
31
+ similarity = F.softmax(similarity, dim=1)
32
+ values, indices = similarity[0].topk(5)
33
+
34
+ # Print the results
35
+ print("Ground Truth: {}".format(ground_truth))
36
+ print("Top predictions:\n")
37
+ for value, index in zip(values, indices):
38
+ print(f"{classes[index]:>16s}: {100 * value.item():.2f}%")
39
+
40
+ """
41
+ The output (the exact numbers may vary):
42
+
43
+ Ground Truth: coughing
44
+ Top predictions:
45
+
46
+ coughing: 98.55%
47
+ sneezing: 1.24%
48
+ drinking sipping: 0.15%
49
+ breathing: 0.02%
50
+ brushing teeth: 0.01%
51
+ """
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ numpy
2
+ scipy
3
+ scikit-learn
4
+ librosa
5
+ soundfile
6
+ pydub
7
+ torch==2.0.1
8
+ torchaudio==2.0.2
9
+ torchlibrosa==0.1.0
10
+ torchvision==0.15.2
11
+ transformers==4.27.4
12
+ einops
13
+ huggingface-hub
14
+ laion-clap==1.1.3
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
src/factory.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import sys
8
+ sys.path.append('../')
9
+
10
+ from typing import Optional
11
+ from copy import deepcopy
12
+
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from ms_clap.src.CLAPWrapper import CLAPWrapper
+ from laion_clap import CLAP_Module # used by the 'laion-clap' branch below
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ try:
20
+ from .flamingo import Flamingo
21
+ from .flamingo_lm import FlamingoLMMixin
22
+ from .utils import extend_instance
23
+ except:
24
+ from flamingo import Flamingo
25
+ from flamingo_lm import FlamingoLMMixin
26
+ from utils import extend_instance
27
+
28
+
29
+ class CLAP(nn.Module):
30
+ def __init__(self, clap_config):
31
+ super(CLAP, self).__init__()
32
+ self.method = clap_config["method"]
33
+ device_id = f'cuda:{torch.cuda.current_device()}'
34
+
35
+ if self.method == 'laion-clap':
36
+ # https://github.com/LAION-AI/CLAP
37
+ if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
38
+ amodel = 'HTSAT-tiny'
39
+ elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
40
+ amodel = 'HTSAT-base'
41
+ else:
42
+ raise NotImplementedError
43
+
44
+ enable_fusion = 'fusion' in clap_config["model_name"].lower()
45
+ self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device_id)
46
+ self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
47
+
48
+ for param in self.laion_clap.parameters():
49
+ param.requires_grad = False
50
+ self.laion_clap.eval()
51
+
52
+ print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))
53
+
54
+ elif self.method == 'microsoft-clap':
55
+ # https://github.com/microsoft/CLAP
56
+ self.ms_clap = CLAPWrapper(
57
+ clap_config["checkpoint"],
58
+ config_root=clap_config["config_root"],
59
+ version=clap_config['model_name'],
60
+ use_cuda=True
61
+ )
62
+
63
+ if clap_config['model_name'] in ['2022', '2023']:
64
+ for param in self.ms_clap.clap.parameters():
65
+ param.requires_grad = False
66
+ self.ms_clap.clap.eval()
67
+ else:
68
+ for param in self.ms_clap.clapcap.parameters():
69
+ param.requires_grad = False
70
+ self.ms_clap.clapcap.eval()
71
+
72
+ print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))
73
+
74
+ else:
75
+ raise NotImplementedError
76
+
77
+ def forward(self, audio_clips):
78
+
79
+ if len(audio_clips.shape) == 2:
80
+ audio_clips = audio_clips.unsqueeze(0)
81
+ assert len(audio_clips.shape) == 3
82
+
83
+ audio_embeds = []
84
+ for x in audio_clips:
85
+ if self.method == 'laion-clap':
86
+ audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
87
+ elif self.method == 'microsoft-clap':
88
+ audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)
89
+
90
+ audio_embeds.append(audio_embed)
91
+
92
+ audio_embeds = torch.stack(audio_embeds, dim=0)
93
+ audio_embeds.requires_grad = False
94
+
95
+ return audio_embeds
96
+
97
+
98
+ def create_model_and_transforms(
99
+ clap_config: dict,
100
+ lang_encoder_path: str,
101
+ tokenizer_path: str,
102
+ audio_transformer_kwargs: dict,
103
+ cross_attn_every_n_layers: int = 1,
104
+ use_local_files: bool = False,
105
+ decoder_layers_attr_name: str = None,
106
+ freeze_lm_embeddings: bool = False,
107
+ unfreeze_full_lm: bool = False,
108
+ cache_dir: Optional[str] = None,
109
+ **flamingo_kwargs,
110
+ ):
111
+ clap = CLAP(clap_config)
112
+
113
+ text_tokenizer = AutoTokenizer.from_pretrained(
114
+ tokenizer_path,
115
+ local_files_only=use_local_files,
116
+ trust_remote_code=True,
117
+ cache_dir=cache_dir,
118
+ )
119
+ text_tokenizer.add_special_tokens(
120
+ {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
121
+ )
122
+ if text_tokenizer.pad_token is None:
123
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
124
+ if text_tokenizer.sep_token is None:
125
+ text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
126
+
127
+ lang_encoder = AutoModelForCausalLM.from_pretrained(
128
+ lang_encoder_path,
129
+ local_files_only=use_local_files,
130
+ trust_remote_code=True,
131
+ cache_dir=cache_dir,
132
+ )
133
+
134
+ extend_instance(lang_encoder, FlamingoLMMixin)
135
+
136
+ if decoder_layers_attr_name is None:
137
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
138
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
139
+ lang_encoder.resize_token_embeddings(len(text_tokenizer))
140
+
141
+ unfreeze_clap = False
142
+
143
+ model = Flamingo(
144
+ clap,
145
+ unfreeze_clap,
146
+ lang_encoder,
147
+ text_tokenizer.encode("<|endofchunk|>")[-1],
148
+ text_tokenizer.encode("<audio>")[-1],
149
+ text_tokenizer.sep_token_id,
150
+ audio_embed_dim=clap_config["audio_embed_dim"],
151
+ audio_transformer_kwargs=audio_transformer_kwargs,
152
+ cross_attn_every_n_layers=cross_attn_every_n_layers,
153
+ **flamingo_kwargs,
154
+ )
155
+
156
+ model.requires_grad_(False)
157
+ assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
158
+
159
+ model.audio_transformer.requires_grad_(True)
160
+ model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
161
+ if not freeze_lm_embeddings:
162
+ model.lang_encoder.get_input_embeddings().requires_grad_(True)
163
+
164
+ if unfreeze_full_lm:
165
+ model.lang_encoder.requires_grad_(True)
166
+
167
+ if unfreeze_clap:
168
+ model.clap.requires_grad_(True)
169
+
170
+ print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
171
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
172
+ sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
173
+ sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad)
174
+ ))
175
+
176
+ return model, text_tokenizer
177
+
178
+
179
+ def _infer_decoder_layers_attr_name(model):
180
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
181
+ if k.lower() in model.__class__.__name__.lower():
182
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
183
+
184
+ raise ValueError(
185
+ f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
186
+ )
187
+
188
+
189
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
190
+ "opt": "model.decoder.layers",
191
+ "gptj": "transformer.h",
192
+ "gpt-j": "transformer.h",
193
+ "pythia": "gpt_neox.layers",
194
+ "llama": "model.layers",
195
+ "gptneoxforcausallm": "gpt_neox.layers",
196
+ "mpt": "transformer.blocks",
197
+ "mosaicgpt": "transformer.blocks",
198
+ }
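As an illustration of the lookup above (the stand-in class below exists only so its name contains "opt"; it is not a real transformers import):

```python
class OPTForCausalLM:          # stand-in: name contains "opt", so the lookup matches
    pass

print(_infer_decoder_layers_attr_name(OPTForCausalLM()))   # model.decoder.layers
```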
src/flamingo.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import torch
+ from einops import rearrange
+ from torch import nn
+
+ from torch.distributed.fsdp.wrap import (
+     enable_wrap,
+     wrap,
+ )
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from torch.distributed.fsdp import (
+     FullyShardedDataParallel as FSDP,
+ )
+
+ try:
+     from .helpers import TransformerEncoder
+     from .utils import apply_with_stopping_condition
+ except:
+     from helpers import TransformerEncoder
+     from utils import apply_with_stopping_condition
+
+
+ class Flamingo(nn.Module):
+     def __init__(
+         self,
+         clap: nn.Module,
+         unfreeze_clap: bool,
+         lang_encoder: nn.Module,
+         eoc_token_id: int,
+         media_token_id: int,
+         sep_token_id: int,
+         audio_embed_dim: int,
+         audio_transformer_kwargs: dict,
+         cross_attn_every_n_layers: int = 1,
+         gradient_checkpointing: bool = False,
+     ):
+         super().__init__()
+
+         self.eoc_token_id = eoc_token_id
+         self.media_token_id = media_token_id
+         self.sep_token_id = sep_token_id
+         self.audio_embed_dim = audio_embed_dim
+         self.clap = clap  # .to(torch.cuda.current_device())
+         self.unfreeze_clap = unfreeze_clap
+         self.clap.requires_grad_(unfreeze_clap)
+
+         if hasattr(lang_encoder.config, "d_model"):
+             self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
+         else:
+             self.lang_dim = lang_encoder.config.hidden_size
+
+         n_head = audio_transformer_kwargs["n_head"]
+         n_layers = audio_transformer_kwargs["n_layers"]
+         d_inner = audio_transformer_kwargs["d_inner"]
+         max_num_media = audio_transformer_kwargs["max_num_media"]
+         max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
+         assert audio_embed_dim % n_head == 0
+
+         self.audio_transformer = TransformerEncoder(
+             d_word_vec=audio_embed_dim,
+             n_layers=n_layers,
+             n_head=n_head,
+             d_k=audio_embed_dim // n_head,
+             d_v=audio_embed_dim // n_head,
+             d_model=audio_embed_dim,
+             d_inner=d_inner,
+             dropout=0.0,
+             n_position=max_num_media,
+             scale_emb=True
+         )
+
+         self.lang_encoder = lang_encoder
+         self.lang_encoder.init_flamingo(
+             media_token_id=media_token_id,
+             lang_hidden_size=self.lang_dim,
+             audio_hidden_size=self.audio_embed_dim,
+             max_window_per_audio=max_window_per_audio,
+             cross_attn_every_n_layers=cross_attn_every_n_layers,
+             gradient_checkpointing=gradient_checkpointing,
+         )
+
+         self._use_gradient_checkpointing = gradient_checkpointing
+         self.audio_transformer._use_gradient_checkpointing = gradient_checkpointing
+         self.clap._use_gradient_checkpointing = gradient_checkpointing
+
+     def forward(
+         self,
+         audio_x: torch.Tensor,
+         audio_x_mask: torch.Tensor,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         labels: torch.Tensor = None,
+         clear_conditioned_layers: bool = True,
+         past_key_values=None,
+         use_cache: bool = False,
+     ):
+         assert (
+             self.lang_encoder.initialized_flamingo
+         ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
+
+         assert (
+             self.lang_encoder._use_cached_audio_x or audio_x is not None
+         ), "Must provide either audio_x or have precached media using cache_media()."
+
+         if self.lang_encoder._use_cached_audio_x:
+             assert (
+                 audio_x is None
+             ), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
+             assert self.lang_encoder.is_conditioned()
+
+         else:
+             self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
+             self._condition_media_locations(input_ids=lang_x)
+
+         output = self.lang_encoder(
+             input_ids=lang_x,
+             attention_mask=attention_mask,
+             labels=labels,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+         )
+
+         if clear_conditioned_layers:
+             self.lang_encoder.clear_conditioned_layers()
+
+         return output
+
+     def generate(
+         self,
+         audio_x: torch.Tensor,
+         audio_x_mask: torch.Tensor,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         **kwargs,
+     ):
+         num_beams = kwargs.pop("num_beams", 1)
+         if num_beams > 1:
+             audio_x = audio_x.repeat_interleave(num_beams, dim=0)
+
+         self.lang_encoder._use_cached_audio_x = True
+         self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
+
+         eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
+         output = self.lang_encoder.generate(
+             input_ids=lang_x,
+             attention_mask=attention_mask,
+             eos_token_id=eos_token_id,
+             num_beams=num_beams,
+             **kwargs,
+         )
+
+         self.lang_encoder.clear_conditioned_layers()
+         self.lang_encoder._use_cached_audio_x = False
+         return output
+
+     def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
+         """
+         rearrange code based on https://github.com/dhansmair/flamingo-mini
+         """
+
+         assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"
+
+         with torch.no_grad():
+             audio_embeds = self.clap(audio_x)
+         B, L, D = audio_embeds.shape  # L is number of windows, D is feature dim
+         assert D == self.audio_embed_dim
+
+         assert audio_x_mask.ndim == 2, "audio_x_mask should be of shape (B, L)"
+
+         if B > 1 and audio_x_mask.shape[0] == 1:
+             audio_x_mask = audio_x_mask.repeat(B, 1)
+
+         assert audio_x_mask.shape[0] == B and audio_x_mask.shape[1] == L, "{} != ({},{})".format(audio_x_mask.shape, B, L)
+
+         audio_x_out = self.audio_transformer(audio_embeds)  # B, L, D
+         audio_x_out = audio_x_out.unsqueeze(2)  # B, L, n=1, D
+         audio_x_mask = audio_x_mask.unsqueeze(2)  # B, L, n=1
+
+         for layer in self.lang_encoder._get_decoder_layers():
+             layer.condition_audio_x(audio_x_out, audio_x_mask)
+
+     def wrap_fsdp(self, wrapper_kwargs, device_id):
+         # unfreeze the decoder layers
+         for block in self.lang_encoder.old_decoder_blocks:
+             block.requires_grad_(True)
+
+         # wrap in FSDP
+         with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
+             self.audio_transformer = wrap(wrap(self.audio_transformer))
+             self.lang_encoder.old_decoder_blocks = nn.ModuleList(
+                 wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
+             )
+             self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
+                 wrap(wrap(layer)) if layer is not None else None
+                 for layer in self.lang_encoder.gated_cross_attn_layers
+             )
+             self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
+             self.lang_encoder.set_input_embeddings(
+                 wrap(wrap(self.lang_encoder.get_input_embeddings()))
+             )
+
+             if hasattr(self.lang_encoder, 'set_output_embeddings'):
+                 self.lang_encoder.set_output_embeddings(
+                     wrap(wrap(self.lang_encoder.get_output_embeddings()))
+                 )
+             else:
+                 print('skip wrapping output embeddings')
+
+         # manually move non-FSDP managed parameters to device_id
+         # these are all in lang_encoder
+         apply_with_stopping_condition(
+             module=self.lang_encoder,
+             apply_fn=lambda m: m.to(device_id),
+             apply_condition=lambda m: len(list(m.children())) == 0,
+             stopping_condition=lambda m: isinstance(m, FSDP),
+         )
+
+         # clap shouldn't be wrapped; should be on each gpu
+         if self.unfreeze_clap:
+             apply_with_stopping_condition(
+                 module=self.clap,
+                 apply_fn=lambda m: m.to(device_id),
+                 apply_condition=lambda m: len(list(m.children())) == 0,
+                 stopping_condition=lambda m: isinstance(m, FSDP),
+             )
+
+         # exclude the original decoder layers from the optimizer
+         for block in self.lang_encoder.old_decoder_blocks:
+             for p in block.parameters():
+                 p.exclude_from_optimizer = True
+
+         # set up clip_grad_norm_ function
+         def clip_grad_norm_(max_norm):
+             self.audio_transformer.clip_grad_norm_(max_norm)
+             for layer in self.lang_encoder.gated_cross_attn_layers:
+                 if layer is not None:
+                     layer.clip_grad_norm_(max_norm)
+             self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
+
+         self.clip_grad_norm_ = clip_grad_norm_
+
+     def _condition_media_locations(self, input_ids: torch.Tensor):
+         media_locations = (input_ids == self.media_token_id)
+
+         for layer in self.lang_encoder._get_decoder_layers():
+             layer.condition_media_locations(media_locations)
+
+     def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
+         self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
+         self._condition_media_locations(input_ids=input_ids)
+         self.lang_encoder._use_cached_audio_x = True
+
+     def uncache_media(self):
+         self.lang_encoder.clear_conditioned_layers()
+         self.lang_encoder._use_cached_audio_x = False
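As a rough guide to the tensor contract above, a hedged usage sketch follows; `clap_encoder`, `llm`, the token ids, and all sizes are illustrative placeholders for objects configured elsewhere in the repo, not values fixed by this file. `audio_x` is a batch of raw audio windows of shape (B, num_windows, window_length), `audio_x_mask` is (B, num_windows), and `lang_x` contains the `<audio>` placeholder token whose id is `media_token_id`.

model = Flamingo(
    clap=clap_encoder,                     # assumed: returns (B, num_windows, audio_embed_dim) features
    unfreeze_clap=False,
    lang_encoder=llm,                      # a causal LM already extended with FlamingoLMMixin
    eoc_token_id=eoc_id,
    media_token_id=audio_id,
    sep_token_id=sep_id,
    audio_embed_dim=2048,                  # assumed; must match the clap feature dim
    audio_transformer_kwargs=dict(n_head=8, n_layers=3, d_inner=4096,
                                  max_num_media=128, max_window_per_audio=16),
)
out = model(audio_x=audio_x, audio_x_mask=audio_x_mask, lang_x=lang_x,
            attention_mask=attn_mask, labels=labels)        # HF-style output with .loss
gen = model.generate(audio_x=audio_x, audio_x_mask=audio_x_mask, lang_x=prompt_ids,
                     attention_mask=prompt_mask, max_new_tokens=64, num_beams=2)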
src/flamingo_lm.py ADDED
@@ -0,0 +1,177 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import torch.nn as nn
+
+ try:
+     from .helpers import GatedCrossAttentionBlock
+     from .utils import getattr_recursive, setattr_recursive
+ except:
+     from helpers import GatedCrossAttentionBlock
+     from utils import getattr_recursive, setattr_recursive
+
+
+ class FlamingoLayer(nn.Module):
+     """
+     FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
+     """
+
+     def __init__(
+         self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
+     ):
+         super().__init__()
+         self.gated_cross_attn_layer = gated_cross_attn_layer
+         self.decoder_layer = decoder_layer
+         self.audio_x = None
+         self.audio_x_mask = None
+         self.few_shot_mask = None
+         self.media_locations = None
+         if self.gated_cross_attn_layer is not None:
+             self.gated_cross_attn_layer._use_gradient_checkpointing = (
+                 gradient_checkpointing
+             )
+         self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
+
+     def is_conditioned(self) -> bool:
+         """Check whether the layer is conditioned."""
+         return (self.audio_x is not None) and (self.audio_x_mask is not None) and (self.media_locations is not None)
+
+     def condition_audio_x(self, audio_x, audio_x_mask):
+         self.audio_x = audio_x
+         self.audio_x_mask = audio_x_mask
+
+     def condition_media_locations(self, media_locations):
+         self.media_locations = media_locations
+
+     def condition_use_cached_media(self, use_cached_media):
+         self.use_cached_media = use_cached_media
+
+     def forward(
+         self,
+         lang_x,
+         attention_mask=None,
+         **decoder_layer_kwargs,
+     ):
+         if self.gated_cross_attn_layer is not None:
+             if self.audio_x is None:
+                 raise ValueError("audio_x must be conditioned before forward pass")
+
+             if self.media_locations is None:
+                 raise ValueError(
+                     "media_locations must be conditioned before forward pass"
+                 )
+
+             lang_x = self.gated_cross_attn_layer(
+                 lang_x,
+                 self.audio_x,
+                 self.audio_x_mask,
+                 media_locations=self.media_locations,
+                 use_cached_media=self.use_cached_media,
+             )
+
+         # Normal decoder layer
+         lang_x = self.decoder_layer(
+             lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
+         )
+         return lang_x
+
+
+ class FlamingoLMMixin(nn.Module):
+     """
+     Mixin to add cross-attention layers to a language model.
+     """
+
+     def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
+         self.decoder_layers_attr_name = decoder_layers_attr_name
+
+     def _get_decoder_layers(self):
+         return getattr_recursive(self, self.decoder_layers_attr_name)
+
+     def _set_decoder_layers(self, value):
+         setattr_recursive(self, self.decoder_layers_attr_name, value)
+
+     def init_flamingo(
+         self,
+         media_token_id,
+         lang_hidden_size,
+         audio_hidden_size,
+         max_window_per_audio,
+         cross_attn_every_n_layers,
+         gradient_checkpointing,
+     ):
+         """
+         Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
+         """
+         self.old_decoder_blocks = self._get_decoder_layers()
+         self.gated_cross_attn_layers = nn.ModuleList(
+             [
+                 GatedCrossAttentionBlock(
+                     dim=lang_hidden_size,
+                     dim_audio=audio_hidden_size,
+                     max_window_per_audio=max_window_per_audio,
+                     only_attend_immediate_media=False,
+                 )
+                 if (layer_idx + 1) % cross_attn_every_n_layers == 0
+                 else None
+                 for layer_idx, _ in enumerate(self._get_decoder_layers())
+             ]
+         )
+         self.init_flamingo_layers(gradient_checkpointing)
+         self.media_token_id = media_token_id
+         self.initialized_flamingo = True
+         self._use_cached_audio_x = False
+
+     def init_flamingo_layers(self, gradient_checkpointing):
+         """
+         Re-initializes the FlamingoLayers.
+         Propagates any changes made to self.gated_cross_attn_layers or self.old_decoder_blocks.
+         """
+         self._set_decoder_layers(
+             nn.ModuleList(
+                 [
+                     FlamingoLayer(
+                         gated_cross_attn_layer, decoder_layer, gradient_checkpointing
+                     )
+                     for gated_cross_attn_layer, decoder_layer in zip(
+                         self.gated_cross_attn_layers, self.old_decoder_blocks
+                     )
+                 ]
+             )
+         )
+
+     def forward(self, input_ids, attention_mask, **kwargs):
+         """Condition the Flamingo layers on the media locations before forward()"""
+         if not self.initialized_flamingo:
+             raise ValueError(
+                 "Flamingo layers are not initialized. Please call `init_flamingo` first."
+             )
+
+         media_locations = input_ids == self.media_token_id
+
+         use_cached_media_locations = (
+             self._use_cached_audio_x
+             and self.is_conditioned()
+             and not media_locations.any()
+         )
+
+         for layer in self._get_decoder_layers():
+             if not use_cached_media_locations:
+                 layer.condition_media_locations(media_locations)
+             layer.condition_use_cached_media(use_cached_media_locations)
+
+         kwargs["input_ids"] = input_ids
+         kwargs["attention_mask"] = attention_mask
+         return super().forward(**kwargs)
+
+     def is_conditioned(self) -> bool:
+         """Check whether all decoder layers are already conditioned."""
+         return all(l.is_conditioned() for l in self._get_decoder_layers())
+
+     def clear_conditioned_layers(self):
+         for layer in self._get_decoder_layers():
+             layer.condition_audio_x(None, None)
+             layer.condition_media_locations(None)
+             layer.condition_use_cached_media(None)
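The mixin is intended to be grafted onto an existing decoder-only model instance rather than subclassed up front. A hedged sketch of that wiring follows; the `extend_instance` helper is defined in src/utils.py below, the OPT attribute path comes from the `__KNOWN_DECODER_LAYERS_ATTR_NAMES` table earlier in this diff, and the checkpoint and numeric values are illustrative assumptions.

from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")   # illustrative checkpoint
extend_instance(llm, FlamingoLMMixin)                             # llm now uses the mixin's forward()
llm.set_decoder_layers_attr_name("model.decoder.layers")          # from __KNOWN_DECODER_LAYERS_ATTR_NAMES
llm.init_flamingo(
    media_token_id=audio_token_id,          # id of the <audio> placeholder token (assumed)
    lang_hidden_size=llm.config.hidden_size,
    audio_hidden_size=2048,                 # assumed; must match the audio transformer dim
    max_window_per_audio=16,                # assumed
    cross_attn_every_n_layers=1,
    gradient_checkpointing=False,
)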
src/helpers.py ADDED
@@ -0,0 +1,379 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ # Adapted from https://github.com/lucidrains/flamingo-pytorch under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ # Adapted from https://github.com/jadore801120/attention-is-all-you-need-pytorch under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ from einops import rearrange, repeat
+ from einops_exts import rearrange_many
+
+ import numpy as np
+
+ import torch
+ from torch import einsum, nn
+ import torch.nn.functional as F
+
+
+ def exists(val):
+     return val is not None
+
+
+ def FeedForward(dim, mult=4):
+     inner_dim = int(dim * mult)
+     return nn.Sequential(
+         nn.LayerNorm(dim),
+         nn.Linear(dim, inner_dim, bias=False),
+         nn.GELU(),
+         nn.Linear(inner_dim, dim, bias=False),
+     )
+
+
+ class ScaledDotProductAttention(nn.Module):
+     ''' Scaled Dot-Product Attention '''
+
+     def __init__(self, temperature, attn_dropout=0.1):
+         super().__init__()
+         self.temperature = temperature
+         self.dropout = nn.Dropout(attn_dropout)
+
+     def forward(self, q, k, v, mask=None):
+
+         attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
+
+         if mask is not None:
+             attn = attn.masked_fill(mask == 0, -1e9)
+
+         attn = self.dropout(F.softmax(attn, dim=-1))
+         output = torch.matmul(attn, v)
+
+         return output, attn
+
+
+ class MultiHeadAttention(nn.Module):
+     ''' Multi-Head Attention module '''
+
+     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+         super().__init__()
+
+         self.n_head = n_head
+         self.d_k = d_k
+         self.d_v = d_v
+
+         self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
+         self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
+         self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
+         self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
+
+         self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
+
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+
+     def forward(self, q, k, v, mask=None):
+
+         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+         sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+         residual = q
+
+         # Pass through the pre-attention projection: b x lq x (n*dv)
+         # Separate different heads: b x lq x n x dv
+         q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+         k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+         v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+         # Transpose for attention dot product: b x n x lq x dv
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+         if mask is not None:
+             mask = mask.unsqueeze(1)  # For head axis broadcasting.
+
+         q, attn = self.attention(q, k, v, mask=mask)
+
+         # Transpose to move the head dimension back: b x lq x n x dv
+         # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
+         q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+         q = self.dropout(self.fc(q))
+         q += residual
+
+         q = self.layer_norm(q)
+
+         return q, attn
+
+
+ class PositionwiseFeedForward(nn.Module):
+     ''' A two-feed-forward-layer module '''
+
+     def __init__(self, d_in, d_hid, dropout=0.1):
+         super().__init__()
+         self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
+         self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
+         self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+
+         residual = x
+
+         x = self.w_2(F.relu(self.w_1(x)))
+         x = self.dropout(x)
+         x += residual
+
+         x = self.layer_norm(x)
+
+         return x
+
+
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, d_hid, n_position=200):
+         super(PositionalEncoding, self).__init__()
+         self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+     def _get_sinusoid_encoding_table(self, n_position, d_hid):
+
+         def get_position_angle_vec(position):
+             return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+         sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+         sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+         sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+         return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+     def forward(self, x):
+         return x + self.pos_table[:, :x.size(1)].clone().detach()
+
+
+ class EncoderLayer(nn.Module):
+     ''' Compose with two layers '''
+
+     def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
+         super(EncoderLayer, self).__init__()
+         self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
+         self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
+
+     def forward(self, enc_input, slf_attn_mask=None):
+         enc_output, enc_slf_attn = self.slf_attn(
+             enc_input, enc_input, enc_input, mask=slf_attn_mask)
+         enc_output = self.pos_ffn(enc_output)
+         return enc_output, enc_slf_attn
+
+
+ class TransformerEncoder(nn.Module):
+     ''' An encoder model with a self-attention mechanism. '''
+
+     def __init__(
+             self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
+             d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):
+
+         super().__init__()
+
+         if n_position > 0:
+             self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
+         else:
+             self.position_enc = lambda x: x
+         self.dropout = nn.Dropout(p=dropout)
+         self.layer_stack = nn.ModuleList([
+             EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+             for _ in range(n_layers)])
+         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+         self.scale_emb = scale_emb
+         self.d_model = d_model
+
+     def forward(self, src_seq, return_attns=False):
+         if len(src_seq.shape) == 2:
+             src_seq = src_seq.unsqueeze(1)
+         B, L, D = src_seq.shape
+
+         enc_slf_attn_list = []
+
+         causal_mask = None
+
+         enc_output = src_seq
+         if self.scale_emb:
+             enc_output = enc_output * self.d_model ** 0.5
+         enc_output = self.dropout(self.position_enc(enc_output))
+         enc_output = self.layer_norm(enc_output)
+
+         for enc_layer in self.layer_stack:
+             enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=causal_mask)
+             enc_slf_attn_list += [enc_slf_attn] if return_attns else []
+
+         if return_attns:
+             return enc_output, enc_slf_attn_list
+         return enc_output
+
+
+ # gated cross attention
+ class MaskedCrossAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_audio,
+         max_window_per_audio,
+         dim_head=64,
+         heads=8,
+         only_attend_immediate_media=True,
+     ):
+         super().__init__()
+         self.max_window_per_audio = max_window_per_audio
+         self.scale = dim_head**-0.5
+         self.heads = heads
+         inner_dim = dim_head * heads
+
+         self.norm = nn.LayerNorm(dim)
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+         self.only_attend_immediate_media = only_attend_immediate_media
+
+     def forward(
+         self,
+         x,
+         media, media_mask,
+         media_locations=None,
+         use_cached_media=False
+     ):
+
+         if not use_cached_media:
+             assert (
+                 media_locations.shape[1] == x.shape[1]
+             ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
+
+         T_txt = x.shape[1]
+         B, L = media.shape[:2]
+         assert media.shape[2] == 1  # extra dim
+         assert L % self.max_window_per_audio == 0  # L must be a multiple of max_window_per_audio (e.g. 4 or 8)
+         h = self.heads
+
+         x = self.norm(x)
+
+         q = self.to_q(x)
+         media = rearrange(media, "b t n d -> b (t n) d")
+
+         k, v = self.to_kv(media).chunk(2, dim=-1)
+         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+         q = q * self.scale
+
+         sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+         # mask padded audio embeddings
+         media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool()  # n = 1 is extra dim
+         sim = sim.masked_fill(~media_mask, -torch.finfo(sim.dtype).max)
+
+         assert self.only_attend_immediate_media is False
+
+         # mask media locations
+         if exists(media_locations):
+             few_shot_mask = torch.zeros(B, T_txt, L).bool().to(sim.device)
+             for batch_idx in range(B):
+                 media_locations_b = media_locations[batch_idx].nonzero()  # locations of <audio>
+                 if len(media_locations_b.shape) > 1:
+                     media_locations_b = media_locations_b.squeeze(-1)
+
+                 for i in range(-1, len(media_locations_b)):
+                     if i == -1:
+                         if len(media_locations_b) == 1:
+                             text_start, text_end = 0, T_txt
+                         else:
+                             text_start, text_end = 0, media_locations_b[i+1]
+
+                     elif i == len(media_locations_b) - 1:
+                         text_start, text_end = media_locations_b[i], T_txt
+
+                     else:
+                         text_start, text_end = media_locations_b[i], media_locations_b[i+1]
+
+                     if self.only_attend_immediate_media:
+                         look_at_window_start = max(i,0) * self.max_window_per_audio
+                     else:
+                         look_at_window_start = 0
+                     look_at_window_end = (max(i,0) + 1) * self.max_window_per_audio
+
+                     few_shot_mask[batch_idx, text_start:text_end, look_at_window_start:look_at_window_end] = True
+
+             sim = sim.masked_fill(~few_shot_mask.unsqueeze(1), -torch.finfo(sim.dtype).max)
+
+         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+         attn = sim.softmax(dim=-1)
+
+         if exists(media_locations) and self.only_attend_immediate_media:
+             # NOTE: unreachable in this adaptation since only_attend_immediate_media is asserted False above
+             # (and text_time is not defined here); kept from the original open_flamingo implementation.
+             text_without_media_mask = text_time == 0
+             text_without_media_mask = rearrange(
+                 text_without_media_mask, "b i -> b 1 i 1"
+             )
+             attn = attn.masked_fill(text_without_media_mask, 0.0)
+
+         out = einsum("... i j, ... j d -> ... i d", attn, v)
+         out = rearrange(out, "b h n d -> b n (h d)")
+         return self.to_out(out)
+
+
+ class GatedCrossAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_audio,
+         max_window_per_audio,
+         dim_head=64,
+         heads=8,
+         ff_mult=4,
+         only_attend_immediate_media=True,
+     ):
+         super().__init__()
+         self.attn = MaskedCrossAttention(
+             dim=dim,
+             dim_audio=dim_audio,
+             max_window_per_audio=max_window_per_audio,
+             dim_head=dim_head,
+             heads=heads,
+             only_attend_immediate_media=only_attend_immediate_media,
+         )
+         self.attn_gate = nn.Parameter(torch.tensor([0.0]))
+
+         self.ff = FeedForward(dim, mult=ff_mult)
+         self.ff_gate = nn.Parameter(torch.tensor([0.0]))
+
+     def forward(
+         self,
+         x,
+         media,
+         media_mask,
+         media_locations=None,
+         use_cached_media=False,
+     ):
+         x = (
+             self.attn(
+                 x,
+                 media,
+                 media_mask,
+                 media_locations=media_locations,
+                 use_cached_media=use_cached_media,
+             )
+             * self.attn_gate.tanh()
+             + x
+         )
+         x = self.ff(x) * self.ff_gate.tanh() + x
+
+         return x
+
+
+ if __name__ == '__main__':
+     enc = TransformerEncoder().cuda()
+     x = torch.randn(4, 512).cuda()
+     output = enc(x)
+     enc._use_gradient_checkpointing = True
+     print(output.shape)
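Beyond the `__main__` smoke test above, a hedged shape check for the gated cross-attention path; all sizes are toy values chosen for illustration, and the snippet assumes helpers.py is importable. Because `attn_gate` and `ff_gate` are initialized to zero, the block starts out as an identity map over `x`, the usual Flamingo trick for stable fine-tuning.

import torch

# x: text hidden states (B, T_txt, dim); media: audio features (B, L, 1, dim_audio);
# media_mask: (B, L, 1); media_locations: (B, T_txt) bool with at least one True per row.
block = GatedCrossAttentionBlock(dim=32, dim_audio=16, max_window_per_audio=4,
                                 only_attend_immediate_media=False)
x = torch.randn(2, 6, 32)
media = torch.randn(2, 4, 1, 16)            # L=4 windows; L % max_window_per_audio == 0
media_mask = torch.ones(2, 4, 1).bool()
media_locations = torch.zeros(2, 6).bool()
media_locations[:, 0] = True                # one <audio> token at position 0
y = block(x, media, media_mask, media_locations=media_locations)
print(y.shape)                              # torch.Size([2, 6, 32]); equals x while the gates are 0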
src/utils.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ def extend_instance(obj, mixin):
+     """Apply mixins to a class instance after creation"""
+     base_cls = obj.__class__
+     base_cls_name = obj.__class__.__name__
+     obj.__class__ = type(
+         base_cls_name, (mixin, base_cls), {}
+     )  # mixin needs to go first for our forward() logic to work
+
+
+ def getattr_recursive(obj, att):
+     """
+     Return nested attribute of obj
+     Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+     """
+     if att == "":
+         return obj
+     i = att.find(".")
+     if i < 0:
+         return getattr(obj, att)
+     else:
+         return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+ def setattr_recursive(obj, att, val):
+     """
+     Set nested attribute of obj
+     Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+     """
+     if "." in att:
+         obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+     setattr(obj, att.split(".")[-1], val)
+
+
+ def apply_with_stopping_condition(
+     module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
+ ):
+     if stopping_condition(module):
+         return
+     if apply_condition(module):
+         apply_fn(module, **other_args)
+     for child in module.children():
+         apply_with_stopping_condition(
+             child,
+             apply_fn,
+             apply_condition=apply_condition,
+             stopping_condition=stopping_condition,
+             **other_args
+         )
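A hedged sketch of how these helpers compose on a toy module (illustrative only; in this repo the same traversal pattern is applied to the language model inside `wrap_fsdp` above):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
inner = getattr_recursive(toy, "1.0")            # same object as toy[1][0] -> nn.Linear(4, 4)
setattr_recursive(toy, "1.1", nn.GELU())         # roughly toy[1][1] = nn.GELU()

# Visit only leaf modules, but never descend into the inner Sequential.
apply_with_stopping_condition(
    module=toy,
    apply_fn=lambda m: print(type(m).__name__),
    apply_condition=lambda m: len(list(m.children())) == 0,
    stopping_condition=lambda m: isinstance(m, nn.Sequential) and m is not toy,
)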