From 459be664067041620ab781d3b5c8023c2081b842 Mon Sep 17 00:00:00 2001 From: Sandipsinh Rathod Date: Sun, 13 Oct 2024 00:58:13 -0400 Subject: [PATCH] chore: switch to gemini --- margdarshak-http-cache/src/cache.rs | 9 ++++++-- src/app_ctx.rs | 2 +- src/http/handle_req.rs | 15 +++++++------ src/http/mod.rs | 2 +- src/http/request.rs | 18 ++++++++++++---- src/io/env.rs | 20 +++++++++++++++++ src/io/fs.rs | 4 ++-- src/io/http.rs | 12 +++++++---- src/io/mod.rs | 3 ++- src/lib.rs | 10 ++++----- src/margdarshak/helper_md/query.md | 6 ++---- src/margdarshak/http1.rs | 33 +++++++++++------------------ src/margdarshak/mod.rs | 27 ++--------------------- src/margdarshak/model.rs | 22 +++++++++---------- src/margdarshak/runner.rs | 29 ++++++++++++++++--------- src/scrapper/mod.rs | 2 +- src/scrapper/scrape.rs | 27 +++++++++++++++-------- src/target_rt.rs | 15 +++++++------ 18 files changed, 142 insertions(+), 114 deletions(-) create mode 100644 src/io/env.rs diff --git a/margdarshak-http-cache/src/cache.rs b/margdarshak-http-cache/src/cache.rs index ee86b72..9a8baab 100644 --- a/margdarshak-http-cache/src/cache.rs +++ b/margdarshak-http-cache/src/cache.rs @@ -30,7 +30,9 @@ impl HttpCacheManager { .eviction_policy(EvictionPolicy::lru()) .max_capacity(cache_size) .build(); - Self { cache: Arc::new(cache) } + Self { + cache: Arc::new(cache), + } } pub async fn clear(&self) -> Result<()> { @@ -56,7 +58,10 @@ impl CacheManager for HttpCacheManager { response: HttpResponse, policy: CachePolicy, ) -> Result { - let data = Store { response: response.clone(), policy }; + let data = Store { + response: response.clone(), + policy, + }; self.cache.insert(cache_key, data).await; self.cache.run_pending_tasks().await; Ok(response) diff --git a/src/app_ctx.rs b/src/app_ctx.rs index 7d3e8b6..61a4eeb 100644 --- a/src/app_ctx.rs +++ b/src/app_ctx.rs @@ -3,4 +3,4 @@ use crate::margdarshak::model::Wizard; pub struct AppCtx { pub wizard: Wizard, pub md: String, -} \ No newline at end of file +} 
diff --git a/src/http/handle_req.rs b/src/http/handle_req.rs index fa0a1fb..7e513ac 100644 --- a/src/http/handle_req.rs +++ b/src/http/handle_req.rs @@ -1,12 +1,15 @@ -use std::sync::Arc; -use bytes::Bytes; -use http_body_util::Full; -use hyper::Response; use crate::app_ctx::AppCtx; use crate::http::request::Request; use crate::margdarshak::model::Question; +use bytes::Bytes; +use http_body_util::Full; +use hyper::Response; +use std::sync::Arc; -pub async fn handle_request(req: Request, app_ctx: Arc) -> anyhow::Result>> { +pub async fn handle_request( + req: Request, + app_ctx: Arc, +) -> anyhow::Result>> { let body = String::from_utf8(req.body.to_vec())?; tracing::info!("{}", body); let body = Question::process(&app_ctx.md, &body); @@ -19,4 +22,4 @@ pub async fn handle_request(req: Request, app_ctx: Arc) -> anyhow::Resul .body(Full::new(Bytes::from(response.ans)))?; Ok(hyper_response) -} \ No newline at end of file +} diff --git a/src/http/mod.rs b/src/http/mod.rs index d9fc20e..8dcfb93 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,2 +1,2 @@ +pub mod handle_req; pub mod request; -pub mod handle_req; \ No newline at end of file diff --git a/src/http/request.rs b/src/http/request.rs index 9b4661d..f4e9277 100644 --- a/src/http/request.rs +++ b/src/http/request.rs @@ -1,6 +1,6 @@ use bytes::Bytes; -use hyper::body::Incoming; use http_body_util::BodyExt; +use hyper::body::Incoming; pub struct Request { pub method: String, @@ -17,7 +17,17 @@ impl Request { let method = parts.method.to_string(); let path = parts.uri.path().to_string(); let query = parts.uri.query().map(|q| q.to_string()); - let headers = parts.headers.iter().map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap().to_string())).collect(); - Ok(Self { method, path, query, headers, body }) + let headers = parts + .headers + .iter() + .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap().to_string())) + .collect(); + Ok(Self { + method, + path, + query, + headers, + body, + }) } -} \ No 
newline at end of file +} diff --git a/src/io/env.rs b/src/io/env.rs new file mode 100644 index 0000000..1ab6564 --- /dev/null +++ b/src/io/env.rs @@ -0,0 +1,20 @@ +use crate::target_rt::EnvIO; +use std::collections::HashMap; + +pub struct NativeEnvIO { + map: HashMap, +} + +impl NativeEnvIO { + pub fn new() -> Self { + Self { + map: std::env::vars().collect(), + } + } +} + +impl EnvIO for NativeEnvIO { + fn get_env(&self, key: &str) -> Option { + self.map.get(key).cloned() + } +} diff --git a/src/io/fs.rs b/src/io/fs.rs index a02c2e3..d131c75 100644 --- a/src/io/fs.rs +++ b/src/io/fs.rs @@ -1,5 +1,5 @@ -use bytes::Bytes; use crate::target_rt::FileIO; +use bytes::Bytes; pub struct NativeFileIO {} @@ -14,4 +14,4 @@ impl FileIO for NativeFileIO { let vec = tokio::fs::read(path).await?; Ok(Bytes::from(vec)) } -} \ No newline at end of file +} diff --git a/src/io/http.rs b/src/io/http.rs index 959a92b..5f064aa 100644 --- a/src/io/http.rs +++ b/src/io/http.rs @@ -1,9 +1,9 @@ +use crate::target_rt::HttpIO; use anyhow::anyhow; use http_cache_reqwest::{Cache, CacheMode, HttpCache, HttpCacheOptions}; +use margdarshak_http_cache::HttpCacheManager; use reqwest::{Client, Request}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; -use margdarshak_http_cache::HttpCacheManager; -use crate::target_rt::HttpIO; pub struct NativeHttpIO { client: ClientWithMiddleware, @@ -30,7 +30,11 @@ impl NativeHttpIO { #[async_trait::async_trait] impl HttpIO for NativeHttpIO { async fn execute(&self, request: Request) -> anyhow::Result { - let resp = self.client.execute(request).await.map_err(|e| anyhow!("{}", e))?; + let resp = self + .client + .execute(request) + .await + .map_err(|e| anyhow!("{}", e))?; Ok(resp) } -} \ No newline at end of file +} diff --git a/src/io/mod.rs b/src/io/mod.rs index b27588a..ded7353 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -1,2 +1,3 @@ +pub mod env; pub mod fs; -pub mod http; \ No newline at end of file +pub mod http; diff --git a/src/lib.rs 
b/src/lib.rs index e21a80f..99c579e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ -pub mod target_rt; -pub mod io; -pub mod scrapper; -pub mod margdarshak; +pub mod app_ctx; pub mod http; -pub mod app_ctx; \ No newline at end of file +pub mod io; +pub mod margdarshak; +pub mod scrapper; +pub mod target_rt; diff --git a/src/margdarshak/helper_md/query.md b/src/margdarshak/helper_md/query.md index 358bc16..e63f03a 100644 --- a/src/margdarshak/helper_md/query.md +++ b/src/margdarshak/helper_md/query.md @@ -1,10 +1,8 @@ Based on the attached text/link below and answer questions accordingly after I write "question:". -Please do not copy paste the text in the answer. -Write the answer in your own words. And you should focus on mentioning link for the source if possible. -Lastly, do not provide any images, just respond in text. +Lastly, do not provide any images, just respond in text in your own words. text: {{input}} @@ -12,4 +10,4 @@ text: question: {{question}} -Make sure to respond in MD format only. \ No newline at end of file +Make sure to respond in markdown format. 
\ No newline at end of file diff --git a/src/margdarshak/http1.rs b/src/margdarshak/http1.rs index f3a5b9b..dfe8fcf 100644 --- a/src/margdarshak/http1.rs +++ b/src/margdarshak/http1.rs @@ -1,17 +1,14 @@ -use std::net::SocketAddr; -use std::sync::Arc; -use hyper::body::Incoming; -use hyper::service::service_fn; -use hyper_util::rt::TokioIo; -use tokio::net::TcpListener; use crate::app_ctx::AppCtx; use crate::http::handle_req::handle_request; use crate::http::request::Request; +use hyper::body::Incoming; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::TcpListener; -pub async fn start_http_1( - port: u16, - app_ctx: AppCtx, -) -> anyhow::Result<()> { +pub async fn start_http_1(port: u16, app_ctx: AppCtx) -> anyhow::Result<()> { let addr = SocketAddr::new([0, 0, 0, 0].into(), port); let listener = TcpListener::bind(addr).await?; let app_ctx = Arc::new(app_ctx); @@ -29,17 +26,11 @@ pub async fn start_http_1( .serve_connection( io, service_fn(move |req: hyper::Request| { - { - let app_ctx = app_ctx.clone(); + let app_ctx = app_ctx.clone(); - async move { - let req = Request::from_hyper(req).await?; - handle_request( - req, - app_ctx.clone(), - ) - .await - } + async move { + let req = Request::from_hyper(req).await?; + handle_request(req, app_ctx.clone()).await } }), ) @@ -52,4 +43,4 @@ pub async fn start_http_1( Err(e) => tracing::error!("An error occurred while handling request: {e}"), } } -} \ No newline at end of file +} diff --git a/src/margdarshak/mod.rs b/src/margdarshak/mod.rs index 7ad0a7c..b60e489 100644 --- a/src/margdarshak/mod.rs +++ b/src/margdarshak/mod.rs @@ -1,26 +1,3 @@ -// use qdrant_client::prelude::PointStruct; -// use qdrant_client::qdrant::{Vector, Vectors}; -// use qdrant_client::qdrant::vectors::VectorsOptions; - -/*fn chunk_md_content(content: &str) -> Vec { - let chunks: Vec = content.split("\n\n").map(|s| s.trim().to_string()).collect(); - chunks -} -*/ - -/*fn 
store_embeddings(text_chunks: Vec, client: &qdrant_client::Qdrant) { - for (i, chunk) in text_chunks.iter().enumerate() { - let vector = Vector { data: embedding.to_vec(), indices: None, vectors_count: None }; - let options = VectorsOptions::Vector(vector); - - let point = PointStruct { - id: i as u64, - vectors: Some(Vectors { vectors_options: Some(options) }), - ..Default::default() - }; - client.upload_point("your_collection", point).unwrap(); - } -}*/ -pub mod model; pub mod http1; -pub mod runner; \ No newline at end of file +pub mod model; +pub mod runner; diff --git a/src/margdarshak/model.rs b/src/margdarshak/model.rs index c1ca009..0b6fb19 100644 --- a/src/margdarshak/model.rs +++ b/src/margdarshak/model.rs @@ -1,8 +1,8 @@ +use anyhow::{anyhow, Result}; use genai::adapter::AdapterKind; use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponse}; -use genai::Client; use genai::resolver::AuthResolver; -use anyhow::{anyhow, Result}; +use genai::Client; const BASE: &str = include_str!("helper_md/query.md"); @@ -34,7 +34,6 @@ impl TryInto for Question<'_> { type Error = anyhow::Error; fn try_into(self) -> Result { - Ok(ChatRequest::new(vec![ ChatMessage::system(self.md), ChatMessage::user(self.qry), @@ -51,13 +50,16 @@ impl TryFrom for Answer { fn try_from(response: ChatResponse) -> Result { let message_content = response.content.ok_or(anyhow!("No response found"))?; - let text_content = message_content.text_as_str().ok_or(anyhow!("Unable to deserialize response"))?; + let text_content = message_content + .text_as_str() + .ok_or(anyhow!("Unable to deserialize response"))?; // Ok(serde_json::from_str(text_content)?) 
- Ok(Self { ans: text_content.to_string() }) + Ok(Self { + ans: text_content.to_string(), + }) } } - impl Wizard { pub fn new(model: String, secret: Option) -> Self { let mut config = genai::adapter::AdapterConfig::default(); @@ -65,9 +67,7 @@ impl Wizard { config = config.with_auth_resolver(AuthResolver::from_key_value(key)); } - - let adapter_kind = AdapterKind::from_model(model.as_str()) - .unwrap_or(AdapterKind::Ollama); + let adapter_kind = AdapterKind::from_model(model.as_str()).unwrap_or(AdapterKind::Ollama); let chat_options = ChatOptions::default() .with_json_mode(true) @@ -87,6 +87,6 @@ impl Wizard { .exec_chat(self.model.as_str(), q.try_into()?, None) .await?; - Answer::try_from(response).map_err(|e| anyhow!("{}",e.to_string())) + Answer::try_from(response).map_err(|e| anyhow!("{}", e.to_string())) } -} \ No newline at end of file +} diff --git a/src/margdarshak/runner.rs b/src/margdarshak/runner.rs index 94d2195..5986d09 100644 --- a/src/margdarshak/runner.rs +++ b/src/margdarshak/runner.rs @@ -1,35 +1,44 @@ -use std::sync::Arc; use crate::app_ctx::AppCtx; +use crate::io::env::NativeEnvIO; use crate::io::fs::NativeFileIO; use crate::io::http::NativeHttpIO; use crate::margdarshak::http1; use crate::margdarshak::model::{Question, Wizard}; use crate::target_rt::TargetRuntime; +use std::sync::Arc; pub async fn run() -> anyhow::Result<()> { let rt = TargetRuntime { http: Arc::new(NativeHttpIO::new()), fs: Arc::new(NativeFileIO {}), + env: Arc::new(NativeEnvIO::new()), }; - let md = crate::scrapper::scrape::handle("https://icds-docs.readthedocs.io/en/latest/".to_string(), &rt).await?; + let md = crate::scrapper::scrape::handle( + "https://icds-docs.readthedocs.io/en/latest/".to_string(), + &rt, + ) + .await?; + // let md = "https://icds-docs.readthedocs.io/en/latest/".to_string(); // TODO: add wrapper structs to avoid accidentally // query instead of MD and vis versa. 
- let query = "".to_string(); - let query = Question::process(&md, &query); + let query = "This is the input, now you should refer this input and answer following questions" + .to_string(); + // let query = "".to_string(); + // let query = Question::process(&md, &query); let query = Question::process(&md, &query); let question = Question::new(&md, &query); - let wizard = Wizard::new("llama3.2".to_string(), None); + let wizard = Wizard::new( + "gemini-1.5-flash-latest".to_string(), + rt.env.get_env("API_KEY"), + ); tracing::info!("Warming up!"); let _ans = wizard.ask(question).await?; tracing::info!("Warmup complete."); - let app_ctx = AppCtx { - wizard, - md, - }; + let app_ctx = AppCtx { wizard, md }; http1::start_http_1(19194, app_ctx).await?; Ok(()) -} \ No newline at end of file +} diff --git a/src/scrapper/mod.rs b/src/scrapper/mod.rs index 1ddaa48..d64fc19 100644 --- a/src/scrapper/mod.rs +++ b/src/scrapper/mod.rs @@ -1 +1 @@ -pub mod scrape; \ No newline at end of file +pub mod scrape; diff --git a/src/scrapper/scrape.rs b/src/scrapper/scrape.rs index 4880529..4e5dfda 100644 --- a/src/scrapper/scrape.rs +++ b/src/scrapper/scrape.rs @@ -1,6 +1,6 @@ +use crate::target_rt::TargetRuntime; use reqwest::Url; use scraper::{Html, Selector}; -use crate::target_rt::TargetRuntime; async fn handle_inner(url: &str, runtime: &TargetRuntime) -> anyhow::Result { let url = Url::parse(url)?; @@ -20,10 +20,16 @@ pub async fn handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result< loop { let parsed = Html::parse_document(&root); - let val = - Selector::parse("a") - .ok() - .and_then(|link_selector| parsed.select(&link_selector).find(|val| val.text().collect::>().join(" ").trim().to_lowercase().eq("next"))); + let val = Selector::parse("a").ok().and_then(|link_selector| { + parsed.select(&link_selector).find(|val| { + val.text() + .collect::>() + .join(" ") + .trim() + .to_lowercase() + .eq("next") + }) + }); match val { Some(val) => { @@ -38,9 +44,12 @@ pub async fn 
handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result< break; } } - }; - let mds = htmls.into_iter().map(|v| htmd::convert(v.as_str())).flatten().collect::>(); + } + let mds = htmls + .into_iter() + .map(|v| htmd::convert(v.as_str())) + .flatten() + .collect::>(); let text = mds.join("\n"); - // let text: String = md_to_text::convert(&text); Ok(text) -} \ No newline at end of file +} diff --git a/src/target_rt.rs b/src/target_rt.rs index fd323a1..4089957 100644 --- a/src/target_rt.rs +++ b/src/target_rt.rs @@ -1,15 +1,15 @@ -use std::sync::Arc; use bytes::Bytes; +use std::sync::Arc; + +pub trait EnvIO { + fn get_env(&self, key: &str) -> Option; +} #[async_trait::async_trait] pub trait HttpIO: Sync + Send + 'static { - async fn execute( - &self, - request: reqwest::Request, - ) -> anyhow::Result; + async fn execute(&self, request: reqwest::Request) -> anyhow::Result; } - #[async_trait::async_trait] pub trait FileIO: Send + Sync { async fn write<'a>(&'a self, path: &'a str, content: &'a Bytes) -> anyhow::Result<()>; @@ -19,4 +19,5 @@ pub trait FileIO: Send + Sync { pub struct TargetRuntime { pub http: Arc, pub fs: Arc, -} \ No newline at end of file + pub env: Arc, +}