aboutsummaryrefslogtreecommitdiffstats
path: root/scratch/semgrep/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'scratch/semgrep/server.py')
-rw-r--r--scratch/semgrep/server.py116
1 files changed, 116 insertions, 0 deletions
diff --git a/scratch/semgrep/server.py b/scratch/semgrep/server.py
new file mode 100644
index 0000000..becabbb
--- /dev/null
+++ b/scratch/semgrep/server.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+import argparse
+import collections
+import hashlib
+import json
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+import chromadb
+
+
def checksum(string: str):
    """Return a 32-character hexadecimal identifier derived from *string*.

    The text is encoded as UTF-8, hashed with SHA-256, and the hex digest
    is truncated to its first 32 characters.
    """
    digest = hashlib.sha256(string.encode("utf-8")).hexdigest()
    return digest[:32]
+
+
def ensure_list(data):
    """Normalize *data* to a list of strings.

    A single string is wrapped in a one-element list. A list whose items
    are all strings (the empty list included) is returned as-is.

    Raises:
        ValueError: If *data* is neither a string nor a list of strings.
    """
    if isinstance(data, str):
        return [data]
    is_string_list = isinstance(data, list) and all(
        isinstance(item, str) for item in data
    )
    if not is_string_list:
        raise ValueError("Data must be a list of strings")
    return data
+
+
def delete_nodes(nodes):
    """Delete the stored entries of every node in *nodes*.

    One delete is issued per node against the module-level ``collection``,
    filtering on the ``node-id`` metadata field.
    """
    for node_id in nodes:
        node_filter = {"node-id": node_id}
        collection.delete(where=node_filter)
+
+
class MyRequestHandler(BaseHTTPRequestHandler):
    """Handle JSON POST requests against the module-level ``collection``.

    The request body must be a JSON object. The first of the keys
    ``query``, ``delete`` and ``insert`` that holds a truthy value selects
    the operation to run.
    """

    def do_POST(self):
        """Read the JSON body, run the requested operation and reply.

        Replies 200 with the JSON-encoded result, or 400 with a plain-text
        reason when the body is not a valid JSON object, names no supported
        operation, or carries a malformed payload.
        """
        try:
            # A missing Content-Length counts as an empty body; a negative
            # one is clamped so read() cannot block waiting for EOF.
            content_length = max(
                int(self.headers.get("Content-Length") or 0), 0
            )
            data = json.loads(self.rfile.read(content_length).decode("utf-8"))
        except ValueError:
            # Covers a non-numeric Content-Length, undecodable bytes and
            # invalid JSON (all of them are ValueError subclasses).
            self._send_text(400, "Invalid JSON data")
            return

        if not isinstance(data, dict):
            self._send_text(400, "Request body must be a JSON object")
            return

        # The payload is passed as an argument, never as the format string:
        # log_message %-formats its first parameter, so a literal '%' in the
        # request data would otherwise raise.
        self.log_message("Received POST request with data: '%s'", data)

        try:
            response = self._run_operation(data)
        except (KeyError, TypeError, ValueError) as err:
            # Malformed payloads get an HTTP answer instead of an unhandled
            # exception that drops the connection without a response.
            self.log_message("Rejected request: %r", err)
            self._send_text(400, f"Bad request: {err}")
            return

        self._send_text(200, json.dumps(response))

    def _run_operation(self, data):
        """Run the operation selected by *data* and return its result.

        Raises:
            ValueError: If no supported operation key holds a truthy value,
                or a ``query``/``delete`` payload is not a string or a list
                of strings.
        """
        if query := data.get("query"):
            # Normalize first: the query may be one string or a list of them.
            queries = ensure_list(query)
            for text in queries:
                self.log_message(
                    "Processing query '%s'", text.replace("\n", " ").strip()
                )
            return collection.query(query_texts=queries)

        if delete_set := data.get("delete"):
            delete_nodes(ensure_list(delete_set))
            return f"Deleted nodes {delete_set}"

        if paragraphs := data.get("insert"):
            documents, metadata = drop_duplicates(paragraphs)
            nodes = {entry.get("node-id") for entry in metadata}
            self.log_message("Processing metadata %s", nodes)
            # Existing entries of the affected nodes are removed before the
            # new paragraphs are added.
            delete_nodes(nodes)
            collection.add(
                documents=documents,
                metadatas=metadata,
                ids=list(map(checksum, documents)),
            )
            return f"Successfully inserted {nodes}"

        raise ValueError(f"Used wrong method. Sent: {data.keys()}")

    def _send_text(self, status, message):
        """Send *message* as a complete ``text/plain`` response."""
        body = message.encode("utf-8")
        self.send_response(status)
        self.send_header("Content-type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
+
+
def run_server(port=8080):
    """Start the HTTP server on *port* and serve requests until stopped.

    Args:
        port: TCP port to listen on. The host part of the address is left
            empty.
    """
    server_address = ("", port)
    # The context manager closes the listening socket when serve_forever()
    # exits, including when it is interrupted by an exception such as
    # KeyboardInterrupt.
    with HTTPServer(server_address, MyRequestHandler) as httpd:
        print(f"Server running on port {port}")
        httpd.serve_forever()
+
+
def drop_duplicates(paragraphs):
    """Normalize paragraph texts and drop repeated ones.

    Each item's ``document`` has its newlines replaced by spaces and is
    stripped. When several items normalize to the same text, only the first
    is kept, together with its ``metadata``; input order is preserved.

    Args:
        paragraphs: Iterable of mappings with ``document`` and ``metadata``
            keys.

    Returns:
        A ``(documents, metadata)`` tuple of two parallel lists.

    Raises:
        KeyError: If an item lacks the ``document`` or ``metadata`` key.
    """
    seen = set()
    documents = []
    metadata = []
    # Single pass with a set of already-seen texts, instead of re-scanning
    # the whole list once per duplicated value.
    for paragraph in paragraphs:
        text = paragraph["document"].replace("\n", " ").strip()
        meta = paragraph["metadata"]
        if text in seen:
            continue
        seen.add(text)
        documents.append(text)
        metadata.append(meta)
    return documents, metadata
+
+
def parse_arguments(args=None):
    """Parse the server's command-line options.

    Args:
        args: Argument list to parse; ``None`` lets argparse read the
            process arguments.

    Returns:
        The parsed namespace with ``database``, ``database_dir`` and
        ``port`` attributes.
    """
    option_specs = (
        (
            ("-db", "--database"),
            {"default": "org-roam", "help": "Name of the collection"},
        ),
        (
            ("-D", "--database-dir"),
            {
                "default": "semantic-roam",
                "help": "Directory where to store database files",
            },
        ),
        (
            ("-p", "--port"),
            {"default": 8080, "type": int, "help": "Port where server listens"},
        ),
    )

    cli = argparse.ArgumentParser(
        description="Run Semantic database server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli.parse_args(args)
+
+
if __name__ == "__main__":
    # Script entry point: open the persistent database, bind the collection
    # as a module-level name (the request handler and delete_nodes read it
    # as a global), then serve until the process is stopped.
    args = parse_arguments()
    client = chromadb.PersistentClient(path=args.database_dir)
    collection = client.get_or_create_collection(name=args.database)
    run_server(port=args.port)