Change cloud language model provider JSON protocol to surface errors and usage information (#29830)
Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent 3984531a45
commit c3d9cdecab

8 changed files with 128 additions and 197 deletions
@@ -9,12 +9,11 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
-    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
-    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
-    ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
+    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -36,9 +35,10 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use zed_llm_client::{
-    CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
-    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
-    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
+    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
+    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
     TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };
@@ -517,7 +517,7 @@ struct PerformLlmCompletionResponse {
     response: Response<AsyncBody>,
     usage: Option<RequestUsage>,
     tool_use_limit_reached: bool,
-    includes_queue_events: bool,
+    includes_status_messages: bool,
 }

 impl CloudLanguageModel {
@@ -545,25 +545,31 @@ impl CloudLanguageModel {
         let request = request_builder
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {token}"))
-            .header("x-zed-client-supports-queueing", "true")
+            .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
             .body(serde_json::to_string(&body)?.into())?;
         let mut response = http_client.send(request).await?;
         let status = response.status();
         if status.is_success() {
-            let includes_queue_events = response
+            let includes_status_messages = response
                 .headers()
-                .get("x-zed-server-supports-queueing")
+                .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                 .is_some();

             let tool_use_limit_reached = response
                 .headers()
                 .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                 .is_some();
-            let usage = RequestUsage::from_headers(response.headers()).ok();
+
+            let usage = if includes_status_messages {
+                None
+            } else {
+                RequestUsage::from_headers(response.headers()).ok()
+            };
+
             return Ok(PerformLlmCompletionResponse {
                 response,
                 usage,
-                includes_queue_events,
+                includes_status_messages,
                 tool_use_limit_reached,
             });
         } else if response
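A note on the negotiation in this hunk: the client opts in by sending CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME on the request, and the server confirms the new protocol by setting SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME on the response. Once confirmed, the header-derived usage is deliberately dropped, because usage will arrive in-band as a status message. A minimal sketch of the check, using the `http` crate's HeaderMap as a stand-in for http_client's header type and a hypothetical header value (both are assumptions for illustration):

```rust
use http::HeaderMap;

// Hypothetical wire value; the real constant lives in zed_llm_client.
const SERVER_SUPPORTS_STATUS_MESSAGES: &str = "x-zed-server-supports-status-messages";

/// When this returns true, usage should be read from the event stream
/// (a `usage_updated` status message) rather than from response headers.
fn server_supports_status_messages(headers: &HeaderMap) -> bool {
    headers.get(SERVER_SUPPORTS_STATUS_MESSAGES).is_some()
}
```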
@@ -767,28 +773,12 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        cx: &AsyncApp,
+        _cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
         >,
     > {
-        self.stream_completion_with_usage(request, cx)
-            .map(|result| result.map(|(stream, _)| stream))
-            .boxed()
-    }
-
-    fn stream_completion_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<(
-            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
-            Option<RequestUsage>,
-        )>,
-    > {
         let thread_id = request.thread_id.clone();
         let prompt_id = request.prompt_id.clone();
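With `stream_completion_with_usage` removed, callers now recover usage from the stream itself. A self-contained sketch of the consumer side; the enums below are simplified stand-ins whose shapes mirror `LanguageModelCompletionEvent::StatusUpdate` and `CompletionRequestStatus::UsageUpdated` from this diff, not the real types:

```rust
use futures::{StreamExt, executor::block_on, stream};

// Illustrative stand-ins for the real event types.
enum Status {
    UsageUpdated { amount: usize, limit: Option<usize> },
    ToolUseLimitReached,
}

enum CompletionEvent {
    Text(String),
    StatusUpdate(Status),
}

fn main() {
    let events = stream::iter(vec![
        CompletionEvent::Text("hello".into()),
        CompletionEvent::StatusUpdate(Status::UsageUpdated {
            amount: 3,
            limit: Some(50),
        }),
    ]);
    block_on(events.for_each(|event| async move {
        match event {
            CompletionEvent::Text(text) => println!("delta: {text}"),
            // Usage arrives in-band instead of as a second return value.
            CompletionEvent::StatusUpdate(Status::UsageUpdated { amount, limit }) => {
                println!("usage: {amount} of {limit:?}")
            }
            CompletionEvent::StatusUpdate(Status::ToolUseLimitReached) => {
                println!("tool use limit reached")
            }
        }
    }));
}
```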
@@ -804,11 +794,11 @@ impl LanguageModel for CloudLanguageModel {
                 );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
                         client.clone(),
@@ -840,32 +830,26 @@ impl LanguageModel for CloudLanguageModel {
                     })?;

                     let mut mapper = AnthropicEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
-                        ),
-                        usage,
-                    ))
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
                         client.clone(),
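The chaining in this hunk appends the synthetic events after the provider's own output: `StreamExt::chain` only polls the second stream once the first is exhausted, so `usage_updated` and `tool_use_limit_reached` are always observed after the last response line. A tiny self-contained illustration of that ordering guarantee:

```rust
use futures::{StreamExt, executor::block_on, stream};

fn main() {
    // Provider output first, synthetic trailers second.
    let body = stream::iter(vec!["delta-1", "delta-2"]);
    let trailer = stream::iter(Some("usage_updated"));
    let all: Vec<_> = block_on(body.chain(trailer).collect::<Vec<_>>());
    assert_eq!(all, vec!["delta-1", "delta-2", "usage_updated"]);
}
```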
@@ -882,32 +866,26 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;

                     let mut mapper = OpenAiEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
-                        ),
-                        usage,
-                    ))
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
                         client.clone(),
@@ -924,22 +902,16 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;

                     let mut mapper = GoogleEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
-                        ),
-                        usage,
-                    ))
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
@@ -948,7 +920,7 @@ impl LanguageModel for CloudLanguageModel {
 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    System(CompletionRequestStatus),
+    Status(CompletionRequestStatus),
     Event(T),
 }

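The rename from `System` to `Status` is a wire-format change, not just a refactor: serde's default externally tagged enum representation combined with `rename_all = "snake_case"` uses the variant name as the JSON key, so lines go from `{"system": ...}` to `{"status": ...}`. A self-contained sketch of the envelope (the `String` payload is a stand-in for the real `CompletionRequestStatus`):

```rust
use serde::{Deserialize, Serialize};

// Stand-in envelope mirroring the enum above; `String` replaces the real
// CompletionRequestStatus payload purely for illustration.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    Status(String),
    Event(T),
}

#[derive(Serialize, Deserialize, Debug)]
struct ProviderEvent {
    text: String,
}

fn main() -> serde_json::Result<()> {
    // With serde's externally tagged representation, each NDJSON line is
    // keyed by the (snake_cased) variant name:
    let status: CloudCompletionEvent<ProviderEvent> =
        serde_json::from_str(r#"{"status":"queued"}"#)?;
    let event: CloudCompletionEvent<ProviderEvent> =
        serde_json::from_str(r#"{"event":{"text":"hello"}}"#)?;
    println!("{status:?}");
    println!("{event:?}");
    Ok(())
}
```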
@@ -968,8 +940,8 @@ where
             Err(error) => {
                 vec![Err(LanguageModelCompletionError::Other(error))]
             }
-            Ok(CloudCompletionEvent::System(event)) => {
-                vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+            Ok(CloudCompletionEvent::Status(event)) => {
+                vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
             }
             Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
         })
@@ -977,11 +949,24 @@ where
         .boxed()
 }

+fn usage_updated_event<T>(
+    usage: Option<RequestUsage>,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(usage.map(|usage| {
+        Ok(CloudCompletionEvent::Status(
+            CompletionRequestStatus::UsageUpdated {
+                amount: usage.amount as usize,
+                limit: usage.limit,
+            },
+        ))
+    }))
+}
+
 fn tool_use_limit_reached_event<T>(
     tool_use_limit_reached: bool,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::iter(tool_use_limit_reached.then(|| {
-        Ok(CloudCompletionEvent::System(
+        Ok(CloudCompletionEvent::Status(
             CompletionRequestStatus::ToolUseLimitReached,
         ))
     }))
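Both helpers lean on the same idiom: `Option<T>` iterates over zero or one items (and `bool::then` produces an `Option`), so `futures::stream::iter` yields an empty or single-event stream that can be chained onto the response. A self-contained illustration:

```rust
use futures::{StreamExt, executor::block_on, stream};

fn main() {
    // An Option iterates over zero or one items, so the resulting stream
    // injects at most one synthetic event.
    let one: Vec<i32> = block_on(stream::iter(Some(42)).collect::<Vec<_>>());
    let none: Vec<i32> = block_on(stream::iter(None::<i32>).collect::<Vec<_>>());
    assert_eq!(one, vec![42]);
    assert!(none.is_empty());
}
```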
@@ -989,7 +974,7 @@ fn tool_use_limit_reached_event<T>(

 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
-    includes_queue_events: bool,
+    includes_status_messages: bool,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
@@ -997,7 +982,7 @@ fn response_lines<T: DeserializeOwned>(
             match body.read_line(&mut line).await {
                 Ok(0) => Ok(None),
                 Ok(_) => {
-                    let event = if includes_queue_events {
+                    let event = if includes_status_messages {
                         serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                     } else {
                         CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
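This final branch is what keeps the old protocol readable: when the server has not confirmed status-message support, each NDJSON line is a bare provider event that gets wrapped in `Event`; otherwise the envelope is parsed directly. A self-contained sketch of the branch, with illustrative `Envelope`/`Delta` stand-ins for the real types:

```rust
use serde::{Deserialize, Serialize, de::DeserializeOwned};

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
enum Envelope<T> {
    Status(String),
    Event(T),
}

#[derive(Serialize, Deserialize, Debug)]
struct Delta {
    text: String,
}

// Mirrors the diff: new-protocol lines carry the envelope, old-protocol
// lines are bare provider events that get wrapped after parsing.
fn parse_line<T: DeserializeOwned>(
    line: &str,
    includes_status_messages: bool,
) -> serde_json::Result<Envelope<T>> {
    if includes_status_messages {
        serde_json::from_str::<Envelope<T>>(line)
    } else {
        Ok(Envelope::Event(serde_json::from_str::<T>(line)?))
    }
}

fn main() -> serde_json::Result<()> {
    let old: Envelope<Delta> = parse_line(r#"{"text":"hi"}"#, false)?;
    let new: Envelope<Delta> = parse_line(r#"{"status":"queued"}"#, true)?;
    println!("{old:?}");
    println!("{new:?}");
    Ok(())
}
```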