websocket support, issue with connection closing

This commit is contained in:
Daniel
2024-10-24 19:36:31 +03:00
parent ce4001680e
commit f61eb541d7
6 changed files with 368 additions and 1 deletions

View File

@@ -0,0 +1,67 @@
package io.github.lumijiez.core.ws;
import java.io.IOException;
import java.net.Socket;
import java.util.UUID;
public class WebSocketConnection {
private final Socket socket;
private final String id;
private final String path;
private boolean isOpen;
public WebSocketConnection(Socket socket, String path) {
this.socket = socket;
this.id = UUID.randomUUID().toString();
this.path = path;
this.isOpen = true;
}
public void send(String message) throws IOException {
if (!isOpen) throw new IOException("Connection is closed");
byte[] payload = message.getBytes();
WebSocketFrame frame = new WebSocketFrame();
frame.setFin(true);
frame.setOpcode(0x1);
frame.setPayload(payload);
frame.write(socket.getOutputStream());
}
public void sendPong() throws IOException {
if (!isOpen) return;
WebSocketFrame frame = new WebSocketFrame();
frame.setFin(true);
frame.setOpcode(0xA);
frame.setPayload(new byte[0]);
frame.write(socket.getOutputStream());
}
public void close() throws IOException {
if (!isOpen) return;
WebSocketFrame frame = new WebSocketFrame();
frame.setFin(true);
frame.setOpcode(0x8);
frame.setPayload(new byte[0]);
frame.write(socket.getOutputStream());
isOpen = false;
socket.close();
}
public String getId() {
return id;
}
public String getPath() {
return path;
}
public boolean isOpen() {
return isOpen;
}
}

View File

@@ -0,0 +1,91 @@
package io.github.lumijiez.core.ws;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
public class WebSocketFrame {
private boolean fin;
private byte opcode;
private byte[] payload;
private boolean masked;
public static WebSocketFrame read(InputStream in) throws IOException {
WebSocketFrame frame = new WebSocketFrame();
int firstByte = in.read();
if (firstByte == -1) return null;
frame.fin = (firstByte & 0x80) != 0;
frame.opcode = (byte)(firstByte & 0x0F);
int secondByte = in.read();
if (secondByte == -1) return null;
frame.masked = (secondByte & 0x80) != 0;
int payloadLength = secondByte & 0x7F;
if (payloadLength == 126) {
payloadLength = (in.read() << 8) | in.read();
} else if (payloadLength == 127) {
throw new IOException("Payload length too large");
}
byte[] maskingKey = new byte[4];
if (frame.masked) {
int bytesRead = in.read(maskingKey);
if (bytesRead != 4) return null;
}
frame.payload = new byte[payloadLength];
int bytesRead = in.read(frame.payload);
if (bytesRead != payloadLength) return null;
if (frame.masked) {
for (int i = 0; i < frame.payload.length; i++) {
frame.payload[i] ^= maskingKey[i % 4];
}
}
return frame;
}
public void write(OutputStream out) throws IOException {
int firstByte = (fin ? 0x80 : 0x00) | (opcode & 0x0F);
out.write(firstByte);
if (payload.length < 126) {
out.write(payload.length);
} else if (payload.length < 65536) {
out.write(126);
out.write(payload.length >> 8);
out.write(payload.length & 0xFF);
} else {
throw new IOException("Payload too large");
}
// Write payload
out.write(payload);
out.flush();
}
public void setFin(boolean fin) {
this.fin = fin;
}
public void setOpcode(int opcode) {
this.opcode = (byte) opcode;
}
public void setPayload(byte[] payload) {
this.payload = payload;
}
public byte getOpcode() {
return opcode;
}
public byte[] getPayload() {
return payload;
}
}

View File

@@ -0,0 +1,7 @@
package io.github.lumijiez.core.ws;
public interface WebSocketHandler {
void onConnect(WebSocketConnection connection);
void onMessage(WebSocketConnection connection, String message);
void onDisconnect(WebSocketConnection connection);
}

View File

@@ -0,0 +1,177 @@
package io.github.lumijiez.core.ws;
import io.github.lumijiez.logging.Logger;
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class WebSocketServer {
private boolean running;
private final int port;
private ServerSocket serverSocket;
private final ExecutorService threadPool;
private final ConcurrentHashMap<String, WebSocketConnection> connections;
private final Map<String, WebSocketHandler> handlers;
public WebSocketServer(int port) {
this.port = port;
this.running = false;
this.threadPool = Executors.newCachedThreadPool();
this.connections = new ConcurrentHashMap<>();
this.handlers = new HashMap<>();
}
public void addHandler(String path, WebSocketHandler handler) {
handlers.put(path, handler);
}
public void broadcast(String path, String message) {
connections.values().stream()
.filter(conn -> conn.getPath().equals(path))
.forEach(conn -> {
try {
conn.send(message);
} catch (IOException e) {
Logger.error("WS", "Error broadcasting message: " + e.getMessage());
}
});
}
public void start() {
try {
serverSocket = new ServerSocket(port);
running = true;
Logger.info("WS", "WebSocket server started on port " + port);
while (running) {
try {
Socket clientSocket = serverSocket.accept();
threadPool.submit(() -> handleClient(clientSocket));
} catch (IOException e) {
if (running) {
Logger.error("WS", "Error accepting WebSocket connection: " + e.getMessage());
}
}
}
} catch (IOException e) {
Logger.error("WS", "Error starting WebSocket server: " + e.getMessage());
} finally {
stop();
}
}
private void handleClient(Socket clientSocket) {
try {
BufferedReader in = new BufferedReader(new InputStreamReader(clientSocket.getInputStream()));
BufferedWriter out = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream()));
String line = in.readLine();
if (line == null) return;
String[] requestLine = line.split(" ");
if (requestLine.length != 3) return;
String path = requestLine[1];
WebSocketHandler handler = handlers.get(path);
if (handler == null) {
clientSocket.close();
return;
}
Map<String, String> headers = new HashMap<>();
while (!(line = in.readLine()).isEmpty()) {
String[] parts = line.split(": ", 2);
if (parts.length == 2) {
headers.put(parts[0].toLowerCase(), parts[1]);
}
}
String key = headers.get("sec-websocket-key");
if (key == null) {
clientSocket.close();
return;
}
String acceptKey = generateAcceptKey(key);
out.write("HTTP/1.1 101 Switching Protocols\r\n");
out.write("Upgrade: websocket\r\n");
out.write("Connection: Upgrade\r\n");
out.write("Sec-WebSocket-Accept: " + acceptKey + "\r\n");
out.write("\r\n");
out.flush();
WebSocketConnection connection = new WebSocketConnection(clientSocket, path);
String connId = connection.getId();
connections.put(connId, connection);
handler.onConnect(connection);
while (running && connection.isOpen()) {
WebSocketFrame frame = WebSocketFrame.read(clientSocket.getInputStream());
if (frame == null) break;
switch (frame.getOpcode()) {
case 0x1:
handler.onMessage(connection, new String(frame.getPayload()));
break;
case 0x8:
connection.close();
break;
case 0x9:
connection.sendPong();
break;
}
}
handler.onDisconnect(connection);
connections.remove(connId);
} catch (IOException e) {
Logger.error("WS", "Error handling WebSocket client: " + e.getMessage());
}
}
private String generateAcceptKey(String key) {
String GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
try {
MessageDigest md = MessageDigest.getInstance("SHA-1");
return Base64.getEncoder().encodeToString(
md.digest((key + GUID).getBytes())
);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
public void stop() {
running = false;
connections.values().forEach(conn -> {
try {
conn.close();
} catch (IOException e) {
Logger.error("WS", "Error closing connection: " + e.getMessage());
}
});
connections.clear();
if (serverSocket != null) {
try {
serverSocket.close();
Logger.info("WS", "WebSocket server stopped");
} catch (IOException e) {
Logger.error("WS", "Error stopping WebSocket server: " + e.getMessage());
}
}
threadPool.shutdownNow();
}
}

View File

@@ -3,6 +3,9 @@ package io.github.lumijiez.example;
import io.github.lumijiez.core.config.ServerConfig; import io.github.lumijiez.core.config.ServerConfig;
import io.github.lumijiez.core.http.HttpServer; import io.github.lumijiez.core.http.HttpServer;
import io.github.lumijiez.core.http.HttpStatus; import io.github.lumijiez.core.http.HttpStatus;
import io.github.lumijiez.core.ws.WebSocketConnection;
import io.github.lumijiez.core.ws.WebSocketHandler;
import io.github.lumijiez.core.ws.WebSocketServer;
import io.github.lumijiez.example.daos.ProductDao; import io.github.lumijiez.example.daos.ProductDao;
import io.github.lumijiez.example.models.Product; import io.github.lumijiez.example.models.Product;
import io.github.lumijiez.logging.Logger; import io.github.lumijiez.logging.Logger;
@@ -34,6 +37,27 @@ public class Main {
res.sendResponse(HttpStatus.OK, product.toString()); res.sendResponse(HttpStatus.OK, product.toString());
}); });
server.start(); WebSocketServer wsServer = new WebSocketServer(8081);
wsServer.addHandler("/chat", new WebSocketHandler() {
@Override
public void onConnect(WebSocketConnection connection) {
Logger.info("WS", "Client connected to chat: " + connection.getId());
}
@Override
public void onMessage(WebSocketConnection connection, String message) {
Logger.info("WS", "Received message: " + message);
wsServer.broadcast("/chat", message);
}
@Override
public void onDisconnect(WebSocketConnection connection) {
Logger.info("WS", "Client disconnected from chat: " + connection.getId());
}
});
new Thread(server::start).start();
new Thread(wsServer::start).start();
} }
} }

View File

@@ -58,6 +58,7 @@ public class ProductDao {
} }
} }
public void deleteProduct(int id) { public void deleteProduct(int id) {
Transaction transaction = null; Transaction transaction = null;
try (Session session = new Configuration().configure().buildSessionFactory().openSession()) { try (Session session = new Configuration().configure().buildSessionFactory().openSession()) {