diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index e744894cf3..41cea4ab86 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -108,8 +108,12 @@ pub enum Model { } impl Model { - pub fn default_fast() -> Self { - Self::Claude3_5Haiku + pub fn default_fast(region: &str) -> Self { + if region.starts_with("us-") { + Self::Claude3_5Haiku + } else { + Self::Claude3Haiku + } } pub fn from_id(id: &str) -> anyhow::Result { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index c377e614c1..ed5e372616 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -229,6 +229,17 @@ impl State { Ok(()) }) } + + fn get_region(&self) -> String { + // Get region - from credentials or directly from settings + let credentials_region = self.credentials.as_ref().map(|s| s.region.clone()); + let settings_region = self.settings.as_ref().and_then(|s| s.region.clone()); + + // Use credentials region if available, otherwise use settings region, finally fall back to default + credentials_region + .or(settings_region) + .unwrap_or(String::from("us-east-1")) + } } pub struct BedrockLanguageModelProvider { @@ -289,8 +300,9 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { Some(self.create_language_model(bedrock::Model::default())) } - fn default_fast_model(&self, _cx: &App) -> Option> { - Some(self.create_language_model(bedrock::Model::default_fast())) + fn default_fast_model(&self, cx: &App) -> Option> { + let region = self.state.read(cx).get_region(); + Some(self.create_language_model(bedrock::Model::default_fast(region.as_str()))) } fn provided_models(&self, cx: &App) -> Vec> { @@ -377,11 +389,7 @@ impl BedrockModel { let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone()); - let region = state - .settings - .as_ref() - .and_then(|s| s.region.clone()) - .unwrap_or(String::from("us-east-1")); + let region = state.get_region(); ( auth_method, @@ -530,16 +538,7 @@ impl LanguageModel for BedrockModel { LanguageModelCompletionError, >, > { - let Ok(region) = cx.read_entity(&self.state, |state, _cx| { - // Get region - from credentials or directly from settings - let credentials_region = state.credentials.as_ref().map(|s| s.region.clone()); - let settings_region = state.settings.as_ref().and_then(|s| s.region.clone()); - - // Use credentials region if available, otherwise use settings region, finally fall back to default - credentials_region - .or(settings_region) - .unwrap_or(String::from("us-east-1")) - }) else { + let Ok(region) = cx.read_entity(&self.state, |state, _cx| state.get_region()) else { return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed(); };