cloud provider: Use CompletionEvent
type from zed_llm_client
(#35285)
Release Notes: - N/A
This commit is contained in:
parent
77dc65d826
commit
65250fe08d
1 changed files with 12 additions and 19 deletions
|
@ -35,8 +35,8 @@ use ui::{TintColor, prelude::*};
|
|||
use util::{ResultExt as _, maybe};
|
||||
use zed_llm_client::{
|
||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
||||
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
|
@ -1040,15 +1040,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CloudCompletionEvent<T> {
|
||||
Status(CompletionRequestStatus),
|
||||
Event(T),
|
||||
}
|
||||
|
||||
fn map_cloud_completion_events<T, F>(
|
||||
stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
|
||||
stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
|
||||
mut map_callback: F,
|
||||
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
where
|
||||
|
@ -1063,10 +1056,10 @@ where
|
|||
Err(error) => {
|
||||
vec![Err(LanguageModelCompletionError::from(error))]
|
||||
}
|
||||
Ok(CloudCompletionEvent::Status(event)) => {
|
||||
Ok(CompletionEvent::Status(event)) => {
|
||||
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
||||
}
|
||||
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
|
||||
Ok(CompletionEvent::Event(event)) => map_callback(event),
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
|
@ -1074,9 +1067,9 @@ where
|
|||
|
||||
fn usage_updated_event<T>(
|
||||
usage: Option<ModelRequestUsage>,
|
||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
||||
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
|
||||
futures::stream::iter(usage.map(|usage| {
|
||||
Ok(CloudCompletionEvent::Status(
|
||||
Ok(CompletionEvent::Status(
|
||||
CompletionRequestStatus::UsageUpdated {
|
||||
amount: usage.amount as usize,
|
||||
limit: usage.limit,
|
||||
|
@ -1087,9 +1080,9 @@ fn usage_updated_event<T>(
|
|||
|
||||
fn tool_use_limit_reached_event<T>(
|
||||
tool_use_limit_reached: bool,
|
||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
||||
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
|
||||
futures::stream::iter(tool_use_limit_reached.then(|| {
|
||||
Ok(CloudCompletionEvent::Status(
|
||||
Ok(CompletionEvent::Status(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
))
|
||||
}))
|
||||
|
@ -1098,7 +1091,7 @@ fn tool_use_limit_reached_event<T>(
|
|||
fn response_lines<T: DeserializeOwned>(
|
||||
response: Response<AsyncBody>,
|
||||
includes_status_messages: bool,
|
||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
||||
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
|
||||
futures::stream::try_unfold(
|
||||
(String::new(), BufReader::new(response.into_body())),
|
||||
move |(mut line, mut body)| async move {
|
||||
|
@ -1106,9 +1099,9 @@ fn response_lines<T: DeserializeOwned>(
|
|||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event = if includes_status_messages {
|
||||
serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
|
||||
serde_json::from_str::<CompletionEvent<T>>(&line)?
|
||||
} else {
|
||||
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
|
||||
CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
|
||||
};
|
||||
|
||||
line.clear();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue