/// Common interface implemented by every LLM backend.
///
/// Implementors must supply the model name, the per-token pricing and the two
/// completion entry points; every other method has a default that can be
/// overridden when the backend supports the feature.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Name of the model this provider was configured with.
    fn model_name(&self) -> &str;

    /// Per-token price as an `(input, output)` pair.
    fn cost_per_token(&self) -> (Decimal, Decimal);

    /// Run a chat completion.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Run a chat completion with tool-use support.
    async fn complete_with_tools(
        &self,
        request: ToolCompletionRequest,
    ) -> Result<ToolCompletionResponse, LlmError>;

    /// List the models the provider offers.
    ///
    /// The default implementation reports no models (an empty list).
    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
        let models: Vec<String> = Vec::new();
        Ok(models)
    }

    /// Fetch metadata for the current model (context length, etc.).
    ///
    /// The default carries only the model name; the context length is unknown.
    async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
        let id = self.model_name().to_string();
        Ok(ModelMetadata {
            id,
            context_length: None,
        })
    }

    /// Resolve which model should be reported for a given request.
    ///
    /// A per-request override accepted by `normalized_model_override` wins;
    /// otherwise the currently active model is reported.
    ///
    /// Providers that ignore per-request model overrides should override this
    /// and return `active_model_name()`.
    fn effective_model_name(&self, requested_model: Option<&str>) -> String {
        match normalized_model_override(requested_model) {
            Some(override_name) => override_name.to_owned(),
            None => self.active_model_name(),
        }
    }

    /// Name of the model that is active right now.
    ///
    /// May differ from `model_name()` if the model was switched at runtime
    /// via `set_model()`. The default simply copies `model_name()`.
    fn active_model_name(&self) -> String {
        String::from(self.model_name())
    }

    /// Switch the active model at runtime.
    ///
    /// # Errors
    ///
    /// Not all providers support this; the default always fails with
    /// `LlmError::RequestFailed`.
    fn set_model(&self, _model: &str) -> Result<(), LlmError> {
        let provider = "unknown".to_string();
        let reason = "Runtime model switching not supported by this provider".to_string();
        Err(LlmError::RequestFailed { provider, reason })
    }

    /// Total cost of a completion: each token count multiplied by its
    /// per-token rate from `cost_per_token()`, then summed.
    fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> Decimal {
        let (input_rate, output_rate) = self.cost_per_token();
        let input_total = input_rate * Decimal::from(input_tokens);
        let output_total = output_rate * Decimal::from(output_tokens);
        input_total + output_total
    }

    /// Cost multiplier for cache-creation tokens (Anthropic prompt caching).
    ///
    /// Returns `1.0` by default (no surcharge). Anthropic providers return
    /// `1.25` for 5-minute TTL or `2.0` for 1-hour TTL.
    fn cache_write_multiplier(&self) -> Decimal {
        Decimal::ONE
    }

    /// Discount divisor for cache-read tokens.
    ///
    /// Cached-read cost = `input_rate / cache_read_discount()`.
    /// Returns `1` by default (no discount). Anthropic returns `10` (90% off),
    /// OpenAI would return `2` (50% off).
    fn cache_read_discount(&self) -> Decimal {
        Decimal::ONE
    }
}
二. Rig 适配器:将第三方(rig-core)的 `CompletionModel` 接入自己的 `LlmProvider` trait
/// Adapter around a rig-core `CompletionModel`, carrying the pricing, caching
/// and request-shaping settings that go with it.
///
/// NOTE(review): presumably this type implements `LlmProvider` so the
/// third-party model can be used through the crate's own trait — the impl is
/// not visible in this excerpt; confirm.
pub struct RigAdapter<M: CompletionModel> {
/// The wrapped rig-core completion model.
model: M,
/// Name of the model (presumably what `model_name()` reports — confirm).
model_name: String,
/// Per-token input cost (presumably the first element of `cost_per_token()` — confirm).
input_cost: Decimal,
/// Per-token output cost (presumably the second element of `cost_per_token()` — confirm).
output_cost: Decimal,
/// Prompt cache retention policy (Anthropic only).
/// When not `CacheRetention::None`, injects top-level `cache_control`
/// via `additional_params` for Anthropic automatic caching. Also controls
/// the cost multiplier for cache-creation tokens.
cache_retention: CacheRetention,
/// Parameter names that this provider does not support (e.g., `"temperature"`).
/// These are stripped from requests before sending to avoid 400 errors.
unsupported_params: HashSet<String>,
/// Default additional parameters merged into every request.
/// Used by providers that need extra top-level fields (e.g., Ollama `think: true`).
default_additional_params: Option<serde_json::Value>,
/// Optional model-discovery endpoint. When set, [`LlmProvider::list_models`]
/// issues a `GET` instead of returning the empty default. rig-core's
/// `CompletionModel` does not expose model discovery, so this is wired
/// explicitly per protocol (OpenAI-compatible, Anthropic, Ollama).
models_endpoint: Option<ModelsEndpoint>,
}
pub struct FailoverProvider {
providers: Vec<Arc<dyn LlmProvider>>,候选 provider 列表(primary 在前)
/// Index of the provider that last handled a request successfully.
/// Used by `model_name()` and `cost_per_token()` so downstream cost
/// tracking reflects the provider that actually served the request.
last_used: AtomicUsize,「最近一次成功」= 跨请求全局视角
/// Per-provider cooldown tracking (same length as `providers`).
cooldowns: Vec<ProviderCooldown>,每个 provider 各自的失败计数 + 冷却时间戳
/// Reference instant for computing elapsed nanos. Shared across all
/// cooldown timestamps so they are comparable.
epoch: Instant,构造时间锚,所有冷却时间戳共用
/// Cooldown configuration.
cooldown_config: CooldownConfig,冷却策略(threshold + duration)
/// Request-scoped provider index keyed by Tokio task ID.
///
/// This allows `effective_model_name()` to report the provider that handled
/// the *current* request, even when other concurrent requests update
/// `last_used`.
provider_for_task: Mutex<HashMap<tokio::task::Id, usize>>, 「当前请求」实际用的 provider
}
pub struct CachedProvider {
inner: Arc<dyn LlmProvider>,
/// `std::sync::Mutex` (not tokio) — never held across an `.await` point,
/// so blocking acquisition is safe and keeps `set_model()` synchronous.
cache: Mutex<HashMap<String, CacheEntry>>,
config: ResponseCacheConfig, 1h和1k条
/// Total `complete()` calls (hits + misses) for periodic stats logging.
request_count: AtomicU64,
/// Running total of cache hits, independent of entry lifecycle.
/// Never decremented on eviction, so `hit_rate_pct` in stats doesn't
/// drift down as entries expire or are LRU-evicted.
total_hit_count: AtomicU64,
}
/// Install a freshly rebuilt provider chain as the new inner provider.
///
/// The snapshot (provider plus its metadata) is captured before the lock is
/// taken and then stored with a single assignment under the write guard, so
/// provider and metadata are replaced together.
pub fn swap(&self, inner: Arc<dyn LlmProvider>) {
    // Capture first: the lock is held only for the assignment itself.
    let snapshot = ProviderSnapshot::capture(inner);
    let mut state = write(&self.state);
    *state = snapshot;
}
/// Return a shared handle to the provider that is installed right now.
///
/// Only the `Arc` is cloned; the guard is released when this function
/// returns, so callers never hold the state lock while using the provider.
fn current(&self) -> Arc<dyn LlmProvider> {
    let state = read(&self.state);
    Arc::clone(&state.inner)
}
/// Forward a chat completion to whichever provider is currently installed.
///
/// # Errors
///
/// Returns whatever error the inner provider's `complete` returns.
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
    // Every delegated request goes through `current()` first: it takes the
    // state lock just long enough to clone the inner `Arc`, so no lock is
    // held while the inner call is awaited.
    self.current().complete(request).await
}
/// Switch the model on the inner provider and refresh the cached snapshot.
///
/// # Errors
///
/// Propagates the inner provider's `set_model` error; in that case the
/// stored snapshot is left untouched.
fn set_model(&self, model: &str) -> Result<(), LlmError> {
    // The write lock covers the delegate call *and* the snapshot refresh.
    // Otherwise a concurrent `swap()` could slip in between and this method
    // would overwrite the newly swapped-in provider with a snapshot captured
    // from the older one. Inner `set_model` impls are synchronous (no
    // `.await`), so holding a std::sync lock across the call is safe.
    let mut state = write(&self.state);
    state.inner.set_model(model)?;
    let inner = Arc::clone(&state.inner);
    *state = ProviderSnapshot::capture(inner);
    Ok(())
}
/// Handle for hot-reloading the LLM provider chain: it owns the swappable
/// wrappers whose inner providers get replaced, plus a lock that serializes
/// reloads.
pub struct LlmReloadHandle {
/// Swappable wrapper around the primary provider.
primary: Arc<SwappableLlmProvider>,
/// Swappable wrapper around the cheap provider; `None` when no cheap
/// wrapper was allocated at startup.
cheap: Option<Arc<SwappableLlmProvider>>,
/// Serializes concurrent `reload()` calls so rapid setting toggles
/// don't fire overlapping chain rebuilds (each rebuild can touch OAuth
/// refresh and HTTP probes; letting them pile up wastes upstream quota
/// and leaves the wrapper briefly pointing at a half-built chain).
reload_lock: tokio::sync::Mutex<()>,
}
/// Rebuild the provider chain from `config` and atomically replace the
/// inner providers of the primary (and cheap, if present) wrappers.
///
/// Reloads are serialized so two concurrent callers cannot race.
pub async fn reload(
&self,
config: &crate::LlmConfig,
session: Arc<crate::SessionManager>,
) -> Result<(), LlmError> {
let _guard = self.reload_lock.lock().await;
let components = crate::build_provider_chain_components(config, session).await?;
self.primary.swap(components.primary);
if let Some(ref cheap_handle) = self.cheap {
let new_cheap = components
.cheap
.unwrap_or_else(|| self.primary.clone() as Arc<dyn LlmProvider>);
cheap_handle.swap(new_cheap);
} else if components.cheap.is_some() {
// Asymmetry: no cheap wrapper was allocated at startup, so a
// newly-configured cheap model cannot be activated via hot-reload.
// Surfacing this through tracing so ops don't think the swap
// silently took effect.
tracing::warn!(
"llm hot-reload: cheap provider is now configured but was not at startup; \
it will only take effect after a full restart",
);
}
Ok(())
}原子化重建主和cheap