assistant: Overhaul provider infrastructure (#14929)
<img width="624" alt="image" src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815"> - [x] Correctly set custom model token count - [x] How to count tokens for Gemini models? - [x] Feature flag zed.dev provider - [x] Figure out how to configure custom models - [ ] Update docs Release Notes: - Added support for quickly switching between multiple language model providers in the assistant panel --------- Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
parent
17ef9a367f
commit
d0f52e90e6
55 changed files with 2757 additions and 2023 deletions
|
@ -1,318 +0,0 @@
|
|||
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
|
||||
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
|
||||
use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task, TextStyle, View};
|
||||
use http::HttpClient;
|
||||
use language_model::Role;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct AnthropicCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
model: AnthropicModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
AnthropicModel::iter()
|
||||
.map(LanguageModel::Anthropic)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = None;
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Anthropic(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_anthropic_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
anthropic::ResponseEvent::ContentBlockStart {
|
||||
content_block, ..
|
||||
} => match content_block {
|
||||
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||
},
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicCompletionProvider {
|
||||
pub fn new(
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.model = model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
|
||||
request.preprocess_anthropic();
|
||||
|
||||
let model = match request.model {
|
||||
LanguageModel::Anthropic(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
let mut system_message = String::new();
|
||||
if request
|
||||
.messages
|
||||
.first()
|
||||
.map_or(false, |message| message.role == Role::System)
|
||||
{
|
||||
system_message = request.messages.remove(0).content;
|
||||
}
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| RequestMessage {
|
||||
role: match msg.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("filtered out by preprocess_request"),
|
||||
},
|
||||
content: msg.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
line_height: relative(1.3),
|
||||
..Default::default()
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 4] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||
"",
|
||||
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
|
@ -1,214 +0,0 @@
|
|||
use crate::{
|
||||
count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{proto, Client};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use language_model::CloudModel;
|
||||
use std::{future, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct CloudCompletionProvider {
|
||||
client: Arc<Client>,
|
||||
model: CloudModel,
|
||||
settings_version: usize,
|
||||
status: client::Status,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
impl CloudCompletionProvider {
|
||||
pub fn new(
|
||||
model: CloudModel,
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Self {
|
||||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
while let Some(status) = status_rx.next().await {
|
||||
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.status = status;
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
model,
|
||||
settings_version,
|
||||
status,
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
|
||||
self.model = model;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
|
||||
Some(self.model.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
CloudModel::iter()
|
||||
.filter_map(move |model| {
|
||||
if let CloudModel::Custom { .. } = model {
|
||||
custom_model.take()
|
||||
} else {
|
||||
Some(model)
|
||||
}
|
||||
})
|
||||
.map(LanguageModel::Cloud)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.status.is_connected()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|_cx| AuthenticationPrompt).into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Cloud(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match &request.model {
|
||||
LanguageModel::Cloud(CloudModel::Gpt4)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(
|
||||
CloudModel::Claude3_5Sonnet
|
||||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku,
|
||||
) => {
|
||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
|
||||
if name.starts_with("anthropic/") {
|
||||
// Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
} else {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model: name.clone(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
mut request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
request.preprocess();
|
||||
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: request.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt;
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||
|
||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("sign_in", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Filled)
|
||||
.full_width()
|
||||
.on_click(|_, cx| {
|
||||
CompletionProvider::global(cx)
|
||||
.authenticate(cx)
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to enable collaboration.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,31 +1,37 @@
|
|||
mod anthropic;
|
||||
mod cloud;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
mod fake;
|
||||
mod ollama;
|
||||
mod open_ai;
|
||||
|
||||
pub use anthropic::*;
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
pub use cloud::*;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub use fake::*;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task, WindowContext};
|
||||
use language_model::{LanguageModel, LanguageModelRequest};
|
||||
pub use ollama::*;
|
||||
pub use open_ai::*;
|
||||
use parking_lot::RwLock;
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||
use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
|
||||
use std::{pin::Pin, sync::Arc, task::Poll};
|
||||
use ui::Context;
|
||||
|
||||
pub struct CompletionResponse {
|
||||
inner: BoxStream<'static, Result<String>>,
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
|
||||
cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
|
||||
}
|
||||
|
||||
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
|
||||
|
||||
impl Global for GlobalLanguageModelCompletionProvider {}
|
||||
|
||||
pub struct LanguageModelCompletionProvider {
|
||||
active_provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||
active_model: Option<Arc<dyn LanguageModel>>,
|
||||
request_limiter: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
||||
|
||||
pub struct LanguageModelCompletionResponse {
|
||||
pub inner: BoxStream<'static, Result<String>>,
|
||||
_lock: SemaphoreGuardArc,
|
||||
}
|
||||
|
||||
impl futures::Stream for CompletionResponse {
|
||||
impl futures::Stream for LanguageModelCompletionResponse {
|
||||
type Item = Result<String>;
|
||||
|
||||
fn poll_next(
|
||||
|
@ -36,73 +42,96 @@ impl futures::Stream for CompletionResponse {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelCompletionProvider: Send + Sync {
|
||||
fn available_models(&self) -> Vec<LanguageModel>;
|
||||
fn settings_version(&self) -> usize;
|
||||
fn is_authenticated(&self) -> bool;
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn model(&self) -> LanguageModel;
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>>;
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
impl LanguageModelCompletionProvider {
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
cx.global::<GlobalLanguageModelCompletionProvider>()
|
||||
.0
|
||||
.clone()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn Any;
|
||||
}
|
||||
pub fn read_global(cx: &AppContext) -> &Self {
|
||||
cx.global::<GlobalLanguageModelCompletionProvider>()
|
||||
.0
|
||||
.read(cx)
|
||||
}
|
||||
|
||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(cx: &mut AppContext) {
|
||||
let provider = cx.new_model(|cx| {
|
||||
let mut this = Self::new(cx);
|
||||
let available_model = LanguageModelRegistry::read_global(cx)
|
||||
.available_models(cx)
|
||||
.first()
|
||||
.unwrap()
|
||||
.clone();
|
||||
this.set_active_model(available_model, cx);
|
||||
this
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelCompletionProvider(provider));
|
||||
}
|
||||
|
||||
pub struct CompletionProvider {
|
||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
client: Option<Arc<Client>>,
|
||||
request_limiter: Arc<Semaphore>,
|
||||
}
|
||||
pub fn new(cx: &mut ModelContext<Self>) -> Self {
|
||||
cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn new(
|
||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
client: Option<Arc<Client>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
client,
|
||||
active_provider: None,
|
||||
active_model: None,
|
||||
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn available_models(&self) -> Vec<LanguageModel> {
|
||||
self.provider.read().available_models()
|
||||
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
self.active_provider.clone()
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.provider.read().settings_version()
|
||||
pub fn set_active_provider(
|
||||
&mut self,
|
||||
provider_name: LanguageModelProviderName,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
|
||||
self.active_model = None;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.provider.read().is_authenticated()
|
||||
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.active_model.clone()
|
||||
}
|
||||
|
||||
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
|
||||
if self.active_model.as_ref().map_or(false, |m| {
|
||||
m.id() == model.id() && m.provider_name() == model.provider_name()
|
||||
}) {
|
||||
return;
|
||||
}
|
||||
|
||||
self.active_provider =
|
||||
LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
|
||||
self.active_model = Some(model);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||
self.active_provider
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.provider.read().authenticate(cx)
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
self.provider.read().authentication_prompt(cx)
|
||||
self.active_provider
|
||||
.as_ref()
|
||||
.map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.provider.read().reset_credentials(cx)
|
||||
}
|
||||
|
||||
pub fn model(&self) -> LanguageModel {
|
||||
self.provider.read().model()
|
||||
self.active_provider
|
||||
.as_ref()
|
||||
.map_or(Task::ready(Ok(())), |provider| {
|
||||
provider.reset_credentials(cx)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
|
@ -110,25 +139,31 @@ impl CompletionProvider {
|
|||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
self.provider.read().count_tokens(request, cx)
|
||||
if let Some(model) = self.active_model() {
|
||||
model.count_tokens(request, cx)
|
||||
} else {
|
||||
std::future::ready(Err(anyhow!("No active model set"))).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<CompletionResponse>> {
|
||||
let rate_limiter = self.request_limiter.clone();
|
||||
let provider = self.provider.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let lock = rate_limiter.acquire_arc().await;
|
||||
let response = provider.read().stream_completion(request);
|
||||
let response = response.await?;
|
||||
Ok(CompletionResponse {
|
||||
inner: response,
|
||||
_lock: lock,
|
||||
) -> Task<Result<LanguageModelCompletionResponse>> {
|
||||
if let Some(language_model) = self.active_model() {
|
||||
let rate_limiter = self.request_limiter.clone();
|
||||
cx.spawn(|cx| async move {
|
||||
let lock = rate_limiter.acquire_arc().await;
|
||||
let response = language_model.stream_completion(request, &cx).await?;
|
||||
Ok(LanguageModelCompletionResponse {
|
||||
inner: response,
|
||||
_lock: lock,
|
||||
})
|
||||
})
|
||||
})
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("No active model set")))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
|
||||
|
@ -143,63 +178,43 @@ impl CompletionProvider {
|
|||
Ok(completion)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn update_provider(
|
||||
&mut self,
|
||||
get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
) {
|
||||
if let Some(client) = &self.client {
|
||||
self.provider = get_provider(Arc::clone(client));
|
||||
} else {
|
||||
log::warn!("completion provider cannot be updated because its client was not set");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl gpui::Global for CompletionProvider {}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn global(cx: &AppContext) -> &Self {
|
||||
cx.global::<Self>()
|
||||
}
|
||||
|
||||
pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
|
||||
&mut self,
|
||||
update: impl FnOnce(&mut T) -> R,
|
||||
) -> Option<R> {
|
||||
let mut provider = self.provider.write();
|
||||
if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
|
||||
Some(update(provider))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use gpui::AppContext;
|
||||
use parking_lot::RwLock;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
use ui::Context;
|
||||
|
||||
use crate::{
|
||||
CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS,
|
||||
LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
|
||||
};
|
||||
|
||||
use language_model::LanguageModelRegistry;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_rate_limiting(cx: &mut AppContext) {
|
||||
SettingsStore::test(cx);
|
||||
let fake_provider = FakeCompletionProvider::setup_test(cx);
|
||||
let fake_provider = LanguageModelRegistry::test(cx);
|
||||
|
||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.available_models(cx)
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap();
|
||||
|
||||
let provider = cx.new_model(|cx| {
|
||||
let mut provider = LanguageModelCompletionProvider::new(cx);
|
||||
provider.set_active_model(model.clone(), cx);
|
||||
provider
|
||||
});
|
||||
|
||||
let fake_model = fake_provider.test_model();
|
||||
|
||||
// Enqueue some requests
|
||||
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
|
||||
let response = provider.stream_completion(
|
||||
let response = provider.read(cx).stream_completion(
|
||||
LanguageModelRequest {
|
||||
temperature: i as f32 / 10.0,
|
||||
..Default::default()
|
||||
|
@ -216,23 +231,18 @@ mod tests {
|
|||
.detach();
|
||||
}
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
fake_model.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||
);
|
||||
|
||||
// Get the first completion request that is in flight and mark it as completed.
|
||||
let completion = fake_provider
|
||||
.pending_completions()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
fake_provider.finish_completion(&completion);
|
||||
let completion = fake_model.pending_completions().into_iter().next().unwrap();
|
||||
fake_model.finish_completion(&completion);
|
||||
|
||||
// Ensure that the number of in-flight completion requests is reduced.
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
fake_model.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||
);
|
||||
|
||||
|
@ -240,32 +250,32 @@ mod tests {
|
|||
|
||||
// Ensure that another completion request was allowed to acquire the lock.
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
fake_model.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||
);
|
||||
|
||||
// Mark all completion requests as finished that are in flight.
|
||||
for request in fake_provider.pending_completions() {
|
||||
fake_provider.finish_completion(&request);
|
||||
for request in fake_model.pending_completions() {
|
||||
fake_model.finish_completion(&request);
|
||||
}
|
||||
|
||||
assert_eq!(fake_provider.completion_count(), 0);
|
||||
assert_eq!(fake_model.completion_count(), 0);
|
||||
|
||||
// Wait until the background tasks acquire the lock again.
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
fake_model.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||
);
|
||||
|
||||
// Finish all remaining completion requests.
|
||||
for request in fake_provider.pending_completions() {
|
||||
fake_provider.finish_completion(&request);
|
||||
for request in fake_model.pending_completions() {
|
||||
fake_model.finish_completion(&request);
|
||||
}
|
||||
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(fake_provider.completion_count(), 0);
|
||||
assert_eq!(fake_model.completion_count(), 0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,115 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use std::sync::Arc;
|
||||
use ui::WindowContext;
|
||||
|
||||
use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct FakeCompletionProvider {
|
||||
current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn setup_test(cx: &mut AppContext) -> Self {
|
||||
use crate::CompletionProvider;
|
||||
use parking_lot::RwLock;
|
||||
|
||||
let this = Self::default();
|
||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
|
||||
cx.set_global(provider);
|
||||
this
|
||||
}
|
||||
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.keys()
|
||||
.map(|k| serde_json::from_str(k).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn completion_count(&self) -> usize {
|
||||
self.current_completion_txs.lock().len()
|
||||
}
|
||||
|
||||
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
||||
let json = serde_json::to_string(request).unwrap();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.get(&json)
|
||||
.unwrap()
|
||||
.unbounded_send(chunk)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn send_last_completion_chunk(&self, chunk: String) {
|
||||
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.remove(&serde_json::to_string(request).unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_last_completion(&self) {
|
||||
self.finish_completion(self.pending_completions().last().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for FakeCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
vec![LanguageModel::default()]
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::default()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
futures::future::ready(Ok(0)).boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.insert(serde_json::to_string(&_request).unwrap(), tx);
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
|
@ -1,347 +0,0 @@
|
|||
use crate::LanguageModelCompletionProvider;
|
||||
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt as _;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use http::HttpClient;
|
||||
use language_model::Role;
|
||||
use ollama::Model as OllamaModel;
|
||||
use ollama::{
|
||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||
|
||||
pub struct OllamaCompletionProvider {
|
||||
api_url: String,
|
||||
model: OllamaModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
self.available_models
|
||||
.iter()
|
||||
.map(|m| LanguageModel::Ollama(m.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
!self.available_models.is_empty()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
let fetch_models = Box::new(move |cx: &mut WindowContext| {
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
provider
|
||||
.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.fetch_models(cx)
|
||||
})
|
||||
.unwrap_or_else(|| Task::ready(Ok(())))
|
||||
})
|
||||
});
|
||||
|
||||
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Ollama(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
// There is no endpoint for this _yet_ in Ollama
|
||||
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
|
||||
let token_count = request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| msg.content.chars().count())
|
||||
.sum::<usize>()
|
||||
/ 4;
|
||||
|
||||
async move { Ok(token_count) }.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let request =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(delta) => {
|
||||
let content = match delta.message {
|
||||
ChatMessage::User { content } => content,
|
||||
ChatMessage::Assistant { content } => content,
|
||||
ChatMessage::System { content } => content,
|
||||
};
|
||||
Some(Ok(content))
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaCompletionProvider {
|
||||
pub fn new(
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
cx: &AppContext,
|
||||
) -> Self {
|
||||
cx.spawn({
|
||||
let api_url = api_url.clone();
|
||||
let client = http_client.clone();
|
||||
let model = model.name.clone();
|
||||
|
||||
|_| async move {
|
||||
if model.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
preload_model(client.as_ref(), &api_url, &model).await
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
Self {
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
cx: &AppContext,
|
||||
) {
|
||||
cx.spawn({
|
||||
let api_url = api_url.clone();
|
||||
let client = self.http_client.clone();
|
||||
let model = model.name.clone();
|
||||
|
||||
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
if model.name.is_empty() {
|
||||
self.select_first_available_model()
|
||||
} else {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn select_first_available_model(&mut self) {
|
||||
if let Some(model) = self.available_models.first() {
|
||||
self.model = model.clone();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|mut cx| async move {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
|
||||
let mut models: Vec<OllamaModel> = models
|
||||
.into_iter()
|
||||
// Since there is no metadata from the Ollama API
|
||||
// indicating which models are embedding models,
|
||||
// simply filter out models with "-embed" in their name
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| OllamaModel::new(&model.name))
|
||||
.collect();
|
||||
|
||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.available_models = models;
|
||||
|
||||
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
|
||||
provider.select_first_available_model()
|
||||
}
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
||||
let model = match request.model {
|
||||
LanguageModel::Ollama(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
ChatRequest {
|
||||
model: model.name,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => ChatMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => ChatMessage::Assistant {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::System => ChatMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
keep_alive: model.keep_alive.unwrap_or_default(),
|
||||
stream: true,
|
||||
options: Some(ChatOptions {
|
||||
num_ctx: Some(model.max_tokens),
|
||||
stop: Some(request.stop),
|
||||
temperature: Some(request.temperature),
|
||||
..Default::default()
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct DownloadOllamaMessage {
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
}
|
||||
|
||||
impl DownloadOllamaMessage {
|
||||
pub fn new(
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
Self { retry_connection }
|
||||
}
|
||||
|
||||
fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("download_ollama_button")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Get Ollama"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
|
||||
}
|
||||
|
||||
fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("retry_ollama_models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Retry"))
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
let connected = (this.retry_connection)(cx);
|
||||
|
||||
cx.spawn(|_this, _cx| async move {
|
||||
connected.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Once Ollama is on your machine, make sure to download a model or two.")
|
||||
.size(LabelSize::Large),
|
||||
)
|
||||
.child(
|
||||
h_flex().w_full().p_4().justify_center().gap_2().child(
|
||||
ButtonLike::new("view-models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("View Available Models"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for DownloadOllamaMessage {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.justify_center()
|
||||
.gap_2()
|
||||
.child(
|
||||
self.render_download_button(cx)
|
||||
)
|
||||
.child(
|
||||
self.render_retry_button(cx)
|
||||
)
|
||||
)
|
||||
.child(self.render_next_steps(cx))
|
||||
.into_any()
|
||||
}
|
||||
}
|
|
@ -1,362 +0,0 @@
|
|||
use crate::CompletionProvider;
|
||||
use crate::LanguageModelCompletionProvider;
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task, TextStyle, View};
|
||||
use http::HttpClient;
|
||||
use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use open_ai::{stream_completion, Request, RequestMessage};
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models_from_settings: Vec<OpenAiModel>,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub fn new(
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models_from_settings: Vec<OpenAiModel>,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
available_models_from_settings,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.model = model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::OpenAi(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => RequestMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => RequestMessage::Assistant {
|
||||
content: Some(msg.content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => RequestMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
|
||||
fn available_models(&self) -> Vec<LanguageModel> {
|
||||
if self.available_models_from_settings.is_empty() {
|
||||
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
|
||||
vec![self.model.clone()]
|
||||
} else {
|
||||
OpenAiModel::iter()
|
||||
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
|
||||
.collect()
|
||||
};
|
||||
available_models
|
||||
.into_iter()
|
||||
.map(LanguageModel::OpenAi)
|
||||
.collect()
|
||||
} else {
|
||||
self.available_models_from_settings
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(LanguageModel::OpenAi)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.api_key = None;
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::OpenAi(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
background_executor: &gpui::BackgroundExecutor,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
background_executor
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match request.model {
|
||||
LanguageModel::Anthropic(_)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
||||
| LanguageModel::Cloud(CloudModel::Custom { .. })
|
||||
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||
}
|
||||
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
line_height: relative(1.3),
|
||||
..Default::default()
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 6] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
|
||||
" - You can create an API key at: platform.openai.com/api-keys",
|
||||
" - Make sure your OpenAI account has credits",
|
||||
" - Having a subscription for another service like GitHub Copilot won't work.",
|
||||
"",
|
||||
"Paste your OpenAI API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue