Fix usage recording in llm service (#16044)
Release Notes: - N/A Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
parent
eb3c4b0e46
commit
b1c69c2178
1 changed files with 104 additions and 64 deletions
|
@ -15,10 +15,14 @@ use axum::{
|
||||||
};
|
};
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use db::{ActiveUserCount, LlmDatabase};
|
use db::{ActiveUserCount, LlmDatabase};
|
||||||
use futures::StreamExt as _;
|
use futures::{Stream, StreamExt as _};
|
||||||
use http_client::IsahcHttpClient;
|
use http_client::IsahcHttpClient;
|
||||||
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use std::sync::Arc;
|
use std::{
|
||||||
|
pin::Pin,
|
||||||
|
sync::Arc,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
|
@ -155,7 +159,7 @@ async fn perform_completion(
|
||||||
|
|
||||||
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
||||||
|
|
||||||
match params.provider {
|
let stream = match params.provider {
|
||||||
LanguageModelProvider::Anthropic => {
|
LanguageModelProvider::Anthropic => {
|
||||||
let api_key = state
|
let api_key = state
|
||||||
.config
|
.config
|
||||||
|
@ -185,37 +189,27 @@ async fn perform_completion(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut recorder = UsageRecorder {
|
chunks
|
||||||
db: state.db.clone(),
|
.map(move |event| {
|
||||||
executor: state.executor.clone(),
|
let chunk = event?;
|
||||||
user_id,
|
let (input_tokens, output_tokens) = match &chunk {
|
||||||
provider: params.provider,
|
|
||||||
model,
|
|
||||||
token_count: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let stream = chunks.map(move |event| {
|
|
||||||
let mut buffer = Vec::new();
|
|
||||||
event.map(|chunk| {
|
|
||||||
match &chunk {
|
|
||||||
anthropic::Event::MessageStart {
|
anthropic::Event::MessageStart {
|
||||||
message: anthropic::Response { usage, .. },
|
message: anthropic::Response { usage, .. },
|
||||||
}
|
}
|
||||||
| anthropic::Event::MessageDelta { usage, .. } => {
|
| anthropic::Event::MessageDelta { usage, .. } => (
|
||||||
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
|
usage.input_tokens.unwrap_or(0) as usize,
|
||||||
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
|
usage.output_tokens.unwrap_or(0) as usize,
|
||||||
}
|
),
|
||||||
_ => {}
|
_ => (0, 0),
|
||||||
}
|
};
|
||||||
|
|
||||||
buffer.clear();
|
anyhow::Ok((
|
||||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
serde_json::to_vec(&chunk).unwrap(),
|
||||||
buffer.push(b'\n');
|
input_tokens,
|
||||||
buffer
|
output_tokens,
|
||||||
|
))
|
||||||
})
|
})
|
||||||
});
|
.boxed()
|
||||||
|
|
||||||
Ok(Response::new(Body::wrap_stream(stream)))
|
|
||||||
}
|
}
|
||||||
LanguageModelProvider::OpenAi => {
|
LanguageModelProvider::OpenAi => {
|
||||||
let api_key = state
|
let api_key = state
|
||||||
|
@ -232,17 +226,21 @@ async fn perform_completion(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let stream = chunks.map(|event| {
|
chunks
|
||||||
let mut buffer = Vec::new();
|
.map(|event| {
|
||||||
event.map(|chunk| {
|
event.map(|chunk| {
|
||||||
buffer.clear();
|
let input_tokens =
|
||||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||||
buffer.push(b'\n');
|
let output_tokens =
|
||||||
buffer
|
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||||
|
(
|
||||||
|
serde_json::to_vec(&chunk).unwrap(),
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
});
|
.boxed()
|
||||||
|
|
||||||
Ok(Response::new(Body::wrap_stream(stream)))
|
|
||||||
}
|
}
|
||||||
LanguageModelProvider::Google => {
|
LanguageModelProvider::Google => {
|
||||||
let api_key = state
|
let api_key = state
|
||||||
|
@ -258,17 +256,20 @@ async fn perform_completion(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let stream = chunks.map(|event| {
|
chunks
|
||||||
let mut buffer = Vec::new();
|
.map(|event| {
|
||||||
event.map(|chunk| {
|
event.map(|chunk| {
|
||||||
buffer.clear();
|
// TODO - implement token counting for Google AI
|
||||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
let input_tokens = 0;
|
||||||
buffer.push(b'\n');
|
let output_tokens = 0;
|
||||||
buffer
|
(
|
||||||
|
serde_json::to_vec(&chunk).unwrap(),
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
});
|
.boxed()
|
||||||
|
|
||||||
Ok(Response::new(Body::wrap_stream(stream)))
|
|
||||||
}
|
}
|
||||||
LanguageModelProvider::Zed => {
|
LanguageModelProvider::Zed => {
|
||||||
let api_key = state
|
let api_key = state
|
||||||
|
@ -290,19 +291,34 @@ async fn perform_completion(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let stream = chunks.map(|event| {
|
chunks
|
||||||
let mut buffer = Vec::new();
|
.map(|event| {
|
||||||
event.map(|chunk| {
|
event.map(|chunk| {
|
||||||
buffer.clear();
|
let input_tokens =
|
||||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||||
buffer.push(b'\n');
|
let output_tokens =
|
||||||
buffer
|
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||||
|
(
|
||||||
|
serde_json::to_vec(&chunk).unwrap(),
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
});
|
.boxed()
|
||||||
|
|
||||||
Ok(Response::new(Body::wrap_stream(stream)))
|
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
|
||||||
|
db: state.db.clone(),
|
||||||
|
executor: state.executor.clone(),
|
||||||
|
user_id,
|
||||||
|
provider: params.provider,
|
||||||
|
model,
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: 0,
|
||||||
|
inner_stream: stream,
|
||||||
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
||||||
|
@ -377,22 +393,46 @@ async fn check_usage_limit(
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
struct UsageRecorder {
|
|
||||||
|
struct TokenCountingStream<S> {
|
||||||
db: Arc<LlmDatabase>,
|
db: Arc<LlmDatabase>,
|
||||||
executor: Executor,
|
executor: Executor,
|
||||||
user_id: i32,
|
user_id: i32,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
model: String,
|
model: String,
|
||||||
token_count: usize,
|
input_tokens: usize,
|
||||||
|
output_tokens: usize,
|
||||||
|
inner_stream: S,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for UsageRecorder {
|
impl<S> Stream for TokenCountingStream<S>
|
||||||
|
where
|
||||||
|
S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
|
||||||
|
{
|
||||||
|
type Item = Result<Vec<u8>, anyhow::Error>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||||
|
Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
|
||||||
|
bytes.push(b'\n');
|
||||||
|
self.input_tokens += input_tokens;
|
||||||
|
self.output_tokens += output_tokens;
|
||||||
|
Poll::Ready(Some(Ok(bytes)))
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Drop for TokenCountingStream<S> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let db = self.db.clone();
|
let db = self.db.clone();
|
||||||
let user_id = self.user_id;
|
let user_id = self.user_id;
|
||||||
let provider = self.provider;
|
let provider = self.provider;
|
||||||
let model = std::mem::take(&mut self.model);
|
let model = std::mem::take(&mut self.model);
|
||||||
let token_count = self.token_count;
|
let token_count = self.input_tokens + self.output_tokens;
|
||||||
self.executor.spawn_detached(async move {
|
self.executor.spawn_detached(async move {
|
||||||
db.record_usage(user_id, provider, &model, token_count, Utc::now())
|
db.record_usage(user_id, provider, &model, token_count, Utc::now())
|
||||||
.await
|
.await
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue