assistant: Overhaul provider infrastructure (#14929)

<img width="624" alt="image"
src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815">

- [x] Correctly set custom model token count
- [x] How to count tokens for Gemini models?
- [x] Feature flag zed.dev provider
- [x] Figure out how to configure custom models
- [ ] Update docs

Release Notes:

- Added support for quickly switching between multiple language model
providers in the assistant panel

---------

Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-07-23 19:48:41 +02:00 committed by GitHub
parent 17ef9a367f
commit d0f52e90e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 2757 additions and 2023 deletions

View file

@ -1,7 +1,4 @@
use crate::{
model::{CloudModel, LanguageModel},
role::Role,
};
use crate::{role::Role, LanguageModelId};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@ -23,16 +20,15 @@ impl LanguageModelRequestMessage {
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
model: model_id.0.to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
@ -40,70 +36,6 @@ impl LanguageModelRequest {
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => self.preprocess_anthropic(),
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
self.preprocess_anthropic();
}
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
self.preprocess_anthropic();
}
_ => {}
},
}
}
pub fn preprocess_anthropic(&mut self) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in self.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
self.messages = new_messages;
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]