From dc45b56759e38da793f8976c94b6a798096c968e Mon Sep 17 00:00:00 2001 From: Oscar Najera Date: Wed, 29 Nov 2023 04:21:37 +0100 Subject: Semantic search server and client --- scratch/semgrep/server.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 scratch/semgrep/server.py (limited to 'scratch/semgrep/server.py') diff --git a/scratch/semgrep/server.py b/scratch/semgrep/server.py new file mode 100644 index 0000000..6a8648b --- /dev/null +++ b/scratch/semgrep/server.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +from http.server import BaseHTTPRequestHandler, HTTPServer +import chromadb +import collections +import hashlib +import json + + +def checksum(string): + sha256 = hashlib.sha256() + sha256.update(string.encode("utf-8")) + return sha256.hexdigest()[:32] + + +def ensure_list(data): + if isinstance(data, str): + return [data] + if isinstance(data, list): + if all(isinstance(l, str) for l in data): + return data + raise ValueError("Data must be a list of strings") + + +class MyRequestHandler(BaseHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length).decode("utf-8") + + try: + data = json.loads(post_data) + # Process the JSON data + response_message = f"Received POST request with data: '{data}'\n" + except ValueError: + response_message = "Invalid JSON data" + self.send_response(400) + + if query := data.get("query"): + response = collection.query(query_texts=ensure_list(query)) + elif paragraph := data.get("store"): + data, metadata = drop_duplicates(paragraph) + collection.add( + documents=data, metadatas=metadata, ids=[checksum(l) for l in data] + ) + response = {"added": data} + else: + raise ValueError(f"Used wrong method. Sent: {data.keys()}") + + response_message = json.dumps(response) + + self.send_response(200) + self.send_header("Content-type", "text/plain") + self.end_headers() + self.wfile.write(response_message.encode("utf-8")) + + +def run_server(port=8080): + server_address = ("", port) + httpd = HTTPServer(server_address, MyRequestHandler) + print(f"Server running on port {port}") + httpd.serve_forever() + + +def drop_duplicates(paragraph): + data = [data["document"] for data in paragraph] + metadata = [data["metadata"] for data in paragraph] + dups = (x for x, count in collections.Counter(data).items() if count > 1) + to_drop = [] + for no in dups: + to_drop.extend([i for i, d in enumerate(data) if d == no][1:]) + to_drop.sort(reverse=True) + for index in to_drop: + data.pop(index) + metadata.pop(index) + return data, metadata + + +def test(): + sample = [ + {"document": "Hello", "metadata": 5}, + {"document": "World", "metadata": 8}, + {"document": "Hello", "metadata": 6}, + {"document": "Good", "metadata": 3}, + {"document": "World", "metadata": 9}, + ] + + assert drop_duplicates(sample) == (["Hello", "World", "Good"], [5, 8, 3]) + + +if __name__ == "__main__": + client = chromadb.PersistentClient(path="./semgrep") + collection = client.get_or_create_collection("org-roam") + run_server() -- cgit v1.2.3