Skip to content

Commit

Permalink
Re-structure dataset info format
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <jay@zijie.wang>
  • Loading branch information
xiaohk committed Feb 8, 2024
1 parent 476adcf commit f884453
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 46 deletions.
3 changes: 2 additions & 1 deletion examples/rag-playground/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ dist-ssr

pnpm-lock.yaml
.vercel
notebooks
notebooks
public/data
56 changes: 42 additions & 14 deletions examples/rag-playground/src/components/playground/playground.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit';
import { customElement, property, state, query } from 'lit/decorators.js';
import { unsafeHTML } from 'lit/directives/unsafe-html.js';
import { ifDefined } from 'lit/directives/if-defined.js';
import { EmbeddingModel } from '../../workers/embedding';
import {
UserConfigManager,
Expand All @@ -14,6 +15,7 @@ import { textGenGpt } from '../../llms/gpt';
import { textGenMememo } from '../../llms/mememo-gen';
import { textGenGemini } from '../../llms/gemini';
import TextGenLocalWorkerInline from '../../llms/web-llm?worker&inline';
import { promptTemplates } from '../../config/promptTemplates';

import type { TextGenMessage } from '../../llms/gpt';
import type { EmbeddingWorkerMessage } from '../../workers/embedding';
Expand All @@ -29,18 +31,20 @@ import '../text-viewer/text-viewer';

import componentCSS from './playground.css?inline';
import EmbeddingWorkerInline from '../../workers/embedding?worker&inline';
import promptTemplatesJSON from '../../config/promptTemplates.json';
import logoIcon from '../../images/icon-logo.svg?raw';

interface DatasetInfo {
dataURL: string;
indexURL: string;
indexURL?: string;
datasetName: string;
datasetNameDisplay: string;
}

enum Dataset {
Arxiv = 'arxiv'
arXiv10k = 'arxiv-10k',
arXiv120k = 'arxiv-120k',
DiffusionDB1m = 'diffusiondb-1m',
accident3k = 'accident-3k'
}

enum Arrow {
Expand All @@ -50,14 +54,35 @@ enum Arrow {
Output = 'output'
}

const promptTemplate = promptTemplatesJSON as Record<Dataset, string>;
const promptTemplate = promptTemplates as Record<Dataset, string>;

const datasets: Record<Dataset, DatasetInfo> = {
[Dataset.Arxiv]: {
dataURL: '/data/ml-arxiv-papers-1000.ndjson.gzip',
indexURL: '/data/ml-arxiv-papers-1000-index.json',
datasetName: 'ml-arxiv-papers',
datasetNameDisplay: 'ML arXiv Abstracts (1k)'
[Dataset.arXiv10k]: {
indexURL: '/data/ml-arxiv-papers-10k-index.json.gzip',
dataURL: '/data/ml-arxiv-papers-10k.ndjson.gzip',
datasetName: 'ml-arxiv-papers-10k',
datasetNameDisplay: 'ML arXiv Abstracts (10k)'
},

[Dataset.arXiv120k]: {
indexURL: '/data/ml-arxiv-papers-120k-index.json.gzip',
dataURL: '/data/ml-arxiv-papers-120k.ndjson.gzip',
datasetName: 'ml-arxiv-papers-120k',
datasetNameDisplay: 'ML arXiv Abstracts (120k)'
},

[Dataset.DiffusionDB1m]: {
indexURL: '/data/diffusiondb-pormpts-1m-index.json.gzip',
dataURL: '/data/diffusiondb-pormpts-1m.ndjson.gzip',
datasetName: 'diffusiondb-pormpts-1m',
datasetNameDisplay: 'DiffusionDB Prompts (1M)'
},

[Dataset.accident3k]: {
indexURL: '/data/accident-3k-index.json.gzip',
dataURL: '/data/accident-3k.ndjson.gzip',
datasetName: 'accidents-3k',
datasetNameDisplay: 'AI Accidents (3k)'
}
};

Expand All @@ -82,6 +107,9 @@ export class MememoPlayground extends LitElement {
@state()
llmOutput = '';

@property()
curDataset: Dataset = Dataset.arXiv120k;

embeddingWorker: Worker;
embeddingWorkerRequestCount = 0;
get embeddingWorkerRequestID() {
Expand Down Expand Up @@ -536,18 +564,18 @@ export class MememoPlayground extends LitElement {
<div class="container container-text">
<mememo-text-viewer
dataURL=${datasets['arxiv'].dataURL}
indexURL=${datasets['arxiv'].indexURL}
datasetName=${datasets['arxiv'].datasetName}
datasetNameDisplay=${datasets['arxiv'].datasetNameDisplay}
dataURL=${datasets[this.curDataset].dataURL}
indexURL=${ifDefined(datasets[this.curDataset].indexURL)}
datasetName=${datasets[this.curDataset].datasetName}
datasetNameDisplay=${datasets[this.curDataset].datasetNameDisplay}
@semanticSearchFinished=${(e: CustomEvent<string[]>) =>
this.semanticSearchFinishedHandler(e)}
></mememo-text-viewer>
</div>
<div class="container container-prompt">
<mememo-prompt-box
template=${promptTemplate[Dataset.Arxiv]}
template=${promptTemplate[this.curDataset]}
userQuery=${this.userQuery}
.relevantDocuments=${this.relevantDocuments}
@runButtonClicked=${(e: CustomEvent<string>) => {
Expand Down
75 changes: 66 additions & 9 deletions examples/rag-playground/src/components/text-viewer/text-viewer.css
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,16 @@
overscroll-behavior: none;

.item {
line-height: 1.2;
line-height: 1.25;
width: 100%;
padding: 0 calc(var(--search-inner-padding-h) + var(--box-padding-h));
padding: 4px calc(var(--search-inner-padding-h) + var(--box-padding-h));
border-bottom: 1px solid var(--gray-200);
cursor: default;
padding-bottom: 3px;
padding-top: 3px;
box-sizing: border-box;
position: relative;

&:has(.distance-overlay:not([is-hidden])) {
padding-right: 56px;
padding-right: 60px;
}

&[clamp-line] {
Expand All @@ -242,15 +240,11 @@
position: absolute;
top: 0px;
right: 0px;
/* transform: translate(0, -50%); */
/* float: right; */

font-size: var(--font-d3);
font-variant-numeric: tabular-nums;
padding: 2px 5px;
border-bottom-left-radius: 4px;
/* border-top-left-radius: 5px; */
/* border-radius: 5px; */
text-align: right;

background-color: color-mix(in lab, var(--blue-100) 70%, transparent 30%);
Expand Down Expand Up @@ -348,3 +342,66 @@
animation: circle-loader-animation 1s infinite linear;
}
}

/* Inline SVG icon wrapper: a 1em square that centers its <svg> child and
   inherits the surrounding text color. */
.svg-icon {
  /* Box: scales with the current font size */
  width: 1em;
  height: 1em;

  /* Center the nested <svg> on both axes */
  display: flex;
  align-items: center;
  justify-content: center;

  /* Color follows the parent; transforms animate around the center */
  color: currentColor;
  transform-origin: center;
  transition: transform 80ms linear;

  /* The SVG fills the wrapper and is painted with the inherited color */
  & svg {
    width: 100%;
    height: 100%;
    fill: currentColor;
  }
}

/* Horizontal row of buttons, vertically centered, 8px apart. */
.button-group {
  display: flex;
  flex-flow: row;
  gap: 8px;
  align-items: center;
}

/* Compact pill button. `all: unset` must stay first: it clears the
   user-agent button styling before the declarations below rebuild it. */
button {
  all: unset;

  /* Layout: single-row flex container (previously declared twice in this
     rule; the redundant second `display: flex` has been removed). */
  display: flex;
  flex-flow: row;
  align-items: center;
  line-height: 1;
  padding: 4px 6px;
  border-radius: 5px;
  white-space: nowrap;

  /* Interaction: clickable, label text not selectable */
  cursor: pointer;
  user-select: none;
  -webkit-user-select: none;

  /* Colors and sizing */
  background-color: color-mix(in lab, var(--gray-200), white 20%);
  color: var(--gray-800);
  font-size: var(--header-secondary-size);
  height: 1em;

  &:hover {
    background-color: color-mix(in lab, var(--gray-300), white 30%);
  }

  &:active {
    background-color: color-mix(in lab, var(--gray-300), white 20%);
  }

  /* Icons inside a button use a fixed 14px box, nudged down 1px */
  .svg-icon {
    position: relative;
    top: 1px;
    color: var(--gray-700);
    width: 14px;
    height: 14px;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@ import componentCSS from './text-viewer.css?inline';
import MememoWorkerInline from '../../workers/mememo-worker?worker&inline';
import searchIcon from '../../images/icon-search.svg?raw';
import crossIcon from '../../images/icon-cross-thick.svg?raw';
import downloadIcon from '../../images/icon-download.svg?raw';
import crossSmallIcon from '../../images/icon-cross.svg?raw';

// Size limits for the document list and lexical search. Their use sites are
// outside this view — presumably a cap on retained documents, the page size
// for incremental loading, and a cap on lexical matches; confirm in the class.
const MAX_DOCUMENTS_IN_MEMORY = 1000;
const DOCUMENT_INCREMENT = 100;
const LEXICAL_SEARCH_LIMIT = 2000;
// d3 number formatter using ',' as the thousands group separator
const numberFormatter = d3.format(',');

// Debug-only load-time instrumentation; disabled while TRACK_LOADING_TIME is
// false.
// NOTE(review): getUTCSeconds() returns only the seconds field (0–59) of the
// current UTC time, not an elapsed-time clock. Any difference computed against
// this value wraps at minute boundaries and can be negative. If this tracking
// is ever enabled, switch both the start value and the later readings to a
// monotonic source such as performance.now() (or Date.now()).
const startLoadingTime = new Date().getUTCSeconds();
const loadingTimes: number[] = [];
const TRACK_LOADING_TIME = false;

/**
* Text viewer element.
*/
Expand Down Expand Up @@ -237,6 +242,12 @@ export class MememoTextViewer extends LitElement {
this.shownDocuments = [...this.shownDocuments, ...documents];
}

// Track the loading time
if (TRACK_LOADING_TIME) {
const now = new Date().getUTCSeconds();
loadingTimes.push(now - startLoadingTime);
}

if (e.data.payload.isLastBatch) {
// Mark the loading has completed
this.markMememoFinishedLoading();
Expand All @@ -248,6 +259,11 @@ export class MememoTextViewer extends LitElement {
// payload: { requestID: 0 }
// };
// this.mememoWorker.postMessage(message);

// Download the loading times
if (TRACK_LOADING_TIME) {
downloadJSON(loadingTimes, undefined, 'loading-times.json');
}
}
break;
}
Expand Down Expand Up @@ -304,7 +320,15 @@ export class MememoTextViewer extends LitElement {

case 'finishExportIndex': {
const indexJSON = e.data.payload.indexJSON;
downloadJSON(indexJSON, undefined, 'index.json');

// Download a compressed file
compressTextGzip(JSON.stringify(indexJSON)).then(
value => {
downloadBlob(value, 'mememo-index.json.gzip');
},
() => {}
);

break;
}

Expand All @@ -314,6 +338,20 @@ export class MememoTextViewer extends LitElement {
}
}

/**
 * Click handler for the download button: asks the MeMemo worker to start
 * exporting the index.
 *
 * Does nothing while the database is still loading. The export itself is
 * asynchronous; the download is presumably triggered once the worker replies
 * with the exported index.
 */
downloadButtonClicked() {
  // Ignore clicks until loading has finished
  if (!this.isMememoFinishedLoading) {
    return;
  }

  const exportRequest: MememoWorkerMessage = {
    command: 'startExportIndex',
    payload: { requestID: 0 }
  };
  this.mememoWorker.postMessage(exportRequest);
}

//==========================================================================||
// Private Helpers ||
//==========================================================================||
Expand Down Expand Up @@ -429,6 +467,11 @@ export class MememoTextViewer extends LitElement {
<div class="header-bar">
<div class="header">MeMemo Database</div>
<div class="description">${this.datasetNameDisplay}</div>
<div class="button-group">
<button @click=${() => this.downloadButtonClicked()}>
<span class="svg-icon">${unsafeHTML(downloadIcon)}</span>
</button>
</div>
</div>

<div class="search-bar-container">
Expand Down Expand Up @@ -500,3 +543,31 @@ declare global {
'mememo-text-viewer': MememoTextViewer;
}
}

/**
 * Save a blob to the user's machine through a browser download.
 *
 * A temporary `<a download>` element pointing at an object URL for the blob
 * is attached to the document, clicked programmatically, and removed again.
 * The object URL is revoked right after the click.
 *
 * @param blob - Content to download
 * @param filename - File name suggested to the browser
 */
const downloadBlob = (blob: Blob, filename: string) => {
  const objectURL = window.URL.createObjectURL(blob);

  // Throwaway anchor that carries the URL and the suggested file name
  const anchor = Object.assign(document.createElement('a'), {
    href: objectURL,
    download: filename
  });

  document.body.appendChild(anchor);
  anchor.click();
  anchor.remove();

  // Release the object URL now that the download has been started
  URL.revokeObjectURL(objectURL);
};

/**
 * Gzip-compress a string with the platform `CompressionStream` API.
 *
 * @param text - Text to compress
 * @returns A promise resolving to a Blob holding the gzip-compressed bytes
 */
const compressTextGzip = async (text: string): Promise<Blob> => {
  // text -> Blob -> byte stream -> gzip compressor
  const gzippedStream = new Blob([text], { type: 'text/plain' })
    .stream()
    .pipeThrough(new CompressionStream('gzip'));

  // Response is used only as a convenient way to collect a stream into a Blob
  return new Response(gzippedStream).blob();
};
3 changes: 0 additions & 3 deletions examples/rag-playground/src/config/promptTemplates.json

This file was deleted.

9 changes: 9 additions & 0 deletions examples/rag-playground/src/config/promptTemplates.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/**
 * Prompt template written for the machine-learning (arXiv) datasets.
 * It contains two placeholders, `{{user}}` and `{{context}}`, which are
 * presumably substituted with the user query and the retrieved documents by
 * the consumer of this template.
 */
const sharedTemplate =
  "You are an expert in machine learning, and you are answering a user's questions about machine learning. The user's question is in <user></user>. You have access to documents in <context></context>. Your answer should be solely based on the provided documents. Provide cite the document source if possible. Answer your question in an <output></output> tag.\n\n<user>{{user}}</user>\n\n<context>{{context}}</context>";

/**
 * Prompt template for each dataset, keyed by dataset identifier.
 * Every dataset currently maps to the same machine-learning-oriented
 * template, including the non-arXiv ones.
 */
export const promptTemplates = {
  'arxiv-10k': sharedTemplate,
  'arxiv-120k': sharedTemplate,
  'diffusiondb-1m': sharedTemplate,
  'accident-3k': sharedTemplate
};
8 changes: 8 additions & 0 deletions examples/rag-playground/src/images/icon-download.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion examples/rag-playground/src/types/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
* Type definitions
*/

export type DocumentRecordStreamData = [string, number[]];
/** [unique key, content, embedding] */
export type DocumentRecordStreamData = [number, string, number[]];

export interface DocumentDBEntry {
id: string;
Expand Down
Loading

0 comments on commit f884453

Please sign in to comment.