#!/usr/bin/env python3
"""HTTP server exposing a ChromaDB collection for semantic search over notes.

Accepts POST requests whose JSON body contains one of the keys
"query", "delete", or "insert".
"""
import argparse
import collections
import hashlib
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

import chromadb


def checksum(string: str) -> str:
    """Return the first 32 hex characters of the SHA-256 digest of `string`."""
    sha256 = hashlib.sha256()
    sha256.update(string.encode("utf-8"))
    return sha256.hexdigest()[:32]


def ensure_list(data):
    """Wrap a single string in a list; pass a list of strings through unchanged."""
    if isinstance(data, str):
        return [data]
    if isinstance(data, list) and all(isinstance(item, str) for item in data):
        return data
    raise ValueError("Data must be a string or a list of strings")


def delete_nodes(nodes):
    """Delete every document whose metadata matches one of the given node ids."""
    for node in nodes:
        collection.delete(where={"node-id": node})


class MyRequestHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length).decode("utf-8")
        try:
            data = json.loads(post_data)
            self.log_message("Received POST request with data: '%s'", data)
        except ValueError:
            self.send_error(400, "Invalid JSON data")
            return

        if query := data.get("query"):
            self.log_message("Processing query '%s'", query.replace("\n", " ").strip())
            response = collection.query(query_texts=ensure_list(query))
        elif delete_set := data.get("delete"):
            delete_nodes(ensure_list(delete_set))
            response = f"Deleted nodes {delete_set}"
        elif paragraphs := data.get("insert"):
            documents, metadata = drop_duplicates(paragraphs)
            nodes = {m.get("node-id") for m in metadata}
            self.log_message("Processing metadata %s", nodes)
            # Remove any existing documents for these nodes before re-inserting.
            delete_nodes(nodes)
            collection.add(
                documents=documents,
                metadatas=metadata,
                ids=list(map(checksum, documents)),
            )
            response = f"Successfully inserted {nodes}"
        else:
            self.send_error(
                400, f"Expected 'query', 'delete' or 'insert'. Sent: {list(data.keys())}"
            )
            return

        response_message = json.dumps(response)
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(response_message.encode("utf-8"))


def run_server(port=8080):
    server_address = ("", port)
    httpd = HTTPServer(server_address, MyRequestHandler)
    print(f"Server running on port {port}")
    httpd.serve_forever()


def drop_duplicates(paragraphs):
    """Normalise documents and drop all but the first copy of each duplicate."""
    documents = [p["document"].replace("\n", " ").strip() for p in paragraphs]
    metadata = [p["metadata"] for p in paragraphs]
    dups = (x for x, count in collections.Counter(documents).items() if count > 1)
    to_drop = []
    for dup in dups:
        # Keep the first occurrence, mark the rest for removal.
        to_drop.extend([i for i, d in enumerate(documents) if d == dup][1:])
    # Pop from the end so earlier indices stay valid.
    to_drop.sort(reverse=True)
    for index in to_drop:
        documents.pop(index)
        metadata.pop(index)
    return documents, metadata


def parse_arguments(args=None):
    parser = argparse.ArgumentParser(
        description="Run Semantic database server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-db", "--database", default="org-roam", help="Name of the collection"
    )
    parser.add_argument(
        "-D",
        "--database-dir",
        default="semantic-roam",
        help="Directory where to store database files",
    )
    parser.add_argument(
        "-p", "--port", default=8080, type=int, help="Port where server listens"
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    args = parse_arguments()
    client = chromadb.PersistentClient(path=args.database_dir)
    # `collection` is a module-level global used by the request handler.
    collection = client.get_or_create_collection(args.database)
    run_server(args.port)
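

# The commented-out sketch below illustrates how a client might call this
# server using only the standard library. It is not part of the server and is
# kept as comments so the module stays importable; the query text and node id
# are hypothetical examples, and the server is assumed to be running locally
# on the default port.
#
#   import json
#   from urllib import request
#
#   def post(payload: dict, port: int = 8080) -> str:
#       req = request.Request(
#           f"http://localhost:{port}",
#           data=json.dumps(payload).encode("utf-8"),
#           method="POST",
#       )
#       with request.urlopen(req) as resp:
#           return resp.read().decode("utf-8")
#
#   # Semantic query against the collection.
#   print(post({"query": "emacs configuration"}))
#   # Insert a document tied to a node id (re-inserting replaces older copies).
#   print(post({"insert": [{"document": "Some note text",
#                           "metadata": {"node-id": "example-node-id"}}]}))
#   # Delete all documents belonging to a node id.
#   print(post({"delete": ["example-node-id"]}))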