Skip to content

Commit

Permalink
Fix OpenVINO image classification examples (#598)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Mar 11, 2024
1 parent 72b0630 commit e6ee88c
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions examples/openvino/image-classification/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
token: str = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
Expand Down Expand Up @@ -239,8 +239,7 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
data_files = {}
Expand All @@ -252,7 +251,6 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)

# If we don't have a validation split, split off a percentage of train as validation.
Expand Down Expand Up @@ -287,15 +285,15 @@ def compute_metrics(p):
finetuning_task="image-classification",
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
model = AutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)

Expand All @@ -311,7 +309,7 @@ def compute_metrics(p):
model_args.feature_extractor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)

# Define torchvision transforms to be applied to each image.
Expand Down

0 comments on commit e6ee88c

Please sign in to comment.