
Sets the account status so the Supermaven status button can leave the "Supermaven is initializing..." state. Signing out and changing service tiers are still needed; they will come in a separate PR.

Release Notes:

- Improved Supermaven status messages ([#12715](https://github.com/zed-industries/zed/issues/12715)).
400 lines
13 KiB
Rust
400 lines
13 KiB
Rust
mod messages;
|
|
mod supermaven_completion_provider;
|
|
|
|
pub use supermaven_completion_provider::*;
|
|
|
|
use anyhow::{Context as _, Result};
|
|
#[allow(unused_imports)]
|
|
use client::{proto, Client};
|
|
use collections::BTreeMap;
|
|
|
|
use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
|
|
use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
|
|
use language::{
|
|
language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
|
|
};
|
|
use messages::*;
|
|
use postage::watch;
|
|
use serde::{Deserialize, Serialize};
|
|
use settings::SettingsStore;
|
|
use smol::{
|
|
io::AsyncWriteExt,
|
|
process::{Child, ChildStdin, ChildStdout, Command},
|
|
};
|
|
use std::{path::PathBuf, process::Stdio, sync::Arc};
|
|
use ui::prelude::*;
|
|
use util::ResultExt;
|
|
|
|
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|
let supermaven = cx.new_model(|_| Supermaven::Starting);
|
|
Supermaven::set_global(supermaven.clone(), cx);
|
|
|
|
let mut provider = all_language_settings(None, cx).inline_completions.provider;
|
|
if provider == language::language_settings::InlineCompletionProvider::Supermaven {
|
|
supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
|
|
}
|
|
|
|
cx.observe_global::<SettingsStore>(move |cx| {
|
|
let new_provider = all_language_settings(None, cx).inline_completions.provider;
|
|
if new_provider != provider {
|
|
provider = new_provider;
|
|
if provider == language::language_settings::InlineCompletionProvider::Supermaven {
|
|
supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
|
|
} else {
|
|
supermaven.update(cx, |supermaven, _cx| supermaven.stop());
|
|
}
|
|
}
|
|
})
|
|
.detach();
|
|
}
|
|
|
|
pub enum Supermaven {
|
|
Starting,
|
|
FailedDownload { error: anyhow::Error },
|
|
Spawned(SupermavenAgent),
|
|
Error { error: anyhow::Error },
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub enum AccountStatus {
|
|
Unknown,
|
|
NeedsActivation { activate_url: String },
|
|
Ready,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct SupermavenGlobal(Model<Supermaven>);
|
|
|
|
impl Global for SupermavenGlobal {}
|
|
|
|
impl Supermaven {
|
|
pub fn global(cx: &AppContext) -> Option<Model<Self>> {
|
|
cx.try_global::<SupermavenGlobal>()
|
|
.map(|model| model.0.clone())
|
|
}
|
|
|
|
pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
|
|
cx.set_global(SupermavenGlobal(supermaven));
|
|
}
|
|
|
|
pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
|
|
if let Self::Starting = self {
|
|
cx.spawn(|this, mut cx| async move {
|
|
let binary_path =
|
|
supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
|
|
|
|
this.update(&mut cx, |this, cx| {
|
|
if let Self::Starting = this {
|
|
*this =
|
|
Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
|
|
}
|
|
anyhow::Ok(())
|
|
})
|
|
})
|
|
.detach_and_log_err(cx)
|
|
}
|
|
}
|
|
|
|
pub fn stop(&mut self) {
|
|
*self = Self::Starting;
|
|
}
|
|
|
|
pub fn is_enabled(&self) -> bool {
|
|
matches!(self, Self::Spawned { .. })
|
|
}
|
|
|
|
pub fn complete(
|
|
&mut self,
|
|
buffer: &Model<Buffer>,
|
|
cursor_position: Anchor,
|
|
cx: &AppContext,
|
|
) -> Option<SupermavenCompletion> {
|
|
if let Self::Spawned(agent) = self {
|
|
let buffer_id = buffer.entity_id();
|
|
let buffer = buffer.read(cx);
|
|
let path = buffer
|
|
.file()
|
|
.and_then(|file| Some(file.as_local()?.abs_path(cx)))
|
|
.unwrap_or_else(|| PathBuf::from("untitled"))
|
|
.to_string_lossy()
|
|
.to_string();
|
|
let content = buffer.text();
|
|
let offset = cursor_position.to_offset(buffer);
|
|
let state_id = agent.next_state_id;
|
|
agent.next_state_id.0 += 1;
|
|
|
|
let (updates_tx, mut updates_rx) = watch::channel();
|
|
postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
|
|
|
|
agent.states.insert(
|
|
state_id,
|
|
SupermavenCompletionState {
|
|
buffer_id,
|
|
prefix_anchor: cursor_position,
|
|
text: String::new(),
|
|
dedent: String::new(),
|
|
updates_tx,
|
|
},
|
|
);
|
|
let _ = agent
|
|
.outgoing_tx
|
|
.unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
|
|
new_id: state_id.0.to_string(),
|
|
updates: vec![
|
|
StateUpdate::FileUpdate(FileUpdateMessage {
|
|
path: path.clone(),
|
|
content,
|
|
}),
|
|
StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
|
|
],
|
|
}));
|
|
|
|
Some(SupermavenCompletion {
|
|
id: state_id,
|
|
updates: updates_rx,
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn completion(
|
|
&self,
|
|
buffer: &Model<Buffer>,
|
|
cursor_position: Anchor,
|
|
cx: &AppContext,
|
|
) -> Option<&str> {
|
|
if let Self::Spawned(agent) = self {
|
|
find_relevant_completion(
|
|
&agent.states,
|
|
buffer.entity_id(),
|
|
&buffer.read(cx).snapshot(),
|
|
cursor_position,
|
|
)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
fn find_relevant_completion<'a>(
|
|
states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
|
|
buffer_id: EntityId,
|
|
buffer: &BufferSnapshot,
|
|
cursor_position: Anchor,
|
|
) -> Option<&'a str> {
|
|
let mut best_completion: Option<&str> = None;
|
|
'completions: for state in states.values() {
|
|
if state.buffer_id != buffer_id {
|
|
continue;
|
|
}
|
|
let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
|
|
continue;
|
|
};
|
|
|
|
let current_cursor_offset = cursor_position.to_offset(buffer);
|
|
let original_cursor_offset = state.prefix_anchor.to_offset(buffer);
|
|
if current_cursor_offset < original_cursor_offset {
|
|
continue;
|
|
}
|
|
|
|
let text_inserted_since_completion_request =
|
|
buffer.text_for_range(original_cursor_offset..current_cursor_offset);
|
|
let mut trimmed_completion = state_completion;
|
|
for chunk in text_inserted_since_completion_request {
|
|
if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
|
|
trimmed_completion = suffix;
|
|
} else {
|
|
continue 'completions;
|
|
}
|
|
}
|
|
|
|
if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
|
|
continue;
|
|
}
|
|
|
|
best_completion = Some(trimmed_completion);
|
|
}
|
|
best_completion
|
|
}
|
|
|
|
pub struct SupermavenAgent {
|
|
_process: Child,
|
|
next_state_id: SupermavenCompletionStateId,
|
|
states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
|
|
outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
|
|
_handle_outgoing_messages: Task<Result<()>>,
|
|
_handle_incoming_messages: Task<Result<()>>,
|
|
pub account_status: AccountStatus,
|
|
service_tier: Option<ServiceTier>,
|
|
#[allow(dead_code)]
|
|
client: Arc<Client>,
|
|
}
|
|
|
|
impl SupermavenAgent {
|
|
fn new(
|
|
binary_path: PathBuf,
|
|
client: Arc<Client>,
|
|
cx: &mut ModelContext<Supermaven>,
|
|
) -> Result<Self> {
|
|
let mut process = Command::new(&binary_path)
|
|
.arg("stdio")
|
|
.stdin(Stdio::piped())
|
|
.stdout(Stdio::piped())
|
|
.stderr(Stdio::piped())
|
|
.kill_on_drop(true)
|
|
.spawn()
|
|
.context("failed to start the binary")?;
|
|
|
|
let stdin = process
|
|
.stdin
|
|
.take()
|
|
.context("failed to get stdin for process")?;
|
|
let stdout = process
|
|
.stdout
|
|
.take()
|
|
.context("failed to get stdout for process")?;
|
|
|
|
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
|
|
|
cx.spawn({
|
|
let client = client.clone();
|
|
let outgoing_tx = outgoing_tx.clone();
|
|
move |this, mut cx| async move {
|
|
let mut status = client.status();
|
|
while let Some(status) = status.next().await {
|
|
if status.is_connected() {
|
|
let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
|
|
outgoing_tx
|
|
.unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
|
|
.ok();
|
|
this.update(&mut cx, |this, cx| {
|
|
if let Supermaven::Spawned(this) = this {
|
|
this.account_status = AccountStatus::Ready;
|
|
cx.notify();
|
|
}
|
|
})?;
|
|
break;
|
|
}
|
|
}
|
|
return anyhow::Ok(());
|
|
}
|
|
})
|
|
.detach();
|
|
|
|
Ok(Self {
|
|
_process: process,
|
|
next_state_id: SupermavenCompletionStateId::default(),
|
|
states: BTreeMap::default(),
|
|
outgoing_tx,
|
|
_handle_outgoing_messages: cx
|
|
.spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
|
|
_handle_incoming_messages: cx
|
|
.spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
|
|
account_status: AccountStatus::Unknown,
|
|
service_tier: None,
|
|
client,
|
|
})
|
|
}
|
|
|
|
async fn handle_outgoing_messages(
|
|
mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
|
|
mut stdin: ChildStdin,
|
|
) -> Result<()> {
|
|
while let Some(message) = outgoing.next().await {
|
|
let bytes = serde_json::to_vec(&message)?;
|
|
stdin.write_all(&bytes).await?;
|
|
stdin.write_all(&[b'\n']).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_incoming_messages(
|
|
this: WeakModel<Supermaven>,
|
|
stdout: ChildStdout,
|
|
mut cx: AsyncAppContext,
|
|
) -> Result<()> {
|
|
const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
|
|
|
|
let stdout = BufReader::new(stdout);
|
|
let mut lines = stdout.lines();
|
|
while let Some(line) = lines.next().await {
|
|
let Some(line) = line.context("failed to read line from stdout").log_err() else {
|
|
continue;
|
|
};
|
|
let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
|
|
continue;
|
|
};
|
|
let Some(message) = serde_json::from_str::<SupermavenMessage>(&line)
|
|
.with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
|
|
.log_err()
|
|
else {
|
|
continue;
|
|
};
|
|
|
|
this.update(&mut cx, |this, _cx| {
|
|
if let Supermaven::Spawned(this) = this {
|
|
this.handle_message(message);
|
|
}
|
|
Task::ready(anyhow::Ok(()))
|
|
})?
|
|
.await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn handle_message(&mut self, message: SupermavenMessage) {
|
|
match message {
|
|
SupermavenMessage::ActivationRequest(request) => {
|
|
self.account_status = match request.activate_url {
|
|
Some(activate_url) => AccountStatus::NeedsActivation {
|
|
activate_url: activate_url.clone(),
|
|
},
|
|
None => AccountStatus::Ready,
|
|
};
|
|
}
|
|
SupermavenMessage::ActivationSuccess => {
|
|
self.account_status = AccountStatus::Ready;
|
|
}
|
|
SupermavenMessage::ServiceTier { service_tier } => {
|
|
self.account_status = AccountStatus::Ready;
|
|
self.service_tier = Some(service_tier);
|
|
}
|
|
SupermavenMessage::Response(response) => {
|
|
let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
|
|
if let Some(state) = self.states.get_mut(&state_id) {
|
|
for item in &response.items {
|
|
match item {
|
|
ResponseItem::Text { text } => state.text.push_str(text),
|
|
ResponseItem::Dedent { text } => state.dedent.push_str(text),
|
|
_ => {}
|
|
}
|
|
}
|
|
*state.updates_tx.borrow_mut() = ();
|
|
}
|
|
}
|
|
SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
|
|
_ => {
|
|
log::warn!("unhandled message: {:?}", message);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
|
|
pub struct SupermavenCompletionStateId(usize);
|
|
|
|
#[allow(dead_code)]
|
|
pub struct SupermavenCompletionState {
|
|
buffer_id: EntityId,
|
|
prefix_anchor: Anchor,
|
|
text: String,
|
|
dedent: String,
|
|
updates_tx: watch::Sender<()>,
|
|
}
|
|
|
|
pub struct SupermavenCompletion {
|
|
pub id: SupermavenCompletionStateId,
|
|
pub updates: watch::Receiver<()>,
|
|
}
|