Simplify LLM protocol (#15366)
In this pull request, we change the zed.dev protocol so that we pass the raw JSON for the specified provider directly to our server. This avoids the need to define a protobuf message that's a superset of all these formats.

@bennetbo: We also changed the settings for `available_models` under `zed.dev` to be a flat format, because the nesting seemed too confusing. Can you help us upgrade the local provider configuration to be consistent with this? We do whatever we need to do when parsing the settings to make this simple for users, even if it's a bit more complex on our end. We want to use versioning to avoid breaking existing users, but we need to keep making progress.

```json
"zed.dev": {
  "available_models": [
    {
      "provider": "anthropic",
      "name": "some-newly-released-model-we-havent-added",
      "max_tokens": 200000
    }
  ]
}
```

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Parent: e0fe7f632c
Commit: d6bdaa8a91

31 changed files with 896 additions and 2154 deletions
```diff
@@ -28,6 +28,7 @@ collections.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
 futures.workspace = true
+google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
 menu.workspace = true
```
```diff
@@ -1,108 +1,42 @@
-pub use anthropic::Model as AnthropicModel;
-use anyhow::{anyhow, Result};
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use strum::EnumIter;
 
-#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "provider", rename_all = "lowercase")]
 pub enum CloudModel {
-    #[serde(rename = "gpt-3.5-turbo")]
-    Gpt3Point5Turbo,
-    #[serde(rename = "gpt-4")]
-    Gpt4,
-    #[serde(rename = "gpt-4-turbo-preview")]
-    Gpt4Turbo,
-    #[serde(rename = "gpt-4o")]
-    #[default]
-    Gpt4Omni,
-    #[serde(rename = "gpt-4o-mini")]
-    Gpt4OmniMini,
-    #[serde(rename = "claude-3-5-sonnet")]
-    Claude3_5Sonnet,
-    #[serde(rename = "claude-3-opus")]
-    Claude3Opus,
-    #[serde(rename = "claude-3-sonnet")]
-    Claude3Sonnet,
-    #[serde(rename = "claude-3-haiku")]
-    Claude3Haiku,
-    #[serde(rename = "gemini-1.5-pro")]
-    Gemini15Pro,
-    #[serde(rename = "gemini-1.5-flash")]
-    Gemini15Flash,
-    #[serde(rename = "custom")]
-    Custom {
-        name: String,
-        max_tokens: Option<usize>,
-    },
+    Anthropic(anthropic::Model),
+    OpenAi(open_ai::Model),
+    Google(google_ai::Model),
+}
+
+impl Default for CloudModel {
+    fn default() -> Self {
+        Self::Anthropic(anthropic::Model::default())
+    }
 }
 
 impl CloudModel {
-    pub fn from_id(value: &str) -> Result<Self> {
-        match value {
-            "gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo),
-            "gpt-4" => Ok(Self::Gpt4),
-            "gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo),
-            "gpt-4o" => Ok(Self::Gpt4Omni),
-            "gpt-4o-mini" => Ok(Self::Gpt4OmniMini),
-            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
-            "claude-3-opus" => Ok(Self::Claude3Opus),
-            "claude-3-sonnet" => Ok(Self::Claude3Sonnet),
-            "claude-3-haiku" => Ok(Self::Claude3Haiku),
-            "gemini-1.5-pro" => Ok(Self::Gemini15Pro),
-            "gemini-1.5-flash" => Ok(Self::Gemini15Flash),
-            _ => Err(anyhow!("invalid model id")),
-        }
-    }
-
     pub fn id(&self) -> &str {
         match self {
-            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
-            Self::Gpt4 => "gpt-4",
-            Self::Gpt4Turbo => "gpt-4-turbo-preview",
-            Self::Gpt4Omni => "gpt-4o",
-            Self::Gpt4OmniMini => "gpt-4o-mini",
-            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
-            Self::Claude3Opus => "claude-3-opus",
-            Self::Claude3Sonnet => "claude-3-sonnet",
-            Self::Claude3Haiku => "claude-3-haiku",
-            Self::Gemini15Pro => "gemini-1.5-pro",
-            Self::Gemini15Flash => "gemini-1.5-flash",
-            Self::Custom { name, .. } => name,
+            CloudModel::Anthropic(model) => model.id(),
+            CloudModel::OpenAi(model) => model.id(),
+            CloudModel::Google(model) => model.id(),
         }
     }
 
     pub fn display_name(&self) -> &str {
         match self {
-            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
-            Self::Gpt4 => "GPT 4",
-            Self::Gpt4Turbo => "GPT 4 Turbo",
-            Self::Gpt4Omni => "GPT 4 Omni",
-            Self::Gpt4OmniMini => "GPT 4 Omni Mini",
-            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
-            Self::Claude3Opus => "Claude 3 Opus",
-            Self::Claude3Sonnet => "Claude 3 Sonnet",
-            Self::Claude3Haiku => "Claude 3 Haiku",
-            Self::Gemini15Pro => "Gemini 1.5 Pro",
-            Self::Gemini15Flash => "Gemini 1.5 Flash",
-            Self::Custom { name, .. } => name,
+            CloudModel::Anthropic(model) => model.display_name(),
+            CloudModel::OpenAi(model) => model.display_name(),
+            CloudModel::Google(model) => model.display_name(),
         }
     }
 
     pub fn max_token_count(&self) -> usize {
         match self {
-            Self::Gpt3Point5Turbo => 2048,
-            Self::Gpt4 => 4096,
-            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
-            Self::Gpt4OmniMini => 128000,
-            Self::Claude3_5Sonnet
-            | Self::Claude3Opus
-            | Self::Claude3Sonnet
-            | Self::Claude3Haiku => 200000,
-            Self::Gemini15Pro => 128000,
-            Self::Gemini15Flash => 32000,
-            Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
+            CloudModel::Anthropic(model) => model.max_token_count(),
+            CloudModel::OpenAi(model) => model.max_token_count(),
+            CloudModel::Google(model) => model.max_token_count(),
         }
     }
 }
```
```diff
@@ -2,5 +2,6 @@ pub mod anthropic;
 pub mod cloud;
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake;
+pub mod google;
 pub mod ollama;
 pub mod open_ai;
```
```diff
@@ -1,4 +1,4 @@
-use anthropic::{stream_completion, Request, RequestMessage};
+use anthropic::stream_completion;
 use anyhow::{anyhow, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
```
```diff
@@ -18,7 +18,7 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
+    LanguageModelProviderState, LanguageModelRequest, Role,
 };
 
 const PROVIDER_ID: &str = "anthropic";
```
```diff
@@ -160,40 +160,6 @@ pub struct AnthropicModel {
     http_client: Arc<dyn HttpClient>,
 }
 
-impl AnthropicModel {
-    fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
-        preprocess_anthropic_request(&mut request);
-
-        let mut system_message = String::new();
-        if request
-            .messages
-            .first()
-            .map_or(false, |message| message.role == Role::System)
-        {
-            system_message = request.messages.remove(0).content;
-        }
-
-        Request {
-            model: self.model.id().to_string(),
-            messages: request
-                .messages
-                .iter()
-                .map(|msg| RequestMessage {
-                    role: match msg.role {
-                        Role::User => anthropic::Role::User,
-                        Role::Assistant => anthropic::Role::Assistant,
-                        Role::System => unreachable!("filtered out by preprocess_request"),
-                    },
-                    content: msg.content.clone(),
-                })
-                .collect(),
-            stream: true,
-            system: system_message,
-            max_tokens: 4092,
-        }
-    }
-}
-
 pub fn count_anthropic_tokens(
     request: LanguageModelRequest,
     cx: &AppContext,
```
```diff
@@ -260,7 +226,7 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = self.to_anthropic_request(request);
+        let request = request.into_anthropic(self.model.id().into());
 
         let http_client = self.http_client.clone();
```
```diff
@@ -285,75 +251,12 @@ impl LanguageModel for AnthropicModel {
                 low_speed_timeout,
             );
             let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(response) => match response {
-                            anthropic::ResponseEvent::ContentBlockStart {
-                                content_block, ..
-                            } => match content_block {
-                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
-                            },
-                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
-                                match delta {
-                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
-                                }
-                            }
-                            _ => None,
-                        },
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
+            Ok(anthropic::extract_text_from_events(response).boxed())
         }
         .boxed()
     }
 }
 
-pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
-    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
-    let mut system_message = String::new();
-
-    for message in request.messages.drain(..) {
-        if message.content.is_empty() {
-            continue;
-        }
-
-        match message.role {
-            Role::User | Role::Assistant => {
-                if let Some(last_message) = new_messages.last_mut() {
-                    if last_message.role == message.role {
-                        last_message.content.push_str("\n\n");
-                        last_message.content.push_str(&message.content);
-                        continue;
-                    }
-                }
-
-                new_messages.push(message);
-            }
-            Role::System => {
-                if !system_message.is_empty() {
-                    system_message.push_str("\n\n");
-                }
-                system_message.push_str(&message.content);
-            }
-        }
-    }
-
-    if !system_message.is_empty() {
-        new_messages.insert(
-            0,
-            LanguageModelRequestMessage {
-                role: Role::System,
-                content: system_message,
-            },
-        );
-    }
-
-    request.messages = new_messages;
-}
-
 struct AuthenticationPrompt {
     api_key: View<Editor>,
     state: gpui::Model<State>,
```
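The inline event matching above is replaced by `anthropic::extract_text_from_events`. As a rough, self-contained sketch of what such an adapter does (the simplified event type and signature here are illustrative, not the crate's real ones):

```rust
use futures::{stream, Stream, StreamExt};

// Illustrative stand-in for the event type in Zed's `anthropic` crate.
#[derive(Debug)]
enum ResponseEvent {
    ContentBlockDelta { text: String },
    MessageStop,
}

// Collapse a stream of provider events into a stream of text chunks:
// text deltas pass through, errors propagate, everything else is dropped.
fn extract_text_from_events(
    events: impl Stream<Item = anyhow::Result<ResponseEvent>>,
) -> impl Stream<Item = anyhow::Result<String>> {
    events.filter_map(|event| async move {
        match event {
            Ok(ResponseEvent::ContentBlockDelta { text }) => Some(Ok(text)),
            Ok(_) => None,
            Err(error) => Some(Err(error)),
        }
    })
}

fn main() {
    futures::executor::block_on(async {
        let events = stream::iter(vec![
            Ok(ResponseEvent::ContentBlockDelta { text: "Hello".into() }),
            Ok(ResponseEvent::MessageStop),
        ]);
        let chunks: Vec<_> = extract_text_from_events(events).collect().await;
        println!("{chunks:?}"); // [Ok("Hello")]
    });
}
```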
```diff
@@ -7,8 +7,10 @@ use crate::{
 use anyhow::Result;
 use client::Client;
 use collections::BTreeMap;
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::sync::Arc;
 use strum::IntoEnumIterator;
```
```diff
@@ -16,14 +18,29 @@ use ui::prelude::*;
 
 use crate::LanguageModelProvider;
 
-use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
+use super::anthropic::count_anthropic_tokens;
 
 pub const PROVIDER_ID: &str = "zed.dev";
 pub const PROVIDER_NAME: &str = "zed.dev";
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
-    pub available_models: Vec<CloudModel>,
+    pub available_models: Vec<AvailableModel>,
 }
 
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(rename_all = "lowercase")]
+pub enum AvailableProvider {
+    Anthropic,
+    OpenAi,
+    Google,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    provider: AvailableProvider,
+    name: String,
+    max_tokens: usize,
+}
+
 pub struct CloudLanguageModelProvider {
```
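The flat `available_models` entry from the commit message deserializes into `AvailableModel` via the derives above. A minimal, self-contained sketch of that mapping (the types are copied here so the snippet runs on its own):

```rust
use serde::{Deserialize, Serialize};

// Self-contained copies of the settings types above.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Debug, Serialize, Deserialize)]
struct AvailableModel {
    provider: AvailableProvider,
    name: String,
    max_tokens: usize,
}

fn main() -> serde_json::Result<()> {
    // The flat entry format from the commit message's zed.dev example.
    let entry = r#"{
        "provider": "anthropic",
        "name": "some-newly-released-model-we-havent-added",
        "max_tokens": 200000
    }"#;
    let model: AvailableModel = serde_json::from_str(entry)?;
    assert_eq!(model.provider, AvailableProvider::Anthropic);
    assert_eq!(model.max_tokens, 200_000);
    Ok(())
}
```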
```diff
@@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
-        // Add base models from CloudModel::iter()
-        for model in CloudModel::iter() {
-            if !matches!(model, CloudModel::Custom { .. }) {
-                models.insert(model.id().to_string(), model);
+        for model in anthropic::Model::iter() {
+            if !matches!(model, anthropic::Model::Custom { .. }) {
+                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
             }
         }
+        for model in open_ai::Model::iter() {
+            if !matches!(model, open_ai::Model::Custom { .. }) {
+                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
+            }
+        }
+        for model in google_ai::Model::iter() {
+            if !matches!(model, google_ai::Model::Custom { .. }) {
+                models.insert(model.id().to_string(), CloudModel::Google(model));
+            }
+        }
 
```
```diff
@@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             .zed_dot_dev
             .available_models
         {
+            let model = match model.provider {
+                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+            };
             models.insert(model.id().to_string(), model.clone());
         }
 
```
```diff
@@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        match &self.model {
-            CloudModel::Gpt3Point5Turbo => {
-                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
-            }
-            CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
-            CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
-            CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
-            CloudModel::Gpt4OmniMini => {
-                count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
-            }
-            CloudModel::Claude3_5Sonnet
-            | CloudModel::Claude3Opus
-            | CloudModel::Claude3Sonnet
-            | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
-            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
-                count_anthropic_tokens(request, cx)
-            }
-            _ => {
-                let request = self.client.request(proto::CountTokensWithLanguageModel {
-                    model: self.model.id().to_string(),
-                    messages: request
-                        .messages
-                        .iter()
-                        .map(|message| message.to_proto())
-                        .collect(),
-                });
+        match self.model.clone() {
+            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
+            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
+            CloudModel::Google(model) => {
+                let client = self.client.clone();
+                let request = request.into_google(model.id().into());
+                let request = google_ai::CountTokensRequest {
+                    contents: request.contents,
+                };
                 async move {
-                    let response = request.await?;
-                    Ok(response.token_count as usize)
+                    let request = serde_json::to_string(&request)?;
+                    let response = client.request(proto::QueryLanguageModel {
+                        provider: proto::LanguageModelProvider::Google as i32,
+                        kind: proto::LanguageModelRequestKind::CountTokens as i32,
+                        request,
+                    });
+                    let response = response.await?;
+                    let response =
+                        serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
+                    Ok(response.total_tokens)
                 }
                 .boxed()
             }
```
```diff
@@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel {
 
     fn stream_completion(
         &self,
-        mut request: LanguageModelRequest,
+        request: LanguageModelRequest,
         _: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         match &self.model {
-            CloudModel::Claude3Opus
-            | CloudModel::Claude3Sonnet
-            | CloudModel::Claude3Haiku
-            | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
-            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
-                preprocess_anthropic_request(&mut request)
+            CloudModel::Anthropic(model) => {
+                let client = self.client.clone();
+                let request = request.into_anthropic(model.id().into());
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client.request_stream(proto::QueryLanguageModel {
+                        provider: proto::LanguageModelProvider::Anthropic as i32,
+                        kind: proto::LanguageModelRequestKind::Complete as i32,
+                        request,
+                    });
+                    let chunks = response.await?;
+                    Ok(anthropic::extract_text_from_events(
+                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                    )
+                    .boxed())
+                }
+                .boxed()
             }
+            CloudModel::OpenAi(model) => {
+                let client = self.client.clone();
+                let request = request.into_open_ai(model.id().into());
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client.request_stream(proto::QueryLanguageModel {
+                        provider: proto::LanguageModelProvider::OpenAi as i32,
+                        kind: proto::LanguageModelRequestKind::Complete as i32,
+                        request,
+                    });
+                    let chunks = response.await?;
+                    Ok(open_ai::extract_text_from_events(
+                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                    )
+                    .boxed())
+                }
+                .boxed()
+            }
+            CloudModel::Google(model) => {
+                let client = self.client.clone();
+                let request = request.into_google(model.id().into());
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client.request_stream(proto::QueryLanguageModel {
+                        provider: proto::LanguageModelProvider::Google as i32,
+                        kind: proto::LanguageModelRequestKind::Complete as i32,
+                        request,
+                    });
+                    let chunks = response.await?;
+                    Ok(google_ai::extract_text_from_events(
+                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                    )
+                    .boxed())
+                }
+                .boxed()
+            }
-            _ => {}
         }
-
-        let request = proto::CompleteWithLanguageModel {
-            model: self.id.0.to_string(),
-            messages: request
-                .messages
-                .iter()
-                .map(|message| message.to_proto())
-                .collect(),
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        };
-
-        self.client
-            .request_stream(request)
-            .map_ok(|stream| {
-                stream
-                    .filter_map(|response| async move {
-                        match response {
-                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
-                            Err(error) => Some(Err(error)),
-                        }
-                    })
-                    .boxed()
-            })
-            .boxed()
     }
 }
```
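Both `count_tokens` and `stream_completion` now follow the same tunnel pattern: serialize the provider's native request to JSON, send it through `proto::QueryLanguageModel`, and decode the provider's native response from the returned string. A hypothetical helper (not part of this PR) capturing the unary case, with stub types standing in for Zed's `client` and `proto` crates:

```rust
use anyhow::Result;
use serde::{de::DeserializeOwned, Serialize};

// Stand-ins for Zed's proto types so the sketch is self-contained.
mod proto {
    pub struct QueryLanguageModel {
        pub provider: i32,
        pub kind: i32,
        pub request: String,
    }
    pub struct QueryLanguageModelResponse {
        pub response: String,
    }
}

struct Client;

impl Client {
    // In Zed this is an RPC to the zed.dev server; here it echoes a canned reply.
    async fn request(
        &self,
        _message: proto::QueryLanguageModel,
    ) -> Result<proto::QueryLanguageModelResponse> {
        Ok(proto::QueryLanguageModelResponse {
            response: r#"{"totalTokens": 42}"#.to_string(),
        })
    }
}

// Serialize a provider-native request, tunnel it, decode a provider-native response.
async fn query_language_model<Req: Serialize, Resp: DeserializeOwned>(
    client: &Client,
    provider: i32,
    kind: i32,
    request: &Req,
) -> Result<Resp> {
    let request = serde_json::to_string(request)?;
    let response = client
        .request(proto::QueryLanguageModel { provider, kind, request })
        .await?;
    Ok(serde_json::from_str(&response.response)?)
}

#[derive(serde::Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CountTokensResponse {
    total_tokens: usize,
}

fn main() -> Result<()> {
    futures::executor::block_on(async {
        // The provider/kind values are arbitrary stand-ins for the proto enums.
        let response: CountTokensResponse =
            query_language_model(&Client, 1, 0, &serde_json::json!({ "contents": [] })).await?;
        println!("{response:?}"); // CountTokensResponse { total_tokens: 42 }
        Ok(())
    })
}
```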
crates/language_model/src/provider/google.rs (new file, 351 lines)
```rust
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use google_ai::stream_generate_content;
use gpui::{
    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
    WhiteSpace,
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest,
};

const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<google_ai::Model>,
}

pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

struct State {
    api_key: Option<String>,
    _subscription: Subscription,
}

impl GoogleLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let state = cx.new_model(|cx| State {
            api_key: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }
}

impl LanguageModelProviderState for GoogleLanguageModelProvider {
    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
        Some(cx.observe(&self.state, |_, _, cx| {
            cx.notify();
        }))
    }
}

impl LanguageModelProvider for GoogleLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from google_ai::Model::iter()
        for model in google_ai::Model::iter() {
            if !matches!(model, google_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .google
            .available_models
        {
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(GoogleLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).api_key.is_some()
    }

    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated(cx) {
            Task::ready(Ok(()))
        } else {
            let api_url = AllLanguageModelSettings::get_global(cx)
                .google
                .api_url
                .clone();
            let state = self.state.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
                    api_key
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };

                state.update(&mut cx, |this, cx| {
                    this.api_key = Some(api_key);
                    cx.notify();
                })
            })
        }
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
            .into()
    }

    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let state = self.state.clone();
        let delete_credentials =
            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            state.update(&mut cx, |this, cx| {
                this.api_key = None;
                cx.notify();
            })
        })
    }
}

pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    state: gpui::Model<State>,
    http_client: Arc<dyn HttpClient>,
}

impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        let request = request.into_google(self.model.id().to_string());
        let http_client = self.http_client.clone();
        let api_key = self.state.read(cx).api_key.clone();
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();

        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    contents: request.contents,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
        let request = request.into_google(self.model.id().to_string());

        let http_client = self.http_client.clone();
        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).google;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let response =
                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
            let events = response.await?;
            Ok(google_ai::extract_text_from_events(events).boxed())
        }
        .boxed()
    }
}

struct AuthenticationPrompt {
    api_key: View<Editor>,
    state: gpui::Model<State>,
}

impl AuthenticationPrompt {
    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text("AIzaSy...", cx);
                editor
            }),
            state,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let settings = &AllLanguageModelSettings::get_global(cx).google;
        let write_credentials =
            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
        let state = self.state.clone();
        cx.spawn(|_, mut cx| async move {
            write_credentials.await?;
            state.update(&mut cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
        .detach_and_log_err(cx);
    }

    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}

impl Render for AuthenticationPrompt {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 4] = [
            "To use the Google AI assistant, you need to add your Google AI API key.",
            "You can create an API key at: https://makersuite.google.com/app/apikey",
            "",
            "Paste your Google AI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}
```
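One detail worth calling out in the new provider: `authenticate` prefers the `GOOGLE_AI_API_KEY` environment variable and only falls back to credentials stored for the API URL. A self-contained sketch of that resolution order:

```rust
use anyhow::{anyhow, Result};

// Mirrors the fallback in `authenticate` above: env var first, then the
// credential store (represented here by a plain Option<Vec<u8>>).
fn resolve_api_key(stored: Option<Vec<u8>>) -> Result<String> {
    if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
        return Ok(api_key);
    }
    let bytes = stored.ok_or_else(|| anyhow!("credentials not found"))?;
    Ok(String::from_utf8(bytes)?)
}

fn main() {
    // With no env var and no stored credentials this takes the error path.
    println!("{:?}", resolve_api_key(None));
}
```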
```diff
@@ -7,7 +7,7 @@ use gpui::{
     WhiteSpace,
 };
 use http_client::HttpClient;
-use open_ai::{stream_completion, Request, RequestMessage};
+use open_ai::stream_completion;
 use settings::{Settings, SettingsStore};
 use std::{sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
```
```diff
@@ -159,35 +159,6 @@ pub struct OpenAiLanguageModel {
     http_client: Arc<dyn HttpClient>,
 }
 
-impl OpenAiLanguageModel {
-    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
-        Request {
-            model: self.model.clone(),
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => RequestMessage::User {
-                        content: msg.content,
-                    },
-                    Role::Assistant => RequestMessage::Assistant {
-                        content: Some(msg.content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => RequestMessage::System {
-                        content: msg.content,
-                    },
-                })
-                .collect(),
-            stream: true,
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
-    }
-}
-
 impl LanguageModel for OpenAiLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
```
```diff
@@ -226,7 +197,7 @@ impl LanguageModel for OpenAiLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        let request = self.to_open_ai_request(request);
+        let request = request.into_open_ai(self.model.id().into());
 
         let http_client = self.http_client.clone();
         let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
```
```diff
@@ -250,15 +221,7 @@ impl LanguageModel for OpenAiLanguageModel {
                 low_speed_timeout,
             );
             let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
+            Ok(open_ai::extract_text_from_events(response).boxed())
         }
         .boxed()
     }
```
```diff
@@ -1,17 +1,17 @@
+use crate::{
+    provider::{
+        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
+        google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider,
+        open_ai::OpenAiLanguageModelProvider,
+    },
+    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
+};
 use client::Client;
 use collections::BTreeMap;
 use gpui::{AppContext, Global, Model, ModelContext};
 use std::sync::Arc;
 use ui::Context;
 
-use crate::{
-    provider::{
-        anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
-        ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
-    },
-    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
-};
-
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let registry = cx.new_model(|cx| {
         let mut registry = LanguageModelRegistry::default();
```
```diff
@@ -40,6 +40,10 @@ fn register_language_model_providers(
         OllamaLanguageModelProvider::new(client.http_client(), cx),
         cx,
     );
+    registry.register_provider(
+        GoogleLanguageModelProvider::new(client.http_client(), cx),
+        cx,
+    );
 
     cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
         let client = client.clone();
```
```diff
@@ -1,4 +1,4 @@
-use crate::{role::Role, LanguageModelId};
+use crate::role::Role;
 use serde::{Deserialize, Serialize};
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
```
```diff
@@ -7,17 +7,6 @@ pub struct LanguageModelRequestMessage {
     pub content: String,
 }
 
-impl LanguageModelRequestMessage {
-    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
-        proto::LanguageModelRequestMessage {
-            role: self.role.to_proto() as i32,
-            content: self.content.clone(),
-            tool_calls: Vec::new(),
-            tool_call_id: None,
-        }
-    }
-}
-
 #[derive(Debug, Default, Serialize, Deserialize)]
 pub struct LanguageModelRequest {
     pub messages: Vec<LanguageModelRequestMessage>,
```
```diff
@@ -26,14 +15,110 @@ pub struct LanguageModelRequest {
 }
 
 impl LanguageModelRequest {
-    pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel {
-        proto::CompleteWithLanguageModel {
-            model: model_id.0.to_string(),
-            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
-            stop: self.stop.clone(),
+    pub fn into_open_ai(self, model: String) -> open_ai::Request {
+        open_ai::Request {
+            model,
+            messages: self
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => open_ai::RequestMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => open_ai::RequestMessage::Assistant {
+                        content: Some(msg.content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => open_ai::RequestMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            stream: true,
+            stop: self.stop,
             temperature: self.temperature,
-            tool_choice: None,
             tools: Vec::new(),
+            tool_choice: None,
         }
     }
+
+    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
+        google_ai::GenerateContentRequest {
+            model,
+            contents: self
+                .messages
+                .into_iter()
+                .map(|msg| google_ai::Content {
+                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+                        text: msg.content,
+                    })],
+                    role: match msg.role {
+                        Role::User => google_ai::Role::User,
+                        Role::Assistant => google_ai::Role::Model,
+                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+                    },
+                })
+                .collect(),
+            generation_config: Some(google_ai::GenerationConfig {
+                candidate_count: Some(1),
+                stop_sequences: Some(self.stop),
+                max_output_tokens: None,
+                temperature: Some(self.temperature as f64),
+                top_p: None,
+                top_k: None,
+            }),
+            safety_settings: None,
+        }
+    }
+
+    pub fn into_anthropic(self, model: String) -> anthropic::Request {
+        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut system_message = String::new();
+
+        for message in self.messages {
+            if message.content.is_empty() {
+                continue;
+            }
+
+            match message.role {
+                Role::User | Role::Assistant => {
+                    if let Some(last_message) = new_messages.last_mut() {
+                        if last_message.role == message.role {
+                            last_message.content.push_str("\n\n");
+                            last_message.content.push_str(&message.content);
+                            continue;
+                        }
+                    }
+
+                    new_messages.push(message);
+                }
+                Role::System => {
+                    if !system_message.is_empty() {
+                        system_message.push_str("\n\n");
+                    }
+                    system_message.push_str(&message.content);
+                }
+            }
+        }
+
+        anthropic::Request {
+            model,
+            messages: new_messages
+                .into_iter()
+                .filter_map(|message| {
+                    Some(anthropic::RequestMessage {
+                        role: match message.role {
+                            Role::User => anthropic::Role::User,
+                            Role::Assistant => anthropic::Role::Assistant,
+                            Role::System => return None,
+                        },
+                        content: message.content,
+                    })
+                })
+                .collect(),
+            stream: true,
+            max_tokens: 4092,
+            system: system_message,
+        }
+    }
 }
```
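The merging rule in `into_anthropic` is easy to miss: empty messages are dropped, consecutive same-role messages are joined with a blank line, and all system messages are concatenated into Anthropic's separate `system` field. A self-contained sketch of just that rule:

```rust
// Simplified copies of the request types, enough to demonstrate coalescing.
#[derive(Debug, PartialEq)]
enum Role {
    User,
    Assistant,
    System,
}

#[derive(Debug)]
struct Message {
    role: Role,
    content: String,
}

// Returns the coalesced user/assistant messages plus the combined system string.
fn coalesce(messages: Vec<Message>) -> (Vec<Message>, String) {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut system_message = String::new();
    for message in messages {
        if message.content.is_empty() {
            continue;
        }
        match message.role {
            Role::User | Role::Assistant => {
                if let Some(last) = new_messages.last_mut() {
                    if last.role == message.role {
                        last.content.push_str("\n\n");
                        last.content.push_str(&message.content);
                        continue;
                    }
                }
                new_messages.push(message);
            }
            Role::System => {
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.content);
            }
        }
    }
    (new_messages, system_message)
}

fn main() {
    let (messages, system) = coalesce(vec![
        Message { role: Role::System, content: "Be terse.".into() },
        Message { role: Role::User, content: "Hi".into() },
        Message { role: Role::User, content: "there".into() },
    ]);
    assert_eq!(messages.len(), 1); // the two user messages were merged
    assert_eq!(messages[0].content, "Hi\n\nthere");
    assert_eq!(system, "Be terse.");
}
```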
```diff
@@ -15,7 +15,6 @@ impl Role {
             Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
             Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
             Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
-            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
             None => Role::User,
         }
     }
```
```diff
@@ -6,12 +6,12 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsSources};
 
-use crate::{
-    provider::{
-        anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings,
-        open_ai::OpenAiSettings,
-    },
-    CloudModel,
+use crate::provider::{
+    anthropic::AnthropicSettings,
+    cloud::{self, ZedDotDevSettings},
+    google::GoogleSettings,
+    ollama::OllamaSettings,
+    open_ai::OpenAiSettings,
 };
 
 /// Initializes the language model settings.
```
```diff
@@ -25,6 +25,7 @@ pub struct AllLanguageModelSettings {
     pub ollama: OllamaSettings,
     pub openai: OpenAiSettings,
     pub zed_dot_dev: ZedDotDevSettings,
+    pub google: GoogleSettings,
 }
 
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
```
```diff
@@ -34,6 +35,7 @@ pub struct AllLanguageModelSettingsContent {
     pub openai: Option<OpenAiSettingsContent>,
     #[serde(rename = "zed.dev")]
     pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
+    pub google: Option<GoogleSettingsContent>,
 }
 
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
```
```diff
@@ -56,9 +58,16 @@ pub struct OpenAiSettingsContent {
     pub available_models: Option<Vec<open_ai::Model>>,
 }
 
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct GoogleSettingsContent {
+    pub api_url: Option<String>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<google_ai::Model>>,
+}
+
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct ZedDotDevSettingsContent {
-    available_models: Option<Vec<CloudModel>>,
+    available_models: Option<Vec<cloud::AvailableModel>>,
 }
 
 impl settings::Settings for AllLanguageModelSettings {
```
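Given `GoogleSettingsContent`, a user-facing settings block would look something like the following (field names come from the struct above; the URL and timeout values are illustrative, not defaults confirmed by this PR):

```json
"google": {
  "api_url": "https://generativelanguage.googleapis.com",
  "low_speed_timeout_in_seconds": 60,
  "available_models": []
}
```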
```diff
@@ -136,6 +145,26 @@ impl settings::Settings for AllLanguageModelSettings {
                     .as_ref()
                     .and_then(|s| s.available_models.clone()),
             );
+
+            merge(
+                &mut settings.google.api_url,
+                value.google.as_ref().and_then(|s| s.api_url.clone()),
+            );
+            if let Some(low_speed_timeout_in_seconds) = value
+                .google
+                .as_ref()
+                .and_then(|s| s.low_speed_timeout_in_seconds)
+            {
+                settings.google.low_speed_timeout =
+                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
+            }
+            merge(
+                &mut settings.google.available_models,
+                value
+                    .google
+                    .as_ref()
+                    .and_then(|s| s.available_models.clone()),
+            );
         }
 
         Ok(settings)
```