Updating Wrappers for Document Models from transformers #52

Merged
merged 1 commit on Nov 3, 2023
128 changes: 59 additions & 69 deletions alfred/fm/huggingfacedocument.py
@@ -1,20 +1,21 @@
 import logging
+import re
 from typing import Optional, List, Tuple
 
 import torch
 from PIL import Image
-from transformers import AutoProcessor, CLIPModel
 
 from alfred.fm.model import LocalAccessFoundationModel
-from alfred.fm.utils import EmbeddingCache
-from .response import RankedResponse
+from .response import CompletionResponse
 
 logger = logging.getLogger(__name__)
 
 
 class HuggingFaceDocumentModel(LocalAccessFoundationModel):
     """
     The HuggingFaceModel class is a wrapper for HuggingFace Document Models
     For now, this class serves as an abstraction for DocumentQA-based prompted labelers.
 
     Currently supports:
      - Donut     (MIT License)
      - LayoutLM  (MIT License)
@@ -25,8 +26,6 @@ def __init__(
         self,
         model_string: str,
         local_path: Optional[str] = None,
-        image_cache_limit: int = 32,
-        text_cache_limit: int = 64,
     ):
         """
         Constructor for HuggingFaceDocumentModel
@@ -35,27 +34,31 @@ def __init__(
         :type model_string: str
         :param local_path: (optional) local path to store the model
         :type local_path: Optional[str]
-        :param cache_limit: vache limit for the text and image inputs respectively
-        :type cache_limit: int
         """
         super().__init__(model_string, local_path)
         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        self.model = CLIPModel.from_pretrained(model_string,
-                                               cache_dir=local_path).to(
-                                                   self.device)
-        self.tokenizer = None
-        self.processor = AutoProcessor.from_pretrained(model_string,
-                                                       cache_dir=local_path)
-
-        self.model.eval()
-
-        image_cache_limit = max(1, image_cache_limit)
-        text_cache_limit = max(1, text_cache_limit)
+        model_string_lowered = model_string.lower()
 
+        self.model_type = ""
+        if "layoutlm" in model_string_lowered:
+            from transformers import pipeline
+            self.pipe = pipeline(
+                "document-question-answering",
+                model=model_string,
+            )
+            self.model_type = "layoutlm"
+        elif "donut" in model_string_lowered:
+            from transformers import DonutProcessor, VisionEncoderDecoderModel
+            self.processor = DonutProcessor.from_pretrained(model_string, cache_dir=local_path)
+            self.model = VisionEncoderDecoderModel.from_pretrained(model_string, cache_dir=local_path).to(self.device)
+            self.model_type = "donut"
+        else:
+            raise NotImplementedError(f"Model {model_string} is not supported for document processing.")
 
-        self.image_cache = EmbeddingCache(image_cache_limit)
-        self.text_cache = EmbeddingCache(text_cache_limit)
+        self.model.eval()
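For orientation, a minimal usage sketch of the dispatch above; the two checkpoint names are illustrative Hugging Face Hub identifiers, not ones pinned by this PR:

```python
from alfred.fm.huggingfacedocument import HuggingFaceDocumentModel

# "donut" in the model string selects the DonutProcessor + VisionEncoderDecoderModel path
donut_qa = HuggingFaceDocumentModel("naver-clova-ix/donut-base-finetuned-docvqa")

# "layoutlm" in the model string selects the document-question-answering pipeline
layoutlm_qa = HuggingFaceDocumentModel("impira/layoutlm-document-qa")

# any other model string raises NotImplementedError
```
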
-    def _score_batch(
+    def _generate_batch(
         self,
         batch_instance: Tuple[List[Image.Image], List[str]],
         **kwargs,
@@ -67,53 +70,40 @@ def _score_batch(
         :type batch_instance: Tuple[List[Image.Image], List[str]]
         :param kwargs: (optional) additional arguments
         :type kwargs: Dict
-        :return: list of RankedResponse
-        :rtype: List[RankedResponse]
+        :return: list of CompletionResponse
+        :rtype: List[CompletionResponse]
         """
-        def _get_image_features(image):
-            image = self.processor(images=image,
-                                   return_tensors="pt").to(self.device)
-            image_features = self.model.get_image_features(**image)
-            image_features = image_features / image_features.norm(
-                p=2, dim=-1, keepdim=True)
-            return image_features
-
-        def _get_text_features(text):
-            text = self.processor(text=text,
-                                  padding=True,
-                                  truncation=True,
-                                  return_tensors="pt").to(self.device)
-            text_features = self.model.get_text_features(**text)
-            text_features = text_features / text_features.norm(
-                p=2, dim=-1, keepdim=True)
-            return text_features
-
-        return_image_features = kwargs.get("return_image_features", False)
-        return_raw_logits = kwargs.get("raw_logits", False)
-
-        image, candidates = batch_instance
-
-        image_features = self.image_cache.get(image, _get_image_features)
-        text_features = self.text_cache.get(candidates, _get_text_features)
-
-        logits_per_image = image_features @ text_features.t()
-        logits_per_text = logits_per_image.t().detach().cpu()
-
-        logits = logits_per_text if return_raw_logits else logits_per_text.softmax(
-            dim=0)
-        prediction = [candidates[i] for i in logits.argmax(dim=0)]
-
-        return [
-            RankedResponse(prediction=prediction[i],
-                           scores={
-                               candidate: logits[cidx][i].item()
-                               for cidx, candidate in enumerate(candidates)
-                           },
-                           logits={
-                               candidate: logits_per_text[cidx][i].item()
-                               for cidx, candidate in enumerate(candidates)
-                           },
-                           embeddings=image_features[i]
-                           if return_image_features else None)
-            for i in range(len(prediction))
-        ]
+        max_new_tokens = kwargs.get("max_new_tokens", 512)
+        if self.model_type == "donut":
+            responses = []
+            for image, prompt in zip(*batch_instance):
+                decoder_input_ids = self.processor.tokenizer(
+                    f"<s_docvqa><s_question>{prompt}</s_question><s_answer>",
+                    add_special_tokens=False,
+                    return_tensors="pt",
+                ).input_ids
+
+                pixel_values = self.processor(image, return_tensors="pt").pixel_values
+
+                outputs = self.model.generate(
+                    pixel_values.to(self.device),
+                    decoder_input_ids=decoder_input_ids.to(self.device),
+                    max_length=max_new_tokens,
+                    pad_token_id=self.processor.tokenizer.pad_token_id,
+                    eos_token_id=self.processor.tokenizer.eos_token_id,
+                    use_cache=True,
+                    bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+                    return_dict_in_generate=True,
+                )
+
+                sequence = self.processor.batch_decode(outputs.sequences)[0]
+                sequence = sequence.replace(self.processor.tokenizer.eos_token, "")
+                sequence = sequence.replace(self.processor.tokenizer.pad_token, "")
+                sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
+                response = self.processor.token2json(sequence)["answer"]
+                responses.append(CompletionResponse(prediction=response))
+            return responses
+        elif self.model_type == "layoutlm":
+            responses = []
+            for image, prompt in zip(*batch_instance):
+                response = self.pipe(image, prompt)["answer"]
+                responses.append(CompletionResponse(prediction=response))
+            return responses
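
The Donut branch wraps each question in the DocVQA task prompt, `<s_docvqa><s_question>...</s_question><s_answer>`, and lets the decoder complete the answer. A minimal sketch of exercising the new path, assuming the hypothetical `donut_qa` instance from the earlier sketch and a local image file:

```python
from PIL import Image

# hypothetical inputs: one document image paired with one question
images = [Image.open("invoice.png").convert("RGB")]
questions = ["What is the invoice number?"]

# _generate_batch zips images with prompts and returns CompletionResponse objects
responses = donut_qa._generate_batch((images, questions), max_new_tokens=128)
print(responses[0].prediction)
```

Calling the private `_generate_batch` directly here is only to mirror the diff; in practice requests would normally go through Alfred's higher-level interfaces.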


1 change: 1 addition & 0 deletions docs/README.md
@@ -25,6 +25,7 @@ A full list of `Alfred` project modules.
 - [Dummy](alfred/fm/dummy.md#dummy)
 - [Flexgen](alfred/fm/flexgen.md#flexgen)
 - [Huggingface](alfred/fm/huggingface.md#huggingface)
+- [Huggingfacedocument](alfred/fm/huggingfacedocument.md#huggingfacedocument)
 - [Huggingfacevlm](alfred/fm/huggingfacevlm.md#huggingfacevlm)
 - [Model](alfred/fm/model.md#model)
 - [Onnx](alfred/fm/onnx.md#onnx)
38 changes: 38 additions & 0 deletions docs/alfred/fm/huggingfacedocument.md
@@ -0,0 +1,38 @@
+# Huggingfacedocument
+
+[Alfred Index](../../README.md#alfred-index) /
+[Alfred](../index.md#alfred) /
+[Fm](./index.md#fm) /
+Huggingfacedocument
+
+> Auto-generated documentation for [alfred.fm.huggingfacedocument](../../../alfred/fm/huggingfacedocument.py) module.
+
+- [Huggingfacedocument](#huggingfacedocument)
+  - [HuggingFaceDocumentModel](#huggingfacedocumentmodel)
+
+## HuggingFaceDocumentModel
+
+[Show source in huggingfacedocument.py:14](../../../alfred/fm/huggingfacedocument.py#L14)
+
+The HuggingFaceModel class is a wrapper for HuggingFace Document Models
+For now, this class serves as an abstraction for DocumentQA-based prompted labelers.
+
+Currently supports:
+ - Donut      (MIT License)
+ - LayoutLM   (MIT License)
+ - LayoutLMv2 (CC BY-NC-SA 4.0)
+ - LayoutLMv3 (CC BY-NC-SA 4.0)
+
+#### Signature
+
+```python
+class HuggingFaceDocumentModel(LocalAccessFoundationModel):
+    def __init__(self, model_string: str, local_path: Optional[str] = None):
+        ...
+```
+
+#### See also
+
+- [LocalAccessFoundationModel](./model.md#localaccessfoundationmodel)


1 change: 1 addition & 0 deletions docs/alfred/fm/index.md
@@ -17,6 +17,7 @@ Fm
 - [Dummy](./dummy.md)
 - [Flexgen](./flexgen.md)
 - [Huggingface](./huggingface.md)
+- [Huggingfacedocument](./huggingfacedocument.md)
 - [Huggingfacevlm](./huggingfacevlm.md)
 - [Model](./model.md)
 - [Onnx](./onnx.md)