diff --git a/src/blocks/file-blocks/sentence-encoder.tsx b/src/blocks/file-blocks/sentence-encoder.tsx index bc555f9..c074634 100644 --- a/src/blocks/file-blocks/sentence-encoder.tsx +++ b/src/blocks/file-blocks/sentence-encoder.tsx @@ -1,22 +1,27 @@ -import '@tensorflow/tfjs'; -import { FileBlockProps } from "@githubnext/utils"; +import "@tensorflow/tfjs"; +import { FileBlockProps, useTailwindCdn } from "@githubnext/utils"; import { useEffect, useState } from "react"; -import * as use from '@tensorflow-models/universal-sentence-encoder'; -// import { Tensor2D } from "@tensorflow/tfjs"; -// import * as tf from '@tensorflow/tfjs-core'; -// import { UniversalSentenceEncoder } from "@tensorflow-models/universal-sentence-encoder"; +import * as use from "@tensorflow-models/universal-sentence-encoder"; +import { UniversalSentenceEncoderQnA } from "@tensorflow-models/universal-sentence-encoder/dist/use_qna"; // zipWith :: (a -> b -> c) -> [a] -> [b] -> [c] -const zipWith = (f: (a: number, b: number) => number, xs: number[], ys: number[]) => { - const ny = ys.length; - return (xs.length <= ny ? xs : xs.slice(0, ny)).map((x, i) => f(x, ys[i])); -} +const zipWith = ( + f: (a: number, b: number) => number, + xs: number[], + ys: number[] +) => { + const ny = ys.length; + return (xs.length <= ny ? xs : xs.slice(0, ny)).map((x, i) => f(x, ys[i])); +}; // dotProduct :: [Int] -> [Int] -> Int const dotProduct = (xs: number[], ys: number[]) => { - const sum = (xs: number[]) => xs ? xs.reduce((a, b) => a + b, 0) : undefined; - return xs.length === ys.length ? (sum(zipWith((a, b) => a * b, xs, ys))) : undefined; -} + const sum = (xs: number[]) => + xs ? xs.reduce((a, b) => a + b, 0) : undefined; + return xs.length === ys.length + ? sum(zipWith((a, b) => a * b, xs, ys)) + : undefined; +}; interface Response { score: number; @@ -25,68 +30,161 @@ interface Response { interface QueryResult { query: string; - responses: Response[] + responses: Response[]; } export default function (props: FileBlockProps) { - const { content } = props; + const status = useTailwindCdn(); + + const { content } = props; + const input = JSON.parse(content); + + const [editView, setEditView] = useState(false); + const [model, setModel] = useState(); + const [results, setResults] = useState([]); - const [results, setResults] = useState([]); - const input = JSON.parse(content); - - useEffect(() => { - const init = async () => { - console.log("initializing...") - - const model = await use.loadQnA(); - const result = model.embed(input); - const query = result['queryEmbedding'].arraySync() as number[][]; // [numQueries, 100] - const answers = result['responseEmbedding'].arraySync() as number[][]; // [numAnswers, 100] - const queriesLength = input.queries.length; - const answersLength = input.responses.length; - - const tempResults = []; - // go through each query - for (let i = 0; i < queriesLength; i++) { - const temp = []; - // calculate the dot product of the query and each answer - for (let j = 0; j < answersLength; j++) { - temp.push({ - response: input.responses[j], - score: dotProduct(query[i], answers[j]) || 0 - }) - } + // custom edit section + const [customQuestion, setCustomQuestion] = useState(); + const [customAnswer, setCustomAnswer] = useState(); + const [computedScore, setComputedScore] = useState(); - tempResults.push({ - query: input.queries[i], - responses: temp - }) + const computeScore = async () => { + if (model && customQuestion && customAnswer) { + const result = model.embed({ + queries: [customQuestion], + responses: [customAnswer], + }); + const query = result["queryEmbedding"].arraySync() as number[][]; // [1, 100] + const answers = result["responseEmbedding"].arraySync() as number[][]; // [1, 100] + setComputedScore(dotProduct(query[0], answers[0]) || 0); + } + }; + + useEffect(() => { + const init = async () => { + console.log("initializing..."); + + const model = await use.loadQnA(); + setModel(model); + const result = model.embed(input); + const query = result["queryEmbedding"].arraySync() as number[][]; // [numQueries, 100] + const answers = result["responseEmbedding"].arraySync() as number[][]; // [numAnswers, 100] + const queriesLength = input.queries.length; + const answersLength = input.responses.length; + + const tempResults = []; + // go through each query + for (let i = 0; i < queriesLength; i++) { + const temp = []; + // calculate the dot product of the query and each answer + for (let j = 0; j < answersLength; j++) { + temp.push({ + response: input.responses[j], + score: dotProduct(query[i], answers[j]) || 0, + }); } - setResults(tempResults); + + tempResults.push({ + query: input.queries[i], + responses: temp, + }); } - init(); - }, []); - - return ( -
-

Sentence Encoder Results

- {results ? results.map((query, i) => ( -
+ setResults(tempResults); + }; + init(); + }, []); + + return ( + <> + {status === "ready" && model ? ( +
+
+

+ Sentence Encoder +

+ +
+ + {editView ? (
-

{query.query}

- {query.responses.map((response, j) => ( -
- - {response.response} — - - {response.score.toFixed(2)} -
- ))} +
+ setCustomQuestion(e.target.value)} + type="text" + placeholder="Question" + className="px-3 py-3 placeholder-blueGray-300 text-blueGray-600 relative bg-white bg-white rounded text-sm border border-blueGray-300 outline-none focus:outline-none focus:ring w-full" + /> +
+
+ setCustomAnswer(e.target.value)} + type="text" + placeholder="Answer" + className="px-3 py-3 placeholder-blueGray-300 text-blueGray-600 relative bg-white bg-white rounded text-sm border border-blueGray-300 outline-none focus:outline-none focus:ring w-full" + /> +
+ {computedScore ? ( +
Score: {computedScore.toFixed(2)}
+ ) : null} +
-
-
- )) :
Loading...
} -
- ) - } - \ No newline at end of file + ) : results ? ( + results.map((query, i) => ( +
+
+ + + + + + + + + + {query.responses.map((response, j) => ( + + + + + + ))} + +
QuestionsAnswerScore
+ {query.query} + + {response.response} + + {response.score.toFixed(2)} +
+
+
+
+ )) + ) : ( +
Loading...
+ )} +
+ ) : ( +
Loading...
+ )} + + ); +}