Skip to content

Commit

Permalink
Merge pull request #1 from geniusrise/feat/notebook
Browse files Browse the repository at this point in the history
Create notebooks for each Auto class
  • Loading branch information
ixaxaar authored Jan 19, 2024
2 parents 188fda5 + ff12828 commit 5042d89
Show file tree
Hide file tree
Showing 10 changed files with 860 additions and 1 deletion.
1 change: 1 addition & 0 deletions geniusrise_text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .qa import QAAPI, QABulk, QAFineTuner
from .summarization import SummarizationAPI, SummarizationBulk, SummarizationFineTuner
from .translation import TranslationAPI, TranslationBulk, TranslationFineTuner
from .notebook import TextJupyterNotebook
3 changes: 2 additions & 1 deletion geniusrise_text/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def listen(
else:
model_revision = None
tokenizer_revision = None
tokenizer_name = model_name
tokenizer_name = model_name

self.model_name = model_name
self.model_revision = model_revision
self.tokenizer_name = tokenizer_name
Expand Down
16 changes: 16 additions & 0 deletions geniusrise_text/notebook/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# 🧠 Geniusrise
# Copyright (C) 2023 geniusrise.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .notebook import TextJupyterNotebook
225 changes: 225 additions & 0 deletions geniusrise_text/notebook/notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# 🧠 Geniusrise
# Copyright (C) 2023 geniusrise.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
from jinja2 import Environment, FileSystemLoader, Template
from nbformat import v4 as nbf
import nbformat
from geniusrise import BatchInput, BatchOutput, Bolt, State
from geniusrise.logging import setup_logger
from typing import Any, Dict, List, Optional


class TextJupyterNotebook(Bolt):
    """
    Bolt that renders a ready-to-run Jupyter notebook for a Hugging Face
    ``Auto*`` model class from a bundled Jinja template, then launches a
    Jupyter Lab server rooted at the output folder so the notebook can be
    used immediately.
    """

    def __init__(
        self,
        input: BatchInput,
        output: BatchOutput,
        state: State,
        **kwargs,
    ):
        """
        Initialize the notebook bolt.

        Args:
            input (BatchInput): Input configuration for the bolt.
            output (BatchOutput): Output configuration; ``output.output_folder``
                receives the generated ``notebook.ipynb``.
            state (State): State manager for the bolt.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        super().__init__(input=input, output=output, state=state)
        self.log = setup_logger(self)

        # Templates ship with the package next to this module — resolve them
        # relative to this file, never the process CWD.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        self.templates_dir = os.path.join(script_dir, "templates")

        # Jinja2 environment rooted at the bundled templates directory.
        self.env = Environment(loader=FileSystemLoader(self.templates_dir))

    def create(
        self,
        model_name: str,
        model_class: str = "AutoModelForCausalLM",
        tokenizer_class: str = "AutoTokenizer",
        use_cuda: bool = False,
        precision: str = "float16",
        quantization: int = 0,
        device_map: str | Dict | None = "auto",
        torchscript: bool = False,
        compile: bool = True,
        awq_enabled: bool = False,
        flash_attention: bool = False,
        port: int = 8888,
        password: Optional[str] = None,
        **model_args: Any,
    ):
        """
        Render a notebook for ``model_class`` and serve it via Jupyter Lab.

        Args:
            model_name (str): Hugging Face model id, optionally suffixed with
                ``:revision`` (e.g. ``"mistralai/Mistral-7B-v0.1:main"``).
            model_class (str): Transformers Auto class the notebook demonstrates.
            tokenizer_class (str): Transformers tokenizer class to load.
            use_cuda (bool): Whether the notebook should move the model to GPU.
            precision (str): Precision hint rendered into the notebook.
            quantization (int): Quantization level rendered into the notebook.
            device_map (str | Dict | None): Device placement strategy.
            torchscript (bool): Whether TorchScript is enabled.
            compile (bool): Whether model compilation is enabled.
            awq_enabled (bool): Whether AWQ quantization is enabled.
            flash_attention (bool): Whether flash attention is enabled.
            port (int): Port for the Jupyter Lab server. Default 8888.
            password (Optional[str]): Access token for the server; ``None``
                disables token authentication.
            **model_args (Any): Extra model arguments rendered into the notebook.

        Raises:
            KeyError: If no notebook template exists for ``model_class``.
        """
        self.model_class = model_class
        self.tokenizer_class = tokenizer_class
        self.use_cuda = use_cuda
        self.precision = precision
        self.quantization = quantization
        self.device_map = device_map
        self.torchscript = torchscript
        self.compile = compile
        self.awq_enabled = awq_enabled
        self.flash_attention = flash_attention
        self.model_args = model_args

        # "name:revision" pins a specific model (and tokenizer) revision.
        if ":" in model_name:
            model_name, _, model_revision = model_name.partition(":")
            tokenizer_revision = model_revision
        else:
            model_revision = None
            tokenizer_revision = None
        tokenizer_name = model_name

        self.model_name = model_name
        self.model_revision = model_revision
        self.tokenizer_name = tokenizer_name
        self.tokenizer_revision = tokenizer_revision

        # NOTE: do NOT rebuild self.env here with a CWD-relative path — the
        # module-relative environment from __init__ is the correct one.

        # Context handed to the Jinja template.
        context = {
            "model_name": model_name,
            "tokenizer_name": tokenizer_name,
            "model_revision": model_revision,
            "tokenizer_revision": tokenizer_revision,
            "model_class": model_class,
            "tokenizer_class": tokenizer_class,
            "use_cuda": use_cuda,
            "precision": precision,
            "quantization": quantization,
            "device_map": device_map,
            "torchscript": torchscript,
            "compile": compile,
            "awq_enabled": awq_enabled,
            "flash_attention": flash_attention,
            "model_args": model_args,
        }

        output_path = self.output.output_folder

        # One template per supported Auto class.
        templates_dir = self.templates_dir
        # fmt: off
        class_to_template_map = {
            "AutoModelForCausalLM": os.path.join(templates_dir, "AutoModelForCausalLM.jinja"),
            "AutoModelForTokenClassification": os.path.join(templates_dir, "AutoModelForTokenClassification.jinja"),
            "AutoModelForSequenceClassification": os.path.join(templates_dir, "AutoModelForSequenceClassification.jinja"),
            "AutoModelForTableQuestionAnswering": os.path.join(templates_dir, "AutoModelForTableQuestionAnswering.jinja"),
            "AutoModelForQuestionAnswering": os.path.join(templates_dir, "AutoModelForQuestionAnswering.jinja"),
            "AutoModelForSeq2SeqLM": os.path.join(templates_dir, "AutoModelForSeq2SeqLM.jinja"),
        }
        # fmt: on

        if model_class not in class_to_template_map:
            raise KeyError(
                f"No notebook template for model class {model_class!r}; "
                f"supported: {sorted(class_to_template_map)}"
            )
        template_name = class_to_template_map[model_class]

        self.create_notebook(
            name=template_name,
            context=context,
            output_path=os.path.join(output_path, "notebook.ipynb"),
        )

        # Themes and UI niceties for the served notebook.
        self.install_packages(
            [
                "jupyterthemes",
                "jupyter==1.0.0",
                "jupyterlab_legos_ui",
                "jupyterlab_darkside_ui",
                "theme-darcula",
            ]
        )

        self.start_jupyter_server(notebook_dir=output_path, port=port, password=password)

    def create_notebook(self, name: str, context: dict, output_path: str):
        """
        Render a Jinja notebook template and write it as an .ipynb file.

        Args:
            name (str): Path to the Jinja template file.
            context (dict): Context variables to render the template.
            output_path (str): Path to save the generated notebook.
        """
        with open(name, "r", encoding="utf-8") as template_file:
            template = Template(template_file.read())

        # The template renders to notebook JSON, parsed into a v4 notebook node.
        notebook_json = template.render(context)
        notebook = nbf.reads(notebook_json)

        with open(output_path, "w", encoding="utf-8") as f:
            nbformat.write(notebook, f)
        self.log.info(f"Notebook created at {output_path}")

    def start_jupyter_server(self, notebook_dir: str, port: int = 8888, password: Optional[str] = None):
        """
        Start a Jupyter Lab server in the specified directory.

        Blocks until the server process exits.

        Args:
            notebook_dir (str): Directory the server serves as its root.
            port (int): Port number for the server. Default is 8888.
            password (Optional[str]): Access token for the server. If None,
                token authentication is disabled (empty token).

        Raises:
            subprocess.CalledProcessError: If the server exits with a
                non-zero status.
        """
        # An empty token disables authentication; interpolating None directly
        # would set the literal token "None".
        token = password if password is not None else ""
        command = [
            "jupyter",
            "lab",
            "--ip=0.0.0.0",
            f"--ServerApp.token={token}",
            "--no-browser",
            "--port",
            str(port),
            "--ServerApp.root_dir",
            notebook_dir,
        ]
        self.log.info(f"Running command {' '.join(command)}")

        subprocess.run(command, check=True)

    def install_packages(self, packages: List[str]):
        """
        Install Python packages using pip.

        Args:
            packages (List[str]): List of package names to install.

        Raises:
            subprocess.CalledProcessError: If any pip install fails.
        """
        # Install one at a time so a failure pinpoints the offending package.
        for package in packages:
            subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
        self.log.info("Required packages installed.")

    def install_jupyter_extensions(self, extensions: List[str]):
        """
        Install and enable Jupyter Notebook extensions for the current user.

        Args:
            extensions (List[str]): List of Jupyter extension names to install.

        Raises:
            subprocess.CalledProcessError: If an install/enable command fails.
        """
        for extension in extensions:
            subprocess.run(["jupyter", "nbextension", "install", extension, "--user"], check=True)
            subprocess.run(["jupyter", "nbextension", "enable", extension, "--user"], check=True)
        self.log.info("Jupyter extensions installed and enabled.")
100 changes: 100 additions & 0 deletions geniusrise_text/notebook/templates/AutoModelForCausalLM.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# {{ model_class }} Demonstration\n",
"\n",
"This notebook demonstrates how to load and use the `{{ model_class }}` from Hugging Face's Transformers library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Importing necessary libraries\n",
"from transformers import {{ model_class }}, {{ tokenizer_class }}\n",
"import torch\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the Model and Tokenizer\n",
"\n",
"Here we load the model and tokenizer. We are using the model `{{ model_name }}` and its corresponding tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Loading the model and tokenizer\n",
"model = {{ model_class }}.from_pretrained('{{ model_name }}', revision='{{ model_revision }}')\n",
"tokenizer = {{ tokenizer_class }}.from_pretrained('{{ tokenizer_name }}', revision='{{ tokenizer_revision }}')\n",
"\n",
"# Additional configurations\n",
"model.to('cuda' if torch.cuda.is_available() and {{ use_cuda }} else 'cpu')\n",
"if '{{ precision }}' == 'float16':\n",
" model = model.half()\n",
"\n",
"# Describe each configuration and its impact here..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Inference\n",
"\n",
"Now, let's use the model to generate some text. We will provide a prompt, and the model will generate a continuation of it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generating text\n",
"prompt = 'Today is a beautiful day'\n",
"inputs = tokenizer(prompt, return_tensors='pt')\n",
"inputs.to(model.device)\n",
"\n",
"# Generate a response\n",
"with torch.no_grad():\n",
" outputs = model.generate(**inputs)\n",
" generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"print(generated_text)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 5042d89

Please sign in to comment.