cloud provider: Use CompletionEvent type from zed_llm_client (#35285)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-07-29 11:28:18 -06:00 committed by GitHub
parent 77dc65d826
commit 65250fe08d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -35,8 +35,8 @@ use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
@ -1040,15 +1040,8 @@ impl LanguageModel for CloudLanguageModel {
}
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
Status(CompletionRequestStatus),
Event(T),
}
fn map_cloud_completion_events<T, F>(
stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
@ -1063,10 +1056,10 @@ where
Err(error) => {
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CloudCompletionEvent::Status(event)) => {
Ok(CompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
}
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
Ok(CompletionEvent::Event(event)) => map_callback(event),
})
})
.boxed()
@ -1074,9 +1067,9 @@ where
fn usage_updated_event<T>(
usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::iter(usage.map(|usage| {
Ok(CloudCompletionEvent::Status(
Ok(CompletionEvent::Status(
CompletionRequestStatus::UsageUpdated {
amount: usage.amount as usize,
limit: usage.limit,
@ -1087,9 +1080,9 @@ fn usage_updated_event<T>(
fn tool_use_limit_reached_event<T>(
tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::iter(tool_use_limit_reached.then(|| {
Ok(CloudCompletionEvent::Status(
Ok(CompletionEvent::Status(
CompletionRequestStatus::ToolUseLimitReached,
))
}))
@ -1098,7 +1091,7 @@ fn tool_use_limit_reached_event<T>(
fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>,
includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::try_unfold(
(String::new(), BufReader::new(response.into_body())),
move |(mut line, mut body)| async move {
@ -1106,9 +1099,9 @@ fn response_lines<T: DeserializeOwned>(
Ok(0) => Ok(None),
Ok(_) => {
let event = if includes_status_messages {
serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
serde_json::from_str::<CompletionEvent<T>>(&line)?
} else {
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
};
line.clear();