Add support for queuing status updates in cloud language model provider (#29818)
This sets us up to display queue position information to the user, once our language model backend is updated to support request queuing. The JSON returned by the LLM backend will need to look like this:

```json
{"queue": {"status": "queued", "position": 1}}
{"queue": {"status": "started"}}
{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}}
```

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
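Each line of that body is a self-contained JSON document (newline-delimited JSON), and serde's default externally tagged enum representation, combined with `rename_all = "snake_case"`, produces exactly this `{"queue": …}` / `{"event": …}` framing. A minimal parsing sketch, assuming only `serde` and `serde_json`; `QueueState` here is a simplified stand-in whose shape is inferred from the sample payloads (the real type comes from the `language_model` crate and isn't shown in this diff):

```rust
use serde::Deserialize;

// Simplified stand-in for `language_model::QueueState` (assumed shape,
// inferred from the sample payloads above).
#[derive(Debug, Deserialize)]
struct QueueState {
    status: String,        // "queued" or "started"
    position: Option<u32>, // only present while still queued
}

// Mirrors the `CloudCompletionEvent<T>` enum introduced by this PR:
// externally tagged, with snake_case variant names.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    Queue(QueueState),
    Event(T),
}

fn main() -> serde_json::Result<()> {
    let body = r#"{"queue": {"status": "queued", "position": 1}}
{"queue": {"status": "started"}}
{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}}"#;

    for line in body.lines() {
        // The wrapped event type is provider-specific, so use raw JSON here.
        let event: CloudCompletionEvent<serde_json::Value> = serde_json::from_str(line)?;
        println!("{event:?}");
    }
    Ok(())
}
```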
parent 4d1df7bcd7
commit 04772bf17d
9 changed files with 492 additions and 430 deletions
```diff
@@ -1,11 +1,10 @@
-use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
+use anthropic::{AnthropicModelMode, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::{Client, UserStore, zed_urls};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
 use futures::{
-    AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
-    stream::BoxStream,
+    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
@@ -14,7 +13,7 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -26,6 +25,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use settings::{Settings, SettingsStore};
 use smol::Timer;
 use smol::io::{AsyncReadExt, BufReader};
+use std::pin::Pin;
 use std::str::FromStr as _;
 use std::{
     sync::{Arc, LazyLock},
@@ -41,9 +41,9 @@ use zed_llm_client::{
 };
 
 use crate::AllLanguageModelSettings;
-use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
-use crate::provider::google::into_google;
-use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
+use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
+use crate::provider::google::{GoogleEventMapper, into_google};
+use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
 
 pub const PROVIDER_NAME: &str = "Zed";
 
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
+    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
@@ -536,13 +536,18 @@ impl CloudLanguageModel {
         let request = request_builder
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {token}"))
+            .header("x-zed-client-supports-queueing", "true")
             .body(serde_json::to_string(&body)?.into())?;
         let mut response = http_client.send(request).await?;
         let status = response.status();
         if status.is_success() {
+            let includes_queue_events = response
+                .headers()
+                .get("x-zed-server-supports-queueing")
+                .is_some();
             let usage = RequestUsage::from_headers(response.headers()).ok();
 
-            return Ok((response, usage));
+            return Ok((response, usage, includes_queue_events));
         } else if response
             .headers()
             .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
```
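Note the handshake this hunk implements: the client advertises support via the `x-zed-client-supports-queueing` request header, and the stream format only changes when the server answers with `x-zed-server-supports-queueing`, so older servers keep streaming bare provider events. A sketch of the server-capability check in isolation (hypothetical free function; the real code inlines this check in `perform_llm_completion`, and the `http` crate types here are an assumption standing in for whatever `http_client` re-exports):

```rust
use http::{HeaderMap, HeaderValue};

/// True when the server opted in to queue-event framing. Only the header's
/// presence matters; its value is ignored, matching the diff above.
fn server_supports_queueing(headers: &HeaderMap<HeaderValue>) -> bool {
    headers.get("x-zed-server-supports-queueing").is_some()
}
```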
```diff
@@ -782,7 +787,7 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -811,9 +816,11 @@
                         Err(err) => anyhow!(err),
                     })?;
 
+                    let mut mapper = AnthropicEventMapper::new();
                     Ok((
-                        crate::provider::anthropic::map_to_language_model_completion_events(
-                            Box::pin(response_lines(response).map_err(AnthropicError::Other)),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
                         ),
                         usage,
                     ))
@@ -829,7 +836,7 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -842,9 +849,12 @@
                         },
                     )
                     .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
                     Ok((
-                        crate::provider::open_ai::map_to_language_model_completion_events(
-                            Box::pin(response_lines(response)),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
                         ),
                         usage,
                     ))
@@ -860,7 +870,7 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -873,10 +883,12 @@
                         },
                     )
                     .await?;
+                    let mut mapper = GoogleEventMapper::new();
                     Ok((
-                        crate::provider::google::map_to_language_model_completion_events(Box::pin(
-                            response_lines(response),
-                        )),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
+                        ),
                         usage,
                     ))
                 });
```
```diff
@@ -890,16 +902,54 @@ impl LanguageModel for CloudLanguageModel {
     }
 }
 
+#[derive(Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CloudCompletionEvent<T> {
+    Queue(QueueState),
+    Event(T),
+}
+
+fn map_cloud_completion_events<T, F>(
+    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
+    mut map_callback: F,
+) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+where
+    T: DeserializeOwned + 'static,
+    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+        + Send
+        + 'static,
+{
+    stream
+        .flat_map(move |event| {
+            futures::stream::iter(match event {
+                Err(error) => {
+                    vec![Err(LanguageModelCompletionError::Other(error))]
+                }
+                Ok(CloudCompletionEvent::Queue(event)) => {
+                    vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+                }
+                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
+            })
+        })
+        .boxed()
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
-) -> impl Stream<Item = Result<T>> {
+    includes_queue_events: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
-        move |(mut line, mut body)| async {
+        move |(mut line, mut body)| async move {
             match body.read_line(&mut line).await {
                 Ok(0) => Ok(None),
                 Ok(_) => {
-                    let event: T = serde_json::from_str(&line)?;
+                    let event = if includes_queue_events {
+                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
+                    } else {
+                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+                    };
+
                     line.clear();
                     Ok(Some((event, (line, body))))
                 }
```
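The `else` branch in `response_lines` is what keeps this backward compatible: when the server never opted in, each bare provider event is wrapped in `CloudCompletionEvent::Event`, so `map_cloud_completion_events` consumes a uniform stream either way. The same logic lifted out of the stream machinery as a sketch (hypothetical helper, reusing the `CloudCompletionEvent` enum from the hunk above):

```rust
use serde::de::DeserializeOwned;

// One NDJSON line -> one unified event, mirroring the branch in `response_lines`.
fn parse_line<T: DeserializeOwned>(
    line: &str,
    includes_queue_events: bool,
) -> serde_json::Result<CloudCompletionEvent<T>> {
    if includes_queue_events {
        // New framing: the line is already {"queue": ...} or {"event": ...}.
        serde_json::from_str(line)
    } else {
        // Legacy framing: a bare provider event; wrap it so downstream
        // mapping code never has to know which server it talked to.
        Ok(CloudCompletionEvent::Event(serde_json::from_str(line)?))
    }
}
```

Queue updates then surface as `LanguageModelCompletionEvent::QueueUpdate` without any provider-specific mapper needing to know about them.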