No support for pipeline or TextStreamer
#34 opened by mohamedlotfy50
I tried to run the model with pipeline and found that it does not support the image-to-text task. I also tried to use TextStreamer, but it raised an exception.
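For context, the failing call was presumably something like the following (a minimal sketch; the exact task string and arguments the poster used are my assumption):

from transformers import pipeline

# Hypothetical reproduction of the failure: Phi-3-vision is loaded via custom
# remote code, and the "image-to-text" pipeline task does not map to it,
# so this raises instead of returning a working pipeline.
pipe = pipeline(
    "image-to-text",
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
)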
I'm using TextStreamer just fine. I'll share my code with you:
from transformers import TextStreamer
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
from io import BytesIO
import argparse

DEFAULT_IMAGE_TOKEN = "<|image_1|>"


def load_image(image_file):
    # Load from a URL or a local path and normalize to RGB.
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def main(args):
    model_id = args.model_base
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map=args.device, trust_remote_code=True, torch_dtype=torch.float16
    )
    if args.model_path:
        # Optionally load a PEFT adapter on top of the base model.
        peft_model_id = args.model_path
        model.load_adapter(peft_model_id)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    messages = [
        {"role": "system", "content": "You are a helpful AI assistant. Provide useful information about the given image."},
    ]
    image = load_image(args.image_file)

    generation_args = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "do_sample": args.temperature > 0,
        "repetition_penalty": args.repetition_penalty,
    }

    while True:
        try:
            inp = input("User: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print("Assistant: ", end="")

        if image is not None and len(messages) < 2:
            # Only put the image token in the first user turn.
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        messages.append({"role": "user", "content": inp})
        prompt = processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(prompt, image, return_tensors="pt").to(args.device)

        # skip_prompt/skip_special_tokens keep the echoed prompt and tokens
        # like <|end|> out of the streamed output.
        streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            generate_ids = model.generate(
                **inputs,
                eos_token_id=processor.tokenizer.eos_token_id,
                streamer=streamer,
                **generation_args,
                use_cache=True,
            )

        # Strip the prompt tokens before decoding, so only the new reply
        # is stored in the conversation history.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
        outputs = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        messages.append({"role": "assistant", "content": outputs})

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default="microsoft/Phi-3-vision-128k-instruct")
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=500)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)
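With the defaults above, running it looks something like this (the script filename is my own placeholder):

python phi3v_stream.py --image-file http://images.cocodataset.org/val2017/000000039769.jpg --temperature 0.2

Enter an empty line at the User: prompt to exit the loop.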
Thanks a lot! I think my problem was that I forgot to set skip_special_tokens=True.
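For anyone else landing here: without that flag, TextStreamer decodes special tokens such as <|end|> straight into the stream, and without skip_prompt=True it also echoes the whole prompt. A minimal sketch of the difference, assuming a processor loaded as in the script above:

from transformers import TextStreamer

# Echoes the prompt and prints special tokens such as <|end|>:
streamer = TextStreamer(processor.tokenizer)

# Streams only the newly generated, human-readable text:
streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)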