Extract completion provider crate (#14823)
We will soon need `semantic_index` to be able to use
`CompletionProvider`. This is currently impossible due to a cyclic crate
dependency, because `CompletionProvider` lives in the `assistant` crate,
which depends on `semantic_index`.
This PR breaks the dependency cycle by extracting two crates out of
`assistant`: `language_model` and `completion`.
Only one piece of logic changed: [this
code](922fcaf5a6 (diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69)
).
* As of https://github.com/zed-industries/zed/pull/13276, whenever we
ask a given completion provider for its available models, OpenAI
providers would go and ask the global assistant settings whether the
user had configured an `available_models` setting, and if so, return
that.
* This PR changes it so that instead of eagerly asking the assistant
settings for this info (the new crate must not depend on `assistant`, or
else the dependency cycle would be back), OpenAI completion providers
now store the user-configured settings as part of their struct, and
whenever the settings change, we update the provider.
In theory, this change should not change user-visible behavior...but
since it's the only change in this large PR that's more than just moving
code around, I'm mentioning it here in case there's an unexpected
regression in practice! (cc @amtoaer in case you'd like to try out this
branch and verify that the feature is still working the way you expect.)
Release Notes:
- N/A
---------
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
b9a53ffa0b
commit
ec487d8f64
30 changed files with 820 additions and 610 deletions
322
crates/completion/src/anthropic.rs
Normal file
322
crates/completion/src/anthropic.rs
Normal file
|
@ -0,0 +1,322 @@
|
|||
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
|
||||
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
|
||||
use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
|
||||
use http::HttpClient;
|
||||
use language_model::Role;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct AnthropicCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
model: AnthropicModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
AnthropicModel::iter()
|
||||
.map(LanguageModel::Anthropic)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = None;
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Anthropic(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_anthropic_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
anthropic::ResponseEvent::ContentBlockStart {
|
||||
content_block, ..
|
||||
} => match content_block {
|
||||
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||
},
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicCompletionProvider {
|
||||
pub fn new(
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.model = model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
|
||||
request.preprocess_anthropic();
|
||||
|
||||
let model = match request.model {
|
||||
LanguageModel::Anthropic(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
let mut system_message = String::new();
|
||||
if request
|
||||
.messages
|
||||
.first()
|
||||
.map_or(false, |message| message.role == Role::System)
|
||||
{
|
||||
system_message = request.messages.remove(0).content;
|
||||
}
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| RequestMessage {
|
||||
role: match msg.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("filtered out by preprocess_request"),
|
||||
},
|
||||
content: msg.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 4] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||
"",
|
||||
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
209
crates/completion/src/cloud.rs
Normal file
209
crates/completion/src/cloud.rs
Normal file
|
@ -0,0 +1,209 @@
|
|||
use crate::{
|
||||
count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{proto, Client};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use language_model::CloudModel;
|
||||
use std::{future, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct CloudCompletionProvider {
|
||||
client: Arc<Client>,
|
||||
model: CloudModel,
|
||||
settings_version: usize,
|
||||
status: client::Status,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
impl CloudCompletionProvider {
|
||||
pub fn new(
|
||||
model: CloudModel,
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Self {
|
||||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
while let Some(status) = status_rx.next().await {
|
||||
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.status = status;
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
model,
|
||||
settings_version,
|
||||
status,
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
|
||||
self.model = model;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
|
||||
Some(custom_model)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
CloudModel::iter()
|
||||
.filter_map(move |model| {
|
||||
if let CloudModel::Custom(_) = model {
|
||||
Some(CloudModel::Custom(custom_model.take()?))
|
||||
} else {
|
||||
Some(model)
|
||||
}
|
||||
})
|
||||
.map(LanguageModel::Cloud)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.status.is_connected()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|_cx| AuthenticationPrompt).into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Cloud(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match request.model {
|
||||
LanguageModel::Cloud(CloudModel::Gpt4)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(
|
||||
CloudModel::Claude3_5Sonnet
|
||||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku,
|
||||
) => {
|
||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(CloudModel::Custom(model)) => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
mut request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
request.preprocess();
|
||||
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: request.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt;
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||
|
||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("sign_in", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Filled)
|
||||
.full_width()
|
||||
.on_click(|_, cx| {
|
||||
CompletionProvider::global(cx)
|
||||
.authenticate(cx)
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to enable collaboration.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
271
crates/completion/src/completion.rs
Normal file
271
crates/completion/src/completion.rs
Normal file
|
@ -0,0 +1,271 @@
|
|||
mod anthropic;
|
||||
mod cloud;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
mod fake;
|
||||
mod ollama;
|
||||
mod open_ai;
|
||||
|
||||
pub use anthropic::*;
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
pub use cloud::*;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub use fake::*;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task, WindowContext};
|
||||
use language_model::{LanguageModel, LanguageModelRequest};
|
||||
pub use ollama::*;
|
||||
pub use open_ai::*;
|
||||
use parking_lot::RwLock;
|
||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||
use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
|
||||
|
||||
pub struct CompletionResponse {
|
||||
inner: BoxStream<'static, Result<String>>,
|
||||
_lock: SemaphoreGuardArc,
|
||||
}
|
||||
|
||||
impl futures::Stream for CompletionResponse {
|
||||
type Item = Result<String>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.inner).poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelCompletionProvider: Send + Sync {
|
||||
fn available_models(&self) -> Vec<LanguageModel>;
|
||||
fn settings_version(&self) -> usize;
|
||||
fn is_authenticated(&self) -> bool;
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn model(&self) -> LanguageModel;
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>>;
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn Any;
|
||||
}
|
||||
|
||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
||||
|
||||
pub struct CompletionProvider {
|
||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
client: Option<Arc<Client>>,
|
||||
request_limiter: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn new(
|
||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
client: Option<Arc<Client>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
client,
|
||||
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn available_models(&self) -> Vec<LanguageModel> {
|
||||
self.provider.read().available_models()
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.provider.read().settings_version()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.provider.read().is_authenticated()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.provider.read().authenticate(cx)
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
self.provider.read().authentication_prompt(cx)
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.provider.read().reset_credentials(cx)
|
||||
}
|
||||
|
||||
pub fn model(&self) -> LanguageModel {
|
||||
self.provider.read().model()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
self.provider.read().count_tokens(request, cx)
|
||||
}
|
||||
|
||||
pub fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<CompletionResponse>> {
|
||||
let rate_limiter = self.request_limiter.clone();
|
||||
let provider = self.provider.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let lock = rate_limiter.acquire_arc().await;
|
||||
let response = provider.read().stream_completion(request);
|
||||
let response = response.await?;
|
||||
Ok(CompletionResponse {
|
||||
inner: response,
|
||||
_lock: lock,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
|
||||
let response = self.stream_completion(request, cx);
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let mut chunks = response.await?;
|
||||
let mut completion = String::new();
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
let chunk = chunk?;
|
||||
completion.push_str(&chunk);
|
||||
}
|
||||
Ok(completion)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn update_provider(
|
||||
&mut self,
|
||||
get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
) {
|
||||
if let Some(client) = &self.client {
|
||||
self.provider = get_provider(Arc::clone(client));
|
||||
} else {
|
||||
log::warn!("completion provider cannot be updated because its client was not set");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl gpui::Global for CompletionProvider {}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn global(cx: &AppContext) -> &Self {
|
||||
cx.global::<Self>()
|
||||
}
|
||||
|
||||
pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
|
||||
&mut self,
|
||||
update: impl FnOnce(&mut T) -> R,
|
||||
) -> Option<R> {
|
||||
let mut provider = self.provider.write();
|
||||
if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
|
||||
Some(update(provider))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::AppContext;
|
||||
use parking_lot::RwLock;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
|
||||
use crate::{
|
||||
CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS,
|
||||
};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_rate_limiting(cx: &mut AppContext) {
|
||||
SettingsStore::test(cx);
|
||||
let fake_provider = FakeCompletionProvider::setup_test(cx);
|
||||
|
||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
|
||||
|
||||
// Enqueue some requests
|
||||
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
|
||||
let response = provider.stream_completion(
|
||||
LanguageModelRequest {
|
||||
temperature: i as f32 / 10.0,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let mut stream = response.await.unwrap();
|
||||
while let Some(message) = stream.next().await {
|
||||
message.unwrap();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||
);
|
||||
|
||||
// Get the first completion request that is in flight and mark it as completed.
|
||||
let completion = fake_provider
|
||||
.pending_completions()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
fake_provider.finish_completion(&completion);
|
||||
|
||||
// Ensure that the number of in-flight completion requests is reduced.
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||
);
|
||||
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
// Ensure that another completion request was allowed to acquire the lock.
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||
);
|
||||
|
||||
// Mark all completion requests as finished that are in flight.
|
||||
for request in fake_provider.pending_completions() {
|
||||
fake_provider.finish_completion(&request);
|
||||
}
|
||||
|
||||
assert_eq!(fake_provider.completion_count(), 0);
|
||||
|
||||
// Wait until the background tasks acquire the lock again.
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||
);
|
||||
|
||||
// Finish all remaining completion requests.
|
||||
for request in fake_provider.pending_completions() {
|
||||
fake_provider.finish_completion(&request);
|
||||
}
|
||||
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(fake_provider.completion_count(), 0);
|
||||
}
|
||||
}
|
115
crates/completion/src/fake.rs
Normal file
115
crates/completion/src/fake.rs
Normal file
|
@ -0,0 +1,115 @@
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use std::sync::Arc;
|
||||
use ui::WindowContext;
|
||||
|
||||
use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct FakeCompletionProvider {
|
||||
current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn setup_test(cx: &mut AppContext) -> Self {
|
||||
use crate::CompletionProvider;
|
||||
use parking_lot::RwLock;
|
||||
|
||||
let this = Self::default();
|
||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
|
||||
cx.set_global(provider);
|
||||
this
|
||||
}
|
||||
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.keys()
|
||||
.map(|k| serde_json::from_str(k).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn completion_count(&self) -> usize {
|
||||
self.current_completion_txs.lock().len()
|
||||
}
|
||||
|
||||
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
||||
let json = serde_json::to_string(request).unwrap();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.get(&json)
|
||||
.unwrap()
|
||||
.unbounded_send(chunk)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn send_last_completion_chunk(&self, chunk: String) {
|
||||
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.remove(&serde_json::to_string(request).unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_last_completion(&self) {
|
||||
self.finish_completion(self.pending_completions().last().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for FakeCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
vec![LanguageModel::default()]
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::default()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
futures::future::ready(Ok(0)).boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.insert(serde_json::to_string(&_request).unwrap(), tx);
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
347
crates/completion/src/ollama.rs
Normal file
347
crates/completion/src/ollama.rs
Normal file
|
@ -0,0 +1,347 @@
|
|||
use crate::LanguageModelCompletionProvider;
|
||||
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt as _;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use http::HttpClient;
|
||||
use language_model::Role;
|
||||
use ollama::Model as OllamaModel;
|
||||
use ollama::{
|
||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||
|
||||
pub struct OllamaCompletionProvider {
|
||||
api_url: String,
|
||||
model: OllamaModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
self.available_models
|
||||
.iter()
|
||||
.map(|m| LanguageModel::Ollama(m.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
!self.available_models.is_empty()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
let fetch_models = Box::new(move |cx: &mut WindowContext| {
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
provider
|
||||
.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.fetch_models(cx)
|
||||
})
|
||||
.unwrap_or_else(|| Task::ready(Ok(())))
|
||||
})
|
||||
});
|
||||
|
||||
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Ollama(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
// There is no endpoint for this _yet_ in Ollama
|
||||
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
|
||||
let token_count = request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| msg.content.chars().count())
|
||||
.sum::<usize>()
|
||||
/ 4;
|
||||
|
||||
async move { Ok(token_count) }.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let request =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(delta) => {
|
||||
let content = match delta.message {
|
||||
ChatMessage::User { content } => content,
|
||||
ChatMessage::Assistant { content } => content,
|
||||
ChatMessage::System { content } => content,
|
||||
};
|
||||
Some(Ok(content))
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaCompletionProvider {
|
||||
pub fn new(
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
cx: &AppContext,
|
||||
) -> Self {
|
||||
cx.spawn({
|
||||
let api_url = api_url.clone();
|
||||
let client = http_client.clone();
|
||||
let model = model.name.clone();
|
||||
|
||||
|_| async move {
|
||||
if model.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
preload_model(client.as_ref(), &api_url, &model).await
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
Self {
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
cx: &AppContext,
|
||||
) {
|
||||
cx.spawn({
|
||||
let api_url = api_url.clone();
|
||||
let client = self.http_client.clone();
|
||||
let model = model.name.clone();
|
||||
|
||||
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
if model.name.is_empty() {
|
||||
self.select_first_available_model()
|
||||
} else {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn select_first_available_model(&mut self) {
|
||||
if let Some(model) = self.available_models.first() {
|
||||
self.model = model.clone();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|mut cx| async move {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
|
||||
let mut models: Vec<OllamaModel> = models
|
||||
.into_iter()
|
||||
// Since there is no metadata from the Ollama API
|
||||
// indicating which models are embedding models,
|
||||
// simply filter out models with "-embed" in their name
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| OllamaModel::new(&model.name))
|
||||
.collect();
|
||||
|
||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.available_models = models;
|
||||
|
||||
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
|
||||
provider.select_first_available_model()
|
||||
}
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
||||
let model = match request.model {
|
||||
LanguageModel::Ollama(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
ChatRequest {
|
||||
model: model.name,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => ChatMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => ChatMessage::Assistant {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::System => ChatMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
keep_alive: model.keep_alive.unwrap_or_default(),
|
||||
stream: true,
|
||||
options: Some(ChatOptions {
|
||||
num_ctx: Some(model.max_tokens),
|
||||
stop: Some(request.stop),
|
||||
temperature: Some(request.temperature),
|
||||
..Default::default()
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct DownloadOllamaMessage {
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
}
|
||||
|
||||
impl DownloadOllamaMessage {
|
||||
pub fn new(
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
Self { retry_connection }
|
||||
}
|
||||
|
||||
fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("download_ollama_button")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Get Ollama"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
|
||||
}
|
||||
|
||||
fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("retry_ollama_models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Retry"))
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
let connected = (this.retry_connection)(cx);
|
||||
|
||||
cx.spawn(|_this, _cx| async move {
|
||||
connected.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Once Ollama is on your machine, make sure to download a model or two.")
|
||||
.size(LabelSize::Large),
|
||||
)
|
||||
.child(
|
||||
h_flex().w_full().p_4().justify_center().gap_2().child(
|
||||
ButtonLike::new("view-models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("View Available Models"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for DownloadOllamaMessage {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.justify_center()
|
||||
.gap_2()
|
||||
.child(
|
||||
self.render_download_button(cx)
|
||||
)
|
||||
.child(
|
||||
self.render_retry_button(cx)
|
||||
)
|
||||
)
|
||||
.child(self.render_next_steps(cx))
|
||||
.into_any()
|
||||
}
|
||||
}
|
365
crates/completion/src/open_ai.rs
Normal file
365
crates/completion/src/open_ai.rs
Normal file
|
@ -0,0 +1,365 @@
|
|||
use crate::CompletionProvider;
|
||||
use crate::LanguageModelCompletionProvider;
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
|
||||
use http::HttpClient;
|
||||
use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use open_ai::{stream_completion, Request, RequestMessage};
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models_from_settings: Vec<OpenAiModel>,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub fn new(
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models_from_settings: Vec<OpenAiModel>,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
available_models_from_settings,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.model = model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::OpenAi(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => RequestMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => RequestMessage::Assistant {
|
||||
content: Some(msg.content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => RequestMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
if self.available_models_from_settings.is_empty() {
|
||||
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
|
||||
vec![self.model.clone()]
|
||||
} else {
|
||||
OpenAiModel::iter()
|
||||
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
|
||||
.collect()
|
||||
};
|
||||
available_models
|
||||
.into_iter()
|
||||
.map(LanguageModel::OpenAi)
|
||||
.collect()
|
||||
} else {
|
||||
self.available_models_from_settings
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(LanguageModel::OpenAi)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.api_key = None;
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::OpenAi(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
background_executor: &gpui::BackgroundExecutor,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
background_executor
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match request.model {
|
||||
LanguageModel::Anthropic(_)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
||||
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||
}
|
||||
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 6] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
|
||||
" - You can create an API key at: platform.openai.com/api-keys",
|
||||
" - Make sure your OpenAI account has credits",
|
||||
" - Having a subscription for another service like GitHub Copilot won't work.",
|
||||
"",
|
||||
"Paste your OpenAI API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue