cloud provider: Use CompletionEvent type from zed_llm_client (#35285)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-07-29 11:28:18 -06:00 committed by GitHub
parent 77dc65d826
commit 65250fe08d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -35,8 +35,8 @@ use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe}; use util::{ResultExt as _, maybe};
use zed_llm_client::{ use zed_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
}; };
@ -1040,15 +1040,8 @@ impl LanguageModel for CloudLanguageModel {
} }
} }
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
Status(CompletionRequestStatus),
Event(T),
}
fn map_cloud_completion_events<T, F>( fn map_cloud_completion_events<T, F>(
stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>, stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
mut map_callback: F, mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where where
@ -1063,10 +1056,10 @@ where
Err(error) => { Err(error) => {
vec![Err(LanguageModelCompletionError::from(error))] vec![Err(LanguageModelCompletionError::from(error))]
} }
Ok(CloudCompletionEvent::Status(event)) => { Ok(CompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
} }
Ok(CloudCompletionEvent::Event(event)) => map_callback(event), Ok(CompletionEvent::Event(event)) => map_callback(event),
}) })
}) })
.boxed() .boxed()
@ -1074,9 +1067,9 @@ where
fn usage_updated_event<T>( fn usage_updated_event<T>(
usage: Option<ModelRequestUsage>, usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { ) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::iter(usage.map(|usage| { futures::stream::iter(usage.map(|usage| {
Ok(CloudCompletionEvent::Status( Ok(CompletionEvent::Status(
CompletionRequestStatus::UsageUpdated { CompletionRequestStatus::UsageUpdated {
amount: usage.amount as usize, amount: usage.amount as usize,
limit: usage.limit, limit: usage.limit,
@ -1087,9 +1080,9 @@ fn usage_updated_event<T>(
fn tool_use_limit_reached_event<T>( fn tool_use_limit_reached_event<T>(
tool_use_limit_reached: bool, tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { ) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::iter(tool_use_limit_reached.then(|| { futures::stream::iter(tool_use_limit_reached.then(|| {
Ok(CloudCompletionEvent::Status( Ok(CompletionEvent::Status(
CompletionRequestStatus::ToolUseLimitReached, CompletionRequestStatus::ToolUseLimitReached,
)) ))
})) }))
@ -1098,7 +1091,7 @@ fn tool_use_limit_reached_event<T>(
fn response_lines<T: DeserializeOwned>( fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>, response: Response<AsyncBody>,
includes_status_messages: bool, includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { ) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::try_unfold( futures::stream::try_unfold(
(String::new(), BufReader::new(response.into_body())), (String::new(), BufReader::new(response.into_body())),
move |(mut line, mut body)| async move { move |(mut line, mut body)| async move {
@ -1106,9 +1099,9 @@ fn response_lines<T: DeserializeOwned>(
Ok(0) => Ok(None), Ok(0) => Ok(None),
Ok(_) => { Ok(_) => {
let event = if includes_status_messages { let event = if includes_status_messages {
serde_json::from_str::<CloudCompletionEvent<T>>(&line)? serde_json::from_str::<CompletionEvent<T>>(&line)?
} else { } else {
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?) CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
}; };
line.clear(); line.clear();