diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs
index 3d13b6f812..f35d77533a 100644
--- a/crates/collab/src/llm.rs
+++ b/crates/collab/src/llm.rs
@@ -15,10 +15,14 @@ use axum::{
 };
 use chrono::{DateTime, Duration, Utc};
 use db::{ActiveUserCount, LlmDatabase};
-use futures::StreamExt as _;
+use futures::{Stream, StreamExt as _};
 use http_client::IsahcHttpClient;
 use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
-use std::sync::Arc;
+use std::{
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
 use tokio::sync::RwLock;
 use util::ResultExt;
 
@@ -155,7 +159,7 @@ async fn perform_completion(
 
     check_usage_limit(&state, params.provider, &model, &claims).await?;
 
-    match params.provider {
+    let stream = match params.provider {
         LanguageModelProvider::Anthropic => {
             let api_key = state
                 .config
@@ -185,37 +189,27 @@ async fn perform_completion(
             )
             .await?;
 
-            let mut recorder = UsageRecorder {
-                db: state.db.clone(),
-                executor: state.executor.clone(),
-                user_id,
-                provider: params.provider,
-                model,
-                token_count: 0,
-            };
-
-            let stream = chunks.map(move |event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    match &chunk {
+            chunks
+                .map(move |event| {
+                    let chunk = event?;
+                    let (input_tokens, output_tokens) = match &chunk {
                         anthropic::Event::MessageStart {
                             message: anthropic::Response { usage, .. },
                         }
-                        | anthropic::Event::MessageDelta { usage, .. } => {
-                            recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
-                            recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
-                        }
-                        _ => {}
-                    }
+                        | anthropic::Event::MessageDelta { usage, .. } => (
+                            usage.input_tokens.unwrap_or(0) as usize,
+                            usage.output_tokens.unwrap_or(0) as usize,
+                        ),
+                        _ => (0, 0),
+                    };
 
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+                    anyhow::Ok((
+                        serde_json::to_vec(&chunk).unwrap(),
+                        input_tokens,
+                        output_tokens,
+                    ))
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::OpenAi => {
             let api_key = state
@@ -232,17 +226,21 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        let input_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+                        let output_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::Google => {
             let api_key = state
@@ -258,17 +256,20 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        // TODO - implement token counting for Google AI
+                        let input_tokens = 0;
+                        let output_tokens = 0;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::Zed => {
             let api_key = state
@@ -290,19 +291,34 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        let input_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+                        let output_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
-    }
+    };
+
+    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
+        db: state.db.clone(),
+        executor: state.executor.clone(),
+        user_id,
+        provider: params.provider,
+        model,
+        input_tokens: 0,
+        output_tokens: 0,
+        inner_stream: stream,
+    })))
 }
 
 fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
@@ -377,22 +393,46 @@ async fn check_usage_limit(
     Ok(())
 }
 
-struct UsageRecorder {
+
+struct TokenCountingStream<S> {
     db: Arc<LlmDatabase>,
     executor: Executor,
     user_id: i32,
     provider: LanguageModelProvider,
     model: String,
-    token_count: usize,
+    input_tokens: usize,
+    output_tokens: usize,
+    inner_stream: S,
 }
 
-impl Drop for UsageRecorder {
+impl<S> Stream for TokenCountingStream<S>
+where
+    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+{
+    type Item = Result<Vec<u8>, anyhow::Error>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match Pin::new(&mut self.inner_stream).poll_next(cx) {
+            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
+                bytes.push(b'\n');
+                self.input_tokens += input_tokens;
+                self.output_tokens += output_tokens;
+                Poll::Ready(Some(Ok(bytes)))
+            }
+            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+impl<S> Drop for TokenCountingStream<S> {
     fn drop(&mut self) {
         let db = self.db.clone();
         let user_id = self.user_id;
         let provider = self.provider;
         let model = std::mem::take(&mut self.model);
-        let token_count = self.token_count;
+        let token_count = self.input_tokens + self.output_tokens;
         self.executor.spawn_detached(async move {
             db.record_usage(user_id, provider, &model, token_count, Utc::now())
                 .await
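
For reference, a minimal self-contained sketch of the pattern this diff introduces: a stream adapter that forwards byte payloads downstream, accumulates token counts as items pass through poll_next, and flushes the total exactly once on Drop, whether the client drained the response or disconnected mid-stream. It assumes only futures = "0.3"; the names CountingStream and the println! stand-in for db.record_usage are illustrative, not the collab crate's API.

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{executor::block_on, stream, Stream, StreamExt as _};

struct CountingStream<S> {
    inner: S,
    input_tokens: usize,
    output_tokens: usize,
}

impl<S> Stream for CountingStream<S>
where
    S: Stream<Item = (Vec<u8>, usize, usize)> + Unpin,
{
    type Item = Vec<u8>;

    // Forward each payload downstream, accumulating the token counts that
    // ride alongside it.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some((bytes, input, output))) => {
                self.input_tokens += input;
                self.output_tokens += output;
                Poll::Ready(Some(bytes))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl<S> Drop for CountingStream<S> {
    // Runs whether the stream was drained or dropped mid-flight, which is
    // what makes Drop a safe place to record usage. Hypothetical stand-in
    // for the db.record_usage call in the diff.
    fn drop(&mut self) {
        println!(
            "recording usage: {} input + {} output tokens",
            self.input_tokens, self.output_tokens
        );
    }
}

fn main() {
    let chunks = stream::iter(vec![
        (b"data: hel\n".to_vec(), 10, 2),
        (b"data: lo\n".to_vec(), 0, 3),
    ]);
    let mut wrapped = CountingStream {
        inner: chunks,
        input_tokens: 0,
        output_tokens: 0,
    };
    block_on(async {
        while let Some(bytes) = wrapped.next().await {
            print!("{}", String::from_utf8_lossy(&bytes));
        }
    });
    // `wrapped` is dropped here, printing
    // "recording usage: 10 input + 5 output tokens".
}

Moving the accounting into the wrapper's Drop (rather than a closure captured by map, as in the old UsageRecorder) ties the usage record to the response stream's lifetime, so a single code path handles all four providers.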