---
license: mit
language:
- en
base_model:
- meta-llama/Meta-Llama-3-8B
pipeline_tag: text-generation
tags:
- transformers
---

## SPEED-synthesis-7b-senior

[Little Giants: Synthesizing High-Quality Embedding Data at Scale](https://arxiv.org/pdf/2410.18634.pdf). Haonan Chen, Liang Wang, Nan Yang, Yutao Zhu, Ziliang Zhao, Furu Wei, Zhicheng Dou, arXiv 2024

This is the senior data synthesis model of SPEED. 

## Usage

Below is an example of synthesizing classification data with this senior generator.

The prompts and miscellaneous scripts can be found in our [GitHub repository](https://github.com/haon-chen/SPEED).

### Transformers

```python
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from prompts_synthesis import get_create_classify_data_prompt
from utils import fix_common_json_errors_and_loads


LLAMA3_PROMPT = """
{prompt} [/INST]
""".strip("\n")

# One-sentence task definitions used to build the synthesis prompts
tasks = [
    'Identify the intended age group for educational technology products.',
    'Classify businesses based on their operational hours.'
]
language = 'English'

# Build one synthesis prompt per task using the prompt templates from the SPEED repo
prompts = [
    LLAMA3_PROMPT.format(prompt=get_create_classify_data_prompt(task=task, language=language)[1]['content'])
    for task in tasks
]

tokenizer = AutoTokenizer.from_pretrained('Haon-Chen/speed-synthesis-7b-senior')
model = AutoModelForCausalLM.from_pretrained('Haon-Chen/speed-synthesis-7b-senior')
model.to("cuda:0")
model.eval()
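# Left-pad and left-truncate so every prompt in the batch ends right where generation begins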
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

with torch.inference_mode():
    # Tokenize the input texts
    encodes = tokenizer(prompts, padding="longest", add_special_tokens=True, return_tensors="pt")
    input_ids = encodes.input_ids.to(model.device)
    attention_mask = encodes.attention_mask.to(model.device)

    # Set the generation parameters
    GEN_CONFIG = {"do_sample": True, "temperature": 1.0, "top_p": 1.0, "max_new_tokens": 800}
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        **GEN_CONFIG
    )
output_texts = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)
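# Keep only the generated continuation by stripping the prompt prefix from each decoded output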
batch_results = []
for i in range(len(output_texts)):
    batch_results.append(output_texts[i][len(prompts[i]):].strip(' '))

# Parse the generated JSON and keep only well-formed records
bad_cnt = 0
outputs = []
for i, result in enumerate(batch_results):
    try:
        output = fix_common_json_errors_and_loads(result)
        user_query = output.get("input_text", "")
        positive_document = output.get("label", "")
        hard_negative_document = output.get("misleading_label", "")
    except Exception:
        bad_cnt += 1
        continue
    out_data = {
        "query": user_query,
        "positives": [positive_document],
        "negatives": [hard_negative_document],
        "language": "English",
        "task_definition": tasks[i],
    }
    outputs.append(out_data)
print(bad_cnt)
print(outputs)
```
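
If you want to keep the synthesized records for later embedding-model training, a minimal follow-up sketch is to dump them to a JSONL file. This is not part of the original script; the `synthetic_classify.jsonl` file name is just an illustrative choice, and the snippet assumes the `outputs` list from the example above.

```python
import json

# Sketch only: persist the synthesized records, one JSON object per line.
# Assumes `outputs` from the snippet above; the file name is illustrative.
with open("synthetic_classify.jsonl", "w", encoding="utf-8") as f:
    for record in outputs:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```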

## Citation

If you find our paper or models helpful, please consider citing us as follows:

```bibtex
@article{chen2024little,
  title={Little Giants: Synthesizing High-Quality Embedding Data at Scale},
  author={Chen, Haonan and Wang, Liang and Yang, Nan and Zhu, Yutao and Zhao, Ziliang and Wei, Furu and Dou, Zhicheng},
  journal={arXiv preprint arXiv:2410.18634},
  year={2024}
}
```

## Limitations