acp: Model-specific prompt capabilities for 1PA (#36879)

Adds support for per-session prompt capabilities and capability changes
on the Zed side (ACP itself still only has per-connection static
capabilities for now), and uses it to reflect image support accurately
in 1PA threads based on the currently-selected model.

Release Notes:

- N/A
This commit is contained in:
Cole Miller 2025-08-25 14:28:11 -04:00 committed by GitHub
parent f1204dfc33
commit 5fd29d37a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 98 additions and 53 deletions

View file

@ -756,6 +756,8 @@ pub struct AcpThread {
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId, session_id: acp::SessionId,
token_usage: Option<TokenUsage>, token_usage: Option<TokenUsage>,
prompt_capabilities: acp::PromptCapabilities,
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -770,6 +772,7 @@ pub enum AcpThreadEvent {
Stopped, Stopped,
Error, Error,
LoadError(LoadError), LoadError(LoadError),
PromptCapabilitiesUpdated,
} }
impl EventEmitter<AcpThreadEvent> for AcpThread {} impl EventEmitter<AcpThreadEvent> for AcpThread {}
@ -821,7 +824,20 @@ impl AcpThread {
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
session_id: acp::SessionId, session_id: acp::SessionId,
mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
cx: &mut Context<Self>,
) -> Self { ) -> Self {
let prompt_capabilities = *prompt_capabilities_rx.borrow();
let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
loop {
let caps = prompt_capabilities_rx.recv().await?;
this.update(cx, |this, cx| {
this.prompt_capabilities = caps;
cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
})?;
}
});
Self { Self {
action_log, action_log,
shared_buffers: Default::default(), shared_buffers: Default::default(),
@ -833,9 +849,15 @@ impl AcpThread {
connection, connection,
session_id, session_id,
token_usage: None, token_usage: None,
prompt_capabilities,
_observe_prompt_capabilities: task,
} }
} }
pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
pub fn connection(&self) -> &Rc<dyn AgentConnection> { pub fn connection(&self) -> &Rc<dyn AgentConnection> {
&self.connection &self.connection
} }
@ -2599,13 +2621,19 @@ mod tests {
.into(), .into(),
); );
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| { let thread = cx.new(|cx| {
AcpThread::new( AcpThread::new(
"Test", "Test",
self.clone(), self.clone(),
project, project,
action_log, action_log,
session_id.clone(), session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
) )
}); });
self.sessions.lock().insert(session_id, thread.downgrade()); self.sessions.lock().insert(session_id, thread.downgrade());
@ -2639,14 +2667,6 @@ mod tests {
} }
} }
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.lock(); let sessions = self.sessions.lock();
let thread = sessions.get(session_id).unwrap().clone(); let thread = sessions.get(session_id).unwrap().clone();

View file

@ -38,8 +38,6 @@ pub trait AgentConnection {
cx: &mut App, cx: &mut App,
) -> Task<Result<acp::PromptResponse>>; ) -> Task<Result<acp::PromptResponse>>;
fn prompt_capabilities(&self) -> acp::PromptCapabilities;
fn resume( fn resume(
&self, &self,
_session_id: &acp::SessionId, _session_id: &acp::SessionId,
@ -329,13 +327,19 @@ mod test_support {
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| { let thread = cx.new(|cx| {
AcpThread::new( AcpThread::new(
"Test", "Test",
self.clone(), self.clone(),
project, project,
action_log, action_log,
session_id.clone(), session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
) )
}); });
self.sessions.lock().insert( self.sessions.lock().insert(
@ -348,14 +352,6 @@ mod test_support {
Task::ready(Ok(thread)) Task::ready(Ok(thread))
} }
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn authenticate( fn authenticate(
&self, &self,
_method_id: acp::AuthMethodId, _method_id: acp::AuthMethodId,

View file

@ -240,13 +240,16 @@ impl NativeAgent {
let title = thread.title(); let title = thread.title();
let project = thread.project.clone(); let project = thread.project.clone();
let action_log = thread.action_log.clone(); let action_log = thread.action_log.clone();
let acp_thread = cx.new(|_cx| { let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
let acp_thread = cx.new(|cx| {
acp_thread::AcpThread::new( acp_thread::AcpThread::new(
title, title,
connection, connection,
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
session_id.clone(), session_id.clone(),
prompt_capabilities_rx,
cx,
) )
}); });
let subscriptions = vec![ let subscriptions = vec![
@ -925,14 +928,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}) })
} }
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}
}
fn resume( fn resume(
&self, &self,
session_id: &acp::SessionId, session_id: &acp::SessionId,

View file

@ -575,11 +575,22 @@ pub struct Thread {
templates: Arc<Templates>, templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>, model: Option<Arc<dyn LanguageModel>>,
summarization_model: Option<Arc<dyn LanguageModel>>, summarization_model: Option<Arc<dyn LanguageModel>>,
prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
pub(crate) project: Entity<Project>, pub(crate) project: Entity<Project>,
pub(crate) action_log: Entity<ActionLog>, pub(crate) action_log: Entity<ActionLog>,
} }
impl Thread { impl Thread {
fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
let image = model.map_or(true, |model| model.supports_images());
acp::PromptCapabilities {
image,
audio: false,
embedded_context: true,
}
}
pub fn new( pub fn new(
project: Entity<Project>, project: Entity<Project>,
project_context: Entity<ProjectContext>, project_context: Entity<ProjectContext>,
@ -590,6 +601,8 @@ impl Thread {
) -> Self { ) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone(); let profile_id = AgentSettings::get_global(cx).default_profile.clone();
let action_log = cx.new(|_cx| ActionLog::new(project.clone())); let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
Self { Self {
id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()), id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
prompt_id: PromptId::new(), prompt_id: PromptId::new(),
@ -617,6 +630,8 @@ impl Thread {
templates, templates,
model, model,
summarization_model: None, summarization_model: None,
prompt_capabilities_tx,
prompt_capabilities_rx,
project, project,
action_log, action_log,
} }
@ -750,6 +765,8 @@ impl Thread {
.or_else(|| registry.default_model()) .or_else(|| registry.default_model())
.map(|model| model.model) .map(|model| model.model)
}); });
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
Self { Self {
id, id,
@ -779,6 +796,8 @@ impl Thread {
project, project,
action_log, action_log,
updated_at: db_thread.updated_at, updated_at: db_thread.updated_at,
prompt_capabilities_tx,
prompt_capabilities_rx,
} }
} }
@ -946,10 +965,12 @@ impl Thread {
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) { pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
let old_usage = self.latest_token_usage(); let old_usage = self.latest_token_usage();
self.model = Some(model); self.model = Some(model);
let new_caps = Self::prompt_capabilities(self.model.as_deref());
let new_usage = self.latest_token_usage(); let new_usage = self.latest_token_usage();
if old_usage != new_usage { if old_usage != new_usage {
cx.emit(TokenUsageUpdated(new_usage)); cx.emit(TokenUsageUpdated(new_usage));
} }
self.prompt_capabilities_tx.send(new_caps).log_err();
cx.notify() cx.notify()
} }

View file

@ -185,13 +185,16 @@ impl AgentConnection for AcpConnection {
let session_id = response.session_id; let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?; let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| { let thread = cx.new(|cx| {
AcpThread::new( AcpThread::new(
self.server_name.clone(), self.server_name.clone(),
self.clone(), self.clone(),
project, project,
action_log, action_log,
session_id.clone(), session_id.clone(),
// ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
watch::Receiver::constant(self.prompt_capabilities),
cx,
) )
})?; })?;
@ -279,10 +282,6 @@ impl AgentConnection for AcpConnection {
}) })
} }
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
self.prompt_capabilities
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) { if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
session.suppress_abort_err = true; session.suppress_abort_err = true;

View file

@ -249,13 +249,19 @@ impl AgentConnection for ClaudeAgentConnection {
}); });
let action_log = cx.new(|_| ActionLog::new(project.clone()))?; let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| { let thread = cx.new(|cx| {
AcpThread::new( AcpThread::new(
"Claude Code", "Claude Code",
self.clone(), self.clone(),
project, project,
action_log, action_log,
session_id.clone(), session_id.clone(),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}),
cx,
) )
})?; })?;
@ -319,14 +325,6 @@ impl AgentConnection for ClaudeAgentConnection {
cx.foreground_executor().spawn(async move { end_rx.await? }) cx.foreground_executor().spawn(async move { end_rx.await? })
} }
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: false,
embedded_context: true,
}
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
let sessions = self.sessions.borrow(); let sessions = self.sessions.borrow();
let Some(session) = sessions.get(session_id) else { let Some(session) = sessions.get(session_id) else {

View file

@ -373,7 +373,7 @@ impl MessageEditor {
if Img::extensions().contains(&extension) && !extension.contains("svg") { if Img::extensions().contains(&extension) && !extension.contains("svg") {
if !self.prompt_capabilities.get().image { if !self.prompt_capabilities.get().image {
return Task::ready(Err(anyhow!("This agent does not support images yet"))); return Task::ready(Err(anyhow!("This model does not support images yet")));
} }
let task = self let task = self
.project .project

View file

@ -474,7 +474,7 @@ impl AcpThreadView {
let action_log = thread.read(cx).action_log().clone(); let action_log = thread.read(cx).action_log().clone();
this.prompt_capabilities this.prompt_capabilities
.set(connection.prompt_capabilities()); .set(thread.read(cx).prompt_capabilities());
let count = thread.read(cx).entries().len(); let count = thread.read(cx).entries().len();
this.list_state.splice(0..0, count); this.list_state.splice(0..0, count);
@ -1163,6 +1163,10 @@ impl AcpThreadView {
}); });
} }
} }
AcpThreadEvent::PromptCapabilitiesUpdated => {
self.prompt_capabilities
.set(thread.read(cx).prompt_capabilities());
}
AcpThreadEvent::TokenUsageUpdated => {} AcpThreadEvent::TokenUsageUpdated => {}
} }
cx.notify(); cx.notify();
@ -5367,6 +5371,12 @@ pub(crate) mod tests {
project, project,
action_log, action_log,
SessionId("test".into()), SessionId("test".into()),
watch::Receiver::constant(acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}),
cx,
) )
}))) })))
} }
@ -5375,14 +5385,6 @@ pub(crate) mod tests {
&[] &[]
} }
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
acp::PromptCapabilities {
image: true,
audio: true,
embedded_context: true,
}
}
fn authenticate( fn authenticate(
&self, &self,
_method_id: acp::AuthMethodId, _method_id: acp::AuthMethodId,

View file

@ -1529,6 +1529,7 @@ impl AgentDiff {
| AcpThreadEvent::TokenUsageUpdated | AcpThreadEvent::TokenUsageUpdated
| AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::PromptCapabilitiesUpdated
| AcpThreadEvent::Retry(_) => {} | AcpThreadEvent::Retry(_) => {}
} }
} }

View file

@ -162,6 +162,19 @@ impl<T> Receiver<T> {
pending_waker_id: None, pending_waker_id: None,
} }
} }
/// Creates a new [`Receiver`] holding an initial value that will never change.
pub fn constant(value: T) -> Self {
let state = Arc::new(RwLock::new(State {
value,
wakers: BTreeMap::new(),
next_waker_id: WakerId::default(),
version: 0,
closed: false,
}));
Self { state, version: 0 }
}
} }
impl<T: Clone> Receiver<T> { impl<T: Clone> Receiver<T> {