# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import torch
from cog import BasePredictor, Input
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        model_name = "defog/sqlcoder-34b-alpha"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=True,
            offload_folder="./.cache",
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt to generate from"),
    ) -> str:
        """Run a single prediction on the model"""
        # Stop generation at the end-of-sequence token; the SQL itself is
        # extracted from the ```sql ... ``` fence in the output below.
        eos_token_id = self.tokenizer.eos_token_id

        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=300,
            do_sample=False,
            num_beams=5,  # beam search with 5 beams for higher-quality results
        )

        generated_query = (
            pipe(
                prompt,
                num_return_sequences=1,
                eos_token_id=eos_token_id,
                pad_token_id=eos_token_id,
            )[0]["generated_text"]
            .split("```sql")[-1]  # keep only the text after the ```sql fence
            .split("```")[0]  # drop anything after the closing fence
            .split(";")[0]  # keep the first SQL statement only
            .strip()
            + ";"
        )
        return generated_query
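

# A minimal local smoke test, outside the Cog interface: loads the model and
# runs one prediction. It assumes a GPU with enough memory for the 34B model
# in fp16; the example schema and question below are illustrative assumptions
# in the defog prompt style, not part of this repo.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    example_prompt = (
        "### Task\n"
        "Generate a SQL query to answer the following question:\n"
        "`How many users signed up last week?`\n\n"
        "### Database Schema\n"
        "CREATE TABLE users (id INT, created_at TIMESTAMP);\n\n"
        "### SQL\n"
        "```sql\n"
    )
    print(predictor.predict(prompt=example_prompt))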