changeset 1103:4b9c98988fa4

PgSqlFlowHandler - packets processing in progress
author Devel 2
date Wed, 13 May 2020 15:09:08 +0200
parents 955e69e85a1d
children 0cdfcf4df1c6
files stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowContext.java stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandler.java stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataEncoder.java stress-tester/src/test/java/com/passus/st/client/pgsql/PgSqlFlowHandlerTest.java
diffstat 4 files changed, 184 insertions(+), 25 deletions(-) [+]
line wrap: on
line diff
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowContext.java	Wed May 13 11:04:21 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowContext.java	Wed May 13 15:09:08 2020 +0200
@@ -12,10 +12,11 @@
     public static final int STAGE_NONE = 0;
     public static final int STAGE_SSL_REQUEST = 1;
     public static final int STAGE_STARTUP_MESSAGE = 2;
-    public static final int STAGE_PASSWORD = 3;
-    public static final int STAGE_AUTH_OK = 3;
-
-    PgSqlMessageType lastMsgType;
+    public static final int STAGE_PASSWORD_REQUIRED = 3;
+    public static final int STAGE_PASSWORD = 4;
+    public static final int STAGE_AUTH_OK = 5;
+    // Set by PgSqlFlowHandler when authentication is rejected; the flow is then disconnected.
+    public static final int STAGE_AUTH_FAILED = 6;
 
     String database;
 
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandler.java	Wed May 13 11:04:21 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandler.java	Wed May 13 15:09:08 2020 +0200
@@ -3,10 +3,7 @@
 import com.passus.commons.Assert;
 import com.passus.commons.time.TimeAware;
 import com.passus.commons.time.TimeGenerator;
-import com.passus.net.pgsql.PgSqlErrorResponseMessage;
-import com.passus.net.pgsql.PgSqlMessage;
-import com.passus.net.pgsql.PgSqlMessageType;
-import com.passus.net.pgsql.PgSqlSimpleQueryMessage;
+import com.passus.net.pgsql.*;
 import com.passus.st.PacketsBulk;
 import com.passus.st.client.AbstractFlowHandler;
 import com.passus.st.client.FlowContext;
@@ -15,6 +12,8 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
+import java.util.List;
+
 import static com.passus.st.Protocols.NETFLOW;
 
 public class PgSqlFlowHandler extends AbstractFlowHandler<PgSqlMetric, PgSqlMessage, PacketsBulk<PgSqlMessage>> implements TimeAware {
@@ -53,13 +52,40 @@
 
+    /**
+     * Advances the connection stage after a client message has been sent.
+     * <p>
+     * Simple queries are added to the metric only once authentication succeeded (STAGE_AUTH_OK). A startup
+     * message is accepted from STAGE_NONE and also after an SSLRequest, which precedes it on the same flow.
+     */
     @Override
     protected void onRequestSent0(PgSqlMessage req, FlowContext flowContext) {
-        if (collectMetrics) {
-            if (req.getType() == PgSqlMessageType.SIMPLE_QUERY) {
-                PgSqlSimpleQueryMessage simpleQueryMsg = (PgSqlSimpleQueryMessage) req;
-                synchronized (metric) {
-                    metric.addQuery(simpleQueryMsg.getQuery());
+        if (context.stage == PgSqlFlowContext.STAGE_AUTH_OK) {
+            if (collectMetrics) {
+                if (req.getType() == PgSqlMessageType.SIMPLE_QUERY) {
+                    PgSqlSimpleQueryMessage simpleQueryMsg = (PgSqlSimpleQueryMessage) req;
+                    synchronized (metric) {
+                        metric.addQuery(simpleQueryMsg.getQuery());
+                    }
                 }
             }
+        } else if (context.stage == PgSqlFlowContext.STAGE_PASSWORD_REQUIRED) {
+            if (req.getType() == PgSqlMessageType.PASSWORD_MESSAGE) {
+                context.stage = PgSqlFlowContext.STAGE_PASSWORD;
+            } else {
+                if (LOGGER.isDebugEnabled()) {
+                    LOGGER.debug("Required PASSWORD_MESSAGE in stage STAGE_PASSWORD_REQUIRED.");
+                }
+
+                disconnectAndBlock(flowContext);
+            }
+        } else if (context.stage == PgSqlFlowContext.STAGE_NONE
+                || context.stage == PgSqlFlowContext.STAGE_SSL_REQUEST) {
+            // After an SSLRequest the client continues with the StartupMessage, so the startup phase
+            // must also be reachable from STAGE_SSL_REQUEST.
+            if (req.getType() == PgSqlMessageType.SSL_REQUEST) {
+                context.stage = PgSqlFlowContext.STAGE_SSL_REQUEST;
+            } else if (req.getType() == PgSqlMessageType.STARTUP_MESSAGE) {
+                context.stage = PgSqlFlowContext.STAGE_STARTUP_MESSAGE;
+            }
         }
     }
 
@@ -72,9 +98,24 @@
         }
     }
 
-    @Override
-    protected void onResponseReceived0(PacketsBulk<PgSqlMessage> bulk, FlowContext flowContext) {
-        if (bulk.packets.size() == 1) {
+    /**
+     * Copies every PARAMETER_STATUS message of a server response into the flow context parameters.
+     */
+    private void populateParameters(List<PgSqlMessage> packets) {
+        for (PgSqlMessage packet : packets) {
+            if (packet.getType() == PgSqlMessageType.PARAMETER_STATUS) {
+                PgSqlParameterStatusMessage param = (PgSqlParameterStatusMessage) packet;
+                context.parameters.put(param.getName(), param.getValue());
+            }
+        }
+    }
+
+    /**
+     * Handles a response received after successful authentication; an ERROR_RESPONSE of the
+     * connection-exception class is logged as a fatal error.
+     */
+    private void processAuthOK(PacketsBulk<PgSqlMessage> bulk, FlowContext flowContext) {
+        if (bulk.packets.size() > 0) {
             PgSqlMessage resp = bulk.packets.get(0);
             if (resp.getType() == PgSqlMessageType.ERROR_RESPONSE) {
                 PgSqlErrorResponseMessage errorMsg = (PgSqlErrorResponseMessage) resp;
@@ -84,14 +125,7 @@
                     }
                 }
 
-                PgSqlMessage req = (PgSqlMessage) flowContext.sentRequest();
-                if (req.getType() == PgSqlMessageType.STARTUP_MESSAGE) {
-                    if (LOGGER.isDebugEnabled()) {
-                        LOGGER.debug("PgSql auth failed. Server message " + errorMsg);
-                    }
-
-                    disconnectAndBlock(flowContext);
-                } else if (errorMsg.isConnectionExceptionClass()) {
+                if (errorMsg.isConnectionExceptionClass()) {
                     if (LOGGER.isDebugEnabled()) {
                         LOGGER.debug("Fatal error. Server message " + errorMsg);
                     }
@@ -103,6 +137,84 @@
     }
 
+    /**
+     * Advances the authentication stage according to the server response.
+     * <p>
+     * A response that does not start with an authentication request (e.g. an ERROR_RESPONSE sent in
+     * reply to the startup message) is treated as an authentication failure.
+     */
     @Override
+    protected void onResponseReceived0(PacketsBulk<PgSqlMessage> bulk, FlowContext flowContext) {
+        switch (context.stage) {
+            case PgSqlFlowContext.STAGE_AUTH_OK:
+                processAuthOK(bulk, flowContext);
+                break;
+            case PgSqlFlowContext.STAGE_STARTUP_MESSAGE: {
+                PgSqlAuthRequestMessage auth = firstAuthRequest(bulk);
+                if (auth == null) {
+                    authFailed(bulk, flowContext);
+                } else if (auth.isSuccess()) {
+                    authSucceeded(bulk);
+                } else {
+                    // The server asked for credentials; the next request must be a PASSWORD_MESSAGE.
+                    context.stage = PgSqlFlowContext.STAGE_PASSWORD_REQUIRED;
+                }
+
+                break;
+            }
+            case PgSqlFlowContext.STAGE_PASSWORD: {
+                PgSqlAuthRequestMessage auth = firstAuthRequest(bulk);
+                if (auth != null && auth.isSuccess()) {
+                    authSucceeded(bulk);
+                } else {
+                    authFailed(bulk, flowContext);
+                }
+
+                break;
+            }
+            default:
+                // Remaining stages (e.g. the answer to an SSLRequest) do not change the stage here.
+                break;
+        }
+    }
+
+    /**
+     * Returns the first message of the bulk when it is an authentication request, otherwise null.
+     */
+    private static PgSqlAuthRequestMessage firstAuthRequest(PacketsBulk<PgSqlMessage> bulk) {
+        if (bulk.packets.isEmpty()) {
+            return null;
+        }
+
+        PgSqlMessage first = bulk.packets.get(0);
+        return first instanceof PgSqlAuthRequestMessage ? (PgSqlAuthRequestMessage) first : null;
+    }
+
+    /**
+     * Marks the flow as authenticated and stores the PARAMETER_STATUS values sent with the success.
+     */
+    private void authSucceeded(PacketsBulk<PgSqlMessage> bulk) {
+        if (LOGGER.isDebugEnabled()) {
+            LOGGER.debug("PgSql auth success.");
+        }
+
+        context.stage = PgSqlFlowContext.STAGE_AUTH_OK;
+        populateParameters(bulk.packets);
+    }
+
+    /**
+     * Marks the flow as failed to authenticate, then disconnects and blocks it.
+     */
+    private void authFailed(PacketsBulk<PgSqlMessage> bulk, FlowContext flowContext) {
+        if (LOGGER.isDebugEnabled()) {
+            PgSqlMessage serverMsg = bulk.packets.isEmpty() ? null : bulk.packets.get(0);
+            LOGGER.debug("PgSql auth failed. Server message " + serverMsg);
+        }
+
+        context.stage = PgSqlFlowContext.STAGE_AUTH_FAILED;
+        disconnectAndBlock(flowContext);
+    }
+
+    @Override
     public TimeGenerator getTimeGenerator() {
         return timeGenerator;
     }
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataEncoder.java	Wed May 13 11:04:21 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataEncoder.java	Wed May 13 15:09:08 2020 +0200
@@ -19,7 +19,9 @@
 
+    /**
+     * Serializes {@code request} into {@code out} by delegating to {@code encoder}; the flow context is not used.
+     */
     @Override
     public void encode(PgSqlMessage request, FlowContext flowContext, ByteBuff out) {
-        context.lastMsgType = request.getType();
         encoder.encode(request, out);
     }
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/test/java/com/passus/st/client/pgsql/PgSqlFlowHandlerTest.java	Wed May 13 15:09:08 2020 +0200
@@ -0,0 +1,110 @@
+package com.passus.st.client.pgsql;
+
+import com.passus.data.ByteBuff;
+import com.passus.net.packet.Tcp;
+import com.passus.net.pgsql.PgSqlDecoder;
+import com.passus.net.pgsql.PgSqlMessage;
+import com.passus.net.source.pcap.PcapUtils;
+import com.passus.st.PacketsBulk;
+import com.passus.st.client.FlowContext;
+import com.passus.st.client.FlowHandlerDataDecoder;
+import com.passus.st.emitter.SessionInfo;
+import org.testng.annotations.Test;
+
+import java.text.ParseException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.testng.AssertJUnit.*;
+
+public class PgSqlFlowHandlerTest {
+
+    private final SessionInfo info;
+
+    public PgSqlFlowHandlerTest() {
+        try {
+            info = SessionInfo.parse("1.1.1.1:100 <-> 2.2.2.2:5432");
+        } catch (ParseException e) {
+            throw new RuntimeException(e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Replays a captured request/response exchange through a fresh {@link PgSqlFlowHandler}.
+     * <p>
+     * The first TCP segment with payload is decoded as the client request and passed to
+     * {@code onRequestSent}; every remaining segment is fed to the handler's response decoder and
+     * the decoded bulk is passed to {@code onResponseReceived}.
+     *
+     * @param pcapFile capture containing one request followed by its response
+     * @return the handler, its stage right after the request and the decoded response bulk
+     */
+    private Result processReqResp(String pcapFile) {
+        List<Tcp> tcps = PcapUtils.readTcpPackets(pcapFile);
+        tcps = tcps.stream().filter(Tcp::hasPayload).collect(Collectors.toList());
+
+        Iterator<Tcp> it = tcps.iterator();
+        FlowContext flowContext = new FlowContext(info);
+        PgSqlDecoder reqDecoder = new PgSqlDecoder();
+        reqDecoder.decode(it.next().getPayload());
+        assertEquals(PgSqlDecoder.STATE_FINISHED, reqDecoder.state());
+        assertTrue(it.hasNext());
+
+        PgSqlFlowHandler handler = new PgSqlFlowHandler();
+        handler.init(flowContext);
+        handler.onRequestSent(reqDecoder.getResult(), flowContext);
+        int afterReqStage = handler.context.stage;
+
+        // The response may span several TCP segments, so all of them are decoded before the result is read.
+        // (Per-segment state checks must target this decoder, not reqDecoder, which is already finished.)
+        FlowHandlerDataDecoder<PacketsBulk<PgSqlMessage>> decoder = handler.getResponseDecoder(flowContext);
+        while (it.hasNext()) {
+            byte[] payload = it.next().getPayload();
+            ByteBuff buff = ByteBuff.wrap(payload, 0, payload.length);
+            decoder.decode(buff, flowContext);
+        }
+
+        PacketsBulk<PgSqlMessage> bulk = decoder.getResult();
+        assertNotNull(bulk);
+        handler.onResponseReceived(bulk, flowContext);
+        return new Result(handler, afterReqStage, bulk);
+    }
+
+    @Test
+    public void testProcessSSLRequest() {
+        Result result = processReqResp("pcap/pgsql/pgsql_ssl_req_resp.pcap");
+
+        PgSqlFlowContext context = result.handler.context;
+        assertFalse(context.ssl);
+        assertEquals(0, result.bulk.packets.size());
+        assertEquals(PgSqlFlowContext.STAGE_SSL_REQUEST, result.afterRequestStage);
+    }
+
+    @Test
+    public void testProcessStartupMessage() {
+        Result result = processReqResp("pcap/pgsql/pgsql_startup_message.pcap");
+
+        PgSqlFlowContext context = result.handler.context;
+        assertEquals(PgSqlFlowContext.STAGE_AUTH_OK, context.stage);
+        assertEquals(11, context.parameters.size());
+        assertEquals(PgSqlFlowContext.STAGE_STARTUP_MESSAGE, result.afterRequestStage);
+    }
+
+    /** Outcome of {@link #processReqResp(String)}. */
+    private static final class Result {
+
+        final PgSqlFlowHandler handler;
+
+        /** Value of {@code handler.context.stage} right after the request was sent. */
+        final int afterRequestStage;
+
+        final PacketsBulk<PgSqlMessage> bulk;
+
+        Result(PgSqlFlowHandler handler, int afterRequestStage, PacketsBulk<PgSqlMessage> bulk) {
+            this.handler = handler;
+            this.afterRequestStage = afterRequestStage;
+            this.bulk = bulk;
+        }
+    }
+}