Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,13 @@
},
"nullable": true
},
"energy_mj": {
"type": "integer",
"format": "int64",
"example": 152,
"nullable": true,
"minimum": 0
},
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
Expand Down Expand Up @@ -2156,6 +2163,13 @@
"input_length"
],
"properties": {
"energy_mj": {
"type": "integer",
"format": "int64",
"example": 152,
"nullable": true,
"minimum": 0
},
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
Expand Down
1 change: 1 addition & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ csv = "1.3.0"
ureq = "=2.9"
pyo3 = { workspace = true }
chrono = "0.4.39"
nvml-wrapper = "0.11.0"


[build-dependencies]
Expand Down
1 change: 1 addition & 0 deletions router/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ mod tests {
generated_tokens: 10,
seed: None,
finish_reason: FinishReason::Length,
energy_mj: None,
}),
});
if let ChatEvent::Events(events) = events {
Expand Down
38 changes: 38 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ use tracing::warn;
use utoipa::ToSchema;
use uuid::Uuid;
use validation::Validation;
use nvml_wrapper::Nvml;
use std::sync::OnceLock;

static NVML: OnceLock<Option<Nvml>> = OnceLock::new();

#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
Expand Down Expand Up @@ -1468,6 +1472,9 @@ pub(crate) struct Details {
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = true, example = 152)]
pub energy_mj: Option<u64>,
}

#[derive(Serialize, ToSchema)]
Expand Down Expand Up @@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
pub seed: Option<u64>,
#[schema(example = 1)]
pub input_length: u32,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = true, example = 152)]
pub energy_mj: Option<u64>,
}

#[derive(Serialize, ToSchema, Clone)]
Expand Down Expand Up @@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
}
}

/// Zero-sized namespace for reading GPU energy counters via NVML.
///
/// All queries share the process-wide `NVML` handle declared at file scope,
/// so NVML is initialised at most once per process.
pub struct EnergyMonitor;

impl EnergyMonitor {
    /// Returns the shared NVML handle, initialising it on first use.
    ///
    /// Yields `None` when NVML cannot be initialised (e.g. no NVIDIA driver
    /// or no GPU present). The failed init is cached in the `OnceLock`, so we
    /// never retry — every later call cheaply returns `None` as well.
    fn nvml() -> Option<&'static Nvml> {
        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
    }

    /// Energy counter for the GPU at `gpu_index`, in millijoules.
    ///
    /// NOTE(review): per the NVML API, `total_energy_consumption` is the
    /// cumulative energy since the driver was last reloaded, not an interval —
    /// callers must take the difference of two readings to measure a request.
    /// Returns `None` if NVML is unavailable, the index is out of range, or
    /// the device does not support the energy counter.
    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
        let nvml = Self::nvml()?;
        let device = nvml.device_by_index(gpu_index).ok()?;
        device.total_energy_consumption().ok()
    }

    /// Sum of the energy counters (millijoules) across all visible GPUs.
    ///
    /// Devices that fail to enumerate or report are skipped rather than
    /// failing the whole read, so one flaky GPU cannot hide the others.
    /// Returns `None` only when NVML itself is unavailable; with zero usable
    /// devices the result is `Some(0)`. Uses `saturating_add` so a pathological
    /// counter value clamps instead of overflowing (which would panic in
    /// debug builds).
    pub fn total_energy_mj() -> Option<u64> {
        let nvml = Self::nvml()?;
        let count = nvml.device_count().ok()?;
        let total = (0..count)
            .filter_map(|i| nvml.device_by_index(i).ok())
            .filter_map(|device| device.total_energy_consumption().ok())
            .fold(0u64, |acc, mj| acc.saturating_add(mj));
        Some(total)
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
17 changes: 16 additions & 1 deletion router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
};
use crate::{ChatTokenizeResponse, JsonSchemaConfig};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
Expand Down Expand Up @@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span,
) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
let start_energy = EnergyMonitor::total_energy_mj();
metrics::counter!("tgi_request_count").increment(1);

// Do not long ultra long inputs, like image payloads.
Expand All @@ -317,6 +318,12 @@ pub(crate) async fn generate_internal(
}
_ => (infer.generate(req).await?, None),
};

let end_energy = EnergyMonitor::total_energy_mj();
let energy_mj = match (start_energy, end_energy) {
(Some(start), Some(end)) => Some(end.saturating_sub(start)),
_ => None,
};

// Token details
let input_length = response._input_length;
Expand Down Expand Up @@ -354,6 +361,7 @@ pub(crate) async fn generate_internal(
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
energy_mj,
})
}
false => None,
Expand Down Expand Up @@ -515,6 +523,7 @@ async fn generate_stream_internal(
impl Stream<Item = Result<StreamResponse, InferError>>,
) {
let start_time = Instant::now();
let start_energy = EnergyMonitor::total_energy_mj();
metrics::counter!("tgi_request_count").increment(1);

tracing::debug!("Input: {}", req.inputs);
Expand Down Expand Up @@ -590,13 +599,19 @@ async fn generate_stream_internal(
queued,
top_tokens,
} => {
let end_energy = EnergyMonitor::total_energy_mj();
let energy_mj = match (start_energy, end_energy) {
(Some(start), Some(end)) => Some(end.saturating_sub(start)),
_ => None,
};
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: generated_text.finish_reason,
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
input_length,
energy_mj,
}),
false => None,
};
Expand Down
Loading