v0.1 #1

Merged
ssdd merged 1 commits from v0.1 into main 2024-10-13 02:00:47 +00:00
21 changed files with 2483 additions and 432 deletions
Showing only changes of commit 066fcf799e - Show all commits

2302
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -3,14 +3,41 @@ name = "hackpsuchatbot"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
#[lib]
#name = "margdarshak"
#path = "src/lib.rs"
#
#[[bin]]
#name = "hackpsuchatbot"
#path = "src/main.rs"
[dependencies] [dependencies]
genai = { git = "https://github.com/laststylebender14/rust-genai.git", rev = "63a542ce20132503c520f4e07108e0d768f243c3", optional = true }
hyper = { version = "1.4.1", features = ["full"] } hyper = { version = "1.4.1", features = ["full"] }
pdf = "0.9.0"
async-trait = "0.1.83" async-trait = "0.1.83"
reqwest = "0.12.8" reqwest = { version = "0.11", features = [
"json",
"rustls-tls",
], default-features = false }
anyhow = "1.0.89" anyhow = "1.0.89"
hyper-util = { version = "0.1", features = ["tokio"] } hyper-util = { version = "0.1", features = ["tokio"] }
http-body-util = "0.1.0" http-body-util = "0.1.0"
bytes = "1.7.2" bytes = "1.7.2"
tokio = "1.40.0" tokio = { version = "1.40.0", features = ["full"] }
scraper = { version = "0.20.0", features = [] }
num_cpus = "1.16.0"
serde_json = "1.0.128"
async-recursion = "1.1.1"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
htmd = "0.1.6"
qdrant-client = "1.12.1"
reqwest-middleware = "0.2.5"
http-cache-reqwest = { version = "0.13.0", features = [
"manager-moka",
], default-features = false }
margdarshak-http-cache = { path = "margdarshak-http-cache" }
genai = { git = "https://github.com/laststylebender14/rust-genai.git", rev = "63a542ce20132503c520f4e07108e0d768f243c3" }
md_to_text = "0.0.0"
[workspace]
members = [".", "margdarshak-http-cache"]

@ -0,0 +1,24 @@
[package]
name = "margdarshak-http-cache"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
http-cache-reqwest = { version = "0.13.0", default-features = false, features = ["manager-moka"] }
moka = { version = "0.12.7", default-features = false, features = [
"future",
] }
http-cache-semantics = { version = "1.0.1", default-features = false, features = ["with_serde", "reqwest"] }
serde = "1.0.202"
async-trait = "0.1.80"
[dev-dependencies]
tokio = { version = "1.37.0", features = ["full"] }
url = { version = "2.5.0", features = ["serde"] }
reqwest = { version = "0.11", features = [
"json",
"rustls-tls",
], default-features = false }
http = "0.2.12"
http-cache = "0.18.0"

@ -0,0 +1,176 @@
use http_cache_reqwest::{CacheManager, HttpResponse};
use http_cache_semantics::CachePolicy;
use serde::{Deserialize, Serialize};
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, BoxError>;
use std::sync::Arc;
use moka::future::Cache;
use moka::policy::EvictionPolicy;
/// Cache manager that keeps HTTP responses in an in-process moka cache.
pub struct HttpCacheManager {
/// Shared cache mapping a cache key to the stored response and its cache policy.
pub cache: Arc<Cache<String, Store>>,
}
impl Default for HttpCacheManager {
fn default() -> Self {
Self::new(42)
}
}
/// A single cache entry: the response together with the policy it was stored under.
#[derive(Clone, Deserialize, Serialize)]
pub struct Store {
/// The cached HTTP response.
response: HttpResponse,
/// The cache policy that was supplied alongside the response in `put`.
policy: CachePolicy,
}
impl HttpCacheManager {
    /// Creates a manager backed by an LRU cache with a maximum capacity of `cache_size`.
    pub fn new(cache_size: u64) -> Self {
        let lru_cache = Cache::builder()
            .max_capacity(cache_size)
            .eviction_policy(EvictionPolicy::lru())
            .build();
        Self { cache: Arc::new(lru_cache) }
    }

    /// Invalidates every entry and waits for the cache's pending maintenance tasks.
    ///
    /// Always returns `Ok(())`.
    pub async fn clear(&self) -> Result<()> {
        let cache = &self.cache;
        cache.invalidate_all();
        cache.run_pending_tasks().await;
        Ok(())
    }
}
#[async_trait::async_trait]
impl CacheManager for HttpCacheManager {
    /// Looks up `cache_key`, returning the stored response and policy when present.
    async fn get(&self, cache_key: &str) -> Result<Option<(HttpResponse, CachePolicy)>> {
        let hit = self.cache.get(cache_key).await;
        Ok(hit.map(|Store { response, policy }| (response, policy)))
    }

    /// Stores `response` and `policy` under `cache_key` and hands the response back.
    async fn put(
        &self,
        cache_key: String,
        response: HttpResponse,
        policy: CachePolicy,
    ) -> Result<HttpResponse> {
        // The cache keeps its own copy so the caller gets the original back.
        let entry = Store { response: response.clone(), policy };
        let cache = &self.cache;
        cache.insert(cache_key, entry).await;
        cache.run_pending_tasks().await;
        Ok(response)
    }

    /// Removes the entry stored under `cache_key`, if any.
    async fn delete(&self, cache_key: &str) -> Result<()> {
        let cache = &self.cache;
        cache.invalidate(cache_key).await;
        cache.run_pending_tasks().await;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use http_cache::HttpVersion;
    use reqwest::{Method, Response, ResponseBuilderExt};
    use url::Url;

    use super::*;

    /// Turns a cached `HttpResponse` back into a `reqwest::Response`.
    fn convert_response(response: HttpResponse) -> Result<Response> {
        let built = http::Response::builder()
            .status(response.status)
            .url(response.url)
            .version(response.version.into())
            .body(response.body)?;
        Ok(built.into())
    }

    /// Stores a fixed 200 response under `key` in `manager`.
    async fn insert_key_into_cache(manager: &HttpCacheManager, key: &str) {
        let request_url = "http://localhost:8080/test";
        let parsed_url = Url::parse(request_url).unwrap();
        let cached = HttpResponse {
            body: vec![1, 2, 3],
            headers: HashMap::default(),
            status: 200,
            url: parsed_url.clone(),
            version: HttpVersion::Http11,
        };
        let as_reqwest = convert_response(cached.clone()).unwrap();
        let request: reqwest::Request =
            reqwest::Request::new(Method::GET, request_url.parse().unwrap());
        let policy = CachePolicy::new(&request, &as_reqwest);
        let _ = manager.put(key.to_string(), cached, policy).await.unwrap();
    }

    #[tokio::test]
    async fn test_put() {
        let manager = HttpCacheManager::default();
        insert_key_into_cache(&manager, "test").await;
        assert!(manager.cache.contains_key("test"));
    }

    #[tokio::test]
    async fn test_get_when_key_present() {
        let manager = HttpCacheManager::default();
        insert_key_into_cache(&manager, "test").await;
        let found = manager.get("test").await.unwrap();
        assert!(found.is_some());
    }

    #[tokio::test]
    async fn test_get_when_key_not_present() {
        let manager = HttpCacheManager::default();
        let found = manager.get("test").await.unwrap();
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_delete_when_key_present() {
        let manager = HttpCacheManager::default();
        insert_key_into_cache(&manager, "test").await;
        assert_eq!(manager.cache.iter().count(), 1);
        let _ = manager.delete("test").await;
        assert_eq!(manager.cache.iter().count(), 0);
    }

    #[tokio::test]
    async fn test_clear() {
        let manager = HttpCacheManager::default();
        insert_key_into_cache(&manager, "test").await;
        assert_eq!(manager.cache.iter().count(), 1);
        let _ = manager.clear().await;
        assert_eq!(manager.cache.iter().count(), 0);
    }

    #[tokio::test]
    async fn test_lru_eviction_policy() {
        let manager = HttpCacheManager::new(2);
        for key in ["test-1", "test-2", "test-10"] {
            insert_key_into_cache(&manager, key).await;
        }

        // The oldest key must have been evicted; the two newest must remain.
        assert!(manager.get("test-1").await.unwrap().is_none());
        assert!(manager.get("test-2").await.unwrap().is_some());
        assert!(manager.get("test-10").await.unwrap().is_some());
        assert_eq!(manager.cache.entry_count(), 2);
    }
}

@ -0,0 +1,3 @@
mod cache;
pub use cache::HttpCacheManager;

6
src/app_ctx.rs Normal file

@ -0,0 +1,6 @@
use crate::margdarshak::model::Wizard;
/// Application state shared with the HTTP request handler.
pub struct AppCtx {
/// Model client used to answer incoming questions.
pub wizard: Wizard,
/// Documentation text that is passed to `Question::process` / `Question::new` for every request.
pub md: String,
}

22
src/http/handle_req.rs Normal file

@ -0,0 +1,22 @@
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use hyper::Response;
use crate::app_ctx::AppCtx;
use crate::http::request::Request;
use crate::margdarshak::model::Question;
pub async fn handle_request(req: Request, app_ctx: Arc<AppCtx>) -> anyhow::Result<Response<Full<Bytes>>> {
let body = String::from_utf8(req.body.to_vec())?;
tracing::info!("{}", body);
let body = Question::process(&app_ctx.md, &body);
let question = Question::new(&app_ctx.md, &body);
let response = app_ctx.wizard.ask(question).await?;
let hyper_response = hyper::Response::builder()
.header("content-type", "application/json")
.body(Full::new(Bytes::from(response.ans)))?;
Ok(hyper_response)
}

2
src/http/mod.rs Normal file

@ -0,0 +1,2 @@
pub mod request;
pub mod handle_req;

23
src/http/request.rs Normal file

@ -0,0 +1,23 @@
use bytes::Bytes;
use hyper::body::Incoming;
use http_body_util::BodyExt;
/// Fully buffered, owned copy of an incoming HTTP request.
pub struct Request {
/// HTTP method rendered as a string.
pub method: String,
/// Path component of the request URI.
pub path: String,
/// Query string of the request URI, if one was present.
pub query: Option<String>,
/// Header name/value pairs in the order yielded by the header map.
pub headers: Vec<(String, String)>,
/// The complete request body.
pub body: Bytes,
}
impl Request {
    /// Buffers an incoming hyper request into an owned [`Request`].
    ///
    /// The whole body is collected into memory before this returns.
    /// Header values are decoded lossily: bytes that are not valid UTF-8 are
    /// replaced with U+FFFD instead of panicking, because header contents are
    /// controlled by the remote client.
    ///
    /// # Errors
    /// Returns an error if reading the request body fails.
    pub async fn from_hyper(req: hyper::Request<Incoming>) -> anyhow::Result<Self> {
        let (parts, body) = req.into_parts();
        let body = body.collect().await?.to_bytes();

        let headers = parts
            .headers
            .iter()
            .map(|(name, value)| {
                (
                    name.as_str().to_owned(),
                    // `HeaderValue::to_str` rejects non-visible-ASCII bytes, so decode the raw bytes.
                    String::from_utf8_lossy(value.as_bytes()).into_owned(),
                )
            })
            .collect();

        Ok(Self {
            method: parts.method.to_string(),
            path: parts.uri.path().to_owned(),
            query: parts.uri.query().map(str::to_owned),
            headers,
            body,
        })
    }
}

@ -1,10 +1,30 @@
use anyhow::anyhow; use anyhow::anyhow;
use http_cache_reqwest::{Cache, CacheMode, HttpCache, HttpCacheOptions};
use reqwest::{Client, Request}; use reqwest::{Client, Request};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use margdarshak_http_cache::HttpCacheManager;
use crate::target_rt::HttpIO; use crate::target_rt::HttpIO;
pub struct NativeHttpIO { pub struct NativeHttpIO {
client: Client, client: ClientWithMiddleware,
}
impl NativeHttpIO {
/// Builds an HTTP client whose requests go through an in-memory response cache
/// (default cache mode, manager capacity 1024).
///
/// # Panics
/// Panics if the underlying `reqwest` client cannot be built.
pub fn new() -> Self {
    let inner = Client::builder().build().expect("Failed to build client");
    let cache_layer = Cache(HttpCache {
        mode: CacheMode::Default,
        manager: HttpCacheManager::new(1024),
        options: HttpCacheOptions::default(),
    });
    Self {
        client: ClientBuilder::new(inner).with(cache_layer).build(),
    }
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]

@ -1,2 +1,2 @@
pub mod fs; pub mod fs;
mod http; pub mod http;

@ -1,2 +1,6 @@
pub mod target_rt; pub mod target_rt;
pub mod io; pub mod io;
pub mod scrapper;
pub mod margdarshak;
pub mod http;
pub mod app_ctx;

@ -1,3 +1,16 @@
fn main() { use hackpsuchatbot::margdarshak::runner;
println!("Hello, world!");
/// Builds a multi-threaded Tokio runtime (one worker per CPU reported by
/// `num_cpus`) and blocks on the application runner until it finishes.
fn run() -> anyhow::Result<()> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.worker_threads(num_cpus::get()).enable_all();
    let runtime = builder.build()?;
    runtime.block_on(runner::run())
}
/// Program entry point: installs the `tracing` formatter subscriber, then runs the app.
/// Any error returned by `run` aborts the process through `unwrap`.
fn main() {
tracing_subscriber::fmt::fmt().init();
run().unwrap();
} }

@ -0,0 +1,15 @@
Based on the attached text/link below, answer the question that follows after I write "question:".
Please do not copy paste the text in the answer.
Write the answer in your own words.
Also, make a point of mentioning the link to the source if possible.
Lastly, do not provide any images, just respond in text.
text:
{{input}}
question:
{{question}}
Make sure to respond in MD format only.

55
src/margdarshak/http1.rs Normal file

@ -0,0 +1,55 @@
use std::net::SocketAddr;
use std::sync::Arc;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use crate::app_ctx::AppCtx;
use crate::http::handle_req::handle_request;
use crate::http::request::Request;
pub async fn start_http_1(
port: u16,
app_ctx: AppCtx,
) -> anyhow::Result<()> {
let addr = SocketAddr::new([0, 0, 0, 0].into(), port);
let listener = TcpListener::bind(addr).await?;
let app_ctx = Arc::new(app_ctx);
tracing::info!("Starting HTTP/1 server on port {port}");
loop {
let stream_result = listener.accept().await;
match stream_result {
Ok((stream, _)) => {
let io = TokioIo::new(stream);
let app_ctx = app_ctx.clone();
tokio::spawn(async {
let server = hyper::server::conn::http1::Builder::new()
.serve_connection(
io,
service_fn(move |req: hyper::Request<Incoming>| {
{
let app_ctx = app_ctx.clone();
async move {
let req = Request::from_hyper(req).await?;
handle_request(
req,
app_ctx.clone(),
)
.await
}
}
}),
)
.await;
if let Err(e) = server {
tracing::error!("An error occurred while handling a request: {e}");
}
});
}
Err(e) => tracing::error!("An error occurred while handling request: {e}"),
}
}
}

26
src/margdarshak/mod.rs Normal file

@ -0,0 +1,26 @@
// use qdrant_client::prelude::PointStruct;
// use qdrant_client::qdrant::{Vector, Vectors};
// use qdrant_client::qdrant::vectors::VectorsOptions;
/*fn chunk_md_content(content: &str) -> Vec<String> {
let chunks: Vec<String> = content.split("\n\n").map(|s| s.trim().to_string()).collect();
chunks
}
*/
/*fn store_embeddings(text_chunks: Vec<String>, client: &qdrant_client::Qdrant) {
for (i, chunk) in text_chunks.iter().enumerate() {
let vector = Vector { data: embedding.to_vec(), indices: None, vectors_count: None };
let options = VectorsOptions::Vector(vector);
let point = PointStruct {
id: i as u64,
vectors: Some(Vectors { vectors_options: Some(options) }),
..Default::default()
};
client.upload_point("your_collection", point).unwrap();
}
}*/
pub mod model;
pub mod http1;
pub mod runner;

92
src/margdarshak/model.rs Normal file

@ -0,0 +1,92 @@
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponse};
use genai::Client;
use genai::resolver::AuthResolver;
use anyhow::{anyhow, Result};
const BASE: &str = include_str!("helper_md/query.md");
/// Chat client bound to a single model name.
pub struct Wizard {
/// The `genai` client used to execute chat requests.
client: Client,
/// Name of the model passed to every chat call.
model: String,
}
/// A borrowed prompt pair that is converted into a `ChatRequest`.
pub struct Question<'a> {
/// Text sent as the system message.
md: &'a str,
/// Text sent as the user message.
qry: &'a str,
}
impl<'a> Question<'a> {
    /// Renders the prompt template with the document text and the user's question.
    ///
    /// `md` fills the `{{input}}` placeholder and `qry` fills `{{question}}`.
    /// The argument order matches [`Question::new`] and every call site, which
    /// pass the document first and the question second.
    ///
    /// Only the template itself is scanned for placeholders, so placeholder-like
    /// text inside `md` or `qry` is inserted verbatim and never expanded.
    pub fn process(md: &str, qry: &str) -> String {
        BASE.split("{{input}}")
            .map(|segment| segment.replace("{{question}}", qry))
            .collect::<Vec<_>>()
            .join(md)
    }

    /// Pairs the document text `md` with the (already rendered) query `qry`.
    pub fn new(md: &'a str, qry: &'a str) -> Self {
        // TODO: decide if we need to pass md or link
        // (e.g. "https://icds-docs.readthedocs.io/en/latest/" in place of the document text).
        Self { md, qry }
    }
}
impl TryInto<ChatRequest> for Question<'_> {
type Error = anyhow::Error;
fn try_into(self) -> Result<ChatRequest> {
Ok(ChatRequest::new(vec![
ChatMessage::system(self.md),
ChatMessage::user(self.qry),
]))
}
}
/// The text extracted from a chat response.
pub struct Answer {
/// Response text, copied verbatim from the chat response content.
pub ans: String,
}
impl TryFrom<ChatResponse> for Answer {
    type Error = anyhow::Error;

    /// Extracts the text content of a chat response.
    ///
    /// # Errors
    /// Fails with "No response found" when the response has no content, and
    /// with "Unable to deserialize response" when the content is not text.
    fn try_from(response: ChatResponse) -> Result<Self> {
        // `ok_or_else` builds the error only on the failure path.
        let content = response
            .content
            .ok_or_else(|| anyhow!("No response found"))?;
        let text = content
            .text_as_str()
            .ok_or_else(|| anyhow!("Unable to deserialize response"))?;
        Ok(Self { ans: text.to_owned() })
    }
}
impl Wizard {
pub fn new(model: String, secret: Option<String>) -> Self {
let mut config = genai::adapter::AdapterConfig::default();
if let Some(key) = secret {
config = config.with_auth_resolver(AuthResolver::from_key_value(key));
}
let adapter_kind = AdapterKind::from_model(model.as_str())
.unwrap_or(AdapterKind::Ollama);
let chat_options = ChatOptions::default()
.with_json_mode(true)
.with_temperature(0.0);
Self {
client: Client::builder()
.with_chat_options(chat_options)
.insert_adapter_config(adapter_kind, config)
.build(),
model,
}
}
pub async fn ask(&self, q: Question<'_>) -> Result<Answer> {
let response = self
.client
.exec_chat(self.model.as_str(), q.try_into()?, None)
.await?;
Answer::try_from(response).map_err(|e| anyhow!("{}",e.to_string()))
}
}

35
src/margdarshak/runner.rs Normal file

@ -0,0 +1,35 @@
use std::sync::Arc;
use crate::app_ctx::AppCtx;
use crate::io::fs::NativeFileIO;
use crate::io::http::NativeHttpIO;
use crate::margdarshak::http1;
use crate::margdarshak::model::{Question, Wizard};
use crate::target_rt::TargetRuntime;
/// Boots the application: scrapes the documentation, warms up the model with an
/// empty question, then serves HTTP/1 on port 19194.
///
/// # Errors
/// Fails if scraping, the warm-up model call, or the HTTP server fails.
pub async fn run() -> anyhow::Result<()> {
    let rt = TargetRuntime {
        http: Arc::new(NativeHttpIO::new()),
        fs: Arc::new(NativeFileIO {}),
    };
    let md = crate::scrapper::scrape::handle(
        "https://icds-docs.readthedocs.io/en/latest/".to_string(),
        &rt,
    )
    .await?;

    // TODO: add wrapper structs to avoid accidentally passing
    // the query instead of the MD and vice versa.
    // The template is rendered exactly once; rendering the result a second time
    // would nest the whole prompt inside itself.
    let query = Question::process(&md, "");
    let question = Question::new(&md, &query);

    let wizard = Wizard::new("llama3.2".to_string(), None);

    tracing::info!("Warming up!");
    let _ans = wizard.ask(question).await?;
    tracing::info!("Warmup complete.");

    let app_ctx = AppCtx { wizard, md };

    http1::start_http_1(19194, app_ctx).await?;
    Ok(())
}

1
src/scrapper/mod.rs Normal file

@ -0,0 +1 @@
pub mod scrape;

46
src/scrapper/scrape.rs Normal file

@ -0,0 +1,46 @@
use reqwest::Url;
use scraper::{Html, Selector};
use crate::target_rt::TargetRuntime;
/// Fetches `url` with a GET request through the runtime's HTTP client and
/// returns the response body as text.
///
/// # Errors
/// Fails if `url` does not parse, the request fails, or the body cannot be read.
async fn handle_inner(url: &str, runtime: &TargetRuntime) -> anyhow::Result<String> {
    let target = Url::parse(url)?;
    tracing::info!("Scraping: {}", target);

    let request = reqwest::Request::new(reqwest::Method::GET, target);
    let body = runtime.http.execute(request).await?.text().await?;
    Ok(body)
}
pub async fn handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result<String> {
let mut htmls = vec![];
let mut root = handle_inner(&url, runtime).await?;
htmls.push(root.clone());
loop {
let parsed = Html::parse_document(&root);
let val =
Selector::parse("a")
.ok()
.and_then(|link_selector| parsed.select(&link_selector).find(|val| val.text().collect::<Vec<_>>().join(" ").trim().to_lowercase().eq("next")));
match val {
Some(val) => {
let next = val.value().attr("href").unwrap_or("#");
url = format!("{}{}", url, next);
let next = handle_inner(&url, runtime).await?;
htmls.push(next.clone());
root = next;
}
None => {
break;
}
}
};
let mds = htmls.into_iter().map(|v| htmd::convert(v.as_str())).flatten().collect::<Vec<_>>();
let text = mds.join("\n");
// let text: String = md_to_text::convert(&text);
Ok(text)
}

@ -1,6 +1,5 @@
use std::sync::Arc; use std::sync::Arc;
use bytes::Bytes; use bytes::Bytes;
use http_body_util::Full;
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait HttpIO: Sync + Send + 'static { pub trait HttpIO: Sync + Send + 'static {