Alimubariz124 commited on
Commit
2e1a242
·
verified ·
1 Parent(s): 905af92

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +7 -4
model_loader.py CHANGED
@@ -1,12 +1,15 @@
1
  from sentence_transformers import SentenceTransformer
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
  def load_embedding_model():
5
  return SentenceTransformer("all-MiniLM-L6-v2")
6
 
 
 
7
  def load_llm():
8
- model_name = "google/flan-t5-base" # You can change this
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModelForCausalLM.from_pretrained(model_name)
11
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
12
  return pipe
 
 
1
  from sentence_transformers import SentenceTransformer
2
+ #from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
  def load_embedding_model():
5
  return SentenceTransformer("all-MiniLM-L6-v2")
6
 
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
+
9
  def load_llm():
10
+ model_name = "google/flan-t5-base"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
+ pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
14
  return pipe
15
+