Merge pull request #52 from BatsResearch/update-huggingfacedocqa
Updating Wrappers for Document Models from transformers
dotpyu authored Nov 3, 2023
2 parents e24adb0 + 249517d commit 8fb4bce
Showing 4 changed files with 99 additions and 69 deletions.
128 changes: 59 additions & 69 deletions alfred/fm/huggingfacedocument.py
@@ -1,20 +1,21 @@
 import logging
+import re
 from typing import Optional, List, Tuple
 
 import torch
 from PIL import Image
-from transformers import AutoProcessor, CLIPModel
 
 from alfred.fm.model import LocalAccessFoundationModel
-from alfred.fm.utils import EmbeddingCache
-from .response import RankedResponse
+from .response import CompletionResponse
 
 logger = logging.getLogger(__name__)
 
 
 class HuggingFaceDocumentModel(LocalAccessFoundationModel):
     """
     The HuggingFaceDocumentModel class is a wrapper for HuggingFace document models.
     For now, this class serves as an abstraction for DocumentQA-based prompted labelers.
     Currently supports:
        - Donut    (MIT License)
        - LayoutLM (MIT License)
@@ -25,8 +26,6 @@ def __init__(
         self,
         model_string: str,
         local_path: Optional[str] = None,
-        image_cache_limit: int = 32,
-        text_cache_limit: int = 64,
     ):
         """
         Constructor for HuggingFaceDocumentModel
@@ -35,27 +34,31 @@ def __init__(
         :type model_string: str
         :param local_path: (optional) local path to store the model
         :type local_path: Optional[str]
-        :param cache_limit: vache limit for the text and image inputs respectively
-        :type cache_limit: int
         """
         super().__init__(model_string, local_path)
         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        self.model = CLIPModel.from_pretrained(model_string,
-                                               cache_dir=local_path).to(self.device)
-        self.tokenizer = None
-        self.processor = AutoProcessor.from_pretrained(model_string,
-                                                       cache_dir=local_path)
-
-        self.model.eval()
-
-        image_cache_limit = max(1, image_cache_limit)
-        text_cache_limit = max(1, text_cache_limit)
-        self.image_cache = EmbeddingCache(image_cache_limit)
-        self.text_cache = EmbeddingCache(text_cache_limit)
+        model_string_lowered = model_string.lower()
+
+        self.model_type = ""
+        if "layoutlm" in model_string_lowered:
+            from transformers import pipeline
+            self.pipe = pipeline(
+                "document-question-answering",
+                model=model_string,
+            )
+            self.model_type = "layoutlm"
+        elif "donut" in model_string_lowered:
+            from transformers import DonutProcessor, VisionEncoderDecoderModel
+            self.processor = DonutProcessor.from_pretrained(model_string, cache_dir=local_path)
+            self.model = VisionEncoderDecoderModel.from_pretrained(model_string, cache_dir=local_path).to(self.device)
+            self.model_type = "donut"
+        else:
+            raise NotImplementedError(f"Model {model_string} is not supported for document processing.")
+
+        self.model.eval()
 
-    def _score_batch(
+    def _generate_batch(
         self,
         batch_instance: Tuple[List[Image.Image], List[str]],
         **kwargs,
@@ -67,53 +70,40 @@ def _score_batch(
         :type batch_instance: Tuple[List[Image.Image], List[str]]
         :param kwargs: (optional) additional arguments
         :type kwargs: Dict
-        :return: list of RankedResponse
-        :rtype: List[RankedResponse]
+        :return: list of CompletionResponse
+        :rtype: List[CompletionResponse]
         """
-        def _get_image_features(image):
-            image = self.processor(images=image,
-                                   return_tensors="pt").to(self.device)
-            image_features = self.model.get_image_features(**image)
-            image_features = image_features / image_features.norm(
-                p=2, dim=-1, keepdim=True)
-            return image_features
-
-        def _get_text_features(text):
-            text = self.processor(text=text,
-                                  padding=True,
-                                  truncation=True,
-                                  return_tensors="pt").to(self.device)
-            text_features = self.model.get_text_features(**text)
-            text_features = text_features / text_features.norm(
-                p=2, dim=-1, keepdim=True)
-            return text_features
-
-        return_image_features = kwargs.get("return_image_features", False)
-        return_raw_logits = kwargs.get("raw_logits", False)
-
-        image, candidates = batch_instance
-
-        image_features = self.image_cache.get(image, _get_image_features)
-        text_features = self.text_cache.get(candidates, _get_text_features)
-
-        logits_per_image = image_features @ text_features.t()
-        logits_per_text = logits_per_image.t().detach().cpu()
-
-        logits = logits_per_text if return_raw_logits else logits_per_text.softmax(dim=0)
-        prediction = [candidates[i] for i in logits.argmax(dim=0)]
-
-        return [
-            RankedResponse(prediction=prediction[i],
-                           scores={
-                               candidate: logits[cidx][i].item()
-                               for cidx, candidate in enumerate(candidates)
-                           },
-                           logits={
-                               candidate: logits_per_text[cidx][i].item()
-                               for cidx, candidate in enumerate(candidates)
-                           },
-                           embeddings=image_features[i]
-                           if return_image_features else None)
-            for i in range(len(prediction))
-        ]
+        max_new_tokens = kwargs.get("max_new_tokens", 512)
+        if self.model_type == "donut":
+            responses = []
+            for image, prompt in zip(*batch_instance):
+                decoder_input_ids = self.processor.tokenizer(
+                    f"<s_docvqa><s_question>{prompt}</s_question><s_answer>",
+                    add_special_tokens=False,
+                    return_tensors="pt",
+                ).input_ids
+
+                pixel_values = self.processor(image, return_tensors="pt").pixel_values
+
+                outputs = self.model.generate(
+                    pixel_values.to(self.device),
+                    decoder_input_ids=decoder_input_ids.to(self.device),
+                    max_length=max_new_tokens,
+                    pad_token_id=self.processor.tokenizer.pad_token_id,
+                    eos_token_id=self.processor.tokenizer.eos_token_id,
+                    use_cache=True,
+                    bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+                    return_dict_in_generate=True,
+                )
+
+                sequence = self.processor.batch_decode(outputs.sequences)[0]
+                sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
+                sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
+                response = self.processor.token2json(sequence)["answer"]
+                responses.append(CompletionResponse(prediction=response))
+            return responses
+        elif self.model_type == "layoutlm":
+            responses = []
+            for image, prompt in zip(*batch_instance):
+                response = self.pipe(image, prompt)["answer"]
+                responses.append(CompletionResponse(prediction=response))
+            return responses


1 change: 1 addition & 0 deletions docs/README.md
@@ -25,6 +25,7 @@ A full list of `Alfred` project modules.
 - [Dummy](alfred/fm/dummy.md#dummy)
 - [Flexgen](alfred/fm/flexgen.md#flexgen)
 - [Huggingface](alfred/fm/huggingface.md#huggingface)
+- [Huggingfacedocument](alfred/fm/huggingfacedocument.md#huggingfacedocument)
 - [Huggingfacevlm](alfred/fm/huggingfacevlm.md#huggingfacevlm)
 - [Model](alfred/fm/model.md#model)
 - [Onnx](alfred/fm/onnx.md#onnx)
38 changes: 38 additions & 0 deletions docs/alfred/fm/huggingfacedocument.md
@@ -0,0 +1,38 @@
# Huggingfacedocument

[Alfred Index](../../README.md#alfred-index) /
[Alfred](../index.md#alfred) /
[Fm](./index.md#fm) /
Huggingfacedocument

> Auto-generated documentation for [alfred.fm.huggingfacedocument](../../../alfred/fm/huggingfacedocument.py) module.

- [Huggingfacedocument](#huggingfacedocument)
  - [HuggingFaceDocumentModel](#huggingfacedocumentmodel)

## HuggingFaceDocumentModel

[Show source in huggingfacedocument.py:14](../../../alfred/fm/huggingfacedocument.py#L14)

The HuggingFaceDocumentModel class is a wrapper for HuggingFace document models.
For now, this class serves as an abstraction for DocumentQA-based prompted labelers.

Currently supports:
- Donut (MIT License)
- LayoutLM (MIT License)
- LayoutLMv2 (CC BY-NC-SA 4.0)
- LayoutLMv3 (CC BY-NC-SA 4.0)

#### Signature

```python
class HuggingFaceDocumentModel(LocalAccessFoundationModel):
def __init__(self, model_string: str, local_path: Optional[str] = None):
...
```

#### See also

- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)
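
A similarly minimal sketch for the LayoutLM path (the checkpoint name is an assumption; any model id containing `layoutlm` is routed to the `document-question-answering` pipeline):

```python
from PIL import Image

from alfred.fm.huggingfacedocument import HuggingFaceDocumentModel

# Assumed checkpoint; routed to the Hugging Face
# "document-question-answering" pipeline branch. Note that this pipeline
# may require an OCR backend (e.g. pytesseract) for plain images.
doc_model = HuggingFaceDocumentModel("impira/layoutlm-document-qa")

image = Image.open("form.png").convert("RGB")
answers = doc_model._generate_batch(([image], ["What is the due date?"]))
print(answers[0].prediction)
```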


1 change: 1 addition & 0 deletions docs/alfred/fm/index.md
@@ -17,6 +17,7 @@ Fm
 - [Dummy](./dummy.md)
 - [Flexgen](./flexgen.md)
 - [Huggingface](./huggingface.md)
+- [Huggingfacedocument](./huggingfacedocument.md)
 - [Huggingfacevlm](./huggingfacevlm.md)
 - [Model](./model.md)
 - [Onnx](./onnx.md)
