Steven10429 committed on
Commit
f52e9f0
·
verified ·
1 Parent(s): 574d76d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -17,6 +17,7 @@ import time
17
  log_queue = queue.Queue()
18
  current_logs = []
19
 
 
20
  def log(msg):
21
  """统一的日志处理函数"""
22
  print(msg)
@@ -74,14 +75,11 @@ def check_system_resources(model_name):
74
  else:
75
  raise MemoryError(f"❌ 系统内存不足 (需要 {required_memory_gb:.1f}GB, 可用 {available_memory_gb:.1f}GB)")
76
 
77
- def setup_environment(model_name, hf_token):
78
- """设置环境并返回设备信息""" # try to get from env
79
- if not hf_token:
80
- raise ValueError("请在设置HF_TOKEN")
81
- login(hf_token)
82
 
83
- # 检查系统资源并决定使用什么设备
84
- device, available_memory = check_system_resources(model_name)
 
85
  return device
86
 
87
  def create_hf_repo(repo_name, hf_token, private=True):
@@ -133,6 +131,7 @@ def download_and_merge_model(base_model_name, lora_model_name, output_dir, devic
133
  model.save_pretrained(output_dir)
134
  tokenizer.save_pretrained(output_dir)
135
 
 
136
  return output_dir
137
 
138
  except Exception as e:
@@ -198,14 +197,15 @@ def quantize_and_push_model(model_path, repo_id, bits=8):
198
  def process_model(base_model, lora_model, repo_name, hf_token, progress=gr.Progress()):
199
  """处理模型的主函数,用于Gradio界面"""
200
  try:
 
201
  # 清空之前的日志
202
  current_logs.clear()
203
 
204
  # 设置环境和检查资源
205
- device = setup_environment(base_model, hf_token)
206
 
207
  # 创建HuggingFace仓库
208
- repo_url = create_hf_repo(repo_name, hf_token)
209
 
210
  # 设置输出目录
211
  output_dir = os.path.join(".", "output", repo_name)
@@ -214,16 +214,25 @@ def process_model(base_model, lora_model, repo_name, hf_token, progress=gr.Progr
214
  # 下载并合并模型
215
  model_path = download_and_merge_model(base_model, lora_model, output_dir, device)
216
 
217
- progress(0.4, desc="开始8位量化...")
218
- # 量化并上传模型
219
- quantize_and_push_model(model_path, repo_name, bits=8)
 
 
 
 
 
 
 
 
 
220
 
221
- progress(0.7, desc="开始4位量化...")
222
- quantize_and_push_model(model_path, repo_name, bits=4)
223
 
224
- final_message = f"全部完成!模型已上传至: https://huggingface.co/{repo_name}"
225
- log(final_message)
226
- progress(1.0, desc="处理完成")
227
 
228
  return "\n".join(current_logs)
229
  except Exception as e:
 
17
  log_queue = queue.Queue()
18
  current_logs = []
19
 
20
+
21
  def log(msg):
22
  """统一的日志处理函数"""
23
  print(msg)
 
75
  else:
76
  raise MemoryError(f"❌ 系统内存不足 (需要 {required_memory_gb:.1f}GB, 可用 {available_memory_gb:.1f}GB)")
77
 
78
+ def setup_environment(api, model_name):
 
 
 
 
79
 
80
+ # # 检查系统资源并决定使用什么设备
81
+ # device, available_memory = check_system_resources(model_name)
82
+ device = "cpu"
83
  return device
84
 
85
  def create_hf_repo(repo_name, hf_token, private=True):
 
131
  model.save_pretrained(output_dir)
132
  tokenizer.save_pretrained(output_dir)
133
 
134
+
135
  return output_dir
136
 
137
  except Exception as e:
 
197
  def process_model(base_model, lora_model, repo_name, hf_token, progress=gr.Progress()):
198
  """处理模型的主函数,用于Gradio界面"""
199
  try:
200
+ api = HfApi(token=hf_token)
201
  # 清空之前的日志
202
  current_logs.clear()
203
 
204
  # 设置环境和检查资源
205
+ device = setup_environment(api, base_model)
206
 
207
  # 创建HuggingFace仓库
208
+ repo_url = create_hf_repo(api, repo_name)
209
 
210
  # 设置输出目录
211
  output_dir = os.path.join(".", "output", repo_name)
 
214
  # 下载并合并模型
215
  model_path = download_and_merge_model(base_model, lora_model, output_dir, device)
216
 
217
+ # 推送到HuggingFace
218
+ log(f"正在将模型推送到HuggingFace...")
219
+ api = HfApi()
220
+ api.upload_folder(
221
+ folder_path=model_path,
222
+ repo_id=repo_name,
223
+ repo_type="model"
224
+ )
225
+
226
+ # progress(0.4, desc="开始8位量化...")
227
+ # # 量化并上传模型
228
+ # quantize_and_push_model(model_path, repo_name, bits=8)
229
 
230
+ # progress(0.7, desc="开始4位量化...")
231
+ # quantize_and_push_model(model_path, repo_name, bits=4)
232
 
233
+ # final_message = f"全部完成!模型已上传至: https://huggingface.co/{repo_name}"
234
+ # log(final_message)
235
+ # progress(1.0, desc="处理完成")
236
 
237
  return "\n".join(current_logs)
238
  except Exception as e: