From e2f0163bc2d528d3cec086f3496f78282e7002c9 Mon Sep 17 00:00:00 2001 From: Chris kerr Date: Sat, 20 Jul 2024 08:41:47 -0400 Subject: [PATCH] server is handling requests --- db/db.go | 2 + main.go | 6 +- requests/requests.go | 67 ---------- server/query.go | 70 +++++++++++ server/server.go | 286 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 362 insertions(+), 69 deletions(-) delete mode 100644 requests/requests.go create mode 100644 server/query.go create mode 100644 server/server.go diff --git a/db/db.go b/db/db.go index d5e2939..c7cf06d 100644 --- a/db/db.go +++ b/db/db.go @@ -8,6 +8,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) +// Initialize MongoDB client func InitDB(uri, database string) (*mongo.Client, error) { clientOptions := options.Client().ApplyURI(uri) client, err := mongo.Connect(context.TODO(), clientOptions) @@ -25,6 +26,7 @@ func InitDB(uri, database string) (*mongo.Client, error) { return client, nil } +// Disconnect from MongoDB func DisconnectDB(client *mongo.Client) { if err := client.Disconnect(context.TODO()); err != nil { fmt.Println("Error disconnecting from MongoDB:", err) diff --git a/main.go b/main.go index f83b7a1..e8f74fa 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,7 @@ import ( "grain/db" "grain/events" - "grain/requests" + "grain/server" "grain/utils" "golang.org/x/net/websocket" @@ -30,8 +30,10 @@ func main() { // Initialize collections events.InitCollections(client, config.Collections.EventKind0, config.Collections.EventKind1) + server.SetClient(client) + // Start WebSocket server - http.Handle("/", websocket.Handler(requests.Handler)) + http.Handle("/", websocket.Handler(server.Handler)) fmt.Println("WebSocket server started on", config.Server.Address) err = http.ListenAndServe(config.Server.Address, nil) if err != nil { diff --git a/requests/requests.go b/requests/requests.go deleted file mode 100644 index 068f064..0000000 --- a/requests/requests.go +++ /dev/null @@ -1,67 +0,0 @@ -package requests - -import ( - "context" - "encoding/json" - "fmt" - - "grain/events" - - "golang.org/x/net/websocket" -) - -func Handler(ws *websocket.Conn) { - var msg string - for { - err := websocket.Message.Receive(ws, &msg) - if err != nil { - fmt.Println("Error receiving message:", err) - return - } - fmt.Println("Received message:", msg) - - // Parse the received message - var event []interface{} - err = json.Unmarshal([]byte(msg), &event) - if err != nil { - fmt.Println("Error parsing message:", err) - return - } - - if len(event) < 2 || event[0] != "EVENT" { - fmt.Println("Invalid event format") - continue - } - - // Convert the event map to an Event struct - eventData, ok := event[1].(map[string]interface{}) - if !ok { - fmt.Println("Invalid event data format") - continue - } - eventBytes, err := json.Marshal(eventData) - if err != nil { - fmt.Println("Error marshaling event data:", err) - continue - } - - var evt events.Event - err = json.Unmarshal(eventBytes, &evt) - if err != nil { - fmt.Println("Error unmarshaling event data:", err) - continue - } - - err = events.HandleEvent(context.TODO(), evt) - if err != nil { - fmt.Println("Error handling event:", err) - continue - } - - err = websocket.Message.Send(ws, "Echo: "+msg) - if err != nil { - fmt.Println("Error sending message:", err) - return - } - } -} diff --git a/server/query.go b/server/query.go new file mode 100644 index 0000000..c5dfa02 --- /dev/null +++ b/server/query.go @@ -0,0 +1,70 @@ +package server + +import ( + "context" + "fmt" + + "grain/events" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +// QueryEvents queries events from the MongoDB collection based on filters +func QueryEvents(filters []Filter, client *mongo.Client, databaseName, collectionName string) ([]events.Event, error) { + collection := client.Database(databaseName).Collection(collectionName) + + var results []events.Event + + for _, filter := range filters { + filterBson := bson.M{} + + if len(filter.IDs) > 0 { + filterBson["_id"] = bson.M{"$in": filter.IDs} + } + if len(filter.Authors) > 0 { + filterBson["author"] = bson.M{"$in": filter.Authors} + } + if len(filter.Kinds) > 0 { + filterBson["kind"] = bson.M{"$in": filter.Kinds} + } + if filter.Tags != nil { + for key, values := range filter.Tags { + if len(values) > 0 { + filterBson[key] = bson.M{"$in": values} + } + } + } + if filter.Since != nil { + filterBson["created_at"] = bson.M{"$gte": *filter.Since} + } + if filter.Until != nil { + filterBson["created_at"] = bson.M{"$lte": *filter.Until} + } + + opts := options.Find() + if filter.Limit != nil { + opts.SetLimit(int64(*filter.Limit)) + } + + cursor, err := collection.Find(context.TODO(), filterBson, opts) + if err != nil { + return nil, fmt.Errorf("error querying events: %v", err) + } + defer cursor.Close(context.TODO()) + + for cursor.Next(context.TODO()) { + var event events.Event + if err := cursor.Decode(&event); err != nil { + return nil, fmt.Errorf("error decoding event: %v", err) + } + results = append(results, event) + } + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %v", err) + } + } + + return results, nil +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..4813acd --- /dev/null +++ b/server/server.go @@ -0,0 +1,286 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "grain/events" + "time" + + "go.mongodb.org/mongo-driver/mongo" + "golang.org/x/net/websocket" +) + +type Subscription struct { + ID string + Filters []Filter +} + +// Filter represents the criteria used to query events +type Filter struct { + IDs []string `json:"ids"` + Authors []string `json:"authors"` + Kinds []int `json:"kinds"` + Tags map[string][]string `json:"tags"` + Since *time.Time `json:"since"` + Until *time.Time `json:"until"` + Limit *int `json:"limit"` +} + +var subscriptions = make(map[string]Subscription) +var client *mongo.Client + +func SetClient(mongoClient *mongo.Client) { + client = mongoClient +} + +func Handler(ws *websocket.Conn) { + defer ws.Close() + + var msg string + for { + err := websocket.Message.Receive(ws, &msg) + if err != nil { + fmt.Println("Error receiving message:", err) + return + } + fmt.Println("Received message:", msg) + + var message []interface{} + err = json.Unmarshal([]byte(msg), &message) + if err != nil { + fmt.Println("Error parsing message:", err) + continue + } + + if len(message) < 2 { + fmt.Println("Invalid message format") + continue + } + + messageType, ok := message[0].(string) + if !ok { + fmt.Println("Invalid message type") + continue + } + + switch messageType { + case "EVENT": + handleEvent(ws, message) + case "REQ": + handleReq(ws, message) + case "CLOSE": + handleClose(ws, message) + default: + fmt.Println("Unknown message type:", messageType) + } + } +} + +func handleEvent(ws *websocket.Conn, message []interface{}) { + if len(message) != 2 { + fmt.Println("Invalid EVENT message format") + return + } + + eventData, ok := message[1].(map[string]interface{}) + if !ok { + fmt.Println("Invalid event data format") + return + } + eventBytes, err := json.Marshal(eventData) + if err != nil { + fmt.Println("Error marshaling event data:", err) + return + } + + var evt events.Event + err = json.Unmarshal(eventBytes, &evt) + if err != nil { + fmt.Println("Error unmarshaling event data:", err) + return + } + + err = events.HandleEvent(context.TODO(), evt) + if err != nil { + fmt.Println("Error handling event:", err) + return + } + + fmt.Println("Event processed:", evt.ID) +} + +func handleReq(ws *websocket.Conn, message []interface{}) { + if len(message) < 3 { + fmt.Println("Invalid REQ message format") + return + } + + subID, ok := message[1].(string) + if !ok { + fmt.Println("Invalid subscription ID format") + return + } + + filters := make([]Filter, len(message)-2) + for i, filter := range message[2:] { + filterData, ok := filter.(map[string]interface{}) + if !ok { + fmt.Println("Invalid filter format") + return + } + + var f Filter + f.IDs = toStringArray(filterData["ids"]) + f.Authors = toStringArray(filterData["authors"]) + f.Kinds = toIntArray(filterData["kinds"]) + f.Tags = toTagsMap(filterData["tags"]) + f.Since = toTime(filterData["since"]) + f.Until = toTime(filterData["until"]) + f.Limit = toInt(filterData["limit"]) + + filters[i] = f + } + + subscriptions[subID] = Subscription{ID: subID, Filters: filters} + fmt.Println("Subscription added:", subID) + + // Query the database with filters and send back the results + events, err := QueryEvents(filters, client, "grain", "event-kind0") + if err != nil { + fmt.Println("Error querying events:", err) + return + } + + for _, evt := range events { + msg := []interface{}{"EVENT", subID, evt} + msgBytes, _ := json.Marshal(msg) + err = websocket.Message.Send(ws, string(msgBytes)) + if err != nil { + fmt.Println("Error sending event:", err) + return + } + } + + // Indicate end of stored events + eoseMsg := []interface{}{"EOSE", subID} + eoseBytes, _ := json.Marshal(eoseMsg) + err = websocket.Message.Send(ws, string(eoseBytes)) + if err != nil { + fmt.Println("Error sending EOSE:", err) + return + } +} + +func handleClose(ws *websocket.Conn, message []interface{}) { + if len(message) != 2 { + fmt.Println("Invalid CLOSE message format") + return + } + + subID, ok := message[1].(string) + if !ok { + fmt.Println("Invalid subscription ID format") + return + } + + delete(subscriptions, subID) + fmt.Println("Subscription closed:", subID) + + closeMsg := []interface{}{"CLOSED", subID, "Subscription closed"} + closeBytes, _ := json.Marshal(closeMsg) + err := websocket.Message.Send(ws, string(closeBytes)) + if err != nil { + fmt.Println("Error sending CLOSE message:", err) + return + } +} + +func toStringArray(i interface{}) []string { + if i == nil { + return nil + } + arr, ok := i.([]interface{}) + if !ok { + return nil + } + var result []string + for _, v := range arr { + str, ok := v.(string) + if ok { + result = append(result, str) + } + } + return result +} + +func toIntArray(i interface{}) []int { + if i == nil { + return nil + } + arr, ok := i.([]interface{}) + if !ok { + return nil + } + var result []int + for _, v := range arr { + num, ok := v.(float64) + if ok { + result = append(result, int(num)) + } + } + return result +} + +func toTagsMap(i interface{}) map[string][]string { + if i == nil { + return nil + } + tags, ok := i.(map[string]interface{}) + if !ok { + return nil + } + result := make(map[string][]string) + for k, v := range tags { + result[k] = toStringArray(v) + } + return result +} + +func toInt64(i interface{}) *int64 { + if i == nil { + return nil + } + num, ok := i.(float64) + if !ok { + return nil + } + val := int64(num) + return &val +} + +func toInt(i interface{}) *int { + if i == nil { + return nil + } + num, ok := i.(float64) + if !ok { + return nil + } + val := int(num) + return &val +} + +func toTime(data interface{}) *time.Time { + if data == nil { + return nil + } + // Ensure data is a float64 which MongoDB uses for numbers + timestamp, ok := data.(float64) + if !ok { + fmt.Println("Invalid timestamp format") + return nil + } + t := time.Unix(int64(timestamp), 0).UTC() + return &t +}