Add assistant_context_editor
crate (#23429)
This PR adds a new `assistant_context_editor` crate. This will ultimately house the `ContextEditor` so that it can be consumed by both `assistant` and `assistant2`. For the purposes of this PR, we just introduce the crate and move some supporting constructs to it, such as the `ContextStore`. Release Notes: - N/A
This commit is contained in:
parent
c450cd51ea
commit
9a7f1d1de4
21 changed files with 304 additions and 211 deletions
|
@ -0,0 +1,16 @@
|
|||
mod context;
|
||||
mod context_store;
|
||||
mod patch;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::Client;
|
||||
use gpui::AppContext;
|
||||
|
||||
pub use crate::context::*;
|
||||
pub use crate::context_store::*;
|
||||
pub use crate::patch::*;
|
||||
|
||||
pub fn init(client: Arc<Client>, _cx: &mut AppContext) {
|
||||
context_store::init(&client.into());
|
||||
}
|
3544
crates/assistant_context_editor/src/context.rs
Normal file
3544
crates/assistant_context_editor/src/context.rs
Normal file
File diff suppressed because it is too large
Load diff
1662
crates/assistant_context_editor/src/context/context_tests.rs
Normal file
1662
crates/assistant_context_editor/src/context/context_tests.rs
Normal file
File diff suppressed because it is too large
Load diff
898
crates/assistant_context_editor/src/context_store.rs
Normal file
898
crates/assistant_context_editor/src/context_store.rs
Normal file
|
@ -0,0 +1,898 @@
|
|||
use crate::{
|
||||
Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
|
||||
SavedContextMetadata,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
|
||||
use assistant_tool::{ToolId, ToolWorkingSet};
|
||||
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
|
||||
use clock::ReplicaId;
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use fuzzy::StringMatchCandidate;
|
||||
use gpui::{
|
||||
AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel,
|
||||
};
|
||||
use language::LanguageRegistry;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
use prompt_library::PromptBuilder;
|
||||
use regex::Regex;
|
||||
use rpc::AnyProtoClient;
|
||||
use std::sync::LazyLock;
|
||||
use std::{
|
||||
cmp::Reverse,
|
||||
ffi::OsStr,
|
||||
mem,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
|
||||
pub(crate) fn init(client: &AnyProtoClient) {
|
||||
client.add_model_message_handler(ContextStore::handle_advertise_contexts);
|
||||
client.add_model_request_handler(ContextStore::handle_open_context);
|
||||
client.add_model_request_handler(ContextStore::handle_create_context);
|
||||
client.add_model_message_handler(ContextStore::handle_update_context);
|
||||
client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RemoteContextMetadata {
|
||||
pub id: ContextId,
|
||||
pub summary: Option<String>,
|
||||
}
|
||||
|
||||
pub struct ContextStore {
|
||||
contexts: Vec<ContextHandle>,
|
||||
contexts_metadata: Vec<SavedContextMetadata>,
|
||||
context_server_manager: Model<ContextServerManager>,
|
||||
context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
host_contexts: Vec<RemoteContextMetadata>,
|
||||
fs: Arc<dyn Fs>,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
slash_commands: Arc<SlashCommandWorkingSet>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
_watch_updates: Task<Option<()>>,
|
||||
client: Arc<Client>,
|
||||
project: Model<Project>,
|
||||
project_is_shared: bool,
|
||||
client_subscription: Option<client::Subscription>,
|
||||
_project_subscriptions: Vec<gpui::Subscription>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
pub enum ContextStoreEvent {
|
||||
ContextCreated(ContextId),
|
||||
}
|
||||
|
||||
impl EventEmitter<ContextStoreEvent> for ContextStore {}
|
||||
|
||||
enum ContextHandle {
|
||||
Weak(WeakModel<Context>),
|
||||
Strong(Model<Context>),
|
||||
}
|
||||
|
||||
impl ContextHandle {
|
||||
fn upgrade(&self) -> Option<Model<Context>> {
|
||||
match self {
|
||||
ContextHandle::Weak(weak) => weak.upgrade(),
|
||||
ContextHandle::Strong(strong) => Some(strong.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn downgrade(&self) -> WeakModel<Context> {
|
||||
match self {
|
||||
ContextHandle::Weak(weak) => weak.clone(),
|
||||
ContextHandle::Strong(strong) => strong.downgrade(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ContextStore {
|
||||
pub fn new(
|
||||
project: Model<Project>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
slash_commands: Arc<SlashCommandWorkingSet>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Model<Self>>> {
|
||||
let fs = project.read(cx).fs().clone();
|
||||
let languages = project.read(cx).languages().clone();
|
||||
let telemetry = project.read(cx).client().telemetry().clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
|
||||
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
|
||||
|
||||
let this = cx.new_model(|cx: &mut ModelContext<Self>| {
|
||||
let context_server_factory_registry =
|
||||
ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new_model(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let mut this = Self {
|
||||
contexts: Vec::new(),
|
||||
contexts_metadata: Vec::new(),
|
||||
context_server_manager,
|
||||
context_server_slash_command_ids: HashMap::default(),
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
host_contexts: Vec::new(),
|
||||
fs,
|
||||
languages,
|
||||
slash_commands,
|
||||
tools,
|
||||
telemetry,
|
||||
_watch_updates: cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
while events.next().await.is_some() {
|
||||
this.update(&mut cx, |this, cx| this.reload(cx))?
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
}),
|
||||
client_subscription: None,
|
||||
_project_subscriptions: vec![
|
||||
cx.observe(&project, Self::handle_project_changed),
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
],
|
||||
project_is_shared: false,
|
||||
client: project.read(cx).client(),
|
||||
project: project.clone(),
|
||||
prompt_builder,
|
||||
};
|
||||
this.handle_project_changed(project.clone(), cx);
|
||||
this.synchronize_contexts(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this
|
||||
})?;
|
||||
this.update(&mut cx, |this, cx| this.reload(cx))?
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
Ok(this)
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_advertise_contexts(
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::AdvertiseContexts>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.host_contexts = envelope
|
||||
.payload
|
||||
.contexts
|
||||
.into_iter()
|
||||
.map(|context| RemoteContextMetadata {
|
||||
id: ContextId::from_proto(context.context_id),
|
||||
summary: context.summary,
|
||||
})
|
||||
.collect();
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_open_context(
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::OpenContext>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<proto::OpenContextResponse> {
|
||||
let context_id = ContextId::from_proto(envelope.payload.context_id);
|
||||
let operations = this.update(&mut cx, |this, cx| {
|
||||
if this.project.read(cx).is_via_collab() {
|
||||
return Err(anyhow!("only the host contexts can be opened"));
|
||||
}
|
||||
|
||||
let context = this
|
||||
.loaded_context_for_id(&context_id, cx)
|
||||
.context("context not found")?;
|
||||
if context.read(cx).replica_id() != ReplicaId::default() {
|
||||
return Err(anyhow!("context must be opened via the host"));
|
||||
}
|
||||
|
||||
anyhow::Ok(
|
||||
context
|
||||
.read(cx)
|
||||
.serialize_ops(&ContextVersion::default(), cx),
|
||||
)
|
||||
})??;
|
||||
let operations = operations.await;
|
||||
Ok(proto::OpenContextResponse {
|
||||
context: Some(proto::Context { operations }),
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_create_context(
|
||||
this: Model<Self>,
|
||||
_: TypedEnvelope<proto::CreateContext>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<proto::CreateContextResponse> {
|
||||
let (context_id, operations) = this.update(&mut cx, |this, cx| {
|
||||
if this.project.read(cx).is_via_collab() {
|
||||
return Err(anyhow!("can only create contexts as the host"));
|
||||
}
|
||||
|
||||
let context = this.create(cx);
|
||||
let context_id = context.read(cx).id().clone();
|
||||
cx.emit(ContextStoreEvent::ContextCreated(context_id.clone()));
|
||||
|
||||
anyhow::Ok((
|
||||
context_id,
|
||||
context
|
||||
.read(cx)
|
||||
.serialize_ops(&ContextVersion::default(), cx),
|
||||
))
|
||||
})??;
|
||||
let operations = operations.await;
|
||||
Ok(proto::CreateContextResponse {
|
||||
context_id: context_id.to_proto(),
|
||||
context: Some(proto::Context { operations }),
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_update_context(
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::UpdateContext>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let context_id = ContextId::from_proto(envelope.payload.context_id);
|
||||
if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
|
||||
let operation_proto = envelope.payload.operation.context("invalid operation")?;
|
||||
let operation = ContextOperation::from_proto(operation_proto)?;
|
||||
context.update(cx, |context, cx| context.apply_ops([operation], cx));
|
||||
}
|
||||
Ok(())
|
||||
})?
|
||||
}
|
||||
|
||||
async fn handle_synchronize_contexts(
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::SynchronizeContexts>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<proto::SynchronizeContextsResponse> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if this.project.read(cx).is_via_collab() {
|
||||
return Err(anyhow!("only the host can synchronize contexts"));
|
||||
}
|
||||
|
||||
let mut local_versions = Vec::new();
|
||||
for remote_version_proto in envelope.payload.contexts {
|
||||
let remote_version = ContextVersion::from_proto(&remote_version_proto);
|
||||
let context_id = ContextId::from_proto(remote_version_proto.context_id);
|
||||
if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
|
||||
let context = context.read(cx);
|
||||
let operations = context.serialize_ops(&remote_version, cx);
|
||||
local_versions.push(context.version(cx).to_proto(context_id.clone()));
|
||||
let client = this.client.clone();
|
||||
let project_id = envelope.payload.project_id;
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let operations = operations.await;
|
||||
for operation in operations {
|
||||
client.send(proto::UpdateContext {
|
||||
project_id,
|
||||
context_id: context_id.to_proto(),
|
||||
operation: Some(operation),
|
||||
})?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
|
||||
this.advertise_contexts(cx);
|
||||
|
||||
anyhow::Ok(proto::SynchronizeContextsResponse {
|
||||
contexts: local_versions,
|
||||
})
|
||||
})?
|
||||
}
|
||||
|
||||
fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
|
||||
let is_shared = self.project.read(cx).is_shared();
|
||||
let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
|
||||
if is_shared == was_shared {
|
||||
return;
|
||||
}
|
||||
|
||||
if is_shared {
|
||||
self.contexts.retain_mut(|context| {
|
||||
if let Some(strong_context) = context.upgrade() {
|
||||
*context = ContextHandle::Strong(strong_context);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
let remote_id = self.project.read(cx).remote_id().unwrap();
|
||||
self.client_subscription = self
|
||||
.client
|
||||
.subscribe_to_entity(remote_id)
|
||||
.log_err()
|
||||
.map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
|
||||
self.advertise_contexts(cx);
|
||||
} else {
|
||||
self.client_subscription = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_: Model<Project>,
|
||||
event: &project::Event,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::Event::Reshared => {
|
||||
self.advertise_contexts(cx);
|
||||
}
|
||||
project::Event::HostReshared | project::Event::Rejoined => {
|
||||
self.synchronize_contexts(cx);
|
||||
}
|
||||
project::Event::DisconnectedFromHost => {
|
||||
self.contexts.retain_mut(|context| {
|
||||
if let Some(strong_context) = context.upgrade() {
|
||||
*context = ContextHandle::Weak(context.downgrade());
|
||||
strong_context.update(cx, |context, cx| {
|
||||
if context.replica_id() != ReplicaId::default() {
|
||||
context.set_capability(language::Capability::ReadOnly, cx);
|
||||
}
|
||||
});
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
self.host_contexts.clear();
|
||||
cx.notify();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
|
||||
let context = cx.new_model(|cx| {
|
||||
Context::local(
|
||||
self.languages.clone(),
|
||||
Some(self.project.clone()),
|
||||
Some(self.telemetry.clone()),
|
||||
self.prompt_builder.clone(),
|
||||
self.slash_commands.clone(),
|
||||
self.tools.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
self.register_context(&context, cx);
|
||||
context
|
||||
}
|
||||
|
||||
pub fn create_remote_context(
|
||||
&mut self,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Model<Context>>> {
|
||||
let project = self.project.read(cx);
|
||||
let Some(project_id) = project.remote_id() else {
|
||||
return Task::ready(Err(anyhow!("project was not remote")));
|
||||
};
|
||||
|
||||
let replica_id = project.replica_id();
|
||||
let capability = project.capability();
|
||||
let language_registry = self.languages.clone();
|
||||
let project = self.project.clone();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let prompt_builder = self.prompt_builder.clone();
|
||||
let slash_commands = self.slash_commands.clone();
|
||||
let tools = self.tools.clone();
|
||||
let request = self.client.request(proto::CreateContext { project_id });
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let response = request.await?;
|
||||
let context_id = ContextId::from_proto(response.context_id);
|
||||
let context_proto = response.context.context("invalid context")?;
|
||||
let context = cx.new_model(|cx| {
|
||||
Context::new(
|
||||
context_id.clone(),
|
||||
replica_id,
|
||||
capability,
|
||||
language_registry,
|
||||
prompt_builder,
|
||||
slash_commands,
|
||||
tools,
|
||||
Some(project),
|
||||
Some(telemetry),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
let operations = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
context_proto
|
||||
.operations
|
||||
.into_iter()
|
||||
.map(ContextOperation::from_proto)
|
||||
.collect::<Result<Vec<_>>>()
|
||||
})
|
||||
.await?;
|
||||
context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
|
||||
existing_context
|
||||
} else {
|
||||
this.register_context(&context, cx);
|
||||
this.synchronize_contexts(cx);
|
||||
context
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_local_context(
|
||||
&mut self,
|
||||
path: PathBuf,
|
||||
cx: &ModelContext<Self>,
|
||||
) -> Task<Result<Model<Context>>> {
|
||||
if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
|
||||
return Task::ready(Ok(existing_context));
|
||||
}
|
||||
|
||||
let fs = self.fs.clone();
|
||||
let languages = self.languages.clone();
|
||||
let project = self.project.clone();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let load = cx.background_executor().spawn({
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let saved_context = fs.load(&path).await?;
|
||||
SavedContext::from_json(&saved_context)
|
||||
}
|
||||
});
|
||||
let prompt_builder = self.prompt_builder.clone();
|
||||
let slash_commands = self.slash_commands.clone();
|
||||
let tools = self.tools.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let saved_context = load.await?;
|
||||
let context = cx.new_model(|cx| {
|
||||
Context::deserialize(
|
||||
saved_context,
|
||||
path.clone(),
|
||||
languages,
|
||||
prompt_builder,
|
||||
slash_commands,
|
||||
tools,
|
||||
Some(project),
|
||||
Some(telemetry),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
|
||||
existing_context
|
||||
} else {
|
||||
this.register_context(&context, cx);
|
||||
context
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
|
||||
self.contexts.iter().find_map(|context| {
|
||||
let context = context.upgrade()?;
|
||||
if context.read(cx).path() == Some(path) {
|
||||
Some(context)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option<Model<Context>> {
|
||||
self.contexts.iter().find_map(|context| {
|
||||
let context = context.upgrade()?;
|
||||
if context.read(cx).id() == id {
|
||||
Some(context)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_remote_context(
|
||||
&mut self,
|
||||
context_id: ContextId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Model<Context>>> {
|
||||
let project = self.project.read(cx);
|
||||
let Some(project_id) = project.remote_id() else {
|
||||
return Task::ready(Err(anyhow!("project was not remote")));
|
||||
};
|
||||
|
||||
if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
|
||||
return Task::ready(Ok(context));
|
||||
}
|
||||
|
||||
let replica_id = project.replica_id();
|
||||
let capability = project.capability();
|
||||
let language_registry = self.languages.clone();
|
||||
let project = self.project.clone();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let request = self.client.request(proto::OpenContext {
|
||||
project_id,
|
||||
context_id: context_id.to_proto(),
|
||||
});
|
||||
let prompt_builder = self.prompt_builder.clone();
|
||||
let slash_commands = self.slash_commands.clone();
|
||||
let tools = self.tools.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let response = request.await?;
|
||||
let context_proto = response.context.context("invalid context")?;
|
||||
let context = cx.new_model(|cx| {
|
||||
Context::new(
|
||||
context_id.clone(),
|
||||
replica_id,
|
||||
capability,
|
||||
language_registry,
|
||||
prompt_builder,
|
||||
slash_commands,
|
||||
tools,
|
||||
Some(project),
|
||||
Some(telemetry),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
let operations = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
context_proto
|
||||
.operations
|
||||
.into_iter()
|
||||
.map(ContextOperation::from_proto)
|
||||
.collect::<Result<Vec<_>>>()
|
||||
})
|
||||
.await?;
|
||||
context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
|
||||
existing_context
|
||||
} else {
|
||||
this.register_context(&context, cx);
|
||||
this.synchronize_contexts(cx);
|
||||
context
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
|
||||
let handle = if self.project_is_shared {
|
||||
ContextHandle::Strong(context.clone())
|
||||
} else {
|
||||
ContextHandle::Weak(context.downgrade())
|
||||
};
|
||||
self.contexts.push(handle);
|
||||
self.advertise_contexts(cx);
|
||||
cx.subscribe(context, Self::handle_context_event).detach();
|
||||
}
|
||||
|
||||
fn handle_context_event(
|
||||
&mut self,
|
||||
context: Model<Context>,
|
||||
event: &ContextEvent,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let Some(project_id) = self.project.read(cx).remote_id() else {
|
||||
return;
|
||||
};
|
||||
|
||||
match event {
|
||||
ContextEvent::SummaryChanged => {
|
||||
self.advertise_contexts(cx);
|
||||
}
|
||||
ContextEvent::Operation(operation) => {
|
||||
let context_id = context.read(cx).id().to_proto();
|
||||
let operation = operation.to_proto();
|
||||
self.client
|
||||
.send(proto::UpdateContext {
|
||||
project_id,
|
||||
context_id,
|
||||
operation: Some(operation),
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn advertise_contexts(&self, cx: &AppContext) {
|
||||
let Some(project_id) = self.project.read(cx).remote_id() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// For now, only the host can advertise their open contexts.
|
||||
if self.project.read(cx).is_via_collab() {
|
||||
return;
|
||||
}
|
||||
|
||||
let contexts = self
|
||||
.contexts
|
||||
.iter()
|
||||
.rev()
|
||||
.filter_map(|context| {
|
||||
let context = context.upgrade()?.read(cx);
|
||||
if context.replica_id() == ReplicaId::default() {
|
||||
Some(proto::ContextMetadata {
|
||||
context_id: context.id().to_proto(),
|
||||
summary: context.summary().map(|summary| summary.text.clone()),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
self.client
|
||||
.send(proto::AdvertiseContexts {
|
||||
project_id,
|
||||
contexts,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let Some(project_id) = self.project.read(cx).remote_id() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let contexts = self
|
||||
.contexts
|
||||
.iter()
|
||||
.filter_map(|context| {
|
||||
let context = context.upgrade()?.read(cx);
|
||||
if context.replica_id() != ReplicaId::default() {
|
||||
Some(context.version(cx).to_proto(context.id().clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let client = self.client.clone();
|
||||
let request = self.client.request(proto::SynchronizeContexts {
|
||||
project_id,
|
||||
contexts,
|
||||
});
|
||||
cx.spawn(|this, cx| async move {
|
||||
let response = request.await?;
|
||||
|
||||
let mut context_ids = Vec::new();
|
||||
let mut operations = Vec::new();
|
||||
this.read_with(&cx, |this, cx| {
|
||||
for context_version_proto in response.contexts {
|
||||
let context_version = ContextVersion::from_proto(&context_version_proto);
|
||||
let context_id = ContextId::from_proto(context_version_proto.context_id);
|
||||
if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
|
||||
context_ids.push(context_id);
|
||||
operations.push(context.read(cx).serialize_ops(&context_version, cx));
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
let operations = futures::future::join_all(operations).await;
|
||||
for (context_id, operations) in context_ids.into_iter().zip(operations) {
|
||||
for operation in operations {
|
||||
client.send(proto::UpdateContext {
|
||||
project_id,
|
||||
context_id: context_id.to_proto(),
|
||||
operation: Some(operation),
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
|
||||
let metadata = self.contexts_metadata.clone();
|
||||
let executor = cx.background_executor().clone();
|
||||
cx.background_executor().spawn(async move {
|
||||
if query.is_empty() {
|
||||
metadata
|
||||
} else {
|
||||
let candidates = metadata
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
|
||||
.collect::<Vec<_>>();
|
||||
let matches = fuzzy::match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
false,
|
||||
100,
|
||||
&Default::default(),
|
||||
executor,
|
||||
)
|
||||
.await;
|
||||
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|mat| metadata[mat.candidate_id].clone())
|
||||
.collect()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
|
||||
&self.host_contexts
|
||||
}
|
||||
|
||||
fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
fs.create_dir(contexts_dir()).await?;
|
||||
|
||||
let mut paths = fs.read_dir(contexts_dir()).await?;
|
||||
let mut contexts = Vec::<SavedContextMetadata>::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
// This is used to filter out contexts saved by the new assistant.
|
||||
if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(title) = ASSISTANT_CONTEXT_REGEX
|
||||
.replace(file_name, "")
|
||||
.lines()
|
||||
.next()
|
||||
{
|
||||
contexts.push(SavedContextMetadata {
|
||||
title: title.to_string(),
|
||||
path,
|
||||
mtime: metadata.mtime.timestamp_for_user().into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.contexts_metadata = contexts;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
|
||||
cx.update_model(
|
||||
&self.context_server_manager,
|
||||
|context_server_manager, cx| {
|
||||
for server in context_server_manager.servers() {
|
||||
context_server_manager
|
||||
.restart_server(&server.id(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
|
||||
cx.subscribe(
|
||||
&self.context_server_manager.clone(),
|
||||
Self::handle_context_server_event,
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn handle_context_server_event(
|
||||
&mut self,
|
||||
context_server_manager: Model<ContextServerManager>,
|
||||
event: &context_server::manager::Event,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let slash_command_working_set = self.slash_commands.clone();
|
||||
let tool_working_set = self.tools.clone();
|
||||
match event {
|
||||
context_server::manager::Event::ServerStarted { server_id } => {
|
||||
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
|
||||
let context_server_manager = context_server_manager.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
let server_id = server_id.clone();
|
||||
|this, mut cx| async move {
|
||||
let Some(protocol) = server.client() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||
let slash_command_ids = prompts
|
||||
.into_iter()
|
||||
.filter(assistant_slash_commands::acceptable_prompt)
|
||||
.map(|prompt| {
|
||||
log::info!(
|
||||
"registering context server command: {:?}",
|
||||
prompt.name
|
||||
);
|
||||
slash_command_working_set.insert(Arc::new(
|
||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||
context_server_manager.clone(),
|
||||
&server,
|
||||
prompt,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
this.update(&mut cx, |this, _cx| {
|
||||
this.context_server_slash_command_ids
|
||||
.insert(server_id.clone(), slash_command_ids);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||
let tool_ids = tools.tools.into_iter().map(|tool| {
|
||||
log::info!("registering context server tool: {:?}", tool.name);
|
||||
tool_working_set.insert(
|
||||
Arc::new(ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
)),
|
||||
)
|
||||
|
||||
}).collect::<Vec<_>>();
|
||||
|
||||
|
||||
this.update(&mut cx, |this, _cx| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(slash_command_ids) =
|
||||
self.context_server_slash_command_ids.remove(server_id)
|
||||
{
|
||||
slash_command_working_set.remove(&slash_command_ids);
|
||||
}
|
||||
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
970
crates/assistant_context_editor/src/patch.rs
Normal file
970
crates/assistant_context_editor/src/patch.rs
Normal file
|
@ -0,0 +1,970 @@
|
|||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use editor::ProposedChangesEditor;
|
||||
use futures::{future, TryFutureExt as _};
|
||||
use gpui::{AppContext, AsyncAppContext, Model, SharedString};
|
||||
use language::{AutoindentMode, Buffer, BufferSnapshot};
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{cmp, ops::Range, path::Path, sync::Arc};
|
||||
use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AssistantPatch {
|
||||
pub range: Range<language::Anchor>,
|
||||
pub title: SharedString,
|
||||
pub edits: Arc<[Result<AssistantEdit>]>,
|
||||
pub status: AssistantPatchStatus,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum AssistantPatchStatus {
|
||||
Pending,
|
||||
Ready,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct AssistantEdit {
|
||||
pub path: String,
|
||||
pub kind: AssistantEditKind,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum AssistantEditKind {
|
||||
Update {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
description: Option<String>,
|
||||
},
|
||||
Create {
|
||||
new_text: String,
|
||||
description: Option<String>,
|
||||
},
|
||||
InsertBefore {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
description: Option<String>,
|
||||
},
|
||||
InsertAfter {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
description: Option<String>,
|
||||
},
|
||||
Delete {
|
||||
old_text: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ResolvedPatch {
|
||||
pub edit_groups: HashMap<Model<Buffer>, Vec<ResolvedEditGroup>>,
|
||||
pub errors: Vec<AssistantPatchResolutionError>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ResolvedEditGroup {
|
||||
pub context_range: Range<language::Anchor>,
|
||||
pub edits: Vec<ResolvedEdit>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ResolvedEdit {
|
||||
range: Range<language::Anchor>,
|
||||
new_text: String,
|
||||
description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct AssistantPatchResolutionError {
|
||||
pub edit_ix: usize,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum SearchDirection {
|
||||
Up,
|
||||
Left,
|
||||
Diagonal,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
cost: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(cost: u32, direction: SearchDirection) -> Self {
|
||||
Self { cost, direction }
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchMatrix {
|
||||
cols: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl SearchMatrix {
|
||||
fn new(rows: usize, cols: usize) -> Self {
|
||||
SearchMatrix {
|
||||
cols,
|
||||
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
|
||||
self.data[row * self.cols + col] = cost;
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolvedPatch {
|
||||
pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut AppContext) {
|
||||
for (buffer, groups) in &self.edit_groups {
|
||||
let branch = editor.branch_buffer_for_base(buffer).unwrap();
|
||||
Self::apply_edit_groups(groups, &branch, cx);
|
||||
}
|
||||
editor.recalculate_all_buffer_diffs();
|
||||
}
|
||||
|
||||
fn apply_edit_groups(
|
||||
groups: &Vec<ResolvedEditGroup>,
|
||||
buffer: &Model<Buffer>,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let mut edits = Vec::new();
|
||||
for group in groups {
|
||||
for suggestion in &group.edits {
|
||||
edits.push((suggestion.range.clone(), suggestion.new_text.clone()));
|
||||
}
|
||||
}
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(
|
||||
edits,
|
||||
Some(AutoindentMode::Block {
|
||||
original_indent_columns: Vec::new(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolvedEdit {
|
||||
pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool {
|
||||
let range = &self.range;
|
||||
let other_range = &other.range;
|
||||
|
||||
// Don't merge if we don't contain the other suggestion.
|
||||
if range.start.cmp(&other_range.start, buffer).is_gt()
|
||||
|| range.end.cmp(&other_range.end, buffer).is_lt()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
let other_offset_range = other_range.to_offset(buffer);
|
||||
let offset_range = range.to_offset(buffer);
|
||||
|
||||
// If the other range is empty at the start of this edit's range, combine the new text
|
||||
if other_offset_range.is_empty() && other_offset_range.start == offset_range.start {
|
||||
self.new_text = format!("{}\n{}", other.new_text, self.new_text);
|
||||
self.range.start = other_range.start;
|
||||
|
||||
if let Some((description, other_description)) =
|
||||
self.description.as_mut().zip(other.description.as_ref())
|
||||
{
|
||||
*description = format!("{}\n{}", other_description, description)
|
||||
}
|
||||
} else {
|
||||
if let Some((description, other_description)) =
|
||||
self.description.as_mut().zip(other.description.as_ref())
|
||||
{
|
||||
description.push('\n');
|
||||
description.push_str(other_description);
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantEdit {
|
||||
pub fn new(
|
||||
path: Option<String>,
|
||||
operation: Option<String>,
|
||||
old_text: Option<String>,
|
||||
new_text: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let path = path.ok_or_else(|| anyhow!("missing path"))?;
|
||||
let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
|
||||
|
||||
let kind = match operation.as_str() {
|
||||
"update" => AssistantEditKind::Update {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
description,
|
||||
},
|
||||
"insert_before" => AssistantEditKind::InsertBefore {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
description,
|
||||
},
|
||||
"insert_after" => AssistantEditKind::InsertAfter {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
description,
|
||||
},
|
||||
"delete" => AssistantEditKind::Delete {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
},
|
||||
"create" => AssistantEditKind::Create {
|
||||
description,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
},
|
||||
_ => Err(anyhow!("unknown operation {operation:?}"))?,
|
||||
};
|
||||
|
||||
Ok(Self { path, kind })
|
||||
}
|
||||
|
||||
pub async fn resolve(
|
||||
&self,
|
||||
project: Model<Project>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<(Model<Buffer>, ResolvedEdit)> {
|
||||
let path = self.path.clone();
|
||||
let kind = self.kind.clone();
|
||||
let buffer = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(Path::new(&path), cx)
|
||||
.or_else(|| {
|
||||
// If we couldn't find a project path for it, put it in the active worktree
|
||||
// so that when we create the buffer, it can be saved.
|
||||
let worktree = project
|
||||
.active_entry()
|
||||
.and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
|
||||
.or_else(|| project.worktrees(cx).next())?;
|
||||
let worktree = worktree.read(cx);
|
||||
|
||||
Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: Arc::from(Path::new(&path)),
|
||||
})
|
||||
})
|
||||
.with_context(|| format!("worktree not found for {:?}", path))?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
|
||||
let suggestion = cx
|
||||
.background_executor()
|
||||
.spawn(async move { kind.resolve(&snapshot) })
|
||||
.await;
|
||||
|
||||
Ok((buffer, suggestion))
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantEditKind {
|
||||
fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit {
|
||||
match self {
|
||||
Self::Update {
|
||||
old_text,
|
||||
new_text,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
ResolvedEdit {
|
||||
range,
|
||||
new_text,
|
||||
description,
|
||||
}
|
||||
}
|
||||
Self::Create {
|
||||
new_text,
|
||||
description,
|
||||
} => ResolvedEdit {
|
||||
range: text::Anchor::MIN..text::Anchor::MAX,
|
||||
description,
|
||||
new_text,
|
||||
},
|
||||
Self::InsertBefore {
|
||||
old_text,
|
||||
mut new_text,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
new_text.push('\n');
|
||||
ResolvedEdit {
|
||||
range: range.start..range.start,
|
||||
new_text,
|
||||
description,
|
||||
}
|
||||
}
|
||||
Self::InsertAfter {
|
||||
old_text,
|
||||
mut new_text,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
new_text.insert(0, '\n');
|
||||
ResolvedEdit {
|
||||
range: range.end..range.end,
|
||||
new_text,
|
||||
description,
|
||||
}
|
||||
}
|
||||
Self::Delete { old_text } => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
ResolvedEdit {
|
||||
range,
|
||||
new_text: String::new(),
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
|
||||
const INSERTION_COST: u32 = 3;
|
||||
const DELETION_COST: u32 = 10;
|
||||
const WHITESPACE_INSERTION_COST: u32 = 1;
|
||||
const WHITESPACE_DELETION_COST: u32 = 1;
|
||||
|
||||
let buffer_len = buffer.len();
|
||||
let query_len = search_query.len();
|
||||
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
|
||||
let mut leading_deletion_cost = 0_u32;
|
||||
for (row, query_byte) in search_query.bytes().enumerate() {
|
||||
let deletion_cost = if query_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_DELETION_COST
|
||||
} else {
|
||||
DELETION_COST
|
||||
};
|
||||
|
||||
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
|
||||
matrix.set(
|
||||
row + 1,
|
||||
0,
|
||||
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
|
||||
);
|
||||
|
||||
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
|
||||
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_INSERTION_COST
|
||||
} else {
|
||||
INSERTION_COST
|
||||
};
|
||||
|
||||
let up = SearchState::new(
|
||||
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
|
||||
SearchDirection::Up,
|
||||
);
|
||||
let left = SearchState::new(
|
||||
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
|
||||
SearchDirection::Left,
|
||||
);
|
||||
let diagonal = SearchState::new(
|
||||
if query_byte == *buffer_byte {
|
||||
matrix.get(row, col).cost
|
||||
} else {
|
||||
matrix
|
||||
.get(row, col)
|
||||
.cost
|
||||
.saturating_add(deletion_cost + insertion_cost)
|
||||
},
|
||||
SearchDirection::Diagonal,
|
||||
);
|
||||
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
|
||||
}
|
||||
}
|
||||
|
||||
// Traceback to find the best match
|
||||
let mut best_buffer_end = buffer_len;
|
||||
let mut best_cost = u32::MAX;
|
||||
for col in 1..=buffer_len {
|
||||
let cost = matrix.get(query_len, col).cost;
|
||||
if cost < best_cost {
|
||||
best_cost = cost;
|
||||
best_buffer_end = col;
|
||||
}
|
||||
}
|
||||
|
||||
let mut query_ix = query_len;
|
||||
let mut buffer_ix = best_buffer_end;
|
||||
while query_ix > 0 && buffer_ix > 0 {
|
||||
let current = matrix.get(query_ix, buffer_ix);
|
||||
match current.direction {
|
||||
SearchDirection::Diagonal => {
|
||||
query_ix -= 1;
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
SearchDirection::Up => {
|
||||
query_ix -= 1;
|
||||
}
|
||||
SearchDirection::Left => {
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
|
||||
start.column = 0;
|
||||
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
|
||||
if end.column > 0 {
|
||||
end.column = buffer.line_len(end.row);
|
||||
}
|
||||
|
||||
buffer.anchor_after(start)..buffer.anchor_before(end)
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantPatch {
|
||||
pub async fn resolve(
|
||||
&self,
|
||||
project: Model<Project>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> ResolvedPatch {
|
||||
let mut resolve_tasks = Vec::new();
|
||||
for (ix, edit) in self.edits.iter().enumerate() {
|
||||
if let Ok(edit) = edit.as_ref() {
|
||||
resolve_tasks.push(
|
||||
edit.resolve(project.clone(), cx.clone())
|
||||
.map_err(move |error| (ix, error)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let edits = future::join_all(resolve_tasks).await;
|
||||
let mut errors = Vec::new();
|
||||
let mut edits_by_buffer = HashMap::default();
|
||||
for entry in edits {
|
||||
match entry {
|
||||
Ok((buffer, edit)) => {
|
||||
edits_by_buffer
|
||||
.entry(buffer)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(edit);
|
||||
}
|
||||
Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError {
|
||||
edit_ix,
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Expand the context ranges of each edit and group edits with overlapping context ranges.
|
||||
let mut edit_groups_by_buffer = HashMap::default();
|
||||
for (buffer, edits) in edits_by_buffer {
|
||||
if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) {
|
||||
edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot));
|
||||
}
|
||||
}
|
||||
|
||||
ResolvedPatch {
|
||||
edit_groups: edit_groups_by_buffer,
|
||||
errors,
|
||||
}
|
||||
}
|
||||
|
||||
fn group_edits(
|
||||
mut edits: Vec<ResolvedEdit>,
|
||||
snapshot: &text::BufferSnapshot,
|
||||
) -> Vec<ResolvedEditGroup> {
|
||||
let mut edit_groups = Vec::<ResolvedEditGroup>::new();
|
||||
// Sort edits by their range so that earlier, larger ranges come first
|
||||
edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot));
|
||||
|
||||
// Merge overlapping edits
|
||||
edits.dedup_by(|a, b| b.try_merge(a, &snapshot));
|
||||
|
||||
// Create context ranges for each edit
|
||||
for edit in edits {
|
||||
let context_range = {
|
||||
let edit_point_range = edit.range.to_point(&snapshot);
|
||||
let start_row = edit_point_range.start.row.saturating_sub(5);
|
||||
let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row);
|
||||
let start = snapshot.anchor_before(Point::new(start_row, 0));
|
||||
let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
|
||||
start..end
|
||||
};
|
||||
|
||||
if let Some(last_group) = edit_groups.last_mut() {
|
||||
if last_group
|
||||
.context_range
|
||||
.end
|
||||
.cmp(&context_range.start, &snapshot)
|
||||
.is_ge()
|
||||
{
|
||||
// Merge with the previous group if context ranges overlap
|
||||
last_group.context_range.end = context_range.end;
|
||||
last_group.edits.push(edit);
|
||||
} else {
|
||||
// Create a new group
|
||||
edit_groups.push(ResolvedEditGroup {
|
||||
context_range,
|
||||
edits: vec![edit],
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Create the first group
|
||||
edit_groups.push(ResolvedEditGroup {
|
||||
context_range,
|
||||
edits: vec![edit],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
edit_groups
|
||||
}
|
||||
|
||||
pub fn path_count(&self) -> usize {
|
||||
self.paths().count()
|
||||
}
|
||||
|
||||
pub fn paths(&self) -> impl '_ + Iterator<Item = &str> {
|
||||
let mut prev_path = None;
|
||||
self.edits.iter().filter_map(move |edit| {
|
||||
if let Ok(edit) = edit {
|
||||
let path = Some(edit.path.as_str());
|
||||
if path != prev_path {
|
||||
prev_path = path;
|
||||
return path;
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for AssistantPatch {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.range == other.range
|
||||
&& self.title == other.title
|
||||
&& Arc::ptr_eq(&self.edits, &other.edits)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for AssistantPatch {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{AppContext, Context};
|
||||
use language::{
|
||||
language_settings::AllLanguageSettings, Language, LanguageConfig, LanguageMatcher,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use ui::BorrowAppContext;
|
||||
use unindent::Unindent as _;
|
||||
use util::test::{generate_marked_text, marked_text_ranges};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_location(cx: &mut AppContext) {
|
||||
assert_location_resolution(
|
||||
concat!(
|
||||
" Lorem\n",
|
||||
"« ipsum\n",
|
||||
" dolor sit amet»\n",
|
||||
" consecteur",
|
||||
),
|
||||
"ipsum\ndolor",
|
||||
cx,
|
||||
);
|
||||
|
||||
assert_location_resolution(
|
||||
&"
|
||||
«fn foo1(a: usize) -> usize {
|
||||
40
|
||||
}»
|
||||
|
||||
fn foo2(b: usize) -> usize {
|
||||
42
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
"fn foo1(b: usize) {\n40\n}",
|
||||
cx,
|
||||
);
|
||||
|
||||
assert_location_resolution(
|
||||
&"
|
||||
fn main() {
|
||||
« Foo
|
||||
.bar()
|
||||
.baz()
|
||||
.qux()»
|
||||
}
|
||||
|
||||
fn foo2(b: usize) -> usize {
|
||||
42
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
"Foo.bar.baz.qux()",
|
||||
cx,
|
||||
);
|
||||
|
||||
assert_location_resolution(
|
||||
&"
|
||||
class Something {
|
||||
one() { return 1; }
|
||||
« two() { return 2222; }
|
||||
three() { return 333; }
|
||||
four() { return 4444; }
|
||||
five() { return 5555; }
|
||||
six() { return 6666; }
|
||||
» seven() { return 7; }
|
||||
eight() { return 8; }
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
&"
|
||||
two() { return 2222; }
|
||||
four() { return 4444; }
|
||||
five() { return 5555; }
|
||||
six() { return 6666; }
|
||||
"
|
||||
.unindent(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_edits(cx: &mut AppContext) {
|
||||
init_test(cx);
|
||||
|
||||
assert_edits(
|
||||
"
|
||||
/// A person
|
||||
struct Person {
|
||||
name: String,
|
||||
age: usize,
|
||||
}
|
||||
|
||||
/// A dog
|
||||
struct Dog {
|
||||
weight: f32,
|
||||
}
|
||||
|
||||
impl Person {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
vec![
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
name: String,
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
"
|
||||
.unindent(),
|
||||
description: None,
|
||||
},
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn name(&self) -> String {
|
||||
format!(\"{} {}\", self.first_name, self.last_name)
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
description: None,
|
||||
},
|
||||
],
|
||||
"
|
||||
/// A person
|
||||
struct Person {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
age: usize,
|
||||
}
|
||||
|
||||
/// A dog
|
||||
struct Dog {
|
||||
weight: f32,
|
||||
}
|
||||
|
||||
impl Person {
|
||||
fn name(&self) -> String {
|
||||
format!(\"{} {}\", self.first_name, self.last_name)
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
cx,
|
||||
);
|
||||
|
||||
// Ensure InsertBefore merges correctly with Update of the same text
|
||||
assert_edits(
|
||||
"
|
||||
fn foo() {
|
||||
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
vec![
|
||||
AssistantEditKind::InsertBefore {
|
||||
old_text: "
|
||||
fn foo() {"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn bar() {
|
||||
qux();
|
||||
}"
|
||||
.unindent(),
|
||||
description: Some("implement bar".into()),
|
||||
},
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
fn foo() {
|
||||
|
||||
}"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn foo() {
|
||||
bar();
|
||||
}"
|
||||
.unindent(),
|
||||
description: Some("call bar in foo".into()),
|
||||
},
|
||||
AssistantEditKind::InsertAfter {
|
||||
old_text: "
|
||||
fn foo() {
|
||||
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn qux() {
|
||||
// todo
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
description: Some("implement qux".into()),
|
||||
},
|
||||
],
|
||||
"
|
||||
fn bar() {
|
||||
qux();
|
||||
}
|
||||
|
||||
fn foo() {
|
||||
bar();
|
||||
}
|
||||
|
||||
fn qux() {
|
||||
// todo
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
cx,
|
||||
);
|
||||
|
||||
// Correctly indent new text when replacing multiple adjacent indented blocks.
|
||||
assert_edits(
|
||||
"
|
||||
impl Numbers {
|
||||
fn one() {
|
||||
1
|
||||
}
|
||||
|
||||
fn two() {
|
||||
2
|
||||
}
|
||||
|
||||
fn three() {
|
||||
3
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
vec![
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
fn one() {
|
||||
1
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn one() {
|
||||
101
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
description: None,
|
||||
},
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
fn two() {
|
||||
2
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn two() {
|
||||
102
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
description: None,
|
||||
},
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
fn three() {
|
||||
3
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn three() {
|
||||
103
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
description: None,
|
||||
},
|
||||
],
|
||||
"
|
||||
impl Numbers {
|
||||
fn one() {
|
||||
101
|
||||
}
|
||||
|
||||
fn two() {
|
||||
102
|
||||
}
|
||||
|
||||
fn three() {
|
||||
103
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
cx,
|
||||
);
|
||||
|
||||
assert_edits(
|
||||
"
|
||||
impl Person {
|
||||
fn set_name(&mut self, name: String) {
|
||||
self.name = name;
|
||||
}
|
||||
|
||||
fn name(&self) -> String {
|
||||
return self.name;
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
vec![
|
||||
AssistantEditKind::Update {
|
||||
old_text: "self.name = name;".unindent(),
|
||||
new_text: "self._name = name;".unindent(),
|
||||
description: None,
|
||||
},
|
||||
AssistantEditKind::Update {
|
||||
old_text: "return self.name;\n".unindent(),
|
||||
new_text: "return self._name;\n".unindent(),
|
||||
description: None,
|
||||
},
|
||||
],
|
||||
"
|
||||
impl Person {
|
||||
fn set_name(&mut self, name: String) {
|
||||
self._name = name;
|
||||
}
|
||||
|
||||
fn name(&self) -> String {
|
||||
return self._name;
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
cx.update_global::<SettingsStore, _>(|settings, cx| {
|
||||
settings.update_user_settings::<AllLanguageSettings>(cx, |_| {});
|
||||
});
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_location_resolution(
|
||||
text_with_expected_range: &str,
|
||||
query: &str,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let (text, _) = marked_text_ranges(text_with_expected_range, false);
|
||||
let buffer = cx.new_model(|cx| Buffer::local(text.clone(), cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let range = AssistantEditKind::resolve_location(&snapshot, query).to_offset(&snapshot);
|
||||
let text_with_actual_range = generate_marked_text(&text, &[range], false);
|
||||
pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_edits(
|
||||
old_text: String,
|
||||
edits: Vec<AssistantEditKind>,
|
||||
new_text: String,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let buffer =
|
||||
cx.new_model(|cx| Buffer::local(old_text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let resolved_edits = edits
|
||||
.into_iter()
|
||||
.map(|kind| kind.resolve(&snapshot))
|
||||
.collect();
|
||||
let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot);
|
||||
ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx);
|
||||
let actual_new_text = buffer.read(cx).text();
|
||||
pretty_assertions::assert_eq!(actual_new_text, new_text);
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(language::tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_indents_query(
|
||||
r#"
|
||||
(call_expression) @indent
|
||||
(field_expression) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
(_ "{" "}" @end) @indent
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue