Fix usage recording in llm service (#16044)

Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>
Max Brunsfeld, 2024-08-09 11:48:18 -07:00, committed by GitHub
parent eb3c4b0e46
commit b1c69c2178


@@ -15,10 +15,14 @@ use axum::{
 };
 use chrono::{DateTime, Duration, Utc};
 use db::{ActiveUserCount, LlmDatabase};
-use futures::StreamExt as _;
+use futures::{Stream, StreamExt as _};
 use http_client::IsahcHttpClient;
 use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
-use std::sync::Arc;
+use std::{
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
 use tokio::sync::RwLock;
 use util::ResultExt;
@@ -155,7 +159,7 @@ async fn perform_completion(
     check_usage_limit(&state, params.provider, &model, &claims).await?;

-    match params.provider {
+    let stream = match params.provider {
         LanguageModelProvider::Anthropic => {
             let api_key = state
                 .config
@@ -185,37 +189,27 @@ async fn perform_completion(
             )
             .await?;

-            let mut recorder = UsageRecorder {
-                db: state.db.clone(),
-                executor: state.executor.clone(),
-                user_id,
-                provider: params.provider,
-                model,
-                token_count: 0,
-            };
-
-            let stream = chunks.map(move |event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    match &chunk {
+            chunks
+                .map(move |event| {
+                    let chunk = event?;
+                    let (input_tokens, output_tokens) = match &chunk {
                         anthropic::Event::MessageStart {
                             message: anthropic::Response { usage, .. },
                         }
-                        | anthropic::Event::MessageDelta { usage, .. } => {
-                            recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
-                            recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
-                        }
-                        _ => {}
-                    }
-
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+                        | anthropic::Event::MessageDelta { usage, .. } => (
+                            usage.input_tokens.unwrap_or(0) as usize,
+                            usage.output_tokens.unwrap_or(0) as usize,
+                        ),
+                        _ => (0, 0),
+                    };
+                    anyhow::Ok((
+                        serde_json::to_vec(&chunk).unwrap(),
+                        input_tokens,
+                        output_tokens,
+                    ))
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::OpenAi => {
             let api_key = state
@@ -232,17 +226,21 @@ async fn perform_completion(
             )
             .await?;

-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        let input_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+                        let output_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::Google => {
             let api_key = state
@@ -258,17 +256,20 @@ async fn perform_completion(
             )
             .await?;

-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        // TODO - implement token counting for Google AI
+                        let input_tokens = 0;
+                        let output_tokens = 0;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::Zed => {
             let api_key = state
@@ -290,19 +291,34 @@ async fn perform_completion(
             )
             .await?;

-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        let input_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+                        let output_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
+                .boxed()
+        }
+    };

-            Ok(Response::new(Body::wrap_stream(stream)))
-        }
-    }
+    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
+        db: state.db.clone(),
+        executor: state.executor.clone(),
+        user_id,
+        provider: params.provider,
+        model,
+        input_tokens: 0,
+        output_tokens: 0,
+        inner_stream: stream,
+    })))
 }

 fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
@@ -377,22 +393,46 @@ async fn check_usage_limit(
     Ok(())
 }

-struct UsageRecorder {
+struct TokenCountingStream<S> {
     db: Arc<LlmDatabase>,
     executor: Executor,
     user_id: i32,
     provider: LanguageModelProvider,
     model: String,
-    token_count: usize,
+    input_tokens: usize,
+    output_tokens: usize,
+    inner_stream: S,
 }

-impl Drop for UsageRecorder {
+impl<S> Stream for TokenCountingStream<S>
+where
+    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+{
+    type Item = Result<Vec<u8>, anyhow::Error>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match Pin::new(&mut self.inner_stream).poll_next(cx) {
+            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
+                bytes.push(b'\n');
+                self.input_tokens += input_tokens;
+                self.output_tokens += output_tokens;
+                Poll::Ready(Some(Ok(bytes)))
+            }
+            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+impl<S> Drop for TokenCountingStream<S> {
     fn drop(&mut self) {
         let db = self.db.clone();
         let user_id = self.user_id;
         let provider = self.provider;
         let model = std::mem::take(&mut self.model);
-        let token_count = self.token_count;
+        let token_count = self.input_tokens + self.output_tokens;
         self.executor.spawn_detached(async move {
             db.record_usage(user_id, provider, &model, token_count, Utc::now())
                 .await
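
For readers skimming the diff: the change routes every provider's response through a single stream adapter. Each provider branch now yields (serialized_chunk, input_tokens, output_tokens) tuples, and TokenCountingStream forwards the bytes, tallies the tokens, and records the total in its Drop impl, so usage is persisted even if the client disconnects mid-stream. Below is a minimal, self-contained sketch of that pattern, not Zed's actual code: the db, executor, user, and model fields are omitted, the newline framing is skipped, and db.record_usage is replaced with a println! stand-in.

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{stream, Stream, StreamExt as _};

struct TokenCountingStream<S> {
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = anyhow::Result<(Vec<u8>, usize, usize)>> + Unpin,
{
    type Item = anyhow::Result<Vec<u8>>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            // Forward the serialized bytes; tally tokens as a side effect.
            Poll::Ready(Some(Ok((bytes, input, output)))) => {
                self.input_tokens += input;
                self.output_tokens += output;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        // The real service spawns a detached task here that calls
        // db.record_usage(...); a println! keeps the sketch runnable.
        println!(
            "recording usage: {} input + {} output tokens",
            self.input_tokens, self.output_tokens
        );
    }
}

fn main() {
    let chunks: Vec<anyhow::Result<(Vec<u8>, usize, usize)>> = vec![
        Ok((b"{\"text\":\"hello\"}".to_vec(), 3, 1)),
        Ok((b"{\"text\":\" world\"}".to_vec(), 0, 2)),
    ];
    futures::executor::block_on(async {
        let mut counted = TokenCountingStream {
            input_tokens: 0,
            output_tokens: 0,
            inner_stream: stream::iter(chunks),
        };
        while let Some(bytes) = counted.next().await {
            println!("{}", String::from_utf8_lossy(&bytes.unwrap()));
        }
        // `counted` drops here, whether the stream was exhausted or the
        // consumer bailed early, and that is when usage gets recorded.
    });
}

Tying the recording to Drop rather than to stream completion is the design point: the previous UsageRecorder only counted tokens in the Anthropic branch, while this wrapper counts uniformly across providers and fires exactly once no matter how the response ends.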