diff --git a/easyocr/detection.py b/easyocr/detection.py
index e2964b71b45..0c424660b1d 100644
--- a/easyocr/detection.py
+++ b/easyocr/detection.py
@@ -40,6 +40,8 @@ def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold,
          for n_img in img_resized_list]
     x = torch.from_numpy(np.array(x))
     x = x.to(device)
+    # Keep the model on the same device as the input batch moved just above.
+    net = net.to(device)
 
     # forward pass
     with torch.no_grad():
@@ -74,8 +76,10 @@ def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold,
 def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
     net = CRAFT()
 
-    if device == 'cpu':
+    if device in ('cpu', 'xpu'):
         net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device, weights_only=False)))
-        if quantize:
+        net.to(device)
+        # Dynamic qint8 quantization is a CPU feature; never apply it to a model placed on XPU.
+        if quantize and device == 'cpu':
             try:
                 torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
diff --git a/easyocr/easyocr.py b/easyocr/easyocr.py
index c08fe0388dd..efd449e5fe2 100644
--- a/easyocr/easyocr.py
+++ b/easyocr/easyocr.py
@@ -10,6 +10,14 @@
 import numpy as np
 import cv2
 import torch
+
+# Optional Intel extension (XPU support); its absence must not break other users.
+try:
+    import intel_extension_for_pytorch as ipex
+except Exception:
+    # Best effort: any import-time failure simply leaves the extension disabled.
+    ipex = None
+
 import os
 import sys
 from PIL import Image
@@ -74,10 +82,13 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                 self.device = 'cuda'
             elif torch.backends.mps.is_available():
                 self.device = 'mps'
+            elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+                # hasattr guard: some PyTorch builds expose no `torch.xpu` attribute at all.
+                self.device = 'xpu'
             else:
                 self.device = 'cpu'
                 if verbose:
-                    LOGGER.warning('Neither CUDA nor MPS are available - defaulting to CPU. Note: This module is much faster with a GPU.')
+                    LOGGER.warning('Neither CUDA/XPU/MPS are available - defaulting to CPU. Note: This module is much faster with a GPU.')
         else:
             self.device = gpu
 
diff --git a/easyocr/recognition.py b/easyocr/recognition.py
index 530ef9517e2..65a7396c992 100644
--- a/easyocr/recognition.py
+++ b/easyocr/recognition.py
@@ -165,13 +165,15 @@ def get_recognizer(recog_network, network_params, character,\
         model_pkg = importlib.import_module(recog_network)
     model = model_pkg.Model(num_class=num_class, **network_params)
 
-    if device == 'cpu':
+    if device in ('cpu', 'xpu'):
         state_dict = torch.load(model_path, map_location=device, weights_only=False)
         new_state_dict = OrderedDict()
         for key, value in state_dict.items():
             new_key = key[7:]
             new_state_dict[new_key] = value
         model.load_state_dict(new_state_dict)
-        if quantize:
+        model.to(device)
+        # Dynamic qint8 quantization is a CPU feature; never apply it to a model placed on XPU.
+        if quantize and device == 'cpu':
             try:
                 torch.quantization.quantize_dynamic(model, dtype=torch.qint8, inplace=True)