Load language models in the background
This commit is contained in:
parent
92b0184036
commit
22046ef9a7
4 changed files with 83 additions and 69 deletions
|
@ -201,8 +201,10 @@ pub struct OpenAICompletionProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAICompletionProvider {
|
impl OpenAICompletionProvider {
|
||||||
pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
|
pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
|
||||||
let model = OpenAILanguageModel::load(model_name);
|
let model = executor
|
||||||
|
.spawn(async move { OpenAILanguageModel::load(&model_name) })
|
||||||
|
.await;
|
||||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
|
|
|
@ -67,11 +67,14 @@ struct OpenAIEmbeddingUsage {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIEmbeddingProvider {
|
impl OpenAIEmbeddingProvider {
|
||||||
pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
|
pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
|
||||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||||
|
|
||||||
let model = OpenAILanguageModel::load("text-embedding-ada-002");
|
// Loading the model is expensive, so ensure this runs off the main thread.
|
||||||
|
let model = executor
|
||||||
|
.spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
|
||||||
|
.await;
|
||||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
|
|
||||||
OpenAIEmbeddingProvider {
|
OpenAIEmbeddingProvider {
|
||||||
|
|
|
@ -31,9 +31,9 @@ use fs::Fs;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use gpui::{
|
use gpui::{
|
||||||
canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
|
canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
|
||||||
AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter, FocusHandle,
|
AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
|
||||||
FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
|
FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
|
||||||
ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
|
IntoElement, Model, ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
|
||||||
StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
|
StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
|
||||||
View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
|
View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
|
||||||
};
|
};
|
||||||
|
@ -123,6 +123,10 @@ impl AssistantPanel {
|
||||||
.await
|
.await
|
||||||
.log_err()
|
.log_err()
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
// Defaulting currently to GPT4, allow for this to be set via config.
|
||||||
|
let completion_provider =
|
||||||
|
OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
|
||||||
|
.await;
|
||||||
|
|
||||||
// TODO: deserialize state.
|
// TODO: deserialize state.
|
||||||
let workspace_handle = workspace.clone();
|
let workspace_handle = workspace.clone();
|
||||||
|
@ -156,11 +160,6 @@ impl AssistantPanel {
|
||||||
});
|
});
|
||||||
|
|
||||||
let semantic_index = SemanticIndex::global(cx);
|
let semantic_index = SemanticIndex::global(cx);
|
||||||
// Defaulting currently to GPT4, allow for this to be set via config.
|
|
||||||
let completion_provider = Arc::new(OpenAICompletionProvider::new(
|
|
||||||
"gpt-4",
|
|
||||||
cx.background_executor().clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
let focus_handle = cx.focus_handle();
|
let focus_handle = cx.focus_handle();
|
||||||
cx.on_focus_in(&focus_handle, Self::focus_in).detach();
|
cx.on_focus_in(&focus_handle, Self::focus_in).detach();
|
||||||
|
@ -176,7 +175,7 @@ impl AssistantPanel {
|
||||||
zoomed: false,
|
zoomed: false,
|
||||||
focus_handle,
|
focus_handle,
|
||||||
toolbar,
|
toolbar,
|
||||||
completion_provider,
|
completion_provider: Arc::new(completion_provider),
|
||||||
api_key_editor: None,
|
api_key_editor: None,
|
||||||
languages: workspace.app_state().languages.clone(),
|
languages: workspace.app_state().languages.clone(),
|
||||||
fs: workspace.app_state().fs.clone(),
|
fs: workspace.app_state().fs.clone(),
|
||||||
|
@ -1079,9 +1078,9 @@ impl AssistantPanel {
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
let saved_conversation = fs.load(&path).await?;
|
let saved_conversation = fs.load(&path).await?;
|
||||||
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation =
|
||||||
Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
|
Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
|
||||||
})?;
|
.await?;
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
// If, by the time we've loaded the conversation, the user has already opened
|
// If, by the time we've loaded the conversation, the user has already opened
|
||||||
// the same conversation, we don't want to open it again.
|
// the same conversation, we don't want to open it again.
|
||||||
|
@ -1462,21 +1461,25 @@ impl Conversation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn deserialize(
|
async fn deserialize(
|
||||||
saved_conversation: SavedConversation,
|
saved_conversation: SavedConversation,
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut AsyncAppContext,
|
||||||
) -> Self {
|
) -> Result<Model<Self>> {
|
||||||
let id = match saved_conversation.id {
|
let id = match saved_conversation.id {
|
||||||
Some(id) => Some(id),
|
Some(id) => Some(id),
|
||||||
None => Some(Uuid::new_v4().to_string()),
|
None => Some(Uuid::new_v4().to_string()),
|
||||||
};
|
};
|
||||||
let model = saved_conversation.model;
|
let model = saved_conversation.model;
|
||||||
let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
|
let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
|
||||||
OpenAICompletionProvider::new(model.full_name(), cx.background_executor().clone()),
|
OpenAICompletionProvider::new(
|
||||||
|
model.full_name().into(),
|
||||||
|
cx.background_executor().clone(),
|
||||||
|
)
|
||||||
|
.await,
|
||||||
);
|
);
|
||||||
completion_provider.retrieve_credentials(cx);
|
cx.update(|cx| completion_provider.retrieve_credentials(cx))?;
|
||||||
let markdown = language_registry.language_for_name("Markdown");
|
let markdown = language_registry.language_for_name("Markdown");
|
||||||
let mut message_anchors = Vec::new();
|
let mut message_anchors = Vec::new();
|
||||||
let mut next_message_id = MessageId(0);
|
let mut next_message_id = MessageId(0);
|
||||||
|
@ -1499,8 +1502,9 @@ impl Conversation {
|
||||||
})
|
})
|
||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
buffer
|
buffer
|
||||||
});
|
})?;
|
||||||
|
|
||||||
|
cx.new_model(|cx| {
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
id,
|
id,
|
||||||
message_anchors,
|
message_anchors,
|
||||||
|
@ -1525,6 +1529,7 @@ impl Conversation {
|
||||||
};
|
};
|
||||||
this.count_remaining_tokens(cx);
|
this.count_remaining_tokens(cx);
|
||||||
this
|
this
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_buffer_event(
|
fn handle_buffer_event(
|
||||||
|
@ -3169,7 +3174,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::MessageId;
|
use crate::MessageId;
|
||||||
use ai::test::FakeCompletionProvider;
|
use ai::test::FakeCompletionProvider;
|
||||||
use gpui::AppContext;
|
use gpui::{AppContext, TestAppContext};
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
|
@ -3487,16 +3492,17 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_serialization(cx: &mut AppContext) {
|
async fn test_serialization(cx: &mut TestAppContext) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
init(cx);
|
cx.update(init);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let completion_provider = Arc::new(FakeCompletionProvider::new());
|
let completion_provider = Arc::new(FakeCompletionProvider::new());
|
||||||
let conversation =
|
let conversation =
|
||||||
cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
|
cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
|
||||||
let message_0 = conversation.read(cx).message_anchors[0].id;
|
let message_0 =
|
||||||
|
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
|
||||||
let message_1 = conversation.update(cx, |conversation, cx| {
|
let message_1 = conversation.update(cx, |conversation, cx| {
|
||||||
conversation
|
conversation
|
||||||
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
|
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
|
||||||
|
@ -3517,9 +3523,9 @@ mod tests {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
||||||
assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
|
assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
messages(&conversation, cx),
|
cx.read(|cx| messages(&conversation, cx)),
|
||||||
[
|
[
|
||||||
(message_0, Role::User, 0..2),
|
(message_0, Role::User, 0..2),
|
||||||
(message_1.id, Role::Assistant, 2..6),
|
(message_1.id, Role::Assistant, 2..6),
|
||||||
|
@ -3527,18 +3533,22 @@ mod tests {
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let deserialized_conversation = cx.new_model(|cx| {
|
let deserialized_conversation = Conversation::deserialize(
|
||||||
Conversation::deserialize(
|
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
|
||||||
conversation.read(cx).serialize(cx),
|
|
||||||
Default::default(),
|
Default::default(),
|
||||||
registry.clone(),
|
registry.clone(),
|
||||||
cx,
|
&mut cx.to_async(),
|
||||||
)
|
)
|
||||||
});
|
.await
|
||||||
let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
|
.unwrap();
|
||||||
assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
|
let deserialized_buffer =
|
||||||
|
deserialized_conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
messages(&deserialized_conversation, cx),
|
deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
|
||||||
|
"a\nb\nc\n"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
cx.read(|cx| messages(&deserialized_conversation, cx)),
|
||||||
[
|
[
|
||||||
(message_0, Role::User, 0..2),
|
(message_0, Role::User, 0..2),
|
||||||
(message_1.id, Role::Assistant, 2..6),
|
(message_1.id, Role::Assistant, 2..6),
|
||||||
|
|
|
@ -90,13 +90,12 @@ pub fn init(
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
cx.spawn(move |cx| async move {
|
cx.spawn(move |cx| async move {
|
||||||
|
let embedding_provider =
|
||||||
|
OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
|
||||||
let semantic_index = SemanticIndex::new(
|
let semantic_index = SemanticIndex::new(
|
||||||
fs,
|
fs,
|
||||||
db_file_path,
|
db_file_path,
|
||||||
Arc::new(OpenAIEmbeddingProvider::new(
|
Arc::new(embedding_provider),
|
||||||
http_client,
|
|
||||||
cx.background_executor().clone(),
|
|
||||||
)),
|
|
||||||
language_registry,
|
language_registry,
|
||||||
cx.clone(),
|
cx.clone(),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue