|
50 | 50 | def createFinetuneRequest( |
51 | 51 | model_limits: FinetuneTrainingLimits, |
52 | 52 | training_file: str, |
53 | | - model: str, |
| 53 | + model: str | None = None, |
54 | 54 | n_epochs: int = 1, |
55 | 55 | validation_file: str | None = "", |
56 | 56 | n_evals: int | None = 0, |
@@ -237,7 +237,7 @@ def create( |
237 | 237 | self, |
238 | 238 | *, |
239 | 239 | training_file: str, |
240 | | - model: str, |
| 240 | + model: str | None = None, |
241 | 241 | n_epochs: int = 1, |
242 | 242 | validation_file: str | None = "", |
243 | 243 | n_evals: int | None = 0, |
@@ -270,7 +270,7 @@ def create( |
270 | 270 |
|
271 | 271 | Args: |
272 | 272 | training_file (str): File-ID of a file uploaded to the Together API |
273 | | - model (str): Name of the base model to run fine-tune job on |
| 273 | + model (str, optional): Name of the base model to run fine-tune job on |
274 | 274 | n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. |
275 | 275 | validation file (str, optional): File ID of a file uploaded to the Together API for validation. |
276 | 276 | n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. |
@@ -320,12 +320,24 @@ def create( |
320 | 320 | FinetuneResponse: Object containing information about fine-tuning job. |
321 | 321 | """ |
322 | 322 |
|
| 323 | + if model is None and from_checkpoint is None: |
| 324 | + raise ValueError("You must specify either a model or a checkpoint") |
| 325 | + |
323 | 326 | requestor = api_requestor.APIRequestor( |
324 | 327 | client=self._client, |
325 | 328 | ) |
326 | 329 |
|
327 | 330 | if model_limits is None: |
328 | | - model_limits = self.get_model_limits(model=model) |
| 331 | + # mypy doesn't understand that model or from_checkpoint is not None |
| 332 | + if model is not None: |
| 333 | + model_name = model |
| 334 | + elif from_checkpoint is not None: |
| 335 | + model_name = from_checkpoint.split(":")[0] |
| 336 | + else: |
| 337 | + # this branch is unreachable, but mypy doesn't know that |
| 338 | + pass |
| 339 | + model_limits = self.get_model_limits(model=model_name) |
| 340 | + |
329 | 341 | finetune_request = createFinetuneRequest( |
330 | 342 | model_limits=model_limits, |
331 | 343 | training_file=training_file, |
@@ -610,7 +622,7 @@ async def create( |
610 | 622 | self, |
611 | 623 | *, |
612 | 624 | training_file: str, |
613 | | - model: str, |
| 625 | + model: str | None = None, |
614 | 626 | n_epochs: int = 1, |
615 | 627 | validation_file: str | None = "", |
616 | 628 | n_evals: int | None = 0, |
@@ -643,7 +655,7 @@ async def create( |
643 | 655 |
|
644 | 656 | Args: |
645 | 657 | training_file (str): File-ID of a file uploaded to the Together API |
646 | | - model (str): Name of the base model to run fine-tune job on |
| 658 | + model (str, optional): Name of the base model to run fine-tune job on |
647 | 659 | n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. |
648 | 660 | validation file (str, optional): File ID of a file uploaded to the Together API for validation. |
649 | 661 | n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. |
@@ -693,12 +705,23 @@ async def create( |
693 | 705 | FinetuneResponse: Object containing information about fine-tuning job. |
694 | 706 | """ |
695 | 707 |
|
| 708 | + if model is None and from_checkpoint is None: |
| 709 | + raise ValueError("You must specify either a model or a checkpoint") |
| 710 | + |
696 | 711 | requestor = api_requestor.APIRequestor( |
697 | 712 | client=self._client, |
698 | 713 | ) |
699 | 714 |
|
700 | 715 | if model_limits is None: |
701 | | - model_limits = await self.get_model_limits(model=model) |
| 716 | + # mypy doesn't understand that model or from_checkpoint is not None |
| 717 | + if model is not None: |
| 718 | + model_name = model |
| 719 | + elif from_checkpoint is not None: |
| 720 | + model_name = from_checkpoint.split(":")[0] |
| 721 | + else: |
| 722 | + # this branch is unreachable, but mypy doesn't know that |
| 723 | + pass |
| 724 | + model_limits = await self.get_model_limits(model=model_name) |
702 | 725 |
|
703 | 726 | finetune_request = createFinetuneRequest( |
704 | 727 | model_limits=model_limits, |
|
0 commit comments