cloud provider: Use CompletionEvent
type from zed_llm_client
(#35285)
Release Notes: - N/A
This commit is contained in:
parent
77dc65d826
commit
65250fe08d
1 changed files with 12 additions and 19 deletions
|
@ -35,8 +35,8 @@ use ui::{TintColor, prelude::*};
|
||||||
use util::{ResultExt as _, maybe};
|
use util::{ResultExt as _, maybe};
|
||||||
use zed_llm_client::{
|
use zed_llm_client::{
|
||||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
||||||
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
|
||||||
ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||||
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||||
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
|
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
|
||||||
};
|
};
|
||||||
|
@ -1040,15 +1040,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum CloudCompletionEvent<T> {
|
|
||||||
Status(CompletionRequestStatus),
|
|
||||||
Event(T),
|
|
||||||
}
|
|
||||||
|
|
||||||
fn map_cloud_completion_events<T, F>(
|
fn map_cloud_completion_events<T, F>(
|
||||||
stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
|
stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
|
||||||
mut map_callback: F,
|
mut map_callback: F,
|
||||||
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||||
where
|
where
|
||||||
|
@ -1063,10 +1056,10 @@ where
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
vec![Err(LanguageModelCompletionError::from(error))]
|
vec![Err(LanguageModelCompletionError::from(error))]
|
||||||
}
|
}
|
||||||
Ok(CloudCompletionEvent::Status(event)) => {
|
Ok(CompletionEvent::Status(event)) => {
|
||||||
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
||||||
}
|
}
|
||||||
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
|
Ok(CompletionEvent::Event(event)) => map_callback(event),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
|
@ -1074,9 +1067,9 @@ where
|
||||||
|
|
||||||
fn usage_updated_event<T>(
|
fn usage_updated_event<T>(
|
||||||
usage: Option<ModelRequestUsage>,
|
usage: Option<ModelRequestUsage>,
|
||||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
|
||||||
futures::stream::iter(usage.map(|usage| {
|
futures::stream::iter(usage.map(|usage| {
|
||||||
Ok(CloudCompletionEvent::Status(
|
Ok(CompletionEvent::Status(
|
||||||
CompletionRequestStatus::UsageUpdated {
|
CompletionRequestStatus::UsageUpdated {
|
||||||
amount: usage.amount as usize,
|
amount: usage.amount as usize,
|
||||||
limit: usage.limit,
|
limit: usage.limit,
|
||||||
|
@ -1087,9 +1080,9 @@ fn usage_updated_event<T>(
|
||||||
|
|
||||||
fn tool_use_limit_reached_event<T>(
|
fn tool_use_limit_reached_event<T>(
|
||||||
tool_use_limit_reached: bool,
|
tool_use_limit_reached: bool,
|
||||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
|
||||||
futures::stream::iter(tool_use_limit_reached.then(|| {
|
futures::stream::iter(tool_use_limit_reached.then(|| {
|
||||||
Ok(CloudCompletionEvent::Status(
|
Ok(CompletionEvent::Status(
|
||||||
CompletionRequestStatus::ToolUseLimitReached,
|
CompletionRequestStatus::ToolUseLimitReached,
|
||||||
))
|
))
|
||||||
}))
|
}))
|
||||||
|
@ -1098,7 +1091,7 @@ fn tool_use_limit_reached_event<T>(
|
||||||
fn response_lines<T: DeserializeOwned>(
|
fn response_lines<T: DeserializeOwned>(
|
||||||
response: Response<AsyncBody>,
|
response: Response<AsyncBody>,
|
||||||
includes_status_messages: bool,
|
includes_status_messages: bool,
|
||||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
|
||||||
futures::stream::try_unfold(
|
futures::stream::try_unfold(
|
||||||
(String::new(), BufReader::new(response.into_body())),
|
(String::new(), BufReader::new(response.into_body())),
|
||||||
move |(mut line, mut body)| async move {
|
move |(mut line, mut body)| async move {
|
||||||
|
@ -1106,9 +1099,9 @@ fn response_lines<T: DeserializeOwned>(
|
||||||
Ok(0) => Ok(None),
|
Ok(0) => Ok(None),
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let event = if includes_status_messages {
|
let event = if includes_status_messages {
|
||||||
serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
|
serde_json::from_str::<CompletionEvent<T>>(&line)?
|
||||||
} else {
|
} else {
|
||||||
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
|
CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
|
||||||
};
|
};
|
||||||
|
|
||||||
line.clear();
|
line.clear();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue