# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import random

import torch
from cog import BasePredictor, Input, Path
| 10 | + |
| 11 | +from lorahub.algorithm import ( |
| 12 | + lorahub_learning, |
| 13 | + default_get_loss, |
| 14 | + default_l1_regularization, |
| 15 | +) |
| 16 | +from lorahub.constant import LORA_MODULE_NAMES |
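

# A minimal sketch (not part of LoRAHub) of a custom regulariser that could be
# passed to lorahub_learning via its get_regular argument in place of
# default_l1_regularization. Assumption: the function receives the list of
# module weights and returns a scalar penalty, mirroring the 0.05 * |w_i|
# default described in the lorahub_learning call below.
def example_l2_regularization(weights, scale=0.05):
    """Hypothetical L2 penalty: scale * sum(w_i ** 2)."""
    return scale * sum(w * w for w in weights)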


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        pass

    def predict(
        self,
        example_inputs: str = Input(
            description="List of input examples, one input per line.",
            default="Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:\nInfer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:\nInfer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:",
        ),
        example_outputs: str = Input(
            description="List of output examples, one output per line.",
            default="(C)\n(E)\n(F)",
        ),
        lora_modules_specified: str = Input(
            description="LoRA modules to use for the composition, chosen from https://huggingface.co/models?search=lorahub; separate modules with commas, e.g. 'lorahub/flan_t5_large-quarel_logic_test, lorahub/flan_t5_large-coqa'.",
            default=None,
        ),
        num_random_lora_modules: int = Input(
            description="Number of LoRA modules to sample at random. Ignored if modules are specified above.",
            default=20,
            ge=2,
            le=196,
        ),
        max_inference_step: int = Input(
            description="Maximum number of iteration steps for optimising the LoRA module composition. We suggest 40 steps when 20 modules are chosen; more modules typically need more steps.",
            default=40,
            ge=10,
            le=100,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomise the seed.", default=None
        ),
    ) -> Path:
| 54 | + """Run a single prediction on the model""" |
| 55 | + |
| 56 | + if seed is None: |
| 57 | + seed = int.from_bytes(os.urandom(2), "big") |
| 58 | + print(f"Using seed: {seed}") |
| 59 | + |
        if lora_modules_specified:
            # strip surrounding whitespace so "a, b" style input matches LORA_MODULE_NAMES
            lora_module_list = [
                module.strip() for module in lora_modules_specified.split(",")
            ]
            for lora_module in lora_module_list:
                assert (
                    lora_module in LORA_MODULE_NAMES
                ), f"{lora_module} is not recognised."
        else:
            # seed the sampler so module selection is reproducible for a given seed
            random.seed(seed)
            lora_module_list = random.sample(LORA_MODULE_NAMES, num_random_lora_modules)

        example_inputs = example_inputs.splitlines()
        example_outputs = example_outputs.splitlines()
        assert len(example_inputs) == len(
            example_outputs
        ), "Number of inputs and outputs do not match."

        # perform LoRAHub learning
        module_weights, model, tokenizer = lorahub_learning(
            lora_module_list=lora_module_list,
            example_inputs=example_inputs,
            example_outputs=example_outputs,
            max_inference_step=max_inference_step,
            model_name_or_path=None,  # if not given, falls back to the model_name_or_path in the LoRA config
            batch_size=None,
            get_loss=default_get_loss,  # objective for the optimisation; loss by default (could be swapped for accuracy or similarity)
            get_regular=default_l1_regularization,  # regularisation term for the weights, 0.05 * |w_i| by default (see example_l2_regularization above)
            seed=seed,
        )

        print("The recommended weight set for the LoRA modules is:")
        for module_weight, module in zip(module_weights, lora_module_list):
            print(f"{module_weight:.4f}: {module}")

        out = "/tmp/out.bin"
        torch.save(model, out)

        return Path(out)
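

# A minimal sketch of reloading the saved artifact outside the Cog request
# flow; running this file directly is an assumption for illustration.
# torch.save above pickles the whole merged model object, so torch.load
# returns it as-is (weights_only=False is needed for full-model pickles on
# recent PyTorch versions).
if __name__ == "__main__":
    merged_model = torch.load("/tmp/out.bin", weights_only=False)
    print(type(merged_model))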