Upload 22 files
Browse files- .gitattributes +1 -0
 - client.py +25 -0
 - data/labels2id.pkl +3 -0
 - models.py +36 -0
 - preprocessors.py +48 -0
 - pretrained_models/ELECT +3 -0
 - pretrained_models/chinese-roberta-wwm-ext/added_tokens.json +1 -0
 - pretrained_models/chinese-roberta-wwm-ext/config.json +28 -0
 - pretrained_models/chinese-roberta-wwm-ext/pytorch_model.bin +3 -0
 - pretrained_models/chinese-roberta-wwm-ext/special_tokens_map.json +1 -0
 - pretrained_models/chinese-roberta-wwm-ext/tokenizer.json +0 -0
 - pretrained_models/chinese-roberta-wwm-ext/tokenizer_config.json +1 -0
 - pretrained_models/chinese-roberta-wwm-ext/vocab.txt +0 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/README.md +55 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/config.json +43 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/pytorch_model.bin +3 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/special_tokens_map.json +7 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer.json +0 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer_config.json +13 -0
 - pretrained_models/roberta_wwm_ext_hunyin_2epoch/vocab.txt +0 -0
 - server.py +156 -0
 - utils/__init__.py +2 -0
 - utils/arg_parser.py +24 -0
 
    	
        .gitattributes
    CHANGED
    
    | 
         @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text 
     | 
|
| 33 | 
         
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         
     | 
| 34 | 
         
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         
     | 
| 35 | 
         
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 
         | 
| 
         | 
|
| 33 | 
         
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         
     | 
| 34 | 
         
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         
     | 
| 35 | 
         
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 36 | 
         
            +
            pretrained_models/ELECT filter=lfs diff=lfs merge=lfs -text
         
     | 
    	
        client.py
    ADDED
    
    | 
         @@ -0,0 +1,25 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import json
         
     | 
| 2 | 
         
            +
            import requests
         
     | 
| 3 | 
         
            +
            import time
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            def json_send(data, url):
         
     | 
| 7 | 
         
            +
                headers = {"Content-type": "application/json",
         
     | 
| 8 | 
         
            +
                           "Accept": "text/plain", "charset": "UTF-8"}
         
     | 
| 9 | 
         
            +
                response = requests.post(url=url, headers=headers, data=json.dumps(data))
         
     | 
| 10 | 
         
            +
                return json.loads(response.text)
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 14 | 
         
            +
                url = 'http://127.0.0.1:9099/check_hunyin'
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                print("Start inference")
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
                while True:
         
     | 
| 19 | 
         
            +
                    input_text = input("Enter text:").strip()
         
     | 
| 20 | 
         
            +
                    if len(input_text) == 0:
         
     | 
| 21 | 
         
            +
                        continue
         
     | 
| 22 | 
         
            +
                    data = {"input": input_text}
         
     | 
| 23 | 
         
            +
                    result = json_send(data, url)
         
     | 
| 24 | 
         
            +
                    print(result['output'])
         
     | 
| 25 | 
         
            +
             
     | 
    	
        data/labels2id.pkl
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:179f76b8b014524ca915315f6eab916a20b582d89016e15b36bbdc055f1790cd
         
     | 
| 3 | 
         
            +
            size 54968
         
     | 
    	
        models.py
    ADDED
    
    | 
         @@ -0,0 +1,36 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            import torch.nn as nn
         
     | 
| 3 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 4 | 
         
            +
            from transformers import AutoModel,AutoTokenizer
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            class Elect(nn.Module):
         
     | 
| 7 | 
         
            +
                def __init__(self,args,device):
         
     | 
| 8 | 
         
            +
                    super(Elect, self).__init__()
         
     | 
| 9 | 
         
            +
                    self.device = device
         
     | 
| 10 | 
         
            +
                    self.plm = AutoModel.from_pretrained(args.ckpt_dir)
         
     | 
| 11 | 
         
            +
                    self.hidden_size = self.plm.config.hidden_size
         
     | 
| 12 | 
         
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
         
     | 
| 13 | 
         
            +
                    self.clf = nn.Linear(self.hidden_size, len(args.labels))
         
     | 
| 14 | 
         
            +
                    self.dropout = nn.Dropout(0.3)
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                    self.p2l = nn.Linear(self.hidden_size,256)
         
     | 
| 17 | 
         
            +
                    self.proj = nn.Linear(self.hidden_size*2,self.hidden_size)
         
     | 
| 18 | 
         
            +
                    self.l2a = nn.Linear(11,256)
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
                    self.la = nn.Parameter(torch.zeros(len(args.labels),self.hidden_size))
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
                def forward(self, batch):
         
     | 
| 23 | 
         
            +
                    ids = batch['ids'].to(self.device, dtype=torch.long)
         
     | 
| 24 | 
         
            +
                    mask = batch['mask'].to(self.device, dtype=torch.long)
         
     | 
| 25 | 
         
            +
                    token_type_ids = batch['token_type_ids'].to(self.device, dtype=torch.long)
         
     | 
| 26 | 
         
            +
                    hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]
         
     | 
| 27 | 
         
            +
                    pooler = hidden_state[:, 0]  # [batch_size, hidden_size]
         
     | 
| 28 | 
         
            +
                    pooler = self.dropout(pooler) # [batch_size, hidden_size]
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                    attn = torch.softmax(pooler@(self.la.transpose(0,1)),dim=-1)  # [batch_size, hidden_size]
         
     | 
| 31 | 
         
            +
                    art = [email protected]  # [batch_size, hidden_size]
         
     | 
| 32 | 
         
            +
                    oa = F.relu(self.proj(torch.cat([art, pooler],dim=-1)))  # [batch_size, hidden_size]
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                    output = self.clf(oa)  # [batch_size, len(labels)]
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                    return output
         
     | 
    	
        preprocessors.py
    ADDED
    
    | 
         @@ -0,0 +1,48 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import json
         
     | 
| 3 | 
         
            +
            import pickle as pkl
         
     | 
| 4 | 
         
            +
            import numpy as np
         
     | 
| 5 | 
         
            +
            from sklearn.preprocessing import MultiLabelBinarizer
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            class BasicPreprocessor(object):
         
     | 
| 8 | 
         
            +
                def __init__(self, data_generator, tokenizer, args):
         
     | 
| 9 | 
         
            +
                    self.data_generator = data_generator
         
     | 
| 10 | 
         
            +
                    self.tokenizer = tokenizer
         
     | 
| 11 | 
         
            +
                    self.args = args
         
     | 
| 12 | 
         
            +
                    file_path = os.path.join(args.data_dir, args.data_file)
         
     | 
| 13 | 
         
            +
                    if file_path.endswith("pkl"):
         
     | 
| 14 | 
         
            +
                        with open(file_path, "rb") as f:
         
     | 
| 15 | 
         
            +
                            self.raw_data = pkl.load(f)
         
     | 
| 16 | 
         
            +
                        print(self.raw_data[0])
         
     | 
| 17 | 
         
            +
                        exit()
         
     | 
| 18 | 
         
            +
                    elif file_path.endswith("json"):
         
     | 
| 19 | 
         
            +
                        self.raw_data = json.load(open(file_path, "r", encoding="utf-8"))
         
     | 
| 20 | 
         
            +
                    self.shuffle()
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
                    self.mlb=MultiLabelBinarizer()
         
     | 
| 23 | 
         
            +
                    self.mlb.fit([args.labels])
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
                def shuffle(self):
         
     | 
| 26 | 
         
            +
                    idx=np.arange(len(self.raw_data))
         
     | 
| 27 | 
         
            +
                    np.random.shuffle(idx)
         
     | 
| 28 | 
         
            +
                    self.raw_data=np.array(self.raw_data)[idx]
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                def process(self):
         
     | 
| 31 | 
         
            +
                    args = self.args
         
     | 
| 32 | 
         
            +
                    data_generator = self.data_generator
         
     | 
| 33 | 
         
            +
                    raw_data = self.raw_data
         
     | 
| 34 | 
         
            +
                    tokenizer = self.tokenizer
         
     | 
| 35 | 
         
            +
                    mlb = self.mlb
         
     | 
| 36 | 
         
            +
                    
         
     | 
| 37 | 
         
            +
                    if args.test_only:
         
     | 
| 38 | 
         
            +
                        train_data = data_generator(raw_data[:1], tokenizer, mlb, 'test', args)
         
     | 
| 39 | 
         
            +
                        test_data = data_generator(raw_data, tokenizer, mlb, 'test', args)
         
     | 
| 40 | 
         
            +
                        return train_data, test_data
         
     | 
| 41 | 
         
            +
                    #只使用90%作为训练集,10%作为测试集,不使用验证集
         
     | 
| 42 | 
         
            +
                    train_data = data_generator(raw_data[:int(len(raw_data)*0.9)], tokenizer, mlb, 'train', args)
         
     | 
| 43 | 
         
            +
                    test_data = data_generator(raw_data[int(len(raw_data)*0.9):], tokenizer, mlb, 'test', args)
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                    return train_data, test_data
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
             
     | 
    	
        pretrained_models/ELECT
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:acc44b4361b2a738336dce66dab399e54338f6100b900ddf1c654fd2d444b0ee
         
     | 
| 3 | 
         
            +
            size 415790649
         
     | 
    	
        pretrained_models/chinese-roberta-wwm-ext/added_tokens.json
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {}
         
     | 
    	
        pretrained_models/chinese-roberta-wwm-ext/config.json
    ADDED
    
    | 
         @@ -0,0 +1,28 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "architectures": [
         
     | 
| 3 | 
         
            +
                "BertForMaskedLM"
         
     | 
| 4 | 
         
            +
              ],
         
     | 
| 5 | 
         
            +
              "attention_probs_dropout_prob": 0.1,
         
     | 
| 6 | 
         
            +
              "bos_token_id": 0,
         
     | 
| 7 | 
         
            +
              "directionality": "bidi",
         
     | 
| 8 | 
         
            +
              "eos_token_id": 2,
         
     | 
| 9 | 
         
            +
              "hidden_act": "gelu",
         
     | 
| 10 | 
         
            +
              "hidden_dropout_prob": 0.1,
         
     | 
| 11 | 
         
            +
              "hidden_size": 768,
         
     | 
| 12 | 
         
            +
              "initializer_range": 0.02,
         
     | 
| 13 | 
         
            +
              "intermediate_size": 3072,
         
     | 
| 14 | 
         
            +
              "layer_norm_eps": 1e-12,
         
     | 
| 15 | 
         
            +
              "max_position_embeddings": 512,
         
     | 
| 16 | 
         
            +
              "model_type": "bert",
         
     | 
| 17 | 
         
            +
              "num_attention_heads": 12,
         
     | 
| 18 | 
         
            +
              "num_hidden_layers": 12,
         
     | 
| 19 | 
         
            +
              "output_past": true,
         
     | 
| 20 | 
         
            +
              "pad_token_id": 1,
         
     | 
| 21 | 
         
            +
              "pooler_fc_size": 768,
         
     | 
| 22 | 
         
            +
              "pooler_num_attention_heads": 12,
         
     | 
| 23 | 
         
            +
              "pooler_num_fc_layers": 3,
         
     | 
| 24 | 
         
            +
              "pooler_size_per_head": 128,
         
     | 
| 25 | 
         
            +
              "pooler_type": "first_token_transform",
         
     | 
| 26 | 
         
            +
              "type_vocab_size": 2,
         
     | 
| 27 | 
         
            +
              "vocab_size": 21128
         
     | 
| 28 | 
         
            +
            }
         
     | 
    	
        pretrained_models/chinese-roberta-wwm-ext/pytorch_model.bin
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:1ded5a5a1c7841dee6e47942f7b5bf2bcf6f73ff19197580f852f7f638f86b35
         
     | 
| 3 | 
         
            +
            size 411578458
         
     | 
    	
        pretrained_models/chinese-roberta-wwm-ext/special_tokens_map.json
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
         
     | 
    	
        pretrained_models/chinese-roberta-wwm-ext/tokenizer.json
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        pretrained_models/chinese-roberta-wwm-ext/tokenizer_config.json
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {"init_inputs": []}
         
     | 
    	
        pretrained_models/chinese-roberta-wwm-ext/vocab.txt
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/README.md
    ADDED
    
    | 
         @@ -0,0 +1,55 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ---
         
     | 
| 2 | 
         
            +
            tags:
         
     | 
| 3 | 
         
            +
            - generated_from_trainer
         
     | 
| 4 | 
         
            +
            metrics:
         
     | 
| 5 | 
         
            +
            - accuracy
         
     | 
| 6 | 
         
            +
            model-index:
         
     | 
| 7 | 
         
            +
            - name: roberta_wwm_ext_hunyin_2epoch
         
     | 
| 8 | 
         
            +
              results: []
         
     | 
| 9 | 
         
            +
            ---
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            <!-- This model card has been generated automatically according to the information the Trainer had access to. You
         
     | 
| 12 | 
         
            +
            should probably proofread and complete it, then remove this comment. -->
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            # roberta_wwm_ext_hunyin_2epoch
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            This model is a fine-tuned version of [/home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext](https://huggingface.co//home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext) on an unknown dataset.
         
     | 
| 17 | 
         
            +
            It achieves the following results on the evaluation set:
         
     | 
| 18 | 
         
            +
            - Loss: 0.0510
         
     | 
| 19 | 
         
            +
            - Accuracy: 0.9881
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            ## Model description
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            More information needed
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            ## Intended uses & limitations
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            More information needed
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            ## Training and evaluation data
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            More information needed
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            ## Training procedure
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            ### Training hyperparameters
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            The following hyperparameters were used during training:
         
     | 
| 38 | 
         
            +
            - learning_rate: 2e-05
         
     | 
| 39 | 
         
            +
            - train_batch_size: 32
         
     | 
| 40 | 
         
            +
            - eval_batch_size: 8
         
     | 
| 41 | 
         
            +
            - seed: 42
         
     | 
| 42 | 
         
            +
            - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
         
     | 
| 43 | 
         
            +
            - lr_scheduler_type: linear
         
     | 
| 44 | 
         
            +
            - num_epochs: 2.0
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            ### Training results
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            ### Framework versions
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
            - Transformers 4.28.0.dev0
         
     | 
| 53 | 
         
            +
            - Pytorch 1.13.1+cu117
         
     | 
| 54 | 
         
            +
            - Datasets 2.10.1
         
     | 
| 55 | 
         
            +
            - Tokenizers 0.13.2
         
     | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/config.json
    ADDED
    
    | 
         @@ -0,0 +1,43 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "_name_or_path": "/home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext",
         
     | 
| 3 | 
         
            +
              "architectures": [
         
     | 
| 4 | 
         
            +
                "BertForSequenceClassification"
         
     | 
| 5 | 
         
            +
              ],
         
     | 
| 6 | 
         
            +
              "attention_probs_dropout_prob": 0.1,
         
     | 
| 7 | 
         
            +
              "bos_token_id": 0,
         
     | 
| 8 | 
         
            +
              "classifier_dropout": null,
         
     | 
| 9 | 
         
            +
              "directionality": "bidi",
         
     | 
| 10 | 
         
            +
              "eos_token_id": 2,
         
     | 
| 11 | 
         
            +
              "hidden_act": "gelu",
         
     | 
| 12 | 
         
            +
              "hidden_dropout_prob": 0.1,
         
     | 
| 13 | 
         
            +
              "hidden_size": 768,
         
     | 
| 14 | 
         
            +
              "id2label": {
         
     | 
| 15 | 
         
            +
                "0": false,
         
     | 
| 16 | 
         
            +
                "1": true
         
     | 
| 17 | 
         
            +
              },
         
     | 
| 18 | 
         
            +
              "initializer_range": 0.02,
         
     | 
| 19 | 
         
            +
              "intermediate_size": 3072,
         
     | 
| 20 | 
         
            +
              "label2id": {
         
     | 
| 21 | 
         
            +
                "false": 0,
         
     | 
| 22 | 
         
            +
                "true": 1
         
     | 
| 23 | 
         
            +
              },
         
     | 
| 24 | 
         
            +
              "layer_norm_eps": 1e-12,
         
     | 
| 25 | 
         
            +
              "max_position_embeddings": 512,
         
     | 
| 26 | 
         
            +
              "model_type": "bert",
         
     | 
| 27 | 
         
            +
              "num_attention_heads": 12,
         
     | 
| 28 | 
         
            +
              "num_hidden_layers": 12,
         
     | 
| 29 | 
         
            +
              "output_past": true,
         
     | 
| 30 | 
         
            +
              "pad_token_id": 1,
         
     | 
| 31 | 
         
            +
              "pooler_fc_size": 768,
         
     | 
| 32 | 
         
            +
              "pooler_num_attention_heads": 12,
         
     | 
| 33 | 
         
            +
              "pooler_num_fc_layers": 3,
         
     | 
| 34 | 
         
            +
              "pooler_size_per_head": 128,
         
     | 
| 35 | 
         
            +
              "pooler_type": "first_token_transform",
         
     | 
| 36 | 
         
            +
              "position_embedding_type": "absolute",
         
     | 
| 37 | 
         
            +
              "problem_type": "single_label_classification",
         
     | 
| 38 | 
         
            +
              "torch_dtype": "float32",
         
     | 
| 39 | 
         
            +
              "transformers_version": "4.28.0.dev0",
         
     | 
| 40 | 
         
            +
              "type_vocab_size": 2,
         
     | 
| 41 | 
         
            +
              "use_cache": true,
         
     | 
| 42 | 
         
            +
              "vocab_size": 21128
         
     | 
| 43 | 
         
            +
            }
         
     | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/pytorch_model.bin
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:cd02e6af0b827ddf0cf89fe32850c1da32c1ce8f83e0157e2f2fb11a93b1a4f9
         
     | 
| 3 | 
         
            +
            size 409149557
         
     | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/special_tokens_map.json
    ADDED
    
    | 
         @@ -0,0 +1,7 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "cls_token": "[CLS]",
         
     | 
| 3 | 
         
            +
              "mask_token": "[MASK]",
         
     | 
| 4 | 
         
            +
              "pad_token": "[PAD]",
         
     | 
| 5 | 
         
            +
              "sep_token": "[SEP]",
         
     | 
| 6 | 
         
            +
              "unk_token": "[UNK]"
         
     | 
| 7 | 
         
            +
            }
         
     | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer.json
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/tokenizer_config.json
    ADDED
    
    | 
         @@ -0,0 +1,13 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "cls_token": "[CLS]",
         
     | 
| 3 | 
         
            +
              "do_lower_case": true,
         
     | 
| 4 | 
         
            +
              "mask_token": "[MASK]",
         
     | 
| 5 | 
         
            +
              "model_max_length": 1000000000000000019884624838656,
         
     | 
| 6 | 
         
            +
              "pad_token": "[PAD]",
         
     | 
| 7 | 
         
            +
              "sep_token": "[SEP]",
         
     | 
| 8 | 
         
            +
              "special_tokens_map_file": "/home/zhangc/law_related/law_telecom/PLMs/chinese-roberta-wwm-ext/special_tokens_map.json",
         
     | 
| 9 | 
         
            +
              "strip_accents": null,
         
     | 
| 10 | 
         
            +
              "tokenize_chinese_chars": true,
         
     | 
| 11 | 
         
            +
              "tokenizer_class": "BertTokenizer",
         
     | 
| 12 | 
         
            +
              "unk_token": "[UNK]"
         
     | 
| 13 | 
         
            +
            }
         
     | 
    	
        pretrained_models/roberta_wwm_ext_hunyin_2epoch/vocab.txt
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        server.py
    ADDED
    
    | 
         @@ -0,0 +1,156 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import json
         
     | 
| 2 | 
         
            +
            import subprocess
         
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
            import codecs
         
     | 
| 5 | 
         
            +
            import logging
         
     | 
| 6 | 
         
            +
            import os
         
     | 
| 7 | 
         
            +
            import math
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import json
         
     | 
| 10 | 
         
            +
            import random
         
     | 
| 11 | 
         
            +
            from tqdm import tqdm
         
     | 
| 12 | 
         
            +
            from transformers import pipeline
         
     | 
| 13 | 
         
            +
            from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            from flask import Flask, request, jsonify
         
     | 
| 17 | 
         
            +
            import json
         
     | 
| 18 | 
         
            +
            import random
         
     | 
| 19 | 
         
            +
            from tqdm import tqdm
         
     | 
| 20 | 
         
            +
            import os
         
     | 
| 21 | 
         
            +
            import pickle as pkl
         
     | 
| 22 | 
         
            +
            from argparse import Namespace
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            from models import Elect
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            import torch
         
     | 
| 27 | 
         
            +
            from transformers import AutoModel,AutoTokenizer
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            from sklearn.preprocessing import MultiLabelBinarizer
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            logger = logging.getLogger(__name__)
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            app = Flask(__name__)
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            hunyin_classifier = None
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            fatiao_args = Namespace()
         
     | 
| 39 | 
         
            +
            fatiao_tokenizer = None
         
     | 
| 40 | 
         
            +
            fatiao_model = None
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
            @app.route('/check_hunyin', methods=['GET', 'POST'])
         
     | 
| 44 | 
         
            +
            def check_hunyin():
         
     | 
| 45 | 
         
            +
                input_text = request.json['input'].strip()
         
     | 
| 46 | 
         
            +
                force_return = request.json['force_return'] if 'force_return' in request.json else False
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                print("input_text:", input_text)
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                if len(input_text) == 0:
         
     | 
| 51 | 
         
            +
                    json_result = {
         
     | 
| 52 | 
         
            +
                        "output": []
         
     | 
| 53 | 
         
            +
                    }
         
     | 
| 54 | 
         
            +
                    return jsonify(json_result)
         
     | 
| 55 | 
         
            +
                
         
     | 
| 56 | 
         
            +
                if not force_return:
         
     | 
| 57 | 
         
            +
                    classifier_result = hunyin_classifier(input_text[:500])
         
     | 
| 58 | 
         
            +
                    print(classifier_result)
         
     | 
| 59 | 
         
            +
                    classifier_result = classifier_result[0]['label']
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                    # 加一条规则,如果输入文本中包含“婚”字,那么直接判定为婚姻相关
         
     | 
| 62 | 
         
            +
                    if '婚' in input_text:
         
     | 
| 63 | 
         
            +
                        classifier_result = True
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                    # 如果不是婚姻相关的,直接返回空
         
     | 
| 66 | 
         
            +
                    if classifier_result == False:
         
     | 
| 67 | 
         
            +
                        json_result = {
         
     | 
| 68 | 
         
            +
                            "output": []
         
     | 
| 69 | 
         
            +
                        }
         
     | 
| 70 | 
         
            +
                        return jsonify(json_result)
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
         
     | 
| 73 | 
         
            +
                batch = {
         
     | 
| 74 | 
         
            +
                    'ids': inputs['input_ids'],
         
     | 
| 75 | 
         
            +
                    'mask': inputs['attention_mask'],
         
     | 
| 76 | 
         
            +
                    'token_type_ids':inputs["token_type_ids"]
         
     | 
| 77 | 
         
            +
                }
         
     | 
| 78 | 
         
            +
                model_output = fatiao_model(batch)
         
     | 
| 79 | 
         
            +
                pred = torch.sigmoid(model_output).cpu().detach().numpy()[0]
         
     | 
| 80 | 
         
            +
                pred_laws = []
         
     | 
| 81 | 
         
            +
                for law_id, score in sorted(enumerate(pred), key=lambda x: x[1], reverse=True):
         
     | 
| 82 | 
         
            +
                    pred_laws.append({
         
     | 
| 83 | 
         
            +
                        'id': law_id,
         
     | 
| 84 | 
         
            +
                        'score': float(score),
         
     | 
| 85 | 
         
            +
                        'text': fatiao_args.mlb.classes_[law_id]
         
     | 
| 86 | 
         
            +
                    })
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                json_result = {
         
     | 
| 89 | 
         
            +
                        "output": pred_laws[:3]
         
     | 
| 90 | 
         
            +
                    }
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                print("json_result:", json_result)
         
     | 
| 93 | 
         
            +
                return jsonify(json_result)
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
                        
         
     | 
| 96 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                # 加载咨询分类模型,用于判断是否与婚姻有关
         
     | 
| 100 | 
         
            +
                hunyin_classifier_path = "./pretrained_models/roberta_wwm_ext_hunyin_2epoch"
         
     | 
| 101 | 
         
            +
                hunyin_config = AutoConfig.from_pretrained(
         
     | 
| 102 | 
         
            +
                    hunyin_classifier_path,
         
     | 
| 103 | 
         
            +
                    num_labels=2,
         
     | 
| 104 | 
         
            +
                )
         
     | 
| 105 | 
         
            +
                hunyin_tokenizer = AutoTokenizer.from_pretrained(
         
     | 
| 106 | 
         
            +
                    hunyin_classifier_path
         
     | 
| 107 | 
         
            +
                )
         
     | 
| 108 | 
         
            +
                hunyin_model = AutoModelForSequenceClassification.from_pretrained(
         
     | 
| 109 | 
         
            +
                    hunyin_classifier_path,
         
     | 
| 110 | 
         
            +
                    config=hunyin_config,
         
     | 
| 111 | 
         
            +
                )
         
     | 
| 112 | 
         
            +
                hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                # 加载法条检索模型
         
     | 
| 115 | 
         
            +
                
         
     | 
| 116 | 
         
            +
                fatiao_args.ckpt_dir = "./pretrained_models/chinese-roberta-wwm-ext"
         
     | 
| 117 | 
         
            +
                fatiao_args.device = "cuda:0"
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                with open(os.path.join("data/labels2id.pkl"), "rb") as f:
         
     | 
| 120 | 
         
            +
                    laws2id = pkl.load(f) 
         
     | 
| 121 | 
         
            +
                    fatiao_args.labels = list(laws2id.keys())
         
     | 
| 122 | 
         
            +
                # get id2laws
         
     | 
| 123 | 
         
            +
                id2laws = {}
         
     | 
| 124 | 
         
            +
                for k, v in laws2id.items():
         
     | 
| 125 | 
         
            +
                    id2laws[v] = k
         
     | 
| 126 | 
         
            +
                # fatiao_args.id2laws = id2laws
         
     | 
| 127 | 
         
            +
                print("法条个数:", len(id2laws))
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)
         
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
                fatiao_args.tokenizer = fatiao_tokenizer
         
     | 
| 132 | 
         
            +
                fatiao_model = Elect(fatiao_args, "cuda:0").to("cuda:0")
         
     | 
| 133 | 
         
            +
                fatiao_model.eval()
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                mlb = MultiLabelBinarizer() # mlb.classes_: idx to law article
         
     | 
| 136 | 
         
            +
                mlb.fit([fatiao_args.labels])
         
     | 
| 137 | 
         
            +
                fatiao_args.mlb = mlb
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                with torch.no_grad():
         
     | 
| 140 | 
         
            +
                    for idx, l in enumerate(fatiao_args.labels):
         
     | 
| 141 | 
         
            +
                        # remove 《民法典》第xxxx条:
         
     | 
| 142 | 
         
            +
                        text = ':'.join(l.split(':')[1:]).lower()
         
     | 
| 143 | 
         
            +
                        la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256,
         
     | 
| 144 | 
         
            +
                               return_tensors="pt")
         
     | 
| 145 | 
         
            +
                        ids = la_in['input_ids'].to(fatiao_args.device)
         
     | 
| 146 | 
         
            +
                        mask = la_in['attention_mask'].to(fatiao_args.device)
         
     | 
| 147 | 
         
            +
                        fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:,0]).squeeze(0)   
         
     | 
| 148 | 
         
            +
                
         
     | 
| 149 | 
         
            +
                
         
     | 
| 150 | 
         
            +
                fatiao_model.load_state_dict(torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device)))
         
     | 
| 151 | 
         
            +
                fatiao_model.to(fatiao_args.device)
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
                logger.info("model loaded")
         
     | 
| 156 | 
         
            +
                app.run(host="0.0.0.0", port=9098, debug=False)
         
     | 
    	
        utils/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,2 @@ 
     | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from .arg_parser import get_parser
         
     | 
| 2 | 
         
            +
            # from .eval_metric import EvalMetric
         
     | 
    	
        utils/arg_parser.py
    ADDED
    
    | 
         @@ -0,0 +1,24 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import argparse
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            def get_parser():
         
     | 
| 5 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 6 | 
         
            +
                parser.add_argument("--data_dir", default="telecom_data/", type=str,
         
     | 
| 7 | 
         
            +
                                    help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", )
         
     | 
| 8 | 
         
            +
                parser.add_argument("--data_file", default="data_filter.pkl", type=str)
         
     | 
| 9 | 
         
            +
                parser.add_argument("--ckpt_dir", default="./PLMs/chinese-roberta-wwm-ext", type=str,
         
     | 
| 10 | 
         
            +
                                    help="The checkpoints dir. Should contain the pretrained model.", )
         
     | 
| 11 | 
         
            +
                parser.add_argument("--preprocessor", default="BasePreprocessor", type=str,
         
     | 
| 12 | 
         
            +
                                    help="Name of preprocessor.", )
         
     | 
| 13 | 
         
            +
                parser.add_argument("--device", default="cuda:0", type=str)
         
     | 
| 14 | 
         
            +
                parser.add_argument("--batch_size", default=128, type=int)
         
     | 
| 15 | 
         
            +
                parser.add_argument("--max_epoch", default=100, type=int)
         
     | 
| 16 | 
         
            +
                parser.add_argument("--top_k", default=5, type=int)
         
     | 
| 17 | 
         
            +
                parser.add_argument("--output_name", default='ELECT_test_output.json', type=str)
         
     | 
| 18 | 
         
            +
                return parser
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            '''
         
     | 
| 21 | 
         
            +
            python main_elect_inference.py \
         
     | 
| 22 | 
         
            +
            --data_file jicheng_questions.json \
         
     | 
| 23 | 
         
            +
            --output_name jicheng_questions_output.json 
         
     | 
| 24 | 
         
            +
            '''
         
     |