acp: Model-specific prompt capabilities for 1PA (#36879)
Adds support for per-session prompt capabilities and capability changes on the Zed side (ACP itself still only has per-connection static capabilities for now), and uses it to reflect image support accurately in 1PA threads based on the currently-selected model. Release Notes: - N/A
This commit is contained in:
parent
f1204dfc33
commit
5fd29d37a6
10 changed files with 98 additions and 53 deletions
|
@ -756,6 +756,8 @@ pub struct AcpThread {
|
||||||
connection: Rc<dyn AgentConnection>,
|
connection: Rc<dyn AgentConnection>,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
token_usage: Option<TokenUsage>,
|
token_usage: Option<TokenUsage>,
|
||||||
|
prompt_capabilities: acp::PromptCapabilities,
|
||||||
|
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -770,6 +772,7 @@ pub enum AcpThreadEvent {
|
||||||
Stopped,
|
Stopped,
|
||||||
Error,
|
Error,
|
||||||
LoadError(LoadError),
|
LoadError(LoadError),
|
||||||
|
PromptCapabilitiesUpdated,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||||
|
@ -821,7 +824,20 @@ impl AcpThread {
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
action_log: Entity<ActionLog>,
|
action_log: Entity<ActionLog>,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
|
mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let prompt_capabilities = *prompt_capabilities_rx.borrow();
|
||||||
|
let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
|
||||||
|
loop {
|
||||||
|
let caps = prompt_capabilities_rx.recv().await?;
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.prompt_capabilities = caps;
|
||||||
|
cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
action_log,
|
action_log,
|
||||||
shared_buffers: Default::default(),
|
shared_buffers: Default::default(),
|
||||||
|
@ -833,9 +849,15 @@ impl AcpThread {
|
||||||
connection,
|
connection,
|
||||||
session_id,
|
session_id,
|
||||||
token_usage: None,
|
token_usage: None,
|
||||||
|
prompt_capabilities,
|
||||||
|
_observe_prompt_capabilities: task,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||||
|
self.prompt_capabilities
|
||||||
|
}
|
||||||
|
|
||||||
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
|
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
|
||||||
&self.connection
|
&self.connection
|
||||||
}
|
}
|
||||||
|
@ -2599,13 +2621,19 @@ mod tests {
|
||||||
.into(),
|
.into(),
|
||||||
);
|
);
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let thread = cx.new(|_cx| {
|
let thread = cx.new(|cx| {
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
"Test",
|
"Test",
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
watch::Receiver::constant(acp::PromptCapabilities {
|
||||||
|
image: true,
|
||||||
|
audio: true,
|
||||||
|
embedded_context: true,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||||
|
@ -2639,14 +2667,6 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
|
||||||
acp::PromptCapabilities {
|
|
||||||
image: true,
|
|
||||||
audio: true,
|
|
||||||
embedded_context: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||||
let sessions = self.sessions.lock();
|
let sessions = self.sessions.lock();
|
||||||
let thread = sessions.get(session_id).unwrap().clone();
|
let thread = sessions.get(session_id).unwrap().clone();
|
||||||
|
|
|
@ -38,8 +38,6 @@ pub trait AgentConnection {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>>;
|
) -> Task<Result<acp::PromptResponse>>;
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities;
|
|
||||||
|
|
||||||
fn resume(
|
fn resume(
|
||||||
&self,
|
&self,
|
||||||
_session_id: &acp::SessionId,
|
_session_id: &acp::SessionId,
|
||||||
|
@ -329,13 +327,19 @@ mod test_support {
|
||||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||||
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let thread = cx.new(|_cx| {
|
let thread = cx.new(|cx| {
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
"Test",
|
"Test",
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
watch::Receiver::constant(acp::PromptCapabilities {
|
||||||
|
image: true,
|
||||||
|
audio: true,
|
||||||
|
embedded_context: true,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
self.sessions.lock().insert(
|
self.sessions.lock().insert(
|
||||||
|
@ -348,14 +352,6 @@ mod test_support {
|
||||||
Task::ready(Ok(thread))
|
Task::ready(Ok(thread))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
|
||||||
acp::PromptCapabilities {
|
|
||||||
image: true,
|
|
||||||
audio: true,
|
|
||||||
embedded_context: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(
|
fn authenticate(
|
||||||
&self,
|
&self,
|
||||||
_method_id: acp::AuthMethodId,
|
_method_id: acp::AuthMethodId,
|
||||||
|
|
|
@ -240,13 +240,16 @@ impl NativeAgent {
|
||||||
let title = thread.title();
|
let title = thread.title();
|
||||||
let project = thread.project.clone();
|
let project = thread.project.clone();
|
||||||
let action_log = thread.action_log.clone();
|
let action_log = thread.action_log.clone();
|
||||||
let acp_thread = cx.new(|_cx| {
|
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
|
||||||
|
let acp_thread = cx.new(|cx| {
|
||||||
acp_thread::AcpThread::new(
|
acp_thread::AcpThread::new(
|
||||||
title,
|
title,
|
||||||
connection,
|
connection,
|
||||||
project.clone(),
|
project.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
prompt_capabilities_rx,
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let subscriptions = vec![
|
let subscriptions = vec![
|
||||||
|
@ -925,14 +928,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
|
||||||
acp::PromptCapabilities {
|
|
||||||
image: true,
|
|
||||||
audio: false,
|
|
||||||
embedded_context: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resume(
|
fn resume(
|
||||||
&self,
|
&self,
|
||||||
session_id: &acp::SessionId,
|
session_id: &acp::SessionId,
|
||||||
|
|
|
@ -575,11 +575,22 @@ pub struct Thread {
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
model: Option<Arc<dyn LanguageModel>>,
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||||
|
prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
|
||||||
|
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
|
||||||
pub(crate) project: Entity<Project>,
|
pub(crate) project: Entity<Project>,
|
||||||
pub(crate) action_log: Entity<ActionLog>,
|
pub(crate) action_log: Entity<ActionLog>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
|
fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
|
||||||
|
let image = model.map_or(true, |model| model.supports_images());
|
||||||
|
acp::PromptCapabilities {
|
||||||
|
image,
|
||||||
|
audio: false,
|
||||||
|
embedded_context: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
project_context: Entity<ProjectContext>,
|
project_context: Entity<ProjectContext>,
|
||||||
|
@ -590,6 +601,8 @@ impl Thread {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
||||||
|
let (prompt_capabilities_tx, prompt_capabilities_rx) =
|
||||||
|
watch::channel(Self::prompt_capabilities(model.as_deref()));
|
||||||
Self {
|
Self {
|
||||||
id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
|
id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
|
||||||
prompt_id: PromptId::new(),
|
prompt_id: PromptId::new(),
|
||||||
|
@ -617,6 +630,8 @@ impl Thread {
|
||||||
templates,
|
templates,
|
||||||
model,
|
model,
|
||||||
summarization_model: None,
|
summarization_model: None,
|
||||||
|
prompt_capabilities_tx,
|
||||||
|
prompt_capabilities_rx,
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
}
|
}
|
||||||
|
@ -750,6 +765,8 @@ impl Thread {
|
||||||
.or_else(|| registry.default_model())
|
.or_else(|| registry.default_model())
|
||||||
.map(|model| model.model)
|
.map(|model| model.model)
|
||||||
});
|
});
|
||||||
|
let (prompt_capabilities_tx, prompt_capabilities_rx) =
|
||||||
|
watch::channel(Self::prompt_capabilities(model.as_deref()));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id,
|
id,
|
||||||
|
@ -779,6 +796,8 @@ impl Thread {
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
updated_at: db_thread.updated_at,
|
updated_at: db_thread.updated_at,
|
||||||
|
prompt_capabilities_tx,
|
||||||
|
prompt_capabilities_rx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -946,10 +965,12 @@ impl Thread {
|
||||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
|
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
|
||||||
let old_usage = self.latest_token_usage();
|
let old_usage = self.latest_token_usage();
|
||||||
self.model = Some(model);
|
self.model = Some(model);
|
||||||
|
let new_caps = Self::prompt_capabilities(self.model.as_deref());
|
||||||
let new_usage = self.latest_token_usage();
|
let new_usage = self.latest_token_usage();
|
||||||
if old_usage != new_usage {
|
if old_usage != new_usage {
|
||||||
cx.emit(TokenUsageUpdated(new_usage));
|
cx.emit(TokenUsageUpdated(new_usage));
|
||||||
}
|
}
|
||||||
|
self.prompt_capabilities_tx.send(new_caps).log_err();
|
||||||
cx.notify()
|
cx.notify()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -185,13 +185,16 @@ impl AgentConnection for AcpConnection {
|
||||||
|
|
||||||
let session_id = response.session_id;
|
let session_id = response.session_id;
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||||
let thread = cx.new(|_cx| {
|
let thread = cx.new(|cx| {
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
self.server_name.clone(),
|
self.server_name.clone(),
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
// ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
|
||||||
|
watch::Receiver::constant(self.prompt_capabilities),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -279,10 +282,6 @@ impl AgentConnection for AcpConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
|
||||||
self.prompt_capabilities
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||||
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
|
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
|
||||||
session.suppress_abort_err = true;
|
session.suppress_abort_err = true;
|
||||||
|
|
|
@ -249,13 +249,19 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
});
|
});
|
||||||
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||||
let thread = cx.new(|_cx| {
|
let thread = cx.new(|cx| {
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
"Claude Code",
|
"Claude Code",
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
watch::Receiver::constant(acp::PromptCapabilities {
|
||||||
|
image: true,
|
||||||
|
audio: false,
|
||||||
|
embedded_context: true,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -319,14 +325,6 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
cx.foreground_executor().spawn(async move { end_rx.await? })
|
cx.foreground_executor().spawn(async move { end_rx.await? })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
|
||||||
acp::PromptCapabilities {
|
|
||||||
image: true,
|
|
||||||
audio: false,
|
|
||||||
embedded_context: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||||
let sessions = self.sessions.borrow();
|
let sessions = self.sessions.borrow();
|
||||||
let Some(session) = sessions.get(session_id) else {
|
let Some(session) = sessions.get(session_id) else {
|
||||||
|
|
|
@ -373,7 +373,7 @@ impl MessageEditor {
|
||||||
|
|
||||||
if Img::extensions().contains(&extension) && !extension.contains("svg") {
|
if Img::extensions().contains(&extension) && !extension.contains("svg") {
|
||||||
if !self.prompt_capabilities.get().image {
|
if !self.prompt_capabilities.get().image {
|
||||||
return Task::ready(Err(anyhow!("This agent does not support images yet")));
|
return Task::ready(Err(anyhow!("This model does not support images yet")));
|
||||||
}
|
}
|
||||||
let task = self
|
let task = self
|
||||||
.project
|
.project
|
||||||
|
|
|
@ -474,7 +474,7 @@ impl AcpThreadView {
|
||||||
let action_log = thread.read(cx).action_log().clone();
|
let action_log = thread.read(cx).action_log().clone();
|
||||||
|
|
||||||
this.prompt_capabilities
|
this.prompt_capabilities
|
||||||
.set(connection.prompt_capabilities());
|
.set(thread.read(cx).prompt_capabilities());
|
||||||
|
|
||||||
let count = thread.read(cx).entries().len();
|
let count = thread.read(cx).entries().len();
|
||||||
this.list_state.splice(0..0, count);
|
this.list_state.splice(0..0, count);
|
||||||
|
@ -1163,6 +1163,10 @@ impl AcpThreadView {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
AcpThreadEvent::PromptCapabilitiesUpdated => {
|
||||||
|
self.prompt_capabilities
|
||||||
|
.set(thread.read(cx).prompt_capabilities());
|
||||||
|
}
|
||||||
AcpThreadEvent::TokenUsageUpdated => {}
|
AcpThreadEvent::TokenUsageUpdated => {}
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -5367,6 +5371,12 @@ pub(crate) mod tests {
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
SessionId("test".into()),
|
SessionId("test".into()),
|
||||||
|
watch::Receiver::constant(acp::PromptCapabilities {
|
||||||
|
image: true,
|
||||||
|
audio: true,
|
||||||
|
embedded_context: true,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
@ -5375,14 +5385,6 @@ pub(crate) mod tests {
|
||||||
&[]
|
&[]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
|
||||||
acp::PromptCapabilities {
|
|
||||||
image: true,
|
|
||||||
audio: true,
|
|
||||||
embedded_context: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(
|
fn authenticate(
|
||||||
&self,
|
&self,
|
||||||
_method_id: acp::AuthMethodId,
|
_method_id: acp::AuthMethodId,
|
||||||
|
|
|
@ -1529,6 +1529,7 @@ impl AgentDiff {
|
||||||
| AcpThreadEvent::TokenUsageUpdated
|
| AcpThreadEvent::TokenUsageUpdated
|
||||||
| AcpThreadEvent::EntriesRemoved(_)
|
| AcpThreadEvent::EntriesRemoved(_)
|
||||||
| AcpThreadEvent::ToolAuthorizationRequired
|
| AcpThreadEvent::ToolAuthorizationRequired
|
||||||
|
| AcpThreadEvent::PromptCapabilitiesUpdated
|
||||||
| AcpThreadEvent::Retry(_) => {}
|
| AcpThreadEvent::Retry(_) => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -162,6 +162,19 @@ impl<T> Receiver<T> {
|
||||||
pending_waker_id: None,
|
pending_waker_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates a new [`Receiver`] holding an initial value that will never change.
|
||||||
|
pub fn constant(value: T) -> Self {
|
||||||
|
let state = Arc::new(RwLock::new(State {
|
||||||
|
value,
|
||||||
|
wakers: BTreeMap::new(),
|
||||||
|
next_waker_id: WakerId::default(),
|
||||||
|
version: 0,
|
||||||
|
closed: false,
|
||||||
|
}));
|
||||||
|
|
||||||
|
Self { state, version: 0 }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Clone> Receiver<T> {
|
impl<T: Clone> Receiver<T> {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue