import argparse
import json
import multiprocessing
import os
import random
import typing

import spacy
import zstandard as zstd
from langdetect import detect
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, help='Path to the input directory')
args = parser.parse_args()
input_dir = args.input_dir

# All outputs go under result/; create it up front so the writes below succeed.
os.makedirs('result', exist_ok=True)


def is_english(text):
    # langdetect raises LangDetectException on empty or undetectable input.
    try:
        return detect(text) == 'en'
    except Exception:
        return False


def process_text(texts, model, out_f, lock):
    for text in texts:
        doc = model(text)
        # Count entity mentions by surface form (distinct Span objects for the
        # same string would otherwise each count once), and keep one
        # representative span per surface form for its label and KB id.
        freq_cnt = {}
        spans = {}
        for e in doc.ents:
            freq_cnt[e.text] = freq_cnt.get(e.text, 0) + 1
            spans[e.text] = e
        if not freq_cnt:
            continue
        # Take the most frequent entity in the document as its main entity.
        most_freq = spans[max(freq_cnt, key=freq_cnt.get)]
        data = {
            'text': text,
            'main_entity': most_freq.text,
            'label': most_freq.label_,
            'id': most_freq.kb_id_,
        }
        json_data = json.dumps(data)
        with lock:
            out_f.write(json_data + '\n')
            out_f.flush()


def run_ner_linking(texts: typing.List[str], ner_model_path: str):
    nlp = spacy.load(ner_model_path)
    # The open file handle and the loaded model are shared with the workers
    # via fork, so this assumes a fork-based start method (i.e. Linux).
    out_f = open('result/temp_store_data.json', 'w', encoding='utf-8')
    lock = multiprocessing.Lock()
    processes = []
    for i in tqdm(range(0, len(texts), 1000)):
        p = multiprocessing.Process(target=process_text,
                                    args=(texts[i:i + 1000], nlp, out_f, lock))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
    out_f.close()


# Step 1: scan every .zst chunk and keep English RedPajama Wikipedia records.
wikipedia_out_path = 'result/wikipedia.json'
subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
wikipedia_data = []
for sub_dir in subdirectories:
    chunk_dir = sub_dir + '/'
    zst_files = [f for f in os.listdir(chunk_dir) if f.endswith('.zst')]
    for file in tqdm(zst_files):
        with open(chunk_dir + file, 'rb') as compressed_file:
            decompressor = zstd.ZstdDecompressor()
            with decompressor.stream_reader(compressed_file) as reader:
                decompressed_data = reader.read()
        # Each archive holds newline-delimited JSON records.
        for line in decompressed_data.splitlines():
            data = json.loads(line)
            if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
                if is_english(data['text']):
                    wikipedia_data.append(data)

with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
    for data in wikipedia_data:
        f.write(json.dumps(data) + '\n')

# Step 2: run NER + entity linking over the filtered texts.
wikipedia_data = []
ner_model_path = 'kc-ner-model'
with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        data = json.loads(line)
        wikipedia_data.append(data['text'])
run_ner_linking(wikipedia_data, ner_model_path)

# Step 3: group main entities by NER label and build the entity-swapped set.
entity_info_path = 'result/entity_info.json'
with open(entity_info_path, 'r', encoding='utf-8') as f:
    entity_info = json.load(f)

category = {}
all_data = []
with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        all_data.append(data)
        if data['label'] not in category:
            category[data['label']] = []
        category[data['label']].append(data['main_entity'])

with open('result/processed_data.json', 'w', encoding='utf-8') as f:
    for data in tqdm(all_data):
        text = data['text']
        main_entity = [data['main_entity']]
        if data['id'] in entity_info:
            main_entity.extend(entity_info[data['id']]['aliases'])
        # Candidate replacements are same-label entities that are not an alias
        # of the main entity; skip the record if no valid candidate exists
        # (this also avoids the infinite rejection-sampling loop that could
        # occur when every same-label entity is an alias).
        candidates = [e for e in category[data['label']] if e not in main_entity]
        if not candidates:
            continue
        replaced_entity = random.choice(candidates)
        # Swap every alias of the main entity for the sampled replacement.
        for entity in main_entity:
            text = text.replace(entity, replaced_entity)
        data = {
            'text': text,
            'original_main_entity': main_entity,
            'replaced_entity': replaced_entity,
        }
        f.write(json.dumps(data) + '\n')
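
# Example invocation (a sketch; the script name and data layout are
# assumptions, not part of the original): point --input_dir at a directory
# whose subdirectories contain the RedPajama .zst chunks, and make sure the
# 'kc-ner-model' spaCy model directory and 'result/entity_info.json'
# (entity id -> {'aliases': [...]}) exist before running:
#
#   python build_entity_swap_data.py --input_dir /path/to/redpajama
#
# Outputs land in result/: wikipedia.json (filtered records),
# temp_store_data.json (per-text main entity), processed_data.json
# (entity-swapped texts).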