Treat invalid JSON in tool calls as failed tool calls (#29375)
Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
a98c648201
commit
720dfee803
17 changed files with 374 additions and 168 deletions
|
@ -43,6 +43,7 @@ use ui::{
|
|||
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use util::markdown::MarkdownString;
|
||||
use workspace::{OpenOptions, Workspace};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
|
@ -769,7 +770,7 @@ impl ActiveThread {
|
|||
this.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
tool_use.status.text(),
|
||||
cx,
|
||||
);
|
||||
|
@ -870,7 +871,7 @@ impl ActiveThread {
|
|||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_label: impl Into<SharedString>,
|
||||
tool_input: &serde_json::Value,
|
||||
tool_input: &str,
|
||||
tool_output: SharedString,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
@ -893,11 +894,10 @@ impl ActiveThread {
|
|||
this.replace(tool_label, cx);
|
||||
});
|
||||
rendered.input.update(cx, |this, cx| {
|
||||
let input = format!(
|
||||
"```json\n{}\n```",
|
||||
serde_json::to_string_pretty(tool_input).unwrap_or_default()
|
||||
this.replace(
|
||||
MarkdownString::code_block("json", tool_input).to_string(),
|
||||
cx,
|
||||
);
|
||||
this.replace(input, cx);
|
||||
});
|
||||
rendered.output.update(cx, |this, cx| {
|
||||
this.replace(tool_output, cx);
|
||||
|
@ -988,7 +988,7 @@ impl ActiveThread {
|
|||
self.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
"".into(),
|
||||
cx,
|
||||
);
|
||||
|
@ -1002,7 +1002,7 @@ impl ActiveThread {
|
|||
self.render_tool_use_markdown(
|
||||
tool_use_id.clone(),
|
||||
ui_text.clone(),
|
||||
input,
|
||||
&serde_json::to_string_pretty(&input).unwrap_or_default(),
|
||||
"".into(),
|
||||
cx,
|
||||
);
|
||||
|
@ -1014,7 +1014,7 @@ impl ActiveThread {
|
|||
self.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
self.thread
|
||||
.read(cx)
|
||||
.output_for_tool(&tool_use.id)
|
||||
|
@ -1026,6 +1026,23 @@ impl ActiveThread {
|
|||
}
|
||||
ThreadEvent::CheckpointChanged => cx.notify(),
|
||||
ThreadEvent::ReceivedTextChunk => {}
|
||||
ThreadEvent::InvalidToolInput {
|
||||
tool_use_id,
|
||||
ui_text,
|
||||
invalid_input_json,
|
||||
} => {
|
||||
self.render_tool_use_markdown(
|
||||
tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json,
|
||||
self.thread
|
||||
.read(cx)
|
||||
.output_for_tool(tool_use_id)
|
||||
.map(|output| output.clone().into())
|
||||
.unwrap_or("".into()),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,9 @@ use anyhow::Result;
|
|||
use client::telemetry::Telemetry;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
|
||||
};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
|
@ -508,7 +510,9 @@ impl CodegenAlternative {
|
|||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks = StripInvalidSpans::new(
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
|
|
@ -17,10 +17,10 @@ use gpui::{
|
|||
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
|
||||
TokenUsage,
|
||||
};
|
||||
|
@ -1275,9 +1275,30 @@ impl Thread {
|
|||
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
|
||||
}
|
||||
|
||||
let event = event?;
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input: invalid_input_json,
|
||||
json_parse_error,
|
||||
}) => {
|
||||
thread.receive_invalid_tool_json(
|
||||
id,
|
||||
tool_name,
|
||||
invalid_input_json,
|
||||
json_parse_error,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
Err(LanguageModelCompletionError::Other(error)) => {
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
request_assistant_message_id = Some(thread.insert_message(
|
||||
|
@ -1390,7 +1411,8 @@ impl Thread {
|
|||
cx.notify();
|
||||
|
||||
thread.auto_capture_telemetry(cx);
|
||||
})?;
|
||||
Ok(())
|
||||
})??;
|
||||
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
|
@ -1681,6 +1703,41 @@ impl Thread {
|
|||
pending_tool_uses
|
||||
}
|
||||
|
||||
pub fn receive_invalid_tool_json(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
invalid_json: Arc<str>,
|
||||
error: String,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
log::error!("The model returned invalid input JSON: {invalid_json}");
|
||||
|
||||
let pending_tool_use = self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
Err(anyhow!("Error parsing input JSON: {error}")),
|
||||
cx,
|
||||
);
|
||||
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
|
||||
pending_tool_use.ui_text.clone()
|
||||
} else {
|
||||
log::error!(
|
||||
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
|
||||
);
|
||||
format!("Unknown tool {}", tool_use_id).into()
|
||||
};
|
||||
|
||||
cx.emit(ThreadEvent::InvalidToolInput {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json: invalid_json,
|
||||
});
|
||||
|
||||
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
}
|
||||
|
||||
pub fn run_tool(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
|
|||
ui_text: Arc<str>,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
InvalidToolInput {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: Arc<str>,
|
||||
invalid_input_json: Arc<str>,
|
||||
},
|
||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
|
|
|
@ -22,7 +22,7 @@ use feature_flags::{
|
|||
};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt,
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
channel::mpsc,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
join,
|
||||
|
@ -3056,7 +3056,8 @@ impl CodegenAlternative {
|
|||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks =
|
||||
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
|
|
@ -253,6 +253,9 @@ impl ExampleContext {
|
|||
}
|
||||
});
|
||||
}
|
||||
ThreadEvent::InvalidToolInput { .. } => {
|
||||
println!("{log_prefix} invalid tool input");
|
||||
}
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
panic!(
|
||||
"{}Bug: Tool confirmation should not be required in eval",
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
};
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||
|
@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move {
|
||||
|
|
|
@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
|
|||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("received bad input JSON")]
|
||||
BadInputJson {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Indicates the format used to define the input schema for a language model tool.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub enum LanguageModelToolSchemaFormat {
|
||||
|
@ -193,7 +206,7 @@ pub struct LanguageModelToolUse {
|
|||
|
||||
pub struct LanguageModelTextStream {
|
||||
pub message_id: Option<String>,
|
||||
pub stream: BoxStream<'static, Result<String>>,
|
||||
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
|
||||
// Has complete token usage after the stream has finished
|
||||
pub last_token_usage: Arc<Mutex<TokenUsage>>,
|
||||
}
|
||||
|
@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
>;
|
||||
|
||||
fn stream_completion_with_usage(
|
||||
&self,
|
||||
|
@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
|
|||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<(
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
Option<RequestUsage>,
|
||||
)>,
|
||||
> {
|
||||
|
|
|
@ -12,10 +12,10 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
|
||||
RateLimiter, Role,
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
|
||||
};
|
||||
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||
use schemars::JsonSchema;
|
||||
|
@ -27,7 +27,7 @@ use std::sync::Arc;
|
|||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::{ResultExt, maybe};
|
||||
use util::ResultExt;
|
||||
|
||||
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
|
||||
const PROVIDER_NAME: &str = "Anthropic";
|
||||
|
@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = into_anthropic(
|
||||
request,
|
||||
self.model.request_id().into(),
|
||||
|
@ -626,7 +631,7 @@ pub fn into_anthropic(
|
|||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
|
@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events(
|
|||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
||||
let input_json = tool_use.input_json.trim();
|
||||
let input_value = if input_json.is_empty() {
|
||||
Ok(serde_json::Value::Object(serde_json::Map::default()))
|
||||
} else {
|
||||
serde_json::Value::from_str(input_json)
|
||||
};
|
||||
let event_result = match input_value {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
},
|
||||
)),
|
||||
Err(json_parse_err) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_use.id.into(),
|
||||
tool_name: tool_use.name.into(),
|
||||
raw_input: input_json.into(),
|
||||
json_parse_error: json_parse_err.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
return Some((
|
||||
vec![maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
input: if input_json.is_empty() {
|
||||
serde_json::Value::Object(
|
||||
serde_json::Map::default(),
|
||||
)
|
||||
} else {
|
||||
serde_json::Value::from_str(
|
||||
input_json
|
||||
)
|
||||
.map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
|
||||
},
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
},
|
||||
))
|
||||
})],
|
||||
state,
|
||||
));
|
||||
return Some((vec![event_result], state));
|
||||
}
|
||||
}
|
||||
Event::MessageStart { message } => {
|
||||
|
@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
Event::Error { error } => {
|
||||
return Some((
|
||||
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
AnthropicError::ApiError(error)
|
||||
)))],
|
||||
state,
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Err(err) => {
|
||||
return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
|
||||
return Some((
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
|
|||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
|
||||
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, TokenUsage,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
|
||||
// Get region - from credentials or directly from settings
|
||||
let region = state
|
||||
|
@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
|
|||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
|
||||
handle: Handle,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
|
@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events(
|
|||
_ => {}
|
||||
},
|
||||
|
||||
Err(err) => return Some((Some(Err(anyhow!(err))), state)),
|
||||
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
|
||||
}
|
||||
}
|
||||
None
|
||||
|
|
|
@ -10,11 +10,11 @@ use futures::{
|
|||
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||
use language_model::{
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
|
||||
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
|
||||
|
@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
self.stream_completion_with_usage(request, cx)
|
||||
.map(|result| result.map(|(stream, _)| stream))
|
||||
.boxed()
|
||||
|
@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<(
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
Option<RequestUsage>,
|
||||
)>,
|
||||
> {
|
||||
|
|
|
@ -17,16 +17,16 @@ use gpui::{
|
|||
Transformation, percentage, svg,
|
||||
};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
|
||||
StopReason,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use std::time::Duration;
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
use util::maybe;
|
||||
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
use super::google::count_google_tokens;
|
||||
|
@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
if let Some(message) = request.messages.last() {
|
||||
if message.contents_empty() {
|
||||
const EMPTY_PROMPT_MSG: &str =
|
||||
|
@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||
is_streaming: bool,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
|
@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
|
|||
Ok(event) => {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no choices"))],
|
||||
vec![Err(anyhow!("Response contained no choices").into())],
|
||||
state,
|
||||
));
|
||||
};
|
||||
|
@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
|
|||
|
||||
let Some(delta) = delta else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no delta"))],
|
||||
vec![Err(anyhow!("Response contained no delta").into())],
|
||||
state,
|
||||
));
|
||||
};
|
||||
|
@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
Some("tool_calls") => {
|
||||
events.extend(state.tool_calls_by_index.drain().map(
|
||||
|(_, tool_call)| {
|
||||
maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
)?,
|
||||
},
|
||||
))
|
||||
})
|
||||
|(_, tool_call)| match serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
));
|
||||
|
||||
|
@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events(
|
|||
|
||||
return Some((events, state));
|
||||
}
|
||||
Err(err) => return Some((vec![Err(err)], state)),
|
||||
Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,9 +9,9 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = into_deepseek(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
|
@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
|
|||
let stream = stream.await?;
|
||||
Ok(stream
|
||||
.map(|result| {
|
||||
result.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
result
|
||||
.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
|
|
|
@ -11,8 +11,9 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
|
||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
|
||||
StopReason,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
|
@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
>,
|
||||
> {
|
||||
let request = into_google(request, self.model.id().to_string());
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
let response = request
|
||||
.await
|
||||
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
|
||||
Ok(map_to_language_model_completion_events(response))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
|
@ -471,7 +479,7 @@ pub fn into_google(
|
|||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
@ -492,7 +500,7 @@ pub fn map_to_language_model_completion_events(
|
|||
if let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
|
||||
let mut events: Vec<_> = Vec::new();
|
||||
let mut wants_to_use_tool = false;
|
||||
if let Some(usage_metadata) = event.usage_metadata {
|
||||
update_usage(&mut state.usage, &usage_metadata);
|
||||
|
@ -559,7 +567,10 @@ pub fn map_to_language_model_completion_events(
|
|||
return Some((events, state));
|
||||
}
|
||||
Err(err) => {
|
||||
return Some((vec![Err(anyhow!(err))], state));
|
||||
return Some((
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
|
|||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
|
@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = self.to_lmstudio_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
|
@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
|
|||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
||||
.map(|result| {
|
||||
result
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
|
|
|
@ -8,9 +8,9 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
|
||||
use futures::stream::BoxStream;
|
||||
|
@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = into_mistral(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
|
@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
|
|||
let stream = stream.await?;
|
||||
Ok(stream
|
||||
.map(|result| {
|
||||
result.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
result
|
||||
.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
|
|
|
@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
|
|||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
|
@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
|
@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
||||
.map(|result| {
|
||||
result
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
|
|
|
@ -9,10 +9,10 @@ use gpui::{
|
|||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
|
||||
};
|
||||
use open_ai::{Model, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
|
@ -24,7 +24,7 @@ use std::sync::Arc;
|
|||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::{ResultExt, maybe};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
|
@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
>,
|
||||
> {
|
||||
let request = into_open_ai(request, &self.model, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
|
@ -419,7 +424,7 @@ pub fn into_open_ai(
|
|||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
|
@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events(
|
|||
Ok(event) => {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no choices"))],
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"Response contained no choices"
|
||||
)))],
|
||||
state,
|
||||
));
|
||||
};
|
||||
|
@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
Some("tool_calls") => {
|
||||
events.extend(state.tool_calls_by_index.drain().map(
|
||||
|(_, tool_call)| {
|
||||
maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
)?,
|
||||
},
|
||||
))
|
||||
})
|
||||
|(_, tool_call)| match serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
));
|
||||
|
||||
|
@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events(
|
|||
|
||||
return Some((events, state));
|
||||
}
|
||||
Err(err) => return Some((vec![Err(err)], state)),
|
||||
Err(err) => {
|
||||
return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue