Commit 67dd387

Merge pull request #6 from chenxwh/main
Add Replicate demo and API
2 parents 8c17939 + b0ba9d0 commit 67dd387

File tree

3 files changed (+112, -1)


README.md (+1, -1)

```diff
@@ -4,7 +4,7 @@ The official repository which contains the code and pre-trained models for our p


 # 🔥 Updates
-
+- [**2023-8-3**]: Integrated into Replicate, check out the [demo](https://replicate.com/cjwbw/lorahub)!
 - [**2023-7-27**]: We released our [code](https://github.com/sail-sg/lorahub) and [demo](https://huggingface.co/spaces/sail/lorahub). Check it out!
 - [**2023-7-26**]: We released our [paper](https://arxiv.org/abs/2307.13269).

```

cog.yaml (new file, +16)

```yaml
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  python_version: "3.11"
  python_packages:
    - "numpy==1.25.2"
    - "torch==2.0.1"
    - "peft==0.4.0"
    - "tqdm==4.65.0"
    - "nevergrad==0.11.0"
    - "transformers==4.31.0"
    - "datasets==2.14.3"

predict: "predict.py:Predictor"
```

predict.py (new file, +95)

```python
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import random

import torch
from cog import BasePredictor, Input, Path

from lorahub.algorithm import (
    lorahub_learning,
    default_get_loss,
    default_l1_regularization,
)
from lorahub.constant import LORA_MODULE_NAMES


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        pass

    def predict(
        self,
        example_inputs: str = Input(
            description="List of input examples, one line per input.",
            default="Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:\nInfer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:\nInfer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:",
        ),
        example_outputs: str = Input(
            description="List of output examples, one line per output.",
            default="(C)\n(E)\n(F)",
        ),
        lora_modules_specified: str = Input(
            description="LoRA modules to compose, separated by commas; available options are listed at https://huggingface.co/models?search=lorahub, e.g. 'lorahub/flan_t5_large-quarel_logic_test, lorahub/flan_t5_large-coqa'.",
            default=None,
        ),
        num_random_lora_modules: int = Input(
            description="Number of LoRA modules to sample at random. Ignored if modules are specified above.",
            default=20,
            ge=2,
            le=196,
        ),
        max_inference_step: int = Input(
            description="Maximum number of steps for optimising the LoRA module composition. We suggest 40 steps when 20 modules are chosen; more modules typically need more steps.",
            default=40,
            ge=10,
            le=100,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        if lora_modules_specified:
            # strip whitespace so "a, b" and "a,b" both parse correctly
            lora_module_list = [m.strip() for m in lora_modules_specified.split(",")]
            for lora_module in lora_module_list:
                assert (
                    lora_module in LORA_MODULE_NAMES
                ), f"{lora_module} is not recognised."
        else:
            lora_module_list = random.sample(LORA_MODULE_NAMES, num_random_lora_modules)

        example_inputs = example_inputs.splitlines()
        example_outputs = example_outputs.splitlines()
        assert len(example_inputs) == len(
            example_outputs
        ), "The numbers of inputs and outputs do not match."

        # perform LoRAHub learning
        module_weights, model, tokenizer = lorahub_learning(
            lora_module_list=lora_module_list,
            example_inputs=example_inputs,
            example_outputs=example_outputs,
            max_inference_step=max_inference_step,
            model_name_or_path=None,  # if not given, use the model_name_or_path from the LoRA config
            batch_size=None,
            get_loss=default_get_loss,  # the objective for optimisation; loss by default (can be changed to e.g. accuracy or similarity)
            get_regular=default_l1_regularization,  # the regularization term for the weights; 0.05 * |w_i| by default
            seed=seed,
        )

        print("The recommended weight set for the LoRA modules is:")
        for module_weight, module in zip(module_weights, lora_module_list):
            print(f"{module_weight:.4f}: {module}")

        out = "/tmp/out.bin"
        torch.save(model, out)

        return Path(out)
```
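
The artifact returned to Replicate is the whole composed model pickled with `torch.save`. A minimal consumer sketch, assuming the same `torch`/`peft`/`transformers` versions are importable at load time (full-object pickles resolve classes by import path):

```python
import torch

# Restore the composed model exactly as predict.py pickled it.
# map_location="cpu" lets this run on a machine without a GPU.
model = torch.load("/tmp/out.bin", map_location="cpu")
model.eval()  # switch to inference mode
```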
