diff options
Diffstat (limited to 'scratch/semgrep/server.py')
-rw-r--r-- | scratch/semgrep/server.py | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/scratch/semgrep/server.py b/scratch/semgrep/server.py new file mode 100644 index 0000000..becabbb --- /dev/null +++ b/scratch/semgrep/server.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +import argparse +import collections +import hashlib +import json +from http.server import BaseHTTPRequestHandler, HTTPServer + +import chromadb + + +def checksum(string: str): + sha256 = hashlib.sha256() + sha256.update(string.encode("utf-8")) + return sha256.hexdigest()[:32] + + +def ensure_list(data): + if isinstance(data, str): + return [data] + if isinstance(data, list): + if all(isinstance(l, str) for l in data): + return data + raise ValueError("Data must be a list of strings") + + +def delete_nodes(nodes): + for node in nodes: + collection.delete(where={"node-id": node}) + + +class MyRequestHandler(BaseHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length).decode("utf-8") + + try: + data = json.loads(post_data) + response_message = f"Received POST request with data: '{data}'\n" + self.log_message(response_message) + except ValueError: + response_message = "Invalid JSON data" + self.send_response(400) + + if query := data.get("query"): + self.log_message("Processing query '%s'", query.replace("\n", " ").strip()) + response = collection.query(query_texts=ensure_list(query)) + elif delete_set := data.get("delete"): + delete_nodes(ensure_list(delete_set)) + response = f"Deleted nodes {delete_set}" + elif paragraphs := data.get("insert"): + data, metadata = drop_duplicates(paragraphs) + nodes = set(m.get("node-id") for m in metadata) + self.log_message("Processing metadata %s", nodes) + delete_nodes(nodes) + collection.add( + documents=data, metadatas=metadata, ids=list(map(checksum, data)) + ) + response = f"Successfully inserted {nodes}" + else: + raise ValueError(f"Used wrong method. Sent: {data.keys()}") + + response_message = json.dumps(response) + + self.send_response(200) + self.send_header("Content-type", "text/plain") + self.end_headers() + self.wfile.write(response_message.encode("utf-8")) + + +def run_server(port=8080): + server_address = ("", port) + httpd = HTTPServer(server_address, MyRequestHandler) + print(f"Server running on port {port}") + httpd.serve_forever() + + +def drop_duplicates(paragraphs): + data = [data["document"].replace("\n", " ").strip() for data in paragraphs] + metadata = [data["metadata"] for data in paragraphs] + dups = (x for x, count in collections.Counter(data).items() if count > 1) + to_drop = [] + for no in dups: + to_drop.extend([i for i, d in enumerate(data) if d == no][1:]) + to_drop.sort(reverse=True) + for index in to_drop: + data.pop(index) + metadata.pop(index) + return data, metadata + + +def parse_arguments(args=None): + parser = argparse.ArgumentParser( + description="Run Semantic database server", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-db", "--database", default="org-roam", help="Name of the collection" + ) + parser.add_argument( + "-D", + "--database-dir", + default="semantic-roam", + help="Directory where to store database files", + ) + parser.add_argument( + "-p", "--port", default=8080, type=int, help="Port where server listens" + ) + + return parser.parse_args(args) + + +if __name__ == "__main__": + args = parse_arguments() + client = chromadb.PersistentClient(path=args.database_dir) + collection = client.get_or_create_collection(args.database) + run_server(args.port) |