diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 8e4f90e11d..c28c6e4587 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -215,12 +215,12 @@ impl InlineCompletionProvider for CopilotCompletionProvider { } } - fn active_completion_text( - &self, + fn active_completion_text<'a>( + &'a self, buffer: &Model, cursor_position: language::Anchor, - cx: &AppContext, - ) -> Option<&str> { + cx: &'a AppContext, + ) -> Option<&'a str> { let buffer_id = buffer.entity_id(); let buffer = buffer.read(cx); let completion = self.active_completion()?; diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index b5c6e3fbc7..6b19104868 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -4356,6 +4356,7 @@ impl Editor { text: completion.text.to_string().into(), }); self.insert_with_autoindent_mode(&completion.text.to_string(), None, cx); + self.refresh_inline_completion(true, cx); cx.notify(); true } else { diff --git a/crates/editor/src/inline_completion_provider.rs b/crates/editor/src/inline_completion_provider.rs index 2fb2cb608f..11658a2d75 100644 --- a/crates/editor/src/inline_completion_provider.rs +++ b/crates/editor/src/inline_completion_provider.rs @@ -30,7 +30,7 @@ pub trait InlineCompletionProvider: 'static + Sized { buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, - ) -> Option<&str>; + ) -> Option<&'a str>; } pub trait InlineCompletionProviderHandle { diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs index c432116357..e4ab556490 100644 --- a/crates/supermaven/src/supermaven.rs +++ b/crates/supermaven/src/supermaven.rs @@ -10,7 +10,9 @@ use collections::BTreeMap; use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel}; -use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset}; +use language::{ + language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset, +}; use messages::*; use postage::watch; use serde::{Deserialize, Serialize}; @@ -19,7 +21,7 @@ use smol::{ io::AsyncWriteExt, process::{Child, ChildStdin, ChildStdout, Command}, }; -use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc}; +use std::{path::PathBuf, process::Stdio, sync::Arc}; use ui::prelude::*; use util::ResultExt; @@ -128,9 +130,9 @@ impl Supermaven { state_id, SupermavenCompletionState { buffer_id, - range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer), - completion: Vec::new(), + prefix_anchor: cursor_position, text: String::new(), + dedent: String::new(), updates_tx, }, ); @@ -158,16 +160,64 @@ impl Supermaven { pub fn completion( &self, - id: SupermavenCompletionStateId, - ) -> Option<&SupermavenCompletionState> { + buffer: &Model, + cursor_position: Anchor, + cx: &AppContext, + ) -> Option<&str> { if let Self::Spawned(agent) = self { - agent.states.get(&id) + find_relevant_completion( + &agent.states, + buffer.entity_id(), + &buffer.read(cx).snapshot(), + cursor_position, + ) } else { None } } } +fn find_relevant_completion<'a>( + states: &'a BTreeMap, + buffer_id: EntityId, + buffer: &BufferSnapshot, + cursor_position: Anchor, +) -> Option<&'a str> { + let mut best_completion: Option<&str> = None; + 'completions: for state in states.values() { + if state.buffer_id != buffer_id { + continue; + } + let Some(state_completion) = state.text.strip_prefix(&state.dedent) else { + continue; + }; + + let current_cursor_offset = cursor_position.to_offset(buffer); + let original_cursor_offset = state.prefix_anchor.to_offset(buffer); + if current_cursor_offset < original_cursor_offset { + continue; + } + + let text_inserted_since_completion_request = + buffer.text_for_range(original_cursor_offset..current_cursor_offset); + let mut trimmed_completion = state_completion; + for chunk in text_inserted_since_completion_request { + if let Some(suffix) = trimmed_completion.strip_prefix(chunk) { + trimmed_completion = suffix; + } else { + continue 'completions; + } + } + + if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) { + continue; + } + + best_completion = Some(trimmed_completion); + } + best_completion +} + pub struct SupermavenAgent { _process: Child, next_state_id: SupermavenCompletionStateId, @@ -311,11 +361,12 @@ impl SupermavenAgent { let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap()); if let Some(state) = self.states.get_mut(&state_id) { for item in &response.items { - if let ResponseItem::Text { text } = item { - state.text.push_str(text); + match item { + ResponseItem::Text { text } => state.text.push_str(text), + ResponseItem::Dedent { text } => state.dedent.push_str(text), + _ => {} } } - state.completion.extend(response.items); *state.updates_tx.borrow_mut() = (); } } @@ -333,9 +384,9 @@ pub struct SupermavenCompletionStateId(usize); #[allow(dead_code)] pub struct SupermavenCompletionState { buffer_id: EntityId, - range: Range, - completion: Vec, + prefix_anchor: Anchor, text: String, + dedent: String, updates_tx: watch::Sender<()>, } diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index 978f4baf00..e939a7ef9c 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -3,9 +3,7 @@ use anyhow::Result; use editor::{Direction, InlineCompletionProvider}; use futures::StreamExt as _; use gpui::{AppContext, Model, ModelContext, Task}; -use language::{ - language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset, -}; +use language::{language_settings::all_language_settings, Anchor, Buffer}; use std::time::Duration; pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); @@ -92,29 +90,16 @@ impl InlineCompletionProvider for SupermavenCompletionProvider { cursor_position: Anchor, cx: &'a AppContext, ) -> Option<&'a str> { - let completion_id = self.completion_id?; - let buffer = buffer.read(cx); - let cursor_offset = cursor_position.to_offset(buffer); - let completion = self.supermaven.read(cx).completion(completion_id)?; + let completion_text = self + .supermaven + .read(cx) + .completion(buffer, cursor_position, cx)?; - let mut completion_range = completion.range.to_offset(buffer); + let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text); - let prefix_len = common_prefix( - buffer.chars_for_range(completion_range.clone()), - completion.text.chars(), - ); - completion_range.start += prefix_len; - let suffix_len = common_prefix( - buffer.reversed_chars_for_range(completion_range.clone()), - completion.text[prefix_len..].chars().rev(), - ); - completion_range.end = completion_range.end.saturating_sub(suffix_len); + let completion_text = completion_text.trim_end(); - let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len]; - if completion_range.is_empty() - && completion_range.start == cursor_offset - && !completion_text.trim().is_empty() - { + if !completion_text.trim().is_empty() { Some(completion_text) } else { None @@ -122,9 +107,24 @@ impl InlineCompletionProvider for SupermavenCompletionProvider { } } -fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { - a.zip(b) - .take_while(|(a, b)| a == b) - .map(|(a, _)| a.len_utf8()) - .sum() +fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str { + if has_leading_newline(&text) { + text + } else if let Some(i) = text.find('\n') { + &text[..i] + } else { + text + } +} + +fn has_leading_newline(text: &str) -> bool { + for c in text.chars() { + if c == '\n' { + return true; + } + if !c.is_whitespace() { + return false; + } + } + false }