diff --git a/google/genai/tests/tunings/test_tune.py b/google/genai/tests/tunings/test_tune.py
index 2212d5ee8..daf3f1356 100755
--- a/google/genai/tests/tunings/test_tune.py
+++ b/google/genai/tests/tunings/test_tune.py
@@ -20,6 +20,12 @@
 from .. import pytest_helper
 import pytest

+
+VERTEX_HTTP_OPTIONS = {
+    'api_version': 'v1beta1',
+    'base_url': 'https://us-central1-autopush-aiplatform.sandbox.googleapis.com/',
+}
+
 evaluation_config=genai_types.EvaluationConfig(
     metrics=[
         genai_types.Metric(name="bleu", prompt_template="test prompt template")
@@ -158,6 +164,26 @@
         ),
         exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
     ),
+    pytest_helper.TestTableItem(
+        name="test_tune_distillation",
+        parameters=genai_types.CreateTuningJobParameters(
+            base_model="meta/llama3_1@llama-3.1-8b-instruct",
+            training_dataset=genai_types.TuningDataset(
+                gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
+            ),
+            config=genai_types.CreateTuningJobConfig(
+                method="DISTILLATION",
+                base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
+                epoch_count=20,
+                validation_dataset=genai_types.TuningValidationDataset(
+                    gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
+                ),
+                output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder",
+                http_options=VERTEX_HTTP_OPTIONS,
+            ),
+        ),
+        exception_if_mldev="parameter is not supported in Gemini API.",
+    ),
 ]

 pytestmark = pytest_helper.setup(
diff --git a/google/genai/tunings.py b/google/genai/tunings.py
index f84cc455d..0d442705a 100644
--- a/google/genai/tunings.py
+++ b/google/genai/tunings.py
@@ -213,6 +213,24 @@ def _CreateTuningJobConfig_to_mldev(
   if getv(from_object, ['beta']) is not None:
     raise ValueError('beta parameter is not supported in Gemini API.')

+  if getv(from_object, ['base_teacher_model']) is not None:
+    raise ValueError(
+        'base_teacher_model parameter is not supported in Gemini API.'
+    )
+
+  if getv(from_object, ['tuned_teacher_model_source']) is not None:
+    raise ValueError(
+        'tuned_teacher_model_source parameter is not supported in Gemini API.'
+    )
+
+  if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
+    raise ValueError(
+        'sft_loss_weight_multiplier parameter is not supported in Gemini API.'
+    )
+
+  if getv(from_object, ['output_uri']) is not None:
+    raise ValueError('output_uri parameter is not supported in Gemini API.')
+
   return to_object


@@ -246,6 +264,16 @@ def _CreateTuningJobConfig_to_vertex(
           ),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['validation_dataset']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec'],
+          _TuningValidationDataset_to_vertex(
+              getv(from_object, ['validation_dataset']), to_object, root_object
+          ),
+      )
+
   if getv(from_object, ['tuned_model_display_name']) is not None:
     setv(
         parent_object,
@@ -275,6 +303,14 @@
           getv(from_object, ['epoch_count']),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['epoch_count']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'hyperParameters', 'epochCount'],
+          getv(from_object, ['epoch_count']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -298,6 +334,14 @@
           getv(from_object, ['learning_rate_multiplier']),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['learning_rate_multiplier']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
+          getv(from_object, ['learning_rate_multiplier']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -317,6 +361,14 @@
           getv(from_object, ['export_last_checkpoint_only']),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['export_last_checkpoint_only']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'exportLastCheckpointOnly'],
+          getv(from_object, ['export_last_checkpoint_only']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -336,6 +388,14 @@
           getv(from_object, ['adapter_size']),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['adapter_size']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'hyperParameters', 'adapterSize'],
+          getv(from_object, ['adapter_size']),
+      )
+
   if getv(from_object, ['batch_size']) is not None:
     raise ValueError('batch_size parameter is not supported in Vertex AI.')

@@ -365,6 +425,16 @@
           ),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['evaluation_config']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'evaluationConfig'],
+          _EvaluationConfig_to_vertex(
+              getv(from_object, ['evaluation_config']), to_object, root_object
+          ),
+      )
+
   if getv(from_object, ['labels']) is not None:
     setv(parent_object, ['labels'], getv(from_object, ['labels']))

@@ -375,6 +445,30 @@
         getv(from_object, ['beta']),
     )

+  if getv(from_object, ['base_teacher_model']) is not None:
+    setv(
+        parent_object,
+        ['distillationSpec', 'baseTeacherModel'],
+        getv(from_object, ['base_teacher_model']),
+    )
+
+  if getv(from_object, ['tuned_teacher_model_source']) is not None:
+    setv(
+        parent_object,
+        ['distillationSpec', 'tunedTeacherModelSource'],
+        getv(from_object, ['tuned_teacher_model_source']),
+    )
+
+  if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
+    setv(
+        parent_object,
+        ['distillationSpec', 'hyperParameters', 'sftLossWeightMultiplier'],
+        getv(from_object, ['sft_loss_weight_multiplier']),
+    )
+
+  if getv(from_object, ['output_uri']) is not None:
+    setv(parent_object, ['outputUri'], getv(from_object, ['output_uri']))
+
   return to_object


@@ -920,6 +1014,14 @@ def _TuningDataset_to_vertex(
           getv(from_object, ['gcs_uri']),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['gcs_uri']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'promptDatasetUri'],
+          getv(from_object, ['gcs_uri']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -939,6 +1041,14 @@
           getv(from_object, ['vertex_dataset_resource']),
       )

+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['vertex_dataset_resource']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'promptDatasetUri'],
+          getv(from_object, ['vertex_dataset_resource']),
+      )
+
   if getv(from_object, ['examples']) is not None:
     raise ValueError('examples parameter is not supported in Vertex AI.')

@@ -1066,6 +1176,13 @@ def _TuningJob_from_vertex(
         getv(from_object, ['preferenceOptimizationSpec']),
     )

+  if getv(from_object, ['distillationSpec']) is not None:
+    setv(
+        to_object,
+        ['distillation_spec'],
+        getv(from_object, ['distillationSpec']),
+    )
+
   if getv(from_object, ['tuningDataStats']) is not None:
     setv(
         to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats'])
diff --git a/google/genai/types.py b/google/genai/types.py
index e932e0373..0836d62fe 100644
--- a/google/genai/types.py
+++ b/google/genai/types.py
@@ -813,6 +813,8 @@ class TuningMethod(_common.CaseInSensitiveEnum):
   """Supervised fine tuning."""
   PREFERENCE_TUNING = 'PREFERENCE_TUNING'
   """Preference optimization tuning."""
+  DISTILLATION = 'DISTILLATION'
+  """Distillation tuning."""


 class DocumentState(_common.CaseInSensitiveEnum):
@@ -10546,6 +10548,107 @@ class PreferenceOptimizationSpecDict(TypedDict, total=False):
 ]


+class DistillationHyperParameters(_common.BaseModel):
+  """Hyperparameters for Distillation.
+
+  This data type is not supported in Gemini API.
+  """
+
+  adapter_size: Optional[AdapterSize] = Field(
+      default=None, description="""Optional. Adapter size for distillation."""
+  )
+  epoch_count: Optional[int] = Field(
+      default=None,
+      description="""Optional. Number of complete passes the model makes over the entire training dataset during training.""",
+  )
+  learning_rate_multiplier: Optional[float] = Field(
+      default=None,
+      description="""Optional. Multiplier for adjusting the default learning rate.""",
+  )
+
+
+class DistillationHyperParametersDict(TypedDict, total=False):
+  """Hyperparameters for Distillation.
+
+  This data type is not supported in Gemini API.
+  """
+
+  adapter_size: Optional[AdapterSize]
+  """Optional. Adapter size for distillation."""
+
+  epoch_count: Optional[int]
+  """Optional. Number of complete passes the model makes over the entire training dataset during training."""
+
+  learning_rate_multiplier: Optional[float]
+  """Optional. Multiplier for adjusting the default learning rate."""
+
+
+DistillationHyperParametersOrDict = Union[
+    DistillationHyperParameters, DistillationHyperParametersDict
+]
+
+
+class DistillationSpec(_common.BaseModel):
+  """Distillation tuning spec for tuning."""
+
+  base_teacher_model: Optional[str] = Field(
+      default=None,
+      description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""",
+  )
+  hyper_parameters: Optional[DistillationHyperParameters] = Field(
+      default=None,
+      description="""Optional. Hyperparameters for Distillation.""",
+  )
+  pipeline_root_directory: Optional[str] = Field(
+      default=None,
+      description="""Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""",
+  )
+  student_model: Optional[str] = Field(
+      default=None,
+      description="""The student model that is being tuned, e.g., "google/gemma-2b-1.1-it". Deprecated. Use base_model instead.""",
+  )
+  training_dataset_uri: Optional[str] = Field(
+      default=None,
+      description="""Deprecated. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""",
+  )
+  tuned_teacher_model_source: Optional[str] = Field(
+      default=None,
+      description="""The resource name of the Tuned teacher model. Format: `projects/{project}/locations/{location}/models/{model}`.""",
+  )
+  validation_dataset_uri: Optional[str] = Field(
+      default=None,
+      description="""Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""",
+  )
+
+
+class DistillationSpecDict(TypedDict, total=False):
+  """Distillation tuning spec for tuning."""
+
+  base_teacher_model: Optional[str]
+  """The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models)."""
+
+  hyper_parameters: Optional[DistillationHyperParametersDict]
+  """Optional. Hyperparameters for Distillation."""
+
+  pipeline_root_directory: Optional[str]
+  """Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts."""
+
+  student_model: Optional[str]
+  """The student model that is being tuned, e.g., "google/gemma-2b-1.1-it". Deprecated. Use base_model instead."""
+
+  training_dataset_uri: Optional[str]
+  """Deprecated. Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file."""
+
+  tuned_teacher_model_source: Optional[str]
+  """The resource name of the Tuned teacher model. Format: `projects/{project}/locations/{location}/models/{model}`."""
+
+  validation_dataset_uri: Optional[str]
+  """Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file."""
+
+
+DistillationSpecOrDict = Union[DistillationSpec, DistillationSpecDict]
+
+
 class GcsDestination(_common.BaseModel):
   """The Google Cloud Storage location where the output is to be written to."""

@@ -11742,6 +11845,9 @@ class TuningJob(_common.BaseModel):
   preference_optimization_spec: Optional[PreferenceOptimizationSpec] = Field(
       default=None, description="""Tuning Spec for Preference Optimization."""
   )
+  distillation_spec: Optional[DistillationSpec] = Field(
+      default=None, description="""Tuning Spec for Distillation."""
+  )
   tuning_data_stats: Optional[TuningDataStats] = Field(
       default=None,
       description="""Output only. The tuning data statistics associated with this TuningJob.""",
@@ -11845,6 +11951,9 @@ class TuningJobDict(TypedDict, total=False):
   preference_optimization_spec: Optional[PreferenceOptimizationSpecDict]
   """Tuning Spec for Preference Optimization."""

+  distillation_spec: Optional[DistillationSpecDict]
+  """Tuning Spec for Distillation."""
+
   tuning_data_stats: Optional[TuningDataStatsDict]
   """Output only. The tuning data statistics associated with this TuningJob."""

@@ -12133,7 +12242,7 @@ class CreateTuningJobConfig(_common.BaseModel):
   )
   method: Optional[TuningMethod] = Field(
       default=None,
-      description="""The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING). If not set, the default method (SFT) will be used.""",
+      description="""The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING or DISTILLATION). If not set, the default method (SFT) will be used.""",
   )
   validation_dataset: Optional[TuningValidationDataset] = Field(
       default=None,
@@ -12184,6 +12293,22 @@
       default=None,
       description="""Weight for KL Divergence regularization, Preference Optimization tuning only.""",
   )
+  base_teacher_model: Optional[str] = Field(
+      default=None,
+      description="""The base teacher model that is being distilled. Distillation only.""",
+  )
+  tuned_teacher_model_source: Optional[str] = Field(
+      default=None,
+      description="""The resource name of the Tuned teacher model. Distillation only.""",
+  )
+  sft_loss_weight_multiplier: Optional[float] = Field(
+      default=None,
+      description="""Multiplier for adjusting the weight of the SFT loss. Distillation only.""",
+  )
+  output_uri: Optional[str] = Field(
+      default=None,
+      description="""The Google Cloud Storage location where the tuning job outputs are written.""",
+  )


 class CreateTuningJobConfigDict(TypedDict, total=False):
@@ -12193,7 +12318,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
   """Used to override HTTP request options."""

   method: Optional[TuningMethod]
-  """The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING). If not set, the default method (SFT) will be used."""
+  """The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING or DISTILLATION). If not set, the default method (SFT) will be used."""

   validation_dataset: Optional[TuningValidationDatasetDict]
   """Validation dataset for tuning. The dataset must be formatted as a JSONL file."""
@@ -12234,6 +12359,18 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
   beta: Optional[float]
   """Weight for KL Divergence regularization, Preference Optimization tuning only."""

+  base_teacher_model: Optional[str]
+  """The base teacher model that is being distilled. Distillation only."""
+
+  tuned_teacher_model_source: Optional[str]
+  """The resource name of the Tuned teacher model. Distillation only."""
+
+  sft_loss_weight_multiplier: Optional[float]
+  """Multiplier for adjusting the weight of the SFT loss. Distillation only."""
+
+  output_uri: Optional[str]
+  """The Google Cloud Storage location where the tuning job outputs are written."""
+

 CreateTuningJobConfigOrDict = Union[
     CreateTuningJobConfig, CreateTuningJobConfigDict