Skip to content

Commit

Permalink
Inpainting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Nico Hiort af Ornäs committed Dec 5, 2024
1 parent b0388c3 commit 9a71779
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/GetImgAiClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public function textToImage(TextToImageRequest $request): ImageResponse

public function inpaint(InpaintingRequest $request): ImageResponse
{
$response = $this->request('POST', '/' . $request->getModel() . '/inpaint', $request->toArray());
$response = $this->request('POST', '/' . $request->getFamily() . '/inpaint', $request->toArray());

return ImageResponse::fromArray($response);
}
Expand Down
41 changes: 28 additions & 13 deletions src/Request/InpaintingRequest.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class InpaintingRequest
private string $prompt;
private string $image;
private string $maskImage;
private string $family;
private string $model;
private ?string $negativePrompt;
private ?string $prompt2;
Expand All @@ -28,6 +29,7 @@ public function __construct(
string $prompt,
string $image,
string $maskImage,
string $family = 'stable-diffusion-xl',
string $model = 'stable-diffusion-xl-v1-0',
?string $negativePrompt = null,
?string $prompt2 = null,
Expand All @@ -44,6 +46,7 @@ public function __construct(
$this->setPrompt($prompt);
$this->setImage($image);
$this->setMaskImage($maskImage);
$this->setFamily($family);
$this->setModel($model);
$this->setNegativePrompt($negativePrompt);
$this->setPrompt2($prompt2);
Expand All @@ -68,23 +71,23 @@ public function setPrompt(string $prompt): void

public function setNegativePrompt(?string $negativePrompt): void
{
if (strlen($negativePrompt) > 2048) {
if ($negativePrompt && strlen($negativePrompt) > 2048) {
throw new InvalidArgumentException('Negative prompt length must be ≤ 2048 characters.');
}
$this->negativePrompt = $negativePrompt;
}

public function setPrompt2(?string $prompt2): void
{
if (strlen($prompt2) > 2048) {
if ($prompt2 && strlen($prompt2) > 2048) {
throw new InvalidArgumentException('Prompt2 length must be ≤ 2048 characters.');
}
$this->prompt2 = $prompt2;
}

public function setNegativePrompt2(?string $negativePrompt2): void
{
if (strlen($negativePrompt2) > 2048) {
if ($negativePrompt2 && strlen($negativePrompt2) > 2048) {
throw new InvalidArgumentException('Negative prompt2 length must be ≤ 2048 characters.');
}
$this->negativePrompt2 = $negativePrompt2;
Expand All @@ -110,6 +113,14 @@ public function setModel(string $model): void
$this->model = $model;
}

public function setFamily(string $family): void
{
if (!in_array($family, ['stable-diffusion-xl'], true)) {
throw new InvalidArgumentException('Unsupported/untested family.');
}
$this->family = $family;
}

public function setStrength(float $strength): void
{
if ($strength < 0 || $strength > 1) {
Expand All @@ -120,23 +131,23 @@ public function setStrength(float $strength): void

public function setWidth(int $width): void
{
if ($width < 256 || $width > 1280) {
throw new InvalidArgumentException('Width must be between 256 and 1280.');
if ($width < 256 || $width > 1536) {
throw new InvalidArgumentException('Width must be between 256 and 1536.');
}
$this->width = $width;
}

public function setHeight(int $height): void
{
if ($height < 256 || $height > 1280) {
throw new InvalidArgumentException('Height must be between 256 and 1280.');
if ($height < 256 || $height > 1536) {
throw new InvalidArgumentException('Height must be between 256 and 1536.');
}
$this->height = $height;
}

public function setSteps(int $steps): void
{
if ($steps < 1 || $steps > 6) {
if ($steps < 1 || $steps > 100) {
throw new InvalidArgumentException('Steps must be between 1 and 6.');
}
$this->steps = $steps;
Expand Down Expand Up @@ -177,12 +188,13 @@ public function setResponseFormat(string $responseFormat): void
public function toArray(): array
{
return array_filter([
'image' => $this->image,
'mask_image' => $this->maskImage,
'model' => $this->model,
'prompt' => $this->prompt,
'negative_prompt' => $this->negativePrompt,
'prompt2' => $this->prompt2,
'negative_prompt2' => $this->negativePrompt2,
'image' => $this->image,
'mask_image' => $this->maskImage,
'strength' => $this->strength,
'width' => $this->width,
'height' => $this->height,
Expand All @@ -191,11 +203,14 @@ public function toArray(): array
'seed' => $this->seed,
'output_format' => $this->outputFormat,
'response_format' => $this->responseFormat,
]);
], function ($value) {
// Keep all values except null
return $value !== null;
});
}

public function getModel(): string
public function getFamily(): string
{
return $this->model;
return $this->family;
}
}

0 comments on commit 9a71779

Please sign in to comment.