diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 489837f159..fe11458282 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -219,32 +219,24 @@ impl LanguageModel for CopilotChatLanguageModel {
         cx: &App,
     ) -> BoxFuture<'static, Result<usize>> {
         match self.model {
-            CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
-            CopilotChatModel::Claude3_7Sonnet => count_anthropic_tokens(request, cx),
-            CopilotChatModel::Claude3_7SonnetThinking => count_anthropic_tokens(request, cx),
+            CopilotChatModel::Claude3_5Sonnet
+            | CopilotChatModel::Claude3_7Sonnet
+            | CopilotChatModel::Claude3_7SonnetThinking => count_anthropic_tokens(request, cx),
             CopilotChatModel::Gemini20Flash | CopilotChatModel::Gemini25Pro => {
                 count_google_tokens(request, cx)
             }
-            _ => {
-                let model = match self.model {
-                    CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
-                    CopilotChatModel::Gpt4 => open_ai::Model::Four,
-                    CopilotChatModel::Gpt4_1 => open_ai::Model::FourPointOne,
-                    CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
-                    CopilotChatModel::O1 => open_ai::Model::O1,
-                    CopilotChatModel::O3Mini => open_ai::Model::O3Mini,
-                    CopilotChatModel::O3 => open_ai::Model::O3,
-                    CopilotChatModel::O4Mini => open_ai::Model::O4Mini,
-                    CopilotChatModel::Claude3_5Sonnet
-                    | CopilotChatModel::Claude3_7Sonnet
-                    | CopilotChatModel::Claude3_7SonnetThinking
-                    | CopilotChatModel::Gemini20Flash
-                    | CopilotChatModel::Gemini25Pro => {
-                        unreachable!()
-                    }
-                };
-                count_open_ai_tokens(request, model, cx)
+            CopilotChatModel::Gpt4o => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
+            CopilotChatModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
+            CopilotChatModel::Gpt4_1 => {
+                count_open_ai_tokens(request, open_ai::Model::FourPointOne, cx)
             }
+            CopilotChatModel::Gpt3_5Turbo => {
+                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
+            }
+            CopilotChatModel::O1 => count_open_ai_tokens(request, open_ai::Model::O1, cx),
+            CopilotChatModel::O3Mini => count_open_ai_tokens(request, open_ai::Model::O3Mini, cx),
+            CopilotChatModel::O3 => count_open_ai_tokens(request, open_ai::Model::O3, cx),
+            CopilotChatModel::O4Mini => count_open_ai_tokens(request, open_ai::Model::O4Mini, cx),
         }
     }
 
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index e186a5151e..4f2a750c5a 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -539,7 +539,7 @@ struct RawToolCall {
 
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
-    model: open_ai::Model,
+    model: Model,
     cx: &App,
 ) -> BoxFuture<'static, Result<usize>> {
     cx.background_spawn(async move {
@@ -559,11 +559,33 @@ pub fn count_open_ai_tokens(
             .collect::<Vec<tiktoken_rs::ChatCompletionRequestMessage>>();
 
         match model {
-            open_ai::Model::Custom { .. }
-            | open_ai::Model::O1Mini
-            | open_ai::Model::O1
-            | open_ai::Model::O3Mini => tiktoken_rs::num_tokens_from_messages("gpt-4", &messages),
-            _ => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
+            Model::Custom { max_tokens, .. } => {
+                let model = if max_tokens >= 100_000 {
+                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
+                    "gpt-4o"
+                } else {
+                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
+                    // supported with this tiktoken method
+                    "gpt-4"
+                };
+                tiktoken_rs::num_tokens_from_messages(model, &messages)
+            }
+            // Not currently supported by tiktoken_rs. All use the same tokenizer as gpt-4o (o200k_base)
+            Model::O1
+            | Model::FourPointOne
+            | Model::FourPointOneMini
+            | Model::FourPointOneNano
+            | Model::O3Mini
+            | Model::O3
+            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages),
+            // Currently supported by tiktoken_rs
+            Model::ThreePointFiveTurbo
+            | Model::Four
+            | Model::FourTurbo
+            | Model::FourOmni
+            | Model::FourOmniMini
+            | Model::O1Preview
+            | Model::O1Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
         }
     })
     .boxed()