diff --git a/Cargo.lock b/Cargo.lock
index 1c251aa0c5..d02e15a67e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6965,6 +6965,7 @@ dependencies = [
"image",
"lmstudio",
"log",
+ "mistral",
"ollama",
"open_ai",
"parking_lot",
@@ -7012,6 +7013,7 @@ dependencies = [
"language_model",
"lmstudio",
"menu",
+ "mistral",
"ollama",
"open_ai",
"project",
@@ -7973,6 +7975,19 @@ dependencies = [
"windows-sys 0.48.0",
]
+[[package]]
+name = "mistral"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+ "strum",
+]
+
[[package]]
name = "msvc_spectre_libs"
version = "0.1.2"
diff --git a/Cargo.toml b/Cargo.toml
index 573732a73e..050f9c57c8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -84,6 +84,7 @@ members = [
"crates/media",
"crates/menu",
"crates/migrator",
+ "crates/mistral",
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
@@ -283,6 +284,7 @@ markdown_preview = { path = "crates/markdown_preview" }
media = { path = "crates/media" }
menu = { path = "crates/menu" }
migrator = { path = "crates/migrator" }
+mistral = { path = "crates/mistral" }
multi_buffer = { path = "crates/multi_buffer" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
diff --git a/assets/icons/ai_mistral.svg b/assets/icons/ai_mistral.svg
new file mode 100644
index 0000000000..23b8f2ef6c
--- /dev/null
+++ b/assets/icons/ai_mistral.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/assets/settings/default.json b/assets/settings/default.json
index eaab8caaf8..4c4732ccb0 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -1193,6 +1193,9 @@
},
"deepseek": {
"api_url": "https://api.deepseek.com"
+ },
+ "mistral": {
+ "api_url": "https://api.mistral.ai/v1"
}
},
// Zed's Prettier integration settings.
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 8b4bc518f8..7b570a54f7 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -28,6 +28,7 @@ http_client.workspace = true
image.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
+mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 19ceea7a53..a5cf54297f 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -269,6 +269,47 @@ impl LanguageModelRequest {
}
}
+    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
+ let len = self.messages.len();
+ let merged_messages =
+ self.messages
+ .into_iter()
+ .fold(Vec::with_capacity(len), |mut acc, msg| {
+ let role = msg.role;
+ let content = msg.string_contents();
+
+ acc.push(match role {
+ Role::User => mistral::RequestMessage::User { content },
+ Role::Assistant => mistral::RequestMessage::Assistant {
+ content: Some(content),
+ tool_calls: Vec::new(),
+ },
+ Role::System => mistral::RequestMessage::System { content },
+ });
+ acc
+ });
+
+ mistral::Request {
+ model,
+ messages: merged_messages,
+ stream: true,
+ max_tokens: max_output_tokens,
+ temperature: self.temperature,
+ response_format: None,
+ tools: self
+ .tools
+ .into_iter()
+ .map(|tool| mistral::ToolDefinition::Function {
+ function: mistral::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
+ }
+ }
+
pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
google_ai::GenerateContentRequest {
model,
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index 700ed739ac..b3b9ddd7f1 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -28,6 +28,7 @@ http_client.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
menu.workspace = true
+mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
project.workspace = true
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 99e5c36d61..11f9415b59 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -17,6 +17,7 @@ pub use crate::provider::cloud::RefreshLlmTokenListener;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
+use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
pub use crate::settings::*;
@@ -64,6 +65,10 @@ fn register_language_model_providers(
GoogleLanguageModelProvider::new(client.http_client(), cx),
cx,
);
+ registry.register_provider(
+ MistralLanguageModelProvider::new(client.http_client(), cx),
+ cx,
+ );
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
cx.observe_flag::(move |enabled, cx| {
diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs
index a7738563e7..06c7355321 100644
--- a/crates/language_models/src/provider.rs
+++ b/crates/language_models/src/provider.rs
@@ -4,5 +4,6 @@ pub mod copilot_chat;
pub mod deepseek;
pub mod google;
pub mod lmstudio;
+pub mod mistral;
pub mod ollama;
pub mod open_ai;
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
new file mode 100644
index 0000000000..eb239858a7
--- /dev/null
+++ b/crates/language_models/src/provider/mistral.rs
@@ -0,0 +1,573 @@
+use anyhow::{anyhow, Result};
+use collections::BTreeMap;
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{future::BoxFuture, FutureExt, StreamExt};
+use gpui::{
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+};
+use http_client::HttpClient;
+use language_model::{
+ LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+};
+
+use futures::stream::BoxStream;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsStore};
+use std::sync::Arc;
+use strum::IntoEnumIterator;
+use theme::ThemeSettings;
+use ui::{prelude::*, Icon, IconName, Tooltip};
+use util::ResultExt;
+
+use crate::AllLanguageModelSettings;
+
+const PROVIDER_ID: &str = "mistral";
+const PROVIDER_NAME: &str = "Mistral";
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct MistralSettings {
+ pub api_url: String,
+    pub available_models: Vec<AvailableModel>,
+ pub needs_setting_migration: bool,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+ pub name: String,
+    pub display_name: Option<String>,
+    pub max_tokens: usize,
+    pub max_output_tokens: Option<u32>,
+    pub max_completion_tokens: Option<u32>,
+}
+
+pub struct MistralLanguageModelProvider {
+    http_client: Arc<dyn HttpClient>,
+    state: gpui::Entity<State>,
+}
+
+pub struct State {
+    api_key: Option<String>,
+    api_key_from_env: bool,
+    _subscription: Subscription,
+}
+
+const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key.is_some()
+ }
+
+    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let settings = &AllLanguageModelSettings::get_global(cx).mistral;
+ let delete_credentials = cx.delete_credentials(&settings.api_url);
+ cx.spawn(|this, mut cx| async move {
+ delete_credentials.await.log_err();
+ this.update(&mut cx, |this, cx| {
+ this.api_key = None;
+ this.api_key_from_env = false;
+ cx.notify();
+ })
+ })
+ }
+
+    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let settings = &AllLanguageModelSettings::get_global(cx).mistral;
+ let write_credentials =
+ cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
+
+ cx.spawn(|this, mut cx| async move {
+ write_credentials.await?;
+ this.update(&mut cx, |this, cx| {
+ this.api_key = Some(api_key);
+ cx.notify();
+ })
+ })
+ }
+
+    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ if self.is_authenticated() {
+ Task::ready(Ok(()))
+ } else {
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .mistral
+ .api_url
+ .clone();
+ cx.spawn(|this, mut cx| async move {
+ let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
+ (api_key, true)
+ } else {
+ let (_, api_key) = cx
+ .update(|cx| cx.read_credentials(&api_url))?
+ .await?
+ .ok_or_else(|| anyhow!("credentials not found"))?;
+ (String::from_utf8(api_key)?, false)
+ };
+ this.update(&mut cx, |this, cx| {
+ this.api_key = Some(api_key);
+ this.api_key_from_env = from_env;
+ cx.notify();
+ })
+ })
+ }
+ }
+}
+
+impl MistralLanguageModelProvider {
+    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| State {
+ api_key: None,
+ api_key_from_env: false,
+            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
+ cx.notify();
+ }),
+ });
+
+ Self { http_client, state }
+ }
+}
+
+impl LanguageModelProviderState for MistralLanguageModelProvider {
+ type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Entity<State>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for MistralLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiMistral
+ }
+
+    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ // Add base models from mistral::Model::iter()
+ for model in mistral::Model::iter() {
+ if !matches!(model, mistral::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ // Override with available models from settings
+ for model in &AllLanguageModelSettings::get_global(cx)
+ .mistral
+ .available_models
+ {
+ models.insert(
+ model.name.clone(),
+ mistral::Model::Custom {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ max_output_tokens: model.max_output_tokens,
+ max_completion_tokens: model.max_completion_tokens,
+ },
+ );
+ }
+
+ models
+ .into_values()
+ .map(|model| {
+ Arc::new(MistralLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+                }) as Arc<dyn LanguageModel>
+ })
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+    fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ }
+}
+
+pub struct MistralLanguageModel {
+    id: LanguageModelId,
+    model: mistral::Model,
+    state: gpui::Entity<State>,
+    http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
+}
+
+impl MistralLanguageModel {
+ fn stream_completion(
+ &self,
+ request: mistral::Request,
+ cx: &AsyncApp,
+    ) -> BoxFuture<
+        'static,
+        Result<BoxStream<'static, Result<mistral::StreamResponse>>>,
+    > {
+ let http_client = self.http_client.clone();
+ let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).mistral;
+ (state.api_key.clone(), settings.api_url.clone())
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ let future = self.request_limiter.stream(async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("Missing Mistral API Key"))?;
+ let request =
+ mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+impl LanguageModel for MistralLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("mistral/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> usize {
+ self.model.max_token_count()
+ }
+
+    fn max_output_tokens(&self) -> Option<u32> {
+ self.model.max_output_tokens()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+    ) -> BoxFuture<'static, Result<usize>> {
+ cx.background_executor()
+ .spawn(async move {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+                .collect::<Vec<_>>();
+
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
+ })
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
+ let stream = self.stream_completion(request, cx);
+
+ async move {
+ let stream = stream.await?;
+ Ok(stream
+ .map(|result| {
+ result.and_then(|response| {
+ response
+ .choices
+ .first()
+ .ok_or_else(|| anyhow!("Empty response"))
+ .map(|choice| {
+ choice
+ .delta
+ .content
+ .clone()
+ .unwrap_or_default()
+ .map(LanguageModelCompletionEvent::Text)
+ })
+ })
+ })
+ .boxed())
+ }
+ .boxed()
+ }
+
+ fn use_any_tool(
+ &self,
+ request: LanguageModelRequest,
+ tool_name: String,
+ tool_description: String,
+ schema: serde_json::Value,
+ cx: &AsyncApp,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
+ request.tools = vec![mistral::ToolDefinition::Function {
+ function: mistral::FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(schema),
+ },
+ }];
+
+ let response = self.stream_completion(request, cx);
+ self.request_limiter
+ .run(async move {
+ let stream = response.await?;
+
+ let tool_args_stream = stream
+ .filter_map(move |response| async move {
+ match response {
+ Ok(response) => {
+ for choice in response.choices {
+ if let Some(tool_calls) = choice.delta.tool_calls {
+ for tool_call in tool_calls {
+ if let Some(function) = tool_call.function {
+ if let Some(args) = function.arguments {
+ return Some(Ok(args));
+ }
+ }
+ }
+ }
+ }
+ None
+ }
+ Err(e) => Some(Err(e)),
+ }
+ })
+ .boxed();
+
+ Ok(tool_args_stream)
+ })
+ .boxed()
+ }
+}
+
+struct ConfigurationView {
+    api_key_editor: Entity<Editor>,
+    state: gpui::Entity<State>,
+    load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", cx);
+ editor
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ |this, mut cx| async move {
+ if let Some(task) = state
+ .update(&mut cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+ // We don't log an error, because "not signed in" is also an error.
+ let _ = task.await;
+ }
+
+ this.update(&mut cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx);
+ if api_key.is_empty() {
+ return;
+ }
+
+ let state = self.state.clone();
+ cx.spawn_in(window, |_, mut cx| async move {
+ state
+ .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, |_, mut cx| async move {
+ state
+ .update(&mut cx, |state, cx| state.reset_api_key(cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ let settings = ThemeSettings::get_global(cx);
+ let text_style = TextStyle {
+ color: cx.theme().colors().text,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features.clone(),
+ font_fallbacks: settings.ui_font.fallbacks.clone(),
+ font_size: rems(0.875).into(),
+ font_weight: settings.ui_font.weight,
+ font_style: FontStyle::Normal,
+ line_height: relative(1.3),
+ white_space: WhiteSpace::Normal,
+ ..Default::default()
+ };
+ EditorElement::new(
+ &self.api_key_editor,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ }
+
+    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ const MISTRAL_CONSOLE_URL: &str = "https://console.mistral.ai/api-keys";
+ const INSTRUCTIONS: [&str; 4] = [
+ "To use Zed's assistant with Mistral, you need to add an API key. Follow these steps:",
+ " - Create one by visiting:",
+ " - Ensure your Mistral account has credits",
+ " - Paste your API key below and hit enter to start using the assistant",
+ ];
+
+ let env_var_set = self.state.read(cx).api_key_from_env;
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials...")).into_any()
+ } else if self.should_render_editor(cx) {
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new(INSTRUCTIONS[0]))
+ .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
+ Button::new("mistral_console", MISTRAL_CONSOLE_URL)
+ .style(ButtonStyle::Subtle)
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Muted)
+ .on_click(move |_, _, cx| cx.open_url(MISTRAL_CONSOLE_URL))
+ )
+ )
+ .children(
+ (2..INSTRUCTIONS.len()).map(|n|
+                    Label::new(INSTRUCTIONS[n])).collect::<Vec<_>>())
+ .child(
+ h_flex()
+ .w_full()
+ .my_2()
+ .px_2()
+ .py_1()
+ .bg(cx.theme().colors().editor_background)
+ .border_1()
+ .border_color(cx.theme().colors().border_variant)
+ .rounded_md()
+ .child(self.render_api_key_editor(cx)),
+ )
+ .child(
+ Label::new(
+ format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small),
+ )
+ .into_any()
+ } else {
+ h_flex()
+ .size_full()
+ .justify_between()
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(Label::new(if env_var_set {
+ format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
+ } else {
+ "API key configured.".to_string()
+ })),
+ )
+ .child(
+ Button::new("reset-key", "Reset key")
+ .icon(Some(IconName::Trash))
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
+ })
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
+ )
+ .into_any()
+ }
+ }
+}
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index eb3afb8f5e..740bfecb5e 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -16,6 +16,7 @@ use crate::provider::{
deepseek::DeepSeekSettings,
google::GoogleSettings,
lmstudio::LmStudioSettings,
+ mistral::MistralSettings,
ollama::OllamaSettings,
open_ai::OpenAiSettings,
};
@@ -63,6 +64,7 @@ pub struct AllLanguageModelSettings {
pub copilot_chat: CopilotChatSettings,
pub lmstudio: LmStudioSettings,
pub deepseek: DeepSeekSettings,
+ pub mistral: MistralSettings,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -76,6 +78,7 @@ pub struct AllLanguageModelSettingsContent {
pub google: Option,
pub deepseek: Option,
pub copilot_chat: Option,
+    pub mistral: Option<MistralSettingsContent>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -171,6 +174,12 @@ pub struct DeepseekSettingsContent {
pub available_models: Option>,
}
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct MistralSettingsContent {
+    pub api_url: Option<String>,
+    pub available_models: Option<Vec<provider::mistral::AvailableModel>>,
+}
+
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum OpenAiSettingsContent {
@@ -356,6 +365,17 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
+
+ // Mistral
+ let mistral = value.mistral.clone();
+ merge(
+ &mut settings.mistral.api_url,
+ mistral.as_ref().and_then(|s| s.api_url.clone()),
+ );
+ merge(
+ &mut settings.mistral.available_models,
+ mistral.as_ref().and_then(|s| s.available_models.clone()),
+ );
}
Ok(settings)
diff --git a/crates/mistral/Cargo.toml b/crates/mistral/Cargo.toml
new file mode 100644
index 0000000000..c4d475f014
--- /dev/null
+++ b/crates/mistral/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "mistral"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/mistral.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http_client.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+strum.workspace = true
diff --git a/crates/mistral/LICENSE-GPL b/crates/mistral/LICENSE-GPL
new file mode 120000
index 0000000000..89e542f750
--- /dev/null
+++ b/crates/mistral/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
new file mode 100644
index 0000000000..19ed244eb9
--- /dev/null
+++ b/crates/mistral/src/mistral.rs
@@ -0,0 +1,344 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::convert::TryFrom;
+use strum::EnumIter;
+
+pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+ Tool,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+ match value.as_str() {
+ "user" => Ok(Self::User),
+ "assistant" => Ok(Self::Assistant),
+ "system" => Ok(Self::System),
+ "tool" => Ok(Self::Tool),
+ _ => Err(anyhow!("invalid role '{value}'")),
+ }
+ }
+}
+
+impl From<Role> for String {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => "user".to_owned(),
+ Role::Assistant => "assistant".to_owned(),
+ Role::System => "system".to_owned(),
+ Role::Tool => "tool".to_owned(),
+ }
+ }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
+pub enum Model {
+ #[serde(rename = "codestral-latest", alias = "codestral-latest")]
+ CodestralLatest,
+ #[serde(rename = "mistral-large-latest", alias = "mistral-large-latest")]
+ MistralLargeLatest,
+ #[serde(rename = "mistral-small-latest", alias = "mistral-small-latest")]
+ MistralSmallLatest,
+ #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo")]
+ OpenMistralNemo,
+ #[serde(rename = "open-codestral-mamba", alias = "open-codestral-mamba")]
+ #[default]
+ OpenCodestralMamba,
+
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+        display_name: Option<String>,
+        max_tokens: usize,
+        max_output_tokens: Option<u32>,
+        max_completion_tokens: Option<u32>,
+ },
+}
+
+impl Model {
+    pub fn from_id(id: &str) -> Result<Self> {
+ match id {
+ "codestral-latest" => Ok(Self::CodestralLatest),
+ "mistral-large-latest" => Ok(Self::MistralLargeLatest),
+ "mistral-small-latest" => Ok(Self::MistralSmallLatest),
+ "open-mistral-nemo" => Ok(Self::OpenMistralNemo),
+ "open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
+ _ => Err(anyhow!("invalid model id")),
+ }
+ }
+
+ pub fn id(&self) -> &str {
+ match self {
+ Self::CodestralLatest => "codestral-latest",
+ Self::MistralLargeLatest => "mistral-large-latest",
+ Self::MistralSmallLatest => "mistral-small-latest",
+ Self::OpenMistralNemo => "open-mistral-nemo",
+ Self::OpenCodestralMamba => "open-codestral-mamba",
+ Self::Custom { name, .. } => name,
+ }
+ }
+
+ pub fn display_name(&self) -> &str {
+ match self {
+ Self::CodestralLatest => "codestral-latest",
+ Self::MistralLargeLatest => "mistral-large-latest",
+ Self::MistralSmallLatest => "mistral-small-latest",
+ Self::OpenMistralNemo => "open-mistral-nemo",
+ Self::OpenCodestralMamba => "open-codestral-mamba",
+ Self::Custom {
+ name, display_name, ..
+ } => display_name.as_ref().unwrap_or(name),
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ match self {
+ Self::CodestralLatest => 256000,
+ Self::MistralLargeLatest => 131000,
+ Self::MistralSmallLatest => 32000,
+ Self::OpenMistralNemo => 131000,
+ Self::OpenCodestralMamba => 256000,
+ Self::Custom { max_tokens, .. } => *max_tokens,
+ }
+ }
+
+    pub fn max_output_tokens(&self) -> Option<u32> {
+ match self {
+ Self::Custom {
+ max_output_tokens, ..
+ } => *max_output_tokens,
+ _ => None,
+ }
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+ pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<ResponseFormat>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<ToolDefinition>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ResponseFormat {
+ Text,
+ #[serde(rename = "json_object")]
+ JsonObject,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+ Function { function: FunctionDefinition },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+ pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CompletionRequest {
+ pub model: String,
+ pub prompt: String,
+ pub max_tokens: u32,
+ pub temperature: f32,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub prediction: Option<Prediction>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub rewrite_speculation: Option<bool>,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Prediction {
+ Content { content: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+ Auto,
+ Required,
+ None,
+ Other(ToolDefinition),
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum RequestMessage {
+ Assistant {
+        content: Option<String>,
+        #[serde(default, skip_serializing_if = "Vec::is_empty")]
+        tool_calls: Vec<ToolCall>,
+ },
+ User {
+ content: String,
+ },
+ System {
+ content: String,
+ },
+ Tool {
+ content: String,
+ tool_call_id: String,
+ },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCall {
+ pub id: String,
+ #[serde(flatten)]
+ pub content: ToolCallContent,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+ Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+ pub name: String,
+ pub arguments: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct CompletionChoice {
+ pub text: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Response {
+ pub id: String,
+ pub object: String,
+ pub created: u64,
+ pub model: String,
+    pub choices: Vec<Choice>,
+ pub usage: Usage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Usage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Choice {
+ pub index: u32,
+ pub message: RequestMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamResponse {
+ pub id: String,
+ pub object: String,
+ pub created: u64,
+ pub model: String,
+    pub choices: Vec<StreamChoice>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamChoice {
+ pub index: u32,
+ pub delta: StreamDelta,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct StreamDelta {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCallChunk>>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+ pub index: usize,
+    pub id: Option<String>,
+    pub function: Option<FunctionChunk>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
+pub async fn stream_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result<BoxStream<'static, Result<StreamResponse>>> {
+ let uri = format!("{api_url}/chat/completions");
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key));
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ let line = line.strip_prefix("data: ")?;
+ if line == "[DONE]" {
+ None
+ } else {
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ Err(anyhow!(
+ "Failed to connect to Mistral API: {} {}",
+ response.status(),
+ body,
+ ))
+ }
+}
diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs
index 3288c03fc3..ca90a16ea7 100644
--- a/crates/ui/src/components/icon.rs
+++ b/crates/ui/src/components/icon.rs
@@ -131,6 +131,7 @@ pub enum IconName {
AiDeepSeek,
AiGoogle,
AiLmStudio,
+ AiMistral,
AiOllama,
AiOpenAi,
AiZed,