Load language models in the background

Antonio Scandurra 2024-01-24 13:36:44 +01:00
parent 92b0184036
commit 22046ef9a7
4 changed files with 83 additions and 69 deletions

View file

@@ -201,8 +201,10 @@ pub struct OpenAICompletionProvider {
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
-        let model = OpenAILanguageModel::load(model_name);
+    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+        let model = executor
+            .spawn(async move { OpenAILanguageModel::load(&model_name) })
+            .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             model,

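The hunk above is the core of the change: `OpenAICompletionProvider::new` becomes `async`, takes an owned `String`, and hands the blocking `OpenAILanguageModel::load` call to GPUI's `BackgroundExecutor` instead of running it on the calling thread. A minimal sketch of the pattern, assuming only the `executor.spawn(...).await` API visible in the diff; `load_model` and `Provider` are hypothetical stand-ins for the real types:

use gpui::BackgroundExecutor;

// Hypothetical stand-in for OpenAILanguageModel::load: a blocking,
// CPU-heavy call that should not run on the main thread.
fn load_model(name: &str) -> String {
    format!("tokenizer for {name}")
}

struct Provider {
    model: String,
}

impl Provider {
    // Take an owned String so the spawned future is 'static, run the
    // blocking load on the background executor, and merely await the
    // resulting task here.
    async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
        let model = executor
            .spawn(async move { load_model(&model_name) })
            .await;
        Self { model }
    }
}

Since `spawn` returns a task that resolves to the closure's output, the caller pays an `.await` rather than a blocked UI thread.
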
View file

@@ -67,11 +67,14 @@ struct OpenAIEmbeddingUsage {
 }
 
 impl OpenAIEmbeddingProvider {
-    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        // Loading the model is expensive, so ensure this runs off the main thread.
+        let model = executor
+            .spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
+            .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAIEmbeddingProvider {

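Note the asymmetry with the completion provider above: here the model name can stay a literal inside the spawned closure because "text-embedding-ada-002" is already `&'static str`, whereas the spawned future cannot capture a borrowed `&str` from the caller's stack, which is most likely why `OpenAICompletionProvider::new` switched from `&str` to an owned `String`. A small sketch of that lifetime constraint (both functions are hypothetical):

use gpui::BackgroundExecutor;

// A string literal is &'static str, so it can move into the future as-is.
async fn literal_len(executor: BackgroundExecutor) -> usize {
    executor
        .spawn(async move { "text-embedding-ada-002".len() })
        .await
}

// A borrowed &str must first be converted to an owned String; capturing
// `name` directly would tie the future to the caller's stack frame.
async fn borrowed_len(name: &str, executor: BackgroundExecutor) -> usize {
    let name = name.to_string();
    executor.spawn(async move { name.len() }).await
}
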
View file

@@ -31,9 +31,9 @@ use fs::Fs;
 use futures::StreamExt;
 use gpui::{
     canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
-    AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter, FocusHandle,
-    FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
-    ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+    AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
+    FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
+    IntoElement, Model, ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
     StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
     View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
 };
@@ -123,6 +123,10 @@ impl AssistantPanel {
             .await
             .log_err()
             .unwrap_or_default();
 
+            // Defaulting currently to GPT4, allow for this to be set via config.
+            let completion_provider =
+                OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+                    .await;
             // TODO: deserialize state.
             let workspace_handle = workspace.clone();
@@ -156,11 +160,6 @@ impl AssistantPanel {
             });
 
             let semantic_index = SemanticIndex::global(cx);
-            // Defaulting currently to GPT4, allow for this to be set via config.
-            let completion_provider = Arc::new(OpenAICompletionProvider::new(
-                "gpt-4",
-                cx.background_executor().clone(),
-            ));
 
             let focus_handle = cx.focus_handle();
             cx.on_focus_in(&focus_handle, Self::focus_in).detach();
@@ -176,7 +175,7 @@ impl AssistantPanel {
                 zoomed: false,
                 focus_handle,
                 toolbar,
-                completion_provider,
+                completion_provider: Arc::new(completion_provider),
                 api_key_editor: None,
                 languages: workspace.app_state().languages.clone(),
                 fs: workspace.app_state().fs.clone(),
@@ -1079,9 +1078,9 @@ impl AssistantPanel {
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
-            let conversation = cx.new_model(|cx| {
-                Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
-            })?;
+            let conversation =
+                Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
+                    .await?;
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
                 // the same conversation, we don't want to open it again.
@@ -1462,21 +1461,25 @@ impl Conversation {
         }
     }
 
-    fn deserialize(
+    async fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
         language_registry: Arc<LanguageRegistry>,
-        cx: &mut ModelContext<Self>,
-    ) -> Self {
+        cx: &mut AsyncAppContext,
+    ) -> Result<Model<Self>> {
         let id = match saved_conversation.id {
             Some(id) => Some(id),
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
-            OpenAICompletionProvider::new(model.full_name(), cx.background_executor().clone()),
+            OpenAICompletionProvider::new(
+                model.full_name().into(),
+                cx.background_executor().clone(),
+            )
+            .await,
         );
-        completion_provider.retrieve_credentials(cx);
+        cx.update(|cx| completion_provider.retrieve_credentials(cx))?;
         let markdown = language_registry.language_for_name("Markdown");
         let mut message_anchors = Vec::new();
         let mut next_message_id = MessageId(0);
@@ -1499,32 +1502,34 @@ impl Conversation {
             })
             .detach_and_log_err(cx);
             buffer
-        });
+        })?;
 
-        let mut this = Self {
-            id,
-            message_anchors,
-            messages_metadata: saved_conversation.message_metadata,
-            next_message_id,
-            summary: Some(Summary {
-                text: saved_conversation.summary,
-                done: true,
-            }),
-            pending_summary: Task::ready(None),
-            completion_count: Default::default(),
-            pending_completions: Default::default(),
-            token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
-            pending_token_count: Task::ready(None),
-            model,
-            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
-            pending_save: Task::ready(Ok(())),
-            path: Some(path),
-            buffer,
-            completion_provider,
-        };
-        this.count_remaining_tokens(cx);
-        this
+        cx.new_model(|cx| {
+            let mut this = Self {
+                id,
+                message_anchors,
+                messages_metadata: saved_conversation.message_metadata,
+                next_message_id,
+                summary: Some(Summary {
+                    text: saved_conversation.summary,
+                    done: true,
+                }),
+                pending_summary: Task::ready(None),
+                completion_count: Default::default(),
+                pending_completions: Default::default(),
+                token_count: None,
+                max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
+                pending_token_count: Task::ready(None),
+                model,
+                _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+                pending_save: Task::ready(Ok(())),
+                path: Some(path),
+                buffer,
+                completion_provider,
+            };
+            this.count_remaining_tokens(cx);
+            this
+        })
     }
 
     fn handle_buffer_event(
@@ -3169,7 +3174,7 @@ mod tests {
     use super::*;
     use crate::MessageId;
    use ai::test::FakeCompletionProvider;
-    use gpui::AppContext;
+    use gpui::{AppContext, TestAppContext};
     use settings::SettingsStore;
 
     #[gpui::test]
@@ -3487,16 +3492,17 @@
     }
 
     #[gpui::test]
-    fn test_serialization(cx: &mut AppContext) {
-        let settings_store = SettingsStore::test(cx);
+    async fn test_serialization(cx: &mut TestAppContext) {
+        let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        init(cx);
+        cx.update(init);
         let registry = Arc::new(LanguageRegistry::test());
         let completion_provider = Arc::new(FakeCompletionProvider::new());
         let conversation =
             cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
-        let buffer = conversation.read(cx).buffer.clone();
-        let message_0 = conversation.read(cx).message_anchors[0].id;
+        let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
+        let message_0 =
+            conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
         let message_1 = conversation.update(cx, |conversation, cx| {
             conversation
                 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
@@ -3517,9 +3523,9 @@
                 .unwrap()
         });
         buffer.update(cx, |buffer, cx| buffer.undo(cx));
-        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
+        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
         assert_eq!(
-            messages(&conversation, cx),
+            cx.read(|cx| messages(&conversation, cx)),
             [
                 (message_0, Role::User, 0..2),
                 (message_1.id, Role::Assistant, 2..6),
@@ -3527,18 +3533,22 @@
             ]
         );
 
-        let deserialized_conversation = cx.new_model(|cx| {
-            Conversation::deserialize(
-                conversation.read(cx).serialize(cx),
-                Default::default(),
-                registry.clone(),
-                cx,
-            )
-        });
-        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
-        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
+        let deserialized_conversation = Conversation::deserialize(
+            conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
+            Default::default(),
+            registry.clone(),
+            &mut cx.to_async(),
+        )
+        .await
+        .unwrap();
+        let deserialized_buffer =
+            deserialized_conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
         assert_eq!(
-            messages(&deserialized_conversation, cx),
+            deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
+            "a\nb\nc\n"
+        );
+        assert_eq!(
+            cx.read(|cx| messages(&deserialized_conversation, cx)),
             [
                 (message_0, Role::User, 0..2),
                 (message_1.id, Role::Assistant, 2..6),

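With `Conversation::deserialize` now async, `test_serialization` migrates from a synchronous `AppContext` test to an async `#[gpui::test]` over `TestAppContext`: state is reached through `update`/`read_with` closures, and the deserializer is driven with `&mut cx.to_async()`. A minimal, hypothetical sketch of that calling convention (the toy counter model is not from the commit):

use gpui::TestAppContext;

#[gpui::test]
async fn test_read_with(cx: &mut TestAppContext) {
    // Entities live behind Model handles; an async test reads and mutates
    // them through closures instead of borrowing an AppContext directly.
    let counter = cx.new_model(|_| 41usize);
    counter.update(cx, |count, _| *count += 1);
    let value = counter.read_with(cx, |count, _| *count);
    assert_eq!(value, 42);
}
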
View file

@@ -90,13 +90,12 @@ pub fn init(
         .detach();
 
     cx.spawn(move |cx| async move {
+        let embedding_provider =
+            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddingProvider::new(
-                http_client,
-                cx.background_executor().clone(),
-            )),
+            Arc::new(embedding_provider),
             language_registry,
             cx.clone(),
         )
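The provider is now constructed inside the future that `init` already spawns, so `init` keeps returning immediately: the `.await` on the async constructor suspends only the spawned task, and the expensive load itself still runs on the background executor inside `OpenAIEmbeddingProvider::new`. A sketch of that shape, assuming the `cx.spawn(...).detach()` pattern shown above; `make_provider` is a hypothetical stand-in:

use gpui::{AppContext, BackgroundExecutor};

// Hypothetical async constructor in the style of OpenAIEmbeddingProvider::new:
// the blocking work is delegated to the background executor.
async fn make_provider(executor: BackgroundExecutor) -> String {
    executor
        .spawn(async move { "embedding provider".to_string() })
        .await
}

fn init(cx: &mut AppContext) {
    cx.spawn(|cx| async move {
        // Awaiting here suspends only this spawned task; init has already
        // returned to its caller.
        let _provider = make_provider(cx.background_executor().clone()).await;
        // ...hand `_provider` to the long-lived service here...
    })
    .detach();
}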