replace api_key with ProviderCredential throughout the AssistantPanel
This commit is contained in:
parent
558f54c424
commit
1e8b23d8fb
5 changed files with 208 additions and 121 deletions
|
@ -9,6 +9,8 @@ pub enum ProviderCredential {
|
||||||
|
|
||||||
pub trait CredentialProvider: Send + Sync {
|
pub trait CredentialProvider: Send + Sync {
|
||||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
|
||||||
|
fn delete_credentials(&self, cx: &AppContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider {
|
||||||
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||||
ProviderCredential::NotNeeded
|
ProviderCredential::NotNeeded
|
||||||
}
|
}
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
|
||||||
|
fn delete_credentials(&self, cx: &AppContext) {}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,12 @@ pub trait CompletionProvider {
|
||||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||||
self.credential_provider().retrieve_credentials(cx)
|
self.credential_provider().retrieve_credentials(cx)
|
||||||
}
|
}
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||||
|
self.credential_provider().save_credentials(cx, credential);
|
||||||
|
}
|
||||||
|
fn delete_credentials(&self, cx: &AppContext) {
|
||||||
|
self.credential_provider().delete_credentials(cx);
|
||||||
|
}
|
||||||
fn complete(
|
fn complete(
|
||||||
&self,
|
&self,
|
||||||
prompt: Box<dyn CompletionRequest>,
|
prompt: Box<dyn CompletionRequest>,
|
||||||
|
|
|
@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider {
|
||||||
ProviderCredential::NoCredentials
|
ProviderCredential::NoCredentials
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||||
|
match credential {
|
||||||
|
ProviderCredential::Credentials { api_key } => {
|
||||||
|
cx.platform()
|
||||||
|
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn delete_credentials(&self, cx: &AppContext) {
|
||||||
|
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::CredentialProvider,
|
auth::{CredentialProvider, ProviderCredential},
|
||||||
completion::{CompletionProvider, CompletionRequest},
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
models::LanguageModel,
|
models::LanguageModel,
|
||||||
};
|
};
|
||||||
|
@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
api_key: String,
|
credential: ProviderCredential,
|
||||||
executor: Arc<Background>,
|
executor: Arc<Background>,
|
||||||
request: Box<dyn CompletionRequest>,
|
request: Box<dyn CompletionRequest>,
|
||||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||||
|
let api_key = match credential {
|
||||||
|
ProviderCredential::Credentials { api_key } => api_key,
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow!("no credentials provider for completion"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||||
|
|
||||||
let json_data = request.data()?;
|
let json_data = request.data()?;
|
||||||
|
@ -188,18 +195,22 @@ pub async fn stream_completion(
|
||||||
pub struct OpenAICompletionProvider {
|
pub struct OpenAICompletionProvider {
|
||||||
model: OpenAILanguageModel,
|
model: OpenAILanguageModel,
|
||||||
credential_provider: OpenAICredentialProvider,
|
credential_provider: OpenAICredentialProvider,
|
||||||
api_key: String,
|
credential: ProviderCredential,
|
||||||
executor: Arc<Background>,
|
executor: Arc<Background>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAICompletionProvider {
|
impl OpenAICompletionProvider {
|
||||||
pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
|
pub fn new(
|
||||||
|
model_name: &str,
|
||||||
|
credential: ProviderCredential,
|
||||||
|
executor: Arc<Background>,
|
||||||
|
) -> Self {
|
||||||
let model = OpenAILanguageModel::load(model_name);
|
let model = OpenAILanguageModel::load(model_name);
|
||||||
let credential_provider = OpenAICredentialProvider {};
|
let credential_provider = OpenAICredentialProvider {};
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
credential_provider,
|
credential_provider,
|
||||||
api_key,
|
credential,
|
||||||
executor,
|
executor,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider {
|
||||||
&self,
|
&self,
|
||||||
prompt: Box<dyn CompletionRequest>,
|
prompt: Box<dyn CompletionRequest>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
let credential = self.credential.clone();
|
||||||
|
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||||
async move {
|
async move {
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
let stream = response
|
let stream = response
|
||||||
|
|
|
@ -7,7 +7,8 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use ai::{
|
use ai::{
|
||||||
completion::CompletionRequest,
|
auth::ProviderCredential,
|
||||||
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
providers::open_ai::{
|
providers::open_ai::{
|
||||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||||
},
|
},
|
||||||
|
@ -100,8 +101,8 @@ pub fn init(cx: &mut AppContext) {
|
||||||
cx.capture_action(ConversationEditor::copy);
|
cx.capture_action(ConversationEditor::copy);
|
||||||
cx.add_action(ConversationEditor::split);
|
cx.add_action(ConversationEditor::split);
|
||||||
cx.capture_action(ConversationEditor::cycle_message_role);
|
cx.capture_action(ConversationEditor::cycle_message_role);
|
||||||
cx.add_action(AssistantPanel::save_api_key);
|
cx.add_action(AssistantPanel::save_credentials);
|
||||||
cx.add_action(AssistantPanel::reset_api_key);
|
cx.add_action(AssistantPanel::reset_credentials);
|
||||||
cx.add_action(AssistantPanel::toggle_zoom);
|
cx.add_action(AssistantPanel::toggle_zoom);
|
||||||
cx.add_action(AssistantPanel::deploy);
|
cx.add_action(AssistantPanel::deploy);
|
||||||
cx.add_action(AssistantPanel::select_next_match);
|
cx.add_action(AssistantPanel::select_next_match);
|
||||||
|
@ -143,7 +144,8 @@ pub struct AssistantPanel {
|
||||||
zoomed: bool,
|
zoomed: bool,
|
||||||
has_focus: bool,
|
has_focus: bool,
|
||||||
toolbar: ViewHandle<Toolbar>,
|
toolbar: ViewHandle<Toolbar>,
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
credential: Rc<RefCell<ProviderCredential>>,
|
||||||
|
completion_provider: Box<dyn CompletionProvider>,
|
||||||
api_key_editor: Option<ViewHandle<Editor>>,
|
api_key_editor: Option<ViewHandle<Editor>>,
|
||||||
has_read_credentials: bool,
|
has_read_credentials: bool,
|
||||||
languages: Arc<LanguageRegistry>,
|
languages: Arc<LanguageRegistry>,
|
||||||
|
@ -205,6 +207,12 @@ impl AssistantPanel {
|
||||||
});
|
});
|
||||||
|
|
||||||
let semantic_index = SemanticIndex::global(cx);
|
let semantic_index = SemanticIndex::global(cx);
|
||||||
|
// Defaulting currently to GPT4, allow for this to be set via config.
|
||||||
|
let completion_provider = Box::new(OpenAICompletionProvider::new(
|
||||||
|
"gpt-4",
|
||||||
|
ProviderCredential::NoCredentials,
|
||||||
|
cx.background().clone(),
|
||||||
|
));
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
workspace: workspace_handle,
|
workspace: workspace_handle,
|
||||||
|
@ -216,7 +224,8 @@ impl AssistantPanel {
|
||||||
zoomed: false,
|
zoomed: false,
|
||||||
has_focus: false,
|
has_focus: false,
|
||||||
toolbar,
|
toolbar,
|
||||||
api_key: Rc::new(RefCell::new(None)),
|
credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)),
|
||||||
|
completion_provider,
|
||||||
api_key_editor: None,
|
api_key_editor: None,
|
||||||
has_read_credentials: false,
|
has_read_credentials: false,
|
||||||
languages: workspace.app_state().languages.clone(),
|
languages: workspace.app_state().languages.clone(),
|
||||||
|
@ -257,10 +266,7 @@ impl AssistantPanel {
|
||||||
cx: &mut ViewContext<Workspace>,
|
cx: &mut ViewContext<Workspace>,
|
||||||
) {
|
) {
|
||||||
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
|
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
|
||||||
if this
|
if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) {
|
||||||
.update(cx, |assistant, cx| assistant.load_api_key(cx))
|
|
||||||
.is_some()
|
|
||||||
{
|
|
||||||
this
|
this
|
||||||
} else {
|
} else {
|
||||||
workspace.focus_panel::<AssistantPanel>(cx);
|
workspace.focus_panel::<AssistantPanel>(cx);
|
||||||
|
@ -292,12 +298,7 @@ impl AssistantPanel {
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
project: &ModelHandle<Project>,
|
project: &ModelHandle<Project>,
|
||||||
) {
|
) {
|
||||||
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
|
let credential = self.credential.borrow().clone();
|
||||||
api_key
|
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
let selection = editor.read(cx).selections.newest_anchor().clone();
|
let selection = editor.read(cx).selections.newest_anchor().clone();
|
||||||
if selection.start.excerpt_id() != selection.end.excerpt_id() {
|
if selection.start.excerpt_id() != selection.end.excerpt_id() {
|
||||||
return;
|
return;
|
||||||
|
@ -329,7 +330,7 @@ impl AssistantPanel {
|
||||||
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
||||||
let provider = Arc::new(OpenAICompletionProvider::new(
|
let provider = Arc::new(OpenAICompletionProvider::new(
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
api_key,
|
credential,
|
||||||
cx.background().clone(),
|
cx.background().clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -816,7 +817,7 @@ impl AssistantPanel {
|
||||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
|
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
|
||||||
let editor = cx.add_view(|cx| {
|
let editor = cx.add_view(|cx| {
|
||||||
ConversationEditor::new(
|
ConversationEditor::new(
|
||||||
self.api_key.clone(),
|
self.credential.clone(),
|
||||||
self.languages.clone(),
|
self.languages.clone(),
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
self.workspace.clone(),
|
self.workspace.clone(),
|
||||||
|
@ -875,17 +876,20 @@ impl AssistantPanel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||||
if let Some(api_key) = self
|
if let Some(api_key) = self
|
||||||
.api_key_editor
|
.api_key_editor
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|editor| editor.read(cx).text(cx))
|
.map(|editor| editor.read(cx).text(cx))
|
||||||
{
|
{
|
||||||
if !api_key.is_empty() {
|
if !api_key.is_empty() {
|
||||||
cx.platform()
|
let credential = ProviderCredential::Credentials {
|
||||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
api_key: api_key.clone(),
|
||||||
.log_err();
|
};
|
||||||
*self.api_key.borrow_mut() = Some(api_key);
|
self.completion_provider
|
||||||
|
.save_credentials(cx, credential.clone());
|
||||||
|
*self.credential.borrow_mut() = credential;
|
||||||
|
|
||||||
self.api_key_editor.take();
|
self.api_key_editor.take();
|
||||||
cx.focus_self();
|
cx.focus_self();
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -895,9 +899,9 @@ impl AssistantPanel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
self.completion_provider.delete_credentials(cx);
|
||||||
self.api_key.take();
|
*self.credential.borrow_mut() = ProviderCredential::NoCredentials;
|
||||||
self.api_key_editor = Some(build_api_key_editor(cx));
|
self.api_key_editor = Some(build_api_key_editor(cx));
|
||||||
cx.focus_self();
|
cx.focus_self();
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -1156,13 +1160,19 @@ impl AssistantPanel {
|
||||||
|
|
||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
let workspace = self.workspace.clone();
|
let workspace = self.workspace.clone();
|
||||||
let api_key = self.api_key.clone();
|
let credential = self.credential.clone();
|
||||||
let languages = self.languages.clone();
|
let languages = self.languages.clone();
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
let saved_conversation = fs.load(&path).await?;
|
let saved_conversation = fs.load(&path).await?;
|
||||||
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
||||||
let conversation = cx.add_model(|cx| {
|
let conversation = cx.add_model(|cx| {
|
||||||
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
|
Conversation::deserialize(
|
||||||
|
saved_conversation,
|
||||||
|
path.clone(),
|
||||||
|
credential,
|
||||||
|
languages,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
// If, by the time we've loaded the conversation, the user has already opened
|
// If, by the time we've loaded the conversation, the user has already opened
|
||||||
|
@ -1186,30 +1196,39 @@ impl AssistantPanel {
|
||||||
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
|
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
|
fn has_credentials(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||||
if self.api_key.borrow().is_none() && !self.has_read_credentials {
|
let credential = self.load_credentials(cx);
|
||||||
self.has_read_credentials = true;
|
match credential {
|
||||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
ProviderCredential::Credentials { .. } => true,
|
||||||
Some(api_key)
|
ProviderCredential::NotNeeded => true,
|
||||||
} else if let Some((_, api_key)) = cx
|
ProviderCredential::NoCredentials => false,
|
||||||
.platform()
|
}
|
||||||
.read_credentials(OPENAI_API_URL)
|
}
|
||||||
.log_err()
|
|
||||||
.flatten()
|
fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> ProviderCredential {
|
||||||
{
|
let existing_credential = self.credential.clone();
|
||||||
String::from_utf8(api_key).log_err()
|
let existing_credential = existing_credential.borrow().clone();
|
||||||
} else {
|
match existing_credential {
|
||||||
None
|
ProviderCredential::NoCredentials => {
|
||||||
};
|
if !self.has_read_credentials {
|
||||||
if let Some(api_key) = api_key {
|
self.has_read_credentials = true;
|
||||||
*self.api_key.borrow_mut() = Some(api_key);
|
let retrieved_credentials = self.completion_provider.retrieve_credentials(cx);
|
||||||
} else if self.api_key_editor.is_none() {
|
|
||||||
self.api_key_editor = Some(build_api_key_editor(cx));
|
match retrieved_credentials {
|
||||||
cx.notify();
|
ProviderCredential::NoCredentials {} => {
|
||||||
|
self.api_key_editor = Some(build_api_key_editor(cx));
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
*self.credential.borrow_mut() = retrieved_credentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.api_key.borrow().clone()
|
self.credential.borrow().clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1394,7 +1413,7 @@ impl Panel for AssistantPanel {
|
||||||
|
|
||||||
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
|
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
|
||||||
if active {
|
if active {
|
||||||
self.load_api_key(cx);
|
self.load_credentials(cx);
|
||||||
|
|
||||||
if self.editors.is_empty() {
|
if self.editors.is_empty() {
|
||||||
self.new_conversation(cx);
|
self.new_conversation(cx);
|
||||||
|
@ -1459,7 +1478,7 @@ struct Conversation {
|
||||||
token_count: Option<usize>,
|
token_count: Option<usize>,
|
||||||
max_token_count: usize,
|
max_token_count: usize,
|
||||||
pending_token_count: Task<Option<()>>,
|
pending_token_count: Task<Option<()>>,
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
credential: Rc<RefCell<ProviderCredential>>,
|
||||||
pending_save: Task<Result<()>>,
|
pending_save: Task<Result<()>>,
|
||||||
path: Option<PathBuf>,
|
path: Option<PathBuf>,
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
|
@ -1471,7 +1490,8 @@ impl Entity for Conversation {
|
||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
fn new(
|
fn new(
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
credential: Rc<RefCell<ProviderCredential>>,
|
||||||
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -1512,7 +1532,7 @@ impl Conversation {
|
||||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||||
pending_save: Task::ready(Ok(())),
|
pending_save: Task::ready(Ok(())),
|
||||||
path: None,
|
path: None,
|
||||||
api_key,
|
credential,
|
||||||
buffer,
|
buffer,
|
||||||
};
|
};
|
||||||
let message = MessageAnchor {
|
let message = MessageAnchor {
|
||||||
|
@ -1559,7 +1579,7 @@ impl Conversation {
|
||||||
fn deserialize(
|
fn deserialize(
|
||||||
saved_conversation: SavedConversation,
|
saved_conversation: SavedConversation,
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
credential: Rc<RefCell<ProviderCredential>>,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -1614,7 +1634,7 @@ impl Conversation {
|
||||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||||
pending_save: Task::ready(Ok(())),
|
pending_save: Task::ready(Ok(())),
|
||||||
path: Some(path),
|
path: Some(path),
|
||||||
api_key,
|
credential,
|
||||||
buffer,
|
buffer,
|
||||||
};
|
};
|
||||||
this.count_remaining_tokens(cx);
|
this.count_remaining_tokens(cx);
|
||||||
|
@ -1736,9 +1756,13 @@ impl Conversation {
|
||||||
}
|
}
|
||||||
|
|
||||||
if should_assist {
|
if should_assist {
|
||||||
let Some(api_key) = self.api_key.borrow().clone() else {
|
let credential = self.credential.borrow().clone();
|
||||||
return Default::default();
|
match credential {
|
||||||
};
|
ProviderCredential::NoCredentials => {
|
||||||
|
return Default::default();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||||
model: self.model.full_name().to_string(),
|
model: self.model.full_name().to_string(),
|
||||||
|
@ -1752,7 +1776,7 @@ impl Conversation {
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
});
|
});
|
||||||
|
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
let stream = stream_completion(credential, cx.background().clone(), request);
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -2018,57 +2042,62 @@ impl Conversation {
|
||||||
|
|
||||||
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
|
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
||||||
let api_key = self.api_key.borrow().clone();
|
let credential = self.credential.borrow().clone();
|
||||||
if let Some(api_key) = api_key {
|
|
||||||
let messages = self
|
|
||||||
.messages(cx)
|
|
||||||
.take(2)
|
|
||||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
|
||||||
.chain(Some(RequestMessage {
|
|
||||||
role: Role::User,
|
|
||||||
content:
|
|
||||||
"Summarize the conversation into a short title without punctuation"
|
|
||||||
.into(),
|
|
||||||
}));
|
|
||||||
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
|
||||||
model: self.model.full_name().to_string(),
|
|
||||||
messages: messages.collect(),
|
|
||||||
stream: true,
|
|
||||||
stop: vec![],
|
|
||||||
temperature: 1.0,
|
|
||||||
});
|
|
||||||
|
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
match credential {
|
||||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
ProviderCredential::NoCredentials => {
|
||||||
async move {
|
return;
|
||||||
let mut messages = stream.await?;
|
}
|
||||||
|
_ => {}
|
||||||
while let Some(message) = messages.next().await {
|
|
||||||
let mut message = message?;
|
|
||||||
if let Some(choice) = message.choices.pop() {
|
|
||||||
let text = choice.delta.content.unwrap_or_default();
|
|
||||||
this.update(&mut cx, |this, cx| {
|
|
||||||
this.summary
|
|
||||||
.get_or_insert(Default::default())
|
|
||||||
.text
|
|
||||||
.push_str(&text);
|
|
||||||
cx.emit(ConversationEvent::SummaryChanged);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
|
||||||
if let Some(summary) = this.summary.as_mut() {
|
|
||||||
summary.done = true;
|
|
||||||
cx.emit(ConversationEvent::SummaryChanged);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
}
|
|
||||||
.log_err()
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let messages = self
|
||||||
|
.messages(cx)
|
||||||
|
.take(2)
|
||||||
|
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||||
|
.chain(Some(RequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: "Summarize the conversation into a short title without punctuation"
|
||||||
|
.into(),
|
||||||
|
}));
|
||||||
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||||
|
model: self.model.full_name().to_string(),
|
||||||
|
messages: messages.collect(),
|
||||||
|
stream: true,
|
||||||
|
stop: vec![],
|
||||||
|
temperature: 1.0,
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = stream_completion(credential, cx.background().clone(), request);
|
||||||
|
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||||
|
async move {
|
||||||
|
let mut messages = stream.await?;
|
||||||
|
|
||||||
|
while let Some(message) = messages.next().await {
|
||||||
|
let mut message = message?;
|
||||||
|
if let Some(choice) = message.choices.pop() {
|
||||||
|
let text = choice.delta.content.unwrap_or_default();
|
||||||
|
this.update(&mut cx, |this, cx| {
|
||||||
|
this.summary
|
||||||
|
.get_or_insert(Default::default())
|
||||||
|
.text
|
||||||
|
.push_str(&text);
|
||||||
|
cx.emit(ConversationEvent::SummaryChanged);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.update(&mut cx, |this, cx| {
|
||||||
|
if let Some(summary) = this.summary.as_mut() {
|
||||||
|
summary.done = true;
|
||||||
|
cx.emit(ConversationEvent::SummaryChanged);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
.log_err()
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2229,13 +2258,13 @@ struct ConversationEditor {
|
||||||
|
|
||||||
impl ConversationEditor {
|
impl ConversationEditor {
|
||||||
fn new(
|
fn new(
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
credential: Rc<RefCell<ProviderCredential>>,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
workspace: WeakViewHandle<Workspace>,
|
workspace: WeakViewHandle<Workspace>,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
|
let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx));
|
||||||
Self::for_conversation(conversation, fs, workspace, cx)
|
Self::for_conversation(conversation, fs, workspace, cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3431,7 +3460,13 @@ mod tests {
|
||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
let conversation = cx.add_model(|cx| {
|
||||||
|
Conversation::new(
|
||||||
|
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
|
||||||
|
registry,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
|
@ -3559,7 +3594,13 @@ mod tests {
|
||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
let conversation = cx.add_model(|cx| {
|
||||||
|
Conversation::new(
|
||||||
|
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
|
||||||
|
registry,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
|
@ -3655,7 +3696,13 @@ mod tests {
|
||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
let conversation = cx.add_model(|cx| {
|
||||||
|
Conversation::new(
|
||||||
|
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
|
||||||
|
registry,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
|
@ -3737,8 +3784,13 @@ mod tests {
|
||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation =
|
let conversation = cx.add_model(|cx| {
|
||||||
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
|
Conversation::new(
|
||||||
|
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
|
||||||
|
registry.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
let message_0 = conversation.read(cx).message_anchors[0].id;
|
let message_0 = conversation.read(cx).message_anchors[0].id;
|
||||||
let message_1 = conversation.update(cx, |conversation, cx| {
|
let message_1 = conversation.update(cx, |conversation, cx| {
|
||||||
|
@ -3775,7 +3827,7 @@ mod tests {
|
||||||
Conversation::deserialize(
|
Conversation::deserialize(
|
||||||
conversation.read(cx).serialize(cx),
|
conversation.read(cx).serialize(cx),
|
||||||
Default::default(),
|
Default::default(),
|
||||||
Default::default(),
|
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
|
||||||
registry.clone(),
|
registry.clone(),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue