Commit ca96001

Fix model_name requirements
1 parent b24a5c3 commit ca96001

File tree: 3 files changed (+40 −10 lines)

  pyproject.toml
  src/together/cli/api/finetune.py
  src/together/resources/finetune.py

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.0"
+version = "1.5.1"
 authors = [
     "Together AI <support@together.ai>"
 ]

src/together/cli/api/finetune.py

Lines changed: 9 additions & 2 deletions

@@ -60,7 +60,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--training-file", type=str, required=True, help="Training file ID from Files API"
 )
-@click.option("--model", type=str, required=True, help="Base model name")
+@click.option("--model", type=str, help="Base model name")
 @click.option("--n-epochs", type=int, default=1, help="Number of epochs to train for")
 @click.option(
     "--validation-file", type=str, default="", help="Validation file ID from Files API"
@@ -214,8 +214,15 @@ def create(
         from_checkpoint=from_checkpoint,
     )
 
+    if model is None and from_checkpoint is None:
+        raise click.BadParameter("You must specify either a model or a checkpoint")
+
+    model_name = model
+    if from_checkpoint is not None:
+        model_name = from_checkpoint.split(":")[0]
+
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
-        model=model
+        model=model_name
     )
 
     if lora:
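
Both the CLI change above and the library change below enforce the same rule: the model argument may now be omitted as long as a checkpoint is supplied, and the name used for the model-limits lookup is derived from the checkpoint ID. A standalone sketch of that resolution rule; the helper name resolve_model_name is hypothetical and does not exist in the codebase:

# Illustrative sketch of the resolution rule introduced by this commit.
def resolve_model_name(model: str | None, from_checkpoint: str | None) -> str:
    if model is None and from_checkpoint is None:
        raise ValueError("You must specify either a model or a checkpoint")
    if model is not None:
        return model
    # A checkpoint ID may carry a ":"-separated suffix; only the part before
    # the first ":" is used as the model name for the limits lookup.
    return from_checkpoint.split(":")[0]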

src/together/resources/finetune.py

Lines changed: 30 additions & 7 deletions

@@ -50,7 +50,7 @@
 def createFinetuneRequest(
     model_limits: FinetuneTrainingLimits,
     training_file: str,
-    model: str,
+    model: str | None = None,
     n_epochs: int = 1,
     validation_file: str | None = "",
     n_evals: int | None = 0,
@@ -237,7 +237,7 @@ def create(
         self,
         *,
         training_file: str,
-        model: str,
+        model: str | None = None,
         n_epochs: int = 1,
         validation_file: str | None = "",
         n_evals: int | None = 0,
@@ -270,7 +270,7 @@ def create(
 
         Args:
             training_file (str): File-ID of a file uploaded to the Together API
-            model (str): Name of the base model to run fine-tune job on
+            model (str, optional): Name of the base model to run fine-tune job on
             n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
             validation file (str, optional): File ID of a file uploaded to the Together API for validation.
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
@@ -320,12 +320,24 @@ def create(
             FinetuneResponse: Object containing information about fine-tuning job.
         """
 
+        if model is None and from_checkpoint is None:
+            raise ValueError("You must specify either a model or a checkpoint")
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
         if model_limits is None:
-            model_limits = self.get_model_limits(model=model)
+            # mypy doesn't understand that model or from_checkpoint is not None
+            if model is not None:
+                model_name = model
+            elif from_checkpoint is not None:
+                model_name = from_checkpoint.split(":")[0]
+            else:
+                # this branch is unreachable, but mypy doesn't know that
+                pass
+            model_limits = self.get_model_limits(model=model_name)
+
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
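
With model now optional in create(), a fine-tuning job can be started from an existing checkpoint without naming a base model. A minimal usage sketch, assuming the standard Together client setup; the file and checkpoint IDs are placeholders:

from together import Together

client = Together()  # expects TOGETHER_API_KEY in the environment

# model is omitted; the model name is derived from the checkpoint ID
job = client.fine_tuning.create(
    training_file="file-abc123",   # placeholder training file ID
    from_checkpoint="ft-xyz789",   # placeholder checkpoint ID
)

# Omitting both model and from_checkpoint now raises ValueError.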
@@ -610,7 +622,7 @@ async def create(
         self,
         *,
         training_file: str,
-        model: str,
+        model: str | None = None,
         n_epochs: int = 1,
         validation_file: str | None = "",
         n_evals: int | None = 0,
@@ -643,7 +655,7 @@ async def create(
 
         Args:
             training_file (str): File-ID of a file uploaded to the Together API
-            model (str): Name of the base model to run fine-tune job on
+            model (str, optional): Name of the base model to run fine-tune job on
             n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
             validation file (str, optional): File ID of a file uploaded to the Together API for validation.
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
@@ -693,12 +705,23 @@ async def create(
             FinetuneResponse: Object containing information about fine-tuning job.
         """
 
+        if model is None and from_checkpoint is None:
+            raise ValueError("You must specify either a model or a checkpoint")
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
        )
 
         if model_limits is None:
-            model_limits = await self.get_model_limits(model=model)
+            # mypy doesn't understand that model or from_checkpoint is not None
+            if model is not None:
+                model_name = model
+            elif from_checkpoint is not None:
+                model_name = from_checkpoint.split(":")[0]
+            else:
+                # this branch is unreachable, but mypy doesn't know that
+                pass
+            model_limits = await self.get_model_limits(model=model_name)
 
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
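
The async client follows the same path. A sketch assuming the AsyncTogether client exported by the package, again with placeholder IDs:

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    # As in the sync version, model may be omitted when from_checkpoint is given
    job = await client.fine_tuning.create(
        training_file="file-abc123",   # placeholder training file ID
        from_checkpoint="ft-xyz789",   # placeholder checkpoint ID
    )
    print(job)


asyncio.run(main())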
