From 720dfee803b640695253a33b0b751caec9d633cb Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Thu, 24 Apr 2025 16:54:27 -0400 Subject: [PATCH] Treat invalid JSON in tool calls as failed tool calls (#29375) Release Notes: - N/A --------- Co-authored-by: Max Co-authored-by: Max Brunsfeld --- crates/agent/src/active_thread.rs | 35 ++++++--- crates/agent/src/buffer_codegen.rs | 8 +- crates/agent/src/thread.rs | 76 +++++++++++++++++-- crates/assistant/src/inline_assistant.rs | 5 +- crates/eval/src/example.rs | 3 + crates/language_model/src/fake_provider.rs | 13 +++- crates/language_model/src/language_model.rs | 24 +++++- .../language_models/src/provider/anthropic.rs | 76 +++++++++++-------- .../language_models/src/provider/bedrock.rs | 18 +++-- crates/language_models/src/provider/cloud.rs | 19 +++-- .../src/provider/copilot_chat.rs | 59 ++++++++------ .../language_models/src/provider/deepseek.rs | 43 ++++++----- crates/language_models/src/provider/google.rs | 25 ++++-- .../language_models/src/provider/lmstudio.rs | 17 ++++- .../language_models/src/provider/mistral.rs | 43 ++++++----- crates/language_models/src/provider/ollama.rs | 17 ++++- .../language_models/src/provider/open_ai.rs | 61 +++++++++------ 17 files changed, 374 insertions(+), 168 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 9abd3850fb..310a283287 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -43,6 +43,7 @@ use ui::{ Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*, }; use util::ResultExt as _; +use util::markdown::MarkdownString; use workspace::{OpenOptions, Workspace}; use zed_actions::assistant::OpenRulesLibrary; @@ -769,7 +770,7 @@ impl ActiveThread { this.render_tool_use_markdown( tool_use.id.clone(), tool_use.ui_text.clone(), - &tool_use.input, + &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(), tool_use.status.text(), cx, ); @@ -870,7 +871,7 @@ impl ActiveThread { &mut self, tool_use_id: LanguageModelToolUseId, tool_label: impl Into, - tool_input: &serde_json::Value, + tool_input: &str, tool_output: SharedString, cx: &mut Context, ) { @@ -893,11 +894,10 @@ impl ActiveThread { this.replace(tool_label, cx); }); rendered.input.update(cx, |this, cx| { - let input = format!( - "```json\n{}\n```", - serde_json::to_string_pretty(tool_input).unwrap_or_default() + this.replace( + MarkdownString::code_block("json", tool_input).to_string(), + cx, ); - this.replace(input, cx); }); rendered.output.update(cx, |this, cx| { this.replace(tool_output, cx); @@ -988,7 +988,7 @@ impl ActiveThread { self.render_tool_use_markdown( tool_use.id.clone(), tool_use.ui_text.clone(), - &tool_use.input, + &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(), "".into(), cx, ); @@ -1002,7 +1002,7 @@ impl ActiveThread { self.render_tool_use_markdown( tool_use_id.clone(), ui_text.clone(), - input, + &serde_json::to_string_pretty(&input).unwrap_or_default(), "".into(), cx, ); @@ -1014,7 +1014,7 @@ impl ActiveThread { self.render_tool_use_markdown( tool_use.id.clone(), tool_use.ui_text.clone(), - &tool_use.input, + &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(), self.thread .read(cx) .output_for_tool(&tool_use.id) @@ -1026,6 +1026,23 @@ impl ActiveThread { } ThreadEvent::CheckpointChanged => cx.notify(), ThreadEvent::ReceivedTextChunk => {} + ThreadEvent::InvalidToolInput { + tool_use_id, + ui_text, + invalid_input_json, + } => { + self.render_tool_use_markdown( + tool_use_id.clone(), + ui_text, + invalid_input_json, + self.thread + .read(cx) + .output_for_tool(tool_use_id) + .map(|output| output.clone().into()) + .unwrap_or("".into()), + cx, + ); + } } } diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs index f323c0ccab..ebdf9e3d9f 100644 --- a/crates/agent/src/buffer_codegen.rs +++ b/crates/agent/src/buffer_codegen.rs @@ -5,7 +5,9 @@ use anyhow::Result; use client::telemetry::Telemetry; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; -use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join}; +use futures::{ + SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join, +}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language_model::{ @@ -508,7 +510,9 @@ impl CodegenAlternative { let mut response_latency = None; let request_start = Instant::now(); let diff = async { - let chunks = StripInvalidSpans::new(stream?.stream); + let chunks = StripInvalidSpans::new( + stream?.stream.map_err(|error| error.into()), + ); futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); let mut line_diff = LineDiff::default(); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index d47594d608..cce7d109c0 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -17,10 +17,10 @@ use gpui::{ AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, + ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason, TokenUsage, }; @@ -1275,9 +1275,30 @@ impl Thread { .push(event.as_ref().map_err(|error| error.to_string()).cloned()); } - let event = event?; - thread.update(cx, |thread, cx| { + let event = match event { + Ok(event) => event, + Err(LanguageModelCompletionError::BadInputJson { + id, + tool_name, + raw_input: invalid_input_json, + json_parse_error, + }) => { + thread.receive_invalid_tool_json( + id, + tool_name, + invalid_input_json, + json_parse_error, + window, + cx, + ); + return Ok(()); + } + Err(LanguageModelCompletionError::Other(error)) => { + return Err(error); + } + }; + match event { LanguageModelCompletionEvent::StartMessage { .. } => { request_assistant_message_id = Some(thread.insert_message( @@ -1390,7 +1411,8 @@ impl Thread { cx.notify(); thread.auto_capture_telemetry(cx); - })?; + Ok(()) + })??; smol::future::yield_now().await; } @@ -1681,6 +1703,41 @@ impl Thread { pending_tool_uses } + pub fn receive_invalid_tool_json( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + invalid_json: Arc, + error: String, + window: Option, + cx: &mut Context, + ) { + log::error!("The model returned invalid input JSON: {invalid_json}"); + + let pending_tool_use = self.tool_use.insert_tool_output( + tool_use_id.clone(), + tool_name, + Err(anyhow!("Error parsing input JSON: {error}")), + cx, + ); + let ui_text = if let Some(pending_tool_use) = &pending_tool_use { + pending_tool_use.ui_text.clone() + } else { + log::error!( + "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)." + ); + format!("Unknown tool {}", tool_use_id).into() + }; + + cx.emit(ThreadEvent::InvalidToolInput { + tool_use_id: tool_use_id.clone(), + ui_text, + invalid_input_json: invalid_json, + }); + + self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); + } + pub fn run_tool( &mut self, tool_use_id: LanguageModelToolUseId, @@ -2282,6 +2339,11 @@ pub enum ThreadEvent { ui_text: Arc, input: serde_json::Value, }, + InvalidToolInput { + tool_use_id: LanguageModelToolUseId, + ui_text: Arc, + invalid_input_json: Arc, + }, Stopped(Result>), MessageAdded(MessageId), MessageEdited(MessageId), diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 179d3bf060..ef7a10e3d0 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -22,7 +22,7 @@ use feature_flags::{ }; use fs::Fs; use futures::{ - SinkExt, Stream, StreamExt, + SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::{BoxFuture, LocalBoxFuture}, join, @@ -3056,7 +3056,8 @@ impl CodegenAlternative { let mut response_latency = None; let request_start = Instant::now(); let diff = async { - let chunks = StripInvalidSpans::new(stream?.stream); + let chunks = + StripInvalidSpans::new(stream?.stream.map_err(|e| e.into())); futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); let mut line_diff = LineDiff::default(); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 39b2d7d57c..fccb9de7c8 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -253,6 +253,9 @@ impl ExampleContext { } }); } + ThreadEvent::InvalidToolInput { .. } => { + println!("{log_prefix} invalid tool input"); + } ThreadEvent::ToolConfirmationNeeded => { panic!( "{}Bug: Tool confirmation should not be required in eval", diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index 25f2a496e7..d1bcbea908 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -1,7 +1,7 @@ use crate::{ - AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, }; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; @@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel { &self, request: LanguageModelRequest, _: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let (tx, rx) = mpsc::unbounded(); self.current_completion_txs.lock().push((request, tx)); async move { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 5cf0895f14..8c9860ca75 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent { UsageUpdate(TokenUsage), } +#[derive(Error, Debug)] +pub enum LanguageModelCompletionError { + #[error("received bad input JSON")] + BadInputJson { + id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + }, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + /// Indicates the format used to define the input schema for a language model tool. #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum LanguageModelToolSchemaFormat { @@ -193,7 +206,7 @@ pub struct LanguageModelToolUse { pub struct LanguageModelTextStream { pub message_id: Option, - pub stream: BoxStream<'static, Result>, + pub stream: BoxStream<'static, Result>, // Has complete token usage after the stream has finished pub last_token_usage: Arc>, } @@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>>; + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + >; fn stream_completion_with_usage( &self, @@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync { ) -> BoxFuture< 'static, Result<( - BoxStream<'static, Result>, + BoxStream<'static, Result>, Option, )>, > { diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index c029cd1402..52f9806832 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -12,10 +12,10 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, - LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent, - RateLimiter, Role, + AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, + LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -27,7 +27,7 @@ use std::sync::Arc; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::{ResultExt, maybe}; +use util::ResultExt; const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID; const PROVIDER_NAME: &str = "Anthropic"; @@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let request = into_anthropic( request, self.model.request_id().into(), @@ -626,7 +631,7 @@ pub fn into_anthropic( pub fn map_to_language_model_completion_events( events: Pin>>>, -) -> impl Stream> { +) -> impl Stream> { struct RawToolUse { id: String, name: String, @@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events( Event::ContentBlockStop { index } => { if let Some(tool_use) = state.tool_uses_by_index.remove(&index) { let input_json = tool_use.input_json.trim(); + let input_value = if input_json.is_empty() { + Ok(serde_json::Value::Object(serde_json::Map::default())) + } else { + serde_json::Value::from_str(input_json) + }; + let event_result = match input_value { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + }, + )), + Err(json_parse_err) => { + Err(LanguageModelCompletionError::BadInputJson { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; - return Some(( - vec![maybe!({ - Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - input: if input_json.is_empty() { - serde_json::Value::Object( - serde_json::Map::default(), - ) - } else { - serde_json::Value::from_str( - input_json - ) - .map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))? - }, - raw_input: tool_use.input_json.clone(), - }, - )) - })], - state, - )); + return Some((vec![event_result], state)); } } Event::MessageStart { message } => { @@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events( } Event::Error { error } => { return Some(( - vec![Err(anyhow!(AnthropicError::ApiError(error)))], + vec![Err(LanguageModelCompletionError::Other(anyhow!( + AnthropicError::ApiError(error) + )))], state, )); } _ => {} }, Err(err) => { - return Some((vec![Err(anthropic_err_to_anyhow(err))], state)); + return Some(( + vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))], + state, + )); } } } diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index edcb604919..b9793ac16e 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -32,9 +32,10 @@ use gpui_tokio::Tokio; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage, + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent, + RateLimiter, Role, TokenUsage, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let Ok(region) = cx.read_entity(&self.state, |state, _cx| { // Get region - from credentials or directly from settings let region = state @@ -780,7 +786,7 @@ pub fn get_bedrock_tokens( pub fn map_to_language_model_completion_events( events: Pin>>>, handle: Handle, -) -> impl Stream> { +) -> impl Stream> { struct RawToolUse { id: String, name: String, @@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events( _ => {} }, - Err(err) => return Some((Some(Err(anyhow!(err))), state)), + Err(err) => return Some((Some(Err(anyhow!(err).into())), state)), } } None diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 4c581194d3..5ff771a49b 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -10,11 +10,11 @@ use futures::{ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ - AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, - LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, - LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, + LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat, + ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, @@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { self.stream_completion_with_usage(request, cx) .map(|result| result.map(|(stream, _)| stream)) .boxed() @@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel { ) -> BoxFuture< 'static, Result<( - BoxStream<'static, Result>, + BoxStream<'static, Result>, Option, )>, > { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index f971cc803d..9d91be8fb6 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -17,16 +17,16 @@ use gpui::{ Transformation, percentage, svg, }; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role, + StopReason, }; use settings::SettingsStore; use std::time::Duration; use strum::IntoEnumIterator; use ui::prelude::*; -use util::maybe; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; @@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { if let Some(message) = request.messages.last() { if message.contents_empty() { const EMPTY_PROMPT_MSG: &str = @@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel { pub fn map_to_language_model_completion_events( events: Pin>>>, is_streaming: bool, -) -> impl Stream> { +) -> impl Stream> { #[derive(Default)] struct RawToolCall { id: String, @@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events( Ok(event) => { let Some(choice) = event.choices.first() else { return Some(( - vec![Err(anyhow!("Response contained no choices"))], + vec![Err(anyhow!("Response contained no choices").into())], state, )); }; @@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events( let Some(delta) = delta else { return Some(( - vec![Err(anyhow!("Response contained no delta"))], + vec![Err(anyhow!("Response contained no delta").into())], state, )); }; @@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events( } Some("tool_calls") => { events.extend(state.tool_calls_by_index.drain().map( - |(_, tool_call)| { - maybe!({ - Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.into(), - name: tool_call.name.as_str().into(), - is_input_complete: true, - raw_input: tool_call.arguments.clone(), - input: serde_json::Value::from_str( - &tool_call.arguments, - )?, - }, - )) - }) + |(_, tool_call)| match serde_json::Value::from_str( + &tool_call.arguments, + ) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => { + Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }) + } }, )); @@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events( return Some((events, state)); } - Err(err) => return Some((vec![Err(err)], state)), + Err(err) => return Some((vec![Err(anyhow!(err).into())], state)), } } diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 9989e4c6b1..f89154fcd9 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -9,9 +9,9 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let request = into_deepseek( request, self.model.id().to_string(), @@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel { let stream = stream.await?; Ok(stream .map(|result| { - result.and_then(|response| { - response - .choices - .first() - .ok_or_else(|| anyhow!("Empty response")) - .map(|choice| { - choice - .delta - .content - .clone() - .unwrap_or_default() - .map(LanguageModelCompletionEvent::Text) - }) - }) + result + .and_then(|response| { + response + .choices + .first() + .ok_or_else(|| anyhow!("Empty response")) + .map(|choice| { + choice + .delta + .content + .clone() + .unwrap_or_default() + .map(LanguageModelCompletionEvent::Text) + }) + }) + .map_err(LanguageModelCompletionError::Other) }) .boxed()) } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 04da5d6c03..c5035f0827 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -11,8 +11,9 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, + AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, + StopReason, }; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, @@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel { cx: &AsyncApp, ) -> BoxFuture< 'static, - Result>>, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + >, > { let request = into_google(request, self.model.id().to_string()); let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { - let response = request.await.map_err(|err| anyhow!(err))?; + let response = request + .await + .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?; Ok(map_to_language_model_completion_events(response)) }); async move { Ok(future.await?.boxed()) }.boxed() @@ -471,7 +479,7 @@ pub fn into_google( pub fn map_to_language_model_completion_events( events: Pin>>>, -) -> impl Stream> { +) -> impl Stream> { use std::sync::atomic::{AtomicU64, Ordering}; static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); @@ -492,7 +500,7 @@ pub fn map_to_language_model_completion_events( if let Some(event) = state.events.next().await { match event { Ok(event) => { - let mut events: Vec> = Vec::new(); + let mut events: Vec<_> = Vec::new(); let mut wants_to_use_tool = false; if let Some(usage_metadata) = event.usage_metadata { update_usage(&mut state.usage, &usage_metadata); @@ -559,7 +567,10 @@ pub fn map_to_language_model_completion_events( return Some((events, state)); } Err(err) => { - return Some((vec![Err(anyhow!(err))], state)); + return Some(( + vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))], + state, + )); } } } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 2f5ae9ebb6..425caa2f45 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -2,7 +2,9 @@ use anyhow::{Result, anyhow}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use http_client::HttpClient; -use language_model::{AuthenticateError, LanguageModelCompletionEvent}; +use language_model::{ + AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, +}; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, @@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let request = self.to_lmstudio_request(request); let http_client = self.http_client.clone(); @@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel { async move { Ok(future .await? - .map(|result| result.map(LanguageModelCompletionEvent::Text)) + .map(|result| { + result + .map(LanguageModelCompletionEvent::Text) + .map_err(LanguageModelCompletionError::Other) + }) .boxed()) } .boxed() diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index a5009c76a6..2774094e2f 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -8,9 +8,9 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use futures::stream::BoxStream; @@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let request = into_mistral( request, self.model.id().to_string(), @@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel { let stream = stream.await?; Ok(stream .map(|result| { - result.and_then(|response| { - response - .choices - .first() - .ok_or_else(|| anyhow!("Empty response")) - .map(|choice| { - choice - .delta - .content - .clone() - .unwrap_or_default() - .map(LanguageModelCompletionEvent::Text) - }) - }) + result + .and_then(|response| { + response + .choices + .first() + .ok_or_else(|| anyhow!("Empty response")) + .map(|choice| { + choice + .delta + .content + .clone() + .unwrap_or_default() + .map(LanguageModelCompletionEvent::Text) + }) + }) + .map_err(LanguageModelCompletionError::Other) }) .boxed()) } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 17c50c8eaf..28586b89b0 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -2,7 +2,9 @@ use anyhow::{Result, anyhow}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use http_client::HttpClient; -use language_model::{AuthenticateError, LanguageModelCompletionEvent}; +use language_model::{ + AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, +}; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, @@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel { &self, request: LanguageModelRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + >, + > { let request = self.to_ollama_request(request); let http_client = self.http_client.clone(); @@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel { async move { Ok(future .await? - .map(|result| result.map(LanguageModelCompletionEvent::Text)) + .map(|result| { + result + .map(LanguageModelCompletionEvent::Text) + .map_err(LanguageModelCompletionError::Other) + }) .boxed()) } .boxed() diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 773c372d16..54f27b1727 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -9,10 +9,10 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, }; use open_ai::{Model, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; @@ -24,7 +24,7 @@ use std::sync::Arc; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::{ResultExt, maybe}; +use util::ResultExt; use crate::{AllLanguageModelSettings, ui::InstructionListItem}; @@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel { cx: &AsyncApp, ) -> BoxFuture< 'static, - Result>>, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + >, > { let request = into_open_ai(request, &self.model, self.max_output_tokens()); let completions = self.stream_completion(request, cx); @@ -419,7 +424,7 @@ pub fn into_open_ai( pub fn map_to_language_model_completion_events( events: Pin>>>, -) -> impl Stream> { +) -> impl Stream> { #[derive(Default)] struct RawToolCall { id: String, @@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events( Ok(event) => { let Some(choice) = event.choices.first() else { return Some(( - vec![Err(anyhow!("Response contained no choices"))], + vec![Err(LanguageModelCompletionError::Other(anyhow!( + "Response contained no choices" + )))], state, )); }; @@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events( } Some("tool_calls") => { events.extend(state.tool_calls_by_index.drain().map( - |(_, tool_call)| { - maybe!({ - Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.into(), - name: tool_call.name.as_str().into(), - is_input_complete: true, - raw_input: tool_call.arguments.clone(), - input: serde_json::Value::from_str( - &tool_call.arguments, - )?, - }, - )) - }) + |(_, tool_call)| match serde_json::Value::from_str( + &tool_call.arguments, + ) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => { + Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }) + } }, )); @@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events( return Some((events, state)); } - Err(err) => return Some((vec![Err(err)], state)), + Err(err) => { + return Some((vec![Err(LanguageModelCompletionError::Other(err))], state)); + } } }