Reuse conversation cache when streaming edits (#30245)

Release Notes:

- Improved latency when the agent applies edits.
This commit is contained in:
Antonio Scandurra 2025-05-08 14:36:34 +02:00 committed by GitHub
parent 032022e37b
commit 9f6809a28d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 847 additions and 21557 deletions

View file

@ -1411,6 +1411,7 @@ impl ActiveThread {
mode: None, mode: None,
messages: vec![request_message], messages: vec![request_message],
tools: vec![], tools: vec![],
tool_choice: None,
stop: vec![], stop: vec![],
temperature: AssistantSettings::temperature_for_model( temperature: AssistantSettings::temperature_for_model(
&configured_model.model, &configured_model.model,
@ -3256,7 +3257,7 @@ impl ActiveThread {
c.tool_use_id.clone(), c.tool_use_id.clone(),
c.ui_text.clone(), c.ui_text.clone(),
c.input.clone(), c.input.clone(),
&c.messages, c.request.clone(),
c.tool.clone(), c.tool.clone(),
configured.model, configured.model,
Some(window.window_handle()), Some(window.window_handle()),

View file

@ -466,6 +466,7 @@ impl CodegenAlternative {
prompt_id: None, prompt_id: None,
mode: None, mode: None,
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature, temperature,
messages: vec![request_message], messages: vec![request_message],

View file

@ -4,7 +4,7 @@ use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource}; use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
use context_server::{ContextServerId, types}; use context_server::{ContextServerId, types};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{Project, context_server_store::ContextServerStore}; use project::{Project, context_server_store::ContextServerStore};
use ui::IconName; use ui::IconName;
@ -72,7 +72,7 @@ impl Tool for ContextServerTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -1245,6 +1245,7 @@ impl MessageEditor {
mode: None, mode: None,
messages: vec![request_message], messages: vec![request_message],
tools: vec![], tools: vec![],
tool_choice: None,
stop: vec![], stop: vec![],
temperature: AssistantSettings::temperature_for_model(&model.model, cx), temperature: AssistantSettings::temperature_for_model(&model.model, cx),
}; };

View file

@ -293,6 +293,7 @@ impl TerminalInlineAssistant {
mode: None, mode: None,
messages: vec![request_message], messages: vec![request_message],
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature, temperature,
} }

View file

@ -1183,6 +1183,7 @@ impl Thread {
mode: None, mode: None,
messages: vec![], messages: vec![],
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature: AssistantSettings::temperature_for_model(&model, cx), temperature: AssistantSettings::temperature_for_model(&model, cx),
}; };
@ -1227,6 +1228,7 @@ impl Thread {
})); }));
} }
let mut message_ix_to_cache = None;
for message in &self.messages { for message in &self.messages {
let mut request_message = LanguageModelRequestMessage { let mut request_message = LanguageModelRequestMessage {
role: message.role, role: message.role,
@ -1263,19 +1265,57 @@ impl Thread {
}; };
} }
self.tool_use let mut cache_message = true;
.attach_tool_uses(message.id, &mut request_message); let mut tool_results_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
if let Some(tool_result) = tool_result {
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
tool_results_message
.content
.push(MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_use.id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
output: None,
}));
} else {
cache_message = false;
log::debug!(
"skipped tool use {:?} because it is still pending",
tool_use
);
}
}
if cache_message {
message_ix_to_cache = Some(request.messages.len());
}
request.messages.push(request_message); request.messages.push(request_message);
if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) { if !tool_results_message.content.is_empty() {
if cache_message {
message_ix_to_cache = Some(request.messages.len());
}
request.messages.push(tool_results_message); request.messages.push(tool_results_message);
} }
} }
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
if let Some(last) = request.messages.last_mut() { if let Some(message_ix_to_cache) = message_ix_to_cache {
last.cache = true; request.messages[message_ix_to_cache].cache = true;
} }
self.attached_tracked_files_state(&mut request.messages, cx); self.attached_tracked_files_state(&mut request.messages, cx);
@ -1302,6 +1342,7 @@ impl Thread {
mode: None, mode: None,
messages: vec![], messages: vec![],
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature: AssistantSettings::temperature_for_model(model, cx), temperature: AssistantSettings::temperature_for_model(model, cx),
}; };
@ -1918,8 +1959,7 @@ impl Thread {
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
) -> Vec<PendingToolUse> { ) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx); self.auto_capture_telemetry(cx);
let request = self.to_completion_request(model.clone(), cx); let request = Arc::new(self.to_completion_request(model.clone(), cx));
let messages = Arc::new(request.messages);
let pending_tool_uses = self let pending_tool_uses = self
.tool_use .tool_use
.pending_tool_uses() .pending_tool_uses()
@ -1937,7 +1977,7 @@ impl Thread {
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
tool_use.input.clone(), tool_use.input.clone(),
messages.clone(), request.clone(),
tool, tool,
); );
cx.emit(ThreadEvent::ToolConfirmationNeeded); cx.emit(ThreadEvent::ToolConfirmationNeeded);
@ -1946,7 +1986,7 @@ impl Thread {
tool_use.id.clone(), tool_use.id.clone(),
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
tool_use.input.clone(), tool_use.input.clone(),
&messages, request.clone(),
tool, tool,
model.clone(), model.clone(),
window, window,
@ -2041,21 +2081,14 @@ impl Thread {
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
ui_text: impl Into<SharedString>, ui_text: impl Into<SharedString>,
input: serde_json::Value, input: serde_json::Value,
messages: &[LanguageModelRequestMessage], request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>, tool: Arc<dyn Tool>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>, cx: &mut Context<Thread>,
) { ) {
let task = self.spawn_tool_use( let task =
tool_use_id.clone(), self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
messages,
input,
tool,
model,
window,
cx,
);
self.tool_use self.tool_use
.run_pending_tool(tool_use_id, ui_text.into(), task); .run_pending_tool(tool_use_id, ui_text.into(), task);
} }
@ -2063,7 +2096,7 @@ impl Thread {
fn spawn_tool_use( fn spawn_tool_use(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
messages: &[LanguageModelRequestMessage], request: Arc<LanguageModelRequest>,
input: serde_json::Value, input: serde_json::Value,
tool: Arc<dyn Tool>, tool: Arc<dyn Tool>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
@ -2077,7 +2110,7 @@ impl Thread {
} else { } else {
tool.run( tool.run(
input, input,
messages, request,
self.project.clone(), self.project.clone(),
self.action_log.clone(), self.action_log.clone(),
model, model,

View file

@ -7,8 +7,8 @@ use futures::FutureExt as _;
use futures::future::Shared; use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task}; use gpui::{App, Entity, SharedString, Task};
use language_model::{ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult, ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, LanguageModelToolUse, LanguageModelToolUseId, Role,
}; };
use project::Project; use project::Project;
use ui::{IconName, Window}; use ui::{IconName, Window};
@ -354,7 +354,7 @@ impl ToolUseState {
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
ui_text: impl Into<Arc<str>>, ui_text: impl Into<Arc<str>>,
input: serde_json::Value, input: serde_json::Value,
messages: Arc<Vec<LanguageModelRequestMessage>>, request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>, tool: Arc<dyn Tool>,
) { ) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
@ -363,7 +363,7 @@ impl ToolUseState {
let confirmation = Confirmation { let confirmation = Confirmation {
tool_use_id, tool_use_id,
input, input,
messages, request,
tool, tool,
ui_text, ui_text,
}; };
@ -449,72 +449,20 @@ impl ToolUseState {
} }
} }
pub fn attach_tool_uses(
&self,
message_id: MessageId,
request_message: &mut LanguageModelRequestMessage,
) {
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
for tool_use in tool_uses {
if self.tool_results.contains_key(&tool_use.id) {
// Do not send tool uses until they are completed
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
} else {
log::debug!(
"skipped tool use {:?} because it is still pending",
tool_use
);
}
}
}
}
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool { pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id) .contains_key(&assistant_message_id)
} }
pub fn tool_results_message( pub fn tool_results(
&self, &self,
assistant_message_id: MessageId, assistant_message_id: MessageId,
) -> Option<LanguageModelRequestMessage> { ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
let tool_uses = self self.tool_uses_by_assistant_message
.tool_uses_by_assistant_message .get(&assistant_message_id)
.get(&assistant_message_id)?; .into_iter()
.flatten()
if tool_uses.is_empty() { .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
for tool_use in tool_uses {
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
request_message
.content
.push(MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_use.id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
output: None,
}));
}
}
Some(request_message)
} }
} }
@ -535,7 +483,7 @@ pub struct Confirmation {
pub tool_use_id: LanguageModelToolUseId, pub tool_use_id: LanguageModelToolUseId,
pub input: serde_json::Value, pub input: serde_json::Value,
pub ui_text: Arc<str>, pub ui_text: Arc<str>,
pub messages: Arc<Vec<LanguageModelRequestMessage>>, pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>, pub tool: Arc<dyn Tool>,
} }

View file

@ -578,6 +578,7 @@ pub enum ToolChoice {
Auto, Auto,
Any, Any,
Tool { name: String }, Tool { name: String },
None,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View file

@ -2585,6 +2585,7 @@ impl AssistantContext {
mode: None, mode: None,
messages: Vec::new(), messages: Vec::new(),
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature: model temperature: model
.and_then(|model| AssistantSettings::temperature_for_model(model, cx)), .and_then(|model| AssistantSettings::temperature_for_model(model, cx)),

View file

@ -19,7 +19,7 @@ use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity}; use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName; use icons::IconName;
use language_model::LanguageModel; use language_model::LanguageModel;
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequest;
use language_model::LanguageModelToolSchemaFormat; use language_model::LanguageModelToolSchemaFormat;
use project::Project; use project::Project;
use workspace::Workspace; use workspace::Workspace;
@ -206,7 +206,7 @@ pub trait Tool: 'static + Send + Sync {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
messages: &[LanguageModelRequestMessage], request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,

View file

@ -3,8 +3,8 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle; use gpui::AnyWindowHandle;
use gpui::{App, AppContext, Entity, Task}; use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelToolSchemaFormat; use language_model::LanguageModel;
use language_model::{LanguageModel, LanguageModelRequestMessage}; use language_model::{LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -74,7 +74,7 @@ impl Tool for CopyPathTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle; use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -61,7 +61,7 @@ impl Tool for CreateDirectoryTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::{SinkExt, StreamExt, channel::mpsc}; use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{Project, ProjectPath}; use project::{Project, ProjectPath};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -59,7 +59,7 @@ impl Tool for DeletePathTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt}; use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -79,7 +79,7 @@ impl Tool for DiagnosticsTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -17,7 +17,7 @@ use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point}; use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage, LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, LanguageModelToolChoice, MessageContent, Role,
}; };
use project::{AgentLocation, Project}; use project::{AgentLocation, Project};
use serde::Serialize; use serde::Serialize;
@ -83,7 +83,7 @@ impl EditAgent {
&self, &self,
buffer: Entity<Buffer>, buffer: Entity<Buffer>,
edit_description: String, edit_description: String,
previous_messages: Vec<LanguageModelRequestMessage>, conversation: &LanguageModelRequest,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> ( ) -> (
Task<Result<EditAgentOutput>>, Task<Result<EditAgentOutput>>,
@ -91,6 +91,7 @@ impl EditAgent {
) { ) {
let this = self.clone(); let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded(); let (events_tx, events_rx) = mpsc::unbounded();
let conversation = conversation.clone();
let output = cx.spawn(async move |cx| { let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?; let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
@ -99,7 +100,7 @@ impl EditAgent {
edit_description, edit_description,
} }
.render(&this.templates)?; .render(&this.templates)?;
let new_chunks = this.request(previous_messages, prompt, cx).await?; let new_chunks = this.request(conversation, prompt, cx).await?;
let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx); let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
while let Some(event) = inner_events.next().await { while let Some(event) = inner_events.next().await {
@ -194,7 +195,7 @@ impl EditAgent {
&self, &self,
buffer: Entity<Buffer>, buffer: Entity<Buffer>,
edit_description: String, edit_description: String,
previous_messages: Vec<LanguageModelRequestMessage>, conversation: &LanguageModelRequest,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> ( ) -> (
Task<Result<EditAgentOutput>>, Task<Result<EditAgentOutput>>,
@ -214,6 +215,7 @@ impl EditAgent {
let this = self.clone(); let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded(); let (events_tx, events_rx) = mpsc::unbounded();
let conversation = conversation.clone();
let output = cx.spawn(async move |cx| { let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?; let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
@ -222,7 +224,7 @@ impl EditAgent {
edit_description, edit_description,
} }
.render(&this.templates)?; .render(&this.templates)?;
let edit_chunks = this.request(previous_messages, prompt, cx).await?; let edit_chunks = this.request(conversation, prompt, cx).await?;
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx); let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
while let Some(event) = inner_events.next().await { while let Some(event) = inner_events.next().await {
@ -512,32 +514,67 @@ impl EditAgent {
async fn request( async fn request(
&self, &self,
mut messages: Vec<LanguageModelRequestMessage>, mut conversation: LanguageModelRequest,
prompt: String, prompt: String,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> { ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
let mut message_content = Vec::new(); let mut messages_iter = conversation.messages.iter_mut();
if let Some(last_message) = messages.last_mut() { if let Some(last_message) = messages_iter.next_back() {
if last_message.role == Role::Assistant { if last_message.role == Role::Assistant {
let old_content_len = last_message.content.len();
last_message last_message
.content .content
.retain(|content| !matches!(content, MessageContent::ToolUse(_))); .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
let new_content_len = last_message.content.len();
// We just removed pending tool uses from the content of the
// last message, so it doesn't make sense to cache it anymore
// (e.g., the message will look very different on the next
// request). Thus, we move the flag to the message prior to it,
// as it will still be a valid prefix of the conversation.
if old_content_len != new_content_len && last_message.cache {
if let Some(prev_message) = messages_iter.next_back() {
last_message.cache = false;
prev_message.cache = true;
}
}
if last_message.content.is_empty() { if last_message.content.is_empty() {
messages.pop(); conversation.messages.pop();
} }
} }
} }
message_content.push(MessageContent::Text(prompt));
messages.push(LanguageModelRequestMessage { conversation.messages.push(LanguageModelRequestMessage {
role: Role::User, role: Role::User,
content: message_content, content: vec![MessageContent::Text(prompt)],
cache: false, cache: false,
}); });
// Include tools in the request so that we can take advantage of
// caching when ToolChoice::None is supported.
let mut tool_choice = None;
let mut tools = Vec::new();
if !conversation.tools.is_empty()
&& self
.model
.supports_tool_choice(LanguageModelToolChoice::None)
{
tool_choice = Some(LanguageModelToolChoice::None);
tools = conversation.tools.clone();
}
let request = LanguageModelRequest { let request = LanguageModelRequest {
messages, thread_id: conversation.thread_id,
..Default::default() prompt_id: conversation.prompt_id,
mode: conversation.mode,
messages: conversation.messages,
tool_choice,
tools,
stop: Vec::new(),
temperature: None,
}; };
Ok(self.model.stream_completion_text(request, cx).await?.stream) Ok(self.model.stream_completion_text(request, cx).await?.stream)
} }

View file

@ -2,14 +2,16 @@ use super::*;
use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput}; use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
use Role::*; use Role::*;
use anyhow::anyhow; use anyhow::anyhow;
use assistant_tool::ToolRegistry;
use client::{Client, UserStore}; use client::{Client, UserStore};
use collections::HashMap; use collections::HashMap;
use fs::FakeFs; use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture}; use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext}; use gpui::{AppContext, TestAppContext};
use indoc::indoc; use indoc::{formatdoc, indoc};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId,
}; };
use project::Project; use project::Project;
use rand::prelude::*; use rand::prelude::*;
@ -37,7 +39,7 @@ fn eval_extract_handle_command_output() {
conversation: vec![ conversation: vec![
message( message(
User, User,
[text(indoc! {" [text(formatdoc! {"
Read the `{input_file_path}` file and extract a method in Read the `{input_file_path}` file and extract a method in
the final stanza of `run_git_blame` to deal with command failures, the final stanza of `run_git_blame` to deal with command failures,
call it `handle_command_output` and take the std::process::Output as the only parameter. call it `handle_command_output` and take the std::process::Output as the only parameter.
@ -96,7 +98,7 @@ fn eval_delete_run_git_blame() {
conversation: vec![ conversation: vec![
message( message(
User, User,
[text(indoc! {" [text(formatdoc! {"
Read the `{input_file_path}` file and delete `run_git_blame`. Just that Read the `{input_file_path}` file and delete `run_git_blame`. Just that
one function, not its usages. one function, not its usages.
"})], "})],
@ -138,6 +140,61 @@ fn eval_delete_run_git_blame() {
); );
} }
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_translate_doc_comments() {
let input_file_path = "root/canvas.rs";
let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
let edit_description = "Translate all doc comments to Italian";
eval(
200,
1.,
EvalInput {
conversation: vec![
message(
User,
[text(formatdoc! {"
Read the {input_file_path} file and edit it (without overwriting it),
translating all the doc comments to italian.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"edit_file",
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::judge_diff("Doc comments were translated to Italian"),
},
);
}
#[test] #[test]
#[cfg_attr(not(feature = "eval"), ignore)] #[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() { fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
@ -152,7 +209,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
conversation: vec![ conversation: vec![
message( message(
User, User,
[text(indoc! {" [text(formatdoc! {"
Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten. Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
Use `ureq` to download the SDK for the current platform and architecture. Use `ureq` to download the SDK for the current platform and architecture.
Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir. Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
@ -160,7 +217,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
that's inside of the archive. that's inside of the archive.
Don't re-download the SDK if that executable already exists. Don't re-download the SDK if that executable already exists.
Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name} Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}
Here are the available wasi-sdk assets: Here are the available wasi-sdk assets:
- wasi-sdk-25.0-x86_64-macos.tar.gz - wasi-sdk-25.0-x86_64-macos.tar.gz
@ -261,11 +318,10 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
fn eval_disable_cursor_blinking() { fn eval_disable_cursor_blinking() {
let input_file_path = "root/editor.rs"; let input_file_path = "root/editor.rs";
let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs"); let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
let edit_description = "Comment out the call to `BlinkManager::enable`"; let edit_description = "Comment out the call to `BlinkManager::enable`";
eval( eval(
200, 200,
0.6, // TODO: make this eval better 0.95,
EvalInput { EvalInput {
conversation: vec![ conversation: vec![
message(User, [text("Let's research how to cursor blinking works.")]), message(User, [text("Let's research how to cursor blinking works.")]),
@ -324,7 +380,11 @@ fn eval_disable_cursor_blinking() {
input_path: input_file_path.into(), input_path: input_file_path.into(),
input_content: Some(input_file_content.into()), input_content: Some(input_file_content.into()),
edit_description: edit_description.into(), edit_description: edit_description.into(),
assertion: EvalAssertion::assert_eq(output_file_content), assertion: EvalAssertion::judge_diff(indoc! {"
- Calls to BlinkManager in `observe_window_activation` were commented out
- The call to `blink_manager.enable` above the call to show_cursor_names was commented out
- All the edits have valid indentation
"}),
}, },
); );
} }
@ -1031,7 +1091,8 @@ impl EvalAssertion {
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) { fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0; let mut evaluated_count = 0;
report_progress(evaluated_count, iterations); let mut failed_count = 0;
report_progress(evaluated_count, failed_count, iterations);
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
@ -1048,7 +1109,6 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
} }
drop(tx); drop(tx);
let mut failed_count = 0;
let mut failed_evals = HashMap::default(); let mut failed_evals = HashMap::default();
let mut errored_evals = HashMap::default(); let mut errored_evals = HashMap::default();
let mut eval_outputs = Vec::new(); let mut eval_outputs = Vec::new();
@ -1073,7 +1133,7 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
} }
evaluated_count += 1; evaluated_count += 1;
report_progress(evaluated_count, iterations); report_progress(evaluated_count, failed_count, iterations);
} }
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32; let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
@ -1144,8 +1204,19 @@ impl Display for EvalOutput {
} }
} }
fn report_progress(evaluated_count: usize, iterations: usize) { fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations); let passed_count = evaluated_count - failed_count;
let passed_ratio = if evaluated_count == 0 {
0.0
} else {
passed_count as f64 / evaluated_count as f64
};
print!(
"\r\x1b[KEvaluated {}/{} ({:.2}%)",
evaluated_count,
iterations,
passed_ratio * 100.0
);
std::io::stdout().flush().unwrap(); std::io::stdout().flush().unwrap();
} }
@ -1158,25 +1229,30 @@ struct EditAgentTest {
impl EditAgentTest { impl EditAgentTest {
async fn new(cx: &mut TestAppContext) -> Self { async fn new(cx: &mut TestAppContext) -> Self {
cx.executor().allow_parking(); cx.executor().allow_parking();
cx.update(settings::init);
cx.update(Project::init_settings);
cx.update(language::init);
cx.update(gpui_tokio::init);
cx.update(client::init_settings);
let fs = FakeFs::new(cx.executor().clone()); let fs = FakeFs::new(cx.executor().clone());
cx.update(|cx| {
settings::init(cx);
gpui_tokio::init(cx);
let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
cx.set_http_client(http_client);
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
settings::init(cx);
Project::init_settings(cx);
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
crate::init(client.http_client(), cx);
});
fs.insert_tree("/root", json!({})).await; fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let (agent_model, judge_model) = cx let (agent_model, judge_model) = cx
.update(|cx| { .update(|cx| {
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let agent_model = let agent_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await; Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
@ -1225,12 +1301,32 @@ impl EditAgentTest {
.update(cx, |project, cx| project.open_buffer(path, cx)) .update(cx, |project, cx| project.open_buffer(path, cx))
.await .await
.unwrap(); .unwrap();
let conversation = LanguageModelRequest {
messages: eval.conversation,
tools: cx.update(|cx| {
ToolRegistry::default_global(cx)
.tools()
.into_iter()
.filter_map(|tool| {
let input_schema = tool
.input_schema(self.agent.model.tool_input_format())
.ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
})
.collect()
}),
..Default::default()
};
let edit_output = if let Some(input_content) = eval.input_content.as_deref() { let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx)); buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
let (edit_output, _) = self.agent.edit( let (edit_output, _) = self.agent.edit(
buffer.clone(), buffer.clone(),
eval.edit_description, eval.edit_description,
eval.conversation, &conversation,
&mut cx.to_async(), &mut cx.to_async(),
); );
edit_output.await? edit_output.await?
@ -1238,7 +1334,7 @@ impl EditAgentTest {
let (edit_output, _) = self.agent.overwrite( let (edit_output, _) = self.agent.overwrite(
buffer.clone(), buffer.clone(),
eval.edit_description, eval.edit_description,
eval.conversation, &conversation,
&mut cx.to_async(), &mut cx.to_async(),
); );
edit_output.await? edit_output.await?

View file

@ -0,0 +1,339 @@
// font-kit/src/canvas.rs
//
// Copyright © 2018 The Pathfinder Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! An in-memory bitmap surface for glyph rasterization.
use std::cmp;
use std::fmt;
use std::sync::LazyLock;

use lazy_static::lazy_static;
use pathfinder_geometry::rect::RectI;
use pathfinder_geometry::vector::Vector2I;

use crate::utils;
/// Lookup table expanding each possible byte of packed 1-bit-per-pixel bitmap
/// data into eight 8-bit values (0x00 or 0xff), most significant bit first.
///
/// Built lazily on first access via the standard library's `LazyLock`,
/// replacing the external `lazy_static!` macro for this item. Indexing
/// through the `LazyLock` works unchanged at the use sites via `Deref`.
static BITMAP_1BPP_TO_8BPP_LUT: LazyLock<[[u8; 8]; 256]> = LazyLock::new(|| {
    let mut lut = [[0u8; 8]; 256];
    for (byte, value) in lut.iter_mut().enumerate() {
        for (bit, out) in value.iter_mut().enumerate() {
            // Bit 0 of the expansion corresponds to the source byte's MSB.
            if byte & (0x80 >> bit) != 0 {
                *out = 0xff;
            }
        }
    }
    lut
});
/// An in-memory bitmap surface for glyph rasterization.
pub struct Canvas {
    /// The raw pixel data.
    ///
    /// Row-major; successive rows are `stride` bytes apart.
    pub pixels: Vec<u8>,
    /// The size of the buffer, in pixels.
    pub size: Vector2I,
    /// The number of *bytes* between successive rows.
    pub stride: usize,
    /// The image format of the canvas.
    pub format: Format,
}
impl Canvas {
    /// Creates a new blank canvas with the given pixel size and format.
    ///
    /// Stride is automatically calculated from width.
    ///
    /// The canvas is initialized with transparent black (all values 0).
    #[inline]
    pub fn new(size: Vector2I, format: Format) -> Canvas {
        Canvas::with_stride(
            size,
            size.x() as usize * format.bytes_per_pixel() as usize,
            format,
        )
    }

    /// Creates a new blank canvas with the given pixel size, stride (number of bytes between
    /// successive rows), and format.
    ///
    /// The canvas is initialized with transparent black (all values 0).
    pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
        Canvas {
            pixels: vec![0; stride * size.y() as usize],
            size,
            stride,
            format,
        }
    }

    /// Copies the entire contents of `src` onto this canvas at the origin,
    /// converting between pixel formats as needed (see `blit_from`).
    #[allow(dead_code)]
    pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
        self.blit_from(
            Vector2I::default(),
            &src.pixels,
            src.size,
            src.stride,
            src.format,
        )
    }

    /// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
    /// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
    /// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
    /// `src_stride` must be equal or larger than the actual data length.
    ///
    /// # Panics
    ///
    /// Panics if `src_bytes` is not exactly `src_stride * src_size.y()` bytes long,
    /// if `src_stride` is smaller than one row's worth of pixel data, or if the
    /// (destination, source) format pair is A8 <-> RGBA32 (unimplemented).
    #[allow(dead_code)]
    pub(crate) fn blit_from(
        &mut self,
        dst_point: Vector2I,
        src_bytes: &[u8],
        src_size: Vector2I,
        src_stride: usize,
        src_format: Format,
    ) {
        assert_eq!(
            src_stride * src_size.y() as usize,
            src_bytes.len(),
            "Number of pixels in src_bytes does not match stride and size."
        );
        assert!(
            src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
            "src_stride must be >= than src_size.x()"
        );
        // Clip the destination rectangle against the canvas bounds; if nothing
        // remains visible there is no work to do.
        let dst_rect = RectI::new(dst_point, src_size);
        let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
        let dst_rect = match dst_rect {
            Some(dst_rect) => dst_rect,
            None => return,
        };
        // Dispatch to the per-row blitter matching this (dest, src) format pair.
        match (self.format, src_format) {
            (Format::A8, Format::A8)
            | (Format::Rgb24, Format::Rgb24)
            | (Format::Rgba32, Format::Rgba32) => {
                self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
            }
            (Format::A8, Format::Rgb24) => {
                self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
            }
            (Format::Rgb24, Format::A8) => {
                self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
            }
            (Format::Rgb24, Format::Rgba32) => self
                .blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
            (Format::Rgba32, Format::Rgb24) => self
                .blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
            (Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
        }
    }

    /// Blits packed 1-bit-per-pixel bitmap data onto this canvas, expanding
    /// each bit into a full 0x00/0xff byte via `BITMAP_1BPP_TO_8BPP_LUT`.
    ///
    /// Only `Format::A8` destinations are supported; any other format panics
    /// with `unimplemented!`. `src_stride` is in bytes.
    ///
    /// NOTE(review): when `dst_rect` is clipped, source data is still read
    /// starting at its first row and column (`y`/`x` run from 0) — confirm
    /// that callers never pass a negative `dst_point`.
    #[allow(dead_code)]
    pub(crate) fn blit_from_bitmap_1bpp(
        &mut self,
        dst_point: Vector2I,
        src_bytes: &[u8],
        src_size: Vector2I,
        src_stride: usize,
    ) {
        if self.format != Format::A8 {
            unimplemented!()
        }
        // Clip against the canvas bounds, as in `blit_from`.
        let dst_rect = RectI::new(dst_point, src_size);
        let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
        let dst_rect = match dst_rect {
            Some(dst_rect) => dst_rect,
            None => return,
        };
        let size = dst_rect.size();
        let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
        let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
        // Each source row packs 8 pixels per byte, rounded up.
        let src_row_stride = utils::div_round_up(size.x() as usize, 8);
        for y in 0..size.y() {
            let (dest_row_start, src_row_start) = (
                (y + dst_rect.origin_y()) as usize * self.stride
                    + dst_rect.origin_x() as usize * dest_bytes_per_pixel,
                y as usize * src_stride,
            );
            let dest_row_end = dest_row_start + dest_row_stride;
            let src_row_end = src_row_start + src_row_stride;
            let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
            let src_row_pixels = &src_bytes[src_row_start..src_row_end];
            for x in 0..src_row_stride {
                // Expand one packed byte into up to eight destination bytes;
                // the last byte of a row may cover fewer than eight pixels.
                let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
                let dest_start = x * 8;
                let dest_end = cmp::min(dest_start + 8, dest_row_stride);
                let src = &pattern[0..(dest_end - dest_start)];
                dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
            }
        }
    }

    /// Blits to area `rect` using the data given in the buffer `src_bytes`.
    /// `src_stride` must be specified in bytes.
    /// The dimensions of `rect` must be in pixels.
    ///
    /// `B` supplies the per-row pixel-format conversion.
    fn blit_from_with<B: Blit>(
        &mut self,
        rect: RectI,
        src_bytes: &[u8],
        src_stride: usize,
        src_format: Format,
    ) {
        let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
        let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
        for y in 0..rect.height() {
            // Byte offsets of this row in the destination and source buffers.
            let (dest_row_start, src_row_start) = (
                (y + rect.origin_y()) as usize * self.stride
                    + rect.origin_x() as usize * dest_bytes_per_pixel,
                y as usize * src_stride,
            );
            let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
            let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
            let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
            let src_row_pixels = &src_bytes[src_row_start..src_row_end];
            B::blit(dest_row_pixels, src_row_pixels)
        }
    }
}
impl fmt::Debug for Canvas {
    /// Formats the canvas for debugging, reporting the pixel buffer's byte
    /// count rather than dumping its contents.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut debug = f.debug_struct("Canvas");
        // Summarize the buffer by its length; printing every byte would be
        // unreadable.
        debug.field("pixels", &self.pixels.len());
        debug.field("size", &self.size);
        debug.field("stride", &self.stride);
        debug.field("format", &self.format);
        debug.finish()
    }
}
/// The image format for the canvas.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
    /// Premultiplied R8G8B8A8, little-endian.
    Rgba32,
    /// R8G8B8, little-endian.
    Rgb24,
    /// A8: a single 8-bit channel per pixel.
    A8,
}
impl Format {
    /// Returns the number of bits per pixel that this image format corresponds to.
    #[inline]
    pub fn bits_per_pixel(self) -> u8 {
        match self {
            Format::A8 => 8,
            Format::Rgb24 => 24,
            Format::Rgba32 => 32,
        }
    }

    /// Returns the number of color channels per pixel that this image format corresponds to.
    #[inline]
    pub fn components_per_pixel(self) -> u8 {
        match self {
            Format::A8 => 1,
            Format::Rgb24 => 3,
            Format::Rgba32 => 4,
        }
    }

    /// Returns the number of bits per color channel that this image format contains.
    #[inline]
    pub fn bits_per_component(self) -> u8 {
        self.bits_per_pixel() / self.components_per_pixel()
    }

    /// Returns the number of bytes per pixel that this image format corresponds to.
    #[inline]
    pub fn bytes_per_pixel(self) -> u8 {
        self.bits_per_pixel() / 8
    }
}
/// The antialiasing strategy that should be used when rasterizing glyphs.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RasterizationOptions {
    /// "Black-and-white" rendering. Each pixel is either entirely on or off.
    Bilevel,
    /// Grayscale antialiasing. Only one channel is used.
    GrayscaleAa,
    /// Subpixel RGB antialiasing, for LCD screens.
    SubpixelAa,
}
/// A per-row pixel-format conversion strategy, selected by
/// `Canvas::blit_from_with` based on the (destination, source) format pair.
trait Blit {
    /// Converts one row of source pixels into `dest`.
    fn blit(dest: &mut [u8], src: &[u8]);
}
/// Row blitter for identical source and destination formats: a straight copy.
struct BlitMemcpy;

impl Blit for BlitMemcpy {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        // `u8` is `Copy`, so `copy_from_slice` (which compiles to `memcpy`)
        // is equivalent to the original `clone_from_slice` and is the
        // idiomatic choice for byte slices.
        dest.copy_from_slice(src)
    }
}
/// Row blitter that reduces RGB24 pixels to A8 by keeping the middle
/// (green) channel of each 3-byte pixel.
struct BlitRgb24ToA8;

impl Blit for BlitRgb24ToA8 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        // TODO(pcwalton): SIMD.
        dest.iter_mut()
            .zip(src.chunks(3))
            .for_each(|(out, rgb)| *out = rgb[1]);
    }
}
/// Row blitter that expands A8 pixels to RGB24 by replicating the single
/// source byte into all three channels.
struct BlitA8ToRgb24;

impl Blit for BlitA8ToRgb24 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        for (rgb, &value) in dest.chunks_mut(3).zip(src) {
            rgb[0] = value;
            rgb[1] = value;
            rgb[2] = value;
        }
    }
}
/// Row blitter that converts RGBA32 to RGB24 by dropping the alpha byte of
/// each 4-byte pixel.
struct BlitRgba32ToRgb24;

impl Blit for BlitRgba32ToRgb24 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        // TODO(pcwalton): SIMD.
        dest.chunks_mut(3)
            .zip(src.chunks(4))
            .for_each(|(rgb, rgba)| rgb.copy_from_slice(&rgba[..3]));
    }
}
/// Row blitter that converts RGB24 to RGBA32 by appending a fully opaque
/// alpha byte (255) to each pixel.
struct BlitRgb24ToRgba32;

impl Blit for BlitRgb24ToRgba32 {
    // `#[inline]` added for consistency with every other `Blit` impl in this
    // file, all of which are tiny per-row loops called from a generic path.
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        for (rgba, rgb) in dest.chunks_mut(4).zip(src.chunks(3)) {
            // Copy R, G, B in one slice copy (as the sibling impls do), then
            // force the alpha channel to opaque.
            rgba[..3].copy_from_slice(rgb);
            rgba[3] = 255;
        }
    }
}

View file

@ -19,7 +19,7 @@ use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer, Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
language_settings::SoftWrap, language_settings::SoftWrap,
}; };
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -146,7 +146,7 @@ impl Tool for EditFileTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
messages: &[LanguageModelRequestMessage], request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
@ -177,7 +177,6 @@ impl Tool for EditFileTool {
}); });
let card_clone = card.clone(); let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| { let task = cx.spawn(async move |cx: &mut AsyncApp| {
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new()); let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
@ -209,14 +208,14 @@ impl Tool for EditFileTool {
edit_agent.overwrite( edit_agent.overwrite(
buffer.clone(), buffer.clone(),
input.display_description.clone(), input.display_description.clone(),
messages, &request,
cx, cx,
) )
} else { } else {
edit_agent.edit( edit_agent.edit(
buffer.clone(), buffer.clone(),
input.display_description.clone(), input.display_description.clone(),
messages, &request,
cx, cx,
) )
}; };
@ -847,7 +846,15 @@ mod tests {
}) })
.unwrap(); .unwrap();
Arc::new(EditFileTool) Arc::new(EditFileTool)
.run(input, &[], project.clone(), action_log, model, None, cx) .run(
input,
Arc::default(),
project.clone(),
action_log,
model,
None,
cx,
)
.output .output
}) })
.await; .await;

View file

@ -9,7 +9,7 @@ use futures::AsyncReadExt as _;
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task}; use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl}; use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -142,7 +142,7 @@ impl Tool for FetchTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -7,7 +7,7 @@ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
}; };
use language; use language;
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -73,7 +73,7 @@ impl Tool for FindPathTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -4,7 +4,7 @@ use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt; use futures::StreamExt;
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{OffsetRangeExt, ParseStatus, Point}; use language::{OffsetRangeExt, ParseStatus, Point};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{ use project::{
Project, Project,
search::{SearchQuery, SearchResult}, search::{SearchQuery, SearchResult},
@ -96,7 +96,7 @@ impl Tool for GrepTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,
@ -746,7 +746,8 @@ mod tests {
let tool = Arc::new(GrepTool); let tool = Arc::new(GrepTool);
let action_log = cx.new(|_cx| ActionLog::new(project.clone())); let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let task = cx.update(|cx| tool.run(input, &[], project, action_log, model, None, cx)); let task =
cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));
match task.output.await { match task.output.await {
Ok(result) => { Ok(result) => {

View file

@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -73,7 +73,7 @@ impl Tool for ListDirectoryTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -86,7 +86,7 @@ impl Tool for MovePathTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -5,7 +5,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc}; use chrono::{Local, Utc};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -56,7 +56,7 @@ impl Tool for NowTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -49,7 +49,7 @@ impl Tool for OpenTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -7,7 +7,7 @@ use gpui::{AnyWindowHandle, App, Entity, Task};
use indoc::formatdoc; use indoc::formatdoc;
use itertools::Itertools; use itertools::Itertools;
use language::{Anchor, Point}; use language::{Anchor, Point};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{AgentLocation, Project}; use project::{AgentLocation, Project};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -83,7 +83,7 @@ impl Tool for ReadFileTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,
@ -231,7 +231,15 @@ mod test {
"path": "root/nonexistent_file.txt" "path": "root/nonexistent_file.txt"
}); });
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, model, None, cx) .run(
input,
Arc::default(),
project.clone(),
action_log,
model,
None,
cx,
)
.output .output
}) })
.await; .await;
@ -262,7 +270,15 @@ mod test {
"path": "root/small_file.txt" "path": "root/small_file.txt"
}); });
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, model, None, cx) .run(
input,
Arc::default(),
project.clone(),
action_log,
model,
None,
cx,
)
.output .output
}) })
.await; .await;
@ -295,7 +311,7 @@ mod test {
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
&[], Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
model.clone(), model.clone(),
@ -325,7 +341,15 @@ mod test {
"offset": 1 "offset": 1
}); });
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, model, None, cx) .run(
input,
Arc::default(),
project.clone(),
action_log,
model,
None,
cx,
)
.output .output
}) })
.await; .await;
@ -372,7 +396,15 @@ mod test {
"end_line": 4 "end_line": 4
}); });
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, model, None, cx) .run(
input,
Arc::default(),
project.clone(),
action_log,
model,
None,
cx,
)
.output .output
}) })
.await; .await;
@ -406,7 +438,7 @@ mod test {
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
&[], Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
model.clone(), model.clone(),
@ -429,7 +461,7 @@ mod test {
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
&[], Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
model.clone(), model.clone(),
@ -450,7 +482,15 @@ mod test {
"end_line": 2 "end_line": 2
}); });
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, model, None, cx) .run(
input,
Arc::default(),
project.clone(),
action_log,
model,
None,
cx,
)
.output .output
}) })
.await; .await;

View file

@ -1,6 +1,4 @@
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make. You MUST respond with a series of edits to a file, using the following format:
You MUST respond with a series of edits to that one file in the following format:
``` ```
<edits> <edits>
@ -51,3 +49,5 @@ Rules for editing:
<edit_description> <edit_description>
{{edit_description}} {{edit_description}}
</edit_description> </edit_description>
Tool calls have been disabled. You MUST start your response with <edits>.

View file

@ -4,7 +4,7 @@ use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared}; use futures::{FutureExt as _, future::Shared};
use gpui::{AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, WeakEntity, Window}; use gpui::{AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, WeakEntity, Window};
use language::LineEnding; use language::LineEnding;
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use portable_pty::{CommandBuilder, PtySize, native_pty_system}; use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use project::{Project, terminals::TerminalKind}; use project::{Project, terminals::TerminalKind};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -107,7 +107,7 @@ impl Tool for TerminalTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,
@ -656,7 +656,7 @@ mod tests {
TerminalTool::run( TerminalTool::run(
Arc::new(TerminalTool::new(cx)), Arc::new(TerminalTool::new(cx)),
serde_json::to_value(input).unwrap(), serde_json::to_value(input).unwrap(),
&[], Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
model, model,
@ -691,7 +691,7 @@ mod tests {
let headless_result = TerminalTool::run( let headless_result = TerminalTool::run(
Arc::new(TerminalTool::new(cx)), Arc::new(TerminalTool::new(cx)),
serde_json::to_value(input).unwrap(), serde_json::to_value(input).unwrap(),
&[], Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
model.clone(), model.clone(),

View file

@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -47,7 +47,7 @@ impl Tool for ThinkingTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -8,7 +8,7 @@ use futures::{Future, FutureExt, TryFutureExt};
use gpui::{ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
}; };
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -53,7 +53,7 @@ impl Tool for WebSearchTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>, _model: Arc<dyn LanguageModel>,

View file

@ -7,9 +7,10 @@ use anyhow::{Error, Result, anyhow};
use aws_sdk_bedrockruntime as bedrock; use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client; pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{ pub use aws_sdk_bedrockruntime::types::{
AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent, AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig, ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec, ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
ToolSpecification as BedrockToolSpec,
}; };
pub use aws_smithy_types::Blob as BedrockBlob; pub use aws_smithy_types::Blob as BedrockBlob;
use aws_smithy_types::{Document, Number as AwsNumber}; use aws_smithy_types::{Document, Number as AwsNumber};

View file

@ -182,11 +182,11 @@ pub enum Tool {
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ToolChoice { pub enum ToolChoice {
Auto, Auto,
Any, Any,
Tool { name: String }, None,
} }
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

View file

@ -578,6 +578,7 @@ impl ExampleInstance {
}], }],
temperature: None, temperature: None,
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
}; };

View file

@ -1774,6 +1774,7 @@ impl GitPanel {
cache: false, cache: false,
}], }],
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature, temperature,
}; };

View file

@ -2,6 +2,7 @@ use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice,
}; };
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@ -152,6 +153,10 @@ impl LanguageModel for FakeLanguageModel {
false false
} }
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
"fake".to_string() "fake".to_string()
} }

View file

@ -246,6 +246,9 @@ pub trait LanguageModel: Send + Sync {
/// Whether this model supports tools. /// Whether this model supports tools.
fn supports_tools(&self) -> bool; fn supports_tools(&self) -> bool;
/// Whether this model supports choosing which tool to use.
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
/// Returns whether this model supports "max mode"; /// Returns whether this model supports "max mode";
fn supports_max_mode(&self) -> bool { fn supports_max_mode(&self) -> bool {
if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID { if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {

View file

@ -203,6 +203,13 @@ pub struct LanguageModelRequestTool {
pub input_schema: serde_json::Value, pub input_schema: serde_json::Value,
} }
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
Auto,
Any,
None,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest { pub struct LanguageModelRequest {
pub thread_id: Option<String>, pub thread_id: Option<String>,
@ -210,6 +217,7 @@ pub struct LanguageModelRequest {
pub mode: Option<CompletionMode>, pub mode: Option<CompletionMode>,
pub messages: Vec<LanguageModelRequestMessage>, pub messages: Vec<LanguageModelRequestMessage>,
pub tools: Vec<LanguageModelRequestTool>, pub tools: Vec<LanguageModelRequestTool>,
pub tool_choice: Option<LanguageModelToolChoice>,
pub stop: Vec<String>, pub stop: Vec<String>,
pub temperature: Option<f32>, pub temperature: Option<f32>,
} }

View file

@ -15,7 +15,8 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, MessageContent,
RateLimiter, Role,
}; };
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -420,6 +421,14 @@ impl LanguageModel for AnthropicModel {
true true
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("anthropic/{}", self.model.id()) format!("anthropic/{}", self.model.id())
} }
@ -620,7 +629,11 @@ pub fn into_anthropic(
input_schema: tool.input_schema, input_schema: tool.input_schema,
}) })
.collect(), .collect(),
tool_choice: None, tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
}),
metadata: None, metadata: None,
stop_sequences: Vec::new(), stop_sequences: Vec::new(),
temperature: request.temperature.or(Some(default_temperature)), temperature: request.temperature.or(Some(default_temperature)),

View file

@ -15,11 +15,11 @@ use bedrock::bedrock_client::types::{
StopReason, StopReason,
}; };
use bedrock::{ use bedrock::{
BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent, BedrockMessage, BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
BedrockModelMode, BedrockStreamingResponse, BedrockThinkingBlock, BedrockThinkingTextBlock, BedrockMessage, BedrockModelMode, BedrockStreamingResponse, BedrockThinkingBlock,
BedrockTool, BedrockToolChoice, BedrockToolConfig, BedrockToolInputSchema, BedrockThinkingTextBlock, BedrockTool, BedrockToolChoice, BedrockToolConfig,
BedrockToolResultBlock, BedrockToolResultContentBlock, BedrockToolResultStatus, BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock,
BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document, BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
}; };
use collections::{BTreeMap, HashMap}; use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider; use credentials_provider::CredentialsProvider;
@ -35,8 +35,8 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
RateLimiter, Role, TokenUsage, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -520,6 +520,15 @@ impl LanguageModel for BedrockModel {
self.model.supports_tool_use() self.model.supports_tool_use()
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
self.model.supports_tool_use()
}
LanguageModelToolChoice::None => false,
}
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("bedrock/{}", self.model.id()) format!("bedrock/{}", self.model.id())
} }
@ -719,11 +728,20 @@ pub fn into_bedrock(
}) })
.collect(); .collect();
let tool_choice = match request.tool_choice {
Some(LanguageModelToolChoice::Auto) | None => {
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
}
Some(LanguageModelToolChoice::Any) => {
BedrockToolChoice::Any(BedrockAnyToolChoice::builder().build())
}
Some(LanguageModelToolChoice::None) => {
return Err(anyhow!("LanguageModelToolChoice::None is not supported"));
}
};
let tool_config: BedrockToolConfig = BedrockToolConfig::builder() let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
.set_tools(Some(tool_spec)) .set_tools(Some(tool_spec))
.tool_choice(BedrockToolChoice::Auto( .tool_choice(tool_choice)
BedrockAutoToolChoice::builder().build(),
))
.build()?; .build()?;
Ok(bedrock::Request { Ok(bedrock::Request {

View file

@ -14,8 +14,9 @@ use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID, LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
ZED_CLOUD_PROVIDER_ID,
}; };
use language_model::{ use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@ -686,6 +687,14 @@ impl LanguageModel for CloudLanguageModel {
} }
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("zed.dev/{}", self.model.id()) format!("zed.dev/{}", self.model.id())
} }

View file

@ -20,8 +20,8 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role, LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolUse, MessageContent,
StopReason, RateLimiter, Role, StopReason,
}; };
use settings::SettingsStore; use settings::SettingsStore;
use std::time::Duration; use std::time::Duration;
@ -197,6 +197,14 @@ impl LanguageModel for CopilotChatLanguageModel {
} }
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => self.supports_tools(),
}
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("copilot_chat/{}", self.model.id()) format!("copilot_chat/{}", self.model.id())
} }
@ -541,7 +549,11 @@ impl CopilotChatLanguageModel {
model, model,
messages, messages,
tools, tools,
tool_choice: None, tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
}),
}) })
} }
} }

View file

@ -11,7 +11,8 @@ use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -282,6 +283,10 @@ impl LanguageModel for DeepSeekLanguageModel {
false false
} }
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("deepseek/{}", self.model.id()) format!("deepseek/{}", self.model.id())
} }

View file

@ -12,8 +12,8 @@ use gpui::{
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
StopReason, LanguageModelToolUseId, MessageContent, StopReason,
}; };
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@ -313,6 +313,14 @@ impl LanguageModel for GoogleLanguageModel {
true true
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchemaSubset LanguageModelToolSchemaFormat::JsonSchemaSubset
} }
@ -484,7 +492,16 @@ pub fn into_google(
.collect(), .collect(),
}] }]
}), }),
tool_config: None, tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
function_calling_config: google_ai::FunctionCallingConfig {
mode: match choice {
LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
},
allowed_function_names: None,
},
}),
} }
} }

View file

@ -4,6 +4,7 @@ use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolChoice,
}; };
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@ -284,6 +285,10 @@ impl LanguageModel for LmStudioLanguageModel {
false false
} }
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("lmstudio/{}", self.model.id()) format!("lmstudio/{}", self.model.id())
} }

View file

@ -10,7 +10,8 @@ use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
}; };
use futures::stream::BoxStream; use futures::stream::BoxStream;
@ -302,6 +303,10 @@ impl LanguageModel for MistralLanguageModel {
false false
} }
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("mistral/{}", self.model.id()) format!("mistral/{}", self.model.id())
} }

View file

@ -5,7 +5,8 @@ use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient; use http_client::HttpClient;
use language_model::{ use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
LanguageModelToolUseId, StopReason,
}; };
use language_model::{ use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@ -324,6 +325,14 @@ impl LanguageModel for OllamaLanguageModel {
self.model.supports_tools.unwrap_or(false) self.model.supports_tools.unwrap_or(false)
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto => false,
LanguageModelToolChoice::Any => false,
LanguageModelToolChoice::None => false,
}
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id()) format!("ollama/{}", self.model.id())
} }

View file

@ -12,7 +12,7 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, LanguageModelToolChoice, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
}; };
use open_ai::{Model, ResponseStreamEvent, stream_completion}; use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -295,6 +295,14 @@ impl LanguageModel for OpenAiLanguageModel {
true true
} }
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto => true,
LanguageModelToolChoice::Any => true,
LanguageModelToolChoice::None => true,
}
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("openai/{}", self.model.id()) format!("openai/{}", self.model.id())
} }
@ -417,7 +425,11 @@ pub fn into_open_ai(
}, },
}) })
.collect(), .collect(),
tool_choice: None, tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
}),
} }
} }

View file

@ -929,6 +929,7 @@ impl RulesLibrary {
cache: false, cache: false,
}], }],
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature: None, temperature: None,
}, },

View file

@ -566,6 +566,7 @@ impl SummaryIndex {
cache: use_cache, cache: use_cache,
}], }],
tools: Vec::new(), tools: Vec::new(),
tool_choice: None,
stop: Vec::new(), stop: Vec::new(),
temperature: None, temperature: None,
}; };