Update scripts/eval_mteb.py
Browse files- scripts/eval_mteb.py +21 -7
    	
        scripts/eval_mteb.py
    CHANGED
    
    | @@ -119,7 +119,6 @@ CMTEB_TASK_LIST = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'Onl | |
| 119 | 
             
                               'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
         | 
| 120 | 
             
                               'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
         | 
| 121 |  | 
| 122 | 
            -
             | 
| 123 | 
             
            MTEB_PL = [
         | 
| 124 | 
             
                "CBD","PolEmo2.0-IN","PolEmo2.0-OUT","AllegroReviews","PAC","MassiveIntentClassification","MassiveScenarioClassification",
         | 
| 125 | 
             
                "SICK-E-PL","PPC","CDSC-E","PSC","8TagsClustering","SICK-R-PL","CDSC-R","STS22",
         | 
| @@ -406,9 +405,9 @@ class Wrapper: | |
| 406 | 
             
                    self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 407 | 
             
                    self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
         | 
| 408 | 
             
                    self.instruction = instruction
         | 
| 409 | 
            -
                    self.default_query = default_query
         | 
|  | |
| 410 | 
             
                    self.force_default = force_default
         | 
| 411 | 
            -
             
         | 
| 412 | 
             
                    if self.tokenizer.padding_side != 'right':
         | 
| 413 | 
             
                        logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
         | 
| 414 | 
             
                        self.tokenizer.padding_side = 'right'
         | 
| @@ -675,13 +674,15 @@ class Wrapper: | |
| 675 | 
             
            def main(args):
         | 
| 676 | 
             
                tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
         | 
| 677 | 
             
                encoder = Encoder(args.model, args.pooling)
         | 
|  | |
| 678 | 
             
                model = Wrapper(
         | 
| 679 | 
             
                    tokenizer, encoder,
         | 
| 680 | 
             
                    batch_size=args.batch_size,
         | 
| 681 | 
             
                    max_seq_len=args.max_seq_len,
         | 
| 682 | 
            -
                    normalize_embeddings=args.norm
         | 
|  | |
| 683 | 
             
                )
         | 
| 684 | 
            -
                
         | 
| 685 | 
             
                if args.task == 'mteb':
         | 
| 686 | 
             
                    task_names = MTEB_TASK_LIST
         | 
| 687 | 
             
                    lang = ['en']
         | 
| @@ -709,8 +710,21 @@ def main(args): | |
| 709 | 
             
                        eval_splits = task_cls.description['eval_splits']
         | 
| 710 | 
             
                    else:
         | 
| 711 | 
             
                        eval_splits = ["test"]
         | 
| 712 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 713 | 
             
                    evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
         | 
|  | |
|  | |
|  | |
|  | |
| 714 | 
             
                    print('\n')
         | 
| 715 |  | 
| 716 |  | 
| @@ -729,4 +743,4 @@ if __name__ == "__main__": | |
| 729 | 
             
                )
         | 
| 730 | 
             
                _PARSER.add_argument("--norm", action="store_true")
         | 
| 731 | 
             
                _ARGS = _PARSER.parse_args()
         | 
| 732 | 
            -
                main(_ARGS)
         | 
|  | |
| 119 | 
             
                               'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
         | 
| 120 | 
             
                               'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
         | 
| 121 |  | 
|  | |
| 122 | 
             
            MTEB_PL = [
         | 
| 123 | 
             
                "CBD","PolEmo2.0-IN","PolEmo2.0-OUT","AllegroReviews","PAC","MassiveIntentClassification","MassiveScenarioClassification",
         | 
| 124 | 
             
                "SICK-E-PL","PPC","CDSC-E","PSC","8TagsClustering","SICK-R-PL","CDSC-R","STS22",
         | 
|  | |
| 405 | 
             
                    self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 406 | 
             
                    self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
         | 
| 407 | 
             
                    self.instruction = instruction
         | 
| 408 | 
            +
                    self.default_query = default_query 
         | 
| 409 | 
            +
                    self.sep = sep
         | 
| 410 | 
             
                    self.force_default = force_default
         | 
|  | |
| 411 | 
             
                    if self.tokenizer.padding_side != 'right':
         | 
| 412 | 
             
                        logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
         | 
| 413 | 
             
                        self.tokenizer.padding_side = 'right'
         | 
|  | |
| 674 | 
             
            def main(args):
         | 
| 675 | 
             
                tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
         | 
| 676 | 
             
                encoder = Encoder(args.model, args.pooling)
         | 
| 677 | 
            +
                default_query = args.default_type == 'query'
         | 
| 678 | 
             
                model = Wrapper(
         | 
| 679 | 
             
                    tokenizer, encoder,
         | 
| 680 | 
             
                    batch_size=args.batch_size,
         | 
| 681 | 
             
                    max_seq_len=args.max_seq_len,
         | 
| 682 | 
            +
                    normalize_embeddings=args.norm,
         | 
| 683 | 
            +
                    default_query=default_query
         | 
| 684 | 
             
                )
         | 
| 685 | 
            +
                sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
         | 
| 686 | 
             
                if args.task == 'mteb':
         | 
| 687 | 
             
                    task_names = MTEB_TASK_LIST
         | 
| 688 | 
             
                    lang = ['en']
         | 
|  | |
| 710 | 
             
                        eval_splits = task_cls.description['eval_splits']
         | 
| 711 | 
             
                    else:
         | 
| 712 | 
             
                        eval_splits = ["test"]
         | 
| 713 | 
            +
                    sym = False
         | 
| 714 | 
            +
                    for name in sym_retrievals:
         | 
| 715 | 
            +
                        if task.startswith(name):
         | 
| 716 | 
            +
                            sym = True
         | 
| 717 | 
            +
                            break
         | 
| 718 | 
            +
                        else:
         | 
| 719 | 
            +
                            sym = False
         | 
| 720 | 
            +
                    if sym:
         | 
| 721 | 
            +
                        logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
         | 
| 722 | 
            +
                        model.force_default = True
         | 
| 723 | 
             
                    evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
         | 
| 724 | 
            +
             | 
| 725 | 
            +
                    if sym:
         | 
| 726 | 
            +
                        logger.info(f"Switch back.")
         | 
| 727 | 
            +
                        model.force_default = force_default_ori
         | 
| 728 | 
             
                    print('\n')
         | 
| 729 |  | 
| 730 |  | 
|  | |
| 743 | 
             
                )
         | 
| 744 | 
             
                _PARSER.add_argument("--norm", action="store_true")
         | 
| 745 | 
             
                _ARGS = _PARSER.parse_args()
         | 
| 746 | 
            +
                main(_ARGS)
         | 
