Supermaven (#10788)

Adds a supermaven provider for completions. There are various other
refactors amidst this branch, primarily to make copilot no longer a
dependency of project as well as show LSP Logs for global LSPs like
copilot properly.

This feature is not enabled by default. We're going to seek to refine it
in the coming weeks.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Kyle Kelley 2024-05-03 12:50:42 -07:00 committed by GitHub
parent 610968815c
commit 6563330239
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 2242 additions and 827 deletions

View file

@ -0,0 +1,41 @@
[package]
name = "supermaven"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/supermaven.rs"
doctest = false
[dependencies]
anyhow.workspace = true
client.workspace = true
collections.workspace = true
editor.workspace = true
gpui.workspace = true
futures.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
supermaven_api.workspace = true
smol.workspace = true
ui.workspace = true
util.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }

View file

@ -0,0 +1,152 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SetApiKey {
pub api_key: String,
}
// Outbound messages
#[derive(Debug, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum OutboundMessage {
SetApiKey(SetApiKey),
StateUpdate(StateUpdateMessage),
#[allow(dead_code)]
UseFreeVersion,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StateUpdateMessage {
pub new_id: String,
pub updates: Vec<StateUpdate>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum StateUpdate {
FileUpdate(FileUpdateMessage),
CursorUpdate(CursorPositionUpdateMessage),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct FileUpdateMessage {
pub path: String,
pub content: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct CursorPositionUpdateMessage {
pub path: String,
pub offset: usize,
}
// Inbound messages coming in on stdout
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ResponseItem {
// A completion
Text { text: String },
// Vestigial message type from old versions -- safe to ignore
Del { text: String },
// Be able to delete whitespace prior to the cursor, likely for the rest of the completion
Dedent { text: String },
// When the completion is over
End,
// Got the closing parentheses and shouldn't show any more after
Barrier,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SupermavenResponse {
pub state_id: String,
pub items: Vec<ResponseItem>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SupermavenMetadataMessage {
pub dust_strings: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SupermavenTaskUpdateMessage {
pub task: String,
pub status: TaskStatus,
pub percent_complete: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
InProgress,
Complete,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SupermavenActiveRepoMessage {
pub repo_simple_name: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SupermavenPopupAction {
OpenUrl { label: String, url: String },
NoOp { label: String },
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct SupermavenPopupMessage {
pub message: String,
pub actions: Vec<SupermavenPopupAction>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "camelCase")]
pub struct ActivationRequest {
pub activate_url: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SupermavenSetMessage {
pub key: String,
pub value: serde_json::Value,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ServiceTier {
FreeNoLicense,
#[serde(other)]
Unknown,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SupermavenMessage {
Response(SupermavenResponse),
Metadata(SupermavenMetadataMessage),
Apology {
message: Option<String>,
},
ActivationRequest(ActivationRequest),
ActivationSuccess,
Passthrough {
passthrough: Box<SupermavenMessage>,
},
Popup(SupermavenPopupMessage),
TaskStatus(SupermavenTaskUpdateMessage),
ActiveRepo(SupermavenActiveRepoMessage),
ServiceTier {
service_tier: ServiceTier,
},
Set(SupermavenSetMessage),
#[serde(other)]
Unknown,
}

View file

@ -0,0 +1,345 @@
mod messages;
mod supermaven_completion_provider;
pub use supermaven_completion_provider::*;
use anyhow::{Context as _, Result};
#[allow(unused_imports)]
use client::{proto, Client};
use collections::BTreeMap;
use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset};
use messages::*;
use postage::watch;
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::{
io::AsyncWriteExt,
process::{Child, ChildStdin, ChildStdout, Command},
};
use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let supermaven = cx.new_model(|_| Supermaven::Starting);
Supermaven::set_global(supermaven.clone(), cx);
let mut provider = all_language_settings(None, cx).inline_completions.provider;
if provider == language::language_settings::InlineCompletionProvider::Supermaven {
supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
}
cx.observe_global::<SettingsStore>(move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
provider = new_provider;
if provider == language::language_settings::InlineCompletionProvider::Supermaven {
supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
} else {
supermaven.update(cx, |supermaven, _cx| supermaven.stop());
}
}
})
.detach();
}
pub enum Supermaven {
Starting,
FailedDownload { error: anyhow::Error },
Spawned(SupermavenAgent),
Error { error: anyhow::Error },
}
#[derive(Clone)]
pub enum AccountStatus {
Unknown,
NeedsActivation { activate_url: String },
Ready,
}
#[derive(Clone)]
struct SupermavenGlobal(Model<Supermaven>);
impl Global for SupermavenGlobal {}
impl Supermaven {
pub fn global(cx: &AppContext) -> Option<Model<Self>> {
cx.try_global::<SupermavenGlobal>()
.map(|model| model.0.clone())
}
pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
cx.set_global(SupermavenGlobal(supermaven));
}
pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
if let Self::Starting = self {
cx.spawn(|this, mut cx| async move {
let binary_path =
supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
this.update(&mut cx, |this, cx| {
if let Self::Starting = this {
*this =
Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
}
anyhow::Ok(())
})
})
.detach_and_log_err(cx)
}
}
pub fn stop(&mut self) {
*self = Self::Starting;
}
pub fn is_enabled(&self) -> bool {
matches!(self, Self::Spawned { .. })
}
pub fn complete(
&mut self,
buffer: &Model<Buffer>,
cursor_position: Anchor,
cx: &AppContext,
) -> Option<SupermavenCompletion> {
if let Self::Spawned(agent) = self {
let buffer_id = buffer.entity_id();
let buffer = buffer.read(cx);
let path = buffer
.file()
.and_then(|file| Some(file.as_local()?.abs_path(cx)))
.unwrap_or_else(|| PathBuf::from("untitled"))
.to_string_lossy()
.to_string();
let content = buffer.text();
let offset = cursor_position.to_offset(buffer);
let state_id = agent.next_state_id;
agent.next_state_id.0 += 1;
let (updates_tx, mut updates_rx) = watch::channel();
postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
agent.states.insert(
state_id,
SupermavenCompletionState {
buffer_id,
range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer),
completion: Vec::new(),
text: String::new(),
updates_tx,
},
);
let _ = agent
.outgoing_tx
.unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
new_id: state_id.0.to_string(),
updates: vec![
StateUpdate::FileUpdate(FileUpdateMessage {
path: path.clone(),
content,
}),
StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
],
}));
Some(SupermavenCompletion {
id: state_id,
updates: updates_rx,
})
} else {
None
}
}
pub fn completion(
&self,
id: SupermavenCompletionStateId,
) -> Option<&SupermavenCompletionState> {
if let Self::Spawned(agent) = self {
agent.states.get(&id)
} else {
None
}
}
}
pub struct SupermavenAgent {
_process: Child,
next_state_id: SupermavenCompletionStateId,
states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
_handle_outgoing_messages: Task<Result<()>>,
_handle_incoming_messages: Task<Result<()>>,
pub account_status: AccountStatus,
service_tier: Option<ServiceTier>,
#[allow(dead_code)]
client: Arc<Client>,
}
impl SupermavenAgent {
fn new(
binary_path: PathBuf,
client: Arc<Client>,
cx: &mut ModelContext<Supermaven>,
) -> Result<Self> {
let mut process = Command::new(&binary_path)
.arg("stdio")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.context("failed to start the binary")?;
let stdin = process
.stdin
.take()
.context("failed to get stdin for process")?;
let stdout = process
.stdout
.take()
.context("failed to get stdout for process")?;
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
cx.spawn({
let client = client.clone();
let outgoing_tx = outgoing_tx.clone();
move |this, mut cx| async move {
let mut status = client.status();
while let Some(status) = status.next().await {
if status.is_connected() {
let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
outgoing_tx
.unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
.ok();
this.update(&mut cx, |this, cx| {
if let Supermaven::Spawned(this) = this {
this.account_status = AccountStatus::Ready;
cx.notify();
}
})?;
break;
}
}
return anyhow::Ok(());
}
})
.detach();
Ok(Self {
_process: process,
next_state_id: SupermavenCompletionStateId::default(),
states: BTreeMap::default(),
outgoing_tx,
_handle_outgoing_messages: cx
.spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
_handle_incoming_messages: cx
.spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
account_status: AccountStatus::Unknown,
service_tier: None,
client,
})
}
async fn handle_outgoing_messages(
mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
mut stdin: ChildStdin,
) -> Result<()> {
while let Some(message) = outgoing.next().await {
let bytes = serde_json::to_vec(&message)?;
stdin.write_all(&bytes).await?;
stdin.write_all(&[b'\n']).await?;
}
Ok(())
}
async fn handle_incoming_messages(
this: WeakModel<Supermaven>,
stdout: ChildStdout,
mut cx: AsyncAppContext,
) -> Result<()> {
const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
let stdout = BufReader::new(stdout);
let mut lines = stdout.lines();
while let Some(line) = lines.next().await {
let Some(line) = line.context("failed to read line from stdout").log_err() else {
continue;
};
let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
continue;
};
let Some(message) = serde_json::from_str::<SupermavenMessage>(&line)
.with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
.log_err()
else {
continue;
};
this.update(&mut cx, |this, _cx| {
if let Supermaven::Spawned(this) = this {
this.handle_message(message);
}
Task::ready(anyhow::Ok(()))
})?
.await?;
}
Ok(())
}
fn handle_message(&mut self, message: SupermavenMessage) {
match message {
SupermavenMessage::ActivationRequest(request) => {
self.account_status = match request.activate_url {
Some(activate_url) => AccountStatus::NeedsActivation {
activate_url: activate_url.clone(),
},
None => AccountStatus::Ready,
};
}
SupermavenMessage::ServiceTier { service_tier } => {
self.service_tier = Some(service_tier);
}
SupermavenMessage::Response(response) => {
let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
if let Some(state) = self.states.get_mut(&state_id) {
for item in &response.items {
if let ResponseItem::Text { text } = item {
state.text.push_str(text);
}
}
state.completion.extend(response.items);
*state.updates_tx.borrow_mut() = ();
}
}
SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
_ => {
log::warn!("unhandled message: {:?}", message);
}
}
}
}
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
#[allow(dead_code)]
pub struct SupermavenCompletionState {
buffer_id: EntityId,
range: Range<Anchor>,
completion: Vec<ResponseItem>,
text: String,
updates_tx: watch::Sender<()>,
}
pub struct SupermavenCompletion {
pub id: SupermavenCompletionStateId,
pub updates: watch::Receiver<()>,
}

View file

@ -0,0 +1,131 @@
use crate::{Supermaven, SupermavenCompletionStateId};
use anyhow::Result;
use editor::{Direction, InlineCompletionProvider};
use futures::StreamExt as _;
use gpui::{AppContext, Model, ModelContext, Task};
use language::{
language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset,
};
use std::time::Duration;
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
pub struct SupermavenCompletionProvider {
supermaven: Model<Supermaven>,
completion_id: Option<SupermavenCompletionStateId>,
pending_refresh: Task<Result<()>>,
}
impl SupermavenCompletionProvider {
pub fn new(supermaven: Model<Supermaven>) -> Self {
Self {
supermaven,
completion_id: None,
pending_refresh: Task::ready(Ok(())),
}
}
}
impl InlineCompletionProvider for SupermavenCompletionProvider {
fn is_enabled(&self, buffer: &Model<Buffer>, cursor_position: Anchor, cx: &AppContext) -> bool {
if !self.supermaven.read(cx).is_enabled() {
return false;
}
let buffer = buffer.read(cx);
let file = buffer.file();
let language = buffer.language_at(cursor_position);
let settings = all_language_settings(file, cx);
settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()))
}
fn refresh(
&mut self,
buffer_handle: Model<Buffer>,
cursor_position: Anchor,
debounce: bool,
cx: &mut ModelContext<Self>,
) {
let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
supermaven.complete(&buffer_handle, cursor_position, cx)
}) else {
return;
};
self.pending_refresh = cx.spawn(|this, mut cx| async move {
if debounce {
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
}
while let Some(()) = completion.updates.next().await {
this.update(&mut cx, |this, cx| {
this.completion_id = Some(completion.id);
cx.notify();
})?;
}
Ok(())
});
}
fn cycle(
&mut self,
_buffer: Model<Buffer>,
_cursor_position: Anchor,
_direction: Direction,
_cx: &mut ModelContext<Self>,
) {
// todo!("cycling")
}
fn accept(&mut self, _cx: &mut ModelContext<Self>) {
self.pending_refresh = Task::ready(Ok(()));
self.completion_id = None;
}
fn discard(&mut self, _cx: &mut ModelContext<Self>) {
self.pending_refresh = Task::ready(Ok(()));
self.completion_id = None;
}
fn active_completion_text<'a>(
&'a self,
buffer: &Model<Buffer>,
cursor_position: Anchor,
cx: &'a AppContext,
) -> Option<&'a str> {
let completion_id = self.completion_id?;
let buffer = buffer.read(cx);
let cursor_offset = cursor_position.to_offset(buffer);
let completion = self.supermaven.read(cx).completion(completion_id)?;
let mut completion_range = completion.range.to_offset(buffer);
let prefix_len = common_prefix(
buffer.chars_for_range(completion_range.clone()),
completion.text.chars(),
);
completion_range.start += prefix_len;
let suffix_len = common_prefix(
buffer.reversed_chars_for_range(completion_range.clone()),
completion.text[prefix_len..].chars().rev(),
);
completion_range.end = completion_range.end.saturating_sub(suffix_len);
let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len];
if completion_range.is_empty()
&& completion_range.start == cursor_offset
&& !completion_text.trim().is_empty()
{
Some(completion_text)
} else {
None
}
}
}
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
a.zip(b)
.take_while(|(a, b)| a == b)
.map(|(a, _)| a.len_utf8())
.sum()
}