diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index da9ebd5a1d..5b9bad4870 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,7 +1,11 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; +use gpui::AppContext; -use crate::models::LanguageModel; +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + models::LanguageModel, +}; pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; @@ -9,6 +13,10 @@ pub trait CompletionRequest: Send + Sync { pub trait CompletionProvider { fn base_model(&self) -> Box; + fn credential_provider(&self) -> Box; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + self.credential_provider().retrieve_credentials(cx) + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 20f72c0ff7..9c9d205ff7 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -13,11 +13,12 @@ use std::{ }; use crate::{ + auth::CredentialProvider, completion::{CompletionProvider, CompletionRequest}, models::LanguageModel, }; -use super::OpenAILanguageModel; +use super::{auth::OpenAICredentialProvider, OpenAILanguageModel}; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; @@ -186,6 +187,7 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, + credential_provider: OpenAICredentialProvider, api_key: String, executor: Arc, } @@ -193,8 +195,10 @@ pub struct OpenAICompletionProvider { impl OpenAICompletionProvider { pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { let model = OpenAILanguageModel::load(model_name); + let credential_provider = OpenAICredentialProvider {}; Self { model, + credential_provider, api_key, executor, } @@ -206,6 +210,10 @@ impl CompletionProvider for OpenAICompletionProvider { let model: Box = Box::new(self.model.clone()); model } + fn credential_provider(&self) -> Box { + let provider: Box = Box::new(self.credential_provider.clone()); + provider + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 64f568da1a..dafc94580d 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -11,7 +11,6 @@ use parking_lot::Mutex; use parse_duration::parse; use postage::watch; use serde::{Deserialize, Serialize}; -use std::env; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index 2c78027b62..b8f99af400 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -155,6 +155,9 @@ impl CompletionProvider for TestCompletionProvider { let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); model } + fn credential_provider(&self) -> Box { + Box::new(NullCredentialProvider {}) + } fn complete( &self, _prompt: Box, diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 3516fc3708..7f4c95f655 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -336,18 +336,13 @@ fn strip_markdown_codeblock( mod tests { use super::*; use ai::test::TestCompletionProvider; - use futures::{ - future::BoxFuture, - stream::{self, BoxStream}, - }; + use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; - use parking_lot::Mutex; use rand::prelude::*; use serde::Serialize; use settings::SettingsStore; - use smol::future::FutureExt; #[derive(Serialize)] pub struct DummyCompletionRequest {