Treat invalid JSON in tool calls as failed tool calls (#29375)

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Richard Feldman 2025-04-24 16:54:27 -04:00 committed by GitHub
parent a98c648201
commit 720dfee803
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 374 additions and 168 deletions

View file

@ -43,6 +43,7 @@ use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*, Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
}; };
use util::ResultExt as _; use util::ResultExt as _;
use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace}; use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenRulesLibrary; use zed_actions::assistant::OpenRulesLibrary;
@ -769,7 +770,7 @@ impl ActiveThread {
this.render_tool_use_markdown( this.render_tool_use_markdown(
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
&tool_use.input, &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
tool_use.status.text(), tool_use.status.text(),
cx, cx,
); );
@ -870,7 +871,7 @@ impl ActiveThread {
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>, tool_label: impl Into<SharedString>,
tool_input: &serde_json::Value, tool_input: &str,
tool_output: SharedString, tool_output: SharedString,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -893,11 +894,10 @@ impl ActiveThread {
this.replace(tool_label, cx); this.replace(tool_label, cx);
}); });
rendered.input.update(cx, |this, cx| { rendered.input.update(cx, |this, cx| {
let input = format!( this.replace(
"```json\n{}\n```", MarkdownString::code_block("json", tool_input).to_string(),
serde_json::to_string_pretty(tool_input).unwrap_or_default() cx,
); );
this.replace(input, cx);
}); });
rendered.output.update(cx, |this, cx| { rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx); this.replace(tool_output, cx);
@ -988,7 +988,7 @@ impl ActiveThread {
self.render_tool_use_markdown( self.render_tool_use_markdown(
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
&tool_use.input, &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
"".into(), "".into(),
cx, cx,
); );
@ -1002,7 +1002,7 @@ impl ActiveThread {
self.render_tool_use_markdown( self.render_tool_use_markdown(
tool_use_id.clone(), tool_use_id.clone(),
ui_text.clone(), ui_text.clone(),
input, &serde_json::to_string_pretty(&input).unwrap_or_default(),
"".into(), "".into(),
cx, cx,
); );
@ -1014,7 +1014,7 @@ impl ActiveThread {
self.render_tool_use_markdown( self.render_tool_use_markdown(
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
&tool_use.input, &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread self.thread
.read(cx) .read(cx)
.output_for_tool(&tool_use.id) .output_for_tool(&tool_use.id)
@ -1026,6 +1026,23 @@ impl ActiveThread {
} }
ThreadEvent::CheckpointChanged => cx.notify(), ThreadEvent::CheckpointChanged => cx.notify(),
ThreadEvent::ReceivedTextChunk => {} ThreadEvent::ReceivedTextChunk => {}
ThreadEvent::InvalidToolInput {
tool_use_id,
ui_text,
invalid_input_json,
} => {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text,
invalid_input_json,
self.thread
.read(cx)
.output_for_tool(tool_use_id)
.map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
}
} }
} }

View file

@ -5,7 +5,9 @@ use anyhow::Result;
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
use collections::HashSet; use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join}; use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{ use language_model::{
@ -508,7 +510,9 @@ impl CodegenAlternative {
let mut response_latency = None; let mut response_latency = None;
let request_start = Instant::now(); let request_start = Instant::now();
let diff = async { let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream); let chunks = StripInvalidSpans::new(
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks); futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string()); let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default(); let mut line_diff = LineDiff::default();

View file

@ -17,10 +17,10 @@ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
}; };
use language_model::{ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage, TokenUsage,
}; };
@ -1275,9 +1275,30 @@ impl Thread {
.push(event.as_ref().map_err(|error| error.to_string()).cloned()); .push(event.as_ref().map_err(|error| error.to_string()).cloned());
} }
let event = event?;
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
let event = match event {
Ok(event) => event,
Err(LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
}) => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
Err(LanguageModelCompletionError::Other(error)) => {
return Err(error);
}
};
match event { match event {
LanguageModelCompletionEvent::StartMessage { .. } => { LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id = Some(thread.insert_message( request_assistant_message_id = Some(thread.insert_message(
@ -1390,7 +1411,8 @@ impl Thread {
cx.notify(); cx.notify();
thread.auto_capture_telemetry(cx); thread.auto_capture_telemetry(cx);
})?; Ok(())
})??;
smol::future::yield_now().await; smol::future::yield_now().await;
} }
@ -1681,6 +1703,41 @@ impl Thread {
pending_tool_uses pending_tool_uses
} }
/// Handles a tool call whose input JSON could not be parsed.
///
/// The invalid JSON is logged, an error output ("Error parsing input JSON: ...")
/// is recorded for the tool use, a `ThreadEvent::InvalidToolInput` is emitted
/// carrying the raw JSON text, and `tool_finished` is then called for the
/// tool use.
///
/// * `tool_use_id` - id of the tool use whose input was invalid.
/// * `tool_name` - name of the tool, forwarded to `insert_tool_output`.
/// * `invalid_json` - the raw input text that failed to parse.
/// * `error` - the parse error message, embedded in the recorded error output.
/// * `window` - forwarded unchanged to `tool_finished`.
pub fn receive_invalid_tool_json(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
invalid_json: Arc<str>,
error: String,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
log::error!("The model returned invalid input JSON: {invalid_json}");
// Record the parse failure as this tool use's output (an `Err`).
let pending_tool_use = self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
cx,
);
// Use the pending tool use's label when one was returned; otherwise log the
// inconsistency and fall back to a placeholder label built from the id.
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
} else {
log::error!(
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
);
format!("Unknown tool {}", tool_use_id).into()
};
// Tell subscribers about the invalid input, including the raw JSON text.
cx.emit(ThreadEvent::InvalidToolInput {
tool_use_id: tool_use_id.clone(),
ui_text,
invalid_input_json: invalid_json,
});
// NOTE(review): the meaning of the `false` argument is not visible in this
// diff - confirm against `tool_finished`'s signature.
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
pub fn run_tool( pub fn run_tool(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
ui_text: Arc<str>, ui_text: Arc<str>,
input: serde_json::Value, input: serde_json::Value,
}, },
InvalidToolInput {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
invalid_input_json: Arc<str>,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>), Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId), MessageAdded(MessageId),
MessageEdited(MessageId), MessageEdited(MessageId),

View file

@ -22,7 +22,7 @@ use feature_flags::{
}; };
use fs::Fs; use fs::Fs;
use futures::{ use futures::{
SinkExt, Stream, StreamExt, SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc, channel::mpsc,
future::{BoxFuture, LocalBoxFuture}, future::{BoxFuture, LocalBoxFuture},
join, join,
@ -3056,7 +3056,8 @@ impl CodegenAlternative {
let mut response_latency = None; let mut response_latency = None;
let request_start = Instant::now(); let request_start = Instant::now();
let diff = async { let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream); let chunks =
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
futures::pin_mut!(chunks); futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string()); let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default(); let mut line_diff = LineDiff::default();

View file

@ -253,6 +253,9 @@ impl ExampleContext {
} }
}); });
} }
ThreadEvent::InvalidToolInput { .. } => {
println!("{log_prefix} invalid tool input");
}
ThreadEvent::ToolConfirmationNeeded => { ThreadEvent::ToolConfirmationNeeded => {
panic!( panic!(
"{}Bug: Tool confirmation should not be required in eval", "{}Bug: Tool confirmation should not be required in eval",

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
}; };
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
_: &AsyncApp, _: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx)); self.current_completion_txs.lock().push((request, tx));
async move { async move {

View file

@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage), UsageUpdate(TokenUsage),
} }
/// Error type for items of a language model completion event stream.
///
/// Keeps "the model produced unparseable tool input JSON" separate from every
/// other failure, so a consumer can match on it and handle that case on its own.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
/// The input JSON of a tool use could not be parsed.
#[error("received bad input JSON")]
BadInputJson {
// Id of the tool use whose input failed to parse.
id: LanguageModelToolUseId,
// Name of the tool the model tried to call.
tool_name: Arc<str>,
// The raw input text that failed to parse.
raw_input: Arc<str>,
// The JSON parse error, rendered as a string.
json_parse_error: String,
},
/// Any other error; displays as the wrapped `anyhow::Error` and can be
/// created from one via `From` (e.g. with `?` or `.into()`).
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Indicates the format used to define the input schema for a language model tool. /// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat { pub enum LanguageModelToolSchemaFormat {
@ -193,7 +206,7 @@ pub struct LanguageModelToolUse {
pub struct LanguageModelTextStream { pub struct LanguageModelTextStream {
pub message_id: Option<String>, pub message_id: Option<String>,
pub stream: BoxStream<'static, Result<String>>, pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
// Has complete token usage after the stream has finished // Has complete token usage after the stream has finished
pub last_token_usage: Arc<Mutex<TokenUsage>>, pub last_token_usage: Arc<Mutex<TokenUsage>>,
} }
@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>; ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
>;
fn stream_completion_with_usage( fn stream_completion_with_usage(
&self, &self,
@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture< ) -> BoxFuture<
'static, 'static,
Result<( Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>, BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>, Option<RequestUsage>,
)>, )>,
> { > {

View file

@ -12,10 +12,10 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
}; };
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -27,7 +27,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*}; use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe}; use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID; const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic"; const PROVIDER_NAME: &str = "Anthropic";
@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_anthropic( let request = into_anthropic(
request, request,
self.model.request_id().into(), self.model.request_id().into(),
@ -626,7 +631,7 @@ pub fn into_anthropic(
pub fn map_to_language_model_completion_events( pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> { ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse { struct RawToolUse {
id: String, id: String,
name: String, name: String,
@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events(
Event::ContentBlockStop { index } => { Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) { if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
let input_json = tool_use.input_json.trim(); let input_json = tool_use.input_json.trim();
let input_value = if input_json.is_empty() {
Ok(serde_json::Value::Object(serde_json::Map::default()))
} else {
serde_json::Value::from_str(input_json)
};
let event_result = match input_value {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
},
)),
Err(json_parse_err) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
return Some(( return Some((vec![event_result], state));
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input: if input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
)
} else {
serde_json::Value::from_str(
input_json
)
.map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
},
raw_input: tool_use.input_json.clone(),
},
))
})],
state,
));
} }
} }
Event::MessageStart { message } => { Event::MessageStart { message } => {
@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events(
} }
Event::Error { error } => { Event::Error { error } => {
return Some(( return Some((
vec![Err(anyhow!(AnthropicError::ApiError(error)))], vec![Err(LanguageModelCompletionError::Other(anyhow!(
AnthropicError::ApiError(error)
)))],
state, state,
)); ));
} }
_ => {} _ => {}
}, },
Err(err) => { Err(err) => {
return Some((vec![Err(anthropic_err_to_anyhow(err))], state)); return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
} }
} }
} }

View file

@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, TokenUsage,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| { let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
// Get region - from credentials or directly from settings // Get region - from credentials or directly from settings
let region = state let region = state
@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events( pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle, handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> { ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse { struct RawToolUse {
id: String, id: String,
name: String, name: String,
@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events(
_ => {} _ => {}
}, },
Err(err) => return Some((Some(Err(anyhow!(err))), state)), Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
} }
} }
None None

View file

@ -10,11 +10,11 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{ use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
ZED_CLOUD_PROVIDER_ID, ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
}; };
use language_model::{ use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
self.stream_completion_with_usage(request, cx) self.stream_completion_with_usage(request, cx)
.map(|result| result.map(|(stream, _)| stream)) .map(|result| result.map(|(stream, _)| stream))
.boxed() .boxed()
@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture< ) -> BoxFuture<
'static, 'static,
Result<( Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>, BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>, Option<RequestUsage>,
)>, )>,
> { > {

View file

@ -17,16 +17,16 @@ use gpui::{
Transformation, percentage, svg, Transformation, percentage, svg,
}; };
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason,
}; };
use settings::SettingsStore; use settings::SettingsStore;
use std::time::Duration; use std::time::Duration;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use ui::prelude::*; use ui::prelude::*;
use util::maybe;
use super::anthropic::count_anthropic_tokens; use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens; use super::google::count_google_tokens;
@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
if let Some(message) = request.messages.last() { if let Some(message) = request.messages.last() {
if message.contents_empty() { if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str = const EMPTY_PROMPT_MSG: &str =
@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events( pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
is_streaming: bool, is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> { ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)] #[derive(Default)]
struct RawToolCall { struct RawToolCall {
id: String, id: String,
@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
Ok(event) => { Ok(event) => {
let Some(choice) = event.choices.first() else { let Some(choice) = event.choices.first() else {
return Some(( return Some((
vec![Err(anyhow!("Response contained no choices"))], vec![Err(anyhow!("Response contained no choices").into())],
state, state,
)); ));
}; };
@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
let Some(delta) = delta else { let Some(delta) = delta else {
return Some(( return Some((
vec![Err(anyhow!("Response contained no delta"))], vec![Err(anyhow!("Response contained no delta").into())],
state, state,
)); ));
}; };
@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events(
} }
Some("tool_calls") => { Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map( events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| { |(_, tool_call)| match serde_json::Value::from_str(
maybe!({ &tool_call.arguments,
Ok(LanguageModelCompletionEvent::ToolUse( ) {
LanguageModelToolUse { Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
id: tool_call.id.into(), LanguageModelToolUse {
name: tool_call.name.as_str().into(), id: tool_call.id.clone().into(),
is_input_complete: true, name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.clone(), is_input_complete: true,
input: serde_json::Value::from_str( input,
&tool_call.arguments, raw_input: tool_call.arguments.clone(),
)?, },
}, )),
)) Err(error) => {
}) Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
}, },
)); ));
@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events(
return Some((events, state)); return Some((events, state));
} }
Err(err) => return Some((vec![Err(err)], state)), Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
} }
} }

View file

@ -9,9 +9,9 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_deepseek( let request = into_deepseek(
request, request,
self.model.id().to_string(), self.model.id().to_string(),
@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
let stream = stream.await?; let stream = stream.await?;
Ok(stream Ok(stream
.map(|result| { .map(|result| {
result.and_then(|response| { result
response .and_then(|response| {
.choices response
.first() .choices
.ok_or_else(|| anyhow!("Empty response")) .first()
.map(|choice| { .ok_or_else(|| anyhow!("Empty response"))
choice .map(|choice| {
.delta choice
.content .delta
.clone() .content
.unwrap_or_default() .clone()
.map(LanguageModelCompletionEvent::Text) .unwrap_or_default()
}) .map(LanguageModelCompletionEvent::Text)
}) })
})
.map_err(LanguageModelCompletionError::Other)
}) })
.boxed()) .boxed())
} }

View file

@ -11,8 +11,9 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat, AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
StopReason,
}; };
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture< ) -> BoxFuture<
'static, 'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>, Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> { > {
let request = into_google(request, self.model.id().to_string()); let request = into_google(request, self.model.id().to_string());
let request = self.stream_completion(request, cx); let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?; let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(map_to_language_model_completion_events(response)) Ok(map_to_language_model_completion_events(response))
}); });
async move { Ok(future.await?.boxed()) }.boxed() async move { Ok(future.await?.boxed()) }.boxed()
@ -471,7 +479,7 @@ pub fn into_google(
pub fn map_to_language_model_completion_events( pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> { ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
@ -492,7 +500,7 @@ pub fn map_to_language_model_completion_events(
if let Some(event) = state.events.next().await { if let Some(event) = state.events.next().await {
match event { match event {
Ok(event) => { Ok(event) => {
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new(); let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false; let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata { if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata); update_usage(&mut state.usage, &usage_metadata);
@ -559,7 +567,10 @@ pub fn map_to_language_model_completion_events(
return Some((events, state)); return Some((events, state));
} }
Err(err) => { Err(err) => {
return Some((vec![Err(anyhow!(err))], state)); return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
} }
} }
} }

View file

@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent}; use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_lmstudio_request(request); let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
async move { async move {
Ok(future Ok(future
.await? .await?
.map(|result| result.map(LanguageModelCompletionEvent::Text)) .map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed()) .boxed())
} }
.boxed() .boxed()

View file

@ -8,9 +8,9 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use futures::stream::BoxStream; use futures::stream::BoxStream;
@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_mistral( let request = into_mistral(
request, request,
self.model.id().to_string(), self.model.id().to_string(),
@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
let stream = stream.await?; let stream = stream.await?;
Ok(stream Ok(stream
.map(|result| { .map(|result| {
result.and_then(|response| { result
response .and_then(|response| {
.choices response
.first() .choices
.ok_or_else(|| anyhow!("Empty response")) .first()
.map(|choice| { .ok_or_else(|| anyhow!("Empty response"))
choice .map(|choice| {
.delta choice
.content .delta
.clone() .content
.unwrap_or_default() .clone()
.map(LanguageModelCompletionEvent::Text) .unwrap_or_default()
}) .map(LanguageModelCompletionEvent::Text)
}) })
})
.map_err(LanguageModelCompletionError::Other)
}) })
.boxed()) .boxed())
} }

View file

@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent}; use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_ollama_request(request); let request = self.to_ollama_request(request);
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
async move { async move {
Ok(future Ok(future
.await? .await?
.map(|result| result.map(LanguageModelCompletionEvent::Text)) .map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed()) .boxed())
} }
.boxed() .boxed()

View file

@ -9,10 +9,10 @@ use gpui::{
}; };
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
RateLimiter, Role, StopReason, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
}; };
use open_ai::{Model, ResponseStreamEvent, stream_completion}; use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -24,7 +24,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*}; use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe}; use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem}; use crate::{AllLanguageModelSettings, ui::InstructionListItem};
@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture< ) -> BoxFuture<
'static, 'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>, Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> { > {
let request = into_open_ai(request, &self.model, self.max_output_tokens()); let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx); let completions = self.stream_completion(request, cx);
@ -419,7 +424,7 @@ pub fn into_open_ai(
pub fn map_to_language_model_completion_events( pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> { ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)] #[derive(Default)]
struct RawToolCall { struct RawToolCall {
id: String, id: String,
@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events(
Ok(event) => { Ok(event) => {
let Some(choice) = event.choices.first() else { let Some(choice) = event.choices.first() else {
return Some(( return Some((
vec![Err(anyhow!("Response contained no choices"))], vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices"
)))],
state, state,
)); ));
}; };
@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events(
} }
Some("tool_calls") => { Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map( events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| { |(_, tool_call)| match serde_json::Value::from_str(
maybe!({ &tool_call.arguments,
Ok(LanguageModelCompletionEvent::ToolUse( ) {
LanguageModelToolUse { Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
id: tool_call.id.into(), LanguageModelToolUse {
name: tool_call.name.as_str().into(), id: tool_call.id.clone().into(),
is_input_complete: true, name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.clone(), is_input_complete: true,
input: serde_json::Value::from_str( input,
&tool_call.arguments, raw_input: tool_call.arguments.clone(),
)?, },
}, )),
)) Err(error) => {
}) Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
}, },
)); ));
@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events(
return Some((events, state)); return Some((events, state));
} }
Err(err) => return Some((vec![Err(err)], state)), Err(err) => {
return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
}
} }
} }