diff --git a/crates/completion/src/cloud.rs b/crates/completion/src/cloud.rs
index ba1a7dd233..959394715b 100644
--- a/crates/completion/src/cloud.rs
+++ b/crates/completion/src/cloud.rs
@@ -101,7 +101,7 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result> {
-        match request.model {
+        match &request.model {
             LanguageModel::Cloud(CloudModel::Gpt4)
             | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
             | LanguageModel::Cloud(CloudModel::Gpt4Omni)
@@ -118,19 +118,24 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
                 count_open_ai_tokens(request, cx.background_executor())
             }
             LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
-                let request = self.client.request(proto::CountTokensWithLanguageModel {
-                    model: name,
-                    messages: request
-                        .messages
-                        .iter()
-                        .map(|message| message.to_proto())
-                        .collect(),
-                });
-                async move {
-                    let response = request.await?;
-                    Ok(response.token_count as usize)
+                if name.starts_with("anthropic/") {
+                    // Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation.
+                    count_open_ai_tokens(request, cx.background_executor())
+                } else {
+                    let request = self.client.request(proto::CountTokensWithLanguageModel {
+                        model: name.clone(),
+                        messages: request
+                            .messages
+                            .iter()
+                            .map(|message| message.to_proto())
+                            .collect(),
+                    });
+                    async move {
+                        let response = request.await?;
+                        Ok(response.token_count as usize)
+                    }
+                    .boxed()
                 }
-                .boxed()
             }
             _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
         }
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index 43cb393a04..0460b5dcf1 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -1,4 +1,3 @@
-use crate::LanguageModelRequest;
 pub use anthropic::Model as AnthropicModel;
 pub use ollama::Model as OllamaModel;
 pub use open_ai::Model as OpenAiModel;
@@ -88,19 +88,4 @@ impl CloudModel {
             Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
         }
     }
-
-    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
-        match self {
-            Self::Claude3Opus
-            | Self::Claude3Sonnet
-            | Self::Claude3Haiku
-            | Self::Claude3_5Sonnet => {
-                request.preprocess_anthropic();
-            }
-            Self::Custom { name, .. } if name.starts_with("anthropic/") => {
-                request.preprocess_anthropic();
-            }
-            _ => {}
-        }
-    }
 }
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index f9c4322cdf..50a46c55a5 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -45,7 +45,7 @@ impl LanguageModelRequest {
     pub fn preprocess(&mut self) {
         match &self.model {
             LanguageModel::OpenAi(_) => {}
-            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Anthropic(_) => self.preprocess_anthropic(),
             LanguageModel::Ollama(_) => {}
             LanguageModel::Cloud(model) => match model {
                 CloudModel::Claude3Opus
@@ -54,6 +54,9 @@ impl LanguageModelRequest {
                 | CloudModel::Claude3_5Sonnet => {
                     self.preprocess_anthropic();
                 }
+                CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
+                    self.preprocess_anthropic();
+                }
                 _ => {}
             },
         }