diff --git a/crates/copilot2/src/copilot2.rs b/crates/copilot2/src/copilot2.rs index dfe861f1c3..834750b25d 100644 --- a/crates/copilot2/src/copilot2.rs +++ b/crates/copilot2/src/copilot2.rs @@ -5,9 +5,10 @@ use anyhow::{anyhow, Context as _, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use collections::{HashMap, HashSet}; -use futures::{channel::oneshot, future::Shared, Future, FutureExt}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; use gpui2::{ - AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Task, WeakHandle, + AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Handle, ModelContext, Task, + WeakHandle, }; use language2::{ language_settings::{all_language_settings, language_settings}, @@ -134,7 +135,7 @@ struct RunningCopilotServer { name: LanguageServerName, lsp: Arc, sign_in_status: SignInStatus, - registered_buffers: HashMap, + registered_buffers: HashMap, } #[derive(Clone, Debug)] @@ -190,23 +191,23 @@ impl RegisteredBuffer { let _ = done_tx.send((self.snapshot_version, self.snapshot.clone())); } else { let buffer = buffer.downgrade(); - let id = buffer.id(); + let id = buffer.entity_id(); let prev_pending_change = mem::replace(&mut self.pending_buffer_change, Task::ready(None)); - self.pending_buffer_change = cx.spawn_weak(|copilot, mut cx| async move { + self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move { prev_pending_change.await; - let old_version = copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| { - let server = copilot.server.as_authenticated().log_err()?; - let buffer = server.registered_buffers.get_mut(&id)?; - Some(buffer.snapshot.version.clone()) - })?; - let new_snapshot = buffer - .upgrade()? - .read_with(&cx, |buffer, _| buffer.snapshot()); + let old_version = copilot + .update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + Some(buffer.snapshot.version.clone()) + }) + .ok()??; + let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?; let content_changes = cx - .background() + .executor() .spawn({ let new_snapshot = new_snapshot.clone(); async move { @@ -232,28 +233,30 @@ impl RegisteredBuffer { }) .await; - copilot.upgrade(&cx)?.update(&mut cx, |copilot, _| { - let server = copilot.server.as_authenticated().log_err()?; - let buffer = server.registered_buffers.get_mut(&id)?; - if !content_changes.is_empty() { - buffer.snapshot_version += 1; - buffer.snapshot = new_snapshot; - server - .lsp - .notify::( - lsp2::DidChangeTextDocumentParams { - text_document: lsp2::VersionedTextDocumentIdentifier::new( - buffer.uri.clone(), - buffer.snapshot_version, - ), - content_changes, - }, - ) - .log_err(); - } - let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone())); - Some(()) - })?; + copilot + .update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + if !content_changes.is_empty() { + buffer.snapshot_version += 1; + buffer.snapshot = new_snapshot; + server + .lsp + .notify::( + lsp2::DidChangeTextDocumentParams { + text_document: lsp2::VersionedTextDocumentIdentifier::new( + buffer.uri.clone(), + buffer.snapshot_version, + ), + content_changes, + }, + ) + .log_err(); + } + let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone())); + Some(()) + }) + .ok()?; Some(()) }); @@ -274,7 +277,7 @@ pub struct Copilot { http: Arc, node_runtime: Arc, server: CopilotServer, - buffers: HashSet>, + buffers: HashSet>, server_id: LanguageServerId, _subscription: gpui2::Subscription, } @@ -311,14 +314,14 @@ impl Copilot { _subscription: cx.on_app_quit(Self::shutdown_language_server), }; this.enable_or_disable_copilot(cx); - cx.observe_global::(move |this, cx| this.enable_or_disable_copilot(cx)) + cx.observe_global::(move |this, cx| this.enable_or_disable_copilot(cx)) .detach(); this } fn shutdown_language_server( &mut self, - cx: &mut ModelContext, + _cx: &mut ModelContext, ) -> impl Future { let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) { CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })), @@ -339,10 +342,8 @@ impl Copilot { if all_language_settings(None, cx).copilot_enabled(None, None) { if matches!(self.server, CopilotServer::Disabled) { let start_task = cx - .spawn({ - move |this, cx| { - Self::start_language_server(server_id, http, node_runtime, this, cx) - } + .spawn(move |this, cx| { + Self::start_language_server(server_id, http, node_runtime, this, cx) }) .shared(); self.server = CopilotServer::Starting { task: start_task }; @@ -381,7 +382,7 @@ impl Copilot { new_server_id: LanguageServerId, http: Arc, node_runtime: Arc, - this: Handle, + this: WeakHandle, mut cx: AsyncAppContext, ) -> impl Future { async move { @@ -446,6 +447,7 @@ impl Copilot { } } }) + .ok(); } } @@ -487,7 +489,7 @@ impl Copilot { cx.notify(); } } - }); + })?; let response = lsp .request::( request::SignInConfirmParams { @@ -513,7 +515,7 @@ impl Copilot { ); Err(Arc::new(error)) } - }) + })? }) .shared(); server.sign_in_status = SignInStatus::SigningIn { @@ -525,7 +527,7 @@ impl Copilot { } }; - cx.foreground() + cx.executor() .spawn(task.map_err(|err| anyhow!("{:?}", err))) } else { // If we're downloading, wait until download is finished @@ -534,11 +536,12 @@ impl Copilot { } } + #[allow(dead_code)] // todo!() fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server { let server = server.clone(); - cx.background().spawn(async move { + cx.executor().spawn(async move { server .request::(request::SignOutParams {}) .await?; @@ -568,7 +571,7 @@ impl Copilot { cx.notify(); - cx.foreground().spawn(start_task) + cx.executor().spawn(start_task) } pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc)> { @@ -594,40 +597,42 @@ impl Copilot { return; } - registered_buffers.entry(buffer.id()).or_insert_with(|| { - let uri: lsp2::Url = uri_for_buffer(buffer, cx); - let language_id = id_for_language(buffer.read(cx).language()); - let snapshot = buffer.read(cx).snapshot(); - server - .notify::( - lsp2::DidOpenTextDocumentParams { - text_document: lsp2::TextDocumentItem { - uri: uri.clone(), - language_id: language_id.clone(), - version: 0, - text: snapshot.text(), + registered_buffers + .entry(buffer.entity_id()) + .or_insert_with(|| { + let uri: lsp2::Url = uri_for_buffer(buffer, cx); + let language_id = id_for_language(buffer.read(cx).language()); + let snapshot = buffer.read(cx).snapshot(); + server + .notify::( + lsp2::DidOpenTextDocumentParams { + text_document: lsp2::TextDocumentItem { + uri: uri.clone(), + language_id: language_id.clone(), + version: 0, + text: snapshot.text(), + }, }, - }, - ) - .log_err(); + ) + .log_err(); - RegisteredBuffer { - uri, - language_id, - snapshot, - snapshot_version: 0, - pending_buffer_change: Task::ready(Some(())), - _subscriptions: [ - cx.subscribe(buffer, |this, buffer, event, cx| { - this.handle_buffer_event(buffer, event, cx).log_err(); - }), - cx.observe_release(buffer, move |this, _buffer, _cx| { - this.buffers.remove(&weak_buffer); - this.unregister_buffer(&weak_buffer); - }), - ], - } - }); + RegisteredBuffer { + uri, + language_id, + snapshot, + snapshot_version: 0, + pending_buffer_change: Task::ready(Some(())), + _subscriptions: [ + cx.subscribe(buffer, |this, buffer, event, cx| { + this.handle_buffer_event(buffer, event, cx).log_err(); + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&weak_buffer); + this.unregister_buffer(&weak_buffer); + }), + ], + } + }); } } @@ -638,7 +643,8 @@ impl Copilot { cx: &mut ModelContext, ) -> Result<()> { if let Ok(server) = self.server.as_running() { - if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.id()) { + if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id()) + { match event { language2::Event::Edited => { let _ = registered_buffer.report_changes(&buffer, cx); @@ -694,7 +700,7 @@ impl Copilot { fn unregister_buffer(&mut self, buffer: &WeakHandle) { if let Ok(server) = self.server.as_running() { - if let Some(buffer) = server.registered_buffers.remove(&buffer.id()) { + if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) { server .lsp .notify::( @@ -746,7 +752,7 @@ impl Copilot { .request::(request::NotifyAcceptedParams { uuid: completion.uuid.clone(), }); - cx.background().spawn(async move { + cx.executor().spawn(async move { request.await?; Ok(()) }) @@ -770,7 +776,7 @@ impl Copilot { .map(|completion| completion.uuid.clone()) .collect(), }); - cx.background().spawn(async move { + cx.executor().spawn(async move { request.await?; Ok(()) }) @@ -797,7 +803,10 @@ impl Copilot { Err(error) => return Task::ready(Err(error)), }; let lsp = server.lsp.clone(); - let registered_buffer = server.registered_buffers.get_mut(&buffer.id()).unwrap(); + let registered_buffer = server + .registered_buffers + .get_mut(&buffer.entity_id()) + .unwrap(); let snapshot = registered_buffer.report_changes(buffer, cx); let buffer = buffer.read(cx); let uri = registered_buffer.uri.clone(); @@ -810,7 +819,7 @@ impl Copilot { .map(|file| file.path().to_path_buf()) .unwrap_or_default(); - cx.foreground().spawn(async move { + cx.executor().spawn(async move { let (version, snapshot) = snapshot.await?; let result = lsp .request::(request::GetCompletionsParams { @@ -867,7 +876,7 @@ impl Copilot { lsp_status: request::SignInStatus, cx: &mut ModelContext, ) { - self.buffers.retain(|buffer| buffer.is_upgradable(cx)); + self.buffers.retain(|buffer| buffer.is_upgradable()); if let Ok(server) = self.server.as_running() { match lsp_status { @@ -876,20 +885,20 @@ impl Copilot { | request::SignInStatus::AlreadySignedIn { .. } => { server.sign_in_status = SignInStatus::Authorized; for buffer in self.buffers.iter().cloned().collect::>() { - if let Some(buffer) = buffer.upgrade(cx) { + if let Some(buffer) = buffer.upgrade() { self.register_buffer(&buffer, cx); } } } request::SignInStatus::NotAuthorized { .. } => { server.sign_in_status = SignInStatus::Unauthorized; - for buffer in self.buffers.iter().copied().collect::>() { + for buffer in self.buffers.iter().cloned().collect::>() { self.unregister_buffer(&buffer); } } request::SignInStatus::NotSignedIn => { server.sign_in_status = SignInStatus::SignedOut; - for buffer in self.buffers.iter().copied().collect::>() { + for buffer in self.buffers.iter().cloned().collect::>() { self.unregister_buffer(&buffer); } } @@ -913,7 +922,7 @@ fn uri_for_buffer(buffer: &Handle, cx: &AppContext) -> lsp2::Url { if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { lsp2::Url::from_file_path(file.abs_path(cx)).unwrap() } else { - format!("buffer://{}", buffer.id()).parse().unwrap() + format!("buffer://{}", buffer.entity_id()).parse().unwrap() } } diff --git a/crates/feature_flags2/src/feature_flags2.rs b/crates/feature_flags2/src/feature_flags2.rs index 48534051e7..cc672d85ca 100644 --- a/crates/feature_flags2/src/feature_flags2.rs +++ b/crates/feature_flags2/src/feature_flags2.rs @@ -36,7 +36,8 @@ where where F: Fn(bool, &mut V, &mut ViewContext) + Send + Sync + 'static, { - self.observe_global::(move |v, feature_flags, cx| { + self.observe_global::(move |v, cx| { + let feature_flags = cx.global::(); callback(feature_flags.has_flag(::NAME), v, cx); }) } diff --git a/crates/gpui2/src/app.rs b/crates/gpui2/src/app.rs index 1eb6ff37b1..d04893b092 100644 --- a/crates/gpui2/src/app.rs +++ b/crates/gpui2/src/app.rs @@ -56,7 +56,7 @@ impl App { ); let text_system = Arc::new(TextSystem::new(platform.text_system())); - let entities = EntityMap::new(); + let mut entities = EntityMap::new(); let unit_entity = entities.insert(entities.reserve(), ()); let app_metadata = AppMetadata { os_name: platform.os_name(), @@ -190,7 +190,7 @@ pub struct AppContext { pub(crate) observers: SubscriberSet, pub(crate) event_listeners: SubscriberSet, pub(crate) release_listeners: SubscriberSet, - pub(crate) global_observers: SubscriberSet, + pub(crate) global_observers: SubscriberSet, pub(crate) quit_observers: SubscriberSet<(), QuitHandler>, pub(crate) layout_id_buffer: Vec, // We recycle this memory across layout requests. pub(crate) propagate_event: bool, @@ -427,12 +427,10 @@ impl AppContext { } fn apply_notify_global_observers_effect(&mut self, type_id: TypeId) { - self.pending_global_notifications.insert(type_id); - let global = self.globals_by_type.remove(&type_id).unwrap(); + self.pending_global_notifications.remove(&type_id); self.global_observers .clone() - .retain(&type_id, |observer| observer(global.as_ref(), self)); - self.globals_by_type.insert(type_id, global); + .retain(&type_id, |observer| observer(self)); } pub fn to_async(&self) -> AsyncAppContext { @@ -563,12 +561,12 @@ impl AppContext { pub fn observe_global( &mut self, - f: impl Fn(&G, &mut Self) + Send + Sync + 'static, + f: impl Fn(&mut Self) + Send + Sync + 'static, ) -> Subscription { self.global_observers.insert( TypeId::of::(), - Box::new(move |global, cx| { - f(global.downcast_ref::().unwrap(), cx); + Box::new(move |cx| { + f(cx); true }), ) @@ -648,7 +646,7 @@ impl Context for AppContext { ) -> Handle { self.update(|cx| { let slot = cx.entities.reserve(); - let entity = build_entity(&mut ModelContext::mutable(cx, slot.id)); + let entity = build_entity(&mut ModelContext::mutable(cx, slot.entity_id)); cx.entities.insert(slot, entity) }) } @@ -660,7 +658,10 @@ impl Context for AppContext { ) -> R { self.update(|cx| { let mut entity = cx.entities.lease(handle); - let result = update(&mut entity, &mut ModelContext::mutable(cx, handle.id)); + let result = update( + &mut entity, + &mut ModelContext::mutable(cx, handle.entity_id), + ); cx.entities.end_lease(entity); result }) diff --git a/crates/gpui2/src/app/entity_map.rs b/crates/gpui2/src/app/entity_map.rs index b62b62705d..c8886ec195 100644 --- a/crates/gpui2/src/app/entity_map.rs +++ b/crates/gpui2/src/app/entity_map.rs @@ -1,10 +1,12 @@ -use crate::Context; +use crate::{AppContext, Context}; use anyhow::{anyhow, Result}; use derive_more::{Deref, DerefMut}; use parking_lot::{RwLock, RwLockUpgradableReadGuard}; use slotmap::{SecondaryMap, SlotMap}; use std::{ any::{Any, TypeId}, + fmt::{self, Display}, + hash::{Hash, Hasher}, marker::PhantomData, mem, sync::{ @@ -15,81 +17,101 @@ use std::{ slotmap::new_key_type! { pub struct EntityId; } -pub(crate) struct EntityMap(Arc>); +impl Display for EntityId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self) + } +} -struct EntityMapState { - ref_counts: SlotMap, +pub(crate) struct EntityMap { entities: SecondaryMap>, - dropped_entities: Vec<(EntityId, Box)>, + ref_counts: Arc>, +} + +struct EntityRefCounts { + counts: SlotMap, + dropped_entity_ids: Vec, } impl EntityMap { pub fn new() -> Self { - Self(Arc::new(RwLock::new(EntityMapState { - ref_counts: SlotMap::with_key(), + Self { entities: SecondaryMap::new(), - dropped_entities: Vec::new(), - }))) + ref_counts: Arc::new(RwLock::new(EntityRefCounts { + counts: SlotMap::with_key(), + dropped_entity_ids: Vec::new(), + })), + } } /// Reserve a slot for an entity, which you can subsequently use with `insert`. pub fn reserve(&self) -> Slot { - let id = self.0.write().ref_counts.insert(1.into()); - Slot(Handle::new(id, Arc::downgrade(&self.0))) + let id = self.ref_counts.write().counts.insert(1.into()); + Slot(Handle::new(id, Arc::downgrade(&self.ref_counts))) } /// Insert an entity into a slot obtained by calling `reserve`. - pub fn insert(&self, slot: Slot, entity: T) -> Handle { + pub fn insert( + &mut self, + slot: Slot, + entity: T, + ) -> Handle { let handle = slot.0; - self.0.write().entities.insert(handle.id, Box::new(entity)); + self.entities.insert(handle.entity_id, Box::new(entity)); handle } /// Move an entity to the stack. - pub fn lease(&self, handle: &Handle) -> Lease { - let id = handle.id; + pub fn lease<'a, T: 'static + Send + Sync>(&mut self, handle: &'a Handle) -> Lease<'a, T> { let entity = Some( - self.0 - .write() - .entities - .remove(id) + self.entities + .remove(handle.entity_id) .expect("Circular entity lease. Is the entity already being updated?") .downcast::() .unwrap(), ); - Lease { id, entity } + Lease { handle, entity } } /// Return an entity after moving it to the stack. pub fn end_lease(&mut self, mut lease: Lease) { - self.0 - .write() - .entities - .insert(lease.id, lease.entity.take().unwrap()); + self.entities + .insert(lease.handle.entity_id, lease.entity.take().unwrap()); + } + + pub fn read(&self, handle: &Handle) -> &T { + self.entities[handle.entity_id].downcast_ref().unwrap() } pub fn weak_handle(&self, id: EntityId) -> WeakHandle { WeakHandle { any_handle: AnyWeakHandle { - id, + entity_id: id, entity_type: TypeId::of::(), - entity_map: Arc::downgrade(&self.0), + entity_ref_counts: Arc::downgrade(&self.ref_counts), }, entity_type: PhantomData, } } - pub fn take_dropped(&self) -> Vec<(EntityId, Box)> { - mem::take(&mut self.0.write().dropped_entities) + pub fn take_dropped(&mut self) -> Vec<(EntityId, Box)> { + let dropped_entity_ids = mem::take(&mut self.ref_counts.write().dropped_entity_ids); + dropped_entity_ids + .into_iter() + .map(|entity_id| (entity_id, self.entities.remove(entity_id).unwrap())) + .collect() } } -pub struct Lease { +pub struct Lease<'a, T: Send + Sync> { entity: Option>, - pub id: EntityId, + pub handle: &'a Handle, } -impl core::ops::Deref for Lease { +impl<'a, T> core::ops::Deref for Lease<'a, T> +where + T: Send + Sync, +{ type Target = T; fn deref(&self) -> &Self::Target { @@ -97,13 +119,19 @@ impl core::ops::Deref for Lease { } } -impl core::ops::DerefMut for Lease { +impl<'a, T> core::ops::DerefMut for Lease<'a, T> +where + T: Send + Sync, +{ fn deref_mut(&mut self) -> &mut Self::Target { self.entity.as_mut().unwrap() } } -impl Drop for Lease { +impl<'a, T> Drop for Lease<'a, T> +where + T: Send + Sync, +{ fn drop(&mut self) { if self.entity.is_some() { // We don't panic here, because other panics can cause us to drop the lease without ending it cleanly. @@ -116,25 +144,29 @@ impl Drop for Lease { pub struct Slot(Handle); pub struct AnyHandle { - pub(crate) id: EntityId, + pub(crate) entity_id: EntityId, entity_type: TypeId, - entity_map: Weak>, + entity_map: Weak>, } impl AnyHandle { - fn new(id: EntityId, entity_type: TypeId, entity_map: Weak>) -> Self { + fn new(id: EntityId, entity_type: TypeId, entity_map: Weak>) -> Self { Self { - id, + entity_id: id, entity_type, entity_map, } } + pub fn entity_id(&self) -> EntityId { + self.entity_id + } + pub fn downgrade(&self) -> AnyWeakHandle { AnyWeakHandle { - id: self.id, + entity_id: self.entity_id, entity_type: self.entity_type, - entity_map: self.entity_map.clone(), + entity_ref_counts: self.entity_map.clone(), } } @@ -158,15 +190,15 @@ impl Clone for AnyHandle { if let Some(entity_map) = self.entity_map.upgrade() { let entity_map = entity_map.read(); let count = entity_map - .ref_counts - .get(self.id) + .counts + .get(self.entity_id) .expect("detected over-release of a handle"); let prev_count = count.fetch_add(1, SeqCst); assert_ne!(prev_count, 0, "Detected over-release of a handle."); } Self { - id: self.id, + entity_id: self.entity_id, entity_type: self.entity_type, entity_map: self.entity_map.clone(), } @@ -178,20 +210,16 @@ impl Drop for AnyHandle { if let Some(entity_map) = self.entity_map.upgrade() { let entity_map = entity_map.upgradable_read(); let count = entity_map - .ref_counts - .get(self.id) + .counts + .get(self.entity_id) .expect("Detected over-release of a handle."); let prev_count = count.fetch_sub(1, SeqCst); assert_ne!(prev_count, 0, "Detected over-release of a handle."); if prev_count == 1 { // We were the last reference to this entity, so we can remove it. let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map); - let entity = entity_map - .entities - .remove(self.id) - .expect("entity was removed twice"); - entity_map.ref_counts.remove(self.id); - entity_map.dropped_entities.push((self.id, entity)); + entity_map.counts.remove(self.entity_id); + entity_map.dropped_entity_ids.push(self.entity_id); } } } @@ -215,7 +243,7 @@ pub struct Handle { } impl Handle { - fn new(id: EntityId, entity_map: Weak>) -> Self { + fn new(id: EntityId, entity_map: Weak>) -> Self { Self { any_handle: AnyHandle::new(id, TypeId::of::(), entity_map), entity_type: PhantomData, @@ -229,6 +257,10 @@ impl Handle { } } + pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T { + cx.entities.read(self) + } + /// Update the entity referenced by this handle with the given function. /// /// The update function receives a context appropriate for its environment. @@ -254,23 +286,36 @@ impl Clone for Handle { #[derive(Clone)] pub struct AnyWeakHandle { - pub(crate) id: EntityId, + pub(crate) entity_id: EntityId, entity_type: TypeId, - entity_map: Weak>, + entity_ref_counts: Weak>, } impl AnyWeakHandle { + pub fn entity_id(&self) -> EntityId { + self.entity_id + } + + pub fn is_upgradable(&self) -> bool { + let ref_count = self + .entity_ref_counts + .upgrade() + .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst))) + .unwrap_or(0); + ref_count > 0 + } + pub fn upgrade(&self) -> Option { - let entity_map = &self.entity_map.upgrade()?; + let entity_map = self.entity_ref_counts.upgrade()?; entity_map .read() - .ref_counts - .get(self.id)? + .counts + .get(self.entity_id)? .fetch_add(1, SeqCst); Some(AnyHandle { - id: self.id, + entity_id: self.entity_id, entity_type: self.entity_type, - entity_map: self.entity_map.clone(), + entity_map: self.entity_ref_counts.clone(), }) } } @@ -284,6 +329,20 @@ where } } +impl Hash for AnyWeakHandle { + fn hash(&self, state: &mut H) { + self.entity_id.hash(state); + } +} + +impl PartialEq for AnyWeakHandle { + fn eq(&self, other: &Self) -> bool { + self.entity_id == other.entity_id + } +} + +impl Eq for AnyWeakHandle {} + #[derive(Deref, DerefMut)] pub struct WeakHandle { #[deref] @@ -331,3 +390,17 @@ impl WeakHandle { ) } } + +impl Hash for WeakHandle { + fn hash(&self, state: &mut H) { + self.any_handle.hash(state); + } +} + +impl PartialEq for WeakHandle { + fn eq(&self, other: &Self) -> bool { + self.any_handle == other.any_handle + } +} + +impl Eq for WeakHandle {} diff --git a/crates/gpui2/src/app/model_context.rs b/crates/gpui2/src/app/model_context.rs index 39d459acef..35d2cc7c86 100644 --- a/crates/gpui2/src/app/model_context.rs +++ b/crates/gpui2/src/app/model_context.rs @@ -4,7 +4,7 @@ use crate::{ }; use derive_more::{Deref, DerefMut}; use futures::FutureExt; -use std::{future::Future, marker::PhantomData}; +use std::{any::TypeId, future::Future, marker::PhantomData}; #[derive(Deref, DerefMut)] pub struct ModelContext<'a, T> { @@ -36,7 +36,7 @@ impl<'a, T: Send + Sync + 'static> ModelContext<'a, T> { let this = self.handle(); let handle = handle.downgrade(); self.app.observers.insert( - handle.id, + handle.entity_id, Box::new(move |cx| { if let Some((this, handle)) = this.upgrade().zip(handle.upgrade()) { this.update(cx, |this, cx| on_notify(this, handle, cx)); @@ -59,7 +59,7 @@ impl<'a, T: Send + Sync + 'static> ModelContext<'a, T> { let this = self.handle(); let handle = handle.downgrade(); self.app.event_listeners.insert( - handle.id, + handle.entity_id, Box::new(move |event, cx| { let event = event.downcast_ref().expect("invalid event type"); if let Some((this, handle)) = this.upgrade().zip(handle.upgrade()) { @@ -85,6 +85,34 @@ impl<'a, T: Send + Sync + 'static> ModelContext<'a, T> { ) } + pub fn observe_release( + &mut self, + handle: &Handle, + on_release: impl Fn(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + Sync + 'static, + ) -> Subscription { + let this = self.handle(); + self.app.release_listeners.insert( + handle.entity_id, + Box::new(move |entity, cx| { + let entity = entity.downcast_mut().expect("invalid entity type"); + if let Some(this) = this.upgrade() { + this.update(cx, |this, cx| on_release(this, entity, cx)); + } + }), + ) + } + + pub fn observe_global( + &mut self, + f: impl Fn(&mut T, &mut ModelContext<'_, T>) + Send + Sync + 'static, + ) -> Subscription { + let handle = self.handle(); + self.global_observers.insert( + TypeId::of::(), + Box::new(move |cx| handle.update(cx, |view, cx| f(view, cx)).is_ok()), + ) + } + pub fn on_app_quit( &mut self, on_quit: impl Fn(&mut T, &mut ModelContext) -> Fut + Send + Sync + 'static, @@ -107,23 +135,6 @@ impl<'a, T: Send + Sync + 'static> ModelContext<'a, T> { ) } - pub fn observe_release( - &mut self, - handle: &Handle, - on_release: impl Fn(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + Sync + 'static, - ) -> Subscription { - let this = self.handle(); - self.app.release_listeners.insert( - handle.id, - Box::new(move |entity, cx| { - let entity = entity.downcast_mut().expect("invalid entity type"); - if let Some(this) = this.upgrade() { - this.update(cx, |this, cx| on_release(this, entity, cx)); - } - }), - ) - } - pub fn notify(&mut self) { if self.app.pending_notifications.insert(self.entity_id) { self.app.pending_effects.push_back(Effect::Notify { diff --git a/crates/gpui2/src/view.rs b/crates/gpui2/src/view.rs index 92914b7279..a8a4d650d6 100644 --- a/crates/gpui2/src/view.rs +++ b/crates/gpui2/src/view.rs @@ -58,7 +58,7 @@ impl Element for View { type ElementState = AnyElement; fn id(&self) -> Option { - Some(ElementId::View(self.state.id)) + Some(ElementId::View(self.state.entity_id)) } fn initialize( @@ -159,7 +159,7 @@ trait ViewObject: 'static + Send + Sync { impl ViewObject for View { fn entity_id(&self) -> EntityId { - self.state.id + self.state.entity_id } fn initialize(&mut self, cx: &mut WindowContext) -> AnyBox { diff --git a/crates/gpui2/src/window.rs b/crates/gpui2/src/window.rs index cfd800fbe1..6712e4cec6 100644 --- a/crates/gpui2/src/window.rs +++ b/crates/gpui2/src/window.rs @@ -1019,15 +1019,12 @@ impl<'a, 'w> WindowContext<'a, 'w> { pub fn observe_global( &mut self, - f: impl Fn(&G, &mut WindowContext<'_, '_>) + Send + Sync + 'static, + f: impl Fn(&mut WindowContext<'_, '_>) + Send + Sync + 'static, ) -> Subscription { let window_id = self.window.handle.id; self.global_observers.insert( TypeId::of::(), - Box::new(move |global, cx| { - let global = global.downcast_ref::().unwrap(); - cx.update_window(window_id, |cx| f(global, cx)).is_ok() - }), + Box::new(move |cx| cx.update_window(window_id, |cx| f(cx)).is_ok()), ) } @@ -1128,7 +1125,7 @@ impl Context for WindowContext<'_, '_> { let entity = build_entity(&mut ViewContext::mutable( &mut *self.app, &mut self.window, - slot.id, + slot.entity_id, )); self.entities.insert(slot, entity) } @@ -1141,7 +1138,7 @@ impl Context for WindowContext<'_, '_> { let mut entity = self.entities.lease(handle); let result = update( &mut *entity, - &mut ViewContext::mutable(&mut *self.app, &mut *self.window, handle.id), + &mut ViewContext::mutable(&mut *self.app, &mut *self.window, handle.entity_id), ); self.entities.end_lease(entity); result @@ -1352,7 +1349,7 @@ impl<'a, 'w, V: Send + Sync + 'static> ViewContext<'a, 'w, V> { let handle = handle.downgrade(); let window_handle = self.window.handle; self.app.observers.insert( - handle.id, + handle.entity_id, Box::new(move |cx| { cx.update_window(window_handle.id, |cx| { if let Some(handle) = handle.upgrade() { @@ -1379,7 +1376,7 @@ impl<'a, 'w, V: Send + Sync + 'static> ViewContext<'a, 'w, V> { let handle = handle.downgrade(); let window_handle = self.window.handle; self.app.event_listeners.insert( - handle.id, + handle.entity_id, Box::new(move |event, cx| { cx.update_window(window_handle.id, |cx| { if let Some(handle) = handle.upgrade() { @@ -1418,7 +1415,7 @@ impl<'a, 'w, V: Send + Sync + 'static> ViewContext<'a, 'w, V> { let this = self.handle(); let window_handle = self.window.handle; self.app.release_listeners.insert( - handle.id, + handle.entity_id, Box::new(move |entity, cx| { let entity = entity.downcast_mut().expect("invalid entity type"); // todo!("are we okay with silently swallowing the error?") @@ -1578,16 +1575,15 @@ impl<'a, 'w, V: Send + Sync + 'static> ViewContext<'a, 'w, V> { pub fn observe_global( &mut self, - f: impl Fn(&mut V, &G, &mut ViewContext<'_, '_, V>) + Send + Sync + 'static, + f: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) + Send + Sync + 'static, ) -> Subscription { let window_id = self.window.handle.id; let handle = self.handle(); self.global_observers.insert( TypeId::of::(), - Box::new(move |global, cx| { - let global = global.downcast_ref::().unwrap(); + Box::new(move |cx| { cx.update_window(window_id, |cx| { - handle.update(cx, |view, cx| f(view, global, cx)).is_ok() + handle.update(cx, |view, cx| f(view, cx)).is_ok() }) .unwrap_or(false) }),