---
license: llama3
---

**This is not an officially supported Google product.**

## Overview

[DiarizationLM](https://arxiv.org/abs/2401.03506) model finetuned on the training subset of the Fisher corpus.

* Foundation model: [unsloth/llama-3-8b-bnb-4bit](https://huggingface.co/unsloth/llama-3-8b-bnb-4bit)
* Finetuning scripts: https://github.com/google/speaker-id/tree/master/DiarizationLM/unsloth

## Training config

This model is finetuned on the training subset of the Fisher corpus, using a LoRA adapter of rank 256. The total number of trainable parameters is 671,088,640. With a batch size of 16, the model has been trained for 25,400 steps, which is roughly 8 epochs of the training data.
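
The LoRA hyperparameters other than the rank are not listed in this card; the authoritative values are in the finetuning scripts linked above. As a rough, non-authoritative sketch, a rank-256 adapter of this size corresponds to a PEFT configuration along the following lines, where `lora_alpha`, `lora_dropout`, and the target module list are assumptions:

```python
from peft import LoraConfig

# Sketch only: r matches the rank reported in this card; the remaining values
# are assumptions, not the exact training configuration.
lora_config = LoraConfig(
    r=256,             # LoRA rank used for this model
    lora_alpha=256,    # assumed
    lora_dropout=0.0,  # assumed
    bias="none",
    task_type="CAUSAL_LM",
    # Assumed to cover all attention and MLP projections of the Llama 3 blocks,
    # which is consistent with the 671,088,640 trainable parameter count.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)
```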

We use the `mixed` flavor during training, meaning we combine data from the `hyp2ora` and `deg2ref` flavors. After the prompt builder, we have a total of 51,063 prompt-completion pairs in our training set.
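
For illustration, here is a hedged sketch of what a single prompt-completion pair looks like in this format (adapted from the example in the Usage section below; the real training pairs are built from Fisher transcripts, and the two flavors differ in how the completion's speaker labels are derived):

```python
# Illustrative example only, not an actual pair from the training set.
# The prompt is the diarization hypothesis followed by the " --> " suffix;
# the completion repeats the words with corrected speaker labels and ends
# with the [eod] token.
prompt = (
    "<speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. "
    "What about <speaker:1> you? I'm doing well, too. Thank you. --> "
)
completion = (
    "<speaker:1> Hello, how are you doing today? <speaker:2> I am doing well. "
    "What about you? <speaker:1> I'm doing well, too. Thank you. [eod]"
)
```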

The finetuning took more than 4 days on a Google Cloud VM instance that has one NVIDIA A100 GPU with 80GB memory.

The maximum prompt length for this model is 6,000 characters, including the " --> " suffix. The maximum sequence length is 4,096 tokens.
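
As a hedged illustration of these limits, a caller can check the character budget before prompting the model. The helper below is hypothetical and not part of the `diarizationlm` package:

```python
# Hypothetical helper: enforce the 6,000-character prompt limit,
# counting the " --> " suffix.
PROMPT_SUFFIX = " --> "
MAX_PROMPT_CHARS = 6000

def build_prompt(hypothesis: str) -> str:
    prompt = hypothesis + PROMPT_SUFFIX
    if len(prompt) > MAX_PROMPT_CHARS:
        # Longer conversations would need to be segmented into shorter
        # chunks before being sent to the model.
        raise ValueError(
            f"Prompt has {len(prompt)} characters; the limit is {MAX_PROMPT_CHARS}.")
    return prompt
```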

## Metrics

Performance on the Fisher testing set:

| System                         | WER (%) | WDER (%) | cpWER (%) |
| ------------------------------ | ------- | -------- | --------- |
| USM + turn-to-diarize baseline | 15.48   | 5.32     | 21.19     |
| + This model                   | -       | 4.40     | 19.76     |

## Usage

First, you need to install two packages:

```
pip install transformers diarizationlm
```

On a machine with a GPU and CUDA, you can use the model by running the following script:

```python
from transformers import LlamaForCausalLM, AutoTokenizer
from diarizationlm import utils

# Diarization hypothesis: recognized words with possibly erroneous speaker labels.
HYPOTHESIS = """<speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you."""

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained("google/DiarizationLM-8b-Fisher-v1", device_map="cuda")
model = LlamaForCausalLM.from_pretrained("google/DiarizationLM-8b-Fisher-v1", device_map="cuda")

print("Tokenizing input...")
# The " --> " suffix marks the end of the prompt.
inputs = tokenizer([HYPOTHESIS + " --> "], return_tensors="pt").to("cuda")

print("Generating completion...")
# max_new_tokens must be an integer; allow ~20% headroom over the prompt length.
outputs = model.generate(**inputs,
                         max_new_tokens=int(inputs.input_ids.shape[1] * 1.2),
                         use_cache=False)

print("Decoding completion...")
completion = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:],
                                    skip_special_tokens=True)[0]

print("Transferring completion to hypothesis text...")
# Keep the original hypothesis words, but take the speaker labels from the completion.
transferred_completion = utils.transfer_llm_completion(completion, HYPOTHESIS)

print("========================================")
print("Hypothesis:", HYPOTHESIS)
print("========================================")
print("Completion:", completion)
print("========================================")
print("Transferred completion:", transferred_completion)
print("========================================")
```

The output will look like the following:

```
Loading model...
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00, 3.32s/it]
generation_config.json: 100%|██████████| 172/172 [00:00<00:00, 992kB/s]
Tokenizing input...
Generating completion...
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Decoding completion...
Transferring completion to hypothesis text...
========================================
Hypothesis: <speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you.
========================================
Completion: <speaker:1> Hello, how are you doing today? <speaker:2> i am doing well. What about you? <speaker:1> i'm doing well, too. Thank you. [eod] [eod] <speaker:2
========================================
Transferred completion: <speaker:1> Hello, how are you doing today? <speaker:2> I am doing well. What about you? <speaker:1> I'm doing well, too. Thank you.
========================================
```

## Citation

Our paper can be cited as:

```
@article{wang2024diarizationlm,
  title={{DiarizationLM: Speaker Diarization Post-Processing with Large Language Models}},
  author={Quan Wang and Yiling Huang and Guanlong Zhao and Evan Clark and Wei Xia and Hank Liao},
  journal={arXiv preprint arXiv:2401.03506},
  year={2024}
}
```