From 106948d996d74bf5ff7e3511f35eefea0a90561f Mon Sep 17 00:00:00 2001
From: rpigott <rpigott@berkeley.edu>
Date: Sun, 27 Jan 2019 00:59:26 -0800
Subject: [PATCH 1/4] update to wsproto 0.13

---
 mitmproxy/proxy/protocol/websocket.py | 87 +++++++++++++--------------
 setup.py                              |  2 +-
 2 files changed, 44 insertions(+), 45 deletions(-)

diff --git mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/protocol/websocket.py
index 0d1964a60..591bae7e0 100644
--- mitmproxy/proxy/protocol/websocket.py
+++ mitmproxy/proxy/protocol/websocket.py
@@ -4,8 +4,9 @@
 
 
 import wsproto
-from wsproto import events
-from wsproto.connection import ConnectionType, WSConnection
+from wsproto import events, WSConnection
+from wsproto.connection import ConnectionType
+from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request
 from wsproto.extensions import PerMessageDeflate
 
 from mitmproxy import exceptions
@@ -56,47 +57,44 @@ def __init__(self, ctx, handshake_flow):
         if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
             if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
                 extensions = [PerMessageDeflate()]
-        self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
-                                                          extensions=extensions)
-        self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
-                                                          host=handshake_flow.request.host,
-                                                          resource=handshake_flow.request.path,
-                                                          extensions=extensions)
+        self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
+        self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)
+
         if extensions:
-            for conn in self.connections.values():
-                conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
+            extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
 
-        data = self.connections[self.server_conn].bytes_to_send()
-        self.connections[self.client_conn].receive_bytes(data)
+        request = Request(extensions = extensions, host = handshake_flow.request.host, target = handshake_flow.request.path)
+        data = self.connections[self.server_conn].send(request)
+        self.connections[self.client_conn].receive_data(data)
 
         event = next(self.connections[self.client_conn].events())
-        assert isinstance(event, events.ConnectionRequested)
+        assert isinstance(event, events.Request)
 
-        self.connections[self.client_conn].accept(event)
-        self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
-        assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
+        data = self.connections[self.client_conn].send(AcceptConnection(extensions=extensions))
+        self.connections[self.server_conn].receive_data(data)
+        assert isinstance(next(self.connections[self.server_conn].events()), events.AcceptConnection)
 
     def _handle_event(self, event, source_conn, other_conn, is_server):
-        if isinstance(event, events.DataReceived):
-            return self._handle_data_received(event, source_conn, other_conn, is_server)
-        elif isinstance(event, events.PingReceived):
-            return self._handle_ping_received(event, source_conn, other_conn, is_server)
-        elif isinstance(event, events.PongReceived):
-            return self._handle_pong_received(event, source_conn, other_conn, is_server)
-        elif isinstance(event, events.ConnectionClosed):
-            return self._handle_connection_closed(event, source_conn, other_conn, is_server)
+        if isinstance(event, events.Message):
+            return self._handle_message(event, source_conn, other_conn, is_server)
+        elif isinstance(event, events.Ping):
+            return self._handle_ping(event, source_conn, other_conn, is_server)
+        elif isinstance(event, events.Pong):
+            return self._handle_pong(event, source_conn, other_conn, is_server)
+        elif isinstance(event, events.CloseConnection):
+            return self._handle_close_connection(event, source_conn, other_conn, is_server)
 
         # fail-safe for unhandled events
         return True  # pragma: no cover
 
-    def _handle_data_received(self, event, source_conn, other_conn, is_server):
+    def _handle_message(self, event, source_conn, other_conn, is_server):
         fb = self.server_frame_buffer if is_server else self.client_frame_buffer
         fb.append(event.data)
 
         if event.message_finished:
             original_chunk_sizes = [len(f) for f in fb]
 
-            if isinstance(event, events.TextReceived):
+            if isinstance(event, events.TextMessage):
                 message_type = wsproto.frame_protocol.Opcode.TEXT
                 payload = ''.join(fb)
             else:
@@ -127,19 +125,20 @@ def get_chunk(payload):
                             yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
 
                 for chunk, final in get_chunk(websocket_message.content):
-                    self.connections[other_conn].send_data(chunk, final)
-                    other_conn.send(self.connections[other_conn].bytes_to_send())
+                    data = self.connections[other_conn].send(Message(data = chunk, message_finished = final))
+                    other_conn.send(data)
 
         if self.flow.stream:
-            self.connections[other_conn].send_data(event.data, event.message_finished)
-            other_conn.send(self.connections[other_conn].bytes_to_send())
+            data = self.connections[other_conn].send(Message(data = event.data, message_finished = event.message_finished))
+            other_conn.send(data)
         return True
 
-    def _handle_ping_received(self, event, source_conn, other_conn, is_server):
-        # PING is automatically answered with a PONG by wsproto
-        self.connections[other_conn].ping()
-        other_conn.send(self.connections[other_conn].bytes_to_send())
-        source_conn.send(self.connections[source_conn].bytes_to_send())
+    def _handle_ping(self, event, source_conn, other_conn, is_server):
+        # Use event.response to create the approprate Pong response
+        data = self.connections[other_conn].send(Ping())
+        other_conn.send(data)
+        data = self.connections[source_conn].send(event.response())
+        source_conn.send(data)
         self.log(
             "Ping Received from {}".format("server" if is_server else "client"),
             "info",
@@ -147,7 +146,7 @@ def _handle_ping_received(self, event, source_conn, other_conn, is_server):
         )
         return True
 
-    def _handle_pong_received(self, event, source_conn, other_conn, is_server):
+    def _handle_pong(self, event, source_conn, other_conn, is_server):
         self.log(
             "Pong Received from {}".format("server" if is_server else "client"),
             "info",
@@ -155,14 +154,15 @@ def _handle_pong_received(self, event, source_conn, other_conn, is_server):
         )
         return True
 
-    def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
+    def _handle_close_connection(self, event, source_conn, other_conn, is_server):
         self.flow.close_sender = "server" if is_server else "client"
         self.flow.close_code = event.code
         self.flow.close_reason = event.reason
 
-        self.connections[other_conn].close(event.code, event.reason)
-        other_conn.send(self.connections[other_conn].bytes_to_send())
-        source_conn.send(self.connections[source_conn].bytes_to_send())
+        data = self.connections[other_conn].send(CloseConnection(code=event.code, reason=event.reason))
+        other_conn.send(data)
+        data = self.connections[source_conn].send(event.response())
+        source_conn.send(data)
 
         return False
 
@@ -170,8 +170,7 @@ def _inject_messages(self, endpoint, message_queue):
         while True:
             try:
                 payload = message_queue.get_nowait()
-                self.connections[endpoint].send_data(payload, final=True)
-                data = self.connections[endpoint].bytes_to_send()
+                data = self.connections[endpoint].send(Message(data = payload, message_finished = True))
                 endpoint.send(data)
             except queue.Empty:
                 break
@@ -197,8 +196,8 @@ def __call__(self):
                     is_server = (source_conn == self.server_conn)
 
                     frame = websockets.Frame.from_file(source_conn.rfile)
-                    self.connections[source_conn].receive_bytes(bytes(frame))
-                    source_conn.send(self.connections[source_conn].bytes_to_send())
+                    data = self.connections[source_conn].receive_data(bytes(frame))
+                    source_conn.send(data)
 
                     if close_received:
                         return

From 2af71a4486494b64431558da2e068dc1babe65fa Mon Sep 17 00:00:00 2001
From: rpigott <rpigott@berkeley.edu>
Date: Wed, 30 Jan 2019 18:12:05 -0800
Subject: [PATCH 2/4] Separate client and server PerMessageDefalte extension
 per #3460

---
 mitmproxy/proxy/protocol/websocket.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/protocol/websocket.py
index 591bae7e0..fd2aed157 100644
--- mitmproxy/proxy/protocol/websocket.py
+++ mitmproxy/proxy/protocol/websocket.py
@@ -53,24 +53,28 @@ def __init__(self, ctx, handshake_flow):
 
         self.connections: dict[object, WSConnection] = {}
 
-        extensions = []
+        client_extensions = []
+        server_extensions = []
         if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
             if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
-                extensions = [PerMessageDeflate()]
+                client_extensions = [PerMessageDeflate()]
+                server_extensions = [PerMessageDeflate()]
         self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
         self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)
 
-        if extensions:
-            extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
+        if client_extensions:
+            client_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
+		if server_extensions:
+            server_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
 
-        request = Request(extensions = extensions, host = handshake_flow.request.host, target = handshake_flow.request.path)
+        request = Request(extensions = client_extensions, host = handshake_flow.request.host, target = handshake_flow.request.path)
         data = self.connections[self.server_conn].send(request)
         self.connections[self.client_conn].receive_data(data)
 
         event = next(self.connections[self.client_conn].events())
         assert isinstance(event, events.Request)
 
-        data = self.connections[self.client_conn].send(AcceptConnection(extensions=extensions))
+        data = self.connections[self.client_conn].send(AcceptConnection(extensions=server_extensions))
         self.connections[self.server_conn].receive_data(data)
         assert isinstance(next(self.connections[self.server_conn].events()), events.AcceptConnection)
 

From 92fc87bc4c3922758fe9862fd309886f06fcb16b Mon Sep 17 00:00:00 2001
From: rpigott <rpigott@berkeley.edu>
Date: Wed, 30 Jan 2019 23:44:35 -0800
Subject: [PATCH 3/4] whitespace error

---
 mitmproxy/proxy/protocol/websocket.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/protocol/websocket.py
index fd2aed157..191da7a28 100644
--- mitmproxy/proxy/protocol/websocket.py
+++ mitmproxy/proxy/protocol/websocket.py
@@ -64,7 +64,7 @@ def __init__(self, ctx, handshake_flow):
 
         if client_extensions:
             client_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
-		if server_extensions:
+        if server_extensions:
             server_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
 
         request = Request(extensions = client_extensions, host = handshake_flow.request.host, target = handshake_flow.request.path)

From 7fa209790e0df869d0d3c96ba75cb1a85943fc7c Mon Sep 17 00:00:00 2001
From: rpigott <rpigott@berkeley.edu>
Date: Thu, 31 Jan 2019 12:34:00 -0800
Subject: [PATCH 4/4] removed whitespace in kwargs

---
 mitmproxy/proxy/protocol/websocket.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/protocol/websocket.py
index 191da7a28..f5ac6a29d 100644
--- mitmproxy/proxy/protocol/websocket.py
+++ mitmproxy/proxy/protocol/websocket.py
@@ -67,7 +67,7 @@ def __init__(self, ctx, handshake_flow):
         if server_extensions:
             server_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
 
-        request = Request(extensions = client_extensions, host = handshake_flow.request.host, target = handshake_flow.request.path)
+        request = Request(extensions=client_extensions, host=handshake_flow.request.host, target=handshake_flow.request.path)
         data = self.connections[self.server_conn].send(request)
         self.connections[self.client_conn].receive_data(data)
 
@@ -129,11 +129,11 @@ def get_chunk(payload):
                             yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
 
                 for chunk, final in get_chunk(websocket_message.content):
-                    data = self.connections[other_conn].send(Message(data = chunk, message_finished = final))
+                    data = self.connections[other_conn].send(Message(data=chunk, message_finished=final))
                     other_conn.send(data)
 
         if self.flow.stream:
-            data = self.connections[other_conn].send(Message(data = event.data, message_finished = event.message_finished))
+            data = self.connections[other_conn].send(Message(data=event.data, message_finished=event.message_finished))
             other_conn.send(data)
         return True
 
@@ -174,7 +174,7 @@ def _inject_messages(self, endpoint, message_queue):
         while True:
             try:
                 payload = message_queue.get_nowait()
-                data = self.connections[endpoint].send(Message(data = payload, message_finished = True))
+                data = self.connections[endpoint].send(Message(data=payload, message_finished=True))
                 endpoint.send(data)
             except queue.Empty:
                 break