Compare commits
15 commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a16825dab6 | ||
![]() |
5df1775990 | ||
![]() |
63e03abce0 | ||
![]() |
32357e338a | ||
![]() |
c701b68dd0 | ||
![]() |
304d7d9368 | ||
![]() |
d057388d0f | ||
![]() |
3eb57cfe2d | ||
![]() |
e7971d001b | ||
![]() |
064376a747 | ||
![]() |
c24798e538 | ||
![]() |
e1523924c4 | ||
![]() |
abdbc255f9 | ||
![]() |
d97a3fbbbe | ||
![]() |
98b8692fde |
45 changed files with 1330 additions and 602 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -3041,6 +3041,7 @@ dependencies = [
|
|||
"context_server",
|
||||
"ctor",
|
||||
"dap",
|
||||
"dap-types",
|
||||
"dap_adapters",
|
||||
"dashmap 6.1.0",
|
||||
"debugger_ui",
|
||||
|
@ -19942,7 +19943,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.194.0"
|
||||
version = "0.194.3"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
|
||||
use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
|
||||
use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName};
|
||||
use collections::IndexMap;
|
||||
use convert_case::{Case, Casing};
|
||||
use fs::Fs;
|
||||
|
@ -72,7 +72,7 @@ impl AgentProfile {
|
|||
&self.id
|
||||
}
|
||||
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
|
||||
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
@ -81,7 +81,7 @@ impl AgentProfile {
|
|||
.read(cx)
|
||||
.tools(cx)
|
||||
.into_iter()
|
||||
.filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
|
||||
.filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,7 @@ mod tests {
|
|||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.map(|(_, tool)| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
|
@ -174,7 +174,7 @@ mod tests {
|
|||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.map(|(_, tool)| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
|
@ -207,7 +207,7 @@ mod tests {
|
|||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.map(|(_, tool)| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
|
@ -267,10 +267,10 @@ mod tests {
|
|||
}
|
||||
|
||||
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
|
||||
cx.new(|_| {
|
||||
cx.new(|cx| {
|
||||
let mut tool_set = ToolWorkingSet::default();
|
||||
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
|
||||
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
|
||||
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx);
|
||||
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx);
|
||||
tool_set
|
||||
})
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ use anyhow::{Result, anyhow};
|
|||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use collections::{HashMap, HashSet};
|
||||
use collections::HashMap;
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::{FutureExt, StreamExt as _, future::Shared};
|
||||
use git::repository::DiffType;
|
||||
|
@ -960,13 +960,14 @@ impl Thread {
|
|||
model: Arc<dyn LanguageModel>,
|
||||
) -> Vec<LanguageModelRequestTool> {
|
||||
if model.supports_tools() {
|
||||
resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
|
||||
self.profile
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.filter_map(|(name, tool)| {
|
||||
// Skip tools that cannot be supported
|
||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name,
|
||||
name: name.into(),
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
|
@ -2386,7 +2387,7 @@ impl Thread {
|
|||
|
||||
let tool_list = available_tools
|
||||
.iter()
|
||||
.map(|tool| format!("- {}: {}", tool.name(), tool.description()))
|
||||
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
|
@ -2606,7 +2607,7 @@ impl Thread {
|
|||
.profile
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.map(|tool| tool.name())
|
||||
.map(|(name, _)| name.clone().into())
|
||||
.collect();
|
||||
|
||||
self.message_feedback.insert(message_id, feedback);
|
||||
|
@ -3144,85 +3145,6 @@ struct PendingCompletion {
|
|||
_task: Task<()>,
|
||||
}
|
||||
|
||||
/// Resolves tool name conflicts by ensuring all tool names are unique.
|
||||
///
|
||||
/// When multiple tools have the same name, this function applies the following rules:
|
||||
/// 1. Native tools always keep their original name
|
||||
/// 2. Context server tools get prefixed with their server ID and an underscore
|
||||
/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
|
||||
/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
|
||||
///
|
||||
/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
|
||||
fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
|
||||
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
|
||||
let mut tool_name = tool.name();
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
tool_name
|
||||
}
|
||||
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
let mut duplicated_tool_names = HashSet::default();
|
||||
let mut seen_tool_names = HashSet::default();
|
||||
for tool in tools {
|
||||
let tool_name = resolve_tool_name(tool);
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
debug_assert!(
|
||||
tool.source() != assistant_tool::ToolSource::Native,
|
||||
"There are two built-in tools with the same name: {}",
|
||||
tool_name
|
||||
);
|
||||
duplicated_tool_names.insert(tool_name);
|
||||
} else {
|
||||
seen_tool_names.insert(tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
if duplicated_tool_names.is_empty() {
|
||||
return tools
|
||||
.into_iter()
|
||||
.map(|tool| (resolve_tool_name(tool), tool.clone()))
|
||||
.collect();
|
||||
}
|
||||
|
||||
tools
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
let mut tool_name = resolve_tool_name(tool);
|
||||
if !duplicated_tool_names.contains(&tool_name) {
|
||||
return Some((tool_name, tool.clone()));
|
||||
}
|
||||
match tool.source() {
|
||||
assistant_tool::ToolSource::Native => {
|
||||
// Built-in tools always keep their original name
|
||||
Some((tool_name, tool.clone()))
|
||||
}
|
||||
assistant_tool::ToolSource::ContextServer { id } => {
|
||||
// Context server tools are prefixed with the context server ID, and truncated if necessary
|
||||
tool_name.insert(0, '_');
|
||||
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
|
||||
let mut id = id.to_string();
|
||||
id.truncate(len);
|
||||
tool_name.insert_str(0, &id);
|
||||
} else {
|
||||
tool_name.insert_str(0, &id);
|
||||
}
|
||||
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
|
||||
None
|
||||
} else {
|
||||
Some((tool_name, tool.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -3238,7 +3160,6 @@ mod tests {
|
|||
use futures::future::BoxFuture;
|
||||
use futures::stream::BoxStream;
|
||||
use gpui::TestAppContext;
|
||||
use icons::IconName;
|
||||
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
||||
use language_model::{
|
||||
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
|
||||
|
@ -3883,148 +3804,6 @@ fn main() {{
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_tool_name_conflicts() {
|
||||
use assistant_tool::{Tool, ToolSource};
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
// Test that tool with very long name is always truncated
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![TestTool::new(
|
||||
"tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
|
||||
ToolSource::Native,
|
||||
)],
|
||||
vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
|
||||
);
|
||||
|
||||
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
|
||||
TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::ContextServer {
|
||||
id: "mcp-with-very-very-very-long-name".into(),
|
||||
},
|
||||
),
|
||||
],
|
||||
vec![
|
||||
"tool-with-very-very-very-long-name",
|
||||
"mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
|
||||
],
|
||||
);
|
||||
|
||||
fn assert_resolve_tool_name_conflicts(
|
||||
tools: Vec<TestTool>,
|
||||
expected: Vec<impl Into<String>>,
|
||||
) {
|
||||
let tools: Vec<Arc<dyn Tool>> = tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t) as Arc<dyn Tool>)
|
||||
.collect();
|
||||
let tools = resolve_tool_name_conflicts(&tools);
|
||||
assert_eq!(tools.len(), expected.len());
|
||||
for (i, expected_name) in expected.into_iter().enumerate() {
|
||||
let expected_name = expected_name.into();
|
||||
let actual_name = &tools[i].0;
|
||||
assert_eq!(
|
||||
actual_name, &expected_name,
|
||||
"Expected '{}' got '{}' at index {}",
|
||||
expected_name, actual_name, i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTool {
|
||||
name: String,
|
||||
source: ToolSource,
|
||||
}
|
||||
|
||||
impl TestTool {
|
||||
fn new(name: impl Into<String>, source: ToolSource) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for TestTool {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Ai
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> assistant_tool::ToolResult {
|
||||
assistant_tool::ToolResult {
|
||||
output: Task::ready(Err(anyhow::anyhow!("No content"))),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create a model that returns errors
|
||||
enum TestError {
|
||||
Overloaded,
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
};
|
||||
use agent_settings::{AgentProfileId, CompletionMode};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ToolId, ToolWorkingSet};
|
||||
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerId;
|
||||
|
@ -537,8 +537,8 @@ impl ThreadStore {
|
|||
}
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
tool_working_set.update(cx, |tool_working_set, cx| {
|
||||
tool_working_set.remove(&tool_ids, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -569,19 +569,17 @@ impl ThreadStore {
|
|||
.log_err()
|
||||
{
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, _| {
|
||||
response
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!("registering context server tool: {:?}", tool.name);
|
||||
tool_working_set.insert(Arc::new(ContextServerTool::new(
|
||||
.update(cx, |tool_working_set, cx| {
|
||||
tool_working_set.extend(
|
||||
response.tools.into_iter().map(|tool| {
|
||||
Arc::new(ContextServerTool::new(
|
||||
context_server_store.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
)))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
)) as Arc<dyn Tool>
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
|
||||
|
|
|
@ -379,6 +379,14 @@ impl ConfigureContextServerModal {
|
|||
};
|
||||
|
||||
self.state = State::Waiting;
|
||||
|
||||
let existing_server = self.context_server_store.read(cx).get_running_server(&id);
|
||||
if existing_server.is_some() {
|
||||
self.context_server_store.update(cx, |store, cx| {
|
||||
store.stop_server(&id, cx).log_err();
|
||||
});
|
||||
}
|
||||
|
||||
let wait_for_context_server_task =
|
||||
wait_for_context_server(&self.context_server_store, id.clone(), cx);
|
||||
cx.spawn({
|
||||
|
@ -399,13 +407,21 @@ impl ConfigureContextServerModal {
|
|||
})
|
||||
.detach();
|
||||
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
|
||||
project_settings.context_servers.insert(id.0, settings);
|
||||
let settings_changed =
|
||||
ProjectSettings::get_global(cx).context_servers.get(&id.0) != Some(&settings);
|
||||
|
||||
if settings_changed {
|
||||
// When we write the settings to the file, the context server will be restarted.
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
|
||||
project_settings.context_servers.insert(id.0, settings);
|
||||
});
|
||||
});
|
||||
});
|
||||
} else if let Some(existing_server) = existing_server {
|
||||
self.context_server_store
|
||||
.update(cx, |store, cx| store.start_server(existing_server, cx));
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context<Self>) {
|
||||
|
|
|
@ -42,8 +42,8 @@ impl IncompatibleToolsState {
|
|||
.profile()
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
|
||||
.cloned()
|
||||
.filter(|(_, tool)| tool.input_schema(model.tool_input_format()).is_err())
|
||||
.map(|(_, tool)| tool.clone())
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ gpui.workspace = true
|
|||
icons.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
|
|
|
@ -1,18 +1,52 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashMap, IndexMap};
|
||||
use gpui::App;
|
||||
use std::{borrow::Borrow, sync::Arc};
|
||||
|
||||
use crate::{Tool, ToolRegistry, ToolSource};
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use gpui::{App, SharedString};
|
||||
use util::debug_panic;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
|
||||
pub struct ToolId(usize);
|
||||
|
||||
/// A unique identifier for a tool within a working set.
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Default)]
|
||||
pub struct UniqueToolName(SharedString);
|
||||
|
||||
impl Borrow<str> for UniqueToolName {
|
||||
fn borrow(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for UniqueToolName {
|
||||
fn from(value: String) -> Self {
|
||||
UniqueToolName(SharedString::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<String> for UniqueToolName {
|
||||
fn into(self) -> String {
|
||||
self.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for UniqueToolName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UniqueToolName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
/// A working set of tools for use in one instance of the Assistant Panel.
|
||||
#[derive(Default)]
|
||||
pub struct ToolWorkingSet {
|
||||
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<UniqueToolName, Arc<dyn Tool>>,
|
||||
next_tool_id: ToolId,
|
||||
}
|
||||
|
||||
|
@ -24,16 +58,20 @@ impl ToolWorkingSet {
|
|||
.or_else(|| ToolRegistry::global(cx).tool(name))
|
||||
}
|
||||
|
||||
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let mut tools = ToolRegistry::global(cx).tools();
|
||||
tools.extend(self.context_server_tools_by_id.values().cloned());
|
||||
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
|
||||
let mut tools = ToolRegistry::global(cx)
|
||||
.tools()
|
||||
.into_iter()
|
||||
.map(|tool| (UniqueToolName(tool.name().into()), tool))
|
||||
.collect::<Vec<_>>();
|
||||
tools.extend(self.context_server_tools_by_name.clone());
|
||||
tools
|
||||
}
|
||||
|
||||
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
|
||||
let mut tools_by_source = IndexMap::default();
|
||||
|
||||
for tool in self.tools(cx) {
|
||||
for (_, tool) in self.tools(cx) {
|
||||
tools_by_source
|
||||
.entry(tool.source())
|
||||
.or_insert_with(Vec::new)
|
||||
|
@ -49,27 +87,324 @@ impl ToolWorkingSet {
|
|||
tools_by_source
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
pub fn insert(&mut self, tool: Arc<dyn Tool>, cx: &App) -> ToolId {
|
||||
let tool_id = self.register_tool(tool);
|
||||
self.tools_changed(cx);
|
||||
tool_id
|
||||
}
|
||||
|
||||
pub fn extend(&mut self, tools: impl Iterator<Item = Arc<dyn Tool>>, cx: &App) -> Vec<ToolId> {
|
||||
let ids = tools.map(|tool| self.register_tool(tool)).collect();
|
||||
self.tools_changed(cx);
|
||||
ids
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
|
||||
self.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
self.tools_changed(cx);
|
||||
}
|
||||
|
||||
fn register_tool(&mut self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
let tool_id = self.next_tool_id;
|
||||
self.next_tool_id.0 += 1;
|
||||
self.context_server_tools_by_id
|
||||
.insert(tool_id, tool.clone());
|
||||
self.tools_changed();
|
||||
tool_id
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
|
||||
self.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
self.tools_changed();
|
||||
}
|
||||
|
||||
fn tools_changed(&mut self) {
|
||||
self.context_server_tools_by_name.clear();
|
||||
self.context_server_tools_by_name.extend(
|
||||
self.context_server_tools_by_id
|
||||
fn tools_changed(&mut self, cx: &App) {
|
||||
self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
|
||||
&self
|
||||
.context_server_tools_by_id
|
||||
.values()
|
||||
.map(|tool| (tool.name(), tool.clone())),
|
||||
.cloned()
|
||||
.collect::<Vec<_>>(),
|
||||
&ToolRegistry::global(cx).tools(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_context_server_tool_name_conflicts(
|
||||
context_server_tools: &[Arc<dyn Tool>],
|
||||
native_tools: &[Arc<dyn Tool>],
|
||||
) -> HashMap<UniqueToolName, Arc<dyn Tool>> {
|
||||
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
|
||||
let mut tool_name = tool.name();
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
tool_name
|
||||
}
|
||||
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
let mut duplicated_tool_names = HashSet::default();
|
||||
let mut seen_tool_names = HashSet::default();
|
||||
seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
|
||||
for tool in context_server_tools {
|
||||
let tool_name = resolve_tool_name(tool);
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
debug_assert!(
|
||||
tool.source() != ToolSource::Native,
|
||||
"Expected MCP tool but got a native tool: {}",
|
||||
tool_name
|
||||
);
|
||||
duplicated_tool_names.insert(tool_name);
|
||||
} else {
|
||||
seen_tool_names.insert(tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
if duplicated_tool_names.is_empty() {
|
||||
return context_server_tools
|
||||
.into_iter()
|
||||
.map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
|
||||
.collect();
|
||||
}
|
||||
|
||||
context_server_tools
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
let mut tool_name = resolve_tool_name(tool);
|
||||
if !duplicated_tool_names.contains(&tool_name) {
|
||||
return Some((tool_name.into(), tool.clone()));
|
||||
}
|
||||
match tool.source() {
|
||||
ToolSource::Native => {
|
||||
debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
|
||||
// Built-in tools always keep their original name
|
||||
Some((tool_name.into(), tool.clone()))
|
||||
}
|
||||
ToolSource::ContextServer { id } => {
|
||||
// Context server tools are prefixed with the context server ID, and truncated if necessary
|
||||
tool_name.insert(0, '_');
|
||||
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
|
||||
let mut id = id.to_string();
|
||||
id.truncate(len);
|
||||
tool_name.insert_str(0, &id);
|
||||
} else {
|
||||
tool_name.insert_str(0, &id);
|
||||
}
|
||||
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
|
||||
None
|
||||
} else {
|
||||
Some((tool_name.into(), tool.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
|
||||
use language_model::{LanguageModel, LanguageModelRequest};
|
||||
use project::Project;
|
||||
|
||||
use crate::{ActionLog, ToolResult};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_unique_tool_names(cx: &mut TestAppContext) {
|
||||
fn assert_tool(
|
||||
tool_working_set: &ToolWorkingSet,
|
||||
unique_name: &str,
|
||||
expected_name: &str,
|
||||
expected_source: ToolSource,
|
||||
cx: &App,
|
||||
) {
|
||||
let tool = tool_working_set.tool(unique_name, cx).unwrap();
|
||||
assert_eq!(tool.name(), expected_name);
|
||||
assert_eq!(tool.source(), expected_source);
|
||||
}
|
||||
|
||||
let tool_registry = cx.update(ToolRegistry::default_global);
|
||||
tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
|
||||
tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));
|
||||
|
||||
let mut tool_working_set = ToolWorkingSet::default();
|
||||
cx.update(|cx| {
|
||||
tool_working_set.extend(
|
||||
vec![
|
||||
Arc::new(TestTool::new(
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-1".into() },
|
||||
)) as Arc<dyn Tool>,
|
||||
Arc::new(TestTool::new(
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-2".into() },
|
||||
)) as Arc<dyn Tool>,
|
||||
]
|
||||
.into_iter(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
|
||||
assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
|
||||
assert_tool(
|
||||
&tool_working_set,
|
||||
"mcp-1_tool2",
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-1".into() },
|
||||
cx,
|
||||
);
|
||||
assert_tool(
|
||||
&tool_working_set,
|
||||
"mcp-2_tool2",
|
||||
"tool2",
|
||||
ToolSource::ContextServer { id: "mcp-2".into() },
|
||||
cx,
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_context_server_tool_name_conflicts() {
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
],
|
||||
vec![TestTool::new(
|
||||
"tool3",
|
||||
ToolSource::ContextServer { id: "mcp-1".into() },
|
||||
)],
|
||||
vec!["tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
],
|
||||
vec![
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::Native),
|
||||
],
|
||||
vec![
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
|
||||
assert_resolve_context_server_tool_name_conflicts(
|
||||
vec![TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::Native,
|
||||
)],
|
||||
vec![TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::ContextServer {
|
||||
id: "mcp-with-very-very-very-long-name".into(),
|
||||
},
|
||||
)],
|
||||
vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
|
||||
);
|
||||
|
||||
fn assert_resolve_context_server_tool_name_conflicts(
|
||||
builtin_tools: Vec<TestTool>,
|
||||
context_server_tools: Vec<TestTool>,
|
||||
expected: Vec<&'static str>,
|
||||
) {
|
||||
let context_server_tools: Vec<Arc<dyn Tool>> = context_server_tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t) as Arc<dyn Tool>)
|
||||
.collect();
|
||||
let builtin_tools: Vec<Arc<dyn Tool>> = builtin_tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t) as Arc<dyn Tool>)
|
||||
.collect();
|
||||
let tools =
|
||||
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
|
||||
assert_eq!(tools.len(), expected.len());
|
||||
for (i, (name, _)) in tools.into_iter().enumerate() {
|
||||
assert_eq!(
|
||||
name.0.as_ref(),
|
||||
expected[i],
|
||||
"Expected '{}' got '{}' at index {}",
|
||||
expected[i],
|
||||
name,
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTool {
|
||||
name: String,
|
||||
source: ToolSource,
|
||||
}
|
||||
|
||||
impl TestTool {
|
||||
fn new(name: impl Into<String>, source: ToolSource) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for TestTool {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn icon(&self) -> icons::IconName {
|
||||
icons::IconName::Ai
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
ToolResult {
|
||||
output: Task::ready(Err(anyhow::anyhow!("No content"))),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,9 +25,7 @@ fn schema_to_json(
|
|||
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
let mut generator = match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
|
||||
// TODO: Gemini docs mention using a subset of OpenAPI 3, so this may benefit from using
|
||||
// `SchemaSettings::openapi3()`.
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::draft07()
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
|
||||
.with(|settings| {
|
||||
settings.meta_schema = None;
|
||||
settings.inline_subschemas = true;
|
||||
|
|
|
@ -218,7 +218,7 @@ impl Tool for TerminalTool {
|
|||
.update(cx, |project, cx| {
|
||||
project.create_terminal(
|
||||
TerminalKind::Task(task::SpawnInTerminal {
|
||||
command: program,
|
||||
command: Some(program),
|
||||
args,
|
||||
cwd,
|
||||
env,
|
||||
|
|
|
@ -93,6 +93,7 @@ context_server.workspace = true
|
|||
ctor.workspace = true
|
||||
dap = { workspace = true, features = ["test-support"] }
|
||||
dap_adapters = { workspace = true, features = ["test-support"] }
|
||||
dap-types.workspace = true
|
||||
debugger_ui = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
extension.workspace = true
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::tests::TestServer;
|
|||
use call::ActiveCall;
|
||||
use collections::{HashMap, HashSet};
|
||||
|
||||
use dap::{Capabilities, adapters::DebugTaskDefinition, transport::RequestHandling};
|
||||
use debugger_ui::debugger_panel::DebugPanel;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs as _, RemoveOptions};
|
||||
|
@ -22,6 +23,7 @@ use language::{
|
|||
use node_runtime::NodeRuntime;
|
||||
use project::{
|
||||
ProjectPath,
|
||||
debugger::session::ThreadId,
|
||||
lsp_store::{FormatTrigger, LspFormatTarget},
|
||||
};
|
||||
use remote::SshRemoteClient;
|
||||
|
@ -29,7 +31,11 @@ use remote_server::{HeadlessAppState, HeadlessProject};
|
|||
use rpc::proto;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::{Arc, atomic::AtomicUsize},
|
||||
};
|
||||
use task::TcpArgumentsTemplate;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
|
@ -688,3 +694,162 @@ async fn test_remote_server_debugger(
|
|||
|
||||
shutdown_session.await.unwrap();
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_slow_adapter_startup_retries(
|
||||
cx_a: &mut TestAppContext,
|
||||
server_cx: &mut TestAppContext,
|
||||
executor: BackgroundExecutor,
|
||||
) {
|
||||
cx_a.update(|cx| {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
command_palette_hooks::init(cx);
|
||||
zlog::init_test();
|
||||
dap_adapters::init(cx);
|
||||
});
|
||||
server_cx.update(|cx| {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
dap_adapters::init(cx);
|
||||
});
|
||||
let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
|
||||
let remote_fs = FakeFs::new(server_cx.executor());
|
||||
remote_fs
|
||||
.insert_tree(
|
||||
path!("/code"),
|
||||
json!({
|
||||
"lib.rs": "fn one() -> usize { 1 }"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// User A connects to the remote project via SSH.
|
||||
server_cx.update(HeadlessProject::init);
|
||||
let remote_http_client = Arc::new(BlockedHttpClient);
|
||||
let node = NodeRuntime::unavailable();
|
||||
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
|
||||
let _headless_project = server_cx.new(|cx| {
|
||||
client::init_settings(cx);
|
||||
HeadlessProject::new(
|
||||
HeadlessAppState {
|
||||
session: server_ssh,
|
||||
fs: remote_fs.clone(),
|
||||
http_client: remote_http_client,
|
||||
node_runtime: node,
|
||||
languages,
|
||||
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await;
|
||||
let mut server = TestServer::start(server_cx.executor()).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
cx_a.update(|cx| {
|
||||
debugger_ui::init(cx);
|
||||
command_palette_hooks::init(cx);
|
||||
});
|
||||
let (project_a, _) = client_a
|
||||
.build_ssh_project(path!("/code"), client_ssh.clone(), cx_a)
|
||||
.await;
|
||||
|
||||
let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a);
|
||||
|
||||
let debugger_panel = workspace
|
||||
.update_in(cx_a, |_workspace, window, cx| {
|
||||
cx.spawn_in(window, DebugPanel::load)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
workspace.update_in(cx_a, |workspace, window, cx| {
|
||||
workspace.add_panel(debugger_panel, window, cx);
|
||||
});
|
||||
|
||||
cx_a.run_until_parked();
|
||||
let debug_panel = workspace
|
||||
.update(cx_a, |workspace, cx| workspace.panel::<DebugPanel>(cx))
|
||||
.unwrap();
|
||||
|
||||
let workspace_window = cx_a
|
||||
.window_handle()
|
||||
.downcast::<workspace::Workspace>()
|
||||
.unwrap();
|
||||
|
||||
let count = Arc::new(AtomicUsize::new(0));
|
||||
let session = debugger_ui::tests::start_debug_session_with(
|
||||
&workspace_window,
|
||||
cx_a,
|
||||
DebugTaskDefinition {
|
||||
adapter: "fake-adapter".into(),
|
||||
label: "test".into(),
|
||||
config: json!({
|
||||
"request": "launch"
|
||||
}),
|
||||
tcp_connection: Some(TcpArgumentsTemplate {
|
||||
port: None,
|
||||
host: None,
|
||||
timeout: None,
|
||||
}),
|
||||
},
|
||||
move |client| {
|
||||
let count = count.clone();
|
||||
client.on_request_ext::<dap::requests::Initialize, _>(move |_seq, _request| {
|
||||
if count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 5 {
|
||||
return RequestHandling::Exit;
|
||||
}
|
||||
RequestHandling::Respond(Ok(Capabilities::default()))
|
||||
});
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
cx_a.run_until_parked();
|
||||
|
||||
let client = session.update(cx_a, |session, _| session.adapter_client().unwrap());
|
||||
client
|
||||
.fake_event(dap::messages::Events::Stopped(dap::StoppedEvent {
|
||||
reason: dap::StoppedEventReason::Pause,
|
||||
description: None,
|
||||
thread_id: Some(1),
|
||||
preserve_focus_hint: None,
|
||||
text: None,
|
||||
all_threads_stopped: None,
|
||||
hit_breakpoint_ids: None,
|
||||
}))
|
||||
.await;
|
||||
|
||||
cx_a.run_until_parked();
|
||||
|
||||
let active_session = debug_panel
|
||||
.update(cx_a, |this, _| this.active_session())
|
||||
.unwrap();
|
||||
|
||||
let running_state = active_session.update(cx_a, |active_session, _| {
|
||||
active_session.running_state().clone()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
client.id(),
|
||||
running_state.read_with(cx_a, |running_state, _| running_state.session_id())
|
||||
);
|
||||
assert_eq!(
|
||||
ThreadId(1),
|
||||
running_state.read_with(cx_a, |running_state, _| running_state
|
||||
.selected_thread_id()
|
||||
.unwrap())
|
||||
);
|
||||
|
||||
let shutdown_session = workspace.update(cx_a, |workspace, cx| {
|
||||
workspace.project().update(cx, |project, cx| {
|
||||
project.dap_store().update(cx, |dap_store, cx| {
|
||||
dap_store.shutdown_session(session.read(cx).session_id(), cx)
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
client_ssh.update(cx_a, |a, _| {
|
||||
a.shutdown_processes(Some(proto::ShutdownRemoteServer {}), executor)
|
||||
});
|
||||
|
||||
shutdown_session.await.unwrap();
|
||||
}
|
||||
|
|
|
@ -442,10 +442,18 @@ impl DebugAdapter for FakeAdapter {
|
|||
_: Option<Vec<String>>,
|
||||
_: &mut AsyncApp,
|
||||
) -> Result<DebugAdapterBinary> {
|
||||
let connection = task_definition
|
||||
.tcp_connection
|
||||
.as_ref()
|
||||
.map(|connection| TcpArguments {
|
||||
host: connection.host(),
|
||||
port: connection.port.unwrap_or(17),
|
||||
timeout: connection.timeout,
|
||||
});
|
||||
Ok(DebugAdapterBinary {
|
||||
command: Some("command".into()),
|
||||
arguments: vec![],
|
||||
connection: None,
|
||||
connection,
|
||||
envs: HashMap::default(),
|
||||
cwd: None,
|
||||
request_args: StartDebuggingRequestArguments {
|
||||
|
|
|
@ -108,7 +108,9 @@ impl DebugAdapterClient {
|
|||
arguments: Some(serialized_arguments),
|
||||
};
|
||||
self.transport_delegate
|
||||
.add_pending_request(sequence_id, callback_tx);
|
||||
.pending_requests
|
||||
.lock()
|
||||
.insert(sequence_id, callback_tx)?;
|
||||
|
||||
log::debug!(
|
||||
"Client {} send `{}` request with sequence_id: {}",
|
||||
|
@ -166,6 +168,7 @@ impl DebugAdapterClient {
|
|||
pub fn kill(&self) {
|
||||
log::debug!("Killing DAP process");
|
||||
self.transport_delegate.transport.lock().kill();
|
||||
self.transport_delegate.pending_requests.lock().shutdown();
|
||||
}
|
||||
|
||||
pub fn has_adapter_logs(&self) -> bool {
|
||||
|
@ -180,11 +183,34 @@ impl DebugAdapterClient {
|
|||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
|
||||
pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
|
||||
{
|
||||
use crate::transport::RequestHandling;
|
||||
|
||||
self.transport_delegate
|
||||
.transport
|
||||
.lock()
|
||||
.as_fake()
|
||||
.on_request::<R, _>(move |seq, request| {
|
||||
RequestHandling::Respond(handler(seq, request))
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn on_request_ext<R: dap_types::requests::Request, F>(&self, handler: F)
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ FnMut(
|
||||
u64,
|
||||
R::Arguments,
|
||||
) -> crate::transport::RequestHandling<
|
||||
Result<R::Response, dap_types::ErrorResponse>,
|
||||
>,
|
||||
{
|
||||
self.transport_delegate
|
||||
.transport
|
||||
|
|
|
@ -49,7 +49,12 @@ pub enum IoKind {
|
|||
StdErr,
|
||||
}
|
||||
|
||||
type Requests = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Response>>>>>;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub enum RequestHandling<T> {
|
||||
Respond(T),
|
||||
Exit,
|
||||
}
|
||||
|
||||
type LogHandlers = Arc<Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
|
||||
|
||||
pub trait Transport: Send + Sync {
|
||||
|
@ -77,7 +82,11 @@ async fn start(
|
|||
) -> Result<Box<dyn Transport>> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
if cfg!(any(test, feature = "test-support")) {
|
||||
return Ok(Box::new(FakeTransport::start(cx).await?));
|
||||
if let Some(connection) = binary.connection.clone() {
|
||||
return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?));
|
||||
} else {
|
||||
return Ok(Box::new(FakeTransport::start_stdio(cx).await?));
|
||||
}
|
||||
}
|
||||
|
||||
if binary.connection.is_some() {
|
||||
|
@ -91,20 +100,62 @@ async fn start(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PendingRequests {
|
||||
inner: Option<HashMap<u64, oneshot::Sender<Result<Response>>>>,
|
||||
}
|
||||
|
||||
impl PendingRequests {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
inner: Some(HashMap::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self, e: anyhow::Error) {
|
||||
let Some(inner) = self.inner.as_mut() else {
|
||||
return;
|
||||
};
|
||||
for (_, sender) in inner.drain() {
|
||||
sender.send(Err(e.cloned())).ok();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert(
|
||||
&mut self,
|
||||
sequence_id: u64,
|
||||
callback_tx: oneshot::Sender<Result<Response>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(inner) = self.inner.as_mut() else {
|
||||
bail!("client is closed")
|
||||
};
|
||||
inner.insert(sequence_id, callback_tx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn remove(
|
||||
&mut self,
|
||||
sequence_id: u64,
|
||||
) -> anyhow::Result<Option<oneshot::Sender<Result<Response>>>> {
|
||||
let Some(inner) = self.inner.as_mut() else {
|
||||
bail!("client is closed");
|
||||
};
|
||||
Ok(inner.remove(&sequence_id))
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&mut self) {
|
||||
self.flush(anyhow!("transport shutdown"));
|
||||
self.inner = None;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct TransportDelegate {
|
||||
log_handlers: LogHandlers,
|
||||
pub(crate) pending_requests: Requests,
|
||||
pub(crate) pending_requests: Arc<Mutex<PendingRequests>>,
|
||||
pub(crate) transport: Mutex<Box<dyn Transport>>,
|
||||
pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
|
||||
tasks: Mutex<Vec<Task<()>>>,
|
||||
}
|
||||
|
||||
impl Drop for TransportDelegate {
|
||||
fn drop(&mut self) {
|
||||
self.transport.lock().kill()
|
||||
}
|
||||
}
|
||||
|
||||
impl TransportDelegate {
|
||||
pub(crate) async fn start(binary: &DebugAdapterBinary, cx: &mut AsyncApp) -> Result<Self> {
|
||||
let log_handlers: LogHandlers = Default::default();
|
||||
|
@ -113,7 +164,7 @@ impl TransportDelegate {
|
|||
transport: Mutex::new(transport),
|
||||
log_handlers,
|
||||
server_tx: Default::default(),
|
||||
pending_requests: Default::default(),
|
||||
pending_requests: Arc::new(Mutex::new(PendingRequests::new())),
|
||||
tasks: Default::default(),
|
||||
})
|
||||
}
|
||||
|
@ -154,16 +205,12 @@ impl TransportDelegate {
|
|||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
pending_requests.lock().drain().for_each(|(_, request)| {
|
||||
request
|
||||
.send(Err(anyhow!("debugger shutdown unexpectedly")))
|
||||
.ok();
|
||||
});
|
||||
pending_requests
|
||||
.lock()
|
||||
.flush(anyhow!("debugger shutdown unexpectedly"));
|
||||
}
|
||||
Err(e) => {
|
||||
pending_requests.lock().drain().for_each(|(_, request)| {
|
||||
request.send(Err(e.cloned())).ok();
|
||||
});
|
||||
pending_requests.lock().flush(e);
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
@ -188,15 +235,6 @@ impl TransportDelegate {
|
|||
self.transport.lock().tcp_arguments()
|
||||
}
|
||||
|
||||
pub(crate) fn add_pending_request(
|
||||
&self,
|
||||
sequence_id: u64,
|
||||
request: oneshot::Sender<Result<Response>>,
|
||||
) {
|
||||
let mut pending_requests = self.pending_requests.lock();
|
||||
pending_requests.insert(sequence_id, request);
|
||||
}
|
||||
|
||||
pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
|
||||
if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
|
||||
server_tx.send(message).await.context("sending message")
|
||||
|
@ -290,7 +328,7 @@ impl TransportDelegate {
|
|||
async fn recv_from_server<Stdout>(
|
||||
server_stdout: Stdout,
|
||||
mut message_handler: DapMessageHandler,
|
||||
pending_requests: Requests,
|
||||
pending_requests: Arc<Mutex<PendingRequests>>,
|
||||
log_handlers: Option<LogHandlers>,
|
||||
) -> Result<()>
|
||||
where
|
||||
|
@ -300,16 +338,17 @@ impl TransportDelegate {
|
|||
let mut reader = BufReader::new(server_stdout);
|
||||
|
||||
let result = loop {
|
||||
match Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
|
||||
.await
|
||||
{
|
||||
let result =
|
||||
Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
|
||||
.await;
|
||||
match result {
|
||||
ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::info!("Debugger closed the connection");
|
||||
return Ok(());
|
||||
}
|
||||
ConnectionResult::Result(Ok(Message::Response(res))) => {
|
||||
let tx = pending_requests.lock().remove(&res.request_seq);
|
||||
let tx = pending_requests.lock().remove(res.request_seq)?;
|
||||
if let Some(tx) = tx {
|
||||
if let Err(e) = tx.send(Self::process_response(res)) {
|
||||
log::trace!("Did not send response `{:?}` for a cancelled", e);
|
||||
|
@ -703,8 +742,7 @@ impl Drop for StdioTransport {
|
|||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
type RequestHandler =
|
||||
Box<dyn Send + FnMut(u64, serde_json::Value) -> dap_types::messages::Response>;
|
||||
type RequestHandler = Box<dyn Send + FnMut(u64, serde_json::Value) -> RequestHandling<Response>>;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
type ResponseHandler = Box<dyn Send + Fn(Response)>;
|
||||
|
@ -715,23 +753,38 @@ pub struct FakeTransport {
|
|||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
// for reverse request responses
|
||||
response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
|
||||
|
||||
stdin_writer: Option<PipeWriter>,
|
||||
stdout_reader: Option<PipeReader>,
|
||||
message_handler: Option<Task<Result<()>>>,
|
||||
kind: FakeTransportKind,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub enum FakeTransportKind {
|
||||
Stdio {
|
||||
stdin_writer: Option<PipeWriter>,
|
||||
stdout_reader: Option<PipeReader>,
|
||||
},
|
||||
Tcp {
|
||||
connection: TcpArguments,
|
||||
executor: BackgroundExecutor,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl FakeTransport {
|
||||
pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(u64, R::Arguments) -> Result<R::Response, ErrorResponse>,
|
||||
F: 'static
|
||||
+ Send
|
||||
+ FnMut(u64, R::Arguments) -> RequestHandling<Result<R::Response, ErrorResponse>>,
|
||||
{
|
||||
self.request_handlers.lock().insert(
|
||||
R::COMMAND,
|
||||
Box::new(move |seq, args| {
|
||||
let result = handler(seq, serde_json::from_value(args).unwrap());
|
||||
let response = match result {
|
||||
let RequestHandling::Respond(response) = result else {
|
||||
return RequestHandling::Exit;
|
||||
};
|
||||
let response = match response {
|
||||
Ok(response) => Response {
|
||||
seq: seq + 1,
|
||||
request_seq: seq,
|
||||
|
@ -749,7 +802,7 @@ impl FakeTransport {
|
|||
message: None,
|
||||
},
|
||||
};
|
||||
response
|
||||
RequestHandling::Respond(response)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
@ -763,86 +816,75 @@ impl FakeTransport {
|
|||
.insert(R::COMMAND, Box::new(handler));
|
||||
}
|
||||
|
||||
async fn start(cx: &mut AsyncApp) -> Result<Self> {
|
||||
async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result<Self> {
|
||||
Ok(Self {
|
||||
request_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
response_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
message_handler: None,
|
||||
kind: FakeTransportKind::Tcp {
|
||||
connection,
|
||||
executor: cx.background_executor().clone(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_messages(
|
||||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
|
||||
stdin_reader: PipeReader,
|
||||
stdout_writer: PipeWriter,
|
||||
) -> Result<()> {
|
||||
use dap_types::requests::{Request, RunInTerminal, StartDebugging};
|
||||
use serde_json::json;
|
||||
|
||||
let (stdin_writer, stdin_reader) = async_pipe::pipe();
|
||||
let (stdout_writer, stdout_reader) = async_pipe::pipe();
|
||||
|
||||
let mut this = Self {
|
||||
request_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
response_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
stdin_writer: Some(stdin_writer),
|
||||
stdout_reader: Some(stdout_reader),
|
||||
message_handler: None,
|
||||
};
|
||||
|
||||
let request_handlers = this.request_handlers.clone();
|
||||
let response_handlers = this.response_handlers.clone();
|
||||
let mut reader = BufReader::new(stdin_reader);
|
||||
let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
|
||||
let mut buffer = String::new();
|
||||
|
||||
this.message_handler = Some(cx.background_spawn(async move {
|
||||
let mut reader = BufReader::new(stdin_reader);
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
|
||||
.await
|
||||
{
|
||||
ConnectionResult::Timeout => {
|
||||
anyhow::bail!("Timed out when connecting to debugger");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::info!("Debugger closed the connection");
|
||||
break Ok(());
|
||||
}
|
||||
ConnectionResult::Result(Err(e)) => break Err(e),
|
||||
ConnectionResult::Result(Ok(message)) => {
|
||||
match message {
|
||||
Message::Request(request) => {
|
||||
// redirect reverse requests to stdout writer/reader
|
||||
if request.command == RunInTerminal::COMMAND
|
||||
|| request.command == StartDebugging::COMMAND
|
||||
{
|
||||
let message =
|
||||
serde_json::to_string(&Message::Request(request)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
} else {
|
||||
let response = if let Some(handle) =
|
||||
request_handlers.lock().get_mut(request.command.as_str())
|
||||
{
|
||||
handle(request.seq, request.arguments.unwrap_or(json!({})))
|
||||
} else {
|
||||
panic!("No request handler for {}", request.command);
|
||||
};
|
||||
let message =
|
||||
serde_json::to_string(&Message::Response(response))
|
||||
.unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
}
|
||||
Message::Event(event) => {
|
||||
loop {
|
||||
match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await {
|
||||
ConnectionResult::Timeout => {
|
||||
anyhow::bail!("Timed out when connecting to debugger");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::info!("Debugger closed the connection");
|
||||
break Ok(());
|
||||
}
|
||||
ConnectionResult::Result(Err(e)) => break Err(e),
|
||||
ConnectionResult::Result(Ok(message)) => {
|
||||
match message {
|
||||
Message::Request(request) => {
|
||||
// redirect reverse requests to stdout writer/reader
|
||||
if request.command == RunInTerminal::COMMAND
|
||||
|| request.command == StartDebugging::COMMAND
|
||||
{
|
||||
let message =
|
||||
serde_json::to_string(&Message::Event(event)).unwrap();
|
||||
serde_json::to_string(&Message::Request(request)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message).as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
} else {
|
||||
let response = if let Some(handle) =
|
||||
request_handlers.lock().get_mut(request.command.as_str())
|
||||
{
|
||||
handle(request.seq, request.arguments.unwrap_or(json!({})))
|
||||
} else {
|
||||
panic!("No request handler for {}", request.command);
|
||||
};
|
||||
let response = match response {
|
||||
RequestHandling::Respond(response) => response,
|
||||
RequestHandling::Exit => {
|
||||
break Err(anyhow!("exit in response to request"));
|
||||
}
|
||||
};
|
||||
let message =
|
||||
serde_json::to_string(&Message::Response(response)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
|
@ -853,20 +895,56 @@ impl FakeTransport {
|
|||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
Message::Response(response) => {
|
||||
if let Some(handle) =
|
||||
response_handlers.lock().get(response.command.as_str())
|
||||
{
|
||||
handle(response);
|
||||
} else {
|
||||
log::error!("No response handler for {}", response.command);
|
||||
}
|
||||
}
|
||||
Message::Event(event) => {
|
||||
let message = serde_json::to_string(&Message::Event(event)).unwrap();
|
||||
|
||||
let mut writer = stdout_writer.lock().await;
|
||||
writer
|
||||
.write_all(TransportDelegate::build_rpc_message(message).as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
Message::Response(response) => {
|
||||
if let Some(handle) =
|
||||
response_handlers.lock().get(response.command.as_str())
|
||||
{
|
||||
handle(response);
|
||||
} else {
|
||||
log::error!("No response handler for {}", response.command);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_stdio(cx: &mut AsyncApp) -> Result<Self> {
|
||||
let (stdin_writer, stdin_reader) = async_pipe::pipe();
|
||||
let (stdout_writer, stdout_reader) = async_pipe::pipe();
|
||||
let kind = FakeTransportKind::Stdio {
|
||||
stdin_writer: Some(stdin_writer),
|
||||
stdout_reader: Some(stdout_reader),
|
||||
};
|
||||
|
||||
let mut this = Self {
|
||||
request_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
response_handlers: Arc::new(Mutex::new(HashMap::default())),
|
||||
message_handler: None,
|
||||
kind,
|
||||
};
|
||||
|
||||
let request_handlers = this.request_handlers.clone();
|
||||
let response_handlers = this.response_handlers.clone();
|
||||
|
||||
this.message_handler = Some(cx.background_spawn(Self::handle_messages(
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
stdin_reader,
|
||||
stdout_writer,
|
||||
)));
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
@ -875,7 +953,10 @@ impl FakeTransport {
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl Transport for FakeTransport {
|
||||
fn tcp_arguments(&self) -> Option<TcpArguments> {
|
||||
None
|
||||
match &self.kind {
|
||||
FakeTransportKind::Stdio { .. } => None,
|
||||
FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn connect(
|
||||
|
@ -886,12 +967,33 @@ impl Transport for FakeTransport {
|
|||
Box<dyn AsyncRead + Unpin + Send + 'static>,
|
||||
)>,
|
||||
> {
|
||||
let result = util::maybe!({
|
||||
Ok((
|
||||
Box::new(self.stdin_writer.take().context("Cannot reconnect")?) as _,
|
||||
Box::new(self.stdout_reader.take().context("Cannot reconnect")?) as _,
|
||||
))
|
||||
});
|
||||
let result = match &mut self.kind {
|
||||
FakeTransportKind::Stdio {
|
||||
stdin_writer,
|
||||
stdout_reader,
|
||||
} => util::maybe!({
|
||||
Ok((
|
||||
Box::new(stdin_writer.take().context("Cannot reconnect")?) as _,
|
||||
Box::new(stdout_reader.take().context("Cannot reconnect")?) as _,
|
||||
))
|
||||
}),
|
||||
FakeTransportKind::Tcp { executor, .. } => {
|
||||
let (stdin_writer, stdin_reader) = async_pipe::pipe();
|
||||
let (stdout_writer, stdout_reader) = async_pipe::pipe();
|
||||
|
||||
let request_handlers = self.request_handlers.clone();
|
||||
let response_handlers = self.response_handlers.clone();
|
||||
|
||||
self.message_handler = Some(executor.spawn(Self::handle_messages(
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
stdin_reader,
|
||||
stdout_writer,
|
||||
)));
|
||||
|
||||
Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _))
|
||||
}
|
||||
};
|
||||
Task::ready(result)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use adapters::latest_github_release;
|
||||
use anyhow::Context as _;
|
||||
use collections::HashMap;
|
||||
use dap::{StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
|
||||
use gpui::AsyncApp;
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
|
||||
use std::{path::PathBuf, sync::OnceLock};
|
||||
use task::DebugRequest;
|
||||
use util::{ResultExt, maybe};
|
||||
|
||||
|
@ -70,6 +71,8 @@ impl JsDebugAdapter {
|
|||
let tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
|
||||
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
|
||||
|
||||
let mut envs = HashMap::default();
|
||||
|
||||
let mut configuration = task_definition.config.clone();
|
||||
if let Some(configuration) = configuration.as_object_mut() {
|
||||
maybe!({
|
||||
|
@ -110,6 +113,12 @@ impl JsDebugAdapter {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(env) = configuration.get("env").cloned() {
|
||||
if let Ok(env) = serde_json::from_value(env) {
|
||||
envs = env;
|
||||
}
|
||||
}
|
||||
|
||||
configuration
|
||||
.entry("cwd")
|
||||
.or_insert(delegate.worktree_root_path().to_string_lossy().into());
|
||||
|
@ -158,7 +167,7 @@ impl JsDebugAdapter {
|
|||
),
|
||||
arguments,
|
||||
cwd: Some(delegate.worktree_root_path().to_path_buf()),
|
||||
envs: HashMap::default(),
|
||||
envs,
|
||||
connection: Some(adapters::TcpArguments {
|
||||
host,
|
||||
port,
|
||||
|
|
|
@ -33,7 +33,7 @@ use std::sync::{Arc, LazyLock};
|
|||
use task::{DebugScenario, TaskContext};
|
||||
use tree_sitter::{Query, StreamingIterator as _};
|
||||
use ui::{ContextMenu, Divider, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use util::maybe;
|
||||
use util::{ResultExt, maybe};
|
||||
use workspace::SplitDirection;
|
||||
use workspace::{
|
||||
Pane, Workspace,
|
||||
|
@ -363,11 +363,17 @@ impl DebugPanel {
|
|||
let label = curr_session.read(cx).label().clone();
|
||||
let adapter = curr_session.read(cx).adapter().clone();
|
||||
let binary = curr_session.read(cx).binary().cloned().unwrap();
|
||||
let task = curr_session.update(cx, |session, cx| session.shutdown(cx));
|
||||
let task_context = curr_session.read(cx).task_context().clone();
|
||||
|
||||
let curr_session_id = curr_session.read(cx).session_id();
|
||||
self.sessions
|
||||
.retain(|session| session.read(cx).session_id(cx) != curr_session_id);
|
||||
let task = dap_store_handle.update(cx, |dap_store, cx| {
|
||||
dap_store.shutdown_session(curr_session_id, cx)
|
||||
});
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
task.await;
|
||||
task.await.log_err();
|
||||
|
||||
let (session, task) = dap_store_handle.update(cx, |dap_store, cx| {
|
||||
let session = dap_store.new_session(label, adapter, task_context, None, cx);
|
||||
|
@ -1298,9 +1304,7 @@ impl Panel for DebugPanel {
|
|||
|
||||
impl Render for DebugPanel {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let has_sessions = self.sessions.len() > 0;
|
||||
let this = cx.weak_entity();
|
||||
debug_assert_eq!(has_sessions, self.active_session.is_some());
|
||||
|
||||
if self
|
||||
.active_session
|
||||
|
@ -1487,8 +1491,8 @@ impl Render for DebugPanel {
|
|||
}))
|
||||
})
|
||||
.map(|this| {
|
||||
if has_sessions {
|
||||
this.children(self.active_session.clone())
|
||||
if let Some(active_session) = self.active_session.clone() {
|
||||
this.child(active_session)
|
||||
} else {
|
||||
let docked_to_bottom = self.position(window, cx) == DockPosition::Bottom;
|
||||
let welcome_experience = v_flex()
|
||||
|
|
|
@ -121,7 +121,7 @@ impl DebugSession {
|
|||
.to_owned()
|
||||
}
|
||||
|
||||
pub(crate) fn running_state(&self) -> &Entity<RunningState> {
|
||||
pub fn running_state(&self) -> &Entity<RunningState> {
|
||||
&self.running_state
|
||||
}
|
||||
|
||||
|
|
|
@ -973,7 +973,7 @@ impl RunningState {
|
|||
|
||||
let task_with_shell = SpawnInTerminal {
|
||||
command_label,
|
||||
command,
|
||||
command: Some(command),
|
||||
args,
|
||||
..task.resolved.clone()
|
||||
};
|
||||
|
@ -1085,19 +1085,6 @@ impl RunningState {
|
|||
.map(PathBuf::from)
|
||||
.or_else(|| session.binary().unwrap().cwd.clone());
|
||||
|
||||
let mut args = request.args.clone();
|
||||
|
||||
// Handle special case for NodeJS debug adapter
|
||||
// If only the Node binary path is provided, we set the command to None
|
||||
// This prevents the NodeJS REPL from appearing, which is not the desired behavior
|
||||
// The expected usage is for users to provide their own Node command, e.g., `node test.js`
|
||||
// This allows the NodeJS debug client to attach correctly
|
||||
let command = if args.len() > 1 {
|
||||
Some(args.remove(0))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut envs: HashMap<String, String> =
|
||||
self.session.read(cx).task_context().project_env.clone();
|
||||
if let Some(Value::Object(env)) = &request.env {
|
||||
|
@ -1111,32 +1098,58 @@ impl RunningState {
|
|||
}
|
||||
}
|
||||
|
||||
let shell = project.read(cx).terminal_settings(&cwd, cx).shell.clone();
|
||||
let kind = if let Some(command) = command {
|
||||
let title = request.title.clone().unwrap_or(command.clone());
|
||||
TerminalKind::Task(task::SpawnInTerminal {
|
||||
id: task::TaskId("debug".to_string()),
|
||||
full_label: title.clone(),
|
||||
label: title.clone(),
|
||||
command: command.clone(),
|
||||
args,
|
||||
command_label: title.clone(),
|
||||
cwd,
|
||||
env: envs,
|
||||
use_new_terminal: true,
|
||||
allow_concurrent_runs: true,
|
||||
reveal: task::RevealStrategy::NoFocus,
|
||||
reveal_target: task::RevealTarget::Dock,
|
||||
hide: task::HideStrategy::Never,
|
||||
shell,
|
||||
show_summary: false,
|
||||
show_command: false,
|
||||
show_rerun: false,
|
||||
})
|
||||
let mut args = request.args.clone();
|
||||
let command = if envs.contains_key("VSCODE_INSPECTOR_OPTIONS") {
|
||||
// Handle special case for NodeJS debug adapter
|
||||
// If the Node binary path is provided (possibly with arguments like --experimental-network-inspection),
|
||||
// we set the command to None
|
||||
// This prevents the NodeJS REPL from appearing, which is not the desired behavior
|
||||
// The expected usage is for users to provide their own Node command, e.g., `node test.js`
|
||||
// This allows the NodeJS debug client to attach correctly
|
||||
if args
|
||||
.iter()
|
||||
.filter(|arg| !arg.starts_with("--"))
|
||||
.collect::<Vec<_>>()
|
||||
.len()
|
||||
> 1
|
||||
{
|
||||
Some(args.remove(0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else if args.len() > 0 {
|
||||
Some(args.remove(0))
|
||||
} else {
|
||||
TerminalKind::Shell(cwd.map(|c| c.to_path_buf()))
|
||||
None
|
||||
};
|
||||
|
||||
let shell = project.read(cx).terminal_settings(&cwd, cx).shell.clone();
|
||||
let title = request
|
||||
.title
|
||||
.clone()
|
||||
.filter(|title| !title.is_empty())
|
||||
.or_else(|| command.clone())
|
||||
.unwrap_or_else(|| "Debug terminal".to_string());
|
||||
let kind = TerminalKind::Task(task::SpawnInTerminal {
|
||||
id: task::TaskId("debug".to_string()),
|
||||
full_label: title.clone(),
|
||||
label: title.clone(),
|
||||
command: command.clone(),
|
||||
args,
|
||||
command_label: title.clone(),
|
||||
cwd,
|
||||
env: envs,
|
||||
use_new_terminal: true,
|
||||
allow_concurrent_runs: true,
|
||||
reveal: task::RevealStrategy::NoFocus,
|
||||
reveal_target: task::RevealTarget::Dock,
|
||||
hide: task::HideStrategy::Never,
|
||||
shell,
|
||||
show_summary: false,
|
||||
show_command: false,
|
||||
show_rerun: false,
|
||||
});
|
||||
|
||||
let workspace = self.workspace.clone();
|
||||
let weak_project = project.downgrade();
|
||||
|
||||
|
@ -1446,7 +1459,7 @@ impl RunningState {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn selected_thread_id(&self) -> Option<ThreadId> {
|
||||
pub fn selected_thread_id(&self) -> Option<ThreadId> {
|
||||
self.thread_id
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ async fn test_direct_attach_to_process(executor: BackgroundExecutor, cx: &mut Te
|
|||
let workspace = init_test_workspace(&project, cx).await;
|
||||
let cx = &mut VisualTestContext::from_window(*workspace, cx);
|
||||
|
||||
let session = start_debug_session_with(
|
||||
let _session = start_debug_session_with(
|
||||
&workspace,
|
||||
cx,
|
||||
DebugTaskDefinition {
|
||||
|
@ -59,14 +59,6 @@ async fn test_direct_attach_to_process(executor: BackgroundExecutor, cx: &mut Te
|
|||
assert!(workspace.active_modal::<AttachModal>(cx).is_none());
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let shutdown_session = project.update(cx, |project, cx| {
|
||||
project.dap_store().update(cx, |dap_store, cx| {
|
||||
dap_store.shutdown_session(session.read(cx).session_id(), cx)
|
||||
})
|
||||
});
|
||||
|
||||
shutdown_session.await.unwrap();
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::{
|
||||
ExtensionLibraryKind, ExtensionManifest, GrammarManifestEntry, parse_wasm_extension_version,
|
||||
ExtensionLibraryKind, ExtensionManifest, GrammarManifestEntry, build_debug_adapter_schema_path,
|
||||
parse_wasm_extension_version,
|
||||
};
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
|
@ -99,12 +100,8 @@ impl ExtensionBuilder {
|
|||
}
|
||||
|
||||
for (debug_adapter_name, meta) in &mut extension_manifest.debug_adapters {
|
||||
let debug_adapter_relative_schema_path =
|
||||
meta.schema_path.clone().unwrap_or_else(|| {
|
||||
Path::new("debug_adapter_schemas")
|
||||
.join(Path::new(debug_adapter_name.as_ref()).with_extension("json"))
|
||||
});
|
||||
let debug_adapter_schema_path = extension_dir.join(debug_adapter_relative_schema_path);
|
||||
let debug_adapter_schema_path =
|
||||
extension_dir.join(build_debug_adapter_schema_path(debug_adapter_name, meta));
|
||||
|
||||
let debug_adapter_schema = fs::read_to_string(&debug_adapter_schema_path)
|
||||
.with_context(|| {
|
||||
|
|
|
@ -132,6 +132,16 @@ impl ExtensionManifest {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn build_debug_adapter_schema_path(
|
||||
adapter_name: &Arc<str>,
|
||||
meta: &DebugAdapterManifestEntry,
|
||||
) -> PathBuf {
|
||||
meta.schema_path.clone().unwrap_or_else(|| {
|
||||
Path::new("debug_adapter_schemas")
|
||||
.join(Path::new(adapter_name.as_ref()).with_extension("json"))
|
||||
})
|
||||
}
|
||||
|
||||
/// A capability for an extension.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "kind")]
|
||||
|
@ -320,6 +330,29 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_adapter_schema_path_with_schema_path() {
|
||||
let adapter_name = Arc::from("my_adapter");
|
||||
let entry = DebugAdapterManifestEntry {
|
||||
schema_path: Some(PathBuf::from("foo/bar")),
|
||||
};
|
||||
|
||||
let path = build_debug_adapter_schema_path(&adapter_name, &entry);
|
||||
assert_eq!(path, PathBuf::from("foo/bar"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_adapter_schema_path_without_schema_path() {
|
||||
let adapter_name = Arc::from("my_adapter");
|
||||
let entry = DebugAdapterManifestEntry { schema_path: None };
|
||||
|
||||
let path = build_debug_adapter_schema_path(&adapter_name, &entry);
|
||||
assert_eq!(
|
||||
path,
|
||||
PathBuf::from("debug_adapter_schemas").join("my_adapter.json")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_exact_match() {
|
||||
let manifest = ExtensionManifest {
|
||||
|
|
|
@ -1633,6 +1633,23 @@ impl ExtensionStore {
|
|||
}
|
||||
}
|
||||
|
||||
for (adapter_name, meta) in loaded_extension.manifest.debug_adapters.iter() {
|
||||
let schema_path = &extension::build_debug_adapter_schema_path(adapter_name, meta);
|
||||
|
||||
if fs.is_file(&src_dir.join(schema_path)).await {
|
||||
match schema_path.parent() {
|
||||
Some(parent) => fs.create_dir(&tmp_dir.join(parent)).await?,
|
||||
None => {}
|
||||
}
|
||||
fs.copy_file(
|
||||
&src_dir.join(schema_path),
|
||||
&tmp_dir.join(schema_path),
|
||||
fs::CopyOptions::default(),
|
||||
)
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ use anyhow::{Context as _, Result};
|
|||
use client::{TypedEnvelope, proto};
|
||||
use collections::{HashMap, HashSet};
|
||||
use extension::{
|
||||
Extension, ExtensionHostProxy, ExtensionLanguageProxy, ExtensionLanguageServerProxy,
|
||||
ExtensionManifest,
|
||||
Extension, ExtensionDebugAdapterProviderProxy, ExtensionHostProxy, ExtensionLanguageProxy,
|
||||
ExtensionLanguageServerProxy, ExtensionManifest,
|
||||
};
|
||||
use fs::{Fs, RemoveOptions, RenameOptions};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
|
||||
|
@ -169,8 +169,9 @@ impl HeadlessExtensionStore {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let wasm_extension: Arc<dyn Extension> =
|
||||
Arc::new(WasmExtension::load(extension_dir, &manifest, wasm_host.clone(), &cx).await?);
|
||||
let wasm_extension: Arc<dyn Extension> = Arc::new(
|
||||
WasmExtension::load(extension_dir.clone(), &manifest, wasm_host.clone(), &cx).await?,
|
||||
);
|
||||
|
||||
for (language_server_id, language_server_config) in &manifest.language_servers {
|
||||
for language in language_server_config.languages() {
|
||||
|
@ -186,6 +187,24 @@ impl HeadlessExtensionStore {
|
|||
);
|
||||
})?;
|
||||
}
|
||||
for (debug_adapter, meta) in &manifest.debug_adapters {
|
||||
let schema_path = extension::build_debug_adapter_schema_path(debug_adapter, meta);
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
this.proxy.register_debug_adapter(
|
||||
wasm_extension.clone(),
|
||||
debug_adapter.clone(),
|
||||
&extension_dir.join(schema_path),
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
for debug_adapter in manifest.debug_locators.keys() {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.proxy
|
||||
.register_debug_locator(wasm_extension.clone(), debug_adapter.clone());
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -999,7 +999,7 @@ impl Extension {
|
|||
) -> Result<Result<DebugRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = resolved_build_task.into();
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
.call_run_dap_locator(store, &locator_name, &build_config_template)
|
||||
.await?
|
||||
|
|
|
@ -299,15 +299,17 @@ impl From<extension::DebugScenario> for DebugScenario {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<SpawnInTerminal> for ResolvedTask {
|
||||
fn from(value: SpawnInTerminal) -> Self {
|
||||
Self {
|
||||
impl TryFrom<SpawnInTerminal> for ResolvedTask {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: SpawnInTerminal) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
label: value.label,
|
||||
command: value.command,
|
||||
command: value.command.context("missing command")?,
|
||||
args: value.args,
|
||||
env: value.env.into_iter().collect(),
|
||||
cwd: value.cwd.map(|s| s.to_string_lossy().into_owned()),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -321,7 +321,7 @@ inventory::submit! {
|
|||
let language_settings_content_ref = generator
|
||||
.subschema_for::<LanguageSettingsContent>()
|
||||
.to_value();
|
||||
let schema = json_schema!({
|
||||
replace_subschema::<LanguageToSettingsMap>(generator, || json_schema!({
|
||||
"type": "object",
|
||||
"properties": params
|
||||
.language_names
|
||||
|
@ -333,8 +333,7 @@ inventory::submit! {
|
|||
)
|
||||
})
|
||||
.collect::<serde_json::Map<_, _>>()
|
||||
});
|
||||
replace_subschema::<LanguageToSettingsMap>(generator, schema)
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -812,9 +812,9 @@ mod tests {
|
|||
.await;
|
||||
|
||||
let executor = cx.executor();
|
||||
let registry = cx.new(|_| {
|
||||
let registry = cx.new(|cx| {
|
||||
let mut registry = ContextServerDescriptorRegistry::new();
|
||||
registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
|
||||
registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
|
||||
registry
|
||||
});
|
||||
let store = cx.new(|cx| {
|
||||
|
|
|
@ -103,19 +103,20 @@ struct ContextServerDescriptorRegistryProxy {
|
|||
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
|
||||
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
|
||||
self.context_server_factory_registry
|
||||
.update(cx, |registry, _| {
|
||||
.update(cx, |registry, cx| {
|
||||
registry.register_context_server_descriptor(
|
||||
id.clone(),
|
||||
Arc::new(ContextServerDescriptor { id, extension })
|
||||
as Arc<dyn registry::ContextServerDescriptor>,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
|
||||
self.context_server_factory_registry
|
||||
.update(cx, |registry, _| {
|
||||
registry.unregister_context_server_descriptor_by_id(&server_id)
|
||||
.update(cx, |registry, cx| {
|
||||
registry.unregister_context_server_descriptor_by_id(&server_id, cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use anyhow::Result;
|
|||
use collections::HashMap;
|
||||
use context_server::ContextServerCommand;
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, Task};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Global, Task};
|
||||
|
||||
use crate::worktree_store::WorktreeStore;
|
||||
|
||||
|
@ -66,12 +66,19 @@ impl ContextServerDescriptorRegistry {
|
|||
&mut self,
|
||||
id: Arc<str>,
|
||||
descriptor: Arc<dyn ContextServerDescriptor>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.context_servers.insert(id, descriptor);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
|
||||
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
|
||||
pub fn unregister_context_server_descriptor_by_id(
|
||||
&mut self,
|
||||
server_id: &str,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.context_servers.remove(server_id);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -119,7 +119,7 @@ impl DapLocator for CargoLocator {
|
|||
.context("Couldn't get cwd from debug config which is needed for locators")?;
|
||||
let builder = ShellBuilder::new(true, &build_config.shell).non_interactive();
|
||||
let (program, args) = builder.build(
|
||||
"cargo".into(),
|
||||
Some("cargo".into()),
|
||||
&build_config
|
||||
.args
|
||||
.iter()
|
||||
|
|
|
@ -660,6 +660,7 @@ pub struct Session {
|
|||
ignore_breakpoints: bool,
|
||||
exception_breakpoints: BTreeMap<String, (ExceptionBreakpointsFilter, IsEnabled)>,
|
||||
background_tasks: Vec<Task<()>>,
|
||||
restart_task: Option<Task<()>>,
|
||||
task_context: TaskContext,
|
||||
}
|
||||
|
||||
|
@ -821,6 +822,7 @@ impl Session {
|
|||
loaded_sources: Vec::default(),
|
||||
threads: IndexMap::default(),
|
||||
background_tasks: Vec::default(),
|
||||
restart_task: None,
|
||||
locations: Default::default(),
|
||||
is_session_terminated: false,
|
||||
ignore_breakpoints: false,
|
||||
|
@ -1865,18 +1867,30 @@ impl Session {
|
|||
}
|
||||
|
||||
pub fn restart(&mut self, args: Option<Value>, cx: &mut Context<Self>) {
|
||||
if self.capabilities.supports_restart_request.unwrap_or(false) && !self.is_terminated() {
|
||||
self.request(
|
||||
RestartCommand {
|
||||
raw: args.unwrap_or(Value::Null),
|
||||
},
|
||||
Self::fallback_to_manual_restart,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
} else {
|
||||
cx.emit(SessionStateEvent::Restart);
|
||||
if self.restart_task.is_some() || self.as_running().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
let supports_dap_restart =
|
||||
self.capabilities.supports_restart_request.unwrap_or(false) && !self.is_terminated();
|
||||
|
||||
self.restart_task = Some(cx.spawn(async move |this, cx| {
|
||||
let _ = this.update(cx, |session, cx| {
|
||||
if supports_dap_restart {
|
||||
session
|
||||
.request(
|
||||
RestartCommand {
|
||||
raw: args.unwrap_or(Value::Null),
|
||||
},
|
||||
Self::fallback_to_manual_restart,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
} else {
|
||||
cx.emit(SessionStateEvent::Restart);
|
||||
}
|
||||
});
|
||||
}));
|
||||
}
|
||||
|
||||
pub fn shutdown(&mut self, cx: &mut Context<Self>) -> Task<()> {
|
||||
|
@ -1914,8 +1928,13 @@ impl Session {
|
|||
|
||||
cx.emit(SessionStateEvent::Shutdown);
|
||||
|
||||
cx.spawn(async move |_, _| {
|
||||
cx.spawn(async move |this, cx| {
|
||||
task.await;
|
||||
let _ = this.update(cx, |this, _| {
|
||||
if let Some(adapter_client) = this.adapter_client() {
|
||||
adapter_client.kill();
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -568,7 +568,7 @@ async fn test_fallback_to_single_worktree_tasks(cx: &mut gpui::TestAppContext) {
|
|||
.into_iter()
|
||||
.map(|(source_kind, task)| {
|
||||
let resolved = task.resolved;
|
||||
(source_kind, resolved.command)
|
||||
(source_kind, resolved.command.unwrap())
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
vec![(
|
||||
|
|
|
@ -149,7 +149,7 @@ impl Project {
|
|||
let settings = self.terminal_settings(&path, cx).clone();
|
||||
|
||||
let builder = ShellBuilder::new(ssh_details.is_none(), &settings.shell).non_interactive();
|
||||
let (command, args) = builder.build(command, &Vec::new());
|
||||
let (command, args) = builder.build(Some(command), &Vec::new());
|
||||
|
||||
let mut env = self
|
||||
.environment
|
||||
|
@ -297,7 +297,10 @@ impl Project {
|
|||
.or_insert_with(|| "xterm-256color".to_string());
|
||||
let (program, args) = wrap_for_ssh(
|
||||
&ssh_command,
|
||||
Some((&spawn_task.command, &spawn_task.args)),
|
||||
spawn_task
|
||||
.command
|
||||
.as_ref()
|
||||
.map(|command| (command, &spawn_task.args)),
|
||||
path.as_deref(),
|
||||
env,
|
||||
python_venv_directory.as_deref(),
|
||||
|
@ -317,14 +320,16 @@ impl Project {
|
|||
add_environment_path(&mut env, &venv_path.join("bin")).log_err();
|
||||
}
|
||||
|
||||
(
|
||||
task_state,
|
||||
let shell = if let Some(program) = spawn_task.command {
|
||||
Shell::WithArguments {
|
||||
program: spawn_task.command,
|
||||
program,
|
||||
args: spawn_task.args,
|
||||
title_override: None,
|
||||
},
|
||||
)
|
||||
}
|
||||
} else {
|
||||
Shell::System
|
||||
};
|
||||
(task_state, shell)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -535,7 +535,7 @@ message DebugScenario {
|
|||
|
||||
message SpawnInTerminal {
|
||||
string label = 1;
|
||||
string command = 2;
|
||||
optional string command = 2;
|
||||
repeated string args = 3;
|
||||
map<string, string> env = 4;
|
||||
optional string cwd = 5;
|
||||
|
|
|
@ -23,35 +23,26 @@ inventory::collect!(ParameterizedJsonSchema);
|
|||
|
||||
const DEFS_PATH: &str = "#/$defs/";
|
||||
|
||||
/// Replaces the JSON schema definition for some type, and returns a reference to it.
|
||||
/// Replaces the JSON schema definition for some type if it is in use (in the definitions list), and
|
||||
/// returns a reference to it.
|
||||
///
|
||||
/// This asserts that JsonSchema::schema_name() + "2" does not exist because this indicates that
|
||||
/// there are multiple types that use this name, and unfortunately schemars APIs do not support
|
||||
/// resolving this ambiguity - see https://github.com/GREsau/schemars/issues/449
|
||||
///
|
||||
/// This takes a closure for `schema` because some settings types are not available on the remote
|
||||
/// server, and so will crash when attempting to access e.g. GlobalThemeRegistry.
|
||||
pub fn replace_subschema<T: JsonSchema>(
|
||||
generator: &mut schemars::SchemaGenerator,
|
||||
schema: schemars::Schema,
|
||||
schema: impl Fn() -> schemars::Schema,
|
||||
) -> schemars::Schema {
|
||||
// The key in definitions may not match T::schema_name() if multiple types have the same name.
|
||||
// This is a workaround for there being no straightforward way to get the key used for a type -
|
||||
// see https://github.com/GREsau/schemars/issues/449
|
||||
let ref_schema = generator.subschema_for::<T>();
|
||||
if let Some(serde_json::Value::String(definition_pointer)) = ref_schema.get("$ref") {
|
||||
if let Some(definition_name) = definition_pointer.strip_prefix(DEFS_PATH) {
|
||||
generator
|
||||
.definitions_mut()
|
||||
.insert(definition_name.to_string(), schema.to_value());
|
||||
return ref_schema;
|
||||
} else {
|
||||
log::error!(
|
||||
"bug: expected `$ref` field to start with {DEFS_PATH}, \
|
||||
got {definition_pointer}"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
log::error!("bug: expected `$ref` field in result of `subschema_for`");
|
||||
}
|
||||
// fallback on just using the schema name, which could collide.
|
||||
let schema_name = T::schema_name();
|
||||
generator
|
||||
.definitions_mut()
|
||||
.insert(schema_name.to_string(), schema.to_value());
|
||||
let definitions = generator.definitions_mut();
|
||||
assert!(!definitions.contains_key(&format!("{schema_name}2")));
|
||||
if definitions.contains_key(schema_name.as_ref()) {
|
||||
definitions.insert(schema_name.to_string(), schema().to_value());
|
||||
}
|
||||
Schema::new_ref(format!("{DEFS_PATH}{schema_name}"))
|
||||
}
|
||||
|
||||
|
|
|
@ -301,7 +301,12 @@ impl DebugTaskFile {
|
|||
.get_mut("properties")
|
||||
.and_then(|value| value.as_object_mut())
|
||||
{
|
||||
properties.remove("label");
|
||||
if properties.remove("label").is_none() {
|
||||
debug_panic!(
|
||||
"Generated TaskTemplate json schema did not have expected 'label' field. \
|
||||
Schema of 2nd alternative is: {template_object:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(arr) = template_object
|
||||
|
@ -311,13 +316,13 @@ impl DebugTaskFile {
|
|||
arr.retain(|v| v.as_str() != Some("label"));
|
||||
}
|
||||
} else {
|
||||
debug_panic!("Task Template schema in debug scenario's needs to be updated");
|
||||
debug_panic!(
|
||||
"Generated TaskTemplate json schema did not match expectations. \
|
||||
Schema is: {build_task_value:?}"
|
||||
);
|
||||
}
|
||||
|
||||
let task_definitions = build_task_value
|
||||
.get("definitions")
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let task_definitions = build_task_value.get("$defs").cloned().unwrap_or_default();
|
||||
|
||||
let adapter_conditions = schemas
|
||||
.0
|
||||
|
@ -375,7 +380,7 @@ impl DebugTaskFile {
|
|||
},
|
||||
"allOf": adapter_conditions
|
||||
},
|
||||
"definitions": task_definitions
|
||||
"$defs": task_definitions
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ pub struct SpawnInTerminal {
|
|||
/// Human readable name of the terminal tab.
|
||||
pub label: String,
|
||||
/// Executable command to spawn.
|
||||
pub command: String,
|
||||
pub command: Option<String>,
|
||||
/// Arguments to the command, potentially unsubstituted,
|
||||
/// to let the shell that spawns the command to do the substitution, if needed.
|
||||
pub args: Vec<String>,
|
||||
|
@ -387,20 +387,26 @@ impl ShellBuilder {
|
|||
}
|
||||
|
||||
/// Returns the program and arguments to run this task in a shell.
|
||||
pub fn build(mut self, task_command: String, task_args: &Vec<String>) -> (String, Vec<String>) {
|
||||
let combined_command = task_args
|
||||
.into_iter()
|
||||
.fold(task_command, |mut command, arg| {
|
||||
command.push(' ');
|
||||
command.push_str(&arg);
|
||||
command
|
||||
});
|
||||
self.args.extend(
|
||||
self.interactive
|
||||
.then(|| "-i".to_owned())
|
||||
pub fn build(
|
||||
mut self,
|
||||
task_command: Option<String>,
|
||||
task_args: &Vec<String>,
|
||||
) -> (String, Vec<String>) {
|
||||
if let Some(task_command) = task_command {
|
||||
let combined_command = task_args
|
||||
.into_iter()
|
||||
.chain(["-c".to_owned(), combined_command]),
|
||||
);
|
||||
.fold(task_command, |mut command, arg| {
|
||||
command.push(' ');
|
||||
command.push_str(&arg);
|
||||
command
|
||||
});
|
||||
self.args.extend(
|
||||
self.interactive
|
||||
.then(|| "-i".to_owned())
|
||||
.into_iter()
|
||||
.chain(["-c".to_owned(), combined_command]),
|
||||
);
|
||||
}
|
||||
|
||||
(self.program, self.args)
|
||||
}
|
||||
|
@ -428,21 +434,29 @@ impl ShellBuilder {
|
|||
}
|
||||
|
||||
/// Returns the program and arguments to run this task in a shell.
|
||||
pub fn build(mut self, task_command: String, task_args: &Vec<String>) -> (String, Vec<String>) {
|
||||
let combined_command = task_args
|
||||
.into_iter()
|
||||
.fold(task_command, |mut command, arg| {
|
||||
command.push(' ');
|
||||
command.push_str(&self.to_windows_shell_variable(arg.to_string()));
|
||||
command
|
||||
});
|
||||
pub fn build(
|
||||
mut self,
|
||||
task_command: Option<String>,
|
||||
task_args: &Vec<String>,
|
||||
) -> (String, Vec<String>) {
|
||||
if let Some(task_command) = task_command {
|
||||
let combined_command = task_args
|
||||
.into_iter()
|
||||
.fold(task_command, |mut command, arg| {
|
||||
command.push(' ');
|
||||
command.push_str(&self.to_windows_shell_variable(arg.to_string()));
|
||||
command
|
||||
});
|
||||
|
||||
match self.windows_shell_type() {
|
||||
WindowsShellType::Powershell => self.args.extend(["-C".to_owned(), combined_command]),
|
||||
WindowsShellType::Cmd => self.args.extend(["/C".to_owned(), combined_command]),
|
||||
WindowsShellType::Other => {
|
||||
self.args
|
||||
.extend(["-i".to_owned(), "-c".to_owned(), combined_command])
|
||||
match self.windows_shell_type() {
|
||||
WindowsShellType::Powershell => {
|
||||
self.args.extend(["-C".to_owned(), combined_command])
|
||||
}
|
||||
WindowsShellType::Cmd => self.args.extend(["/C".to_owned(), combined_command]),
|
||||
WindowsShellType::Other => {
|
||||
self.args
|
||||
.extend(["-i".to_owned(), "-c".to_owned(), combined_command])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
172
crates/task/src/shell_builder.rs
Normal file
172
crates/task/src/shell_builder.rs
Normal file
|
@ -0,0 +1,172 @@
|
|||
use crate::Shell;
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
|
||||
enum ShellKind {
|
||||
#[default]
|
||||
Posix,
|
||||
Powershell,
|
||||
Cmd,
|
||||
}
|
||||
|
||||
impl ShellKind {
|
||||
fn new(program: &str) -> Self {
|
||||
if program == "powershell"
|
||||
|| program.ends_with("powershell.exe")
|
||||
|| program == "pwsh"
|
||||
|| program.ends_with("pwsh.exe")
|
||||
{
|
||||
ShellKind::Powershell
|
||||
} else if program == "cmd" || program.ends_with("cmd.exe") {
|
||||
ShellKind::Cmd
|
||||
} else {
|
||||
// Someother shell detected, the user might install and use a
|
||||
// unix-like shell.
|
||||
ShellKind::Posix
|
||||
}
|
||||
}
|
||||
|
||||
fn to_shell_variable(&self, input: &str) -> String {
|
||||
match self {
|
||||
Self::Powershell => Self::to_powershell_variable(input),
|
||||
Self::Cmd => Self::to_cmd_variable(input),
|
||||
Self::Posix => input.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_cmd_variable(input: &str) -> String {
|
||||
if let Some(var_str) = input.strip_prefix("${") {
|
||||
if var_str.find(':').is_none() {
|
||||
// If the input starts with "${", remove the trailing "}"
|
||||
format!("%{}%", &var_str[..var_str.len() - 1])
|
||||
} else {
|
||||
// `${SOME_VAR:-SOME_DEFAULT}`, we currently do not handle this situation,
|
||||
// which will result in the task failing to run in such cases.
|
||||
input.into()
|
||||
}
|
||||
} else if let Some(var_str) = input.strip_prefix('$') {
|
||||
// If the input starts with "$", directly append to "$env:"
|
||||
format!("%{}%", var_str)
|
||||
} else {
|
||||
// If no prefix is found, return the input as is
|
||||
input.into()
|
||||
}
|
||||
}
|
||||
fn to_powershell_variable(input: &str) -> String {
|
||||
if let Some(var_str) = input.strip_prefix("${") {
|
||||
if var_str.find(':').is_none() {
|
||||
// If the input starts with "${", remove the trailing "}"
|
||||
format!("$env:{}", &var_str[..var_str.len() - 1])
|
||||
} else {
|
||||
// `${SOME_VAR:-SOME_DEFAULT}`, we currently do not handle this situation,
|
||||
// which will result in the task failing to run in such cases.
|
||||
input.into()
|
||||
}
|
||||
} else if let Some(var_str) = input.strip_prefix('$') {
|
||||
// If the input starts with "$", directly append to "$env:"
|
||||
format!("$env:{}", var_str)
|
||||
} else {
|
||||
// If no prefix is found, return the input as is
|
||||
input.into()
|
||||
}
|
||||
}
|
||||
|
||||
fn args_for_shell(&self, interactive: bool, combined_command: String) -> Vec<String> {
|
||||
match self {
|
||||
ShellKind::Powershell => vec!["-C".to_owned(), combined_command],
|
||||
ShellKind::Cmd => vec!["/C".to_owned(), combined_command],
|
||||
ShellKind::Posix => interactive
|
||||
.then(|| "-i".to_owned())
|
||||
.into_iter()
|
||||
.chain(["-c".to_owned(), combined_command])
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn system_shell() -> String {
|
||||
if cfg!(target_os = "windows") {
|
||||
// `alacritty_terminal` uses this as default on Windows. See:
|
||||
// https://github.com/alacritty/alacritty/blob/0d4ab7bca43213d96ddfe40048fc0f922543c6f8/alacritty_terminal/src/tty/windows/mod.rs#L130
|
||||
// We could use `util::get_windows_system_shell()` here, but we are running tasks here, so leave it to `powershell.exe`
|
||||
// should be okay.
|
||||
"powershell.exe".to_string()
|
||||
} else {
|
||||
std::env::var("SHELL").unwrap_or("/bin/sh".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// ShellBuilder is used to turn a user-requested task into a
|
||||
/// program that can be executed by the shell.
|
||||
pub struct ShellBuilder {
|
||||
/// The shell to run
|
||||
program: String,
|
||||
args: Vec<String>,
|
||||
interactive: bool,
|
||||
kind: ShellKind,
|
||||
}
|
||||
|
||||
pub static DEFAULT_REMOTE_SHELL: &str = "\"${SHELL:-sh}\"";
|
||||
|
||||
impl ShellBuilder {
|
||||
/// Create a new ShellBuilder as configured.
|
||||
pub fn new(is_local: bool, shell: &Shell) -> Self {
|
||||
let (program, args) = match shell {
|
||||
Shell::System => {
|
||||
if is_local {
|
||||
(system_shell(), Vec::new())
|
||||
} else {
|
||||
(DEFAULT_REMOTE_SHELL.to_string(), Vec::new())
|
||||
}
|
||||
}
|
||||
Shell::Program(shell) => (shell.clone(), Vec::new()),
|
||||
Shell::WithArguments { program, args, .. } => (program.clone(), args.clone()),
|
||||
};
|
||||
let kind = ShellKind::new(&program);
|
||||
Self {
|
||||
program,
|
||||
args,
|
||||
interactive: true,
|
||||
kind,
|
||||
}
|
||||
}
|
||||
pub fn non_interactive(mut self) -> Self {
|
||||
self.interactive = false;
|
||||
self
|
||||
}
|
||||
/// Returns the label to show in the terminal tab
|
||||
pub fn command_label(&self, command_label: &str) -> String {
|
||||
match self.kind {
|
||||
ShellKind::Powershell => {
|
||||
format!("{} -C '{}'", self.program, command_label)
|
||||
}
|
||||
ShellKind::Cmd => {
|
||||
format!("{} /C '{}'", self.program, command_label)
|
||||
}
|
||||
ShellKind::Posix => {
|
||||
let interactivity = self.interactive.then_some("-i ").unwrap_or_default();
|
||||
format!("{} {interactivity}-c '{}'", self.program, command_label)
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns the program and arguments to run this task in a shell.
|
||||
pub fn build(
|
||||
mut self,
|
||||
task_command: Option<String>,
|
||||
task_args: &Vec<String>,
|
||||
) -> (String, Vec<String>) {
|
||||
if let Some(task_command) = task_command {
|
||||
let combined_command = task_args
|
||||
.into_iter()
|
||||
.fold(task_command, |mut command, arg| {
|
||||
command.push(' ');
|
||||
command.push_str(&self.kind.to_shell_variable(arg));
|
||||
command
|
||||
});
|
||||
|
||||
self.args
|
||||
.extend(self.kind.args_for_shell(self.interactive, combined_command));
|
||||
}
|
||||
|
||||
(self.program, self.args)
|
||||
}
|
||||
}
|
|
@ -253,7 +253,7 @@ impl TaskTemplate {
|
|||
command_label
|
||||
},
|
||||
),
|
||||
command,
|
||||
command: Some(command),
|
||||
args: self.args.clone(),
|
||||
env,
|
||||
use_new_terminal: self.use_new_terminal,
|
||||
|
@ -633,7 +633,7 @@ mod tests {
|
|||
"Human-readable label should have long substitutions trimmed"
|
||||
);
|
||||
assert_eq!(
|
||||
spawn_in_terminal.command,
|
||||
spawn_in_terminal.command.clone().unwrap(),
|
||||
format!("echo test_file {long_value}"),
|
||||
"Command should be substituted with variables and those should not be shortened"
|
||||
);
|
||||
|
@ -650,7 +650,7 @@ mod tests {
|
|||
spawn_in_terminal.command_label,
|
||||
format!(
|
||||
"{} arg1 test_selected_text arg2 5678 arg3 {long_value}",
|
||||
spawn_in_terminal.command
|
||||
spawn_in_terminal.command.clone().unwrap()
|
||||
),
|
||||
"Command label args should be substituted with variables and those should not be shortened"
|
||||
);
|
||||
|
@ -709,7 +709,7 @@ mod tests {
|
|||
assert_substituted_variables(&resolved_task, Vec::new());
|
||||
let resolved = resolved_task.resolved;
|
||||
assert_eq!(resolved.label, task.label);
|
||||
assert_eq!(resolved.command, task.command);
|
||||
assert_eq!(resolved.command, Some(task.command));
|
||||
assert_eq!(resolved.args, task.args);
|
||||
}
|
||||
|
||||
|
|
|
@ -499,7 +499,7 @@ impl TerminalPanel {
|
|||
|
||||
let task = SpawnInTerminal {
|
||||
command_label,
|
||||
command,
|
||||
command: Some(command),
|
||||
args,
|
||||
..task.clone()
|
||||
};
|
||||
|
|
|
@ -978,11 +978,10 @@ pub struct ThemeName(pub Arc<str>);
|
|||
inventory::submit! {
|
||||
ParameterizedJsonSchema {
|
||||
add_and_get_ref: |generator, _params, cx| {
|
||||
let schema = json_schema!({
|
||||
replace_subschema::<ThemeName>(generator, || json_schema!({
|
||||
"type": "string",
|
||||
"enum": ThemeRegistry::global(cx).list_names(),
|
||||
});
|
||||
replace_subschema::<ThemeName>(generator, schema)
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -996,15 +995,14 @@ pub struct IconThemeName(pub Arc<str>);
|
|||
inventory::submit! {
|
||||
ParameterizedJsonSchema {
|
||||
add_and_get_ref: |generator, _params, cx| {
|
||||
let schema = json_schema!({
|
||||
replace_subschema::<IconThemeName>(generator, || json_schema!({
|
||||
"type": "string",
|
||||
"enum": ThemeRegistry::global(cx)
|
||||
.list_icon_themes()
|
||||
.into_iter()
|
||||
.map(|icon_theme| icon_theme.name)
|
||||
.collect::<Vec<SharedString>>(),
|
||||
});
|
||||
replace_subschema::<IconThemeName>(generator, schema)
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1018,11 +1016,12 @@ pub struct FontFamilyName(pub Arc<str>);
|
|||
inventory::submit! {
|
||||
ParameterizedJsonSchema {
|
||||
add_and_get_ref: |generator, params, _cx| {
|
||||
let schema = json_schema!({
|
||||
"type": "string",
|
||||
"enum": params.font_names,
|
||||
});
|
||||
replace_subschema::<FontFamilyName>(generator, schema)
|
||||
replace_subschema::<FontFamilyName>(generator, || {
|
||||
json_schema!({
|
||||
"type": "string",
|
||||
"enum": params.font_names,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1669,7 +1669,7 @@ impl ShellExec {
|
|||
id: TaskId("vim".to_string()),
|
||||
full_label: command.clone(),
|
||||
label: command.clone(),
|
||||
command: command.clone(),
|
||||
command: Some(command.clone()),
|
||||
args: Vec::new(),
|
||||
command_label: command.clone(),
|
||||
cwd,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
description = "The fast, collaborative code editor."
|
||||
edition.workspace = true
|
||||
name = "zed"
|
||||
version = "0.194.0"
|
||||
version = "0.194.3"
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
authors = ["Zed Team <hi@zed.dev>"]
|
||||
|
|
|
@ -1 +1 @@
|
|||
dev
|
||||
stable
|
Loading…
Add table
Add a link
Reference in a new issue