Change cloud language model provider JSON protocol to surface errors and usage information (#29830)
Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
3984531a45
commit
c3d9cdecab
8 changed files with 128 additions and 197 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -18825,9 +18825,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zed_llm_client"
|
name = "zed_llm_client"
|
||||||
version = "0.7.2"
|
version = "0.7.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7"
|
checksum = "2adf9bc80def4ec93c190f06eb78111865edc2576019a9753eaef6fd7bc3b72c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -611,7 +611,7 @@ wasmtime-wasi = "29"
|
||||||
which = "6.0.0"
|
which = "6.0.0"
|
||||||
wit-component = "0.221"
|
wit-component = "0.221"
|
||||||
workspace-hack = "0.1.0"
|
workspace-hack = "0.1.0"
|
||||||
zed_llm_client = "0.7.2"
|
zed_llm_client = "0.7.4"
|
||||||
zstd = "0.11"
|
zstd = "0.11"
|
||||||
|
|
||||||
[workspace.dependencies.async-stripe]
|
[workspace.dependencies.async-stripe]
|
||||||
|
|
|
@ -37,7 +37,7 @@ use settings::Settings;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use zed_llm_client::CompletionMode;
|
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
|
||||||
|
|
||||||
use crate::ThreadStore;
|
use crate::ThreadStore;
|
||||||
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
|
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
|
||||||
|
@ -1356,20 +1356,17 @@ impl Thread {
|
||||||
self.last_received_chunk_at = Some(Instant::now());
|
self.last_received_chunk_at = Some(Instant::now());
|
||||||
|
|
||||||
let task = cx.spawn(async move |thread, cx| {
|
let task = cx.spawn(async move |thread, cx| {
|
||||||
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
|
let stream_completion_future = model.stream_completion(request, &cx);
|
||||||
let initial_token_usage =
|
let initial_token_usage =
|
||||||
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
|
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
let (mut events, usage) = stream_completion_future.await?;
|
let mut events = stream_completion_future.await?;
|
||||||
|
|
||||||
let mut stop_reason = StopReason::EndTurn;
|
let mut stop_reason = StopReason::EndTurn;
|
||||||
let mut current_token_usage = TokenUsage::default();
|
let mut current_token_usage = TokenUsage::default();
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |_thread, cx| {
|
.update(cx, |_thread, cx| {
|
||||||
if let Some(usage) = usage {
|
|
||||||
cx.emit(ThreadEvent::UsageUpdated(usage));
|
|
||||||
}
|
|
||||||
cx.emit(ThreadEvent::NewRequest);
|
cx.emit(ThreadEvent::NewRequest);
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
|
@ -1515,27 +1512,34 @@ impl Thread {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::QueueUpdate(status) => {
|
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
||||||
if let Some(completion) = thread
|
if let Some(completion) = thread
|
||||||
.pending_completions
|
.pending_completions
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.find(|completion| completion.id == pending_completion_id)
|
.find(|completion| completion.id == pending_completion_id)
|
||||||
{
|
{
|
||||||
let queue_state = match status {
|
match status_update {
|
||||||
language_model::CompletionRequestStatus::Queued {
|
CompletionRequestStatus::Queued {
|
||||||
position,
|
position,
|
||||||
} => Some(QueueState::Queued { position }),
|
} => {
|
||||||
language_model::CompletionRequestStatus::Started => {
|
completion.queue_state = QueueState::Queued { position };
|
||||||
Some(QueueState::Started)
|
|
||||||
}
|
}
|
||||||
language_model::CompletionRequestStatus::ToolUseLimitReached => {
|
CompletionRequestStatus::Started => {
|
||||||
|
completion.queue_state = QueueState::Started;
|
||||||
|
}
|
||||||
|
CompletionRequestStatus::Failed {
|
||||||
|
code, message
|
||||||
|
} => {
|
||||||
|
return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
|
||||||
|
}
|
||||||
|
CompletionRequestStatus::UsageUpdated {
|
||||||
|
amount, limit
|
||||||
|
} => {
|
||||||
|
cx.emit(ThreadEvent::UsageUpdated(RequestUsage { limit, amount: amount as i32 }));
|
||||||
|
}
|
||||||
|
CompletionRequestStatus::ToolUseLimitReached => {
|
||||||
thread.tool_use_limit_reached = true;
|
thread.tool_use_limit_reached = true;
|
||||||
None
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(queue_state) = queue_state {
|
|
||||||
completion.queue_state = queue_state;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1690,19 +1694,27 @@ impl Thread {
|
||||||
|
|
||||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||||
async move {
|
async move {
|
||||||
let stream = model.model.stream_completion_text_with_usage(request, &cx);
|
let mut messages = model.model.stream_completion(request, &cx).await?;
|
||||||
let (mut messages, usage) = stream.await?;
|
|
||||||
|
|
||||||
if let Some(usage) = usage {
|
|
||||||
this.update(cx, |_thread, cx| {
|
|
||||||
cx.emit(ThreadEvent::UsageUpdated(usage));
|
|
||||||
})
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut new_summary = String::new();
|
let mut new_summary = String::new();
|
||||||
while let Some(message) = messages.stream.next().await {
|
while let Some(event) = messages.next().await {
|
||||||
let text = message?;
|
let event = event?;
|
||||||
|
let text = match event {
|
||||||
|
LanguageModelCompletionEvent::Text(text) => text,
|
||||||
|
LanguageModelCompletionEvent::StatusUpdate(
|
||||||
|
CompletionRequestStatus::UsageUpdated { amount, limit },
|
||||||
|
) => {
|
||||||
|
this.update(cx, |_, cx| {
|
||||||
|
cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
|
||||||
|
limit,
|
||||||
|
amount: amount as i32,
|
||||||
|
}));
|
||||||
|
})?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
let mut lines = text.lines();
|
let mut lines = text.lines();
|
||||||
new_summary.extend(lines.next());
|
new_summary.extend(lines.next());
|
||||||
|
|
||||||
|
|
|
@ -2371,7 +2371,7 @@ impl AssistantContext {
|
||||||
});
|
});
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
LanguageModelCompletionEvent::QueueUpdate { .. } => {}
|
LanguageModelCompletionEvent::StatusUpdate { .. } => {}
|
||||||
LanguageModelCompletionEvent::StartMessage { .. } => {}
|
LanguageModelCompletionEvent::StartMessage { .. } => {}
|
||||||
LanguageModelCompletionEvent::Stop(reason) => {
|
LanguageModelCompletionEvent::Stop(reason) => {
|
||||||
stop_reason = reason;
|
stop_reason = reason;
|
||||||
|
@ -2429,8 +2429,8 @@ impl AssistantContext {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::ToolUse(_) => {}
|
LanguageModelCompletionEvent::ToolUse(_) |
|
||||||
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -1018,7 +1018,7 @@ pub fn response_events_to_markdown(
|
||||||
Ok(
|
Ok(
|
||||||
LanguageModelCompletionEvent::UsageUpdate(_)
|
LanguageModelCompletionEvent::UsageUpdate(_)
|
||||||
| LanguageModelCompletionEvent::StartMessage { .. }
|
| LanguageModelCompletionEvent::StartMessage { .. }
|
||||||
| LanguageModelCompletionEvent::QueueUpdate { .. },
|
| LanguageModelCompletionEvent::StatusUpdate { .. },
|
||||||
) => {}
|
) => {}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||||
|
@ -1093,7 +1093,7 @@ impl ThreadDialog {
|
||||||
|
|
||||||
// Skip these
|
// Skip these
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
|
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
|
||||||
| Ok(LanguageModelCompletionEvent::QueueUpdate { .. })
|
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
|
||||||
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
||||||
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,8 @@ use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use util::serde::is_default;
|
use util::serde::is_default;
|
||||||
use zed_llm_client::{
|
use zed_llm_client::{
|
||||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
|
||||||
|
MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub use crate::model::*;
|
pub use crate::model::*;
|
||||||
|
@ -64,18 +65,10 @@ pub struct LanguageModelCacheConfiguration {
|
||||||
pub min_total_token: usize,
|
pub min_total_token: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "status", rename_all = "snake_case")]
|
|
||||||
pub enum CompletionRequestStatus {
|
|
||||||
Queued { position: usize },
|
|
||||||
Started,
|
|
||||||
ToolUseLimitReached,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A completion event from a language model.
|
/// A completion event from a language model.
|
||||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||||
pub enum LanguageModelCompletionEvent {
|
pub enum LanguageModelCompletionEvent {
|
||||||
QueueUpdate(CompletionRequestStatus),
|
StatusUpdate(CompletionRequestStatus),
|
||||||
Stop(StopReason),
|
Stop(StopReason),
|
||||||
Text(String),
|
Text(String),
|
||||||
Thinking {
|
Thinking {
|
||||||
|
@ -299,41 +292,15 @@ pub trait LanguageModel: Send + Sync {
|
||||||
>,
|
>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
fn stream_completion_with_usage(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AsyncApp,
|
|
||||||
) -> BoxFuture<
|
|
||||||
'static,
|
|
||||||
Result<(
|
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
|
||||||
Option<RequestUsage>,
|
|
||||||
)>,
|
|
||||||
> {
|
|
||||||
self.stream_completion(request, cx)
|
|
||||||
.map(|result| result.map(|stream| (stream, None)))
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion_text(
|
fn stream_completion_text(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
|
) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
|
||||||
self.stream_completion_text_with_usage(request, cx)
|
let future = self.stream_completion(request, cx);
|
||||||
.map(|result| result.map(|(stream, _usage)| stream))
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion_text_with_usage(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AsyncApp,
|
|
||||||
) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
|
|
||||||
let future = self.stream_completion_with_usage(request, cx);
|
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let (events, usage) = future.await?;
|
let events = future.await?;
|
||||||
let mut events = events.fuse();
|
let mut events = events.fuse();
|
||||||
let mut message_id = None;
|
let mut message_id = None;
|
||||||
let mut first_item_text = None;
|
let mut first_item_text = None;
|
||||||
|
@ -358,7 +325,7 @@ pub trait LanguageModel: Send + Sync {
|
||||||
let last_token_usage = last_token_usage.clone();
|
let last_token_usage = last_token_usage.clone();
|
||||||
async move {
|
async move {
|
||||||
match result {
|
match result {
|
||||||
Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
|
Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
|
||||||
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
|
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
|
||||||
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
||||||
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
|
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
|
||||||
|
@ -375,14 +342,11 @@ pub trait LanguageModel: Send + Sync {
|
||||||
}))
|
}))
|
||||||
.boxed();
|
.boxed();
|
||||||
|
|
||||||
Ok((
|
Ok(LanguageModelTextStream {
|
||||||
LanguageModelTextStream {
|
message_id,
|
||||||
message_id,
|
stream,
|
||||||
stream,
|
last_token_usage,
|
||||||
last_token_usage,
|
})
|
||||||
},
|
|
||||||
usage,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,6 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::RequestUsage;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct RateLimiter {
|
pub struct RateLimiter {
|
||||||
semaphore: Arc<Semaphore>,
|
semaphore: Arc<Semaphore>,
|
||||||
|
@ -69,32 +67,4 @@ impl RateLimiter {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stream_with_usage<'a, Fut, T>(
|
|
||||||
&self,
|
|
||||||
future: Fut,
|
|
||||||
) -> impl 'a
|
|
||||||
+ Future<
|
|
||||||
Output = Result<(
|
|
||||||
impl Stream<Item = T::Item> + use<Fut, T>,
|
|
||||||
Option<RequestUsage>,
|
|
||||||
)>,
|
|
||||||
>
|
|
||||||
where
|
|
||||||
Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
|
|
||||||
T: Stream,
|
|
||||||
{
|
|
||||||
let guard = self.semaphore.acquire_arc();
|
|
||||||
async move {
|
|
||||||
let guard = guard.await;
|
|
||||||
let (inner, usage) = future.await?;
|
|
||||||
Ok((
|
|
||||||
RateLimitGuard {
|
|
||||||
inner,
|
|
||||||
_guard: guard,
|
|
||||||
},
|
|
||||||
usage,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,12 +9,11 @@ use futures::{
|
||||||
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
|
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
|
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
|
||||||
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
|
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||||
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
|
||||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
|
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
|
||||||
ZED_CLOUD_PROVIDER_ID,
|
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
|
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
|
||||||
|
@ -36,9 +35,10 @@ use strum::IntoEnumIterator;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ui::{TintColor, prelude::*};
|
use ui::{TintColor, prelude::*};
|
||||||
use zed_llm_client::{
|
use zed_llm_client::{
|
||||||
CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
|
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
||||||
EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||||
MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||||
|
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||||
TOOL_USE_LIMIT_REACHED_HEADER_NAME,
|
TOOL_USE_LIMIT_REACHED_HEADER_NAME,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -517,7 +517,7 @@ struct PerformLlmCompletionResponse {
|
||||||
response: Response<AsyncBody>,
|
response: Response<AsyncBody>,
|
||||||
usage: Option<RequestUsage>,
|
usage: Option<RequestUsage>,
|
||||||
tool_use_limit_reached: bool,
|
tool_use_limit_reached: bool,
|
||||||
includes_queue_events: bool,
|
includes_status_messages: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CloudLanguageModel {
|
impl CloudLanguageModel {
|
||||||
|
@ -545,25 +545,31 @@ impl CloudLanguageModel {
|
||||||
let request = request_builder
|
let request = request_builder
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {token}"))
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
.header("x-zed-client-supports-queueing", "true")
|
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
|
||||||
.body(serde_json::to_string(&body)?.into())?;
|
.body(serde_json::to_string(&body)?.into())?;
|
||||||
let mut response = http_client.send(request).await?;
|
let mut response = http_client.send(request).await?;
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
let includes_queue_events = response
|
let includes_status_messages = response
|
||||||
.headers()
|
.headers()
|
||||||
.get("x-zed-server-supports-queueing")
|
.get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
|
||||||
.is_some();
|
.is_some();
|
||||||
|
|
||||||
let tool_use_limit_reached = response
|
let tool_use_limit_reached = response
|
||||||
.headers()
|
.headers()
|
||||||
.get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
|
.get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
|
||||||
.is_some();
|
.is_some();
|
||||||
let usage = RequestUsage::from_headers(response.headers()).ok();
|
|
||||||
|
let usage = if includes_status_messages {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
RequestUsage::from_headers(response.headers()).ok()
|
||||||
|
};
|
||||||
|
|
||||||
return Ok(PerformLlmCompletionResponse {
|
return Ok(PerformLlmCompletionResponse {
|
||||||
response,
|
response,
|
||||||
usage,
|
usage,
|
||||||
includes_queue_events,
|
includes_status_messages,
|
||||||
tool_use_limit_reached,
|
tool_use_limit_reached,
|
||||||
});
|
});
|
||||||
} else if response
|
} else if response
|
||||||
|
@ -767,28 +773,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncApp,
|
_cx: &AsyncApp,
|
||||||
) -> BoxFuture<
|
) -> BoxFuture<
|
||||||
'static,
|
'static,
|
||||||
Result<
|
Result<
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||||
>,
|
>,
|
||||||
> {
|
|
||||||
self.stream_completion_with_usage(request, cx)
|
|
||||||
.map(|result| result.map(|(stream, _)| stream))
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion_with_usage(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
_cx: &AsyncApp,
|
|
||||||
) -> BoxFuture<
|
|
||||||
'static,
|
|
||||||
Result<(
|
|
||||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
|
||||||
Option<RequestUsage>,
|
|
||||||
)>,
|
|
||||||
> {
|
> {
|
||||||
let thread_id = request.thread_id.clone();
|
let thread_id = request.thread_id.clone();
|
||||||
let prompt_id = request.prompt_id.clone();
|
let prompt_id = request.prompt_id.clone();
|
||||||
|
@ -804,11 +794,11 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
);
|
);
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream_with_usage(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let PerformLlmCompletionResponse {
|
let PerformLlmCompletionResponse {
|
||||||
response,
|
response,
|
||||||
usage,
|
usage,
|
||||||
includes_queue_events,
|
includes_status_messages,
|
||||||
tool_use_limit_reached,
|
tool_use_limit_reached,
|
||||||
} = Self::perform_llm_completion(
|
} = Self::perform_llm_completion(
|
||||||
client.clone(),
|
client.clone(),
|
||||||
|
@ -840,32 +830,26 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let mut mapper = AnthropicEventMapper::new();
|
let mut mapper = AnthropicEventMapper::new();
|
||||||
Ok((
|
Ok(map_cloud_completion_events(
|
||||||
map_cloud_completion_events(
|
Box::pin(
|
||||||
Box::pin(
|
response_lines(response, includes_status_messages)
|
||||||
response_lines(response, includes_queue_events)
|
.chain(usage_updated_event(usage))
|
||||||
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
|
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
|
||||||
),
|
|
||||||
move |event| mapper.map_event(event),
|
|
||||||
),
|
),
|
||||||
usage,
|
move |event| mapper.map_event(event),
|
||||||
))
|
))
|
||||||
});
|
});
|
||||||
async move {
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let (stream, usage) = future.await?;
|
|
||||||
Ok((stream.boxed(), usage))
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = into_open_ai(request, model, model.max_output_tokens());
|
let request = into_open_ai(request, model, model.max_output_tokens());
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream_with_usage(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let PerformLlmCompletionResponse {
|
let PerformLlmCompletionResponse {
|
||||||
response,
|
response,
|
||||||
usage,
|
usage,
|
||||||
includes_queue_events,
|
includes_status_messages,
|
||||||
tool_use_limit_reached,
|
tool_use_limit_reached,
|
||||||
} = Self::perform_llm_completion(
|
} = Self::perform_llm_completion(
|
||||||
client.clone(),
|
client.clone(),
|
||||||
|
@ -882,32 +866,26 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut mapper = OpenAiEventMapper::new();
|
let mut mapper = OpenAiEventMapper::new();
|
||||||
Ok((
|
Ok(map_cloud_completion_events(
|
||||||
map_cloud_completion_events(
|
Box::pin(
|
||||||
Box::pin(
|
response_lines(response, includes_status_messages)
|
||||||
response_lines(response, includes_queue_events)
|
.chain(usage_updated_event(usage))
|
||||||
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
|
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
|
||||||
),
|
|
||||||
move |event| mapper.map_event(event),
|
|
||||||
),
|
),
|
||||||
usage,
|
move |event| mapper.map_event(event),
|
||||||
))
|
))
|
||||||
});
|
});
|
||||||
async move {
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let (stream, usage) = future.await?;
|
|
||||||
Ok((stream.boxed(), usage))
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = into_google(request, model.id().into());
|
let request = into_google(request, model.id().into());
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream_with_usage(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let PerformLlmCompletionResponse {
|
let PerformLlmCompletionResponse {
|
||||||
response,
|
response,
|
||||||
usage,
|
usage,
|
||||||
includes_queue_events,
|
includes_status_messages,
|
||||||
tool_use_limit_reached,
|
tool_use_limit_reached,
|
||||||
} = Self::perform_llm_completion(
|
} = Self::perform_llm_completion(
|
||||||
client.clone(),
|
client.clone(),
|
||||||
|
@ -924,22 +902,16 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut mapper = GoogleEventMapper::new();
|
let mut mapper = GoogleEventMapper::new();
|
||||||
Ok((
|
Ok(map_cloud_completion_events(
|
||||||
map_cloud_completion_events(
|
Box::pin(
|
||||||
Box::pin(
|
response_lines(response, includes_status_messages)
|
||||||
response_lines(response, includes_queue_events)
|
.chain(usage_updated_event(usage))
|
||||||
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
|
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
|
||||||
),
|
|
||||||
move |event| mapper.map_event(event),
|
|
||||||
),
|
),
|
||||||
usage,
|
move |event| mapper.map_event(event),
|
||||||
))
|
))
|
||||||
});
|
});
|
||||||
async move {
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
let (stream, usage) = future.await?;
|
|
||||||
Ok((stream.boxed(), usage))
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -948,7 +920,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum CloudCompletionEvent<T> {
|
pub enum CloudCompletionEvent<T> {
|
||||||
System(CompletionRequestStatus),
|
Status(CompletionRequestStatus),
|
||||||
Event(T),
|
Event(T),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -968,8 +940,8 @@ where
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
vec![Err(LanguageModelCompletionError::Other(error))]
|
vec![Err(LanguageModelCompletionError::Other(error))]
|
||||||
}
|
}
|
||||||
Ok(CloudCompletionEvent::System(event)) => {
|
Ok(CloudCompletionEvent::Status(event)) => {
|
||||||
vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
|
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
||||||
}
|
}
|
||||||
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
|
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
|
||||||
})
|
})
|
||||||
|
@ -977,11 +949,24 @@ where
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn usage_updated_event<T>(
|
||||||
|
usage: Option<RequestUsage>,
|
||||||
|
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
||||||
|
futures::stream::iter(usage.map(|usage| {
|
||||||
|
Ok(CloudCompletionEvent::Status(
|
||||||
|
CompletionRequestStatus::UsageUpdated {
|
||||||
|
amount: usage.amount as usize,
|
||||||
|
limit: usage.limit,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
fn tool_use_limit_reached_event<T>(
|
fn tool_use_limit_reached_event<T>(
|
||||||
tool_use_limit_reached: bool,
|
tool_use_limit_reached: bool,
|
||||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
||||||
futures::stream::iter(tool_use_limit_reached.then(|| {
|
futures::stream::iter(tool_use_limit_reached.then(|| {
|
||||||
Ok(CloudCompletionEvent::System(
|
Ok(CloudCompletionEvent::Status(
|
||||||
CompletionRequestStatus::ToolUseLimitReached,
|
CompletionRequestStatus::ToolUseLimitReached,
|
||||||
))
|
))
|
||||||
}))
|
}))
|
||||||
|
@ -989,7 +974,7 @@ fn tool_use_limit_reached_event<T>(
|
||||||
|
|
||||||
fn response_lines<T: DeserializeOwned>(
|
fn response_lines<T: DeserializeOwned>(
|
||||||
response: Response<AsyncBody>,
|
response: Response<AsyncBody>,
|
||||||
includes_queue_events: bool,
|
includes_status_messages: bool,
|
||||||
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
|
||||||
futures::stream::try_unfold(
|
futures::stream::try_unfold(
|
||||||
(String::new(), BufReader::new(response.into_body())),
|
(String::new(), BufReader::new(response.into_body())),
|
||||||
|
@ -997,7 +982,7 @@ fn response_lines<T: DeserializeOwned>(
|
||||||
match body.read_line(&mut line).await {
|
match body.read_line(&mut line).await {
|
||||||
Ok(0) => Ok(None),
|
Ok(0) => Ok(None),
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let event = if includes_queue_events {
|
let event = if includes_status_messages {
|
||||||
serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
|
serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
|
||||||
} else {
|
} else {
|
||||||
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
|
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue