| |
| |
|
|
| from cog import BasePredictor, Input |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
| import argparse |
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that turns a text prompt into a single SQL statement.

    ``setup`` loads ``defog/sqlcoder-34b-alpha`` and builds a text-generation
    pipeline once; ``predict`` runs beam search on the prompt and extracts the
    first SQL statement from the completion.
    """

    MODEL_NAME = "defog/sqlcoder-34b-alpha"
    # Budget for tokens generated *after* the prompt. ``max_new_tokens`` is
    # used rather than ``max_length`` because ``max_length`` also counts the
    # prompt tokens, so long prompts (e.g. ones that embed a schema) would
    # leave little or no room for the answer.
    MAX_NEW_TOKENS = 300
    # Beam width for the deterministic (do_sample=False) beam search.
    NUM_BEAMS = 5

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=True,
            offload_folder="./.cache",
        )
        self.eos_token_id = self.tokenizer.eos_token_id
        # The pipeline depends only on the model/tokenizer, so build it once
        # here instead of on every prediction.
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=self.MAX_NEW_TOKENS,
            do_sample=False,
            num_beams=self.NUM_BEAMS,
            # Return only the completion so SQL extraction never picks up
            # code fences or semicolons that appear in the prompt itself.
            return_full_text=False,
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt to generate from"),
    ) -> str:
        """Run a single prediction on the model.

        Args:
            prompt: Text fed verbatim to the model.

        Returns:
            The first SQL statement found in the model's completion,
            whitespace-stripped and terminated with a single ``;``.
        """
        outputs = self.pipe(
            prompt,
            num_return_sequences=1,
            eos_token_id=self.eos_token_id,
            # EOS doubles as the pad token; presumably the tokenizer defines
            # no dedicated pad token — NOTE(review): confirm.
            pad_token_id=self.eos_token_id,
        )
        return self._extract_sql(outputs[0]["generated_text"])

    @staticmethod
    def _extract_sql(text: str) -> str:
        """Return the first SQL statement in *text*, terminated with ``;``.

        Keeps the text after the last ```` ```sql ```` fence (or all of *text*
        if there is none), cuts it at the next ```` ``` ```` fence and then at
        the first ``;``, strips surrounding whitespace and appends ``;``.
        If nothing remains, the result is just ``";"``.
        """
        after_open_fence = text.split("```sql")[-1]
        inside_fence = after_open_fence.split("```")[0]
        # NOTE: a ';' inside a SQL string literal also ends the statement here.
        first_statement = inside_fence.split(";")[0]
        return first_statement.strip() + ";"
|
|