Skip to content

Commit

Permalink
add websockets instead of rest api
Browse files Browse the repository at this point in the history
  • Loading branch information
AjayThorve committed Mar 6, 2024
1 parent 748ef90 commit 200cba5
Show file tree
Hide file tree
Showing 17 changed files with 496 additions and 359 deletions.
11 changes: 4 additions & 7 deletions jupyterlab_nvdashboard/apps/cpu.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
import psutil
import time
import tornado
from jupyter_server.base.handlers import APIHandler
from jupyterlab_nvdashboard.apps.utils import CustomWebSocketHandler


class CPUResourceHandler(APIHandler):
@tornado.web.authenticated
def get(self):
class CPUResourceWebSocketHandler(CustomWebSocketHandler):
def send_data(self):
now = time.time()
stats = {
"time": now * 1000,
Expand All @@ -18,5 +16,4 @@ def get(self):
"network_read": psutil.net_io_counters().bytes_recv,
"network_write": psutil.net_io_counters().bytes_sent,
}
self.set_header("Content-Type", "application/json")
self.write(json.dumps(stats))
self.write_message(json.dumps(stats))
42 changes: 17 additions & 25 deletions jupyterlab_nvdashboard/apps/gpu.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from jupyterlab_nvdashboard.apps.utils import CustomWebSocketHandler
import pynvml
import time
import tornado
from jupyter_server.base.handlers import APIHandler

try:
pynvml.nvmlInit()
Expand Down Expand Up @@ -41,19 +40,17 @@
pci_gen = None


class GPUUtilizationHandler(APIHandler):
@tornado.web.authenticated
def get(self):
class GPUUtilizationWebSocketHandler(CustomWebSocketHandler):
def send_data(self):
gpu_utilization = [
pynvml.nvmlDeviceGetUtilizationRates(gpu_handles[i]).gpu
for i in range(ngpus)
]
self.finish(json.dumps({"gpu_utilization": gpu_utilization}))
self.write_message(json.dumps({"gpu_utilization": gpu_utilization}))


class GPUUsageHandler(APIHandler):
@tornado.web.authenticated
def get(self):
class GPUUsageWebSocketHandler(CustomWebSocketHandler):
def send_data(self):
memory_usage = [
pynvml.nvmlDeviceGetMemoryInfo(handle).used
for handle in gpu_handles
Expand All @@ -64,16 +61,15 @@ def get(self):
for handle in gpu_handles
]

self.finish(
self.write_message(
json.dumps(
{"memory_usage": memory_usage, "total_memory": total_memory}
)
)


class GPUResourceHandler(APIHandler):
@tornado.web.authenticated
def get(self):
class GPUResourceWebSocketHandler(CustomWebSocketHandler):
def send_data(self):
now = time.time()
stats = {
"time": now * 1000,
Expand Down Expand Up @@ -118,15 +114,14 @@ def get(self):
stats["gpu_memory_total"] = round(
(stats["gpu_memory_total"] / gpu_mem_sum) * 100, 2
)
self.set_header("Content-Type", "application/json")
self.write(json.dumps(stats))
print("writing message", stats)
self.write_message(json.dumps(stats))


class NVLinkThroughputHandler(APIHandler):
class NVLinkThroughputWebSocketHandler(CustomWebSocketHandler):
prev_throughput = None

@tornado.web.authenticated
def get(self):
def send_data(self):
throughput = [
pynvml.nvmlDeviceGetFieldValues(
handle,
Expand Down Expand Up @@ -162,9 +157,8 @@ def get(self):
# Store the current throughput for the next request
self.prev_throughput = throughput

self.set_header("Content-Type", "application/json")
# Send the change in throughput as part of the response
self.write(
self.write_message(
json.dumps(
{
"nvlink_rx": [
Expand All @@ -191,9 +185,8 @@ def get(self):
)


class PCIStatsHandler(APIHandler):
@tornado.web.authenticated
def get(self):
class PCIStatsWebSocketHandler(CustomWebSocketHandler):
def send_data(self):
# Use device-0 to get "upper bound"
pci_width = pynvml.nvmlDeviceGetMaxPcieLinkWidth(gpu_handles[0])
pci_bw = {
Expand Down Expand Up @@ -231,5 +224,4 @@ def get(self):
"max_rxtx_tp": max_rxtx_tp,
}

self.set_header("Content-Type", "application/json")
self.write(json.dumps(stats))
self.write_message(json.dumps(stats))
31 changes: 31 additions & 0 deletions jupyterlab_nvdashboard/apps/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from tornado.websocket import WebSocketHandler
import tornado
import json


class CustomWebSocketHandler(WebSocketHandler):
def open(self):
self.write_message(json.dumps({"status": "connected"}))
self.set_nodelay(True)
# Start a periodic callback to send data every 50ms
self.callback = tornado.ioloop.PeriodicCallback(self.send_data, 1000)
self.callback.start()

def on_message(self, message):
message_data = json.loads(message)
# Update the periodic callback frequency
new_frequency = message_data["updateFrequency"]
if hasattr(self, "callback"):
self.callback.stop()
self.callback = tornado.ioloop.PeriodicCallback(
self.send_data, new_frequency
)
if not message_data["isPaused"]:
self.callback.start()

def on_close(self):
if hasattr(self, "callback") and self.callback.is_running():
self.callback.stop()

def send_data(self):
pass
12 changes: 6 additions & 6 deletions jupyterlab_nvdashboard/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def setup_handlers(web_app):
base_url, URL_PATH, "nvlink_throughput"
)
handlers += [
(route_pattern_gpu_util, apps.gpu.GPUUtilizationHandler),
(route_pattern_gpu_usage, apps.gpu.GPUUsageHandler),
(route_pattern_gpu_resource, apps.gpu.GPUResourceHandler),
(route_pattern_pci_stats, apps.gpu.PCIStatsHandler),
(route_pattern_gpu_util, apps.gpu.GPUUtilizationWebSocketHandler),
(route_pattern_gpu_usage, apps.gpu.GPUUsageWebSocketHandler),
(route_pattern_gpu_resource, apps.gpu.GPUResourceWebSocketHandler),
(route_pattern_pci_stats, apps.gpu.PCIStatsWebSocketHandler),
(
route_pattern_nvlink_throughput,
apps.gpu.NVLinkThroughputHandler,
apps.gpu.NVLinkThroughputWebSocketHandler,
),
]

Expand All @@ -41,7 +41,7 @@ def setup_handlers(web_app):
)

handlers += [
(route_pattern_cpu_resource, apps.cpu.CPUResourceHandler),
(route_pattern_cpu_resource, apps.cpu.CPUResourceWebSocketHandler),
]

web_app.add_handlers(host_pattern, handlers)
9 changes: 8 additions & 1 deletion schema/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
"type": "object",
"properties": {
"updateFrequency": {
"type": "integer",
"title": "Frequency of Updates",
"description": "The frequency of updates for the GPU Dashboard widgets, in milliseconds.",
"type": "integer",
"default": 100,
"minimum": 1
},
"maxTimeSeriesDataRecords": {
"title": "Maximum Number of Data Records",
"description": "This setting determines the maximum number of data points that can be displayed in each time series chart within Nvdashboard. To apply changes to this setting, please close and reopen the chart window",
"type": "integer",
"default": 1000,
"minimum": 10
}
},
"additionalProperties": false
Expand Down
1 change: 1 addition & 0 deletions src/assets/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ export const WIDGET_TRACKER_NAME = 'gpu-dashboard-widgets';
export const COMMAND_OPEN_SETTINGS = 'settingeditor:open';
export const COMMAND_OPEN_WIDGET = 'gpu-dashboard-widget:open';
export const DEFAULT_UPDATE_FREQUENCY = 100; // ms
export const DEFAULT_MAX_RECORDS_TIMESERIES = 1000; // count
110 changes: 91 additions & 19 deletions src/assets/hooks.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1,107 @@
import { ISettingRegistry } from '@jupyterlab/settingregistry';
import { SetStateAction, useEffect } from 'react';
import { DEFAULT_UPDATE_FREQUENCY, PLUGIN_ID_CONFIG } from './constants';
import { SetStateAction, useEffect, useRef } from 'react';
import {
DEFAULT_MAX_RECORDS_TIMESERIES,
DEFAULT_UPDATE_FREQUENCY,
PLUGIN_ID_CONFIG
} from './constants';
import { connectToWebSocket } from '../handler';

function loadSettingRegistry(
settingRegistry: ISettingRegistry,
setUpdateFrequency: {
(value: SetStateAction<number>): void;
(arg0: number): void;
/**
* Updates the settings for update frequency and maximum records for time series charts.
*/
const updateSettings = (
settings: ISettingRegistry.ISettings,
setUpdateFrequency: (value: SetStateAction<number>) => void,
setMaxRecords?: (value: SetStateAction<number>) => void
) => {
setUpdateFrequency(
(settings.get('updateFrequency').composite as number) ||
DEFAULT_UPDATE_FREQUENCY
);
if (setMaxRecords) {
setMaxRecords(
(settings.get('maxTimeSeriesDataRecords').composite as number) ||
DEFAULT_MAX_RECORDS_TIMESERIES
);
}
) {
};

/**
* Loads the setting registry and updates the settings accordingly.
*/
export const loadSettingRegistry = (
settingRegistry: ISettingRegistry,
setUpdateFrequency: (value: SetStateAction<number>) => void,
setIsSettingsLoaded: (value: SetStateAction<boolean>) => void,
setMaxRecords?: (value: SetStateAction<number>) => void
) => {
useEffect(() => {
const loadSettings = async () => {
try {
const settings = await settingRegistry.load(PLUGIN_ID_CONFIG);
const loadedUpdateFrequency =
(settings.get('updateFrequency').composite as number) ||
DEFAULT_UPDATE_FREQUENCY;
setUpdateFrequency(loadedUpdateFrequency);

updateSettings(settings, setUpdateFrequency, setMaxRecords);
settings.changed.connect(() => {
setUpdateFrequency(
(settings.get('updateFrequency').composite as number) ||
DEFAULT_UPDATE_FREQUENCY
);
updateSettings(settings, setUpdateFrequency, setMaxRecords);
});
setIsSettingsLoaded(true);
} catch (error) {
console.error(`An error occurred while loading settings: ${error}`);
}
};
loadSettings();
}, []);
}
};

/**
* Custom hook to establish a WebSocket connection and handle incoming messages.
*/
export const useWebSocket = <T>(
endpoint: string,
isPaused: boolean,
updateFrequency: number,
processData: (response: T, isPaused: boolean) => void,
isSettingsLoaded: boolean
) => {
const wsRef = useRef<WebSocket | null>(null);

useEffect(() => {
if (!isSettingsLoaded) {
return;
}

wsRef.current = connectToWebSocket(endpoint);
const ws = wsRef.current;

ws.onopen = () => {
console.log('WebSocket connected');
};

export default loadSettingRegistry;
ws.onmessage = event => {
const response = JSON.parse(event.data);
if (response.status !== 'connected') {
processData(response, isPaused);
} else {
ws.send(JSON.stringify({ updateFrequency, isPaused }));
}
};

ws.onerror = error => {
console.error('WebSocket error:', error);
};

ws.onclose = () => {
console.log('WebSocket disconnected');
};

return () => {
ws.close();
};
}, [isSettingsLoaded]);

useEffect(() => {
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify({ updateFrequency, isPaused }));
}
}, [isPaused, updateFrequency]);
};
52 changes: 52 additions & 0 deletions src/assets/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,55 @@ export interface IWidgetInfo {
title: string;
instance: MainAreaWidget;
}

export interface IGpuResourceProps {
time: number;
gpu_utilization_total: number;
gpu_memory_total: number;
rx_total: number;
tx_total: number;
gpu_utilization_individual: number[];
gpu_memory_individual: number[];
}

export interface IGpuUtilizationProps {
gpu_utilization: number[];
}

export interface IGpuUsageProps {
memory_usage: number[];
total_memory: number[];
}

export interface ICPUResourceProps {
time: number;
cpu_utilization: number;
memory_usage: number;
disk_read: number;
disk_write: number;
network_read: number;
network_write: number;
disk_read_current: number;
disk_write_current: number;
network_read_current: number;
network_write_current: number;
}

export interface INVLinkThroughputProps {
nvlink_tx: number[];
nvlink_rx: number[];
max_rxtx_bw: number;
}

export interface INVLinkTimeLineProps {
time: number;
nvlink_tx: number[];
nvlink_rx: number[];
max_rxtx_bw: number;
}

export interface IPCIThroughputProps {
pci_tx: number[];
pci_rx: number[];
max_rxtx_tp: number;
}
Loading

0 comments on commit 200cba5

Please sign in to comment.