mod auth_modal; mod request; use anyhow::{anyhow, Result}; use async_compression::futures::bufread::GzipDecoder; use auth_modal::AuthModal; use client::Client; use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task}; use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16}; use lsp::LanguageServer; use settings::Settings; use smol::{fs, io::BufReader, stream::StreamExt}; use std::{ env::consts, path::{Path, PathBuf}, sync::Arc, }; use util::{ fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt, }; use workspace::Workspace; actions!(copilot, [SignIn, SignOut, ToggleAuthStatus]); pub fn init(client: Arc, cx: &mut MutableAppContext) { let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx)); cx.set_global(copilot.clone()); cx.add_action(|workspace: &mut Workspace, _: &SignIn, cx| { let copilot = Copilot::global(cx); if copilot.read(cx).status() == Status::Authorized { return; } copilot .update(cx, |copilot, cx| copilot.sign_in(cx)) .detach_and_log_err(cx); workspace.toggle_modal(cx, |_workspace, cx| build_auth_modal(cx)); }); cx.add_action(|workspace: &mut Workspace, _: &SignOut, cx| { let copilot = Copilot::global(cx); copilot .update(cx, |copilot, cx| copilot.sign_out(cx)) .detach_and_log_err(cx); if workspace.modal::().is_some() { workspace.dismiss_modal(cx) } }); cx.add_action(|workspace: &mut Workspace, _: &ToggleAuthStatus, cx| { workspace.toggle_modal(cx, |_workspace, cx| build_auth_modal(cx)) }) } fn build_auth_modal(cx: &mut gpui::ViewContext) -> gpui::ViewHandle { let modal = cx.add_view(|cx| AuthModal::new(cx)); cx.subscribe(&modal, |workspace, _, e: &auth_modal::Event, cx| match e { auth_modal::Event::Dismiss => workspace.dismiss_modal(cx), }) .detach(); modal } enum CopilotServer { Downloading, Error(Arc), Started { server: Arc, status: SignInStatus, }, } #[derive(Clone, Debug, PartialEq, Eq)] struct PromptingUser { user_code: String, verification_uri: String, } #[derive(Clone, Debug, PartialEq, Eq)] enum SignInStatus { Authorized { user: String }, Unauthorized { user: String }, PromptingUser(PromptingUser), SignedOut, } #[derive(Debug)] pub enum Event { PromptUserDeviceFlow { user_code: String, verification_uri: String, }, } #[derive(Debug, PartialEq, Eq)] pub enum Status { Downloading, Error(Arc), SignedOut, Unauthorized, Authorized, } impl Status { fn is_authorized(&self) -> bool { matches!(self, Status::Authorized) } } #[derive(Debug)] pub struct Completion { pub position: Anchor, pub text: String, } struct Copilot { server: CopilotServer, } impl Entity for Copilot { type Event = Event; } impl Copilot { fn global(cx: &AppContext) -> ModelHandle { cx.global::>().clone() } fn start(http: Arc, cx: &mut ModelContext) -> Self { // TODO: Don't eagerly download the LSP cx.spawn(|this, mut cx| async move { let start_language_server = async { let server_path = get_lsp_binary(http).await?; let server = LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?; let server = server.initialize(Default::default()).await?; let status = server .request::(request::CheckStatusParams { local_checks_only: false, }) .await?; anyhow::Ok((server, status)) }; let server = start_language_server.await; this.update(&mut cx, |this, cx| { cx.notify(); match server { Ok((server, status)) => { this.server = CopilotServer::Started { server, status: SignInStatus::SignedOut, }; this.update_sign_in_status(status, cx); } Err(error) => { this.server = CopilotServer::Error(error.to_string().into()); } } }) }) .detach(); Self { server: CopilotServer::Downloading, } } fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { if let CopilotServer::Started { server, .. } = &self.server { let server = server.clone(); cx.spawn(|this, mut cx| async move { let sign_in = server .request::(request::SignInInitiateParams {}) .await?; if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in { this.update(&mut cx, |this, cx| { this.update_prompting_user( flow.user_code.clone(), flow.verification_uri, cx, ); }); // TODO: catch an error here and clear the corresponding user code let response = server .request::(request::SignInConfirmParams { user_code: flow.user_code, }) .await?; this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx)); } anyhow::Ok(()) }) } else { Task::ready(Err(anyhow!("copilot hasn't started yet"))) } } fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { if let CopilotServer::Started { server, .. } = &self.server { let server = server.clone(); cx.spawn(|this, mut cx| async move { server .request::(request::SignOutParams {}) .await?; this.update(&mut cx, |this, cx| { if let CopilotServer::Started { status, .. } = &mut this.server { *status = SignInStatus::SignedOut; cx.notify(); } }); anyhow::Ok(()) }) } else { Task::ready(Err(anyhow!("copilot hasn't started yet"))) } } pub fn completion( &self, buffer: &ModelHandle, position: T, cx: &mut ModelContext, ) -> Task>> where T: ToPointUtf16, { let server = match self.authenticated_server() { Ok(server) => server, Err(error) => return Task::ready(Err(error)), }; let buffer = buffer.read(cx).snapshot(); let request = server .request::(build_completion_params(&buffer, position, cx)); cx.background().spawn(async move { let result = request.await?; let completion = result .completions .into_iter() .next() .map(|completion| completion_from_lsp(completion, &buffer)); anyhow::Ok(completion) }) } pub fn completions_cycling( &self, buffer: &ModelHandle, position: T, cx: &mut ModelContext, ) -> Task>> where T: ToPointUtf16, { let server = match self.authenticated_server() { Ok(server) => server, Err(error) => return Task::ready(Err(error)), }; let buffer = buffer.read(cx).snapshot(); let request = server.request::(build_completion_params( &buffer, position, cx, )); cx.background().spawn(async move { let result = request.await?; let completions = result .completions .into_iter() .map(|completion| completion_from_lsp(completion, &buffer)) .collect(); anyhow::Ok(completions) }) } pub fn status(&self) -> Status { match &self.server { CopilotServer::Downloading => Status::Downloading, CopilotServer::Error(error) => Status::Error(error.clone()), CopilotServer::Started { status, .. } => match status { SignInStatus::Authorized { .. } => Status::Authorized, SignInStatus::Unauthorized { .. } | SignInStatus::PromptingUser { .. } => { Status::Unauthorized } SignInStatus::SignedOut => Status::SignedOut, }, } } pub fn prompting_user(&self) -> Option<&PromptingUser> { if let CopilotServer::Started { status, .. } = &self.server { if let SignInStatus::PromptingUser(prompt) = status { return Some(prompt); } } None } fn update_prompting_user( &mut self, user_code: String, verification_uri: String, cx: &mut ModelContext, ) { if let CopilotServer::Started { status, .. } = &mut self.server { *status = SignInStatus::PromptingUser(PromptingUser { user_code, verification_uri, }); cx.notify(); } } fn update_sign_in_status( &mut self, lsp_status: request::SignInStatus, cx: &mut ModelContext, ) { if let CopilotServer::Started { status, .. } = &mut self.server { *status = match lsp_status { request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => { SignInStatus::Authorized { user } } request::SignInStatus::NotAuthorized { user } => { SignInStatus::Unauthorized { user } } _ => SignInStatus::SignedOut, }; cx.notify(); } } fn authenticated_server(&self) -> Result> { match &self.server { CopilotServer::Downloading => Err(anyhow!("copilot is still downloading")), CopilotServer::Error(error) => Err(anyhow!( "copilot was not started because of an error: {}", error )), CopilotServer::Started { server, status } => { if matches!(status, SignInStatus::Authorized { .. }) { Ok(server.clone()) } else { Err(anyhow!("must sign in before using copilot")) } } } } } fn build_completion_params( buffer: &BufferSnapshot, position: T, cx: &AppContext, ) -> request::GetCompletionsParams where T: ToPointUtf16, { let position = position.to_point_utf16(&buffer); let language_name = buffer.language_at(position).map(|language| language.name()); let language_name = language_name.as_deref(); let path; let relative_path; if let Some(file) = buffer.file() { if let Some(file) = file.as_local() { path = file.abs_path(cx); } else { path = file.full_path(cx); } relative_path = file.path().to_path_buf(); } else { path = PathBuf::from("/untitled"); relative_path = PathBuf::from("untitled"); } let settings = cx.global::(); let language_id = match language_name { Some("Plain Text") => "plaintext".to_string(), Some(language_name) => language_name.to_lowercase(), None => "plaintext".to_string(), }; request::GetCompletionsParams { doc: request::GetCompletionsDocument { source: buffer.text(), tab_size: settings.tab_size(language_name).into(), indent_size: 1, insert_spaces: !settings.hard_tabs(language_name), uri: lsp::Url::from_file_path(&path).unwrap(), path: path.to_string_lossy().into(), relative_path: relative_path.to_string_lossy().into(), language_id, position: point_to_lsp(position), version: 0, }, } } fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion { let position = buffer.clip_point_utf16(point_from_lsp(completion.position), Bias::Left); Completion { position: buffer.anchor_before(position), text: completion.display_text, } } async fn get_lsp_binary(http: Arc) -> anyhow::Result { ///Check for the latest copilot language server and download it if we haven't already async fn fetch_latest(http: Arc) -> anyhow::Result { let release = latest_github_release("zed-industries/copilot", http.clone()).await?; let asset_name = format!("copilot-darwin-{}.gz", consts::ARCH); let asset = release .assets .iter() .find(|asset| asset.name == asset_name) .ok_or_else(|| anyhow!("no asset found matching {:?}", asset_name))?; fs::create_dir_all(&*paths::COPILOT_DIR).await?; let destination_path = paths::COPILOT_DIR.join(format!("copilot-{}-{}", release.name, consts::ARCH)); if fs::metadata(&destination_path).await.is_err() { let mut response = http .get(&asset.browser_download_url, Default::default(), true) .await .map_err(|err| anyhow!("error downloading release: {}", err))?; let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); let mut file = fs::File::create(&destination_path).await?; futures::io::copy(decompressed_bytes, &mut file).await?; fs::set_permissions( &destination_path, ::from_mode(0o755), ) .await?; remove_matching(&paths::COPILOT_DIR, |entry| entry != destination_path).await; } Ok(destination_path) } match fetch_latest(http).await { ok @ Result::Ok(..) => ok, e @ Err(..) => { e.log_err(); // Fetch a cached binary, if it exists (|| async move { let mut last = None; let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?; while let Some(entry) = entries.next().await { last = Some(entry?.path()); } last.ok_or_else(|| anyhow!("no cached binary")) })() .await } } } #[cfg(test)] mod tests { use super::*; use gpui::TestAppContext; use util::http; #[gpui::test] async fn test_smoke(cx: &mut TestAppContext) { Settings::test_async(cx); let http = http::client(); let copilot = cx.add_model(|cx| Copilot::start(http, cx)); smol::Timer::after(std::time::Duration::from_secs(2)).await; copilot .update(cx, |copilot, cx| copilot.sign_in(cx)) .await .unwrap(); dbg!(copilot.read_with(cx, |copilot, _| copilot.status())); let buffer = cx.add_model(|cx| language::Buffer::new(0, "fn foo() -> ", cx)); dbg!(copilot .update(cx, |copilot, cx| copilot.completion(&buffer, 12, cx)) .await .unwrap()); dbg!(copilot .update(cx, |copilot, cx| copilot .completions_cycling(&buffer, 12, cx)) .await .unwrap()); } }