Skip to content

Commit

Permalink
Fix canvas context bug and update model loading
Browse files Browse the repository at this point in the history
logic
  • Loading branch information
salim laimeche committed Jul 27, 2024
1 parent e8b76a7 commit 88ed0b6
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 50 deletions.
1 change: 1 addition & 0 deletions app/lib/cocossd/detect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export const cocossdVideoInference = async (
if (context) {
context.drawImage(video, 0, 0, video.videoWidth, video.videoHeight)
const predictions = await net.detect(video)
console.log(predictions)
predictions.forEach(prediction => {
const [x, y, width, height] = prediction.bbox
context.strokeStyle = "#00FFFF"
Expand Down
4 changes: 2 additions & 2 deletions app/lib/yolov8n/detect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ export const detectVideo = (vidSource, model, canvasRef) => {
*/
const detectFrame = async () => {
if (vidSource.videoWidth === 0 && vidSource.srcObject === null) {
const ctx = canvasRef.getContext("2d")
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height) // clean canvas
const ctx = canvasRef?.getContext("2d") // get canvas context
ctx?.clearRect(0, 0, ctx?.canvas.width, ctx?.canvas.height) // clean canvas
return // handle if source is closed
}

Expand Down
8 changes: 4 additions & 4 deletions app/lib/yolov8n/renderBox.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ export const renderBoxes = (
const score = (scores_data[i] * 100).toFixed(1)

let [y1, x1, y2, x2] = boxes_data.slice(i * 4, (i + 1) * 4)
x1 *= ratios[0]
x2 *= ratios[0]
y1 *= ratios[1]
y2 *= ratios[1]
// x1 *= ratios[0]
// x2 *= ratios[0]
// y1 *= ratios[1]
// y2 *= ratios[1]
const width = x2 - x1
const height = y2 - y1

Expand Down
6 changes: 1 addition & 5 deletions app/video-inference/page.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { redirect } from "next/navigation"
import { UserView } from "../lib/identity/definition"
import { verifySession } from "../lib/identity/session-local"
import VideoInference from "@/components/VideoInference"

export default async function VideoInferencePage() {
const session = await verifySession()

const user: UserView = {
id: session?.userId as string,
name: session?.name as string,
Expand All @@ -13,9 +13,5 @@ export default async function VideoInferencePage() {
container: session?.container as string,
}

if (!user.chatid) {
redirect("/parameter/telegram")
}

return <VideoInference user={user} />
}
77 changes: 38 additions & 39 deletions components/VideoInference.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import { useEffect, useState, useRef } from "react"
import { ModelComputerVision } from "@/models/model-list"
import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar"
import { Badge } from "./ui/badge"
import { set } from "date-fns"
import { cocossdVideoInference } from "@/app/lib/cocossd/detect"
import { detectVideo } from "@/app/lib/yolov8n/detect"

Expand All @@ -27,12 +26,11 @@ interface IProps {
}

export default function VideoInference({ user }: IProps) {
const [model, setModel] = useState<ObjectDetection | null>(null)
const [coco, setCoco] = useState<ObjectDetection | null>(null)
const [yolo, setYolo] = useState({
net: null,
inputShape: [1, 0, 0, 3],
}) // init model & input shape
const [loading, setLoading] = useState({ loading: true }) // loading state
const [modelName, setModelName] = useState<string>("")
const [loadModel, setLoadModel] = useState<boolean>(false)
const [videoSrc, setVideoSrc] = useState<string | null>(null)
Expand All @@ -49,45 +47,46 @@ export default function VideoInference({ user }: IProps) {
}, [])

useEffect(() => {
console.log("model", modelName)
if (modelName === ModelComputerVision.COCO_SSD) {
yolo?.net?.dispose()
setLoadModel(true)
load()
.then(loadedModel => setModel(loadedModel))
.catch(err => console.error(err))
.finally(() => setLoadModel(false))
}
tf.getBackend() !== "webgl" && tf.setBackend("webgl")
tf.ready().then(() => {
console.log("model", modelName)
if (modelName === ModelComputerVision.COCO_SSD) {
yolo?.net?.dispose()
setLoadModel(true)
load()
.then(loadedModel => setCoco(loadedModel))
.catch(err => console.error(err))
.finally(() => setLoadModel(false))
}

if (modelName === ModelComputerVision.YOLOV8N) {
model?.dispose()
setLoadModel(true)
setModel(null)
tf.ready().then(async () => {
const yolov8 = await tf.loadGraphModel(
`https://huggingface.co/salim4n/yolov8n_web_model/resolve/main/model.json`,
{
onProgress: fractions => {
setLoading({ loading: true }) // set loading fractions
console.log(`Loading YOLOv8n: ${fractions * 100}%`)
},
}
) // load model
if (modelName === ModelComputerVision.YOLOV8N) {
coco?.dispose()
setLoadModel(true)
setCoco(null)
tf.ready().then(async () => {
const yolov8 = await tf.loadGraphModel(
`https://huggingface.co/salim4n/yolov8n_web_model/resolve/main/model.json`,
{
onProgress: fractions => {
console.log(`Loading YOLOv8n: ${fractions * 100}%`)
},
}
) // load model

// warming up model
const dummyInput = tf.ones(yolov8.inputs[0].shape)
const warmupResults = yolov8.execute(dummyInput)
// warming up model
const dummyInput = tf.ones(yolov8.inputs[0].shape)
const warmupResults = yolov8.execute(dummyInput)

setLoading({ loading: false })
setYolo({
net: yolov8,
inputShape: yolov8.inputs[0].shape,
}) // set model & input shape
setYolo({
net: yolov8,
inputShape: yolov8.inputs[0].shape,
}) // set model & input shape

tf.dispose([warmupResults, dummyInput]) // cleanup memory
setLoadModel(false)
})
}
tf.dispose([warmupResults, dummyInput]) // cleanup memory
setLoadModel(false)
})
}
})
}, [modelName])

const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
Expand All @@ -100,7 +99,7 @@ export default function VideoInference({ user }: IProps) {

const handleCreateVideoWithBoundingBox = () => {
if (modelName === ModelComputerVision.COCO_SSD) {
cocossdVideoInference(model, videoRef, canvasRef)
cocossdVideoInference(coco, videoRef, canvasRef)
} else if (modelName === ModelComputerVision.YOLOV8N) {
detectVideo(videoRef.current, yolo, canvasRef)
}
Expand Down

0 comments on commit 88ed0b6

Please sign in to comment.