assistant: Overhaul provider infrastructure (#14929)
<img width="624" alt="image" src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815"> - [x] Correctly set custom model token count - [x] How to count tokens for Gemini models? - [x] Feature flag zed.dev provider - [x] Figure out how to configure custom models - [ ] Update docs Release Notes: - Added support for quickly switching between multiple language model providers in the assistant panel --------- Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
parent
17ef9a367f
commit
d0f52e90e6
55 changed files with 2757 additions and 2023 deletions
|
@ -22,12 +22,27 @@ test-support = [
|
|||
|
||||
[dependencies]
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http.workspace = true
|
||||
menu.workspace = true
|
||||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
proto = { workspace = true, features = ["test-support"] }
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
strum.workspace = true
|
||||
proto = { workspace = true, features = ["test-support"] }
|
||||
theme.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
|
|
|
@ -1,7 +1,84 @@
|
|||
mod model;
|
||||
pub mod provider;
|
||||
mod registry;
|
||||
mod request;
|
||||
mod role;
|
||||
pub mod settings;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
|
||||
|
||||
pub use model::*;
|
||||
pub use registry::*;
|
||||
pub use request::*;
|
||||
pub use role::*;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
settings::init(cx);
|
||||
registry::init(client, cx);
|
||||
}
|
||||
|
||||
pub trait LanguageModel: Send + Sync {
|
||||
fn id(&self) -> LanguageModelId;
|
||||
fn name(&self) -> LanguageModelName;
|
||||
fn provider_name(&self) -> LanguageModelProviderName;
|
||||
fn telemetry_id(&self) -> String;
|
||||
|
||||
fn max_token_count(&self) -> usize;
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>>;
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
}
|
||||
|
||||
pub trait LanguageModelProvider: 'static {
|
||||
fn name(&self) -> LanguageModelProviderName;
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool;
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
pub trait LanguageModelProviderState: 'static {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct LanguageModelId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct LanguageModelName(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct LanguageModelProviderName(pub SharedString);
|
||||
|
||||
impl From<String> for LanguageModelId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
pub use anthropic::Model as AnthropicModel;
|
||||
use anyhow::{anyhow, Result};
|
||||
pub use ollama::Model as OllamaModel;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::JsonSchema;
|
||||
|
@ -38,6 +39,23 @@ pub enum CloudModel {
|
|||
}
|
||||
|
||||
impl CloudModel {
|
||||
pub fn from_id(value: &str) -> Result<Self> {
|
||||
match value {
|
||||
"gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo),
|
||||
"gpt-4" => Ok(Self::Gpt4),
|
||||
"gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo),
|
||||
"gpt-4o" => Ok(Self::Gpt4Omni),
|
||||
"gpt-4o-mini" => Ok(Self::Gpt4OmniMini),
|
||||
"claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
|
||||
"claude-3-opus" => Ok(Self::Claude3Opus),
|
||||
"claude-3-sonnet" => Ok(Self::Claude3Sonnet),
|
||||
"claude-3-haiku" => Ok(Self::Claude3Haiku),
|
||||
"gemini-1.5-pro" => Ok(Self::Gemini15Pro),
|
||||
"gemini-1.5-flash" => Ok(Self::Gemini15Flash),
|
||||
_ => Err(anyhow!("invalid model id")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||
|
|
|
@ -4,57 +4,3 @@ pub use anthropic::Model as AnthropicModel;
|
|||
pub use cloud_model::*;
|
||||
pub use ollama::Model as OllamaModel;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub enum LanguageModel {
|
||||
Cloud(CloudModel),
|
||||
OpenAi(OpenAiModel),
|
||||
Anthropic(AnthropicModel),
|
||||
Ollama(OllamaModel),
|
||||
}
|
||||
|
||||
impl Default for LanguageModel {
|
||||
fn default() -> Self {
|
||||
LanguageModel::Cloud(CloudModel::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel {
|
||||
pub fn telemetry_id(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
||||
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
|
||||
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
|
||||
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.display_name().into(),
|
||||
LanguageModel::Anthropic(model) => model.display_name().into(),
|
||||
LanguageModel::Cloud(model) => model.display_name().into(),
|
||||
LanguageModel::Ollama(model) => model.display_name().into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.max_token_count(),
|
||||
LanguageModel::Anthropic(model) => model.max_token_count(),
|
||||
LanguageModel::Cloud(model) => model.max_token_count(),
|
||||
LanguageModel::Ollama(model) => model.max_token_count(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.id(),
|
||||
LanguageModel::Anthropic(model) => model.id(),
|
||||
LanguageModel::Cloud(model) => model.id(),
|
||||
LanguageModel::Ollama(model) => model.id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
6
crates/language_model/src/provider.rs
Normal file
6
crates/language_model/src/provider.rs
Normal file
|
@ -0,0 +1,6 @@
|
|||
pub mod anthropic;
|
||||
pub mod cloud;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod fake;
|
||||
pub mod ollama;
|
||||
pub mod open_ai;
|
454
crates/language_model/src/provider/anthropic.rs
Normal file
454
crates/language_model/src/provider/anthropic.rs
Normal file
|
@ -0,0 +1,454 @@
|
|||
use anthropic::{stream_completion, Request, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{
|
||||
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
|
||||
WhiteSpace,
|
||||
};
|
||||
use http::HttpClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
|
||||
const PROVIDER_NAME: &str = "anthropic";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct AnthropicSettings {
|
||||
pub api_url: String,
|
||||
pub low_speed_timeout: Option<Duration>,
|
||||
pub available_models: Vec<anthropic::Model>,
|
||||
}
|
||||
|
||||
pub struct AnthropicLanguageModelProvider {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
api_key: Option<String>,
|
||||
settings: AnthropicSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl AnthropicLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||
let state = cx.new_model(|cx| State {
|
||||
api_key: None,
|
||||
settings: AnthropicSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
}
|
||||
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
|
||||
// Add base models from anthropic::Model::iter()
|
||||
for model in anthropic::Model::iter() {
|
||||
if !matches!(model, anthropic::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), model);
|
||||
}
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
models
|
||||
.into_values()
|
||||
.map(|model| {
|
||||
Arc::new(AnthropicModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||
self.state.read(cx).api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let state = self.state.clone();
|
||||
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AnthropicModel {
|
||||
id: LanguageModelId,
|
||||
model: anthropic::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
impl AnthropicModel {
|
||||
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
|
||||
preprocess_anthropic_request(&mut request);
|
||||
|
||||
let mut system_message = String::new();
|
||||
if request
|
||||
.messages
|
||||
.first()
|
||||
.map_or(false, |message| message.role == Role::System)
|
||||
{
|
||||
system_message = request.messages.remove(0).content;
|
||||
}
|
||||
|
||||
Request {
|
||||
model: self.model.clone(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| RequestMessage {
|
||||
role: match msg.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("filtered out by preprocess_request"),
|
||||
},
|
||||
content: msg.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_anthropic_tokens(
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
impl LanguageModel for AnthropicModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("anthropic/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_anthropic_tokens(request, cx)
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_anthropic_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||
(
|
||||
state.api_key.clone(),
|
||||
state.settings.api_url.clone(),
|
||||
state.settings.low_speed_timeout,
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
anthropic::ResponseEvent::ContentBlockStart {
|
||||
content_block, ..
|
||||
} => match content_block {
|
||||
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||
},
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
|
||||
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages.drain(..) {
|
||||
if message.content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
if let Some(last_message) = new_messages.last_mut() {
|
||||
if last_message.role == message.role {
|
||||
last_message.content.push_str("\n\n");
|
||||
last_message.content.push_str(&message.content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
new_messages.push(message);
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !system_message.is_empty() {
|
||||
new_messages.insert(
|
||||
0,
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: system_message,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
request.messages = new_messages;
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
state,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(
|
||||
&self.state.read(cx).settings.api_url,
|
||||
"Bearer",
|
||||
api_key.as_bytes(),
|
||||
);
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 4] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||
"",
|
||||
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
287
crates/language_model/src/provider/cloud.rs
Normal file
287
crates/language_model/src/provider/cloud.rs
Normal file
|
@ -0,0 +1,287 @@
|
|||
use super::open_ai::count_open_ai_tokens;
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use collections::HashMap;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
use crate::LanguageModelProvider;
|
||||
|
||||
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
|
||||
|
||||
pub const PROVIDER_NAME: &str = "zed.dev";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct ZedDotDevSettings {
|
||||
pub available_models: Vec<CloudModel>,
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModelProvider {
|
||||
client: Arc<Client>,
|
||||
state: gpui::Model<State>,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
client: Arc<Client>,
|
||||
status: client::Status,
|
||||
settings: ZedDotDevSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl CloudLanguageModelProvider {
|
||||
pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
|
||||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
|
||||
let state = cx.new_model(|cx| State {
|
||||
client: client.clone(),
|
||||
status,
|
||||
settings: ZedDotDevSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
let state_ref = state.downgrade();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
while let Some(status) = status_rx.next().await {
|
||||
if let Some(this) = state_ref.upgrade() {
|
||||
_ = this.update(&mut cx, |this, cx| {
|
||||
this.status = status;
|
||||
cx.notify();
|
||||
});
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
client,
|
||||
state,
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for CloudLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
|
||||
// Add base models from CloudModel::iter()
|
||||
for model in CloudModel::iter() {
|
||||
if !matches!(model, CloudModel::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), model);
|
||||
}
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
models
|
||||
.into_values()
|
||||
.map(|model| {
|
||||
Arc::new(CloudLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
client: self.client.clone(),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||
self.state.read(cx).status.is_connected()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.state.read(cx).authenticate(cx)
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|_cx| AuthenticationPrompt {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: CloudModel,
|
||||
client: Arc<Client>,
|
||||
}
|
||||
|
||||
impl LanguageModel for CloudLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("zed.dev/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match &self.model {
|
||||
CloudModel::Gpt3Point5Turbo => {
|
||||
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
|
||||
}
|
||||
CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
|
||||
CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
|
||||
CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
|
||||
CloudModel::Gpt4OmniMini => {
|
||||
count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
|
||||
}
|
||||
CloudModel::Claude3_5Sonnet
|
||||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
|
||||
_ => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
mut request: LanguageModelRequest,
|
||||
_: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match &self.model {
|
||||
CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku
|
||||
| CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
|
||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
preprocess_anthropic_request(&mut request)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: self.id.0.to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||
|
||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("sign_in", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Filled)
|
||||
.full_width()
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
this.state.update(cx, |provider, cx| {
|
||||
provider.authenticate(cx).detach_and_log_err(cx);
|
||||
cx.notify();
|
||||
});
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to enable collaboration.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
160
crates/language_model/src/provider/fake.rs
Normal file
160
crates/language_model/src/provider/fake.rs
Normal file
|
@ -0,0 +1,160 @@
|
|||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
|
||||
use crate::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||
use http::Result;
|
||||
use ui::WindowContext;
|
||||
|
||||
pub fn language_model_id() -> LanguageModelId {
|
||||
LanguageModelId::from("fake".to_string())
|
||||
}
|
||||
|
||||
pub fn language_model_name() -> LanguageModelName {
|
||||
LanguageModelName::from("Fake".to_string())
|
||||
}
|
||||
|
||||
pub fn provider_name() -> LanguageModelProviderName {
|
||||
LanguageModelProviderName::from("fake".to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct FakeLanguageModelProvider {
|
||||
current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for FakeLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for FakeLanguageModelProvider {
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
provider_name()
|
||||
}
|
||||
|
||||
fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
vec![Arc::new(FakeLanguageModel {
|
||||
current_completion_txs: self.current_completion_txs.clone(),
|
||||
})]
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, _: &AppContext) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeLanguageModelProvider {
|
||||
pub fn test_model(&self) -> FakeLanguageModel {
|
||||
FakeLanguageModel {
|
||||
current_completion_txs: self.current_completion_txs.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FakeLanguageModel {
|
||||
current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
|
||||
impl FakeLanguageModel {
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.keys()
|
||||
.map(|k| serde_json::from_str(k).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn completion_count(&self) -> usize {
|
||||
self.current_completion_txs.lock().unwrap().len()
|
||||
}
|
||||
|
||||
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
||||
let json = serde_json::to_string(request).unwrap();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&json)
|
||||
.unwrap()
|
||||
.unbounded_send(chunk)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn send_last_completion_chunk(&self, chunk: String) {
|
||||
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&serde_json::to_string(request).unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_last_completion(&self) {
|
||||
self.finish_completion(self.pending_completions().last().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for FakeLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
language_model_id()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
language_model_name()
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
provider_name()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
"fake".to_string()
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
1000000
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_: LanguageModelRequest,
|
||||
_: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
futures::future::ready(Ok(0)).boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(serde_json::to_string(&request).unwrap(), tx);
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
}
|
368
crates/language_model/src/provider/ollama.rs
Normal file
368
crates/language_model/src/provider/ollama.rs
Normal file
|
@ -0,0 +1,368 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
|
||||
use http::HttpClient;
|
||||
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, Role,
|
||||
};
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||
|
||||
const PROVIDER_NAME: &str = "ollama";
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct OllamaSettings {
|
||||
pub api_url: String,
|
||||
pub low_speed_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
pub struct OllamaLanguageModelProvider {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<ollama::Model>,
|
||||
settings: OllamaSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.settings.api_url.clone();
|
||||
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
|
||||
let mut models: Vec<ollama::Model> = models
|
||||
.into_iter()
|
||||
// Since there is no metadata from the Ollama API
|
||||
// indicating which models are embedding models,
|
||||
// simply filter out models with "-embed" in their name
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| ollama::Model::new(&model.name))
|
||||
.collect();
|
||||
|
||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.available_models = models;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||
Self {
|
||||
http_client: http_client.clone(),
|
||||
state: cx.new_model(|cx| State {
|
||||
http_client,
|
||||
available_models: Default::default(),
|
||||
settings: OllamaSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
|
||||
cx.notify();
|
||||
}),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||
|
||||
let state = self.state.clone();
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|mut cx| async move {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
|
||||
let mut models: Vec<ollama::Model> = models
|
||||
.into_iter()
|
||||
// Since there is no metadata from the Ollama API
|
||||
// indicating which models are embedding models,
|
||||
// simply filter out models with "-embed" in their name
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| ollama::Model::new(&model.name))
|
||||
.collect();
|
||||
|
||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.available_models = models;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
self.state
|
||||
.read(cx)
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|model| {
|
||||
Arc::new(OllamaLanguageModel {
|
||||
id: LanguageModelId::from(model.name.clone()),
|
||||
model: model.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
state: self.state.clone(),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||
!self.state.read(cx).available_models.is_empty()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
let state = self.state.clone();
|
||||
let fetch_models = Box::new(move |cx: &mut WindowContext| {
|
||||
state.update(cx, |this, cx| this.fetch_models(cx))
|
||||
});
|
||||
|
||||
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OllamaLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: ollama::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
impl OllamaLanguageModel {
|
||||
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
||||
ChatRequest {
|
||||
model: self.model.name.clone(),
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => ChatMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => ChatMessage::Assistant {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::System => ChatMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
|
||||
stream: true,
|
||||
options: Some(ChatOptions {
|
||||
num_ctx: Some(self.model.max_tokens),
|
||||
stop: Some(request.stop),
|
||||
temperature: Some(request.temperature),
|
||||
..Default::default()
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OllamaLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("ollama/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
// There is no endpoint for this _yet_ in Ollama
|
||||
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
|
||||
let token_count = request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| msg.content.chars().count())
|
||||
.sum::<usize>()
|
||||
/ 4;
|
||||
|
||||
async move { Ok(token_count) }.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||
(
|
||||
state.settings.api_url.clone(),
|
||||
state.settings.low_speed_timeout,
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let request =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(delta) => {
|
||||
let content = match delta.message {
|
||||
ChatMessage::User { content } => content,
|
||||
ChatMessage::Assistant { content } => content,
|
||||
ChatMessage::System { content } => content,
|
||||
};
|
||||
Some(Ok(content))
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct DownloadOllamaMessage {
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
}
|
||||
|
||||
impl DownloadOllamaMessage {
|
||||
pub fn new(
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
Self { retry_connection }
|
||||
}
|
||||
|
||||
fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("download_ollama_button")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Get Ollama"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
|
||||
}
|
||||
|
||||
fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("retry_ollama_models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Retry"))
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
let connected = (this.retry_connection)(cx);
|
||||
|
||||
cx.spawn(|_this, _cx| async move {
|
||||
connected.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Once Ollama is on your machine, make sure to download a model or two.")
|
||||
.size(LabelSize::Large),
|
||||
)
|
||||
.child(
|
||||
h_flex().w_full().p_4().justify_center().gap_2().child(
|
||||
ButtonLike::new("view-models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("View Available Models"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for DownloadOllamaMessage {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.justify_center()
|
||||
.gap_2()
|
||||
.child(
|
||||
self.render_download_button(cx)
|
||||
)
|
||||
.child(
|
||||
self.render_retry_button(cx)
|
||||
)
|
||||
)
|
||||
.child(self.render_next_steps(cx))
|
||||
.into_any()
|
||||
}
|
||||
}
|
398
crates/language_model/src/provider/open_ai.rs
Normal file
398
crates/language_model/src/provider/open_ai.rs
Normal file
|
@ -0,0 +1,398 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||
use gpui::{
|
||||
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
|
||||
WhiteSpace,
|
||||
};
|
||||
use http::HttpClient;
|
||||
use open_ai::{stream_completion, Request, RequestMessage};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, Role,
|
||||
};
|
||||
|
||||
const PROVIDER_NAME: &str = "openai";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct OpenAiSettings {
|
||||
pub api_url: String,
|
||||
pub low_speed_timeout: Option<Duration>,
|
||||
pub available_models: Vec<open_ai::Model>,
|
||||
}
|
||||
|
||||
pub struct OpenAiLanguageModelProvider {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
api_key: Option<String>,
|
||||
settings: OpenAiSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl OpenAiLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||
let state = cx.new_model(|cx| State {
|
||||
api_key: None,
|
||||
settings: OpenAiSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
|
||||
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||
Some(cx.observe(&self.state, |_, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
|
||||
// Add base models from open_ai::Model::iter()
|
||||
for model in open_ai::Model::iter() {
|
||||
if !matches!(model, open_ai::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), model);
|
||||
}
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
models
|
||||
.into_values()
|
||||
.map(|model| {
|
||||
Arc::new(OpenAiLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||
self.state.read(cx).api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenAiLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: open_ai::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
impl OpenAiLanguageModel {
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
Request {
|
||||
model: self.model.clone(),
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => RequestMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => RequestMessage::Assistant {
|
||||
content: Some(msg.content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => RequestMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAiLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("openai/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, self.model.clone(), cx)
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||
(
|
||||
state.api_key.clone(),
|
||||
state.settings.api_url.clone(),
|
||||
state.settings.low_speed_timeout,
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
model: open_ai::Model,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if let open_ai::Model::Custom { .. } = model {
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||
} else {
|
||||
tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
state,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(
|
||||
&self.state.read(cx).settings.api_url,
|
||||
"Bearer",
|
||||
api_key.as_bytes(),
|
||||
);
|
||||
let state = self.state.clone();
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 6] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
|
||||
" - You can create an API key at: platform.openai.com/api-keys",
|
||||
" - Make sure your OpenAI account has credits",
|
||||
" - Having a subscription for another service like GitHub Copilot won't work.",
|
||||
"",
|
||||
"Paste your OpenAI API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
172
crates/language_model/src/registry.rs
Normal file
172
crates/language_model/src/registry.rs
Normal file
|
@ -0,0 +1,172 @@
|
|||
use client::Client;
|
||||
use collections::HashMap;
|
||||
use gpui::{AppContext, Global, Model, ModelContext};
|
||||
use std::sync::Arc;
|
||||
use ui::Context;
|
||||
|
||||
use crate::{
|
||||
provider::{
|
||||
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||
},
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
};
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let registry = cx.new_model(|cx| {
|
||||
let mut registry = LanguageModelRegistry::default();
|
||||
register_language_model_providers(&mut registry, client, cx);
|
||||
registry
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
}
|
||||
|
||||
fn register_language_model_providers(
|
||||
registry: &mut LanguageModelRegistry,
|
||||
client: Arc<Client>,
|
||||
cx: &mut ModelContext<LanguageModelRegistry>,
|
||||
) {
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
|
||||
registry.register_provider(
|
||||
AnthropicLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
OpenAiLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
OllamaLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
|
||||
let client = client.clone();
|
||||
LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
|
||||
if enabled {
|
||||
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
|
||||
} else {
|
||||
registry.unregister_provider(
|
||||
&LanguageModelProviderName::from(
|
||||
crate::provider::cloud::PROVIDER_NAME.to_string(),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
|
||||
|
||||
impl Global for GlobalLanguageModelRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
|
||||
}
|
||||
|
||||
impl LanguageModelRegistry {
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
cx.global::<GlobalLanguageModelRegistry>().0.clone()
|
||||
}
|
||||
|
||||
pub fn read_global(cx: &AppContext) -> &Self {
|
||||
cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
|
||||
let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
|
||||
let registry = cx.new_model(|cx| {
|
||||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
registry
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
fake_provider
|
||||
}
|
||||
|
||||
pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
|
||||
&mut self,
|
||||
provider: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let name = provider.name();
|
||||
|
||||
if let Some(subscription) = provider.subscribe(cx) {
|
||||
subscription.detach();
|
||||
}
|
||||
|
||||
self.providers.insert(name, Arc::new(provider));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn unregister_provider(
|
||||
&mut self,
|
||||
name: &LanguageModelProviderName,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if self.providers.remove(name).is_some() {
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn providers(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
self.providers
|
||||
.values()
|
||||
.flat_map(|provider| provider.provided_models(cx))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn available_models_grouped_by_provider(
|
||||
&self,
|
||||
cx: &AppContext,
|
||||
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
|
||||
self.providers
|
||||
.iter()
|
||||
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn provider(
|
||||
&self,
|
||||
name: &LanguageModelProviderName,
|
||||
) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::fake::FakeLanguageModelProvider;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_register_providers(cx: &mut AppContext) {
|
||||
let registry = cx.new_model(|_| LanguageModelRegistry::default());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.register_provider(FakeLanguageModelProvider::default(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||
assert_eq!(providers.len(), 1);
|
||||
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||
assert!(providers.is_empty());
|
||||
}
|
||||
}
|
|
@ -1,7 +1,4 @@
|
|||
use crate::{
|
||||
model::{CloudModel, LanguageModel},
|
||||
role::Role,
|
||||
};
|
||||
use crate::{role::Role, LanguageModelId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
@ -23,16 +20,15 @@ impl LanguageModelRequestMessage {
|
|||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct LanguageModelRequest {
|
||||
pub model: LanguageModel,
|
||||
pub messages: Vec<LanguageModelRequestMessage>,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl LanguageModelRequest {
|
||||
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
|
||||
pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel {
|
||||
proto::CompleteWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
model: model_id.0.to_string(),
|
||||
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
|
||||
stop: self.stop.clone(),
|
||||
temperature: self.temperature,
|
||||
|
@ -40,70 +36,6 @@ impl LanguageModelRequest {
|
|||
tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
|
||||
pub fn preprocess(&mut self) {
|
||||
match &self.model {
|
||||
LanguageModel::OpenAi(_) => {}
|
||||
LanguageModel::Anthropic(_) => self.preprocess_anthropic(),
|
||||
LanguageModel::Ollama(_) => {}
|
||||
LanguageModel::Cloud(model) => match model {
|
||||
CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku
|
||||
| CloudModel::Claude3_5Sonnet => {
|
||||
self.preprocess_anthropic();
|
||||
}
|
||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
self.preprocess_anthropic();
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_anthropic(&mut self) {
|
||||
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in self.messages.drain(..) {
|
||||
if message.content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
if let Some(last_message) = new_messages.last_mut() {
|
||||
if last_message.role == message.role {
|
||||
last_message.content.push_str("\n\n");
|
||||
last_message.content.push_str(&message.content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
new_messages.push(message);
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !system_message.is_empty() {
|
||||
new_messages.insert(
|
||||
0,
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: system_message,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
self.messages = new_messages;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
|
143
crates/language_model/src/settings.rs
Normal file
143
crates/language_model/src/settings.rs
Normal file
|
@ -0,0 +1,143 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::AppContext;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
||||
use crate::{
|
||||
provider::{
|
||||
anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings,
|
||||
open_ai::OpenAiSettings,
|
||||
},
|
||||
CloudModel,
|
||||
};
|
||||
|
||||
/// Initializes the language model settings.
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
AllLanguageModelSettings::register(cx);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AllLanguageModelSettings {
|
||||
pub open_ai: OpenAiSettings,
|
||||
pub anthropic: AnthropicSettings,
|
||||
pub ollama: OllamaSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AllLanguageModelSettingsContent {
|
||||
pub anthropic: Option<AnthropicSettingsContent>,
|
||||
pub ollama: Option<OllamaSettingsContent>,
|
||||
pub open_ai: Option<OpenAiSettingsContent>,
|
||||
#[serde(rename = "zed.dev")]
|
||||
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AnthropicSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<anthropic::Model>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OllamaSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OpenAiSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||
pub available_models: Option<Vec<open_ai::Model>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct ZedDotDevSettingsContent {
|
||||
available_models: Option<Vec<CloudModel>>,
|
||||
}
|
||||
|
||||
impl settings::Settings for AllLanguageModelSettings {
|
||||
const KEY: Option<&'static str> = Some("language_models");
|
||||
|
||||
type FileContent = AllLanguageModelSettingsContent;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
fn merge<T>(target: &mut T, value: Option<T>) {
|
||||
if let Some(value) = value {
|
||||
*target = value;
|
||||
}
|
||||
}
|
||||
|
||||
let mut settings = AllLanguageModelSettings::default();
|
||||
|
||||
for value in sources.defaults_and_customizations() {
|
||||
merge(
|
||||
&mut settings.anthropic.api_url,
|
||||
value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
if let Some(low_speed_timeout_in_seconds) = value
|
||||
.anthropic
|
||||
.as_ref()
|
||||
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
{
|
||||
settings.anthropic.low_speed_timeout =
|
||||
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||
}
|
||||
merge(
|
||||
&mut settings.anthropic.available_models,
|
||||
value
|
||||
.anthropic
|
||||
.as_ref()
|
||||
.and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
merge(
|
||||
&mut settings.ollama.api_url,
|
||||
value.ollama.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
if let Some(low_speed_timeout_in_seconds) = value
|
||||
.ollama
|
||||
.as_ref()
|
||||
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
{
|
||||
settings.ollama.low_speed_timeout =
|
||||
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||
}
|
||||
|
||||
merge(
|
||||
&mut settings.open_ai.api_url,
|
||||
value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
if let Some(low_speed_timeout_in_seconds) = value
|
||||
.open_ai
|
||||
.as_ref()
|
||||
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||
{
|
||||
settings.open_ai.low_speed_timeout =
|
||||
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||
}
|
||||
merge(
|
||||
&mut settings.open_ai.available_models,
|
||||
value
|
||||
.open_ai
|
||||
.as_ref()
|
||||
.and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
merge(
|
||||
&mut settings.zed_dot_dev.available_models,
|
||||
value
|
||||
.zed_dot_dev
|
||||
.as_ref()
|
||||
.and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue