Merge pull request #267 from smart-on-fhir/mikix/llama2-13b
feat: bump WIP llama2 image to 13B model (from 7B)
mikix authored Aug 22, 2023
2 parents 92b86be + db0fcd1 commit 9723ef2
Showing 3 changed files with 22 additions and 15 deletions.
compose.yaml (16 changes: 9 additions & 7 deletions)
@@ -11,7 +11,7 @@ services:
       - AWS_SESSION_TOKEN
       - AWS_PROFILE
       - AWS_DEFAULT_PROFILE
-      - CUMULUS_HUGGING_FACE_URL
+      - CUMULUS_HUGGING_FACE_URL=http://llama2:8086/
       - URL_CTAKES_REST=http://ctakes-covid:8080/ctakes-web-rest/service/analyze
       - URL_CNLP_NEGATION=http://cnlp-transformers:8000/negation/process
     volumes:
@@ -71,25 +71,27 @@ services:
           devices:
             - capabilities: [gpu]
 
   # This is a WIP llama2 setup, currently suitable for running in a g5.xlarge AWS instance.
   llama2:
-    image: ghcr.io/huggingface/text-generation-inference:1.0.0
+    image: ghcr.io/huggingface/text-generation-inference:1.0.1
     environment:
       # If you update anything here that could affect NLP results, consider updating the
       # task_version of any tasks that use this docker.
       - HUGGING_FACE_HUB_TOKEN
       - MAX_BATCH_PREFILL_TOKENS=2048 # default of 4096 overwhelms a 16GB machine
-      - MODEL_ID=meta-llama/Llama-2-7b-chat-hf
-      - REVISION=08751db2aca9bf2f7f80d2e516117a53d7450235
+      - MODEL_ID=meta-llama/Llama-2-13b-chat-hf
+      - QUANTIZE=bitsandbytes-nf4 # 4bit
+      - PORT=8086
+      - REVISION=0ba94ac9b9e1d5a0037780667e8b219adde1908c
     healthcheck:
       # There's no curl or wget inside this container, but there is python3!
-      test: ["CMD", "python3", "-c", "import socket; socket.create_connection(('localhost', 80))"]
+      test: ["CMD", "python3", "-c", "import socket; socket.create_connection(('localhost', 8086))"]
       start_period: 20m # give plenty of time for startup, since we may be downloading a model
     volumes:
       - hf-data:/data
     networks:
       - cumulus-etl
     ports:
-      - 8086:80
+      - 8086:8086
     deploy:
       resources:
         reservations:
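Note how the port pieces line up after this change: the container previously listened on TGI's default port 80 (hence the old healthcheck against port 80 and the old 8086:80 mapping), while the new PORT=8086 makes it listen on 8086 internally, so the healthcheck, the published port, and the CUMULUS_HUGGING_FACE_URL value all agree. As a quick way to verify the service from the host, here is a minimal smoke-test sketch; it assumes the stack is running with port 8086 published on localhost, and uses the standard text-generation-inference /info and /generate routes:

    import httpx

    BASE_URL = "http://localhost:8086"  # assumption: querying from the compose host

    # /info reports the serving setup; these values should match the sanity
    # check in hf_tasks.py below.
    info = httpx.get(f"{BASE_URL}/info").json()
    print(info.get("model_id"))   # expect "meta-llama/Llama-2-13b-chat-hf"
    print(info.get("model_sha"))  # expect "0ba94ac9b9e1d5a0037780667e8b219adde1908c"

    # /generate runs one completion; the first call can be slow while the
    # model downloads and loads, hence the generous timeout.
    response = httpx.post(
        f"{BASE_URL}/generate",
        json={"inputs": "Say hello.", "parameters": {"max_new_tokens": 20}},
        timeout=300,
    )
    response.raise_for_status()
    print(response.json()["generated_text"])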
cumulus_etl/etl/studies/hftest/hf_tasks.py (16 changes: 8 additions & 8 deletions)
@@ -30,11 +30,11 @@ class HuggingFaceTestTask(tasks.EtlTask):
     # ** 0 **
     # This is fluid until we actually promote this to a real task - feel free to update without bumping the version.
     # container: ghcr.io/huggingface/text-generation-inference
-    # container revision: 3ef5ffbc6400370ff2e1546550a6bad3ac61b079
+    # container revision: 09eca6422788b1710c54ee0d05dd6746f16bb681
     # container properties:
     #   MAX_BATCH_PREFILL_TOKENS=2048
-    # model: meta-llama/Llama-2-7b-chat-hf
-    # model revision: 08751db2aca9bf2f7f80d2e516117a53d7450235
+    #   QUANTIZE=bitsandbytes-nf4
+    # model: meta-llama/Llama-2-13b-chat-hf
+    # model revision: 0ba94ac9b9e1d5a0037780667e8b219adde1908c
     # system prompt:
     #   "You will be given a clinical note, and you should reply with a short summary of that note."
     # user prompt: a clinical note
@@ -51,9 +51,9 @@ async def prepare_task(self) -> bool:

         # Sanity check a few of the properties, to make sure we don't accidentally get pointed at an unexpected model.
         expected_info_present = (
-            raw_info.get("model_id") == "meta-llama/Llama-2-7b-chat-hf"
-            and raw_info.get("model_sha") == "08751db2aca9bf2f7f80d2e516117a53d7450235"
-            and raw_info.get("sha") == "3ef5ffbc6400370ff2e1546550a6bad3ac61b079"
+            raw_info.get("model_id") == "meta-llama/Llama-2-13b-chat-hf"
+            and raw_info.get("model_sha") == "0ba94ac9b9e1d5a0037780667e8b219adde1908c"
+            and raw_info.get("sha") == "09eca6422788b1710c54ee0d05dd6746f16bb681"
         )
         if not expected_info_present:
             logging.warning(" Skipping task: NLP server is using an unexpected model setup.")
@@ -64,7 +64,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:

     async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
         """Passes clinical notes through HF and returns any symptoms found"""
-        http_client = httpx.AsyncClient()
+        http_client = httpx.AsyncClient(timeout=300)
 
         for docref in self.read_ndjson(progress=progress):
             can_process = nlp.is_docref_valid(docref) and self.scrubber.scrub_resource(docref, scrub_attachments=False)
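For reference, chat-tuned Llama 2 models expect the system and user prompts pinned above to be wrapped in Meta's instruction template before generation. This diff does not show where cumulus_etl assembles the prompt, so the helper below is purely illustrative:

    def format_llama2_chat_prompt(system: str, user: str) -> str:
        # Meta's Llama-2 chat template; most servers add the leading BOS
        # token themselves, so it is omitted here.
        return f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"

    prompt = format_llama2_chat_prompt(
        "You will be given a clinical note, and you should reply with a short summary of that note.",
        "Patient reports a three-day history of cough and low-grade fever.",  # stand-in clinical note
    )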
cumulus_etl/nlp/huggingface.py (5 changes: 5 additions & 0 deletions)
@@ -63,6 +63,11 @@ async def hf_prompt(prompt: str | dict, *, client: httpx.AsyncClient = None) ->
"options": {
"wait_for_model": True,
},
"parameters": {
# Maybe max_new_tokens should be configurable, but let's hope a universal value is fine for now.
# When bumping this, consider whether you should bump the task version of any tasks that call this.
"max_new_tokens": 1000,
},
},
)
response.raise_for_status()
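The hunk above shows only the tail of the JSON body that hf_prompt() posts. For context, the whole request now looks roughly like this sketch; the "inputs" key and the endpoint URL are assumptions based on the standard text-generation-inference API and the compose.yaml default, not something this diff shows:

    import asyncio

    import httpx

    async def demo_prompt(text: str) -> dict:
        # Mirrors the hf_prompt() request shape, including the new "parameters" block.
        async with httpx.AsyncClient(timeout=300) as client:
            response = await client.post(
                "http://llama2:8086/",  # CUMULUS_HUGGING_FACE_URL from compose.yaml
                json={
                    "inputs": text,  # assumed key, per the standard TGI request shape
                    "options": {"wait_for_model": True},
                    "parameters": {
                        # Caps response length; changing it can change NLP results,
                        # which is why the in-code comment warns about task versions.
                        "max_new_tokens": 1000,
                    },
                },
            )
            response.raise_for_status()
            return response.json()

    result = asyncio.run(demo_prompt("Summarize this note: ..."))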
