diff --git a/Cargo.lock b/Cargo.lock
index 1c251aa0c5..d02e15a67e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6965,6 +6965,7 @@ dependencies = [
  "image",
  "lmstudio",
  "log",
+ "mistral",
  "ollama",
  "open_ai",
  "parking_lot",
@@ -7012,6 +7013,7 @@ dependencies = [
  "language_model",
  "lmstudio",
  "menu",
+ "mistral",
  "ollama",
  "open_ai",
  "project",
@@ -7973,6 +7975,19 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
 
+[[package]]
+name = "mistral"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+ "strum",
+]
+
 [[package]]
 name = "msvc_spectre_libs"
 version = "0.1.2"
diff --git a/Cargo.toml b/Cargo.toml
index 573732a73e..050f9c57c8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -84,6 +84,7 @@ members = [
     "crates/media",
     "crates/menu",
     "crates/migrator",
+    "crates/mistral",
     "crates/multi_buffer",
     "crates/node_runtime",
     "crates/notifications",
@@ -283,6 +284,7 @@ markdown_preview = { path = "crates/markdown_preview" }
 media = { path = "crates/media" }
 menu = { path = "crates/menu" }
 migrator = { path = "crates/migrator" }
+mistral = { path = "crates/mistral" }
 multi_buffer = { path = "crates/multi_buffer" }
 node_runtime = { path = "crates/node_runtime" }
 notifications = { path = "crates/notifications" }
diff --git a/assets/icons/ai_mistral.svg b/assets/icons/ai_mistral.svg
new file mode 100644
index 0000000000..23b8f2ef6c
--- /dev/null
+++ b/assets/icons/ai_mistral.svg
@@ -0,0 +1 @@
+Mistral
\ No newline at end of file
diff --git a/assets/settings/default.json b/assets/settings/default.json
index eaab8caaf8..4c4732ccb0 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -1193,6 +1193,9 @@
     },
     "deepseek": {
       "api_url": "https://api.deepseek.com"
+    },
+    "mistral": {
+      "api_url": "https://api.mistral.ai/v1"
     }
   },
   // Zed's Prettier integration settings.
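Note: the default.json hunk above only sets the provider's default api_url; per-user overrides go through the language_models settings that the rest of this diff wires up. A minimal sketch of such an override, assuming these settings live under the "language_models" key and using the AvailableModel fields (name, display_name, max_tokens, max_output_tokens, max_completion_tokens) introduced in crates/language_models/src/provider/mistral.rs below — the model entry itself is purely illustrative, not a model shipped by this change:

{
  "language_models": {
    "mistral": {
      "api_url": "https://api.mistral.ai/v1",
      "available_models": [
        {
          "name": "my-custom-mistral-model",
          "display_name": "My Custom Mistral Model",
          "max_tokens": 131000,
          "max_output_tokens": 4096
        }
      ]
    }
  }
}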
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 8b4bc518f8..7b570a54f7 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -28,6 +28,7 @@ http_client.workspace = true
 image.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
 log.workspace = true
+mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 19ceea7a53..a5cf54297f 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -269,6 +269,47 @@ impl LanguageModelRequest {
         }
     }
 
+    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
+        let len = self.messages.len();
+        let merged_messages =
+            self.messages
+                .into_iter()
+                .fold(Vec::with_capacity(len), |mut acc, msg| {
+                    let role = msg.role;
+                    let content = msg.string_contents();
+
+                    acc.push(match role {
+                        Role::User => mistral::RequestMessage::User { content },
+                        Role::Assistant => mistral::RequestMessage::Assistant {
+                            content: Some(content),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => mistral::RequestMessage::System { content },
+                    });
+                    acc
+                });
+
+        mistral::Request {
+            model,
+            messages: merged_messages,
+            stream: true,
+            max_tokens: max_output_tokens,
+            temperature: self.temperature,
+            response_format: None,
+            tools: self
+                .tools
+                .into_iter()
+                .map(|tool| mistral::ToolDefinition::Function {
+                    function: mistral::FunctionDefinition {
+                        name: tool.name,
+                        description: Some(tool.description),
+                        parameters: Some(tool.input_schema),
+                    },
+                })
+                .collect(),
+        }
+    }
+
     pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
         google_ai::GenerateContentRequest {
             model,
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index 700ed739ac..b3b9ddd7f1 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -28,6 +28,7 @@ http_client.workspace = true
 language_model.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
 menu.workspace = true
+mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 project.workspace = true
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 99e5c36d61..11f9415b59 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -17,6 +17,7 @@ pub use crate::provider::cloud::RefreshLlmTokenListener;
 use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 use crate::provider::google::GoogleLanguageModelProvider;
 use crate::provider::lmstudio::LmStudioLanguageModelProvider;
+use crate::provider::mistral::MistralLanguageModelProvider;
 use crate::provider::ollama::OllamaLanguageModelProvider;
 use crate::provider::open_ai::OpenAiLanguageModelProvider;
 pub use crate::settings::*;
@@ -64,6 +65,10 @@ fn register_language_model_providers(
         GoogleLanguageModelProvider::new(client.http_client(), cx),
         cx,
     );
+    registry.register_provider(
+        MistralLanguageModelProvider::new(client.http_client(), cx),
+        cx,
+    );
     registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 
     cx.observe_flag::(move |enabled, cx| {
diff --git
a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index a7738563e7..06c7355321 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -4,5 +4,6 @@ pub mod copilot_chat; pub mod deepseek; pub mod google; pub mod lmstudio; +pub mod mistral; pub mod ollama; pub mod open_ai; diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs new file mode 100644 index 0000000000..eb239858a7 --- /dev/null +++ b/crates/language_models/src/provider/mistral.rs @@ -0,0 +1,573 @@ +use anyhow::{anyhow, Result}; +use collections::BTreeMap; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, FutureExt, StreamExt}; +use gpui::{ + AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, +}; +use http_client::HttpClient; +use language_model::{ + LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, +}; + +use futures::stream::BoxStream; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; +use strum::IntoEnumIterator; +use theme::ThemeSettings; +use ui::{prelude::*, Icon, IconName, Tooltip}; +use util::ResultExt; + +use crate::AllLanguageModelSettings; + +const PROVIDER_ID: &str = "mistral"; +const PROVIDER_NAME: &str = "Mistral"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct MistralSettings { + pub api_url: String, + pub available_models: Vec, + pub needs_setting_migration: bool, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub max_output_tokens: Option, + pub max_completion_tokens: Option, +} + +pub struct MistralLanguageModelProvider { + http_client: Arc, + state: gpui::Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + _subscription: Subscription, +} + +const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY"; + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).mistral; + let delete_credentials = cx.delete_credentials(&settings.api_url); + cx.spawn(|this, mut cx| async move { + delete_credentials.await.log_err(); + this.update(&mut cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).mistral; + let write_credentials = + cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); + + cx.spawn(|this, mut cx| async move { + write_credentials.await?; + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + Task::ready(Ok(())) + } else { + let api_url = AllLanguageModelSettings::get_global(cx) + .mistral + .api_url + .clone(); + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| 
cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + (String::from_utf8(api_key)?, false) + }; + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + }) + }) + } + } +} + +impl MistralLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::(|_this: &mut State, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } +} + +impl LanguageModelProviderState for MistralLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for MistralLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiMistral + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from mistral::Model::iter() + for model in mistral::Model::iter() { + if !matches!(model, mistral::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &AllLanguageModelSettings::get_global(cx) + .mistral + .available_models + { + models.insert( + model.name.clone(), + mistral::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + max_output_tokens: model.max_output_tokens, + max_completion_tokens: model.max_completion_tokens, + }, + ); + } + + models + .into_values() + .map(|model| { + Arc::new(MistralLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct MistralLanguageModel { + id: LanguageModelId, + model: mistral::Model, + state: gpui::Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl MistralLanguageModel { + fn stream_completion( + &self, + request: mistral::Request, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result>>, + > { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).mistral; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.ok_or_else(|| anyhow!("Missing Mistral API Key"))?; + let request = + mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { 
Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for MistralLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("mistral/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + cx.background_executor() + .spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) + }) + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> { + let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens()); + let stream = self.stream_completion(request, cx); + + async move { + let stream = stream.await?; + Ok(stream + .map(|result| { + result.and_then(|response| { + response + .choices + .first() + .ok_or_else(|| anyhow!("Empty response")) + .map(|choice| { + choice + .delta + .content + .clone() + .unwrap_or_default() + .map(LanguageModelCompletionEvent::Text) + }) + }) + }) + .boxed()) + } + .boxed() + } + + fn use_any_tool( + &self, + request: LanguageModelRequest, + tool_name: String, + tool_description: String, + schema: serde_json::Value, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> { + let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens()); + request.tools = vec![mistral::ToolDefinition::Function { + function: mistral::FunctionDefinition { + name: tool_name.clone(), + description: Some(tool_description), + parameters: Some(schema), + }, + }]; + + let response = self.stream_completion(request, cx); + self.request_limiter + .run(async move { + let stream = response.await?; + + let tool_args_stream = stream + .filter_map(move |response| async move { + match response { + Ok(response) => { + for choice in response.choices { + if let Some(tool_calls) = choice.delta.tool_calls { + for tool_call in tool_calls { + if let Some(function) = tool_call.function { + if let Some(args) = function.arguments { + return Some(Ok(args)); + } + } + } + } + } + None + } + Err(e) => Some(Err(e)), + } + }) + .boxed(); + + Ok(tool_args_stream) + }) + .boxed() + } +} + +struct ConfigurationView { + api_key_editor: Entity, + state: gpui::Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", cx); + editor + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = 
Some(cx.spawn_in(window, { + let state = state.clone(); + |this, mut cx| async move { + if let Some(task) = state + .update(&mut cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + + this.update(&mut cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, |_, mut cx| async move { + state + .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))? + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, |_, mut cx| async move { + state + .update(&mut cx, |state, cx| state.reset_api_key(cx))? + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + white_space: WhiteSpace::Normal, + ..Default::default() + }; + EditorElement::new( + &self.api_key_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + const MISTRAL_CONSOLE_URL: &str = "https://console.mistral.ai/api-keys"; + const INSTRUCTIONS: [&str; 4] = [ + "To use Zed's assistant with Mistral, you need to add an API key. 
Follow these steps:", + " - Create one by visiting:", + " - Ensure your Mistral account has credits", + " - Paste your API key below and hit enter to start using the assistant", + ]; + + let env_var_set = self.state.read(cx).api_key_from_env; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials...")).into_any() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new(INSTRUCTIONS[0])) + .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child( + Button::new("mistral_console", MISTRAL_CONSOLE_URL) + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| cx.open_url(MISTRAL_CONSOLE_URL)) + ) + ) + .children( + (2..INSTRUCTIONS.len()).map(|n| + Label::new(INSTRUCTIONS[n])).collect::>()) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .border_1() + .border_color(cx.theme().colors().border_variant) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."), + ) + .size(LabelSize::Small), + ) + .into_any() + } else { + h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-key", "Reset key") + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .disabled(env_var_set) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index eb3afb8f5e..740bfecb5e 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -16,6 +16,7 @@ use crate::provider::{ deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings, + mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, }; @@ -63,6 +64,7 @@ pub struct AllLanguageModelSettings { pub copilot_chat: CopilotChatSettings, pub lmstudio: LmStudioSettings, pub deepseek: DeepSeekSettings, + pub mistral: MistralSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -76,6 +78,7 @@ pub struct AllLanguageModelSettingsContent { pub google: Option, pub deepseek: Option, pub copilot_chat: Option, + pub mistral: Option, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -171,6 +174,12 @@ pub struct DeepseekSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct MistralSettingsContent { + pub api_url: Option, + pub available_models: Option>, +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[serde(untagged)] pub enum OpenAiSettingsContent { @@ -356,6 +365,17 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.available_models.clone()), ); + + // Mistral + let mistral = 
value.mistral.clone(); + merge( + &mut settings.mistral.api_url, + mistral.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.mistral.available_models, + mistral.as_ref().and_then(|s| s.available_models.clone()), + ); } Ok(settings) diff --git a/crates/mistral/Cargo.toml b/crates/mistral/Cargo.toml new file mode 100644 index 0000000000..c4d475f014 --- /dev/null +++ b/crates/mistral/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "mistral" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/mistral.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +futures.workspace = true +http_client.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +strum.workspace = true diff --git a/crates/mistral/LICENSE-GPL b/crates/mistral/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/mistral/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs new file mode 100644 index 0000000000..19ed244eb9 --- /dev/null +++ b/crates/mistral/src/mistral.rs @@ -0,0 +1,344 @@ +use anyhow::{anyhow, Result}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::convert::TryFrom; +use strum::EnumIter; + +pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1"; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, + Tool, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + "system" => Ok(Self::System), + "tool" => Ok(Self::Tool), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + Role::System => "system".to_owned(), + Role::Tool => "tool".to_owned(), + } + } +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[serde(rename = "codestral-latest", alias = "codestral-latest")] + CodestralLatest, + #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")] + MistralLargeLatest, + #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")] + MistralSmallLatest, + #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")] + OpenMistralNemo, + #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")] + #[default] + OpenCodestralMamba, + + #[serde(rename = "custom")] + Custom { + name: String, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. 
+ display_name: Option, + max_tokens: usize, + max_output_tokens: Option, + max_completion_tokens: Option, + }, +} + +impl Model { + pub fn from_id(id: &str) -> Result { + match id { + "codestral-latest" => Ok(Self::CodestralLatest), + "mistral-large-latest" => Ok(Self::MistralLargeLatest), + "mistral-small-latest" => Ok(Self::MistralSmallLatest), + "open-mistral-nemo" => Ok(Self::OpenMistralNemo), + "open-codestral-mamba" => Ok(Self::OpenCodestralMamba), + _ => Err(anyhow!("invalid model id")), + } + } + + pub fn id(&self) -> &str { + match self { + Self::CodestralLatest => "codestral-latest", + Self::MistralLargeLatest => "mistral-large-latest", + Self::MistralSmallLatest => "mistral-small-latest", + Self::OpenMistralNemo => "open-mistral-nemo", + Self::OpenCodestralMamba => "open-codestral-mamba", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::CodestralLatest => "codestral-latest", + Self::MistralLargeLatest => "mistral-large-latest", + Self::MistralSmallLatest => "mistral-small-latest", + Self::OpenMistralNemo => "open-mistral-nemo", + Self::OpenCodestralMamba => "open-codestral-mamba", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::CodestralLatest => 256000, + Self::MistralLargeLatest => 131000, + Self::MistralSmallLatest => 32000, + Self::OpenMistralNemo => 131000, + Self::OpenCodestralMamba => 256000, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> Option { + match self { + Self::Custom { + max_output_tokens, .. + } => *max_output_tokens, + _ => None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseFormat { + Text, + #[serde(rename = "json_object")] + JsonObject, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolDefinition { + Function { function: FunctionDefinition }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: Option, + pub parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionRequest { + pub model: String, + pub prompt: String, + pub max_tokens: u32, + pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prediction: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub rewrite_speculation: Option, +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Prediction { + Content { content: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Auto, + Required, + None, + Other(ToolDefinition), +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum RequestMessage { + Assistant { + content: Option, + #[serde(default, skip_serializing_if = 
"Vec::is_empty")] + tool_calls: Vec, + }, + User { + content: String, + }, + System { + content: String, + }, + Tool { + content: String, + tool_call_id: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCall { + pub id: String, + #[serde(flatten)] + pub content: ToolCallContent, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolCallContent { + Function { function: FunctionContent }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionContent { + pub name: String, + pub arguments: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct CompletionChoice { + pub text: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Response { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Choice { + pub index: u32, + pub message: RequestMessage, + pub finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StreamResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StreamChoice { + pub index: u32, + pub delta: StreamDelta, + pub finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StreamDelta { + pub role: Option, + pub content: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCallChunk { + pub index: usize, + pub id: Option, + pub function: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionChunk { + pub name: Option, + pub arguments: Option, +} + +pub async fn stream_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result>> { + let uri = format!("{api_url}/chat/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)); + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let line = line.strip_prefix("data: ")?; + if line == "[DONE]" { + None + } else { + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(anyhow!(error))), + } + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + Err(anyhow!( + "Failed to connect to Mistral API: {} {}", + response.status(), + body, + )) + } +} diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index 3288c03fc3..ca90a16ea7 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -131,6 +131,7 @@ pub enum IconName { 
     AiDeepSeek,
     AiGoogle,
     AiLmStudio,
+    AiMistral,
     AiOllama,
     AiOpenAi,
     AiZed,
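
For reviewers, a minimal sketch of how the new mistral crate's streaming API is meant to be driven, based on the Request, RequestMessage, StreamResponse, and stream_completion definitions added in crates/mistral/src/mistral.rs above; the wrapper function, the HttpClient argument, and the way the API key is supplied are illustrative assumptions rather than code from this diff:

use futures::StreamExt;
use http_client::HttpClient;

// Hypothetical helper showing the intended call sequence; not part of this change.
async fn stream_haiku(client: &dyn HttpClient, api_key: &str) -> anyhow::Result<()> {
    let request = mistral::Request {
        model: mistral::Model::MistralSmallLatest.id().to_string(),
        messages: vec![mistral::RequestMessage::User {
            content: "Write a haiku about wind.".to_string(),
        }],
        stream: true,
        max_tokens: Some(256),
        temperature: Some(0.7),
        response_format: None,
        tools: Vec::new(),
    };

    // stream_completion returns a boxed stream of SSE chunks, each already
    // deserialized into a StreamResponse.
    let mut events =
        mistral::stream_completion(client, mistral::MISTRAL_API_URL, api_key, request).await?;

    while let Some(event) = events.next().await {
        for choice in event?.choices {
            // Deltas carry optional text content; tool-call chunks arrive the same way.
            if let Some(text) = choice.delta.content {
                print!("{text}");
            }
        }
    }
    Ok(())
}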