Skip to content

Commit

Permalink
feat: add 4k support (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
baptadn authored Jan 14, 2023
1 parent 69f7776 commit a3d64cc
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 23 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ REPLICATE_API_TOKEN=
REPLICATE_USERNAME=
REPLICATE_MAX_TRAIN_STEPS=3000
REPLICATE_NEGATIVE_PROMPT="cropped face, cover face, cover visage, mutated hands"
REPLICATE_HD_VERSION_MODEL_ID=
NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN=
SECRET=
EMAIL_FROM=
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ REPLICATE_API_TOKEN=
REPLICATE_USERNAME=
REPLICATE_MAX_TRAIN_STEPS=3000
REPLICATE_NEGATIVE_PROMPT=
REPLICATE_HD_VERSION_MODEL_ID=
// Replicate instance token (should be rare)
NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN=
Expand Down
5 changes: 5 additions & 0 deletions prisma/migrations/20230114145652_hd_status/migration.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- CreateEnum
CREATE TYPE "HdStatus" AS ENUM ('NO', 'PENDING', 'PROCESSED');

-- AlterTable
ALTER TABLE "Shot" ADD COLUMN "hdStatus" "HdStatus" NOT NULL DEFAULT 'NO';
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Shot" ADD COLUMN "hdPredictionId" TEXT;
2 changes: 2 additions & 0 deletions prisma/migrations/20230114160322_url_hd/migration.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Shot" ADD COLUMN "hdOutputUrl" TEXT;
11 changes: 10 additions & 1 deletion prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ generator client {
provider = "prisma-client-js"
}

enum HdStatus {
NO
PENDING
PROCESSED
}

model Account {
id String @id @default(cuid())
userId String @map("user_id")
Expand Down Expand Up @@ -76,7 +82,7 @@ model Project {
userId String?
shots Shot[]
credits Int @default(100)
promptWizardCredits Int @default(30)
promptWizardCredits Int @default(20)
Payment Payment[]
}

Expand All @@ -93,6 +99,9 @@ model Shot {
bookmarked Boolean? @default(false)
blurhash String?
seed Int?
hdStatus HdStatus @default(NO)
hdPredictionId String?
hdOutputUrl String?
}

model Payment {
Expand Down
2 changes: 1 addition & 1 deletion src/components/home/Pricing.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ const Pricing = () => {
<b>1</b> Studio with a <b>custom trained model</b>
</CheckedListItem>
<CheckedListItem>
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars 4K
generation
</CheckedListItem>
<CheckedListItem>
Expand Down
4 changes: 2 additions & 2 deletions src/components/projects/FormPayment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ const FormPayment = ({
<b>1</b> Studio with a <b>custom trained model</b>
</CheckedListItem>
<CheckedListItem>
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars
generation (512x512 resolution)
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars 4K
generation
</CheckedListItem>
<CheckedListItem>
<b>30</b> AI prompt assists
Expand Down
112 changes: 95 additions & 17 deletions src/components/projects/shot/ShotCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,27 @@ import { BsHeart, BsHeartFill } from "react-icons/bs";
import { HiDownload } from "react-icons/hi";
import { IoMdCheckmarkCircleOutline } from "react-icons/io";
import { MdOutlineModelTraining } from "react-icons/md";
import { Ri4KFill } from "react-icons/ri";
import { useMutation, useQuery } from "react-query";
import ShotImage from "./ShotImage";
import { TbFaceIdError } from "react-icons/tb";

const getHdLabel = (shot: Shot, isHd: boolean) => {
if (shot.hdStatus === "NO") {
return "Generate in 4K";
}

if (shot.hdStatus === "PENDING") {
return "4K in progress";
}

if (shot.hdStatus === "PROCESSED" && isHd) {
return "Show standard resolution";
}

return "Show 4K";
};

const ShotCard = ({
shot: initialShot,
handleSeed,
Expand All @@ -37,6 +54,7 @@ const ShotCard = ({
const { onCopy, hasCopied } = useClipboard(initialShot.prompt);
const { query } = useRouter();
const [shot, setShot] = useState(initialShot);
const [isHd, setIsHd] = useState(Boolean(shot.hdOutputUrl));

const { mutate: bookmark, isLoading } = useMutation(
`update-shot-${initialShot.id}`,
Expand All @@ -54,6 +72,19 @@ const ShotCard = ({
}
);

const { mutate: createdHd, isLoading: isCreatingHd } = useMutation(
`create-hd-${initialShot.id}`,
() =>
axios.post<{ shot: Shot }>(
`/api/projects/${query.id}/predictions/${initialShot.id}/hd`
),
{
onSuccess: (response) => {
setShot(response.data.shot);
},
}
);

useQuery(
`shot-${initialShot.id}`,
() =>
Expand All @@ -65,10 +96,33 @@ const ShotCard = ({
{
refetchInterval: (data) => (data?.shot.outputUrl ? false : 5000),
refetchOnWindowFocus: false,
enabled: !initialShot.outputUrl,
enabled: !initialShot.outputUrl && initialShot.status !== "failed",
initialData: { shot: initialShot },
onSuccess: (response) => {
setShot(response.shot);
},
}
);

useQuery(
`shot-hd-${initialShot.id}`,
() =>
axios
.get<{ shot: Shot }>(
`/api/projects/${query.id}/predictions/${initialShot.id}/hd`
)
.then((res) => res.data),
{
refetchInterval: (data) =>
data?.shot.hdStatus !== "PENDING" ? false : 5000,
refetchOnWindowFocus: false,
enabled: shot.hdStatus === "PENDING",
initialData: { shot: initialShot },
onSuccess: (response) => {
setShot(response.shot);
if (response.shot.hdOutputUrl) {
setIsHd(true);
}
},
}
);
Expand All @@ -82,7 +136,7 @@ const ShotCard = ({
position="relative"
>
{shot.outputUrl ? (
<ShotImage shot={shot} />
<ShotImage isHd={isHd} shot={shot} />
) : (
<Box>
<AspectRatio ratio={1}>
Expand All @@ -104,10 +158,7 @@ const ShotCard = ({
</Box>
)}
<Flex position="relative" p={3} flexDirection="column">
<Flex alignItems="center" justifyContent="space-between">
<Text color="blackAlpha.700" fontSize="xs">
{formatRelative(new Date(shot.createdAt), new Date())}
</Text>
<Flex alignItems="center" justifyContent="flex-end">
<Box>
{shot.seed && shot.outputUrl && (
<Tooltip hasArrow label="Re-use style">
Expand All @@ -129,17 +180,41 @@ const ShotCard = ({
</Tooltip>
)}
{shot.outputUrl && (
<IconButton
size="sm"
as={Link}
href={shot.outputUrl}
target="_blank"
variant="ghost"
aria-label="Download"
fontSize="md"
icon={<HiDownload />}
/>
<>
<IconButton
size="sm"
as={Link}
href={isHd ? shot.hdOutputUrl : shot.outputUrl}
target="_blank"
variant="ghost"
aria-label="Download"
fontSize="md"
icon={<HiDownload />}
/>
<Tooltip hasArrow label={getHdLabel(shot, isHd)}>
<IconButton
icon={<Ri4KFill />}
color={isHd ? "red.400" : "gray.600"}
isLoading={shot.hdStatus === "PENDING" || isCreatingHd}
onClick={() => {
if (shot.hdStatus === "NO") {
createdHd();
} else if (
shot.hdStatus === "PROCESSED" &&
shot.hdOutputUrl
) {
setIsHd(!isHd);
}
}}
size="sm"
variant="ghost"
aria-label="Make 4K"
fontSize="lg"
/>
</Tooltip>
</>
)}

<Tooltip
hasArrow
label={`${shot.bookmarked ? "Remove" : "Add"} to your gallery`}
Expand Down Expand Up @@ -168,7 +243,10 @@ const ShotCard = ({
{shot.prompt}
</Text>

<HStack mt={4}>
<HStack justifyContent="space-between" mt={4}>
<Text color="beige.400" fontSize="xs">
{formatRelative(new Date(shot.createdAt), new Date())}
</Text>
<Button
rightIcon={hasCopied ? <IoMdCheckmarkCircleOutline /> : undefined}
colorScheme="beige"
Expand Down
4 changes: 2 additions & 2 deletions src/components/projects/shot/ShotImage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { useRouter } from "next/router";
import React from "react";
import { Controlled as ControlledZoom } from "react-medium-image-zoom";

const ShotImage = ({ shot }: { shot: Shot }) => {
const ShotImage = ({ shot, isHd = false }: { shot: Shot; isHd?: boolean }) => {
const { push, query } = useRouter();
const { onOpen, onClose, isOpen: isZoomed } = useDisclosure();

Expand Down Expand Up @@ -34,7 +34,7 @@ const ShotImage = ({ shot }: { shot: Shot }) => {
placeholder="blur"
blurDataURL={shot.blurhash || "placeholder"}
alt={shot.prompt}
src={shot.outputUrl!}
src={isHd ? shot.hdOutputUrl! : shot.outputUrl!}
width={512}
height={512}
unoptimized
Expand Down
75 changes: 75 additions & 0 deletions src/pages/api/projects/[id]/predictions/[predictionId]/hd.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import replicateClient from "@/core/clients/replicate";
import db from "@/core/db";
import { NextApiRequest, NextApiResponse } from "next";
import { getSession } from "next-auth/react";

const handler = async (req: NextApiRequest, res: NextApiResponse) => {
const projectId = req.query.id as string;
const predictionId = req.query.predictionId as string;

const session = await getSession({ req });

if (!session?.user) {
return res.status(401).json({ message: "Not authenticated" });
}

const project = await db.project.findFirstOrThrow({
where: { id: projectId, userId: session.userId },
});

let shot = await db.shot.findFirstOrThrow({
where: { projectId: project.id, id: predictionId },
});

if (req.method === "POST") {
if (shot.hdStatus !== "NO") {
return res.status(400).json({ message: "4K already applied" });
}

const { data } = await replicateClient.post(
`https://api.replicate.com/v1/predictions`,
{
input: {
image: shot.outputUrl,
upscale: 8,
face_upsample: true,
codeformer_fidelity: 1,
},
version: process.env.REPLICATE_HD_VERSION_MODEL_ID,
}
);

shot = await db.shot.update({
where: { id: shot.id },
data: { hdStatus: "PENDING", hdPredictionId: data.id },
});

return res.json({ shot });
}

if (req.method === "GET") {
if (shot.hdStatus !== "PENDING") {
return res.status(400).json({ message: "4K already applied" });
}

const { data: prediction } = await replicateClient.get(
`https://api.replicate.com/v1/predictions/${shot.hdPredictionId}`
);

if (prediction.output) {
shot = await db.shot.update({
where: { id: shot.id },
data: {
hdStatus: "PROCESSED",
hdOutputUrl: prediction.output,
},
});
}

return res.json({ shot });
}

return res.status(405).json({ message: "Method not allowed" });
};

export default handler;

1 comment on commit a3d64cc

@vercel
Copy link

@vercel vercel bot commented on a3d64cc Jan 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.