collab: Track cache writes/reads in LLM usage (#18834)

This PR extends LLM usage tracking to also record cache writes and cache
reads for Anthropic models.
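
For context, Anthropic's Messages API reports cache activity alongside the
usual token counts in the `usage` object of streaming events. A minimal
sketch of the shape consumed here, assuming serde-based deserialization (the
`Usage` struct below is illustrative, not the exact type from this
repository):

```rust
use serde::Deserialize;

/// Illustrative sketch of the streaming `usage` payload. Fields are
/// `Option` because `message_delta` events may omit counts that did not
/// change, which is why the handler falls back with `unwrap_or(0)`.
#[derive(Debug, Deserialize)]
struct Usage {
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
    /// Tokens written to the prompt cache by this request.
    cache_creation_input_tokens: Option<u64>,
    /// Tokens served from the prompt cache instead of being reprocessed.
    cache_read_input_tokens: Option<u64>,
}
```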

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Marshall Bowers · 2024-10-07 17:32:49 -04:00 · committed by GitHub
parent c5d252b837
commit d55f025906
9 changed files with 241 additions and 39 deletions


@@ -318,22 +318,31 @@ async fn perform_completion(
             chunks
                 .map(move |event| {
                     let chunk = event?;
-                    let (input_tokens, output_tokens) = match &chunk {
+                    let (
+                        input_tokens,
+                        output_tokens,
+                        cache_creation_input_tokens,
+                        cache_read_input_tokens,
+                    ) = match &chunk {
                         anthropic::Event::MessageStart {
                             message: anthropic::Response { usage, .. },
                         }
                         | anthropic::Event::MessageDelta { usage, .. } => (
                             usage.input_tokens.unwrap_or(0) as usize,
                             usage.output_tokens.unwrap_or(0) as usize,
+                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
+                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                         ),
-                        _ => (0, 0),
+                        _ => (0, 0, 0, 0),
                     };
-                    anyhow::Ok((
-                        serde_json::to_vec(&chunk).unwrap(),
+                    anyhow::Ok(CompletionChunk {
+                        bytes: serde_json::to_vec(&chunk).unwrap(),
                         input_tokens,
                         output_tokens,
-                    ))
+                        cache_creation_input_tokens,
+                        cache_read_input_tokens,
+                    })
                 })
                 .boxed()
         }
@@ -359,11 +368,13 @@ async fn perform_completion(
                         chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                     let output_tokens =
                         chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
-                    (
-                        serde_json::to_vec(&chunk).unwrap(),
+                    CompletionChunk {
+                        bytes: serde_json::to_vec(&chunk).unwrap(),
                         input_tokens,
                         output_tokens,
-                    )
+                        cache_creation_input_tokens: 0,
+                        cache_read_input_tokens: 0,
+                    }
                 })
             })
             .boxed()
@@ -387,13 +398,13 @@ async fn perform_completion(
             .map(|event| {
                 event.map(|chunk| {
                     // TODO - implement token counting for Google AI
-                    let input_tokens = 0;
-                    let output_tokens = 0;
-                    (
-                        serde_json::to_vec(&chunk).unwrap(),
-                        input_tokens,
-                        output_tokens,
-                    )
+                    CompletionChunk {
+                        bytes: serde_json::to_vec(&chunk).unwrap(),
+                        input_tokens: 0,
+                        output_tokens: 0,
+                        cache_creation_input_tokens: 0,
+                        cache_read_input_tokens: 0,
+                    }
                 })
             })
             .boxed()
@@ -407,6 +418,8 @@ async fn perform_completion(
         model,
         input_tokens: 0,
         output_tokens: 0,
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
         inner_stream: stream,
     })))
 }
@@ -551,6 +564,14 @@ async fn check_usage_limit(
     Ok(())
 }
 
+struct CompletionChunk {
+    bytes: Vec<u8>,
+    input_tokens: usize,
+    output_tokens: usize,
+    cache_creation_input_tokens: usize,
+    cache_read_input_tokens: usize,
+}
+
 struct TokenCountingStream<S> {
     state: Arc<LlmState>,
     claims: LlmTokenClaims,
@@ -558,22 +579,26 @@ struct TokenCountingStream<S> {
     model: String,
     input_tokens: usize,
     output_tokens: usize,
+    cache_creation_input_tokens: usize,
+    cache_read_input_tokens: usize,
     inner_stream: S,
 }
 
 impl<S> Stream for TokenCountingStream<S>
 where
-    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
 {
     type Item = Result<Vec<u8>, anyhow::Error>;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         match Pin::new(&mut self.inner_stream).poll_next(cx) {
-            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
-                bytes.push(b'\n');
-                self.input_tokens += input_tokens;
-                self.output_tokens += output_tokens;
-                Poll::Ready(Some(Ok(bytes)))
+            Poll::Ready(Some(Ok(mut chunk))) => {
+                chunk.bytes.push(b'\n');
+                self.input_tokens += chunk.input_tokens;
+                self.output_tokens += chunk.output_tokens;
+                self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
+                self.cache_read_input_tokens += chunk.cache_read_input_tokens;
+                Poll::Ready(Some(Ok(chunk.bytes)))
             }
             Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
             Poll::Ready(None) => Poll::Ready(None),
@@ -590,6 +615,8 @@ impl<S> Drop for TokenCountingStream<S> {
         let model = std::mem::take(&mut self.model);
         let input_token_count = self.input_tokens;
         let output_token_count = self.output_tokens;
+        let cache_creation_input_token_count = self.cache_creation_input_tokens;
+        let cache_read_input_token_count = self.cache_read_input_tokens;
         self.state.executor.spawn_detached(async move {
             let usage = state
                 .db
@@ -599,6 +626,8 @@ impl<S> Drop for TokenCountingStream<S> {
                     provider,
                     &model,
                     input_token_count,
+                    cache_creation_input_token_count,
+                    cache_read_input_token_count,
                     output_token_count,
                     Utc::now(),
                 )
@@ -630,11 +659,20 @@ impl<S> Drop for TokenCountingStream<S> {
                         model,
                         provider: provider.to_string(),
                         input_token_count: input_token_count as u64,
+                        cache_creation_input_token_count: cache_creation_input_token_count
+                            as u64,
+                        cache_read_input_token_count: cache_read_input_token_count as u64,
                         output_token_count: output_token_count as u64,
                         requests_this_minute: usage.requests_this_minute as u64,
                         tokens_this_minute: usage.tokens_this_minute as u64,
                         tokens_this_day: usage.tokens_this_day as u64,
                         input_tokens_this_month: usage.input_tokens_this_month as u64,
+                        cache_creation_input_tokens_this_month: usage
+                            .cache_creation_input_tokens_this_month
+                            as u64,
+                        cache_read_input_tokens_this_month: usage
+                            .cache_read_input_tokens_this_month
+                            as u64,
                         output_tokens_this_month: usage.output_tokens_this_month as u64,
                         spending_this_month: usage.spending_this_month as u64,
                         lifetime_spending: usage.lifetime_spending as u64,
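
Downstream, these counters allow cached traffic to be priced separately from
ordinary input tokens. A hypothetical helper (not part of this PR) sketching
how the four counts could feed a spend estimate; the 1.25x write and 0.1x
read multipliers reflect Anthropic's published prompt-caching pricing and
are assumptions here, as are the function name and the cents-based price
parameter:

```rust
/// Hypothetical cost estimate; not taken from this PR.
fn estimate_input_cost_cents(
    input_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
    price_per_million_input_tokens_cents: f64,
) -> f64 {
    let per_token = price_per_million_input_tokens_cents / 1_000_000.0;
    input_tokens as f64 * per_token
        // Cache writes are billed at a premium over base input tokens.
        + cache_creation_input_tokens as f64 * per_token * 1.25
        // Cache reads are billed at a fraction of the base input rate.
        + cache_read_input_tokens as f64 * per_token * 0.10
}
```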