Merge branch 'main' into push-trzsxukkukpr
This commit is contained in:
commit
6522ef5456
142 changed files with 4750 additions and 2275 deletions
49
Cargo.lock
generated
49
Cargo.lock
generated
|
@ -17,6 +17,7 @@ dependencies = [
|
|||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"markdown",
|
||||
"parking_lot",
|
||||
"project",
|
||||
|
@ -137,9 +138,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.18"
|
||||
version = "0.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8e4c1dccb35e69d32566f0d11948d902f9942fc3f038821816c1150cf5925f4"
|
||||
checksum = "12dbfec3d27680337ed9d3064eecafe97acf0b0f190148bb4e29d96707c9e403"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
|
@ -150,6 +151,44 @@ dependencies = [
|
|||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp_thread",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"anyhow",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"indoc",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
"workspace-hack",
|
||||
"worktree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_servers"
|
||||
version = "0.1.0"
|
||||
|
@ -214,6 +253,7 @@ dependencies = [
|
|||
"acp_thread",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent2",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
"ai_onboarding",
|
||||
|
@ -9209,6 +9249,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"async-compression",
|
||||
"async-fs",
|
||||
"async-tar",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
|
@ -9240,9 +9281,11 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"sha2",
|
||||
"smol",
|
||||
"snippet_provider",
|
||||
"task",
|
||||
"tempfile",
|
||||
"text",
|
||||
"theme",
|
||||
"toml 0.8.20",
|
||||
|
@ -20376,7 +20419,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.199.0"
|
||||
version = "0.200.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
|
|
|
@ -4,6 +4,7 @@ members = [
|
|||
"crates/acp_thread",
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent2",
|
||||
"crates/agent_servers",
|
||||
"crates/agent_settings",
|
||||
"crates/agent_ui",
|
||||
|
@ -229,6 +230,7 @@ edition = "2024"
|
|||
|
||||
acp_thread = { path = "crates/acp_thread" }
|
||||
agent = { path = "crates/agent" }
|
||||
agent2 = { path = "crates/agent2" }
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
agent_ui = { path = "crates/agent_ui" }
|
||||
agent_settings = { path = "crates/agent_settings" }
|
||||
|
@ -423,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
|||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.18"
|
||||
agent-client-protocol = "0.0.20"
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
|
|
1
assets/images/certified_user_stamp.svg
Normal file
1
assets/images/certified_user_stamp.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 6.4 KiB |
1
assets/images/pro_trial_stamp.svg
Normal file
1
assets/images/pro_trial_stamp.svg
Normal file
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="51" height="51" fill="none"><path fill="#000" fill-opacity=".15" d="M45 3a3 3 0 0 1 3 3v39a3 3 0 0 1-3 3H6a3 3 0 0 1-3-3V6a3 3 0 0 1 3-3h39ZM10 7a3 3 0 0 0-3 3v31a3 3 0 0 0 3 3h31a3 3 0 0 0 3-3V10a3 3 0 0 0-3-3H10Z"/><rect width="36" height="36" x="7.5" y="7.5" stroke="#000" stroke-dasharray="2 2" rx=".5"/><rect width="44" height="44" x="3.5" y="3.5" stroke="#000" stroke-dasharray="2 2" rx="2.5"/><path fill="#000" d="M28.636 13.124c.732 0 1.27.617 1.27 1.464s-.538 1.465-1.27 1.465-1.27-.618-1.27-1.465c0-.847.538-1.464 1.27-1.464Zm-6.066-.784c.69 0 1.171.465 1.171 1.124s-.48 1.124-1.17 1.124h-.66V16h-.653v-3.66h1.312Zm3.323.8c.46 0 .784.319.837.82h-.649c-.042-.166-.188-.27-.376-.27-.319 0-.544.297-.544.73V16h-.627v-2.823h.6v.382h.022a.844.844 0 0 1 .737-.418Zm2.743.56c-.382 0-.643.36-.643.888s.261.89.643.89c.381 0 .643-.362.643-.89s-.262-.889-.643-.889Zm-6.725.314h.686c.308 0 .517-.22.517-.55 0-.33-.21-.549-.517-.549h-.686v1.099ZM28.63 36.124c.649 0 1.083.366 1.083.91v.957c0 .382.036.721.104 1.009h-.606a2.793 2.793 0 0 1-.078-.366h-.027c-.151.266-.413.419-.757.419-.539 0-.915-.356-.915-.868 0-.523.397-.811 1.228-.884l.173-.016c.177-.016.25-.09.25-.246 0-.24-.172-.392-.465-.392-.267 0-.46.126-.48.314h-.613c.006-.486.466-.837 1.104-.837Zm3.413 1.867c0 .355.167.527.512.528.146 0 .303-.031.455-.1v.523c-.157.074-.361.11-.575.11-.67 0-1.02-.366-1.02-1.06v-2.286h-.81v-.575h1.438v2.86Zm-11.452-2.076h-1.03V39h-.654v-3.085h-1.035v-.575h2.719v.575Zm2.23.226c.46 0 .784.318.836.82h-.648c-.042-.167-.188-.271-.376-.271-.32 0-.544.297-.544.73V39h-.627v-2.823h.6v.382h.022a.844.844 0 0 1 .737-.418Zm2.993 2.284h.785V39h-2.196v-.575h.785v-1.673h-.758v-.575h1.384v2.248Zm3.272-.784a.583.583 0 0 1-.204.062l-.168.022c-.46.057-.653.193-.653.449 0 .23.162.382.418.382.34 0 .607-.272.607-.622v-.293Zm-3.638-2.782c.267 0 .477.21.477.476a.477.477 0 0 1-.952 0c0-.267.214-.475.475-.476Z"/><path fill="#000" fill-rule="evenodd" d="M20.219 19.813a.406.406 0 0 0-.407.406v8.937H19V20.22c0-.673.546-1.219 1.219-1.219h10.884a.61.61 0 0 1 .431 1.04l-6.704 6.704h1.889v-.838h.812v1.041a.61.61 0 0 1-.61.61h-2.903l-1.397 1.396h6.332v-5.078h.813v5.078a.812.812 0 0 1-.813.813H21.81l-1.422 1.422h10.394a.406.406 0 0 0 .407-.407v-8.937H32v8.937c0 .673-.546 1.219-1.219 1.219H19.897a.61.61 0 0 1-.431-1.04l6.678-6.679h-1.863v.813h-.812v-1.016a.61.61 0 0 1 .61-.61h2.878l1.422-1.421h-6.332v5.078h-.813v-5.078c0-.449.364-.813.813-.813h7.144l1.422-1.422H20.219Z" clip-rule="evenodd"/></svg>
|
After Width: | Height: | Size: 2.5 KiB |
|
@ -25,6 +25,7 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
markdown.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
|
|
|
@ -18,6 +18,7 @@ use project::{AgentLocation, Project};
|
|||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt::Formatter;
|
||||
use std::process::ExitStatus;
|
||||
use std::rc::Rc;
|
||||
use std::{
|
||||
fmt::Display,
|
||||
|
@ -581,6 +582,7 @@ pub enum AcpThreadEvent {
|
|||
ToolAuthorizationRequired,
|
||||
Stopped,
|
||||
Error,
|
||||
ServerExited(ExitStatus),
|
||||
}
|
||||
|
||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||
|
@ -654,6 +656,10 @@ impl AcpThread {
|
|||
&self.entries
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> &acp::SessionId {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
pub fn status(&self) -> ThreadStatus {
|
||||
if self.send_task.is_some() {
|
||||
if self.waiting_for_tool_confirmation() {
|
||||
|
@ -1229,6 +1235,10 @@ impl AcpThread {
|
|||
pub fn to_markdown(&self, cx: &App) -> String {
|
||||
self.entries.iter().map(|e| e.to_markdown(cx)).collect()
|
||||
}
|
||||
|
||||
pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
|
||||
cx.emit(AcpThreadEvent::ServerExited(status));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1371,6 +1381,9 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
})?;
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
|
@ -1443,7 +1456,9 @@ mod tests {
|
|||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
Ok(())
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
},
|
||||
|
@ -1516,7 +1531,9 @@ mod tests {
|
|||
})
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
Ok(())
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
@ -1626,7 +1643,9 @@ mod tests {
|
|||
})
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
Ok(())
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
@ -1680,7 +1699,7 @@ mod tests {
|
|||
acp::PromptRequest,
|
||||
WeakEntity<AcpThread>,
|
||||
AsyncApp,
|
||||
) -> LocalBoxFuture<'static, Result<()>>
|
||||
) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
|
||||
+ 'static,
|
||||
>,
|
||||
>,
|
||||
|
@ -1707,7 +1726,7 @@ mod tests {
|
|||
acp::PromptRequest,
|
||||
WeakEntity<AcpThread>,
|
||||
AsyncApp,
|
||||
) -> LocalBoxFuture<'static, Result<()>>
|
||||
) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
|
||||
+ 'static,
|
||||
) -> Self {
|
||||
self.on_user_message.replace(Rc::new(handler));
|
||||
|
@ -1749,7 +1768,11 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn prompt(
|
||||
&self,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
if let Some(handler) = &self.on_user_message {
|
||||
|
@ -1757,7 +1780,9 @@ mod tests {
|
|||
let thread = thread.clone();
|
||||
cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
|
||||
} else {
|
||||
Task::ready(Ok(()))
|
||||
Task::ready(Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,13 +1,61 @@
|
|||
use std::{error::Error, fmt, path::Path, rc::Rc};
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language_model::LanguageModel;
|
||||
use project::Project;
|
||||
use ui::App;
|
||||
|
||||
use crate::AcpThread;
|
||||
|
||||
/// Trait for agents that support listing, selecting, and querying language models.
|
||||
///
|
||||
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
|
||||
pub trait ModelSelector: 'static {
|
||||
/// Lists all available language models for this agent.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `cx`: The GPUI app context for async operations and global access.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the list of models or an error (e.g., if no models are configured).
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
|
||||
|
||||
/// Selects a model for a specific session (thread).
|
||||
///
|
||||
/// This sets the default model for future interactions in the session.
|
||||
/// If the session doesn't exist or the model is invalid, it returns an error.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `session_id`: The ID of the session (thread) to apply the model to.
|
||||
/// - `model`: The model to select (should be one from [list_models]).
|
||||
/// - `cx`: The GPUI app context.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to `Ok(())` on success or an error.
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>>;
|
||||
|
||||
/// Retrieves the currently selected model for a specific session (thread).
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `session_id`: The ID of the session (thread) to query.
|
||||
/// - `cx`: The GPUI app context.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>>;
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
|
@ -20,9 +68,18 @@ pub trait AgentConnection {
|
|||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
|
||||
-> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
None // Default impl for agents that don't support it
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::{
|
|||
},
|
||||
tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
|
||||
};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
@ -2112,12 +2112,10 @@ impl Thread {
|
|||
return;
|
||||
}
|
||||
|
||||
let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
|
||||
|
||||
let request = self.to_summarize_request(
|
||||
&model.model,
|
||||
CompletionIntent::ThreadSummarization,
|
||||
added_user_message.into(),
|
||||
SUMMARIZE_THREAD_PROMPT.into(),
|
||||
cx,
|
||||
);
|
||||
|
||||
|
@ -4047,8 +4045,8 @@ fn main() {{
|
|||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Brief");
|
||||
fake_model.stream_last_completion_response(" Introduction");
|
||||
fake_model.send_last_completion_stream_text_chunk("Brief");
|
||||
fake_model.send_last_completion_stream_text_chunk(" Introduction");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
|
@ -4141,7 +4139,7 @@ fn main() {{
|
|||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("A successful summary");
|
||||
fake_model.send_last_completion_stream_text_chunk("A successful summary");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
|
@ -4774,7 +4772,7 @@ fn main() {{
|
|||
!pending.is_empty(),
|
||||
"Should have a pending completion after retry"
|
||||
);
|
||||
fake_model.stream_completion_response(&pending[0], "Success!");
|
||||
fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
|
||||
fake_model.end_completion_stream(&pending[0]);
|
||||
cx.run_until_parked();
|
||||
|
||||
|
@ -4942,7 +4940,7 @@ fn main() {{
|
|||
|
||||
// Check for pending completions and complete them
|
||||
if let Some(pending) = inner_fake.pending_completions().first() {
|
||||
inner_fake.stream_completion_response(pending, "Success!");
|
||||
inner_fake.send_completion_stream_text_chunk(pending, "Success!");
|
||||
inner_fake.end_completion_stream(pending);
|
||||
}
|
||||
cx.run_until_parked();
|
||||
|
@ -5427,7 +5425,7 @@ fn main() {{
|
|||
|
||||
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Assistant response");
|
||||
fake_model.send_last_completion_stream_text_chunk("Assistant response");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
|
54
crates/agent2/Cargo.toml
Normal file
54
crates/agent2/Cargo.toml
Normal file
|
@ -0,0 +1,54 @@
|
|||
[package]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/agent2.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
indoc.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
project.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
worktree.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
clock = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
worktree = { workspace = true, "features" = ["test-support"] }
|
1
crates/agent2/LICENSE-GPL
Symbolic link
1
crates/agent2/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
345
crates/agent2/src/agent.rs
Normal file
345
crates/agent2/src/agent.rs
Normal file
|
@ -0,0 +1,345 @@
|
|||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::StreamExt;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use project::Project;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
||||
|
||||
/// Holds both the internal Thread and the AcpThread for a session
|
||||
struct Session {
|
||||
/// The internal thread that processes messages
|
||||
thread: Entity<Thread>,
|
||||
/// The ACP thread that handles protocol communication
|
||||
acp_thread: WeakEntity<acp_thread::AcpThread>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl NativeAgent {
|
||||
pub fn new(templates: Arc<Templates>) -> Self {
|
||||
log::info!("Creating new NativeAgent");
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
templates,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl ModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
log::info!("Found {} available models", models.len());
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!(
|
||||
"Setting model for session {}: {:?}",
|
||||
session_id,
|
||||
model.name()
|
||||
);
|
||||
let agent = self.0.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(session) = agent.sessions.get(&session_id) {
|
||||
session.thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let thread = agent
|
||||
.read_with(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
})?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
let agent = self.0.clone();
|
||||
log::info!("Creating new thread for project at: {:?}", cwd);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Starting thread creation in async context");
|
||||
// Create Thread
|
||||
let (session_id, thread) = agent.update(
|
||||
cx,
|
||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
|
||||
// Log available models for debugging
|
||||
let available_count = registry.available_models(cx).count();
|
||||
log::debug!("Total available models: {}", available_count);
|
||||
|
||||
let default_model = registry
|
||||
.default_model()
|
||||
.map(|configured| {
|
||||
log::info!(
|
||||
"Using configured default model: {:?} from provider: {:?}",
|
||||
configured.model.name(),
|
||||
configured.provider.name()
|
||||
);
|
||||
configured.model
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
Ok((session_id, thread))
|
||||
},
|
||||
)??;
|
||||
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread,
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
})
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
||||
Ok(acp_thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[] // No auth for in-process
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Get session
|
||||
let (thread, acp_thread) = agent
|
||||
.update(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
log::error!("Session not found: {}", session_id);
|
||||
anyhow::anyhow!("Session not found")
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
// Convert prompt to message
|
||||
let message = convert_prompt_to_message(params.prompt);
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
log::debug!("Message content: {}", message);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let mut response_stream =
|
||||
thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentThoughtChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCall(tool_call),
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCallUpdate(tool_call_update),
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
// TODO: Consider sending an error message to the UI
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
log::info!("Cancelling on session: {}", session_id);
|
||||
self.0.update(cx, |agent, cx| {
|
||||
if let Some(agent) = agent.sessions.get(session_id) {
|
||||
agent.thread.update(cx, |thread, _cx| thread.cancel());
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert ACP content blocks to a message string
|
||||
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||
log::debug!("Converting {} content blocks to message", blocks.len());
|
||||
let mut message = String::new();
|
||||
|
||||
for block in blocks {
|
||||
match block {
|
||||
acp::ContentBlock::Text(text) => {
|
||||
log::trace!("Processing text block: {} chars", text.text.len());
|
||||
message.push_str(&text.text);
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(link) => {
|
||||
log::trace!("Processing resource link: {}", link.uri);
|
||||
message.push_str(&format!(" @{} ", link.uri));
|
||||
}
|
||||
acp::ContentBlock::Image(_) => {
|
||||
log::trace!("Processing image block");
|
||||
message.push_str(" [image] ");
|
||||
}
|
||||
acp::ContentBlock::Audio(_) => {
|
||||
log::trace!("Processing audio block");
|
||||
message.push_str(" [audio] ");
|
||||
}
|
||||
acp::ContentBlock::Resource(resource) => {
|
||||
log::trace!("Processing resource block: {:?}", resource.resource);
|
||||
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message
|
||||
}
|
13
crates/agent2/src/agent2.rs
Normal file
13
crates/agent2/src/agent2.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
mod agent;
|
||||
mod native_agent_server;
|
||||
mod prompts;
|
||||
mod templates;
|
||||
mod thread;
|
||||
mod tools;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use agent::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use thread::*;
|
58
crates/agent2/src/native_agent_server.rs
Normal file
58
crates/agent2/src/native_agent_server.rs
Normal file
|
@ -0,0 +1,58 @@
|
|||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer;
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn name(&self) -> &'static str {
|
||||
"Native Agent"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Native Agent"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"How can I help you today?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
// Using the ZedAssistant icon as it's the native built-in agent
|
||||
ui::IconName::ZedAssistant
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
|
||||
log::info!(
|
||||
"NativeAgentServer::connect called for path: {:?}",
|
||||
_root_dir
|
||||
);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
// Create templates (you might want to load these from files or resources)
|
||||
let templates = Templates::new();
|
||||
|
||||
// Create the native agent
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
log::info!("NativeAgentServer connection established successfully");
|
||||
|
||||
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
|
||||
})
|
||||
}
|
||||
}
|
35
crates/agent2/src/prompts.rs
Normal file
35
crates/agent2/src/prompts.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
use crate::{
|
||||
templates::{BaseTemplate, Template, Templates, WorktreeData},
|
||||
thread::Prompt,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity};
|
||||
use project::Project;
|
||||
|
||||
pub struct BasePrompt {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl BasePrompt {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self { project }
|
||||
}
|
||||
}
|
||||
|
||||
impl Prompt for BasePrompt {
|
||||
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
|
||||
BaseTemplate {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
shell: util::get_system_shell(),
|
||||
worktrees: self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| WorktreeData {
|
||||
root_name: worktree.read(cx).root_name().to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
.render(templates)
|
||||
}
|
||||
}
|
57
crates/agent2/src/templates.rs
Normal file
57
crates/agent2/src/templates.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct BaseTemplate {
|
||||
pub os: String,
|
||||
pub shell: String,
|
||||
pub worktrees: Vec<WorktreeData>,
|
||||
}
|
||||
|
||||
impl Template for BaseTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "base.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeData {
|
||||
pub root_name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GlobTemplate {
|
||||
pub project_roots: String,
|
||||
}
|
||||
|
||||
impl Template for GlobTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
||||
}
|
56
crates/agent2/src/templates/base.hbs
Normal file
56
crates/agent2/src/templates/base.hbs
Normal file
|
@ -0,0 +1,56 @@
|
|||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
8
crates/agent2/src/templates/glob.hbs
Normal file
8
crates/agent2/src/templates/glob.hbs
Normal file
|
@ -0,0 +1,8 @@
|
|||
Find paths on disk with glob patterns.
|
||||
|
||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
||||
|
||||
{{project_roots}}
|
||||
|
||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
||||
sure to anchor them with one of the directories listed above.
|
534
crates/agent2/src/tests/mod.rs
Normal file
534
crates/agent2/src/tests/mod.rs
Normal file
|
@ -0,0 +1,534 @@
|
|||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use client::{Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
|
||||
StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_thinking(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
|
||||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
model.clone(),
|
||||
indoc! {"
|
||||
Testing:
|
||||
|
||||
Generate a thinking step where you just think the word 'Think',
|
||||
and have your final answer be 'Hello'
|
||||
"},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## assistant
|
||||
<think>Think</think>
|
||||
Hello
|
||||
"}
|
||||
)
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
|
||||
// Test a tool calls that's likely to complete *after* streaming stops.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert!(thread
|
||||
.messages()
|
||||
.last()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(WordListTool);
|
||||
thread.send(model.clone(), "Test the word_list tool.", cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_call.status == acp::ToolCallStatus::Pending {
|
||||
if !last_tool_use.is_input_complete
|
||||
&& last_tool_use.input.get("g").is_none()
|
||||
{
|
||||
saw_partial_tool_use = true;
|
||||
}
|
||||
} else {
|
||||
last_tool_use
|
||||
.input
|
||||
.get("a")
|
||||
.expect("'a' has streamed because input is now complete");
|
||||
last_tool_use
|
||||
.input
|
||||
.get("g")
|
||||
.expect("'g' has streamed because input is now complete");
|
||||
}
|
||||
} else {
|
||||
panic!("last content should be a tool use");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_partial_tool_use,
|
||||
"should see at least one partially streamed tool use in the history"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
// Test concurrent tool calls with different delay times
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let stop_reasons = stop_events(events);
|
||||
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
let text = last_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
assert!(text.contains("Ding"));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(InfiniteTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Call the echo tool and then call the infinite tool, then explain their output",
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Wait until both tools are called.
|
||||
let mut expected_tool_calls = vec!["echo", "infinite"];
|
||||
let mut echo_id = None;
|
||||
let mut echo_completed = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event.unwrap() {
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
assert_eq!(tool_call.title, expected_tool_calls.remove(0));
|
||||
if tool_call.title == "echo" {
|
||||
echo_id = Some(tool_call.id);
|
||||
}
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
|
||||
id,
|
||||
fields:
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
..
|
||||
},
|
||||
}) if Some(&id) == echo_id.as_ref() => {
|
||||
echo_completed = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if expected_tool_calls.is_empty() && echo_completed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel the current send and ensure that the event stream is closed, even
|
||||
// if one of the tools is still running.
|
||||
thread.update(cx, |thread, _cx| thread.cancel());
|
||||
events.collect::<Vec<_>>().await;
|
||||
|
||||
// Ensure we can still send a new message after cancellation.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_refusal(cx: &mut TestAppContext) {
|
||||
let fake_model = Arc::new(FakeLanguageModel::default());
|
||||
let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(fake_model.clone(), "Hello", cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
Hello
|
||||
## assistant
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
// If the model refuses to continue, the thread should remove all the messages after the last user message.
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
|
||||
let events = events.collect::<Vec<_>>().await;
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.to_markdown(), "");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.update(settings::init);
|
||||
let templates = Templates::new();
|
||||
|
||||
// Initialize language model system with test provider
|
||||
cx.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
client::init_settings(cx);
|
||||
|
||||
let http_client = FakeHttpClient::with_404_response();
|
||||
let clock = Arc::new(clock::FakeSystemClock::new());
|
||||
let client = Client::new(clock, http_client, cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
cx.executor().forbid_parking();
|
||||
|
||||
// Create agent and connection
|
||||
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
let selector_opt = connection.model_selector();
|
||||
assert!(
|
||||
selector_opt.is_some(),
|
||||
"agent2 should always support ModelSelector"
|
||||
);
|
||||
let selector = selector_opt.unwrap();
|
||||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
|
||||
// Create a thread using new_thread
|
||||
let cwd = Path::new("/test");
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
// Get the session_id from the AcpThread
|
||||
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||
|
||||
// Test selected_model returns the default
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
let model = model.as_fake();
|
||||
assert_eq!(model.id().0, "fake", "should return default model");
|
||||
|
||||
let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
|
||||
cx.run_until_parked();
|
||||
model.send_last_completion_stream_text_chunk("def");
|
||||
cx.run_until_parked();
|
||||
acp_thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
abc
|
||||
|
||||
## Assistant
|
||||
|
||||
def
|
||||
|
||||
"}
|
||||
)
|
||||
});
|
||||
|
||||
// Test cancel
|
||||
cx.update(|cx| connection.cancel(&session_id, cx));
|
||||
request.await.expect("prompt should fail gracefully");
|
||||
|
||||
// Ensure that dropping the ACP thread causes the native thread to be
|
||||
// dropped as well.
|
||||
cx.update(|_| drop(acp_thread));
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
connection.prompt(
|
||||
acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec!["ghi".into()],
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
result.as_ref().unwrap_err().to_string(),
|
||||
"Session not found",
|
||||
"unexpected result: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> Vec<acp::StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
}
|
||||
|
||||
enum TestModel {
|
||||
Sonnet4,
|
||||
Sonnet4Thinking,
|
||||
Fake(Arc<FakeLanguageModel>),
|
||||
}
|
||||
|
||||
impl TestModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
match self {
|
||||
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
|
||||
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
|
||||
TestModel::Fake(fake_model) => fake_model.id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(|cx| {
|
||||
settings::init(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
let templates = Templates::new();
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
if let TestModel::Fake(model) = model {
|
||||
Task::ready(model as Arc<_>)
|
||||
} else {
|
||||
let model_id = model.id();
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id() == model_id)
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
|
||||
|
||||
ThreadTest { model, thread }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
106
crates/agent2/src/tests/test_tools.rs
Normal file
106
crates/agent2/src/tests/test_tools.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
use super::*;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, SharedString, Task};
|
||||
use std::future;
|
||||
|
||||
/// A tool that echoes its input
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
/// The text to echo.
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok(input.text))
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that waits for a specified delay
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct DelayToolInput {
|
||||
/// The delay in milliseconds.
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
pub struct DelayTool;
|
||||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
cx.foreground_executor().spawn(async move {
|
||||
smol::Timer::after(Duration::from_millis(input.ms)).await;
|
||||
Ok("Ding".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct InfiniteToolInput {}
|
||||
|
||||
pub struct InfiniteTool;
|
||||
|
||||
impl AgentTool for InfiniteTool {
|
||||
type Input = InfiniteToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"infinite".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
future::pending::<()>().await;
|
||||
unreachable!()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that takes an object with map from letters to random words starting with that letter.
|
||||
/// All fiealds are required! Pass a word for every letter!
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct WordListInput {
|
||||
/// Provide a random word that starts with A.
|
||||
a: Option<String>,
|
||||
/// Provide a random word that starts with B.
|
||||
b: Option<String>,
|
||||
/// Provide a random word that starts with C.
|
||||
c: Option<String>,
|
||||
/// Provide a random word that starts with D.
|
||||
d: Option<String>,
|
||||
/// Provide a random word that starts with E.
|
||||
e: Option<String>,
|
||||
/// Provide a random word that starts with F.
|
||||
f: Option<String>,
|
||||
/// Provide a random word that starts with G.
|
||||
g: Option<String>,
|
||||
}
|
||||
|
||||
pub struct WordListTool;
|
||||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok("ok".to_string()))
|
||||
}
|
||||
}
|
754
crates/agent2/src/thread.rs
Normal file
754
crates/agent2/src/thread.rs
Normal file
|
@ -0,0 +1,754 @@
|
|||
use crate::{prompts::BasePrompt, templates::Templates};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, stream::FuturesUnordered};
|
||||
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||
};
|
||||
use log;
|
||||
use project::Project;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, fmt::Write, sync::Arc};
|
||||
use util::{markdown::MarkdownCodeBlock, ResultExt};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
impl AgentMessage {
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut markdown = format!("## {}\n", self.role);
|
||||
|
||||
for content in &self.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
markdown.push_str(text);
|
||||
markdown.push('\n');
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
markdown.push_str("<think>");
|
||||
markdown.push_str(text);
|
||||
markdown.push_str("</think>\n");
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
|
||||
MessageContent::Image(_) => {
|
||||
markdown.push_str("<image />\n");
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
markdown.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
tool_use.name, tool_use.id
|
||||
));
|
||||
markdown.push_str(&format!(
|
||||
"{}\n",
|
||||
MarkdownCodeBlock {
|
||||
tag: "json",
|
||||
text: &format!("{:#}", tool_use.input)
|
||||
}
|
||||
));
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
markdown.push_str(&format!(
|
||||
"**Tool Result**: {} (ID: {})\n\n",
|
||||
tool_result.tool_name, tool_result.tool_use_id
|
||||
));
|
||||
if tool_result.is_error {
|
||||
markdown.push_str("**ERROR:**\n");
|
||||
}
|
||||
|
||||
match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
writeln!(markdown, "{text}\n").ok();
|
||||
}
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
writeln!(markdown, "<image />\n").ok();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(output) = tool_result.output.as_ref() {
|
||||
writeln!(
|
||||
markdown,
|
||||
"**Debug Output**:\n\n```json\n{}\n```\n",
|
||||
serde_json::to_string_pretty(output).unwrap()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
markdown
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AgentResponseEvent {
|
||||
Text(String),
|
||||
Thinking(String),
|
||||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp::ToolCallUpdate),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
pub trait Prompt {
|
||||
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
// action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
default_model: Arc<dyn LanguageModel>,
|
||||
) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
system_prompts: vec![Arc::new(BasePrompt::new(project))],
|
||||
running_turn: None,
|
||||
pending_tool_uses: HashMap::default(),
|
||||
tools: BTreeMap::default(),
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[AgentMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
self.tools.insert(tool.name(), tool.erase());
|
||||
}
|
||||
|
||||
pub fn remove_tool(&mut self, name: &str) -> bool {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
self.running_turn.take();
|
||||
|
||||
let tool_results = self
|
||||
.pending_tool_uses
|
||||
.drain()
|
||||
.map(|(tool_use_id, tool_use)| {
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id,
|
||||
tool_name: tool_use.name.clone(),
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
|
||||
output: None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
self.last_user_message().content.extend(tool_results);
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||
let content = content.into();
|
||||
log::info!("Thread::send called with model: {:?}", model.name());
|
||||
log::debug!("Thread::send content: {:?}", content);
|
||||
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let user_message_ix = self.messages.len();
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content],
|
||||
});
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
'outer: loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// println!(
|
||||
// "request: {}",
|
||||
// serde_json::to_string_pretty(&request).unwrap()
|
||||
// );
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
if let Some(reason) = to_acp_stop_reason(reason) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::Stop(reason)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
if reason == StopReason::Refusal {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.messages.truncate(user_message_ix);
|
||||
})?;
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event, &events_tx, cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Error in completion stream: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
break;
|
||||
}
|
||||
log::info!("Found {} tool uses to execute", tool_uses.len());
|
||||
|
||||
// As tool results trickle in, insert them in the last user
|
||||
// message so that they can be sent on the next tick of the
|
||||
// agentic loop.
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
to_acp_tool_call_update(&tool_result),
|
||||
)))
|
||||
.ok();
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
||||
thread
|
||||
.last_user_message()
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
} else {
|
||||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
}));
|
||||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self, cx: &App) -> Option<AgentMessage> {
|
||||
log::debug!("Building system message");
|
||||
let mut system_message = AgentMessage {
|
||||
role: Role::System,
|
||||
content: Vec::new(),
|
||||
};
|
||||
|
||||
for prompt in &self.system_prompts {
|
||||
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
||||
system_message
|
||||
.content
|
||||
.push(MessageContent::Text(rendered_prompt));
|
||||
}
|
||||
}
|
||||
|
||||
let result = (!system_message.content.is_empty()).then_some(system_message);
|
||||
log::debug!("System message built: {}", result.is_some());
|
||||
result
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
log::trace!("Handling streamed completion event: {:?}", event);
|
||||
use LanguageModelCompletionEvent::*;
|
||||
|
||||
match event {
|
||||
StartMessage { .. } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
Text(new_text) => self.handle_text_event(new_text, events_tx, cx),
|
||||
Thinking { text, signature } => {
|
||||
self.handle_thinking_event(text, signature, events_tx, cx)
|
||||
}
|
||||
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, events_tx, cx);
|
||||
}
|
||||
ToolUseJsonParseError {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
} => {
|
||||
return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
)));
|
||||
}
|
||||
UsageUpdate(_) | StatusUpdate(_) => {}
|
||||
Stop(_) => unreachable!(),
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_text_event(
|
||||
&mut self,
|
||||
new_text: String,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone())))
|
||||
.ok();
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_thinking_event(
|
||||
&mut self,
|
||||
new_text: String,
|
||||
new_signature: Option<String>,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone())))
|
||||
.ok();
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
|
||||
{
|
||||
text.push_str(&new_text);
|
||||
*signature = new_signature.or(signature.take());
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Thinking {
|
||||
text: new_text,
|
||||
signature: new_signature,
|
||||
});
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.last_assistant_message();
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::RedactedThinking(data));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
self.pending_tool_uses
|
||||
.insert(tool_use.id.clone(), tool_use.clone());
|
||||
let last_message = self.last_assistant_message();
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
*last_tool_use = tool_use.clone();
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if push_new_tool_use {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
title: tool_use.name.to_string(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
})))
|
||||
.ok();
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
} else {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
if !tool_use.is_input_complete {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
|
||||
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
let content = format!("No tool named {} exists", tool_use.name);
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_tool_use_json_parse_error_event(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
) -> LanguageModelToolResult {
|
||||
let tool_output = format!("Error parsing input JSON: {json_parse_error}");
|
||||
LanguageModelToolResult {
|
||||
tool_use_id,
|
||||
tool_name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(tool_output.into()),
|
||||
output: Some(serde_json::Value::String(raw_input.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
||||
if self
|
||||
.messages
|
||||
.last()
|
||||
.map_or(true, |m| m.role != Role::Assistant)
|
||||
{
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the user and returns a mutable reference.
|
||||
fn last_user_message(&mut self) -> &mut AgentMessage {
|
||||
if self.messages.last().map_or(true, |m| m.role != Role::User) {
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
) -> LanguageModelRequest {
|
||||
log::debug!("Building completion request");
|
||||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||
|
||||
let messages = self.build_request_messages(cx);
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools: Vec<LanguageModelRequestTool> = self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: true,
|
||||
};
|
||||
|
||||
log::debug!("Completion request built successfully");
|
||||
request
|
||||
}
|
||||
|
||||
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
self.messages.len()
|
||||
);
|
||||
|
||||
let messages = self
|
||||
.build_system_message(cx)
|
||||
.iter()
|
||||
.chain(self.messages.iter())
|
||||
.map(|message| {
|
||||
log::trace!(
|
||||
" - {} message with {} content items",
|
||||
match message.role {
|
||||
Role::System => "System",
|
||||
Role::User => "User",
|
||||
Role::Assistant => "Assistant",
|
||||
},
|
||||
message.content.len()
|
||||
);
|
||||
LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
messages
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut markdown = String::new();
|
||||
for message in &self.messages {
|
||||
markdown.push_str(&message.to_markdown());
|
||||
}
|
||||
markdown
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
where
|
||||
T: AgentTool,
|
||||
{
|
||||
fn name(&self) -> SharedString {
|
||||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_acp_stop_reason(reason: StopReason) -> Option<acp::StopReason> {
|
||||
match reason {
|
||||
StopReason::EndTurn => Some(acp::StopReason::EndTurn),
|
||||
StopReason::MaxTokens => Some(acp::StopReason::MaxTokens),
|
||||
StopReason::Refusal => Some(acp::StopReason::Refusal),
|
||||
StopReason::ToolUse => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate {
|
||||
let status = if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
};
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
|
||||
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
|
||||
acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::Image(acp::ImageContent {
|
||||
annotations: None,
|
||||
data: source.to_string(),
|
||||
mime_type: ImageFormat::Png.mime_type().to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
};
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(status),
|
||||
content: Some(vec![content]),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
1
crates/agent2/src/tools.rs
Normal file
1
crates/agent2/src/tools.rs
Normal file
|
@ -0,0 +1 @@
|
|||
mod glob;
|
76
crates/agent2/src/tools/glob.rs
Normal file
76
crates/agent2/src/tools/glob.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot as WorktreeSnapshot;
|
||||
|
||||
use crate::{
|
||||
templates::{GlobTemplate, Template, Templates},
|
||||
thread::AgentTool,
|
||||
};
|
||||
|
||||
// Description is dynamic, see `fn description` below
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct GlobInput {
|
||||
/// A POSIX glob pattern
|
||||
glob: SharedString,
|
||||
}
|
||||
|
||||
struct GlobTool {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl AgentTool for GlobTool {
|
||||
type Input = GlobInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"glob".into()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
let project_roots = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).root_name().into())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
GlobTemplate { project_roots }
|
||||
.render(&self.templates)
|
||||
.expect("template failed to render")
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
||||
};
|
||||
|
||||
let snapshots: Vec<WorktreeSnapshot> = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
});
|
||||
let output = paths
|
||||
.map(|path| format!("{}\n", path.display()))
|
||||
.collect::<String>();
|
||||
Ok(output)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -380,6 +380,7 @@ impl AcpConnection {
|
|||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
log::trace!("Spawned (pid: {})", child.id());
|
||||
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
|
@ -463,7 +464,11 @@ impl AgentConnection for AcpConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(
|
||||
&self,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let chunks = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
|
@ -483,7 +488,9 @@ impl AgentConnection for AcpConnection {
|
|||
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
|
||||
cx.foreground_executor().spawn(async move {
|
||||
task.await?;
|
||||
anyhow::Ok(())
|
||||
anyhow::Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ pub struct AcpConnection {
|
|||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
auth_methods: Vec<acp::AuthMethod>,
|
||||
_io_task: Task<Result<()>>,
|
||||
_child: smol::process::Child,
|
||||
}
|
||||
|
||||
pub struct AcpSession {
|
||||
|
@ -47,6 +46,7 @@ impl AcpConnection {
|
|||
|
||||
let stdout = child.stdout.take().expect("Failed to take stdout");
|
||||
let stdin = child.stdin.take().expect("Failed to take stdin");
|
||||
log::trace!("Spawned (pid: {})", child.id());
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
|
@ -63,6 +63,23 @@ impl AcpConnection {
|
|||
|
||||
let io_task = cx.background_spawn(io_task);
|
||||
|
||||
cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
let status = child.status().await?;
|
||||
|
||||
for session in sessions.borrow().values() {
|
||||
session
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.emit_server_exited(status, cx))
|
||||
.ok();
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let response = connection
|
||||
.initialize(acp::InitializeRequest {
|
||||
protocol_version: acp::VERSION,
|
||||
|
@ -84,7 +101,6 @@ impl AcpConnection {
|
|||
connection: connection.into(),
|
||||
server_name,
|
||||
sessions,
|
||||
_child: child,
|
||||
_io_task: io_task,
|
||||
})
|
||||
}
|
||||
|
@ -153,10 +169,16 @@ impl AgentConnection for AcpConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(
|
||||
&self,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { Ok(conn.prompt(params).await?) })
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let response = conn.prompt(params).await?;
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
|
|
|
@ -89,6 +89,7 @@ impl AgentServerCommand {
|
|||
pub(crate) async fn resolve(
|
||||
path_bin_name: &'static str,
|
||||
extra_args: &[&'static str],
|
||||
fallback_path: Option<&Path>,
|
||||
settings: Option<AgentServerSettings>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
|
@ -105,13 +106,24 @@ impl AgentServerCommand {
|
|||
env: agent_settings.command.env,
|
||||
});
|
||||
} else {
|
||||
find_bin_in_path(path_bin_name, project, cx)
|
||||
.await
|
||||
.map(|path| Self {
|
||||
match find_bin_in_path(path_bin_name, project, cx).await {
|
||||
Some(path) => Some(Self {
|
||||
path,
|
||||
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
|
||||
env: None,
|
||||
})
|
||||
}),
|
||||
None => fallback_path.and_then(|path| {
|
||||
if path.exists() {
|
||||
Some(Self {
|
||||
path: path.to_path_buf(),
|
||||
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
|
||||
env: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,8 +101,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("claude", &[], settings, &project, cx).await
|
||||
let Some(command) = AgentServerCommand::resolve(
|
||||
"claude",
|
||||
&[],
|
||||
Some(&util::paths::home_dir().join(".claude/local/claude")),
|
||||
settings,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
else {
|
||||
anyhow::bail!("Failed to find claude binary");
|
||||
};
|
||||
|
@ -114,43 +121,42 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
|
||||
log::trace!("Starting session with id: {}", session_id);
|
||||
|
||||
cx.background_spawn({
|
||||
let session_id = session_id.clone();
|
||||
async move {
|
||||
let mut outgoing_rx = Some(outgoing_rx);
|
||||
let mut child = spawn_claude(
|
||||
&command,
|
||||
ClaudeSessionMode::Start,
|
||||
session_id.clone(),
|
||||
&mcp_config_path,
|
||||
&cwd,
|
||||
)?;
|
||||
|
||||
let mut child = spawn_claude(
|
||||
&command,
|
||||
ClaudeSessionMode::Start,
|
||||
session_id.clone(),
|
||||
&mcp_config_path,
|
||||
&cwd,
|
||||
)
|
||||
.await?;
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let pid = child.id();
|
||||
log::trace!("Spawned (pid: {})", pid);
|
||||
let pid = child.id();
|
||||
log::trace!("Spawned (pid: {})", pid);
|
||||
|
||||
ClaudeAgentSession::handle_io(
|
||||
outgoing_rx.take().unwrap(),
|
||||
incoming_message_tx.clone(),
|
||||
child.stdin.take().unwrap(),
|
||||
child.stdout.take().unwrap(),
|
||||
)
|
||||
.await?;
|
||||
cx.background_spawn(async move {
|
||||
let mut outgoing_rx = Some(outgoing_rx);
|
||||
|
||||
log::trace!("Stopped (pid: {})", pid);
|
||||
ClaudeAgentSession::handle_io(
|
||||
outgoing_rx.take().unwrap(),
|
||||
incoming_message_tx.clone(),
|
||||
stdin,
|
||||
stdout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(mcp_config_path);
|
||||
anyhow::Ok(())
|
||||
}
|
||||
log::trace!("Stopped (pid: {})", pid);
|
||||
|
||||
drop(mcp_config_path);
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
||||
let handler_task = cx.spawn({
|
||||
let end_turn_tx = end_turn_tx.clone();
|
||||
let thread_rx = thread_rx.clone();
|
||||
let mut thread_rx = thread_rx.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = incoming_message_rx.next().await {
|
||||
ClaudeAgentSession::handle_message(
|
||||
|
@ -161,6 +167,16 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
)
|
||||
.await
|
||||
}
|
||||
|
||||
if let Some(status) = child.status().await.log_err() {
|
||||
if let Some(thread) = thread_rx.recv().await.ok() {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.emit_server_exited(status, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -191,7 +207,11 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(
|
||||
&self,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(¶ms.session_id) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
|
@ -235,10 +255,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
return Task::ready(Err(anyhow!(err)));
|
||||
}
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
rx.await??;
|
||||
Ok(())
|
||||
})
|
||||
cx.foreground_executor().spawn(async move { rx.await? })
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||
|
@ -252,6 +269,14 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
.outgoing_tx
|
||||
.unbounded_send(SdkMessage::new_interrupt_message())
|
||||
.log_err();
|
||||
|
||||
if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::Cancelled,
|
||||
}))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -262,7 +287,7 @@ enum ClaudeSessionMode {
|
|||
Resume,
|
||||
}
|
||||
|
||||
async fn spawn_claude(
|
||||
fn spawn_claude(
|
||||
command: &AgentServerCommand,
|
||||
mode: ClaudeSessionMode,
|
||||
session_id: acp::SessionId,
|
||||
|
@ -313,7 +338,7 @@ async fn spawn_claude(
|
|||
|
||||
struct ClaudeAgentSession {
|
||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
@ -322,7 +347,7 @@ impl ClaudeAgentSession {
|
|||
async fn handle_message(
|
||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
message: SdkMessage,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
match message {
|
||||
|
@ -355,6 +380,24 @@ impl ClaudeAgentSession {
|
|||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::Thinking { thinking } => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(thinking.into(), true, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::RedactedThinking => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
"[REDACTED]".into(),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::ToolUse { id, name, input } => {
|
||||
let claude_tool = ClaudeTool::infer(&name, input);
|
||||
|
||||
|
@ -404,8 +447,6 @@ impl ClaudeAgentSession {
|
|||
}
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
|
@ -427,7 +468,7 @@ impl ClaudeAgentSession {
|
|||
..
|
||||
} => {
|
||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
||||
if is_error {
|
||||
if is_error || subtype == ResultErrorType::ErrorDuringExecution {
|
||||
end_turn_tx
|
||||
.send(Err(anyhow!(
|
||||
"Error: {}",
|
||||
|
@ -435,7 +476,14 @@ impl ClaudeAgentSession {
|
|||
)))
|
||||
.ok();
|
||||
} else {
|
||||
end_turn_tx.send(Ok(())).ok();
|
||||
let stop_reason = match subtype {
|
||||
ResultErrorType::Success => acp::StopReason::EndTurn,
|
||||
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
||||
ResultErrorType::ErrorDuringExecution => unreachable!(),
|
||||
};
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse { stop_reason }))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -548,11 +596,13 @@ enum ContentChunk {
|
|||
content: Content,
|
||||
tool_use_id: String,
|
||||
},
|
||||
Thinking {
|
||||
thinking: String,
|
||||
},
|
||||
RedactedThinking,
|
||||
// TODO
|
||||
Image,
|
||||
Document,
|
||||
Thinking,
|
||||
RedactedThinking,
|
||||
WebSearchToolResult,
|
||||
#[serde(untagged)]
|
||||
UntaggedText(String),
|
||||
|
@ -562,12 +612,12 @@ impl Display for ContentChunk {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentChunk::Text { text } => write!(f, "{}", text),
|
||||
ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
|
||||
ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
|
||||
ContentChunk::UntaggedText(text) => write!(f, "{}", text),
|
||||
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::ToolUse { .. }
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
write!(f, "\n{:?}\n", &self)
|
||||
|
@ -660,7 +710,7 @@ struct ControlResponse {
|
|||
subtype: ResultErrorType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ResultErrorType {
|
||||
Success,
|
||||
|
@ -714,6 +764,8 @@ enum PermissionMode {
|
|||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::e2e_tests;
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
|
||||
crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
|
||||
|
@ -726,6 +778,68 @@ pub(crate) mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn test_todo_plan(cx: &mut TestAppContext) {
|
||||
let fs = e2e_tests::init_test(cx).await;
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread =
|
||||
e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut entries_len = 0;
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
entries_len = thread.plan().entries.len();
|
||||
assert!(thread.plan().entries.len() > 0, "Empty plan");
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Mark the first entry status as in progress without acting on it.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
thread.plan().entries[0].status,
|
||||
acp::PlanEntryStatus::InProgress
|
||||
));
|
||||
assert_eq!(thread.plan().entries.len(), entries_len);
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Now mark the first entry as completed without acting on it.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
thread.plan().entries[0].status,
|
||||
acp::PlanEntryStatus::Completed
|
||||
));
|
||||
assert_eq!(thread.plan().entries.len(), entries_len);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_content_untagged_text() {
|
||||
let json = json!("Hello, world!");
|
||||
|
|
|
@ -143,25 +143,6 @@ impl ClaudeTool {
|
|||
Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
|
||||
Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
|
||||
Self::WebSearch(Some(params)) => vec![params.to_string().into()],
|
||||
Self::TodoWrite(Some(params)) => vec![
|
||||
params
|
||||
.todos
|
||||
.iter()
|
||||
.map(|todo| {
|
||||
format!(
|
||||
"- {} {}: {}",
|
||||
match todo.status {
|
||||
TodoStatus::Completed => "✅",
|
||||
TodoStatus::InProgress => "🚧",
|
||||
TodoStatus::Pending => "⬜",
|
||||
},
|
||||
todo.priority,
|
||||
todo.content
|
||||
)
|
||||
})
|
||||
.join("\n")
|
||||
.into(),
|
||||
],
|
||||
Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
|
||||
Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
|
@ -193,6 +174,10 @@ impl ClaudeTool {
|
|||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Self::TodoWrite(Some(_)) => {
|
||||
// These are mapped to plan updates later
|
||||
vec![]
|
||||
}
|
||||
Self::Task(None)
|
||||
| Self::NotebookRead(None)
|
||||
| Self::NotebookEdit(None)
|
||||
|
@ -488,10 +473,11 @@ impl std::fmt::Display for GrepToolParams {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
|
||||
#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TodoPriority {
|
||||
High,
|
||||
#[default]
|
||||
Medium,
|
||||
Low,
|
||||
}
|
||||
|
@ -526,14 +512,13 @@ impl Into<acp::PlanEntryStatus> for TodoStatus {
|
|||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
pub struct Todo {
|
||||
/// Unique identifier
|
||||
pub id: String,
|
||||
/// Task description
|
||||
pub content: String,
|
||||
/// Priority level of the todo
|
||||
pub priority: TodoPriority,
|
||||
/// Current status of the todo
|
||||
pub status: TodoStatus,
|
||||
/// Priority level of the todo
|
||||
#[serde(default)]
|
||||
pub priority: TodoPriority,
|
||||
}
|
||||
|
||||
impl Into<acp::PlanEntry> for Todo {
|
||||
|
|
|
@ -311,6 +311,27 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
|||
});
|
||||
}
|
||||
|
||||
pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx).await;
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
|
||||
});
|
||||
|
||||
let weak_thread = thread.downgrade();
|
||||
drop(thread);
|
||||
|
||||
cx.executor().run_until_parked();
|
||||
assert!(!weak_thread.is_upgradable());
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! common_e2e_tests {
|
||||
($server:expr, allow_option_id = $allow_option_id:expr) => {
|
||||
|
@ -351,6 +372,12 @@ macro_rules! common_e2e_tests {
|
|||
async fn cancel(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_cancel($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_thread_drop($server, cx).await;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::path::Path;
|
|||
use std::rc::Rc;
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand};
|
||||
use acp_thread::AgentConnection;
|
||||
use acp_thread::{AgentConnection, LoadError};
|
||||
use anyhow::Result;
|
||||
use gpui::{Entity, Task};
|
||||
use project::Project;
|
||||
|
@ -48,12 +48,42 @@ impl AgentServer for Gemini {
|
|||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find gemini binary");
|
||||
};
|
||||
|
||||
crate::acp::connect(server_name, command, &root_dir, cx).await
|
||||
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
|
||||
if result.is_err() {
|
||||
let version_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--version")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let help_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--help")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
||||
|
||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
||||
|
||||
if !supported {
|
||||
return Err(LoadError::Unsupported {
|
||||
error_message: format!(
|
||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
||||
current_version
|
||||
).into(),
|
||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
||||
}.into())
|
||||
}
|
||||
}
|
||||
result
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,9 @@ use std::borrow::Cow;
|
|||
|
||||
pub use crate::agent_profile::*;
|
||||
|
||||
pub const SUMMARIZE_THREAD_PROMPT: &str =
|
||||
include_str!("../../agent/src/prompts/summarize_thread_prompt.txt");
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
AgentSettings::register(cx);
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"]
|
|||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent.workspace = true
|
||||
agent2.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
ai_onboarding.workspace = true
|
||||
|
|
|
@ -5,6 +5,7 @@ use audio::{Audio, Sound};
|
|||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
use std::process::ExitStatus;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
@ -20,10 +21,10 @@ use editor::{
|
|||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
|
||||
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
|
||||
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
|
||||
list, percentage, point, prelude::*, pulsating_between,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
|
||||
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
|
||||
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
|
||||
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use language::{Buffer, Language};
|
||||
|
@ -33,7 +34,9 @@ use project::Project;
|
|||
use settings::Settings as _;
|
||||
use text::{Anchor, BufferSnapshot};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
||||
|
@ -68,6 +71,7 @@ pub struct AcpThreadView {
|
|||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
last_error: Option<Entity<Markdown>>,
|
||||
list_state: ListState,
|
||||
scrollbar_state: ScrollbarState,
|
||||
auth_task: Option<Task<()>>,
|
||||
expanded_tool_calls: HashSet<acp::ToolCallId>,
|
||||
expanded_thinking_blocks: HashSet<(usize, usize)>,
|
||||
|
@ -90,6 +94,9 @@ enum ThreadState {
|
|||
Unauthenticated {
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
},
|
||||
ServerExited {
|
||||
status: ExitStatus,
|
||||
},
|
||||
}
|
||||
|
||||
impl AcpThreadView {
|
||||
|
@ -169,22 +176,7 @@ impl AcpThreadView {
|
|||
|
||||
let mention_set = mention_set.clone();
|
||||
|
||||
let list_state = ListState::new(
|
||||
0,
|
||||
gpui::ListAlignment::Bottom,
|
||||
px(2048.0),
|
||||
cx.processor({
|
||||
move |this: &mut Self, index: usize, window, cx| {
|
||||
let Some((entry, len)) = this.thread().and_then(|thread| {
|
||||
let entries = &thread.read(cx).entries();
|
||||
Some((entries.get(index)?, entries.len()))
|
||||
}) else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
this.render_entry(index, len, entry, window, cx)
|
||||
}
|
||||
}),
|
||||
);
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
|
||||
|
||||
Self {
|
||||
agent: agent.clone(),
|
||||
|
@ -198,7 +190,8 @@ impl AcpThreadView {
|
|||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
diff_editors: Default::default(),
|
||||
list_state: list_state,
|
||||
list_state: list_state.clone(),
|
||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||
last_error: None,
|
||||
auth_task: None,
|
||||
expanded_tool_calls: HashSet::default(),
|
||||
|
@ -228,7 +221,7 @@ impl AcpThreadView {
|
|||
let connect_task = agent.connect(&root_dir, &project, cx);
|
||||
let load_task = cx.spawn_in(window, async move |this, cx| {
|
||||
let connection = match connect_task.await {
|
||||
Ok(thread) => thread,
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
this.update(cx, |this, cx| {
|
||||
this.handle_load_error(err, cx);
|
||||
|
@ -239,6 +232,20 @@ impl AcpThreadView {
|
|||
}
|
||||
};
|
||||
|
||||
// this.update_in(cx, |_this, _window, cx| {
|
||||
// let status = connection.exit_status(cx);
|
||||
// cx.spawn(async move |this, cx| {
|
||||
// let status = status.await.ok();
|
||||
// this.update(cx, |this, cx| {
|
||||
// this.thread_state = ThreadState::ServerExited { status };
|
||||
// cx.notify();
|
||||
// })
|
||||
// .ok();
|
||||
// })
|
||||
// .detach();
|
||||
// })
|
||||
// .ok();
|
||||
|
||||
let result = match connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), &root_dir, cx)
|
||||
|
@ -307,7 +314,8 @@ impl AcpThreadView {
|
|||
ThreadState::Ready { thread, .. } => Some(thread),
|
||||
ThreadState::Unauthenticated { .. }
|
||||
| ThreadState::Loading { .. }
|
||||
| ThreadState::LoadError(..) => None,
|
||||
| ThreadState::LoadError(..)
|
||||
| ThreadState::ServerExited { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -317,6 +325,7 @@ impl AcpThreadView {
|
|||
ThreadState::Loading { .. } => "Loading…".into(),
|
||||
ThreadState::LoadError(_) => "Failed to load".into(),
|
||||
ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
|
||||
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -646,6 +655,9 @@ impl AcpThreadView {
|
|||
cx,
|
||||
);
|
||||
}
|
||||
AcpThreadEvent::ServerExited(status) => {
|
||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
@ -785,7 +797,7 @@ impl AcpThreadView {
|
|||
window: &mut Window,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
match &entry {
|
||||
let primary = match &entry {
|
||||
AgentThreadEntry::UserMessage(message) => div()
|
||||
.py_4()
|
||||
.px_2()
|
||||
|
@ -850,6 +862,20 @@ impl AcpThreadView {
|
|||
.px_5()
|
||||
.child(self.render_tool_call(index, tool_call, window, cx))
|
||||
.into_any(),
|
||||
};
|
||||
|
||||
let Some(thread) = self.thread() else {
|
||||
return primary;
|
||||
};
|
||||
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
|
||||
if index == total_entries - 1 && !is_generating {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.child(primary)
|
||||
.child(self.render_thread_controls(cx))
|
||||
.into_any_element()
|
||||
} else {
|
||||
primary
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1369,7 +1395,29 @@ impl AcpThreadView {
|
|||
.into_any()
|
||||
}
|
||||
|
||||
fn render_error_state(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
|
||||
fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement {
|
||||
v_flex()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_error_agent_logo())
|
||||
.child(
|
||||
v_flex()
|
||||
.mt_4()
|
||||
.mb_2()
|
||||
.gap_0p5()
|
||||
.text_center()
|
||||
.items_center()
|
||||
.child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium))
|
||||
.child(
|
||||
Label::new(format!("Exit status: {}", status.code().unwrap_or(-127)))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
|
||||
let mut container = v_flex()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
|
@ -2404,7 +2452,7 @@ impl AcpThreadView {
|
|||
}
|
||||
}
|
||||
|
||||
fn render_thread_controls(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render_thread_controls(&self, cx: &Context<Self>) -> impl IntoElement {
|
||||
let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
|
@ -2425,9 +2473,8 @@ impl AcpThreadView {
|
|||
}));
|
||||
|
||||
h_flex()
|
||||
.mt_1()
|
||||
.mr_1()
|
||||
.py_2()
|
||||
.pb_2()
|
||||
.px(RESPONSE_PADDING_X)
|
||||
.opacity(0.4)
|
||||
.hover(|style| style.opacity(1.))
|
||||
|
@ -2436,6 +2483,39 @@ impl AcpThreadView {
|
|||
.child(open_as_markdown)
|
||||
.child(scroll_to_top)
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.id("acp-thread-scrollbar")
|
||||
.occlude()
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AcpThreadView {
|
||||
|
@ -2481,24 +2561,36 @@ impl Render for AcpThreadView {
|
|||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_error_state(e, cx)),
|
||||
.child(self.render_load_error(e, cx)),
|
||||
ThreadState::ServerExited { status } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_server_exited(*status, cx)),
|
||||
ThreadState::Ready { thread, .. } => {
|
||||
let thread_clone = thread.clone();
|
||||
|
||||
v_flex().flex_1().map(|this| {
|
||||
if self.list_state.item_count() > 0 {
|
||||
let is_generating =
|
||||
matches!(thread_clone.read(cx).status(), ThreadStatus::Generating);
|
||||
|
||||
this.child(
|
||||
list(self.list_state.clone())
|
||||
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
|
||||
.flex_grow()
|
||||
.into_any(),
|
||||
list(
|
||||
self.list_state.clone(),
|
||||
cx.processor(|this, index: usize, window, cx| {
|
||||
let Some((entry, len)) = this.thread().and_then(|thread| {
|
||||
let entries = &thread.read(cx).entries();
|
||||
Some((entries.get(index)?, entries.len()))
|
||||
}) else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
this.render_entry(index, len, entry, window, cx)
|
||||
}),
|
||||
)
|
||||
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
|
||||
.flex_grow()
|
||||
.into_any(),
|
||||
)
|
||||
.when(!is_generating, |this| {
|
||||
this.child(self.render_thread_controls(cx))
|
||||
})
|
||||
.child(self.render_vertical_scrollbar(cx))
|
||||
.children(match thread_clone.read(cx).status() {
|
||||
ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => {
|
||||
None
|
||||
|
@ -2713,6 +2805,16 @@ mod tests {
|
|||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_drop(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await;
|
||||
let weak_view = thread_view.downgrade();
|
||||
drop(thread_view);
|
||||
assert!(!weak_view.is_upgradable());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
@ -2943,7 +3045,11 @@ mod tests {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn prompt(
|
||||
&self,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
|
@ -2977,7 +3083,9 @@ mod tests {
|
|||
}
|
||||
cx.spawn(async move |_| {
|
||||
try_join_all(tasks).await?;
|
||||
Ok(())
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3021,7 +3129,11 @@ mod tests {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn prompt(
|
||||
&self,
|
||||
_params: acp::PromptRequest,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
Task::ready(Err(anyhow::anyhow!("Error prompting")))
|
||||
}
|
||||
|
||||
|
|
|
@ -69,8 +69,6 @@ pub struct ActiveThread {
|
|||
messages: Vec<MessageId>,
|
||||
list_state: ListState,
|
||||
scrollbar_state: ScrollbarState,
|
||||
show_scrollbar: bool,
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
|
||||
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
|
||||
editing_message: Option<(MessageId, EditingMessageState)>,
|
||||
|
@ -780,13 +778,7 @@ impl ActiveThread {
|
|||
cx.observe_global::<SettingsStore>(|_, cx| cx.notify()),
|
||||
];
|
||||
|
||||
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
|
||||
let this = cx.entity().downgrade();
|
||||
move |ix, window: &mut Window, cx: &mut App| {
|
||||
this.update(cx, |this, cx| this.render_message(ix, window, cx))
|
||||
.unwrap()
|
||||
}
|
||||
});
|
||||
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.));
|
||||
|
||||
let workspace_subscription = if let Some(workspace) = workspace.upgrade() {
|
||||
Some(cx.observe_release(&workspace, |this, _, cx| {
|
||||
|
@ -811,9 +803,7 @@ impl ActiveThread {
|
|||
expanded_thinking_segments: HashMap::default(),
|
||||
expanded_code_blocks: HashMap::default(),
|
||||
list_state: list_state.clone(),
|
||||
scrollbar_state: ScrollbarState::new(list_state),
|
||||
show_scrollbar: false,
|
||||
hide_scrollbar_task: None,
|
||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||
editing_message: None,
|
||||
last_error: None,
|
||||
copied_code_block_ids: HashSet::default(),
|
||||
|
@ -1846,7 +1836,12 @@ impl ActiveThread {
|
|||
)))
|
||||
}
|
||||
|
||||
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
|
||||
fn render_message(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let message_id = self.messages[ix];
|
||||
let workspace = self.workspace.clone();
|
||||
let thread = self.thread.read(cx);
|
||||
|
@ -3503,60 +3498,37 @@ impl ActiveThread {
|
|||
}
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !self.show_scrollbar && !self.scrollbar_state.is_dragging() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("active-thread-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.occlude()
|
||||
.id("active-thread-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||
)
|
||||
}
|
||||
|
||||
fn hide_scrollbar_later(&mut self, cx: &mut Context<Self>) {
|
||||
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
|
||||
self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SCROLLBAR_SHOW_INTERVAL)
|
||||
.await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
if !thread.scrollbar_state.is_dragging() {
|
||||
thread.show_scrollbar = false;
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
|
||||
|
@ -3597,26 +3569,8 @@ impl Render for ActiveThread {
|
|||
.size_full()
|
||||
.relative()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.on_mouse_move(cx.listener(|this, _, _, cx| {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_later(cx);
|
||||
cx.notify();
|
||||
}))
|
||||
.on_scroll_wheel(cx.listener(|this, _, _, cx| {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_later(cx);
|
||||
cx.notify();
|
||||
}))
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|this, _, _, cx| {
|
||||
this.hide_scrollbar_later(cx);
|
||||
}),
|
||||
)
|
||||
.child(list(self.list_state.clone()).flex_grow())
|
||||
.when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| {
|
||||
this.child(scrollbar)
|
||||
})
|
||||
.child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow())
|
||||
.child(self.render_vertical_scrollbar(cx))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1523,7 +1523,8 @@ impl AgentDiff {
|
|||
}
|
||||
AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Error => {}
|
||||
| AcpThreadEvent::Error
|
||||
| AcpThreadEvent::ServerExited(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -970,13 +970,7 @@ impl AgentPanel {
|
|||
)
|
||||
});
|
||||
|
||||
this.set_active_view(
|
||||
ActiveView::ExternalAgentThread {
|
||||
thread_view: thread_view.clone(),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
this.set_active_view(ActiveView::ExternalAgentThread { thread_view }, window, cx);
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
@ -1987,6 +1981,22 @@ impl AgentPanel {
|
|||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Native Agent Thread")
|
||||
.icon(IconName::ZedAssistant)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::NativeAgent,
|
||||
),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
});
|
||||
menu
|
||||
}))
|
||||
|
@ -2649,6 +2659,31 @@ impl AgentPanel {
|
|||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-native-agent-thread-btn",
|
||||
"New Native Agent Thread",
|
||||
IconName::ZedAssistant,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::NativeAgent,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
}),
|
||||
|
|
|
@ -151,6 +151,7 @@ enum ExternalAgent {
|
|||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
NativeAgent,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
|
@ -158,6 +159,7 @@ impl ExternalAgent {
|
|||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -504,7 +504,7 @@ impl Render for ContextStrip {
|
|||
)
|
||||
.on_click({
|
||||
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
|
||||
if event.down.click_count > 1 {
|
||||
if event.click_count() > 1 {
|
||||
this.open_context(&context, window, cx);
|
||||
} else {
|
||||
this.focused_index = Some(i);
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use ai_onboarding::{AgentPanelOnboardingCard, BulletItem};
|
||||
use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions};
|
||||
use client::zed_urls;
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Divider, List, Tooltip, prelude::*};
|
||||
use ui::{Divider, Tooltip, prelude::*};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct EndTrialUpsell {
|
||||
|
@ -18,6 +18,8 @@ impl EndTrialUpsell {
|
|||
|
||||
impl RenderOnce for EndTrialUpsell {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let plan_definitions = PlanDefinitions;
|
||||
|
||||
let pro_section = v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
|
@ -31,13 +33,7 @@ impl RenderOnce for EndTrialUpsell {
|
|||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
)
|
||||
.child(plan_definitions.pro_plan(false))
|
||||
.child(
|
||||
Button::new("cta-button", "Upgrade to Zed Pro")
|
||||
.full_width()
|
||||
|
@ -68,11 +64,7 @@ impl RenderOnce for EndTrialUpsell {
|
|||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts with the Claude models"))
|
||||
.child(BulletItem::new("2,000 accepted edit predictions")),
|
||||
);
|
||||
.child(plan_definitions.free_plan());
|
||||
|
||||
AgentPanelOnboardingCard::new()
|
||||
.child(Headline::new("Your Zed Pro Trial has expired"))
|
||||
|
@ -102,18 +94,20 @@ impl RenderOnce for EndTrialUpsell {
|
|||
|
||||
impl Component for EndTrialUpsell {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
ComponentScope::Onboarding
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"End of Trial Upsell Banner"
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AgentEndTrialUpsell"
|
||||
"End of Trial Upsell Banner"
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
Some(
|
||||
v_flex()
|
||||
.p_4()
|
||||
.gap_4()
|
||||
.child(EndTrialUpsell {
|
||||
dismiss_upsell: Arc::new(|_, _| {}),
|
||||
})
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
|
||||
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use ui::{Divider, List, prelude::*};
|
||||
|
||||
use crate::BulletItem;
|
||||
use ui::{Divider, List, ListBulletItem, prelude::*};
|
||||
|
||||
pub struct ApiKeysWithProviders {
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
|
@ -128,7 +126,7 @@ impl RenderOnce for ApiKeysWithoutProviders {
|
|||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(List::new().child(BulletItem::new(
|
||||
.child(List::new().child(ListBulletItem::new(
|
||||
"Add your own keys to use AI without signing in.",
|
||||
)))
|
||||
.child(
|
||||
|
|
|
@ -3,6 +3,7 @@ mod agent_panel_onboarding_card;
|
|||
mod agent_panel_onboarding_content;
|
||||
mod ai_upsell_card;
|
||||
mod edit_prediction_onboarding_content;
|
||||
mod plan_definitions;
|
||||
mod young_account_banner;
|
||||
|
||||
pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders};
|
||||
|
@ -11,51 +12,14 @@ pub use agent_panel_onboarding_content::AgentPanelOnboarding;
|
|||
pub use ai_upsell_card::AiUpsellCard;
|
||||
use cloud_llm_client::Plan;
|
||||
pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
|
||||
pub use plan_definitions::PlanDefinitions;
|
||||
pub use young_account_banner::YoungAccountBanner;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use gpui::{AnyElement, Entity, IntoElement, ParentElement, SharedString};
|
||||
use ui::{Divider, List, ListItem, RegisterComponent, TintColor, Tooltip, prelude::*};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct BulletItem {
|
||||
label: SharedString,
|
||||
}
|
||||
|
||||
impl BulletItem {
|
||||
pub fn new(label: impl Into<SharedString>) -> Self {
|
||||
Self {
|
||||
label: label.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for BulletItem {
|
||||
fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
||||
let line_height = 0.85 * window.line_height();
|
||||
|
||||
ListItem::new("list-item")
|
||||
.selectable(false)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.min_w_0()
|
||||
.gap_1()
|
||||
.items_start()
|
||||
.child(
|
||||
h_flex().h(line_height).justify_center().child(
|
||||
Icon::new(IconName::Dash)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Hidden),
|
||||
),
|
||||
)
|
||||
.child(div().w_full().min_w_0().child(Label::new(self.label))),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
use gpui::{AnyElement, Entity, IntoElement, ParentElement};
|
||||
use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*};
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum SignInStatus {
|
||||
|
@ -130,107 +94,6 @@ impl ZedAiOnboarding {
|
|||
self
|
||||
}
|
||||
|
||||
fn free_plan_definition(&self, cx: &mut App) -> impl IntoElement {
|
||||
v_flex()
|
||||
.mt_2()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Free")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(
|
||||
Label::new("(Current Plan)")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6)))
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts per month with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"2,000 accepted edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
fn pro_trial_definition(&self) -> impl IntoElement {
|
||||
List::new()
|
||||
.child(BulletItem::new("150 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited accepted edit predictions with Zeta, our open-source model",
|
||||
))
|
||||
}
|
||||
|
||||
fn pro_plan_definition(&self, cx: &mut App) -> impl IntoElement {
|
||||
v_flex().mt_2().gap_1().map(|this| {
|
||||
if self.account_too_young {
|
||||
this.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Pro")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Accent)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts per month with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited accepted edit predictions with Zeta, our open-source model",
|
||||
))
|
||||
.child(BulletItem::new("$20 USD per month")),
|
||||
)
|
||||
.child(
|
||||
Button::new("pro", "Get Started")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Upgrade To Pro Clicked", state = "young-account");
|
||||
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Pro Trial")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Accent)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(self.pro_trial_definition())
|
||||
.child(BulletItem::new(
|
||||
"Try it out for 14 days for free, no credit card required",
|
||||
)),
|
||||
)
|
||||
.child(
|
||||
Button::new("pro", "Start Free Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn render_accept_terms_of_service(&self) -> AnyElement {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
|
@ -269,6 +132,7 @@ impl ZedAiOnboarding {
|
|||
|
||||
fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement {
|
||||
let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn);
|
||||
let plan_definitions = PlanDefinitions;
|
||||
|
||||
v_flex()
|
||||
.gap_1()
|
||||
|
@ -278,7 +142,7 @@ impl ZedAiOnboarding {
|
|||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(self.pro_trial_definition())
|
||||
.child(plan_definitions.pro_plan(false))
|
||||
.child(
|
||||
Button::new("sign_in", "Try Zed Pro for Free")
|
||||
.disabled(signing_in)
|
||||
|
@ -297,43 +161,132 @@ impl ZedAiOnboarding {
|
|||
|
||||
fn render_free_plan_state(&self, cx: &mut App) -> AnyElement {
|
||||
let young_account_banner = YoungAccountBanner;
|
||||
let plan_definitions = PlanDefinitions;
|
||||
|
||||
v_flex()
|
||||
.relative()
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to Zed AI"))
|
||||
.map(|this| {
|
||||
if self.account_too_young {
|
||||
this.child(young_account_banner)
|
||||
} else {
|
||||
this.child(self.free_plan_definition(cx)).when_some(
|
||||
self.dismiss_onboarding.as_ref(),
|
||||
|this, dismiss_callback| {
|
||||
let callback = dismiss_callback.clone();
|
||||
if self.account_too_young {
|
||||
v_flex()
|
||||
.relative()
|
||||
.max_w_full()
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to Zed AI"))
|
||||
.child(young_account_banner)
|
||||
.child(
|
||||
v_flex()
|
||||
.mt_2()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Pro")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Accent)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(plan_definitions.pro_plan(true))
|
||||
.child(
|
||||
Button::new("pro", "Get Started")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!(
|
||||
"Upgrade To Pro Clicked",
|
||||
state = "young-account"
|
||||
);
|
||||
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
|
||||
}),
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
v_flex()
|
||||
.relative()
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to Zed AI"))
|
||||
.child(
|
||||
v_flex()
|
||||
.mt_2()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Free")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(
|
||||
Label::new("(Current Plan)")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Custom(
|
||||
cx.theme().colors().text_muted.opacity(0.6),
|
||||
))
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(plan_definitions.free_plan()),
|
||||
)
|
||||
.when_some(
|
||||
self.dismiss_onboarding.as_ref(),
|
||||
|this, dismiss_callback| {
|
||||
let callback = dismiss_callback.clone();
|
||||
|
||||
this.child(
|
||||
h_flex().absolute().top_0().right_0().child(
|
||||
IconButton::new("dismiss_onboarding", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text("Dismiss"))
|
||||
.on_click(move |_, window, cx| {
|
||||
telemetry::event!(
|
||||
"Banner Dismissed",
|
||||
source = "AI Onboarding",
|
||||
);
|
||||
callback(window, cx)
|
||||
}),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
})
|
||||
.child(self.pro_plan_definition(cx))
|
||||
.into_any_element()
|
||||
this.child(
|
||||
h_flex().absolute().top_0().right_0().child(
|
||||
IconButton::new("dismiss_onboarding", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text("Dismiss"))
|
||||
.on_click(move |_, window, cx| {
|
||||
telemetry::event!(
|
||||
"Banner Dismissed",
|
||||
source = "AI Onboarding",
|
||||
);
|
||||
callback(window, cx)
|
||||
}),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.mt_2()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Pro Trial")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Accent)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(plan_definitions.pro_trial(true))
|
||||
.child(
|
||||
Button::new("pro", "Start Free Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!(
|
||||
"Start Trial Clicked",
|
||||
state = "post-sign-in"
|
||||
);
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
fn render_trial_state(&self, _cx: &mut App) -> AnyElement {
|
||||
let plan_definitions = PlanDefinitions;
|
||||
|
||||
v_flex()
|
||||
.relative()
|
||||
.gap_1()
|
||||
|
@ -343,13 +296,7 @@ impl ZedAiOnboarding {
|
|||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("150 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
)
|
||||
.child(plan_definitions.pro_trial(false))
|
||||
.when_some(
|
||||
self.dismiss_onboarding.as_ref(),
|
||||
|this, dismiss_callback| {
|
||||
|
@ -374,6 +321,8 @@ impl ZedAiOnboarding {
|
|||
}
|
||||
|
||||
fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement {
|
||||
let plan_definitions = PlanDefinitions;
|
||||
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to Zed Pro"))
|
||||
|
@ -382,13 +331,7 @@ impl ZedAiOnboarding {
|
|||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
)
|
||||
.child(plan_definitions.pro_plan(false))
|
||||
.child(
|
||||
Button::new("pro", "Continue with Zed Pro")
|
||||
.full_width()
|
||||
|
@ -425,7 +368,15 @@ impl RenderOnce for ZedAiOnboarding {
|
|||
|
||||
impl Component for ZedAiOnboarding {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
ComponentScope::Onboarding
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"Agent Panel Banners"
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"Agent Panel Banners"
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
|
@ -450,8 +401,9 @@ impl Component for ZedAiOnboarding {
|
|||
|
||||
Some(
|
||||
v_flex()
|
||||
.p_4()
|
||||
.gap_4()
|
||||
.items_center()
|
||||
.max_w_4_5()
|
||||
.children(vec![
|
||||
single_example(
|
||||
"Not Signed-in",
|
||||
|
@ -462,8 +414,8 @@ impl Component for ZedAiOnboarding {
|
|||
onboarding(SignInStatus::SignedIn, false, None, false),
|
||||
),
|
||||
single_example(
|
||||
"Account too young",
|
||||
onboarding(SignInStatus::SignedIn, false, None, true),
|
||||
"Young Account",
|
||||
onboarding(SignInStatus::SignedIn, true, None, true),
|
||||
),
|
||||
single_example(
|
||||
"Free Plan",
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use client::{Client, zed_urls};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Divider, List, Vector, VectorName, prelude::*};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, AnyElement, App, IntoElement, RenderOnce, Transformation, Window,
|
||||
percentage,
|
||||
};
|
||||
use ui::{Divider, Vector, VectorName, prelude::*};
|
||||
|
||||
use crate::{BulletItem, SignInStatus};
|
||||
use crate::{SignInStatus, plan_definitions::PlanDefinitions};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct AiUpsellCard {
|
||||
|
@ -36,6 +39,8 @@ impl AiUpsellCard {
|
|||
|
||||
impl RenderOnce for AiUpsellCard {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let plan_definitions = PlanDefinitions;
|
||||
|
||||
let pro_section = v_flex()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
|
@ -51,13 +56,7 @@ impl RenderOnce for AiUpsellCard {
|
|||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
);
|
||||
.child(plan_definitions.pro_plan(false));
|
||||
|
||||
let free_section = v_flex()
|
||||
.flex_grow()
|
||||
|
@ -74,11 +73,7 @@ impl RenderOnce for AiUpsellCard {
|
|||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts with Claude models"))
|
||||
.child(BulletItem::new("2,000 accepted edit predictions")),
|
||||
);
|
||||
.child(plan_definitions.free_plan());
|
||||
|
||||
let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child(
|
||||
Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.))
|
||||
|
@ -101,44 +96,11 @@ impl RenderOnce for AiUpsellCard {
|
|||
),
|
||||
));
|
||||
|
||||
const DESCRIPTION: &str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
|
||||
let description = PlanDefinitions::AI_DESCRIPTION;
|
||||
|
||||
let footer_buttons = match self.sign_in_status {
|
||||
SignInStatus::SignedIn => v_flex()
|
||||
.items_center()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("sign_in", "Start 14-day Free Pro Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
})
|
||||
.when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)),
|
||||
)
|
||||
.child(
|
||||
Label::new("No credit card required")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element(),
|
||||
_ => Button::new("sign_in", "Sign In")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index))
|
||||
.on_click({
|
||||
let callback = self.sign_in.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
|
||||
callback(window, cx)
|
||||
}
|
||||
})
|
||||
.into_any_element(),
|
||||
};
|
||||
|
||||
v_flex()
|
||||
let card = v_flex()
|
||||
.relative()
|
||||
.flex_grow()
|
||||
.p_4()
|
||||
.pt_3()
|
||||
.border_1()
|
||||
|
@ -146,31 +108,135 @@ impl RenderOnce for AiUpsellCard {
|
|||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(grid_bg)
|
||||
.child(gradient_bg)
|
||||
.child(Label::new("Try Zed AI").size(LabelSize::Large))
|
||||
.child(gradient_bg);
|
||||
|
||||
let plans_section = h_flex()
|
||||
.w_full()
|
||||
.mt_1p5()
|
||||
.mb_2p5()
|
||||
.items_start()
|
||||
.gap_6()
|
||||
.child(free_section)
|
||||
.child(pro_section);
|
||||
|
||||
let footer_container = v_flex().items_center().gap_1();
|
||||
|
||||
let certified_user_stamp = div()
|
||||
.absolute()
|
||||
.top_2()
|
||||
.right_2()
|
||||
.size(rems_from_px(72.))
|
||||
.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(DESCRIPTION).color(Color::Muted)),
|
||||
)
|
||||
Vector::new(
|
||||
VectorName::CertifiedUserStamp,
|
||||
rems_from_px(72.),
|
||||
rems_from_px(72.),
|
||||
)
|
||||
.color(Color::Custom(cx.theme().colors().text_accent.alpha(0.3)))
|
||||
.with_animation(
|
||||
"loading_stamp",
|
||||
Animation::new(Duration::from_secs(10)).repeat(),
|
||||
|this, delta| this.transform(Transformation::rotate(percentage(delta))),
|
||||
),
|
||||
);
|
||||
|
||||
let pro_trial_stamp = div()
|
||||
.absolute()
|
||||
.top_2()
|
||||
.right_2()
|
||||
.size(rems_from_px(72.))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.mt_1p5()
|
||||
.mb_2p5()
|
||||
.items_start()
|
||||
.gap_6()
|
||||
.child(free_section)
|
||||
.child(pro_section),
|
||||
)
|
||||
.child(footer_buttons)
|
||||
Vector::new(
|
||||
VectorName::ProTrialStamp,
|
||||
rems_from_px(72.),
|
||||
rems_from_px(72.),
|
||||
)
|
||||
.color(Color::Custom(cx.theme().colors().text.alpha(0.2))),
|
||||
);
|
||||
|
||||
match self.sign_in_status {
|
||||
SignInStatus::SignedIn => match self.user_plan {
|
||||
None | Some(Plan::ZedFree) => card
|
||||
.child(Label::new("Try Zed AI").size(LabelSize::Large))
|
||||
.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(description).color(Color::Muted)),
|
||||
)
|
||||
.child(plans_section)
|
||||
.child(
|
||||
footer_container
|
||||
.child(
|
||||
Button::new("start_trial", "Start 14-day Free Pro Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.when_some(self.tab_index, |this, tab_index| {
|
||||
this.tab_index(tab_index)
|
||||
})
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!(
|
||||
"Start Trial Clicked",
|
||||
state = "post-sign-in"
|
||||
);
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Label::new("No credit card required")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
),
|
||||
Some(Plan::ZedProTrial) => card
|
||||
.child(pro_trial_stamp)
|
||||
.child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large))
|
||||
.child(
|
||||
Label::new("Here's what you get for the next 14 days:")
|
||||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(plan_definitions.pro_trial(false)),
|
||||
Some(Plan::ZedPro) => card
|
||||
.child(certified_user_stamp)
|
||||
.child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large))
|
||||
.child(
|
||||
Label::new("Here's what you get:")
|
||||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(plan_definitions.pro_plan(false)),
|
||||
},
|
||||
// Signed Out State
|
||||
_ => card
|
||||
.child(Label::new("Try Zed AI").size(LabelSize::Large))
|
||||
.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(description).color(Color::Muted)),
|
||||
)
|
||||
.child(plans_section)
|
||||
.child(
|
||||
Button::new("sign_in", "Sign In")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index))
|
||||
.on_click({
|
||||
let callback = self.sign_in.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
|
||||
callback(window, cx)
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for AiUpsellCard {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
ComponentScope::Onboarding
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
|
@ -188,7 +254,6 @@ impl Component for AiUpsellCard {
|
|||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
Some(
|
||||
v_flex()
|
||||
.p_4()
|
||||
.gap_4()
|
||||
.children(vec![example_group(vec![
|
||||
single_example(
|
||||
|
@ -202,11 +267,31 @@ impl Component for AiUpsellCard {
|
|||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Signed In State",
|
||||
"Free Plan",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: None,
|
||||
user_plan: Some(Plan::ZedFree),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Pro Trial",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: Some(Plan::ZedProTrial),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Pro Plan",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: Some(Plan::ZedPro),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
|
|
39
crates/ai_onboarding/src/plan_definitions.rs
Normal file
39
crates/ai_onboarding/src/plan_definitions.rs
Normal file
|
@ -0,0 +1,39 @@
|
|||
use gpui::{IntoElement, ParentElement};
|
||||
use ui::{List, ListBulletItem, prelude::*};
|
||||
|
||||
/// Centralized definitions for Zed AI plans
|
||||
pub struct PlanDefinitions;
|
||||
|
||||
impl PlanDefinitions {
|
||||
pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
|
||||
|
||||
pub fn free_plan(&self) -> impl IntoElement {
|
||||
List::new()
|
||||
.child(ListBulletItem::new("50 prompts with Claude models"))
|
||||
.child(ListBulletItem::new("2,000 accepted edit predictions"))
|
||||
}
|
||||
|
||||
pub fn pro_trial(&self, period: bool) -> impl IntoElement {
|
||||
List::new()
|
||||
.child(ListBulletItem::new("150 prompts with Claude models"))
|
||||
.child(ListBulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
))
|
||||
.when(period, |this| {
|
||||
this.child(ListBulletItem::new(
|
||||
"Try it out for 14 days for free, no credit card required",
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pro_plan(&self, price: bool) -> impl IntoElement {
|
||||
List::new()
|
||||
.child(ListBulletItem::new("500 prompts with Claude models"))
|
||||
.child(ListBulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
))
|
||||
.when(price, |this| {
|
||||
this.child(ListBulletItem::new("$20 USD per month"))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -15,6 +15,7 @@ impl RenderOnce for YoungAccountBanner {
|
|||
.child(YOUNG_ACCOUNT_DISCLAIMER);
|
||||
|
||||
div()
|
||||
.max_w_full()
|
||||
.my_1()
|
||||
.child(Banner::new().severity(ui::Severity::Warning).child(label))
|
||||
}
|
||||
|
|
|
@ -2,16 +2,16 @@
|
|||
mod assistant_context_tests;
|
||||
mod context_store;
|
||||
|
||||
use agent_settings::AgentSettings;
|
||||
use agent_settings::{AgentSettings, SUMMARIZE_THREAD_PROMPT};
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use assistant_slash_command::{
|
||||
SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
|
||||
SlashCommandResult, SlashCommandWorkingSet,
|
||||
};
|
||||
use assistant_slash_commands::FileCommandMetadata;
|
||||
use client::{self, Client, proto, telemetry::Telemetry};
|
||||
use client::{self, Client, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::{Fs, RenameOptions};
|
||||
use futures::{FutureExt, StreamExt, future::Shared};
|
||||
|
@ -2080,7 +2080,18 @@ impl AssistantContext {
|
|||
});
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StatusUpdate { .. } => {}
|
||||
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
||||
match status_update {
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit } => {
|
||||
this.update_model_request_usage(
|
||||
amount as u32,
|
||||
limit,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
stop_reason = reason;
|
||||
|
@ -2677,10 +2688,7 @@ impl AssistantContext {
|
|||
let mut request = self.to_completion_request(Some(&model.model), cx);
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![
|
||||
"Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
|
||||
.into(),
|
||||
],
|
||||
content: vec![SUMMARIZE_THREAD_PROMPT.into()],
|
||||
cache: false,
|
||||
});
|
||||
|
||||
|
@ -2956,6 +2964,21 @@ impl AssistantContext {
|
|||
summary.text = custom_summary;
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
}
|
||||
|
||||
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) {
|
||||
let Some(project) = &self.project else {
|
||||
return;
|
||||
};
|
||||
project.read(cx).user_store().update(cx, |user_store, cx| {
|
||||
user_store.update_model_request_usage(
|
||||
ModelRequestUsage(RequestUsage {
|
||||
amount: amount as i32,
|
||||
limit,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
|
|
@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) {
|
|||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Brief");
|
||||
fake_model.stream_last_completion_response(" Introduction");
|
||||
fake_model.send_last_completion_stream_text_chunk("Brief");
|
||||
fake_model.send_last_completion_stream_text_chunk(" Introduction");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
|
@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
|
|||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("A successful summary");
|
||||
fake_model.send_last_completion_stream_text_chunk("A successful summary");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
|
@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model(
|
|||
|
||||
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Assistant response");
|
||||
fake_model.send_last_completion_stream_text_chunk("Assistant response");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
|
|
@ -962,7 +962,7 @@ mod tests {
|
|||
);
|
||||
cx.run_until_parked();
|
||||
|
||||
model.stream_last_completion_response("<old_text>a");
|
||||
model.send_last_completion_stream_text_chunk("<old_text>a");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), vec![]);
|
||||
assert_eq!(
|
||||
|
@ -974,7 +974,7 @@ mod tests {
|
|||
None
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("bc</old_text>");
|
||||
model.send_last_completion_stream_text_chunk("bc</old_text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
|
@ -996,7 +996,7 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("<new_text>abX");
|
||||
model.send_last_completion_stream_text_chunk("<new_text>abX");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_eq!(
|
||||
|
@ -1011,7 +1011,7 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("cY");
|
||||
model.send_last_completion_stream_text_chunk("cY");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_eq!(
|
||||
|
@ -1026,8 +1026,8 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("</new_text>");
|
||||
model.stream_last_completion_response("<old_text>hall");
|
||||
model.send_last_completion_stream_text_chunk("</new_text>");
|
||||
model.send_last_completion_stream_text_chunk("<old_text>hall");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), vec![]);
|
||||
assert_eq!(
|
||||
|
@ -1042,8 +1042,8 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("ucinated old</old_text>");
|
||||
model.stream_last_completion_response("<new_text>");
|
||||
model.send_last_completion_stream_text_chunk("ucinated old</old_text>");
|
||||
model.send_last_completion_stream_text_chunk("<new_text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
|
@ -1061,8 +1061,8 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("hallucinated new</new_");
|
||||
model.stream_last_completion_response("text>");
|
||||
model.send_last_completion_stream_text_chunk("hallucinated new</new_");
|
||||
model.send_last_completion_stream_text_chunk("text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), vec![]);
|
||||
assert_eq!(
|
||||
|
@ -1077,7 +1077,7 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("<old_text>\nghi\nj");
|
||||
model.send_last_completion_stream_text_chunk("<old_text>\nghi\nj");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
|
@ -1099,8 +1099,8 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("kl</old_text>");
|
||||
model.stream_last_completion_response("<new_text>");
|
||||
model.send_last_completion_stream_text_chunk("kl</old_text>");
|
||||
model.send_last_completion_stream_text_chunk("<new_text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
|
@ -1122,7 +1122,7 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
model.stream_last_completion_response("GHI</new_text>");
|
||||
model.send_last_completion_stream_text_chunk("GHI</new_text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
|
@ -1367,7 +1367,9 @@ mod tests {
|
|||
cx.background_spawn(async move {
|
||||
for chunk in chunks {
|
||||
executor.simulate_random_delay().await;
|
||||
model.as_fake().stream_last_completion_response(chunk);
|
||||
model
|
||||
.as_fake()
|
||||
.send_last_completion_stream_text_chunk(chunk);
|
||||
}
|
||||
model.as_fake().end_last_completion_stream();
|
||||
})
|
||||
|
|
|
@ -1577,7 +1577,7 @@ mod tests {
|
|||
|
||||
// Stream the unformatted content
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||
model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
|
@ -1641,7 +1641,7 @@ mod tests {
|
|||
|
||||
// Stream the unformatted content
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||
model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
|
@ -1720,7 +1720,9 @@ mod tests {
|
|||
|
||||
// Stream the content with trailing whitespace
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
|
||||
model.send_last_completion_stream_text_chunk(
|
||||
CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
|
||||
);
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
|
@ -1777,7 +1779,9 @@ mod tests {
|
|||
|
||||
// Stream the content with trailing whitespace
|
||||
cx.executor().run_until_parked();
|
||||
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
|
||||
model.send_last_completion_stream_text_chunk(
|
||||
CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
|
||||
);
|
||||
model.end_last_completion_stream();
|
||||
|
||||
edit_task.await
|
||||
|
|
|
@ -1,31 +1,23 @@
|
|||
use anyhow::{Context as _, bail};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::{HashMap, HashSet};
|
||||
use sea_orm::ActiveValue;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
|
||||
use util::{ResultExt, maybe};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
|
||||
use crate::db::{
|
||||
CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||
UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, billing_customer,
|
||||
};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::stripe_client::{
|
||||
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
|
||||
StripeSubscriptionId,
|
||||
};
|
||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||
use crate::{
|
||||
db::{
|
||||
CreateBillingCustomerParams, CreateBillingSubscriptionParams,
|
||||
CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
||||
UpdateBillingSubscriptionParams, billing_customer,
|
||||
},
|
||||
stripe_billing::StripeBilling,
|
||||
};
|
||||
|
||||
/// The amount of time we wait in between each poll of Stripe events.
|
||||
///
|
||||
|
@ -542,194 +534,3 @@ pub async fn find_or_create_billing_customer(
|
|||
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.context("failed to sync LLM request usage to Stripe")
|
||||
.trace_err();
|
||||
executor
|
||||
.sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn sync_model_request_usage_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
let sync_model_request_usage_using_cloud = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
|
||||
if sync_model_request_usage_using_cloud {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::info!("Stripe usage sync: Starting");
|
||||
let started_at = Utc::now();
|
||||
|
||||
let staff_users = app.db.get_staff_users().await?;
|
||||
let staff_user_ids = staff_users
|
||||
.iter()
|
||||
.map(|user| user.id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
|
||||
let usage_meters = llm_db
|
||||
.get_current_subscription_usage_meters(Utc::now())
|
||||
.await?;
|
||||
let mut usage_meters_by_user_id =
|
||||
HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
|
||||
for (usage_meter, usage) in usage_meters {
|
||||
let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
|
||||
meters.push(usage_meter);
|
||||
}
|
||||
|
||||
log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
|
||||
let get_zed_pro_subscriptions_started_at = Utc::now();
|
||||
let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
|
||||
log::info!(
|
||||
"Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
|
||||
billing_subscriptions.len(),
|
||||
Utc::now() - get_zed_pro_subscriptions_started_at
|
||||
);
|
||||
|
||||
let claude_sonnet_4 = stripe_billing
|
||||
.find_price_by_lookup_key("claude-sonnet-4-requests")
|
||||
.await?;
|
||||
let claude_sonnet_4_max = stripe_billing
|
||||
.find_price_by_lookup_key("claude-sonnet-4-requests-max")
|
||||
.await?;
|
||||
let claude_opus_4 = stripe_billing
|
||||
.find_price_by_lookup_key("claude-opus-4-requests")
|
||||
.await?;
|
||||
let claude_opus_4_max = stripe_billing
|
||||
.find_price_by_lookup_key("claude-opus-4-requests-max")
|
||||
.await?;
|
||||
let claude_3_5_sonnet = stripe_billing
|
||||
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
|
||||
.await?;
|
||||
let claude_3_7_sonnet = stripe_billing
|
||||
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
|
||||
.await?;
|
||||
let claude_3_7_sonnet_max = stripe_billing
|
||||
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
|
||||
.await?;
|
||||
|
||||
let model_mode_combinations = [
|
||||
("claude-opus-4", CompletionMode::Max),
|
||||
("claude-opus-4", CompletionMode::Normal),
|
||||
("claude-sonnet-4", CompletionMode::Max),
|
||||
("claude-sonnet-4", CompletionMode::Normal),
|
||||
("claude-3-7-sonnet", CompletionMode::Max),
|
||||
("claude-3-7-sonnet", CompletionMode::Normal),
|
||||
("claude-3-5-sonnet", CompletionMode::Normal),
|
||||
];
|
||||
|
||||
let billing_subscription_count = billing_subscriptions.len();
|
||||
|
||||
log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
|
||||
|
||||
for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
|
||||
maybe!(async {
|
||||
if staff_user_ids.contains(&user_id) {
|
||||
return anyhow::Ok(());
|
||||
}
|
||||
|
||||
let stripe_customer_id =
|
||||
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||
let stripe_subscription_id =
|
||||
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
|
||||
|
||||
let usage_meters = usage_meters_by_user_id.get(&user_id);
|
||||
|
||||
for (model, mode) in &model_mode_combinations {
|
||||
let Ok(model) =
|
||||
llm_db.model(LanguageModelProvider::Anthropic, model)
|
||||
else {
|
||||
log::warn!("Failed to load model for user {user_id}: {model}");
|
||||
continue;
|
||||
};
|
||||
|
||||
let (price, meter_event_name) = match model.name.as_str() {
|
||||
"claude-opus-4" => match mode {
|
||||
CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
|
||||
CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
|
||||
},
|
||||
"claude-sonnet-4" => match mode {
|
||||
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
|
||||
CompletionMode::Max => {
|
||||
(&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
|
||||
}
|
||||
},
|
||||
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
|
||||
"claude-3-7-sonnet" => match mode {
|
||||
CompletionMode::Normal => {
|
||||
(&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
|
||||
}
|
||||
CompletionMode::Max => {
|
||||
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
|
||||
}
|
||||
},
|
||||
model_name => {
|
||||
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
|
||||
}
|
||||
};
|
||||
|
||||
let model_requests = usage_meters
|
||||
.and_then(|usage_meters| {
|
||||
usage_meters
|
||||
.iter()
|
||||
.find(|meter| meter.model_id == model.id && meter.mode == *mode)
|
||||
})
|
||||
.map(|usage_meter| usage_meter.requests)
|
||||
.unwrap_or(0);
|
||||
|
||||
if model_requests > 0 {
|
||||
stripe_billing
|
||||
.subscribe_to_price(&stripe_subscription_id, price)
|
||||
.await?;
|
||||
}
|
||||
|
||||
stripe_billing
|
||||
.bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
|
||||
Utc::now() - started_at
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -85,19 +85,6 @@ impl Database {
|
|||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing subscription with the specified ID.
|
||||
pub async fn get_billing_subscription_by_id(
|
||||
&self,
|
||||
id: BillingSubscriptionId,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find_by_id(id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing subscription with the specified Stripe subscription ID.
|
||||
pub async fn get_billing_subscription_by_stripe_subscription_id(
|
||||
&self,
|
||||
|
@ -143,119 +130,6 @@ impl Database {
|
|||
.await
|
||||
}
|
||||
|
||||
/// Returns all of the billing subscriptions for the user with the specified ID.
|
||||
///
|
||||
/// Note that this returns the subscriptions regardless of their status.
|
||||
/// If you're wanting to check if a use has an active billing subscription,
|
||||
/// use `get_active_billing_subscriptions` instead.
|
||||
pub async fn get_billing_subscriptions(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Vec<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
let subscriptions = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(subscriptions)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_billing_subscriptions(
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.filter(billing_subscription::Column::Kind.is_null())
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_zed_pro_billing_subscriptions(
|
||||
&self,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_zed_pro_billing_subscriptions_for_users(
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the user has an active billing subscription.
|
||||
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
|
||||
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
mod billing_subscription_tests;
|
||||
mod buffer_tests;
|
||||
mod channel_tests;
|
||||
mod contributor_tests;
|
||||
|
|
|
@ -1,96 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::db::billing_subscription::StripeSubscriptionStatus;
|
||||
use crate::db::tests::new_test_user;
|
||||
use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
|
||||
use crate::test_both_dbs;
|
||||
|
||||
use super::Database;
|
||||
|
||||
test_both_dbs!(
|
||||
test_get_active_billing_subscriptions,
|
||||
test_get_active_billing_subscriptions_postgres,
|
||||
test_get_active_billing_subscriptions_sqlite
|
||||
);
|
||||
|
||||
async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
|
||||
// A user with no subscription has no active billing subscriptions.
|
||||
{
|
||||
let user_id = new_test_user(db, "no-subscription-user@example.com").await;
|
||||
let subscription_count = db
|
||||
.count_active_billing_subscriptions(user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(subscription_count, 0);
|
||||
}
|
||||
|
||||
// A user with an active subscription has one active billing subscription.
|
||||
{
|
||||
let user_id = new_test_user(db, "active-user@example.com").await;
|
||||
let customer = db
|
||||
.create_billing_customer(&CreateBillingCustomerParams {
|
||||
user_id,
|
||||
stripe_customer_id: "cus_active_user".into(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
|
||||
|
||||
db.create_billing_subscription(&CreateBillingSubscriptionParams {
|
||||
billing_customer_id: customer.id,
|
||||
kind: None,
|
||||
stripe_subscription_id: "sub_active_user".into(),
|
||||
stripe_subscription_status: StripeSubscriptionStatus::Active,
|
||||
stripe_cancellation_reason: None,
|
||||
stripe_current_period_start: None,
|
||||
stripe_current_period_end: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
|
||||
assert_eq!(subscriptions.len(), 1);
|
||||
|
||||
let subscription = &subscriptions[0];
|
||||
assert_eq!(
|
||||
subscription.stripe_subscription_id,
|
||||
"sub_active_user".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
subscription.stripe_subscription_status,
|
||||
StripeSubscriptionStatus::Active
|
||||
);
|
||||
}
|
||||
|
||||
// A user with a past-due subscription has no active billing subscriptions.
|
||||
{
|
||||
let user_id = new_test_user(db, "past-due-user@example.com").await;
|
||||
let customer = db
|
||||
.create_billing_customer(&CreateBillingCustomerParams {
|
||||
user_id,
|
||||
stripe_customer_id: "cus_past_due_user".into(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
|
||||
|
||||
db.create_billing_subscription(&CreateBillingSubscriptionParams {
|
||||
billing_customer_id: customer.id,
|
||||
kind: None,
|
||||
stripe_subscription_id: "sub_past_due_user".into(),
|
||||
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
|
||||
stripe_cancellation_reason: None,
|
||||
stripe_current_period_start: None,
|
||||
stripe_current_period_end: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let subscription_count = db
|
||||
.count_active_billing_subscriptions(user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(subscription_count, 0);
|
||||
}
|
||||
}
|
|
@ -1,6 +1,5 @@
|
|||
use super::*;
|
||||
|
||||
pub mod providers;
|
||||
pub mod subscription_usage_meters;
|
||||
pub mod subscription_usages;
|
||||
pub mod usages;
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
use crate::db::UserId;
|
||||
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
/// Returns all current subscription usage meters as of the given timestamp.
|
||||
pub async fn get_current_subscription_usage_meters(
|
||||
&self,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
|
||||
let now = convert_chrono_to_time(now)?;
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
let result = subscription_usage_meter::Entity::find()
|
||||
.inner_join(subscription_usage::Entity)
|
||||
.filter(
|
||||
subscription_usage::Column::PeriodStartAt
|
||||
.lte(now)
|
||||
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
|
||||
)
|
||||
.select_also(subscription_usage::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let result = result
|
||||
.into_iter()
|
||||
.filter_map(|(meter, usage)| {
|
||||
let usage = usage?;
|
||||
Some((meter, usage))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns all current subscription usage meters for the given user as of the given timestamp.
|
||||
pub async fn get_current_subscription_usage_meters_for_user(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
|
||||
let now = convert_chrono_to_time(now)?;
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
let result = subscription_usage_meter::Entity::find()
|
||||
.inner_join(subscription_usage::Entity)
|
||||
.filter(subscription_usage::Column::UserId.eq(user_id))
|
||||
.filter(
|
||||
subscription_usage::Column::PeriodStartAt
|
||||
.lte(now)
|
||||
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
|
||||
)
|
||||
.select_also(subscription_usage::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let result = result
|
||||
.into_iter()
|
||||
.filter_map(|(meter, usage)| {
|
||||
let usage = usage?;
|
||||
Some((meter, usage))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -1,28 +1,7 @@
|
|||
use time::PrimitiveDateTime;
|
||||
|
||||
use crate::db::UserId;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
|
||||
use chrono::{Datelike as _, Timelike as _};
|
||||
|
||||
let date = time::Date::from_calendar_date(
|
||||
datetime.year(),
|
||||
time::Month::try_from(datetime.month() as u8).unwrap(),
|
||||
datetime.day() as u8,
|
||||
)?;
|
||||
|
||||
let time = time::Time::from_hms_nano(
|
||||
datetime.hour() as u8,
|
||||
datetime.minute() as u8,
|
||||
datetime.second() as u8,
|
||||
datetime.nanosecond(),
|
||||
)?;
|
||||
|
||||
Ok(PrimitiveDateTime::new(date, time))
|
||||
}
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_subscription_usage_for_period(
|
||||
&self,
|
||||
|
|
|
@ -8,7 +8,6 @@ use axum::{
|
|||
};
|
||||
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::api::billing::sync_llm_request_usage_with_stripe_periodically;
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
|
@ -31,7 +30,7 @@ use tower_http::trace::TraceLayer;
|
|||
use tracing_subscriber::{
|
||||
Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt,
|
||||
};
|
||||
use util::{ResultExt as _, maybe};
|
||||
use util::ResultExt as _;
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
|
||||
|
@ -133,29 +132,6 @@ async fn main() -> Result<()> {
|
|||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
spawn_user_backfiller(state.clone());
|
||||
|
||||
let llm_db = maybe!(async {
|
||||
let database_url = state
|
||||
.config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.context("missing LLM_DATABASE_URL")?;
|
||||
let max_connections = state
|
||||
.config
|
||||
.llm_database_max_connections
|
||||
.context("missing LLM_DATABASE_MAX_CONNECTIONS")?;
|
||||
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
LlmDatabase::new(db_options, state.executor.clone()).await
|
||||
})
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
if let Some(mut llm_db) = llm_db {
|
||||
llm_db.initialize().await?;
|
||||
sync_llm_request_usage_with_stripe_periodically(state.clone());
|
||||
}
|
||||
|
||||
app = app
|
||||
.merge(collab::api::events::router())
|
||||
.merge(collab::api::extensions::router())
|
||||
|
|
|
@ -340,9 +340,6 @@ impl Server {
|
|||
.add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
|
||||
.add_request_handler(
|
||||
forward_read_only_project_request::<proto::LanguageServerIdForName>,
|
||||
)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
|
||||
.add_request_handler(
|
||||
forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
|
||||
|
|
|
@ -24,10 +24,7 @@ use language::{
|
|||
};
|
||||
use project::{
|
||||
ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT,
|
||||
lsp_store::{
|
||||
lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
|
||||
rust_analyzer_ext::RUST_ANALYZER_NAME,
|
||||
},
|
||||
lsp_store::lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
|
||||
project_settings::{InlineBlameSettings, ProjectSettings},
|
||||
};
|
||||
use recent_projects::disconnected_overlay::DisconnectedOverlay;
|
||||
|
@ -3786,11 +3783,18 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
|
|||
cx_b.update(editor::init);
|
||||
|
||||
client_a.language_registry().add(rust_lang());
|
||||
client_b.language_registry().add(rust_lang());
|
||||
let mut fake_language_servers = client_a.language_registry().register_fake_lsp(
|
||||
"Rust",
|
||||
FakeLspAdapter {
|
||||
name: RUST_ANALYZER_NAME,
|
||||
name: "rust-analyzer",
|
||||
..FakeLspAdapter::default()
|
||||
},
|
||||
);
|
||||
client_b.language_registry().add(rust_lang());
|
||||
client_b.language_registry().register_fake_lsp_adapter(
|
||||
"Rust",
|
||||
FakeLspAdapter {
|
||||
name: "rust-analyzer",
|
||||
..FakeLspAdapter::default()
|
||||
},
|
||||
);
|
||||
|
|
|
@ -103,28 +103,16 @@ impl ChatPanel {
|
|||
});
|
||||
|
||||
cx.new(|cx| {
|
||||
let entity = cx.entity().downgrade();
|
||||
let message_list = ListState::new(
|
||||
0,
|
||||
gpui::ListAlignment::Bottom,
|
||||
px(1000.),
|
||||
move |ix, window, cx| {
|
||||
if let Some(entity) = entity.upgrade() {
|
||||
entity.update(cx, |this: &mut Self, cx| {
|
||||
this.render_message(ix, window, cx).into_any_element()
|
||||
})
|
||||
} else {
|
||||
div().into_any()
|
||||
}
|
||||
},
|
||||
);
|
||||
let message_list = ListState::new(0, gpui::ListAlignment::Bottom, px(1000.));
|
||||
|
||||
message_list.set_scroll_handler(cx.listener(|this, event: &ListScrollEvent, _, cx| {
|
||||
if event.visible_range.start < MESSAGE_LOADING_THRESHOLD {
|
||||
this.load_more_messages(cx);
|
||||
}
|
||||
this.is_scrolled_to_bottom = !event.is_scrolled;
|
||||
}));
|
||||
message_list.set_scroll_handler(cx.listener(
|
||||
|this: &mut Self, event: &ListScrollEvent, _, cx| {
|
||||
if event.visible_range.start < MESSAGE_LOADING_THRESHOLD {
|
||||
this.load_more_messages(cx);
|
||||
}
|
||||
this.is_scrolled_to_bottom = !event.is_scrolled;
|
||||
},
|
||||
));
|
||||
|
||||
let local_offset = chrono::Local::now().offset().local_minus_utc();
|
||||
let mut this = Self {
|
||||
|
@ -399,7 +387,7 @@ impl ChatPanel {
|
|||
ix: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
) -> AnyElement {
|
||||
let active_chat = &self.active_chat.as_ref().unwrap().0;
|
||||
let (message, is_continuation_from_previous, is_admin) =
|
||||
active_chat.update(cx, |active_chat, cx| {
|
||||
|
@ -582,6 +570,7 @@ impl ChatPanel {
|
|||
self.render_popover_buttons(message_id, can_delete_message, can_edit_message, cx)
|
||||
.mt_neg_2p5(),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn has_open_menu(&self, message_id: Option<u64>) -> bool {
|
||||
|
@ -979,7 +968,13 @@ impl Render for ChatPanel {
|
|||
)
|
||||
.child(div().flex_grow().px_2().map(|this| {
|
||||
if self.active_chat.is_some() {
|
||||
this.child(list(self.message_list.clone()).size_full())
|
||||
this.child(
|
||||
list(
|
||||
self.message_list.clone(),
|
||||
cx.processor(Self::render_message),
|
||||
)
|
||||
.size_full(),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
div()
|
||||
|
|
|
@ -324,20 +324,6 @@ impl CollabPanel {
|
|||
)
|
||||
.detach();
|
||||
|
||||
let entity = cx.entity().downgrade();
|
||||
let list_state = ListState::new(
|
||||
0,
|
||||
gpui::ListAlignment::Top,
|
||||
px(1000.),
|
||||
move |ix, window, cx| {
|
||||
if let Some(entity) = entity.upgrade() {
|
||||
entity.update(cx, |this, cx| this.render_list_entry(ix, window, cx))
|
||||
} else {
|
||||
div().into_any()
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let mut this = Self {
|
||||
width: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
|
@ -345,7 +331,7 @@ impl CollabPanel {
|
|||
fs: workspace.app_state().fs.clone(),
|
||||
pending_serialization: Task::ready(None),
|
||||
context_menu: None,
|
||||
list_state,
|
||||
list_state: ListState::new(0, gpui::ListAlignment::Top, px(1000.)),
|
||||
channel_name_editor,
|
||||
filter_editor,
|
||||
entries: Vec::default(),
|
||||
|
@ -2431,7 +2417,13 @@ impl CollabPanel {
|
|||
});
|
||||
v_flex()
|
||||
.size_full()
|
||||
.child(list(self.list_state.clone()).size_full())
|
||||
.child(
|
||||
list(
|
||||
self.list_state.clone(),
|
||||
cx.processor(Self::render_list_entry),
|
||||
)
|
||||
.size_full(),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.child(div().mx_2().border_primary(cx).border_t_1())
|
||||
|
@ -2605,7 +2597,7 @@ impl CollabPanel {
|
|||
let contact = contact.clone();
|
||||
move |this, event: &ClickEvent, window, cx| {
|
||||
this.deploy_contact_context_menu(
|
||||
event.down.position,
|
||||
event.position(),
|
||||
contact.clone(),
|
||||
window,
|
||||
cx,
|
||||
|
|
|
@ -118,16 +118,7 @@ impl NotificationPanel {
|
|||
})
|
||||
.detach();
|
||||
|
||||
let entity = cx.entity().downgrade();
|
||||
let notification_list =
|
||||
ListState::new(0, ListAlignment::Top, px(1000.), move |ix, window, cx| {
|
||||
entity
|
||||
.upgrade()
|
||||
.and_then(|entity| {
|
||||
entity.update(cx, |this, cx| this.render_notification(ix, window, cx))
|
||||
})
|
||||
.unwrap_or_else(|| div().into_any())
|
||||
});
|
||||
let notification_list = ListState::new(0, ListAlignment::Top, px(1000.));
|
||||
notification_list.set_scroll_handler(cx.listener(
|
||||
|this, event: &ListScrollEvent, _, cx| {
|
||||
if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD {
|
||||
|
@ -687,7 +678,16 @@ impl Render for NotificationPanel {
|
|||
),
|
||||
)
|
||||
} else {
|
||||
this.child(list(self.notification_list.clone()).size_full())
|
||||
this.child(
|
||||
list(
|
||||
self.notification_list.clone(),
|
||||
cx.processor(|this, ix, window, cx| {
|
||||
this.render_notification(ix, window, cx)
|
||||
.unwrap_or_else(|| div().into_any())
|
||||
}),
|
||||
)
|
||||
.size_full(),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -318,8 +318,10 @@ pub enum ComponentScope {
|
|||
Notification,
|
||||
#[strum(serialize = "Overlays & Layering")]
|
||||
Overlays,
|
||||
Onboarding,
|
||||
Status,
|
||||
Typography,
|
||||
Utilities,
|
||||
#[strum(serialize = "Version Control")]
|
||||
VersionControl,
|
||||
}
|
||||
|
|
|
@ -1015,15 +1015,13 @@ impl DebugDelegate {
|
|||
let language_names = languages.language_names();
|
||||
let language = dap_registry
|
||||
.adapter_language(&scenario.adapter)
|
||||
.map(|language| TaskSourceKind::Language {
|
||||
name: language.into(),
|
||||
});
|
||||
.map(|language| TaskSourceKind::Language { name: language.0 });
|
||||
|
||||
let language = language.or_else(|| {
|
||||
scenario.label.split_whitespace().find_map(|word| {
|
||||
language_names
|
||||
.iter()
|
||||
.find(|name| name.eq_ignore_ascii_case(word))
|
||||
.find(|name| name.as_ref().eq_ignore_ascii_case(word))
|
||||
.map(|name| TaskSourceKind::Language {
|
||||
name: name.to_owned().into(),
|
||||
})
|
||||
|
|
|
@ -29,7 +29,6 @@ use ui::{
|
|||
Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable,
|
||||
Tooltip, Window, div, h_flex, px, v_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint};
|
||||
|
||||
|
@ -56,8 +55,6 @@ pub(crate) struct BreakpointList {
|
|||
scrollbar_state: ScrollbarState,
|
||||
breakpoints: Vec<BreakpointEntry>,
|
||||
session: Option<Entity<Session>>,
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
show_scrollbar: bool,
|
||||
focus_handle: FocusHandle,
|
||||
scroll_handle: UniformListScrollHandle,
|
||||
selected_ix: Option<usize>,
|
||||
|
@ -103,8 +100,6 @@ impl BreakpointList {
|
|||
worktree_store,
|
||||
scrollbar_state,
|
||||
breakpoints: Default::default(),
|
||||
hide_scrollbar_task: None,
|
||||
show_scrollbar: false,
|
||||
workspace,
|
||||
session,
|
||||
focus_handle,
|
||||
|
@ -565,21 +560,6 @@ impl BreakpointList {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
|
||||
self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SCROLLBAR_SHOW_INTERVAL)
|
||||
.await;
|
||||
panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.show_scrollbar = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_list(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let selected_ix = self.selected_ix;
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
@ -614,43 +594,39 @@ impl BreakpointList {
|
|||
.flex_grow()
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !(self.show_scrollbar || self.scrollbar_state.is_dragging()) {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("breakpoint-list-vertical-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.occlude()
|
||||
.id("breakpoint-list-vertical-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
pub(crate) fn render_control_strip(&self) -> AnyElement {
|
||||
let selection_kind = self.selection_kind();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
@ -819,15 +795,6 @@ impl Render for BreakpointList {
|
|||
.id("breakpoint-list")
|
||||
.key_context("BreakpointList")
|
||||
.track_focus(&self.focus_handle)
|
||||
.on_hover(cx.listener(|this, hovered, window, cx| {
|
||||
if *hovered {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_task.take();
|
||||
cx.notify();
|
||||
} else if !this.focus_handle.contains_focused(window, cx) {
|
||||
this.hide_scrollbar(window, cx);
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(Self::select_next))
|
||||
.on_action(cx.listener(Self::select_previous))
|
||||
.on_action(cx.listener(Self::select_first))
|
||||
|
@ -844,7 +811,7 @@ impl Render for BreakpointList {
|
|||
v_flex()
|
||||
.size_full()
|
||||
.child(self.render_list(cx))
|
||||
.children(self.render_vertical_scrollbar(cx)),
|
||||
.child(self.render_vertical_scrollbar(cx)),
|
||||
)
|
||||
.when_some(self.strip_mode, |this, _| {
|
||||
this.child(Divider::horizontal()).child(
|
||||
|
|
|
@ -13,22 +13,8 @@ pub(crate) struct LoadedSourceList {
|
|||
|
||||
impl LoadedSourceList {
|
||||
pub fn new(session: Entity<Session>, cx: &mut Context<Self>) -> Self {
|
||||
let weak_entity = cx.weak_entity();
|
||||
let focus_handle = cx.focus_handle();
|
||||
|
||||
let list = ListState::new(
|
||||
0,
|
||||
gpui::ListAlignment::Top,
|
||||
px(1000.),
|
||||
move |ix, _window, cx| {
|
||||
weak_entity
|
||||
.upgrade()
|
||||
.map(|loaded_sources| {
|
||||
loaded_sources.update(cx, |this, cx| this.render_entry(ix, cx))
|
||||
})
|
||||
.unwrap_or(div().into_any())
|
||||
},
|
||||
);
|
||||
let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
|
||||
|
||||
let _subscription = cx.subscribe(&session, |this, _, event, cx| match event {
|
||||
SessionEvent::Stopped(_) | SessionEvent::LoadedSources => {
|
||||
|
@ -98,6 +84,12 @@ impl Render for LoadedSourceList {
|
|||
.track_focus(&self.focus_handle)
|
||||
.size_full()
|
||||
.p_1()
|
||||
.child(list(self.list.clone()).size_full())
|
||||
.child(
|
||||
list(
|
||||
self.list.clone(),
|
||||
cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)),
|
||||
)
|
||||
.size_full(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,6 @@ use ui::{
|
|||
ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString,
|
||||
StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList};
|
||||
|
@ -34,9 +33,7 @@ pub(crate) struct MemoryView {
|
|||
workspace: WeakEntity<Workspace>,
|
||||
scroll_handle: UniformListScrollHandle,
|
||||
scroll_state: ScrollbarState,
|
||||
show_scrollbar: bool,
|
||||
stack_frame_list: WeakEntity<StackFrameList>,
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
focus_handle: FocusHandle,
|
||||
view_state: ViewState,
|
||||
query_editor: Entity<Editor>,
|
||||
|
@ -150,8 +147,6 @@ impl MemoryView {
|
|||
scroll_state,
|
||||
scroll_handle,
|
||||
stack_frame_list,
|
||||
show_scrollbar: false,
|
||||
hide_scrollbar_task: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
view_state,
|
||||
query_editor,
|
||||
|
@ -168,61 +163,42 @@ impl MemoryView {
|
|||
.detach();
|
||||
this
|
||||
}
|
||||
fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
|
||||
self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SCROLLBAR_SHOW_INTERVAL)
|
||||
.await;
|
||||
panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.show_scrollbar = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !(self.show_scrollbar || self.scroll_state.is_dragging()) {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("memory-view-vertical-scrollbar")
|
||||
.on_drag_move(cx.listener(|this, evt, _, cx| {
|
||||
let did_handle = this.handle_scroll_drag(evt);
|
||||
cx.notify();
|
||||
if did_handle {
|
||||
cx.stop_propagation()
|
||||
}
|
||||
}))
|
||||
.on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
|
||||
.on_hover(|_, _, cx| {
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.occlude()
|
||||
.id("memory-view-vertical-scrollbar")
|
||||
.on_drag_move(cx.listener(|this, evt, _, cx| {
|
||||
let did_handle = this.handle_scroll_drag(evt);
|
||||
cx.notify();
|
||||
if did_handle {
|
||||
cx.stop_propagation()
|
||||
}
|
||||
}))
|
||||
.on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scroll_state.clone())),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scroll_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
fn render_memory(&self, cx: &mut Context<Self>) -> UniformList {
|
||||
|
@ -920,15 +896,6 @@ impl Render for MemoryView {
|
|||
.on_action(cx.listener(Self::page_up))
|
||||
.size_full()
|
||||
.track_focus(&self.focus_handle)
|
||||
.on_hover(cx.listener(|this, hovered, window, cx| {
|
||||
if *hovered {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_task.take();
|
||||
cx.notify();
|
||||
} else if !this.focus_handle.contains_focused(window, cx) {
|
||||
this.hide_scrollbar(window, cx);
|
||||
}
|
||||
}))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
|
@ -978,7 +945,7 @@ impl Render for MemoryView {
|
|||
)
|
||||
.with_priority(1)
|
||||
}))
|
||||
.children(self.render_vertical_scrollbar(cx)),
|
||||
.child(self.render_vertical_scrollbar(cx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,13 +70,7 @@ impl StackFrameList {
|
|||
_ => {}
|
||||
});
|
||||
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.), {
|
||||
let this = cx.weak_entity();
|
||||
move |ix, _window, cx| {
|
||||
this.update(cx, |this, cx| this.render_entry(ix, cx))
|
||||
.unwrap_or(div().into_any())
|
||||
}
|
||||
});
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
|
||||
let scrollbar_state = ScrollbarState::new(list_state.clone());
|
||||
|
||||
let mut this = Self {
|
||||
|
@ -708,11 +702,14 @@ impl StackFrameList {
|
|||
self.activate_selected_entry(window, cx);
|
||||
}
|
||||
|
||||
fn render_list(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.p_1()
|
||||
.size_full()
|
||||
.child(list(self.list_state.clone()).size_full())
|
||||
fn render_list(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div().p_1().size_full().child(
|
||||
list(
|
||||
self.list_state.clone(),
|
||||
cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)),
|
||||
)
|
||||
.size_full(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1107,7 +1107,7 @@ impl VariableList {
|
|||
let variable_value = value.clone();
|
||||
this.on_click(cx.listener(
|
||||
move |this, click: &ClickEvent, window, cx| {
|
||||
if click.down.click_count < 2 {
|
||||
if click.click_count() < 2 {
|
||||
return;
|
||||
}
|
||||
let editor = Self::create_variable_editor(
|
||||
|
|
|
@ -29,16 +29,14 @@ pub fn switch_source_header(
|
|||
return;
|
||||
};
|
||||
|
||||
let server_lookup =
|
||||
find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME);
|
||||
let Some((_, _, server_to_query, buffer)) =
|
||||
find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let project = project.clone();
|
||||
let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client();
|
||||
cx.spawn_in(window, async move |_editor, cx| {
|
||||
let Some((_, _, server_to_query, buffer)) =
|
||||
server_lookup.await
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let source_file = buffer.read_with(cx, |buffer, _| {
|
||||
buffer.file().map(|file| file.path()).map(|path| path.to_string_lossy().to_string()).unwrap_or_else(|| "Unknown".to_string())
|
||||
})?;
|
||||
|
|
|
@ -8183,7 +8183,7 @@ impl Editor {
|
|||
editor.set_breakpoint_context_menu(
|
||||
row,
|
||||
Some(position),
|
||||
event.down.position,
|
||||
event.position(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
@ -8350,7 +8350,11 @@ impl Editor {
|
|||
.icon_color(color)
|
||||
.toggle_state(is_active)
|
||||
.on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| {
|
||||
let quick_launch = e.down.button == MouseButton::Left;
|
||||
let quick_launch = match e {
|
||||
ClickEvent::Keyboard(_) => true,
|
||||
ClickEvent::Mouse(e) => e.down.button == MouseButton::Left,
|
||||
};
|
||||
|
||||
window.focus(&editor.focus_handle(cx));
|
||||
editor.toggle_code_actions(
|
||||
&ToggleCodeActions {
|
||||
|
@ -8362,7 +8366,7 @@ impl Editor {
|
|||
);
|
||||
}))
|
||||
.on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| {
|
||||
editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx);
|
||||
editor.set_breakpoint_context_menu(row, position, event.position(), window, cx);
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -22188,7 +22192,6 @@ impl SemanticsProvider for Entity<Project> {
|
|||
}
|
||||
|
||||
fn supports_inlay_hints(&self, buffer: &Entity<Buffer>, cx: &mut App) -> bool {
|
||||
// TODO: make this work for remote projects
|
||||
self.update(cx, |project, cx| {
|
||||
if project
|
||||
.active_debug_session(cx)
|
||||
|
|
|
@ -43,11 +43,11 @@ use gpui::{
|
|||
Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges,
|
||||
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox,
|
||||
HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length,
|
||||
ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad,
|
||||
ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString,
|
||||
Size, StatefulInteractiveElement, Style, Styled, TextRun, TextStyleRefinement, WeakEntity,
|
||||
Window, anchored, deferred, div, fill, linear_color_stop, linear_gradient, outline, point, px,
|
||||
quad, relative, size, solid_background, transparent_black,
|
||||
ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent, MouseMoveEvent,
|
||||
MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent,
|
||||
ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled, TextRun,
|
||||
TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop,
|
||||
linear_gradient, outline, point, px, quad, relative, size, solid_background, transparent_black,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use language::language_settings::{
|
||||
|
@ -949,8 +949,12 @@ impl EditorElement {
|
|||
|
||||
let hovered_link_modifier = Editor::multi_cursor_modifier(false, &event.modifiers(), cx);
|
||||
|
||||
if !pending_nonempty_selections && hovered_link_modifier && text_hitbox.is_hovered(window) {
|
||||
let point = position_map.point_for_position(event.up.position);
|
||||
if let Some(mouse_position) = event.mouse_position()
|
||||
&& !pending_nonempty_selections
|
||||
&& hovered_link_modifier
|
||||
&& text_hitbox.is_hovered(window)
|
||||
{
|
||||
let point = position_map.point_for_position(mouse_position);
|
||||
editor.handle_click_hovered_link(point, event.modifiers(), window, cx);
|
||||
editor.selection_drag_state = SelectionDragState::None;
|
||||
|
||||
|
@ -3735,7 +3739,7 @@ impl EditorElement {
|
|||
move |editor, e: &ClickEvent, window, cx| {
|
||||
editor.open_excerpts_common(
|
||||
Some(jump_data.clone()),
|
||||
e.down.modifiers.secondary(),
|
||||
e.modifiers().secondary(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
@ -6882,10 +6886,10 @@ impl EditorElement {
|
|||
// Fire click handlers during the bubble phase.
|
||||
DispatchPhase::Bubble => editor.update(cx, |editor, cx| {
|
||||
if let Some(mouse_down) = captured_mouse_down.take() {
|
||||
let event = ClickEvent {
|
||||
let event = ClickEvent::Mouse(MouseClickEvent {
|
||||
down: mouse_down,
|
||||
up: event.clone(),
|
||||
};
|
||||
});
|
||||
Self::click(editor, &event, &position_map, window, cx);
|
||||
}
|
||||
}),
|
||||
|
|
|
@ -3,9 +3,8 @@ use std::time::Duration;
|
|||
|
||||
use crate::Editor;
|
||||
use collections::HashMap;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use gpui::AsyncApp;
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language::Buffer;
|
||||
use language::Language;
|
||||
|
@ -18,7 +17,6 @@ use project::Project;
|
|||
use project::TaskSourceKind;
|
||||
use project::lsp_store::lsp_ext_command::GetLspRunnables;
|
||||
use smol::future::FutureExt as _;
|
||||
use smol::stream::StreamExt;
|
||||
use task::ResolvedTask;
|
||||
use task::TaskContext;
|
||||
use text::BufferId;
|
||||
|
@ -29,52 +27,32 @@ pub(crate) fn find_specific_language_server_in_selection<F>(
|
|||
editor: &Editor,
|
||||
cx: &mut App,
|
||||
filter_language: F,
|
||||
language_server_name: &str,
|
||||
) -> Task<Option<(Anchor, Arc<Language>, LanguageServerId, Entity<Buffer>)>>
|
||||
language_server_name: LanguageServerName,
|
||||
) -> Option<(Anchor, Arc<Language>, LanguageServerId, Entity<Buffer>)>
|
||||
where
|
||||
F: Fn(&Language) -> bool,
|
||||
{
|
||||
let Some(project) = &editor.project else {
|
||||
return Task::ready(None);
|
||||
};
|
||||
|
||||
let applicable_buffers = editor
|
||||
let project = editor.project.clone()?;
|
||||
editor
|
||||
.selections
|
||||
.disjoint_anchors()
|
||||
.iter()
|
||||
.filter_map(|selection| Some((selection.head(), selection.head().buffer_id?)))
|
||||
.unique_by(|(_, buffer_id)| *buffer_id)
|
||||
.filter_map(|(trigger_anchor, buffer_id)| {
|
||||
.find_map(|(trigger_anchor, buffer_id)| {
|
||||
let buffer = editor.buffer().read(cx).buffer(buffer_id)?;
|
||||
let language = buffer.read(cx).language_at(trigger_anchor.text_anchor)?;
|
||||
if filter_language(&language) {
|
||||
Some((trigger_anchor, buffer, language))
|
||||
let server_id = buffer.update(cx, |buffer, cx| {
|
||||
project
|
||||
.read(cx)
|
||||
.language_server_id_for_name(buffer, &language_server_name, cx)
|
||||
})?;
|
||||
Some((trigger_anchor, language, server_id, buffer))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let applicable_buffer_tasks = applicable_buffers
|
||||
.into_iter()
|
||||
.map(|(trigger_anchor, buffer, language)| {
|
||||
let task = buffer.update(cx, |buffer, cx| {
|
||||
project.update(cx, |project, cx| {
|
||||
project.language_server_id_for_name(buffer, language_server_name, cx)
|
||||
})
|
||||
});
|
||||
(trigger_anchor, buffer, language, task)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
cx.background_spawn(async move {
|
||||
for (trigger_anchor, buffer, language, task) in applicable_buffer_tasks {
|
||||
if let Some(server_id) = task.await {
|
||||
return Some((trigger_anchor, language, server_id, buffer));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
async fn lsp_task_context(
|
||||
|
@ -116,9 +94,9 @@ pub fn lsp_tasks(
|
|||
for_position: Option<text::Anchor>,
|
||||
cx: &mut App,
|
||||
) -> Task<Vec<(TaskSourceKind, Vec<(Option<LocationLink>, ResolvedTask)>)>> {
|
||||
let mut lsp_task_sources = task_sources
|
||||
let lsp_task_sources = task_sources
|
||||
.iter()
|
||||
.map(|(name, buffer_ids)| {
|
||||
.filter_map(|(name, buffer_ids)| {
|
||||
let buffers = buffer_ids
|
||||
.iter()
|
||||
.filter(|&&buffer_id| match for_position {
|
||||
|
@ -127,61 +105,63 @@ pub fn lsp_tasks(
|
|||
})
|
||||
.filter_map(|&buffer_id| project.read(cx).buffer_for_id(buffer_id, cx))
|
||||
.collect::<Vec<_>>();
|
||||
language_server_for_buffers(project.clone(), name.clone(), buffers, cx)
|
||||
|
||||
let server_id = buffers.iter().find_map(|buffer| {
|
||||
project.read_with(cx, |project, cx| {
|
||||
project.language_server_id_for_name(buffer.read(cx), name, cx)
|
||||
})
|
||||
});
|
||||
server_id.zip(Some(buffers))
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>();
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
cx.spawn(async move |cx| {
|
||||
let mut lsp_tasks = HashMap::default();
|
||||
while let Some(server_to_query) = lsp_task_sources.next().await {
|
||||
if let Some((server_id, buffers)) = server_to_query {
|
||||
let mut new_lsp_tasks = Vec::new();
|
||||
for buffer in buffers {
|
||||
let source_kind = match buffer.update(cx, |buffer, _| {
|
||||
buffer.language().map(|language| language.name())
|
||||
}) {
|
||||
Ok(Some(language_name)) => TaskSourceKind::Lsp {
|
||||
server: server_id,
|
||||
language_name: SharedString::from(language_name),
|
||||
},
|
||||
Ok(None) => continue,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
let id_base = source_kind.to_id_base();
|
||||
let lsp_buffer_context = lsp_task_context(&project, &buffer, cx)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
for (server_id, buffers) in lsp_task_sources {
|
||||
let mut new_lsp_tasks = Vec::new();
|
||||
for buffer in buffers {
|
||||
let source_kind = match buffer.update(cx, |buffer, _| {
|
||||
buffer.language().map(|language| language.name())
|
||||
}) {
|
||||
Ok(Some(language_name)) => TaskSourceKind::Lsp {
|
||||
server: server_id,
|
||||
language_name: SharedString::from(language_name),
|
||||
},
|
||||
Ok(None) => continue,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
let id_base = source_kind.to_id_base();
|
||||
let lsp_buffer_context = lsp_task_context(&project, &buffer, cx)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Ok(runnables_task) = project.update(cx, |project, cx| {
|
||||
let buffer_id = buffer.read(cx).remote_id();
|
||||
project.request_lsp(
|
||||
buffer,
|
||||
LanguageServerToQuery::Other(server_id),
|
||||
GetLspRunnables {
|
||||
buffer_id,
|
||||
position: for_position,
|
||||
if let Ok(runnables_task) = project.update(cx, |project, cx| {
|
||||
let buffer_id = buffer.read(cx).remote_id();
|
||||
project.request_lsp(
|
||||
buffer,
|
||||
LanguageServerToQuery::Other(server_id),
|
||||
GetLspRunnables {
|
||||
buffer_id,
|
||||
position: for_position,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
}) {
|
||||
if let Some(new_runnables) = runnables_task.await.log_err() {
|
||||
new_lsp_tasks.extend(new_runnables.runnables.into_iter().filter_map(
|
||||
|(location, runnable)| {
|
||||
let resolved_task =
|
||||
runnable.resolve_task(&id_base, &lsp_buffer_context)?;
|
||||
Some((location, resolved_task))
|
||||
},
|
||||
cx,
|
||||
)
|
||||
}) {
|
||||
if let Some(new_runnables) = runnables_task.await.log_err() {
|
||||
new_lsp_tasks.extend(
|
||||
new_runnables.runnables.into_iter().filter_map(
|
||||
|(location, runnable)| {
|
||||
let resolved_task = runnable
|
||||
.resolve_task(&id_base, &lsp_buffer_context)?;
|
||||
Some((location, resolved_task))
|
||||
},
|
||||
),
|
||||
);
|
||||
}
|
||||
));
|
||||
}
|
||||
lsp_tasks
|
||||
.entry(source_kind)
|
||||
.or_insert_with(Vec::new)
|
||||
.append(&mut new_lsp_tasks);
|
||||
}
|
||||
lsp_tasks
|
||||
.entry(source_kind)
|
||||
.or_insert_with(Vec::new)
|
||||
.append(&mut new_lsp_tasks);
|
||||
}
|
||||
}
|
||||
lsp_tasks.into_iter().collect()
|
||||
|
@ -198,27 +178,3 @@ pub fn lsp_tasks(
|
|||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn language_server_for_buffers(
|
||||
project: Entity<Project>,
|
||||
name: LanguageServerName,
|
||||
candidates: Vec<Entity<Buffer>>,
|
||||
cx: &mut App,
|
||||
) -> Task<Option<(LanguageServerId, Vec<Entity<Buffer>>)>> {
|
||||
cx.spawn(async move |cx| {
|
||||
for buffer in &candidates {
|
||||
let server_id = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
project.update(cx, |project, cx| {
|
||||
project.language_server_id_for_name(buffer, &name.0, cx)
|
||||
})
|
||||
})
|
||||
.ok()?
|
||||
.await;
|
||||
if let Some(server_id) = server_id {
|
||||
return Some((server_id, candidates));
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
|
|
|
@ -57,21 +57,21 @@ pub fn go_to_parent_module(
|
|||
return;
|
||||
};
|
||||
|
||||
let server_lookup = find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
);
|
||||
let Some((trigger_anchor, _, server_to_query, buffer)) =
|
||||
find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let project = project.clone();
|
||||
let lsp_store = project.read(cx).lsp_store();
|
||||
let upstream_client = lsp_store.read(cx).upstream_client();
|
||||
cx.spawn_in(window, async move |editor, cx| {
|
||||
let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
|
||||
let location_links = if let Some((client, project_id)) = upstream_client {
|
||||
let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?;
|
||||
|
||||
|
@ -121,7 +121,7 @@ pub fn go_to_parent_module(
|
|||
)
|
||||
})?
|
||||
.await?;
|
||||
Ok(())
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
@ -139,21 +139,19 @@ pub fn expand_macro_recursively(
|
|||
return;
|
||||
};
|
||||
|
||||
let server_lookup = find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
);
|
||||
|
||||
let Some((trigger_anchor, rust_language, server_to_query, buffer)) =
|
||||
find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let project = project.clone();
|
||||
let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client();
|
||||
cx.spawn_in(window, async move |_editor, cx| {
|
||||
let Some((trigger_anchor, rust_language, server_to_query, buffer)) = server_lookup.await
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let macro_expansion = if let Some((client, project_id)) = upstream_client {
|
||||
let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?;
|
||||
let request = proto::LspExtExpandMacro {
|
||||
|
@ -231,20 +229,20 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu
|
|||
return;
|
||||
};
|
||||
|
||||
let server_lookup = find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
);
|
||||
let Some((trigger_anchor, _, server_to_query, buffer)) =
|
||||
find_specific_language_server_in_selection(
|
||||
editor,
|
||||
cx,
|
||||
is_rust_language,
|
||||
RUST_ANALYZER_NAME,
|
||||
)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let project = project.clone();
|
||||
let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client();
|
||||
cx.spawn_in(window, async move |_editor, cx| {
|
||||
let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let docs_urls = if let Some((client, project_id)) = upstream_client {
|
||||
let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?;
|
||||
let request = proto::LspExtOpenDocs {
|
||||
|
|
|
@ -10,7 +10,7 @@ use fs::{FakeFs, Fs, RealFs};
|
|||
use futures::{AsyncReadExt, StreamExt, io::BufReader};
|
||||
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
|
||||
use http_client::{FakeHttpClient, Response};
|
||||
use language::{BinaryStatus, LanguageMatcher, LanguageRegistry};
|
||||
use language::{BinaryStatus, LanguageMatcher, LanguageName, LanguageRegistry};
|
||||
use language_extension::LspAccess;
|
||||
use lsp::LanguageServerName;
|
||||
use node_runtime::NodeRuntime;
|
||||
|
@ -306,7 +306,11 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
|||
|
||||
assert_eq!(
|
||||
language_registry.language_names(),
|
||||
["ERB", "Plain Text", "Ruby"]
|
||||
[
|
||||
LanguageName::new("ERB"),
|
||||
LanguageName::new("Plain Text"),
|
||||
LanguageName::new("Ruby"),
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
theme_registry.list_names(),
|
||||
|
@ -458,7 +462,11 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
|||
|
||||
assert_eq!(
|
||||
language_registry.language_names(),
|
||||
["ERB", "Plain Text", "Ruby"]
|
||||
[
|
||||
LanguageName::new("ERB"),
|
||||
LanguageName::new("Plain Text"),
|
||||
LanguageName::new("Ruby"),
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
language_registry.grammar_names(),
|
||||
|
@ -513,7 +521,10 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
|||
assert_eq!(actual_language.hidden, expected_language.hidden);
|
||||
}
|
||||
|
||||
assert_eq!(language_registry.language_names(), ["Plain Text"]);
|
||||
assert_eq!(
|
||||
language_registry.language_names(),
|
||||
[LanguageName::new("Plain Text")]
|
||||
);
|
||||
assert_eq!(language_registry.grammar_names(), []);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -208,8 +208,15 @@ impl<'a> Matcher<'a> {
|
|||
return 1.0;
|
||||
}
|
||||
|
||||
let path_len = prefix.len() + path.len();
|
||||
let limit = self.last_positions[query_idx];
|
||||
let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
|
||||
let safe_limit = limit.min(max_valid_index);
|
||||
|
||||
if path_idx > safe_limit {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let path_len = prefix.len() + path.len();
|
||||
if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
|
||||
return memoized;
|
||||
}
|
||||
|
@ -218,16 +225,13 @@ impl<'a> Matcher<'a> {
|
|||
let mut best_position = 0;
|
||||
|
||||
let query_char = self.lowercase_query[query_idx];
|
||||
let limit = self.last_positions[query_idx];
|
||||
|
||||
let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
|
||||
let safe_limit = limit.min(max_valid_index);
|
||||
|
||||
let mut last_slash = 0;
|
||||
|
||||
for j in path_idx..=safe_limit {
|
||||
let extra_lowercase_chars_count = extra_lowercase_chars
|
||||
.iter()
|
||||
.take_while(|(i, _)| i < &&j)
|
||||
.take_while(|&(&i, _)| i < j)
|
||||
.map(|(_, increment)| increment)
|
||||
.sum::<usize>();
|
||||
let j_regular = j - extra_lowercase_chars_count;
|
||||
|
@ -236,10 +240,9 @@ impl<'a> Matcher<'a> {
|
|||
lowercase_prefix[j]
|
||||
} else {
|
||||
let path_index = j - prefix.len();
|
||||
if path_index < path_lowercased.len() {
|
||||
path_lowercased[path_index]
|
||||
} else {
|
||||
continue;
|
||||
match path_lowercased.get(path_index) {
|
||||
Some(&char) => char,
|
||||
None => continue,
|
||||
}
|
||||
};
|
||||
let is_path_sep = path_char == MAIN_SEPARATOR;
|
||||
|
@ -255,18 +258,16 @@ impl<'a> Matcher<'a> {
|
|||
#[cfg(target_os = "windows")]
|
||||
let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
|
||||
if need_to_score {
|
||||
let curr = if j_regular < prefix.len() {
|
||||
prefix[j_regular]
|
||||
} else {
|
||||
path[j_regular - prefix.len()]
|
||||
let curr = match prefix.get(j_regular) {
|
||||
Some(&curr) => curr,
|
||||
None => path[j_regular - prefix.len()],
|
||||
};
|
||||
|
||||
let mut char_score = 1.0;
|
||||
if j > path_idx {
|
||||
let last = if j_regular - 1 < prefix.len() {
|
||||
prefix[j_regular - 1]
|
||||
} else {
|
||||
path[j_regular - 1 - prefix.len()]
|
||||
let last = match prefix.get(j_regular - 1) {
|
||||
Some(&last) => last,
|
||||
None => path[j_regular - 1 - prefix.len()],
|
||||
};
|
||||
|
||||
if last == MAIN_SEPARATOR {
|
||||
|
|
|
@ -111,8 +111,24 @@ impl Render for Example {
|
|||
.flex_row()
|
||||
.gap_3()
|
||||
.items_center()
|
||||
.child(button("el1").tab_index(4).child("Button 1"))
|
||||
.child(button("el2").tab_index(5).child("Button 2")),
|
||||
.child(
|
||||
button("el1")
|
||||
.tab_index(4)
|
||||
.child("Button 1")
|
||||
.on_click(cx.listener(|this, _, _, cx| {
|
||||
this.message = "You have clicked Button 1.".into();
|
||||
cx.notify();
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
button("el2")
|
||||
.tab_index(5)
|
||||
.child("Button 2")
|
||||
.on_click(cx.listener(|this, _, _, cx| {
|
||||
this.message = "You have clicked Button 2.".into();
|
||||
cx.notify();
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -165,8 +165,8 @@ impl Render for WindowShadow {
|
|||
},
|
||||
)
|
||||
.on_click(|e, window, _| {
|
||||
if e.down.button == MouseButton::Right {
|
||||
window.show_window_menu(e.up.position);
|
||||
if e.is_right_click() {
|
||||
window.show_window_menu(e.position());
|
||||
}
|
||||
})
|
||||
.text_color(black())
|
||||
|
|
|
@ -19,10 +19,10 @@ use crate::{
|
|||
Action, AnyDrag, AnyElement, AnyTooltip, AnyView, App, Bounds, ClickEvent, DispatchPhase,
|
||||
Element, ElementId, Entity, FocusHandle, Global, GlobalElementId, Hitbox, HitboxBehavior,
|
||||
HitboxId, InspectorElementId, IntoElement, IsZero, KeyContext, KeyDownEvent, KeyUpEvent,
|
||||
LayoutId, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent,
|
||||
Overflow, ParentElement, Pixels, Point, Render, ScrollWheelEvent, SharedString, Size, Style,
|
||||
StyleRefinement, Styled, Task, TooltipId, Visibility, Window, WindowControlArea, point, px,
|
||||
size,
|
||||
KeyboardButton, KeyboardClickEvent, LayoutId, ModifiersChangedEvent, MouseButton,
|
||||
MouseClickEvent, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Overflow, ParentElement, Pixels,
|
||||
Point, Render, ScrollWheelEvent, SharedString, Size, Style, StyleRefinement, Styled, Task,
|
||||
TooltipId, Visibility, Window, WindowControlArea, point, px, size,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use refineable::Refineable;
|
||||
|
@ -484,10 +484,9 @@ impl Interactivity {
|
|||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.click_listeners
|
||||
.push(Box::new(move |event, window, cx| {
|
||||
listener(event, window, cx)
|
||||
}));
|
||||
self.click_listeners.push(Rc::new(move |event, window, cx| {
|
||||
listener(event, window, cx)
|
||||
}));
|
||||
}
|
||||
|
||||
/// On drag initiation, this callback will be used to create a new view to render the dragged value for a
|
||||
|
@ -1156,7 +1155,7 @@ pub(crate) type MouseMoveListener =
|
|||
pub(crate) type ScrollWheelListener =
|
||||
Box<dyn Fn(&ScrollWheelEvent, DispatchPhase, &Hitbox, &mut Window, &mut App) + 'static>;
|
||||
|
||||
pub(crate) type ClickListener = Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>;
|
||||
pub(crate) type ClickListener = Rc<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>;
|
||||
|
||||
pub(crate) type DragListener =
|
||||
Box<dyn Fn(&dyn Any, Point<Pixels>, &mut Window, &mut App) -> AnyView + 'static>;
|
||||
|
@ -1950,6 +1949,12 @@ impl Interactivity {
|
|||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let is_focused = self
|
||||
.tracked_focus_handle
|
||||
.as_ref()
|
||||
.map(|handle| handle.is_focused(window))
|
||||
.unwrap_or(false);
|
||||
|
||||
// If this element can be focused, register a mouse down listener
|
||||
// that will automatically transfer focus when hitting the element.
|
||||
// This behavior can be suppressed by using `cx.prevent_default()`.
|
||||
|
@ -2113,6 +2118,39 @@ impl Interactivity {
|
|||
}
|
||||
});
|
||||
|
||||
if is_focused {
|
||||
// Press enter, space to trigger click, when the element is focused.
|
||||
window.on_key_event({
|
||||
let click_listeners = click_listeners.clone();
|
||||
let hitbox = hitbox.clone();
|
||||
move |event: &KeyUpEvent, phase, window, cx| {
|
||||
if phase.bubble() && !window.default_prevented() {
|
||||
let stroke = &event.keystroke;
|
||||
let keyboard_button = if stroke.key.eq("enter") {
|
||||
Some(KeyboardButton::Enter)
|
||||
} else if stroke.key.eq("space") {
|
||||
Some(KeyboardButton::Space)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(button) = keyboard_button
|
||||
&& !stroke.modifiers.modified()
|
||||
{
|
||||
let click_event = ClickEvent::Keyboard(KeyboardClickEvent {
|
||||
button,
|
||||
bounds: hitbox.bounds,
|
||||
});
|
||||
|
||||
for listener in &click_listeners {
|
||||
listener(&click_event, window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
window.on_mouse_event({
|
||||
let mut captured_mouse_down = None;
|
||||
let hitbox = hitbox.clone();
|
||||
|
@ -2138,10 +2176,10 @@ impl Interactivity {
|
|||
// Fire click handlers during the bubble phase.
|
||||
DispatchPhase::Bubble => {
|
||||
if let Some(mouse_down) = captured_mouse_down.take() {
|
||||
let mouse_click = ClickEvent {
|
||||
let mouse_click = ClickEvent::Mouse(MouseClickEvent {
|
||||
down: mouse_down,
|
||||
up: event.clone(),
|
||||
};
|
||||
});
|
||||
for listener in &click_listeners {
|
||||
listener(&mouse_click, window, cx);
|
||||
}
|
||||
|
|
|
@ -18,10 +18,16 @@ use refineable::Refineable as _;
|
|||
use std::{cell::RefCell, ops::Range, rc::Rc};
|
||||
use sum_tree::{Bias, Dimensions, SumTree};
|
||||
|
||||
type RenderItemFn = dyn FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static;
|
||||
|
||||
/// Construct a new list element
|
||||
pub fn list(state: ListState) -> List {
|
||||
pub fn list(
|
||||
state: ListState,
|
||||
render_item: impl FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static,
|
||||
) -> List {
|
||||
List {
|
||||
state,
|
||||
render_item: Box::new(render_item),
|
||||
style: StyleRefinement::default(),
|
||||
sizing_behavior: ListSizingBehavior::default(),
|
||||
}
|
||||
|
@ -30,6 +36,7 @@ pub fn list(state: ListState) -> List {
|
|||
/// A list element
|
||||
pub struct List {
|
||||
state: ListState,
|
||||
render_item: Box<RenderItemFn>,
|
||||
style: StyleRefinement,
|
||||
sizing_behavior: ListSizingBehavior,
|
||||
}
|
||||
|
@ -55,7 +62,6 @@ impl std::fmt::Debug for ListState {
|
|||
struct StateInner {
|
||||
last_layout_bounds: Option<Bounds<Pixels>>,
|
||||
last_padding: Option<Edges<Pixels>>,
|
||||
render_item: Box<dyn FnMut(usize, &mut Window, &mut App) -> AnyElement>,
|
||||
items: SumTree<ListItem>,
|
||||
logical_scroll_top: Option<ListOffset>,
|
||||
alignment: ListAlignment,
|
||||
|
@ -186,19 +192,10 @@ impl ListState {
|
|||
/// above and below the visible area. Elements within this area will
|
||||
/// be measured even though they are not visible. This can help ensure
|
||||
/// that the list doesn't flicker or pop in when scrolling.
|
||||
pub fn new<R>(
|
||||
item_count: usize,
|
||||
alignment: ListAlignment,
|
||||
overdraw: Pixels,
|
||||
render_item: R,
|
||||
) -> Self
|
||||
where
|
||||
R: 'static + FnMut(usize, &mut Window, &mut App) -> AnyElement,
|
||||
{
|
||||
pub fn new(item_count: usize, alignment: ListAlignment, overdraw: Pixels) -> Self {
|
||||
let this = Self(Rc::new(RefCell::new(StateInner {
|
||||
last_layout_bounds: None,
|
||||
last_padding: None,
|
||||
render_item: Box::new(render_item),
|
||||
items: SumTree::default(),
|
||||
logical_scroll_top: None,
|
||||
alignment,
|
||||
|
@ -532,6 +529,7 @@ impl StateInner {
|
|||
available_width: Option<Pixels>,
|
||||
available_height: Pixels,
|
||||
padding: &Edges<Pixels>,
|
||||
render_item: &mut RenderItemFn,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> LayoutItemsResponse {
|
||||
|
@ -566,7 +564,7 @@ impl StateInner {
|
|||
// If we're within the visible area or the height wasn't cached, render and measure the item's element
|
||||
if visible_height < available_height || size.is_none() {
|
||||
let item_index = scroll_top.item_ix + ix;
|
||||
let mut element = (self.render_item)(item_index, window, cx);
|
||||
let mut element = render_item(item_index, window, cx);
|
||||
let element_size = element.layout_as_root(available_item_space, window, cx);
|
||||
size = Some(element_size);
|
||||
if visible_height < available_height {
|
||||
|
@ -601,7 +599,7 @@ impl StateInner {
|
|||
cursor.prev();
|
||||
if let Some(item) = cursor.item() {
|
||||
let item_index = cursor.start().0;
|
||||
let mut element = (self.render_item)(item_index, window, cx);
|
||||
let mut element = render_item(item_index, window, cx);
|
||||
let element_size = element.layout_as_root(available_item_space, window, cx);
|
||||
let focus_handle = item.focus_handle();
|
||||
rendered_height += element_size.height;
|
||||
|
@ -650,7 +648,7 @@ impl StateInner {
|
|||
let size = if let ListItem::Measured { size, .. } = item {
|
||||
*size
|
||||
} else {
|
||||
let mut element = (self.render_item)(cursor.start().0, window, cx);
|
||||
let mut element = render_item(cursor.start().0, window, cx);
|
||||
element.layout_as_root(available_item_space, window, cx)
|
||||
};
|
||||
|
||||
|
@ -683,7 +681,7 @@ impl StateInner {
|
|||
while let Some(item) = cursor.item() {
|
||||
if item.contains_focused(window, cx) {
|
||||
let item_index = cursor.start().0;
|
||||
let mut element = (self.render_item)(cursor.start().0, window, cx);
|
||||
let mut element = render_item(cursor.start().0, window, cx);
|
||||
let size = element.layout_as_root(available_item_space, window, cx);
|
||||
item_layouts.push_back(ItemLayout {
|
||||
index: item_index,
|
||||
|
@ -708,6 +706,7 @@ impl StateInner {
|
|||
bounds: Bounds<Pixels>,
|
||||
padding: Edges<Pixels>,
|
||||
autoscroll: bool,
|
||||
render_item: &mut RenderItemFn,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Result<LayoutItemsResponse, ListOffset> {
|
||||
|
@ -716,6 +715,7 @@ impl StateInner {
|
|||
Some(bounds.size.width),
|
||||
bounds.size.height,
|
||||
&padding,
|
||||
render_item,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
@ -753,8 +753,7 @@ impl StateInner {
|
|||
let Some(item) = cursor.item() else { break };
|
||||
|
||||
let size = item.size().unwrap_or_else(|| {
|
||||
let mut item =
|
||||
(self.render_item)(cursor.start().0, window, cx);
|
||||
let mut item = render_item(cursor.start().0, window, cx);
|
||||
let item_available_size = size(
|
||||
bounds.size.width.into(),
|
||||
AvailableSpace::MinContent,
|
||||
|
@ -876,8 +875,14 @@ impl Element for List {
|
|||
window.rem_size(),
|
||||
);
|
||||
|
||||
let layout_response =
|
||||
state.layout_items(None, available_height, &padding, window, cx);
|
||||
let layout_response = state.layout_items(
|
||||
None,
|
||||
available_height,
|
||||
&padding,
|
||||
&mut self.render_item,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_element_width = layout_response.max_item_width;
|
||||
|
||||
let summary = state.items.summary();
|
||||
|
@ -951,15 +956,16 @@ impl Element for List {
|
|||
let padding = style
|
||||
.padding
|
||||
.to_pixels(bounds.size.into(), window.rem_size());
|
||||
let layout = match state.prepaint_items(bounds, padding, true, window, cx) {
|
||||
Ok(layout) => layout,
|
||||
Err(autoscroll_request) => {
|
||||
state.logical_scroll_top = Some(autoscroll_request);
|
||||
state
|
||||
.prepaint_items(bounds, padding, false, window, cx)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
let layout =
|
||||
match state.prepaint_items(bounds, padding, true, &mut self.render_item, window, cx) {
|
||||
Ok(layout) => layout,
|
||||
Err(autoscroll_request) => {
|
||||
state.logical_scroll_top = Some(autoscroll_request);
|
||||
state
|
||||
.prepaint_items(bounds, padding, false, &mut self.render_item, window, cx)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
state.last_layout_bounds = Some(bounds);
|
||||
state.last_padding = Some(padding);
|
||||
|
@ -1108,9 +1114,7 @@ mod test {
|
|||
|
||||
let cx = cx.add_empty_window();
|
||||
|
||||
let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| {
|
||||
div().h(px(10.)).w_full().into_any()
|
||||
});
|
||||
let state = ListState::new(5, crate::ListAlignment::Top, px(10.));
|
||||
|
||||
// Ensure that the list is scrolled to the top
|
||||
state.scroll_to(gpui::ListOffset {
|
||||
|
@ -1121,7 +1125,11 @@ mod test {
|
|||
struct TestView(ListState);
|
||||
impl Render for TestView {
|
||||
fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
|
||||
list(self.0.clone()).w_full().h_full()
|
||||
list(self.0.clone(), |_, _, _| {
|
||||
div().h(px(10.)).w_full().into_any()
|
||||
})
|
||||
.w_full()
|
||||
.h_full()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1154,14 +1162,16 @@ mod test {
|
|||
|
||||
let cx = cx.add_empty_window();
|
||||
|
||||
let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| {
|
||||
div().h(px(20.)).w_full().into_any()
|
||||
});
|
||||
let state = ListState::new(5, crate::ListAlignment::Top, px(10.));
|
||||
|
||||
struct TestView(ListState);
|
||||
impl Render for TestView {
|
||||
fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
|
||||
list(self.0.clone()).w_full().h_full()
|
||||
list(self.0.clone(), |_, _, _| {
|
||||
div().h(px(20.)).w_full().into_any()
|
||||
})
|
||||
.w_full()
|
||||
.h_full()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, Window,
|
||||
point, seal::Sealed,
|
||||
Bounds, Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render,
|
||||
Window, point, seal::Sealed,
|
||||
};
|
||||
use smallvec::SmallVec;
|
||||
use std::{any::Any, fmt::Debug, ops::Deref, path::PathBuf};
|
||||
|
@ -141,7 +141,7 @@ impl MouseEvent for MouseUpEvent {}
|
|||
|
||||
/// A click event, generated when a mouse button is pressed and released.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ClickEvent {
|
||||
pub struct MouseClickEvent {
|
||||
/// The mouse event when the button was pressed.
|
||||
pub down: MouseDownEvent,
|
||||
|
||||
|
@ -149,18 +149,119 @@ pub struct ClickEvent {
|
|||
pub up: MouseUpEvent,
|
||||
}
|
||||
|
||||
/// A click event that was generated by a keyboard button being pressed and released.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KeyboardClickEvent {
|
||||
/// The keyboard button that was pressed to trigger the click.
|
||||
pub button: KeyboardButton,
|
||||
|
||||
/// The bounds of the element that was clicked.
|
||||
pub bounds: Bounds<Pixels>,
|
||||
}
|
||||
|
||||
/// A click event, generated when a mouse button or keyboard button is pressed and released.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ClickEvent {
|
||||
/// A click event trigger by a mouse button being pressed and released.
|
||||
Mouse(MouseClickEvent),
|
||||
/// A click event trigger by a keyboard button being pressed and released.
|
||||
Keyboard(KeyboardClickEvent),
|
||||
}
|
||||
|
||||
impl ClickEvent {
|
||||
/// Returns the modifiers that were held down during both the
|
||||
/// mouse down and mouse up events
|
||||
/// Returns the modifiers that were held during the click event
|
||||
///
|
||||
/// `Keyboard`: The keyboard click events never have modifiers.
|
||||
/// `Mouse`: Modifiers that were held during the mouse key up event.
|
||||
pub fn modifiers(&self) -> Modifiers {
|
||||
Modifiers {
|
||||
control: self.up.modifiers.control && self.down.modifiers.control,
|
||||
alt: self.up.modifiers.alt && self.down.modifiers.alt,
|
||||
shift: self.up.modifiers.shift && self.down.modifiers.shift,
|
||||
platform: self.up.modifiers.platform && self.down.modifiers.platform,
|
||||
function: self.up.modifiers.function && self.down.modifiers.function,
|
||||
match self {
|
||||
// Click events are only generated from keyboard events _without any modifiers_, so we know the modifiers are always Default
|
||||
ClickEvent::Keyboard(_) => Modifiers::default(),
|
||||
// Click events on the web only reflect the modifiers for the keyup event,
|
||||
// tested via observing the behavior of the `ClickEvent.shiftKey` field in Chrome 138
|
||||
// under various combinations of modifiers and keyUp / keyDown events.
|
||||
ClickEvent::Mouse(event) => event.up.modifiers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the position of the click event
|
||||
///
|
||||
/// `Keyboard`: The bottom left corner of the clicked hitbox
|
||||
/// `Mouse`: The position of the mouse when the button was released.
|
||||
pub fn position(&self) -> Point<Pixels> {
|
||||
match self {
|
||||
ClickEvent::Keyboard(event) => event.bounds.bottom_left(),
|
||||
ClickEvent::Mouse(event) => event.up.position,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the mouse position of the click event
|
||||
///
|
||||
/// `Keyboard`: None
|
||||
/// `Mouse`: The position of the mouse when the button was released.
|
||||
pub fn mouse_position(&self) -> Option<Point<Pixels>> {
|
||||
match self {
|
||||
ClickEvent::Keyboard(_) => None,
|
||||
ClickEvent::Mouse(event) => Some(event.up.position),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns if this was a right click
|
||||
///
|
||||
/// `Keyboard`: false
|
||||
/// `Mouse`: Whether the right button was pressed and released
|
||||
pub fn is_right_click(&self) -> bool {
|
||||
match self {
|
||||
ClickEvent::Keyboard(_) => false,
|
||||
ClickEvent::Mouse(event) => {
|
||||
event.down.button == MouseButton::Right && event.up.button == MouseButton::Right
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether the click was a standard click
|
||||
///
|
||||
/// `Keyboard`: Always true
|
||||
/// `Mouse`: Left button pressed and released
|
||||
pub fn standard_click(&self) -> bool {
|
||||
match self {
|
||||
ClickEvent::Keyboard(_) => true,
|
||||
ClickEvent::Mouse(event) => {
|
||||
event.down.button == MouseButton::Left && event.up.button == MouseButton::Left
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether the click focused the element
|
||||
///
|
||||
/// `Keyboard`: false, keyboard clicks only work if an element is already focused
|
||||
/// `Mouse`: Whether this was the first focusing click
|
||||
pub fn first_focus(&self) -> bool {
|
||||
match self {
|
||||
ClickEvent::Keyboard(_) => false,
|
||||
ClickEvent::Mouse(event) => event.down.first_mouse,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the click count of the click event
|
||||
///
|
||||
/// `Keyboard`: Always 1
|
||||
/// `Mouse`: Count of clicks from MouseUpEvent
|
||||
pub fn click_count(&self) -> usize {
|
||||
match self {
|
||||
ClickEvent::Keyboard(_) => 1,
|
||||
ClickEvent::Mouse(event) => event.up.click_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum representing the keyboard button that was pressed for a click event.
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
|
||||
pub enum KeyboardButton {
|
||||
/// Enter key was clicked
|
||||
Enter,
|
||||
/// Space key was clicked
|
||||
Space,
|
||||
}
|
||||
|
||||
/// An enum representing the mouse button that was pressed.
|
||||
|
|
|
@ -606,7 +606,7 @@ impl BladeRenderer {
|
|||
xy_position: v.xy_position,
|
||||
st_position: v.st_position,
|
||||
color: path.color,
|
||||
bounds: path.bounds.intersect(&path.content_mask.bounds),
|
||||
bounds: path.clipped_bounds(),
|
||||
}));
|
||||
}
|
||||
let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) };
|
||||
|
@ -735,13 +735,13 @@ impl BladeRenderer {
|
|||
paths
|
||||
.iter()
|
||||
.map(|path| PathSprite {
|
||||
bounds: path.bounds,
|
||||
bounds: path.clipped_bounds(),
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
let mut bounds = first_path.bounds;
|
||||
let mut bounds = first_path.clipped_bounds();
|
||||
for path in paths.iter().skip(1) {
|
||||
bounds = bounds.union(&path.bounds);
|
||||
bounds = bounds.union(&path.clipped_bounds());
|
||||
}
|
||||
vec![PathSprite { bounds }]
|
||||
};
|
||||
|
|
|
@ -791,13 +791,13 @@ impl MetalRenderer {
|
|||
sprites = paths
|
||||
.iter()
|
||||
.map(|path| PathSprite {
|
||||
bounds: path.bounds,
|
||||
bounds: path.clipped_bounds(),
|
||||
})
|
||||
.collect();
|
||||
} else {
|
||||
let mut bounds = first_path.bounds;
|
||||
let mut bounds = first_path.clipped_bounds();
|
||||
for path in paths.iter().skip(1) {
|
||||
bounds = bounds.union(&path.bounds);
|
||||
bounds = bounds.union(&path.clipped_bounds());
|
||||
}
|
||||
sprites = vec![PathSprite { bounds }];
|
||||
}
|
||||
|
|
|
@ -435,7 +435,7 @@ impl DirectXRenderer {
|
|||
xy_position: v.xy_position,
|
||||
st_position: v.st_position,
|
||||
color: path.color,
|
||||
bounds: path.bounds.intersect(&path.content_mask.bounds),
|
||||
bounds: path.clipped_bounds(),
|
||||
}));
|
||||
}
|
||||
|
||||
|
@ -487,13 +487,13 @@ impl DirectXRenderer {
|
|||
paths
|
||||
.iter()
|
||||
.map(|path| PathSprite {
|
||||
bounds: path.bounds,
|
||||
bounds: path.clipped_bounds(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
let mut bounds = first_path.bounds;
|
||||
let mut bounds = first_path.clipped_bounds();
|
||||
for path in paths.iter().skip(1) {
|
||||
bounds = bounds.union(&path.bounds);
|
||||
bounds = bounds.union(&path.clipped_bounds());
|
||||
}
|
||||
vec![PathSprite { bounds }]
|
||||
};
|
||||
|
|
|
@ -174,20 +174,37 @@ impl WindowsWindowInner {
|
|||
let width = lparam.loword().max(1) as i32;
|
||||
let height = lparam.hiword().max(1) as i32;
|
||||
let new_size = size(DevicePixels(width), DevicePixels(height));
|
||||
|
||||
let scale_factor = lock.scale_factor;
|
||||
let mut should_resize_renderer = false;
|
||||
if lock.restore_from_minimized.is_some() {
|
||||
lock.callbacks.request_frame = lock.restore_from_minimized.take();
|
||||
} else {
|
||||
lock.renderer.resize(new_size).log_err();
|
||||
should_resize_renderer = true;
|
||||
}
|
||||
drop(lock);
|
||||
|
||||
self.handle_size_change(new_size, scale_factor, should_resize_renderer);
|
||||
Some(0)
|
||||
}
|
||||
|
||||
fn handle_size_change(
|
||||
&self,
|
||||
device_size: Size<DevicePixels>,
|
||||
scale_factor: f32,
|
||||
should_resize_renderer: bool,
|
||||
) {
|
||||
let new_logical_size = device_size.to_pixels(scale_factor);
|
||||
let mut lock = self.state.borrow_mut();
|
||||
lock.logical_size = new_logical_size;
|
||||
if should_resize_renderer {
|
||||
lock.renderer.resize(device_size).log_err();
|
||||
}
|
||||
let new_size = new_size.to_pixels(scale_factor);
|
||||
lock.logical_size = new_size;
|
||||
if let Some(mut callback) = lock.callbacks.resize.take() {
|
||||
drop(lock);
|
||||
callback(new_size, scale_factor);
|
||||
callback(new_logical_size, scale_factor);
|
||||
self.state.borrow_mut().callbacks.resize = Some(callback);
|
||||
}
|
||||
Some(0)
|
||||
}
|
||||
|
||||
fn handle_size_move_loop(&self, handle: HWND) -> Option<isize> {
|
||||
|
@ -747,7 +764,9 @@ impl WindowsWindowInner {
|
|||
) -> Option<isize> {
|
||||
let new_dpi = wparam.loword() as f32;
|
||||
let mut lock = self.state.borrow_mut();
|
||||
lock.scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32;
|
||||
let is_maximized = lock.is_maximized();
|
||||
let new_scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32;
|
||||
lock.scale_factor = new_scale_factor;
|
||||
lock.border_offset.update(handle).log_err();
|
||||
drop(lock);
|
||||
|
||||
|
@ -771,6 +790,13 @@ impl WindowsWindowInner {
|
|||
.log_err();
|
||||
}
|
||||
|
||||
// When maximized, SetWindowPos doesn't send WM_SIZE, so we need to manually
|
||||
// update the size and call the resize callback
|
||||
if is_maximized {
|
||||
let device_size = size(DevicePixels(width), DevicePixels(height));
|
||||
self.handle_size_change(device_size, new_scale_factor, true);
|
||||
}
|
||||
|
||||
Some(0)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,12 @@ use crate::{
|
|||
AtlasTextureId, AtlasTile, Background, Bounds, ContentMask, Corners, Edges, Hsla, Pixels,
|
||||
Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, point,
|
||||
};
|
||||
use std::{fmt::Debug, iter::Peekable, ops::Range, slice};
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
iter::Peekable,
|
||||
ops::{Add, Range, Sub},
|
||||
slice,
|
||||
};
|
||||
|
||||
#[allow(non_camel_case_types, unused)]
|
||||
pub(crate) type PathVertex_ScaledPixels = PathVertex<ScaledPixels>;
|
||||
|
@ -793,6 +798,16 @@ impl Path<Pixels> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Path<T>
|
||||
where
|
||||
T: Clone + Debug + Default + PartialEq + PartialOrd + Add<T, Output = T> + Sub<Output = T>,
|
||||
{
|
||||
#[allow(unused)]
|
||||
pub(crate) fn clipped_bounds(&self) -> Bounds<T> {
|
||||
self.bounds.intersect(&self.content_mask.bounds)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Path<ScaledPixels>> for Primitive {
|
||||
fn from(path: Path<ScaledPixels>) -> Self {
|
||||
Primitive::Path(path)
|
||||
|
|
|
@ -79,11 +79,13 @@ pub enum DispatchPhase {
|
|||
|
||||
impl DispatchPhase {
|
||||
/// Returns true if this represents the "bubble" phase.
|
||||
#[inline]
|
||||
pub fn bubble(self) -> bool {
|
||||
self == DispatchPhase::Bubble
|
||||
}
|
||||
|
||||
/// Returns true if this represents the "capture" phase.
|
||||
#[inline]
|
||||
pub fn capture(self) -> bool {
|
||||
self == DispatchPhase::Capture
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ use url::Url;
|
|||
pub struct GitHubLspBinaryVersion {
|
||||
pub name: String,
|
||||
pub url: String,
|
||||
pub digest: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
|
@ -24,6 +25,7 @@ pub struct GithubRelease {
|
|||
pub struct GithubReleaseAsset {
|
||||
pub name: String,
|
||||
pub browser_download_url: String,
|
||||
pub digest: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn latest_github_release(
|
||||
|
|
|
@ -2353,9 +2353,9 @@ mod tests {
|
|||
assert_eq!(
|
||||
languages.language_names(),
|
||||
&[
|
||||
"JSON".to_string(),
|
||||
"Plain Text".to_string(),
|
||||
"Rust".to_string(),
|
||||
LanguageName::new("JSON"),
|
||||
LanguageName::new("Plain Text"),
|
||||
LanguageName::new("Rust"),
|
||||
]
|
||||
);
|
||||
|
||||
|
@ -2366,9 +2366,9 @@ mod tests {
|
|||
assert_eq!(
|
||||
languages.language_names(),
|
||||
&[
|
||||
"JSON".to_string(),
|
||||
"Plain Text".to_string(),
|
||||
"Rust".to_string(),
|
||||
LanguageName::new("JSON"),
|
||||
LanguageName::new("Plain Text"),
|
||||
LanguageName::new("Rust"),
|
||||
]
|
||||
);
|
||||
|
||||
|
@ -2379,9 +2379,9 @@ mod tests {
|
|||
assert_eq!(
|
||||
languages.language_names(),
|
||||
&[
|
||||
"JSON".to_string(),
|
||||
"Plain Text".to_string(),
|
||||
"Rust".to_string(),
|
||||
LanguageName::new("JSON"),
|
||||
LanguageName::new("Plain Text"),
|
||||
LanguageName::new("Rust"),
|
||||
]
|
||||
);
|
||||
|
||||
|
|
|
@ -547,15 +547,15 @@ impl LanguageRegistry {
|
|||
self.state.read().language_settings.clone()
|
||||
}
|
||||
|
||||
pub fn language_names(&self) -> Vec<String> {
|
||||
pub fn language_names(&self) -> Vec<LanguageName> {
|
||||
let state = self.state.read();
|
||||
let mut result = state
|
||||
.available_languages
|
||||
.iter()
|
||||
.filter_map(|l| l.loaded.not().then_some(l.name.to_string()))
|
||||
.chain(state.languages.iter().map(|l| l.config.name.to_string()))
|
||||
.filter_map(|l| l.loaded.not().then_some(l.name.clone()))
|
||||
.chain(state.languages.iter().map(|l| l.config.name.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
result.sort_unstable_by_key(|language_name| language_name.to_lowercase());
|
||||
result.sort_unstable_by_key(|language_name| language_name.as_ref().to_lowercase());
|
||||
result
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,12 @@ pub struct ToolUseRequest {
|
|||
pub struct FakeLanguageModel {
|
||||
provider_id: LanguageModelProviderId,
|
||||
provider_name: LanguageModelProviderName,
|
||||
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
|
||||
current_completion_txs: Mutex<
|
||||
Vec<(
|
||||
LanguageModelRequest,
|
||||
mpsc::UnboundedSender<LanguageModelCompletionEvent>,
|
||||
)>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl Default for FakeLanguageModel {
|
||||
|
@ -118,10 +123,21 @@ impl FakeLanguageModel {
|
|||
self.current_completion_txs.lock().len()
|
||||
}
|
||||
|
||||
pub fn stream_completion_response(
|
||||
pub fn send_completion_stream_text_chunk(
|
||||
&self,
|
||||
request: &LanguageModelRequest,
|
||||
chunk: impl Into<String>,
|
||||
) {
|
||||
self.send_completion_stream_event(
|
||||
request,
|
||||
LanguageModelCompletionEvent::Text(chunk.into()),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn send_completion_stream_event(
|
||||
&self,
|
||||
request: &LanguageModelRequest,
|
||||
event: impl Into<LanguageModelCompletionEvent>,
|
||||
) {
|
||||
let current_completion_txs = self.current_completion_txs.lock();
|
||||
let tx = current_completion_txs
|
||||
|
@ -129,7 +145,7 @@ impl FakeLanguageModel {
|
|||
.find(|(req, _)| req == request)
|
||||
.map(|(_, tx)| tx)
|
||||
.unwrap();
|
||||
tx.unbounded_send(chunk.into()).unwrap();
|
||||
tx.unbounded_send(event.into()).unwrap();
|
||||
}
|
||||
|
||||
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
|
||||
|
@ -138,8 +154,15 @@ impl FakeLanguageModel {
|
|||
.retain(|(req, _)| req != request);
|
||||
}
|
||||
|
||||
pub fn stream_last_completion_response(&self, chunk: impl Into<String>) {
|
||||
self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
|
||||
pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) {
|
||||
self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||
}
|
||||
|
||||
pub fn send_last_completion_stream_event(
|
||||
&self,
|
||||
event: impl Into<LanguageModelCompletionEvent>,
|
||||
) {
|
||||
self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
|
||||
}
|
||||
|
||||
pub fn end_last_completion_stream(&self) {
|
||||
|
@ -201,12 +224,7 @@ impl LanguageModel for FakeLanguageModel {
|
|||
> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move {
|
||||
Ok(rx
|
||||
.map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn as_fake(&self) -> &Self {
|
||||
|
|
|
@ -136,6 +136,7 @@ impl State {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
let mut current_user = user_store.read(cx).watch_current_user();
|
||||
Self {
|
||||
client: client.clone(),
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
|
@ -151,22 +152,14 @@ impl State {
|
|||
let (client, llm_api_token) = this
|
||||
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
|
||||
|
||||
loop {
|
||||
let is_authenticated = user_store
|
||||
.read_with(cx, |user_store, _cx| user_store.current_user().is_some())?;
|
||||
if is_authenticated {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(100))
|
||||
.await;
|
||||
while current_user.borrow().is_none() {
|
||||
current_user.next().await;
|
||||
}
|
||||
|
||||
let response = Self::fetch_models(client, llm_api_token).await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_models(response, cx);
|
||||
})
|
||||
let response =
|
||||
Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
|
||||
this.update(cx, |this, cx| this.update_models(response, cx))?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.await
|
||||
.context("failed to fetch Zed models")
|
||||
|
@ -1267,8 +1260,16 @@ impl Render for ConfigurationView {
|
|||
}
|
||||
|
||||
impl Component for ZedAiConfiguration {
|
||||
fn name() -> &'static str {
|
||||
"AI Configuration Content"
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AI Configuration Content"
|
||||
}
|
||||
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
ComponentScope::Onboarding
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
|
|
|
@ -121,13 +121,13 @@ impl LanguageSelectorDelegate {
|
|||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
language_registry
|
||||
.available_language_for_name(&name)?
|
||||
.available_language_for_name(name.as_ref())?
|
||||
.hidden()
|
||||
.not()
|
||||
.then_some(name)
|
||||
})
|
||||
.enumerate()
|
||||
.map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name))
|
||||
.map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Self {
|
||||
|
|
|
@ -36,6 +36,7 @@ load-grammars = [
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-compression.workspace = true
|
||||
async-fs.workspace = true
|
||||
async-tar.workspace = true
|
||||
async-trait.workspace = true
|
||||
chrono.workspace = true
|
||||
|
@ -62,6 +63,7 @@ regex.workspace = true
|
|||
rope.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
sha2.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
|
@ -69,6 +71,7 @@ settings.workspace = true
|
|||
smol.workspace = true
|
||||
snippet_provider.workspace = true
|
||||
task.workspace = true
|
||||
tempfile.workspace = true
|
||||
toml.workspace = true
|
||||
tree-sitter = { workspace = true, optional = true }
|
||||
tree-sitter-bash = { workspace = true, optional = true }
|
||||
|
|
|
@ -2,14 +2,16 @@ use anyhow::{Context as _, Result, bail};
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use gpui::{App, AsyncApp};
|
||||
use http_client::github::{GitHubLspBinaryVersion, latest_github_release};
|
||||
use http_client::github::{AssetKind, GitHubLspBinaryVersion, latest_github_release};
|
||||
pub use language::*;
|
||||
use lsp::{InitializeParams, LanguageServerBinary, LanguageServerName};
|
||||
use project::lsp_store::clangd_ext;
|
||||
use serde_json::json;
|
||||
use smol::fs;
|
||||
use std::{any::Any, env::consts, path::PathBuf, sync::Arc};
|
||||
use util::{ResultExt, archive::extract_zip, fs::remove_matching, maybe, merge_json_value_into};
|
||||
use util::{ResultExt, fs::remove_matching, maybe, merge_json_value_into};
|
||||
|
||||
use crate::github_download::{GithubBinaryMetadata, download_server_binary};
|
||||
|
||||
pub struct CLspAdapter;
|
||||
|
||||
|
@ -58,6 +60,7 @@ impl super::LspAdapter for CLspAdapter {
|
|||
let version = GitHubLspBinaryVersion {
|
||||
name: release.tag_name,
|
||||
url: asset.browser_download_url.clone(),
|
||||
digest: asset.digest.clone(),
|
||||
};
|
||||
Ok(Box::new(version) as Box<_>)
|
||||
}
|
||||
|
@ -68,32 +71,67 @@ impl super::LspAdapter for CLspAdapter {
|
|||
container_dir: PathBuf,
|
||||
delegate: &dyn LspAdapterDelegate,
|
||||
) -> Result<LanguageServerBinary> {
|
||||
let version = version.downcast::<GitHubLspBinaryVersion>().unwrap();
|
||||
let version_dir = container_dir.join(format!("clangd_{}", version.name));
|
||||
let GitHubLspBinaryVersion { name, url, digest } =
|
||||
&*version.downcast::<GitHubLspBinaryVersion>().unwrap();
|
||||
let version_dir = container_dir.join(format!("clangd_{name}"));
|
||||
let binary_path = version_dir.join("bin/clangd");
|
||||
|
||||
if fs::metadata(&binary_path).await.is_err() {
|
||||
let mut response = delegate
|
||||
.http_client()
|
||||
.get(&version.url, Default::default(), true)
|
||||
.await
|
||||
.context("error downloading release")?;
|
||||
anyhow::ensure!(
|
||||
response.status().is_success(),
|
||||
"download failed with status {}",
|
||||
response.status().to_string()
|
||||
);
|
||||
extract_zip(&container_dir, response.body_mut())
|
||||
.await
|
||||
.with_context(|| format!("unzipping clangd archive to {container_dir:?}"))?;
|
||||
remove_matching(&container_dir, |entry| entry != version_dir).await;
|
||||
}
|
||||
|
||||
Ok(LanguageServerBinary {
|
||||
path: binary_path,
|
||||
let binary = LanguageServerBinary {
|
||||
path: binary_path.clone(),
|
||||
env: None,
|
||||
arguments: Vec::new(),
|
||||
})
|
||||
arguments: Default::default(),
|
||||
};
|
||||
|
||||
let metadata_path = version_dir.join("metadata");
|
||||
let metadata = GithubBinaryMetadata::read_from_file(&metadata_path)
|
||||
.await
|
||||
.ok();
|
||||
if let Some(metadata) = metadata {
|
||||
let validity_check = async || {
|
||||
delegate
|
||||
.try_exec(LanguageServerBinary {
|
||||
path: binary_path.clone(),
|
||||
arguments: vec!["--version".into()],
|
||||
env: None,
|
||||
})
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
log::warn!("Unable to run {binary_path:?} asset, redownloading: {err}",)
|
||||
})
|
||||
};
|
||||
if let (Some(actual_digest), Some(expected_digest)) = (&metadata.digest, digest) {
|
||||
if actual_digest == expected_digest {
|
||||
if validity_check().await.is_ok() {
|
||||
return Ok(binary);
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"SHA-256 mismatch for {binary_path:?} asset, downloading new asset. Expected: {expected_digest}, Got: {actual_digest}"
|
||||
);
|
||||
}
|
||||
} else if validity_check().await.is_ok() {
|
||||
return Ok(binary);
|
||||
}
|
||||
}
|
||||
download_server_binary(
|
||||
delegate,
|
||||
url,
|
||||
digest.as_deref(),
|
||||
&container_dir,
|
||||
AssetKind::Zip,
|
||||
)
|
||||
.await?;
|
||||
remove_matching(&container_dir, |entry| entry != version_dir).await;
|
||||
GithubBinaryMetadata::write_to_file(
|
||||
&GithubBinaryMetadata {
|
||||
metadata_version: 1,
|
||||
digest: digest.clone(),
|
||||
},
|
||||
&metadata_path,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(binary)
|
||||
}
|
||||
|
||||
async fn cached_server_binary(
|
||||
|
|
|
@ -5,7 +5,7 @@ use gpui::AsyncApp;
|
|||
use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate};
|
||||
use lsp::{LanguageServerBinary, LanguageServerName};
|
||||
use node_runtime::NodeRuntime;
|
||||
use project::Fs;
|
||||
use project::{Fs, lsp_store::language_server_settings};
|
||||
use serde_json::json;
|
||||
use smol::fs;
|
||||
use std::{
|
||||
|
@ -14,7 +14,7 @@ use std::{
|
|||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
use util::{ResultExt, maybe, merge_json_value_into};
|
||||
|
||||
const SERVER_PATH: &str =
|
||||
"node_modules/vscode-langservers-extracted/bin/vscode-css-language-server";
|
||||
|
@ -134,6 +134,37 @@ impl LspAdapter for CssLspAdapter {
|
|||
"provideFormatter": true
|
||||
})))
|
||||
}
|
||||
|
||||
async fn workspace_configuration(
|
||||
self: Arc<Self>,
|
||||
_: &dyn Fs,
|
||||
delegate: &Arc<dyn LspAdapterDelegate>,
|
||||
_: Arc<dyn LanguageToolchainStore>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut default_config = json!({
|
||||
"css": {
|
||||
"lint": {}
|
||||
},
|
||||
"less": {
|
||||
"lint": {}
|
||||
},
|
||||
"scss": {
|
||||
"lint": {}
|
||||
}
|
||||
});
|
||||
|
||||
let project_options = cx.update(|cx| {
|
||||
language_server_settings(delegate.as_ref(), &self.name(), cx)
|
||||
.and_then(|s| s.settings.clone())
|
||||
})?;
|
||||
|
||||
if let Some(override_options) = project_options {
|
||||
merge_json_value_into(override_options, &mut default_config);
|
||||
}
|
||||
|
||||
Ok(default_config)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_cached_server_binary(
|
||||
|
|
190
crates/languages/src/github_download.rs
Normal file
190
crates/languages/src/github_download.rs
Normal file
|
@ -0,0 +1,190 @@
|
|||
use std::{path::Path, pin::Pin, task::Poll};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
|
||||
use http_client::github::AssetKind;
|
||||
use language::LspAdapterDelegate;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug)]
|
||||
pub(crate) struct GithubBinaryMetadata {
|
||||
pub(crate) metadata_version: u64,
|
||||
pub(crate) digest: Option<String>,
|
||||
}
|
||||
|
||||
impl GithubBinaryMetadata {
|
||||
pub(crate) async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
|
||||
let metadata_content = async_fs::read_to_string(metadata_path)
|
||||
.await
|
||||
.with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
|
||||
let metadata: GithubBinaryMetadata = serde_json::from_str(&metadata_content)
|
||||
.with_context(|| format!("parsing metadata file at {metadata_path:?}"))?;
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
|
||||
let metadata_content = serde_json::to_string(self)
|
||||
.with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
|
||||
async_fs::write(metadata_path, metadata_content.as_bytes())
|
||||
.await
|
||||
.with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn download_server_binary(
|
||||
delegate: &dyn LspAdapterDelegate,
|
||||
url: &str,
|
||||
digest: Option<&str>,
|
||||
destination_path: &Path,
|
||||
asset_kind: AssetKind,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
log::info!("downloading github artifact from {url}");
|
||||
let mut response = delegate
|
||||
.http_client()
|
||||
.get(url, Default::default(), true)
|
||||
.await
|
||||
.with_context(|| format!("downloading release from {url}"))?;
|
||||
let body = response.body_mut();
|
||||
match digest {
|
||||
Some(expected_sha_256) => {
|
||||
let temp_asset_file = tempfile::NamedTempFile::new()
|
||||
.with_context(|| format!("creating a temporary file for {url}"))?;
|
||||
let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
|
||||
let mut writer = HashingWriter {
|
||||
writer: async_fs::File::from(temp_asset_file),
|
||||
hasher: Sha256::new(),
|
||||
};
|
||||
futures::io::copy(&mut BufReader::new(body), &mut writer)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("saving archive contents into the temporary file for {url}",)
|
||||
})?;
|
||||
let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
|
||||
anyhow::ensure!(
|
||||
asset_sha_256 == expected_sha_256,
|
||||
"{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
|
||||
);
|
||||
writer
|
||||
.writer
|
||||
.seek(std::io::SeekFrom::Start(0))
|
||||
.await
|
||||
.with_context(|| format!("seeking temporary file {destination_path:?}",))?;
|
||||
stream_file_archive(&mut writer.writer, url, destination_path, asset_kind)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("extracting downloaded asset for {url} into {destination_path:?}",)
|
||||
})?;
|
||||
}
|
||||
None => stream_response_archive(body, url, destination_path, asset_kind)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("extracting response for asset {url} into {destination_path:?}",)
|
||||
})?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stream_response_archive(
|
||||
response: impl AsyncRead + Unpin,
|
||||
url: &str,
|
||||
destination_path: &Path,
|
||||
asset_kind: AssetKind,
|
||||
) -> Result<()> {
|
||||
match asset_kind {
|
||||
AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
|
||||
AssetKind::Gz => extract_gz(destination_path, url, response).await?,
|
||||
AssetKind::Zip => {
|
||||
util::archive::extract_zip(&destination_path, response).await?;
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stream_file_archive(
|
||||
file_archive: impl AsyncRead + AsyncSeek + Unpin,
|
||||
url: &str,
|
||||
destination_path: &Path,
|
||||
asset_kind: AssetKind,
|
||||
) -> Result<()> {
|
||||
match asset_kind {
|
||||
AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
|
||||
AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
|
||||
#[cfg(not(windows))]
|
||||
AssetKind::Zip => {
|
||||
util::archive::extract_seekable_zip(&destination_path, file_archive).await?;
|
||||
}
|
||||
#[cfg(windows)]
|
||||
AssetKind::Zip => {
|
||||
util::archive::extract_zip(&destination_path, file_archive).await?;
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn extract_tar_gz(
|
||||
destination_path: &Path,
|
||||
url: &str,
|
||||
from: impl AsyncRead + Unpin,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
|
||||
let archive = async_tar::Archive::new(decompressed_bytes);
|
||||
archive
|
||||
.unpack(&destination_path)
|
||||
.await
|
||||
.with_context(|| format!("extracting {url} to {destination_path:?}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn extract_gz(
|
||||
destination_path: &Path,
|
||||
url: &str,
|
||||
from: impl AsyncRead + Unpin,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
|
||||
let mut file = smol::fs::File::create(&destination_path)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("creating a file {destination_path:?} for a download from {url}")
|
||||
})?;
|
||||
futures::io::copy(&mut decompressed_bytes, &mut file)
|
||||
.await
|
||||
.with_context(|| format!("extracting {url} to {destination_path:?}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct HashingWriter<W: AsyncWrite + Unpin> {
|
||||
writer: W,
|
||||
hasher: Sha256,
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::result::Result<usize, std::io::Error>> {
|
||||
match Pin::new(&mut self.writer).poll_write(cx, buf) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
self.hasher.update(&buf[..n]);
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.writer).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.writer).poll_close(cx)
|
||||
}
|
||||
}
|
|
@ -269,7 +269,15 @@ impl JsonLspAdapter {
|
|||
.await;
|
||||
|
||||
let config = cx.update(|cx| {
|
||||
Self::get_workspace_config(self.languages.language_names().clone(), adapter_schemas, cx)
|
||||
Self::get_workspace_config(
|
||||
self.languages
|
||||
.language_names()
|
||||
.into_iter()
|
||||
.map(|name| name.to_string())
|
||||
.collect(),
|
||||
adapter_schemas,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
writer.replace(config.clone());
|
||||
return Ok(config);
|
||||
|
@ -509,6 +517,7 @@ impl LspAdapter for NodeVersionAdapter {
|
|||
Ok(Box::new(GitHubLspBinaryVersion {
|
||||
name: release.tag_name,
|
||||
url: asset.browser_download_url.clone(),
|
||||
digest: asset.digest.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ use crate::{json::JsonTaskProvider, python::BasedPyrightLspAdapter};
|
|||
mod bash;
|
||||
mod c;
|
||||
mod css;
|
||||
mod github_download;
|
||||
mod go;
|
||||
mod json;
|
||||
mod package_json;
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_trait::async_trait;
|
||||
use collections::HashMap;
|
||||
use futures::{StreamExt, io::BufReader};
|
||||
use futures::StreamExt;
|
||||
use gpui::{App, AppContext, AsyncApp, SharedString, Task};
|
||||
use http_client::github::AssetKind;
|
||||
use http_client::github::{GitHubLspBinaryVersion, latest_github_release};
|
||||
|
@ -23,14 +22,11 @@ use std::{
|
|||
sync::{Arc, LazyLock},
|
||||
};
|
||||
use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName};
|
||||
use util::archive::extract_zip;
|
||||
use util::fs::make_file_executable;
|
||||
use util::merge_json_value_into;
|
||||
use util::{
|
||||
ResultExt,
|
||||
fs::{make_file_executable, remove_matching},
|
||||
maybe,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
|
||||
use crate::github_download::{GithubBinaryMetadata, download_server_binary};
|
||||
use crate::language_settings::language_settings;
|
||||
|
||||
pub struct RustLspAdapter;
|
||||
|
@ -163,7 +159,6 @@ impl LspAdapter for RustLspAdapter {
|
|||
)
|
||||
.await?;
|
||||
let asset_name = Self::build_asset_name();
|
||||
|
||||
let asset = release
|
||||
.assets
|
||||
.iter()
|
||||
|
@ -172,6 +167,7 @@ impl LspAdapter for RustLspAdapter {
|
|||
Ok(Box::new(GitHubLspBinaryVersion {
|
||||
name: release.tag_name,
|
||||
url: asset.browser_download_url.clone(),
|
||||
digest: asset.digest.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -181,58 +177,76 @@ impl LspAdapter for RustLspAdapter {
|
|||
container_dir: PathBuf,
|
||||
delegate: &dyn LspAdapterDelegate,
|
||||
) -> Result<LanguageServerBinary> {
|
||||
let version = version.downcast::<GitHubLspBinaryVersion>().unwrap();
|
||||
let destination_path = container_dir.join(format!("rust-analyzer-{}", version.name));
|
||||
let GitHubLspBinaryVersion { name, url, digest } =
|
||||
&*version.downcast::<GitHubLspBinaryVersion>().unwrap();
|
||||
let expected_digest = digest
|
||||
.as_ref()
|
||||
.and_then(|digest| digest.strip_prefix("sha256:"));
|
||||
let destination_path = container_dir.join(format!("rust-analyzer-{name}"));
|
||||
let server_path = match Self::GITHUB_ASSET_KIND {
|
||||
AssetKind::TarGz | AssetKind::Gz => destination_path.clone(), // Tar and gzip extract in place.
|
||||
AssetKind::Zip => destination_path.clone().join("rust-analyzer.exe"), // zip contains a .exe
|
||||
};
|
||||
|
||||
if fs::metadata(&server_path).await.is_err() {
|
||||
remove_matching(&container_dir, |entry| entry != destination_path).await;
|
||||
let binary = LanguageServerBinary {
|
||||
path: server_path.clone(),
|
||||
env: None,
|
||||
arguments: Default::default(),
|
||||
};
|
||||
|
||||
let mut response = delegate
|
||||
.http_client()
|
||||
.get(&version.url, Default::default(), true)
|
||||
.await
|
||||
.with_context(|| format!("downloading release from {}", version.url))?;
|
||||
match Self::GITHUB_ASSET_KIND {
|
||||
AssetKind::TarGz => {
|
||||
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||
let archive = async_tar::Archive::new(decompressed_bytes);
|
||||
archive.unpack(&destination_path).await.with_context(|| {
|
||||
format!("extracting {} to {:?}", version.url, destination_path)
|
||||
})?;
|
||||
}
|
||||
AssetKind::Gz => {
|
||||
let mut decompressed_bytes =
|
||||
GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||
let mut file =
|
||||
fs::File::create(&destination_path).await.with_context(|| {
|
||||
format!(
|
||||
"creating a file {:?} for a download from {}",
|
||||
destination_path, version.url,
|
||||
)
|
||||
})?;
|
||||
futures::io::copy(&mut decompressed_bytes, &mut file)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("extracting {} to {:?}", version.url, destination_path)
|
||||
})?;
|
||||
}
|
||||
AssetKind::Zip => {
|
||||
extract_zip(&destination_path, response.body_mut())
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("unzipping {} to {:?}", version.url, destination_path)
|
||||
})?;
|
||||
}
|
||||
let metadata_path = destination_path.with_extension("metadata");
|
||||
let metadata = GithubBinaryMetadata::read_from_file(&metadata_path)
|
||||
.await
|
||||
.ok();
|
||||
if let Some(metadata) = metadata {
|
||||
let validity_check = async || {
|
||||
delegate
|
||||
.try_exec(LanguageServerBinary {
|
||||
path: server_path.clone(),
|
||||
arguments: vec!["--version".into()],
|
||||
env: None,
|
||||
})
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
log::warn!("Unable to run {server_path:?} asset, redownloading: {err}",)
|
||||
})
|
||||
};
|
||||
|
||||
// todo("windows")
|
||||
make_file_executable(&server_path).await?;
|
||||
if let (Some(actual_digest), Some(expected_digest)) =
|
||||
(&metadata.digest, expected_digest)
|
||||
{
|
||||
if actual_digest == expected_digest {
|
||||
if validity_check().await.is_ok() {
|
||||
return Ok(binary);
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"SHA-256 mismatch for {destination_path:?} asset, downloading new asset. Expected: {expected_digest}, Got: {actual_digest}"
|
||||
);
|
||||
}
|
||||
} else if validity_check().await.is_ok() {
|
||||
return Ok(binary);
|
||||
}
|
||||
}
|
||||
|
||||
_ = fs::remove_dir_all(&destination_path).await;
|
||||
download_server_binary(
|
||||
delegate,
|
||||
url,
|
||||
expected_digest,
|
||||
&destination_path,
|
||||
Self::GITHUB_ASSET_KIND,
|
||||
)
|
||||
.await?;
|
||||
make_file_executable(&server_path).await?;
|
||||
GithubBinaryMetadata::write_to_file(
|
||||
&GithubBinaryMetadata {
|
||||
metadata_version: 1,
|
||||
digest: expected_digest.map(ToString::to_string),
|
||||
},
|
||||
&metadata_path,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(LanguageServerBinary {
|
||||
path: server_path,
|
||||
env: None,
|
||||
|
@ -1025,7 +1039,11 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ
|
|||
let mut last = None;
|
||||
let mut entries = fs::read_dir(&container_dir).await?;
|
||||
while let Some(entry) = entries.next().await {
|
||||
last = Some(entry?.path());
|
||||
let path = entry?.path();
|
||||
if path.extension().is_some_and(|ext| ext == "metadata") {
|
||||
continue;
|
||||
}
|
||||
last = Some(path);
|
||||
}
|
||||
|
||||
anyhow::Ok(LanguageServerBinary {
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_tar::Archive;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Local};
|
||||
use collections::HashMap;
|
||||
|
@ -15,7 +13,7 @@ use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName};
|
|||
use node_runtime::NodeRuntime;
|
||||
use project::{Fs, lsp_store::language_server_settings};
|
||||
use serde_json::{Value, json};
|
||||
use smol::{fs, io::BufReader, lock::RwLock, stream::StreamExt};
|
||||
use smol::{fs, lock::RwLock, stream::StreamExt};
|
||||
use std::{
|
||||
any::Any,
|
||||
borrow::Cow,
|
||||
|
@ -24,11 +22,10 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
use task::{TaskTemplate, TaskTemplates, VariableName};
|
||||
use util::archive::extract_zip;
|
||||
use util::merge_json_value_into;
|
||||
use util::{ResultExt, fs::remove_matching, maybe};
|
||||
|
||||
use crate::{PackageJson, PackageJsonData};
|
||||
use crate::{PackageJson, PackageJsonData, github_download::download_server_binary};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TypeScriptContextProvider {
|
||||
|
@ -897,6 +894,7 @@ impl LspAdapter for EsLintLspAdapter {
|
|||
|
||||
Ok(Box::new(GitHubLspBinaryVersion {
|
||||
name: Self::CURRENT_VERSION.into(),
|
||||
digest: None,
|
||||
url,
|
||||
}))
|
||||
}
|
||||
|
@ -914,43 +912,14 @@ impl LspAdapter for EsLintLspAdapter {
|
|||
if fs::metadata(&server_path).await.is_err() {
|
||||
remove_matching(&container_dir, |entry| entry != destination_path).await;
|
||||
|
||||
let mut response = delegate
|
||||
.http_client()
|
||||
.get(&version.url, Default::default(), true)
|
||||
.await
|
||||
.context("downloading release")?;
|
||||
match Self::GITHUB_ASSET_KIND {
|
||||
AssetKind::TarGz => {
|
||||
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||
let archive = Archive::new(decompressed_bytes);
|
||||
archive.unpack(&destination_path).await.with_context(|| {
|
||||
format!("extracting {} to {:?}", version.url, destination_path)
|
||||
})?;
|
||||
}
|
||||
AssetKind::Gz => {
|
||||
let mut decompressed_bytes =
|
||||
GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||
let mut file =
|
||||
fs::File::create(&destination_path).await.with_context(|| {
|
||||
format!(
|
||||
"creating a file {:?} for a download from {}",
|
||||
destination_path, version.url,
|
||||
)
|
||||
})?;
|
||||
futures::io::copy(&mut decompressed_bytes, &mut file)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("extracting {} to {:?}", version.url, destination_path)
|
||||
})?;
|
||||
}
|
||||
AssetKind::Zip => {
|
||||
extract_zip(&destination_path, response.body_mut())
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("unzipping {} to {:?}", version.url, destination_path)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
download_server_binary(
|
||||
delegate,
|
||||
&version.url,
|
||||
None,
|
||||
&destination_path,
|
||||
Self::GITHUB_ASSET_KIND,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut dir = fs::read_dir(&destination_path).await?;
|
||||
let first = dir.next().await.context("missing first file")??;
|
||||
|
|
|
@ -18,6 +18,7 @@ use workspace::item::{Item, ItemHandle};
|
|||
use workspace::{Pane, Workspace};
|
||||
|
||||
use crate::markdown_elements::ParsedMarkdownElement;
|
||||
use crate::markdown_renderer::CheckboxClickedEvent;
|
||||
use crate::{
|
||||
MovePageDown, MovePageUp, OpenFollowingPreview, OpenPreview, OpenPreviewToTheSide,
|
||||
markdown_elements::ParsedMarkdown,
|
||||
|
@ -203,114 +204,7 @@ impl MarkdownPreviewView {
|
|||
cx: &mut Context<Workspace>,
|
||||
) -> Entity<Self> {
|
||||
cx.new(|cx| {
|
||||
let view = cx.entity().downgrade();
|
||||
|
||||
let list_state = ListState::new(
|
||||
0,
|
||||
gpui::ListAlignment::Top,
|
||||
px(1000.),
|
||||
move |ix, window, cx| {
|
||||
if let Some(view) = view.upgrade() {
|
||||
view.update(cx, |this: &mut Self, cx| {
|
||||
let Some(contents) = &this.contents else {
|
||||
return div().into_any();
|
||||
};
|
||||
|
||||
let mut render_cx =
|
||||
RenderContext::new(Some(this.workspace.clone()), window, cx)
|
||||
.with_checkbox_clicked_callback({
|
||||
let view = view.clone();
|
||||
move |checked, source_range, window, cx| {
|
||||
view.update(cx, |view, cx| {
|
||||
if let Some(editor) = view
|
||||
.active_editor
|
||||
.as_ref()
|
||||
.map(|s| s.editor.clone())
|
||||
{
|
||||
editor.update(cx, |editor, cx| {
|
||||
let task_marker =
|
||||
if checked { "[x]" } else { "[ ]" };
|
||||
|
||||
editor.edit(
|
||||
vec![(source_range, task_marker)],
|
||||
cx,
|
||||
);
|
||||
});
|
||||
view.parse_markdown_from_active_editor(
|
||||
false, window, cx,
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let block = contents.children.get(ix).unwrap();
|
||||
let rendered_block = render_markdown_block(block, &mut render_cx);
|
||||
|
||||
let should_apply_padding = Self::should_apply_padding_between(
|
||||
block,
|
||||
contents.children.get(ix + 1),
|
||||
);
|
||||
|
||||
div()
|
||||
.id(ix)
|
||||
.when(should_apply_padding, |this| {
|
||||
this.pb(render_cx.scaled_rems(0.75))
|
||||
})
|
||||
.group("markdown-block")
|
||||
.on_click(cx.listener(
|
||||
move |this, event: &ClickEvent, window, cx| {
|
||||
if event.down.click_count == 2 {
|
||||
if let Some(source_range) = this
|
||||
.contents
|
||||
.as_ref()
|
||||
.and_then(|c| c.children.get(ix))
|
||||
.and_then(|block| block.source_range())
|
||||
{
|
||||
this.move_cursor_to_block(
|
||||
window,
|
||||
cx,
|
||||
source_range.start..source_range.start,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
))
|
||||
.map(move |container| {
|
||||
let indicator = div()
|
||||
.h_full()
|
||||
.w(px(4.0))
|
||||
.when(ix == this.selected_block, |this| {
|
||||
this.bg(cx.theme().colors().border)
|
||||
})
|
||||
.group_hover("markdown-block", |s| {
|
||||
if ix == this.selected_block {
|
||||
s
|
||||
} else {
|
||||
s.bg(cx.theme().colors().border_variant)
|
||||
}
|
||||
})
|
||||
.rounded_xs();
|
||||
|
||||
container.child(
|
||||
div()
|
||||
.relative()
|
||||
.child(
|
||||
div()
|
||||
.pl(render_cx.scaled_rems(1.0))
|
||||
.child(rendered_block),
|
||||
)
|
||||
.child(indicator.absolute().left_0().top_0()),
|
||||
)
|
||||
})
|
||||
.into_any()
|
||||
})
|
||||
} else {
|
||||
div().into_any()
|
||||
}
|
||||
},
|
||||
);
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
|
||||
|
||||
let mut this = Self {
|
||||
selected_block: 0,
|
||||
|
@ -607,10 +501,107 @@ impl Render for MarkdownPreviewView {
|
|||
.p_4()
|
||||
.text_size(buffer_size)
|
||||
.line_height(relative(buffer_line_height.value()))
|
||||
.child(
|
||||
div()
|
||||
.flex_grow()
|
||||
.map(|this| this.child(list(self.list_state.clone()).size_full())),
|
||||
)
|
||||
.child(div().flex_grow().map(|this| {
|
||||
this.child(
|
||||
list(
|
||||
self.list_state.clone(),
|
||||
cx.processor(|this, ix, window, cx| {
|
||||
let Some(contents) = &this.contents else {
|
||||
return div().into_any();
|
||||
};
|
||||
|
||||
let mut render_cx =
|
||||
RenderContext::new(Some(this.workspace.clone()), window, cx)
|
||||
.with_checkbox_clicked_callback(cx.listener(
|
||||
move |this, e: &CheckboxClickedEvent, window, cx| {
|
||||
if let Some(editor) = this
|
||||
.active_editor
|
||||
.as_ref()
|
||||
.map(|s| s.editor.clone())
|
||||
{
|
||||
editor.update(cx, |editor, cx| {
|
||||
let task_marker =
|
||||
if e.checked() { "[x]" } else { "[ ]" };
|
||||
|
||||
editor.edit(
|
||||
vec![(e.source_range(), task_marker)],
|
||||
cx,
|
||||
);
|
||||
});
|
||||
this.parse_markdown_from_active_editor(
|
||||
false, window, cx,
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
},
|
||||
));
|
||||
|
||||
let block = contents.children.get(ix).unwrap();
|
||||
let rendered_block = render_markdown_block(block, &mut render_cx);
|
||||
|
||||
let should_apply_padding = Self::should_apply_padding_between(
|
||||
block,
|
||||
contents.children.get(ix + 1),
|
||||
);
|
||||
|
||||
div()
|
||||
.id(ix)
|
||||
.when(should_apply_padding, |this| {
|
||||
this.pb(render_cx.scaled_rems(0.75))
|
||||
})
|
||||
.group("markdown-block")
|
||||
.on_click(cx.listener(
|
||||
move |this, event: &ClickEvent, window, cx| {
|
||||
if event.click_count() == 2 {
|
||||
if let Some(source_range) = this
|
||||
.contents
|
||||
.as_ref()
|
||||
.and_then(|c| c.children.get(ix))
|
||||
.and_then(|block: &ParsedMarkdownElement| {
|
||||
block.source_range()
|
||||
})
|
||||
{
|
||||
this.move_cursor_to_block(
|
||||
window,
|
||||
cx,
|
||||
source_range.start..source_range.start,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
))
|
||||
.map(move |container| {
|
||||
let indicator = div()
|
||||
.h_full()
|
||||
.w(px(4.0))
|
||||
.when(ix == this.selected_block, |this| {
|
||||
this.bg(cx.theme().colors().border)
|
||||
})
|
||||
.group_hover("markdown-block", |s| {
|
||||
if ix == this.selected_block {
|
||||
s
|
||||
} else {
|
||||
s.bg(cx.theme().colors().border_variant)
|
||||
}
|
||||
})
|
||||
.rounded_xs();
|
||||
|
||||
container.child(
|
||||
div()
|
||||
.relative()
|
||||
.child(
|
||||
div()
|
||||
.pl(render_cx.scaled_rems(1.0))
|
||||
.child(rendered_block),
|
||||
)
|
||||
.child(indicator.absolute().left_0().top_0()),
|
||||
)
|
||||
})
|
||||
.into_any()
|
||||
}),
|
||||
)
|
||||
.size_full(),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue