diff --git a/pyproject.toml b/pyproject.toml index 696a444b9..2e910a5a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,8 @@ onnx = ["numpy>=1.22", "onnxruntime>=1.10.0", "sentencepiece>=0.1.91"] oskut = ["oskut>=1.3"] +qwen3 = ["torch>=1.9.0", "transformers>=4.22.1"] + sefr_cut = ["sefr_cut>=1.1"] spacy_thai = ["spacy_thai>=0.7.1"] diff --git a/pythainlp/chat/core.py b/pythainlp/chat/core.py index d7dce64ad..0b2a4dd4a 100644 --- a/pythainlp/chat/core.py +++ b/pythainlp/chat/core.py @@ -10,7 +10,6 @@ from pythainlp.generate.wangchanglm import WangChanGLM - class ChatBotModel: history: list[tuple[str, str]] model: "WangChanGLM" @@ -39,7 +38,7 @@ def load_model( :param bool return_dict: return_dict :param bool load_in_8bit: load model in 8bit :param str device: device (cpu, cuda or other) - :param torch_dtype torch_dtype: torch_dtype + :param Optional[torch.dtype] torch_dtype: torch_dtype :param str offload_folder: offload folder :param bool low_cpu_mem_usage: low cpu mem usage """ diff --git a/pythainlp/generate/wangchanglm.py b/pythainlp/generate/wangchanglm.py index a8932c314..ab5531c1b 100644 --- a/pythainlp/generate/wangchanglm.py +++ b/pythainlp/generate/wangchanglm.py @@ -54,7 +54,7 @@ def load_model( :param bool return_dict: return dict :param bool load_in_8bit: load model in 8bit :param str device: device (cpu, cuda or other) - :param torch_dtype torch_dtype: torch_dtype + :param Optional[torch.dtype] torch_dtype: torch_dtype :param str offload_folder: offload folder :param bool low_cpu_mem_usage: low cpu mem usage """ diff --git a/pythainlp/lm/__init__.py b/pythainlp/lm/__init__.py index 8f72a7d70..36ff60a7c 100644 --- a/pythainlp/lm/__init__.py +++ b/pythainlp/lm/__init__.py @@ -2,8 +2,9 @@ # SPDX-FileType: SOURCE # SPDX-License-Identifier: Apache-2.0 -__all__: list[str] = ["calculate_ngram_counts", "remove_repeated_ngrams"] +__all__: list[str] = ["calculate_ngram_counts", "remove_repeated_ngrams", "Qwen3"] +from pythainlp.lm.qwen3 import Qwen3 from pythainlp.lm.text_util import ( calculate_ngram_counts, remove_repeated_ngrams, diff --git a/pythainlp/lm/qwen3.py b/pythainlp/lm/qwen3.py new file mode 100644 index 000000000..b67d20ab5 --- /dev/null +++ b/pythainlp/lm/qwen3.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project +# SPDX-FileType: SOURCE +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + import torch + from transformers import PreTrainedModel, PreTrainedTokenizerBase + + +class Qwen3: + """Qwen3-0.6B language model for Thai text generation. + + A small but capable language model from Alibaba Cloud's Qwen family, + optimized for various NLP tasks including Thai language processing. + """ + + def __init__(self) -> None: + self.model: Optional["PreTrainedModel"] = None + self.tokenizer: Optional["PreTrainedTokenizerBase"] = None + self.device: Optional[str] = None + self.torch_dtype: Optional["torch.dtype"] = None + self.model_path: Optional[str] = None + + def load_model( + self, + model_path: str = "Qwen/Qwen3-0.6B", + device: str = "cuda", + torch_dtype: Optional["torch.dtype"] = None, + low_cpu_mem_usage: bool = True, + ) -> None: + """Load Qwen3 model. 
+ + :param str model_path: model path or HuggingFace model ID + :param str device: device (cpu, cuda or other) + :param Optional[torch.dtype] torch_dtype: torch data type (e.g., torch.float16, torch.bfloat16) + :param bool low_cpu_mem_usage: low cpu mem usage + + :Example: + :: + + from pythainlp.lm import Qwen3 + import torch + + model = Qwen3() + model.load_model(device="cpu", torch_dtype=torch.bfloat16) + """ + try: + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + except (ImportError, ModuleNotFoundError) as exc: + raise ImportError( + "Qwen3 language model requires optional dependencies. " + "Install them with: pip install 'pythainlp[qwen3]'" + ) from exc + + # Set default torch_dtype if not provided + if torch_dtype is None: + torch_dtype = torch.float16 + + # Check CUDA availability early before loading model + if device.startswith("cuda"): + if not torch.cuda.is_available(): + raise RuntimeError( + "CUDA device requested but CUDA is not available. " + "Check your PyTorch installation and GPU drivers, or use " + "device='cpu' instead." + ) + + self.device = device + self.torch_dtype = torch_dtype + self.model_path = model_path + + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + except OSError as exc: + raise RuntimeError( + f"Failed to load tokenizer from '{self.model_path}'. " + "Check the model path or your network connection." + ) from exc + + try: + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, + device_map=device, + torch_dtype=torch_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + except OSError as exc: + # Clean up tokenizer on failure + self.tokenizer = None + raise RuntimeError( + f"Failed to load model from '{self.model_path}'. " + "This can happen due to an invalid model path, missing files, " + "or insufficient disk space." + ) from exc + except Exception as exc: + # Clean up tokenizer on failure + self.tokenizer = None + raise RuntimeError( + f"Failed to load model weights: {exc}. " + "This can be caused by insufficient memory, an incompatible " + "torch_dtype setting, or other configuration issues." + ) from exc + + def generate( + self, + text: str, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + top_k: int = 50, + do_sample: bool = True, + skip_special_tokens: bool = True, + ) -> str: + """Generate text from a prompt. + + :param str text: input text prompt + :param int max_new_tokens: maximum number of new tokens to generate + :param float temperature: temperature for sampling (higher = more random) + :param float top_p: top p for nucleus sampling + :param int top_k: top k for top-k sampling + :param bool do_sample: whether to use sampling or greedy decoding + :param bool skip_special_tokens: skip special tokens in output + :return: generated text + :rtype: str + + :Example: + :: + + from pythainlp.lm import Qwen3 + import torch + + model = Qwen3() + model.load_model(device="cpu", torch_dtype=torch.bfloat16) + + result = model.generate("สวัสดี") + print(result) + """ + if self.model is None or self.tokenizer is None or self.device is None: + raise RuntimeError( + "Model not loaded. Please call load_model() first." + ) + + if not text or not isinstance(text, str): + raise ValueError( + "text parameter must be a non-empty string." + ) + + try: + import torch + except (ImportError, ModuleNotFoundError) as exc: + raise ImportError( + "Qwen3 language model requires optional dependencies. 
" + "Install them with: pip install 'pythainlp[qwen3]'" + ) from exc + + inputs = self.tokenizer(text, return_tensors="pt") + input_ids = inputs["input_ids"].to(self.device) + + # Note: When do_sample=False (greedy decoding), temperature, top_p, + # and top_k parameters are ignored by the transformers library + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + do_sample=do_sample, + ) + + # Decode only the newly generated tokens + # output_ids and input_ids are guaranteed to be 2D tensors with + # batch size 1 from the tokenizer call above + generated_text = self.tokenizer.decode( + output_ids[0][len(input_ids[0]) :], + skip_special_tokens=skip_special_tokens, + ) + + return generated_text + + def chat( + self, + messages: list[dict[str, Any]], + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + top_k: int = 50, + do_sample: bool = True, + skip_special_tokens: bool = True, + ) -> str: + """Generate text using chat format. + + :param list[dict[str, Any]] messages: list of message dictionaries with 'role' and 'content' keys + :param int max_new_tokens: maximum number of new tokens to generate + :param float temperature: temperature for sampling + :param float top_p: top p for nucleus sampling + :param int top_k: top k for top-k sampling + :param bool do_sample: whether to use sampling + :param bool skip_special_tokens: skip special tokens in output + :return: generated response + :rtype: str + + :Example: + :: + + from pythainlp.lm import Qwen3 + import torch + + model = Qwen3() + model.load_model(device="cpu", torch_dtype=torch.bfloat16) + + messages = [{"role": "user", "content": "สวัสดีครับ"}] + response = model.chat(messages) + print(response) + """ + if self.model is None or self.tokenizer is None or self.device is None: + raise RuntimeError( + "Model not loaded. Please call load_model() first." + ) + + if not messages or not isinstance(messages, list): + raise ValueError( + "messages parameter must be a non-empty list of message dictionaries." + ) + + # Apply chat template if available, otherwise format manually + if hasattr(self.tokenizer, "apply_chat_template"): + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + # Simple fallback format - preserve content newlines + lines = [] + for msg in messages: + role = str(msg.get("role", "user")).replace("\n", " ") + content = str(msg.get("content", "")) + lines.append(f"{role}: {content}") + text = "\n".join(lines) + "\nassistant: " + + try: + import torch + except (ImportError, ModuleNotFoundError) as exc: + raise ImportError( + "Qwen3 language model requires optional dependencies. 
" + "Install them with: pip install 'pythainlp[qwen3]'" + ) from exc + + inputs = self.tokenizer(text, return_tensors="pt") + input_ids = inputs["input_ids"].to(self.device) + + # Note: When do_sample=False (greedy decoding), temperature, top_p, + # and top_k parameters are ignored by the transformers library + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + do_sample=do_sample, + ) + + # Decode only the newly generated tokens + # output_ids and input_ids are guaranteed to be 2D tensors with + # batch size 1 from the tokenizer call above + generated_text = self.tokenizer.decode( + output_ids[0][len(input_ids[0]) :], + skip_special_tokens=skip_special_tokens, + ) + + return generated_text diff --git a/pythainlp/phayathaibert/core.py b/pythainlp/phayathaibert/core.py index b8fff4d4a..dba072907 100644 --- a/pythainlp/phayathaibert/core.py +++ b/pythainlp/phayathaibert/core.py @@ -10,8 +10,13 @@ from typing import TYPE_CHECKING, Union if TYPE_CHECKING: - from transformers import CamembertTokenizer - from transformers.pipelines.base import Pipeline + from transformers import ( + AutoModelForMaskedLM, + AutoModelForTokenClassification, + CamembertTokenizer, + Pipeline, + PreTrainedTokenizerBase, + ) from transformers import ( CamembertTokenizer, @@ -212,13 +217,13 @@ def __init__(self) -> None: pipeline, ) - self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained( + self.tokenizer: "PreTrainedTokenizerBase" = AutoTokenizer.from_pretrained( _model_name ) - self.model_for_masked_lm: AutoModelForMaskedLM = ( + self.model_for_masked_lm: "AutoModelForMaskedLM" = ( AutoModelForMaskedLM.from_pretrained(_model_name) ) - self.model: "Pipeline" = pipeline( + self.model: "Pipeline" = pipeline( # transformers.Pipeline "fill-mask", tokenizer=self.tokenizer, model=self.model_for_masked_lm, @@ -311,8 +316,8 @@ def __init__(self, model: str = "lunarlist/pos_thai_phayathai") -> None: AutoTokenizer, ) - self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model) - self.model: AutoModelForTokenClassification = ( + self.tokenizer: "PreTrainedTokenizerBase" = AutoTokenizer.from_pretrained(model) + self.model: "AutoModelForTokenClassification" = ( AutoModelForTokenClassification.from_pretrained(model) ) @@ -356,8 +361,8 @@ def __init__(self, model: str = "Pavarissy/phayathaibert-thainer") -> None: AutoTokenizer, ) - self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model) - self.model: AutoModelForTokenClassification = ( + self.tokenizer: "PreTrainedTokenizerBase" = AutoTokenizer.from_pretrained(model) + self.model: "AutoModelForTokenClassification" = ( AutoModelForTokenClassification.from_pretrained(model) ) diff --git a/tests/noauto_torch/__init__.py b/tests/noauto_torch/__init__.py index 675181b6c..eba6b8584 100644 --- a/tests/noauto_torch/__init__.py +++ b/tests/noauto_torch/__init__.py @@ -23,6 +23,7 @@ # Names of module to be tested test_packages: list[str] = [ + "tests.noauto_torch.testn_lm_torch", "tests.noauto_torch.testn_spell_torch", "tests.noauto_torch.testn_tag_torch", "tests.noauto_torch.testn_tokenize_torch", diff --git a/tests/noauto_torch/testn_lm_torch.py b/tests/noauto_torch/testn_lm_torch.py new file mode 100644 index 000000000..7de647d76 --- /dev/null +++ b/tests/noauto_torch/testn_lm_torch.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project +# SPDX-FileType: SOURCE +# SPDX-License-Identifier: Apache-2.0 + +import unittest + 
+from pythainlp.lm import Qwen3 + + +class LMTestCaseN(unittest.TestCase): + def test_qwen3_initialization(self): + # Test that Qwen3 can be instantiated + try: + model = Qwen3() + self.assertIsNotNone(model) + self.assertIsNone(model.model) + self.assertIsNone(model.tokenizer) + except ImportError: + # Skip if dependencies not installed + self.skipTest("Qwen3 dependencies not installed") + + def test_qwen3_generate_without_load(self): + # Test that generate raises error when model is not loaded + try: + model = Qwen3() + with self.assertRaises(RuntimeError): + model.generate("test") + except ImportError: + # Skip if dependencies not installed + self.skipTest("Qwen3 dependencies not installed") + + def test_qwen3_chat_without_load(self): + # Test that chat raises error when model is not loaded + try: + model = Qwen3() + with self.assertRaises(RuntimeError): + model.chat([{"role": "user", "content": "test"}]) + except ImportError: + # Skip if dependencies not installed + self.skipTest("Qwen3 dependencies not installed") + + def test_qwen3_generate_empty_text(self): + # Test that generate validates text input + try: + model = Qwen3() + model.model = object() # Mock to bypass load check + model.tokenizer = object() + model.device = "cpu" + with self.assertRaises(ValueError): + model.generate("") + except ImportError: + # Skip if dependencies not installed + self.skipTest("Qwen3 dependencies not installed") + + def test_qwen3_chat_empty_messages(self): + # Test that chat validates messages input + try: + model = Qwen3() + model.model = object() # Mock to bypass load check + model.tokenizer = object() + model.device = "cpu" + with self.assertRaises(ValueError): + model.chat([]) + except ImportError: + # Skip if dependencies not installed + self.skipTest("Qwen3 dependencies not installed")
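For reference, the Qwen3 wrapper added in pythainlp/lm/qwen3.py can be driven end to end roughly as below. This is a minimal sketch, assuming the optional extras introduced by this patch are installed (pip install 'pythainlp[qwen3]') and that the default Qwen/Qwen3-0.6B checkpoint is downloadable from the Hugging Face Hub or already cached; the prompts and sampling settings are illustrative only.

# Minimal usage sketch for the new pythainlp.lm.Qwen3 wrapper.
# Assumes: pip install 'pythainlp[qwen3]' and access to the
# Qwen/Qwen3-0.6B weights (downloaded on first use or cached locally).
import torch

from pythainlp.lm import Qwen3

model = Qwen3()

# CPU example as in the docstrings; on a GPU, pass device="cuda"
# and a dtype such as torch.float16.
model.load_model(device="cpu", torch_dtype=torch.bfloat16)

# Plain prompt completion.
print(model.generate("สวัสดี", max_new_tokens=64))

# Chat-style generation with role/content messages.
messages = [{"role": "user", "content": "สวัสดีครับ"}]
print(model.chat(messages, max_new_tokens=64, temperature=0.7))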