AhmedSSoliman committed · Commit 2cdb77d · 1 Parent(s): 0a0064d

Update README.md

Files changed (1): README.md (+29 -19)

README.md CHANGED
@@ -17,22 +17,32 @@ tags:
  SQL generation model fine-tuned from Mistral-7B-Instruct-v0.1.
  Inspired by https://huggingface.co/kanxxyc/Mistral-7B-SQLTuned
 
- ## Training procedure
-
-
- The following `bitsandbytes` quantization config was used during training:
- - quant_method: bitsandbytes
- - load_in_8bit: False
- - load_in_4bit: True
- - llm_int8_threshold: 6.0
- - llm_int8_skip_modules: None
- - llm_int8_enable_fp32_cpu_offload: False
- - llm_int8_has_fp16_weight: False
- - bnb_4bit_quant_type: nf4
- - bnb_4bit_use_double_quant: False
- - bnb_4bit_compute_dtype: float16
-
- ### Framework versions
-
-
- - PEFT 0.6.0.dev0
+ ### Code
+ ```py
+ import torch
+ from peft import PeftModel, PeftConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ peft_model_id = "AhmedSSoliman/Mistral-Instruct-SQL-Generation"
+ config = PeftConfig.from_pretrained(peft_model_id)
+
+ # Load the 4-bit quantized base model and its tokenizer
+ model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, trust_remote_code=True, return_dict=True, load_in_4bit=True, device_map='auto')
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+
+ # Load the LoRA adapter on top of the base model
+ model = PeftModel.from_pretrained(model, peft_model_id)
+
+ def predict_SQL(table, question):
+     pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
+     prompt = f"[INST] Write SQL query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: {table} Question: {question} [/INST] Here is the SQL query to answer to the question: {question}: ``` "
+     #prompt = f"### Schema: {table} ### Question: {question} # "
+     ans = pipe(prompt, max_new_tokens=200)
+     # The prompt already contains two ``` fences, so index 2 is the model's generated SQL
+     generatedSql = ans[0]['generated_text'].split('```')[2]
+     return generatedSql
+
+ table = "CREATE TABLE Employee (name VARCHAR, salary INTEGER);"
+ question = 'Show names for all employees with salary more than the average.'
+
+ generatedSql = predict_SQL(table, question)
+ print(generatedSql)
+ ```
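
For reference, the quantization settings removed in this commit map onto a `BitsAndBytesConfig` from `transformers`. The sketch below is a reconstruction from the listed field values, not code from the commit itself; passing an explicit `quantization_config` is the equivalent of the `load_in_4bit=True` shortcut used in the example above.

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Reconstructed from the removed README fields (an assumption, not part of the commit):
# 4-bit NF4 weights, no double quantization, float16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.float16,
)

# The README names Mistral-7B-Instruct-v0.1 as the base model.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
)
```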