Treat invalid JSON in tool calls as failed tool calls (#29375)

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Richard Feldman 2025-04-24 16:54:27 -04:00 committed by GitHub
parent a98c648201
commit 720dfee803
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 374 additions and 168 deletions

View file

@ -43,6 +43,7 @@ use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
@ -769,7 +770,7 @@ impl ActiveThread {
this.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
tool_use.status.text(),
cx,
);
@ -870,7 +871,7 @@ impl ActiveThread {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>,
tool_input: &serde_json::Value,
tool_input: &str,
tool_output: SharedString,
cx: &mut Context<Self>,
) {
@ -893,11 +894,10 @@ impl ActiveThread {
this.replace(tool_label, cx);
});
rendered.input.update(cx, |this, cx| {
let input = format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
this.replace(
MarkdownString::code_block("json", tool_input).to_string(),
cx,
);
this.replace(input, cx);
});
rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx);
@ -988,7 +988,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
"".into(),
cx,
);
@ -1002,7 +1002,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text.clone(),
input,
&serde_json::to_string_pretty(&input).unwrap_or_default(),
"".into(),
cx,
);
@ -1014,7 +1014,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread
.read(cx)
.output_for_tool(&tool_use.id)
@ -1026,6 +1026,23 @@ impl ActiveThread {
}
ThreadEvent::CheckpointChanged => cx.notify(),
ThreadEvent::ReceivedTextChunk => {}
ThreadEvent::InvalidToolInput {
tool_use_id,
ui_text,
invalid_input_json,
} => {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text,
invalid_input_json,
self.thread
.read(cx)
.output_for_tool(tool_use_id)
.map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
}
}
}

View file

@ -5,7 +5,9 @@ use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
@ -508,7 +510,9 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks = StripInvalidSpans::new(
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();

View file

@ -17,10 +17,10 @@ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
};
@ -1275,9 +1275,30 @@ impl Thread {
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
let event = event?;
thread.update(cx, |thread, cx| {
let event = match event {
Ok(event) => event,
Err(LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
}) => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
Err(LanguageModelCompletionError::Other(error)) => {
return Err(error);
}
};
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id = Some(thread.insert_message(
@ -1390,7 +1411,8 @@ impl Thread {
cx.notify();
thread.auto_capture_telemetry(cx);
})?;
Ok(())
})??;
smol::future::yield_now().await;
}
@ -1681,6 +1703,41 @@ impl Thread {
pending_tool_uses
}
pub fn receive_invalid_tool_json(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
invalid_json: Arc<str>,
error: String,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
log::error!("The model returned invalid input JSON: {invalid_json}");
let pending_tool_use = self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
cx,
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
} else {
log::error!(
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
);
format!("Unknown tool {}", tool_use_id).into()
};
cx.emit(ThreadEvent::InvalidToolInput {
tool_use_id: tool_use_id.clone(),
ui_text,
invalid_input_json: invalid_json,
});
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
pub fn run_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
ui_text: Arc<str>,
input: serde_json::Value,
},
InvalidToolInput {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
invalid_input_json: Arc<str>,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),

View file

@ -22,7 +22,7 @@ use feature_flags::{
};
use fs::Fs;
use futures::{
SinkExt, Stream, StreamExt,
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{BoxFuture, LocalBoxFuture},
join,
@ -3056,7 +3056,8 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks =
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();

View file

@ -253,6 +253,9 @@ impl ExampleContext {
}
});
}
ThreadEvent::InvalidToolInput { .. } => {
println!("{log_prefix} invalid tool input");
}
ThreadEvent::ToolConfirmationNeeded => {
panic!(
"{}Bug: Tool confirmation should not be required in eval",

View file

@ -1,7 +1,7 @@
use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
&self,
request: LanguageModelRequest,
_: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move {

View file

@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage),
}
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
@ -193,7 +206,7 @@ pub struct LanguageModelToolUse {
pub struct LanguageModelTextStream {
pub message_id: Option<String>,
pub stream: BoxStream<'static, Result<String>>,
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
// Has complete token usage after the stream has finished
pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
>;
fn stream_completion_with_usage(
&self,
@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {

View file

@ -12,10 +12,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@ -27,7 +27,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe};
use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic";
@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_anthropic(
request,
self.model.request_id().into(),
@ -626,7 +631,7 @@ pub fn into_anthropic(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events(
Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
let input_json = tool_use.input_json.trim();
let input_value = if input_json.is_empty() {
Ok(serde_json::Value::Object(serde_json::Map::default()))
} else {
serde_json::Value::from_str(input_json)
};
let event_result = match input_value {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
},
)),
Err(json_parse_err) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
return Some((
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input: if input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
)
} else {
serde_json::Value::from_str(
input_json
)
.map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
},
raw_input: tool_use.input_json.clone(),
},
))
})],
state,
));
return Some((vec![event_result], state));
}
}
Event::MessageStart { message } => {
@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events(
}
Event::Error { error } => {
return Some((
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
vec![Err(LanguageModelCompletionError::Other(anyhow!(
AnthropicError::ApiError(error)
)))],
state,
));
}
_ => {}
},
Err(err) => {
return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
}
}
}

View file

@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
// Get region - from credentials or directly from settings
let region = state
@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events(
_ => {}
},
Err(err) => return Some((Some(Err(anyhow!(err))), state)),
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
}
}
None

View file

@ -10,11 +10,11 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
ZED_CLOUD_PROVIDER_ID,
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
self.stream_completion_with_usage(request, cx)
.map(|result| result.map(|(stream, _)| stream))
.boxed()
@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {

View file

@ -17,16 +17,16 @@ use gpui::{
Transformation, percentage, svg,
};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason,
};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::prelude::*;
use util::maybe;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
vec![Err(anyhow!("Response contained no choices"))],
vec![Err(anyhow!("Response contained no choices").into())],
state,
));
};
@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
let Some(delta) = delta else {
return Some((
vec![Err(anyhow!("Response contained no delta"))],
vec![Err(anyhow!("Response contained no delta").into())],
state,
));
};
@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| {
maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
raw_input: tool_call.arguments.clone(),
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
},
))
})
|(_, tool_call)| match serde_json::Value::from_str(
&tool_call.arguments,
) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
},
));
@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => return Some((vec![Err(err)], state)),
Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
}
}

View file

@ -9,9 +9,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_deepseek(
request,
self.model.id().to_string(),
@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
result.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}

View file

@ -11,8 +11,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> {
let request = into_google(request, self.model.id().to_string());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(map_to_language_model_completion_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@ -471,7 +479,7 @@ pub fn into_google(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
@ -492,7 +500,7 @@ pub fn map_to_language_model_completion_events(
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata);
@ -559,7 +567,10 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => {
return Some((vec![Err(anyhow!(err))], state));
return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
}
}
}

View file

@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone();
@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()

View file

@ -8,9 +8,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use futures::stream::BoxStream;
@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_mistral(
request,
self.model.id().to_string(),
@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
result.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}

View file

@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()

View file

@ -9,10 +9,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
};
use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
@ -24,7 +24,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe};
use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
@ -419,7 +424,7 @@ pub fn into_open_ai(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
vec![Err(anyhow!("Response contained no choices"))],
vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices"
)))],
state,
));
};
@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| {
maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
raw_input: tool_call.arguments.clone(),
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
},
))
})
|(_, tool_call)| match serde_json::Value::from_str(
&tool_call.arguments,
) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
},
));
@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => return Some((vec![Err(err)], state)),
Err(err) => {
return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
}
}
}