zeta: Extract usage information from response headers (#28999)
This PR updates the Zeta provider to extract the usage information from the response headers, if they are present. For now we just log the information, but we'll need to figure out where this needs to get threaded through to in order to display it in the UI. Release Notes: - N/A
This commit is contained in:
parent
8660101b83
commit
be63d51eb7
3 changed files with 48 additions and 9 deletions
5
Cargo.lock
generated
5
Cargo.lock
generated
|
@ -18361,10 +18361,11 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zed_llm_client"
|
name = "zed_llm_client"
|
||||||
version = "0.5.1"
|
version = "0.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9ee4d410dbc030c3e6e3af78fc76296f6bebe20dcb6d7d3fa24bca306fc8c1ce"
|
checksum = "b91b8b05f1028157205026e525869eb860fa89bec87ea60b445efc91d05df31f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"strum 0.27.1",
|
"strum 0.27.1",
|
||||||
|
|
|
@ -604,7 +604,7 @@ wasmtime-wasi = "29"
|
||||||
which = "6.0.0"
|
which = "6.0.0"
|
||||||
wit-component = "0.221"
|
wit-component = "0.221"
|
||||||
workspace-hack = "0.1.0"
|
workspace-hack = "0.1.0"
|
||||||
zed_llm_client = "0.5.1"
|
zed_llm_client = "0.6.0"
|
||||||
zstd = "0.11"
|
zstd = "0.11"
|
||||||
metal = "0.29"
|
metal = "0.29"
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ mod rate_completion_modal;
|
||||||
|
|
||||||
pub(crate) use completion_diff_element::*;
|
pub(crate) use completion_diff_element::*;
|
||||||
use db::kvp::KEY_VALUE_STORE;
|
use db::kvp::KEY_VALUE_STORE;
|
||||||
|
use http_client::http::{HeaderMap, HeaderValue};
|
||||||
pub use init::*;
|
pub use init::*;
|
||||||
use inline_completion::DataCollectionState;
|
use inline_completion::DataCollectionState;
|
||||||
use license_detection::LICENSE_FILES_TO_CHECK;
|
use license_detection::LICENSE_FILES_TO_CHECK;
|
||||||
|
@ -54,8 +55,9 @@ use workspace::Workspace;
|
||||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId};
|
use workspace::notifications::{ErrorMessagePrompt, NotificationId};
|
||||||
use worktree::Worktree;
|
use worktree::Worktree;
|
||||||
use zed_llm_client::{
|
use zed_llm_client::{
|
||||||
|
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||||
EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
|
EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
|
||||||
PredictEditsResponse,
|
PredictEditsResponse, UsageLimit,
|
||||||
};
|
};
|
||||||
|
|
||||||
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
|
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
|
||||||
|
@ -74,6 +76,32 @@ const MAX_EVENT_COUNT: usize = 16;
|
||||||
|
|
||||||
actions!(edit_prediction, [ClearHistory]);
|
actions!(edit_prediction, [ClearHistory]);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub limit: UsageLimit,
|
||||||
|
pub amount: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Usage {
|
||||||
|
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
|
||||||
|
let limit = headers
|
||||||
|
.get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header")
|
||||||
|
})?;
|
||||||
|
let limit = UsageLimit::from_str(limit.to_str()?)?;
|
||||||
|
|
||||||
|
let amount = headers
|
||||||
|
.get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header")
|
||||||
|
})?;
|
||||||
|
let amount = amount.to_str()?.parse::<i32>()?;
|
||||||
|
|
||||||
|
Ok(Self { limit, amount })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||||
pub struct InlineCompletionId(Uuid);
|
pub struct InlineCompletionId(Uuid);
|
||||||
|
|
||||||
|
@ -359,7 +387,7 @@ impl Zeta {
|
||||||
) -> Task<Result<Option<InlineCompletion>>>
|
) -> Task<Result<Option<InlineCompletion>>>
|
||||||
where
|
where
|
||||||
F: FnOnce(PerformPredictEditsParams) -> R + 'static,
|
F: FnOnce(PerformPredictEditsParams) -> R + 'static,
|
||||||
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
|
R: Future<Output = Result<(PredictEditsResponse, Option<Usage>)>> + Send + 'static,
|
||||||
{
|
{
|
||||||
let snapshot = self.report_changes_for_buffer(&buffer, cx);
|
let snapshot = self.report_changes_for_buffer(&buffer, cx);
|
||||||
let diagnostic_groups = snapshot.diagnostic_groups(None);
|
let diagnostic_groups = snapshot.diagnostic_groups(None);
|
||||||
|
@ -467,7 +495,7 @@ impl Zeta {
|
||||||
body,
|
body,
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
let response = match response {
|
let (response, usage) = match response {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
if err.is::<ZedUpdateRequiredError>() {
|
if err.is::<ZedUpdateRequiredError>() {
|
||||||
|
@ -503,6 +531,14 @@ impl Zeta {
|
||||||
|
|
||||||
log::debug!("completion response: {}", &response.output_excerpt);
|
log::debug!("completion response: {}", &response.output_excerpt);
|
||||||
|
|
||||||
|
if let Some(usage) = usage {
|
||||||
|
let limit = match usage.limit {
|
||||||
|
UsageLimit::Limited(limit) => limit.to_string(),
|
||||||
|
UsageLimit::Unlimited => "unlimited".to_string(),
|
||||||
|
};
|
||||||
|
log::info!("edit prediction usage: {} / {}", usage.amount, limit);
|
||||||
|
}
|
||||||
|
|
||||||
Self::process_completion_response(
|
Self::process_completion_response(
|
||||||
response,
|
response,
|
||||||
buffer,
|
buffer,
|
||||||
|
@ -685,7 +721,7 @@ and then another
|
||||||
use std::future::ready;
|
use std::future::ready;
|
||||||
|
|
||||||
self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
|
self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
|
||||||
ready(Ok(response))
|
ready(Ok((response, None)))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -714,7 +750,7 @@ and then another
|
||||||
|
|
||||||
fn perform_predict_edits(
|
fn perform_predict_edits(
|
||||||
params: PerformPredictEditsParams,
|
params: PerformPredictEditsParams,
|
||||||
) -> impl Future<Output = Result<PredictEditsResponse>> {
|
) -> impl Future<Output = Result<(PredictEditsResponse, Option<Usage>)>> {
|
||||||
async move {
|
async move {
|
||||||
let PerformPredictEditsParams {
|
let PerformPredictEditsParams {
|
||||||
client,
|
client,
|
||||||
|
@ -760,9 +796,11 @@ and then another
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
|
let usage = Usage::from_headers(response.headers()).ok();
|
||||||
|
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
return Ok(serde_json::from_str(&body)?);
|
return Ok((serde_json::from_str(&body)?, usage));
|
||||||
} else if !did_retry
|
} else if !did_retry
|
||||||
&& response
|
&& response
|
||||||
.headers()
|
.headers()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue