changeset 1101:f7d9542beb3d

Add PgSqlFlowContext; PgSqlFlowHandlerDataDecoder now uses PgSqlResponsePacketsDecoder
author Devel 2
date Wed, 13 May 2020 11:04:03 +0200
parents 150d42e14d54
children 955e69e85a1d
files stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowContext.java stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandler.java stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataDecoder.java stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataEncoder.java stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlResponsePacketsDecoder.java stress-tester/src/main/java/com/passus/st/client/pgsql/filter/PgSqlFilter.java stress-tester/src/main/java/com/passus/st/client/pgsql/filter/PgSqlLoginFilter.java stress-tester/src/test/java/com/passus/st/client/pgsql/PgSqlResponsePacketsDecoderTest.java
diffstat 8 files changed, 177 insertions(+), 71 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowContext.java	Wed May 13 11:04:03 2020 +0200
@@ -0,0 +1,37 @@
+package com.passus.st.client.pgsql;
+
+import com.passus.net.pgsql.PgSqlMessageType;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Mutable PostgreSQL session state of a single flow. {@link PgSqlFlowHandler} shares one
+ * instance between its data encoder and its data decoder.
+ */
+public final class PgSqlFlowContext {
+
+    public static final int STAGE_ERROR = -2;
+    public static final int STAGE_AUTH_FAILED = -1;
+    public static final int STAGE_NONE = 0;
+    /** An SSLRequest is outstanding; the response decoder reads the 'S'/'N' reply byte. */
+    public static final int STAGE_SSL_REQUEST = 1;
+    public static final int STAGE_STARTUP_MESSAGE = 2;
+    public static final int STAGE_PASSWORD = 3;
+    // Must differ from STAGE_PASSWORD, otherwise the two stages are indistinguishable.
+    public static final int STAGE_AUTH_OK = 4;
+
+    /** Type of the most recently encoded request message. */
+    PgSqlMessageType lastMsgType;
+
+    String database;
+
+    String user;
+
+    int stage = STAGE_NONE;
+
+    /** Set by the response decoder from the server's reply to an SSLRequest. */
+    boolean ssl = false;
+
+    Map<String, String> parameters = new HashMap<>();
+}
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandler.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandler.java	Wed May 13 11:04:03 2020 +0200
@@ -7,6 +7,7 @@
 import com.passus.net.pgsql.PgSqlMessage;
 import com.passus.net.pgsql.PgSqlMessageType;
 import com.passus.net.pgsql.PgSqlSimpleQueryMessage;
+import com.passus.st.PacketsBulk;
 import com.passus.st.client.AbstractFlowHandler;
 import com.passus.st.client.FlowContext;
 import com.passus.st.client.FlowHandlerDataDecoder;
@@ -16,20 +17,22 @@
 
 import static com.passus.st.Protocols.NETFLOW;
 
-public class PgSqlFlowHandler extends AbstractFlowHandler<PgSqlMetric, PgSqlMessage, PgSqlMessage> implements TimeAware {
+public class PgSqlFlowHandler extends AbstractFlowHandler<PgSqlMetric, PgSqlMessage, PacketsBulk<PgSqlMessage>> implements TimeAware {
 
     private final Logger LOGGER = LogManager.getLogger(PgSqlFlowHandler.class);
 
     TimeGenerator timeGenerator = TimeGenerator.getDefaultGenerator();
 
+    PgSqlFlowContext context = new PgSqlFlowContext();
+
     @Override
     protected FlowHandlerDataEncoder<PgSqlMessage> createEncoder() {
-        return new PgSqlFlowHandlerDataEncoder();
+        return new PgSqlFlowHandlerDataEncoder(context);
     }
 
     @Override
-    protected FlowHandlerDataDecoder<PgSqlMessage> createDecoder() {
-        return new PgSqlFlowHandlerDataDecoder();
+    protected FlowHandlerDataDecoder<PacketsBulk<PgSqlMessage>> createDecoder() {
+        return new PgSqlFlowHandlerDataDecoder(context);
     }
 
     @Override
@@ -70,28 +73,44 @@
     }
 
+    /**
+     * Handles a decoded response bulk. If it contains an ErrorResponse, the error code is
+     * added to the metric (when metrics are collected) and the flow is disconnected when the
+     * error answers a startup message or belongs to the connection-exception class.
+     */
     @Override
-    protected void onResponseReceived0(PgSqlMessage resp, FlowContext flowContext) {
-        if (resp.getType() == PgSqlMessageType.ERROR_RESPONSE) {
-            PgSqlErrorResponseMessage errorMsg = (PgSqlErrorResponseMessage) resp;
-            if (collectMetrics) {
-                synchronized (metric) {
-                    metric.addErrorCode(errorMsg.getCode());
-                }
-            }
-
-            PgSqlMessage req = (PgSqlMessage) flowContext.sentRequest();
-            if (req.getType() == PgSqlMessageType.STARTUP_MESSAGE) {
-                if (LOGGER.isDebugEnabled()) {
-                    LOGGER.debug("PgSql auth failed. Server message " + errorMsg);
-                }
-
-                disconnectAndBlock(flowContext);
-            } else if (errorMsg.isConnectionExceptionClass()) {
-                if (LOGGER.isDebugEnabled()) {
-                    LOGGER.debug("Fatal error. Server message " + errorMsg);
-                }
-
-                disconnectAndBlock(flowContext);
-            }
-        }
+    protected void onResponseReceived0(PacketsBulk<PgSqlMessage> bulk, FlowContext flowContext) {
+        // The response decoder completes a bulk on ReadyForQuery, so an ErrorResponse is usually
+        // accompanied by other messages in it (e.g. ReadyForQuery after a failed query).
+        PgSqlErrorResponseMessage errorMsg = null;
+        for (int i = 0; i < bulk.packets.size() && errorMsg == null; i++) {
+            PgSqlMessage resp = bulk.packets.get(i);
+            if (resp.getType() == PgSqlMessageType.ERROR_RESPONSE) {
+                errorMsg = (PgSqlErrorResponseMessage) resp;
+            }
+        }
+
+        if (errorMsg == null) {
+            return;
+        }
+
+        if (collectMetrics) {
+            synchronized (metric) {
+                metric.addErrorCode(errorMsg.getCode());
+            }
+        }
+
+        PgSqlMessage req = (PgSqlMessage) flowContext.sentRequest();
+        if (req != null && req.getType() == PgSqlMessageType.STARTUP_MESSAGE) {
+            if (LOGGER.isDebugEnabled()) {
+                LOGGER.debug("PgSql auth failed. Server message " + errorMsg);
+            }
+
+            disconnectAndBlock(flowContext);
+        } else if (errorMsg.isConnectionExceptionClass()) {
+            if (LOGGER.isDebugEnabled()) {
+                LOGGER.debug("Fatal error. Server message " + errorMsg);
+            }
+
+            disconnectAndBlock(flowContext);
+        }
     }
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataDecoder.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataDecoder.java	Wed May 13 11:04:03 2020 +0200
@@ -2,8 +2,8 @@
 
 import com.passus.data.ByteBuff;
 import com.passus.data.DataDecoder;
-import com.passus.net.pgsql.PgSqlDecoder;
 import com.passus.net.pgsql.PgSqlMessage;
+import com.passus.st.PacketsBulk;
 import com.passus.st.client.FlowContext;
 import com.passus.st.client.FlowHandlerDataDecoder;
 import org.apache.logging.log4j.LogManager;
@@ -11,18 +11,28 @@
 
 import static com.passus.st.client.FlowUtils.debug;
 
-public class PgSqlFlowHandlerDataDecoder implements FlowHandlerDataDecoder<PgSqlMessage> {
+/**
+ * Flow-level PostgreSQL response decoder. Wraps a {@link PgSqlResponsePacketsDecoder}
+ * bound to the flow's {@link PgSqlFlowContext}; its results are bulks of messages.
+ */
+public class PgSqlFlowHandlerDataDecoder implements FlowHandlerDataDecoder<PacketsBulk<PgSqlMessage>> {
 
     private static final Logger LOGGER = LogManager.getLogger(PgSqlFlowHandlerDataDecoder.class);
 
-    private PgSqlDecoder decoder = new PgSqlDecoder(false);
+    private PgSqlResponsePacketsDecoder decoder;
 
-    public PgSqlFlowHandlerDataDecoder() {
+    private final PgSqlFlowContext context;
 
+    /**
+     * @param context per-flow session state, handed on to the response decoder
+     */
+    public PgSqlFlowHandlerDataDecoder(PgSqlFlowContext context) {
+        this.context = context;
+        this.decoder = new PgSqlResponsePacketsDecoder(context);
     }
 
     @Override
-    public PgSqlMessage getResult() {
+    public PacketsBulk<PgSqlMessage> getResult() {
         return decoder.getResult();
     }
 
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataEncoder.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlFlowHandlerDataEncoder.java	Wed May 13 11:04:03 2020 +0200
@@ -10,8 +10,20 @@
 
     private final PgSqlEncoder encoder = new PgSqlEncoder();
 
+    private final PgSqlFlowContext context;
+
+    /**
+     * @param context per-flow session state; {@link #encode} stores the type of every
+     *                outgoing request in it
+     */
+    public PgSqlFlowHandlerDataEncoder(PgSqlFlowContext context) {
+        this.context = context;
+    }
+
     @Override
     public void encode(PgSqlMessage request, FlowContext flowContext, ByteBuff out) {
+        // Record the request type in the shared flow context before serializing the request.
+        context.lastMsgType = request.getType();
         encoder.encode(request, out);
     }
 }
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlResponsePacketsDecoder.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/PgSqlResponsePacketsDecoder.java	Wed May 13 11:04:03 2020 +0200
@@ -6,10 +6,7 @@
 import com.passus.net.FixedSizeLengthPduXHandler;
 import com.passus.net.PduX;
 import com.passus.net.mysql.MySqlDecoder;
-import com.passus.net.mysql.MySqlDecoderContext;
-import com.passus.net.pgsql.PgSqlDecoder;
-import com.passus.net.pgsql.PgSqlMessage;
-import com.passus.net.pgsql.PgSqlMessageType;
+import com.passus.net.pgsql.*;
 import com.passus.st.PacketsBulk;
 import com.passus.st.client.mysql.MySqlResponsePacketsDecoder;
 import org.apache.commons.lang3.mutable.MutableInt;
@@ -24,11 +21,17 @@
 
     private final PacketsBulk<PgSqlMessage> bulk;
 
+    private final PgSqlFlowContext context;
+
     private final PgSqlPdu pdu;
 
-    public PgSqlResponsePacketsDecoder() {
+    private final PgSqlDecoder decoder;
+
+    public PgSqlResponsePacketsDecoder(PgSqlFlowContext context) {
         this.bulk = new PacketsBulk<>(PGSQL);
         this.pdu = new PgSqlPdu();
+        this.decoder = new PgSqlDecoder(false);
+        this.context = context;
     }
 
     @Override
@@ -36,17 +39,64 @@
         return bulk;
     }
 
+    /**
+     * Decodes one PDU handed over by {@link PgSqlPdu#onNewPdu} and appends the resulting
+     * message to the bulk; decoding ReadyForQuery completes the bulk.
+     */
+    private void doDecode(byte[] data, int offset, int length) {
+        try {
+            decoder.decode(data, offset, length);
+            if (decoder.state() == PgSqlDecoder.STATE_FINISHED) {
+                PgSqlMessage packet = decoder.getResult();
+                if (packet != null) {
+                    bulk.packets.add(packet);
+                    if (packet.getType() == PgSqlMessageType.READY_FOR_QUERY) {
+                        state(STATE_FINISHED);
+                    }
+                }
+
+                decoder.clear();
+            } else if (decoder.state() == PgSqlDecoder.STATE_ERROR) {
+                error("Decoder error. " + decoder.getLastError());
+            }
+        } catch (Exception e) {
+            if (LOGGER.isDebugEnabled()) {
+                LOGGER.debug(e.getMessage(), e);
+            }
+
+            decoder.clear();
+        }
+    }
 
+    /**
+     * While an SSLRequest is outstanding the input is treated as the server's SSL reply;
+     * otherwise it is passed to the PDU framer, which hands each PDU to {@link #doDecode}.
+     */
     @Override
     public int decode(byte[] data, int offset, int length) {
+        if (context.stage == PgSqlFlowContext.STAGE_SSL_REQUEST) {
+            // Nothing to inspect yet; data[offset] is not part of the input.
+            if (length <= 0) {
+                return 0;
+            }
+
+            // The reply is a single byte: 'S' accepts SSL, 'N' refuses it.
+            if (data[offset] == 'S') {
+                context.ssl = true;
+            } else if (data[offset] == 'N') {
+                context.ssl = false;
+            }
+
+            state(STATE_FINISHED);
+            return length;
+        }
+
         pdu.handle(data, offset, length);
         return length;
     }
 
     private class PgSqlPdu implements FixedSizeLengthPduXHandler {
 
-        private final PgSqlDecoder decoder;
-
         private final PduX pdu;
 
         private boolean sslRequested = false;
@@ -54,7 +104,6 @@
         private boolean sslSession;
 
         public PgSqlPdu() {
-            decoder = new PgSqlDecoder(false);
             pdu = new FixedSizeLengthPduX(this, 5, false);
         }
 
@@ -78,28 +127,7 @@
 
         @Override
         public void onNewPdu(byte[] data, int offset, int length) {
-            try {
-                decoder.decode(data, offset, length);
-                if (decoder.state() == MySqlDecoder.STATE_FINISHED) {
-                    PgSqlMessage packet = decoder.getResult();
-                    if (packet != null) {
-                        bulk.packets.add(packet);
-                        if (packet.getType() == PgSqlMessageType.READY_FOR_QUERY) {
-                            state(STATE_FINISHED);
-                        }
-                    }
-
-                    decoder.clear();
-                } else if (decoder.state() == MySqlDecoder.STATE_ERROR) {
-                    error("Decoder error. " + decoder.getLastError());
-                }
-            } catch (Exception e) {
-                if (LOGGER.isDebugEnabled()) {
-                    LOGGER.debug(e.getMessage(), e);
-                }
-
-                decoder.clear();
-            }
+            doDecode(data, offset, length);
         }
     }
 }
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/filter/PgSqlFilter.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/filter/PgSqlFilter.java	Wed May 13 11:04:03 2020 +0200
@@ -1,9 +1,12 @@
 package com.passus.st.client.pgsql.filter;
 
 import com.passus.net.pgsql.PgSqlMessage;
+import com.passus.st.PacketsBulk;
 import com.passus.st.client.FlowContext;
 import com.passus.st.filter.FlowFilter;
 
+import static com.passus.st.Protocols.PGSQL;
+
 /**
  * @author mikolaj.podbielski
  */
@@ -12,24 +15,43 @@
     @Override
+    @SuppressWarnings("unchecked")
     public int filterInbound(Object req, Object resp, FlowContext context) {
-        if (resp instanceof PgSqlMessage) {
-            return filterInbound(req, (PgSqlMessage) resp, context);
+        // Responses reach the filter as message bulks, not as single PgSqlMessage objects.
+        if (isPgSqlBulk(resp)) {
+            return filterInbound(req, (PacketsBulk<PgSqlMessage>) resp, context);
         }
         return DUNNO;
     }
 
-    public int filterInbound(Object req, PgSqlMessage resp, FlowContext context) {
+    /**
+     * Hook for PostgreSQL response bulks; returns {@code DUNNO} unless overridden.
+     */
+    public int filterInbound(Object req, PacketsBulk<PgSqlMessage> bulk, FlowContext context) {
         return DUNNO;
     }
 
     @Override
+    @SuppressWarnings("unchecked")
     public int filterOutbound(Object req, Object resp, FlowContext context) {
-        if (req instanceof PgSqlMessage || resp instanceof PgSqlMessage) {
-            return filterOutbound((PgSqlMessage) req, (PgSqlMessage) resp, context);
+        boolean pgSqlReq = req instanceof PgSqlMessage;
+        boolean pgSqlResp = isPgSqlBulk(resp);
+        if (pgSqlReq || pgSqlResp) {
+            // Pass null rather than casting an argument of another type (ClassCastException).
+            return filterOutbound(pgSqlReq ? (PgSqlMessage) req : null,
+                    pgSqlResp ? (PacketsBulk<PgSqlMessage>) resp : null, context);
         }
         return DUNNO;
     }
 
-    public int filterOutbound(PgSqlMessage req, PgSqlMessage resp, FlowContext context) {
+    /**
+     * Hook for outgoing PostgreSQL traffic; {@code req} or {@code bulk} may be {@code null}.
+     * Returns {@code DUNNO} unless overridden.
+     */
+    public int filterOutbound(PgSqlMessage req, PacketsBulk<PgSqlMessage> bulk, FlowContext context) {
         return DUNNO;
     }
+
+    /** Returns true when {@code obj} is a message bulk of the PGSQL protocol. */
+    private static boolean isPgSqlBulk(Object obj) {
+        return obj instanceof PacketsBulk && ((PacketsBulk<?>) obj).protocol == PGSQL;
+    }
 }
--- a/stress-tester/src/main/java/com/passus/st/client/pgsql/filter/PgSqlLoginFilter.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/pgsql/filter/PgSqlLoginFilter.java	Wed May 13 11:04:03 2020 +0200
@@ -7,6 +7,7 @@
 import com.passus.config.schema.NodeDefinition;
 import com.passus.config.schema.NodeDefinitionCreator;
 import com.passus.net.pgsql.*;
+import com.passus.st.PacketsBulk;
 import com.passus.st.client.FlowContext;
 import com.passus.st.client.credentials.Credentials;
 import com.passus.st.client.credentials.CredentialsProvider;
@@ -52,16 +53,23 @@
     }
 
+    /**
+     * Stores the server's authentication request under the {@code authReq} flow parameter.
+     * Every message of the bulk is inspected; each AUTH_REQUEST found is stored in order.
+     */
     @Override
-    public int filterInbound(Object req, PgSqlMessage resp, FlowContext context) {
-        if (resp.getType() == PgSqlMessageType.AUTH_REQUEST) {
-            PgSqlAuthRequestMessage authReq = (PgSqlAuthRequestMessage) resp;
-            context.setParam("authReq", authReq);
+    public int filterInbound(Object req, PacketsBulk<PgSqlMessage> bulk, FlowContext context) {
+        for (int i = 0; i < bulk.packets.size(); i++) {
+            PgSqlMessage resp = bulk.packets.get(i);
+            if (resp.getType() == PgSqlMessageType.AUTH_REQUEST) {
+                PgSqlAuthRequestMessage authReq = (PgSqlAuthRequestMessage) resp;
+                context.setParam("authReq", authReq);
+            }
         }
         return DUNNO;
     }
 
     @Override
-    public int filterOutbound(PgSqlMessage req, PgSqlMessage resp, FlowContext context) {
+    public int filterOutbound(PgSqlMessage req, PacketsBulk<PgSqlMessage> bulk, FlowContext context) {
         if (req.getType() == PgSqlMessageType.STARTUP_MESSAGE) {
             PgSqlStartupMessage startupMsg = (PgSqlStartupMessage) req;
             Credentials credentials = credentialsProvider.getCredentials(context);
--- a/stress-tester/src/test/java/com/passus/st/client/pgsql/PgSqlResponsePacketsDecoderTest.java	Tue May 12 13:36:50 2020 +0200
+++ b/stress-tester/src/test/java/com/passus/st/client/pgsql/PgSqlResponsePacketsDecoderTest.java	Wed May 13 11:04:03 2020 +0200
@@ -17,15 +17,51 @@
 public class PgSqlResponsePacketsDecoderTest {
 
     @Test
-    public void testProcessSimpleQueryAndData() {
-        List<Tcp> tcps = PcapUtils.readTcpPackets("pcap/pgsql/pgsql_simple_query_and_data.pcap");
+    public void testProcessSSLRequest() {
+        List<Tcp> tcps = PcapUtils.readTcpPackets("pcap/pgsql/pgsql_ssl_req_resp.pcap");
 
+        PgSqlFlowContext context = new PgSqlFlowContext();
         PgSqlDecoder reqDecoder = new PgSqlDecoder(true);
-        PgSqlResponsePacketsDecoder respDecoder = new PgSqlResponsePacketsDecoder();
+        PgSqlResponsePacketsDecoder respDecoder = new PgSqlResponsePacketsDecoder(context);
 
         Iterator<Tcp> it = tcps.iterator();
         reqDecoder.decode(it.next().getPayload());
         assertEquals(STATE_FINISHED, reqDecoder.state());
+        context.stage = PgSqlFlowContext.STAGE_SSL_REQUEST;
+
+        respDecoder.decode(it.next().getPayload());
+        assertEquals(STATE_FINISHED, respDecoder.state());
+    }
+
+    @Test
+    public void testProcessStartupMessage() {
+        List<Tcp> tcps = PcapUtils.readTcpPackets("pcap/pgsql/pgsql_startup_message.pcap");
+
+        PgSqlFlowContext context = new PgSqlFlowContext();
+        PgSqlDecoder reqDecoder = new PgSqlDecoder(true);
+        PgSqlResponsePacketsDecoder respDecoder = new PgSqlResponsePacketsDecoder(context);
+
+        Iterator<Tcp> it = tcps.iterator();
+        reqDecoder.decode(it.next().getPayload());
+        assertEquals(STATE_FINISHED, reqDecoder.state());
+        context.stage = PgSqlFlowContext.STAGE_STARTUP_MESSAGE;
+
+        respDecoder.decode(it.next().getPayload());
+        assertEquals(STATE_FINISHED, respDecoder.state());
+    }
+
+    @Test
+    public void testProcessSimpleQueryAndData() {
+        List<Tcp> tcps = PcapUtils.readTcpPackets("pcap/pgsql/pgsql_simple_query_and_data.pcap");
+
+        PgSqlFlowContext context = new PgSqlFlowContext();
+        PgSqlDecoder reqDecoder = new PgSqlDecoder(true);
+        PgSqlResponsePacketsDecoder respDecoder = new PgSqlResponsePacketsDecoder(context);
+
+        Iterator<Tcp> it = tcps.iterator();
+        reqDecoder.decode(it.next().getPayload());
+        assertEquals(STATE_FINISHED, reqDecoder.state());
+        context.stage = PgSqlFlowContext.STAGE_AUTH_OK;
 
         respDecoder.decode(it.next().getPayload());
         assertEquals(STATE_DATA_NEEDED, respDecoder.state());