Skip to content

Commit

Permalink
feat: use fix token for prompted (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
baptadn authored Dec 18, 2022
1 parent 35a8f7b commit d1fe49c
Show file tree
Hide file tree
Showing 13 changed files with 58 additions and 63 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ S3_UPLOAD_BUCKET=
S3_UPLOAD_REGION=
REPLICATE_API_TOKEN=
REPLICATE_USERNAME=
REPLICATE_MAX_TRAIN_STEPS=3000
NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN=
SECRET=
EMAIL_FROM=
EMAIL_SERVER=smtp://localhost:1080
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ S3_UPLOAD_REGION=
// Replicate API token / username
REPLICATE_API_TOKEN=
REPLICATE_USERNAME=
REPLICATE_MAX_TRAIN_STEPS=3000
// Replicate instance token (should be rare)
NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN=
// Random secret for NextAuth
SECRET=
Expand Down
1 change: 0 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
"superjson": "^1.12.0",
"typescript": "4.9.3",
"uniqid": "^5.4.0",
"url-slug": "^3.0.4",
"windups": "^1.2.1"
},
"devDependencies": {
Expand Down
23 changes: 8 additions & 15 deletions src/components/dashboard/Uploader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ const Uploader = ({ handleOnAdd }: { handleOnAdd: () => void }) => {
const [uploadState, setUploadState] = useState<TUploadState>("not_uploaded");
const [errorMessages, setErrorMessages] = useState<string[]>([]);
const [urls, setUrls] = useState<string[]>([]);
const [instanceName, setInstanceName] = useState<string>("");
const [studioName, setStudioName] = useState<string>("");
const [instanceClass, setInstanceClass] = useState<string>("man");
const { uploadToS3 } = useS3Upload();
const toast = useToast();
Expand Down Expand Up @@ -102,7 +102,7 @@ const Uploader = ({ handleOnAdd }: { handleOnAdd: () => void }) => {
() =>
axios.post("/api/projects", {
urls,
instanceName,
studioName,
instanceClass,
}),
{
Expand All @@ -112,7 +112,7 @@ const Uploader = ({ handleOnAdd }: { handleOnAdd: () => void }) => {
// Reset
setFiles([]);
setUrls([]);
setInstanceName("");
setStudioName("");
setInstanceClass("");
setUploadState("not_uploaded");

Expand Down Expand Up @@ -262,17 +262,10 @@ const Uploader = ({ handleOnAdd }: { handleOnAdd: () => void }) => {
<Input
isRequired
backgroundColor="white"
placeholder="Subject name"
value={instanceName}
onChange={(e) => setInstanceName(e.currentTarget.value)}
placeholder="Studio name"
value={studioName}
onChange={(e) => setStudioName(e.currentTarget.value)}
/>
<FormHelperText color="blackAlpha.600">
This name will be use to name your person in your prompt:{" "}
<b>{`Painting of ${
instanceName || "Alice"
} ${instanceClass} by Andy Warhol`}</b>
.
</FormHelperText>
</FormControl>
<FormControl>
<Select
Expand All @@ -293,12 +286,12 @@ const Uploader = ({ handleOnAdd }: { handleOnAdd: () => void }) => {
</FormControl>
<Box>
<Button
disabled={!Boolean(instanceName)}
disabled={!Boolean(studioName)}
isLoading={isLoading}
variant="brand"
rightIcon={<MdCheckCircle />}
onClick={() => {
if (instanceName && instanceClass) {
if (studioName && instanceClass) {
handleCreateProject();
}
}}
Expand Down
3 changes: 2 additions & 1 deletion src/components/projects/ProjectCard.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { getRefinedStudioName } from "@/core/utils/projects";
import { ProjectWithShots } from "@/pages/studio/[id]";
import {
Avatar,
Expand Down Expand Up @@ -68,7 +69,7 @@ const ProjectCard = ({
<Flex width="100%">
<Box flex="1">
<Text fontSize="2xl" fontWeight="semibold">
Studio <b>{project.instanceName}</b>{" "}
Studio <b>{getRefinedStudioName(project)}</b>{" "}
{isReady && (
<Badge colorScheme="teal">{project.credits} shots left</Badge>
)}
Expand Down
15 changes: 3 additions & 12 deletions src/components/projects/PromptPanel.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import BuyShotButton from "@/components/projects/shot/BuyShotButton";
import { getRefinedInstanceClass } from "@/core/utils/predictions";
import useProjectContext from "@/hooks/use-project-context";
import {
Box,
Expand Down Expand Up @@ -85,11 +84,7 @@ const PromptPanel = () => {
focusBorderColor="gray.400"
_focus={{ shadow: "md" }}
mr={2}
placeholder={`a portrait of a ${
project.instanceName
} ${getRefinedInstanceClass(
project.instanceClass
)} as an astronaut, highly-detailed, trending on artstation`}
placeholder={`a portrait of a @me as an astronaut, highly-detailed, trending on artstation`}
/>

<Button
Expand Down Expand Up @@ -139,12 +134,8 @@ const PromptPanel = () => {
</HStack>
) : (
<Text fontSize="md">
<Icon as={BsLightbulb} /> Use the keyword{" "}
<b>
{project.instanceName}{" "}
{getRefinedInstanceClass(project.instanceClass)}
</b>{" "}
as the subject in your prompt. Need prompt inspiration? Check{" "}
<Icon as={BsLightbulb} /> Use the keyword <b>@me</b> as the subject in
your prompt. Need prompt inspiration? Check{" "}
<ChakraLink
textDecoration="underline"
isExternal
Expand Down
23 changes: 11 additions & 12 deletions src/core/utils/predictions.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import { Project } from "@prisma/client";

export const getRefinedInstanceClass = (instanceClass: string) => {
return instanceClass === "man" || instanceClass === "woman"
? "person"
: instanceClass;
};

export const getTrainCoefficient = (imagesCount: number) => {
if (imagesCount > 25) {
return 25;
}

if (imagesCount < 10) {
return 10;
}

return imagesCount;
};

export const extractSeedFromLogs = (logsContent: string) => {
try {
const logLines = logsContent.split("\n");
Expand All @@ -27,3 +17,12 @@ export const extractSeedFromLogs = (logsContent: string) => {
return undefined;
}
};

export const replacePromptToken = (prompt: string, project: Project) => {
const refinedPrompt = prompt.replaceAll(
"@me",
`${project.instanceName} ${getRefinedInstanceClass(project.instanceClass)}`
);

return refinedPrompt;
};
11 changes: 11 additions & 0 deletions src/core/utils/projects.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { Project } from "@prisma/client";

export const getRefinedStudioName = (project: Project) => {
if (
project.instanceName === process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN
) {
return project.name;
}

return project.instanceName;
};
7 changes: 7 additions & 0 deletions src/core/utils/prompts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export const getPrompts = (type: "viking", name: string) => {
const prompts = {
viking: `close up portrait of ${name} person as a viking, full visage, volumetric lighting, beautiful, golden hour, sharp focus, ultra detailed, cgsociety by leesha hannigan, ross tran, thierry doizon, kai carpenter, ignacio fernandez rios, noir photorealism, film`,
};

return prompts[type];
};
5 changes: 3 additions & 2 deletions src/pages/api/projects/[id]/predictions/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import replicateClient from "@/core/clients/replicate";
import db from "@/core/db";
import { replacePromptToken } from "@/core/utils/predictions";
import { NextApiRequest, NextApiResponse } from "next";
import { getSession } from "next-auth/react";

Expand All @@ -26,7 +27,7 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
`https://api.replicate.com/v1/predictions`,
{
input: {
prompt,
prompt: replacePromptToken(prompt, project),
...(seed && { seed }),
},
version: project.modelVersionId,
Expand All @@ -35,7 +36,7 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {

const shot = await db.shot.create({
data: {
prompt: data.input.prompt,
prompt,
replicateId: data.id,
status: "starting",
projectId: project.id,
Expand Down
14 changes: 4 additions & 10 deletions src/pages/api/projects/[id]/train.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ import db from "@/core/db";
import { NextApiRequest, NextApiResponse } from "next";
import { getSession } from "next-auth/react";
import replicateClient from "@/core/clients/replicate";
import {
getRefinedInstanceClass,
getTrainCoefficient,
} from "@/core/utils/predictions";
import { getRefinedInstanceClass } from "@/core/utils/predictions";

const handler = async (req: NextApiRequest, res: NextApiResponse) => {
const projectId = req.query.id as string;
Expand All @@ -25,20 +22,17 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
});

const instanceClass = getRefinedInstanceClass(project.instanceClass);
const trainCoefficient = getTrainCoefficient(project.imageUrls.length);

const responseReplicate = await replicateClient.post(
"/v1/trainings",
{
input: {
instance_prompt: `a photo of a ${project.instanceName} ${instanceClass}`,
instance_prompt: `a photo of a ${process.env.INSTANCE_TOKEN} ${instanceClass}`,
class_prompt: `a photo of a ${instanceClass}`,
instance_data: `https://${process.env.S3_UPLOAD_BUCKET}.s3.amazonaws.com/${project.id}.zip`,
max_train_steps: trainCoefficient * 80,
num_class_images: trainCoefficient * 12,
lr_warmup_steps: Math.round((trainCoefficient * 80) / 10),
max_train_steps: Number(process.env.REPLICATE_MAX_TRAIN_STEPS),
num_class_images: 200,
learning_rate: 1e-6,
lr_scheduler: "polynomial",
},
model: `${process.env.REPLICATE_USERNAME}/${project.name}`,
webhook_completed: `${process.env.NEXTAUTH_URL}/api/webhooks/completed`,
Expand Down
8 changes: 3 additions & 5 deletions src/pages/api/projects/index.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import { NextApiRequest, NextApiResponse } from "next";
import { getSession } from "next-auth/react";
import db from "@/core/db";
import uniqid from "uniqid";
import { createZipFolder } from "@/core/utils/assets";
import s3Client from "@/core/clients/s3";
import { PutObjectCommand } from "@aws-sdk/client-s3";
import replicateClient from "@/core/clients/replicate";
import urlSlug from "url-slug";

const handler = async (req: NextApiRequest, res: NextApiResponse) => {
const session = await getSession({ req });
Expand All @@ -17,17 +15,17 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {

if (req.method === "POST") {
const urls = req.body.urls as string[];
const instanceName = req.body.instanceName as string;
const studioName = req.body.studioName as string;
const instanceClass = req.body.instanceClass as string;

const project = await db.project.create({
data: {
imageUrls: urls,
name: uniqid(),
name: studioName,
userId: session.userId,
modelStatus: "not_created",
instanceClass: instanceClass || "person",
instanceName: urlSlug(instanceName, { separator: "" }),
instanceName: process.env.NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN!,
credits: Number(process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT) || 50,
},
});
Expand Down
5 changes: 0 additions & 5 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4976,11 +4976,6 @@ uri-js@^4.2.2:
dependencies:
punycode "^2.1.0"

url-slug@^3.0.4:
version "3.0.4"
resolved "https://registry.yarnpkg.com/url-slug/-/url-slug-3.0.4.tgz#3278b556666389cd88d9210a12e7577dafac6d57"
integrity sha512-C880WJTo68O4J59i+w9Yp4P0iDUrMeCAVwNvHAYjSc59XsO4pkPfe8bsEHEyZrEfWWy8BDyxMQi6BPVHp/GEbg==

use-callback-ref@^1.3.0:
version "1.3.0"
resolved "https://registry.yarnpkg.com/use-callback-ref/-/use-callback-ref-1.3.0.tgz#772199899b9c9a50526fedc4993fc7fa1f7e32d5"
Expand Down

1 comment on commit d1fe49c

@vercel
Copy link

@vercel vercel bot commented on d1fe49c Dec 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.