yxccai commited on
Commit
764612d
·
verified ·
1 Parent(s): 5af2a96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -5,32 +5,35 @@ import torch
5
  import re
6
 
7
 
8
- from transformers import AutoModelForSequenceClassification, AutoTokenizer,LlamaConfig,AutoConfig
9
  import torch
10
 
11
- # ==== 配置修正 ====
12
- # 1. 定义并注册配置类
13
  class CustomLlamaConfig(LlamaConfig):
14
- model_type = "llama"
 
15
  def _rope_scaling_validation(self):
16
- pass
17
 
18
- AutoConfig.register("llama", CustomLlamaConfig)
 
19
 
20
- # 2. 加载修正后的配置
 
21
  config = CustomLlamaConfig.from_pretrained("unsloth/DeepSeek-R1-Distill-Llama-8B")
22
 
23
- # ==== 加载模型 ====
24
  model = AutoModelForSequenceClassification.from_pretrained(
25
  "unsloth/DeepSeek-R1-Distill-Llama-8B",
26
  config=config,
27
- trust_remote_code=True
 
28
  )
29
 
30
- # ==== 加载适配器 ===-
31
  model.load_adapter("yxccai/ds-ai-app")
32
 
33
- # ==== 加载分词器 ===-
34
  tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")
35
 
36
 
 
5
  import re
6
 
7
 
8
+ from transformers import AutoConfig, LlamaConfig, AutoModelForSequenceClassification, AutoTokenizer
9
  import torch
10
 
11
+ # ==== 自定义配置类 ====
 
12
  class CustomLlamaConfig(LlamaConfig):
13
+ model_type = "custom_llama" # 新名称
14
+
15
  def _rope_scaling_validation(self):
16
+ pass # 禁用RoPE验证
17
 
18
+ # ==== 注册配置 ====
19
+ AutoConfig.register("custom_llama", CustomLlamaConfig) # 使用新名称
20
 
21
+ # ==== 加载模型 ====
22
+ # 1. 加载配置
23
  config = CustomLlamaConfig.from_pretrained("unsloth/DeepSeek-R1-Distill-Llama-8B")
24
 
25
+ # 2. 加载模型(关键修改)
26
  model = AutoModelForSequenceClassification.from_pretrained(
27
  "unsloth/DeepSeek-R1-Distill-Llama-8B",
28
  config=config,
29
+ trust_remote_code=True,
30
+ _config_class=CustomLlamaConfig # 明确指定配置类
31
  )
32
 
33
+ # 3. 加载适配器
34
  model.load_adapter("yxccai/ds-ai-app")
35
 
36
+ # 4. 加载分词器
37
  tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")
38
 
39