Reuse conversation cache when streaming edits (#30245)
Release Notes:

- Improved latency when the agent applies edits.

parent 032022e37b
commit 9f6809a28d
50 changed files with 847 additions and 21557 deletions
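The change in a nutshell: tools now receive the entire completion request (`Arc<LanguageModelRequest>`) instead of just a slice of `LanguageModelRequestMessage`s, and the edit agent builds its streaming-edit request on top of that conversation rather than from scratch, so the provider-side prompt cache established for the thread keeps being reused. A minimal, self-contained sketch of the idea with simplified stand-in types (this is not the commit's code; the real logic lives in `EditAgent::request` below):

// Sketch only: reuse an existing conversation as the base of a follow-up request
// so a provider-side prompt-cache prefix stays valid, instead of rebuilding the
// message list from scratch.
#[derive(Clone)]
struct Message {
    role: &'static str,
    text: String,
    cache: bool, // marks the end of the cacheable prefix
}

#[derive(Clone)]
struct Request {
    messages: Vec<Message>,
}

fn edit_request(conversation: &Request, prompt: &str) -> Request {
    let mut request = conversation.clone();
    // Append the edit prompt as a new user message; earlier messages (and their
    // cache markers) are kept intact, so the cached prefix is still a prefix.
    request.messages.push(Message {
        role: "user",
        text: prompt.to_string(),
        cache: false,
    });
    request
}

fn main() {
    let conversation = Request {
        messages: vec![Message { role: "user", text: "Refactor foo".into(), cache: true }],
    };
    let request = edit_request(&conversation, "Apply the edit described above");
    assert_eq!(request.messages.len(), 2);
}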
@@ -1411,6 +1411,7 @@ impl ActiveThread {
 mode: None,
 messages: vec![request_message],
 tools: vec![],
+tool_choice: None,
 stop: vec![],
 temperature: AssistantSettings::temperature_for_model(
 &configured_model.model,

@@ -3256,7 +3257,7 @@ impl ActiveThread {
 c.tool_use_id.clone(),
 c.ui_text.clone(),
 c.input.clone(),
-&c.messages,
+c.request.clone(),
 c.tool.clone(),
 configured.model,
 Some(window.window_handle()),

@@ -466,6 +466,7 @@ impl CodegenAlternative {
 prompt_id: None,
 mode: None,
 tools: Vec::new(),
+tool_choice: None,
 stop: Vec::new(),
 temperature,
 messages: vec![request_message],

@@ -4,7 +4,7 @@ use anyhow::{Result, anyhow, bail};
 use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
 use context_server::{ContextServerId, types};
 use gpui::{AnyWindowHandle, App, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::{Project, context_server_store::ContextServerStore};
 use ui::IconName;

@@ -72,7 +72,7 @@ impl Tool for ContextServerTool {
 fn run(
 self: Arc<Self>,
 input: serde_json::Value,
-_messages: &[LanguageModelRequestMessage],
+_request: Arc<LanguageModelRequest>,
 _project: Entity<Project>,
 _action_log: Entity<ActionLog>,
 _model: Arc<dyn LanguageModel>,

@@ -1245,6 +1245,7 @@ impl MessageEditor {
 mode: None,
 messages: vec![request_message],
 tools: vec![],
+tool_choice: None,
 stop: vec![],
 temperature: AssistantSettings::temperature_for_model(&model.model, cx),
 };

@@ -293,6 +293,7 @@ impl TerminalInlineAssistant {
 mode: None,
 messages: vec![request_message],
 tools: Vec::new(),
+tool_choice: None,
 stop: Vec::new(),
 temperature,
 }

@@ -1183,6 +1183,7 @@ impl Thread {
 mode: None,
 messages: vec![],
 tools: Vec::new(),
+tool_choice: None,
 stop: Vec::new(),
 temperature: AssistantSettings::temperature_for_model(&model, cx),
 };

@@ -1227,6 +1228,7 @@ impl Thread {
 }));
 }

+let mut message_ix_to_cache = None;
 for message in &self.messages {
 let mut request_message = LanguageModelRequestMessage {
 role: message.role,
@@ -1263,19 +1265,57 @@ impl Thread {
     };
     }

-    self.tool_use
-        .attach_tool_uses(message.id, &mut request_message);
+    let mut cache_message = true;
+    let mut tool_results_message = LanguageModelRequestMessage {
+        role: Role::User,
+        content: Vec::new(),
+        cache: false,
+    };
+    for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
+        if let Some(tool_result) = tool_result {
+            request_message
+                .content
+                .push(MessageContent::ToolUse(tool_use.clone()));
+            tool_results_message
+                .content
+                .push(MessageContent::ToolResult(LanguageModelToolResult {
+                    tool_use_id: tool_use.id.clone(),
+                    tool_name: tool_result.tool_name.clone(),
+                    is_error: tool_result.is_error,
+                    content: if tool_result.content.is_empty() {
+                        // Surprisingly, the API fails if we return an empty string here.
+                        // It thinks we are sending a tool use without a tool result.
+                        "<Tool returned an empty string>".into()
+                    } else {
+                        tool_result.content.clone()
+                    },
+                    output: None,
+                }));
+        } else {
+            cache_message = false;
+            log::debug!(
+                "skipped tool use {:?} because it is still pending",
+                tool_use
+            );
+        }
+    }

+    if cache_message {
+        message_ix_to_cache = Some(request.messages.len());
+    }
     request.messages.push(request_message);

-    if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
+    if !tool_results_message.content.is_empty() {
+        if cache_message {
+            message_ix_to_cache = Some(request.messages.len());
+        }
         request.messages.push(tool_results_message);
     }
 }

 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
-if let Some(last) = request.messages.last_mut() {
-    last.cache = true;
+if let Some(message_ix_to_cache) = message_ix_to_cache {
+    request.messages[message_ix_to_cache].cache = true;
 }

 self.attached_tracked_files_state(&mut request.messages, cx);
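The hunk above only records a cache point for a message when every tool use attached to it already has a result; a message with a pending tool call will look different on the next request, so caching it would be wasted. A self-contained sketch of that selection rule, using a stand-in type rather than the commit's own structs:

// Stand-in for a request message: `stable` is true when all of its tool uses
// already have results, so the message will be identical on the next request.
struct Msg {
    stable: bool,
}

// Mirrors the `message_ix_to_cache` bookkeeping: remember the index of the most
// recent stable message; that is where the prompt-cache marker gets set.
fn cache_index(messages: &[Msg]) -> Option<usize> {
    let mut index = None;
    for (ix, message) in messages.iter().enumerate() {
        if message.stable {
            index = Some(ix);
        }
    }
    index
}

fn main() {
    let msgs = [Msg { stable: true }, Msg { stable: false }, Msg { stable: true }];
    assert_eq!(cache_index(&msgs), Some(2));
}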
@@ -1302,6 +1342,7 @@ impl Thread {
 mode: None,
 messages: vec![],
 tools: Vec::new(),
+tool_choice: None,
 stop: Vec::new(),
 temperature: AssistantSettings::temperature_for_model(model, cx),
 };

@@ -1918,8 +1959,7 @@ impl Thread {
 model: Arc<dyn LanguageModel>,
 ) -> Vec<PendingToolUse> {
 self.auto_capture_telemetry(cx);
-let request = self.to_completion_request(model.clone(), cx);
-let messages = Arc::new(request.messages);
+let request = Arc::new(self.to_completion_request(model.clone(), cx));
 let pending_tool_uses = self
 .tool_use
 .pending_tool_uses()

@@ -1937,7 +1977,7 @@ impl Thread {
 tool_use.id.clone(),
 tool_use.ui_text.clone(),
 tool_use.input.clone(),
-messages.clone(),
+request.clone(),
 tool,
 );
 cx.emit(ThreadEvent::ToolConfirmationNeeded);

@@ -1946,7 +1986,7 @@ impl Thread {
 tool_use.id.clone(),
 tool_use.ui_text.clone(),
 tool_use.input.clone(),
-&messages,
+request.clone(),
 tool,
 model.clone(),
 window,

@@ -2041,21 +2081,14 @@ impl Thread {
 tool_use_id: LanguageModelToolUseId,
 ui_text: impl Into<SharedString>,
 input: serde_json::Value,
-messages: &[LanguageModelRequestMessage],
+request: Arc<LanguageModelRequest>,
 tool: Arc<dyn Tool>,
 model: Arc<dyn LanguageModel>,
 window: Option<AnyWindowHandle>,
 cx: &mut Context<Thread>,
 ) {
-let task = self.spawn_tool_use(
-    tool_use_id.clone(),
-    messages,
-    input,
-    tool,
-    model,
-    window,
-    cx,
-);
+let task =
+    self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
 self.tool_use
     .run_pending_tool(tool_use_id, ui_text.into(), task);
 }

@@ -2063,7 +2096,7 @@ impl Thread {
 fn spawn_tool_use(
 &mut self,
 tool_use_id: LanguageModelToolUseId,
-messages: &[LanguageModelRequestMessage],
+request: Arc<LanguageModelRequest>,
 input: serde_json::Value,
 tool: Arc<dyn Tool>,
 model: Arc<dyn LanguageModel>,

@@ -2077,7 +2110,7 @@ impl Thread {
 } else {
 tool.run(
 input,
-messages,
+request,
 self.project.clone(),
 self.action_log.clone(),
 model,

@@ -7,8 +7,8 @@ use futures::FutureExt as _;
 use futures::future::Shared;
 use gpui::{App, Entity, SharedString, Task};
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
+    LanguageModelToolUse, LanguageModelToolUseId, Role,
 };
 use project::Project;
 use ui::{IconName, Window};

@@ -354,7 +354,7 @@ impl ToolUseState {
 tool_use_id: LanguageModelToolUseId,
 ui_text: impl Into<Arc<str>>,
 input: serde_json::Value,
-messages: Arc<Vec<LanguageModelRequestMessage>>,
+request: Arc<LanguageModelRequest>,
 tool: Arc<dyn Tool>,
 ) {
 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {

@@ -363,7 +363,7 @@ impl ToolUseState {
 let confirmation = Confirmation {
 tool_use_id,
 input,
-messages,
+request,
 tool,
 ui_text,
 };

@@ -449,72 +449,20 @@ impl ToolUseState {
 }
 }

-pub fn attach_tool_uses(
-    &self,
-    message_id: MessageId,
-    request_message: &mut LanguageModelRequestMessage,
-) {
-    if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
-        for tool_use in tool_uses {
-            if self.tool_results.contains_key(&tool_use.id) {
-                // Do not send tool uses until they are completed
-                request_message
-                    .content
-                    .push(MessageContent::ToolUse(tool_use.clone()));
-            } else {
-                log::debug!(
-                    "skipped tool use {:?} because it is still pending",
-                    tool_use
-                );
-            }
-        }
-    }
-}
-
 pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
     self.tool_uses_by_assistant_message
         .contains_key(&assistant_message_id)
 }

-pub fn tool_results_message(
+pub fn tool_results(
     &self,
     assistant_message_id: MessageId,
-) -> Option<LanguageModelRequestMessage> {
-    let tool_uses = self
-        .tool_uses_by_assistant_message
-        .get(&assistant_message_id)?;
-
-    if tool_uses.is_empty() {
-        return None;
-    }
-
-    let mut request_message = LanguageModelRequestMessage {
-        role: Role::User,
-        content: vec![],
-        cache: false,
-    };
-
-    for tool_use in tool_uses {
-        if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
-            request_message
-                .content
-                .push(MessageContent::ToolResult(LanguageModelToolResult {
-                    tool_use_id: tool_use.id.clone(),
-                    tool_name: tool_result.tool_name.clone(),
-                    is_error: tool_result.is_error,
-                    content: if tool_result.content.is_empty() {
-                        // Surprisingly, the API fails if we return an empty string here.
-                        // It thinks we are sending a tool use without a tool result.
-                        "<Tool returned an empty string>".into()
-                    } else {
-                        tool_result.content.clone()
-                    },
-                    output: None,
-                }));
-        }
-    }
-
-    Some(request_message)
+) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
+    self.tool_uses_by_assistant_message
+        .get(&assistant_message_id)
+        .into_iter()
+        .flatten()
+        .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
 }
 }
@@ -535,7 +483,7 @@ pub struct Confirmation {
 pub tool_use_id: LanguageModelToolUseId,
 pub input: serde_json::Value,
 pub ui_text: Arc<str>,
-pub messages: Arc<Vec<LanguageModelRequestMessage>>,
+pub request: Arc<LanguageModelRequest>,
 pub tool: Arc<dyn Tool>,
 }

@@ -578,6 +578,7 @@ pub enum ToolChoice {
 Auto,
 Any,
 Tool { name: String },
+None,
 }

 #[derive(Debug, Serialize, Deserialize)]

@@ -2585,6 +2585,7 @@ impl AssistantContext {
 mode: None,
 messages: Vec::new(),
 tools: Vec::new(),
+tool_choice: None,
 stop: Vec::new(),
 temperature: model
 .and_then(|model| AssistantSettings::temperature_for_model(model, cx)),

@@ -19,7 +19,7 @@ use gpui::Window;
 use gpui::{App, Entity, SharedString, Task, WeakEntity};
 use icons::IconName;
 use language_model::LanguageModel;
-use language_model::LanguageModelRequestMessage;
+use language_model::LanguageModelRequest;
 use language_model::LanguageModelToolSchemaFormat;
 use project::Project;
 use workspace::Workspace;

@@ -206,7 +206,7 @@ pub trait Tool: 'static + Send + Sync {
 fn run(
 self: Arc<Self>,
 input: serde_json::Value,
-messages: &[LanguageModelRequestMessage],
+request: Arc<LanguageModelRequest>,
 project: Entity<Project>,
 action_log: Entity<ActionLog>,
 model: Arc<dyn LanguageModel>,

@@ -3,8 +3,8 @@ use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::AnyWindowHandle;
 use gpui::{App, AppContext, Entity, Task};
-use language_model::LanguageModelToolSchemaFormat;
-use language_model::{LanguageModel, LanguageModelRequestMessage};
+use language_model::LanguageModel;
+use language_model::{LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -74,7 +74,7 @@ impl Tool for CopyPathTool {
 fn run(
 self: Arc<Self>,
 input: serde_json::Value,
-_messages: &[LanguageModelRequestMessage],
+_request: Arc<LanguageModelRequest>,
 project: Entity<Project>,
 _action_log: Entity<ActionLog>,
 _model: Arc<dyn LanguageModel>,

@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::AnyWindowHandle;
 use gpui::{App, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -61,7 +61,7 @@ impl Tool for CreateDirectoryTool {
 fn run(
 self: Arc<Self>,
 input: serde_json::Value,
-_messages: &[LanguageModelRequestMessage],
+_request: Arc<LanguageModelRequest>,
 project: Entity<Project>,
 _action_log: Entity<ActionLog>,
 _model: Arc<dyn LanguageModel>,

@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use futures::{SinkExt, StreamExt, channel::mpsc};
 use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::{Project, ProjectPath};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -59,7 +59,7 @@ impl Tool for DeletePathTool {
 fn run(
 self: Arc<Self>,
 input: serde_json::Value,
-_messages: &[LanguageModelRequestMessage],
+_request: Arc<LanguageModelRequest>,
 project: Entity<Project>,
 action_log: Entity<ActionLog>,
 _model: Arc<dyn LanguageModel>,

@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::{AnyWindowHandle, App, Entity, Task};
 use language::{DiagnosticSeverity, OffsetRangeExt};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -79,7 +79,7 @@ impl Tool for DiagnosticsTool {
 fn run(
 self: Arc<Self>,
 input: serde_json::Value,
-_messages: &[LanguageModelRequestMessage],
+_request: Arc<LanguageModelRequest>,
 project: Entity<Project>,
 action_log: Entity<ActionLog>,
 _model: Arc<dyn LanguageModel>,

@@ -17,7 +17,7 @@ use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
 use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
-    MessageContent, Role,
+    LanguageModelToolChoice, MessageContent, Role,
 };
 use project::{AgentLocation, Project};
 use serde::Serialize;

@@ -83,7 +83,7 @@ impl EditAgent {
 &self,
 buffer: Entity<Buffer>,
 edit_description: String,
-previous_messages: Vec<LanguageModelRequestMessage>,
+conversation: &LanguageModelRequest,
 cx: &mut AsyncApp,
 ) -> (
 Task<Result<EditAgentOutput>>,

@@ -91,6 +91,7 @@ impl EditAgent {
 ) {
 let this = self.clone();
 let (events_tx, events_rx) = mpsc::unbounded();
+let conversation = conversation.clone();
 let output = cx.spawn(async move |cx| {
 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;

@@ -99,7 +100,7 @@ impl EditAgent {
 edit_description,
 }
 .render(&this.templates)?;
-let new_chunks = this.request(previous_messages, prompt, cx).await?;
+let new_chunks = this.request(conversation, prompt, cx).await?;

 let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
 while let Some(event) = inner_events.next().await {

@@ -194,7 +195,7 @@ impl EditAgent {
 &self,
 buffer: Entity<Buffer>,
 edit_description: String,
-previous_messages: Vec<LanguageModelRequestMessage>,
+conversation: &LanguageModelRequest,
 cx: &mut AsyncApp,
 ) -> (
 Task<Result<EditAgentOutput>>,

@@ -214,6 +215,7 @@ impl EditAgent {

 let this = self.clone();
 let (events_tx, events_rx) = mpsc::unbounded();
+let conversation = conversation.clone();
 let output = cx.spawn(async move |cx| {
 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;

@@ -222,7 +224,7 @@ impl EditAgent {
 edit_description,
 }
 .render(&this.templates)?;
-let edit_chunks = this.request(previous_messages, prompt, cx).await?;
+let edit_chunks = this.request(conversation, prompt, cx).await?;

 let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
 while let Some(event) = inner_events.next().await {
@@ -512,32 +514,67 @@ impl EditAgent {

 async fn request(
     &self,
-    mut messages: Vec<LanguageModelRequestMessage>,
+    mut conversation: LanguageModelRequest,
     prompt: String,
     cx: &mut AsyncApp,
 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
-    let mut message_content = Vec::new();
-    if let Some(last_message) = messages.last_mut() {
+    let mut messages_iter = conversation.messages.iter_mut();
+    if let Some(last_message) = messages_iter.next_back() {
         if last_message.role == Role::Assistant {
+            let old_content_len = last_message.content.len();
             last_message
                 .content
                 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
+            let new_content_len = last_message.content.len();
+
+            // We just removed pending tool uses from the content of the
+            // last message, so it doesn't make sense to cache it anymore
+            // (e.g., the message will look very different on the next
+            // request). Thus, we move the flag to the message prior to it,
+            // as it will still be a valid prefix of the conversation.
+            if old_content_len != new_content_len && last_message.cache {
+                if let Some(prev_message) = messages_iter.next_back() {
+                    last_message.cache = false;
+                    prev_message.cache = true;
+                }
+            }
+
             if last_message.content.is_empty() {
-                messages.pop();
+                conversation.messages.pop();
             }
         }
     }
-    message_content.push(MessageContent::Text(prompt));
-    messages.push(LanguageModelRequestMessage {
+    conversation.messages.push(LanguageModelRequestMessage {
         role: Role::User,
-        content: message_content,
+        content: vec![MessageContent::Text(prompt)],
         cache: false,
     });

+    // Include tools in the request so that we can take advantage of
+    // caching when ToolChoice::None is supported.
+    let mut tool_choice = None;
+    let mut tools = Vec::new();
+    if !conversation.tools.is_empty()
+        && self
+            .model
+            .supports_tool_choice(LanguageModelToolChoice::None)
+    {
+        tool_choice = Some(LanguageModelToolChoice::None);
+        tools = conversation.tools.clone();
+    }
+
     let request = LanguageModelRequest {
-        messages,
-        ..Default::default()
+        thread_id: conversation.thread_id,
+        prompt_id: conversation.prompt_id,
+        mode: conversation.mode,
+        messages: conversation.messages,
+        tool_choice,
+        tools,
+        stop: Vec::new(),
+        temperature: None,
     };

     Ok(self.model.stream_completion_text(request, cx).await?.stream)
 }
@@ -2,14 +2,16 @@ use super::*;
 use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
 use Role::*;
 use anyhow::anyhow;
+use assistant_tool::ToolRegistry;
 use client::{Client, UserStore};
 use collections::HashMap;
 use fs::FakeFs;
 use futures::{FutureExt, future::LocalBoxFuture};
 use gpui::{AppContext, TestAppContext};
-use indoc::indoc;
+use indoc::{formatdoc, indoc};
 use language_model::{
-    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
+    LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
+    LanguageModelToolUseId,
 };
 use project::Project;
 use rand::prelude::*;

@@ -37,7 +39,7 @@ fn eval_extract_handle_command_output() {
 conversation: vec![
 message(
 User,
-[text(indoc! {"
+[text(formatdoc! {"
 Read the `{input_file_path}` file and extract a method in
 the final stanza of `run_git_blame` to deal with command failures,
 call it `handle_command_output` and take the std::process::Output as the only parameter.

@@ -96,7 +98,7 @@ fn eval_delete_run_git_blame() {
 conversation: vec![
 message(
 User,
-[text(indoc! {"
+[text(formatdoc! {"
 Read the `{input_file_path}` file and delete `run_git_blame`. Just that
 one function, not its usages.
 "})],

@@ -138,6 +140,61 @@ fn eval_delete_run_git_blame() {
 );
 }

+#[test]
+#[cfg_attr(not(feature = "eval"), ignore)]
+fn eval_translate_doc_comments() {
+    let input_file_path = "root/canvas.rs";
+    let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
+    let edit_description = "Translate all doc comments to Italian";
+    eval(
+        200,
+        1.,
+        EvalInput {
+            conversation: vec![
+                message(
+                    User,
+                    [text(formatdoc! {"
+                        Read the {input_file_path} file and edit it (without overwriting it),
+                        translating all the doc comments to italian.
+                    "})],
+                ),
+                message(
+                    Assistant,
+                    [tool_use(
+                        "tool_1",
+                        "read_file",
+                        ReadFileToolInput {
+                            path: input_file_path.into(),
+                            start_line: None,
+                            end_line: None,
+                        },
+                    )],
+                ),
+                message(
+                    User,
+                    [tool_result("tool_1", "read_file", input_file_content)],
+                ),
+                message(
+                    Assistant,
+                    [tool_use(
+                        "tool_2",
+                        "edit_file",
+                        EditFileToolInput {
+                            display_description: edit_description.into(),
+                            path: input_file_path.into(),
+                            create_or_overwrite: false,
+                        },
+                    )],
+                ),
+            ],
+            input_path: input_file_path.into(),
+            input_content: Some(input_file_content.into()),
+            edit_description: edit_description.into(),
+            assertion: EvalAssertion::judge_diff("Doc comments were translated to Italian"),
+        },
+    );
+}
+
 #[test]
 #[cfg_attr(not(feature = "eval"), ignore)]
 fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {

@@ -152,7 +209,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
 conversation: vec![
 message(
 User,
-[text(indoc! {"
+[text(formatdoc! {"
 Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
 Use `ureq` to download the SDK for the current platform and architecture.
 Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.

@@ -160,7 +217,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
 that's inside of the archive.
 Don't re-download the SDK if that executable already exists.

-Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
+Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}

 Here are the available wasi-sdk assets:
 - wasi-sdk-25.0-x86_64-macos.tar.gz

@@ -261,11 +318,10 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
 fn eval_disable_cursor_blinking() {
 let input_file_path = "root/editor.rs";
 let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
-let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
 let edit_description = "Comment out the call to `BlinkManager::enable`";
 eval(
 200,
-0.6, // TODO: make this eval better
+0.95,
 EvalInput {
 conversation: vec![
 message(User, [text("Let's research how to cursor blinking works.")]),

@@ -324,7 +380,11 @@ fn eval_disable_cursor_blinking() {
 input_path: input_file_path.into(),
 input_content: Some(input_file_content.into()),
 edit_description: edit_description.into(),
-assertion: EvalAssertion::assert_eq(output_file_content),
+assertion: EvalAssertion::judge_diff(indoc! {"
+    - Calls to BlinkManager in `observe_window_activation` were commented out
+    - The call to `blink_manager.enable` above the call to show_cursor_names was commented out
+    - All the edits have valid indentation
+"}),
 },
 );
 }

@@ -1031,7 +1091,8 @@ impl EvalAssertion {

 fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
 let mut evaluated_count = 0;
-report_progress(evaluated_count, iterations);
+let mut failed_count = 0;
+report_progress(evaluated_count, failed_count, iterations);

 let (tx, rx) = mpsc::channel();

@@ -1048,7 +1109,6 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
 }
 drop(tx);

-let mut failed_count = 0;
 let mut failed_evals = HashMap::default();
 let mut errored_evals = HashMap::default();
 let mut eval_outputs = Vec::new();

@@ -1073,7 +1133,7 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
 }

 evaluated_count += 1;
-report_progress(evaluated_count, iterations);
+report_progress(evaluated_count, failed_count, iterations);
 }

 let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;

@@ -1144,8 +1204,19 @@ impl Display for EvalOutput {
 }
 }

-fn report_progress(evaluated_count: usize, iterations: usize) {
-    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
+fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
+    let passed_count = evaluated_count - failed_count;
+    let passed_ratio = if evaluated_count == 0 {
+        0.0
+    } else {
+        passed_count as f64 / evaluated_count as f64
+    };
+    print!(
+        "\r\x1b[KEvaluated {}/{} ({:.2}%)",
+        evaluated_count,
+        iterations,
+        passed_ratio * 100.0
+    );
     std::io::stdout().flush().unwrap();
 }

@@ -1158,25 +1229,30 @@ struct EditAgentTest {
 impl EditAgentTest {
 async fn new(cx: &mut TestAppContext) -> Self {
 cx.executor().allow_parking();
-cx.update(settings::init);
-cx.update(Project::init_settings);
-cx.update(language::init);
-cx.update(gpui_tokio::init);
-cx.update(client::init_settings);
-
 let fs = FakeFs::new(cx.executor().clone());
+cx.update(|cx| {
+    settings::init(cx);
+    gpui_tokio::init(cx);
+    let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
+    cx.set_http_client(http_client);
+
+    client::init_settings(cx);
+    let client = Client::production(cx);
+    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+    settings::init(cx);
+    Project::init_settings(cx);
+    language::init(cx);
+    language_model::init(client.clone(), cx);
+    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+    crate::init(client.http_client(), cx);
+});
+
 fs.insert_tree("/root", json!({})).await;
 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 let (agent_model, judge_model) = cx
 .update(|cx| {
-let http_client = ReqwestClient::user_agent("agent tests").unwrap();
-cx.set_http_client(Arc::new(http_client));
-
-let client = Client::production(cx);
-let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-language_model::init(client.clone(), cx);
-language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
-
 cx.spawn(async move |cx| {
 let agent_model =
 Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;

@@ -1225,12 +1301,32 @@ impl EditAgentTest {
 .update(cx, |project, cx| project.open_buffer(path, cx))
 .await
 .unwrap();
+let conversation = LanguageModelRequest {
+    messages: eval.conversation,
+    tools: cx.update(|cx| {
+        ToolRegistry::default_global(cx)
+            .tools()
+            .into_iter()
+            .filter_map(|tool| {
+                let input_schema = tool
+                    .input_schema(self.agent.model.tool_input_format())
+                    .ok()?;
+                Some(LanguageModelRequestTool {
+                    name: tool.name(),
+                    description: tool.description(),
+                    input_schema,
+                })
+            })
+            .collect()
+    }),
+    ..Default::default()
+};
 let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
 buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
 let (edit_output, _) = self.agent.edit(
 buffer.clone(),
 eval.edit_description,
-eval.conversation,
+&conversation,
 &mut cx.to_async(),
 );
 edit_output.await?

@@ -1238,7 +1334,7 @@ impl EditAgentTest {
 let (edit_output, _) = self.agent.overwrite(
 buffer.clone(),
 eval.edit_description,
-eval.conversation,
+&conversation,
 &mut cx.to_async(),
 );
 edit_output.await?

File diff suppressed because it is too large
@@ -0,0 +1,339 @@
+// font-kit/src/canvas.rs
+//
+// Copyright © 2018 The Pathfinder Project Developers.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! An in-memory bitmap surface for glyph rasterization.
+
+use lazy_static::lazy_static;
+use pathfinder_geometry::rect::RectI;
+use pathfinder_geometry::vector::Vector2I;
+use std::cmp;
+use std::fmt;
+
+use crate::utils;
+
+lazy_static! {
+    static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
+        let mut lut = [[0; 8]; 256];
+        for byte in 0..0x100 {
+            let mut value = [0; 8];
+            for bit in 0..8 {
+                if (byte & (0x80 >> bit)) != 0 {
+                    value[bit] = 0xff;
+                }
+            }
+            lut[byte] = value
+        }
+        lut
+    };
+}
+
+/// An in-memory bitmap surface for glyph rasterization.
+pub struct Canvas {
+    /// The raw pixel data.
+    pub pixels: Vec<u8>,
+    /// The size of the buffer, in pixels.
+    pub size: Vector2I,
+    /// The number of *bytes* between successive rows.
+    pub stride: usize,
+    /// The image format of the canvas.
+    pub format: Format,
+}
+
+impl Canvas {
+    /// Creates a new blank canvas with the given pixel size and format.
+    ///
+    /// Stride is automatically calculated from width.
+    ///
+    /// The canvas is initialized with transparent black (all values 0).
+    #[inline]
+    pub fn new(size: Vector2I, format: Format) -> Canvas {
+        Canvas::with_stride(
+            size,
+            size.x() as usize * format.bytes_per_pixel() as usize,
+            format,
+        )
+    }
+
+    /// Creates a new blank canvas with the given pixel size, stride (number of bytes between
+    /// successive rows), and format.
+    ///
+    /// The canvas is initialized with transparent black (all values 0).
+    pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
+        Canvas {
+            pixels: vec![0; stride * size.y() as usize],
+            size,
+            stride,
+            format,
+        }
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
+        self.blit_from(
+            Vector2I::default(),
+            &src.pixels,
+            src.size,
+            src.stride,
+            src.format,
+        )
+    }
+
+    /// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
+    /// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
+    /// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
+    /// `src_stride` must be equal or larger than the actual data length.
+    #[allow(dead_code)]
+    pub(crate) fn blit_from(
+        &mut self,
+        dst_point: Vector2I,
+        src_bytes: &[u8],
+        src_size: Vector2I,
+        src_stride: usize,
+        src_format: Format,
+    ) {
+        assert_eq!(
+            src_stride * src_size.y() as usize,
+            src_bytes.len(),
+            "Number of pixels in src_bytes does not match stride and size."
+        );
+        assert!(
+            src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
+            "src_stride must be >= than src_size.x()"
+        );
+
+        let dst_rect = RectI::new(dst_point, src_size);
+        let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
+        let dst_rect = match dst_rect {
+            Some(dst_rect) => dst_rect,
+            None => return,
+        };
+
+        match (self.format, src_format) {
+            (Format::A8, Format::A8)
+            | (Format::Rgb24, Format::Rgb24)
+            | (Format::Rgba32, Format::Rgba32) => {
+                self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
+            }
+            (Format::A8, Format::Rgb24) => {
+                self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
+            }
+            (Format::Rgb24, Format::A8) => {
+                self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
+            }
+            (Format::Rgb24, Format::Rgba32) => self
+                .blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
+            (Format::Rgba32, Format::Rgb24) => self
+                .blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
+            (Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
+        }
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn blit_from_bitmap_1bpp(
+        &mut self,
+        dst_point: Vector2I,
+        src_bytes: &[u8],
+        src_size: Vector2I,
+        src_stride: usize,
+    ) {
+        if self.format != Format::A8 {
+            unimplemented!()
+        }
+
+        let dst_rect = RectI::new(dst_point, src_size);
+        let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
+        let dst_rect = match dst_rect {
+            Some(dst_rect) => dst_rect,
+            None => return,
+        };
+
+        let size = dst_rect.size();
+
+        let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
+        let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
+        let src_row_stride = utils::div_round_up(size.x() as usize, 8);
+
+        for y in 0..size.y() {
+            let (dest_row_start, src_row_start) = (
+                (y + dst_rect.origin_y()) as usize * self.stride
+                    + dst_rect.origin_x() as usize * dest_bytes_per_pixel,
+                y as usize * src_stride,
+            );
+            let dest_row_end = dest_row_start + dest_row_stride;
+            let src_row_end = src_row_start + src_row_stride;
+            let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
+            let src_row_pixels = &src_bytes[src_row_start..src_row_end];
+            for x in 0..src_row_stride {
+                let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
+                let dest_start = x * 8;
+                let dest_end = cmp::min(dest_start + 8, dest_row_stride);
+                let src = &pattern[0..(dest_end - dest_start)];
+                dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
+            }
+        }
+    }
+
+    /// Blits to area `rect` using the data given in the buffer `src_bytes`.
+    /// `src_stride` must be specified in bytes.
+    /// The dimensions of `rect` must be in pixels.
+    fn blit_from_with<B: Blit>(
+        &mut self,
+        rect: RectI,
+        src_bytes: &[u8],
+        src_stride: usize,
+        src_format: Format,
+    ) {
+        let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
+        let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
+
+        for y in 0..rect.height() {
+            let (dest_row_start, src_row_start) = (
+                (y + rect.origin_y()) as usize * self.stride
+                    + rect.origin_x() as usize * dest_bytes_per_pixel,
+                y as usize * src_stride,
+            );
+            let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
+            let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
+            let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
+            let src_row_pixels = &src_bytes[src_row_start..src_row_end];
+            B::blit(dest_row_pixels, src_row_pixels)
+        }
+    }
+}
+
+impl fmt::Debug for Canvas {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("Canvas")
+            .field("pixels", &self.pixels.len()) // Do not dump a vector content.
+            .field("size", &self.size)
+            .field("stride", &self.stride)
+            .field("format", &self.format)
+            .finish()
+    }
+}
+
+/// The image format for the canvas.
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum Format {
+    /// Premultiplied R8G8B8A8, little-endian.
+    Rgba32,
+    /// R8G8B8, little-endian.
+    Rgb24,
+    /// A8.
+    A8,
+}
+
+impl Format {
+    /// Returns the number of bits per pixel that this image format corresponds to.
+    #[inline]
+    pub fn bits_per_pixel(self) -> u8 {
+        match self {
+            Format::Rgba32 => 32,
+            Format::Rgb24 => 24,
+            Format::A8 => 8,
+        }
+    }
+
+    /// Returns the number of color channels per pixel that this image format corresponds to.
+    #[inline]
+    pub fn components_per_pixel(self) -> u8 {
+        match self {
+            Format::Rgba32 => 4,
+            Format::Rgb24 => 3,
+            Format::A8 => 1,
+        }
+    }
+
+    /// Returns the number of bits per color channel that this image format contains.
+    #[inline]
+    pub fn bits_per_component(self) -> u8 {
+        self.bits_per_pixel() / self.components_per_pixel()
+    }
+
+    /// Returns the number of bytes per pixel that this image format corresponds to.
+    #[inline]
+    pub fn bytes_per_pixel(self) -> u8 {
+        self.bits_per_pixel() / 8
+    }
+}
+
+/// The antialiasing strategy that should be used when rasterizing glyphs.
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum RasterizationOptions {
+    /// "Black-and-white" rendering. Each pixel is either entirely on or off.
+    Bilevel,
+    /// Grayscale antialiasing. Only one channel is used.
+    GrayscaleAa,
+    /// Subpixel RGB antialiasing, for LCD screens.
+    SubpixelAa,
+}
+
+trait Blit {
+    fn blit(dest: &mut [u8], src: &[u8]);
+}
+
+struct BlitMemcpy;
+
+impl Blit for BlitMemcpy {
+    #[inline]
+    fn blit(dest: &mut [u8], src: &[u8]) {
+        dest.clone_from_slice(src)
+    }
+}
+
+struct BlitRgb24ToA8;
+
+impl Blit for BlitRgb24ToA8 {
+    #[inline]
+    fn blit(dest: &mut [u8], src: &[u8]) {
+        // TODO(pcwalton): SIMD.
+        for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
+            *dest = src[1]
+        }
+    }
+}
+
+struct BlitA8ToRgb24;
+
+impl Blit for BlitA8ToRgb24 {
+    #[inline]
+    fn blit(dest: &mut [u8], src: &[u8]) {
+        for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
+            dest[0] = *src;
+            dest[1] = *src;
+            dest[2] = *src;
+        }
+    }
+}
+
+struct BlitRgba32ToRgb24;
+
+impl Blit for BlitRgba32ToRgb24 {
+    #[inline]
+    fn blit(dest: &mut [u8], src: &[u8]) {
+        // TODO(pcwalton): SIMD.
+        for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
+            dest.copy_from_slice(&src[0..3])
+        }
+    }
+}
+
+struct BlitRgb24ToRgba32;
+
+impl Blit for BlitRgb24ToRgba32 {
+    fn blit(dest: &mut [u8], src: &[u8]) {
+        for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
+            dest[0] = src[0];
+            dest[1] = src[1];
+            dest[2] = src[2];
+            dest[3] = 255;
+        }
+    }
+}
@@ -19,7 +19,7 @@ use language::{
     Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
     language_settings::SoftWrap,
 };
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -146,7 +146,7 @@ impl Tool for EditFileTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        messages: &[LanguageModelRequestMessage],
+        request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         action_log: Entity<ActionLog>,
         model: Arc<dyn LanguageModel>,

@@ -177,7 +177,6 @@ impl Tool for EditFileTool {
         });

         let card_clone = card.clone();
-        let messages = messages.to_vec();
         let task = cx.spawn(async move |cx: &mut AsyncApp| {
             let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());

@@ -209,14 +208,14 @@ impl Tool for EditFileTool {
                 edit_agent.overwrite(
                     buffer.clone(),
                     input.display_description.clone(),
-                    messages,
+                    &request,
                     cx,
                 )
             } else {
                 edit_agent.edit(
                     buffer.clone(),
                     input.display_description.clone(),
-                    messages,
+                    &request,
                     cx,
                 )
             };

@@ -847,7 +846,15 @@ mod tests {
             })
             .unwrap();
         Arc::new(EditFileTool)
-            .run(input, &[], project.clone(), action_log, model, None, cx)
+            .run(
+                input,
+                Arc::default(),
+                project.clone(),
+                action_log,
+                model,
+                None,
+                cx,
+            )
             .output
         })
         .await;
@@ -9,7 +9,7 @@ use futures::AsyncReadExt as _;
 use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
 use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 use http_client::{AsyncBody, HttpClientWithUrl};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -142,7 +142,7 @@ impl Tool for FetchTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         _project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -7,7 +7,7 @@ use gpui::{
     AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 };
 use language;
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -73,7 +73,7 @@ impl Tool for FindPathTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -4,7 +4,7 @@ use assistant_tool::{ActionLog, Tool, ToolResult};
 use futures::StreamExt;
 use gpui::{AnyWindowHandle, App, Entity, Task};
 use language::{OffsetRangeExt, ParseStatus, Point};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::{
     Project,
     search::{SearchQuery, SearchResult},

@@ -96,7 +96,7 @@ impl Tool for GrepTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -746,7 +746,8 @@ mod tests {
         let tool = Arc::new(GrepTool);
         let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let task = cx.update(|cx| tool.run(input, &[], project, action_log, model, None, cx));
+        let task =
+            cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));

         match task.output.await {
             Ok(result) => {

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
 use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::{AnyWindowHandle, App, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -73,7 +73,7 @@ impl Tool for ListDirectoryTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
 use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -86,7 +86,7 @@ impl Tool for MovePathTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -5,7 +5,7 @@ use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use chrono::{Local, Utc};
 use gpui::{AnyWindowHandle, App, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -56,7 +56,7 @@ impl Tool for NowTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         _project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -49,7 +49,7 @@ impl Tool for OpenTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -7,7 +7,7 @@ use gpui::{AnyWindowHandle, App, Entity, Task};
 use indoc::formatdoc;
 use itertools::Itertools;
 use language::{Anchor, Point};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::{AgentLocation, Project};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -83,7 +83,7 @@ impl Tool for ReadFileTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -231,7 +231,15 @@ mod test {
             "path": "root/nonexistent_file.txt"
         });
         Arc::new(ReadFileTool)
-            .run(input, &[], project.clone(), action_log, model, None, cx)
+            .run(
+                input,
+                Arc::default(),
+                project.clone(),
+                action_log,
+                model,
+                None,
+                cx,
+            )
             .output
         })
         .await;

@@ -262,7 +270,15 @@ mod test {
             "path": "root/small_file.txt"
        });
         Arc::new(ReadFileTool)
-            .run(input, &[], project.clone(), action_log, model, None, cx)
+            .run(
+                input,
+                Arc::default(),
+                project.clone(),
+                action_log,
+                model,
+                None,
+                cx,
+            )
             .output
         })
         .await;

@@ -295,7 +311,7 @@ mod test {
         Arc::new(ReadFileTool)
             .run(
                 input,
-                &[],
+                Arc::default(),
                 project.clone(),
                 action_log.clone(),
                 model.clone(),

@@ -325,7 +341,15 @@ mod test {
             "offset": 1
         });
         Arc::new(ReadFileTool)
-            .run(input, &[], project.clone(), action_log, model, None, cx)
+            .run(
+                input,
+                Arc::default(),
+                project.clone(),
+                action_log,
+                model,
+                None,
+                cx,
+            )
             .output
         })
         .await;

@@ -372,7 +396,15 @@ mod test {
             "end_line": 4
         });
         Arc::new(ReadFileTool)
-            .run(input, &[], project.clone(), action_log, model, None, cx)
+            .run(
+                input,
+                Arc::default(),
+                project.clone(),
+                action_log,
+                model,
+                None,
+                cx,
+            )
             .output
         })
         .await;

@@ -406,7 +438,7 @@ mod test {
         Arc::new(ReadFileTool)
             .run(
                 input,
-                &[],
+                Arc::default(),
                 project.clone(),
                 action_log.clone(),
                 model.clone(),

@@ -429,7 +461,7 @@ mod test {
         Arc::new(ReadFileTool)
             .run(
                 input,
-                &[],
+                Arc::default(),
                 project.clone(),
                 action_log.clone(),
                 model.clone(),

@@ -450,7 +482,15 @@ mod test {
             "end_line": 2
         });
         Arc::new(ReadFileTool)
-            .run(input, &[], project.clone(), action_log, model, None, cx)
+            .run(
+                input,
+                Arc::default(),
+                project.clone(),
+                action_log,
+                model,
+                None,
+                cx,
+            )
             .output
         })
         .await;
@@ -1,6 +1,4 @@
-You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
-
-You MUST respond with a series of edits to that one file in the following format:
+You MUST respond with a series of edits to a file, using the following format:
 
 ```
 <edits>

@@ -51,3 +49,5 @@ Rules for editing:
 <edit_description>
 {{edit_description}}
 </edit_description>
+
+Tool calls have been disabled. You MUST start your response with <edits>.
@@ -4,7 +4,7 @@ use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
 use futures::{FutureExt as _, future::Shared};
 use gpui::{AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, WeakEntity, Window};
 use language::LineEnding;
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use portable_pty::{CommandBuilder, PtySize, native_pty_system};
 use project::{Project, terminals::TerminalKind};
 use schemars::JsonSchema;

@@ -107,7 +107,7 @@ impl Tool for TerminalTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -656,7 +656,7 @@ mod tests {
         TerminalTool::run(
             Arc::new(TerminalTool::new(cx)),
             serde_json::to_value(input).unwrap(),
-            &[],
+            Arc::default(),
             project.clone(),
             action_log.clone(),
             model,

@@ -691,7 +691,7 @@ mod tests {
         let headless_result = TerminalTool::run(
             Arc::new(TerminalTool::new(cx)),
             serde_json::to_value(input).unwrap(),
-            &[],
+            Arc::default(),
             project.clone(),
             action_log.clone(),
             model.clone(),

@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
 use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool, ToolResult};
 use gpui::{AnyWindowHandle, App, Entity, Task};
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -47,7 +47,7 @@ impl Tool for ThinkingTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         _project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,

@@ -8,7 +8,7 @@ use futures::{Future, FutureExt, TryFutureExt};
 use gpui::{
     AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 };
-use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -53,7 +53,7 @@ impl Tool for WebSearchTool {
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
-        _messages: &[LanguageModelRequestMessage],
+        _request: Arc<LanguageModelRequest>,
         _project: Entity<Project>,
         _action_log: Entity<ActionLog>,
         _model: Arc<dyn LanguageModel>,
@@ -7,9 +7,10 @@ use anyhow::{Error, Result, anyhow};
 use aws_sdk_bedrockruntime as bedrock;
 pub use aws_sdk_bedrockruntime as bedrock_client;
 pub use aws_sdk_bedrockruntime::types::{
-    AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
-    Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
-    ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
+    AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
+    ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
+    ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
+    ToolSpecification as BedrockToolSpec,
 };
 pub use aws_smithy_types::Blob as BedrockBlob;
 use aws_smithy_types::{Document, Number as AwsNumber};

@@ -182,11 +182,11 @@ pub enum Tool {
 }

 #[derive(Serialize, Deserialize)]
-#[serde(tag = "type", rename_all = "lowercase")]
+#[serde(rename_all = "lowercase")]
 pub enum ToolChoice {
     Auto,
     Any,
-    Tool { name: String },
+    None,
 }

 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -578,6 +578,7 @@ impl ExampleInstance {
             }],
             temperature: None,
             tools: Vec::new(),
+            tool_choice: None,
             stop: Vec::new(),
         };

@@ -1774,6 +1774,7 @@ impl GitPanel {
                 cache: false,
             }],
             tools: Vec::new(),
+            tool_choice: None,
             stop: Vec::new(),
             temperature,
         };

@@ -2,6 +2,7 @@ use crate::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelToolChoice,
 };
 use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@@ -152,6 +153,10 @@ impl LanguageModel for FakeLanguageModel {
         false
     }

+    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        false
+    }
+
     fn telemetry_id(&self) -> String {
         "fake".to_string()
     }

@@ -246,6 +246,9 @@ pub trait LanguageModel: Send + Sync {
     /// Whether this model supports tools.
     fn supports_tools(&self) -> bool;

+    /// Whether this model supports choosing which tool to use.
+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
+
     /// Returns whether this model supports "max mode";
     fn supports_max_mode(&self) -> bool {
         if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {

@@ -203,6 +203,13 @@ pub struct LanguageModelRequestTool {
     pub input_schema: serde_json::Value,
 }

+#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
+pub enum LanguageModelToolChoice {
+    Auto,
+    Any,
+    None,
+}
+
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub struct LanguageModelRequest {
     pub thread_id: Option<String>,

@@ -210,6 +217,7 @@ pub struct LanguageModelRequest {
     pub mode: Option<CompletionMode>,
     pub messages: Vec<LanguageModelRequestMessage>,
     pub tools: Vec<LanguageModelRequestTool>,
+    pub tool_choice: Option<LanguageModelToolChoice>,
     pub stop: Vec<String>,
     pub temperature: Option<f32>,
 }
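Note: the hunks above introduce the new `LanguageModelToolChoice` enum, the `tool_choice` field on `LanguageModelRequest`, and the `supports_tool_choice` capability check on the `LanguageModel` trait. The following minimal Rust sketch is not part of this diff; the names mirror the types above but the snippet is self-contained and purely illustrative of how a caller might gate the new field on provider support:

    // Illustrative sketch only; mirrors the enum added in this PR.
    #[derive(Debug, Clone, Copy, PartialEq)]
    enum ToolChoice {
        Auto, // the model decides whether to call a tool
        Any,  // the model must call some tool
        None, // the model must not call a tool
    }

    #[derive(Default)]
    struct Request {
        tool_choice: Option<ToolChoice>,
    }

    // Stand-in for the trait method `LanguageModel::supports_tool_choice`.
    fn supports_tool_choice(choice: ToolChoice) -> bool {
        matches!(choice, ToolChoice::Auto | ToolChoice::Any | ToolChoice::None)
    }

    fn main() {
        let desired = ToolChoice::Any;
        let request = Request {
            // Only set the field when the target model reports support for it.
            tool_choice: supports_tool_choice(desired).then_some(desired),
        };
        println!("tool_choice set: {}", request.tool_choice.is_some());
    }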
@@ -15,7 +15,8 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, MessageContent,
+    RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;

@@ -420,6 +421,14 @@ impl LanguageModel for AnthropicModel {
         true
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto
+            | LanguageModelToolChoice::Any
+            | LanguageModelToolChoice::None => true,
+        }
+    }
+
     fn telemetry_id(&self) -> String {
         format!("anthropic/{}", self.model.id())
     }

@@ -620,7 +629,11 @@ pub fn into_anthropic(
                 input_schema: tool.input_schema,
             })
             .collect(),
-        tool_choice: None,
+        tool_choice: request.tool_choice.map(|choice| match choice {
+            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
+            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
+            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
+        }),
         metadata: None,
         stop_sequences: Vec::new(),
         temperature: request.temperature.or(Some(default_temperature)),
@@ -15,11 +15,11 @@ use bedrock::bedrock_client::types::{
     StopReason,
 };
 use bedrock::{
-    BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent, BedrockMessage,
-    BedrockModelMode, BedrockStreamingResponse, BedrockThinkingBlock, BedrockThinkingTextBlock,
-    BedrockTool, BedrockToolChoice, BedrockToolConfig, BedrockToolInputSchema,
-    BedrockToolResultBlock, BedrockToolResultContentBlock, BedrockToolResultStatus,
-    BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
+    BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
+    BedrockMessage, BedrockModelMode, BedrockStreamingResponse, BedrockThinkingBlock,
+    BedrockThinkingTextBlock, BedrockTool, BedrockToolChoice, BedrockToolConfig,
+    BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock,
+    BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
 };
 use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;

@@ -35,8 +35,8 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, TokenUsage,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
+    LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -520,6 +520,15 @@ impl LanguageModel for BedrockModel {
         self.model.supports_tool_use()
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
+                self.model.supports_tool_use()
+            }
+            LanguageModelToolChoice::None => false,
+        }
+    }
+
     fn telemetry_id(&self) -> String {
         format!("bedrock/{}", self.model.id())
     }

@@ -719,11 +728,20 @@ pub fn into_bedrock(
         })
         .collect();

+    let tool_choice = match request.tool_choice {
+        Some(LanguageModelToolChoice::Auto) | None => {
+            BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
+        }
+        Some(LanguageModelToolChoice::Any) => {
+            BedrockToolChoice::Any(BedrockAnyToolChoice::builder().build())
+        }
+        Some(LanguageModelToolChoice::None) => {
+            return Err(anyhow!("LanguageModelToolChoice::None is not supported"));
+        }
+    };
     let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
         .set_tools(Some(tool_spec))
-        .tool_choice(BedrockToolChoice::Auto(
-            BedrockAutoToolChoice::builder().build(),
-        ))
+        .tool_choice(tool_choice)
         .build()?;

     Ok(bedrock::Request {
@@ -14,8 +14,9 @@ use language_model::{
     AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+    ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,

@@ -686,6 +687,14 @@ impl LanguageModel for CloudLanguageModel {
         }
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto
+            | LanguageModelToolChoice::Any
+            | LanguageModelToolChoice::None => true,
+        }
+    }
+
     fn telemetry_id(&self) -> String {
         format!("zed.dev/{}", self.model.id())
     }

@@ -20,8 +20,8 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use settings::SettingsStore;
 use std::time::Duration;

@@ -197,6 +197,14 @@ impl LanguageModel for CopilotChatLanguageModel {
         }
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto
+            | LanguageModelToolChoice::Any
+            | LanguageModelToolChoice::None => self.supports_tools(),
+        }
+    }
+
     fn telemetry_id(&self) -> String {
         format!("copilot_chat/{}", self.model.id())
     }

@@ -541,7 +549,11 @@ impl CopilotChatLanguageModel {
             model,
             messages,
             tools,
-            tool_choice: None,
+            tool_choice: request.tool_choice.map(|choice| match choice {
+                LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
+                LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
+                LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
+            }),
         })
     }
 }
@@ -11,7 +11,8 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelToolChoice, RateLimiter, Role,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@@ -282,6 +283,10 @@ impl LanguageModel for DeepSeekLanguageModel {
         false
     }

+    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        false
+    }
+
     fn telemetry_id(&self) -> String {
         format!("deepseek/{}", self.model.id())
     }
@@ -12,8 +12,8 @@ use gpui::{
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
-    StopReason,
+    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
+    LanguageModelToolUseId, MessageContent, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,

@@ -313,6 +313,14 @@ impl LanguageModel for GoogleLanguageModel {
         true
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto
+            | LanguageModelToolChoice::Any
+            | LanguageModelToolChoice::None => true,
+        }
+    }
+
     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
         LanguageModelToolSchemaFormat::JsonSchemaSubset
     }

@@ -484,7 +492,16 @@ pub fn into_google(
                 .collect(),
             }]
         }),
-        tool_config: None,
+        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
+            function_calling_config: google_ai::FunctionCallingConfig {
+                mode: match choice {
+                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
+                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
+                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
+                },
+                allowed_function_names: None,
+            },
+        }),
     }
 }

@@ -4,6 +4,7 @@ use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+    LanguageModelToolChoice,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,

@@ -284,6 +285,10 @@ impl LanguageModel for LmStudioLanguageModel {
         false
     }

+    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        false
+    }
+
     fn telemetry_id(&self) -> String {
         format!("lmstudio/{}", self.model.id())
     }
@@ -10,7 +10,8 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelToolChoice, RateLimiter, Role,
 };

 use futures::stream::BoxStream;

@@ -302,6 +303,10 @@ impl LanguageModel for MistralLanguageModel {
         false
     }

+    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        false
+    }
+
     fn telemetry_id(&self) -> String {
         format!("mistral/{}", self.model.id())
     }
@@ -5,7 +5,8 @@ use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
+    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
+    LanguageModelToolUseId, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,

@@ -324,6 +325,14 @@ impl LanguageModel for OllamaLanguageModel {
         self.model.supports_tools.unwrap_or(false)
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto => false,
+            LanguageModelToolChoice::Any => false,
+            LanguageModelToolChoice::None => false,
+        }
+    }
+
     fn telemetry_id(&self) -> String {
         format!("ollama/{}", self.model.id())
     }
@@ -12,7 +12,7 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
+    LanguageModelToolChoice, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
 };
 use open_ai::{Model, ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;

@@ -295,6 +295,14 @@ impl LanguageModel for OpenAiLanguageModel {
         true
     }

+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        match choice {
+            LanguageModelToolChoice::Auto => true,
+            LanguageModelToolChoice::Any => true,
+            LanguageModelToolChoice::None => true,
+        }
+    }
+
     fn telemetry_id(&self) -> String {
         format!("openai/{}", self.model.id())
     }

@@ -417,7 +425,11 @@ pub fn into_open_ai(
                 },
             })
             .collect(),
-        tool_choice: None,
+        tool_choice: request.tool_choice.map(|choice| match choice {
+            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
+            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
+            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
+        }),
     }
 }

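Note: in the OpenAI mapping above, `LanguageModelToolChoice::Any` becomes `open_ai::ToolChoice::Required`, since OpenAI's chat API spells forced tool use as "required" while Anthropic and Bedrock call it "any". The sketch below is illustrative only and does not use the crate's real types; it just records that naming translation:

    // Hypothetical helper: maps the provider-neutral choice name to the wire
    // value OpenAI expects. "any" has no direct equivalent; it becomes "required".
    fn to_openai_wire_value(choice: &str) -> &'static str {
        match choice {
            "auto" => "auto",
            "any" => "required",
            "none" => "none",
            other => panic!("unknown tool choice: {other}"),
        }
    }

    fn main() {
        assert_eq!(to_openai_wire_value("any"), "required");
    }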
@@ -929,6 +929,7 @@ impl RulesLibrary {
                 cache: false,
             }],
             tools: Vec::new(),
+            tool_choice: None,
             stop: Vec::new(),
             temperature: None,
         },

@@ -566,6 +566,7 @@ impl SummaryIndex {
                 cache: use_cache,
             }],
             tools: Vec::new(),
+            tool_choice: None,
             stop: Vec::new(),
             temperature: None,
         };