From e8b1380c9019078554d401310b0793f0c0f8584e Mon Sep 17 00:00:00 2001 From: 0ceanSlim Date: Thu, 22 Aug 2024 16:12:12 -0400 Subject: [PATCH] req returns all lastest events when no kind is specefied. --- server/handlers/req.go | 74 +++++++++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 15 deletions(-) diff --git a/server/handlers/req.go b/server/handlers/req.go index beb9d7b..9b3b231 100644 --- a/server/handlers/req.go +++ b/server/handlers/req.go @@ -50,6 +50,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri fmt.Println("Dropped oldest subscription:", oldestSubID) } + // Prepare filters based on the incoming message filters := make([]relay.Filter, len(message)-2) for i, filter := range message[2:] { filterData, ok := filter.(map[string]interface{}) @@ -73,7 +74,7 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri // Add the new subscription or update the existing one subscriptions[subID] = filters - fmt.Println("Subscription updated:", subID) + fmt.Printf("Subscription updated: %s with %d filters\n", subID, len(filters)) // Query the database with filters and send back the results queriedEvents, err := QueryEvents(filters, db.GetClient(), "grain") @@ -83,6 +84,13 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri return } + // Log the number of events retrieved + fmt.Printf("Retrieved %d events for subscription %s\n", len(queriedEvents), subID) + if len(queriedEvents) == 0 { + fmt.Printf("No matching events found for subscription %s\n", subID) + } + + // Send each event back to the client for _, evt := range queriedEvents { msg := []interface{}{"EVENT", subID, evt} msgBytes, _ := json.Marshal(msg) @@ -103,16 +111,19 @@ func HandleReq(ws *websocket.Conn, message []interface{}, subscriptions map[stri response.SendClosed(ws, subID, "error: could not send EOSE") return } + + fmt.Println("Subscription handling completed, keeping connection open.") }) } -// QueryEvents queries events from the MongoDB collection based on filters +// QueryEvents queries events from the MongoDB collection(s) based on filters func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName string) ([]relay.Event, error) { var results []relay.Event for _, filter := range filters { filterBson := bson.M{} + // Construct the BSON query based on the filters if len(filter.IDs) > 0 { filterBson["id"] = bson.M{"$in": filter.IDs} } @@ -145,24 +156,57 @@ func QueryEvents(filters []relay.Filter, client *mongo.Client, databaseName stri opts.SetLimit(int64(*filter.Limit)) } - for _, kind := range filter.Kinds { - collectionName := fmt.Sprintf("event-kind%d", kind) - collection := client.Database(databaseName).Collection(collectionName) - cursor, err := collection.Find(context.TODO(), filterBson, opts) + // If no specific kinds are specified, query all collections in the database + if len(filter.Kinds) == 0 { + collections, err := client.Database(databaseName).ListCollectionNames(context.TODO(), bson.D{}) if err != nil { - return nil, fmt.Errorf("error querying events: %v", err) + return nil, fmt.Errorf("error listing collections: %v", err) } - defer cursor.Close(context.TODO()) - for cursor.Next(context.TODO()) { - var event relay.Event - if err := cursor.Decode(&event); err != nil { - return nil, fmt.Errorf("error decoding event: %v", err) + for _, collectionName := range collections { + fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson) + + collection := client.Database(databaseName).Collection(collectionName) + cursor, err := collection.Find(context.TODO(), filterBson, opts) + if err != nil { + return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err) + } + defer cursor.Close(context.TODO()) + + for cursor.Next(context.TODO()) { + var event relay.Event + if err := cursor.Decode(&event); err != nil { + return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err) + } + results = append(results, event) + } + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err) } - results = append(results, event) } - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %v", err) + } else { + // Query specific collections based on kinds + for _, kind := range filter.Kinds { + collectionName := fmt.Sprintf("event-kind%d", kind) + fmt.Printf("Querying collection: %s with query: %v\n", collectionName, filterBson) + + collection := client.Database(databaseName).Collection(collectionName) + cursor, err := collection.Find(context.TODO(), filterBson, opts) + if err != nil { + return nil, fmt.Errorf("error querying collection %s: %v", collectionName, err) + } + defer cursor.Close(context.TODO()) + + for cursor.Next(context.TODO()) { + var event relay.Event + if err := cursor.Decode(&event); err != nil { + return nil, fmt.Errorf("error decoding event from collection %s: %v", collectionName, err) + } + results = append(results, event) + } + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error in collection %s: %v", collectionName, err) + } } } }