acp: Model-specific prompt capabilities for 1PA (#36879)
Adds support for per-session prompt capabilities and capability changes on the Zed side (ACP itself still only has per-connection static capabilities for now), and uses it to reflect image support accurately in 1PA threads based on the currently-selected model. Release Notes: - N/A
This commit is contained in:
parent
66d9fb09cc
commit
5d8e0f6ad1
10 changed files with 98 additions and 53 deletions
|
@ -756,6 +756,8 @@ pub struct AcpThread {
|
|||
connection: Rc<dyn AgentConnection>,
|
||||
session_id: acp::SessionId,
|
||||
token_usage: Option<TokenUsage>,
|
||||
prompt_capabilities: acp::PromptCapabilities,
|
||||
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -770,6 +772,7 @@ pub enum AcpThreadEvent {
|
|||
Stopped,
|
||||
Error,
|
||||
LoadError(LoadError),
|
||||
PromptCapabilitiesUpdated,
|
||||
}
|
||||
|
||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||
|
@ -821,7 +824,20 @@ impl AcpThread {
|
|||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
session_id: acp::SessionId,
|
||||
mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let prompt_capabilities = *prompt_capabilities_rx.borrow();
|
||||
let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
|
||||
loop {
|
||||
let caps = prompt_capabilities_rx.recv().await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.prompt_capabilities = caps;
|
||||
cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
|
||||
})?;
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
action_log,
|
||||
shared_buffers: Default::default(),
|
||||
|
@ -833,9 +849,15 @@ impl AcpThread {
|
|||
connection,
|
||||
session_id,
|
||||
token_usage: None,
|
||||
prompt_capabilities,
|
||||
_observe_prompt_capabilities: task,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
self.prompt_capabilities
|
||||
}
|
||||
|
||||
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
|
||||
&self.connection
|
||||
}
|
||||
|
@ -2599,13 +2621,19 @@ mod tests {
|
|||
.into(),
|
||||
);
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
"Test",
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
watch::Receiver::constant(acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: true,
|
||||
embedded_context: true,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
|
@ -2639,14 +2667,6 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: true,
|
||||
embedded_context: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(session_id).unwrap().clone();
|
||||
|
|
|
@ -38,8 +38,6 @@ pub trait AgentConnection {
|
|||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities;
|
||||
|
||||
fn resume(
|
||||
&self,
|
||||
_session_id: &acp::SessionId,
|
||||
|
@ -329,13 +327,19 @@ mod test_support {
|
|||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
"Test",
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
watch::Receiver::constant(acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: true,
|
||||
embedded_context: true,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
self.sessions.lock().insert(
|
||||
|
@ -348,14 +352,6 @@ mod test_support {
|
|||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: true,
|
||||
embedded_context: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method_id: acp::AuthMethodId,
|
||||
|
|
|
@ -240,13 +240,16 @@ impl NativeAgent {
|
|||
let title = thread.title();
|
||||
let project = thread.project.clone();
|
||||
let action_log = thread.action_log.clone();
|
||||
let acp_thread = cx.new(|_cx| {
|
||||
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
|
||||
let acp_thread = cx.new(|cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
title,
|
||||
connection,
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
session_id.clone(),
|
||||
prompt_capabilities_rx,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let subscriptions = vec![
|
||||
|
@ -925,14 +928,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: false,
|
||||
embedded_context: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn resume(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
|
|
|
@ -575,11 +575,22 @@ pub struct Thread {
|
|||
templates: Arc<Templates>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||
prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
|
||||
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
|
||||
pub(crate) project: Entity<Project>,
|
||||
pub(crate) action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
|
||||
let image = model.map_or(true, |model| model.supports_images());
|
||||
acp::PromptCapabilities {
|
||||
image,
|
||||
audio: false,
|
||||
embedded_context: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
|
@ -590,6 +601,8 @@ impl Thread {
|
|||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
||||
let (prompt_capabilities_tx, prompt_capabilities_rx) =
|
||||
watch::channel(Self::prompt_capabilities(model.as_deref()));
|
||||
Self {
|
||||
id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
|
||||
prompt_id: PromptId::new(),
|
||||
|
@ -617,6 +630,8 @@ impl Thread {
|
|||
templates,
|
||||
model,
|
||||
summarization_model: None,
|
||||
prompt_capabilities_tx,
|
||||
prompt_capabilities_rx,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
|
@ -750,6 +765,8 @@ impl Thread {
|
|||
.or_else(|| registry.default_model())
|
||||
.map(|model| model.model)
|
||||
});
|
||||
let (prompt_capabilities_tx, prompt_capabilities_rx) =
|
||||
watch::channel(Self::prompt_capabilities(model.as_deref()));
|
||||
|
||||
Self {
|
||||
id,
|
||||
|
@ -779,6 +796,8 @@ impl Thread {
|
|||
project,
|
||||
action_log,
|
||||
updated_at: db_thread.updated_at,
|
||||
prompt_capabilities_tx,
|
||||
prompt_capabilities_rx,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -946,10 +965,12 @@ impl Thread {
|
|||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
|
||||
let old_usage = self.latest_token_usage();
|
||||
self.model = Some(model);
|
||||
let new_caps = Self::prompt_capabilities(self.model.as_deref());
|
||||
let new_usage = self.latest_token_usage();
|
||||
if old_usage != new_usage {
|
||||
cx.emit(TokenUsageUpdated(new_usage));
|
||||
}
|
||||
self.prompt_capabilities_tx.send(new_caps).log_err();
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
|
|
|
@ -185,13 +185,16 @@ impl AgentConnection for AcpConnection {
|
|||
|
||||
let session_id = response.session_id;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||
let thread = cx.new(|_cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
self.server_name.clone(),
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
// ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
|
||||
watch::Receiver::constant(self.prompt_capabilities),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
|
@ -279,10 +282,6 @@ impl AgentConnection for AcpConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
self.prompt_capabilities
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
|
||||
session.suppress_abort_err = true;
|
||||
|
|
|
@ -249,13 +249,19 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||
let thread = cx.new(|_cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
"Claude Code",
|
||||
self.clone(),
|
||||
project,
|
||||
action_log,
|
||||
session_id.clone(),
|
||||
watch::Receiver::constant(acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: false,
|
||||
embedded_context: true,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
|
@ -319,14 +325,6 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
cx.foreground_executor().spawn(async move { end_rx.await? })
|
||||
}
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: false,
|
||||
embedded_context: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(session_id) else {
|
||||
|
|
|
@ -373,7 +373,7 @@ impl MessageEditor {
|
|||
|
||||
if Img::extensions().contains(&extension) && !extension.contains("svg") {
|
||||
if !self.prompt_capabilities.get().image {
|
||||
return Task::ready(Err(anyhow!("This agent does not support images yet")));
|
||||
return Task::ready(Err(anyhow!("This model does not support images yet")));
|
||||
}
|
||||
let task = self
|
||||
.project
|
||||
|
|
|
@ -474,7 +474,7 @@ impl AcpThreadView {
|
|||
let action_log = thread.read(cx).action_log().clone();
|
||||
|
||||
this.prompt_capabilities
|
||||
.set(connection.prompt_capabilities());
|
||||
.set(thread.read(cx).prompt_capabilities());
|
||||
|
||||
let count = thread.read(cx).entries().len();
|
||||
this.list_state.splice(0..0, count);
|
||||
|
@ -1163,6 +1163,10 @@ impl AcpThreadView {
|
|||
});
|
||||
}
|
||||
}
|
||||
AcpThreadEvent::PromptCapabilitiesUpdated => {
|
||||
self.prompt_capabilities
|
||||
.set(thread.read(cx).prompt_capabilities());
|
||||
}
|
||||
AcpThreadEvent::TokenUsageUpdated => {}
|
||||
}
|
||||
cx.notify();
|
||||
|
@ -5367,6 +5371,12 @@ pub(crate) mod tests {
|
|||
project,
|
||||
action_log,
|
||||
SessionId("test".into()),
|
||||
watch::Receiver::constant(acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: true,
|
||||
embedded_context: true,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})))
|
||||
}
|
||||
|
@ -5375,14 +5385,6 @@ pub(crate) mod tests {
|
|||
&[]
|
||||
}
|
||||
|
||||
fn prompt_capabilities(&self) -> acp::PromptCapabilities {
|
||||
acp::PromptCapabilities {
|
||||
image: true,
|
||||
audio: true,
|
||||
embedded_context: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method_id: acp::AuthMethodId,
|
||||
|
|
|
@ -1529,6 +1529,7 @@ impl AgentDiff {
|
|||
| AcpThreadEvent::TokenUsageUpdated
|
||||
| AcpThreadEvent::EntriesRemoved(_)
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::PromptCapabilitiesUpdated
|
||||
| AcpThreadEvent::Retry(_) => {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -162,6 +162,19 @@ impl<T> Receiver<T> {
|
|||
pending_waker_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new [`Receiver`] holding an initial value that will never change.
|
||||
pub fn constant(value: T) -> Self {
|
||||
let state = Arc::new(RwLock::new(State {
|
||||
value,
|
||||
wakers: BTreeMap::new(),
|
||||
next_waker_id: WakerId::default(),
|
||||
version: 0,
|
||||
closed: false,
|
||||
}));
|
||||
|
||||
Self { state, version: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> Receiver<T> {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue