Wire up Azure OpenAI completion provider (#8646)

This PR wires up support for [Azure
OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
as an alternative AI provider in the assistant panel.

This can be configured using the following in the settings file:

```json
{
  "assistant": {
    "provider": {
      "type": "azure_openai",
      "api_url": "https://{your-resource-name}.openai.azure.com",
      "deployment_id": "gpt-4",
      "api_version": "2023-05-15"
    }
  }
}
```

You will need to deploy a model within Azure and update the settings
accordingly.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-02-29 22:01:33 -05:00 committed by GitHub
parent 7c1ef966f3
commit eb1ab69606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 290 additions and 65 deletions

View file

@ -124,16 +124,18 @@ impl AssistantPanel {
.await
.log_err()
.unwrap_or_default();
let (api_url, model_name) = cx.update(|cx| {
let (provider_kind, api_url, model_name) = cx.update(|cx| {
let settings = AssistantSettings::get_global(cx);
(
settings.openai_api_url.clone(),
settings.default_open_ai_model.full_name().to_string(),
)
})?;
anyhow::Ok((
settings.provider_kind()?,
settings.provider_api_url()?,
settings.provider_model_name()?,
))
})??;
let completion_provider = OpenAiCompletionProvider::new(
api_url,
OpenAiCompletionProviderKind::OpenAi,
provider_kind,
model_name,
cx.background_executor().clone(),
)
@ -693,24 +695,29 @@ impl AssistantPanel {
Task::ready(Ok(Vec::new()))
};
let mut model = AssistantSettings::get_global(cx)
.default_open_ai_model
.clone();
let model_name = model.full_name();
let Some(mut model_name) = AssistantSettings::get_global(cx)
.provider_model_name()
.log_err()
else {
return;
};
let prompt = cx.background_executor().spawn(async move {
let snippets = snippets.await?;
let prompt = cx.background_executor().spawn({
let model_name = model_name.clone();
async move {
let snippets = snippets.await?;
let language_name = language_name.as_deref();
generate_content_prompt(
user_prompt,
language_name,
buffer,
range,
snippets,
model_name,
project_name,
)
let language_name = language_name.as_deref();
generate_content_prompt(
user_prompt,
language_name,
buffer,
range,
snippets,
&model_name,
project_name,
)
}
});
let mut messages = Vec::new();
@ -722,7 +729,7 @@ impl AssistantPanel {
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
model = conversation.model.clone();
model_name = conversation.model.full_name().to_string();
}
cx.spawn(|_, mut cx| async move {
@ -735,7 +742,7 @@ impl AssistantPanel {
});
let request = Box::new(OpenAiRequest {
model: model.full_name().into(),
model: model_name,
messages,
stream: true,
stop: vec!["|END|>".to_string()],
@ -1454,8 +1461,14 @@ impl Conversation {
});
let settings = AssistantSettings::get_global(cx);
let model = settings.default_open_ai_model.clone();
let api_url = settings.openai_api_url.clone();
let model = settings
.provider_model()
.log_err()
.unwrap_or(OpenAiModel::FourTurbo);
let api_url = settings
.provider_api_url()
.log_err()
.unwrap_or_else(|| OPEN_AI_API_URL.to_string());
let mut this = Self {
id: Some(Uuid::new_v4().to_string()),
@ -3655,9 +3668,9 @@ fn report_assistant_event(
let client = workspace.read(cx).project().read(cx).client();
let telemetry = client.telemetry();
let model = AssistantSettings::get_global(cx)
.default_open_ai_model
.clone();
let Ok(model_name) = AssistantSettings::get_global(cx).provider_model_name() else {
return;
};
telemetry.report_assistant_event(conversation_id, assistant_kind, model.full_name())
telemetry.report_assistant_event(conversation_id, assistant_kind, &model_name)
}

View file

@ -1,10 +1,14 @@
use anyhow;
use ai::providers::open_ai::{
AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
};
use anyhow::anyhow;
use gpui::Pixels;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OpenAiModel {
#[serde(rename = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
@ -17,25 +21,25 @@ pub enum OpenAiModel {
impl OpenAiModel {
pub fn full_name(&self) -> &'static str {
match self {
OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
OpenAiModel::Four => "gpt-4-0613",
OpenAiModel::FourTurbo => "gpt-4-1106-preview",
Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
Self::Four => "gpt-4-0613",
Self::FourTurbo => "gpt-4-1106-preview",
}
}
pub fn short_name(&self) -> &'static str {
match self {
OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
OpenAiModel::Four => "gpt-4",
OpenAiModel::FourTurbo => "gpt-4-turbo",
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo",
}
}
pub fn cycle(&self) -> Self {
match self {
OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
OpenAiModel::Four => OpenAiModel::FourTurbo,
OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
Self::ThreePointFiveTurbo => Self::Four,
Self::Four => Self::FourTurbo,
Self::FourTurbo => Self::ThreePointFiveTurbo,
}
}
}
@ -48,14 +52,99 @@ pub enum AssistantDockPosition {
Bottom,
}
#[derive(Deserialize, Debug)]
#[derive(Debug, Deserialize)]
pub struct AssistantSettings {
/// Whether to show the assistant panel button in the status bar.
pub button: bool,
/// Where to dock the assistant.
pub dock: AssistantDockPosition,
/// Default width in pixels when the assistant is docked to the left or right.
pub default_width: Pixels,
/// Default height in pixels when the assistant is docked to the bottom.
pub default_height: Pixels,
/// The default OpenAI model to use when starting new conversations.
#[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: OpenAiModel,
/// OpenAI API base URL to use when starting new conversations.
#[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: String,
/// The settings for the AI provider.
pub provider: AiProviderSettings,
}
impl AssistantSettings {
    /// Resolves which completion-provider kind is configured.
    ///
    /// Errors when the Azure OpenAI provider is selected but is missing its
    /// deployment ID or API version, since both are required to build the
    /// Azure request URL.
    pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
        let azure = match &self.provider {
            AiProviderSettings::OpenAi(_) => return Ok(OpenAiCompletionProviderKind::OpenAi),
            AiProviderSettings::AzureOpenAi(azure) => azure,
        };

        let deployment_id = azure
            .deployment_id
            .clone()
            .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
        let api_version = azure
            .api_version
            .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;

        Ok(OpenAiCompletionProviderKind::AzureOpenAi {
            deployment_id,
            api_version,
        })
    }

    /// Returns the API base URL for the configured provider.
    ///
    /// OpenAI falls back to the public endpoint; Azure OpenAI has no sensible
    /// default endpoint, so a missing URL is an error there.
    pub fn provider_api_url(&self) -> anyhow::Result<String> {
        match &self.provider {
            AiProviderSettings::OpenAi(openai) => {
                let url = openai.api_url.clone();
                Ok(url.unwrap_or_else(|| OPEN_AI_API_URL.to_string()))
            }
            AiProviderSettings::AzureOpenAi(azure) => match azure.api_url.clone() {
                Some(url) => Ok(url),
                None => Err(anyhow!("no Azure OpenAI API URL")),
            },
        }
    }

    /// Returns the model to use when starting new conversations.
    pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
        match &self.provider {
            AiProviderSettings::OpenAi(openai) => {
                Ok(openai.default_model.unwrap_or(OpenAiModel::FourTurbo))
            }
            // TODO: We need to use an Azure OpenAI model here.
            AiProviderSettings::AzureOpenAi(_) => Ok(OpenAiModel::FourTurbo),
        }
    }

    /// Returns the model name to send with completion requests.
    ///
    /// For Azure OpenAI the deployment ID stands in for the model name, so a
    /// missing deployment ID is an error.
    pub fn provider_model_name(&self) -> anyhow::Result<String> {
        match &self.provider {
            AiProviderSettings::OpenAi(openai) => {
                let model = openai.default_model.unwrap_or(OpenAiModel::FourTurbo);
                Ok(model.full_name().to_string())
            }
            AiProviderSettings::AzureOpenAi(azure) => match azure.deployment_id.clone() {
                Some(deployment_id) => Ok(deployment_id),
                None => Err(anyhow!("no Azure OpenAI deployment ID")),
            },
        }
    }
}
// Registers `AssistantSettings` under the `assistant` key of the settings file.
impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    fn load(
        default_value: &Self::FileContent,
        user_values: &[&Self::FileContent],
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        // Layer user-provided values over the defaults via plain JSON merge;
        // no custom merge logic is needed for these settings.
        Self::load_via_json_merge(default_value, user_values)
    }
}
/// Assistant panel settings
@ -77,26 +166,88 @@ pub struct AssistantSettingsContent {
///
/// Default: 320
pub default_height: Option<f32>,
/// Deprecated: Please use `provider.default_model` instead.
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
#[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: Option<OpenAiModel>,
/// Deprecated: Please use `provider.api_url` instead.
/// OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
#[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: Option<String>,
/// The settings for the AI provider.
#[serde(default)]
pub provider: AiProviderSettingsContent,
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
/// The parsed settings for the AI provider, selected by the `type` field of
/// the `provider` object in the settings JSON (`"openai"` or `"azure_openai"`).
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AiProviderSettings {
    /// The settings for the OpenAI provider.
    // Explicit rename keeps the tag `"openai"` rather than the
    // snake_case-derived `"open_ai"`.
    #[serde(rename = "openai")]
    OpenAi(OpenAiProviderSettings),
    /// The settings for the Azure OpenAI provider.
    #[serde(rename = "azure_openai")]
    AzureOpenAi(AzureOpenAiProviderSettings),
}
type FileContent = AssistantSettingsContent;
/// The settings for the AI provider used by the Zed Assistant.
// File-content (serializable/schema) counterpart of `AiProviderSettings`,
// using the same `type` tag values.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AiProviderSettingsContent {
    /// The settings for the OpenAI provider.
    // Explicit rename keeps the tag `"openai"` rather than the
    // snake_case-derived `"open_ai"`.
    #[serde(rename = "openai")]
    OpenAi(OpenAiProviderSettingsContent),
    /// The settings for the Azure OpenAI provider.
    #[serde(rename = "azure_openai")]
    AzureOpenAi(AzureOpenAiProviderSettingsContent),
}
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
impl Default for AiProviderSettingsContent {
    /// Defaults to the OpenAI provider (with its own field defaults) when the
    /// user has not configured a `provider` in their settings.
    fn default() -> Self {
        Self::OpenAi(OpenAiProviderSettingsContent::default())
    }
}
/// The parsed settings for the OpenAI provider.
// Optional fields fall back to defaults in the `AssistantSettings` accessors
// (`provider_api_url`, `provider_model`).
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAiProviderSettings {
    /// The OpenAI API base URL to use when starting new conversations.
    pub api_url: Option<String>,
    /// The default OpenAI model to use when starting new conversations.
    pub default_model: Option<OpenAiModel>,
}
/// File-content (serializable/schema) counterpart of `OpenAiProviderSettings`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OpenAiProviderSettingsContent {
    /// The OpenAI API base URL to use when starting new conversations.
    ///
    /// Default: https://api.openai.com/v1
    pub api_url: Option<String>,
    /// The default OpenAI model to use when starting new conversations.
    ///
    /// Default: gpt-4-1106-preview
    pub default_model: Option<OpenAiModel>,
}
/// The parsed settings for the Azure OpenAI provider.
// All fields are optional at parse time; `AssistantSettings::provider_kind`,
// `provider_api_url`, and `provider_model_name` report errors for the ones
// Azure actually requires (URL, API version, deployment ID).
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAiProviderSettings {
    /// The Azure OpenAI API base URL to use when starting new conversations.
    pub api_url: Option<String>,
    /// The Azure OpenAI API version.
    pub api_version: Option<AzureOpenAiApiVersion>,
    /// The Azure OpenAI deployment ID.
    pub deployment_id: Option<String>,
}
/// File-content (serializable/schema) counterpart of
/// `AzureOpenAiProviderSettings`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AzureOpenAiProviderSettingsContent {
    /// The Azure OpenAI API base URL to use when starting new conversations.
    pub api_url: Option<String>,
    /// The Azure OpenAI API version.
    pub api_version: Option<AzureOpenAiApiVersion>,
    /// The Azure OpenAI deployment ID.
    pub deployment_id: Option<String>,
}