Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)

Resurrected this from some assistant work I did in Spring of 2023.
- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default
(but preserve the API key option for now; see the sketch after this list)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
  - We currently disallow cycling when setting a custom model, but we
    need a better solution to keep OpenAI models available while testing
    the Google ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] CloudFront so we can block abuse
- [x] Rate-limiting
- [x] Fix panic when using inline assistant
- [x] Double-check that the upgraded `AssistantSettings` are
backwards-compatible
- [x] Add hosted LLM interaction behind a `language-models` feature
flag.
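
Here is a minimal sketch of the default-to-zed.dev behavior from the checklist above. Everything in it is illustrative: the `CompletionProvider` enum, `from_settings`, and the token/key parameters are hypothetical stand-ins for the real `completion_provider` module added in this PR.

```rust
// Illustrative only: hypothetical types sketching how the client could
// pick a provider, defaulting to Zed's server unless the user has
// configured an OpenAI API key.
enum CompletionProvider {
    /// Proxy completions through zed.dev; no API key required.
    ZedDotDev { access_token: String },
    /// Call OpenAI directly with a user-supplied key.
    OpenAi { api_key: String },
}

impl CompletionProvider {
    /// Keep the direct OpenAI path when the user supplied a key,
    /// otherwise fall back to proxying through Zed's server.
    fn from_settings(openai_api_key: Option<String>, zed_access_token: String) -> Self {
        match openai_api_key {
            Some(api_key) => CompletionProvider::OpenAi { api_key },
            None => CompletionProvider::ZedDotDev {
                access_token: zed_access_token,
            },
        }
    }
}

fn main() {
    // No key configured, so completions are proxied through zed.dev.
    let provider = CompletionProvider::from_settings(None, "session-token".into());
    match provider {
        CompletionProvider::ZedDotDev { .. } => println!("proxying through zed.dev"),
        CompletionProvider::OpenAi { .. } => println!("calling OpenAI directly"),
    }
}
```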

Release Notes:

- We are temporarily removing the semantic index in order to redesign it
from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>
Nathan Sobo 2024-03-19 12:22:26 -06:00 committed by GitHub
parent 905a24079a
commit 8ae5a3b61a
87 changed files with 3647 additions and 8937 deletions

@@ -1,22 +1,24 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
mod completion_provider;
mod prompts;
mod saved_conversation;
mod streaming_diff;
use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAiModel;
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use client::{proto, Client};
pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, SharedString};
use regex::Regex;
pub(crate) use saved_conversation::*;
use serde::{Deserialize, Serialize};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
use settings::Settings;
use std::{
fmt::{self, Display},
sync::Arc,
};
actions!(
assistant,
@@ -30,7 +32,6 @@ actions!(
ResetKey,
InlineAssist,
ToggleIncludeConversation,
ToggleRetrieveContext,
]
);
@@ -39,6 +40,139 @@ actions!(
)]
struct MessageId(usize);
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
ZedDotDev(ZedDotDevModel),
OpenAi(OpenAiModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::ZedDotDev(ZedDotDevModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
LanguageModel::ZedDotDev(model) => match model {
ZedDotDevModel::GptThreePointFiveTurbo
| ZedDotDevModel::GptFour
| ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
},
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: match self.role {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
} as i32,
content: self.content.clone(),
}
}
}
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
pub index: u32,
pub delta: LanguageModelResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
@@ -53,71 +187,9 @@ enum MessageStatus {
Error(SharedString),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
api_url: Option<String>,
model: OpenAiModel,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}
pub fn init(cx: &mut AppContext) {
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
AssistantSettings::register(cx);
completion_provider::init(client, cx);
assistant_panel::init(cx);
}
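
For reference, a hedged usage sketch of the request types added in this diff, as a call site might look inside the assistant crate (the message content and temperature are made up for illustration; `to_proto` is the conversion shown above):

```rust
// Hypothetical call site using only the types added in this diff.
// By default the request targets the zed.dev-proxied model, so no
// API key is needed.
let request = LanguageModelRequest {
    model: LanguageModel::default(), // LanguageModel::ZedDotDev(ZedDotDevModel::default())
    messages: vec![LanguageModelRequestMessage {
        role: Role::User,
        content: "Summarize this conversation.".into(),
    }],
    stop: Vec::new(),
    temperature: 1.0,
};

// Convert to the simplified protobuf message for the RPC to Zed's server.
let rpc_message: proto::CompleteWithLanguageModel = request.to_proto();
```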